#!/usr/bin/env python3
"""
AirLLM Ollama-Compatible Wrapper

This wrapper provides an Ollama-like interface for AirLLM, making it easy
to replace Ollama in existing projects.
"""

import torch
from typing import List, Dict, Optional, Union

# Try to import airllm, handling a known BetterTransformer import error
# gracefully: some airllm versions import optimum.bettertransformer at
# module load time and crash if it is missing.
try:
    from airllm import AutoModel
    AIRLLM_AVAILABLE = True
except ImportError as e:
    if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e):
        # Work around the BetterTransformer import issue by injecting a
        # no-op stand-in module into sys.modules before retrying.
        import sys
        import importlib.util

        class DummyBetterTransformer:
            """No-op stand-in: `transform` returns the model unchanged."""

            @staticmethod
            def transform(model):
                return model

        # Build an empty module object registered under the name airllm
        # tries to import, carrying the dummy class it expects.
        spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
        dummy_module = importlib.util.module_from_spec(spec)
        dummy_module.BetterTransformer = DummyBetterTransformer
        sys.modules["optimum.bettertransformer"] = dummy_module

        try:
            from airllm import AutoModel
            AIRLLM_AVAILABLE = True
        except ImportError:
            AIRLLM_AVAILABLE = False
            AutoModel = None
    else:
        AIRLLM_AVAILABLE = False
        AutoModel = None


class AirLLMOllamaWrapper:
    """
    A wrapper that provides an Ollama-like API for AirLLM.

    Usage:
        # Instead of: ollama.generate(model="llama2", prompt="Hello")
        # Use: airllm_wrapper.generate(model="llama2", prompt="Hello")
    """

    def __init__(self, model_name: str, compression: Optional[str] = None, **kwargs):
        """
        Initialize AirLLM model.

        Args:
            model_name: Hugging Face model name or path
                (e.g., "meta-llama/Llama-3.2-3B-Instruct")
            compression: Optional compression ('4bit' or '8bit') for
                3x speed improvement
            **kwargs: Additional arguments for AutoModel.from_pretrained()

        Raises:
            ImportError: If airllm could not be imported at module load time.
        """
        if not AIRLLM_AVAILABLE or AutoModel is None:
            raise ImportError(
                "AirLLM is not available. "
                "Please install it with: pip install airllm bitsandbytes\n"
                "If you see a BetterTransformer error, you may need to "
                "install: pip install optimum[bettertransformer]"
            )

        print(f"Loading AirLLM model: {model_name}")
        self.model = AutoModel.from_pretrained(
            model_name,
            compression=compression,
            **kwargs
        )
        self.model_name = model_name
        print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,  # Ignored, kept for API compatibility
        max_tokens: int = 50,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stream: bool = False,
        **kwargs
    ) -> Union[str, Dict]:
        """
        Generate text from a prompt (Ollama-compatible interface).

        Args:
            prompt: Input text prompt
            model: Ignored (kept for compatibility)
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0.0 to 1.0); 0 selects
                greedy decoding
            top_p: Nucleus sampling parameter
            stream: If True, return a dict response (true streaming is
                not yet implemented)
            **kwargs: Additional generation parameters (may override the
                defaults below, e.g. do_sample)

        Returns:
            Generated text string, or a dict with a "response" key when
            stream=True.
        """
        # Tokenize input
        input_tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=512,  # Adjust as needed
            padding=False
        )

        # Move to GPU if available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        input_ids = input_tokens['input_ids'].to(device)

        # Prepare generation parameters. NOTE: temperature/top_p are only
        # honored by HF generate() when do_sample=True, so enable sampling
        # whenever a non-zero temperature is requested (kwargs may override).
        gen_kwargs = {
            'max_new_tokens': max_tokens,
            'use_cache': True,
            'return_dict_in_generate': True,
            'do_sample': temperature > 0,
            'temperature': temperature,
            'top_p': top_p,
            **kwargs
        }

        # Generate without tracking gradients
        with torch.inference_mode():
            generation_output = self.model.generate(input_ids, **gen_kwargs)

        # Decode only the newly generated tokens. Decoding the full
        # sequence and stripping the prompt via startswith() is unreliable
        # because the decoded text usually begins with special tokens (BOS).
        sequence = generation_output.sequences[0]
        new_tokens = sequence[input_ids.shape[1]:]
        output = self.model.tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()

        if stream:
            # For streaming, return a dict (simplified version; a real
            # implementation would yield chunks)
            return {"response": output}
        else:
            return output

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        max_tokens: int = 50,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Chat interface (Ollama-compatible).

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model: Ignored (kept for compatibility)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional parameters forwarded to generate()

        Returns:
            Generated response string
        """
        # Format messages into a single flat prompt
        prompt = self._format_messages(messages)
        return self.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a single prompt.

        Unknown roles are silently skipped; the prompt always ends with
        a trailing "Assistant:" cue for the model to complete.
        """
        formatted = []
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                formatted.append(f"System: {content}")
            elif role == 'user':
                formatted.append(f"User: {content}")
            elif role == 'assistant':
                formatted.append(f"Assistant: {content}")
        return "\n".join(formatted) + "\nAssistant:"

    def embeddings(self, prompt: str) -> List[float]:
        """
        Get embeddings for a prompt (simplified - returns token ids).

        Note: This is a placeholder. Real embeddings would require a model
        forward pass; here only the tokenizer output is returned.
        """
        tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False
        )
        # This is a placeholder - actual embeddings would require a model
        # forward pass
        return tokens['input_ids'].tolist()[0]


# Convenience function for easy migration
def create_ollama_client(model_name: str, compression: Optional[str] = None, **kwargs):
    """
    Create an Ollama-compatible client using AirLLM.

    Usage:
        client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")
        response = client.generate("Hello, how are you?")
    """
    return AirLLMOllamaWrapper(model_name, compression=compression, **kwargs)


# Example usage
if __name__ == "__main__":
    # Example 1: Basic generation
    print("Example 1: Basic Generation")
    print("=" * 60)

    # Initialize (this will take time on first run)
    # client = create_ollama_client("garage-bAInd/Platypus2-70B-instruct")

    # Generate
    # response = client.generate("What is the capital of France?")
    # print(f"Response: {response}")

    print("\nExample 2: Chat Interface")
    print("=" * 60)

    # Chat example
    # messages = [
    #     {"role": "user", "content": "Hello! How are you?"}
    # ]
    # response = client.chat(messages)
    # print(f"Response: {response}")

    print("\nUncomment the code above to test!")