Fix transformers 4.39.3 compatibility issues with AirLLM

- Fix RoPE scaling compatibility: automatically convert unsupported 'llama3' type to 'linear' for local models
- Patch LlamaSdpaAttention to filter out position_embeddings argument that AirLLM passes but transformers 4.39.3 doesn't accept
- Add better error handling with specific guidance for compatibility issues
- Fix config file modification for local models with unsupported rope_scaling types
- Improve error messages to help diagnose transformers version compatibility issues

These fixes allow nanobot to work with transformers 4.39.3 and AirLLM.
This commit is contained in:
Tanya 2026-02-18 12:39:29 -05:00
parent f1faee54b6
commit 7961bf1360
2 changed files with 240 additions and 49 deletions

View File

@ -2,6 +2,7 @@
import json
import asyncio
import sys
from typing import Any
from pathlib import Path
@ -129,14 +130,20 @@ class AirLLMProvider(LLMProvider):
# Run the synchronous client in an executor to avoid blocking
loop = asyncio.get_event_loop()
response_text = await loop.run_in_executor(
None,
lambda: client.chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
try:
response_text = await loop.run_in_executor(
None,
lambda: client.chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)
)
)
except Exception as e:
import traceback
error_msg = f"AirLLM generation failed: {e}\n{traceback.format_exc()}"
print(error_msg, file=sys.stderr)
raise RuntimeError(f"AirLLM provider error: {e}") from e
# Parse tool calls from response if present
# This is a simplified parser - you may need to adjust based on model output format

View File

@ -9,37 +9,100 @@ making it easy to replace Ollama in existing projects.
import torch
from typing import List, Dict, Optional, Union
# Try to import airllm, handle BetterTransformer import error gracefully
# Try to import airllm, preferring the local checkout if available
import sys
import os
import importlib.util
# Inject dummy BetterTransformer BEFORE importing airllm (local code needs it)
class DummyBetterTransformer:
    """No-op stand-in for optimum's BetterTransformer.

    AirLLM imports BetterTransformer at module load time; this stub lets
    that import succeed and simply hands the model back unchanged.
    """

    @staticmethod
    def transform(model):
        # Identity transform: nothing to optimize, return as-is.
        return model
# Register a stub optimum.bettertransformer module before airllm imports it,
# but never clobber a real installation that is already loaded.
if "optimum.bettertransformer" not in sys.modules:
    stub_spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
    stub_module = importlib.util.module_from_spec(stub_spec)
    stub_module.BetterTransformer = DummyBetterTransformer
    sys.modules["optimum.bettertransformer"] = stub_module
# Fix RoPE scaling compatibility: patch transformers to handle "llama3" type
def _patch_rope_scaling():
"""Patch transformers LlamaConfig to handle unsupported 'llama3' RoPE scaling type."""
try:
from transformers import LlamaConfig
from transformers.models.llama.configuration_llama import LlamaConfig as OriginalLlamaConfig
# Store original __init__ if not already patched
if not hasattr(OriginalLlamaConfig, '_rope_scaling_patched'):
original_init = OriginalLlamaConfig.__init__
def patched_init(self, *args, **kwargs):
# Call original init
original_init(self, *args, **kwargs)
# Fix rope_scaling if it's "llama3" (unsupported in some transformers versions)
if hasattr(self, 'rope_scaling') and self.rope_scaling is not None:
# Check if it's a dict or object
if isinstance(self.rope_scaling, dict):
if self.rope_scaling.get('type') == 'llama3':
print("Warning: Converting unsupported RoPE scaling 'llama3' to 'linear'")
self.rope_scaling['type'] = 'linear'
if 'factor' not in self.rope_scaling:
self.rope_scaling['factor'] = 1.0
elif hasattr(self.rope_scaling, 'type'):
if getattr(self.rope_scaling, 'type', None) == 'llama3':
print("Warning: Converting unsupported RoPE scaling 'llama3' to 'linear'")
# Convert to dict format
factor = getattr(self.rope_scaling, 'factor', 1.0)
self.rope_scaling = {'type': 'linear', 'factor': factor}
OriginalLlamaConfig.__init__ = patched_init
OriginalLlamaConfig._rope_scaling_patched = True
except Exception as e:
# If patching fails, we'll handle it in the error handler
print(f"Warning: Could not patch RoPE scaling: {e}", file=sys.stderr)
def _patch_attention_position_embeddings():
"""Patch LlamaSdpaAttention to accept and ignore position_embeddings argument for AirLLM compatibility."""
try:
from transformers.models.llama import modeling_llama
import functools
# Check if LlamaSdpaAttention exists and hasn't been patched
if hasattr(modeling_llama, 'LlamaSdpaAttention'):
LlamaSdpaAttention = modeling_llama.LlamaSdpaAttention
if not hasattr(LlamaSdpaAttention, '_position_embeddings_patched'):
original_forward = LlamaSdpaAttention.forward
@functools.wraps(original_forward)
def patched_forward(self, *args, **kwargs):
# Remove position_embeddings if present (AirLLM compatibility)
kwargs.pop('position_embeddings', None)
# Call original forward
return original_forward(self, *args, **kwargs)
LlamaSdpaAttention.forward = patched_forward
LlamaSdpaAttention._position_embeddings_patched = True
except Exception as e:
# If patching fails, we'll handle it in the error handler
print(f"Warning: Could not patch attention position_embeddings: {e}", file=sys.stderr)
# The compatibility patches must be in place before airllm is imported.
_patch_rope_scaling()
_patch_attention_position_embeddings()

