Skip to content

Commit 998b08d

Browse files
authored
feat: BashCommandTool (#510)
1 parent ac677f3 commit 998b08d

File tree

3 files changed

+306
-1
lines changed

3 files changed

+306
-1
lines changed

src/codegen/extensions/langchain/tools.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""Langchain tools for workspace operations."""
22

33
import json
4-
from typing import ClassVar, Literal, Optional
4+
from typing import Callable, ClassVar, Literal, Optional
55

66
from langchain.tools import BaseTool
77
from pydantic import BaseModel, Field
88

99
from codegen import Codebase
1010
from codegen.extensions.linear.linear_client import LinearClient
11+
from codegen.extensions.tools.bash import run_bash_command
1112
from codegen.extensions.tools.linear.linear import (
1213
linear_comment_on_issue_tool,
1314
linear_create_issue_tool,
@@ -16,6 +17,7 @@
1617
linear_get_teams_tool,
1718
linear_search_issues_tool,
1819
)
20+
from codegen.extensions.tools.link_annotation import add_links_to_message
1921

2022
from ..tools import (
2123
commit,
@@ -354,6 +356,30 @@ def _run(self, query: str, k: int = 5, preview_length: int = 200) -> str:
354356
return json.dumps(result, indent=2)
355357

356358

359+
########################################################################################################################
360+
# BASH
361+
########################################################################################################################
362+
363+
364+
class RunBashCommandInput(BaseModel):
    """Argument schema for the bash command tool."""

    # Command line to execute.
    command: str = Field(
        ...,
        description="The command to run",
    )
    # Background flag; defaults to False.
    is_background: bool = Field(
        default=False,
        description="Whether to run the command in the background",
    )
369+
370+
371+
class RunBashCommandTool(BaseTool):
372+
"""Tool for running bash commands."""
373+
374+
name: ClassVar[str] = "run_bash_command"
375+
description: ClassVar[str] = "Run a bash command and return its output"
376+
args_schema: ClassVar[type[BaseModel]] = RunBashCommandInput
377+
378+
def _run(self, command: str, is_background: bool = False) -> str:
379+
result = run_bash_command(command, is_background)
380+
return json.dumps(result, indent=2)
381+
382+
357383
########################################################################################################################
358384
# GITHUB
359385
########################################################################################################################
@@ -607,6 +633,43 @@ def _run(self) -> str:
607633
return json.dumps(result, indent=2)
608634

609635

636+
########################################################################################################################
637+
# SLACK
638+
########################################################################################################################
639+
640+
641+
class SlackSendMessageInput(BaseModel):
    """Argument schema for the Slack message tool."""

    # Text of the message to post.
    content: str = Field(
        ...,
        description="Message to send to Slack",
    )
645+
646+
647+
class SlackSendMessageTool(BaseTool):
648+
"""Tool for sending a message to Slack."""
649+
650+
name: ClassVar[str] = "send_slack_message"
651+
description: ClassVar[str] = (
652+
"Send a message via Slack."
653+
"Write symbol names (classes, functions, etc.) or full filepaths in single backticks and they will be auto-linked to the code."
654+
"Use Slack-style markdown for other links."
655+
)
656+
args_schema: ClassVar[type[BaseModel]] = SlackSendMessageInput
657+
say: Callable[[str], None] = Field(exclude=True)
658+
codebase: Codebase = Field(exclude=True)
659+
660+
def __init__(self, codebase: Codebase, say: Callable[[str], None]) -> None:
661+
super().__init__(say=say, codebase=codebase)
662+
self.say = say
663+
self.codebase = codebase
664+
665+
def _run(self, content: str) -> str:
666+
print("> Adding links to message")
667+
content_formatted = add_links_to_message(content, self.codebase)
668+
print("> Sending message to Slack")
669+
self.say(content_formatted)
670+
return "✅ Message sent successfully"
671+
672+
610673
########################################################################################################################
611674
# EXPORT
612675
########################################################################################################################
@@ -631,6 +694,7 @@ def get_workspace_tools(codebase: Codebase) -> list["BaseTool"]:
631694
MoveSymbolTool(codebase),
632695
RenameFileTool(codebase),
633696
RevealSymbolTool(codebase),
697+
RunBashCommandTool(), # Note: This tool doesn't need the codebase
634698
SearchTool(codebase),
635699
SemanticEditTool(codebase),
636700
SemanticSearchTool(codebase),

src/codegen/extensions/tools/bash.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
"""Tools for running bash commands."""
2+
3+
import re
4+
import shlex
5+
import subprocess
6+
from typing import Any
7+
8+
# Whitelist of allowed commands and the flags each one accepts.
# An empty set means the command's flags are not checked at all.
ALLOWED_COMMANDS = {
    "ls": {"-l", "-a", "-h", "-t", "-r", "--color"},
    "cat": {"-n", "--number"},
    "head": {"-n"},
    "tail": {"-n", "-f"},
    "grep": {"-i", "-r", "-n", "-l", "-v", "--color"},
    "find": {"-name", "-type", "-size", "-mtime"},
    "pwd": set(),
    "echo": set(),  # echo is safe with any args
    "ps": {"-ef", "-aux"},
    "df": {"-h"},
    "du": {"-h", "-s"},
    "wc": {"-l", "-w", "-c"},
}

# Newlines act as command separators when the string is handed to a shell.
_NEWLINE_PATTERN = re.compile(r"[\r\n]")

# (compiled pattern, description) pairs checked, in order, against the raw
# command string before it is tokenized. Redirection operators are matched
# with or without a following space, and ">>" is listed before ">" so that
# append redirection is reported as such.
_DANGEROUS_PATTERNS = [
    (re.compile(r"[|;&`$]"), "shell operators (|, ;, &, `, $)"),
    (re.compile(r"\brm\s"), "remove command"),
    (re.compile(r">>"), "append redirection"),
    (re.compile(r">"), "output redirection"),
    (re.compile(r"<"), "input redirection"),
    (re.compile(r"\.\."), "parent directory traversal"),
    (re.compile(r"\bsudo\s"), "sudo command"),
    (re.compile(r"\bchmod\s"), "chmod command"),
    (re.compile(r"\bchown\s"), "chown command"),
    (re.compile(r"\bmv\s"), "move command"),
    (re.compile(r"\bcp\s"), "copy command"),
]


def validate_command(command: str) -> tuple[bool, str]:
    """Validate if a command is safe to execute.

    The raw string is first screened for dangerous patterns (shell operators,
    redirection, embedded newlines, a handful of destructive commands). It is
    then tokenized with ``shlex.split``; the first token must be a key of
    ``ALLOWED_COMMANDS`` and every flag must belong to that command's set.

    Args:
        command: The command to validate

    Returns:
        Tuple of (is_valid, error_message); ``error_message`` is empty when
        the command is valid.
    """
    if not isinstance(command, str):
        return False, f"Failed to validate command: expected a string, got {type(command).__name__}"

    # Surrounding whitespace is harmless; only interior newlines chain commands.
    if _NEWLINE_PATTERN.search(command.strip()):
        return False, "Command contains dangerous pattern: newline (command separator)"

    # Check for dangerous patterns first, before splitting
    for pattern, description in _DANGEROUS_PATTERNS:
        if pattern.search(command):
            return False, f"Command contains dangerous pattern: {description}"

    # Split command into tokens while preserving quoted strings.
    # shlex raises ValueError on malformed input such as an unclosed quote.
    try:
        tokens = shlex.split(command)
    except ValueError as e:
        return False, f"Failed to validate command: {e!s}"

    if not tokens:
        return False, "Empty command"

    base_cmd, *args = tokens
    if base_cmd not in ALLOWED_COMMANDS:
        return False, f"Command '{base_cmd}' is not allowed. Allowed commands: {', '.join(sorted(ALLOWED_COMMANDS))}"

    allowed_flags = ALLOWED_COMMANDS[base_cmd]

    # For commands with no flag restrictions (like echo), skip flag validation
    if not allowed_flags:
        return True, ""

    flags = set()
    for token in args:
        if not token.startswith("-"):
            continue
        if token.startswith("--") or token in allowed_flags:
            # Long options, and whitelisted multi-letter single-dash flags
            # such as "-name" or "-ef", are checked as a whole.
            flags.add(token)
        else:
            # Combined short options (e.g. -la) are checked letter by letter.
            flags.update(f"-{char}" for char in token[1:])

    invalid_flags = flags - allowed_flags
    if invalid_flags:
        return False, f"Flags {invalid_flags} are not allowed for command '{base_cmd}'. Allowed flags: {allowed_flags}"

    return True, ""
91+
92+
93+
def run_bash_command(command: str, is_background: bool = False) -> dict[str, Any]:
    """Run a bash command and return its output.

    The command is checked with ``validate_command`` before anything is
    executed; a rejected command is reported as an error and never run.

    Args:
        command: The command to run
        is_background: Whether to run the command in the background. A
            background command is started and not waited for, so its output
            is not captured.

    Returns:
        Dictionary with a "status" key ("success" or "error"). Foreground
        successes carry "stdout" and "stderr"; background successes carry a
        "message" ending in the PID of the spawned process; errors carry an
        "error" message, plus "stdout"/"stderr" when the command ran and
        exited with a non-zero code.
    """
    # First validate the command
    is_valid, error_message = validate_command(command)
    if not is_valid:
        return {
            "status": "error",
            "error": f"Invalid command: {error_message}",
        }

    # NOTE(review): shell=True is kept so the shell still expands globs in the
    # command; safety therefore relies entirely on validate_command above.
    try:
        if is_background:
            # Nothing ever reads the child's output, so it is discarded
            # instead of piped: an unread pipe blocks the child once the OS
            # pipe buffer fills.
            process = subprocess.Popen(
                command,
                shell=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            return {
                "status": "success",
                "message": f"Command '{command}' started in background with PID {process.pid}",
            }

        # For foreground processes, we wait for completion
        result = subprocess.run(
            command,
            shell=True,
            capture_output=True,
            text=True,
            check=True,  # Raises CalledProcessError on a non-zero exit code
        )
    except subprocess.CalledProcessError as e:
        return {
            "status": "error",
            "error": f"Command failed with exit code {e.returncode}",
            "stdout": e.stdout,
            "stderr": e.stderr,
        }
    except Exception as e:
        # Boundary handler: the tool reports failures as data rather than raising.
        return {
            "status": "error",
            "error": f"Failed to run command: {e!s}",
        }

    return {
        "status": "success",
        "stdout": result.stdout,
        "stderr": result.stderr,
    }
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""Tests for bash command tools."""
2+
3+
import time
4+
5+
from codegen.extensions.tools.bash import run_bash_command
6+
7+
8+
def test_run_bash_command(tmp_path, monkeypatch) -> None:
    """Test running whitelisted commands in the foreground.

    The test runs inside a temporary directory holding one known ``.py`` file,
    so the ``grep`` glob below always has a match regardless of where pytest
    was started.
    """
    (tmp_path / "sample.py").write_text("# test fixture\n")
    monkeypatch.chdir(tmp_path)

    # Test a simple echo command
    result = run_bash_command("echo 'Hello, World!'")
    assert result["status"] == "success"
    assert "Hello, World!" in result["stdout"]
    assert result["stderr"] == ""

    # Test ls with combined flags
    result = run_bash_command("ls -la")
    assert result["status"] == "success"

    # Test ls with separate flags
    result = run_bash_command("ls -l -a")
    assert result["status"] == "success"

    # Test ls with long option
    result = run_bash_command("ls --color")
    assert result["status"] == "success"

    # Test grep with allowed flags; grep exits non-zero when nothing matches,
    # which is why the fixture file above contains the word "test".
    result = run_bash_command("grep -n test *.py")
    assert result["status"] == "success"
31+
32+
33+
def test_command_validation() -> None:
    """Check that disallowed commands, flags and patterns are rejected."""
    # A destructive command is refused outright
    rm_result = run_bash_command("rm -rf /")
    assert rm_result["status"] == "error"
    assert "dangerous pattern: remove command" in rm_result["error"]

    # An unknown long option is refused
    long_flag_result = run_bash_command("ls --invalid-flag")
    assert long_flag_result["status"] == "error"
    assert "Flags" in long_flag_result["error"]
    assert "not allowed" in long_flag_result["error"]

    # One bad letter inside a combined short option is refused (-z is not allowed)
    combined_result = run_bash_command("ls -laz")
    assert combined_result["status"] == "error"
    assert "Flags {'-z'} are not allowed" in combined_result["error"]

    # Each dangerous command paired with the description expected in its error
    cases = [
        ("ls | grep test", "shell operators"),  # Pipe
        ("ls; rm file", "shell operators"),  # Command chaining
        ("ls > output.txt", "output redirection"),  # Redirection
        ("sudo ls", "sudo command"),  # Sudo
        ("ls ../parent", "parent directory traversal"),  # Parent directory
        ("mv file1 file2", "move command"),  # Move
        ("cp file1 file2", "copy command"),  # Copy
        ("chmod +x file", "chmod command"),  # Change permissions
    ]

    for cmd, pattern in cases:
        result = run_bash_command(cmd)
        assert result["status"] == "error", f"Command should be blocked: {cmd}"
        assert f"dangerous pattern: {pattern}" in result["error"], f"Expected '{pattern}' in error for command: {cmd}"
78+
79+
80+
def test_background_command() -> None:
    """Test background command execution and stop the spawned process afterwards."""
    import contextlib
    import os
    import signal

    # Test a safe background command
    result = run_bash_command("tail -f /dev/null", is_background=True)
    assert result["status"] == "success"
    assert "started in background with PID" in result["message"]

    # Clean up directly: a cleanup routed through run_bash_command would be
    # rejected by its validator, leaving the tail process running.
    # NOTE(review): the reported PID belongs to the process started with
    # shell=True; confirm that signalling it also stops tail on the CI shell.
    pid = int(result["message"].split()[-1])
    with contextlib.suppress(ProcessLookupError):
        os.kill(pid, signal.SIGTERM)

0 commit comments

Comments
 (0)