Clean up providers: keep only Ollama, AirLLM, vLLM, and DeepSeek

- Remove Qwen/DashScope provider and all Qwen-specific code
- Remove gateway providers (OpenRouter, AiHubMix)
- Remove cloud providers (Anthropic, OpenAI, Gemini, Zhipu, Moonshot, MiniMax, Groq)
- Update default model from Platypus to llama3.2
- Remove Platypus references throughout codebase
- Add AirLLM provider support with local model path support
- Update setup scripts to only show Llama models
- Clean up provider registry and config schema
This commit is contained in:
Tanya 2026-02-17 14:20:47 -05:00
parent dd63337a83
commit f1e95626f8
9 changed files with 1057 additions and 216 deletions

242
airllm_ollama_wrapper.py Normal file
View File

@ -0,0 +1,242 @@
#!/usr/bin/env python3
"""
AirLLM Ollama-Compatible Wrapper
This wrapper provides an Ollama-like interface for AirLLM,
making it easy to replace Ollama in existing projects.
"""
import torch
from typing import List, Dict, Optional, Union
# Try to import airllm. Some airllm releases eagerly import
# optimum.bettertransformer, which may be absent or broken; when that is the
# specific failure we inject a no-op stand-in module and retry the import.
try:
    from airllm import AutoModel
    AIRLLM_AVAILABLE = True  # True when AirLLM can actually be used
except ImportError as e:
    if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e):
        # Try to work around BetterTransformer import issue
        import sys
        import importlib.util

        # Create a dummy BetterTransformer module to allow airllm to import.
        # airllm only needs `transform` to exist; returning the model
        # unchanged keeps behavior identical.
        class DummyBetterTransformer:
            @staticmethod
            def transform(model):
                return model

        # Inject dummy module before importing airllm
        spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
        dummy_module = importlib.util.module_from_spec(spec)
        dummy_module.BetterTransformer = DummyBetterTransformer
        sys.modules["optimum.bettertransformer"] = dummy_module
        try:
            from airllm import AutoModel
            AIRLLM_AVAILABLE = True
        except ImportError:
            AIRLLM_AVAILABLE = False
            AutoModel = None
    else:
        # Unrelated import failure: report AirLLM as unavailable.
        AIRLLM_AVAILABLE = False
        AutoModel = None
class AirLLMOllamaWrapper:
    """
    A wrapper that provides an Ollama-like API for AirLLM.

    Usage:
        # Instead of: ollama.generate(model="llama2", prompt="Hello")
        # Use: airllm_wrapper.generate(model="llama2", prompt="Hello")
    """

    def __init__(self, model_name: str, compression: Optional[str] = None, **kwargs):
        """
        Initialize AirLLM model.

        Args:
            model_name: Hugging Face model name or path (e.g., "meta-llama/Llama-3.2-3B-Instruct")
            compression: Optional compression ('4bit' or '8bit') for 3x speed improvement
            **kwargs: Additional arguments for AutoModel.from_pretrained()

        Raises:
            ImportError: If AirLLM (or its optional dependencies) is not installed.
        """
        if not AIRLLM_AVAILABLE or AutoModel is None:
            raise ImportError(
                "AirLLM is not available. Please install it with: pip install airllm bitsandbytes\n"
                "If you see a BetterTransformer error, you may need to install: pip install optimum[bettertransformer]"
            )
        print(f"Loading AirLLM model: {model_name}")
        self.model = AutoModel.from_pretrained(
            model_name,
            compression=compression,
            **kwargs
        )
        self.model_name = model_name
        print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,  # Ignored, kept for API compatibility
        max_tokens: int = 50,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stream: bool = False,
        **kwargs
    ) -> Union[str, Dict]:
        """
        Generate text from a prompt (Ollama-compatible interface).

        Args:
            prompt: Input text prompt
            model: Ignored (kept for compatibility)
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0.0 to 1.0)
            top_p: Nucleus sampling parameter
            stream: If True, return a dict response (true streaming not yet implemented)
            **kwargs: Additional generation parameters

        Returns:
            Generated text string, or a dict with a "response" key when stream=True
        """
        # Tokenize input
        input_tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=512,  # Adjust as needed
            padding=False
        )
        # Move to GPU if available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        input_ids = input_tokens['input_ids'].to(device)
        # Prepare generation parameters
        gen_kwargs = {
            'max_new_tokens': max_tokens,
            'use_cache': True,
            'return_dict_in_generate': True,
            'temperature': temperature,
            'top_p': top_p,
            **kwargs
        }
        # Generate
        with torch.inference_mode():
            generation_output = self.model.generate(input_ids, **gen_kwargs)
        # FIX: decode only the newly generated tokens. generate() returns the
        # prompt tokens followed by the new tokens; decoding the whole sequence
        # and stripping with startswith(prompt) is unreliable because the
        # re-decoded prompt can differ from the original text (special tokens,
        # tokenizer round-trip artifacts).
        input_length = input_ids.shape[1]
        generated_ids = generation_output.sequences[0, input_length:]
        output = self.model.tokenizer.decode(generated_ids, skip_special_tokens=True)
        # Defensive fallback: strip the prompt if the tokenizer still echoed it.
        if output.startswith(prompt):
            output = output[len(prompt):].strip()
        if stream:
            # For streaming, return a dict (simplified version)
            return {"response": output}
        else:
            return output

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        max_tokens: int = 50,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Chat interface (Ollama-compatible).

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model: Ignored (kept for compatibility)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional parameters

        Returns:
            Generated response string
        """
        # Format messages into a prompt
        prompt = self._format_messages(messages)
        return self.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a single plain-text prompt.

        Unknown roles are silently skipped; a trailing "Assistant:" cue is
        always appended so the model continues as the assistant.
        """
        formatted = []
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                formatted.append(f"System: {content}")
            elif role == 'user':
                formatted.append(f"User: {content}")
            elif role == 'assistant':
                formatted.append(f"Assistant: {content}")
        return "\n".join(formatted) + "\nAssistant:"

    def embeddings(self, prompt: str) -> List[float]:
        """
        Get embeddings for a prompt (simplified - returns token ids, not
        real embeddings).

        Note: This is a placeholder. Actual embeddings would require a model
        forward pass over the hidden states.
        """
        tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False
        )
        # This is a placeholder - actual embeddings would require model forward pass
        return tokens['input_ids'].tolist()[0]
