From 7ef18c4e8a1b2aaa9df7e7e37ff2ffa1361d1621 Mon Sep 17 00:00:00 2001 From: Kiplangatkorir Date: Mon, 2 Feb 2026 20:39:08 +0300 Subject: [PATCH] Validate tool params and add tests --- nanobot/agent/tools/base.py | 73 +++++++++++++++++++++++++++ nanobot/agent/tools/registry.py | 5 +- tests/test_tool_validation.py | 88 +++++++++++++++++++++++++++++++++ 3 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 tests/test_tool_validation.py diff --git a/nanobot/agent/tools/base.py b/nanobot/agent/tools/base.py index 6fcfec6..355150f 100644 --- a/nanobot/agent/tools/base.py +++ b/nanobot/agent/tools/base.py @@ -42,6 +42,79 @@ class Tool(ABC): String result of the tool execution. """ pass + + def validate_params(self, params: dict[str, Any]) -> list[str]: + """ + Lightweight JSON schema validation for tool parameters. + + Returns a list of error strings (empty if valid). + Unknown params are ignored. + """ + schema = self.parameters or {} + if schema.get("type") != "object": + return [] + + return self._validate_schema(params, schema, path="") + + def _validate_schema(self, value: Any, schema: dict[str, Any], path: str) -> list[str]: + errors: list[str] = [] + expected_type = schema.get("type") + + type_map = { + "string": str, + "integer": int, + "number": (int, float), + "boolean": bool, + "array": list, + "object": dict, + } + + def label(p: str) -> str: + return p or "parameter" + + if expected_type in type_map and not isinstance(value, type_map[expected_type]): + errors.append(f"{label(path)} should be {expected_type}") + return errors + + if "enum" in schema and value not in schema["enum"]: + errors.append(f"{label(path)} must be one of {schema['enum']}") + + if expected_type in ("integer", "number"): + if "minimum" in schema and value < schema["minimum"]: + errors.append(f"{label(path)} must be >= {schema['minimum']}") + if "maximum" in schema and value > schema["maximum"]: + errors.append(f"{label(path)} must be <= {schema['maximum']}") + + if expected_type == "string": + if "minLength" in schema and len(value) < schema["minLength"]: + errors.append(f"{label(path)} must be at least {schema['minLength']} chars") + if "maxLength" in schema and len(value) > schema["maxLength"]: + errors.append(f"{label(path)} must be at most {schema['maxLength']} chars") + + if expected_type == "object": + properties = schema.get("properties", {}) or {} + required = set(schema.get("required", []) or []) + + for key in required: + if key not in value: + p = f"{path}.{key}" if path else key + errors.append(f"missing required {p}") + + for key, item in value.items(): + prop_schema = properties.get(key) + if not prop_schema: + continue # ignore unknown fields + p = f"{path}.{key}" if path else key + errors.extend(self._validate_schema(item, prop_schema, p)) + + if expected_type == "array": + items_schema = schema.get("items") + if items_schema: + for idx, item in enumerate(value): + p = f"{path}[{idx}]" if path else f"[{idx}]" + errors.extend(self._validate_schema(item, items_schema, p)) + + return errors def to_schema(self) -> dict[str, Any]: """Convert tool to OpenAI function schema format.""" diff --git a/nanobot/agent/tools/registry.py b/nanobot/agent/tools/registry.py index 1e8f56d..d9b33ff 100644 --- a/nanobot/agent/tools/registry.py +++ b/nanobot/agent/tools/registry.py @@ -52,8 +52,11 @@ class ToolRegistry: tool = self._tools.get(name) if not tool: return f"Error: Tool '{name}' not found" - + try: + errors = tool.validate_params(params) + if errors: + return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) return await tool.execute(**params) except Exception as e: return f"Error executing {name}: {str(e)}" diff --git a/tests/test_tool_validation.py b/tests/test_tool_validation.py new file mode 100644 index 0000000..f11c667 --- /dev/null +++ b/tests/test_tool_validation.py @@ -0,0 +1,88 @@ +from typing import Any + +from nanobot.agent.tools.base import Tool +from nanobot.agent.tools.registry import ToolRegistry + + +class SampleTool(Tool): + @property + def name(self) -> str: + return "sample" + + @property + def description(self) -> str: + return "sample tool" + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "query": {"type": "string", "minLength": 2}, + "count": {"type": "integer", "minimum": 1, "maximum": 10}, + "mode": {"type": "string", "enum": ["fast", "full"]}, + "meta": { + "type": "object", + "properties": { + "tag": {"type": "string"}, + "flags": { + "type": "array", + "items": {"type": "string"}, + }, + }, + "required": ["tag"], + }, + }, + "required": ["query", "count"], + } + + async def execute(self, **kwargs: Any) -> str: + return "ok" + + +def test_validate_params_missing_required() -> None: + tool = SampleTool() + errors = tool.validate_params({"query": "hi"}) + assert "missing required count" in "; ".join(errors) + + +def test_validate_params_type_and_range() -> None: + tool = SampleTool() + errors = tool.validate_params({"query": "hi", "count": 0}) + assert any("count must be >= 1" in e for e in errors) + + errors = tool.validate_params({"query": "hi", "count": "2"}) + assert any("count should be integer" in e for e in errors) + + +def test_validate_params_enum_and_min_length() -> None: + tool = SampleTool() + errors = tool.validate_params({"query": "h", "count": 2, "mode": "slow"}) + assert any("query must be at least 2 chars" in e for e in errors) + assert any("mode must be one of" in e for e in errors) + + +def test_validate_params_nested_object_and_array() -> None: + tool = SampleTool() + errors = tool.validate_params( + { + "query": "hi", + "count": 2, + "meta": {"flags": [1, "ok"]}, + } + ) + assert any("missing required meta.tag" in e for e in errors) + assert any("meta.flags[0] should be string" in e for e in errors) + + +def test_validate_params_ignores_unknown_fields() -> None: + tool = SampleTool() + errors = tool.validate_params({"query": "hi", "count": 2, "extra": "x"}) + assert errors == [] + + +async def test_registry_returns_validation_error() -> None: + reg = ToolRegistry() + reg.register(SampleTool()) + result = await reg.execute("sample", {"query": "hi"}) + assert "Invalid parameters" in result