KeiStory

반응형

ai 가 만들어 준 이미지;;

LangChain으로 도구 호출 승인 시스템 구현하기

 

이전 포스팅에서 LangChain의 도구(Tool) 기능과 스트리밍 처리 방법에 대해 알아보았습니다.

2026.01.20 - [코딩/Python_AI] - ChatOpenAI 와 FastAPI 를 이용해 Streaming 채팅 구현하기

2026.01.23 - [코딩/Python_AI] - ChatOpenAI 와 FastAPI 챗 서버에 Tool Calling 기능 추가하기

이번에는 실제 운영 환경에서 사용할 수 있는 도구 호출 승인 시스템 구현 방법을 알아봅니다.

 

왜 도구 호출 승인이 필요한가?

AI 에이전트가 자동으로 도구를 실행하는 것은 편리하지만, 때로는 위험할 수 있습니다:

  • 데이터베이스 삭제 작업
  • 외부 API 호출 (비용 발생)
  • 파일 시스템 변경
  • 이메일 발송 등

이러한 작업들은 실행 전에 사용자의 명시적인 승인을 받는 것이 안전합니다.

 

시스템 아키텍처

 

이번에 구현할 시스템은 다음과 같은 흐름으로 동작합니다.

  1. 사용자가 메시지를 전송
  2. AI가 응답을 생성하는 과정에서 도구 호출이 필요하다고 판단
  3. 서버는 도구를 즉시 실행하지 않고 승인 대기 상태로 중단
  4. 클라이언트에 “도구 실행 승인 요청” 이벤트 전달
  5. 사용자가 승인 또는 거부 선택
  6. 승인 시 도구 실행 → 결과를 기반으로 AI 최종 응답 생성

이 구조의 핵심 포인트는 다음 두 가지입니다.

  • SSE 스트리밍을 유지한 상태에서 승인 플로우를 분기
  • 도구 실행 전 상태(message + tool call 정보)를 서버에 임시 저장

 

Server

주요 구성 요소

pendingDictionary
승인 대기 중인 세션 정보를 임시 저장하는 메모리 저장소

/v1/chat/completion
일반 채팅 + 도구 호출 감지용 엔드포인트

/v1/chat/approve
도구 실행 승인/거부 처리 엔드포인트

generateStream()
최초 AI 응답 스트리밍 및 승인 요청 트리거

generateApprovalStream()
승인 이후 도구 실행 + 최종 응답 스트리밍

server.py

# uv add python-dotenv langchain langchain-openai fastapi uvicorn
import json
import uvicorn
import anyio
import uuid

from dotenv                  import load_dotenv
from pydantic                import BaseModel
from typing                  import List
from typing                  import Optional
from typing                  import Dict
from typing                  import Any
from pydantic                import Field
from langchain_core.tools    import tool
from langchain_core.messages import HumanMessage
from langchain_core.messages import AIMessage
from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolMessage
from langchain_openai        import ChatOpenAI
from fastapi                 import FastAPI
from fastapi                 import HTTPException
from fastapi.responses       import StreamingResponse

load_dotenv()

class Message(BaseModel):
    role    : str
    content : str

class ChatRequest(BaseModel):
    messages    : List[Message]
    model       : str             = "gpt-4o-mini"
    temperature : Optional[float] = Field(default = 0.7, ge = 0, le = 2)
    max_tokens  : Optional[int  ] = None

class ApprovalRequest(BaseModel):
    sessionid : str
    approved   : bool

@tool
def getStringLength(text: str) -> int:
    """입력된 문자열의 길이를 반환한다."""
    print(f"\n>>> [도구 실행] getStringLength 호출됨! 입력값: {text}")
    return len(text)

toolList       = [getStringLength]
toolDictionary = {tool.name : tool for tool in toolList}

pendingDictionary : Dict[str, Dict[str, Any]] = {}

def getLangchainMessageList(messageList : List[Message]):
    targetMessageList = []
    for message in messageList:
        if message.role == "system":
            targetMessageList.append(SystemMessage(content = message.content))
        elif message.role == "user":
            targetMessageList.append(HumanMessage(content = message.content))
        elif message.role == "assistant":
            targetMessageList.append(AIMessage(content = message.content))
        else:
            raise ValueError(f"Unknown role : {message.role}")
    return targetMessageList