# Convenience function for easy migration
def create_ollama_client(model_name: str, compression: Optional[str] = None, **kwargs):
    """
    Create an Ollama-compatible client backed by AirLLM.

    Usage:
        client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")
        response = client.generate("Hello, how are you?")
    """
    wrapper = AirLLMOllamaWrapper(model_name, compression=compression, **kwargs)
    return wrapper
# Example usage (runs only when executed as a script)
if __name__ == "__main__":
    # Example 1: Basic generation
    print("Example 1: Basic Generation")
    print("=" * 60)
    # Initialize (this will take time on first run).
    # NOTE: the example previously referenced a Platypus model; Platypus
    # references were removed from the codebase, so mirror the default
    # Llama model instead.
    # client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")
    # Generate
    # response = client.generate("What is the capital of France?")
    # print(f"Response: {response}")
    print("\nExample 2: Chat Interface")
    print("=" * 60)
    # Chat example
    # messages = [
    #     {"role": "user", "content": "Hello! How are you?"}
    # ]
    # response = client.chat(messages)
    # print(f"Response: {response}")
    print("\nUncomment the code above to test!")

View File

@ -265,10 +265,60 @@ This file stores important information that should persist across sessions.
def _make_provider(config):
"""Create LiteLLMProvider from config. Exits if no API key found."""
from nanobot.providers.litellm_provider import LiteLLMProvider
"""Create LLM provider from config. Supports LiteLLMProvider and AirLLMProvider."""
provider_name = config.get_provider_name()
p = config.get_provider()
model = config.agents.defaults.model
# Check if AirLLM provider is requested
if provider_name == "airllm":
try:
from nanobot.providers.airllm_provider import AirLLMProvider
# AirLLM doesn't need API key, but we can use model path from config
# Check if model is specified in the airllm provider config
airllm_config = getattr(config.providers, "airllm", None)
model_path = None
compression = None
# Try to get model from airllm config's api_key field (repurposed as model path)
# or from the default model
if airllm_config and airllm_config.api_key:
# Check if api_key looks like a model path (contains '/') or is an HF token
if '/' in airllm_config.api_key:
model_path = airllm_config.api_key
hf_token = None
else:
# Treat as HF token, use model from defaults
model_path = model
hf_token = airllm_config.api_key
else:
model_path = model
hf_token = None
# Check for compression setting in extra_headers or api_base
if airllm_config:
if airllm_config.api_base:
compression = airllm_config.api_base # Repurpose api_base as compression
elif airllm_config.extra_headers and "compression" in airllm_config.extra_headers:
compression = airllm_config.extra_headers["compression"]
# Check for HF token in extra_headers
if not hf_token and airllm_config.extra_headers and "hf_token" in airllm_config.extra_headers:
hf_token = airllm_config.extra_headers["hf_token"]
return AirLLMProvider(
api_key=airllm_config.api_key if airllm_config else None,
api_base=compression if compression else None,
default_model=model_path,
compression=compression,
hf_token=hf_token,
)
except ImportError as e:
console.print(f"[red]Error: AirLLM provider not available: {e}[/red]")
console.print("Please ensure airllm_ollama_wrapper.py is in the Python path.")
raise typer.Exit(1)
# Default to LiteLLMProvider
from nanobot.providers.litellm_provider import LiteLLMProvider
if not (p and p.api_key) and not model.startswith("bedrock/"):
console.print("[red]Error: No API key configured.[/red]")
console.print("Set one in ~/.nanobot/config.json under providers section")
@ -278,7 +328,7 @@ def _make_provider(config):
api_base=config.get_api_base(),
default_model=model,
extra_headers=p.extra_headers if p else None,
provider_name=config.get_provider_name(),
provider_name=provider_name,
)

View File

@ -177,18 +177,10 @@ class ProviderConfig(BaseModel):
class ProvidersConfig(BaseModel):
"""Configuration for LLM providers."""
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
openai: ProviderConfig = Field(default_factory=ProviderConfig)
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
groq: ProviderConfig = Field(default_factory=ProviderConfig)
zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
dashscope: ProviderConfig = Field(default_factory=ProviderConfig) # 阿里云通义千问
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
ollama: ProviderConfig = Field(default_factory=ProviderConfig)
airllm: ProviderConfig = Field(default_factory=ProviderConfig)
class GatewayConfig(BaseModel):
@ -241,14 +233,37 @@ class Config(BaseSettings):
# Match by keyword (order follows PROVIDERS registry)
for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None)
if p and any(kw in model_lower for kw in spec.keywords) and p.api_key:
return p, spec.name
if p and any(kw in model_lower for kw in spec.keywords):
# For local providers (Ollama, AirLLM), allow empty api_key or "dummy"
# For other providers, require api_key
if spec.is_local:
# Local providers can work with empty/dummy api_key
if p.api_key or p.api_base or spec.name == "airllm":
return p, spec.name
elif p.api_key:
return p, spec.name
# Check local providers by api_base detection (for explicit config)
for spec in PROVIDERS:
if spec.is_local:
p = getattr(self.providers, spec.name, None)
if p:
# Check if api_base matches the provider's detection pattern
if spec.detect_by_base_keyword and p.api_base and spec.detect_by_base_keyword in p.api_base:
return p, spec.name
# AirLLM is detected by provider name being "airllm"
if spec.name == "airllm" and p.api_key: # api_key can be model path
return p, spec.name
# Fallback: gateways first, then others (follows registry order)
for spec in PROVIDERS:
p = getattr(self.providers, spec.name, None)
if p and p.api_key:
return p, spec.name
if p:
# For local providers, allow empty/dummy api_key
if spec.is_local and (p.api_key or p.api_base):
return p, spec.name
elif p.api_key:
return p, spec.name
return None, None
def get_provider(self, model: str | None = None) -> ProviderConfig | None:

