Skip to content

Commit 23d0e89

Browse files
authored
Add langchain tool support (#8292)
* add langchain tool support * rename * update dependencies
1 parent e2b5fe9 commit 23d0e89

File tree

9 files changed

+350
-52
lines changed

9 files changed

+350
-52
lines changed

docs/docs/api/primitives/Tool.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- __call__
99
- acall
1010
- from_mcp_tool
11+
- from_langchain
1112
show_source: true
1213
show_root_heading: true
1314
heading_level: 2

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
git+https://github.com/stanfordnlp/dspy.git
1+
-e ..
22
mkdocs-material
33
mkdocs-jupyter
44
mkdocs-material[imaging]

dspy/predict/react.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
1616
"""
1717
`tools` is either a list of functions, callable classes, or `dspy.Tool` instances.
1818
"""
19-
19+
super().__init__()
2020
self.signature = signature = ensure_signature(signature)
2121
self.max_iters = max_iters
2222

dspy/primitives/tool.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import inspect
3-
from typing import TYPE_CHECKING, Any, Callable, Optional, get_origin, get_type_hints
3+
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, get_origin, get_type_hints
44

55
from jsonschema import ValidationError, validate
66
from pydantic import BaseModel, TypeAdapter, create_model
@@ -9,7 +9,9 @@
99

1010
if TYPE_CHECKING:
1111
import mcp
12+
from langchain.tools import BaseTool
1213

14+
_TYPE_MAPPING = {"string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict}
1315

1416
class Tool:
1517
"""Tool class.
@@ -94,7 +96,7 @@ def _parse_function(self, func: Callable, arg_desc: Optional[dict[str, str]] = N
9496
origin = get_origin(v) or v
9597
if isinstance(origin, type) and issubclass(origin, BaseModel):
9698
# Get json schema, and replace $ref with the actual schema
97-
v_json_schema = resolve_json_schema_reference(v.model_json_schema())
99+
v_json_schema = _resolve_json_schema_reference(v.model_json_schema())
98100
args[k] = v_json_schema
99101
else:
100102
args[k] = TypeAdapter(v).json_schema()
@@ -172,6 +174,21 @@ def from_mcp_tool(cls, session: "mcp.client.session.ClientSession", tool: "mcp.t
172174

173175
return convert_mcp_tool(session, tool)
174176

177+
@classmethod
178+
def from_langchain(cls, tool: "BaseTool") -> "Tool":
179+
"""
180+
Build a DSPy tool from a LangChain tool.
181+
182+
Args:
183+
tool: The LangChain tool to convert.
184+
185+
Returns:
186+
A Tool object.
187+
"""
188+
from dspy.utils.langchain_tool import convert_langchain_tool
189+
190+
return convert_langchain_tool(tool)
191+
175192
def __repr__(self):
176193
return f"Tool(name={self.name}, desc={self.desc}, args={self.args})"
177194

@@ -181,7 +198,7 @@ def __str__(self):
181198
return f"{self.name}{desc} {arg_desc}"
182199

183200

184-
def resolve_json_schema_reference(schema: dict) -> dict:
201+
def _resolve_json_schema_reference(schema: dict) -> dict:
185202
"""Recursively resolve json model schema, expanding all references."""
186203

187204
# If there are no definitions to resolve, return the main schema
@@ -205,3 +222,35 @@ def resolve_refs(obj: Any) -> Any:
205222
# Remove the $defs key as it's no longer needed
206223
resolved_schema.pop("$defs", None)
207224
return resolved_schema
225+
226+
227+
def convert_input_schema_to_tool_args(
228+
schema: dict[str, Any],
229+
) -> Tuple[dict[str, Any], dict[str, Type], dict[str, str]]:
230+
"""Convert an input json schema to tool arguments compatible with DSPy Tool.
231+
232+
Args:
233+
schema: An input json schema describing the tool's input parameters
234+
235+
Returns:
236+
A tuple of (args, arg_types, arg_desc) for DSPy Tool definition.
237+
"""
238+
args, arg_types, arg_desc = {}, {}, {}
239+
properties = schema.get("properties", None)
240+
if properties is None:
241+
return args, arg_types, arg_desc
242+
243+
required = schema.get("required", [])
244+
245+
defs = schema.get("$defs", {})
246+
247+
for name, prop in properties.items():
248+
if len(defs) > 0:
249+
prop = _resolve_json_schema_reference({"$defs": defs, **prop})
250+
args[name] = prop
251+
arg_types[name] = _TYPE_MAPPING.get(prop.get("type"), Any)
252+
arg_desc[name] = prop.get("description", "No description provided.")
253+
if name in required:
254+
arg_desc[name] += " (Required)"
255+
256+
return args, arg_types, arg_desc

dspy/utils/langchain_tool.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import TYPE_CHECKING, Any
2+
from dspy.primitives.tool import Tool, convert_input_schema_to_tool_args
3+
4+
if TYPE_CHECKING:
5+
from langchain.tools import BaseTool
6+
7+
8+
9+
def convert_langchain_tool(tool: "BaseTool") -> Tool:
10+
"""Build a DSPy tool from a LangChain tool.
11+
12+
This function converts a LangChain tool (either created with @tool decorator
13+
or by subclassing BaseTool) into a DSPy Tool.
14+
15+
Args:
16+
tool: The LangChain tool to convert.
17+
18+
Returns:
19+
A DSPy Tool object.
20+
"""
21+
async def func(**kwargs):
22+
try:
23+
result = await tool.ainvoke(kwargs)
24+
return result
25+
except Exception as e:
26+
raise RuntimeError(f"Failed to call LangChain tool {tool.name}: {str(e)}")
27+
28+
# Get args_schema from the tool
29+
# https://python.langchain.com/api_reference/core/tools/langchain_core.tools.base.BaseTool.html#langchain_core.tools.base.BaseTool.args_schema
30+
args_schema = tool.args_schema
31+
args, _, arg_desc = convert_input_schema_to_tool_args(args_schema.model_json_schema())
32+
33+
# The args_schema of Langchain tool is a pydantic model, so we can get the type hints from the model fields
34+
arg_types = {
35+
key: field.annotation if field.annotation is not None else Any
36+
for key, field in args_schema.model_fields.items()
37+
}
38+
39+
return Tool(
40+
func=func,
41+
name=tool.name,
42+
desc=tool.description,
43+
args=args,
44+
arg_types=arg_types,
45+
arg_desc=arg_desc
46+
)

dspy/utils/mcp.py

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,10 @@
1-
from typing import TYPE_CHECKING, Any, Tuple, Type, Union
1+
from typing import TYPE_CHECKING, Any, Union
22

3-
from dspy.primitives.tool import Tool, resolve_json_schema_reference
3+
from dspy.primitives.tool import Tool, convert_input_schema_to_tool_args
44

55
if TYPE_CHECKING:
66
import mcp
77

8-
TYPE_MAPPING = {"string": str, "integer": int, "number": float, "boolean": bool, "array": list, "object": dict}
9-
10-
11-
def _convert_input_schema_to_tool_args(
12-
schema: dict[str, Any],
13-
) -> Tuple[dict[str, Any], dict[str, Type], dict[str, str]]:
14-
"""Convert an input schema to tool arguments compatible with DSPy Tool.
15-
16-
Args:
17-
schema: An input schema describing the tool's input parameters
18-
19-
Returns:
20-
A tuple of (args, arg_types, arg_desc) for DSPy Tool definition.
21-
"""
22-
args, arg_types, arg_desc = {}, {}, {}
23-
properties = schema.get("properties", None)
24-
if properties is None:
25-
return args, arg_types, arg_desc
26-
27-
required = schema.get("required", [])
28-
29-
defs = schema.get("$defs", {})
30-
31-
for name, prop in properties.items():
32-
if len(defs) > 0:
33-
prop = resolve_json_schema_reference({"$defs": defs, **prop})
34-
args[name] = prop
35-
# MCP tools are validated through jsonschema using args, so arg_types are not strictly required.
36-
arg_types[name] = TYPE_MAPPING.get(prop.get("type"), Any)
37-
arg_desc[name] = prop.get("description", "No description provided.")
38-
if name in required:
39-
arg_desc[name] += " (Required)"
40-
41-
return args, arg_types, arg_desc
42-
438

449
def _convert_mcp_tool_result(call_tool_result: "mcp.types.CallToolResult") -> Union[str, list[Any]]:
4510
from mcp.types import TextContent
@@ -72,7 +37,7 @@ def convert_mcp_tool(session: "mcp.client.session.ClientSession", tool: "mcp.typ
7237
Returns:
7338
A dspy Tool object.
7439
"""
75-
args, arg_types, arg_desc = _convert_input_schema_to_tool_args(tool.inputSchema)
40+
args, arg_types, arg_desc = convert_input_schema_to_tool_args(tool.inputSchema)
7641

7742
# Convert the MCP tool and Session to a single async method
7843
async def func(*args, **kwargs):

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ anthropic = ["anthropic>=0.18.0,<1.0.0"]
5252
weaviate = ["weaviate-client~=4.5.4"]
5353
aws = ["boto3~=1.34.78"]
5454
mcp = ["mcp; python_version >= '3.10'"]
55+
langchain = ["langchain_core"]
5556
dev = [
5657
"pytest>=6.2.5",
5758
"pytest-mock>=3.12.0",
@@ -66,6 +67,7 @@ dev = [
6667
]
6768
test_extras = [
6869
"mcp; python_version >= '3.10'",
70+
"langchain_core",
6971
]
7072

7173
[tool.setuptools.packages.find]

tests/utils/test_langchain_tool.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import importlib
2+
import pytest
3+
4+
if importlib.util.find_spec("langchain_core") is None:
5+
pytest.skip(reason="langchain_core is not installed", allow_module_level=True)
6+
7+
from pydantic import BaseModel
8+
9+
from dspy.utils.langchain_tool import convert_langchain_tool
10+
11+
12+
@pytest.mark.asyncio
13+
@pytest.mark.extra
14+
async def test_convert_custom_simple_tool():
15+
from langchain_core.tools import tool
16+
17+
@tool
18+
def add(a: int, b: int) -> int:
19+
"""Add two numbers."""
20+
return a + b
21+
22+
tool = convert_langchain_tool(add)
23+
assert tool.name == "add"
24+
assert tool.desc == "Add two numbers."
25+
assert tool.args == {"a": {"title": "A", "type": "integer"}, "b": {"title": "B", "type": "integer"}}
26+
assert tool.arg_types == {"a": int, "b": int}
27+
assert tool.arg_desc == {"a": "No description provided. (Required)", "b": "No description provided. (Required)"}
28+
assert await tool.acall(a=1, b=2) == 3
29+
30+
31+
@pytest.mark.asyncio
32+
@pytest.mark.extra
33+
async def test_convert_custom_tool_with_custom_class():
34+
from langchain_core.tools import tool
35+
36+
class Profile(BaseModel):
37+
name: str
38+
age: int
39+
40+
@tool
41+
def get_age(profile: Profile) -> int:
42+
"""Get the age of the profile."""
43+
return profile.age
44+
45+
tool = convert_langchain_tool(get_age)
46+
assert tool.name == "get_age"
47+
assert tool.desc == "Get the age of the profile."
48+
assert tool.args == {"profile": {"title": "Profile", "type": "object", "properties": {"name": {"title": "Name", "type": "string"}, "age": {"title": "Age", "type": "integer"}}, "required": ["name", "age"]}}
49+
assert tool.arg_types == {"profile": Profile}
50+
assert tool.arg_desc == {"profile": "No description provided. (Required)"}
51+
assert await tool.acall(profile=Profile(name="John", age=20)) == 20

0 commit comments

Comments
 (0)