nanobot/nanobot/providers/airllm_provider.py
Tanya f1e95626f8 Clean up providers: keep only Ollama, AirLLM, vLLM, and DeepSeek
- Remove Qwen/DashScope provider and all Qwen-specific code
- Remove gateway providers (OpenRouter, AiHubMix)
- Remove cloud providers (Anthropic, OpenAI, Gemini, Zhipu, Moonshot, MiniMax, Groq)
- Update default model from Platypus to llama3.2
- Remove Platypus references throughout codebase
- Add AirLLM provider support with local model path support
- Update setup scripts to only show Llama models
- Clean up provider registry and config schema
2026-02-17 14:20:47 -05:00

182 lines
7.3 KiB
Python

"""AirLLM provider implementation for direct local model inference."""
import asyncio
import json
import re
from pathlib import Path
from typing import Any

from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
# Import the wrapper behind a guard: the AirLLM stack may not be installed,
# so record availability in a flag instead of failing at module import time.
try:
    from nanobot.providers.airllm_wrapper import AirLLMOllamaWrapper, create_ollama_client
    AIRLLM_WRAPPER_AVAILABLE = True  # wrapper importable; provider can run
    _import_error = None  # no failure to report
except ImportError as e:
    AIRLLM_WRAPPER_AVAILABLE = False
    AirLLMOllamaWrapper = None  # placeholders so later references don't NameError
    create_ollama_client = None
    _import_error = str(e)  # surfaced later in AirLLMProvider._ensure_client
