nanobot/airllm_ollama_wrapper.py
Tanya f1e95626f8 Clean up providers: keep only Ollama, AirLLM, vLLM, and DeepSeek
- Remove Qwen/DashScope provider and all Qwen-specific code
- Remove gateway providers (OpenRouter, AiHubMix)
- Remove cloud providers (Anthropic, OpenAI, Gemini, Zhipu, Moonshot, MiniMax, Groq)
- Update default model from Platypus to llama3.2
- Remove Platypus references throughout codebase
- Add AirLLM provider support with local model path support
- Update setup scripts to only show Llama models
- Clean up provider registry and config schema
2026-02-17 14:20:47 -05:00

243 lines
7.9 KiB
Python

#!/usr/bin/env python3
"""
AirLLM Ollama-Compatible Wrapper
This wrapper provides an Ollama-like interface for AirLLM,
making it easy to replace Ollama in existing projects.
"""
import torch
from typing import List, Dict, Optional, Union
# Try to import airllm, handle BetterTransformer import error gracefully.
# airllm's import chain can raise ImportError via its transitive dependency
# "optimum.bettertransformer" even when airllm itself is installed.
try:
    from airllm import AutoModel
    AIRLLM_AVAILABLE = True
except ImportError as e:
    if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e):
        # Try to work around BetterTransformer import issue by registering a
        # stub module under the failing name before retrying the import.
        import sys
        import importlib.util

        # Create a dummy BetterTransformer module to allow airllm to import.
        # transform() is a no-op that hands the model back unchanged.
        class DummyBetterTransformer:
            @staticmethod
            def transform(model):
                return model

        # Inject dummy module before importing airllm: build an empty module
        # object (loader=None means nothing to execute), attach the stub class,
        # and register it in sys.modules so the failing import resolves.
        # NOTE(review): the parent "optimum" package is never registered here;
        # presumably airllm imports "optimum.bettertransformer" directly from
        # sys.modules — confirm this workaround still fires on current airllm.
        spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
        dummy_module = importlib.util.module_from_spec(spec)
        dummy_module.BetterTransformer = DummyBetterTransformer
        sys.modules["optimum.bettertransformer"] = dummy_module
        try:
            from airllm import AutoModel
            AIRLLM_AVAILABLE = True
        except ImportError:
            # Workaround did not help; disable the provider gracefully so the
            # rest of the module can still be imported.
            AIRLLM_AVAILABLE = False
            AutoModel = None
    else:
        # Unrelated import failure (e.g. airllm simply not installed).
        AIRLLM_AVAILABLE = False
        AutoModel = None
