- Remove Qwen/DashScope provider and all Qwen-specific code - Remove gateway providers (OpenRouter, AiHubMix) - Remove cloud providers (Anthropic, OpenAI, Gemini, Zhipu, Moonshot, MiniMax, Groq) - Update default model from Platypus to llama3.2 - Remove Platypus references throughout codebase - Add AirLLM provider support with local model path support - Update setup scripts to only show Llama models - Clean up provider registry and config schema
243 lines
7.9 KiB
Python
243 lines
7.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
AirLLM Ollama-Compatible Wrapper
|
|
|
|
This wrapper provides an Ollama-like interface for AirLLM,
|
|
making it easy to replace Ollama in existing projects.
|
|
"""
|
|
|
|
import torch
|
|
from typing import List, Dict, Optional, Union
|
|
|
|
# Try to import airllm, handle BetterTransformer import error gracefully.
# Some airllm versions unconditionally do `from optimum.bettertransformer import
# BetterTransformer`, which breaks on newer optimum releases that removed it.
try:
    from airllm import AutoModel
    AIRLLM_AVAILABLE = True
except ImportError as e:
    if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e):
        # Try to work around BetterTransformer import issue
        import sys
        import importlib.util

        # Create a dummy BetterTransformer module to allow airllm to import.
        # transform() is a no-op passthrough so models are left untouched.
        class DummyBetterTransformer:
            @staticmethod
            def transform(model):
                return model

        # Inject dummy module before importing airllm.
        # NOTE(review): this only shadows the "optimum.bettertransformer"
        # submodule; if the parent "optimum" package itself is missing, the
        # retry below still fails and we fall through to AIRLLM_AVAILABLE=False.
        spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
        dummy_module = importlib.util.module_from_spec(spec)
        dummy_module.BetterTransformer = DummyBetterTransformer
        sys.modules["optimum.bettertransformer"] = dummy_module

        # Retry the import now that the shim is registered.
        try:
            from airllm import AutoModel
            AIRLLM_AVAILABLE = True
        except ImportError:
            AIRLLM_AVAILABLE = False
            AutoModel = None
    else:
        # Unrelated ImportError (airllm genuinely absent): degrade gracefully,
        # AirLLMOllamaWrapper.__init__ raises a helpful error at use time.
        AIRLLM_AVAILABLE = False
        AutoModel = None
|
|
|
|
|
|
class AirLLMOllamaWrapper:
    """
    A wrapper that provides an Ollama-like API for AirLLM.

    Usage:
        # Instead of: ollama.generate(model="llama2", prompt="Hello")
        # Use: airllm_wrapper.generate(model="llama2", prompt="Hello")
    """

    def __init__(self, model_name: str, compression: Optional[str] = None, **kwargs):
        """
        Initialize AirLLM model.

        Args:
            model_name: Hugging Face model name or path (e.g., "meta-llama/Llama-3.2-3B-Instruct")
            compression: Optional compression ('4bit' or '8bit') for 3x speed improvement
            **kwargs: Additional arguments for AutoModel.from_pretrained()

        Raises:
            ImportError: If airllm could not be imported at module load time.
        """
        if not AIRLLM_AVAILABLE or AutoModel is None:
            raise ImportError(
                "AirLLM is not available. Please install it with: pip install airllm bitsandbytes\n"
                "If you see a BetterTransformer error, you may need to install: pip install optimum[bettertransformer]"
            )

        print(f"Loading AirLLM model: {model_name}")
        self.model = AutoModel.from_pretrained(
            model_name,
            compression=compression,
            **kwargs
        )
        self.model_name = model_name
        print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,  # Ignored, kept for API compatibility
        max_tokens: int = 50,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stream: bool = False,
        **kwargs
    ) -> Union[str, Dict]:
        """
        Generate text from a prompt (Ollama-compatible interface).

        Args:
            prompt: Input text prompt
            model: Ignored (kept for compatibility)
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0.0 to 1.0); 0 selects greedy decoding
            top_p: Nucleus sampling parameter
            stream: If True, return a {"response": ...} dict (true streaming not yet implemented)
            **kwargs: Additional generation parameters (override the defaults below)

        Returns:
            Generated text string, or a dict with a "response" key when stream=True
        """
        # Tokenize input
        input_tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=512,  # Adjust as needed
            padding=False
        )

        # Move to GPU if available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        input_ids = input_tokens['input_ids'].to(device)

        # Prepare generation parameters.
        # BUG FIX: without do_sample, HF generate() defaults to greedy decoding
        # and silently ignores temperature/top_p. Callers may still override
        # any of these via **kwargs (placed last, so it wins).
        gen_kwargs = {
            'max_new_tokens': max_tokens,
            'use_cache': True,
            'return_dict_in_generate': True,
            'do_sample': temperature > 0,
            'temperature': temperature,
            'top_p': top_p,
            **kwargs
        }

        # Generate
        with torch.inference_mode():
            generation_output = self.model.generate(input_ids, **gen_kwargs)

        # BUG FIX: decode only the newly generated tokens. The previous
        # string-prefix strip (output.startswith(prompt)) almost never matched
        # because decode() re-inserts special tokens (e.g. BOS) before the
        # prompt text, so the prompt leaked into the returned response.
        prompt_length = input_ids.shape[1]
        new_tokens = generation_output.sequences[0][prompt_length:]
        output = self.model.tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()

        if stream:
            # For streaming, return a generator (simplified version)
            return {"response": output}
        else:
            return output

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        max_tokens: int = 50,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Chat interface (Ollama-compatible).

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model: Ignored (kept for compatibility)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional parameters forwarded to generate()

        Returns:
            Generated response string
        """
        # Format messages into a prompt
        prompt = self._format_messages(messages)
        return self.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a single prompt.

        Unknown roles are skipped; the prompt always ends with "Assistant:"
        to cue the model into answering.
        """
        formatted = []
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                formatted.append(f"System: {content}")
            elif role == 'user':
                formatted.append(f"User: {content}")
            elif role == 'assistant':
                formatted.append(f"Assistant: {content}")
        return "\n".join(formatted) + "\nAssistant:"

    def embeddings(self, prompt: str) -> List[float]:
        """
        Get embeddings for a prompt (simplified - returns token embeddings).

        Note: This is a placeholder. It returns raw token *ids*, not real
        embeddings; full embeddings would require a model forward pass and
        access to the model's hidden states.
        """
        tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False
        )
        # This is a placeholder - actual embeddings would require model forward pass
        return tokens['input_ids'].tolist()[0]
|
|
|
|
|
|
# Convenience function for easy migration
|
|
def create_ollama_client(model_name: str, compression: Optional[str] = None, **kwargs):
    """
    Build an Ollama-compatible client backed by AirLLM.

    Example:
        client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")
        reply = client.generate("Hello, how are you?")
    """
    client = AirLLMOllamaWrapper(model_name, compression=compression, **kwargs)
    return client
|
|
|
|
|
|
# Example usage (kept commented out: loading a real model downloads weights
# and takes significant time/VRAM on first run).
if __name__ == "__main__":
    # Example 1: Basic generation
    print("Example 1: Basic Generation")
    print("=" * 60)

    # Initialize (this will take time on first run).
    # Updated to the project's Llama model; the old Platypus reference was a
    # leftover from before the provider cleanup.
    # client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")

    # Generate
    # response = client.generate("What is the capital of France?")
    # print(f"Response: {response}")

    print("\nExample 2: Chat Interface")
    print("=" * 60)

    # Chat example
    # messages = [
    #     {"role": "user", "content": "Hello! How are you?"}
    # ]
    # response = client.chat(messages)
    # print(f"Response: {response}")

    print("\nUncomment the code above to test!")
|
|
|