View File

@ -3,4 +3,8 @@
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.litellm_provider import LiteLLMProvider
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider"]
try:
from nanobot.providers.airllm_provider import AirLLMProvider
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "AirLLMProvider"]
except ImportError:
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider"]

View File

@ -0,0 +1,181 @@
"""AirLLM provider implementation for direct local model inference."""
import json
import asyncio
from typing import Any
from pathlib import Path
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
# Import the wrapper - handle import errors gracefully
# Guarded import: the wrapper pulls in heavy optional dependencies, so a
# broken/missing AirLLM install must not break importing this provider module.
# The failure is recorded and surfaced later when the provider is actually used.
try:
    from nanobot.providers.airllm_wrapper import AirLLMOllamaWrapper, create_ollama_client
    AIRLLM_WRAPPER_AVAILABLE = True  # wrapper importable; provider is usable
    _import_error = None             # no error to report
except ImportError as e:
    AIRLLM_WRAPPER_AVAILABLE = False
    AirLLMOllamaWrapper = None
    create_ollama_client = None
    _import_error = str(e)           # shown to the user in _ensure_client()
class AirLLMProvider(LLMProvider):
    """
    LLM provider using AirLLM for direct local model inference.

    This provider loads models directly into memory and runs inference locally,
    bypassing HTTP API calls. It's optimized for GPU-limited environments.

    Because the shared ProviderConfig schema has no AirLLM-specific fields,
    two fields are repurposed:
      * ``api_key``  may carry a Hugging Face token OR a model path ("org/model").
      * ``api_base`` may carry the compression setting ('4bit' or '8bit').
    """

    def __init__(
        self,
        api_key: str | None = None,   # Repurposed: can be HF token or model name
        api_base: str | None = None,  # Repurposed: compression setting ('4bit' or '8bit')
        default_model: str = "meta-llama/Llama-3.2-3B-Instruct",
        compression: str | None = None,  # '4bit' or '8bit' for speed improvement
        model_path: str | None = None,   # Override default model
        hf_token: str | None = None,     # Hugging Face token for gated models
    ):
        super().__init__(api_key, api_base)
        self.default_model = model_path or default_model
        # If api_base is set and looks like compression, use it
        if api_base and api_base in ('4bit', '8bit'):
            self.compression = api_base
        else:
            self.compression = compression
        # Heuristic: a long api_key without '/' looks like an HF token; one
        # containing '/' is treated as a model path (handled just below).
        if api_key and '/' not in api_key and len(api_key) > 20:
            self.hf_token = api_key
        else:
            self.hf_token = hf_token
        # If api_key looks like a model path, use it as the model
        if api_key and '/' in api_key:
            self.default_model = api_key
        self._client: AirLLMOllamaWrapper | None = None
        self._model_loaded = False

    def _ensure_client(self) -> AirLLMOllamaWrapper:
        """Lazy-load the AirLLM client; model loading is expensive, so do it once."""
        if not AIRLLM_WRAPPER_AVAILABLE:
            error_msg = (
                "AirLLM wrapper is not available. Please ensure airllm_ollama_wrapper.py "
                "is in the Python path and AirLLM is installed."
            )
            # FIX: _import_error is always bound at module scope (None on
            # success), so test its value directly instead of probing globals().
            if _import_error:
                error_msg += f"\nImport error: {_import_error}"
            raise ImportError(error_msg)
        if self._client is None or not self._model_loaded:
            print(f"Initializing AirLLM with model: {self.default_model}")
            if self.compression:
                print(f"Using compression: {self.compression}")
            if self.hf_token:
                print("Using Hugging Face token for authentication")
            # Prepare kwargs for model loading
            kwargs = {}
            if self.hf_token:
                kwargs['hf_token'] = self.hf_token
            self._client = create_ollama_client(
                self.default_model,
                compression=self.compression,
                **kwargs
            )
            self._model_loaded = True
            print("AirLLM model loaded and ready!")
        return self._client

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """
        Send a chat completion request using AirLLM.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions (Note: tool calling support may be limited).
            model: Model identifier (ignored if different from initialized model).
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature.

        Returns:
            LLMResponse with content and/or tool calls.
        """
        # If a different model is requested, we'd need to reload (expensive).
        # For now, we'll use the initialized model.
        if model and model != self.default_model:
            print(f"Warning: Model {model} requested but {self.default_model} is loaded. Using loaded model.")
        client = self._ensure_client()
        # Format tools into the prompt if provided (basic tool support).
        # Note: Full tool calling requires model support and proper formatting.
        if tools:
            tools_text = "\n".join([
                f"- {tool.get('function', {}).get('name', 'unknown')}: {tool.get('function', {}).get('description', '')}"
                for tool in tools
            ])
            if messages and messages[-1].get('role') == 'user':
                # FIX: copy the list and the last message before appending so
                # the caller's message history is never mutated across turns.
                messages = list(messages)
                last = dict(messages[-1])
                last['content'] = last.get('content', '') + f"\n\nAvailable tools:\n{tools_text}"
                messages[-1] = last
        # Run the synchronous client in an executor to avoid blocking the loop.
        # FIX: get_running_loop() is the correct call inside a coroutine
        # (get_event_loop() here is deprecated since Python 3.10).
        loop = asyncio.get_running_loop()
        response_text = await loop.run_in_executor(
            None,
            lambda: client.chat(
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        )
        # Parse tool calls from response if present.
        # This is a simplified parser - you may need to adjust based on model output format.
        tool_calls = []
        content = response_text
        # Try to extract JSON tool calls from the response.
        # Some models return tool calls as JSON in the content.
        if "tool_calls" in response_text.lower() or "function" in response_text.lower():
            try:
                # Look for JSON blocks in the response
                import re
                json_pattern = r'\{[^{}]*"function"[^{}]*\}'
                matches = re.findall(json_pattern, response_text, re.DOTALL)
                for match in matches:
                    try:
                        tool_data = json.loads(match)
                        if "function" in tool_data:
                            func = tool_data["function"]
                            tool_calls.append(ToolCallRequest(
                                id=tool_data.get("id", f"call_{len(tool_calls)}"),
                                name=func.get("name", "unknown"),
                                arguments=func.get("arguments", {}),
                            ))
                            # Remove the tool call from content
                            content = content.replace(match, "").strip()
                    except json.JSONDecodeError:
                        pass
            except Exception:
                pass  # If parsing fails, just return the content as-is
        return LLMResponse(
            content=content,
            tool_calls=tool_calls if tool_calls else [],
            finish_reason="stop",
            usage={},  # AirLLM doesn't provide usage stats in the wrapper
        )

    def get_default_model(self) -> str:
        """Get the default model (HF id or local path)."""
        return self.default_model

View File

@ -0,0 +1,327 @@
#!/usr/bin/env python3
"""
AirLLM Ollama-Compatible Wrapper
This wrapper provides an Ollama-like interface for AirLLM,
making it easy to replace Ollama in existing projects.
"""
import torch
from typing import List, Dict, Optional, Union
# Try to import airllm. Some airllm releases eagerly import
# optimum.bettertransformer, which may be absent or broken; when that is the
# specific failure we inject a no-op stand-in module and retry the import.
try:
    from airllm import AutoModel
    AIRLLM_AVAILABLE = True  # True when AirLLM can actually be used
except ImportError as e:
    if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e):
        # Try to work around BetterTransformer import issue
        import sys
        import importlib.util

        # Create a dummy BetterTransformer module to allow airllm to import.
        # airllm only needs `transform` to exist; returning the model
        # unchanged keeps behavior identical.
        class DummyBetterTransformer:
            @staticmethod
            def transform(model):
                return model

        # Inject dummy module before importing airllm
        spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
        dummy_module = importlib.util.module_from_spec(spec)
        dummy_module.BetterTransformer = DummyBetterTransformer
        sys.modules["optimum.bettertransformer"] = dummy_module
        try:
            from airllm import AutoModel
            AIRLLM_AVAILABLE = True
        except ImportError:
            AIRLLM_AVAILABLE = False
            AutoModel = None
    else:
        # Unrelated import failure: report AirLLM as unavailable.
        AIRLLM_AVAILABLE = False
        AutoModel = None
