Clean up providers: keep only Ollama, AirLLM, vLLM, and DeepSeek
- Remove Qwen/DashScope provider and all Qwen-specific code
- Remove gateway providers (OpenRouter, AiHubMix)
- Remove cloud providers (Anthropic, OpenAI, Gemini, Zhipu, Moonshot, MiniMax, Groq)
- Update default model from Platypus to llama3.2
- Remove Platypus references throughout codebase
- Add AirLLM provider support with local model path support
- Update setup scripts to only show Llama models
- Clean up provider registry and config schema
This commit is contained in:
parent
dd63337a83
commit
f1e95626f8
242
airllm_ollama_wrapper.py
Normal file
242
airllm_ollama_wrapper.py
Normal file
@ -0,0 +1,242 @@
|
||||
#!/usr/bin/env python3
"""
AirLLM Ollama-Compatible Wrapper

This wrapper provides an Ollama-like interface for AirLLM,
making it easy to replace Ollama in existing projects.
"""

import torch
from typing import List, Dict, Optional, Union

# Try to import airllm, handle BetterTransformer import error gracefully.
# AIRLLM_AVAILABLE / AutoModel are module-level flags checked by the wrapper
# class below, so a missing dependency surfaces at construction time instead
# of at import time.
try:
    from airllm import AutoModel
    AIRLLM_AVAILABLE = True
except ImportError as e:
    if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e):
        # Try to work around BetterTransformer import issue: presumably some
        # optimum versions no longer ship this submodule while airllm still
        # imports it — TODO confirm against installed airllm/optimum versions.
        import sys
        import importlib.util

        # Create a dummy BetterTransformer module to allow airllm to import.
        class DummyBetterTransformer:
            @staticmethod
            def transform(model):
                # No-op: return the model unchanged.
                return model

        # Inject dummy module before importing airllm so that
        # `from optimum.bettertransformer import BetterTransformer` resolves.
        spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
        dummy_module = importlib.util.module_from_spec(spec)
        dummy_module.BetterTransformer = DummyBetterTransformer
        sys.modules["optimum.bettertransformer"] = dummy_module

        # Retry the import now that the stand-in module is registered.
        try:
            from airllm import AutoModel
            AIRLLM_AVAILABLE = True
        except ImportError:
            AIRLLM_AVAILABLE = False
            AutoModel = None
    else:
        # Any other import failure: mark AirLLM as unavailable.
        AIRLLM_AVAILABLE = False
        AutoModel = None
|
||||
|
||||
|
||||
class AirLLMOllamaWrapper:
    """
    A wrapper that provides an Ollama-like API for AirLLM.

    Usage:
        # Instead of: ollama.generate(model="llama2", prompt="Hello")
        # Use: airllm_wrapper.generate(model="llama2", prompt="Hello")
    """

    def __init__(self, model_name: str, compression: Optional[str] = None, **kwargs):
        """
        Initialize AirLLM model.

        Args:
            model_name: Hugging Face model name or path (e.g., "meta-llama/Llama-3.2-3B-Instruct")
            compression: Optional compression ('4bit' or '8bit') for 3x speed improvement
            **kwargs: Additional arguments for AutoModel.from_pretrained()

        Raises:
            ImportError: If the airllm package could not be imported.
        """
        if not AIRLLM_AVAILABLE or AutoModel is None:
            raise ImportError(
                "AirLLM is not available. Please install it with: pip install airllm bitsandbytes\n"
                "If you see a BetterTransformer error, you may need to install: pip install optimum[bettertransformer]"
            )

        print(f"Loading AirLLM model: {model_name}")
        self.model = AutoModel.from_pretrained(
            model_name,
            compression=compression,
            **kwargs
        )
        self.model_name = model_name
        print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,  # Ignored, kept for API compatibility
        max_tokens: int = 50,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stream: bool = False,
        **kwargs
    ) -> Union[str, Dict]:
        """
        Generate text from a prompt (Ollama-compatible interface).

        Args:
            prompt: Input text prompt
            model: Ignored (kept for compatibility)
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0.0 to 1.0)
            top_p: Nucleus sampling parameter
            stream: If True, return streaming response (not yet implemented)
            **kwargs: Additional generation parameters

        Returns:
            Generated text string or dict with response
        """
        # Tokenize input
        input_tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=512,  # Adjust as needed
            padding=False
        )

        # Move to GPU if available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        input_ids = input_tokens['input_ids'].to(device)

        # Prepare generation parameters
        gen_kwargs = {
            'max_new_tokens': max_tokens,
            'use_cache': True,
            'return_dict_in_generate': True,
            'temperature': temperature,
            'top_p': top_p,
            **kwargs
        }

        # Generate
        with torch.inference_mode():
            generation_output = self.model.generate(input_ids, **gen_kwargs)

        # BUGFIX: decode only the newly generated tokens. Decoding the full
        # sequence and stripping via output.startswith(prompt) is unreliable:
        # the decoded text typically begins with special tokens (e.g. "<s>"),
        # so the prefix check fails and the prompt leaks into the output.
        input_length = input_ids.shape[1]
        generated_ids = generation_output.sequences[0, input_length:]
        output = self.model.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        if stream:
            # For streaming, return a generator (simplified version)
            return {"response": output}
        else:
            return output

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        max_tokens: int = 50,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Chat interface (Ollama-compatible).

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model: Ignored (kept for compatibility)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional parameters

        Returns:
            Generated response string
        """
        # Format messages into a prompt
        prompt = self._format_messages(messages)
        return self.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a single prompt."""
        formatted = []
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                formatted.append(f"System: {content}")
            elif role == 'user':
                formatted.append(f"User: {content}")
            elif role == 'assistant':
                formatted.append(f"Assistant: {content}")
        return "\n".join(formatted) + "\nAssistant:"

    def embeddings(self, prompt: str) -> List[float]:
        """
        Get embeddings for a prompt (simplified - returns token embeddings).

        Note: This is a simplified version. For full embeddings,
        you may need to access model internals.
        """
        tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False
        )
        # This is a placeholder - actual embeddings would require model forward pass
        return tokens['input_ids'].tolist()[0]
