"""OpenAI-compatible API client (for Ollama / vLLM / TGI / OpenAI-style servers). This lets LLM Council talk to any OpenAI-compatible server (local Ollama, remote Ollama, vLLM, TGI, etc.). """ from __future__ import annotations import asyncio import os from typing import Any, Dict, List, Optional import httpx from .config import ( OPENAI_COMPAT_BASE_URL, OPENAI_COMPAT_RETRIES, OPENAI_COMPAT_RETRY_BACKOFF_SECONDS, OPENAI_COMPAT_TIMEOUT_SECONDS, OPENAI_COMPAT_CONNECT_TIMEOUT_SECONDS, OPENAI_COMPAT_WRITE_TIMEOUT_SECONDS, OPENAI_COMPAT_POOL_TIMEOUT_SECONDS, DEBUG, ) def _resolve_chat_completions_url(base_url: str) -> str: """ Accepts either: - http://host:8000 -> http://host:8000/v1/chat/completions - http://host:8000/v1 -> http://host:8000/v1/chat/completions - http://host:8000/v1/ -> http://host:8000/v1/chat/completions """ base = base_url.rstrip("/") if base.endswith("/v1"): return f"{base}/chat/completions" if "/v1/" in f"{base}/": # Already has /v1 somewhere; assume caller gave full root including /v1 return f"{base}/chat/completions" return f"{base}/v1/chat/completions" def _resolve_models_url(base_url: str) -> str: base = base_url.rstrip("/") if base.endswith("/v1"): return f"{base}/models" if "/v1/" in f"{base}/": return f"{base}/models" return f"{base}/v1/models" def _resolve_ollama_tags_url(base_url: str) -> str: """Resolve Ollama's native /api/tags endpoint URL.""" base = base_url.rstrip("/") return f"{base}/api/tags" def _should_retry(status_code: int) -> bool: return status_code in {408, 409, 425, 429, 500, 502, 503, 504} async def query_model( model: str, messages: List[Dict[str, str]], *, base_url: Optional[str] = None, api_key: Optional[str] = None, max_tokens: int = 2048, timeout: Optional[float] = None, client: Optional[httpx.AsyncClient] = None, ) -> Optional[Dict[str, Any]]: """Query a model via an OpenAI-compatible chat completions endpoint.""" resolved_base_url = base_url or OPENAI_COMPAT_BASE_URL if not resolved_base_url: print("Error querying OpenAI-compatible provider: OPENAI_COMPAT_BASE_URL not set") return None resolved_api_key = api_key if api_key is not None else os.getenv("OPENAI_COMPAT_API_KEY") resolved_timeout = OPENAI_COMPAT_TIMEOUT_SECONDS if timeout is None else timeout retries = OPENAI_COMPAT_RETRIES backoff = OPENAI_COMPAT_RETRY_BACKOFF_SECONDS url = _resolve_chat_completions_url(resolved_base_url) headers = {"Content-Type": "application/json"} if resolved_api_key: headers["Authorization"] = f"Bearer {resolved_api_key}" payload: Dict[str, Any] = { "model": model, "messages": messages, "max_tokens": max_tokens, } if DEBUG: print(f"[DEBUG] Querying model '{model}' at {url} (timeout={resolved_timeout}s, max_tokens={max_tokens})") close_client = False try: if client is None: # Use explicit Timeout object to ensure read timeout is set correctly # For LLM requests, we need a long read timeout since generation can take time timeout_config = httpx.Timeout( connect=OPENAI_COMPAT_CONNECT_TIMEOUT_SECONDS, read=resolved_timeout, # Read timeout: use the configured timeout write=OPENAI_COMPAT_WRITE_TIMEOUT_SECONDS, pool=OPENAI_COMPAT_POOL_TIMEOUT_SECONDS ) client = httpx.AsyncClient(timeout=timeout_config) close_client = True attempt = 0 while True: if DEBUG: print(f"[DEBUG] Attempt {attempt + 1}/{retries + 1}: POST {url}") resp = await client.post(url, headers=headers, json=payload) if resp.status_code != 200: # Preserve server-provided error text for debugging. try: err_json = resp.json() err_msg = err_json.get("error", {}).get("message", resp.text) except Exception: err_msg = resp.text if attempt < retries and _should_retry(resp.status_code): await asyncio.sleep(backoff * (2**attempt)) attempt += 1 continue print(f"Error querying model {model} (HTTP {resp.status_code}): {err_msg}") return None data = resp.json() msg = data["choices"][0]["message"] if DEBUG: print(f"[DEBUG] Model '{model}' responded successfully") return { "content": msg.get("content"), "reasoning_details": msg.get("reasoning_details"), } except httpx.TimeoutException as e: print(f"[ERROR] Model '{model}' timeout after {resolved_timeout}s at {url}") print( f"[ERROR] This can mean the model is loading / slow, OR that the server/port is unreachable.\n" f"[ERROR] Check connectivity: curl {resolved_base_url}/api/tags" ) return None except httpx.ConnectError as e: print(f"[ERROR] Cannot connect to {url}: {e}") print(f"[ERROR] Is Ollama running? Check: curl {resolved_base_url}/api/tags") return None except Exception as e: print(f"[ERROR] Unexpected error querying model '{model}' at {url}: {type(e).__name__}: {e}") import traceback traceback.print_exc() return None finally: if close_client and client is not None: await client.aclose() async def list_models( *, base_url: Optional[str] = None, api_key: Optional[str] = None, timeout: Optional[float] = None, client: Optional[httpx.AsyncClient] = None, ) -> Optional[List[str]]: """Return model IDs from an OpenAI-compatible server (/v1/models).""" resolved_base_url = base_url or OPENAI_COMPAT_BASE_URL if not resolved_base_url: return None resolved_api_key = api_key if api_key is not None else os.getenv("OPENAI_COMPAT_API_KEY") resolved_timeout = OPENAI_COMPAT_TIMEOUT_SECONDS if timeout is None else timeout retries = OPENAI_COMPAT_RETRIES backoff = OPENAI_COMPAT_RETRY_BACKOFF_SECONDS # Try OpenAI-compatible endpoint first url = _resolve_models_url(resolved_base_url) headers = {"Content-Type": "application/json"} if resolved_api_key: headers["Authorization"] = f"Bearer {resolved_api_key}" close_client = False try: if client is None: # Use explicit Timeout object for list_models (faster operation) timeout_config = httpx.Timeout( connect=OPENAI_COMPAT_CONNECT_TIMEOUT_SECONDS, read=resolved_timeout, write=OPENAI_COMPAT_WRITE_TIMEOUT_SECONDS, pool=OPENAI_COMPAT_POOL_TIMEOUT_SECONDS ) client = httpx.AsyncClient(timeout=timeout_config) close_client = True attempt = 0 while True: resp = await client.get(url, headers=headers) if resp.status_code == 200: data = resp.json() # Try OpenAI-compatible format first items = data.get("data", []) if items: ids: List[str] = [] for it in items: mid = it.get("id") if mid: ids.append(mid) return ids # Fallback: check if it's already in Ollama format items = data.get("models", []) if items: ids: List[str] = [] for it in items: mid = it.get("name") or it.get("model") if mid: ids.append(mid) return ids return [] # If /v1/models fails, try Ollama's native /api/tags endpoint if resp.status_code == 404 and attempt == 0: ollama_url = _resolve_ollama_tags_url(resolved_base_url) if DEBUG: print(f"[DEBUG] /v1/models not found, trying Ollama native API: {ollama_url}") resp = await client.get(ollama_url, headers=headers) if resp.status_code == 200: data = resp.json() items = data.get("models", []) if items: ids: List[str] = [] for it in items: mid = it.get("name") or it.get("model") if mid: ids.append(mid) return ids if attempt < retries and _should_retry(resp.status_code): await asyncio.sleep(backoff * (2**attempt)) attempt += 1 continue return None except Exception as e: if DEBUG: msg = str(e) if str(e) else "(no message)" print(f"[DEBUG] Error listing models: {type(e).__name__}: {msg}") return None finally: if close_client and client is not None: await client.aclose()