- Merged latest 166 commits from origin/main
- Resolved conflicts in .gitignore, commands.py, schema.py, providers/__init__.py, and registry.py
- Kept both local providers (Ollama, AirLLM) and new providers from main
- Preserved transformers 4.39.3 compatibility fixes
- Combined error handling improvements with new features
62 lines
1.8 KiB
Python
62 lines
1.8 KiB
Python
#!/usr/bin/env python3
|
|
"""Direct test of AirLLM fix with Llama3.2"""
|
|
import sys
|
|
import os
|
|
|
|
# Put the local nanobot and AirLLM source checkouts at the very front of the
# module search path so they win over any installed copies. After this runs,
# the AirLLM checkout is sys.path[0] and the nanobot checkout is sys.path[1].
sys.path[0:0] = [
    '/home/ladmin/code/airllm/airllm/air_llm',
    '/home/ladmin/code/nanobot/nanobot',
]
|
|
|
|
# Register a stand-in ``optimum.bettertransformer`` module before AirLLM is
# imported, so a later ``from optimum.bettertransformer import
# BetterTransformer`` resolves to the no-op class below instead of the real
# optimum implementation (presumably AirLLM performs that import — confirm
# against the AirLLM sources in use).
import importlib.util


class DummyBetterTransformer:
    """No-op replacement for ``optimum.bettertransformer.BetterTransformer``."""

    @staticmethod
    def transform(model, *args, **kwargs):
        """Return *model* unchanged.

        Any extra positional or keyword arguments (the real ``transform``
        accepts options such as ``keep_original_model``) are accepted and
        ignored, so callers that pass them do not fail with ``TypeError``.
        """
        return model


# Only inject when nothing has registered the module yet; an entry that is
# already in sys.modules (real or stub) is left untouched.
if "optimum.bettertransformer" not in sys.modules:
    # A spec with no loader makes module_from_spec build a plain, empty module.
    spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
    dummy_module = importlib.util.module_from_spec(spec)
    dummy_module.BetterTransformer = DummyBetterTransformer
    # The parent "optimum" package is intentionally not registered: the import
    # system looks up the full dotted name in sys.modules first, so
    # ``from optimum.bettertransformer import ...`` is satisfied by this entry.
    sys.modules["optimum.bettertransformer"] = dummy_module
|
|
|
|
print("=" * 60)
print("TESTING AIRLLM FIX WITH LLAMA3.2")
print("=" * 60)

try:
    from airllm import AutoModel
    import torch  # only used to probe for a usable CUDA device below

    print("✓ AirLLM imported")

    print("\nLoading model...")
    model = AutoModel.from_pretrained("/home/ladmin/.local/models/llama3.2-3b-instruct")
    print("✓ Model loaded")

    print("\nTesting generation...")
    prompt = "Hello, what is 2+5?"
    print(f"Prompt: {prompt}")

    # Tokenize. Choose the device by asking torch whether CUDA is actually
    # usable, not by whether CUDA_VISIBLE_DEVICES happens to be set: the
    # variable is normally unset on GPU machines (meaning "all GPUs visible"),
    # and a value like "-1" hides every GPU while still being truthy.
    # torch.cuda.is_available() itself honours CUDA_VISIBLE_DEVICES.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_ids = model.tokenizer(prompt, return_tensors="pt")['input_ids'].to(device)

    # Generate
    print("Generating (this may take a minute)...")
    output = model.generate(input_ids, max_new_tokens=20, temperature=0.7)

    # Decode only the tokens after the prompt (slice off the first
    # input_ids.shape[1] positions of the returned sequence).
    response = model.tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)

    print(f"\n{'='*60}")
    print("SUCCESS! Response:")
    print(f"{'='*60}")
    print(response)
    print(f"{'='*60}")

except Exception as e:
    # Top-level script boundary: report the failure with a full traceback and
    # exit non-zero so callers/CI can detect it.
    print(f"\n{'='*60}")
    print("ERROR")
    print(f"{'='*60}")
    import traceback
    print(f"{e}")
    traceback.print_exc()
    sys.exit(1)
|
|
|