feat: add parameter validation and safety guard for exec tool

This commit is contained in:
Re-bin 2026-02-04 03:45:26 +00:00
parent e508f73f54
commit a20d887f9e
6 changed files with 64 additions and 58 deletions

View File

@@ -40,14 +40,17 @@ class AgentLoop:
workspace: Path, workspace: Path,
model: str | None = None, model: str | None = None,
max_iterations: int = 20, max_iterations: int = 20,
brave_api_key: str | None = None brave_api_key: str | None = None,
exec_config: "ExecToolConfig | None" = None,
): ):
from nanobot.config.schema import ExecToolConfig
self.bus = bus self.bus = bus
self.provider = provider self.provider = provider
self.workspace = workspace self.workspace = workspace
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.max_iterations = max_iterations self.max_iterations = max_iterations
self.brave_api_key = brave_api_key self.brave_api_key = brave_api_key
self.exec_config = exec_config or ExecToolConfig()
self.context = ContextBuilder(workspace) self.context = ContextBuilder(workspace)
self.sessions = SessionManager(workspace) self.sessions = SessionManager(workspace)
@@ -58,6 +61,7 @@ class AgentLoop:
bus=bus, bus=bus,
model=self.model, model=self.model,
brave_api_key=brave_api_key, brave_api_key=brave_api_key,
exec_config=self.exec_config,
) )
self._running = False self._running = False
@@ -72,7 +76,11 @@ class AgentLoop:
self.tools.register(ListDirTool()) self.tools.register(ListDirTool())
# Shell tool # Shell tool
self.tools.register(ExecTool(working_dir=str(self.workspace))) self.tools.register(ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.exec_config.restrict_to_workspace,
))
# Web tools # Web tools
self.tools.register(WebSearchTool(api_key=self.brave_api_key)) self.tools.register(WebSearchTool(api_key=self.brave_api_key))

View File