# Prefer a local AirLLM checkout over the installed package when present.
LOCAL_AIRLLM_PATH = "/home/ladmin/code/airllm/airllm/air_llm"
if os.path.exists(LOCAL_AIRLLM_PATH) and LOCAL_AIRLLM_PATH not in sys.path:
    sys.path.insert(0, LOCAL_AIRLLM_PATH)
try:
from airllm import AutoModel
AIRLLM_AVAILABLE = True
except ImportError as e:
if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e):
# Try to work around BetterTransformer import issue
import sys
import importlib.util
# Create a dummy BetterTransformer module to allow airllm to import
class DummyBetterTransformer:
@staticmethod
def transform(model):
return model
# Inject dummy module before importing airllm
spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
dummy_module = importlib.util.module_from_spec(spec)
dummy_module.BetterTransformer = DummyBetterTransformer
sys.modules["optimum.bettertransformer"] = dummy_module
try:
from airllm import AutoModel
AIRLLM_AVAILABLE = True
except ImportError:
AIRLLM_AVAILABLE = False
AutoModel = None
else:
AIRLLM_AVAILABLE = False
AutoModel = None
AIRLLM_AVAILABLE = False
AutoModel = None
print(f"Warning: Failed to import AirLLM: {e}", file=sys.stderr)
class AirLLMOllamaWrapper:
@ -67,24 +130,131 @@ class AirLLMOllamaWrapper:
)
print(f"Loading AirLLM model: {model_name}")
# Fix RoPE scaling compatibility issue: transformers 4.39.3 doesn't support "llama3" type
# Modify config file if it's a local path and has unsupported rope_scaling
model_path = model_name
if os.path.exists(model_name) or model_name.startswith('/') or model_name.startswith('~'):
if model_name.startswith('~'):
model_path = os.path.expanduser(model_name)
else:
model_path = os.path.abspath(model_name)
config_json_path = os.path.join(model_path, "config.json")
if os.path.exists(config_json_path):
try:
import json
with open(config_json_path, 'r') as f:
config_data = json.load(f)
# Check and fix rope_scaling
if 'rope_scaling' in config_data and config_data['rope_scaling'] is not None:
rope_scaling = config_data['rope_scaling']
if isinstance(rope_scaling, dict) and rope_scaling.get('type') == 'llama3':
print("Warning: Fixing unsupported RoPE scaling type 'llama3' -> 'linear'")
# Backup original config
backup_path = config_json_path + ".backup"
if not os.path.exists(backup_path):
import shutil
shutil.copy2(config_json_path, backup_path)
# Fix the rope_scaling type
config_data['rope_scaling']['type'] = 'linear'
if 'factor' not in config_data['rope_scaling']:
config_data['rope_scaling']['factor'] = 1.0
# Save fixed config
with open(config_json_path, 'w') as f:
json.dump(config_data, f, indent=2)
print(f"Fixed config saved to {config_json_path}")
except Exception as e:
print(f"Warning: Could not fix config file: {e}", file=sys.stderr)
# Determine max_seq_len before loading model
# AirLLM needs this at initialization time
max_seq_len = 2048 # Default for Llama models
# Check if this is a Llama model to determine appropriate max length
# We need to load config first to check model type
try:
from transformers import AutoConfig
config = AutoConfig.from_pretrained(model_name, **{k: v for k, v in kwargs.items() if k in ['token', 'trust_remote_code']})
model_type = getattr(config, 'model_type', '').lower()
is_llama = 'llama' in model_type or 'llama' in model_name.lower()
# Also fix rope_scaling in the loaded config object if needed
if is_llama and hasattr(config, 'rope_scaling') and config.rope_scaling is not None:
if isinstance(config.rope_scaling, dict) and config.rope_scaling.get('type') == 'llama3':
print("Warning: Converting RoPE scaling 'llama3' to 'linear' in config object")
config.rope_scaling['type'] = 'linear'
if 'factor' not in config.rope_scaling:
config.rope_scaling['factor'] = 1.0
elif hasattr(config.rope_scaling, 'type') and getattr(config.rope_scaling, 'type', None) == 'llama3':
# Convert object to dict
factor = getattr(config.rope_scaling, 'factor', 1.0)
config.rope_scaling = {'type': 'linear', 'factor': factor}
if is_llama:
config_max = getattr(config, 'max_position_embeddings', None)
if config_max and config_max > 0:
max_seq_len = min(config_max, 2048)
else:
max_seq_len = 2048
else:
config_max = getattr(config, 'max_position_embeddings', None)
if config_max and config_max > 0 and config_max <= 2048:
max_seq_len = config_max
else:
max_seq_len = 512
except Exception:
# Fallback to defaults if config loading fails
pass
# AutoModel.from_pretrained() accepts:
# - Hugging Face model IDs (e.g., "meta-llama/Llama-3.1-8B-Instruct")
# - Local paths (e.g., "/path/to/local/model")
# - Can use local_dir parameter for local models
self.model = AutoModel.from_pretrained(
model_name,
compression=compression,
**kwargs
)
try:
self.model = AutoModel.from_pretrained(
model_name,
compression=compression,
max_seq_len=max_seq_len, # Pass max_seq_len to AirLLM
**kwargs
)
except ValueError as e:
# Handle specific RoPE scaling errors
if "Unknown RoPE scaling type" in str(e) or "rope_scaling" in str(e).lower():
import traceback
error_msg = (
f"RoPE scaling compatibility error: {e}\n"
"The model config uses a RoPE scaling type not supported by your transformers version.\n"
"If this is a local model, the config file should have been fixed automatically.\n"
"If the error persists, try:\n"
"1. For local models: Check that config.json has rope_scaling.type='linear' instead of 'llama3'\n"
"2. Upgrade transformers: pip install --upgrade transformers\n"
"3. Or downgrade to a compatible version: pip install 'transformers==4.37.0'\n"
f"\nFull traceback:\n{traceback.format_exc()}"
)
raise RuntimeError(error_msg) from e
raise
except Exception as e:
import traceback
error_msg = (
f"Failed to load AirLLM model '{model_name}': {e}\n"
f"Error type: {type(e).__name__}\n"
"This is often a transformers version compatibility issue.\n"
"Try one of these solutions:\n"
"1. Install an older transformers version: pip install 'transformers==4.37.0'\n"
"2. Or try: pip install 'transformers==4.38.2'\n"
"3. If using transformers 4.39.3, try downgrading: pip install 'transformers==4.37.0'\n"
"4. Check AirLLM compatibility with your transformers version\n"
f"\nFull traceback:\n{traceback.format_exc()}"
)
raise RuntimeError(error_msg) from e
self.model_name = model_name
# Get the model's maximum sequence length for AirLLM
# IMPORTANT: AirLLM processes sequences in chunks, and each chunk must fit
# within the model's position embedding limits.
# Even if the model config says it supports longer sequences via rope scaling,
# AirLLM's chunking mechanism requires the base size.
self.max_length = 2048 # Default for Llama models
# Store max_length for tokenization
self.max_length = max_seq_len
# Check if this is a Llama model to determine appropriate max length
is_llama = False
@ -178,7 +348,7 @@ class AirLLMOllamaWrapper:
gen_kwargs = {
'max_new_tokens': max_gen_tokens,
'use_cache': True,
'use_cache': False, # Disable cache to avoid DynamicCache compatibility issues
'return_dict_in_generate': True,
'temperature': temperature,
'top_p': top_p,
@ -190,8 +360,22 @@ class AirLLMOllamaWrapper:
gen_kwargs['attention_mask'] = attention_mask
# Generate
with torch.inference_mode():
generation_output = self.model.generate(input_ids, **gen_kwargs)
try:
with torch.inference_mode():
generation_output = self.model.generate(input_ids, **gen_kwargs)
except (TypeError, RuntimeError) as e:
if "position_embeddings" in str(e) or "cannot unpack" in str(e):
error_msg = (
f"AirLLM compatibility error with transformers: {e}\n"
"This is a known issue with AirLLM and transformers version compatibility.\n"
"Try one of these solutions:\n"
"1. Install transformers 4.37.0: pip install 'transformers==4.37.0'\n"
"2. Or try transformers 4.38.2: pip install 'transformers==4.38.2'\n"
"3. If you're using 4.39.3, it may have compatibility issues - try downgrading\n"
"4. Or use Ollama instead: nanobot agent -m 'Hello' (with Ollama provider)"
)
raise RuntimeError(error_msg) from e
raise
# Decode output - get only the newly generated tokens
if hasattr(generation_output, 'sequences'):