diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 4a96b84..bfe6e89 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -40,14 +40,17 @@ class AgentLoop: workspace: Path, model: str | None = None, max_iterations: int = 20, - brave_api_key: str | None = None + brave_api_key: str | None = None, + exec_config: "ExecToolConfig | None" = None, ): + from nanobot.config.schema import ExecToolConfig self.bus = bus self.provider = provider self.workspace = workspace self.model = model or provider.get_default_model() self.max_iterations = max_iterations self.brave_api_key = brave_api_key + self.exec_config = exec_config or ExecToolConfig() self.context = ContextBuilder(workspace) self.sessions = SessionManager(workspace) @@ -58,6 +61,7 @@ class AgentLoop: bus=bus, model=self.model, brave_api_key=brave_api_key, + exec_config=self.exec_config, ) self._running = False @@ -72,7 +76,11 @@ class AgentLoop: self.tools.register(ListDirTool()) # Shell tool - self.tools.register(ExecTool(working_dir=str(self.workspace))) + self.tools.register(ExecTool( + working_dir=str(self.workspace), + timeout=self.exec_config.timeout, + restrict_to_workspace=self.exec_config.restrict_to_workspace, + )) # Web tools self.tools.register(WebSearchTool(api_key=self.brave_api_key)) diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index d3b320c..05ffbb8 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -33,12 +33,15 @@ class SubagentManager: bus: MessageBus, model: str | None = None, brave_api_key: str | None = None, + exec_config: "ExecToolConfig | None" = None, ): + from nanobot.config.schema import ExecToolConfig self.provider = provider self.workspace = workspace self.bus = bus self.model = model or provider.get_default_model() self.brave_api_key = brave_api_key + self.exec_config = exec_config or ExecToolConfig() self._running_tasks: dict[str, asyncio.Task[None]] = {} async def spawn( @@ -96,7 +99,11 @@ class SubagentManager: tools.register(ReadFileTool()) tools.register(WriteFileTool()) tools.register(ListDirTool()) - tools.register(ExecTool(working_dir=str(self.workspace))) + tools.register(ExecTool( + working_dir=str(self.workspace), + timeout=self.exec_config.timeout, + restrict_to_workspace=self.exec_config.restrict_to_workspace, + )) tools.register(WebSearchTool(api_key=self.brave_api_key)) tools.register(WebFetchTool()) diff --git a/nanobot/agent/tools/base.py b/nanobot/agent/tools/base.py index 6fcfec6..cbaadbd 100644 --- a/nanobot/agent/tools/base.py +++ b/nanobot/agent/tools/base.py @@ -12,6 +12,15 @@ class Tool(ABC): the environment, such as reading files, executing commands, etc. """ + _TYPE_MAP = { + "string": str, + "integer": int, + "number": (int, float), + "boolean": bool, + "array": list, + "object": dict, + } + @property @abstractmethod def name(self) -> str: @@ -42,6 +51,65 @@ class Tool(ABC): String result of the tool execution. """ pass + + def validate_params(self, params: dict[str, Any]) -> list[str]: + """ + Lightweight JSON schema validation for tool parameters. + + Returns a list of error strings (empty if valid). + Unknown params are ignored. + """ + schema = self.parameters or {} + + # Default to an object schema if type is missing, and fail fast on unsupported top-level types. + if "type" not in schema: + schema = {"type": "object", **schema} + elif schema.get("type") != "object": + raise ValueError( + f"Tool parameter schemas must have top-level type 'object'; got {schema.get('type')!r}" + ) + + return self._validate_schema(params, schema, path="") + + def _validate_schema(self, value: Any, schema: dict[str, Any], path: str) -> list[str]: + errors: list[str] = [] + expected_type = schema.get("type") + label = path or "parameter" + + if expected_type in self._TYPE_MAP and not isinstance(value, self._TYPE_MAP[expected_type]): + return [f"{label} should be {expected_type}"] + + if "enum" in schema and value not in schema["enum"]: + errors.append(f"{label} must be one of {schema['enum']}") + + if expected_type in ("integer", "number"): + if "minimum" in schema and value < schema["minimum"]: + errors.append(f"{label} must be >= {schema['minimum']}") + if "maximum" in schema and value > schema["maximum"]: + errors.append(f"{label} must be <= {schema['maximum']}") + + if expected_type == "string": + if "minLength" in schema and len(value) < schema["minLength"]: + errors.append(f"{label} must be at least {schema['minLength']} chars") + if "maxLength" in schema and len(value) > schema["maxLength"]: + errors.append(f"{label} must be at most {schema['maxLength']} chars") + + if expected_type == "object": + properties = schema.get("properties", {}) + for key in schema.get("required", []): + if key not in value: + errors.append(f"missing required {path}.{key}" if path else f"missing required {key}") + for key, item in value.items(): + if key in properties: + errors.extend(self._validate_schema(item, properties[key], f"{path}.{key}" if path else key)) + + if expected_type == "array": + items_schema = schema.get("items") + if items_schema: + for idx, item in enumerate(value): + errors.extend(self._validate_schema(item, items_schema, f"{path}[{idx}]" if path else f"[{idx}]")) + + return errors def to_schema(self) -> dict[str, Any]: """Convert tool to OpenAI function schema format.""" diff --git a/nanobot/agent/tools/registry.py b/nanobot/agent/tools/registry.py index 1e8f56d..d9b33ff 100644 --- a/nanobot/agent/tools/registry.py +++ b/nanobot/agent/tools/registry.py @@ -52,8 +52,11 @@ class ToolRegistry: tool = self._tools.get(name) if not tool: return f"Error: Tool '{name}' not found" - + try: + errors = tool.validate_params(params) + if errors: + return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) return await tool.execute(**params) except Exception as e: return f"Error executing {name}: {str(e)}" diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index bf7f064..143d187 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -2,6 +2,8 @@ import asyncio import os +import re +from pathlib import Path from typing import Any from nanobot.agent.tools.base import Tool @@ -10,9 +12,28 @@ from nanobot.agent.tools.base import Tool class ExecTool(Tool): """Tool to execute shell commands.""" - def __init__(self, timeout: int = 60, working_dir: str | None = None): + def __init__( + self, + timeout: int = 60, + working_dir: str | None = None, + deny_patterns: list[str] | None = None, + allow_patterns: list[str] | None = None, + restrict_to_workspace: bool = False, + ): self.timeout = timeout self.working_dir = working_dir + self.deny_patterns = deny_patterns or [ + r"\brm\s+-[rf]{1,2}\b", # rm -r, rm -rf, rm -fr + r"\bdel\s+/[fq]\b", # del /f, del /q + r"\brmdir\s+/s\b", # rmdir /s + r"\b(format|mkfs|diskpart)\b", # disk operations + r"\bdd\s+if=", # dd + r">\s*/dev/sd", # write to disk + r"\b(shutdown|reboot|poweroff)\b", # system power + r":\(\)\s*\{.*\};\s*:", # fork bomb + ] + self.allow_patterns = allow_patterns or [] + self.restrict_to_workspace = restrict_to_workspace @property def name(self) -> str: @@ -41,6 +62,9 @@ class ExecTool(Tool): async def execute(self, command: str, working_dir: str | None = None, **kwargs: Any) -> str: cwd = working_dir or self.working_dir or os.getcwd() + guard_error = self._guard_command(command, cwd) + if guard_error: + return guard_error try: process = await asyncio.create_subprocess_shell( @@ -83,3 +107,35 @@ class ExecTool(Tool): except Exception as e: return f"Error executing command: {str(e)}" + + def _guard_command(self, command: str, cwd: str) -> str | None: + """Best-effort safety guard for potentially destructive commands.""" + cmd = command.strip() + lower = cmd.lower() + + for pattern in self.deny_patterns: + if re.search(pattern, lower): + return "Error: Command blocked by safety guard (dangerous pattern detected)" + + if self.allow_patterns: + if not any(re.search(p, lower) for p in self.allow_patterns): + return "Error: Command blocked by safety guard (not in allowlist)" + + if self.restrict_to_workspace: + if "..\\" in cmd or "../" in cmd: + return "Error: Command blocked by safety guard (path traversal detected)" + + cwd_path = Path(cwd).resolve() + + win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd) + posix_paths = re.findall(r"/[^\s\"']+", cmd) + + for raw in win_paths + posix_paths: + try: + p = Path(raw).resolve() + except Exception: + continue + if cwd_path not in p.parents and p != cwd_path: + return "Error: Command blocked by safety guard (path outside working dir)" + + return None diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 5ecc31b..6b95667 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -202,7 +202,8 @@ def gateway( workspace=config.workspace_path, model=config.agents.defaults.model, max_iterations=config.agents.defaults.max_tool_iterations, - brave_api_key=config.tools.web.search.api_key or None + brave_api_key=config.tools.web.search.api_key or None, + exec_config=config.tools.exec, ) # Create cron service @@ -309,7 +310,8 @@ def agent( bus=bus, provider=provider, workspace=config.workspace_path, - brave_api_key=config.tools.web.search.api_key or None + brave_api_key=config.tools.web.search.api_key or None, + exec_config=config.tools.exec, ) if message: diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 71e3361..4c34834 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -73,9 +73,16 @@ class WebToolsConfig(BaseModel): search: WebSearchConfig = Field(default_factory=WebSearchConfig) +class ExecToolConfig(BaseModel): + """Shell exec tool configuration.""" + timeout: int = 60 + restrict_to_workspace: bool = False # If true, block commands accessing paths outside workspace + + class ToolsConfig(BaseModel): """Tools configuration.""" web: WebToolsConfig = Field(default_factory=WebToolsConfig) + exec: ExecToolConfig = Field(default_factory=ExecToolConfig) class Config(BaseSettings): diff --git a/tests/test_tool_validation.py b/tests/test_tool_validation.py new file mode 100644 index 0000000..f11c667 --- /dev/null +++ b/tests/test_tool_validation.py @@ -0,0 +1,88 @@ +from typing import Any + +from nanobot.agent.tools.base import Tool +from nanobot.agent.tools.registry import ToolRegistry + + +class SampleTool(Tool): + @property + def name(self) -> str: + return "sample" + + @property + def description(self) -> str: + return "sample tool" + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "query": {"type": "string", "minLength": 2}, + "count": {"type": "integer", "minimum": 1, "maximum": 10}, + "mode": {"type": "string", "enum": ["fast", "full"]}, + "meta": { + "type": "object", + "properties": { + "tag": {"type": "string"}, + "flags": { + "type": "array", + "items": {"type": "string"}, + }, + }, + "required": ["tag"], + }, + }, + "required": ["query", "count"], + } + + async def execute(self, **kwargs: Any) -> str: + return "ok" + + +def test_validate_params_missing_required() -> None: + tool = SampleTool() + errors = tool.validate_params({"query": "hi"}) + assert "missing required count" in "; ".join(errors) + + +def test_validate_params_type_and_range() -> None: + tool = SampleTool() + errors = tool.validate_params({"query": "hi", "count": 0}) + assert any("count must be >= 1" in e for e in errors) + + errors = tool.validate_params({"query": "hi", "count": "2"}) + assert any("count should be integer" in e for e in errors) + + +def test_validate_params_enum_and_min_length() -> None: + tool = SampleTool() + errors = tool.validate_params({"query": "h", "count": 2, "mode": "slow"}) + assert any("query must be at least 2 chars" in e for e in errors) + assert any("mode must be one of" in e for e in errors) + + +def test_validate_params_nested_object_and_array() -> None: + tool = SampleTool() + errors = tool.validate_params( + { + "query": "hi", + "count": 2, + "meta": {"flags": [1, "ok"]}, + } + ) + assert any("missing required meta.tag" in e for e in errors) + assert any("meta.flags[0] should be string" in e for e in errors) + + +def test_validate_params_ignores_unknown_fields() -> None: + tool = SampleTool() + errors = tool.validate_params({"query": "hi", "count": 2, "extra": "x"}) + assert errors == [] + + +async def test_registry_returns_validation_error() -> None: + reg = ToolRegistry() + reg.register(SampleTool()) + result = await reg.execute("sample", {"query": "hi"}) + assert "Invalid parameters" in result