@@ -33,12 +33,15 @@ class SubagentManager:
bus: MessageBus, bus: MessageBus,
model: str | None = None, model: str | None = None,
brave_api_key: str | None = None, brave_api_key: str | None = None,
exec_config: "ExecToolConfig | None" = None,
): ):
from nanobot.config.schema import ExecToolConfig
self.provider = provider self.provider = provider
self.workspace = workspace self.workspace = workspace
self.bus = bus self.bus = bus
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.brave_api_key = brave_api_key self.brave_api_key = brave_api_key
self.exec_config = exec_config or ExecToolConfig()
self._running_tasks: dict[str, asyncio.Task[None]] = {} self._running_tasks: dict[str, asyncio.Task[None]] = {}
async def spawn( async def spawn(
@@ -96,7 +99,11 @@ class SubagentManager:
tools.register(ReadFileTool()) tools.register(ReadFileTool())
tools.register(WriteFileTool()) tools.register(WriteFileTool())
tools.register(ListDirTool()) tools.register(ListDirTool())
tools.register(ExecTool(working_dir=str(self.workspace))) tools.register(ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.exec_config.restrict_to_workspace,
))
tools.register(WebSearchTool(api_key=self.brave_api_key)) tools.register(WebSearchTool(api_key=self.brave_api_key))
tools.register(WebFetchTool()) tools.register(WebFetchTool())

View File

@@ -12,6 +12,15 @@ class Tool(ABC):
the environment, such as reading files, executing commands, etc. the environment, such as reading files, executing commands, etc.
""" """
_TYPE_MAP = {
"string": str,
"integer": int,
"number": (int, float),
"boolean": bool,
"array": list,
"object": dict,
}
@property @property
@abstractmethod @abstractmethod
def name(self) -> str: def name(self) -> str:
@@ -65,60 +74,40 @@ class Tool(ABC):
def _validate_schema(self, value: Any, schema: dict[str, Any], path: str) -> list[str]: def _validate_schema(self, value: Any, schema: dict[str, Any], path: str) -> list[str]:
errors: list[str] = [] errors: list[str] = []
expected_type = schema.get("type") expected_type = schema.get("type")
label = path or "parameter"
type_map = { if expected_type in self._TYPE_MAP and not isinstance(value, self._TYPE_MAP[expected_type]):
"string": str, return [f"{label} should be {expected_type}"]
"integer": int,
"number": (int, float),
"boolean": bool,
"array": list,
"object": dict,
}
def label(p: str) -> str:
return p or "parameter"
if expected_type in type_map and not isinstance(value, type_map[expected_type]):
errors.append(f"{label(path)} should be {expected_type}")
return errors
if "enum" in schema and value not in schema["enum"]: if "enum" in schema and value not in schema["enum"]:
errors.append(f"{label(path)} must be one of {schema['enum']}") errors.append(f"{label} must be one of {schema['enum']}")
if expected_type in ("integer", "number"): if expected_type in ("integer", "number"):
if "minimum" in schema and value < schema["minimum"]: if "minimum" in schema and value < schema["minimum"]:
errors.append(f"{label(path)} must be >= {schema['minimum']}") errors.append(f"{label} must be >= {schema['minimum']}")
if "maximum" in schema and value > schema["maximum"]: if "maximum" in schema and value > schema["maximum"]:
errors.append(f"{label(path)} must be <= {schema['maximum']}") errors.append(f"{label} must be <= {schema['maximum']}")
if expected_type == "string": if expected_type == "string":
if "minLength" in schema and len(value) < schema["minLength"]: if "minLength" in schema and len(value) < schema["minLength"]:
errors.append(f"{label(path)} must be at least {schema['minLength']} chars") errors.append(f"{label} must be at least {schema['minLength']} chars")
if "maxLength" in schema and len(value) > schema["maxLength"]: if "maxLength" in schema and len(value) > schema["maxLength"]:
errors.append(f"{label(path)} must be at most {schema['maxLength']} chars") errors.append(f"{label} must be at most {schema['maxLength']} chars")
if expected_type == "object": if expected_type == "object":
properties = schema.get("properties", {}) or {} properties = schema.get("properties", {})
required = set(schema.get("required", []) or []) for key in schema.get("required", []):
for key in required:
if key not in value: if key not in value:
p = f"{path}.{key}" if path else key errors.append(f"missing required {path}.{key}" if path else f"missing required {key}")
errors.append(f"missing required {p}")
for key, item in value.items(): for key, item in value.items():
prop_schema = properties.get(key) if key in properties:
if not prop_schema: errors.extend(self._validate_schema(item, properties[key], f"{path}.{key}" if path else key))
continue # ignore unknown fields
p = f"{path}.{key}" if path else key
errors.extend(self._validate_schema(item, prop_schema, p))
if expected_type == "array": if expected_type == "array":
items_schema = schema.get("items") items_schema = schema.get("items")
if items_schema: if items_schema:
for idx, item in enumerate(value): for idx, item in enumerate(value):
p = f"{path}[{idx}]" if path else f"[{idx}]" errors.extend(self._validate_schema(item, items_schema, f"{path}[{idx}]" if path else f"[{idx}]"))
errors.extend(self._validate_schema(item, items_schema, p))
return errors return errors

View File

@@ -18,29 +18,22 @@ class ExecTool(Tool):
working_dir: str | None = None, working_dir: str | None = None,
deny_patterns: list[str] | None = None, deny_patterns: list[str] | None = None,
allow_patterns: list[str] | None = None, allow_patterns: list[str] | None = None,
restrict_to_working_dir: bool = False, restrict_to_workspace: bool = False,
): ):
self.timeout = timeout self.timeout = timeout
self.working_dir = working_dir self.working_dir = working_dir
self.deny_patterns = deny_patterns or [ self.deny_patterns = deny_patterns or [
r"\brm\s+-rf\b", r"\brm\s+-[rf]{1,2}\b", # rm -r, rm -rf, rm -fr
r"\brm\s+-fr\b", r"\bdel\s+/[fq]\b", # del /f, del /q
r"\brm\s+-r\b", r"\brmdir\s+/s\b", # rmdir /s
r"\bdel\s+/f\b", r"\b(format|mkfs|diskpart)\b", # disk operations
r"\bdel\s+/q\b", r"\bdd\s+if=", # dd
r"\brmdir\s+/s\b", r">\s*/dev/sd", # write to disk
r"\bformat\b", r"\b(shutdown|reboot|poweroff)\b", # system power
r"\bmkfs\b", r":\(\)\s*\{.*\};\s*:", # fork bomb
r"\bdd\s+if=",
r">\s*/dev/sd",
r"\bdiskpart\b",
r"\bshutdown\b",
r"\breboot\b",
r"\bpoweroff\b",
r":\(\)\s*\{\s*:\s*\|\s*:\s*&\s*\};\s*:",
] ]
self.allow_patterns = allow_patterns or [] self.allow_patterns = allow_patterns or []
self.restrict_to_working_dir = restrict_to_working_dir self.restrict_to_workspace = restrict_to_workspace
@property @property
def name(self) -> str: def name(self) -> str:
@@ -128,14 +121,14 @@ class ExecTool(Tool):
if not any(re.search(p, lower) for p in self.allow_patterns): if not any(re.search(p, lower) for p in self.allow_patterns):
return "Error: Command blocked by safety guard (not in allowlist)" return "Error: Command blocked by safety guard (not in allowlist)"
if self.restrict_to_working_dir: if self.restrict_to_workspace:
if "..\\" in cmd or "../" in cmd: if "..\\" in cmd or "../" in cmd:
return "Error: Command blocked by safety guard (path traversal detected)" return "Error: Command blocked by safety guard (path traversal detected)"
cwd_path = Path(cwd).resolve() cwd_path = Path(cwd).resolve()
win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd) win_paths = re.findall(r"[A-Za-z]:\\[^\\\"']+", cmd)
posix_paths = re.findall(r"/[^\\s\"']+", cmd) posix_paths = re.findall(r"/[^\s\"']+", cmd)
for raw in win_paths + posix_paths: for raw in win_paths + posix_paths:
try: try:

