import asyncio
import sys
from typing import Optional
from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from anthropic import Anthropic
from dotenv import load_dotenv

# 環境変数を読み込む
load_dotenv()  

class MCPClient:
    def __init__(self):
        """クライアントの初期化"""

        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()
        self.anthropic = Anthropic()


    async def connect_to_server(self, server_script_path: str):
        """MCPサーバに接続"""

        # 引数のチェック
        is_python = server_script_path.endswith('.py')
        is_js = server_script_path.endswith('.js')
        if not (is_python or is_js):
            raise ValueError("サーバスクリプトは.pyまたは.jsである必要があります。")

        # サーバパラメータの準備
        command = "python" if is_python else "node"
        server_params = StdioServerParameters(
            command=command,
            args=[server_script_path],
            env=None
        )

        # MCPサーバに接続
        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
        await self.session.initialize()

        # 利用可能ツールの一覧表示
        response = await self.session.list_tools()
        tools = response.tools
        print("\nMCPサーバに接続\n利用可能ツール:", [tool.name for tool in tools])


    async def process_query(self, query: str) -> str:
        """クエリの処理"""

        # メッセージリストの準備
        messages = [
            {
                "role": "user",
                "content": query
            }
        ]

        # 利用可能ツールの準備
        response = await self.session.list_tools()
        available_tools = [{
            "name": tool.name,
            "description": tool.description,
            "input_schema": tool.inputSchema
        } for tool in response.tools]

        # LLMの呼び出し
        response = self.anthropic.messages.create(
            model="claude-3-5-sonnet-20241022",
            max_tokens=1000,
            messages=messages,
            tools=available_tools
        )

        # 応答の処理
        final_text = []
        assistant_message_content = []
        for content in response.content:
            # テキスト
            if content.type == 'text':
                final_text.append(content.text)
                assistant_message_content.append(content)
            
            # ツール
            elif content.type == 'tool_use':
                tool_name = content.name
                tool_args = content.input

                # ツールの呼び出し
                result = await self.session.call_tool(tool_name, tool_args)
                final_text.append(f"ツール呼び出し: ({result})")
                assistant_message_content.append(content)
                messages.append({
                    "role": "assistant",
                    "content": assistant_message_content
                })
                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "tool_result",
                            "tool_use_id": content.id,
                            "content": result.content
                        }
                    ]
                })

                # LLMの呼び出し
                response = self.anthropic.messages.create(
                    model="claude-3-5-sonnet-20241022",
                    max_tokens=1000,
                    messages=messages,
                    tools=available_tools
                )
                final_text.append(response.content[0].text)
        return "\n".join(final_text)


    async def chat_loop(self):
        """MCPクライアントのチャットループ"""

        print("\nチャットループ開始")
        while True:
            try:
                query = input("\n> ").strip()
                response = await self.process_query(query)  # クエリの実行
                print("\n" + response)
            except Exception as e:
                print(f"\nError: {str(e)}")


    async def cleanup(self):
        """MCPクライアントのクリーンアップ"""
        await self.exit_stack.aclose()


async def main():
    """メイン"""

    # 引数が2個未満
    if len(sys.argv) < 2:
        print("Usage: python client.py <path_to_server_script>")
        sys.exit(1)

    # MCPクライアントの準備
    client = MCPClient()
    try:
        await client.connect_to_server(sys.argv[1])  # MCPサーバとの接続
        await client.chat_loop()  # MCPクライアントのチャットループ
    finally:
        await client.cleanup()  # MCPクライアントのクリーンアップ

if __name__ == "__main__":
    """メイン"""
    asyncio.run(main())