Compare commits

**10 Commits** (`c8831a1e1e ... 9c858699f3`)

| SHA1 |
|---|
| 9c858699f3 |
| 7961bf1360 |
| f1faee54b6 |
| 2f8205150f |
| 216c9f5039 |
| f1e95626f8 |
| dd63337a83 |
| cdc37e2f5e |
| 554ba81473 |
| cbab72ab72 |
**.gitignore** (vendored, +1 line)
```diff
@@ -14,6 +14,7 @@ docs/
 *.pywz
 *.pyzz
 .venv/
+vllm-env/
 __pycache__/
 poetry.lock
 .pytest_cache/
```
**README.md** (+11 lines)
```diff
@@ -573,6 +573,17 @@ nanobot gateway
 
 </details>
 
+## 🌐 Agent Social Network
+
+🐈 nanobot is capable of linking to the agent social network (agent community). **Just send one message and your nanobot joins automatically!**
+
+| Platform | How to Join (send this message to your bot) |
+|----------|-------------|
+| [**Moltbook**](https://www.moltbook.com/) | `Read https://moltbook.com/skill.md and follow the instructions to join Moltbook` |
+| [**ClawdChat**](https://clawdchat.ai/) | `Read https://clawdchat.ai/skill.md and follow the instructions to join ClawdChat` |
+
+Simply send the command above to your nanobot (via CLI or any chat channel), and it will handle the rest.
+
 ## ⚙️ Configuration
 
 Config file: `~/.nanobot/config.json`
```
**SETUP.md** (new file, 239 lines)
# Nanobot Setup Guide

This guide will help you set up nanobot on a fresh system, pulling from the repository and configuring it to use Ollama and AirLLM with Llama models.

## Prerequisites

- Python 3.10 or higher
- Git
- (Optional) CUDA-capable GPU for AirLLM (recommended for better performance)

## Step 1: Clone the Repository

```bash
git clone <repository-url>
cd nanobot
```

If you're using a specific branch (e.g., the cleanup branch):
```bash
git checkout feature/cleanup-providers-llama-only
```

## Step 2: Create Virtual Environment

```bash
python3 -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
```

## Step 3: Install Dependencies

```bash
pip install --upgrade pip
pip install -e .
```

If you plan to use AirLLM, also install:
```bash
pip install airllm bitsandbytes
```
## Step 4: Choose Your Provider Setup

You have two main options:

### Option A: Use Ollama (Easiest, No Tokens Needed)

1. **Install Ollama** (if not already installed):

```bash
# Linux/Mac
curl -fsSL https://ollama.ai/install.sh | sh

# Or download from: https://ollama.ai
```

2. **Pull a Llama model**:

```bash
ollama pull llama3.2:latest
```

3. **Configure nanobot**:

```bash
mkdir -p ~/.nanobot
cat > ~/.nanobot/config.json << 'EOF'
{
  "providers": {
    "ollama": {
      "apiKey": "dummy",
      "apiBase": "http://localhost:11434/v1"
    }
  },
  "agents": {
    "defaults": {
      "model": "llama3.2:latest"
    }
  }
}
EOF
chmod 600 ~/.nanobot/config.json
```
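To confirm the Ollama config above actually answers requests, a minimal smoke test can call Ollama's OpenAI-compatible endpoint directly. This is a hedged sketch: it assumes Ollama is serving on `localhost:11434` as configured above, and the helper names (`make_chat_payload`, `ask_ollama`) are hypothetical, not part of nanobot.

```python
import json
import urllib.request

def make_chat_payload(model: str, prompt: str) -> dict:
    """Build an OpenAI-style chat payload like the one nanobot sends."""
    return {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
    }

def ask_ollama(prompt: str, model: str = "llama3.2:latest") -> str:
    """POST to Ollama's OpenAI-compatible chat endpoint and return the reply text."""
    body = json.dumps(make_chat_payload(model, prompt)).encode()
    req = urllib.request.Request(
        "http://localhost:11434/v1/chat/completions",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req, timeout=60) as resp:
        data = json.load(resp)
    return data["choices"][0]["message"]["content"]

# ask_ollama("Hello, what is 2+5?")  # requires a running Ollama server
```

If this works from a plain Python shell but `nanobot agent` fails, the problem is in the config file rather than the server.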
### Option B: Use AirLLM (Direct Local Inference, No HTTP Server)

1. **Get Hugging Face Token** (one-time, for downloading gated models):
   - Go to: https://huggingface.co/settings/tokens
   - Create a new token with "Read" permission
   - Copy the token (starts with `hf_`)

2. **Accept Llama License**:
   - Go to: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct
   - Click "Agree and access repository"
   - Accept the license terms

3. **Download Llama Model** (one-time):

```bash
# Install huggingface_hub if needed
pip install huggingface_hub

# Download model to local directory
huggingface-cli download meta-llama/Llama-3.2-3B-Instruct \
  --local-dir ~/.local/models/llama3.2-3b-instruct \
  --token YOUR_HF_TOKEN_HERE
```

4. **Configure nanobot**:

```bash
mkdir -p ~/.nanobot
cat > ~/.nanobot/config.json << 'EOF'
{
  "providers": {
    "airllm": {
      "apiKey": "/home/YOUR_USERNAME/.local/models/llama3.2-3b-instruct",
      "apiBase": null,
      "extraHeaders": {}
    }
  },
  "agents": {
    "defaults": {
      "model": "/home/YOUR_USERNAME/.local/models/llama3.2-3b-instruct"
    }
  }
}
EOF
chmod 600 ~/.nanobot/config.json
```

**Important**: Replace `YOUR_USERNAME` with your actual username, or use `~/.local/models/llama3.2-3b-instruct` (the `~` will be expanded).
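The `~` expansion mentioned above is the standard `expanduser()` behavior, so both spellings of the path resolve to the same directory. A quick sketch to verify what the configured path resolves to on your machine:

```python
from pathlib import Path

# The "~" form from the config; expanduser() resolves it against $HOME
configured = "~/.local/models/llama3.2-3b-instruct"
resolved = Path(configured).expanduser()
print(resolved)  # e.g. /home/<user>/.local/models/llama3.2-3b-instruct
```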
## Step 5: Test the Setup

```bash
nanobot agent -m "Hello, what is 2+5?"
```

You should see a response from the model. If you get errors, see the Troubleshooting section below.

## Step 6: (Optional) Use Setup Script

Instead of manual configuration, you can use the provided setup script:

```bash
python3 setup_llama_airllm.py
```

This script will:
- Guide you through model selection
- Help you configure the Hugging Face token
- Set up the config file automatically

## Configuration File Location

- **Path**: `~/.nanobot/config.json`
- **Permissions**: Should be `600` (read/write for owner only)
- **Backup**: Always back up the file before editing!

## Available Providers

After setup, nanobot supports:

- **Ollama**: Local OpenAI-compatible server (no tokens needed)
- **AirLLM**: Direct local model inference (no HTTP server, no tokens after download)
- **vLLM**: Local OpenAI-compatible server (for advanced users)
- **DeepSeek**: API or local models (for future use)

## Recommended Models

### For Ollama:
- `llama3.2:latest` - Fast, minimal memory (recommended)
- `llama3.1:8b` - Good balance
- `llama3.1:70b` - Best quality (needs more GPU memory)

### For AirLLM:
- `meta-llama/Llama-3.2-3B-Instruct` - Fast, minimal memory (recommended)
- `meta-llama/Llama-3.1-8B-Instruct` - Good balance
- Local path: `~/.local/models/llama3.2-3b-instruct` (after download)
## Troubleshooting

### "Model not found" error (AirLLM)
- Make sure you've accepted the Llama license on Hugging Face
- Verify your HF token has read permissions
- Check that the model path in the config is correct
- Ensure the model files are downloaded (check `~/.local/models/llama3.2-3b-instruct/`)

### "Connection refused" error (Ollama)
- Make sure Ollama is running: `ollama serve`
- Check that Ollama is listening on port 11434: `curl http://localhost:11434/api/tags`
- Verify the model is pulled: `ollama list`

### "Out of memory" error (AirLLM)
- Try a smaller model (Llama-3.2-3B-Instruct instead of 8B)
- Use compression: set `apiBase` to `"4bit"` or `"8bit"` in the airllm config
- Close other GPU-intensive applications

### "No API key configured" error
- For Ollama: Use `"dummy"` as apiKey (it's not actually used)
- For AirLLM: No API key is needed for local paths, but the model files must be downloaded

### Import errors
- Make sure the virtual environment is activated
- Reinstall dependencies: `pip install -e .`
- For AirLLM: `pip install airllm bitsandbytes`

## Using Local Model Paths (No Tokens After Download)

Once you've downloaded a model locally with AirLLM, you can use it indefinitely without any tokens:

```json
{
  "providers": {
    "airllm": {
      "apiKey": "/path/to/your/local/model"
    }
  },
  "agents": {
    "defaults": {
      "model": "/path/to/your/local/model"
    }
  }
}
```

The model path should point to a directory containing:
- `config.json`
- `tokenizer.json` (or `tokenizer_config.json`)
- Model weights (`model.safetensors` or `pytorch_model.bin`)
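The directory checklist above can be automated. This is a sketch with a hypothetical helper name (`check_model_dir` is not part of nanobot); it reports which of the required file groups are missing from a candidate model directory:

```python
from pathlib import Path

def check_model_dir(model_dir: Path) -> list[str]:
    """Return the file groups missing from an AirLLM model directory."""
    required_any = [
        ("config.json",),
        ("tokenizer.json", "tokenizer_config.json"),
        ("model.safetensors", "pytorch_model.bin"),
    ]
    missing = []
    for group in required_any:
        # Each group is satisfied if at least one of its files exists
        if not any((model_dir / name).exists() for name in group):
            missing.append(" or ".join(group))
    return missing

# Example: check_model_dir(Path.home() / ".local/models/llama3.2-3b-instruct")
```

An empty return value means the directory looks complete; anything listed points at the download step to repeat.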
## Next Steps

- Read the main README.md for usage examples
- Check `nanobot --help` for available commands
- Explore the workspace features: `nanobot workspace create myproject`

## Getting Help

- Check the repository issues
- Review the code comments
- Test with a simple query first: `nanobot agent -m "Hello"`
**airllm_ollama_wrapper.py** (new file, 242 lines)
```python
#!/usr/bin/env python3
"""
AirLLM Ollama-Compatible Wrapper

This wrapper provides an Ollama-like interface for AirLLM,
making it easy to replace Ollama in existing projects.
"""

import torch
from typing import List, Dict, Optional, Union

# Try to import airllm, handle BetterTransformer import error gracefully
try:
    from airllm import AutoModel
    AIRLLM_AVAILABLE = True
except ImportError as e:
    if "optimum.bettertransformer" in str(e) or "BetterTransformer" in str(e):
        # Try to work around BetterTransformer import issue
        import sys
        import importlib.util

        # Create a dummy BetterTransformer module to allow airllm to import
        class DummyBetterTransformer:
            @staticmethod
            def transform(model):
                return model

        # Inject dummy module before importing airllm
        spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
        dummy_module = importlib.util.module_from_spec(spec)
        dummy_module.BetterTransformer = DummyBetterTransformer
        sys.modules["optimum.bettertransformer"] = dummy_module

        try:
            from airllm import AutoModel
            AIRLLM_AVAILABLE = True
        except ImportError:
            AIRLLM_AVAILABLE = False
            AutoModel = None
    else:
        AIRLLM_AVAILABLE = False
        AutoModel = None


class AirLLMOllamaWrapper:
    """
    A wrapper that provides an Ollama-like API for AirLLM.

    Usage:
        # Instead of: ollama.generate(model="llama2", prompt="Hello")
        # Use: airllm_wrapper.generate(model="llama2", prompt="Hello")
    """

    def __init__(self, model_name: str, compression: Optional[str] = None, **kwargs):
        """
        Initialize AirLLM model.

        Args:
            model_name: Hugging Face model name or path (e.g., "meta-llama/Llama-3.2-3B-Instruct")
            compression: Optional compression ('4bit' or '8bit') for 3x speed improvement
            **kwargs: Additional arguments for AutoModel.from_pretrained()
        """
        if not AIRLLM_AVAILABLE or AutoModel is None:
            raise ImportError(
                "AirLLM is not available. Please install it with: pip install airllm bitsandbytes\n"
                "If you see a BetterTransformer error, you may need to install: pip install optimum[bettertransformer]"
            )

        print(f"Loading AirLLM model: {model_name}")
        self.model = AutoModel.from_pretrained(
            model_name,
            compression=compression,
            **kwargs
        )
        self.model_name = model_name
        print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,  # Ignored, kept for API compatibility
        max_tokens: int = 50,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stream: bool = False,
        **kwargs
    ) -> Union[str, Dict]:
        """
        Generate text from a prompt (Ollama-compatible interface).

        Args:
            prompt: Input text prompt
            model: Ignored (kept for compatibility)
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0.0 to 1.0)
            top_p: Nucleus sampling parameter
            stream: If True, return streaming response (not yet implemented)
            **kwargs: Additional generation parameters

        Returns:
            Generated text string or dict with response
        """
        # Tokenize input
        input_tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            return_attention_mask=False,
            truncation=True,
            max_length=512,  # Adjust as needed
            padding=False
        )

        # Move to GPU if available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        input_ids = input_tokens['input_ids'].to(device)

        # Prepare generation parameters
        gen_kwargs = {
            'max_new_tokens': max_tokens,
            'use_cache': True,
            'return_dict_in_generate': True,
            'temperature': temperature,
            'top_p': top_p,
            **kwargs
        }

        # Generate
        with torch.inference_mode():
            generation_output = self.model.generate(input_ids, **gen_kwargs)

        # Decode output
        output = self.model.tokenizer.decode(generation_output.sequences[0])

        # Remove the input prompt from output (if present)
        if output.startswith(prompt):
            output = output[len(prompt):].strip()

        if stream:
            # For streaming, return a generator (simplified version)
            return {"response": output}
        else:
            return output

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        max_tokens: int = 50,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Chat interface (Ollama-compatible).

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model: Ignored (kept for compatibility)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional parameters

        Returns:
            Generated response string
        """
        # Format messages into a prompt
        prompt = self._format_messages(messages)
        return self.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a single prompt."""
        formatted = []
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                formatted.append(f"System: {content}")
            elif role == 'user':
                formatted.append(f"User: {content}")
            elif role == 'assistant':
                formatted.append(f"Assistant: {content}")
        return "\n".join(formatted) + "\nAssistant:"

    def embeddings(self, prompt: str) -> List[float]:
        """
        Get embeddings for a prompt (simplified - returns token embeddings).

        Note: This is a simplified version. For full embeddings,
        you may need to access model internals.
        """
        tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False
        )
        # This is a placeholder - actual embeddings would require model forward pass
        return tokens['input_ids'].tolist()[0]


# Convenience function for easy migration
def create_ollama_client(model_name: str, compression: Optional[str] = None, **kwargs):
    """
    Create an Ollama-compatible client using AirLLM.

    Usage:
        client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")
        response = client.generate("Hello, how are you?")
    """
    return AirLLMOllamaWrapper(model_name, compression=compression, **kwargs)


# Example usage
if __name__ == "__main__":
    # Example 1: Basic generation
    print("Example 1: Basic Generation")
    print("=" * 60)

    # Initialize (this will take time on first run)
    # client = create_ollama_client("garage-bAInd/Platypus2-70B-instruct")

    # Generate
    # response = client.generate("What is the capital of France?")
    # print(f"Response: {response}")

    print("\nExample 2: Chat Interface")
    print("=" * 60)

    # Chat example
    # messages = [
    #     {"role": "user", "content": "Hello! How are you?"}
    # ]
    # response = client.chat(messages)
    # print(f"Response: {response}")

    print("\nUncomment the code above to test!")
```
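The `_format_messages` method above determines exactly what the model sees for a chat conversation. A standalone sketch of the same formatting (hypothetical function name, no model or airllm install needed) makes the prompt shape easy to inspect:

```python
def format_messages(messages: list[dict]) -> str:
    """Mirror of the wrapper's chat-prompt formatting, runnable without a model."""
    formatted = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        # Only the three known roles are rendered; unknown roles are dropped
        label = {"system": "System", "user": "User", "assistant": "Assistant"}.get(role)
        if label:
            formatted.append(f"{label}: {content}")
    return "\n".join(formatted) + "\nAssistant:"

print(format_messages([
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi!"},
]))
# System: Be brief.
# User: Hi!
# Assistant:
```

The trailing `Assistant:` is what cues a base-style Llama model to continue as the assistant; without it, completions tend to continue the user's turn instead.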
**Web search tool** (modified):

```diff
@@ -44,7 +44,7 @@ def _validate_url(url: str) -> tuple[bool, str]:
 
 
 class WebSearchTool(Tool):
-    """Search the web using Brave Search API."""
+    """Search the web using DuckDuckGo (free, no API key required)."""
 
     name = "web_search"
     description = "Search the web. Returns titles, URLs, and snippets."
```
```diff
@@ -58,13 +58,20 @@ class WebSearchTool(Tool):
     }
 
     def __init__(self, api_key: str | None = None, max_results: int = 5):
+        # Keep api_key parameter for backward compatibility, but use DuckDuckGo if not provided
         self.api_key = api_key or os.environ.get("BRAVE_API_KEY", "")
         self.max_results = max_results
+        self.use_brave = bool(self.api_key)
 
     async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
-        if not self.api_key:
-            return "Error: BRAVE_API_KEY not configured"
+        # Try Brave API if key is available, otherwise use DuckDuckGo
+        if self.use_brave:
+            return await self._brave_search(query, count)
+        else:
+            return await self._duckduckgo_search(query, count)
+
+    async def _brave_search(self, query: str, count: int | None = None) -> str:
+        """Search using Brave API (requires API key)."""
         try:
             n = min(max(count or self.max_results, 1), 10)
             async with httpx.AsyncClient() as client:
```
```diff
@@ -89,6 +96,79 @@ class WebSearchTool(Tool):
         except Exception as e:
             return f"Error: {e}"
 
+    async def _duckduckgo_search(self, query: str, count: int | None = None) -> str:
+        """Search using DuckDuckGo (free, no API key)."""
+        try:
+            n = min(max(count or self.max_results, 1), 10)
+
+            # Try using duckduckgo_search library if available
+            try:
+                from duckduckgo_search import DDGS
+                with DDGS() as ddgs:
+                    results = []
+                    for r in ddgs.text(query, max_results=n):
+                        results.append({
+                            "title": r.get("title", ""),
+                            "url": r.get("href", ""),
+                            "description": r.get("body", "")
+                        })
+
+                if not results:
+                    return f"No results found for: {query}"
+
+                lines = [f"Results for: {query}\n"]
+                for i, item in enumerate(results, 1):
+                    lines.append(f"{i}. {item['title']}\n   {item['url']}")
+                    if item['description']:
+                        lines.append(f"   {item['description']}")
+                return "\n".join(lines)
+            except ImportError:
+                # Fallback: use DuckDuckGo instant answer API (simpler, but limited)
+                async with httpx.AsyncClient(
+                    follow_redirects=True,
+                    timeout=15.0
+                ) as client:
+                    # Use DuckDuckGo instant answer API (no key needed)
+                    url = "https://api.duckduckgo.com/"
+                    r = await client.get(
+                        url,
+                        params={"q": query, "format": "json", "no_html": "1", "skip_disambig": "1"},
+                        headers={"User-Agent": USER_AGENT},
+                    )
+                    r.raise_for_status()
+                    data = r.json()
+
+                    results = []
+                    # Get RelatedTopics (search results)
+                    if "RelatedTopics" in data:
+                        for topic in data["RelatedTopics"][:n]:
+                            if "Text" in topic and "FirstURL" in topic:
+                                results.append({
+                                    "title": topic.get("Text", "").split(" - ")[0] if " - " in topic.get("Text", "") else topic.get("Text", "")[:50],
+                                    "url": topic.get("FirstURL", ""),
+                                    "description": topic.get("Text", "")
+                                })
+
+                    # Also check AbstractText for direct answer
+                    if "AbstractText" in data and data["AbstractText"]:
+                        results.insert(0, {
+                            "title": data.get("Heading", query),
+                            "url": data.get("AbstractURL", ""),
+                            "description": data.get("AbstractText", "")
+                        })
+
+                    if not results:
+                        return f"No results found for: {query}. Try installing 'duckduckgo-search' package for better results: pip install duckduckgo-search"
+
+                    lines = [f"Results for: {query}\n"]
+                    for i, item in enumerate(results[:n], 1):
+                        lines.append(f"{i}. {item['title']}\n   {item['url']}")
+                        if item['description']:
+                            lines.append(f"   {item['description']}")
+                    return "\n".join(lines)
+        except Exception as e:
+            return f"Error searching: {e}. Try installing 'duckduckgo-search' package: pip install duckduckgo-search"
 
 
 class WebFetchTool(Tool):
     """Fetch and extract content from a URL using Readability."""
```
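The instant-answer fallback in `_duckduckgo_search` does all its real work on the decoded JSON payload, so that part can be exercised without any network. This is a sketch with a hypothetical helper name (`parse_instant_answer`), mirroring how the code extracts results from `RelatedTopics` and `AbstractText`:

```python
def parse_instant_answer(data: dict, n: int) -> list[dict]:
    """Extract (title, url, description) results from a DuckDuckGo instant-answer payload."""
    results = []
    for topic in data.get("RelatedTopics", [])[:n]:
        if "Text" in topic and "FirstURL" in topic:
            text = topic["Text"]
            # Title heuristic from the code above: text before " - ", else a 50-char prefix
            title = text.split(" - ")[0] if " - " in text else text[:50]
            results.append({"title": title, "url": topic["FirstURL"], "description": text})
    # A non-empty abstract becomes the top result
    if data.get("AbstractText"):
        results.insert(0, {
            "title": data.get("Heading", ""),
            "url": data.get("AbstractURL", ""),
            "description": data["AbstractText"],
        })
    return results
```

Feeding this a canned payload makes it easy to regression-test the parsing separately from httpx and the live endpoint.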
**CLI module** (modified):

```diff
@@ -265,10 +265,60 @@ This file stores important information that should persist across sessions.
 
 
 def _make_provider(config):
-    """Create LiteLLMProvider from config. Exits if no API key found."""
-    from nanobot.providers.litellm_provider import LiteLLMProvider
+    """Create LLM provider from config. Supports LiteLLMProvider and AirLLMProvider."""
+    provider_name = config.get_provider_name()
     p = config.get_provider()
     model = config.agents.defaults.model
+
+    # Check if AirLLM provider is requested
+    if provider_name == "airllm":
+        try:
+            from nanobot.providers.airllm_provider import AirLLMProvider
+            # AirLLM doesn't need API key, but we can use model path from config
+            # Check if model is specified in the airllm provider config
+            airllm_config = getattr(config.providers, "airllm", None)
+            model_path = None
+            compression = None
+
+            # Try to get model from airllm config's api_key field (repurposed as model path)
+            # or from the default model
+            if airllm_config and airllm_config.api_key:
+                # Check if api_key looks like a model path (contains '/') or is an HF token
+                if '/' in airllm_config.api_key:
+                    model_path = airllm_config.api_key
+                    hf_token = None
+                else:
+                    # Treat as HF token, use model from defaults
+                    model_path = model
+                    hf_token = airllm_config.api_key
+            else:
+                model_path = model
+                hf_token = None
+
+            # Check for compression setting in extra_headers or api_base
+            if airllm_config:
+                if airllm_config.api_base:
+                    compression = airllm_config.api_base  # Repurpose api_base as compression
+                elif airllm_config.extra_headers and "compression" in airllm_config.extra_headers:
+                    compression = airllm_config.extra_headers["compression"]
+                # Check for HF token in extra_headers
+                if not hf_token and airllm_config.extra_headers and "hf_token" in airllm_config.extra_headers:
+                    hf_token = airllm_config.extra_headers["hf_token"]
+
+            return AirLLMProvider(
+                api_key=airllm_config.api_key if airllm_config else None,
+                api_base=compression if compression else None,
+                default_model=model_path,
+                compression=compression,
+                hf_token=hf_token,
+            )
+        except ImportError as e:
+            console.print(f"[red]Error: AirLLM provider not available: {e}[/red]")
+            console.print("Please ensure airllm_ollama_wrapper.py is in the Python path.")
+            raise typer.Exit(1)
+
+    # Default to LiteLLMProvider
+    from nanobot.providers.litellm_provider import LiteLLMProvider
     if not (p and p.api_key) and not model.startswith("bedrock/"):
         console.print("[red]Error: No API key configured.[/red]")
         console.print("Set one in ~/.nanobot/config.json under providers section")
```
```diff
@@ -278,7 +328,7 @@ def _make_provider(config):
         api_base=config.get_api_base(),
         default_model=model,
         extra_headers=p.extra_headers if p else None,
-        provider_name=config.get_provider_name(),
+        provider_name=provider_name,
     )
```
```diff
@@ -444,9 +494,16 @@ def agent(
     if message:
         # Single message mode
         async def run_once():
-            with _thinking_ctx():
-                response = await agent_loop.process_direct(message, session_id)
-            _print_agent_response(response, render_markdown=markdown)
+            try:
+                with _thinking_ctx():
+                    response = await agent_loop.process_direct(message, session_id)
+                # response is a string (content) from process_direct
+                _print_agent_response(response or "", render_markdown=markdown)
+            except Exception as e:
+                import traceback
+                console.print(f"[red]Error: {e}[/red]")
+                console.print(f"[dim]{traceback.format_exc()}[/dim]")
+                raise
 
         asyncio.run(run_once())
     else:
```
**Configuration schema** (modified):

```diff
@@ -1,7 +1,7 @@
 """Configuration schema using Pydantic."""
 
 from pathlib import Path
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ConfigDict
 from pydantic_settings import BaseSettings
```
```diff
@@ -177,18 +177,10 @@ class ProviderConfig(BaseModel):
 
 class ProvidersConfig(BaseModel):
     """Configuration for LLM providers."""
-    anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
-    openai: ProviderConfig = Field(default_factory=ProviderConfig)
-    openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
     deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
-    groq: ProviderConfig = Field(default_factory=ProviderConfig)
-    zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
-    dashscope: ProviderConfig = Field(default_factory=ProviderConfig)  # Alibaba Cloud Tongyi Qianwen
     vllm: ProviderConfig = Field(default_factory=ProviderConfig)
-    gemini: ProviderConfig = Field(default_factory=ProviderConfig)
-    moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
-    minimax: ProviderConfig = Field(default_factory=ProviderConfig)
-    aihubmix: ProviderConfig = Field(default_factory=ProviderConfig)  # AiHubMix API gateway
+    ollama: ProviderConfig = Field(default_factory=ProviderConfig)
+    airllm: ProviderConfig = Field(default_factory=ProviderConfig)
 
 
 class GatewayConfig(BaseModel):
```
```diff
@@ -241,13 +233,36 @@ class Config(BaseSettings):
         # Match by keyword (order follows PROVIDERS registry)
         for spec in PROVIDERS:
             p = getattr(self.providers, spec.name, None)
-            if p and any(kw in model_lower for kw in spec.keywords) and p.api_key:
+            if p and any(kw in model_lower for kw in spec.keywords):
+                # For local providers (Ollama, AirLLM), allow empty api_key or "dummy"
+                # For other providers, require api_key
+                if spec.is_local:
+                    # Local providers can work with empty/dummy api_key
+                    if p.api_key or p.api_base or spec.name == "airllm":
+                        return p, spec.name
+                elif p.api_key:
+                    return p, spec.name
+
+        # Check local providers by api_base detection (for explicit config)
+        for spec in PROVIDERS:
+            if spec.is_local:
+                p = getattr(self.providers, spec.name, None)
+                if p:
+                    # Check if api_base matches the provider's detection pattern
+                    if spec.detect_by_base_keyword and p.api_base and spec.detect_by_base_keyword in p.api_base:
+                        return p, spec.name
+                    # AirLLM is detected by provider name being "airllm"
+                    if spec.name == "airllm" and p.api_key:  # api_key can be model path
+                        return p, spec.name
 
         # Fallback: gateways first, then others (follows registry order)
         for spec in PROVIDERS:
             p = getattr(self.providers, spec.name, None)
-            if p and p.api_key:
+            if p:
+                # For local providers, allow empty/dummy api_key
+                if spec.is_local and (p.api_key or p.api_base):
+                    return p, spec.name
+                elif p.api_key:
+                    return p, spec.name
         return None, None
```
```diff
@@ -281,6 +296,7 @@ class Config(BaseSettings):
                 return spec.default_api_base
         return None
 
-    class Config:
-        env_prefix = "NANOBOT_"
-        env_nested_delimiter = "__"
+    model_config = ConfigDict(
+        env_prefix="NANOBOT_",
+        env_nested_delimiter="__"
+    )
```
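The `env_prefix` and `env_nested_delimiter` settings above mean environment variables like `NANOBOT_PROVIDERS__OLLAMA__API_KEY` map onto the nested config field `providers.ollama.api_key`. A stdlib-only sketch of that mapping (hypothetical helper, just to illustrate the naming convention):

```python
def env_to_path(name: str, prefix: str = "NANOBOT_", delim: str = "__") -> list[str]:
    """Translate a prefixed env var name into the nested config field path."""
    assert name.startswith(prefix), "not a nanobot setting"
    # Strip the prefix, split on the nesting delimiter, lowercase each segment
    return [part.lower() for part in name[len(prefix):].split(delim)]
```

So exporting `NANOBOT_PROVIDERS__OLLAMA__API_KEY=dummy` is an alternative to editing `~/.nanobot/config.json` by hand.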
**Providers package exports** (modified):

```diff
@@ -3,4 +3,8 @@
 from nanobot.providers.base import LLMProvider, LLMResponse
 from nanobot.providers.litellm_provider import LiteLLMProvider
 
-__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider"]
+try:
+    from nanobot.providers.airllm_provider import AirLLMProvider
+    __all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "AirLLMProvider"]
+except ImportError:
+    __all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider"]
```
**nanobot/providers/airllm_provider.py** (new file, 188 lines)
@ -0,0 +1,188 @@
|
||||
"""AirLLM provider implementation for direct local model inference."""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
# Import the wrapper - handle import errors gracefully
|
||||
try:
|
||||
from nanobot.providers.airllm_wrapper import AirLLMOllamaWrapper, create_ollama_client
|
||||
AIRLLM_WRAPPER_AVAILABLE = True
|
||||
_import_error = None
|
||||
except ImportError as e:
|
||||
AIRLLM_WRAPPER_AVAILABLE = False
|
||||
AirLLMOllamaWrapper = None
|
||||
create_ollama_client = None
|
||||
_import_error = str(e)
|
||||
|
||||
|
||||
class AirLLMProvider(LLMProvider):
|
||||
"""
|
||||
LLM provider using AirLLM for direct local model inference.
|
||||
|
||||
This provider loads models directly into memory and runs inference locally,
|
||||
bypassing HTTP API calls. It's optimized for GPU-limited environments.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None, # Repurposed: can be HF token or model name
|
||||
api_base: str | None = None, # Repurposed: compression setting ('4bit' or '8bit')
|
||||
default_model: str = "meta-llama/Llama-3.2-3B-Instruct",
|
||||
compression: str | None = None, # '4bit' or '8bit' for speed improvement
|
||||
model_path: str | None = None, # Override default model
|
||||
hf_token: str | None = None, # Hugging Face token for gated models
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = model_path or default_model
|
||||
# If api_base is set and looks like compression, use it
|
||||
if api_base and api_base in ('4bit', '8bit'):
|
||||
self.compression = api_base
|
||||
else:
|
||||
self.compression = compression
|
||||
# If api_key is provided and doesn't look like a model path, treat as HF token
|
||||
if api_key and '/' not in api_key and len(api_key) > 20:
|
||||
self.hf_token = api_key
|
||||
else:
|
||||
self.hf_token = hf_token
|
||||
# If api_key looks like a model path, use it as the model
|
||||
if api_key and '/' in api_key:
|
||||
self.default_model = api_key
|
||||
self._client: AirLLMOllamaWrapper | None = None
|
||||
self._model_loaded = False
|
||||

    def _ensure_client(self) -> AirLLMOllamaWrapper:
        """Lazy-load the AirLLM client."""
        if not AIRLLM_WRAPPER_AVAILABLE:
            error_msg = (
                "AirLLM wrapper is not available. Please ensure airllm_ollama_wrapper.py "
                "is in the Python path and AirLLM is installed."
            )
            if '_import_error' in globals():
                error_msg += f"\nImport error: {_import_error}"
            raise ImportError(error_msg)

        if self._client is None or not self._model_loaded:
            print(f"Initializing AirLLM with model: {self.default_model}")
            if self.compression:
                print(f"Using compression: {self.compression}")
            if self.hf_token:
                print("Using Hugging Face token for authentication")

            # Prepare kwargs for model loading
            kwargs = {}
            if self.hf_token:
                kwargs['hf_token'] = self.hf_token

            self._client = create_ollama_client(
                self.default_model,
                compression=self.compression,
                **kwargs
            )
            self._model_loaded = True
            print("AirLLM model loaded and ready!")

        return self._client

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """
        Send a chat completion request using AirLLM.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions (tool-calling support may be limited).
            model: Model identifier (ignored if different from the initialized model).
            max_tokens: Maximum tokens in the response.
            temperature: Sampling temperature.

        Returns:
            LLMResponse with content and/or tool calls.
        """
        # Switching models would require an expensive reload,
        # so we always use the model loaded at initialization.
        if model and model != self.default_model:
            print(f"Warning: Model {model} requested but {self.default_model} is loaded. Using loaded model.")

        client = self._ensure_client()

        # Format tools into the prompt if provided (basic tool support).
        # Note: full tool calling requires model support and proper formatting.
        if tools:
            # Add tool definitions to the last user message
            tools_text = "\n".join([
                f"- {tool.get('function', {}).get('name', 'unknown')}: {tool.get('function', {}).get('description', '')}"
                for tool in tools
            ])
            # Append to messages (simplified; a full implementation would format these properly)
            if messages and messages[-1].get('role') == 'user':
                messages[-1]['content'] += f"\n\nAvailable tools:\n{tools_text}"

        # Run the synchronous client in an executor to avoid blocking the event loop
        loop = asyncio.get_running_loop()
        try:
            response_text = await loop.run_in_executor(
                None,
                lambda: client.chat(
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature,
                )
            )
        except Exception as e:
            import traceback
            error_msg = f"AirLLM generation failed: {e}\n{traceback.format_exc()}"
            print(error_msg, file=sys.stderr)
            raise RuntimeError(f"AirLLM provider error: {e}") from e

        # Parse tool calls from the response if present.
        # This is a simplified parser; adjust it to match your model's output format.
        tool_calls = []
        content = response_text

        # Try to extract JSON tool calls from the response;
        # some models return tool calls as JSON embedded in the content.
        if "tool_calls" in response_text.lower() or "function" in response_text.lower():
            try:
                # Look for JSON blocks in the response
                import re
                json_pattern = r'\{[^{}]*"function"[^{}]*\}'
                matches = re.findall(json_pattern, response_text, re.DOTALL)
                for match in matches:
                    try:
                        tool_data = json.loads(match)
                        if "function" in tool_data:
                            func = tool_data["function"]
                            tool_calls.append(ToolCallRequest(
                                id=tool_data.get("id", f"call_{len(tool_calls)}"),
                                name=func.get("name", "unknown"),
                                arguments=func.get("arguments", {}),
                            ))
                            # Remove the tool call from the content
                            content = content.replace(match, "").strip()
                    except json.JSONDecodeError:
                        pass
            except Exception:
                pass  # If parsing fails, just return the content as-is

        return LLMResponse(
            content=content,
            tool_calls=tool_calls if tool_calls else [],
            finish_reason="stop",
            usage={},  # The AirLLM wrapper does not report usage stats
        )

    def get_default_model(self) -> str:
        """Get the default model."""
        return self.default_model

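Note that the regex above cannot match nested JSON objects such as `{"function": {"name": ...}}`, because `[^{}]*` excludes inner braces. A brace-balancing scan is one alternative; this is our own sketch, not part of the provider, and it still ignores braces inside JSON string values:

```python
import json

def extract_json_objects(text):
    """Return every balanced top-level {...} span in text that parses as JSON."""
    objects, depth, start = [], 0, None
    for i, ch in enumerate(text):
        if ch == '{':
            if depth == 0:
                start = i
            depth += 1
        elif ch == '}' and depth > 0:
            depth -= 1
            if depth == 0:
                try:
                    objects.append(json.loads(text[start:i + 1]))
                except json.JSONDecodeError:
                    pass  # balanced span but not valid JSON; skip it
    return objects

calls = extract_json_objects('prefix {"function": {"name": "search", "arguments": {"q": "x"}}} suffix')
print(calls)
```

Unlike the flat regex, this recovers the nested `arguments` object as a proper dict.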
511 nanobot/providers/airllm_wrapper.py (Normal file)
@ -0,0 +1,511 @@
#!/usr/bin/env python3
"""
AirLLM Ollama-Compatible Wrapper

This wrapper provides an Ollama-like interface for AirLLM,
making it easy to replace Ollama in existing projects.
"""

import importlib.util
import os
import sys
from typing import Dict, List, Optional, Union

import torch


# Inject a dummy BetterTransformer BEFORE importing airllm (the local code needs it)
class DummyBetterTransformer:
    @staticmethod
    def transform(model):
        return model

if "optimum.bettertransformer" not in sys.modules:
    spec = importlib.util.spec_from_loader("optimum.bettertransformer", None)
    dummy_module = importlib.util.module_from_spec(spec)
    dummy_module.BetterTransformer = DummyBetterTransformer
    sys.modules["optimum.bettertransformer"] = dummy_module
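Registering a synthetic module in `sys.modules` before the real import, as done above, satisfies any downstream `import` of that name. A self-contained sketch of the same pattern (module and attribute names here are illustrative only):

```python
import importlib.util
import sys

# Register a stand-in module before anything imports it
spec = importlib.util.spec_from_loader("fake_dependency", None)
stub = importlib.util.module_from_spec(spec)
stub.answer = 42  # attribute the downstream code expects
sys.modules["fake_dependency"] = stub

# Downstream code can now import it as if it were installed
import fake_dependency
print(fake_dependency.answer)
```

The import machinery checks `sys.modules` first, so the stub shadows any real package of the same name for the rest of the process.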
# Fix RoPE scaling compatibility: patch transformers to handle the "llama3" type
def _patch_rope_scaling():
    """Patch transformers' LlamaConfig to handle the unsupported 'llama3' RoPE scaling type."""
    try:
        from transformers.models.llama.configuration_llama import LlamaConfig as OriginalLlamaConfig

        # Store the original __init__ if not already patched
        if not hasattr(OriginalLlamaConfig, '_rope_scaling_patched'):
            original_init = OriginalLlamaConfig.__init__

            def patched_init(self, *args, **kwargs):
                # Call the original init
                original_init(self, *args, **kwargs)

                # Fix rope_scaling if it is "llama3" (unsupported in some transformers versions)
                if hasattr(self, 'rope_scaling') and self.rope_scaling is not None:
                    # rope_scaling may be a dict or an object
                    if isinstance(self.rope_scaling, dict):
                        if self.rope_scaling.get('type') == 'llama3':
                            print("Warning: Converting unsupported RoPE scaling 'llama3' to 'linear'")
                            self.rope_scaling['type'] = 'linear'
                            if 'factor' not in self.rope_scaling:
                                self.rope_scaling['factor'] = 1.0
                    elif hasattr(self.rope_scaling, 'type'):
                        if getattr(self.rope_scaling, 'type', None) == 'llama3':
                            print("Warning: Converting unsupported RoPE scaling 'llama3' to 'linear'")
                            # Convert to dict format
                            factor = getattr(self.rope_scaling, 'factor', 1.0)
                            self.rope_scaling = {'type': 'linear', 'factor': factor}

            OriginalLlamaConfig.__init__ = patched_init
            OriginalLlamaConfig._rope_scaling_patched = True
    except Exception as e:
        # If patching fails, the error handler downstream will report it
        print(f"Warning: Could not patch RoPE scaling: {e}", file=sys.stderr)
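The normalization inside the patch reduces to a small dict transform. A pure-function sketch of it (the function name is ours):

```python
def normalize_rope_scaling(rope_scaling):
    """Rewrite an unsupported 'llama3' RoPE scaling entry as 'linear'."""
    if isinstance(rope_scaling, dict) and rope_scaling.get('type') == 'llama3':
        fixed = dict(rope_scaling)          # copy rather than mutate the input
        fixed['type'] = 'linear'
        fixed.setdefault('factor', 1.0)     # supply a factor if none was given
        return fixed
    return rope_scaling                     # supported entries pass through unchanged

print(normalize_rope_scaling({'type': 'llama3'}))
print(normalize_rope_scaling({'type': 'linear', 'factor': 2.0}))
```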
def _patch_attention_position_embeddings():
    """Patch LlamaSdpaAttention to accept and ignore a position_embeddings argument (AirLLM compatibility)."""
    try:
        import functools

        from transformers.models.llama import modeling_llama

        # Check that LlamaSdpaAttention exists and hasn't been patched yet
        if hasattr(modeling_llama, 'LlamaSdpaAttention'):
            LlamaSdpaAttention = modeling_llama.LlamaSdpaAttention
            if not hasattr(LlamaSdpaAttention, '_position_embeddings_patched'):
                original_forward = LlamaSdpaAttention.forward

                @functools.wraps(original_forward)
                def patched_forward(self, *args, **kwargs):
                    # Drop position_embeddings if present (AirLLM compatibility)
                    kwargs.pop('position_embeddings', None)
                    # Call the original forward
                    return original_forward(self, *args, **kwargs)

                LlamaSdpaAttention.forward = patched_forward
                LlamaSdpaAttention._position_embeddings_patched = True
    except Exception as e:
        # If patching fails, the error handler downstream will report it
        print(f"Warning: Could not patch attention position_embeddings: {e}", file=sys.stderr)


# Apply the patches before importing airllm
_patch_rope_scaling()
_patch_attention_position_embeddings()

# Try to import airllm, preferring the local checkout if available
LOCAL_AIRLLM_PATH = "/home/ladmin/code/airllm/airllm/air_llm"
if os.path.exists(LOCAL_AIRLLM_PATH) and LOCAL_AIRLLM_PATH not in sys.path:
    sys.path.insert(0, LOCAL_AIRLLM_PATH)

try:
    from airllm import AutoModel
    AIRLLM_AVAILABLE = True
except ImportError as e:
    AIRLLM_AVAILABLE = False
    AutoModel = None
    print(f"Warning: Failed to import AirLLM: {e}", file=sys.stderr)

class AirLLMOllamaWrapper:
    """
    A wrapper that provides an Ollama-like API for AirLLM.

    Usage:
        # Instead of: ollama.generate(model="llama2", prompt="Hello")
        # Use: airllm_wrapper.generate(model="llama2", prompt="Hello")
    """

    def __init__(self, model_name: str, compression: Optional[str] = None, **kwargs):
        """
        Initialize the AirLLM model.

        Args:
            model_name: Hugging Face model name or path (e.g., "meta-llama/Llama-3.2-3B-Instruct")
            compression: Optional compression ('4bit' or '8bit') for roughly 3x faster inference
            **kwargs: Additional arguments for AutoModel.from_pretrained()
        """
        if not AIRLLM_AVAILABLE or AutoModel is None:
            raise ImportError(
                "AirLLM is not available. Please install it with: pip install airllm bitsandbytes\n"
                "If you see a BetterTransformer error, you may need to install: pip install optimum[bettertransformer]"
            )

        print(f"Loading AirLLM model: {model_name}")

        # Fix a RoPE scaling compatibility issue: transformers 4.39.3 does not support the "llama3" type.
        # Modify the config file if this is a local path with an unsupported rope_scaling entry.
        model_path = model_name
        if os.path.exists(model_name) or model_name.startswith('/') or model_name.startswith('~'):
            if model_name.startswith('~'):
                model_path = os.path.expanduser(model_name)
            else:
                model_path = os.path.abspath(model_name)

            config_json_path = os.path.join(model_path, "config.json")
            if os.path.exists(config_json_path):
                try:
                    import json
                    with open(config_json_path, 'r') as f:
                        config_data = json.load(f)

                    # Check and fix rope_scaling
                    if 'rope_scaling' in config_data and config_data['rope_scaling'] is not None:
                        rope_scaling = config_data['rope_scaling']
                        if isinstance(rope_scaling, dict) and rope_scaling.get('type') == 'llama3':
                            print("Warning: Fixing unsupported RoPE scaling type 'llama3' -> 'linear'")
                            # Back up the original config
                            backup_path = config_json_path + ".backup"
                            if not os.path.exists(backup_path):
                                import shutil
                                shutil.copy2(config_json_path, backup_path)

                            # Fix the rope_scaling type
                            config_data['rope_scaling']['type'] = 'linear'
                            if 'factor' not in config_data['rope_scaling']:
                                config_data['rope_scaling']['factor'] = 1.0

                            # Save the fixed config
                            with open(config_json_path, 'w') as f:
                                json.dump(config_data, f, indent=2)
                            print(f"Fixed config saved to {config_json_path}")
                except Exception as e:
                    print(f"Warning: Could not fix config file: {e}", file=sys.stderr)

        # Determine max_seq_len before loading the model;
        # AirLLM needs this at initialization time.
        max_seq_len = 2048  # Default for Llama models

        # Load the config first to check the model type and pick an appropriate max length
        try:
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained(model_name, **{k: v for k, v in kwargs.items() if k in ['token', 'trust_remote_code']})
            model_type = getattr(config, 'model_type', '').lower()
            is_llama = 'llama' in model_type or 'llama' in model_name.lower()

            # Also fix rope_scaling on the loaded config object if needed
            if is_llama and hasattr(config, 'rope_scaling') and config.rope_scaling is not None:
                if isinstance(config.rope_scaling, dict) and config.rope_scaling.get('type') == 'llama3':
                    print("Warning: Converting RoPE scaling 'llama3' to 'linear' in config object")
                    config.rope_scaling['type'] = 'linear'
                    if 'factor' not in config.rope_scaling:
                        config.rope_scaling['factor'] = 1.0
                elif hasattr(config.rope_scaling, 'type') and getattr(config.rope_scaling, 'type', None) == 'llama3':
                    # Convert the object to dict format
                    factor = getattr(config.rope_scaling, 'factor', 1.0)
                    config.rope_scaling = {'type': 'linear', 'factor': factor}

            if is_llama:
                config_max = getattr(config, 'max_position_embeddings', None)
                if config_max and config_max > 0:
                    max_seq_len = min(config_max, 2048)
                else:
                    max_seq_len = 2048
            else:
                config_max = getattr(config, 'max_position_embeddings', None)
                if config_max and config_max > 0 and config_max <= 2048:
                    max_seq_len = config_max
                else:
                    max_seq_len = 512
        except Exception:
            # Fall back to the defaults if config loading fails
            pass
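The chunk-size selection above reduces to a small pure function of the model family and its position-embedding limit. A sketch (the function name is ours):

```python
def pick_max_seq_len(is_llama, config_max):
    """Choose an AirLLM chunk size from the model family and its position-embedding limit."""
    if is_llama:
        # Llama models handle larger chunks; cap at 2048 for safety
        return min(config_max, 2048) if config_max and config_max > 0 else 2048
    # Other families: only trust small limits, otherwise stay very conservative
    if config_max and 0 < config_max <= 2048:
        return config_max
    return 512

print(pick_max_seq_len(True, 131072))   # huge Llama limit is capped -> 2048
print(pick_max_seq_len(False, 1024))    # small non-Llama limit is honored -> 1024
print(pick_max_seq_len(False, None))    # unknown limit -> conservative 512
```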

        # AutoModel.from_pretrained() accepts:
        # - Hugging Face model IDs (e.g., "meta-llama/Llama-3.1-8B-Instruct")
        # - Local paths (e.g., "/path/to/local/model")
        try:
            self.model = AutoModel.from_pretrained(
                model_name,
                compression=compression,
                max_seq_len=max_seq_len,  # Pass max_seq_len through to AirLLM
                **kwargs
            )
        except ValueError as e:
            # Handle RoPE-scaling-specific errors
            if "Unknown RoPE scaling type" in str(e) or "rope_scaling" in str(e).lower():
                import traceback
                error_msg = (
                    f"RoPE scaling compatibility error: {e}\n"
                    "The model config uses a RoPE scaling type not supported by your transformers version.\n"
                    "If this is a local model, the config file should have been fixed automatically.\n"
                    "If the error persists, try:\n"
                    "1. For local models: check that config.json has rope_scaling.type='linear' instead of 'llama3'\n"
                    "2. Upgrade transformers: pip install --upgrade transformers\n"
                    "3. Or downgrade to a compatible version: pip install 'transformers==4.37.0'\n"
                    f"\nFull traceback:\n{traceback.format_exc()}"
                )
                raise RuntimeError(error_msg) from e
            raise
        except Exception as e:
            import traceback
            error_msg = (
                f"Failed to load AirLLM model '{model_name}': {e}\n"
                f"Error type: {type(e).__name__}\n"
                "This is often a transformers version compatibility issue.\n"
                "Try one of these solutions:\n"
                "1. Install an older transformers version: pip install 'transformers==4.37.0'\n"
                "2. Or try: pip install 'transformers==4.38.2'\n"
                "3. Check AirLLM compatibility with your transformers version\n"
                f"\nFull traceback:\n{traceback.format_exc()}"
            )
            raise RuntimeError(error_msg) from e
        self.model_name = model_name

        # Store max_length for tokenization
        self.max_length = max_seq_len

        # Re-check against the loaded model's config to refine the limit
        is_llama = False
        if hasattr(self.model, 'config'):
            model_type = getattr(self.model.config, 'model_type', '').lower()
            is_llama = 'llama' in model_type or 'llama' in self.model_name.lower()

        if is_llama:
            # Llama models typically support 2048-4096 tokens, and AirLLM works
            # well with them, so we can use larger chunks.
            if hasattr(self.model, 'config'):
                config_max = getattr(self.model.config, 'max_position_embeddings', None)
                if config_max and config_max > 0:
                    # Use the config value, capped at 2048 for AirLLM safety
                    self.max_length = min(config_max, 2048)
                else:
                    self.max_length = 2048  # Safe default for Llama
        else:
            # For other models (e.g., DeepSeek), use a conservative default
            if hasattr(self.model, 'config'):
                config_max = getattr(self.model.config, 'max_position_embeddings', None)
                if config_max and config_max > 0 and config_max <= 2048:
                    self.max_length = config_max
                else:
                    self.max_length = 512  # Very conservative

        print(f"Using sequence length limit: {self.max_length} (AirLLM chunk size)")

        print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,  # Ignored, kept for API compatibility
        max_tokens: int = 50,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stream: bool = False,
        **kwargs
    ) -> Union[str, Dict]:
        """
        Generate text from a prompt (Ollama-compatible interface).

        Args:
            prompt: Input text prompt
            model: Ignored (kept for compatibility)
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0.0 to 1.0)
            top_p: Nucleus sampling parameter
            stream: If True, return a streaming response (not yet implemented)
            **kwargs: Additional generation parameters

        Returns:
            Generated text string or a dict with the response
        """
        # Tokenize the input with an attention mask.
        # AirLLM processes sequences in chunks, but each chunk must fit within the model's
        # position-embedding limits, so we truncate to the model's max_length.
        input_tokens = self.model.tokenizer(
            prompt,
            return_tensors="pt",
            return_attention_mask=True,
            truncation=True,
            max_length=self.max_length,  # Respect the model's position-embedding limit
            padding=False
        )

        # Move to GPU if available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        input_ids = input_tokens['input_ids'].to(device)
        attention_mask = input_tokens.get('attention_mask', None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        # Ensure we don't exceed max_length (manual truncation as a safety check)
        seq_length = input_ids.shape[1]
        if seq_length > self.max_length:
            print(f"Warning: Sequence length ({seq_length}) exceeds limit ({self.max_length}), truncating...")
            input_ids = input_ids[:, :self.max_length]
            if attention_mask is not None:
                attention_mask = attention_mask[:, :self.max_length]
            seq_length = self.max_length

        if seq_length >= self.max_length:
            print(f"Note: Using sequence of {seq_length} tokens (at limit: {self.max_length})")

        # Prepare generation parameters; cap the number of new tokens
        max_gen_tokens = min(max_tokens, 512)

        gen_kwargs = {
            'max_new_tokens': max_gen_tokens,
            'use_cache': False,  # Disable cache to avoid DynamicCache compatibility issues
            'return_dict_in_generate': True,
            'temperature': temperature,
            'top_p': top_p,
            **kwargs
        }

        # Add the attention mask if available
        if attention_mask is not None:
            gen_kwargs['attention_mask'] = attention_mask

        # Generate
        try:
            with torch.inference_mode():
                generation_output = self.model.generate(input_ids, **gen_kwargs)
        except (TypeError, RuntimeError) as e:
            if "position_embeddings" in str(e) or "cannot unpack" in str(e):
                error_msg = (
                    f"AirLLM compatibility error with transformers: {e}\n"
                    "This is a known issue with AirLLM and transformers version compatibility.\n"
                    "Try one of these solutions:\n"
                    "1. Install transformers 4.37.0: pip install 'transformers==4.37.0'\n"
                    "2. Or try transformers 4.38.2: pip install 'transformers==4.38.2'\n"
                    "3. Or use Ollama instead: nanobot agent -m 'Hello' (with the Ollama provider)"
                )
                raise RuntimeError(error_msg) from e
            raise

        # Decode the output, keeping only the newly generated tokens
        if hasattr(generation_output, 'sequences'):
            # Extract only the new tokens (after the input length)
            input_length = input_ids.shape[1]
            generated_ids = generation_output.sequences[0, input_length:]
            output = self.model.tokenizer.decode(generated_ids, skip_special_tokens=True)
        else:
            # Fallback for older output formats (a plain tensor of token ids)
            output = self.model.tokenizer.decode(generation_output[0], skip_special_tokens=True)
            # Remove the input prompt from the output if present
            if output.startswith(prompt):
                output = output[len(prompt):].strip()

        if stream:
            # For streaming, return a dict (simplified; a real implementation would yield chunks)
            return {"response": output}
        else:
            return output

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: Optional[str] = None,
        max_tokens: int = 50,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Chat interface (Ollama-compatible).

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model: Ignored (kept for compatibility)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional parameters

        Returns:
            Generated response string
        """
        # Use the model's chat template if available (Llama models ship one)
        if hasattr(self.model.tokenizer, 'apply_chat_template') and self.model.tokenizer.chat_template:
            try:
                # Use the model's native chat template
                prompt = self.model.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            except Exception:
                # Fall back to simple formatting if the chat template fails
                prompt = self._format_messages(messages)
        else:
            # Fall back to simple formatting
            prompt = self._format_messages(messages)

        return self.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a single prompt (fallback method)."""
        formatted = []
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                formatted.append(f"System: {content}")
            elif role == 'user':
                formatted.append(f"User: {content}")
            elif role == 'assistant':
                formatted.append(f"Assistant: {content}")
        return "\n".join(formatted) + "\nAssistant:"
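The fallback formatter is easy to exercise on its own. The same logic as a free function, for illustration:

```python
def format_messages(messages):
    """Render chat messages as a plain 'Role: content' transcript ending with an Assistant cue."""
    labels = {'system': 'System', 'user': 'User', 'assistant': 'Assistant'}
    formatted = [
        f"{labels[m.get('role', 'user')]}: {m.get('content', '')}"
        for m in messages
        if m.get('role', 'user') in labels  # unknown roles are dropped, as in the method
    ]
    return "\n".join(formatted) + "\nAssistant:"

print(format_messages([
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi!"},
]))
```

The trailing `Assistant:` cue prompts base models without a chat template to continue in the assistant's voice.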

    def embeddings(self, prompt: str) -> List[float]:
        """
        Get embeddings for a prompt (simplified: returns token ids, not real embeddings).

        Note: this is a placeholder. Real embeddings would require a model
        forward pass over the tokenized input.
        """
        tokens = self.model.tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=False
        )
        # Placeholder: returns token ids rather than embedding vectors
        return tokens['input_ids'].tolist()[0]


# Convenience function for easy migration
def create_ollama_client(model_name: str, compression: Optional[str] = None, **kwargs):
    """
    Create an Ollama-compatible client using AirLLM.

    Usage:
        client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")
        response = client.generate("Hello, how are you?")
    """
    return AirLLMOllamaWrapper(model_name, compression=compression, **kwargs)


# Example usage
if __name__ == "__main__":
    # Example 1: basic generation
    print("Example 1: Basic Generation")
    print("=" * 60)

    # Initialize (this will take time on the first run)
    # client = create_ollama_client("meta-llama/Llama-3.2-3B-Instruct")

    # Generate
    # response = client.generate("What is the capital of France?")
    # print(f"Response: {response}")

    print("\nExample 2: Chat Interface")
    print("=" * 60)

    # Chat example
    # messages = [
    #     {"role": "user", "content": "Hello! How are you?"}
    # ]
    # response = client.chat(messages)
    # print(f"Response: {response}")

    print("\nUncomment the code above to test!")
@ -127,6 +127,7 @@ class LiteLLMProvider(LLMProvider):
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "stream": False,  # Explicitly disable streaming to avoid hangs with some providers
        }

        # Apply model-specific overrides (e.g. kimi-k2.5 temperature)
@ -148,6 +149,11 @@ class LiteLLMProvider(LLMProvider):
            kwargs["tools"] = tools
            kwargs["tool_choice"] = "auto"

        # Add a timeout to prevent hangs (especially with local servers).
        # Ollama can be slow with complex prompts, so use a longer timeout;
        # increased to 400s for larger models like mistral-nemo.
        kwargs["timeout"] = 400.0

        try:
            response = await acompletion(**kwargs)
            return self._parse_response(response)
@ -6,7 +6,7 @@ Adding a new provider:
2. Add a field to ProvidersConfig in config/schema.py.
Done. Env vars, prefixing, config matching, status display all derive from here.

Order matters — it controls match priority and fallback. Gateways first.
Order matters — it controls match priority and fallback.
Every entry writes out all fields so you can copy-paste as a template.
"""
@ -62,86 +62,10 @@ class ProviderSpec:

PROVIDERS: tuple[ProviderSpec, ...] = (

    # === Gateways (detected by api_key / api_base, not model name) =========
    # Gateways can route any model, so they win in fallback.

    # OpenRouter: global gateway, keys start with "sk-or-"
    ProviderSpec(
        name="openrouter",
        keywords=("openrouter",),
        env_key="OPENROUTER_API_KEY",
        display_name="OpenRouter",
        litellm_prefix="openrouter",  # claude-3 → openrouter/claude-3
        skip_prefixes=(),
        env_extras=(),
        is_gateway=True,
        is_local=False,
        detect_by_key_prefix="sk-or-",
        detect_by_base_keyword="openrouter",
        default_api_base="https://openrouter.ai/api/v1",
        strip_model_prefix=False,
        model_overrides=(),
    ),

    # AiHubMix: global gateway, OpenAI-compatible interface.
    # strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
    # so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
    ProviderSpec(
        name="aihubmix",
        keywords=("aihubmix",),
        env_key="OPENAI_API_KEY",  # OpenAI-compatible
        display_name="AiHubMix",
        litellm_prefix="openai",  # → openai/{model}
        skip_prefixes=(),
        env_extras=(),
        is_gateway=True,
        is_local=False,
        detect_by_key_prefix="",
        detect_by_base_keyword="aihubmix",
        default_api_base="https://aihubmix.com/v1",
        strip_model_prefix=True,  # anthropic/claude-3 → claude-3 → openai/claude-3
        model_overrides=(),
    ),

    # === Standard providers (matched by model-name keywords) ===============

    # Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
    ProviderSpec(
        name="anthropic",
        keywords=("anthropic", "claude"),
        env_key="ANTHROPIC_API_KEY",
        display_name="Anthropic",
        litellm_prefix="",
        skip_prefixes=(),
        env_extras=(),
        is_gateway=False,
        is_local=False,
        detect_by_key_prefix="",
        detect_by_base_keyword="",
        default_api_base="",
        strip_model_prefix=False,
        model_overrides=(),
    ),

    # OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
    ProviderSpec(
        name="openai",
        keywords=("openai", "gpt"),
        env_key="OPENAI_API_KEY",
        display_name="OpenAI",
        litellm_prefix="",
        skip_prefixes=(),
        env_extras=(),
        is_gateway=False,
        is_local=False,
        detect_by_key_prefix="",
        detect_by_base_keyword="",
        default_api_base="",
        strip_model_prefix=False,
        model_overrides=(),
    ),

    # DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
    # Can be used with local models or the API.
    ProviderSpec(
        name="deepseek",
        keywords=("deepseek",),
@ -159,107 +83,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
        model_overrides=(),
    ),

    # Gemini: needs "gemini/" prefix for LiteLLM.
    ProviderSpec(
        name="gemini",
        keywords=("gemini",),
        env_key="GEMINI_API_KEY",
        display_name="Gemini",
        litellm_prefix="gemini",  # gemini-pro → gemini/gemini-pro
        skip_prefixes=("gemini/",),  # avoid double-prefix
        env_extras=(),
        is_gateway=False,
        is_local=False,
        detect_by_key_prefix="",
        detect_by_base_keyword="",
        default_api_base="",
        strip_model_prefix=False,
        model_overrides=(),
    ),

    # Zhipu: LiteLLM uses the "zai/" prefix.
    # Also mirrors the key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
    # skip_prefixes: don't add "zai/" when already routed via a gateway.
    ProviderSpec(
        name="zhipu",
        keywords=("zhipu", "glm", "zai"),
        env_key="ZAI_API_KEY",
        display_name="Zhipu AI",
        litellm_prefix="zai",  # glm-4 → zai/glm-4
        skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
        env_extras=(
            ("ZHIPUAI_API_KEY", "{api_key}"),
        ),
        is_gateway=False,
        is_local=False,
        detect_by_key_prefix="",
        detect_by_base_keyword="",
        default_api_base="",
        strip_model_prefix=False,
        model_overrides=(),
    ),

    # DashScope: Qwen models, needs "dashscope/" prefix.
    ProviderSpec(
        name="dashscope",
        keywords=("qwen", "dashscope"),
        env_key="DASHSCOPE_API_KEY",
        display_name="DashScope",
        litellm_prefix="dashscope",  # qwen-max → dashscope/qwen-max
        skip_prefixes=("dashscope/", "openrouter/"),
        env_extras=(),
        is_gateway=False,
        is_local=False,
        detect_by_key_prefix="",
        detect_by_base_keyword="",
        default_api_base="",
        strip_model_prefix=False,
        model_overrides=(),
    ),

    # Moonshot: Kimi models, needs "moonshot/" prefix.
    # LiteLLM requires the MOONSHOT_API_BASE env var to find the endpoint.
    # The Kimi K2.5 API enforces temperature >= 1.0.
    ProviderSpec(
        name="moonshot",
        keywords=("moonshot", "kimi"),
        env_key="MOONSHOT_API_KEY",
        display_name="Moonshot",
        litellm_prefix="moonshot",  # kimi-k2.5 → moonshot/kimi-k2.5
        skip_prefixes=("moonshot/", "openrouter/"),
        env_extras=(
            ("MOONSHOT_API_BASE", "{api_base}"),
        ),
        is_gateway=False,
        is_local=False,
        detect_by_key_prefix="",
        detect_by_base_keyword="",
        default_api_base="https://api.moonshot.ai/v1",  # intl; use api.moonshot.cn for China
        strip_model_prefix=False,
        model_overrides=(
            ("kimi-k2.5", {"temperature": 1.0}),
        ),
    ),

# MiniMax: needs "minimax/" prefix for LiteLLM routing.
|
||||
# Uses OpenAI-compatible API at api.minimax.io/v1.
|
||||
ProviderSpec(
|
||||
name="minimax",
|
||||
keywords=("minimax",),
|
||||
env_key="MINIMAX_API_KEY",
|
||||
display_name="MiniMax",
|
||||
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
|
||||
skip_prefixes=("minimax/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://api.minimax.io/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# === Local deployment (matched by config key, NOT by api_base) =========
|
||||
|
||||
# vLLM / any OpenAI-compatible local server.
|
||||
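The `litellm_prefix` / `skip_prefixes` pattern shared by every spec above boils down to one small routing rule. The helper below is an illustrative sketch of that rule, not the project's actual function:

```python
def qualify_model(model: str, litellm_prefix: str, skip_prefixes: tuple[str, ...]) -> str:
    """Prepend the provider's LiteLLM prefix unless the model is already routed."""
    if any(model.startswith(p) for p in skip_prefixes):
        return model  # already prefixed, or routed via a gateway
    if litellm_prefix:
        return f"{litellm_prefix}/{model}"
    return model

# glm-4 picks up the "zai/" prefix, but an already-routed model is left alone.
print(qualify_model("glm-4", "zai", ("zhipu/", "zai/", "openrouter/")))
print(qualify_model("openrouter/glm-4", "zai", ("zhipu/", "zai/", "openrouter/")))
```

An empty `litellm_prefix` (as the local specs below use) makes this a no-op, which is why local OpenAI-compatible servers pass model names through unchanged.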
```diff
@@ -281,23 +104,44 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         model_overrides=(),
     ),
 
-    # === Auxiliary (not a primary LLM provider) ============================
-
-    # Groq: mainly used for Whisper voice transcription, also usable for LLM.
-    # Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
+    # Ollama: local OpenAI-compatible server.
+    # Use OpenAI-compatible endpoint, not native Ollama API.
+    # Detected when config key is "ollama" or api_base contains "11434" or "ollama".
     ProviderSpec(
-        name="groq",
-        keywords=("groq",),
-        env_key="GROQ_API_KEY",
-        display_name="Groq",
-        litellm_prefix="groq",  # llama3-8b-8192 → groq/llama3-8b-8192
-        skip_prefixes=("groq/",),  # avoid double-prefix
+        name="ollama",
+        keywords=("ollama", "llama"),  # Match both "ollama" and "llama" model names
+        env_key="OPENAI_API_KEY",  # Use OpenAI-compatible API
+        display_name="Ollama",
+        litellm_prefix="",  # No prefix - use as OpenAI-compatible
+        skip_prefixes=(),
+        env_extras=(
+            ("OPENAI_API_BASE", "{api_base}"),  # Set OpenAI API base to Ollama endpoint
+        ),
         is_gateway=False,
+        is_local=True,
         detect_by_key_prefix="",
+        detect_by_base_keyword="11434",  # Detect by default Ollama port
+        default_api_base="http://localhost:11434/v1",
         strip_model_prefix=False,
         model_overrides=(),
     ),
 
     # AirLLM: direct local model inference (no HTTP server).
     # Loads models directly into memory for GPU-optimized inference.
     # Detected when config key is "airllm".
     ProviderSpec(
         name="airllm",
         keywords=("airllm",),
         env_key="",  # No API key needed (local)
         display_name="AirLLM",
         litellm_prefix="",  # Not used with LiteLLM
         skip_prefixes=(),
         env_extras=(),
         is_gateway=False,
-        is_local=False,
+        is_local=True,
         detect_by_key_prefix="",
         detect_by_base_keyword="",
-        default_api_base="",
+        default_api_base="",  # Not used (direct Python calls)
         strip_model_prefix=False,
         model_overrides=(),
     ),
```
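The `model_overrides` tuples (e.g. Moonshot pinning `kimi-k2.5` to temperature 1.0) suggest a simple merge of forced parameters at request time. The helper below is an illustrative sketch under that assumption, using substring matching so that `"kimi-k2.5"` also matches the prefixed `"moonshot/kimi-k2.5"`; it is not the project's actual implementation:

```python
def apply_model_overrides(
    model: str,
    params: dict,
    overrides: tuple[tuple[str, dict], ...],
) -> dict:
    """Merge per-model forced parameters into the request params."""
    merged = dict(params)
    for model_key, forced in overrides:
        if model_key in model:  # e.g. "kimi-k2.5" inside "moonshot/kimi-k2.5"
            merged.update(forced)
    return merged

overrides = (("kimi-k2.5", {"temperature": 1.0}),)
print(apply_model_overrides("moonshot/kimi-k2.5", {"temperature": 0.2}, overrides))
```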
```diff
@@ -325,12 +169,11 @@ def find_gateway(
     api_key: str | None = None,
     api_base: str | None = None,
 ) -> ProviderSpec | None:
-    """Detect gateway/local provider.
+    """Detect local provider.
 
     Priority:
-    1. provider_name — if it maps to a gateway/local spec, use it directly.
-    2. api_key prefix — e.g. "sk-or-" → OpenRouter.
-    3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix.
+    1. provider_name — if it maps to a local spec, use it directly.
+    2. api_base keyword — e.g. "11434" in URL → Ollama.
 
     A standard provider with a custom api_base (e.g. DeepSeek behind a proxy)
     will NOT be mistaken for vLLM — the old fallback is gone.
@@ -341,10 +184,8 @@ def find_gateway(
     if spec and (spec.is_gateway or spec.is_local):
         return spec
 
-    # 2. Auto-detect by api_key prefix / api_base keyword
+    # 2. Auto-detect by api_base keyword
     for spec in PROVIDERS:
-        if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
-            return spec
         if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
             return spec
```
setup.sh (new file, 397 lines)
```bash
#!/bin/bash
# Nanobot Setup Script
# Automates installation and configuration of nanobot with Ollama/AirLLM

set -e  # Exit on error

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

# Configuration
VENV_DIR="venv"
CONFIG_DIR="$HOME/.nanobot"
CONFIG_FILE="$CONFIG_DIR/config.json"
MODEL_DIR="$HOME/.local/models/llama3.2-3b-instruct"
MODEL_NAME="meta-llama/Llama-3.2-3B-Instruct"

# Functions
print_header() {
    echo -e "\n${BLUE}========================================${NC}"
    echo -e "${BLUE}$1${NC}"
    echo -e "${BLUE}========================================${NC}\n"
}

print_success() {
    echo -e "${GREEN}✓ $1${NC}"
}

print_warning() {
    echo -e "${YELLOW}⚠ $1${NC}"
}

print_error() {
    echo -e "${RED}✗ $1${NC}"
}

print_info() {
    echo -e "${BLUE}ℹ $1${NC}"
}

# Check if command exists
command_exists() {
    command -v "$1" >/dev/null 2>&1
}

# Check prerequisites
check_prerequisites() {
    print_header "Checking Prerequisites"

    local missing=0

    if ! command_exists python3; then
        print_error "Python 3 is not installed"
        missing=1
    else
        PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
        print_success "Python $PYTHON_VERSION found"

        # Check Python version (need 3.10+)
        PYTHON_MAJOR=$(echo $PYTHON_VERSION | cut -d. -f1)
        PYTHON_MINOR=$(echo $PYTHON_VERSION | cut -d. -f2)
        if [ "$PYTHON_MAJOR" -lt 3 ] || ([ "$PYTHON_MAJOR" -eq 3 ] && [ "$PYTHON_MINOR" -lt 10 ]); then
            print_error "Python 3.10+ required, found $PYTHON_VERSION"
            missing=1
        fi
    fi

    if ! command_exists git; then
        print_warning "Git is not installed (optional, but recommended)"
    else
        print_success "Git found"
    fi

    if ! command_exists pip3 && ! python3 -m pip --version >/dev/null 2>&1; then
        print_error "pip is not installed"
        missing=1
    else
        print_success "pip found"
    fi

    if [ $missing -eq 1 ]; then
        print_error "Missing required prerequisites. Please install them first."
        exit 1
    fi

    print_success "All prerequisites met"
}

# Create virtual environment
setup_venv() {
    print_header "Setting Up Virtual Environment"

    if [ -d "$VENV_DIR" ]; then
        print_warning "Virtual environment already exists at $VENV_DIR"
        read -p "Recreate it? (y/n): " -n 1 -r
        echo
        if [[ $REPLY =~ ^[Yy]$ ]]; then
            rm -rf "$VENV_DIR"
            print_info "Removed existing virtual environment"
        else
            print_info "Using existing virtual environment"
            return
        fi
    fi

    print_info "Creating virtual environment..."
    python3 -m venv "$VENV_DIR"
    print_success "Virtual environment created"

    print_info "Activating virtual environment..."
    source "$VENV_DIR/bin/activate"
    print_success "Virtual environment activated"

    print_info "Upgrading pip..."
    pip install --upgrade pip --quiet
    print_success "pip upgraded"
}

# Install dependencies
install_dependencies() {
    print_header "Installing Dependencies"

    if [ -z "$VIRTUAL_ENV" ]; then
        source "$VENV_DIR/bin/activate"
    fi

    print_info "Installing nanobot and dependencies..."
    pip install -e . --quiet
    print_success "Nanobot installed"

    # Check if AirLLM should be installed
    read -p "Do you want to use AirLLM? (y/n): " -n 1 -r
    echo
    if [[ $REPLY =~ ^[Yy]$ ]]; then
        print_info "Installing AirLLM..."
        pip install airllm bitsandbytes --quiet || {
            print_warning "AirLLM installation had issues, but continuing..."
            print_info "You can install it later with: pip install airllm bitsandbytes"
        }
        print_success "AirLLM installed (or attempted)"
        USE_AIRLLM=true
    else
        USE_AIRLLM=false
    fi
}

# Check for Ollama
check_ollama() {
    if command_exists ollama; then
        print_success "Ollama is installed"
        if ollama list >/dev/null 2>&1; then
            print_success "Ollama is running"
            return 0
        else
            print_warning "Ollama is installed but not running"
            return 1
        fi
    else
        print_warning "Ollama is not installed"
        return 1
    fi
}

# Setup Ollama configuration
setup_ollama() {
    print_header "Setting Up Ollama"

    if ! check_ollama; then
        print_info "Ollama is not installed or not running"
        read -p "Do you want to install Ollama? (y/n): " -n 1 -r
        echo
        if [[ $REPLY =~ ^[Yy]$ ]]; then
            print_info "Installing Ollama..."
            curl -fsSL https://ollama.ai/install.sh | sh || {
                print_error "Failed to install Ollama automatically"
                print_info "Please install manually from: https://ollama.ai"
                return 1
            }
            print_success "Ollama installed"
        else
            return 1
        fi
    fi

    # Check if llama3.2 is available
    if ollama list | grep -q "llama3.2"; then
        print_success "llama3.2 model found"
    else
        print_info "Downloading llama3.2 model (this may take a while)..."
        ollama pull llama3.2:latest || {
            print_error "Failed to pull llama3.2 model"
            return 1
        }
        print_success "llama3.2 model downloaded"
    fi

    # Create config
    mkdir -p "$CONFIG_DIR"
    cat > "$CONFIG_FILE" << EOF
{
  "providers": {
    "ollama": {
      "apiKey": "dummy",
      "apiBase": "http://localhost:11434/v1"
    }
  },
  "agents": {
    "defaults": {
      "model": "llama3.2:latest"
    }
  }
}
EOF
    chmod 600 "$CONFIG_FILE"
    print_success "Ollama configuration created at $CONFIG_FILE"
    return 0
}

# Setup AirLLM configuration
setup_airllm() {
    print_header "Setting Up AirLLM"

    # Check if model already exists
    if [ -d "$MODEL_DIR" ] && [ -f "$MODEL_DIR/config.json" ]; then
        print_success "Model already exists at $MODEL_DIR"
    else
        print_info "Model needs to be downloaded"
        print_info "You'll need a Hugging Face token to download gated models"
        echo
        print_info "Steps:"
        echo "  1. Get token: https://huggingface.co/settings/tokens"
        echo "  2. Accept license: https://huggingface.co/$MODEL_NAME"
        echo
        read -p "Do you have a Hugging Face token? (y/n): " -n 1 -r
        echo
        if [[ ! $REPLY =~ ^[Yy]$ ]]; then
            print_warning "Skipping model download. You can download it later."
            print_info "To download later, run:"
            echo "  huggingface-cli download $MODEL_NAME --local-dir $MODEL_DIR --token YOUR_TOKEN"
            return 1
        fi

        read -p "Enter your Hugging Face token: " -s HF_TOKEN
        echo

        if [ -z "$HF_TOKEN" ]; then
            print_error "Token is required"
            return 1
        fi

        # Install huggingface_hub if needed
        if [ -z "$VIRTUAL_ENV" ]; then
            source "$VENV_DIR/bin/activate"
        fi
        pip install huggingface_hub --quiet

        print_info "Downloading model (this may take a while, ~2GB)..."
        mkdir -p "$MODEL_DIR"
        huggingface-cli download "$MODEL_NAME" \
            --local-dir "$MODEL_DIR" \
            --token "$HF_TOKEN" \
            --local-dir-use-symlinks False || {
            print_error "Failed to download model"
            print_info "Make sure you've accepted the license at: https://huggingface.co/$MODEL_NAME"
            return 1
        }
        print_success "Model downloaded to $MODEL_DIR"
    fi

    # Create config
    mkdir -p "$CONFIG_DIR"
    cat > "$CONFIG_FILE" << EOF
{
  "providers": {
    "airllm": {
      "apiKey": "$MODEL_DIR",
      "apiBase": null,
      "extraHeaders": {}
    }
  },
  "agents": {
    "defaults": {
      "model": "$MODEL_DIR"
    }
  }
}
EOF
    chmod 600 "$CONFIG_FILE"
    print_success "AirLLM configuration created at $CONFIG_FILE"
    return 0
}

# Test installation
test_installation() {
    print_header "Testing Installation"

    if [ -z "$VIRTUAL_ENV" ]; then
        source "$VENV_DIR/bin/activate"
    fi

    print_info "Testing nanobot installation..."
    if nanobot --help >/dev/null 2>&1; then
        print_success "Nanobot is installed and working"
    else
        print_error "Nanobot test failed"
        return 1
    fi

    print_info "Testing with a simple query..."
    if nanobot agent -m "Hello, what is 2+5?" >/dev/null 2>&1; then
        print_success "Test query successful!"
    else
        print_warning "Test query had issues (this might be normal if model is still loading)"
        print_info "Try running manually: nanobot agent -m 'Hello'"
    fi
}

# Main setup flow
main() {
    print_header "Nanobot Setup Script"
    print_info "This script will set up nanobot with Ollama or AirLLM"
    echo

    # Check prerequisites
    check_prerequisites

    # Setup virtual environment
    setup_venv

    # Install dependencies
    install_dependencies

    # Choose provider
    echo
    print_header "Choose Provider"
    echo "1. Ollama (easiest, no tokens needed)"
    echo "2. AirLLM (direct local inference, no HTTP server)"
    echo "3. Both (try Ollama first, fall back to AirLLM)"
    echo
    read -p "Choose option (1-3): " -n 1 -r
    echo

    PROVIDER_SETUP=false

    case $REPLY in
        1)
            if setup_ollama; then
                PROVIDER_SETUP=true
            fi
            ;;
        2)
            if setup_airllm; then
                PROVIDER_SETUP=true
            fi
            ;;
        3)
            if setup_ollama || setup_airllm; then
                PROVIDER_SETUP=true
            fi
            ;;
        *)
            print_warning "Invalid choice, skipping provider setup"
            ;;
    esac

    if [ "$PROVIDER_SETUP" = false ]; then
        print_warning "Provider setup incomplete. You can configure manually later."
        print_info "Config file location: $CONFIG_FILE"
    fi

    # Test installation
    test_installation

    # Final instructions
    echo
    print_header "Setup Complete!"
    echo
    print_success "Nanobot is ready to use!"
    echo
    print_info "To activate the virtual environment:"
    echo "  source $VENV_DIR/bin/activate"
    echo
    print_info "To use nanobot:"
    echo "  nanobot agent -m 'Your message here'"
    echo
    print_info "Configuration file: $CONFIG_FILE"
    echo
    print_info "For more information, see SETUP.md"
    echo
}

# Run main function
main
```
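Both setup paths write the same two-section config shape (`providers` plus `agents.defaults.model`). A quick sanity check of the generated file can be sketched as below; the expected keys are taken from the heredocs in the script, and this checker is illustrative rather than part of nanobot:

```python
import json
from pathlib import Path


def check_nanobot_config(path: Path) -> list[str]:
    """Return a list of problems found in a generated ~/.nanobot/config.json."""
    try:
        cfg = json.loads(path.read_text())
    except (OSError, json.JSONDecodeError) as e:
        return [f"unreadable config: {e}"]
    problems = []
    if not cfg.get("providers"):
        problems.append("no providers configured")
    if not cfg.get("agents", {}).get("defaults", {}).get("model"):
        problems.append("no default model set")
    return problems


print(check_nanobot_config(Path.home() / ".nanobot" / "config.json"))
```

An empty list means the file at least has the shape the setup script produces; it does not verify that the configured endpoint or model actually works.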
setup_llama_airllm.py (new file, 175 lines)
```python
#!/usr/bin/env python3
"""
Setup script to configure nanobot to use Llama models with AirLLM.
This script will:
1. Check/create the config file
2. Set up Llama model configuration
3. Guide you through getting a Hugging Face token if needed
"""

import json
import os
from pathlib import Path

CONFIG_PATH = Path.home() / ".nanobot" / "config.json"


def get_hf_token_instructions():
    """Print instructions for getting a Hugging Face token."""
    print("\n" + "=" * 70)
    print("GETTING A HUGGING FACE TOKEN")
    print("=" * 70)
    print("\nTo use Llama models (which are gated), you need a Hugging Face token:")
    print("\n1. Go to: https://huggingface.co/settings/tokens")
    print("2. Click 'New token'")
    print("3. Give it a name (e.g., 'nanobot')")
    print("4. Select 'Read' permission")
    print("5. Click 'Generate token'")
    print("6. Copy the token (starts with 'hf_...')")
    print("\nThen accept the Llama model license:")
    print("1. Go to: https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct")
    print("2. Click 'Agree and access repository'")
    print("3. Accept the license terms")
    print("\n" + "=" * 70 + "\n")


def load_existing_config():
    """Load existing config or return default."""
    if CONFIG_PATH.exists():
        try:
            with open(CONFIG_PATH) as f:
                return json.load(f)
        except Exception as e:
            print(f"Warning: Could not read existing config: {e}")
            return {}
    return {}


def create_llama_config():
    """Create or update config for Llama with AirLLM."""
    config = load_existing_config()

    # Ensure providers section exists
    if "providers" not in config:
        config["providers"] = {}

    # Ensure agents section exists
    if "agents" not in config:
        config["agents"] = {}
    if "defaults" not in config["agents"]:
        config["agents"]["defaults"] = {}

    # Choose Llama model
    print("\n" + "=" * 70)
    print("CHOOSE LLAMA MODEL")
    print("=" * 70)
    print("\nAvailable models:")
    print("  1. Llama-3.2-3B-Instruct (Recommended - fast, minimal memory)")
    print("  2. Llama-3.1-8B-Instruct (Good balance of performance and speed)")
    print("  3. Custom (enter model path)")

    choice = input("\nChoose model (1-3, default: 1): ").strip() or "1"

    model_map = {
        "1": "meta-llama/Llama-3.2-3B-Instruct",
        "2": "meta-llama/Llama-3.1-8B-Instruct",
    }

    if choice == "3":
        model_path = input("Enter model path (e.g., meta-llama/Llama-3.2-3B-Instruct): ").strip()
        if not model_path:
            model_path = "meta-llama/Llama-3.2-3B-Instruct"
            print(f"Using default: {model_path}")
    else:
        model_path = model_map.get(choice, "meta-llama/Llama-3.2-3B-Instruct")

    # Set up AirLLM provider with Llama model
    # Note: apiKey can be used as model path, or we can put model in defaults
    config["providers"]["airllm"] = {
        "apiKey": "",  # Will be set to model path
        "apiBase": None,
        "extraHeaders": {}
    }

    # Set default model
    config["agents"]["defaults"]["model"] = model_path

    # Ask for Hugging Face token
    print("\n" + "=" * 70)
    print("HUGGING FACE TOKEN SETUP")
    print("=" * 70)
    print("\nDo you have a Hugging Face token? (Required for Llama models)")
    print("If not, we'll show you how to get one.\n")

    has_token = input("Do you have a Hugging Face token? (y/n): ").strip().lower()

    if has_token == 'y':
        hf_token = input("\nEnter your Hugging Face token (starts with 'hf_'): ").strip()
        if hf_token and hf_token.startswith('hf_'):
            # Store token in extraHeaders
            config["providers"]["airllm"]["extraHeaders"]["hf_token"] = hf_token
            # Also set apiKey to model path (AirLLM uses apiKey as model path if it contains '/')
            config["providers"]["airllm"]["apiKey"] = config["agents"]["defaults"]["model"]
            print("\n✓ Token configured!")
        else:
            print("⚠ Warning: Token doesn't look valid (should start with 'hf_')")
            print("You can add it later by editing the config file.")
            # Still set model path in apiKey
            config["providers"]["airllm"]["apiKey"] = config["agents"]["defaults"]["model"]
    else:
        get_hf_token_instructions()
        print("\nYou can add your token later by:")
        print(f"1. Editing: {CONFIG_PATH}")
        print("2. Adding your token to: providers.airllm.extraHeaders.hf_token")
        print("\nOr run this script again after getting your token.")

    return config


def save_config(config):
    """Save config to file."""
    CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
    with open(CONFIG_PATH, 'w') as f:
        json.dump(config, f, indent=2)

    # Set secure permissions
    os.chmod(CONFIG_PATH, 0o600)
    print(f"\n✓ Configuration saved to: {CONFIG_PATH}")
    print("✓ File permissions set to 600 (read/write for owner only)")


def main():
    """Main setup function."""
    print("\n" + "=" * 70)
    print("NANOBOT LLAMA + AIRLLM SETUP")
    print("=" * 70)
    print("\nThis script will configure nanobot to use Llama models with AirLLM.\n")

    if CONFIG_PATH.exists():
        print(f"Found existing config at: {CONFIG_PATH}")
        backup = input("\nCreate backup? (y/n): ").strip().lower()
        if backup == 'y':
            backup_path = CONFIG_PATH.with_suffix('.json.backup')
            import shutil
            shutil.copy(CONFIG_PATH, backup_path)
            print(f"✓ Backup created: {backup_path}")
    else:
        print(f"Creating new config at: {CONFIG_PATH}")

    config = create_llama_config()
    save_config(config)

    print("\n" + "=" * 70)
    print("SETUP COMPLETE!")
    print("=" * 70)
    print("\nConfiguration:")
    print(f"  Model: {config['agents']['defaults']['model']}")
    print("  Provider: airllm")
    if config["providers"]["airllm"].get("extraHeaders", {}).get("hf_token"):
        print(f"  HF Token: {'*' * 20} (configured)")
    else:
        print("  HF Token: Not configured (add it to use gated models)")

    print("\nNext steps:")
    print("  1. If you need a Hugging Face token, follow the instructions above")
    print("  2. Test it: nanobot agent -m 'Hello, what is 2+5?'")
    print("\n" + "=" * 70 + "\n")


if __name__ == "__main__":
    main()
```