View File

@@ -202,7 +202,8 @@ def gateway(
workspace=config.workspace_path, workspace=config.workspace_path,
model=config.agents.defaults.model, model=config.agents.defaults.model,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
brave_api_key=config.tools.web.search.api_key or None brave_api_key=config.tools.web.search.api_key or None,
exec_config=config.tools.exec,
) )
# Create cron service # Create cron service
@@ -309,7 +310,8 @@ def agent(
bus=bus, bus=bus,
provider=provider, provider=provider,
workspace=config.workspace_path, workspace=config.workspace_path,
brave_api_key=config.tools.web.search.api_key or None brave_api_key=config.tools.web.search.api_key or None,
exec_config=config.tools.exec,
) )
if message: if message:

View File

@@ -73,9 +73,16 @@ class WebToolsConfig(BaseModel):
search: WebSearchConfig = Field(default_factory=WebSearchConfig) search: WebSearchConfig = Field(default_factory=WebSearchConfig)
class ExecToolConfig(BaseModel):
"""Shell exec tool configuration."""
timeout: int = 60
restrict_to_workspace: bool = False # If true, block commands accessing paths outside workspace
class ToolsConfig(BaseModel): class ToolsConfig(BaseModel):
"""Tools configuration.""" """Tools configuration."""
web: WebToolsConfig = Field(default_factory=WebToolsConfig) web: WebToolsConfig = Field(default_factory=WebToolsConfig)
exec: ExecToolConfig = Field(default_factory=ExecToolConfig)
class Config(BaseSettings): class Config(BaseSettings):