class AirLLMOllamaWrapper:
    """
    A wrapper that provides an Ollama-like API for AirLLM.

    Usage:
        # Instead of: ollama.generate(model="llama2", prompt="Hello")
        # Use: airllm_wrapper.generate(model="llama2", prompt="Hello")
    """

    def __init__(self, model_name: str, compression: Optional[str] = None, **kwargs):
        """
        Initialize AirLLM model.

        Args:
            model_name: Hugging Face model name or path (e.g., "meta-llama/Llama-3.2-3B-Instruct")
            compression: Optional compression ('4bit' or '8bit') for 3x speed improvement
            **kwargs: Additional arguments for AutoModel.from_pretrained()

        Raises:
            ImportError: If AirLLM (or its optional dependencies) is not installed.
        """
        if not AIRLLM_AVAILABLE or AutoModel is None:
            raise ImportError(
                "AirLLM is not available. Please install it with: pip install airllm bitsandbytes\n"
                "If you see a BetterTransformer error, you may need to install: pip install optimum[bettertransformer]"
            )
        print(f"Loading AirLLM model: {model_name}")
        # AutoModel.from_pretrained() accepts:
        # - Hugging Face model IDs (e.g., "meta-llama/Llama-3.1-8B-Instruct")
        # - Local paths (e.g., "/path/to/local/model")
        # - Can use local_dir parameter for local models
        self.model = AutoModel.from_pretrained(
            model_name,
            compression=compression,
            **kwargs
        )
        self.model_name = model_name
        # Get the model's maximum sequence length for AirLLM.
        # IMPORTANT: AirLLM processes sequences in chunks, and each chunk must fit
        # within the model's position embedding limits.
        # Even if the model config says it supports longer sequences via rope scaling,
        # AirLLM's chunking mechanism requires the base size.
        self.max_length = 2048  # Default for Llama models
        # Check if this is a Llama model to determine appropriate max length.
        # Detection uses the config's model_type when present, falling back to
        # a substring match on the model name.
        is_llama = False
        if hasattr(self.model, 'config'):
            model_type = getattr(self.model.config, 'model_type', '').lower()
            is_llama = 'llama' in model_type or 'llama' in self.model_name.lower()
        if is_llama:
            # Llama models: typically support 2048-4096 tokens.
            # AirLLM works well with Llama, so we can use larger chunks.
            if hasattr(self.model, 'config'):
                config_max = getattr(self.model.config, 'max_position_embeddings', None)
                if config_max and config_max > 0:
                    # Use config value, but cap at 2048 for AirLLM safety
                    self.max_length = min(config_max, 2048)
                else:
                    self.max_length = 2048  # Safe default for Llama
        else:
            # For other models (e.g., DeepSeek), use conservative default
            if hasattr(self.model, 'config'):
                config_max = getattr(self.model.config, 'max_position_embeddings', None)
                if config_max and config_max > 0 and config_max <= 2048:
                    self.max_length = config_max
                else:
                    self.max_length = 512  # Very conservative
        print(f"Using sequence length limit: {self.max_length} (AirLLM chunk size)")
        print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,  # Ignored, kept for API compatibility
        max_tokens: int = 50,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stream: bool = False,
        **kwargs
    ) -> Union[str, Dict]:
        """
        Generate text from a prompt (Ollama-compatible interface).

        Args:
            prompt: Input text prompt
            model: Ignored (kept for compatibility)
            max_tokens: Maximum number of tokens to generate (capped at 512 internally)
            temperature: Sampling temperature (0.0 to 1.0)
            top_p: Nucleus sampling parameter
            stream: If True, return a dict response (true streaming not yet implemented)
            **kwargs: Additional generation parameters

        Returns:
            Generated text string, or a dict with a "response" key when stream=True
        """
        # Tokenize input with attention mask.
        # AirLLM processes sequences in chunks, but each chunk must fit within the model's
        # position embedding limits. We need to ensure we don't exceed the chunk size.
        # Use the model's max_length to ensure compatibility with position embeddings.
        input_tokens = self.model.tokenizer(
            prompt,
            return_tensors="pt",
            return_attention_mask=True,
            truncation=True,
            max_length=self.max_length,  # Respect model's position embedding limit
            padding=False
        )
        # Move to GPU if available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        input_ids = input_tokens['input_ids'].to(device)
        attention_mask = input_tokens.get('attention_mask', None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        # Ensure we don't exceed max_length (manual truncation as safety check,
        # in case the tokenizer's truncation was bypassed).
        seq_length = input_ids.shape[1]
        if seq_length > self.max_length:
            print(f"Warning: Sequence length ({seq_length}) exceeds limit ({self.max_length}), truncating...")
            input_ids = input_ids[:, :self.max_length]
            if attention_mask is not None:
                attention_mask = attention_mask[:, :self.max_length]
            seq_length = self.max_length
        if seq_length >= self.max_length:
            print(f"Note: Using sequence of {seq_length} tokens (at limit: {self.max_length})")
        # Prepare generation parameters.
        # Cap generation length at 512 new tokens regardless of the request.
        max_gen_tokens = min(max_tokens, 512)
        gen_kwargs = {
            'max_new_tokens': max_gen_tokens,
            'use_cache': True,
            'return_dict_in_generate': True,
            'temperature': temperature,
            'top_p': top_p,
            **kwargs
        }
        # Add attention mask if available
        if attention_mask is not None:
            gen_kwargs['attention_mask'] = attention_mask
        # Generate
        with torch.inference_mode():
            generation_output = self.model.generate(input_ids, **gen_kwargs)
        # Decode output - get only the newly generated tokens (the sequences
        # tensor contains the prompt tokens followed by the new tokens).
        if hasattr(generation_output, 'sequences'):
            # Extract only the new tokens (after input length)
            input_length = input_ids.shape[1]
            generated_ids = generation_output.sequences[0, input_length:]
            output = self.model.tokenizer.decode(generated_ids, skip_special_tokens=True)
        else:
            # Fallback for older output formats
            output = self.model.tokenizer.decode(generation_output.sequences[0], skip_special_tokens=True)
        # Remove the input prompt from output if present (defensive fallback)
        if output.startswith(prompt):
            output = output[len(prompt):].strip()
        if stream:
            # For streaming, return a generator (simplified version)
            return {"response": output}
        else:
            return output

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        max_tokens: int = 50,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Chat interface (Ollama-compatible).

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model: Ignored (kept for compatibility)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional parameters

        Returns:
            Generated response string
        """
        # Try to use the model's chat template if available (for Llama, etc.);
        # the native template usually yields much better responses than the
        # plain "Role: content" fallback.
        if hasattr(self.model.tokenizer, 'apply_chat_template') and self.model.tokenizer.chat_template:
            try:
                # Use the model's native chat template
                prompt = self.model.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            except Exception:
                # Fallback to simple formatting if chat template fails
                prompt = self._format_messages(messages)
        else:
            # Fallback to simple formatting
            prompt = self._format_messages(messages)
        return self.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a single prompt (fallback method).

        Unknown roles are silently skipped; a trailing "Assistant:" cue is
        always appended so the model continues as the assistant.
        """
        formatted = []
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                formatted.append(f"System: {content}")
            elif role == 'user':
                formatted.append(f"User: {content}")
            elif role == 'assistant':
                formatted.append(f"Assistant: {content}")
        return "\n".join(formatted) + "\nAssistant:"

    def embeddings(self, prompt: str) -> List[float]:
        """
        Get embeddings for a prompt (simplified - returns token ids, not
        real embeddings).

        Note: This is a simplified version. For full embeddings,
        you may need to access model internals.
        """
        tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False
        )
        # This is a placeholder - actual embeddings would require model forward pass
        return tokens['input_ids'].tolist()[0]
