Validate tool params and add tests
This commit is contained in:
parent
30d6e4b4b6
commit
7ef18c4e8a
@ -43,6 +43,79 @@ class Tool(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
||||||
|
"""
|
||||||
|
Lightweight JSON schema validation for tool parameters.
|
||||||
|
|
||||||
|
Returns a list of error strings (empty if valid).
|
||||||
|
Unknown params are ignored.
|
||||||
|
"""
|
||||||
|
schema = self.parameters or {}
|
||||||
|
if schema.get("type") != "object":
|
||||||
|
return []
|
||||||
|
|
||||||
|
return self._validate_schema(params, schema, path="")
|
||||||
|
|
||||||
|
def _validate_schema(self, value: Any, schema: dict[str, Any], path: str) -> list[str]:
|
||||||
|
errors: list[str] = []
|
||||||
|
expected_type = schema.get("type")
|
||||||
|
|
||||||
|
type_map = {
|
||||||
|
"string": str,
|
||||||
|
"integer": int,
|
||||||
|
"number": (int, float),
|
||||||
|
"boolean": bool,
|
||||||
|
"array": list,
|
||||||
|
"object": dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
def label(p: str) -> str:
|
||||||
|
return p or "parameter"
|
||||||
|
|
||||||
|
if expected_type in type_map and not isinstance(value, type_map[expected_type]):
|
||||||
|
errors.append(f"{label(path)} should be {expected_type}")
|
||||||
|
return errors
|
||||||
|
|
||||||
|
if "enum" in schema and value not in schema["enum"]:
|
||||||
|
errors.append(f"{label(path)} must be one of {schema['enum']}")
|
||||||
|
|
||||||
|
if expected_type in ("integer", "number"):
|
||||||
|
if "minimum" in schema and value < schema["minimum"]:
|
||||||
|
errors.append(f"{label(path)} must be >= {schema['minimum']}")
|
||||||
|
if "maximum" in schema and value > schema["maximum"]:
|
||||||
|
errors.append(f"{label(path)} must be <= {schema['maximum']}")
|
||||||
|
|
||||||
|
if expected_type == "string":
|
||||||
|
if "minLength" in schema and len(value) < schema["minLength"]:
|
||||||
|
errors.append(f"{label(path)} must be at least {schema['minLength']} chars")
|
||||||
|
if "maxLength" in schema and len(value) > schema["maxLength"]:
|
||||||
|
errors.append(f"{label(path)} must be at most {schema['maxLength']} chars")
|
||||||
|
|
||||||
|
if expected_type == "object":
|
||||||
|
properties = schema.get("properties", {}) or {}
|
||||||
|
required = set(schema.get("required", []) or [])
|
||||||
|
|
||||||
|
for key in required:
|
||||||
|
if key not in value:
|
||||||
|
p = f"{path}.{key}" if path else key
|
||||||
|
errors.append(f"missing required {p}")
|
||||||
|
|
||||||
|
for key, item in value.items():
|
||||||
|
prop_schema = properties.get(key)
|
||||||
|
if not prop_schema:
|
||||||
|
continue # ignore unknown fields
|
||||||
|
p = f"{path}.{key}" if path else key
|
||||||
|
errors.extend(self._validate_schema(item, prop_schema, p))
|
||||||
|
|
||||||
|
if expected_type == "array":
|
||||||
|
items_schema = schema.get("items")
|
||||||
|
if items_schema:
|
||||||
|
for idx, item in enumerate(value):
|
||||||
|
p = f"{path}[{idx}]" if path else f"[{idx}]"
|
||||||
|
errors.extend(self._validate_schema(item, items_schema, p))
|
||||||
|
|
||||||
|
return errors
|
||||||
|
|
||||||
def to_schema(self) -> dict[str, Any]:
|
def to_schema(self) -> dict[str, Any]:
|
||||||
"""Convert tool to OpenAI function schema format."""
|
"""Convert tool to OpenAI function schema format."""
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -54,6 +54,9 @@ class ToolRegistry:
|
|||||||
return f"Error: Tool '{name}' not found"
|
return f"Error: Tool '{name}' not found"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
errors = tool.validate_params(params)
|
||||||
|
if errors:
|
||||||
|
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors)
|
||||||
return await tool.execute(**params)
|
return await tool.execute(**params)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error executing {name}: {str(e)}"
|
return f"Error executing {name}: {str(e)}"
|
||||||
|
|||||||
88
tests/test_tool_validation.py
Normal file
88
tests/test_tool_validation.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from nanobot.agent.tools.base import Tool
|
||||||
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
|
||||||
|
class SampleTool(Tool):
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "sample"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "sample tool"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "minLength": 2},
|
||||||
|
"count": {"type": "integer", "minimum": 1, "maximum": 10},
|
||||||
|
"mode": {"type": "string", "enum": ["fast", "full"]},
|
||||||
|
"meta": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"tag": {"type": "string"},
|
||||||
|
"flags": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["tag"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query", "count"],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, **kwargs: Any) -> str:
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_missing_required() -> None:
|
||||||
|
tool = SampleTool()
|
||||||
|
errors = tool.validate_params({"query": "hi"})
|
||||||
|
assert "missing required count" in "; ".join(errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_type_and_range() -> None:
|
||||||
|
tool = SampleTool()
|
||||||
|
errors = tool.validate_params({"query": "hi", "count": 0})
|
||||||
|
assert any("count must be >= 1" in e for e in errors)
|
||||||
|
|
||||||
|
errors = tool.validate_params({"query": "hi", "count": "2"})
|
||||||
|
assert any("count should be integer" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_enum_and_min_length() -> None:
|
||||||
|
tool = SampleTool()
|
||||||
|
errors = tool.validate_params({"query": "h", "count": 2, "mode": "slow"})
|
||||||
|
assert any("query must be at least 2 chars" in e for e in errors)
|
||||||
|
assert any("mode must be one of" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_nested_object_and_array() -> None:
|
||||||
|
tool = SampleTool()
|
||||||
|
errors = tool.validate_params(
|
||||||
|
{
|
||||||
|
"query": "hi",
|
||||||
|
"count": 2,
|
||||||
|
"meta": {"flags": [1, "ok"]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert any("missing required meta.tag" in e for e in errors)
|
||||||
|
assert any("meta.flags[0] should be string" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_params_ignores_unknown_fields() -> None:
|
||||||
|
tool = SampleTool()
|
||||||
|
errors = tool.validate_params({"query": "hi", "count": 2, "extra": "x"})
|
||||||
|
assert errors == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_registry_returns_validation_error() -> None:
|
||||||
|
reg = ToolRegistry()
|
||||||
|
reg.register(SampleTool())
|
||||||
|
result = await reg.execute("sample", {"query": "hi"})
|
||||||
|
assert "Invalid parameters" in result
|
||||||
Loading…
x
Reference in New Issue
Block a user