
Commit 6629cd1

CG-10967: Create Custom Langgraph Nodes + Retry Policy (#788)
- No more `create_react_agent` from `langgraph.prebuilt`: in `graph.py` we now have our own langgraph nodes.
- Retry policy with an initial interval of `30s` and a backoff factor of `0.5`: the first wait is `30s`, then `15s`, then `7.5s`, and so on (see the sketch below).
1 parent: 7de5af7
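
To make that schedule concrete, here is a minimal sketch, assuming langgraph's `RetryPolicy` dataclass (importable from `langgraph.types` in recent releases); only the two parameter values come from this commit.

```python
# A sketch of the retry parameters described above; not the commit's actual code.
from langgraph.types import RetryPolicy  # older langgraph exposed this via langgraph.pregel

# backoff_factor < 1 gives a *decreasing* schedule: wait_n = 30 * 0.5**n
retry_policy = RetryPolicy(initial_interval=30.0, backoff_factor=0.5)

for attempt in range(4):
    print(f"wait before attempt {attempt + 2}: {30.0 * 0.5**attempt:.1f}s")
# wait before attempt 2: 30.0s
# wait before attempt 3: 15.0s
# wait before attempt 4: 7.5s
# wait before attempt 5: 3.8s
```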

File tree: 11 files changed, +283 −73 lines

codegen-examples/examples/langchain_agent/run.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -20,7 +20,7 @@

 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph.graph import CompiledGraph
-from langgraph.prebuilt import create_react_agent
+from codegen.extensions.langchain.graph import create_react_agent
 from langchain_core.messages import SystemMessage


@@ -70,7 +70,7 @@ def create_codebase_agent(

     memory = MemorySaver() if memory else None

-    return create_react_agent(model=llm, tools=tools, prompt=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)


 if __name__ == "__main__":
```
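
For downstream callers the migration is mechanical: swap the import path and rename the `prompt` keyword to `system_message`. A minimal sketch, with placeholder model and tools (not this repo's real configuration):

```python
# Placeholder setup; any LangChain chat model and tool list would do here.
from langchain_core.messages import SystemMessage

from codegen.extensions.langchain.graph import create_react_agent  # was: langgraph.prebuilt

llm = ...  # e.g. a ChatAnthropic or ChatOpenAI instance
tools: list = []  # e.g. the tools from codegen.extensions.langchain.tools
system_message = SystemMessage(content="You are a helpful codebase agent.")

# The keyword is now `system_message` (previously `prompt` with the prebuilt agent).
agent = create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=None, debug=False)
```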

codegen-examples/examples/swebench_agent_run/local_run.ipynb

Lines changed: 11 additions & 41 deletions
```diff
@@ -7,7 +7,14 @@
    "outputs": [],
    "source": [
     "%load_ext autoreload\n",
-    "%autoreload 2"
+    "%autoreload 2\n",
+    "\n",
+    "from dotenv import load_dotenv  # type: ignore\n",
+    "\n",
+    "load_dotenv()\n",
+    "\n",
+    "from codegen.extensions.swebench.utils import SWEBenchDataset, get_swe_bench_examples  # noqa: E402\n",
+    "from run_eval import run_eval  # noqa: E402"
    ]
   },
   {
@@ -16,9 +23,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from codegen.sdk.core.codebase import Codebase\n",
-    "from codegen.extensions.swebench.utils import SWEBenchDataset, get_swe_bench_examples\n",
-    "from run_eval import run_eval"
+    "examples = get_swe_bench_examples(dataset=SWEBenchDataset.LITE, split=\"test\", offset=0, length=10)"
    ]
   },
   {
@@ -27,43 +32,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "examples = get_swe_bench_examples(dataset=SWEBenchDataset.LITE, split=\"test\", offset=0, length=1)"
+    "await run_eval(use_existing_preds=None, dataset=\"lite\", length=20, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "codebase = Codebase.from_repo(examples[0].repo, commit=examples[0].base_commit, tmp_dir=f\"/tmp/{examples[0].instance_id}\")\n",
-    "# this will allow us to reuse the codebase for multiple examples\n",
-    "codebases = {examples[0].instance_id: codebase}"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "await run_eval(use_existing_preds=None, dataset=\"lite\", length=None, instance_id=examples[0].instance_id, local=True, codebases=codebases)\n",
-    "codebases[examples[0].instance_id].reset()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
 "metadata": {
@@ -82,7 +52,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.13.1"
+   "version": "3.13.0"
   }
  },
 "nbformat": 4,
```

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -10,6 +10,7 @@ dependencies = [
     "tiktoken<1.0.0,>=0.5.1",
     "tabulate>=0.9.0,<1.0.0",
     "codeowners<1.0.0,>=0.6.0",
+    "anthropic",
     "dataclasses-json<1.0.0,>=0.6.4",
     "dicttoxml<2.0.0,>=1.7.16",
     "xmltodict<1.0.0,>=0.13.0",
```

src/codegen/agents/chat_agent.py

Lines changed: 95 additions & 0 deletions
```python
from typing import TYPE_CHECKING, Optional
from uuid import uuid4

from langchain.tools import BaseTool
from langchain_core.messages import AIMessage

from codegen.extensions.langchain.agent import create_chat_agent

if TYPE_CHECKING:
    from codegen import Codebase


class ChatAgent:
    """Agent for interacting with a codebase."""

    def __init__(self, codebase: "Codebase", model_provider: str = "anthropic", model_name: str = "claude-3-5-sonnet-latest", memory: bool = True, tools: Optional[list[BaseTool]] = None, **kwargs):
        """Initialize a ChatAgent.

        Args:
            codebase: The codebase to operate on
            model_provider: The model provider to use ("anthropic" or "openai")
            model_name: Name of the model to use
            memory: Whether to let LLM keep track of the conversation history
            tools: Additional tools to use
            **kwargs: Additional LLM configuration options. Supported options:
                - temperature: Temperature parameter (0-1)
                - top_p: Top-p sampling parameter (0-1)
                - top_k: Top-k sampling parameter (>= 1)
                - max_tokens: Maximum number of tokens to generate
        """
        self.codebase = codebase
        self.agent = create_chat_agent(self.codebase, model_provider=model_provider, model_name=model_name, memory=memory, additional_tools=tools, **kwargs)

    def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
        """Run the agent with a prompt.

        Args:
            prompt: The prompt to run
            thread_id: Optional thread ID for message history. If None, a new thread is created.

        Returns:
            The agent's response
        """
        if thread_id is None:
            thread_id = str(uuid4())

        input = {"query": prompt}
        stream = self.agent.stream(input, config={"configurable": {"thread_id": thread_id}}, stream_mode="values")

        for s in stream:
            message = s["messages"][-1]
            if isinstance(message, tuple):
                print(message)
            else:
                if isinstance(message, AIMessage) and isinstance(message.content, list) and "text" in message.content[0]:
                    AIMessage(message.content[0]["text"]).pretty_print()
                else:
                    message.pretty_print()

        return s["final_answer"]

    def chat(self, prompt: str, thread_id: Optional[str] = None) -> tuple[str, str]:
        """Chat with the agent, maintaining conversation history.

        Args:
            prompt: The user message
            thread_id: Optional thread ID for message history. If None, a new thread is created.

        Returns:
            A tuple of (response_content, thread_id) to allow continued conversation
        """
        if thread_id is None:
            thread_id = str(uuid4())
            print(f"Starting new chat thread: {thread_id}")
        else:
            print(f"Continuing chat thread: {thread_id}")

        response = self.run(prompt, thread_id=thread_id)
        return response, thread_id

    def get_chat_history(self, thread_id: str) -> list:
        """Retrieve the chat history for a specific thread.

        Args:
            thread_id: The thread ID to retrieve history for

        Returns:
            List of messages in the conversation history
        """
        # Access the agent's memory to get conversation history
        if hasattr(self.agent, "get_state"):
            state = self.agent.get_state({"configurable": {"thread_id": thread_id}})
            if state and "messages" in state:
                return state["messages"]
        return []
```
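
A hedged usage sketch of the new `ChatAgent`. The repo slug is a placeholder, and `Codebase.from_repo` is assumed to accept it (the removed notebook cell above used `Codebase.from_repo(examples[0].repo, ...)`):

```python
# Illustrative only: assumes ANTHROPIC_API_KEY is set in the environment and
# that Codebase.from_repo takes an "owner/repo" slug.
from codegen import Codebase
from codegen.agents.chat_agent import ChatAgent

codebase = Codebase.from_repo("fastapi/fastapi")  # placeholder repo
agent = ChatAgent(codebase)

# First turn creates a thread; reuse the returned thread_id to keep history.
answer, thread_id = agent.chat("What does the routing module do?")
follow_up, _ = agent.chat("List the files involved.", thread_id=thread_id)
```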

src/codegen/agents/code_agent.py

Lines changed: 18 additions & 5 deletions
```diff
@@ -3,7 +3,7 @@
 from uuid import uuid4

 from langchain.tools import BaseTool
-from langchain_core.messages import AIMessage
+from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.runnables.config import RunnableConfig
 from langsmith import Client

@@ -94,8 +94,17 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:

         # this message has a reducer which appends the current message to the existing history
         # see more https://langchain-ai.github.io/langgraph/concepts/low_level/#reducers
-        input = {"messages": [("user", prompt)]}
-        tags, metadata = self.get_tags_metadata()
+        input = {"query": prompt}
+        metadata = {"project": self.project_name}
+        tags = []
+        # Add SWEBench run ID and instance ID to the metadata and tags for filtering
+        if self.run_id is not None:
+            metadata["swebench_run_id"] = self.run_id
+            tags.append(self.run_id)
+
+        if self.instance_id is not None:
+            metadata["swebench_instance_id"] = self.instance_id
+            tags.append(self.instance_id)

         config = RunnableConfig(configurable={"thread_id": thread_id}, tags=tags, metadata=metadata, recursion_limit=100)
         # we stream the steps instead of invoke because it allows us to access intermediate nodes
@@ -105,7 +114,11 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
         run_ids = []

         for s in stream:
-            message = s["messages"][-1]
+            if len(s["messages"]) == 0:
+                message = HumanMessage(content=prompt)
+            else:
+                message = s["messages"][-1]
+
             if isinstance(message, tuple):
                 print(message)
             else:
@@ -119,7 +132,7 @@ def run(self, prompt: str, thread_id: Optional[str] = None) -> str:
                 run_ids.append(message.additional_kwargs["run_id"])

         # Get the last message content
-        result = s["messages"][-1].content
+        result = s["final_answer"]

         # Try to find run IDs in the LangSmith client's recent runs
         try:
```
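
The input/output change is worth pausing on: `run` now feeds `{"query": prompt}` into the graph and reads the result from `s["final_answer"]` instead of the last message's content. That implies the custom graph defines its own state schema; a speculative sketch of its shape, inferred purely from these call sites (the real definition lives in `graph.py`):

```python
# Hypothetical reconstruction, not the actual graph.py definition: the fields
# are inferred from the call sites in this diff.
from typing import Annotated, TypedDict

from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages


class AgentState(TypedDict):
    query: str  # the user prompt, passed in as {"query": prompt}
    messages: Annotated[list[AnyMessage], add_messages]  # reducer appends to history
    final_answer: str  # set by the terminal node; read back as s["final_answer"]
```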

src/codegen/extensions/langchain/agent.py

Lines changed: 9 additions & 8 deletions
```diff
@@ -6,11 +6,10 @@
 from langchain_core.messages import SystemMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph.graph import CompiledGraph
-from langgraph.prebuilt import create_react_agent

-from .llm import LLM
-from .prompts import REASONER_SYSTEM_MESSAGE
-from .tools import (
+from codegen.extensions.langchain.llm import LLM
+from codegen.extensions.langchain.prompts import REASONER_SYSTEM_MESSAGE
+from codegen.extensions.langchain.tools import (
     CreateFileTool,
     DeleteFileTool,
     ListDirectoryTool,
@@ -25,6 +24,8 @@
     ViewFileTool,
 )

+from .graph import create_react_agent
+
 if TYPE_CHECKING:
     from codegen import Codebase

@@ -88,7 +89,7 @@ def create_codebase_agent(

     memory = MemorySaver() if memory else None

-    return create_react_agent(model=llm, tools=tools, prompt=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)


 def create_chat_agent(
@@ -137,7 +138,7 @@ def create_chat_agent(

     memory = MemorySaver() if memory else None

-    return create_react_agent(model=llm, tools=tools, prompt=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)


 def create_codebase_inspector_agent(
@@ -174,7 +175,7 @@ def create_codebase_inspector_agent(
     ]

     memory = MemorySaver() if memory else None
-    return create_react_agent(model=llm, tools=tools, prompt=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)


 def create_agent_with_tools(
@@ -208,4 +209,4 @@ def create_agent_with_tools(

     memory = MemorySaver() if memory else None

-    return create_react_agent(model=llm, tools=tools, prompt=system_message, checkpointer=memory, debug=debug)
+    return create_react_agent(model=llm, tools=tools, system_message=system_message, checkpointer=memory, debug=debug)
```
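
Putting the pieces together, the following is a speculative sketch of what a custom `create_react_agent` in `graph.py` could look like: a hand-rolled ReAct loop with the commit's retry policy attached to each node. The node names, the `MessagesState` schema, and the wiring are assumptions; only the `system_message`/`checkpointer`/`debug` keywords and the retry parameters come from this commit.

```python
# Speculative sketch of a custom create_react_agent; NOT the committed graph.py.
from langchain_core.messages import SystemMessage
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.types import RetryPolicy


def create_react_agent(model, tools, system_message: SystemMessage, checkpointer=None, debug: bool = False):
    # The commit's retry parameters: 30s first wait, halved each attempt (30s, 15s, 7.5s, ...).
    # Newer langgraph releases spell this kwarg `retry_policy` on add_node.
    retry = RetryPolicy(initial_interval=30.0, backoff_factor=0.5)
    model_with_tools = model.bind_tools(tools)

    def reasoner(state: MessagesState) -> dict:
        # Call the LLM on the running history, prepending the system message.
        response = model_with_tools.invoke([system_message, *state["messages"]])
        return {"messages": [response]}

    def should_continue(state: MessagesState) -> str:
        # Keep looping through the tool node while the model requests tool calls.
        last = state["messages"][-1]
        return "tools" if getattr(last, "tool_calls", None) else END

    builder = StateGraph(MessagesState)
    builder.add_node("reasoner", reasoner, retry=retry)
    builder.add_node("tools", ToolNode(tools), retry=retry)
    builder.add_edge(START, "reasoner")
    builder.add_conditional_edges("reasoner", should_continue)
    builder.add_edge("tools", "reasoner")
    return builder.compile(checkpointer=checkpointer, debug=debug)
```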
