Fix transformers 4.39.3 compatibility issues with AirLLM
- Fix RoPE scaling compatibility: automatically convert the unsupported 'llama3' type to 'linear' for local models
- Patch LlamaSdpaAttention to filter out the position_embeddings argument that AirLLM passes but transformers 4.39.3 doesn't accept
- Add better error handling with specific guidance for compatibility issues
- Fix config-file modification for local models with unsupported rope_scaling types
- Improve error messages to help diagnose transformers version compatibility issues

These fixes allow nanobot to work with transformers 4.39.3 and AirLLM.
This commit is contained in:
parent
f1faee54b6
commit
7961bf1360
@ -2,6 +2,7 @@
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
|
||||
@ -129,14 +130,20 @@ class AirLLMProvider(LLMProvider):
|
||||
|
||||
# Run the synchronous client in an executor to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
response_text = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: client.chat(
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
try:
|
||||
response_text = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: client.chat(
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_msg = f"AirLLM generation failed: {e}\n{traceback.format_exc()}"
|
||||
print(error_msg, file=sys.stderr)
|
||||
raise RuntimeError(f"AirLLM provider error: {e}") from e
|
||||
|
||||
# Parse tool calls from response if present
|
||||
# This is a simplified parser - you may need to adjust based on model output format
|
||||
|
||||
@ -9,37 +9,100 @@ making it easy to replace Ollama in existing projects.
|
||||
import torch
|
||||
from typing import List, Dict, Optional, Union
|
||||
|
||||
# Try to import airllm, handle BetterTransformer import error gracefully
|
||||
# Try to import airllm, preferring the local checkout if available
|
||||
import sys
|
||||
import os
|
||||
import importlib.util
|
||||
|
||||
# Inject dummy BetterTransformer BEFORE importing airllm (local code needs it)
|
||||
class DummyBetterTransformer:
    """No-op stand-in for optimum.bettertransformer.BetterTransformer.

    Exposes the single ``transform`` entry point the AirLLM code path
    expects; the model is handed back untouched, so the BetterTransformer
    optimization step is simply skipped.
    """

    @staticmethod
    def transform(model):
        # Identity transform: return the model exactly as received.
        return model
|
||||
|
||||
if "optimum.bettertransformer" not in sys.modules:
|
||||
spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
|
||||
dummy_module = importlib.util.module_from_spec(spec)
|
||||
dummy_module.BetterTransformer = DummyBetterTransformer
|
||||
sys.modules["optimum.bettertransformer"] = dummy_module
|
||||
|
||||
# Fix RoPE scaling compatibility: patch transformers to handle "llama3" type
|
||||
def _patch_rope_scaling():
|
||||
"""Patch transformers LlamaConfig to handle unsupported 'llama3' RoPE scaling type."""
|
||||
try:
|
||||
from transformers import LlamaConfig
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig as OriginalLlamaConfig
|
||||
|
||||
# Store original __init__ if not already patched
|
||||
if not hasattr(OriginalLlamaConfig, '_rope_scaling_patched'):
|
||||
original_init = OriginalLlamaConfig.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
# Call original init
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
# Fix rope_scaling if it's "llama3" (unsupported in some transformers versions)
|
||||
if hasattr(self, 'rope_scaling') and self.rope_scaling is not None:
|
||||
# Check if it's a dict or object
|
||||
if isinstance(self.rope_scaling, dict):
|
||||
if self.rope_scaling.get('type') == 'llama3':
|
||||
print("Warning: Converting unsupported RoPE scaling 'llama3' to 'linear'")
|
||||
self.rope_scaling['type'] = 'linear'
|
||||
if 'factor' not in self.rope_scaling:
|
||||
self.rope_scaling['factor'] = 1.0
|
||||
elif hasattr(self.rope_scaling, 'type'):
|
||||
if getattr(self.rope_scaling, 'type', None) == 'llama3':
|
||||
print("Warning: Converting unsupported RoPE scaling 'llama3' to 'linear'")
|
||||
# Convert to dict format
|
||||
factor = getattr(self.rope_scaling, 'factor', 1.0)
|
||||
self.rope_scaling = {'type': 'linear', 'factor': factor}
|
||||
|
||||
OriginalLlamaConfig.__init__ = patched_init
|
||||
OriginalLlamaConfig._rope_scaling_patched = True
|
||||
except Exception as e:
|
||||
# If patching fails, we'll handle it in the error handler
|
||||
print(f"Warning: Could not patch RoPE scaling: {e}", file=sys.stderr)
|
||||
|
||||
def _patch_attention_position_embeddings():
|
||||
"""Patch LlamaSdpaAttention to accept and ignore position_embeddings argument for AirLLM compatibility."""
|
||||
try:
|
||||
from transformers.models.llama import modeling_llama
|
||||
import functools
|
||||
|
||||
# Check if LlamaSdpaAttention exists and hasn't been patched
|
||||
if hasattr(modeling_llama, 'LlamaSdpaAttention'):
|
||||
LlamaSdpaAttention = modeling_llama.LlamaSdpaAttention
|
||||
if not hasattr(LlamaSdpaAttention, '_position_embeddings_patched'):
|
||||
original_forward = LlamaSdpaAttention.forward
|
||||
|
||||
@functools.wraps(original_forward)
|
||||
def patched_forward(self, *args, **kwargs):
|
||||
# Remove position_embeddings if present (AirLLM compatibility)
|
||||
kwargs.pop('position_embeddings', None)
|
||||
# Call original forward
|
||||
return original_forward(self, *args, **kwargs)
|
||||
|
||||
LlamaSdpaAttention.forward = patched_forward
|
||||
LlamaSdpaAttention._position_embeddings_patched = True
|
||||
except Exception as e:
|
||||
# If patching fails, we'll handle it in the error handler
|
||||
print(f"Warning: Could not patch attention position_embeddings: {e}", file=sys.stderr)
|
||||
|
||||
# Apply both compatibility patches before airllm is imported, so any
# transformers classes airllm touches are already wrapped.
_patch_rope_scaling()
_patch_attention_position_embeddings()

# NOTE(review): hard-coded developer checkout path — presumably intended to
# prefer a local AirLLM source tree over the installed package; confirm this
# machine-specific path should ship, or make it configurable.
LOCAL_AIRLLM_PATH = "/home/ladmin/code/airllm/airllm/air_llm"
if os.path.exists(LOCAL_AIRLLM_PATH) and LOCAL_AIRLLM_PATH not in sys.path:
    # Prepend so the local checkout shadows any pip-installed airllm.
    sys.path.insert(0, LOCAL_AIRLLM_PATH)
|
||||
|
||||
try:
|
||||
from airllm import AutoModel
|
||||
AIRLLM_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e):
|
||||
# Try to work around BetterTransformer import issue
|
||||
import sys
|
||||
import importlib.util
|
||||
|
||||
# Create a dummy BetterTransformer module to allow airllm to import
|
||||
class DummyBetterTransformer:
|
||||
@staticmethod
|
||||
def transform(model):
|
||||
return model
|
||||
|
||||
# Inject dummy module before importing airllm
|
||||
spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
|
||||
dummy_module = importlib.util.module_from_spec(spec)
|
||||
dummy_module.BetterTransformer = DummyBetterTransformer
|
||||
sys.modules["optimum.bettertransformer"] = dummy_module
|
||||
|
||||
try:
|
||||
from airllm import AutoModel
|
||||
AIRLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIRLLM_AVAILABLE = False
|
||||
AutoModel = None
|
||||
else:
|
||||
AIRLLM_AVAILABLE = False
|
||||
AutoModel = None
|
||||
AIRLLM_AVAILABLE = False
|
||||
AutoModel = None
|
||||
print(f"Warning: Failed to import AirLLM: {e}", file=sys.stderr)
|
||||
|
||||
|
||||
class AirLLMOllamaWrapper:
|
||||
@ -67,24 +130,131 @@ class AirLLMOllamaWrapper:
|
||||
)
|
||||
|
||||
print(f"Loading AirLLM model: {model_name}")
|
||||
|
||||
# Fix RoPE scaling compatibility issue: transformers 4.39.3 doesn't support "llama3" type
|
||||
# Modify config file if it's a local path and has unsupported rope_scaling
|
||||
model_path = model_name
|
||||
if os.path.exists(model_name) or model_name.startswith('/') or model_name.startswith('~'):
|
||||
if model_name.startswith('~'):
|
||||
model_path = os.path.expanduser(model_name)
|
||||
else:
|
||||
model_path = os.path.abspath(model_name)
|
||||
|
||||
config_json_path = os.path.join(model_path, "config.json")
|
||||
if os.path.exists(config_json_path):
|
||||
try:
|
||||
import json
|
||||
with open(config_json_path, 'r') as f:
|
||||
config_data = json.load(f)
|
||||
|
||||
# Check and fix rope_scaling
|
||||
if 'rope_scaling' in config_data and config_data['rope_scaling'] is not None:
|
||||
rope_scaling = config_data['rope_scaling']
|
||||
if isinstance(rope_scaling, dict) and rope_scaling.get('type') == 'llama3':
|
||||
print("Warning: Fixing unsupported RoPE scaling type 'llama3' -> 'linear'")
|
||||
# Backup original config
|
||||
backup_path = config_json_path + ".backup"
|
||||
if not os.path.exists(backup_path):
|
||||
import shutil
|
||||
shutil.copy2(config_json_path, backup_path)
|
||||
|
||||
# Fix the rope_scaling type
|
||||
config_data['rope_scaling']['type'] = 'linear'
|
||||
if 'factor' not in config_data['rope_scaling']:
|
||||
config_data['rope_scaling']['factor'] = 1.0
|
||||
|
||||
# Save fixed config
|
||||
with open(config_json_path, 'w') as f:
|
||||
json.dump(config_data, f, indent=2)
|
||||
print(f"Fixed config saved to {config_json_path}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not fix config file: {e}", file=sys.stderr)
|
||||
|
||||
# Determine max_seq_len before loading model
|
||||
# AirLLM needs this at initialization time
|
||||
max_seq_len = 2048 # Default for Llama models
|
||||
|
||||
# Check if this is a Llama model to determine appropriate max length
|
||||
# We need to load config first to check model type
|
||||
try:
|
||||
from transformers import AutoConfig
|
||||
config = AutoConfig.from_pretrained(model_name, **{k: v for k, v in kwargs.items() if k in ['token', 'trust_remote_code']})
|
||||
model_type = getattr(config, 'model_type', '').lower()
|
||||
is_llama = 'llama' in model_type or 'llama' in model_name.lower()
|
||||
|
||||
# Also fix rope_scaling in the loaded config object if needed
|
||||
if is_llama and hasattr(config, 'rope_scaling') and config.rope_scaling is not None:
|
||||
if isinstance(config.rope_scaling, dict) and config.rope_scaling.get('type') == 'llama3':
|
||||
print("Warning: Converting RoPE scaling 'llama3' to 'linear' in config object")
|
||||
config.rope_scaling['type'] = 'linear'
|
||||
if 'factor' not in config.rope_scaling:
|
||||
config.rope_scaling['factor'] = 1.0
|
||||
elif hasattr(config.rope_scaling, 'type') and getattr(config.rope_scaling, 'type', None) == 'llama3':
|
||||
# Convert object to dict
|
||||
factor = getattr(config.rope_scaling, 'factor', 1.0)
|
||||
config.rope_scaling = {'type': 'linear', 'factor': factor}
|
||||
|
||||
if is_llama:
|
||||
config_max = getattr(config, 'max_position_embeddings', None)
|
||||
if config_max and config_max > 0:
|
||||
max_seq_len = min(config_max, 2048)
|
||||
else:
|
||||
max_seq_len = 2048
|
||||
else:
|
||||
config_max = getattr(config, 'max_position_embeddings', None)
|
||||
if config_max and config_max > 0 and config_max <= 2048:
|
||||
max_seq_len = config_max
|
||||
else:
|
||||
max_seq_len = 512
|
||||
except Exception:
|
||||
# Fallback to defaults if config loading fails
|
||||
pass
|
||||
|
||||
# AutoModel.from_pretrained() accepts:
|
||||
# - Hugging Face model IDs (e.g., "meta-llama/Llama-3.1-8B-Instruct")
|
||||
# - Local paths (e.g., "/path/to/local/model")
|
||||
# - Can use local_dir parameter for local models
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
compression=compression,
|
||||
**kwargs
|
||||
)
|
||||
try:
|
||||
self.model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
compression=compression,
|
||||
max_seq_len=max_seq_len, # Pass max_seq_len to AirLLM
|
||||
**kwargs
|
||||
)
|
||||
except ValueError as e:
|
||||
# Handle specific RoPE scaling errors
|
||||
if "Unknown RoPE scaling type" in str(e) or "rope_scaling" in str(e).lower():
|
||||
import traceback
|
||||
error_msg = (
|
||||
f"RoPE scaling compatibility error: {e}\n"
|
||||
"The model config uses a RoPE scaling type not supported by your transformers version.\n"
|
||||
"If this is a local model, the config file should have been fixed automatically.\n"
|
||||
"If the error persists, try:\n"
|
||||
"1. For local models: Check that config.json has rope_scaling.type='linear' instead of 'llama3'\n"
|
||||
"2. Upgrade transformers: pip install --upgrade transformers\n"
|
||||
"3. Or downgrade to a compatible version: pip install 'transformers==4.37.0'\n"
|
||||
f"\nFull traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
raise RuntimeError(error_msg) from e
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_msg = (
|
||||
f"Failed to load AirLLM model '{model_name}': {e}\n"
|
||||
f"Error type: {type(e).__name__}\n"
|
||||
"This is often a transformers version compatibility issue.\n"
|
||||
"Try one of these solutions:\n"
|
||||
"1. Install an older transformers version: pip install 'transformers==4.37.0'\n"
|
||||
"2. Or try: pip install 'transformers==4.38.2'\n"
|
||||
"3. If using transformers 4.39.3, try downgrading: pip install 'transformers==4.37.0'\n"
|
||||
"4. Check AirLLM compatibility with your transformers version\n"
|
||||
f"\nFull traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
raise RuntimeError(error_msg) from e
|
||||
self.model_name = model_name
|
||||
|
||||
# Get the model's maximum sequence length for AirLLM
|
||||
# IMPORTANT: AirLLM processes sequences in chunks, and each chunk must fit
|
||||
# within the model's position embedding limits.
|
||||
# Even if the model config says it supports longer sequences via rope scaling,
|
||||
# AirLLM's chunking mechanism requires the base size.
|
||||
|
||||
self.max_length = 2048 # Default for Llama models
|
||||
# Store max_length for tokenization
|
||||
self.max_length = max_seq_len
|
||||
|
||||
# Check if this is a Llama model to determine appropriate max length
|
||||
is_llama = False
|
||||
@ -178,7 +348,7 @@ class AirLLMOllamaWrapper:
|
||||
|
||||
gen_kwargs = {
|
||||
'max_new_tokens': max_gen_tokens,
|
||||
'use_cache': True,
|
||||
'use_cache': False, # Disable cache to avoid DynamicCache compatibility issues
|
||||
'return_dict_in_generate': True,
|
||||
'temperature': temperature,
|
||||
'top_p': top_p,
|
||||
@ -190,8 +360,22 @@ class AirLLMOllamaWrapper:
|
||||
gen_kwargs['attention_mask'] = attention_mask
|
||||
|
||||
# Generate
|
||||
with torch.inference_mode():
|
||||
generation_output = self.model.generate(input_ids, **gen_kwargs)
|
||||
try:
|
||||
with torch.inference_mode():
|
||||
generation_output = self.model.generate(input_ids, **gen_kwargs)
|
||||
except (TypeError, RuntimeError) as e:
|
||||
if "position_embeddings" in str(e) or "cannot unpack" in str(e):
|
||||
error_msg = (
|
||||
f"AirLLM compatibility error with transformers: {e}\n"
|
||||
"This is a known issue with AirLLM and transformers version compatibility.\n"
|
||||
"Try one of these solutions:\n"
|
||||
"1. Install transformers 4.37.0: pip install 'transformers==4.37.0'\n"
|
||||
"2. Or try transformers 4.38.2: pip install 'transformers==4.38.2'\n"
|
||||
"3. If you're using 4.39.3, it may have compatibility issues - try downgrading\n"
|
||||
"4. Or use Ollama instead: nanobot agent -m 'Hello' (with Ollama provider)"
|
||||
)
|
||||
raise RuntimeError(error_msg) from e
|
||||
raise
|
||||
|
||||
# Decode output - get only the newly generated tokens
|
||||
if hasattr(generation_output, 'sequences'):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user