"""Unified LLM client. This module routes LLM requests to OpenAI-compatible servers (Ollama, vLLM, TGI, etc.). The base URL is determined by: - If USE_LOCAL_OLLAMA=true: uses http://localhost:11434 - Else if OPENAI_COMPAT_BASE_URL is set: uses that URL - Else: raises an error (base URL must be configured) """ from __future__ import annotations import os from typing import Any, Dict, List, Optional from .config import MAX_TOKENS, OPENAI_COMPAT_BASE_URL, LLM_TIMEOUT_SECONDS, DEBUG def _get_provider_name() -> str: """Returns the provider name (always 'openai_compat' now).""" return "openai_compat" def _get_max_concurrency() -> int: """ Maximum number of in-flight model requests when calling query_models_parallel. - If LLM_MAX_CONCURRENCY is unset/empty/invalid: unlimited (0) - If set to 1: strictly sequential - If set to N>1: at most N in flight """ raw = (os.getenv("LLM_MAX_CONCURRENCY") or "").strip() if not raw: return 0 try: v = int(raw) except ValueError: return 0 return max(0, v) def get_provider_info() -> Dict[str, Any]: """Get information about the configured provider.""" from .config import OPENAI_COMPAT_BASE_URL return { "provider": "openai_compat", "base_url": OPENAI_COMPAT_BASE_URL } async def list_models() -> Optional[List[str]]: """List available models from the OpenAI-compatible server.""" from .openai_compat import list_models as _list return await _list() async def query_model( model: str, messages: List[Dict[str, str]], timeout: Optional[float] = None, max_tokens_override: Optional[int] = None, ) -> Optional[Dict[str, Any]]: """Query a model via OpenAI-compatible API.""" from .openai_compat import query_model as _query max_tokens = max_tokens_override if max_tokens_override is not None else MAX_TOKENS resolved_timeout = timeout if timeout is not None else LLM_TIMEOUT_SECONDS return await _query( model, messages, max_tokens=max_tokens, timeout=resolved_timeout, ) async def query_models_parallel( models: List[str], messages: List[Dict[str, str]], timeout: Optional[float] = None, max_tokens_override: Optional[int] = None, ) -> Dict[str, Optional[Dict[str, Any]]]: import asyncio resolved_timeout = timeout if timeout is not None else LLM_TIMEOUT_SECONDS limit = _get_max_concurrency() # If limit is 1, run completely sequentially (one at a time, wait for each to finish) if limit == 1: results = {} for model in models: if DEBUG: print(f"[DEBUG] Running model '{model}' sequentially (concurrency=1)") results[model] = await query_model( model, messages, timeout=resolved_timeout, max_tokens_override=max_tokens_override, ) return results # If limit <= 0 or >= len(models), run all in parallel (no limit) if limit <= 0 or limit >= len(models): tasks = [ query_model( model, messages, timeout=resolved_timeout, max_tokens_override=max_tokens_override, ) for model in models ] responses = await asyncio.gather(*tasks) return {model: response for model, response in zip(models, responses)} # Otherwise, use semaphore to limit concurrency (2, 3, etc.) sem = asyncio.Semaphore(limit) async def _run_one(model: str) -> Optional[Dict[str, Any]]: async with sem: return await query_model( model, messages, timeout=resolved_timeout, max_tokens_override=max_tokens_override, ) tasks = [_run_one(model) for model in models] responses = await asyncio.gather(*tasks) return {model: response for model, response in zip(models, responses)}