||||
|
||||
|
||||
# Convenience function for easy migration
def create_ollama_client(model_name: str, compression: Optional[str] = None, **kwargs):
    """
    Create an Ollama-compatible client using AirLLM.

    Usage:
        client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")
        response = client.generate("Hello, how are you?")
    """
    wrapper_kwargs = dict(kwargs, compression=compression)
    return AirLLMOllamaWrapper(model_name, **wrapper_kwargs)
|
||||
|
||||
|
||||
# Example usage
if __name__ == "__main__":
    banner = "=" * 60

    # Example 1: one-shot text generation.
    print("Example 1: Basic Generation")
    print(banner)

    # Initialize (this will take time on first run)
    # client = create_ollama_client("garage-bAInd/Platypus2-70B-instruct")

    # Generate
    # response = client.generate("What is the capital of France?")
    # print(f"Response: {response}")

    # Example 2: multi-turn chat interface.
    print("\nExample 2: Chat Interface")
    print(banner)

    # Chat example
    # messages = [
    #     {"role": "user", "content": "Hello! How are you?"}
    # ]
    # response = client.chat(messages)
    # print(f"Response: {response}")

    print("\nUncomment the code above to test!")
|
||||
|
||||
@ -265,10 +265,60 @@ This file stores important information that should persist across sessions.
|
||||
|
||||
|
||||
def _make_provider(config):
|
||||
"""Create LiteLLMProvider from config. Exits if no API key found."""
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
"""Create LLM provider from config. Supports LiteLLMProvider and AirLLMProvider."""
|
||||
provider_name = config.get_provider_name()
|
||||
p = config.get_provider()
|
||||
model = config.agents.defaults.model
|
||||
|
||||
# Check if AirLLM provider is requested
|
||||
if provider_name == "airllm":
|
||||
try:
|
||||
from nanobot.providers.airllm_provider import AirLLMProvider
|
||||
# AirLLM doesn't need API key, but we can use model path from config
|
||||
# Check if model is specified in the airllm provider config
|
||||
airllm_config = getattr(config.providers, "airllm", None)
|
||||
model_path = None
|
||||
compression = None
|
||||
|
||||
# Try to get model from airllm config's api_key field (repurposed as model path)
|
||||
# or from the default model
|
||||
if airllm_config and airllm_config.api_key:
|
||||
# Check if api_key looks like a model path (contains '/') or is an HF token
|
||||
if '/' in airllm_config.api_key:
|
||||
model_path = airllm_config.api_key
|
||||
hf_token = None
|
||||
else:
|
||||
# Treat as HF token, use model from defaults
|
||||
model_path = model
|
||||
hf_token = airllm_config.api_key
|
||||
else:
|
||||
model_path = model
|
||||
hf_token = None
|
||||
|
||||
# Check for compression setting in extra_headers or api_base
|
||||
if airllm_config:
|
||||
if airllm_config.api_base:
|
||||
compression = airllm_config.api_base # Repurpose api_base as compression
|
||||
elif airllm_config.extra_headers and "compression" in airllm_config.extra_headers:
|
||||
compression = airllm_config.extra_headers["compression"]
|
||||
# Check for HF token in extra_headers
|
||||
if not hf_token and airllm_config.extra_headers and "hf_token" in airllm_config.extra_headers:
|
||||
hf_token = airllm_config.extra_headers["hf_token"]
|
||||
|
||||
return AirLLMProvider(
|
||||
api_key=airllm_config.api_key if airllm_config else None,
|
||||
api_base=compression if compression else None,
|
||||
default_model=model_path,
|
||||
compression=compression,
|
||||
hf_token=hf_token,
|
||||
)
|
||||
except ImportError as e:
|
||||
console.print(f"[red]Error: AirLLM provider not available: {e}[/red]")
|
||||
console.print("Please ensure airllm_ollama_wrapper.py is in the Python path.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Default to LiteLLMProvider
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
if not (p and p.api_key) and not model.startswith("bedrock/"):
|
||||
console.print("[red]Error: No API key configured.[/red]")
|
||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||
@ -278,7 +328,7 @@ def _make_provider(config):
|
||||
api_base=config.get_api_base(),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
provider_name=config.get_provider_name(),
|
||||
provider_name=provider_name,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -177,18 +177,10 @@ class ProviderConfig(BaseModel):
|
||||
|
||||
class ProvidersConfig(BaseModel):
    """Configuration for LLM providers.

    Only the providers this project supports are declared: DeepSeek (cloud),
    plus the local backends vLLM, Ollama, and AirLLM. Each field defaults to
    an empty ProviderConfig so attribute access never raises.
    """

    deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
    vllm: ProviderConfig = Field(default_factory=ProviderConfig)
    ollama: ProviderConfig = Field(default_factory=ProviderConfig)
    airllm: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
|
||||
|
||||
class GatewayConfig(BaseModel):
|
||||
@ -241,14 +233,37 @@ class Config(BaseSettings):
|
||||
# Match by keyword (order follows PROVIDERS registry)
|
||||
for spec in PROVIDERS:
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
if p and any(kw in model_lower for kw in spec.keywords) and p.api_key:
|
||||
return p, spec.name
|
||||
if p and any(kw in model_lower for kw in spec.keywords):
|
||||
# For local providers (Ollama, AirLLM), allow empty api_key or "dummy"
|
||||
# For other providers, require api_key
|
||||
if spec.is_local:
|
||||
# Local providers can work with empty/dummy api_key
|
||||
if p.api_key or p.api_base or spec.name == "airllm":
|
||||
return p, spec.name
|
||||
elif p.api_key:
|
||||
return p, spec.name
|
||||
|
||||
# Check local providers by api_base detection (for explicit config)
|
||||
for spec in PROVIDERS:
|
||||
if spec.is_local:
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
if p:
|
||||
# Check if api_base matches the provider's detection pattern
|
||||
if spec.detect_by_base_keyword and p.api_base and spec.detect_by_base_keyword in p.api_base:
|
||||
return p, spec.name
|
||||
# AirLLM is detected by provider name being "airllm"
|
||||
if spec.name == "airllm" and p.api_key: # api_key can be model path
|
||||
return p, spec.name
|
||||
|
||||
# Fallback: gateways first, then others (follows registry order)
|
||||
for spec in PROVIDERS:
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
if p and p.api_key:
|
||||
return p, spec.name
|
||||
if p:
|
||||
# For local providers, allow empty/dummy api_key
|
||||
if spec.is_local and (p.api_key or p.api_base):
|
||||
return p, spec.name
|
||||
elif p.api_key:
|
||||
return p, spec.name
|
||||
return None, None
|
||||
|
||||
def get_provider(self, model: str | None = None) -> ProviderConfig | None:
|
||||
|
||||
@ -3,4 +3,8 @@
|
||||
"""Public API of the providers package.

AirLLMProvider is exported only when its (heavy, optional) dependencies are
installed; everything else is always available.
"""

from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.litellm_provider import LiteLLMProvider

__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider"]

try:
    from nanobot.providers.airllm_provider import AirLLMProvider
except ImportError:
    # AirLLM is optional; the package works without it.
    pass
else:
    # Single source of truth: extend the base list instead of maintaining
    # two near-duplicate __all__ assignments.
    __all__.append("AirLLMProvider")
|
||||
|
||||
181
nanobot/providers/airllm_provider.py
Normal file
181
nanobot/providers/airllm_provider.py
Normal file
@ -0,0 +1,181 @@
|
||||
"""AirLLM provider implementation for direct local model inference."""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
# Import the wrapper - handle import errors gracefully
|
||||
try:
|
||||
from nanobot.providers.airllm_wrapper import AirLLMOllamaWrapper, create_ollama_client
|
||||
AIRLLM_WRAPPER_AVAILABLE = True
|
||||
_import_error = None
|
||||
except ImportError as e:
|
||||
AIRLLM_WRAPPER_AVAILABLE = False
|
||||
AirLLMOllamaWrapper = None
|
||||
create_ollama_client = None
|
||||
_import_error = str(e)
|
||||
|
||||
|
||||
class AirLLMProvider(LLMProvider):
    """
    LLM provider using AirLLM for direct local model inference.

    This provider loads models directly into memory and runs inference locally,
    bypassing HTTP API calls. It's optimized for GPU-limited environments.
    """

    def __init__(
        self,
        api_key: str | None = None,  # Repurposed: can be HF token or model name
        api_base: str | None = None,  # Repurposed: compression setting ('4bit' or '8bit')
        default_model: str = "meta-llama/Llama-3.2-3B-Instruct",
        compression: str | None = None,  # '4bit' or '8bit' for speed improvement
        model_path: str | None = None,  # Override default model
        hf_token: str | None = None,  # Hugging Face token for gated models
    ):
        super().__init__(api_key, api_base)
        self.default_model = model_path or default_model
        # If api_base is set and looks like compression, use it.
        if api_base and api_base in ('4bit', '8bit'):
            self.compression = api_base
        else:
            self.compression = compression
        # Heuristic: an api_key without '/' and longer than 20 chars is treated
        # as a Hugging Face token rather than a model path.
        if api_key and '/' not in api_key and len(api_key) > 20:
            self.hf_token = api_key
        else:
            self.hf_token = hf_token
        # If api_key looks like a model path (contains '/'), use it as the model.
        if api_key and '/' in api_key:
            self.default_model = api_key
        self._client: AirLLMOllamaWrapper | None = None
        self._model_loaded = False

    def _ensure_client(self) -> AirLLMOllamaWrapper:
        """Lazy-load the AirLLM client.

        Raises:
            ImportError: If the AirLLM wrapper could not be imported.
        """
        if not AIRLLM_WRAPPER_AVAILABLE:
            error_msg = (
                "AirLLM wrapper is not available. Please ensure airllm_ollama_wrapper.py "
                "is in the Python path and AirLLM is installed."
            )
            # _import_error is always bound at module level; the previous
            # `'_import_error' in globals()` guard was dead weight.
            if _import_error:
                error_msg += f"\nImport error: {_import_error}"
            raise ImportError(error_msg)

        if self._client is None or not self._model_loaded:
            print(f"Initializing AirLLM with model: {self.default_model}")
            if self.compression:
                print(f"Using compression: {self.compression}")
            if self.hf_token:
                print("Using Hugging Face token for authentication")

            # Prepare kwargs for model loading
            kwargs = {}
            if self.hf_token:
                kwargs['hf_token'] = self.hf_token

            self._client = create_ollama_client(
                self.default_model,
                compression=self.compression,
                **kwargs
            )
            self._model_loaded = True
            print("AirLLM model loaded and ready!")

        return self._client

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """
        Send a chat completion request using AirLLM.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions (Note: tool calling support may be limited).
            model: Model identifier (ignored if different from initialized model).
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature.

        Returns:
            LLMResponse with content and/or tool calls.
        """
        # If a different model is requested, we'd need to reload (expensive)
        # For now, we'll use the initialized model
        if model and model != self.default_model:
            print(f"Warning: Model {model} requested but {self.default_model} is loaded. Using loaded model.")

        client = self._ensure_client()

        # Format tools into the prompt if provided (basic tool support)
        # Note: Full tool calling requires model support and proper formatting
        if tools:
            # Add tool definitions to the system message or last user message
            tools_text = "\n".join([
                f"- {tool.get('function', {}).get('name', 'unknown')}: {tool.get('function', {}).get('description', '')}"
                for tool in tools
            ])
            # BUGFIX: append to a *copy* of the last user message instead of
            # mutating the caller's message dicts in place — the previous
            # in-place `+=` leaked the tool text into the caller's history.
            if messages and messages[-1].get('role') == 'user':
                messages = list(messages)
                last = dict(messages[-1])
                last['content'] = f"{last.get('content', '')}\n\nAvailable tools:\n{tools_text}"
                messages[-1] = last

        # Run the synchronous client in an executor to avoid blocking the
        # event loop. get_running_loop() is the correct call inside a
        # coroutine (get_event_loop() here is deprecated since Python 3.10).
        loop = asyncio.get_running_loop()
        response_text = await loop.run_in_executor(
            None,
            lambda: client.chat(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        )

        # Parse tool calls from response if present
        # This is a simplified parser - you may need to adjust based on model output format
        tool_calls = []
        content = response_text

        # Try to extract JSON tool calls from the response
        # Some models return tool calls as JSON in the content
        if "tool_calls" in response_text.lower() or "function" in response_text.lower():
            try:
                # Look for JSON blocks in the response
                import re
                json_pattern = r'\{[^{}]*"function"[^{}]*\}'
                matches = re.findall(json_pattern, response_text, re.DOTALL)
                for match in matches:
                    try:
                        tool_data = json.loads(match)
                        if "function" in tool_data:
                            func = tool_data["function"]
                            tool_calls.append(ToolCallRequest(
                                id=tool_data.get("id", f"call_{len(tool_calls)}"),
                                name=func.get("name", "unknown"),
                                arguments=func.get("arguments", {}),
                            ))
                            # Remove the tool call from content
                            content = content.replace(match, "").strip()
                    except json.JSONDecodeError:
                        pass
            except Exception:
                pass  # If parsing fails, just return the content as-is

        return LLMResponse(
            content=content,
            tool_calls=tool_calls if tool_calls else [],
            finish_reason="stop",
            usage={},  # AirLLM doesn't provide usage stats in the wrapper
        )

    def get_default_model(self) -> str:
        """Get the default model."""
        return self.default_model
|
||||
|
||||
327
nanobot/providers/airllm_wrapper.py
Normal file
327
nanobot/providers/airllm_wrapper.py
Normal file
@ -0,0 +1,327 @@
|
||||
#!/usr/bin/env python3
"""
AirLLM Ollama-Compatible Wrapper

This wrapper provides an Ollama-like interface for AirLLM,
making it easy to replace Ollama in existing projects.
"""

import torch
from typing import List, Dict, Optional, Union

# Try to import airllm, handle BetterTransformer import error gracefully.
# AIRLLM_AVAILABLE / AutoModel are module-level flags checked at construction
# time by AirLLMOllamaWrapper below.
try:
    from airllm import AutoModel
    AIRLLM_AVAILABLE = True
except ImportError as e:
    if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e):
        # Try to work around BetterTransformer import issue: presumably some
        # optimum versions no longer ship this submodule while airllm still
        # imports it — TODO confirm against installed airllm/optimum versions.
        import sys
        import importlib.util

        # Create a dummy BetterTransformer module to allow airllm to import.
        class DummyBetterTransformer:
            @staticmethod
            def transform(model):
                # No-op: return the model unchanged.
                return model

        # Inject dummy module before importing airllm so that
        # `from optimum.bettertransformer import BetterTransformer` resolves.
        spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
        dummy_module = importlib.util.module_from_spec(spec)
        dummy_module.BetterTransformer = DummyBetterTransformer
        sys.modules["optimum.bettertransformer"] = dummy_module

        # Retry the import with the stand-in module registered.
        try:
            from airllm import AutoModel
            AIRLLM_AVAILABLE = True
        except ImportError:
            AIRLLM_AVAILABLE = False
            AutoModel = None
    else:
        # Any other import failure: mark AirLLM as unavailable.
        AIRLLM_AVAILABLE = False
        AutoModel = None
|
||||
|
||||
|
||||
class AirLLMOllamaWrapper:
|
||||
"""
|
||||
A wrapper that provides an Ollama-like API for AirLLM.
|
||||
|
||||
Usage:
|
||||
# Instead of: ollama.generate(model="llama2", prompt="Hello")
|
||||
# Use: airllm_wrapper.generate(model="llama2", prompt="Hello")
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str, compression: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
Initialize AirLLM model.
|
||||
|
||||
Args:
|
||||
model_name: Hugging Face model name or path (e.g., "meta-llama/Llama-3.2-3B-Instruct")
|
||||
compression: Optional compression ('4bit' or '8bit') for 3x speed improvement
|
||||
**kwargs: Additional arguments for AutoModel.from_pretrained()
|
||||
"""
|
||||
if not AIRLLM_AVAILABLE or AutoModel is None:
|
||||
raise ImportError(
|
||||
"AirLLM is not available. Please install it with: pip install airllm bitsandbytes\n"
|
||||
"If you see a BetterTransformer error, you may need to install: pip install optimum[bettertransformer]"
|
||||
)
|
||||
|
||||
print(f"Loading AirLLM model: {model_name}")
|
||||
# AutoModel.from_pretrained() accepts:
|
||||
# - Hugging Face model IDs (e.g., "meta-llama/Llama-3.1-8B-Instruct")
|
||||
# - Local paths (e.g., "/path/to/local/model")
|
||||
# - Can use local_dir parameter for local models
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
compression=compression,
|
||||
**kwargs
|
||||
)
|
||||
self.model_name = model_name
|
||||
|
||||
# Get the model's maximum sequence length for AirLLM
|
||||
# IMPORTANT: AirLLM processes sequences in chunks, and each chunk must fit
|
||||
# within the model's position embedding limits.
|
||||
# Even if the model config says it supports longer sequences via rope scaling,
|
||||
# AirLLM's chunking mechanism requires the base size.
|
||||
|
||||
self.max_length = 2048 # Default for Llama models
|
||||
|
||||
# Check if this is a Llama model to determine appropriate max length
|
||||
is_llama = False
|
||||
if hasattr(self.model, 'config'):
|
||||
model_type = getattr(self.model.config, 'model_type', '').lower()
|
||||
is_llama = 'llama' in model_type or 'llama' in self.model_name.lower()
|
||||
|
||||
if is_llama:
|
||||
# Llama models: typically support 2048-4096 tokens
|
||||
# AirLLM works well with Llama, so we can use larger chunks
|
||||
if hasattr(self.model, 'config'):
|
||||
config_max = getattr(self.model.config, 'max_position_embeddings', None)
|
||||
if config_max and config_max > 0:
|
||||
# Use config value, but cap at 2048 for AirLLM safety
|
||||
self.max_length = min(config_max, 2048)
|
||||
else:
|
||||
self.max_length = 2048 # Safe default for Llama
|
||||
else:
|
||||
# For other models (e.g., DeepSeek), use conservative default
|
||||
if hasattr(self.model, 'config'):
|
||||
config_max = getattr(self.model.config, 'max_position_embeddings', None)
|
||||
if config_max and config_max > 0 and config_max <= 2048:
|
||||
self.max_length = config_max
|
||||
else:
|
||||
self.max_length = 512 # Very conservative
|
||||
|
||||
print(f"Using sequence length limit: {self.max_length} (AirLLM chunk size)")
|
||||
|
||||
print("Model loaded successfully!")
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: Optional[str] = None, # Ignored, kept for API compatibility
|
||||
max_tokens: int = 50,
|
||||
temperature: float = 0.7,
|
||||
top_p: float = 0.9,
|
||||
stream: bool = False,
|
||||
**kwargs
|
||||
) -> Union[str, Dict]:
|
||||
"""
|
||||
Generate text from a prompt (Ollama-compatible interface).
|
||||
|
||||
Args:
|
||||
prompt: Input text prompt
|
||||
model: Ignored (kept for compatibility)
|
||||
max_tokens: Maximum number of tokens to generate
|
||||
temperature: Sampling temperature (0.0 to 1.0)
|
||||
top_p: Nucleus sampling parameter
|
||||
stream: If True, return streaming response (not yet implemented)
|
||||
**kwargs: Additional generation parameters
|
||||
|
||||
Returns:
|
||||
Generated text string or dict with response
|
||||
"""
|
||||
# Tokenize input with attention mask
|
||||
# AirLLM processes sequences in chunks, but each chunk must fit within the model's
|
||||
# position embedding limits. We need to ensure we don't exceed the chunk size.
|
||||
# Use the model's max_length to ensure compatibility with position embeddings
|
||||
input_tokens = self.model.tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
return_attention_mask=True,
|
||||
truncation=True,
|
||||
max_length=self.max_length, # Respect model's position embedding limit
|
||||
padding=False
|
||||
)
|
||||
|
||||
# Move to GPU if available
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
input_ids = input_tokens['input_ids'].to(device)
|
||||
attention_mask = input_tokens.get('attention_mask', None)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(device)
|
||||
|
||||
# Ensure we don't exceed max_length (manual truncation as safety check)
|
||||
seq_length = input_ids.shape[1]
|
||||
if seq_length > self.max_length:
|
||||
print(f"Warning: Sequence length ({seq_length}) exceeds limit ({self.max_length}), truncating...")
|
||||
input_ids = input_ids[:, :self.max_length]
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :self.max_length]
|
||||
seq_length = self.max_length
|
||||
|
||||
if seq_length >= self.max_length:
|
||||
print(f"Note: Using sequence of {seq_length} tokens (at limit: {self.max_length})")
|
||||
|
||||
# Prepare generation parameters
|
||||
# For Llama models, we can use more tokens
|
||||
max_gen_tokens = min(max_tokens, 512)
|
||||
|
||||
gen_kwargs = {
|
||||
'max_new_tokens': max_gen_tokens,
|
||||
'use_cache': True,
|
||||
'return_dict_in_generate': True,
|
||||
'temperature': temperature,
|
||||
'top_p': top_p,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
# Add attention mask if available
|
||||
if attention_mask is not None:
|
||||
gen_kwargs['attention_mask'] = attention_mask
|
||||
|
||||
# Generate
|
||||
with torch.inference_mode():
|
||||
generation_output = self.model.generate(input_ids, **gen_kwargs)
|
||||
|
||||
# Decode output - get only the newly generated tokens
|
||||
if hasattr(generation_output, 'sequences'):
|
||||
# Extract only the new tokens (after input length)
|
||||
input_length = input_ids.shape[1]
|
||||
generated_ids = generation_output.sequences[0, input_length:]
|
||||
output = self.model.tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||
else:
|
||||
# Fallback for older output formats
|
||||
output = self.model.tokenizer.decode(generation_output.sequences[0], skip_special_tokens=True)
|
||||
# Remove the input prompt from output if present
|
||||
if output.startswith(prompt):
|
||||
output = output[len(prompt):].strip()
|
||||
|
||||
if stream:
|
||||
# For streaming, return a generator (simplified version)
|
||||
return {"response": output}
|
||||
else:
|
||||
return output
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: Optional[str] = None,
|
||||
max_tokens: int = 50,
|
||||
temperature: float = 0.7,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
Chat interface (Ollama-compatible).
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content' keys
|
||||
model: Ignored (kept for compatibility)
|
||||
max_tokens: Maximum tokens to generate
|
||||
temperature: Sampling temperature
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
Generated response string
|
||||
"""
|
||||
# Try to use the model's chat template if available (for Llama, etc.)
|
||||
if hasattr(self.model.tokenizer, 'apply_chat_template') and self.model.tokenizer.chat_template:
|
||||
try:
|
||||
# Use the model's native chat template
|
||||
prompt = self.model.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
except Exception:
|
||||
# Fallback to simple formatting if chat template fails
|
||||
prompt = self._format_messages(messages)
|
||||
else:
|
||||
# Fallback to simple formatting
|
||||
prompt = self._format_messages(messages)
|
||||
|
||||
return self.generate(
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _format_messages(self, messages: List[Dict[str, str]]) -> str:
|
||||
"""Format chat messages into a single prompt (fallback method)."""
|
||||
formatted = []
|
||||
for msg in messages:
|
||||
role = msg.get('role', 'user')
|
||||
content = msg.get('content', '')
|
||||
if role == 'system':
|
||||
formatted.append(f"System: {content}")
|
||||
elif role == 'user':
|
||||
formatted.append(f"User: {content}")
|
||||
elif role == 'assistant':
|
||||
formatted.append(f"Assistant: {content}")
|
||||
return "\n".join(formatted) + "\nAssistant:"
|
||||
|
||||
def embeddings(self, prompt: str) -> List[float]:
    """
    Return a pseudo-embedding for *prompt* (Ollama-compatible stub).

    NOTE(review): this returns the prompt's token IDs, not real
    embeddings — producing true embeddings would require a forward
    pass through the model.
    """
    encoded = self.model.tokenizer(
        [prompt],
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=False
    )
    # Placeholder: unwrap the single batch entry of token IDs.
    input_ids = encoded['input_ids']
    return input_ids.tolist()[0]
|
||||
|
||||
|
||||
# Convenience function for easy migration
|
||||
def create_ollama_client(model_name: str, compression: Optional[str] = None, **kwargs):
    """
    Create an Ollama-compatible client using AirLLM.

    Convenience factory for migrating code that previously talked to
    Ollama: the returned wrapper exposes generate/chat/embeddings.

    Args:
        model_name: Hugging Face model ID or local model path.
        compression: Optional AirLLM compression mode, forwarded to the wrapper.
        **kwargs: Extra arguments forwarded to AirLLMOllamaWrapper.

    Usage:
        client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")
        response = client.generate("Hello, how are you?")
    """
    return AirLLMOllamaWrapper(model_name, compression=compression, **kwargs)
|
||||
|
||||
|
||||
# Example usage (smoke-test only: the real calls are commented out because
# they download/load model weights, which is slow on first run)
if __name__ == "__main__":
    # Example 1: Basic generation
    print("Example 1: Basic Generation")
    print("=" * 60)

    # Initialize (this will take time on first run)
    # client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")

    # Generate
    # response = client.generate("What is the capital of France?")
    # print(f"Response: {response}")

    print("\nExample 2: Chat Interface")
    print("=" * 60)

    # Chat example
    # messages = [
    #     {"role": "user", "content": "Hello! How are you?"}
    # ]
    # response = client.chat(messages)
    # print(f"Response: {response}")

    print("\nUncomment the code above to test!")
|
||||
|
||||
@ -127,6 +127,7 @@ class LiteLLMProvider(LLMProvider):
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"stream": False, # Explicitly disable streaming to avoid hangs with some providers
|
||||
}
|
||||
|
||||
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
||||
@ -148,6 +149,11 @@ class LiteLLMProvider(LLMProvider):
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
# Add timeout to prevent hangs (especially with local servers)
|
||||
# Ollama can be slow with complex prompts, so use a longer timeout
|
||||
# Increased to 400s for larger models like mistral-nemo
|
||||
kwargs["timeout"] = 400.0
|
||||
|
||||
try:
|
||||
response = await acompletion(**kwargs)
|
||||
return self._parse_response(response)
|
||||
|
||||
@ -6,7 +6,7 @@ Adding a new provider:
|
||||
2. Add a field to ProvidersConfig in config/schema.py.
|
||||
Done. Env vars, prefixing, config matching, status display all derive from here.
|
||||
|
||||
Order matters — it controls match priority and fallback. Gateways first.
|
||||
Order matters — it controls match priority and fallback.
|
||||
Every entry writes out all fields so you can copy-paste as a template.
|
||||
"""
|
||||
|
||||
@ -62,86 +62,10 @@ class ProviderSpec:
|
||||
|
||||
PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
|
||||
# === Gateways (detected by api_key / api_base, not model name) =========
|
||||
# Gateways can route any model, so they win in fallback.
|
||||
|
||||
# OpenRouter: global gateway, keys start with "sk-or-"
|
||||
ProviderSpec(
|
||||
name="openrouter",
|
||||
keywords=("openrouter",),
|
||||
env_key="OPENROUTER_API_KEY",
|
||||
display_name="OpenRouter",
|
||||
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="sk-or-",
|
||||
detect_by_base_keyword="openrouter",
|
||||
default_api_base="https://openrouter.ai/api/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# AiHubMix: global gateway, OpenAI-compatible interface.
|
||||
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
|
||||
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
|
||||
ProviderSpec(
|
||||
name="aihubmix",
|
||||
keywords=("aihubmix",),
|
||||
env_key="OPENAI_API_KEY", # OpenAI-compatible
|
||||
display_name="AiHubMix",
|
||||
litellm_prefix="openai", # → openai/{model}
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="aihubmix",
|
||||
default_api_base="https://aihubmix.com/v1",
|
||||
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Standard providers (matched by model-name keywords) ===============
|
||||
|
||||
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
||||
ProviderSpec(
|
||||
name="anthropic",
|
||||
keywords=("anthropic", "claude"),
|
||||
env_key="ANTHROPIC_API_KEY",
|
||||
display_name="Anthropic",
|
||||
litellm_prefix="",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
|
||||
ProviderSpec(
|
||||
name="openai",
|
||||
keywords=("openai", "gpt"),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="OpenAI",
|
||||
litellm_prefix="",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
|
||||
# Can be used with local models or API.
|
||||
ProviderSpec(
|
||||
name="deepseek",
|
||||
keywords=("deepseek",),
|
||||
@ -159,107 +83,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# Gemini: needs "gemini/" prefix for LiteLLM.
|
||||
ProviderSpec(
|
||||
name="gemini",
|
||||
keywords=("gemini",),
|
||||
env_key="GEMINI_API_KEY",
|
||||
display_name="Gemini",
|
||||
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
|
||||
skip_prefixes=("gemini/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# Zhipu: LiteLLM uses "zai/" prefix.
|
||||
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
|
||||
# skip_prefixes: don't add "zai/" when already routed via gateway.
|
||||
ProviderSpec(
|
||||
name="zhipu",
|
||||
keywords=("zhipu", "glm", "zai"),
|
||||
env_key="ZAI_API_KEY",
|
||||
display_name="Zhipu AI",
|
||||
litellm_prefix="zai", # glm-4 → zai/glm-4
|
||||
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
|
||||
env_extras=(
|
||||
("ZHIPUAI_API_KEY", "{api_key}"),
|
||||
),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# DashScope: Qwen models, needs "dashscope/" prefix.
|
||||
ProviderSpec(
|
||||
name="dashscope",
|
||||
keywords=("qwen", "dashscope"),
|
||||
env_key="DASHSCOPE_API_KEY",
|
||||
display_name="DashScope",
|
||||
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
|
||||
skip_prefixes=("dashscope/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# Moonshot: Kimi models, needs "moonshot/" prefix.
|
||||
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
|
||||
# Kimi K2.5 API enforces temperature >= 1.0.
|
||||
ProviderSpec(
|
||||
name="moonshot",
|
||||
keywords=("moonshot", "kimi"),
|
||||
env_key="MOONSHOT_API_KEY",
|
||||
display_name="Moonshot",
|
||||
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
|
||||
skip_prefixes=("moonshot/", "openrouter/"),
|
||||
env_extras=(
|
||||
("MOONSHOT_API_BASE", "{api_base}"),
|
||||
),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(
|
||||
("kimi-k2.5", {"temperature": 1.0}),
|
||||
),
|
||||
),
|
||||
|
||||
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
|
||||
# Uses OpenAI-compatible API at api.minimax.io/v1.
|
||||
ProviderSpec(
|
||||
name="minimax",
|
||||
keywords=("minimax",),
|
||||
env_key="MINIMAX_API_KEY",
|
||||
display_name="MiniMax",
|
||||
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
|
||||
skip_prefixes=("minimax/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://api.minimax.io/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Local deployment (matched by config key, NOT by api_base) =========
|
||||
|
||||
# vLLM / any OpenAI-compatible local server.
|
||||
@ -281,23 +104,44 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Auxiliary (not a primary LLM provider) ============================
|
||||
|
||||
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
|
||||
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
|
||||
# Ollama: local OpenAI-compatible server.
|
||||
# Use OpenAI-compatible endpoint, not native Ollama API.
|
||||
# Detected when config key is "ollama" or api_base contains "11434" or "ollama".
|
||||
ProviderSpec(
|
||||
name="groq",
|
||||
keywords=("groq",),
|
||||
env_key="GROQ_API_KEY",
|
||||
display_name="Groq",
|
||||
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
|
||||
skip_prefixes=("groq/",), # avoid double-prefix
|
||||
name="ollama",
|
||||
keywords=("ollama", "llama"), # Match both "ollama" and "llama" model names
|
||||
env_key="OPENAI_API_KEY", # Use OpenAI-compatible API
|
||||
display_name="Ollama",
|
||||
litellm_prefix="", # No prefix - use as OpenAI-compatible
|
||||
skip_prefixes=(),
|
||||
env_extras=(
|
||||
("OPENAI_API_BASE", "{api_base}"), # Set OpenAI API base to Ollama endpoint
|
||||
),
|
||||
is_gateway=False,
|
||||
is_local=True,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="11434", # Detect by default Ollama port
|
||||
default_api_base="http://localhost:11434/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# AirLLM: direct local model inference (no HTTP server).
|
||||
# Loads models directly into memory for GPU-optimized inference.
|
||||
# Detected when config key is "airllm".
|
||||
ProviderSpec(
|
||||
name="airllm",
|
||||
keywords=("airllm",),
|
||||
env_key="", # No API key needed (local)
|
||||
display_name="AirLLM",
|
||||
litellm_prefix="", # Not used with LiteLLM
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
is_local=True,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
default_api_base="", # Not used (direct Python calls)
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
@ -325,12 +169,11 @@ def find_gateway(
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
) -> ProviderSpec | None:
|
||||
"""Detect gateway/local provider.
|
||||
"""Detect local provider.
|
||||
|
||||
Priority:
|
||||
1. provider_name — if it maps to a gateway/local spec, use it directly.
|
||||
2. api_key prefix — e.g. "sk-or-" → OpenRouter.
|
||||
3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix.
|
||||
1. provider_name — if it maps to a local spec, use it directly.
|
||||
2. api_base keyword — e.g. "11434" in URL → Ollama.
|
||||
|
||||
A standard provider with a custom api_base (e.g. DeepSeek behind a proxy)
|
||||
will NOT be mistaken for vLLM — the old fallback is gone.
|
||||
@ -341,10 +184,8 @@ def find_gateway(
|
||||
if spec and (spec.is_gateway or spec.is_local):
|
||||
return spec
|
||||
|
||||
# 2. Auto-detect by api_key prefix / api_base keyword
|
||||
# 2. Auto-detect by api_base keyword
|
||||
for spec in PROVIDERS:
|
||||
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
|
||||
return spec
|
||||
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
|
||||
return spec
|
||||
|
||||
|
||||
175
setup_llama_airllm.py
Normal file
175
setup_llama_airllm.py
Normal file
@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Setup script to configure nanobot to use Llama models with AirLLM.
|
||||
This script will:
|
||||
1. Check/create the config file
|
||||
2. Set up Llama model configuration
|
||||
3. Guide you through getting a Hugging Face token if needed
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
CONFIG_PATH = Path.home() / ".nanobot" / "config.json"
|
||||
|
||||
def get_hf_token_instructions():
    """Print step-by-step instructions for obtaining a Hugging Face token
    and accepting the gated Llama model license."""
    banner = "=" * 70
    lines = (
        "\n" + banner,
        "GETTING A HUGGING FACE TOKEN",
        banner,
        "\nTo use Llama models (which are gated), you need a Hugging Face token:",
        "\n1. Go to: https://huggingface.co/settings/tokens",
        "2. Click 'New token'",
        "3. Give it a name (e.g., 'nanobot')",
        "4. Select 'Read' permission",
        "5. Click 'Generate token'",
        "6. Copy the token (starts with 'hf_...')",
        "\nThen accept the Llama model license:",
        "1. Go to: https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct",
        "2. Click 'Agree and access repository'",
        "3. Accept the license terms",
        "\n" + banner + "\n",
    )
    # Single write instead of many prints; stdout output is identical.
    print("\n".join(lines))
||||
|
||||
def load_existing_config():
    """Return the parsed config from CONFIG_PATH, or {} when the file is
    missing or unreadable (a warning is printed in the unreadable case)."""
    if not CONFIG_PATH.exists():
        return {}
    try:
        with open(CONFIG_PATH) as handle:
            return json.load(handle)
    except Exception as exc:
        # Corrupt or unreadable config should not abort setup — start fresh.
        print(f"Warning: Could not read existing config: {exc}")
        return {}
||||
|
||||
def create_llama_config():
    """Create or update config for Llama with AirLLM.

    Interactively prompts for a model choice and (optionally) a Hugging
    Face token, then returns the updated config dict. Nothing is written
    to disk here — the caller persists the result via save_config().
    """
    config = load_existing_config()

    # Ensure providers section exists
    if "providers" not in config:
        config["providers"] = {}

    # Ensure agents section exists
    if "agents" not in config:
        config["agents"] = {}
    if "defaults" not in config["agents"]:
        config["agents"]["defaults"] = {}

    # Choose Llama model
    print("\n" + "="*70)
    print("CHOOSE LLAMA MODEL")
    print("="*70)
    print("\nAvailable models:")
    print(" 1. Llama-3.2-3B-Instruct (Recommended - fast, minimal memory)")
    print(" 2. Llama-3.1-8B-Instruct (Good balance of performance and speed)")
    print(" 3. Custom (enter model path)")

    choice = input("\nChoose model (1-3, default: 1): ").strip() or "1"

    model_map = {
        "1": "meta-llama/Llama-3.2-3B-Instruct",
        "2": "meta-llama/Llama-3.1-8B-Instruct",
    }

    if choice == "3":
        model_path = input("Enter model path (e.g., meta-llama/Llama-3.2-3B-Instruct): ").strip()
        if not model_path:
            model_path = "meta-llama/Llama-3.2-3B-Instruct"
            print(f"Using default: {model_path}")
    else:
        # Any unrecognized answer falls back to option 1's model.
        model_path = model_map.get(choice, "meta-llama/Llama-3.2-3B-Instruct")

    # Set up AirLLM provider with Llama model
    # Note: apiKey can be used as model path, or we can put model in defaults
    config["providers"]["airllm"] = {
        "apiKey": "",  # Will be set to model path
        "apiBase": None,
        "extraHeaders": {}
    }

    # Set default model
    config["agents"]["defaults"]["model"] = model_path

    # Ask for Hugging Face token
    print("\n" + "="*70)
    print("HUGGING FACE TOKEN SETUP")
    print("="*70)
    print("\nDo you have a Hugging Face token? (Required for Llama models)")
    print("If not, we'll show you how to get one.\n")

    has_token = input("Do you have a Hugging Face token? (y/n): ").strip().lower()

    if has_token == 'y':
        hf_token = input("\nEnter your Hugging Face token (starts with 'hf_'): ").strip()
        if hf_token and hf_token.startswith('hf_'):
            # Store token in extraHeaders
            config["providers"]["airllm"]["extraHeaders"]["hf_token"] = hf_token
            # Also set apiKey to model path (AirLLM uses apiKey as model path if it contains '/')
            config["providers"]["airllm"]["apiKey"] = config["agents"]["defaults"]["model"]
            print("\n✓ Token configured!")
        else:
            print("⚠ Warning: Token doesn't look valid (should start with 'hf_')")
            print("You can add it later by editing the config file.")
            # Still set model path in apiKey
            config["providers"]["airllm"]["apiKey"] = config["agents"]["defaults"]["model"]
    else:
        get_hf_token_instructions()
        print("\nYou can add your token later by:")
        print(f"1. Editing: {CONFIG_PATH}")
        print("2. Adding your token to: providers.airllm.extraHeaders.hf_token")
        print("\nOr run this script again after getting your token.")
        # NOTE(review): in this branch apiKey is left "" (the model path is
        # only stored in agents.defaults.model), unlike both branches above
        # — confirm the AirLLM provider can resolve the model without it.

    return config
|
||||
|
||||
def save_config(config):
    """Write *config* to CONFIG_PATH as indented JSON and restrict the
    file to owner read/write (it may contain a Hugging Face token)."""
    CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(CONFIG_PATH, 'w') as handle:
        json.dump(config, handle, indent=2)

    # Set secure permissions — the file can hold secrets.
    os.chmod(CONFIG_PATH, 0o600)
    print(f"\n✓ Configuration saved to: {CONFIG_PATH}")
    print(f"✓ File permissions set to 600 (read/write for owner only)")
|
||||
|
||||
def main():
    """Interactive entry point: optionally back up an existing config,
    then build and persist the Llama/AirLLM configuration."""
    print("\n" + "="*70)
    print("NANOBOT LLAMA + AIRLLM SETUP")
    print("="*70)
    print("\nThis script will configure nanobot to use Llama models with AirLLM.\n")

    if CONFIG_PATH.exists():
        print(f"Found existing config at: {CONFIG_PATH}")
        backup = input("\nCreate backup? (y/n): ").strip().lower()
        if backup == 'y':
            backup_path = CONFIG_PATH.with_suffix('.json.backup')
            # Imported lazily — only needed when the user asks for a backup.
            import shutil
            shutil.copy(CONFIG_PATH, backup_path)
            print(f"✓ Backup created: {backup_path}")
    else:
        print(f"Creating new config at: {CONFIG_PATH}")

    config = create_llama_config()
    save_config(config)

    print("\n" + "="*70)
    print("SETUP COMPLETE!")
    print("="*70)
    print("\nConfiguration:")
    print(f" Model: {config['agents']['defaults']['model']}")
    print(f" Provider: airllm")
    if config["providers"]["airllm"].get("extraHeaders", {}).get("hf_token"):
        # Never echo the token itself — show a masked placeholder.
        print(f" HF Token: {'*' * 20} (configured)")
    else:
        print(f" HF Token: Not configured (add it to use gated models)")

    print("\nNext steps:")
    print(" 1. If you need a Hugging Face token, follow the instructions above")
    print(" 2. Test it: nanobot agent -m 'Hello, what is 2+5?'")
    print("\n" + "="*70 + "\n")
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user