Skip to content

Commit 9e6935d

Browse files
committed
调整为自动注册工具类实例,便于后续扩展。
1 parent 59f81ae commit 9e6935d

File tree

2 files changed

+60
-51
lines changed

2 files changed

+60
-51
lines changed

src/handles/base.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,64 @@
1-
from typing import Dict, Any, Sequence
1+
from typing import Dict, Any, Sequence, Type, ClassVar
22

33
from mcp.types import TextContent, Tool
44

55

6-
class BaseHandler:
6+
class ToolRegistry:
7+
"""工具注册表,用于管理所有工具实例"""
8+
_tools: ClassVar[Dict[str, 'BaseHandler']] = {}
9+
10+
@classmethod
11+
def register(cls, tool_class: Type['BaseHandler']) -> Type['BaseHandler']:
12+
"""注册工具类
13+
14+
Args:
15+
tool_class: 要注册的工具类
16+
17+
Returns:
18+
返回注册的工具类,方便作为装饰器使用
19+
"""
20+
tool = tool_class()
21+
cls._tools[tool.name] = tool
22+
return tool_class
23+
24+
@classmethod
25+
def get_tool(cls, name: str) -> 'BaseHandler':
26+
"""获取工具实例
27+
28+
Args:
29+
name: 工具名称
30+
31+
Returns:
32+
工具实例
33+
34+
Raises:
35+
ValueError: 当工具不存在时抛出
36+
"""
37+
if name not in cls._tools:
38+
raise ValueError(f"未知的工具: {name}")
39+
return cls._tools[name]
40+
41+
@classmethod
42+
def get_all_tools(cls) -> list[Tool]:
43+
"""获取所有工具的描述
44+
45+
Returns:
46+
所有工具的描述列表
47+
"""
48+
return [tool.get_tool_description() for tool in cls._tools.values()]
749

50+
51+
class BaseHandler:
52+
"""工具基类"""
853
name: str = ""
954
description: str = ""
1055

56+
def __init_subclass__(cls, **kwargs):
57+
"""子类初始化时自动注册到工具注册表"""
58+
super().__init_subclass__(**kwargs)
59+
if cls.name: # 只注册有名称的工具
60+
ToolRegistry.register(cls)
61+
1162
def get_tool_description(self) -> Tool:
1263
raise NotImplementedError
1364

src/server.py

Lines changed: 7 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,49 +10,18 @@
1010
from starlette.applications import Starlette
1111
from starlette.routing import Route, Mount
1212

13-
from handles import (
14-
ExecuteSQL,
15-
GetChineseInitials,
16-
GetTableIndex,
17-
GetTableLock,
18-
GetTableName,
19-
GetTableDesc
20-
)
21-
22-
# 初始化工具实例
23-
execute_sql = ExecuteSQL()
24-
get_chinese_initials = GetChineseInitials()
25-
get_table_index = GetTableIndex()
26-
get_table_desc = GetTableDesc()
27-
get_table_name = GetTableName()
28-
get_table_lock = GetTableLock()
29-
13+
from handles.base import ToolRegistry
3014

3115
# 初始化服务器
3216
app = Server("operateMysql")
3317

3418

3519
@app.list_tools()
3620
async def list_tools() -> list[Tool]:
37-
"""列出所有可用的MySQL操作工具
38-
39-
Returns:
40-
list[Tool]: 返回工具列表,包含:
41-
- execute_sql: 执行SQL语句
42-
- get_chinese_initials: 获取中文拼音首字母
43-
- get_table_index: 获取表索引信息
44-
- get_table_desc: 获取表结构描述
45-
- get_table_name: 获取表名
46-
- get_table_lock: 获取表锁信息
4721
"""
48-
return [
49-
execute_sql.get_tool_description(),
50-
get_chinese_initials.get_tool_description(),
51-
get_table_index.get_tool_description(),
52-
get_table_desc.get_tool_description(),
53-
get_table_name.get_tool_description(),
54-
get_table_lock.get_tool_description()
55-
]
22+
列出所有可用的MySQL操作工具
23+
"""
24+
return ToolRegistry.get_all_tools()
5625

5726
@app.call_tool()
5827
async def call_tool(name: str, arguments: dict) -> Sequence[TextContent]:
@@ -68,20 +37,9 @@ async def call_tool(name: str, arguments: dict) -> Sequence[TextContent]:
6837
Raises:
6938
ValueError: 当指定了未知的工具名称时抛出异常
7039
"""
71-
if name == execute_sql.name:
72-
return await execute_sql.run_tool(arguments)
73-
elif name == get_table_index.name:
74-
return await get_table_index.run_tool(arguments)
75-
elif name == get_table_name.name:
76-
return await get_table_name.run_tool(arguments)
77-
elif name == get_table_lock.name:
78-
return await get_table_lock.run_tool(arguments)
79-
elif name == get_table_desc.name:
80-
return await get_table_desc.run_tool(arguments)
81-
elif name == get_chinese_initials.name:
82-
return await get_chinese_initials.run_tool(arguments)
83-
84-
raise ValueError(f"未知的工具: {name}")
40+
tool = ToolRegistry.get_tool(name)
41+
42+
return await tool.run_tool(arguments)
8543

8644

8745
async def run_stdio():

0 commit comments

Comments
 (0)