# Convenience function for easy migration
def create_ollama_client(model_name: str, compression: Optional[str] = None, **kwargs):
    """
    Build an Ollama-compatible client backed by AirLLM.

    Usage:
        client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")
        response = client.generate("Hello, how are you?")
    """
    return AirLLMOllamaWrapper(model_name=model_name, compression=compression, **kwargs)
# Example usage (runs only when executed as a script; the model-loading
# calls are left commented out because they download/load a full LLM)
if __name__ == "__main__":
    # Example 1: Basic generation
    print("Example 1: Basic Generation")
    print("=" * 60)
    # Initialize (this will take time on first run)
    # client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")
    # Generate
    # response = client.generate("What is the capital of France?")
    # print(f"Response: {response}")
    print("\nExample 2: Chat Interface")
    print("=" * 60)
    # Chat example
    # messages = [
    #     {"role": "user", "content": "Hello! How are you?"}
    # ]
    # response = client.chat(messages)
    # print(f"Response: {response}")
    print("\nUncomment the code above to test!")

View File

@ -127,6 +127,7 @@ class LiteLLMProvider(LLMProvider):
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
"stream": False, # Explicitly disable streaming to avoid hangs with some providers
}
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
@ -148,6 +149,11 @@ class LiteLLMProvider(LLMProvider):
kwargs["tools"] = tools
kwargs["tool_choice"] = "auto"
# Add timeout to prevent hangs (especially with local servers)
# Ollama can be slow with complex prompts, so use a longer timeout
# Increased to 400s for larger models like mistral-nemo
kwargs["timeout"] = 400.0
try:
response = await acompletion(**kwargs)
return self._parse_response(response)