async def generateStream(chatRequest : ChatRequest):
    try:
        chatOpenAI = (ChatOpenAI(
            model       = chatRequest.model,
            temperature = chatRequest.temperature,
            max_tokens  = chatRequest.max_tokens,
            streaming   = True
        ).bind_tools(toolList))

        messageList = getLangchainMessageList(chatRequest.messages)
        
        while True:
            fullAIMessageChunk = None
            async for aiMessageChunk in chatOpenAI.astream(messageList):
                if fullAIMessageChunk is None:
                    fullAIMessageChunk = aiMessageChunk
                else:
                    fullAIMessageChunk += aiMessageChunk
                if aiMessageChunk.content:
                    yield f"data: {json.dumps({'type' : 'content', 'content' : aiMessageChunk.content}, ensure_ascii = False)}\n\n"
            
            # 도구 호출이 있는 경우
            if fullAIMessageChunk and fullAIMessageChunk.tool_calls:
                # 세션 ID 생성
                sessionId = str(uuid.uuid4())
                
                # 도구 호출 정보 수집
                toolCallsInfo = []
                for toolCall in fullAIMessageChunk.tool_calls:
                    toolCallsInfo.append({
                        "id"        : toolCall["id"],
                        "name"      : toolCall["name"],
                        "arguments" : toolCall["args"]
                    })
                
                # 세션 정보 저장
                pendingDictionary[sessionId] = {
                    "messageList" : messageList,
                    "aiMessage"   : fullAIMessageChunk,
                    "toolCalls"   : toolCallsInfo
                }
                
                # 승인 요청 메시지 전송
                yield f"data: {json.dumps({'type' : 'approvalRequired', 'sessionId' : sessionId, 'toolCalls' : toolCallsInfo}, ensure_ascii = False)}\n\n"
                yield "data: [DONE]\n\n"
                break
            else:
                # 도구 호출이 없으면 정상 종료
                yield "data: [DONE]\n\n"
                break
                
    except Exception as exception:
        print(f"STREAMING ERROR : {str(exception)}")
        yield f"data: {json.dumps({'type' : 'error', 'error' : str(exception)}, ensure_ascii = False)}\n\n"

async def generateApprovalStream(sessionId : str, approved : bool):
    try:
        if sessionId not in pendingDictionary:
            yield f"data: {json.dumps({'type' : 'error', 'error' : 'Invalid session ID'}, ensure_ascii = False)}\n\n"
            return
        
        sessionData = pendingDictionary[sessionId]
        
        if not approved:
            yield f"data: {json.dumps({'type' : 'content', 'content' : '도구 실행이 사용자에 의해 취소되었습니다.'}, ensure_ascii = False)}\n\n"
            yield "data: [DONE]\n\n"
            del pendingDictionary[sessionId]
            return
        
        # 승인된 경우 - 도구 실행
        messageList     = sessionData["messageList"]
        fullAIMessage   = sessionData["aiMessage"]
        toolCallsInfo   = sessionData["toolCalls"]
        
        messageList.append(fullAIMessage)
        
        # 각 도구 실행
        for toolCallInfo in toolCallsInfo:
            toolName      = toolCallInfo["name"]
            toolArguments = toolCallInfo["arguments"]
            toolCallId    = toolCallInfo["id"]
            
            yield f"data: {json.dumps({'type' : 'toolExecuting', 'tool' : toolName, 'arguments' : toolArguments}, ensure_ascii = False)}\n\n"
            
            toolFunction = toolDictionary.get(toolName)
            if not toolFunction:
                result = f"Error : Unknown tool '{toolName}'"
            else:
                try:
                    result = await anyio.to_thread.run_sync(toolFunction.invoke, toolArguments)
                except Exception as exception:
                    result = f"Error executing {toolName} : {str(exception)}"
            
            toolMessage = ToolMessage(content = str(result), tool_call_id = toolCallId)
            messageList.append(toolMessage)
            
            yield f"data: {json.dumps({'type' : 'toolResult', 'tool' : toolName, 'result' : result}, ensure_ascii = False)}\n\n"
        
        # 도구 실행 결과를 바탕으로 AI 응답 재생성
        chatOpenAI = (ChatOpenAI(
            model       = "gpt-4o-mini",
            temperature = 0.7,
            streaming   = True
        ).bind_tools(toolList))
        
        newline_content = '\n'
        yield f"data: {json.dumps({'type' : 'content', 'content' : newline_content}, ensure_ascii = False)}\n\n"
        
        async for aiMessageChunk in chatOpenAI.astream(messageList):
            if aiMessageChunk.content:
                yield f"data: {json.dumps({'type' : 'content', 'content' : aiMessageChunk.content}, ensure_ascii = False)}\n\n"
        
        yield "data: [DONE]\n\n"
        
        # 세션 정리
        del pendingDictionary[sessionId]
        
    except Exception as exception:
        print(f"APPROVAL STREAMING ERROR : {str(exception)}")
        yield f"data: {json.dumps({'type' : 'error', 'error' : str(exception)}, ensure_ascii = False)}\n\n"