class AirLLMProvider(LLMProvider):
    """
    LLM provider using AirLLM for direct local model inference.

    This provider loads models directly into memory and runs inference
    locally, bypassing HTTP API calls. It's optimized for GPU-limited
    environments.

    ``api_key`` and ``api_base`` are repurposed (see ``__init__``) so this
    provider can be driven by the same generic config schema as the
    HTTP-based providers.
    """

    def __init__(
        self,
        api_key: str | None = None,  # Repurposed: can be HF token or model name
        api_base: str | None = None,  # Repurposed: compression setting ('4bit' or '8bit')
        default_model: str = "meta-llama/Llama-3.2-3B-Instruct",
        compression: str | None = None,  # '4bit' or '8bit' for speed improvement
        model_path: str | None = None,  # Override default model
        hf_token: str | None = None,  # Hugging Face token for gated models
    ):
        """
        Initialize the provider.

        Repurposing rules for the generic config fields:

        * ``api_base`` equal to ``'4bit'``/``'8bit'`` is taken as the
          compression setting.
        * ``api_key`` containing ``'/'`` is taken as a model repo id/path.
        * ``api_key`` without ``'/'`` and longer than 20 characters is taken
          as a Hugging Face token (heuristic; pass ``hf_token`` explicitly
          when possible).
        """
        super().__init__(api_key, api_base)
        self.default_model = model_path or default_model

        # api_base doubles as the compression setting when it looks like one.
        if api_base and api_base in ('4bit', '8bit'):
            self.compression = api_base
        else:
            self.compression = compression

        # Heuristic: a long api_key with no '/' is assumed to be an HF token.
        if api_key and '/' not in api_key and len(api_key) > 20:
            self.hf_token = api_key
        else:
            self.hf_token = hf_token

        # An api_key containing '/' looks like a HF repo id; use it as the model.
        if api_key and '/' in api_key:
            self.default_model = api_key

        self._client: AirLLMOllamaWrapper | None = None
        self._model_loaded = False

    def _ensure_client(self) -> AirLLMOllamaWrapper:
        """Lazy-load the AirLLM client, raising a helpful ImportError if the
        wrapper (or AirLLM itself) is not installed."""
        if not AIRLLM_WRAPPER_AVAILABLE:
            error_msg = (
                "AirLLM wrapper is not available. Please ensure airllm_ollama_wrapper.py "
                "is in the Python path and AirLLM is installed."
            )
            # _import_error is always defined at module level (None on a
            # successful import), so test its value rather than globals().
            if _import_error:
                error_msg += f"\nImport error: {_import_error}"
            raise ImportError(error_msg)

        if self._client is None or not self._model_loaded:
            print(f"Initializing AirLLM with model: {self.default_model}")
            if self.compression:
                print(f"Using compression: {self.compression}")
            if self.hf_token:
                print("Using Hugging Face token for authentication")
            # Only pass hf_token through when we actually have one.
            kwargs = {}
            if self.hf_token:
                kwargs['hf_token'] = self.hf_token
            self._client = create_ollama_client(
                self.default_model,
                compression=self.compression,
                **kwargs
            )
            self._model_loaded = True
            print("AirLLM model loaded and ready!")
        return self._client

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """
        Send a chat completion request using AirLLM.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions (tool calling support is
                best-effort: tools are injected as plain text and tool calls
                are scraped back out of the response).
            model: Model identifier (ignored if different from initialized model).
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature.

        Returns:
            LLMResponse with content and/or tool calls.
        """
        # Switching models would require an expensive reload, so a mismatched
        # request just warns and uses the already-loaded model.
        if model and model != self.default_model:
            print(f"Warning: Model {model} requested but {self.default_model} is loaded. Using loaded model.")
        client = self._ensure_client()

        if tools:
            # Fix: work on a patched copy instead of mutating the caller's
            # message dicts in place.
            messages = self._inject_tools_prompt(messages, tools)

        # Fix: get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated there since Python 3.10.
        loop = asyncio.get_running_loop()
        # The wrapper's chat() is synchronous; run it in the default executor
        # so we don't block the event loop during (slow) local inference.
        response_text = await loop.run_in_executor(
            None,
            lambda: client.chat(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        )

        content, tool_calls = self._parse_tool_calls(response_text)
        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason="stop",
            usage={},  # AirLLM doesn't provide usage stats in the wrapper
        )

    @staticmethod
    def _inject_tools_prompt(
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Return a copy of *messages* with a plain-text tool listing appended
        to the last user message. If the last message isn't from the user the
        tools are silently dropped (matches previous behavior)."""
        tools_text = "\n".join([
            f"- {tool.get('function', {}).get('name', 'unknown')}: {tool.get('function', {}).get('description', '')}"
            for tool in tools
        ])
        # NOTE(review): assumes 'content' is a str — TODO confirm callers never
        # pass structured (list-form) content here.
        if messages and messages[-1].get('role') == 'user':
            patched = dict(messages[-1])
            patched['content'] += f"\n\nAvailable tools:\n{tools_text}"
            return [*messages[:-1], patched]
        return messages

    @staticmethod
    def _parse_tool_calls(response_text: str) -> tuple[str, list[ToolCallRequest]]:
        """Best-effort extraction of JSON tool calls embedded in *response_text*.

        Returns the remaining content (with matched tool-call JSON removed)
        and the parsed ToolCallRequest list. Parsing failures are swallowed so
        a plain-text reply always comes through unchanged.
        """
        tool_calls: list[ToolCallRequest] = []
        content = response_text
        # Cheap pre-filter before running the regex scan.
        if "tool_calls" in response_text.lower() or "function" in response_text.lower():
            try:
                # NOTE(review): this pattern cannot match nested braces, so it
                # only catches flat {"..."function"...} objects — tool calls
                # whose "function" value is itself an object will be missed.
                json_pattern = r'\{[^{}]*"function"[^{}]*\}'
                matches = re.findall(json_pattern, response_text, re.DOTALL)
                for match in matches:
                    try:
                        tool_data = json.loads(match)
                        if "function" in tool_data:
                            func = tool_data["function"]
                            tool_calls.append(ToolCallRequest(
                                id=tool_data.get("id", f"call_{len(tool_calls)}"),
                                name=func.get("name", "unknown"),
                                arguments=func.get("arguments", {}),
                            ))
                            # Strip the tool-call JSON out of the visible content.
                            content = content.replace(match, "").strip()
                    except json.JSONDecodeError:
                        continue  # not valid JSON after all; skip this match
            except Exception:
                pass  # best-effort: fall back to returning the raw content
        return content, tool_calls

    def get_default_model(self) -> str:
        """Get the default model."""
        return self.default_model