Skip to content

feat: BashCommandTool #510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion src/codegen/extensions/langchain/tools.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
"""Langchain tools for workspace operations."""

import json
from typing import ClassVar, Literal, Optional
from typing import Callable, ClassVar, Literal, Optional

from langchain.tools import BaseTool
from pydantic import BaseModel, Field

from codegen import Codebase
from codegen.extensions.linear.linear_client import LinearClient
from codegen.extensions.tools.bash import run_bash_command
from codegen.extensions.tools.linear.linear import (
linear_comment_on_issue_tool,
linear_create_issue_tool,
Expand All @@ -16,6 +17,7 @@
linear_get_teams_tool,
linear_search_issues_tool,
)
from codegen.extensions.tools.link_annotation import add_links_to_message

from ..tools import (
commit,
Expand Down Expand Up @@ -46,9 +48,9 @@
class ViewFileTool(BaseTool):
"""Tool for viewing file contents and metadata."""

name: ClassVar[str] = "view_file"

Check failure on line 51 in src/codegen/extensions/langchain/tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
description: ClassVar[str] = "View the contents and metadata of a file in the codebase"

Check failure on line 52 in src/codegen/extensions/langchain/tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
args_schema: ClassVar[type[BaseModel]] = ViewFileInput

Check failure on line 53 in src/codegen/extensions/langchain/tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
codebase: Codebase = Field(exclude=True)

def __init__(self, codebase: Codebase) -> None:
Expand All @@ -69,9 +71,9 @@
class ListDirectoryTool(BaseTool):
"""Tool for listing directory contents."""

name: ClassVar[str] = "list_directory"

Check failure on line 74 in src/codegen/extensions/langchain/tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
description: ClassVar[str] = "List contents of a directory in the codebase"

Check failure on line 75 in src/codegen/extensions/langchain/tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
args_schema: ClassVar[type[BaseModel]] = ListDirectoryInput

Check failure on line 76 in src/codegen/extensions/langchain/tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
codebase: Codebase = Field(exclude=True)

def __init__(self, codebase: Codebase) -> None:
Expand All @@ -92,9 +94,9 @@
class SearchTool(BaseTool):
"""Tool for searching the codebase."""

name: ClassVar[str] = "search"

Check failure on line 97 in src/codegen/extensions/langchain/tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
description: ClassVar[str] = "Search the codebase using text search"

Check failure on line 98 in src/codegen/extensions/langchain/tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
args_schema: ClassVar[type[BaseModel]] = SearchInput

Check failure on line 99 in src/codegen/extensions/langchain/tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
codebase: Codebase = Field(exclude=True)

def __init__(self, codebase: Codebase) -> None:
Expand All @@ -115,7 +117,7 @@
class EditFileTool(BaseTool):
"""Tool for editing files."""

name: ClassVar[str] = "edit_file"

Check failure on line 120 in src/codegen/extensions/langchain/tools.py

View workflow job for this annotation

GitHub Actions / mypy

error: Cannot override instance variable (previously declared on base class "BaseTool") with class variable [misc]
description: ClassVar[str] = "Edit a file by replacing its entire content"
args_schema: ClassVar[type[BaseModel]] = EditFileInput
codebase: Codebase = Field(exclude=True)
Expand Down Expand Up @@ -354,6 +356,30 @@
return json.dumps(result, indent=2)


########################################################################################################################
# BASH
########################################################################################################################


class RunBashCommandInput(BaseModel):
"""Input for running a bash command."""

command: str = Field(..., description="The command to run")
is_background: bool = Field(default=False, description="Whether to run the command in the background")


class RunBashCommandTool(BaseTool):
"""Tool for running bash commands."""

name: ClassVar[str] = "run_bash_command"
description: ClassVar[str] = "Run a bash command and return its output"
args_schema: ClassVar[type[BaseModel]] = RunBashCommandInput

def _run(self, command: str, is_background: bool = False) -> str:
result = run_bash_command(command, is_background)
return json.dumps(result, indent=2)


########################################################################################################################
# GITHUB
########################################################################################################################
Expand Down Expand Up @@ -607,6 +633,43 @@
return json.dumps(result, indent=2)


########################################################################################################################
# SLACK
########################################################################################################################


class SlackSendMessageInput(BaseModel):
"""Input for sending a message to Slack."""

content: str = Field(..., description="Message to send to Slack")


class SlackSendMessageTool(BaseTool):
"""Tool for sending a message to Slack."""

name: ClassVar[str] = "send_slack_message"
description: ClassVar[str] = (
"Send a message via Slack."
"Write symbol names (classes, functions, etc.) or full filepaths in single backticks and they will be auto-linked to the code."
"Use Slack-style markdown for other links."
)
args_schema: ClassVar[type[BaseModel]] = SlackSendMessageInput
say: Callable[[str], None] = Field(exclude=True)
codebase: Codebase = Field(exclude=True)

def __init__(self, codebase: Codebase, say: Callable[[str], None]) -> None:
super().__init__(say=say, codebase=codebase)
self.say = say
self.codebase = codebase

def _run(self, content: str) -> str:
print("> Adding links to message")
content_formatted = add_links_to_message(content, self.codebase)
print("> Sending message to Slack")
self.say(content_formatted)
return "✅ Message sent successfully"


########################################################################################################################
# EXPORT
########################################################################################################################
Expand All @@ -631,6 +694,7 @@
MoveSymbolTool(codebase),
RenameFileTool(codebase),
RevealSymbolTool(codebase),
RunBashCommandTool(), # Note: This tool doesn't need the codebase
SearchTool(codebase),
SemanticEditTool(codebase),
SemanticSearchTool(codebase),
Expand Down
151 changes: 151 additions & 0 deletions src/codegen/extensions/tools/bash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""Tools for running bash commands."""

import re
import shlex
import subprocess
from typing import Any

# Whitelist of allowed commands and their flags
ALLOWED_COMMANDS = {
"ls": {"-l", "-a", "-h", "-t", "-r", "--color"},
"cat": {"-n", "--number"},
"head": {"-n"},
"tail": {"-n", "-f"},
"grep": {"-i", "-r", "-n", "-l", "-v", "--color"},
"find": {"-name", "-type", "-size", "-mtime"},
"pwd": set(),
"echo": set(), # echo is safe with any args
"ps": {"-ef", "-aux"},
"df": {"-h"},
"du": {"-h", "-s"},
"wc": {"-l", "-w", "-c"},
}


def validate_command(command: str) -> tuple[bool, str]:
"""Validate if a command is safe to execute.

Args:
command: The command to validate

Returns:
Tuple of (is_valid, error_message)
"""
try:
# Check for dangerous patterns first, before splitting
dangerous_patterns = [
(r"[|;&`$]", "shell operators (|, ;, &, `, $)"),
(r"rm\s", "remove command"),
(r">\s", "output redirection"),
(r">>\s", "append redirection"),
(r"<\s", "input redirection"),
(r"\.\.", "parent directory traversal"),
(r"sudo\s", "sudo command"),
(r"chmod\s", "chmod command"),
(r"chown\s", "chown command"),
(r"mv\s", "move command"),
(r"cp\s", "copy command"),
]

for pattern, description in dangerous_patterns:
if re.search(pattern, command):
return False, f"Command contains dangerous pattern: {description}"

# Split command into tokens while preserving quoted strings
tokens = shlex.split(command)
if not tokens:
return False, "Empty command"

# Get base command (first token)
base_cmd = tokens[0]

# Check if base command is in whitelist
if base_cmd not in ALLOWED_COMMANDS:
return False, f"Command '{base_cmd}' is not allowed. Allowed commands: {', '.join(sorted(ALLOWED_COMMANDS.keys()))}"

# Extract and split combined flags (e.g., -la -> -l -a)
flags = set()
for token in tokens[1:]:
if token.startswith("-"):
if token.startswith("--"):
# Handle long options (e.g., --color)
flags.add(token)
else:
# Handle combined short options (e.g., -la)
# Skip the first "-" and add each character as a flag
for char in token[1:]:
flags.add(f"-{char}")

allowed_flags = ALLOWED_COMMANDS[base_cmd]

# For commands with no flag restrictions (like echo), skip flag validation
if allowed_flags:
invalid_flags = flags - allowed_flags
if invalid_flags:
return False, f"Flags {invalid_flags} are not allowed for command '{base_cmd}'. Allowed flags: {allowed_flags}"

return True, ""

except Exception as e:
return False, f"Failed to validate command: {e!s}"


def run_bash_command(command: str, is_background: bool = False) -> dict[str, Any]:
"""Run a bash command and return its output.

Args:
command: The command to run
is_background: Whether to run the command in the background

Returns:
Dictionary containing the command output or error
"""
# First validate the command
is_valid, error_message = validate_command(command)
if not is_valid:
return {
"status": "error",
"error": f"Invalid command: {error_message}",
}

try:
if is_background:
# For background processes, we use Popen and return immediately
process = subprocess.Popen(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
return {
"status": "success",
"message": f"Command '{command}' started in background with PID {process.pid}",
}

# For foreground processes, we wait for completion
result = subprocess.run(
command,
shell=True,
capture_output=True,
text=True,
check=True, # This will raise CalledProcessError if command fails
)

return {
"status": "success",
"stdout": result.stdout,
"stderr": result.stderr,
}
except subprocess.CalledProcessError as e:
return {
"status": "error",
"error": f"Command failed with exit code {e.returncode}",
"stdout": e.stdout,
"stderr": e.stderr,
}
except Exception as e:
return {
"status": "error",
"error": f"Failed to run command: {e!s}",
}
90 changes: 90 additions & 0 deletions tests/integration/extension/test_bash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Tests for bash command tools."""

import time

from codegen.extensions.tools.bash import run_bash_command


def test_run_bash_command() -> None:
"""Test running a bash command."""
# Test a simple echo command
result = run_bash_command("echo 'Hello, World!'")
assert result["status"] == "success"
assert "Hello, World!" in result["stdout"]
assert result["stderr"] == ""

# Test ls with combined flags
result = run_bash_command("ls -la")
assert result["status"] == "success"

# Test ls with separate flags
result = run_bash_command("ls -l -a")
assert result["status"] == "success"

# Test ls with long option
result = run_bash_command("ls --color")
assert result["status"] == "success"

# Test grep with allowed flags
result = run_bash_command("grep -n test *.py")
assert result["status"] == "success"


def test_command_validation() -> None:
"""Test command validation."""
# Test disallowed command
result = run_bash_command("rm -rf /")
assert result["status"] == "error"
assert "dangerous pattern: remove command" in result["error"]

# Test command with disallowed flags
result = run_bash_command("ls --invalid-flag")
assert result["status"] == "error"
assert "Flags" in result["error"]
assert "not allowed" in result["error"]

# Test command with invalid combined flags
result = run_bash_command("ls -laz") # -z is not allowed
assert result["status"] == "error"
assert "Flags {'-z'} are not allowed" in result["error"]

# Test dangerous patterns
dangerous_commands = [
"ls | grep test", # Pipe
"ls; rm file", # Command chaining
"ls > output.txt", # Redirection
"sudo ls", # Sudo
"ls ../parent", # Parent directory
"mv file1 file2", # Move
"cp file1 file2", # Copy
"chmod +x file", # Change permissions
]

expected_patterns = [
"shell operators", # For pipe
"shell operators", # For command chaining
"output redirection", # For redirection
"sudo command", # For sudo
"parent directory traversal", # For parent directory
"move command", # For move
"copy command", # For copy
"chmod command", # For chmod
]

for cmd, pattern in zip(dangerous_commands, expected_patterns):
result = run_bash_command(cmd)
assert result["status"] == "error", f"Command should be blocked: {cmd}"
assert f"dangerous pattern: {pattern}" in result["error"], f"Expected '{pattern}' in error for command: {cmd}"


def test_background_command() -> None:
"""Test background command execution."""
# Test a safe background command
result = run_bash_command("tail -f /dev/null", is_background=True)
assert result["status"] == "success"
assert "started in background with PID" in result["message"]

# Clean up by finding and killing the background process
pid = int(result["message"].split()[-1])
run_bash_command(f"ps -p {pid} || true") # Check if process exists
time.sleep(1) # Give process time to start/stop