fastAPI = FastAPI()

@fastAPI.post("/v1/chat/completion")
async def processChatCompletion(chatRequest : ChatRequest):
    if not chatRequest.messages:
        raise HTTPException(status_code = 400, detail = "Messages cannot be empty")
    return StreamingResponse(generateStream(chatRequest), media_type = "text/event-stream", headers= {"Cache-Control" : "no-cache", "Connection" : "keep-alive"})

@fastAPI.post("/v1/chat/approve")
async def processApproval(approvalRequest : ApprovalRequest):
    return StreamingResponse(
        generateApprovalStream(approvalRequest.sessionid, approvalRequest.approved),
        media_type = "text/event-stream",
        headers = {"Cache-Control" : "no-cache", "Connection" : "keep-alive"}
    )

@fastAPI.get("/health")
async def processHealth():
    return {"status" : "healthy"}

if __name__ == "__main__":
    uvicorn.run(fastAPI, host = "0.0.0.0", port = 8000)

Client

클라이언트의 역할

- SSE 스트림을 실시간으로 수신
- 일반 텍스트 응답은 그대로 출력
- approvalRequired 이벤트 수신 시
    도구 이름과 인자를 사용자에게 출력
    승인 여부를 입력받아 서버에 전달
- 승인 이후 도구 실행 상태 및 결과를 실시간 표시

승인 UX 흐름

  1. 도구 호출 감지
  2. 콘솔에 도구 정보 출력
  3. yes / no 입력 요청
  4. 승인 결과를 /v1/chat/approve로 전송
  5. 도구 실행 로그 + 최종 AI 응답 스트리밍

client.py

# uv add httpx
import httpx
import json
import asyncio

from datetime import datetime
from typing   import List
from typing   import Dict
from typing   import Optional

