# Changelog:
# - Add Groq provider config for voice transcription support
# - Pass Groq API key to Telegram channel for voice transcription
# - Increase Ollama timeout settings (10 min read timeout for slow GPU responses)
# - Improve timeout handling in custom provider
"""Direct OpenAI-compatible provider — bypasses LiteLLM."""
from __future__ import annotations

from typing import Any

import json_repair
from openai import AsyncOpenAI

from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest

|
class CustomProvider(LLMProvider):
    """OpenAI-compatible chat provider that talks to the endpoint directly.

    Bypasses LiteLLM. Intended for local servers such as Ollama, whose
    first request can be very slow (model load onto the GPU), hence the
    generous read timeout configured on the client.
    """

    # Tool names accepted when recovering tool calls embedded in message
    # content (Ollama sometimes emits the call as JSON text instead of
    # using the structured tool_calls field).
    _VALID_TOOLS = ("exec", "read_file", "write_file", "list_dir", "web_search")

    def __init__(self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default"):
        """Create a client for an OpenAI-compatible server.

        Args:
            api_key: API key; local servers usually ignore it.
            api_base: Base URL of the OpenAI-compatible endpoint.
            default_model: Model used when ``chat()`` is called without one.
        """
        super().__init__(api_key, api_base)
        self.default_model = default_model
        # Separate connect/read/write/pool timeouts: Ollama can be slow,
        # especially on the first request while the model loads.
        from openai import Timeout
        self._client = AsyncOpenAI(
            api_key=api_key,
            base_url=api_base,
            timeout=Timeout(
                connect=60.0,   # connection timeout
                read=600.0,     # read timeout (10 min for slow Ollama responses)
                write=60.0,     # write timeout
                pool=60.0,      # pool timeout
            ),
        )

    async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
                   model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7) -> LLMResponse:
        """Send a chat completion request and return a parsed LLMResponse.

        Errors (including timeouts) are reported as an LLMResponse with
        ``finish_reason="error"`` rather than raised, so callers always get
        a uniform return type.
        """
        import asyncio

        kwargs: dict[str, Any] = {"model": model or self.default_model, "messages": messages,
                                  "max_tokens": max(1, max_tokens), "temperature": temperature}
        if tools:
            kwargs.update(tools=tools, tool_choice="auto")
        try:
            # Safety-net timeout deliberately LONGER than the client's 600 s
            # read timeout, so the client timeout (with its clearer error)
            # fires first and this wrapper only catches a wedged connection.
            # BUG FIX: this was 310 s ("slightly longer than client timeout
            # (300s)"), which silently cancelled requests halfway through the
            # advertised 10-minute read window configured in __init__.
            return self._parse(await asyncio.wait_for(
                self._client.chat.completions.create(**kwargs),
                timeout=660.0,
            ))
        except asyncio.TimeoutError:
            return LLMResponse(content="Error: Request timed out after 660 seconds", finish_reason="error")
        except Exception as e:
            return LLMResponse(content=f"Error: {e}", finish_reason="error")

    def _parse(self, response: Any) -> LLMResponse:
        """Convert a raw OpenAI-style response object into an LLMResponse."""
        choice = response.choices[0]
        msg = choice.message

        # First, take any structured tool calls. json_repair tolerates the
        # slightly malformed argument JSON some local models emit.
        tool_calls = [
            ToolCallRequest(
                id=tc.id,
                name=tc.function.name,
                arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments,
            )
            for tc in (msg.tool_calls or [])
        ]

        # If no structured tool calls, try to recover one from the content.
        # The key-name pre-check keeps false positives (ordinary JSON in the
        # reply) from triggering the expensive scan.
        content = msg.content or ""
        if not tool_calls and content and '"name"' in content and '"parameters"' in content:
            content, tool_calls = self._extract_inline_tool_calls(content, tool_calls)

        u = response.usage
        return LLMResponse(
            content=content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop",
            usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
            reasoning_content=getattr(msg, "reasoning_content", None),
        )

    def _extract_inline_tool_calls(self, content: str, tool_calls: list[ToolCallRequest]) -> tuple[str, list[ToolCallRequest]]:
        """Scan *content* for ``{"name": ..., "parameters": ...}`` tool-call JSON.

        Recognized calls are appended to *tool_calls* and stripped from the
        returned content; everything else is left in place. NOTE: brace
        matching is naive — it does not skip braces inside string literals —
        which matches the pre-existing behavior.

        Returns:
            The (possibly shortened) content and the updated tool-call list.
        """
        import re

        # Pattern anchors on the opening of a tool-call object.
        pattern = r'\{\s*"name"\s*:\s*"(\w+)"'
        start_pos = 0
        for _ in range(5):  # safety limit against pathological content
            match = re.search(pattern, content[start_pos:])
            if not match:
                break
            json_start = start_pos + match.start()

            # Find the matching closing brace by counting nesting depth.
            brace_count = 0
            json_end = json_start
            found_end = False
            for i, char in enumerate(content[json_start:], json_start):
                if char == '{':
                    brace_count += 1
                elif char == '}':
                    brace_count -= 1
                    if brace_count == 0:
                        json_end = i + 1
                        found_end = True
                        break

            if found_end:
                try:
                    tool_obj = json_repair.loads(content[json_start:json_end])
                    # Only accept a dict carrying both "name" and "parameters"
                    # where the name is one of our known tools.
                    if (isinstance(tool_obj, dict) and
                            "name" in tool_obj and
                            "parameters" in tool_obj and
                            isinstance(tool_obj["name"], str) and
                            tool_obj["name"] in self._VALID_TOOLS):
                        tool_calls.append(ToolCallRequest(
                            id=f"call_{len(tool_calls)}",
                            name=tool_obj["name"],
                            arguments=tool_obj["parameters"] if isinstance(tool_obj["parameters"], dict) else {"raw": str(tool_obj["parameters"])},
                        ))
                        # Remove the recognized call from the visible content
                        # and rescan from the same offset (text shifted left).
                        content = content[:json_start] + content[json_end:].strip()
                        start_pos = json_start
                        continue
                except Exception:
                    pass  # unparsable candidate: skip this match

            start_pos = json_start + 1  # move past this match

        return content, tool_calls

    def get_default_model(self) -> str:
        """Return the model used when ``chat()`` does not specify one."""
        return self.default_model
|