Merge PR #30: Harden exec tool with safety guard
This commit is contained in:
commit
579cbfc8fe
@ -40,14 +40,17 @@ class AgentLoop:
|
|||||||
workspace: Path,
|
workspace: Path,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_iterations: int = 20,
|
max_iterations: int = 20,
|
||||||
brave_api_key: str | None = None
|
brave_api_key: str | None = None,
|
||||||
|
exec_config: "ExecToolConfig | None" = None,
|
||||||
):
|
):
|
||||||
|
from nanobot.config.schema import ExecToolConfig
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.max_iterations = max_iterations
|
self.max_iterations = max_iterations
|
||||||
self.brave_api_key = brave_api_key
|
self.brave_api_key = brave_api_key
|
||||||
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
|
|
||||||
self.context = ContextBuilder(workspace)
|
self.context = ContextBuilder(workspace)
|
||||||
self.sessions = SessionManager(workspace)
|
self.sessions = SessionManager(workspace)
|
||||||
@ -58,6 +61,7 @@ class AgentLoop:
|
|||||||
bus=bus,
|
bus=bus,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
brave_api_key=brave_api_key,
|
brave_api_key=brave_api_key,
|
||||||
|
exec_config=self.exec_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._running = False
|
self._running = False
|
||||||
@ -72,7 +76,11 @@ class AgentLoop:
|
|||||||
self.tools.register(ListDirTool())
|
self.tools.register(ListDirTool())
|
||||||
|
|
||||||
# Shell tool
|
# Shell tool
|
||||||
self.tools.register(ExecTool(working_dir=str(self.workspace)))
|
self.tools.register(ExecTool(
|
||||||
|
working_dir=str(self.workspace),
|
||||||
|
timeout=self.exec_config.timeout,
|
||||||
|
restrict_to_workspace=self.exec_config.restrict_to_workspace,
|
||||||
|
))
|
||||||
|
|
||||||
# Web tools
|
# Web tools
|
||||||
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
|
self.tools.register(WebSearchTool(api_key=self.brave_api_key))
|
||||||
|
|||||||
@ -33,12 +33,15 @@ class SubagentManager:
|
|||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
brave_api_key: str | None = None,
|
brave_api_key: str | None = None,
|
||||||
|
exec_config: "ExecToolConfig | None" = None,
|
||||||
):
|
):
|
||||||
|
from nanobot.config.schema import ExecToolConfig
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.brave_api_key = brave_api_key
|
self.brave_api_key = brave_api_key
|
||||||
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
|
|
||||||
async def spawn(
|
async def spawn(
|
||||||
@ -96,7 +99,11 @@ class SubagentManager:
|
|||||||
tools.register(ReadFileTool())
|
tools.register(ReadFileTool())
|
||||||
tools.register(WriteFileTool())
|
tools.register(WriteFileTool())
|
||||||
tools.register(ListDirTool())
|
tools.register(ListDirTool())
|
||||||
tools.register(ExecTool(working_dir=str(self.workspace)))
|
tools.register(ExecTool(
|
||||||
|
working_dir=str(self.workspace),
|
||||||
|
timeout=self.exec_config.timeout,
|
||||||
|
restrict_to_workspace=self.exec_config.restrict_to_workspace,
|
||||||
|
))
|
||||||
tools.register(WebSearchTool(api_key=self.brave_api_key))
|
tools.register(WebSearchTool(api_key=self.brave_api_key))
|
||||||
tools.register(WebFetchTool())
|
tools.register(WebFetchTool())
|
||||||
|
|
||||||
|
|||||||
@ -12,6 +12,15 @@ class Tool(ABC):
|
|||||||
the environment, such as reading files, executing commands, etc.
|
the environment, such as reading files, executing commands, etc.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_TYPE_MAP = {
|
||||||
|
"string": str,
|
||||||
|
"integer": int,
|
||||||
|
"number": (int, float),
|
||||||
|
"boolean": bool,
|
||||||
|
"array": list,
|
||||||
|
"object": dict,
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@ -42,6 +51,65 @@ class Tool(ABC):
|
|||||||
String result of the tool execution.
|
String result of the tool execution.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
||||||
|
"""
|
||||||
|
Lightweight JSON schema validation for tool parameters.
|
||||||
|
|
||||||
|
Returns a list of error strings (empty if valid).
|
||||||
|
Unknown params are ignored.
|
||||||
|
"""
|
||||||
|
schema = self.parameters or {}
|
||||||
|
|
||||||
|
# Default to an object schema if type is missing, and fail fast on unsupported top-level types.
|
||||||
|
if "type" not in schema:
|
||||||
|
schema = {"type": "object", **schema}
|
||||||
|
elif schema.get("type") != "object":
|
||||||
|
raise ValueError(
|
||||||
|
f"Tool parameter schemas must have top-level type 'object'; got {schema.get('type')!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._validate_schema(params, schema, path="")
|
||||||
|
|
||||||
|
def _validate_schema(self, value: Any, schema: dict[str, Any], path: str) -> list[str]:
|
||||||
|
errors: list[str] = []
|
||||||
|
expected_type = schema.get("type")
|
||||||
|
label = path or "parameter"
|
||||||
|
|
||||||
|
if expected_type in self._TYPE_MAP and not isinstance(value, self._TYPE_MAP[expected_type]):
|
||||||
|
return [f"{label} should be {expected_type}"]
|
||||||
|
|
||||||
|
if "enum" in schema and value not in schema["enum"]:
|
||||||
|
errors.append(f"{label} must be one of {schema['enum']}")
|
||||||
|
|
||||||
|
if expected_type in ("integer", "number"):
|
||||||
|
if "minimum" in schema and value < schema["minimum"]:
|
||||||
|
errors.append(f"{label} must be >= {schema['minimum']}")
|
||||||
|
if "maximum" in schema and value > schema["maximum"]:
|
||||||
|
errors.append(f"{label} must be <= {schema['maximum']}")
|
||||||
|
|
||||||
|
if expected_type == "string":
|
||||||
|
if "minLength" in schema and len(value) < schema["minLength"]:
|
||||||
|
errors.append(f"{label} must be at least {schema['minLength']} chars")
|
||||||
|
if "maxLength" in schema and len(value) > schema["maxLength"]:
|
||||||
|
errors.append(f"{label} must be at most {schema['maxLength']} chars")
|
||||||
|
|
||||||
|
if expected_type == "object":
|
||||||
|
properties = schema.get("properties", {})
|
||||||
|
for key in schema.get("required", []):
|
||||||
|
if key not in value:
|
||||||
|
errors.append(f"missing required {path}.{key}" if path else f"missing required {key}")
|
||||||
|
for key, item in value.items():
|
||||||
|
if key in properties:
|
||||||
|
errors.extend(self._validate_schema(item, properties[key], f"{path}.{key}" if path else key))
|
||||||
|
|
||||||
|
if expected_type == "array":
|
||||||
|
items_schema = schema.get("items")
|
||||||
|
if items_schema:
|
||||||
|
for idx, item in enumerate(value):
|
||||||
|
errors.extend(self._validate_schema(item, items_schema, f"{path}[{idx}]" if path else f"[{idx}]"))
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
def to_schema(self) -> dict[str, Any]:
|
def to_schema(self) -> dict[str, Any]:
|
||||||
"""Convert tool to OpenAI function schema format."""
|
"""Convert tool to OpenAI function schema format."""
|
||||||
|
|||||||
@ -52,8 +52,11 @@ class ToolRegistry:
|
|||||||
tool = self._tools.get(name)
|
tool = self._tools.get(name)
|
||||||
if not tool:
|
if not tool:
|
||||||
return f"Error: Tool '{name}' not found"
|
return f"Error: Tool '{name}' not found"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
errors = tool.validate_params(params)
|
||||||
|
if errors:
|
||||||
|
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors)
|
||||||
return await tool.execute(**params)
|
return await tool.execute(**params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error executing {name}: {str(e)}"
|
return f"Error executing {name}: {str(e)}"
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
@ -10,9 +12,28 @@ from nanobot.agent.tools.base import Tool
|
|||||||
class ExecTool(Tool):
|
class ExecTool(Tool):
|
||||||
"""Tool to execute shell commands."""
|
"""Tool to execute shell commands."""
|
||||||
|
|
||||||
def __init__(self, timeout: int = 60, working_dir: str | None = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
timeout: int = 60,
|
||||||
|
working_dir: str | None = None,
|
||||||
|
deny_patterns: list[str] | None = None,
|
||||||
|
allow_patterns: list[str] | None = None,
|
||||||
|
restrict_to_workspace: bool = False,
|
||||||
|
):
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.working_dir = working_dir
|
self.working_dir = working_dir
|
||||||
|
self.deny_patterns = deny_patterns or [
|
||||||
|
r"\brm\s+-[rf]{1,2}\b", # rm -r, rm -rf, rm -fr
|
||||||
|
r"\bdel\s+/[fq]\b", # del /f, del /q
|
||||||
|
r"\brmdir\s+/s\b", # rmdir /s
|
||||||
|
r"\b(format|mkfs|diskpart)\b", # disk operations
|
||||||
|
r"\bdd\s+if=", # dd
|
||||||
|
r">\s*/dev/sd", # write to disk
|
||||||
|
r"\b(shutdown|reboot|poweroff)\b", # system power
|
||||||
|
r":\(\)\s*\{.*\};\s*:", # fork bomb
|
||||||
|
]
|
||||||
|
self.allow_patterns = allow_patterns or []
|
||||||
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@ -41,6 +62,9 @@ class ExecTool(Tool):
|
|||||||
|
|
||||||
async def execute(self, command: str, working_dir: str | None = None, **kwargs: Any) -> str:
|
async def execute(self, command: str, working_dir: str | None = None, **kwargs: Any) -> str:
|
||||||
cwd = working_dir or self.working_dir or os.getcwd()
|
cwd = working_dir or self.working_dir or os.getcwd()
|
||||||
|
guard_error = self._guard_command(command, cwd)
|
||||||
|
if guard_error:
|
||||||
|
return guard_error
|
||||||
|
|
||||||
try:
|
try:
|
||||||
process = await asyncio.create_subprocess_shell(
|
process = await asyncio.create_subprocess_shell(
|
||||||
@ -83,3 +107,35 @@ class ExecTool(Tool):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error executing command: {str(e)}"
|
return f"Error executing command: {str(e)}"
|
||||||
|
|
||||||
|
def _guard_command(self, command: str, cwd: str) -> str | None:
|
||||||
|
"""Best-effort safety guard for potentially destructive commands."""
|
||||||
|
cmd = command.strip()
|
||||||
|
lower = cmd.lower()
|
||||||
|
|
||||||
|
for pattern in self.deny_patterns:
|
||||||
|
if re.search(pattern, lower):
|
||||||
|
return "Error: Command blocked by safety guard (dangerous pattern detected)"
|
||||||
|
|
||||||
|
if self.allow_patterns:
|
||||||
|
if not any(re.search(p, lower) for p in self.allow_patterns):
|
||||||
|
return "Error: Command blocked by safety guard (not in allowlist)"
|
||||||
|
|
||||||
|
if self.restrict_to_workspace:
|
||||||
|
if "..\\" in cmd or "../" in cmd:
|
||||||
|
return "Error: Command blocked by safety guard (path traversal detected)"
|
||||||
|
|
||||||
|
cwd_path = Path(cwd).resolve()
|
||||||
|
|
||||||
|
win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd)
|
||||||
|
posix_paths = re.findall(r"/[^\s\"']+", cmd)
|
||||||
|
|
||||||
|
for raw in win_paths + posix_paths:
|
||||||
|
try:
|
||||||
|
p = Path(raw).resolve()
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if cwd_path not in p.parents and p != cwd_path:
|
||||||
|
return "Error: Command blocked by safety guard (path outside working dir)"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|||||||
@ -202,7 +202,8 @@ def gateway(
|
|||||||
workspace=config.workspace_path,
|
workspace=config.workspace_path,
|
||||||
model=config.agents.defaults.model,
|
model=config.agents.defaults.model,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
brave_api_key=config.tools.web.search.api_key or None
|
brave_api_key=config.tools.web.search.api_key or None,
|
||||||
|
exec_config=config.tools.exec,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create cron service
|
# Create cron service
|
||||||
@ -309,7 +310,8 @@ def agent(
|
|||||||
bus=bus,
|
bus=bus,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=config.workspace_path,
|
workspace=config.workspace_path,
|
||||||
brave_api_key=config.tools.web.search.api_key or None
|
brave_api_key=config.tools.web.search.api_key or None,
|
||||||
|
exec_config=config.tools.exec,
|
||||||
)
|
)
|
||||||
|
|
||||||
if message:
|
if message:
|
||||||
|
|||||||
@ -73,9 +73,16 @@ class WebToolsConfig(BaseModel):
|
|||||||
search: WebSearchConfig = Field(default_factory=WebSearchConfig)
|
search: WebSearchConfig = Field(default_factory=WebSearchConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class ExecToolConfig(BaseModel):
|
||||||
|
"""Shell exec tool configuration."""
|
||||||
|
timeout: int = 60
|
||||||
|
restrict_to_workspace: bool = False # If true, block commands accessing paths outside workspace
|
||||||
|
|
||||||
|
|
||||||
class ToolsConfig(BaseModel):
|
class ToolsConfig(BaseModel):
|
||||||
"""Tools configuration."""
|
"""Tools configuration."""
|
||||||
web: WebToolsConfig = Field(default_factory=WebToolsConfig)
|
web: WebToolsConfig = Field(default_factory=WebToolsConfig)
|
||||||
|
exec: ExecToolConfig = Field(default_factory=ExecToolConfig)
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseSettings):
|
class Config(BaseSettings):
|
||||||
|
|||||||
88
tests/test_tool_validation.py
Normal file
88
tests/test_tool_validation.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from nanobot.agent.tools.base import Tool
|
||||||
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
|
||||||
|
class SampleTool(Tool):
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "sample"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "sample tool"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "minLength": 2},
|
||||||
|
"count": {"type": "integer", "minimum": 1, "maximum": 10},
|
||||||
|
"mode": {"type": "string", "enum": ["fast", "full"]},
|
||||||
|
"meta": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"tag": {"type": "string"},
|
||||||
|
"flags": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["tag"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query", "count"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, **kwargs: Any) -> str:
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_missing_required() -> None:
|
||||||
|
tool = SampleTool()
|
||||||
|
errors = tool.validate_params({"query": "hi"})
|
||||||
|
assert "missing required count" in "; ".join(errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_type_and_range() -> None:
|
||||||
|
tool = SampleTool()
|
||||||
|
errors = tool.validate_params({"query": "hi", "count": 0})
|
||||||
|
assert any("count must be >= 1" in e for e in errors)
|
||||||
|
|
||||||
|
errors = tool.validate_params({"query": "hi", "count": "2"})
|
||||||
|
assert any("count should be integer" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_enum_and_min_length() -> None:
|
||||||
|
tool = SampleTool()
|
||||||
|
errors = tool.validate_params({"query": "h", "count": 2, "mode": "slow"})
|
||||||
|
assert any("query must be at least 2 chars" in e for e in errors)
|
||||||
|
assert any("mode must be one of" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_nested_object_and_array() -> None:
|
||||||
|
tool = SampleTool()
|
||||||
|
errors = tool.validate_params(
|
||||||
|
{
|
||||||
|
"query": "hi",
|
||||||
|
"count": 2,
|
||||||
|
"meta": {"flags": [1, "ok"]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert any("missing required meta.tag" in e for e in errors)
|
||||||
|
assert any("meta.flags[0] should be string" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_ignores_unknown_fields() -> None:
|
||||||
|
tool = SampleTool()
|
||||||
|
errors = tool.validate_params({"query": "hi", "count": 2, "extra": "x"})
|
||||||
|
assert errors == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_registry_returns_validation_error() -> None:
|
||||||
|
reg = ToolRegistry()
|
||||||
|
reg.register(SampleTool())
|
||||||
|
result = await reg.execute("sample", {"query": "hi"})
|
||||||
|
assert "Invalid parameters" in result
|
||||||
Loading…
x
Reference in New Issue
Block a user