From f1e95626f83821ca6717ef3db505fa01ac9114e4 Mon Sep 17 00:00:00 2001 From: Tanya Date: Tue, 17 Feb 2026 14:20:47 -0500 Subject: [PATCH] Clean up providers: keep only Ollama, AirLLM, vLLM, and DeepSeek - Remove Qwen/DashScope provider and all Qwen-specific code - Remove gateway providers (OpenRouter, AiHubMix) - Remove cloud providers (Anthropic, OpenAI, Gemini, Zhipu, Moonshot, MiniMax, Groq) - Update default model from Platypus to llama3.2 - Remove Platypus references throughout codebase - Add AirLLM provider support with local model path support - Update setup scripts to only show Llama models - Clean up provider registry and config schema --- airllm_ollama_wrapper.py | 242 +++++++++++++++++++ nanobot/cli/commands.py | 56 ++++- nanobot/config/schema.py | 43 ++-- nanobot/providers/__init__.py | 6 +- nanobot/providers/airllm_provider.py | 181 ++++++++++++++ nanobot/providers/airllm_wrapper.py | 327 ++++++++++++++++++++++++++ nanobot/providers/litellm_provider.py | 6 + nanobot/providers/registry.py | 237 +++---------------- setup_llama_airllm.py | 175 ++++++++++++++ 9 files changed, 1057 insertions(+), 216 deletions(-) create mode 100644 airllm_ollama_wrapper.py create mode 100644 nanobot/providers/airllm_provider.py create mode 100644 nanobot/providers/airllm_wrapper.py create mode 100644 setup_llama_airllm.py diff --git a/airllm_ollama_wrapper.py b/airllm_ollama_wrapper.py new file mode 100644 index 0000000..cdd186a --- /dev/null +++ b/airllm_ollama_wrapper.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +""" +AirLLM Ollama-Compatible Wrapper + +This wrapper provides an Ollama-like interface for AirLLM, +making it easy to replace Ollama in existing projects. 
+""" + +import torch +from typing import List, Dict, Optional, Union + +# Try to import airllm, handle BetterTransformer import error gracefully +try: + from airllm import AutoModel + AIRLLM_AVAILABLE = True +except ImportError as e: + if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e): + # Try to work around BetterTransformer import issue + import sys + import importlib.util + + # Create a dummy BetterTransformer module to allow airllm to import + class DummyBetterTransformer: + @staticmethod + def transform(model): + return model + + # Inject dummy module before importing airllm + spec = importlib.util.spec_from_loader("optimum.bettertransformer", None) + dummy_module = importlib.util.module_from_spec(spec) + dummy_module.BetterTransformer = DummyBetterTransformer + sys.modules["optimum.bettertransformer"] = dummy_module + + try: + from airllm import AutoModel + AIRLLM_AVAILABLE = True + except ImportError: + AIRLLM_AVAILABLE = False + AutoModel = None + else: + AIRLLM_AVAILABLE = False + AutoModel = None + + +class AirLLMOllamaWrapper: + """ + A wrapper that provides an Ollama-like API for AirLLM. + + Usage: + # Instead of: ollama.generate(model="llama2", prompt="Hello") + # Use: airllm_wrapper.generate(model="llama2", prompt="Hello") + """ + + def __init__(self, model_name: str, compression: Optional[str] = None, **kwargs): + """ + Initialize AirLLM model. + + Args: + model_name: Hugging Face model name or path (e.g., "meta-llama/Llama-3.2-3B-Instruct") + compression: Optional compression ('4bit' or '8bit') for 3x speed improvement + **kwargs: Additional arguments for AutoModel.from_pretrained() + """ + if not AIRLLM_AVAILABLE or AutoModel is None: + raise ImportError( + "AirLLM is not available. 
Please install it with: pip install airllm bitsandbytes\n" + "If you see a BetterTransformer error, you may need to install: pip install optimum[bettertransformer]" + ) + + print(f"Loading AirLLM model: {model_name}") + self.model = AutoModel.from_pretrained( + model_name, + compression=compression, + **kwargs + ) + self.model_name = model_name + print("Model loaded successfully!") + + def generate( + self, + prompt: str, + model: Optional[str] = None, # Ignored, kept for API compatibility + max_tokens: int = 50, + temperature: float = 0.7, + top_p: float = 0.9, + stream: bool = False, + **kwargs + ) -> Union[str, Dict]: + """ + Generate text from a prompt (Ollama-compatible interface). + + Args: + prompt: Input text prompt + model: Ignored (kept for compatibility) + max_tokens: Maximum number of tokens to generate + temperature: Sampling temperature (0.0 to 1.0) + top_p: Nucleus sampling parameter + stream: If True, return streaming response (not yet implemented) + **kwargs: Additional generation parameters + + Returns: + Generated text string or dict with response + """ + # Tokenize input + input_tokens = self.model.tokenizer( + [prompt], + return_tensors="pt", + return_attention_mask=False, + truncation=True, + max_length=512, # Adjust as needed + padding=False + ) + + # Move to GPU if available + device = 'cuda' if torch.cuda.is_available() else 'cpu' + input_ids = input_tokens['input_ids'].to(device) + + # Prepare generation parameters + gen_kwargs = { + 'max_new_tokens': max_tokens, + 'use_cache': True, + 'return_dict_in_generate': True, + 'temperature': temperature, + 'top_p': top_p, + **kwargs + } + + # Generate + with torch.inference_mode(): + generation_output = self.model.generate(input_ids, **gen_kwargs) + + # Decode output + output = self.model.tokenizer.decode(generation_output.sequences[0]) + + # Remove the input prompt from output (if present) + if output.startswith(prompt): + output = output[len(prompt):].strip() + + if stream: + # For streaming, 
return a generator (simplified version) + return {"response": output} + else: + return output + + def chat( + self, + messages: List[Dict[str, str]], + model: Optional[str] = None, + max_tokens: int = 50, + temperature: float = 0.7, + **kwargs + ) -> str: + """ + Chat interface (Ollama-compatible). + + Args: + messages: List of message dicts with 'role' and 'content' keys + model: Ignored (kept for compatibility) + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + **kwargs: Additional parameters + + Returns: + Generated response string + """ + # Format messages into a prompt + prompt = self._format_messages(messages) + return self.generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + **kwargs + ) + + def _format_messages(self, messages: List[Dict[str, str]]) -> str: + """Format chat messages into a single prompt.""" + formatted = [] + for msg in messages: + role = msg.get('role', 'user') + content = msg.get('content', '') + if role == 'system': + formatted.append(f"System: {content}") + elif role == 'user': + formatted.append(f"User: {content}") + elif role == 'assistant': + formatted.append(f"Assistant: {content}") + return "\n".join(formatted) + "\nAssistant:" + + def embeddings(self, prompt: str) -> List[float]: + """ + Get embeddings for a prompt (simplified - returns token embeddings). + + Note: This is a simplified version. For full embeddings, + you may need to access model internals. + """ + tokens = self.model.tokenizer( + [prompt], + return_tensors="pt", + truncation=True, + max_length=512, + padding=False + ) + # This is a placeholder - actual embeddings would require model forward pass + return tokens['input_ids'].tolist()[0] + + +# Convenience function for easy migration +def create_ollama_client(model_name: str, compression: Optional[str] = None, **kwargs): + """ + Create an Ollama-compatible client using AirLLM. 
+ + Usage: + client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct") + response = client.generate("Hello, how are you?") + """ + return AirLLMOllamaWrapper(model_name, compression=compression, **kwargs) + + +# Example usage +if __name__ == "__main__": + # Example 1: Basic generation + print("Example 1: Basic Generation") + print("=" * 60) + + # Initialize (this will take time on first run) + # client = create_ollama_client("garage-bAInd/Platypus2-70B-instruct") + + # Generate + # response = client.generate("What is the capital of France?") + # print(f"Response: {response}") + + print("\nExample 2: Chat Interface") + print("=" * 60) + + # Chat example + # messages = [ + # {"role": "user", "content": "Hello! How are you?"} + # ] + # response = client.chat(messages) + # print(f"Response: {response}") + + print("\nUncomment the code above to test!") + diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index aa99d55..ee98536 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -265,10 +265,60 @@ This file stores important information that should persist across sessions. def _make_provider(config): - """Create LiteLLMProvider from config. Exits if no API key found.""" - from nanobot.providers.litellm_provider import LiteLLMProvider + """Create LLM provider from config. 
Supports LiteLLMProvider and AirLLMProvider.""" + provider_name = config.get_provider_name() p = config.get_provider() model = config.agents.defaults.model + + # Check if AirLLM provider is requested + if provider_name == "airllm": + try: + from nanobot.providers.airllm_provider import AirLLMProvider + # AirLLM doesn't need API key, but we can use model path from config + # Check if model is specified in the airllm provider config + airllm_config = getattr(config.providers, "airllm", None) + model_path = None + compression = None + + # Try to get model from airllm config's api_key field (repurposed as model path) + # or from the default model + if airllm_config and airllm_config.api_key: + # Check if api_key looks like a model path (contains '/') or is an HF token + if '/' in airllm_config.api_key: + model_path = airllm_config.api_key + hf_token = None + else: + # Treat as HF token, use model from defaults + model_path = model + hf_token = airllm_config.api_key + else: + model_path = model + hf_token = None + + # Check for compression setting in extra_headers or api_base + if airllm_config: + if airllm_config.api_base: + compression = airllm_config.api_base # Repurpose api_base as compression + elif airllm_config.extra_headers and "compression" in airllm_config.extra_headers: + compression = airllm_config.extra_headers["compression"] + # Check for HF token in extra_headers + if not hf_token and airllm_config.extra_headers and "hf_token" in airllm_config.extra_headers: + hf_token = airllm_config.extra_headers["hf_token"] + + return AirLLMProvider( + api_key=airllm_config.api_key if airllm_config else None, + api_base=compression if compression else None, + default_model=model_path, + compression=compression, + hf_token=hf_token, + ) + except ImportError as e: + console.print(f"[red]Error: AirLLM provider not available: {e}[/red]") + console.print("Please ensure airllm_ollama_wrapper.py is in the Python path.") + raise typer.Exit(1) + + # Default to LiteLLMProvider + 
from nanobot.providers.litellm_provider import LiteLLMProvider if not (p and p.api_key) and not model.startswith("bedrock/"): console.print("[red]Error: No API key configured.[/red]") console.print("Set one in ~/.nanobot/config.json under providers section") @@ -278,7 +328,7 @@ def _make_provider(config): api_base=config.get_api_base(), default_model=model, extra_headers=p.extra_headers if p else None, - provider_name=config.get_provider_name(), + provider_name=provider_name, ) diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 19feba4..797b331 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -177,18 +177,10 @@ class ProviderConfig(BaseModel): class ProvidersConfig(BaseModel): """Configuration for LLM providers.""" - anthropic: ProviderConfig = Field(default_factory=ProviderConfig) - openai: ProviderConfig = Field(default_factory=ProviderConfig) - openrouter: ProviderConfig = Field(default_factory=ProviderConfig) deepseek: ProviderConfig = Field(default_factory=ProviderConfig) - groq: ProviderConfig = Field(default_factory=ProviderConfig) - zhipu: ProviderConfig = Field(default_factory=ProviderConfig) - dashscope: ProviderConfig = Field(default_factory=ProviderConfig) # 阿里云通义千问 vllm: ProviderConfig = Field(default_factory=ProviderConfig) - gemini: ProviderConfig = Field(default_factory=ProviderConfig) - moonshot: ProviderConfig = Field(default_factory=ProviderConfig) - minimax: ProviderConfig = Field(default_factory=ProviderConfig) - aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway + ollama: ProviderConfig = Field(default_factory=ProviderConfig) + airllm: ProviderConfig = Field(default_factory=ProviderConfig) class GatewayConfig(BaseModel): @@ -241,14 +233,37 @@ class Config(BaseSettings): # Match by keyword (order follows PROVIDERS registry) for spec in PROVIDERS: p = getattr(self.providers, spec.name, None) - if p and any(kw in model_lower for kw in spec.keywords) and p.api_key: - 
return p, spec.name + if p and any(kw in model_lower for kw in spec.keywords): + # For local providers (Ollama, AirLLM), allow empty api_key or "dummy" + # For other providers, require api_key + if spec.is_local: + # Local providers can work with empty/dummy api_key + if p.api_key or p.api_base or spec.name == "airllm": + return p, spec.name + elif p.api_key: + return p, spec.name + + # Check local providers by api_base detection (for explicit config) + for spec in PROVIDERS: + if spec.is_local: + p = getattr(self.providers, spec.name, None) + if p: + # Check if api_base matches the provider's detection pattern + if spec.detect_by_base_keyword and p.api_base and spec.detect_by_base_keyword in p.api_base: + return p, spec.name + # AirLLM is detected by provider name being "airllm" + if spec.name == "airllm" and p.api_key: # api_key can be model path + return p, spec.name # Fallback: gateways first, then others (follows registry order) for spec in PROVIDERS: p = getattr(self.providers, spec.name, None) - if p and p.api_key: - return p, spec.name + if p: + # For local providers, allow empty/dummy api_key + if spec.is_local and (p.api_key or p.api_base): + return p, spec.name + elif p.api_key: + return p, spec.name return None, None def get_provider(self, model: str | None = None) -> ProviderConfig | None: diff --git a/nanobot/providers/__init__.py b/nanobot/providers/__init__.py index ceff8fa..d2bdb18 100644 --- a/nanobot/providers/__init__.py +++ b/nanobot/providers/__init__.py @@ -3,4 +3,8 @@ from nanobot.providers.base import LLMProvider, LLMResponse from nanobot.providers.litellm_provider import LiteLLMProvider -__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider"] +try: + from nanobot.providers.airllm_provider import AirLLMProvider + __all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "AirLLMProvider"] +except ImportError: + __all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider"] diff --git a/nanobot/providers/airllm_provider.py 
b/nanobot/providers/airllm_provider.py new file mode 100644 index 0000000..e57aa6b --- /dev/null +++ b/nanobot/providers/airllm_provider.py @@ -0,0 +1,181 @@ +"""AirLLM provider implementation for direct local model inference.""" + +import json +import asyncio +from typing import Any +from pathlib import Path + +from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest + +# Import the wrapper - handle import errors gracefully +try: + from nanobot.providers.airllm_wrapper import AirLLMOllamaWrapper, create_ollama_client + AIRLLM_WRAPPER_AVAILABLE = True + _import_error = None +except ImportError as e: + AIRLLM_WRAPPER_AVAILABLE = False + AirLLMOllamaWrapper = None + create_ollama_client = None + _import_error = str(e) + + +class AirLLMProvider(LLMProvider): + """ + LLM provider using AirLLM for direct local model inference. + + This provider loads models directly into memory and runs inference locally, + bypassing HTTP API calls. It's optimized for GPU-limited environments. 
+ """ + + def __init__( + self, + api_key: str | None = None, # Repurposed: can be HF token or model name + api_base: str | None = None, # Repurposed: compression setting ('4bit' or '8bit') + default_model: str = "meta-llama/Llama-3.2-3B-Instruct", + compression: str | None = None, # '4bit' or '8bit' for speed improvement + model_path: str | None = None, # Override default model + hf_token: str | None = None, # Hugging Face token for gated models + ): + super().__init__(api_key, api_base) + self.default_model = model_path or default_model + # If api_base is set and looks like compression, use it + if api_base and api_base in ('4bit', '8bit'): + self.compression = api_base + else: + self.compression = compression + # If api_key is provided and doesn't look like a model path, treat as HF token + if api_key and '/' not in api_key and len(api_key) > 20: + self.hf_token = api_key + else: + self.hf_token = hf_token + # If api_key looks like a model path, use it as the model + if api_key and '/' in api_key: + self.default_model = api_key + self._client: AirLLMOllamaWrapper | None = None + self._model_loaded = False + + def _ensure_client(self) -> AirLLMOllamaWrapper: + """Lazy-load the AirLLM client.""" + if not AIRLLM_WRAPPER_AVAILABLE: + error_msg = ( + "AirLLM wrapper is not available. Please ensure airllm_ollama_wrapper.py " + "is in the Python path and AirLLM is installed." 
+ ) + if '_import_error' in globals(): + error_msg += f"\nImport error: {_import_error}" + raise ImportError(error_msg) + + if self._client is None or not self._model_loaded: + print(f"Initializing AirLLM with model: {self.default_model}") + if self.compression: + print(f"Using compression: {self.compression}") + if self.hf_token: + print("Using Hugging Face token for authentication") + + # Prepare kwargs for model loading + kwargs = {} + if self.hf_token: + kwargs['hf_token'] = self.hf_token + + self._client = create_ollama_client( + self.default_model, + compression=self.compression, + **kwargs + ) + self._model_loaded = True + print("AirLLM model loaded and ready!") + + return self._client + + async def chat( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + ) -> LLMResponse: + """ + Send a chat completion request using AirLLM. + + Args: + messages: List of message dicts with 'role' and 'content'. + tools: Optional list of tool definitions (Note: tool calling support may be limited). + model: Model identifier (ignored if different from initialized model). + max_tokens: Maximum tokens in response. + temperature: Sampling temperature. + + Returns: + LLMResponse with content and/or tool calls. + """ + # If a different model is requested, we'd need to reload (expensive) + # For now, we'll use the initialized model + if model and model != self.default_model: + print(f"Warning: Model {model} requested but {self.default_model} is loaded. 
Using loaded model.") + + client = self._ensure_client() + + # Format tools into the prompt if provided (basic tool support) + # Note: Full tool calling requires model support and proper formatting + if tools: + # Add tool definitions to the system message or last user message + tools_text = "\n".join([ + f"- {tool.get('function', {}).get('name', 'unknown')}: {tool.get('function', {}).get('description', '')}" + for tool in tools + ]) + # Append to messages (simplified - full implementation would format properly) + if messages and messages[-1].get('role') == 'user': + messages[-1]['content'] += f"\n\nAvailable tools:\n{tools_text}" + + # Run the synchronous client in an executor to avoid blocking + loop = asyncio.get_event_loop() + response_text = await loop.run_in_executor( + None, + lambda: client.chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + ) + ) + + # Parse tool calls from response if present + # This is a simplified parser - you may need to adjust based on model output format + tool_calls = [] + content = response_text + + # Try to extract JSON tool calls from the response + # Some models return tool calls as JSON in the content + if "tool_calls" in response_text.lower() or "function" in response_text.lower(): + try: + # Look for JSON blocks in the response + import re + json_pattern = r'\{[^{}]*"function"[^{}]*\}' + matches = re.findall(json_pattern, response_text, re.DOTALL) + for match in matches: + try: + tool_data = json.loads(match) + if "function" in tool_data: + func = tool_data["function"] + tool_calls.append(ToolCallRequest( + id=tool_data.get("id", f"call_{len(tool_calls)}"), + name=func.get("name", "unknown"), + arguments=func.get("arguments", {}), + )) + # Remove the tool call from content + content = content.replace(match, "").strip() + except json.JSONDecodeError: + pass + except Exception: + pass # If parsing fails, just return the content as-is + + return LLMResponse( + content=content, + 
tool_calls=tool_calls if tool_calls else [], + finish_reason="stop", + usage={}, # AirLLM doesn't provide usage stats in the wrapper + ) + + def get_default_model(self) -> str: + """Get the default model.""" + return self.default_model + diff --git a/nanobot/providers/airllm_wrapper.py b/nanobot/providers/airllm_wrapper.py new file mode 100644 index 0000000..e2e2723 --- /dev/null +++ b/nanobot/providers/airllm_wrapper.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 +""" +AirLLM Ollama-Compatible Wrapper + +This wrapper provides an Ollama-like interface for AirLLM, +making it easy to replace Ollama in existing projects. +""" + +import torch +from typing import List, Dict, Optional, Union + +# Try to import airllm, handle BetterTransformer import error gracefully +try: + from airllm import AutoModel + AIRLLM_AVAILABLE = True +except ImportError as e: + if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e): + # Try to work around BetterTransformer import issue + import sys + import importlib.util + + # Create a dummy BetterTransformer module to allow airllm to import + class DummyBetterTransformer: + @staticmethod + def transform(model): + return model + + # Inject dummy module before importing airllm + spec = importlib.util.spec_from_loader("optimum.bettertransformer", None) + dummy_module = importlib.util.module_from_spec(spec) + dummy_module.BetterTransformer = DummyBetterTransformer + sys.modules["optimum.bettertransformer"] = dummy_module + + try: + from airllm import AutoModel + AIRLLM_AVAILABLE = True + except ImportError: + AIRLLM_AVAILABLE = False + AutoModel = None + else: + AIRLLM_AVAILABLE = False + AutoModel = None + + +class AirLLMOllamaWrapper: + """ + A wrapper that provides an Ollama-like API for AirLLM. 
+ + Usage: + # Instead of: ollama.generate(model="llama2", prompt="Hello") + # Use: airllm_wrapper.generate(model="llama2", prompt="Hello") + """ + + def __init__(self, model_name: str, compression: Optional[str] = None, **kwargs): + """ + Initialize AirLLM model. + + Args: + model_name: Hugging Face model name or path (e.g., "meta-llama/Llama-3.2-3B-Instruct") + compression: Optional compression ('4bit' or '8bit') for 3x speed improvement + **kwargs: Additional arguments for AutoModel.from_pretrained() + """ + if not AIRLLM_AVAILABLE or AutoModel is None: + raise ImportError( + "AirLLM is not available. Please install it with: pip install airllm bitsandbytes\n" + "If you see a BetterTransformer error, you may need to install: pip install optimum[bettertransformer]" + ) + + print(f"Loading AirLLM model: {model_name}") + # AutoModel.from_pretrained() accepts: + # - Hugging Face model IDs (e.g., "meta-llama/Llama-3.1-8B-Instruct") + # - Local paths (e.g., "/path/to/local/model") + # - Can use local_dir parameter for local models + self.model = AutoModel.from_pretrained( + model_name, + compression=compression, + **kwargs + ) + self.model_name = model_name + + # Get the model's maximum sequence length for AirLLM + # IMPORTANT: AirLLM processes sequences in chunks, and each chunk must fit + # within the model's position embedding limits. + # Even if the model config says it supports longer sequences via rope scaling, + # AirLLM's chunking mechanism requires the base size. 
+ + self.max_length = 2048 # Default for Llama models + + # Check if this is a Llama model to determine appropriate max length + is_llama = False + if hasattr(self.model, 'config'): + model_type = getattr(self.model.config, 'model_type', '').lower() + is_llama = 'llama' in model_type or 'llama' in self.model_name.lower() + + if is_llama: + # Llama models: typically support 2048-4096 tokens + # AirLLM works well with Llama, so we can use larger chunks + if hasattr(self.model, 'config'): + config_max = getattr(self.model.config, 'max_position_embeddings', None) + if config_max and config_max > 0: + # Use config value, but cap at 2048 for AirLLM safety + self.max_length = min(config_max, 2048) + else: + self.max_length = 2048 # Safe default for Llama + else: + # For other models (e.g., DeepSeek), use conservative default + if hasattr(self.model, 'config'): + config_max = getattr(self.model.config, 'max_position_embeddings', None) + if config_max and config_max > 0 and config_max <= 2048: + self.max_length = config_max + else: + self.max_length = 512 # Very conservative + + print(f"Using sequence length limit: {self.max_length} (AirLLM chunk size)") + + print("Model loaded successfully!") + + def generate( + self, + prompt: str, + model: Optional[str] = None, # Ignored, kept for API compatibility + max_tokens: int = 50, + temperature: float = 0.7, + top_p: float = 0.9, + stream: bool = False, + **kwargs + ) -> Union[str, Dict]: + """ + Generate text from a prompt (Ollama-compatible interface). 
+ + Args: + prompt: Input text prompt + model: Ignored (kept for compatibility) + max_tokens: Maximum number of tokens to generate + temperature: Sampling temperature (0.0 to 1.0) + top_p: Nucleus sampling parameter + stream: If True, return streaming response (not yet implemented) + **kwargs: Additional generation parameters + + Returns: + Generated text string or dict with response + """ + # Tokenize input with attention mask + # AirLLM processes sequences in chunks, but each chunk must fit within the model's + # position embedding limits. We need to ensure we don't exceed the chunk size. + # Use the model's max_length to ensure compatibility with position embeddings + input_tokens = self.model.tokenizer( + prompt, + return_tensors="pt", + return_attention_mask=True, + truncation=True, + max_length=self.max_length, # Respect model's position embedding limit + padding=False + ) + + # Move to GPU if available + device = 'cuda' if torch.cuda.is_available() else 'cpu' + input_ids = input_tokens['input_ids'].to(device) + attention_mask = input_tokens.get('attention_mask', None) + if attention_mask is not None: + attention_mask = attention_mask.to(device) + + # Ensure we don't exceed max_length (manual truncation as safety check) + seq_length = input_ids.shape[1] + if seq_length > self.max_length: + print(f"Warning: Sequence length ({seq_length}) exceeds limit ({self.max_length}), truncating...") + input_ids = input_ids[:, :self.max_length] + if attention_mask is not None: + attention_mask = attention_mask[:, :self.max_length] + seq_length = self.max_length + + if seq_length >= self.max_length: + print(f"Note: Using sequence of {seq_length} tokens (at limit: {self.max_length})") + + # Prepare generation parameters + # For Llama models, we can use more tokens + max_gen_tokens = min(max_tokens, 512) + + gen_kwargs = { + 'max_new_tokens': max_gen_tokens, + 'use_cache': True, + 'return_dict_in_generate': True, + 'temperature': temperature, + 'top_p': top_p, + **kwargs + } 
+ + # Add attention mask if available + if attention_mask is not None: + gen_kwargs['attention_mask'] = attention_mask + + # Generate + with torch.inference_mode(): + generation_output = self.model.generate(input_ids, **gen_kwargs) + + # Decode output - get only the newly generated tokens + if hasattr(generation_output, 'sequences'): + # Extract only the new tokens (after input length) + input_length = input_ids.shape[1] + generated_ids = generation_output.sequences[0, input_length:] + output = self.model.tokenizer.decode(generated_ids, skip_special_tokens=True) + else: + # Fallback for older output formats + output = self.model.tokenizer.decode(generation_output.sequences[0], skip_special_tokens=True) + # Remove the input prompt from output if present + if output.startswith(prompt): + output = output[len(prompt):].strip() + + if stream: + # For streaming, return a generator (simplified version) + return {"response": output} + else: + return output + + def chat( + self, + messages: List[Dict[str, str]], + model: Optional[str] = None, + max_tokens: int = 50, + temperature: float = 0.7, + **kwargs + ) -> str: + """ + Chat interface (Ollama-compatible). + + Args: + messages: List of message dicts with 'role' and 'content' keys + model: Ignored (kept for compatibility) + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + **kwargs: Additional parameters + + Returns: + Generated response string + """ + # Try to use the model's chat template if available (for Llama, etc.) 
+ if hasattr(self.model.tokenizer, 'apply_chat_template') and self.model.tokenizer.chat_template: + try: + # Use the model's native chat template + prompt = self.model.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + except Exception: + # Fallback to simple formatting if chat template fails + prompt = self._format_messages(messages) + else: + # Fallback to simple formatting + prompt = self._format_messages(messages) + + return self.generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + **kwargs + ) + + def _format_messages(self, messages: List[Dict[str, str]]) -> str: + """Format chat messages into a single prompt (fallback method).""" + formatted = [] + for msg in messages: + role = msg.get('role', 'user') + content = msg.get('content', '') + if role == 'system': + formatted.append(f"System: {content}") + elif role == 'user': + formatted.append(f"User: {content}") + elif role == 'assistant': + formatted.append(f"Assistant: {content}") + return "\n".join(formatted) + "\nAssistant:" + + def embeddings(self, prompt: str) -> List[float]: + """ + Get embeddings for a prompt (simplified - returns token embeddings). + + Note: This is a simplified version. For full embeddings, + you may need to access model internals. + """ + tokens = self.model.tokenizer( + [prompt], + return_tensors="pt", + truncation=True, + max_length=512, + padding=False + ) + # This is a placeholder - actual embeddings would require model forward pass + return tokens['input_ids'].tolist()[0] + + +# Convenience function for easy migration +def create_ollama_client(model_name: str, compression: Optional[str] = None, **kwargs): + """ + Create an Ollama-compatible client using AirLLM. 
+ + Usage: + client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct") + response = client.generate("Hello, how are you?") + """ + return AirLLMOllamaWrapper(model_name, compression=compression, **kwargs) + + +# Example usage +if __name__ == "__main__": + # Example 1: Basic generation + print("Example 1: Basic Generation") + print("=" * 60) + + # Initialize (this will take time on first run) + # client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct") + + # Generate + # response = client.generate("What is the capital of France?") + # print(f"Response: {response}") + + print("\nExample 2: Chat Interface") + print("=" * 60) + + # Chat example + # messages = [ + # {"role": "user", "content": "Hello! How are you?"} + # ] + # response = client.chat(messages) + # print(f"Response: {response}") + + print("\nUncomment the code above to test!") + diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py index 7865139..82798e4 100644 --- a/nanobot/providers/litellm_provider.py +++ b/nanobot/providers/litellm_provider.py @@ -127,6 +127,7 @@ class LiteLLMProvider(LLMProvider): "messages": messages, "max_tokens": max_tokens, "temperature": temperature, + "stream": False, # Explicitly disable streaming to avoid hangs with some providers } # Apply model-specific overrides (e.g. kimi-k2.5 temperature) @@ -148,6 +149,11 @@ class LiteLLMProvider(LLMProvider): kwargs["tools"] = tools kwargs["tool_choice"] = "auto" + # Add timeout to prevent hangs (especially with local servers) + # Ollama can be slow with complex prompts, so use a longer timeout + # Increased to 400s for larger models like mistral-nemo + kwargs["timeout"] = 400.0 + try: response = await acompletion(**kwargs) return self._parse_response(response) diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index fdd036e..8990f2d 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -6,7 +6,7 @@ Adding a new provider: 2. 
Add a field to ProvidersConfig in config/schema.py. Done. Env vars, prefixing, config matching, status display all derive from here. -Order matters — it controls match priority and fallback. Gateways first. +Order matters — it controls match priority and fallback. Every entry writes out all fields so you can copy-paste as a template. """ @@ -62,86 +62,10 @@ class ProviderSpec: PROVIDERS: tuple[ProviderSpec, ...] = ( - # === Gateways (detected by api_key / api_base, not model name) ========= - # Gateways can route any model, so they win in fallback. - - # OpenRouter: global gateway, keys start with "sk-or-" - ProviderSpec( - name="openrouter", - keywords=("openrouter",), - env_key="OPENROUTER_API_KEY", - display_name="OpenRouter", - litellm_prefix="openrouter", # claude-3 → openrouter/claude-3 - skip_prefixes=(), - env_extras=(), - is_gateway=True, - is_local=False, - detect_by_key_prefix="sk-or-", - detect_by_base_keyword="openrouter", - default_api_base="https://openrouter.ai/api/v1", - strip_model_prefix=False, - model_overrides=(), - ), - - # AiHubMix: global gateway, OpenAI-compatible interface. - # strip_model_prefix=True: it doesn't understand "anthropic/claude-3", - # so we strip to bare "claude-3" then re-prefix as "openai/claude-3". - ProviderSpec( - name="aihubmix", - keywords=("aihubmix",), - env_key="OPENAI_API_KEY", # OpenAI-compatible - display_name="AiHubMix", - litellm_prefix="openai", # → openai/{model} - skip_prefixes=(), - env_extras=(), - is_gateway=True, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="aihubmix", - default_api_base="https://aihubmix.com/v1", - strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3 - model_overrides=(), - ), - # === Standard providers (matched by model-name keywords) =============== - # Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed. 
- ProviderSpec( - name="anthropic", - keywords=("anthropic", "claude"), - env_key="ANTHROPIC_API_KEY", - display_name="Anthropic", - litellm_prefix="", - skip_prefixes=(), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), - ), - - # OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed. - ProviderSpec( - name="openai", - keywords=("openai", "gpt"), - env_key="OPENAI_API_KEY", - display_name="OpenAI", - litellm_prefix="", - skip_prefixes=(), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), - ), - # DeepSeek: needs "deepseek/" prefix for LiteLLM routing. + # Can be used with local models or API. ProviderSpec( name="deepseek", keywords=("deepseek",), @@ -159,107 +83,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( model_overrides=(), ), - # Gemini: needs "gemini/" prefix for LiteLLM. - ProviderSpec( - name="gemini", - keywords=("gemini",), - env_key="GEMINI_API_KEY", - display_name="Gemini", - litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro - skip_prefixes=("gemini/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), - ), - - # Zhipu: LiteLLM uses "zai/" prefix. - # Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that). - # skip_prefixes: don't add "zai/" when already routed via gateway. 
- ProviderSpec( - name="zhipu", - keywords=("zhipu", "glm", "zai"), - env_key="ZAI_API_KEY", - display_name="Zhipu AI", - litellm_prefix="zai", # glm-4 → zai/glm-4 - skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"), - env_extras=( - ("ZHIPUAI_API_KEY", "{api_key}"), - ), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), - ), - - # DashScope: Qwen models, needs "dashscope/" prefix. - ProviderSpec( - name="dashscope", - keywords=("qwen", "dashscope"), - env_key="DASHSCOPE_API_KEY", - display_name="DashScope", - litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max - skip_prefixes=("dashscope/", "openrouter/"), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), - ), - - # Moonshot: Kimi models, needs "moonshot/" prefix. - # LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint. - # Kimi K2.5 API enforces temperature >= 1.0. - ProviderSpec( - name="moonshot", - keywords=("moonshot", "kimi"), - env_key="MOONSHOT_API_KEY", - display_name="Moonshot", - litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5 - skip_prefixes=("moonshot/", "openrouter/"), - env_extras=( - ("MOONSHOT_API_BASE", "{api_base}"), - ), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China - strip_model_prefix=False, - model_overrides=( - ("kimi-k2.5", {"temperature": 1.0}), - ), - ), - - # MiniMax: needs "minimax/" prefix for LiteLLM routing. - # Uses OpenAI-compatible API at api.minimax.io/v1. 
-    ProviderSpec(
-        name="minimax",
-        keywords=("minimax",),
-        env_key="MINIMAX_API_KEY",
-        display_name="MiniMax",
-        litellm_prefix="minimax",  # MiniMax-M2.1 → minimax/MiniMax-M2.1
-        skip_prefixes=("minimax/", "openrouter/"),
-        env_extras=(),
-        is_gateway=False,
-        is_local=False,
-        detect_by_key_prefix="",
-        detect_by_base_keyword="",
-        default_api_base="https://api.minimax.io/v1",
-        strip_model_prefix=False,
-        model_overrides=(),
-    ),
-
     # === Local deployment (matched by config key, NOT by api_base) =========
 
     # vLLM / any OpenAI-compatible local server.
@@ -281,23 +104,44 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         model_overrides=(),
     ),
 
-    # === Auxiliary (not a primary LLM provider) ============================
-
-    # Groq: mainly used for Whisper voice transcription, also usable for LLM.
-    # Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
+    # Ollama: local OpenAI-compatible server.
+    # Use OpenAI-compatible endpoint, not native Ollama API.
+    # Detected when the config key is "ollama" or the api_base contains "11434" (default Ollama port).
     ProviderSpec(
-        name="groq",
-        keywords=("groq",),
-        env_key="GROQ_API_KEY",
-        display_name="Groq",
-        litellm_prefix="groq",  # llama3-8b-8192 → groq/llama3-8b-8192
-        skip_prefixes=("groq/",),  # avoid double-prefix
+        name="ollama",
+        keywords=("ollama", "llama"),  # NOTE: matches ANY model name containing "ollama" or "llama"
+        env_key="OPENAI_API_KEY",  # Use OpenAI-compatible API
+        display_name="Ollama",
+        litellm_prefix="",  # No prefix - use as OpenAI-compatible
+        skip_prefixes=(),
+        env_extras=(
+            ("OPENAI_API_BASE", "{api_base}"),  # Set OpenAI API base to Ollama endpoint
+        ),
+        is_gateway=False,
+        is_local=True,
+        detect_by_key_prefix="",
+        detect_by_base_keyword="11434",  # Detect by default Ollama port
+        default_api_base="http://localhost:11434/v1",
+        strip_model_prefix=False,
+        model_overrides=(),
+    ),
+
+    # AirLLM: direct local model inference (no HTTP server).
+    # Loads models directly into memory for GPU-optimized inference. 
+ # Detected when config key is "airllm". + ProviderSpec( + name="airllm", + keywords=("airllm",), + env_key="", # No API key needed (local) + display_name="AirLLM", + litellm_prefix="", # Not used with LiteLLM + skip_prefixes=(), env_extras=(), is_gateway=False, - is_local=False, + is_local=True, detect_by_key_prefix="", detect_by_base_keyword="", - default_api_base="", + default_api_base="", # Not used (direct Python calls) strip_model_prefix=False, model_overrides=(), ), @@ -325,12 +169,11 @@ def find_gateway( api_key: str | None = None, api_base: str | None = None, ) -> ProviderSpec | None: - """Detect gateway/local provider. + """Detect local provider. Priority: - 1. provider_name — if it maps to a gateway/local spec, use it directly. - 2. api_key prefix — e.g. "sk-or-" → OpenRouter. - 3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix. + 1. provider_name — if it maps to a local spec, use it directly. + 2. api_base keyword — e.g. "11434" in URL → Ollama. A standard provider with a custom api_base (e.g. DeepSeek behind a proxy) will NOT be mistaken for vLLM — the old fallback is gone. @@ -341,10 +184,8 @@ def find_gateway( if spec and (spec.is_gateway or spec.is_local): return spec - # 2. Auto-detect by api_key prefix / api_base keyword + # 2. Auto-detect by api_base keyword for spec in PROVIDERS: - if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix): - return spec if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base: return spec diff --git a/setup_llama_airllm.py b/setup_llama_airllm.py new file mode 100644 index 0000000..424f5b6 --- /dev/null +++ b/setup_llama_airllm.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +""" +Setup script to configure nanobot to use Llama models with AirLLM. +This script will: +1. Check/create the config file +2. Set up Llama model configuration +3. 
Guide you through getting a Hugging Face token if needed +""" + +import json +import os +from pathlib import Path + +CONFIG_PATH = Path.home() / ".nanobot" / "config.json" + +def get_hf_token_instructions(): + """Print instructions for getting a Hugging Face token.""" + print("\n" + "="*70) + print("GETTING A HUGGING FACE TOKEN") + print("="*70) + print("\nTo use Llama models (which are gated), you need a Hugging Face token:") + print("\n1. Go to: https://huggingface.co/settings/tokens") + print("2. Click 'New token'") + print("3. Give it a name (e.g., 'nanobot')") + print("4. Select 'Read' permission") + print("5. Click 'Generate token'") + print("6. Copy the token (starts with 'hf_...')") + print("\nThen accept the Llama model license:") + print("1. Go to: https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct") + print("2. Click 'Agree and access repository'") + print("3. Accept the license terms") + print("\n" + "="*70 + "\n") + +def load_existing_config(): + """Load existing config or return default.""" + if CONFIG_PATH.exists(): + try: + with open(CONFIG_PATH) as f: + return json.load(f) + except Exception as e: + print(f"Warning: Could not read existing config: {e}") + return {} + return {} + +def create_llama_config(): + """Create or update config for Llama with AirLLM.""" + config = load_existing_config() + + # Ensure providers section exists + if "providers" not in config: + config["providers"] = {} + + # Ensure agents section exists + if "agents" not in config: + config["agents"] = {} + if "defaults" not in config["agents"]: + config["agents"]["defaults"] = {} + + # Choose Llama model + print("\n" + "="*70) + print("CHOOSE LLAMA MODEL") + print("="*70) + print("\nAvailable models:") + print(" 1. Llama-3.2-3B-Instruct (Recommended - fast, minimal memory)") + print(" 2. Llama-3.1-8B-Instruct (Good balance of performance and speed)") + print(" 3. 
Custom (enter model path)") + + choice = input("\nChoose model (1-3, default: 1): ").strip() or "1" + + model_map = { + "1": "meta-llama/Llama-3.2-3B-Instruct", + "2": "meta-llama/Llama-3.1-8B-Instruct", + } + + if choice == "3": + model_path = input("Enter model path (e.g., meta-llama/Llama-3.2-3B-Instruct): ").strip() + if not model_path: + model_path = "meta-llama/Llama-3.2-3B-Instruct" + print(f"Using default: {model_path}") + else: + model_path = model_map.get(choice, "meta-llama/Llama-3.2-3B-Instruct") + + # Set up AirLLM provider with Llama model + # Note: apiKey can be used as model path, or we can put model in defaults + config["providers"]["airllm"] = { + "apiKey": "", # Will be set to model path + "apiBase": None, + "extraHeaders": {} + } + + # Set default model + config["agents"]["defaults"]["model"] = model_path + + # Ask for Hugging Face token + print("\n" + "="*70) + print("HUGGING FACE TOKEN SETUP") + print("="*70) + print("\nDo you have a Hugging Face token? (Required for Llama models)") + print("If not, we'll show you how to get one.\n") + + has_token = input("Do you have a Hugging Face token? (y/n): ").strip().lower() + + if has_token == 'y': + hf_token = input("\nEnter your Hugging Face token (starts with 'hf_'): ").strip() + if hf_token and hf_token.startswith('hf_'): + # Store token in extraHeaders + config["providers"]["airllm"]["extraHeaders"]["hf_token"] = hf_token + # Also set apiKey to model path (AirLLM uses apiKey as model path if it contains '/') + config["providers"]["airllm"]["apiKey"] = config["agents"]["defaults"]["model"] + print("\n✓ Token configured!") + else: + print("⚠ Warning: Token doesn't look valid (should start with 'hf_')") + print("You can add it later by editing the config file.") + # Still set model path in apiKey + config["providers"]["airllm"]["apiKey"] = config["agents"]["defaults"]["model"] + else: + get_hf_token_instructions() + print("\nYou can add your token later by:") + print(f"1. 
Editing: {CONFIG_PATH}") + print("2. Adding your token to: providers.airllm.extraHeaders.hf_token") + print("\nOr run this script again after getting your token.") + + return config + +def save_config(config): + """Save config to file.""" + CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True) + with open(CONFIG_PATH, 'w') as f: + json.dump(config, f, indent=2) + + # Set secure permissions + os.chmod(CONFIG_PATH, 0o600) + print(f"\n✓ Configuration saved to: {CONFIG_PATH}") + print(f"✓ File permissions set to 600 (read/write for owner only)") + +def main(): + """Main setup function.""" + print("\n" + "="*70) + print("NANOBOT LLAMA + AIRLLM SETUP") + print("="*70) + print("\nThis script will configure nanobot to use Llama models with AirLLM.\n") + + if CONFIG_PATH.exists(): + print(f"Found existing config at: {CONFIG_PATH}") + backup = input("\nCreate backup? (y/n): ").strip().lower() + if backup == 'y': + backup_path = CONFIG_PATH.with_suffix('.json.backup') + import shutil + shutil.copy(CONFIG_PATH, backup_path) + print(f"✓ Backup created: {backup_path}") + else: + print(f"Creating new config at: {CONFIG_PATH}") + + config = create_llama_config() + save_config(config) + + print("\n" + "="*70) + print("SETUP COMPLETE!") + print("="*70) + print("\nConfiguration:") + print(f" Model: {config['agents']['defaults']['model']}") + print(f" Provider: airllm") + if config["providers"]["airllm"].get("extraHeaders", {}).get("hf_token"): + print(f" HF Token: {'*' * 20} (configured)") + else: + print(f" HF Token: Not configured (add it to use gated models)") + + print("\nNext steps:") + print(" 1. If you need a Hugging Face token, follow the instructions above") + print(" 2. Test it: nanobot agent -m 'Hello, what is 2+5?'") + print("\n" + "="*70 + "\n") + +if __name__ == "__main__": + main() +