- Remove Qwen/DashScope provider and all Qwen-specific code
- Remove gateway providers (OpenRouter, AiHubMix)
- Remove cloud providers (Anthropic, OpenAI, Gemini, Zhipu, Moonshot, MiniMax, Groq)
- Update default model from Platypus to llama3.2
- Remove Platypus references throughout the codebase
- Add AirLLM provider support with local model path support
- Update setup scripts to show only Llama models
- Clean up provider registry and config schema
182 lines · 7.3 KiB · Python
"""AirLLM provider implementation for direct local model inference."""
|
|
|
|
import json
|
|
import asyncio
|
|
from typing import Any
|
|
from pathlib import Path
|
|
|
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
|
|
|
# Optional dependency: the AirLLM wrapper may be absent; record availability
# and the import error instead of failing at module import time.
try:
    from nanobot.providers.airllm_wrapper import AirLLMOllamaWrapper, create_ollama_client
except ImportError as e:
    AIRLLM_WRAPPER_AVAILABLE = False
    AirLLMOllamaWrapper = None
    create_ollama_client = None
    _import_error = str(e)
else:
    AIRLLM_WRAPPER_AVAILABLE = True
    _import_error = None
|
|
|
|
|
|
class AirLLMProvider(LLMProvider):
    """
    LLM provider using AirLLM for direct local model inference.

    This provider loads models directly into memory and runs inference locally,
    bypassing HTTP API calls. It's optimized for GPU-limited environments.

    The generic ``api_key``/``api_base`` constructor arguments are repurposed
    so this provider fits the common provider interface: ``api_key`` may carry
    a Hugging Face token or a model path, and ``api_base`` may carry a
    compression setting ('4bit' or '8bit').
    """

    def __init__(
        self,
        api_key: str | None = None,  # Repurposed: can be HF token or model name
        api_base: str | None = None,  # Repurposed: compression setting ('4bit' or '8bit')
        default_model: str = "meta-llama/Llama-3.2-3B-Instruct",
        compression: str | None = None,  # '4bit' or '8bit' for speed improvement
        model_path: str | None = None,  # Override default model
        hf_token: str | None = None,  # Hugging Face token for gated models
    ):
        super().__init__(api_key, api_base)
        self.default_model = model_path or default_model

        # api_base doubles as the compression setting when it looks like one.
        self.compression: str | None = api_base if api_base in ('4bit', '8bit') else compression

        # Heuristic: a long api_key with no '/' is treated as an HF token;
        # one containing '/' is treated as a model path (handled below).
        if api_key and '/' not in api_key and len(api_key) > 20:
            self.hf_token: str | None = api_key
        else:
            self.hf_token = hf_token
        if api_key and '/' in api_key:
            self.default_model = api_key

        self._client: AirLLMOllamaWrapper | None = None
        self._model_loaded = False

    def _ensure_client(self) -> AirLLMOllamaWrapper:
        """Lazy-load the AirLLM client, raising a helpful error if unavailable."""
        if not AIRLLM_WRAPPER_AVAILABLE:
            error_msg = (
                "AirLLM wrapper is not available. Please ensure airllm_ollama_wrapper.py "
                "is in the Python path and AirLLM is installed."
            )
            # _import_error is always bound at module level (None on success),
            # so a simple truthiness test suffices — no globals() probing.
            if _import_error:
                error_msg += f"\nImport error: {_import_error}"
            raise ImportError(error_msg)

        if self._client is None or not self._model_loaded:
            print(f"Initializing AirLLM with model: {self.default_model}")
            if self.compression:
                print(f"Using compression: {self.compression}")
            if self.hf_token:
                print("Using Hugging Face token for authentication")

            # Only pass hf_token through when set, so the wrapper's own
            # default handling applies otherwise.
            kwargs = {'hf_token': self.hf_token} if self.hf_token else {}

            self._client = create_ollama_client(
                self.default_model,
                compression=self.compression,
                **kwargs,
            )
            self._model_loaded = True
            print("AirLLM model loaded and ready!")

        return self._client

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """
        Send a chat completion request using AirLLM.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions (Note: tool calling support may be limited).
            model: Model identifier (ignored if different from initialized model).
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature.

        Returns:
            LLMResponse with content and/or tool calls.
        """
        # Switching models would require an expensive reload; warn and use the
        # already-loaded model instead.
        if model and model != self.default_model:
            print(f"Warning: Model {model} requested but {self.default_model} is loaded. Using loaded model.")

        client = self._ensure_client()

        # Basic tool support: surface tool definitions in the prompt.
        # Note: full tool calling requires model support and proper formatting.
        if tools:
            tools_text = "\n".join(
                f"- {tool.get('function', {}).get('name', 'unknown')}: {tool.get('function', {}).get('description', '')}"
                for tool in tools
            )
            if messages and messages[-1].get('role') == 'user':
                # Fix: copy the list and the last message dict so the caller's
                # data is not mutated by the tool-prompt injection.
                messages = list(messages)
                last = dict(messages[-1])
                last['content'] = f"{last['content']}\n\nAvailable tools:\n{tools_text}"
                messages[-1] = last

        # Run the synchronous client in an executor to avoid blocking the
        # event loop (get_running_loop is the non-deprecated accessor here).
        loop = asyncio.get_running_loop()
        response_text = await loop.run_in_executor(
            None,
            lambda: client.chat(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            ),
        )

        tool_calls, content = self._extract_tool_calls(response_text)

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason="stop",
            usage={},  # AirLLM doesn't provide usage stats in the wrapper
        )

    @staticmethod
    def _extract_tool_calls(response_text: str) -> tuple[list[ToolCallRequest], str]:
        """
        Best-effort extraction of JSON tool calls embedded in model output.

        Scans the text for JSON objects containing a "function" key using
        json.JSONDecoder.raw_decode, which — unlike a flat regex — correctly
        handles nested objects such as {"function": {"name": ..., "arguments": {...}}}.

        Args:
            response_text: Raw model output.

        Returns:
            (tool_calls, content): parsed ToolCallRequest objects and the
            remaining content with the matched JSON blocks stripped out.
        """
        tool_calls: list[ToolCallRequest] = []
        content = response_text

        # Cheap pre-filter, mirroring the provider's heuristic: only attempt
        # parsing when the output mentions tool calls or functions.
        lowered = response_text.lower()
        if "tool_calls" not in lowered and "function" not in lowered:
            return tool_calls, content

        decoder = json.JSONDecoder()
        pos = 0
        while (start := response_text.find('{', pos)) != -1:
            try:
                obj, end = decoder.raw_decode(response_text, start)
            except json.JSONDecodeError:
                # Not a JSON object at this brace; advance past it.
                pos = start + 1
                continue
            if isinstance(obj, dict) and "function" in obj:
                func = obj.get("function")
                if not isinstance(func, dict):
                    func = {}
                tool_calls.append(ToolCallRequest(
                    id=obj.get("id", f"call_{len(tool_calls)}"),
                    name=func.get("name", "unknown"),
                    arguments=func.get("arguments", {}),
                ))
                # Remove the tool-call JSON from the visible content.
                content = content.replace(response_text[start:end], "").strip()
            pos = end
        return tool_calls, content

    def get_default_model(self) -> str:
        """Get the default model."""
        return self.default_model
|
|
|