#!/usr/bin/env python3
"""Test get_pos_emb_args directly.

Loads a local checkpoint through airllm's ``AutoModel`` and checks that
``model.get_pos_emb_args(0, 128)`` returns a dict whose
``"position_embeddings"`` entry is a 2-tuple ``(cos, sin)``.

The airllm source directory and the model path can be overridden with the
``AIRLLM_SRC_DIR`` and ``AIRLLM_MODEL_PATH`` environment variables; the
defaults are the paths this script originally hard-coded. The process exits
with status 0 when the check passes and 1 when it fails.
"""
import importlib.util
import os
import sys

DEFAULT_AIRLLM_SRC_DIR = "/home/ladmin/code/airllm/airllm/air_llm"
DEFAULT_MODEL_PATH = "/home/ladmin/.local/models/llama3.2-3b-instruct"
STUB_MODULE_NAME = "optimum.bettertransformer"


class DummyBetterTransformer:
    """No-op stand-in for ``optimum.bettertransformer.BetterTransformer``."""

    @staticmethod
    def transform(model):
        """Return *model* unchanged."""
        return model


def install_bettertransformer_stub():
    """Register a stub ``optimum.bettertransformer`` module in ``sys.modules``.

    The stub exposes ``BetterTransformer = DummyBetterTransformer``. It must be
    installed before airllm is imported (presumably airllm imports
    ``BetterTransformer`` from this module). If a module with that name is
    already present in ``sys.modules`` it is left untouched.
    """
    if STUB_MODULE_NAME in sys.modules:
        return
    spec = importlib.util.spec_from_loader(STUB_MODULE_NAME, None)
    stub = importlib.util.module_from_spec(spec)
    stub.BetterTransformer = DummyBetterTransformer
    sys.modules[STUB_MODULE_NAME] = stub


def check_position_embeddings(result):
    """Print diagnostics for a ``get_pos_emb_args`` result and validate it.

    Args:
        result: the value returned by ``model.get_pos_emb_args``.

    Returns:
        True only when *result* is a dict whose ``"position_embeddings"``
        entry is a 2-tuple ``(cos, sin)``; False otherwise. Both elements are
        expected to expose ``.shape`` (presumably tensors).
    """
    print(f"Result type: {type(result)}")
    print(f"Result keys: {result.keys() if isinstance(result, dict) else 'not a dict'}")

    if not (isinstance(result, dict) and "position_embeddings" in result):
        print(f"✗ position_embeddings not in result: {result}")
        return False

    pos_emb = result["position_embeddings"]
    print(f"position_embeddings type: {type(pos_emb)}")
    if not (isinstance(pos_emb, tuple) and len(pos_emb) == 2):
        print(f"✗ position_embeddings is not a 2-tuple: {pos_emb}")
        return False

    cos, sin = pos_emb
    print(f"✓ cos shape: {cos.shape}, sin shape: {sin.shape}")
    print("✓ SUCCESS: position_embeddings created correctly")
    return True


def main():
    """Load the model, call ``get_pos_emb_args(0, 128)`` and validate it.

    Returns:
        Process exit status: 0 when the check passes, 1 otherwise.
    """
    # Prepend the airllm source checkout so it takes precedence over any
    # installed copy.
    sys.path.insert(0, os.environ.get("AIRLLM_SRC_DIR", DEFAULT_AIRLLM_SRC_DIR))
    # The stub has to be in sys.modules before airllm is imported.
    install_bettertransformer_stub()
    from airllm import AutoModel

    print("Loading model...")
    model = AutoModel.from_pretrained(
        os.environ.get("AIRLLM_MODEL_PATH", DEFAULT_MODEL_PATH)
    )
    print("Model loaded")

    print("\nTesting get_pos_emb_args...")
    result = model.get_pos_emb_args(0, 128)
    return 0 if check_position_embeddings(result) else 1


if __name__ == "__main__":
    sys.exit(main())