class ChatClient:
    def __init__(self, serverURL : str = "http://localhost:8000"):
        self.serverURL                          = serverURL
        self.messageList : List[Dict[str, str]] = []
        self.timeout                            = httpx.Timeout(60.0, connect = 5.0)

    def addMessage(self, role : str, content : str) -> None:
        self.messageList.append(
            {
                "role"      : role,
                "content"   : content,
                "timestamp" : datetime.now().isoformat()
            }
        )

    def getMessageList(self) -> List[Dict[str, str]]:
        return [{"role" : message["role"], "content" : message["content"]} for message in self.messageList]

    def clearMessageList(self) -> None:
        self.messageList.clear()
        print("MESSAGE LIST CLEARED")

    async def checkServerHealth(self) -> bool:
        try:
            async with httpx.AsyncClient(timeout = self.timeout) as asyncClient:
                response = await asyncClient.get(f"{self.serverURL}/health")
                return response.status_code == 200
        except Exception as exception:
            print(f"HEALTH CHECK FAILED : {exception}")
            return False

    async def astream(self, userInput : str, model : str = "gpt-4o-mini", temperature : float = 0.7) -> Optional[str]:
        self.addMessage("user", userInput)
        requestDictionary = {
            "messages"    : self.getMessageList(),
            "model"       : model,
            "temperature" : temperature
        }
        
        async with httpx.AsyncClient(timeout = self.timeout) as asyncClient:
            try:
                async with asyncClient.stream("POST", f"{self.serverURL}/v1/chat/completion", json = requestDictionary) as response:
                    if response.status_code != 200:
                        print(f"HTTP ERROR : {response.status_code}")
                        return None
                    
                    fullResponse = ""
                    print("ASSISTANT : ", end = "", flush = True)
                    
                    async for line in response.aiter_lines():
                        if not line.startswith("data: "):
                            continue
                        
                        dataLine = line[6:]
                        if dataLine == "[DONE]":
                            break
                        
                        try:
                            dataDictionary = json.loads(dataLine)
                            dataType = dataDictionary.get("type")
                            
                            if dataType == "content":
                                content = dataDictionary["content"]
                                print(content, end = "", flush = True)
                                fullResponse += content
                            
                            elif dataType == "approvalRequired":
                                # 도구 승인 요청 처리
                                print()
                                sessionId = dataDictionary.get("sessionId")
                                toolCalls = dataDictionary.get("toolCalls")
                                
                                if sessionId and toolCalls:
                                    approvalResult = await self.requestApproval(sessionId, toolCalls)
                                    return approvalResult
                            
                            elif dataType == "error":
                                print()
                                print(f"ERROR : {dataDictionary.get('error')}")
                                return None
                        
                        except json.JSONDecodeError:
                            continue
                    
                    print()
                    if fullResponse:
                        self.addMessage("assistant", fullResponse)
                    return fullResponse
            
            except httpx.TimeoutException:
                print("REQUEST TIMEOUT")
            except httpx.ConnectError:
                print("CANNOT CONNECT TO SERVER")
            except Exception:
                print("UNEXPECTED ERROR", exc_info = True)
        
        return None

    def showMessageList(self) -> None:
        print()
        print("-" * 50)
        print("CURRENT MESSAGE LIST")
        print("-" * 50)
        if not self.messageList:
            print("(EMPTY)")
        else:
            for i, message in enumerate(self.messageList, 1):
                content = message["content"]
                if len(content) > 100:
                    content = content[:100] + "..."
                print(f"{i}. [{message['role'].upper()}] {content}")
        print("-" * 50)
        print()

    async def requestApproval(self, sessionId : str, toolCalls : List[Dict]) -> Optional[str]:
        """도구 호출 승인 처리"""
        print()
        print("=" * 60)
        print("도구 실행 승인 요청")
        print("=" * 60)
        
        for i, toolCall in enumerate(toolCalls, 1):
            print(f" 도구 이름: {toolCall['name']}]")
            print(f" 인자: {json.dumps(toolCall['arguments'], ensure_ascii=False, indent=2)}")
        
        print("=" * 60)
        
        while True:
            approval = input("도구 실행을 승인하시겠습니까? (yes(y)/no(n)): ").strip().lower()
            if approval in ("yes", "y"):
                approved = True
                break
            elif approval in ("no", "n"):
                approved = False
                break
            else:
                print("'yes(y)' 또는 'no(n)'를 입력해주세요.")
        
        print("=" * 60)
        
        # 승인 요청 전송
        return await self.sendRequestApproval(sessionId, approved)


    async def sendRequestApproval(self, sessionId : str, approved : bool) -> Optional[str]:
        """승인/거부 결정을 서버에 전송하고 결과 스트림 처리"""
        requestDictionary = {
            "sessionid" : sessionId,
            "approved"  : approved
        }
        
        async with httpx.AsyncClient(timeout = self.timeout) as asyncClient:
            try:
                async with asyncClient.stream("POST", f"{self.serverURL}/v1/chat/approve", json = requestDictionary) as response:
                    if response.status_code != 200:
                        print(f"HTTP ERROR : {response.status_code}")
                        return None
                    
                    fullResponse = ""
                    print("ASSISTANT : ", end = "", flush = True)
                    print()

                    async for line in response.aiter_lines():
                        if not line.startswith("data: "):
                            continue
                        
                        dataLine = line[6:]
                        if dataLine == "[DONE]":
                            break
                        
                        try:
                            dataDictionary = json.loads(dataLine)
                            dataType = dataDictionary.get("type")
                            
                            if dataType == "toolExecuting":
                                toolName = dataDictionary.get("tool")
                                arguments = dataDictionary.get("arguments")
                                print(f"[도구 실행 중: {toolName}({json.dumps(arguments, ensure_ascii=False)})]", flush = True)
                           
                            elif dataType == "toolResult":
                                toolName = dataDictionary.get("tool")
                                result = dataDictionary.get("result")
                                print(f"[도구 호출 결과: {toolName} → {result}]", flush = True)
                                
                            elif dataType == "content":
                                content = dataDictionary["content"]
                                print(content, end = "", flush = True)
                                fullResponse += content
                            
                            elif dataType == "error":
                                print(f"ERROR : {dataDictionary.get('error')}")
                                return None
                        
                        except json.JSONDecodeError:
                            continue
                    
                    print()
                    if fullResponse:
                        self.addMessage("assistant", fullResponse)
                    return fullResponse
            
            except httpx.TimeoutException:
                print("REQUEST TIMEOUT")
            except httpx.ConnectError:
                print("CANNOT CONNECT TO SERVER")
            except Exception:
                print("UNEXPECTED ERROR", exc_info = True)
        
        return None

async def main():
    chatClient = ChatClient()
    print("COMMANDS : quit | clear | show | health")
    print()
    if not await chatClient.checkServerHealth():
        print("SERVER HEALTH CHECK FAILED")
        print()
    
    while True:
        try:
            userInput = input("YOU : ").strip()
            if userInput in ("quit", "exit"):
                break
            elif userInput == "clear":
                chatClient.clearMessageList()
            elif userInput == "show":
                chatClient.showMessageList()
            elif userInput == "health":
                ok = await chatClient.checkServerHealth()
                print("HEALTHY" if ok else "UNHEALTHY")
            elif userInput:
                await chatClient.astream(userInput)
        except KeyboardInterrupt:
            break

if __name__ == "__main__":
    asyncio.run(main())

 

728x90

공유하기

facebook twitter kakaoTalk kakaostory naver band