class AirLLMOllamaWrapper:
    """
    A wrapper that provides an Ollama-like API for AirLLM.

    Usage:
        # Instead of: ollama.generate(model="llama2", prompt="Hello")
        # Use: airllm_wrapper.generate(model="llama2", prompt="Hello")
    """

    def __init__(self, model_name: str, compression: Optional[str] = None, **kwargs):
        """
        Initialize AirLLM model.

        Args:
            model_name: Hugging Face model name or path
                (e.g., "meta-llama/Llama-3.2-3B-Instruct")
            compression: Optional compression ('4bit' or '8bit') for 3x speed improvement
            **kwargs: Additional arguments for AutoModel.from_pretrained()

        Raises:
            ImportError: If the airllm package could not be imported.
        """
        if not AIRLLM_AVAILABLE or AutoModel is None:
            raise ImportError(
                "AirLLM is not available. Please install it with: pip install airllm bitsandbytes\n"
                "If you see a BetterTransformer error, you may need to install: pip install optimum[bettertransformer]"
            )
        print(f"Loading AirLLM model: {model_name}")
        self.model = AutoModel.from_pretrained(
            model_name,
            compression=compression,
            **kwargs
        )
        self.model_name = model_name
        print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,  # Ignored, kept for API compatibility
        max_tokens: int = 50,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stream: bool = False,
        **kwargs
    ) -> Union[str, Dict]:
        """
        Generate text from a prompt (Ollama-compatible interface).

        Args:
            prompt: Input text prompt
            model: Ignored (kept for compatibility)
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0.0 to 1.0)
            top_p: Nucleus sampling parameter
            stream: If True, return streaming response (not yet implemented;
                currently returns the full text wrapped in a dict)
            **kwargs: Additional generation parameters; these override the
                defaults below, so callers may pass e.g. do_sample=False

        Returns:
            Generated text string, or a dict with a "response" key when
            stream=True.
        """
        # Tokenize input. Prompts longer than max_length are truncated.
        input_tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=512,  # Adjust as needed
            padding=False
        )
        # Move to GPU if available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        input_ids = input_tokens['input_ids'].to(device)
        # Prepare generation parameters. do_sample=True is required for
        # temperature/top_p to take effect — without it HF generate() silently
        # ignores both (previous code omitted it). Caller kwargs win.
        gen_kwargs = {
            'max_new_tokens': max_tokens,
            'use_cache': True,
            'return_dict_in_generate': True,
            'do_sample': True,
            'temperature': temperature,
            'top_p': top_p,
            **kwargs
        }
        # Generate (no-grad inference mode for speed/memory)
        with torch.inference_mode():
            generation_output = self.model.generate(input_ids, **gen_kwargs)
        # Decode output. skip_special_tokens=True keeps the BOS marker (e.g.
        # "<s>") out of the text — previously it made the startswith() check
        # below fail, so the prompt was never stripped from the output.
        output = self.model.tokenizer.decode(
            generation_output.sequences[0],
            skip_special_tokens=True
        )
        # Remove the input prompt from output (if present)
        if output.startswith(prompt):
            output = output[len(prompt):].strip()
        if stream:
            # For streaming, return a generator (simplified version)
            return {"response": output}
        else:
            return output

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        max_tokens: int = 50,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Chat interface (Ollama-compatible).

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model: Ignored (kept for compatibility)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional parameters forwarded to generate()

        Returns:
            Generated response string
        """
        # Flatten the chat transcript into a single plain-text prompt
        prompt = self._format_messages(messages)
        return self.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a single "Role: content" prompt.

        Unknown roles are dropped; a trailing "Assistant:" cue prompts the
        model to answer as the assistant.
        """
        formatted = []
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                formatted.append(f"System: {content}")
            elif role == 'user':
                formatted.append(f"User: {content}")
            elif role == 'assistant':
                formatted.append(f"Assistant: {content}")
        return "\n".join(formatted) + "\nAssistant:"

    def embeddings(self, prompt: str) -> List[float]:
        """
        Get embeddings for a prompt (simplified - returns token embeddings).

        Note: This is a placeholder that returns the token IDs of the prompt,
        NOT real embedding vectors. A full implementation would run a model
        forward pass and pool the hidden states.
        """
        tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False
        )
        # This is a placeholder - actual embeddings would require model forward pass
        return tokens['input_ids'].tolist()[0]
# Convenience function for easy migration
def create_ollama_client(model_name: str, compression: Optional[str] = None, **kwargs):
    """
    Create an Ollama-compatible client using AirLLM.

    Usage:
        client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")
        response = client.generate("Hello, how are you?")
    """
    wrapper = AirLLMOllamaWrapper(model_name, compression=compression, **kwargs)
    return wrapper
# Example usage
if __name__ == "__main__":
    separator = "=" * 60

    # Example 1: Basic generation
    print("Example 1: Basic Generation")
    print(separator)
    # Initialize (this will take time on first run)
    # client = create_ollama_client("garage-bAInd/Platypus2-70B-instruct")
    # Generate
    # response = client.generate("What is the capital of France?")
    # print(f"Response: {response}")

    print("\nExample 2: Chat Interface")
    print(separator)
    # Chat example
    # messages = [
    #     {"role": "user", "content": "Hello! How are you?"}
    # ]
    # response = client.chat(messages)
    # print(f"Response: {response}")

    print("\nUncomment the code above to test!")