Fix transformers 4.39.3 compatibility issues with AirLLM
- Fix RoPE scaling compatibility: automatically convert unsupported 'llama3' type to 'linear' for local models - Patch LlamaSdpaAttention to filter out position_embeddings argument that AirLLM passes but transformers 4.39.3 doesn't accept - Add better error handling with specific guidance for compatibility issues - Fix config file modification for local models with unsupported rope_scaling types - Improve error messages to help diagnose transformers version compatibility issues These fixes allow nanobot to work with transformers 4.39.3 and AirLLM.
This commit is contained in:
parent
f1faee54b6
commit
7961bf1360
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import sys
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -129,14 +130,20 @@ class AirLLMProvider(LLMProvider):
|
|||||||
|
|
||||||
# Run the synchronous client in an executor to avoid blocking
|
# Run the synchronous client in an executor to avoid blocking
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
response_text = await loop.run_in_executor(
|
try:
|
||||||
None,
|
response_text = await loop.run_in_executor(
|
||||||
lambda: client.chat(
|
None,
|
||||||
messages=messages,
|
lambda: client.chat(
|
||||||
max_tokens=max_tokens,
|
messages=messages,
|
||||||
temperature=temperature,
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
error_msg = f"AirLLM generation failed: {e}\n{traceback.format_exc()}"
|
||||||
|
print(error_msg, file=sys.stderr)
|
||||||
|
raise RuntimeError(f"AirLLM provider error: {e}") from e
|
||||||
|
|
||||||
# Parse tool calls from response if present
|
# Parse tool calls from response if present
|
||||||
# This is a simplified parser - you may need to adjust based on model output format
|
# This is a simplified parser - you may need to adjust based on model output format
|
||||||
|
|||||||
@ -9,37 +9,100 @@ making it easy to replace Ollama in existing projects.
|
|||||||
import torch
|
import torch
|
||||||
from typing import List, Dict, Optional, Union
|
from typing import List, Dict, Optional, Union
|
||||||
|
|
||||||
# Try to import airllm, handle BetterTransformer import error gracefully
|
# Try to import airllm, preferring the local checkout if available
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
|
# Inject dummy BetterTransformer BEFORE importing airllm (local code needs it)
|
||||||
|
class DummyBetterTransformer:
|
||||||
|
@staticmethod
|
||||||
|
def transform(model):
|
||||||
|
return model
|
||||||
|
|
||||||
|
if "optimum.bettertransformer" not in sys.modules:
|
||||||
|
spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
|
||||||
|
dummy_module = importlib.util.module_from_spec(spec)
|
||||||
|
dummy_module.BetterTransformer = DummyBetterTransformer
|
||||||
|
sys.modules["optimum.bettertransformer"] = dummy_module
|
||||||
|
|
||||||
|
# Fix RoPE scaling compatibility: patch transformers to handle "llama3" type
|
||||||
|
def _patch_rope_scaling():
|
||||||
|
"""Patch transformers LlamaConfig to handle unsupported 'llama3' RoPE scaling type."""
|
||||||
|
try:
|
||||||
|
from transformers import LlamaConfig
|
||||||
|
from transformers.models.llama.configuration_llama import LlamaConfig as OriginalLlamaConfig
|
||||||
|
|
||||||
|
# Store original __init__ if not already patched
|
||||||
|
if not hasattr(OriginalLlamaConfig, '_rope_scaling_patched'):
|
||||||
|
original_init = OriginalLlamaConfig.__init__
|
||||||
|
|
||||||
|
def patched_init(self, *args, **kwargs):
|
||||||
|
# Call original init
|
||||||
|
original_init(self, *args, **kwargs)
|
||||||
|
|
||||||
|
# Fix rope_scaling if it's "llama3" (unsupported in some transformers versions)
|
||||||
|
if hasattr(self, 'rope_scaling') and self.rope_scaling is not None:
|
||||||
|
# Check if it's a dict or object
|
||||||
|
if isinstance(self.rope_scaling, dict):
|
||||||
|
if self.rope_scaling.get('type') == 'llama3':
|
||||||
|
print("Warning: Converting unsupported RoPE scaling 'llama3' to 'linear'")
|
||||||
|
self.rope_scaling['type'] = 'linear'
|
||||||
|
if 'factor' not in self.rope_scaling:
|
||||||
|
self.rope_scaling['factor'] = 1.0
|
||||||
|
elif hasattr(self.rope_scaling, 'type'):
|
||||||
|
if getattr(self.rope_scaling, 'type', None) == 'llama3':
|
||||||
|
print("Warning: Converting unsupported RoPE scaling 'llama3' to 'linear'")
|
||||||
|
# Convert to dict format
|
||||||
|
factor = getattr(self.rope_scaling, 'factor', 1.0)
|
||||||
|
self.rope_scaling = {'type': 'linear', 'factor': factor}
|
||||||
|
|
||||||
|
OriginalLlamaConfig.__init__ = patched_init
|
||||||
|
OriginalLlamaConfig._rope_scaling_patched = True
|
||||||
|
except Exception as e:
|
||||||
|
# If patching fails, we'll handle it in the error handler
|
||||||
|
print(f"Warning: Could not patch RoPE scaling: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
def _patch_attention_position_embeddings():
|
||||||
|
"""Patch LlamaSdpaAttention to accept and ignore position_embeddings argument for AirLLM compatibility."""
|
||||||
|
try:
|
||||||
|
from transformers.models.llama import modeling_llama
|
||||||
|
import functools
|
||||||
|
|
||||||
|
# Check if LlamaSdpaAttention exists and hasn't been patched
|
||||||
|
if hasattr(modeling_llama, 'LlamaSdpaAttention'):
|
||||||
|
LlamaSdpaAttention = modeling_llama.LlamaSdpaAttention
|
||||||
|
if not hasattr(LlamaSdpaAttention, '_position_embeddings_patched'):
|
||||||
|
original_forward = LlamaSdpaAttention.forward
|
||||||
|
|
||||||
|
@functools.wraps(original_forward)
|
||||||
|
def patched_forward(self, *args, **kwargs):
|
||||||
|
# Remove position_embeddings if present (AirLLM compatibility)
|
||||||
|
kwargs.pop('position_embeddings', None)
|
||||||
|
# Call original forward
|
||||||
|
return original_forward(self, *args, **kwargs)
|
||||||
|
|
||||||
|
LlamaSdpaAttention.forward = patched_forward
|
||||||
|
LlamaSdpaAttention._position_embeddings_patched = True
|
||||||
|
except Exception as e:
|
||||||
|
# If patching fails, we'll handle it in the error handler
|
||||||
|
print(f"Warning: Could not patch attention position_embeddings: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Apply the patches before importing airllm
|
||||||
|
_patch_rope_scaling()
|
||||||
|
_patch_attention_position_embeddings()
|
||||||
|
|
||||||
|
LOCAL_AIRLLM_PATH = "/home/ladmin/code/airllm/airllm/air_llm"
|
||||||
|
if os.path.exists(LOCAL_AIRLLM_PATH) and LOCAL_AIRLLM_PATH not in sys.path:
|
||||||
|
sys.path.insert(0, LOCAL_AIRLLM_PATH)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from airllm import AutoModel
|
from airllm import AutoModel
|
||||||
AIRLLM_AVAILABLE = True
|
AIRLLM_AVAILABLE = True
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e):
|
AIRLLM_AVAILABLE = False
|
||||||
# Try to work around BetterTransformer import issue
|
AutoModel = None
|
||||||
import sys
|
print(f"Warning: Failed to import AirLLM: {e}", file=sys.stderr)
|
||||||
import importlib.util
|
|
||||||
|
|
||||||
# Create a dummy BetterTransformer module to allow airllm to import
|
|
||||||
class DummyBetterTransformer:
|
|
||||||
@staticmethod
|
|
||||||
def transform(model):
|
|
||||||
return model
|
|
||||||
|
|
||||||
# Inject dummy module before importing airllm
|
|
||||||
spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
|
|
||||||
dummy_module = importlib.util.module_from_spec(spec)
|
|
||||||
dummy_module.BetterTransformer = DummyBetterTransformer
|
|
||||||
sys.modules["optimum.bettertransformer"] = dummy_module
|
|
||||||
|
|
||||||
try:
|
|
||||||
from airllm import AutoModel
|
|
||||||
AIRLLM_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
AIRLLM_AVAILABLE = False
|
|
||||||
AutoModel = None
|
|
||||||
else:
|
|
||||||
AIRLLM_AVAILABLE = False
|
|
||||||
AutoModel = None
|
|
||||||
|
|
||||||
|
|
||||||
class AirLLMOllamaWrapper:
|
class AirLLMOllamaWrapper:
|
||||||
@ -67,24 +130,131 @@ class AirLLMOllamaWrapper:
|
|||||||
)
|
)
|
||||||
|
|
||||||
print(f"Loading AirLLM model: {model_name}")
|
print(f"Loading AirLLM model: {model_name}")
|
||||||
|
|
||||||
|
# Fix RoPE scaling compatibility issue: transformers 4.39.3 doesn't support "llama3" type
|
||||||
|
# Modify config file if it's a local path and has unsupported rope_scaling
|
||||||
|
model_path = model_name
|
||||||
|
if os.path.exists(model_name) or model_name.startswith('/') or model_name.startswith('~'):
|
||||||
|
if model_name.startswith('~'):
|
||||||
|
model_path = os.path.expanduser(model_name)
|
||||||
|
else:
|
||||||
|
model_path = os.path.abspath(model_name)
|
||||||
|
|
||||||
|
config_json_path = os.path.join(model_path, "config.json")
|
||||||
|
if os.path.exists(config_json_path):
|
||||||
|
try:
|
||||||
|
import json
|
||||||
|
with open(config_json_path, 'r') as f:
|
||||||
|
config_data = json.load(f)
|
||||||
|
|
||||||
|
# Check and fix rope_scaling
|
||||||
|
if 'rope_scaling' in config_data and config_data['rope_scaling'] is not None:
|
||||||
|
rope_scaling = config_data['rope_scaling']
|
||||||
|
if isinstance(rope_scaling, dict) and rope_scaling.get('type') == 'llama3':
|
||||||
|
print("Warning: Fixing unsupported RoPE scaling type 'llama3' -> 'linear'")
|
||||||
|
# Backup original config
|
||||||
|
backup_path = config_json_path + ".backup"
|
||||||
|
if not os.path.exists(backup_path):
|
||||||
|
import shutil
|
||||||
|
shutil.copy2(config_json_path, backup_path)
|
||||||
|
|
||||||
|
# Fix the rope_scaling type
|
||||||
|
config_data['rope_scaling']['type'] = 'linear'
|
||||||
|
if 'factor' not in config_data['rope_scaling']:
|
||||||
|
config_data['rope_scaling']['factor'] = 1.0
|
||||||
|
|
||||||
|
# Save fixed config
|
||||||
|
with open(config_json_path, 'w') as f:
|
||||||
|
json.dump(config_data, f, indent=2)
|
||||||
|
print(f"Fixed config saved to {config_json_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not fix config file: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Determine max_seq_len before loading model
|
||||||
|
# AirLLM needs this at initialization time
|
||||||
|
max_seq_len = 2048 # Default for Llama models
|
||||||
|
|
||||||
|
# Check if this is a Llama model to determine appropriate max length
|
||||||
|
# We need to load config first to check model type
|
||||||
|
try:
|
||||||
|
from transformers import AutoConfig
|
||||||
|
config = AutoConfig.from_pretrained(model_name, **{k: v for k, v in kwargs.items() if k in ['token', 'trust_remote_code']})
|
||||||
|
model_type = getattr(config, 'model_type', '').lower()
|
||||||
|
is_llama = 'llama' in model_type or 'llama' in model_name.lower()
|
||||||
|
|
||||||
|
# Also fix rope_scaling in the loaded config object if needed
|
||||||
|
if is_llama and hasattr(config, 'rope_scaling') and config.rope_scaling is not None:
|
||||||
|
if isinstance(config.rope_scaling, dict) and config.rope_scaling.get('type') == 'llama3':
|
||||||
|
print("Warning: Converting RoPE scaling 'llama3' to 'linear' in config object")
|
||||||
|
config.rope_scaling['type'] = 'linear'
|
||||||
|
if 'factor' not in config.rope_scaling:
|
||||||
|
config.rope_scaling['factor'] = 1.0
|
||||||
|
elif hasattr(config.rope_scaling, 'type') and getattr(config.rope_scaling, 'type', None) == 'llama3':
|
||||||
|
# Convert object to dict
|
||||||
|
factor = getattr(config.rope_scaling, 'factor', 1.0)
|
||||||
|
config.rope_scaling = {'type': 'linear', 'factor': factor}
|
||||||
|
|
||||||
|
if is_llama:
|
||||||
|
config_max = getattr(config, 'max_position_embeddings', None)
|
||||||
|
if config_max and config_max > 0:
|
||||||
|
max_seq_len = min(config_max, 2048)
|
||||||
|
else:
|
||||||
|
max_seq_len = 2048
|
||||||
|
else:
|
||||||
|
config_max = getattr(config, 'max_position_embeddings', None)
|
||||||
|
if config_max and config_max > 0 and config_max <= 2048:
|
||||||
|
max_seq_len = config_max
|
||||||
|
else:
|
||||||
|
max_seq_len = 512
|
||||||
|
except Exception:
|
||||||
|
# Fallback to defaults if config loading fails
|
||||||
|
pass
|
||||||
|
|
||||||
# AutoModel.from_pretrained() accepts:
|
# AutoModel.from_pretrained() accepts:
|
||||||
# - Hugging Face model IDs (e.g., "meta-llama/Llama-3.1-8B-Instruct")
|
# - Hugging Face model IDs (e.g., "meta-llama/Llama-3.1-8B-Instruct")
|
||||||
# - Local paths (e.g., "/path/to/local/model")
|
# - Local paths (e.g., "/path/to/local/model")
|
||||||
# - Can use local_dir parameter for local models
|
# - Can use local_dir parameter for local models
|
||||||
self.model = AutoModel.from_pretrained(
|
try:
|
||||||
model_name,
|
self.model = AutoModel.from_pretrained(
|
||||||
compression=compression,
|
model_name,
|
||||||
**kwargs
|
compression=compression,
|
||||||
)
|
max_seq_len=max_seq_len, # Pass max_seq_len to AirLLM
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
# Handle specific RoPE scaling errors
|
||||||
|
if "Unknown RoPE scaling type" in str(e) or "rope_scaling" in str(e).lower():
|
||||||
|
import traceback
|
||||||
|
error_msg = (
|
||||||
|
f"RoPE scaling compatibility error: {e}\n"
|
||||||
|
"The model config uses a RoPE scaling type not supported by your transformers version.\n"
|
||||||
|
"If this is a local model, the config file should have been fixed automatically.\n"
|
||||||
|
"If the error persists, try:\n"
|
||||||
|
"1. For local models: Check that config.json has rope_scaling.type='linear' instead of 'llama3'\n"
|
||||||
|
"2. Upgrade transformers: pip install --upgrade transformers\n"
|
||||||
|
"3. Or downgrade to a compatible version: pip install 'transformers==4.37.0'\n"
|
||||||
|
f"\nFull traceback:\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
raise RuntimeError(error_msg) from e
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
error_msg = (
|
||||||
|
f"Failed to load AirLLM model '{model_name}': {e}\n"
|
||||||
|
f"Error type: {type(e).__name__}\n"
|
||||||
|
"This is often a transformers version compatibility issue.\n"
|
||||||
|
"Try one of these solutions:\n"
|
||||||
|
"1. Install an older transformers version: pip install 'transformers==4.37.0'\n"
|
||||||
|
"2. Or try: pip install 'transformers==4.38.2'\n"
|
||||||
|
"3. If using transformers 4.39.3, try downgrading: pip install 'transformers==4.37.0'\n"
|
||||||
|
"4. Check AirLLM compatibility with your transformers version\n"
|
||||||
|
f"\nFull traceback:\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
raise RuntimeError(error_msg) from e
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
# Get the model's maximum sequence length for AirLLM
|
# Store max_length for tokenization
|
||||||
# IMPORTANT: AirLLM processes sequences in chunks, and each chunk must fit
|
self.max_length = max_seq_len
|
||||||
# within the model's position embedding limits.
|
|
||||||
# Even if the model config says it supports longer sequences via rope scaling,
|
|
||||||
# AirLLM's chunking mechanism requires the base size.
|
|
||||||
|
|
||||||
self.max_length = 2048 # Default for Llama models
|
|
||||||
|
|
||||||
# Check if this is a Llama model to determine appropriate max length
|
# Check if this is a Llama model to determine appropriate max length
|
||||||
is_llama = False
|
is_llama = False
|
||||||
@ -178,7 +348,7 @@ class AirLLMOllamaWrapper:
|
|||||||
|
|
||||||
gen_kwargs = {
|
gen_kwargs = {
|
||||||
'max_new_tokens': max_gen_tokens,
|
'max_new_tokens': max_gen_tokens,
|
||||||
'use_cache': True,
|
'use_cache': False, # Disable cache to avoid DynamicCache compatibility issues
|
||||||
'return_dict_in_generate': True,
|
'return_dict_in_generate': True,
|
||||||
'temperature': temperature,
|
'temperature': temperature,
|
||||||
'top_p': top_p,
|
'top_p': top_p,
|
||||||
@ -190,8 +360,22 @@ class AirLLMOllamaWrapper:
|
|||||||
gen_kwargs['attention_mask'] = attention_mask
|
gen_kwargs['attention_mask'] = attention_mask
|
||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
with torch.inference_mode():
|
try:
|
||||||
generation_output = self.model.generate(input_ids, **gen_kwargs)
|
with torch.inference_mode():
|
||||||
|
generation_output = self.model.generate(input_ids, **gen_kwargs)
|
||||||
|
except (TypeError, RuntimeError) as e:
|
||||||
|
if "position_embeddings" in str(e) or "cannot unpack" in str(e):
|
||||||
|
error_msg = (
|
||||||
|
f"AirLLM compatibility error with transformers: {e}\n"
|
||||||
|
"This is a known issue with AirLLM and transformers version compatibility.\n"
|
||||||
|
"Try one of these solutions:\n"
|
||||||
|
"1. Install transformers 4.37.0: pip install 'transformers==4.37.0'\n"
|
||||||
|
"2. Or try transformers 4.38.2: pip install 'transformers==4.38.2'\n"
|
||||||
|
"3. If you're using 4.39.3, it may have compatibility issues - try downgrading\n"
|
||||||
|
"4. Or use Ollama instead: nanobot agent -m 'Hello' (with Ollama provider)"
|
||||||
|
)
|
||||||
|
raise RuntimeError(error_msg) from e
|
||||||
|
raise
|
||||||
|
|
||||||
# Decode output - get only the newly generated tokens
|
# Decode output - get only the newly generated tokens
|
||||||
if hasattr(generation_output, 'sequences'):
|
if hasattr(generation_output, 'sequences'):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user