nanobot/test_airllm_fix.py
Tanya e6b5ead3fd Merge origin/main into feature branch
- Merged latest 166 commits from origin/main
- Resolved conflicts in .gitignore, commands.py, schema.py, providers/__init__.py, and registry.py
- Kept both local providers (Ollama, AirLLM) and new providers from main
- Preserved transformers 4.39.3 compatibility fixes
- Combined error handling improvements with new features
2026-02-18 13:03:19 -05:00

62 lines
1.8 KiB
Python

#!/usr/bin/env python3
"""Direct test of AirLLM fix with Llama3.2"""
import sys
import os
# Put the local checkouts at the front of the import search path.
# The AirLLM directory comes first, followed by the nanobot directory.
sys.path[:0] = [
    '/home/ladmin/code/airllm/airllm/air_llm',
    '/home/ladmin/code/nanobot/nanobot',
]
# Inject BetterTransformer before importing
import importlib.util
class DummyBetterTransformer:
    """No-op stand-in for ``optimum.bettertransformer.BetterTransformer``.

    Only the ``transform`` entry point is provided, so code that imports
    ``BetterTransformer`` and calls ``transform`` on a model keeps working
    while the model itself is left untouched.
    """

    @staticmethod
    def transform(model):
        """Return *model* unchanged.

        Args:
            model: Any object; it is neither inspected nor modified.

        Returns:
            The very same object that was passed in.
        """
        return model
# Register the stub only when nothing is cached under this module name yet,
# so an implementation that has already been imported is never replaced.
if "optimum.bettertransformer" not in sys.modules:
    # A spec without a loader is enough to create an empty module object
    # that carries the dotted name.
    spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
    dummy_module = importlib.util.module_from_spec(spec)
    dummy_module.BetterTransformer = DummyBetterTransformer
    # Subsequent imports of "optimum.bettertransformer" are answered from
    # this sys.modules entry.
    sys.modules["optimum.bettertransformer"] = dummy_module
# Smoke test: import AirLLM, load the local Llama 3.2 checkpoint, generate a
# short completion for a fixed prompt and print it. Any exception prints the
# message and traceback, then exits with status 1.
print("=" * 60)
print("TESTING AIRLLM FIX WITH LLAMA3.2")
print("=" * 60)
try:
    from airllm import AutoModel
    print("✓ AirLLM imported")

    print("\nLoading model...")
    model = AutoModel.from_pretrained("/home/ladmin/.local/models/llama3.2-3b-instruct")
    print("✓ Model loaded")

    print("\nTesting generation...")
    prompt = "Hello, what is 2+5?"
    print(f"Prompt: {prompt}")

    # Tokenize
    # Ask torch whether a GPU is usable instead of reading
    # CUDA_VISIBLE_DEVICES: that variable only filters which GPUs are
    # visible, is often unset on machines that do have a GPU, and values
    # such as "-1" are truthy even though they hide every GPU.
    import torch
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_ids = model.tokenizer(prompt, return_tensors="pt")['input_ids'].to(device)

    # Generate
    print("Generating (this may take a minute)...")
    # NOTE(review): temperature typically only matters when sampling is
    # enabled — confirm against AirLLM's generate() before relying on it.
    output = model.generate(input_ids, max_new_tokens=20, temperature=0.7)

    # Decode
    # Skip the first input_ids.shape[1] positions of the first returned
    # sequence so only what follows the prompt-length prefix is decoded.
    response = model.tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)

    print(f"\n{'='*60}")
    print("SUCCESS! Response:")
    print(f"{'='*60}")
    print(response)
    print(f"{'='*60}")
except Exception as e:
    # Top-level boundary of the script: report everything and signal
    # failure through the exit status.
    print(f"\n{'='*60}")
    print("ERROR")
    print(f"{'='*60}")
    import traceback
    print(f"{e}")
    traceback.print_exc()
    sys.exit(1)