View File

@ -6,7 +6,7 @@ Adding a new provider:
2. Add a field to ProvidersConfig in config/schema.py.
Done. Env vars, prefixing, config matching, status display all derive from here.
Order matters it controls match priority and fallback. Gateways first.
Order matters — it controls match priority and fallback.
Every entry writes out all fields so you can copy-paste as a template.
"""
@ -62,86 +62,10 @@ class ProviderSpec:
PROVIDERS: tuple[ProviderSpec, ...] = (
# === Gateways (detected by api_key / api_base, not model name) =========
# Gateways can route any model, so they win in fallback.
# OpenRouter: global gateway, keys start with "sk-or-"
ProviderSpec(
name="openrouter",
keywords=("openrouter",),
env_key="OPENROUTER_API_KEY",
display_name="OpenRouter",
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="sk-or-",
detect_by_base_keyword="openrouter",
default_api_base="https://openrouter.ai/api/v1",
strip_model_prefix=False,
model_overrides=(),
),
# AiHubMix: global gateway, OpenAI-compatible interface.
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
ProviderSpec(
name="aihubmix",
keywords=("aihubmix",),
env_key="OPENAI_API_KEY", # OpenAI-compatible
display_name="AiHubMix",
litellm_prefix="openai", # → openai/{model}
skip_prefixes=(),
env_extras=(),
is_gateway=True,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="aihubmix",
default_api_base="https://aihubmix.com/v1",
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
model_overrides=(),
),
# === Standard providers (matched by model-name keywords) ===============
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
ProviderSpec(
name="anthropic",
keywords=("anthropic", "claude"),
env_key="ANTHROPIC_API_KEY",
display_name="Anthropic",
litellm_prefix="",
skip_prefixes=(),
env_extras=(),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="",
strip_model_prefix=False,
model_overrides=(),
),
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
ProviderSpec(
name="openai",
keywords=("openai", "gpt"),
env_key="OPENAI_API_KEY",
display_name="OpenAI",
litellm_prefix="",
skip_prefixes=(),
env_extras=(),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="",
strip_model_prefix=False,
model_overrides=(),
),
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
# Can be used with local models or API.
ProviderSpec(
name="deepseek",
keywords=("deepseek",),
@ -159,107 +83,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
),
# Gemini: needs "gemini/" prefix for LiteLLM.
ProviderSpec(
name="gemini",
keywords=("gemini",),
env_key="GEMINI_API_KEY",
display_name="Gemini",
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
skip_prefixes=("gemini/",), # avoid double-prefix
env_extras=(),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="",
strip_model_prefix=False,
model_overrides=(),
),
# Zhipu: LiteLLM uses "zai/" prefix.
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
# skip_prefixes: don't add "zai/" when already routed via gateway.
ProviderSpec(
name="zhipu",
keywords=("zhipu", "glm", "zai"),
env_key="ZAI_API_KEY",
display_name="Zhipu AI",
litellm_prefix="zai", # glm-4 → zai/glm-4
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
env_extras=(
("ZHIPUAI_API_KEY", "{api_key}"),
),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="",
strip_model_prefix=False,
model_overrides=(),
),
# DashScope: Qwen models, needs "dashscope/" prefix.
ProviderSpec(
name="dashscope",
keywords=("qwen", "dashscope"),
env_key="DASHSCOPE_API_KEY",
display_name="DashScope",
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
skip_prefixes=("dashscope/", "openrouter/"),
env_extras=(),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="",
strip_model_prefix=False,
model_overrides=(),
),
# Moonshot: Kimi models, needs "moonshot/" prefix.
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
# Kimi K2.5 API enforces temperature >= 1.0.
ProviderSpec(
name="moonshot",
keywords=("moonshot", "kimi"),
env_key="MOONSHOT_API_KEY",
display_name="Moonshot",
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
skip_prefixes=("moonshot/", "openrouter/"),
env_extras=(
("MOONSHOT_API_BASE", "{api_base}"),
),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
strip_model_prefix=False,
model_overrides=(
("kimi-k2.5", {"temperature": 1.0}),
),
),
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
# Uses OpenAI-compatible API at api.minimax.io/v1.
ProviderSpec(
name="minimax",
keywords=("minimax",),
env_key="MINIMAX_API_KEY",
display_name="MiniMax",
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
skip_prefixes=("minimax/", "openrouter/"),
env_extras=(),
is_gateway=False,
is_local=False,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="https://api.minimax.io/v1",
strip_model_prefix=False,
model_overrides=(),
),
# === Local deployment (matched by config key, NOT by api_base) =========
# vLLM / any OpenAI-compatible local server.
@ -281,23 +104,44 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
model_overrides=(),
),
# === Auxiliary (not a primary LLM provider) ============================
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
# Ollama: local OpenAI-compatible server.
# Use OpenAI-compatible endpoint, not native Ollama API.
# Detected when config key is "ollama" or api_base contains "11434" or "ollama".
ProviderSpec(
name="groq",
keywords=("groq",),
env_key="GROQ_API_KEY",
display_name="Groq",
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
skip_prefixes=("groq/",), # avoid double-prefix
name="ollama",
keywords=("ollama", "llama"), # Match both "ollama" and "llama" model names
env_key="OPENAI_API_KEY", # Use OpenAI-compatible API
display_name="Ollama",
litellm_prefix="", # No prefix - use as OpenAI-compatible
skip_prefixes=(),
env_extras=(
("OPENAI_API_BASE", "{api_base}"), # Set OpenAI API base to Ollama endpoint
),
is_gateway=False,
is_local=True,
detect_by_key_prefix="",
detect_by_base_keyword="11434", # Detect by default Ollama port
default_api_base="http://localhost:11434/v1",
strip_model_prefix=False,
model_overrides=(),
),
# AirLLM: direct local model inference (no HTTP server).
# Loads models directly into memory for GPU-optimized inference.
# Detected when config key is "airllm".
ProviderSpec(
name="airllm",
keywords=("airllm",),
env_key="", # No API key needed (local)
display_name="AirLLM",
litellm_prefix="", # Not used with LiteLLM
skip_prefixes=(),
env_extras=(),
is_gateway=False,
is_local=False,
is_local=True,
detect_by_key_prefix="",
detect_by_base_keyword="",
default_api_base="",
default_api_base="", # Not used (direct Python calls)
strip_model_prefix=False,
model_overrides=(),
),
@ -325,12 +169,11 @@ def find_gateway(
api_key: str | None = None,
api_base: str | None = None,
) -> ProviderSpec | None:
"""Detect gateway/local provider.
"""Detect local provider.
Priority:
1. provider_name if it maps to a gateway/local spec, use it directly.
2. api_key prefix e.g. "sk-or-" OpenRouter.
3. api_base keyword e.g. "aihubmix" in URL AiHubMix.
1. provider_name if it maps to a local spec, use it directly.
2. api_base keyword e.g. "11434" in URL Ollama.
A standard provider with a custom api_base (e.g. DeepSeek behind a proxy)
will NOT be mistaken for vLLM the old fallback is gone.
@ -341,10 +184,8 @@ def find_gateway(
if spec and (spec.is_gateway or spec.is_local):
return spec
# 2. Auto-detect by api_key prefix / api_base keyword
# 2. Auto-detect by api_base keyword
for spec in PROVIDERS:
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
return spec
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
return spec

175
setup_llama_airllm.py Normal file
View File

@ -0,0 +1,175 @@
#!/usr/bin/env python3
"""
Setup script to configure nanobot to use Llama models with AirLLM.
This script will:
1. Check/create the config file
2. Set up Llama model configuration
3. Guide you through getting a Hugging Face token if needed
"""
import json
import os
from pathlib import Path
# Location of the shared nanobot JSON config; the main application reads
# the same file, so this script only adds/updates keys, never replaces it.
CONFIG_PATH = Path.home() / ".nanobot" / "config.json"
def get_hf_token_instructions():
    """Print step-by-step instructions for obtaining a Hugging Face token.

    Llama checkpoints on the Hub are gated, so the user needs both a
    read-scoped token and an accepted model license before downloads work.
    """
    bar = "=" * 70
    instructions = [
        "",
        bar,
        "GETTING A HUGGING FACE TOKEN",
        bar,
        "",
        "To use Llama models (which are gated), you need a Hugging Face token:",
        "",
        "1. Go to: https://huggingface.co/settings/tokens",
        "2. Click 'New token'",
        "3. Give it a name (e.g., 'nanobot')",
        "4. Select 'Read' permission",
        "5. Click 'Generate token'",
        "6. Copy the token (starts with 'hf_...')",
        "",
        "Then accept the Llama model license:",
        "1. Go to: https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct",
        "2. Click 'Agree and access repository'",
        "3. Accept the license terms",
        "",
        bar,
        "",
    ]
    # One print of the joined lines emits exactly the same bytes as the
    # original sequence of print() calls.
    print("\n".join(instructions))
def load_existing_config():
    """Load the existing nanobot config as a dict.

    Returns:
        The parsed config mapping, or an empty dict when the file is
        missing, unreadable, invalid JSON, or not a JSON object — so
        callers can always treat the result as a mutable mapping.
    """
    if not CONFIG_PATH.exists():
        return {}
    try:
        config = json.loads(CONFIG_PATH.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError) as e:
        # Best-effort: continue setup with defaults instead of aborting.
        print(f"Warning: Could not read existing config: {e}")
        return {}
    if not isinstance(config, dict):
        # A non-object top level (e.g. a JSON list) would break callers
        # that index into config["providers"] / config["agents"].
        print("Warning: Existing config is not a JSON object; ignoring it.")
        return {}
    return config
def create_llama_config():
    """Interactively build or update the nanobot config for Llama + AirLLM.

    Prompts the user for a Llama model choice and (optionally) a Hugging
    Face token, merges the answers into any existing config, and returns
    the resulting dict. Nothing is written to disk here; the caller is
    expected to persist the result via save_config().

    Returns:
        dict: the merged config with providers.airllm and
        agents.defaults.model populated.
    """
    config = load_existing_config()
    # Ensure providers section exists
    if "providers" not in config:
        config["providers"] = {}
    # Ensure agents section exists
    if "agents" not in config:
        config["agents"] = {}
    if "defaults" not in config["agents"]:
        config["agents"]["defaults"] = {}
    # Choose Llama model
    print("\n" + "="*70)
    print("CHOOSE LLAMA MODEL")
    print("="*70)
    print("\nAvailable models:")
    print(" 1. Llama-3.2-3B-Instruct (Recommended - fast, minimal memory)")
    print(" 2. Llama-3.1-8B-Instruct (Good balance of performance and speed)")
    print(" 3. Custom (enter model path)")
    # Empty input falls through to "1" (the recommended default).
    choice = input("\nChoose model (1-3, default: 1): ").strip() or "1"
    model_map = {
        "1": "meta-llama/Llama-3.2-3B-Instruct",
        "2": "meta-llama/Llama-3.1-8B-Instruct",
    }
    if choice == "3":
        model_path = input("Enter model path (e.g., meta-llama/Llama-3.2-3B-Instruct): ").strip()
        if not model_path:
            model_path = "meta-llama/Llama-3.2-3B-Instruct"
            print(f"Using default: {model_path}")
    else:
        # Any unrecognized numeric choice also falls back to the default model.
        model_path = model_map.get(choice, "meta-llama/Llama-3.2-3B-Instruct")
    # Set up AirLLM provider with Llama model
    # Note: apiKey can be used as model path, or we can put model in defaults
    config["providers"]["airllm"] = {
        "apiKey": "",  # Will be set to model path
        "apiBase": None,
        "extraHeaders": {}
    }
    # Set default model
    config["agents"]["defaults"]["model"] = model_path
    # Ask for Hugging Face token
    print("\n" + "="*70)
    print("HUGGING FACE TOKEN SETUP")
    print("="*70)
    print("\nDo you have a Hugging Face token? (Required for Llama models)")
    print("If not, we'll show you how to get one.\n")
    has_token = input("Do you have a Hugging Face token? (y/n): ").strip().lower()
    if has_token == 'y':
        hf_token = input("\nEnter your Hugging Face token (starts with 'hf_'): ").strip()
        if hf_token and hf_token.startswith('hf_'):
            # Store token in extraHeaders
            config["providers"]["airllm"]["extraHeaders"]["hf_token"] = hf_token
            # Also set apiKey to model path (AirLLM uses apiKey as model path if it contains '/')
            config["providers"]["airllm"]["apiKey"] = config["agents"]["defaults"]["model"]
            print("\n✓ Token configured!")
        else:
            # Malformed token: warn but keep going — the config is still
            # usable for ungated models, and the token can be added later.
            print("⚠ Warning: Token doesn't look valid (should start with 'hf_')")
            print("You can add it later by editing the config file.")
            # Still set model path in apiKey
            config["providers"]["airllm"]["apiKey"] = config["agents"]["defaults"]["model"]
    else:
        # No token yet: show how to obtain one and where to put it.
        # NOTE(review): in this branch apiKey is left empty — presumably
        # intentional until a token exists; confirm against the AirLLM
        # provider's model-path handling.
        get_hf_token_instructions()
        print("\nYou can add your token later by:")
        print(f"1. Editing: {CONFIG_PATH}")
        print("2. Adding your token to: providers.airllm.extraHeaders.hf_token")
        print("\nOr run this script again after getting your token.")
    return config
def save_config(config):
    """Write *config* to CONFIG_PATH as pretty-printed JSON.

    Creates the parent directory if needed, then restricts the file to
    owner read/write (0o600) because it may hold a Hugging Face token.

    Args:
        config: JSON-serializable mapping to persist.
    """
    CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(CONFIG_PATH, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)
    # Set secure permissions (token is a secret; chmod after write is the
    # original behavior — the file briefly exists with default perms).
    os.chmod(CONFIG_PATH, 0o600)
    print(f"\n✓ Configuration saved to: {CONFIG_PATH}")
    # Fixed: was an f-string with no placeholders (ruff F541).
    print("✓ File permissions set to 600 (read/write for owner only)")
def main():
    """Run the interactive setup: optional backup, prompts, save, summary.

    Flow: offer a backup when a config already exists, delegate the
    question/answer flow to create_llama_config(), persist the result via
    save_config(), then print a summary (the token itself is masked).
    """
    print("\n" + "="*70)
    print("NANOBOT LLAMA + AIRLLM SETUP")
    print("="*70)
    print("\nThis script will configure nanobot to use Llama models with AirLLM.\n")
    if CONFIG_PATH.exists():
        print(f"Found existing config at: {CONFIG_PATH}")
        backup = input("\nCreate backup? (y/n): ").strip().lower()
        if backup == 'y':
            # Backup lands next to the original as config.json.backup.
            backup_path = CONFIG_PATH.with_suffix('.json.backup')
            import shutil
            shutil.copy(CONFIG_PATH, backup_path)
            print(f"✓ Backup created: {backup_path}")
    else:
        print(f"Creating new config at: {CONFIG_PATH}")
    config = create_llama_config()
    save_config(config)
    print("\n" + "="*70)
    print("SETUP COMPLETE!")
    print("="*70)
    print("\nConfiguration:")
    print(f" Model: {config['agents']['defaults']['model']}")
    print(f" Provider: airllm")
    # Never echo the token back; only report whether one was stored.
    if config["providers"]["airllm"].get("extraHeaders", {}).get("hf_token"):
        print(f" HF Token: {'*' * 20} (configured)")
    else:
        print(f" HF Token: Not configured (add it to use gated models)")
    print("\nNext steps:")
    print(" 1. If you need a Hugging Face token, follow the instructions above")
    print(" 2. Test it: nanobot agent -m 'Hello, what is 2+5?'")
    print("\n" + "="*70 + "\n")
# Entry point when run directly: `python setup_llama_airllm.py`.
if __name__ == "__main__":
    main()