refactor: simplify parameter validation logic

This commit is contained in:
Re-bin 2026-02-04 03:50:39 +00:00
parent 579cbfc8fe
commit 9a0f8fcc73

View File

@ -53,62 +53,41 @@ class Tool(ABC):
pass
def validate_params(self, params: dict[str, Any]) -> list[str]:
"""
Lightweight JSON schema validation for tool parameters.
Returns a list of error strings (empty if valid).
Unknown params are ignored.
"""
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
schema = self.parameters or {}
if schema.get("type", "object") != "object":
raise ValueError(f"Schema must be object type, got {schema.get('type')!r}")
return self._validate(params, {**schema, "type": "object"}, "")
# Default to an object schema if type is missing, and fail fast on unsupported top-level types.
if "type" not in schema:
schema = {"type": "object", **schema}
elif schema.get("type") != "object":
raise ValueError(
f"Tool parameter schemas must have top-level type 'object'; got {schema.get('type')!r}"
)
return self._validate_schema(params, schema, path="")
def _validate_schema(self, value: Any, schema: dict[str, Any], path: str) -> list[str]:
errors: list[str] = []
expected_type = schema.get("type")
label = path or "parameter"
if expected_type in self._TYPE_MAP and not isinstance(value, self._TYPE_MAP[expected_type]):
return [f"{label} should be {expected_type}"]
if "enum" in schema and value not in schema["enum"]:
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
t, label = schema.get("type"), path or "parameter"
if t in self._TYPE_MAP and not isinstance(val, self._TYPE_MAP[t]):
return [f"{label} should be {t}"]
errors = []
if "enum" in schema and val not in schema["enum"]:
errors.append(f"{label} must be one of {schema['enum']}")
if expected_type in ("integer", "number"):
if "minimum" in schema and value < schema["minimum"]:
if t in ("integer", "number"):
if "minimum" in schema and val < schema["minimum"]:
errors.append(f"{label} must be >= {schema['minimum']}")
if "maximum" in schema and value > schema["maximum"]:
if "maximum" in schema and val > schema["maximum"]:
errors.append(f"{label} must be <= {schema['maximum']}")
if expected_type == "string":
if "minLength" in schema and len(value) < schema["minLength"]:
if t == "string":
if "minLength" in schema and len(val) < schema["minLength"]:
errors.append(f"{label} must be at least {schema['minLength']} chars")
if "maxLength" in schema and len(value) > schema["maxLength"]:
if "maxLength" in schema and len(val) > schema["maxLength"]:
errors.append(f"{label} must be at most {schema['maxLength']} chars")
if expected_type == "object":
properties = schema.get("properties", {})
for key in schema.get("required", []):
if key not in value:
errors.append(f"missing required {path}.{key}" if path else f"missing required {key}")
for key, item in value.items():
if key in properties:
errors.extend(self._validate_schema(item, properties[key], f"{path}.{key}" if path else key))
if expected_type == "array":
items_schema = schema.get("items")
if items_schema:
for idx, item in enumerate(value):
errors.extend(self._validate_schema(item, items_schema, f"{path}[{idx}]" if path else f"[{idx}]"))
if t == "object":
props = schema.get("properties", {})
for k in schema.get("required", []):
if k not in val:
errors.append(f"missing required {path + '.' + k if path else k}")
for k, v in val.items():
if k in props:
errors.extend(self._validate(v, props[k], path + '.' + k if path else k))
if t == "array" and "items" in schema:
for i, item in enumerate(val):
errors.extend(self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]"))
return errors
def to_schema(self) -> dict[str, Any]: