diff --git a/nanobot/auth/__init__.py b/nanobot/auth/__init__.py index e74e1c2..c74d992 100644 --- a/nanobot/auth/__init__.py +++ b/nanobot/auth/__init__.py @@ -1,6 +1,6 @@ -"""鉴权相关模块。""" +"""Authentication modules.""" -from nanobot.auth.codex_oauth import ( +from nanobot.auth.codex import ( ensure_codex_token_available, get_codex_token, login_codex_oauth_interactive, diff --git a/nanobot/auth/codex/__init__.py b/nanobot/auth/codex/__init__.py new file mode 100644 index 0000000..707cd4d --- /dev/null +++ b/nanobot/auth/codex/__init__.py @@ -0,0 +1,15 @@ +"""Codex OAuth module.""" + +from nanobot.auth.codex.flow import ( + ensure_codex_token_available, + get_codex_token, + login_codex_oauth_interactive, +) +from nanobot.auth.codex.models import CodexToken + +__all__ = [ + "CodexToken", + "ensure_codex_token_available", + "get_codex_token", + "login_codex_oauth_interactive", +] diff --git a/nanobot/auth/codex/constants.py b/nanobot/auth/codex/constants.py new file mode 100644 index 0000000..bbe676a --- /dev/null +++ b/nanobot/auth/codex/constants.py @@ -0,0 +1,25 @@ +"""Codex OAuth constants.""" + +CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" +AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize" +TOKEN_URL = "https://auth.openai.com/oauth/token" +REDIRECT_URI = "http://localhost:1455/auth/callback" +SCOPE = "openid profile email offline_access" +JWT_CLAIM_PATH = "https://api.openai.com/auth" + +DEFAULT_ORIGINATOR = "nanobot" +TOKEN_FILENAME = "codex.json" +MANUAL_PROMPT_DELAY_SEC = 3 +SUCCESS_HTML = ( + "" + "" + "
" + "" + "" + "Authentication successful. Return to your terminal to continue.
" + "" + "" +) diff --git a/nanobot/auth/codex/flow.py b/nanobot/auth/codex/flow.py new file mode 100644 index 0000000..0966327 --- /dev/null +++ b/nanobot/auth/codex/flow.py @@ -0,0 +1,312 @@ +"""Codex OAuth login and token management.""" + +from __future__ import annotations + +import asyncio +import sys +import threading +import time +import urllib.parse +import webbrowser +from typing import Any, Callable + +import httpx + +from nanobot.auth.codex.constants import ( + AUTHORIZE_URL, + CLIENT_ID, + DEFAULT_ORIGINATOR, + MANUAL_PROMPT_DELAY_SEC, + REDIRECT_URI, + SCOPE, + TOKEN_URL, +) +from nanobot.auth.codex.models import CodexToken +from nanobot.auth.codex.pkce import ( + _create_state, + _decode_account_id, + _generate_pkce, + _parse_authorization_input, + _parse_token_payload, +) +from nanobot.auth.codex.server import _start_local_server +from nanobot.auth.codex.storage import ( + _FileLock, + _get_token_path, + _load_token_file, + _save_token_file, + _try_import_codex_cli_token, +) + + +def _exchange_code_for_token(code: str, verifier: str) -> CodexToken: + data = { + "grant_type": "authorization_code", + "client_id": CLIENT_ID, + "code": code, + "code_verifier": verifier, + "redirect_uri": REDIRECT_URI, + } + with httpx.Client(timeout=30.0) as client: + response = client.post(TOKEN_URL, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"}) + if response.status_code != 200: + raise RuntimeError(f"Token exchange failed: {response.status_code} {response.text}") + + payload = response.json() + access, refresh, expires_in = _parse_token_payload(payload, "Token response missing fields") + print("Received access token:", access) + account_id = _decode_account_id(access) + return CodexToken( + access=access, + refresh=refresh, + expires=int(time.time() * 1000 + expires_in * 1000), + account_id=account_id, + ) + + +async def _exchange_code_for_token_async(code: str, verifier: str) -> CodexToken: + data = { + "grant_type": "authorization_code", + "client_id": CLIENT_ID, + "code": code, + "code_verifier": verifier, + "redirect_uri": REDIRECT_URI, + } + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post( + TOKEN_URL, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + if response.status_code != 200: + raise RuntimeError(f"Token exchange failed: {response.status_code} {response.text}") + + payload = response.json() + access, refresh, expires_in = _parse_token_payload(payload, "Token response missing fields") + + account_id = _decode_account_id(access) + return CodexToken( + access=access, + refresh=refresh, + expires=int(time.time() * 1000 + expires_in * 1000), + account_id=account_id, + ) + + +def _refresh_token(refresh_token: str) -> CodexToken: + data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": CLIENT_ID, + } + with httpx.Client(timeout=30.0) as client: + response = client.post(TOKEN_URL, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"}) + if response.status_code != 200: + raise RuntimeError(f"Token refresh failed: {response.status_code} {response.text}") + + payload = response.json() + access, refresh, expires_in = _parse_token_payload(payload, "Token refresh response missing fields") + + account_id = _decode_account_id(access) + return CodexToken( + access=access, + refresh=refresh, + expires=int(time.time() * 1000 + expires_in * 1000), + account_id=account_id, + ) + + +def get_codex_token() -> CodexToken: + """Get an available token (refresh if needed).""" + token = _load_token_file() or _try_import_codex_cli_token() + if not token: + raise RuntimeError("Codex OAuth credentials not found. Please run the login command.") + + # Refresh 60 seconds early. + now_ms = int(time.time() * 1000) + if token.expires - now_ms > 60 * 1000: + return token + + lock_path = _get_token_path().with_suffix(".lock") + with _FileLock(lock_path): + # Re-read to avoid stale token if another process refreshed it. + token = _load_token_file() or token + now_ms = int(time.time() * 1000) + if token.expires - now_ms > 60 * 1000: + return token + try: + refreshed = _refresh_token(token.refresh) + _save_token_file(refreshed) + return refreshed + except Exception: + # If refresh fails, re-read the file to avoid false negatives. + latest = _load_token_file() + if latest and latest.expires - now_ms > 0: + return latest + raise + + +def ensure_codex_token_available() -> None: + """Ensure a valid token is available; raise if not.""" + _ = get_codex_token() + + +async def _read_stdin_line() -> str: + loop = asyncio.get_running_loop() + if hasattr(loop, "add_reader") and sys.stdin: + future: asyncio.Future[str] = loop.create_future() + + def _on_readable() -> None: + line = sys.stdin.readline() + if not future.done(): + future.set_result(line) + + try: + loop.add_reader(sys.stdin, _on_readable) + except Exception: + return await loop.run_in_executor(None, sys.stdin.readline) + + try: + return await future + finally: + try: + loop.remove_reader(sys.stdin) + except Exception: + pass + + return await loop.run_in_executor(None, sys.stdin.readline) + + +async def _await_manual_input( + on_manual_code_input: Callable[[str], None], +) -> str: + await asyncio.sleep(MANUAL_PROMPT_DELAY_SEC) + on_manual_code_input("Paste the authorization code (or full redirect URL), or wait for the browser callback:") + return await _read_stdin_line() + + +def login_codex_oauth_interactive( + on_auth: Callable[[str], None] | None = None, + on_prompt: Callable[[str], str] | None = None, + on_status: Callable[[str], None] | None = None, + on_progress: Callable[[str], None] | None = None, + on_manual_code_input: Callable[[str], None] = None, + originator: str = DEFAULT_ORIGINATOR, +) -> CodexToken: + """Interactive login flow.""" + + async def _login_async() -> CodexToken: + verifier, challenge = _generate_pkce() + state = _create_state() + + params = { + "response_type": "code", + "client_id": CLIENT_ID, + "redirect_uri": REDIRECT_URI, + "scope": SCOPE, + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": state, + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true", + "originator": originator, + } + url = f"{AUTHORIZE_URL}?{urllib.parse.urlencode(params)}" + + loop = asyncio.get_running_loop() + code_future: asyncio.Future[str] = loop.create_future() + + def _notify(code_value: str) -> None: + if code_future.done(): + return + loop.call_soon_threadsafe(code_future.set_result, code_value) + + server, server_error = _start_local_server(state, on_code=_notify) + if on_auth: + on_auth(url) + else: + webbrowser.open(url) + + if not server and server_error and on_status: + on_status( + f"Local callback server could not start ({server_error}). " + "You will need to paste the callback URL or authorization code." + ) + + code: str | None = None + try: + if server: + if on_progress and not on_manual_code_input: + on_progress("Waiting for browser callback...") + + tasks: list[asyncio.Task[Any]] = [] + callback_task = asyncio.create_task(asyncio.wait_for(code_future, timeout=120)) + tasks.append(callback_task) + manual_task = asyncio.create_task(_await_manual_input(on_manual_code_input)) + tasks.append(manual_task) + + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + for task in pending: + task.cancel() + + for task in done: + try: + result = task.result() + except asyncio.TimeoutError: + result = None + if not result: + continue + if task is manual_task: + parsed_code, parsed_state = _parse_authorization_input(result) + if parsed_state and parsed_state != state: + raise RuntimeError("State validation failed.") + code = parsed_code + else: + code = result + if code: + break + + if not code: + prompt = "Please paste the callback URL or authorization code:" + if on_prompt: + raw = await loop.run_in_executor(None, on_prompt, prompt) + else: + raw = await loop.run_in_executor(None, input, prompt) + parsed_code, parsed_state = _parse_authorization_input(raw) + if parsed_state and parsed_state != state: + raise RuntimeError("State validation failed.") + code = parsed_code + + if not code: + raise RuntimeError("Authorization code not found.") + + if on_progress: + on_progress("Exchanging authorization code for tokens...") + token = await _exchange_code_for_token_async(code, verifier) + _save_token_file(token) + return token + finally: + if server: + server.shutdown() + server.server_close() + + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(_login_async()) + + result: list[CodexToken] = [] + error: list[Exception] = [] + + def _runner() -> None: + try: + result.append(asyncio.run(_login_async())) + except Exception as exc: + error.append(exc) + + thread = threading.Thread(target=_runner) + thread.start() + thread.join() + if error: + raise error[0] + return result[0] diff --git a/nanobot/auth/codex/models.py b/nanobot/auth/codex/models.py new file mode 100644 index 0000000..e3a5f55 --- /dev/null +++ b/nanobot/auth/codex/models.py @@ -0,0 +1,15 @@ +"""Codex OAuth data models.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class CodexToken: + """Codex OAuth token data structure.""" + + access: str + refresh: str + expires: int + account_id: str diff --git a/nanobot/auth/codex/pkce.py b/nanobot/auth/codex/pkce.py new file mode 100644 index 0000000..b682386 --- /dev/null +++ b/nanobot/auth/codex/pkce.py @@ -0,0 +1,77 @@ +"""PKCE and authorization helpers.""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import os +import urllib.parse +from typing import Any + +from nanobot.auth.codex.constants import JWT_CLAIM_PATH + + +def _base64url(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8") + + +def _decode_base64url(data: str) -> bytes: + padding = "=" * (-len(data) % 4) + return base64.urlsafe_b64decode(data + padding) + + +def _generate_pkce() -> tuple[str, str]: + verifier = _base64url(os.urandom(32)) + challenge = _base64url(hashlib.sha256(verifier.encode("utf-8")).digest()) + return verifier, challenge + + +def _create_state() -> str: + return _base64url(os.urandom(16)) + + +def _parse_authorization_input(raw: str) -> tuple[str | None, str | None]: + value = raw.strip() + if not value: + return None, None + try: + url = urllib.parse.urlparse(value) + qs = urllib.parse.parse_qs(url.query) + code = qs.get("code", [None])[0] + state = qs.get("state", [None])[0] + if code: + return code, state + except Exception: + pass + + if "#" in value: + parts = value.split("#", 1) + return parts[0] or None, parts[1] or None + + if "code=" in value: + qs = urllib.parse.parse_qs(value) + return qs.get("code", [None])[0], qs.get("state", [None])[0] + + return value, None + + +def _decode_account_id(access_token: str) -> str: + parts = access_token.split(".") + if len(parts) != 3: + raise ValueError("Invalid JWT token") + payload = json.loads(_decode_base64url(parts[1]).decode("utf-8")) + auth = payload.get(JWT_CLAIM_PATH) or {} + account_id = auth.get("chatgpt_account_id") + if not account_id: + raise ValueError("Failed to extract account_id from token") + return str(account_id) + + +def _parse_token_payload(payload: dict[str, Any], missing_message: str) -> tuple[str, str, int]: + access = payload.get("access_token") + refresh = payload.get("refresh_token") + expires_in = payload.get("expires_in") + if not access or not refresh or not isinstance(expires_in, int): + raise RuntimeError(missing_message) + return access, refresh, expires_in diff --git a/nanobot/auth/codex/server.py b/nanobot/auth/codex/server.py new file mode 100644 index 0000000..f31db19 --- /dev/null +++ b/nanobot/auth/codex/server.py @@ -0,0 +1,115 @@ +"""Local OAuth callback server.""" + +from __future__ import annotations + +import socket +import threading +import urllib.parse +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any, Callable + +from nanobot.auth.codex.constants import SUCCESS_HTML + + +class _OAuthHandler(BaseHTTPRequestHandler): + """Local callback HTTP handler.""" + + server_version = "NanobotOAuth/1.0" + protocol_version = "HTTP/1.1" + + def do_GET(self) -> None: # noqa: N802 + try: + url = urllib.parse.urlparse(self.path) + if url.path != "/auth/callback": + self.send_response(404) + self.end_headers() + self.wfile.write(b"Not found") + return + + qs = urllib.parse.parse_qs(url.query) + code = qs.get("code", [None])[0] + state = qs.get("state", [None])[0] + + if state != self.server.expected_state: + self.send_response(400) + self.end_headers() + self.wfile.write(b"State mismatch") + return + + if not code: + self.send_response(400) + self.end_headers() + self.wfile.write(b"Missing code") + return + + self.server.code = code + try: + if getattr(self.server, "on_code", None): + self.server.on_code(code) + except Exception: + pass + body = SUCCESS_HTML.encode("utf-8") + self.send_response(200) + self.send_header("Content-Type", "text/html; charset=utf-8") + self.send_header("Content-Length", str(len(body))) + self.send_header("Connection", "close") + self.end_headers() + self.wfile.write(body) + try: + self.wfile.flush() + except Exception: + pass + self.close_connection = True + except Exception: + self.send_response(500) + self.end_headers() + self.wfile.write(b"Internal error") + + def log_message(self, format: str, *args: Any) -> None: # noqa: A003 + # Suppress default logs to avoid noisy output. + return + + +class _OAuthServer(HTTPServer): + """OAuth callback server with state.""" + + def __init__( + self, + server_address: tuple[str, int], + expected_state: str, + on_code: Callable[[str], None] | None = None, + ): + super().__init__(server_address, _OAuthHandler) + self.expected_state = expected_state + self.code: str | None = None + self.on_code = on_code + + +def _start_local_server( + state: str, + on_code: Callable[[str], None] | None = None, +) -> tuple[_OAuthServer | None, str | None]: + """Start a local OAuth callback server on the first available localhost address.""" + try: + addrinfos = socket.getaddrinfo("localhost", 1455, type=socket.SOCK_STREAM) + except OSError as exc: + return None, f"Failed to resolve localhost: {exc}" + + last_error: OSError | None = None + for family, _socktype, _proto, _canonname, sockaddr in addrinfos: + try: + # Support IPv4/IPv6 to avoid missing callbacks when localhost resolves to ::1. + class _AddrOAuthServer(_OAuthServer): + address_family = family + + server = _AddrOAuthServer(sockaddr, state, on_code=on_code) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + return server, None + except OSError as exc: + last_error = exc + continue + + if last_error: + return None, f"Local callback server failed to start: {last_error}" + return None, "Local callback server failed to start: unknown error" diff --git a/nanobot/auth/codex/storage.py b/nanobot/auth/codex/storage.py new file mode 100644 index 0000000..31e5e3d --- /dev/null +++ b/nanobot/auth/codex/storage.py @@ -0,0 +1,118 @@ +"""Token storage helpers.""" + +from __future__ import annotations + +import json +import os +import time +from pathlib import Path + +from nanobot.auth.codex.constants import TOKEN_FILENAME +from nanobot.auth.codex.models import CodexToken +from nanobot.utils.helpers import ensure_dir, get_data_path + + +def _get_token_path() -> Path: + auth_dir = ensure_dir(get_data_path() / "auth") + return auth_dir / TOKEN_FILENAME + + +def _load_token_file() -> CodexToken | None: + path = _get_token_path() + if not path.exists(): + return None + try: + data = json.loads(path.read_text(encoding="utf-8")) + return CodexToken( + access=data["access"], + refresh=data["refresh"], + expires=int(data["expires"]), + account_id=data["account_id"], + ) + except Exception: + return None + + +def _save_token_file(token: CodexToken) -> None: + path = _get_token_path() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + json.dumps( + { + "access": token.access, + "refresh": token.refresh, + "expires": token.expires, + "account_id": token.account_id, + }, + ensure_ascii=True, + indent=2, + ), + encoding="utf-8", + ) + try: + os.chmod(path, 0o600) + except Exception: + # Ignore permission setting failures. + pass + + +def _try_import_codex_cli_token() -> CodexToken | None: + codex_path = Path.home() / ".codex" / "auth.json" + if not codex_path.exists(): + return None + try: + data = json.loads(codex_path.read_text(encoding="utf-8")) + tokens = data.get("tokens") or {} + access = tokens.get("access_token") + refresh = tokens.get("refresh_token") + account_id = tokens.get("account_id") + if not access or not refresh or not account_id: + return None + try: + mtime = codex_path.stat().st_mtime + expires = int(mtime * 1000 + 60 * 60 * 1000) + except Exception: + expires = int(time.time() * 1000 + 60 * 60 * 1000) + token = CodexToken( + access=str(access), + refresh=str(refresh), + expires=expires, + account_id=str(account_id), + ) + _save_token_file(token) + return token + except Exception: + return None + + +class _FileLock: + """Simple file lock to reduce concurrent refreshes.""" + + def __init__(self, path: Path): + self._path = path + self._fp = None + + def __enter__(self) -> "_FileLock": + self._path.parent.mkdir(parents=True, exist_ok=True) + self._fp = open(self._path, "a+") + try: + import fcntl + + fcntl.flock(self._fp.fileno(), fcntl.LOCK_EX) + except Exception: + # Non-POSIX or failed lock: continue without locking. + pass + return self + + def __exit__(self, exc_type, exc, tb) -> None: + try: + import fcntl + + fcntl.flock(self._fp.fileno(), fcntl.LOCK_UN) + except Exception: + pass + try: + if self._fp: + self._fp.close() + except Exception: + pass diff --git a/nanobot/auth/codex_oauth.py b/nanobot/auth/codex_oauth.py deleted file mode 100644 index 0784267..0000000 --- a/nanobot/auth/codex_oauth.py +++ /dev/null @@ -1,607 +0,0 @@ -"""OpenAI Codex OAuth implementation.""" - -from __future__ import annotations - -import asyncio -import base64 -import hashlib -import json -import os -import socket -import sys -import threading -import time -import urllib.parse -import webbrowser -from dataclasses import dataclass -from http.server import BaseHTTPRequestHandler, HTTPServer -from pathlib import Path -from typing import Any, Callable - -import httpx - -from nanobot.utils.helpers import ensure_dir, get_data_path - -# Fixed parameters (sourced from the official Codex CLI OAuth client). -CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" -AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize" -TOKEN_URL = "https://auth.openai.com/oauth/token" -REDIRECT_URI = "http://localhost:1455/auth/callback" -SCOPE = "openid profile email offline_access" -JWT_CLAIM_PATH = "https://api.openai.com/auth" - -DEFAULT_ORIGINATOR = "nanobot" -TOKEN_FILENAME = "codex.json" -MANUAL_PROMPT_DELAY_SEC = 3 -SUCCESS_HTML = ( - "" - "" - "" - "" - "" - "Authentication successful. Return to your terminal to continue.
" - "" - "" -) - - -@dataclass -class CodexToken: - """Codex OAuth token data structure.""" - access: str - refresh: str - expires: int - account_id: str - - -def _base64url(data: bytes) -> str: - return base64.urlsafe_b64encode(data).rstrip(b"=").decode("utf-8") - - -def _decode_base64url(data: str) -> bytes: - padding = "=" * (-len(data) % 4) - return base64.urlsafe_b64decode(data + padding) - - -def _generate_pkce() -> tuple[str, str]: - verifier = _base64url(os.urandom(32)) - challenge = _base64url(hashlib.sha256(verifier.encode("utf-8")).digest()) - return verifier, challenge - - -def _create_state() -> str: - return _base64url(os.urandom(16)) - - -def _get_token_path() -> Path: - auth_dir = ensure_dir(get_data_path() / "auth") - return auth_dir / TOKEN_FILENAME - - -def _parse_authorization_input(raw: str) -> tuple[str | None, str | None]: - value = raw.strip() - if not value: - return None, None - try: - url = urllib.parse.urlparse(value) - qs = urllib.parse.parse_qs(url.query) - code = qs.get("code", [None])[0] - state = qs.get("state", [None])[0] - if code: - return code, state - except Exception: - pass - - if "#" in value: - parts = value.split("#", 1) - return parts[0] or None, parts[1] or None - - if "code=" in value: - qs = urllib.parse.parse_qs(value) - return qs.get("code", [None])[0], qs.get("state", [None])[0] - - return value, None - - -def _decode_account_id(access_token: str) -> str: - parts = access_token.split(".") - if len(parts) != 3: - raise ValueError("Invalid JWT token") - payload = json.loads(_decode_base64url(parts[1]).decode("utf-8")) - auth = payload.get(JWT_CLAIM_PATH) or {} - account_id = auth.get("chatgpt_account_id") - if not account_id: - raise ValueError("Failed to extract account_id from token") - return str(account_id) - - -class _OAuthHandler(BaseHTTPRequestHandler): - """Local callback HTTP handler.""" - - server_version = "NanobotOAuth/1.0" - protocol_version = "HTTP/1.1" - - def do_GET(self) -> None: # noqa: N802 - try: - url = urllib.parse.urlparse(self.path) - if url.path != "/auth/callback": - self.send_response(404) - self.end_headers() - self.wfile.write(b"Not found") - return - - qs = urllib.parse.parse_qs(url.query) - code = qs.get("code", [None])[0] - state = qs.get("state", [None])[0] - - if state != self.server.expected_state: - self.send_response(400) - self.end_headers() - self.wfile.write(b"State mismatch") - return - - if not code: - self.send_response(400) - self.end_headers() - self.wfile.write(b"Missing code") - return - - self.server.code = code - try: - if getattr(self.server, "on_code", None): - self.server.on_code(code) - except Exception: - pass - body = SUCCESS_HTML.encode("utf-8") - self.send_response(200) - self.send_header("Content-Type", "text/html; charset=utf-8") - self.send_header("Content-Length", str(len(body))) - self.send_header("Connection", "close") - self.end_headers() - self.wfile.write(body) - try: - self.wfile.flush() - except Exception: - pass - self.close_connection = True - except Exception: - self.send_response(500) - self.end_headers() - self.wfile.write(b"Internal error") - - def log_message(self, format: str, *args: Any) -> None: # noqa: A003 - # Suppress default logs to avoid noisy output. - return - - -class _OAuthServer(HTTPServer): - """OAuth callback server with state.""" - - def __init__( - self, - server_address: tuple[str, int], - expected_state: str, - on_code: Callable[[str], None] | None = None, - ): - super().__init__(server_address, _OAuthHandler) - self.expected_state = expected_state - self.code: str | None = None - self.on_code = on_code - - -def _start_local_server( - state: str, - on_code: Callable[[str], None] | None = None, -) -> tuple[_OAuthServer | None, str | None]: - """Start a local OAuth callback server on the first available localhost address.""" - try: - addrinfos = socket.getaddrinfo("localhost", 1455, type=socket.SOCK_STREAM) - except OSError as exc: - return None, f"Failed to resolve localhost: {exc}" - - last_error: OSError | None = None - for family, _socktype, _proto, _canonname, sockaddr in addrinfos: - try: - # 兼容 IPv4/IPv6 监听,避免 localhost 解析到 ::1 时收不到回调 - class _AddrOAuthServer(_OAuthServer): - address_family = family - - server = _AddrOAuthServer(sockaddr, state, on_code=on_code) - thread = threading.Thread(target=server.serve_forever, daemon=True) - thread.start() - return server, None - except OSError as exc: - last_error = exc - continue - - if last_error: - return None, f"Local callback server failed to start: {last_error}" - return None, "Local callback server failed to start: unknown error" - - -def _exchange_code_for_token(code: str, verifier: str) -> CodexToken: - data = { - "grant_type": "authorization_code", - "client_id": CLIENT_ID, - "code": code, - "code_verifier": verifier, - "redirect_uri": REDIRECT_URI, - } - with httpx.Client(timeout=30.0) as client: - response = client.post(TOKEN_URL, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"}) - if response.status_code != 200: - raise RuntimeError(f"Token exchange failed: {response.status_code} {response.text}") - - payload = response.json() - access = payload.get("access_token") - refresh = payload.get("refresh_token") - expires_in = payload.get("expires_in") - if not access or not refresh or not isinstance(expires_in, int): - raise RuntimeError("Token response missing fields") - print("Received access token:", access) - account_id = _decode_account_id(access) - return CodexToken( - access=access, - refresh=refresh, - expires=int(time.time() * 1000 + expires_in * 1000), - account_id=account_id, - ) - - -async def _exchange_code_for_token_async(code: str, verifier: str) -> CodexToken: - data = { - "grant_type": "authorization_code", - "client_id": CLIENT_ID, - "code": code, - "code_verifier": verifier, - "redirect_uri": REDIRECT_URI, - } - async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post( - TOKEN_URL, - data=data, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - ) - if response.status_code != 200: - raise RuntimeError(f"Token exchange failed: {response.status_code} {response.text}") - - payload = response.json() - access = payload.get("access_token") - refresh = payload.get("refresh_token") - expires_in = payload.get("expires_in") - if not access or not refresh or not isinstance(expires_in, int): - raise RuntimeError("Token response missing fields") - - account_id = _decode_account_id(access) - return CodexToken( - access=access, - refresh=refresh, - expires=int(time.time() * 1000 + expires_in * 1000), - account_id=account_id, - ) - - -def _refresh_token(refresh_token: str) -> CodexToken: - data = { - "grant_type": "refresh_token", - "refresh_token": refresh_token, - "client_id": CLIENT_ID, - } - with httpx.Client(timeout=30.0) as client: - response = client.post(TOKEN_URL, data=data, headers={"Content-Type": "application/x-www-form-urlencoded"}) - if response.status_code != 200: - raise RuntimeError(f"Token refresh failed: {response.status_code} {response.text}") - - payload = response.json() - access = payload.get("access_token") - refresh = payload.get("refresh_token") - expires_in = payload.get("expires_in") - if not access or not refresh or not isinstance(expires_in, int): - raise RuntimeError("Token refresh response missing fields") - - account_id = _decode_account_id(access) - return CodexToken( - access=access, - refresh=refresh, - expires=int(time.time() * 1000 + expires_in * 1000), - account_id=account_id, - ) - - -def _load_token_file() -> CodexToken | None: - path = _get_token_path() - if not path.exists(): - return None - try: - data = json.loads(path.read_text(encoding="utf-8")) - return CodexToken( - access=data["access"], - refresh=data["refresh"], - expires=int(data["expires"]), - account_id=data["account_id"], - ) - except Exception: - return None - - -def _save_token_file(token: CodexToken) -> None: - path = _get_token_path() - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text( - json.dumps( - { - "access": token.access, - "refresh": token.refresh, - "expires": token.expires, - "account_id": token.account_id, - }, - ensure_ascii=True, - indent=2, - ), - encoding="utf-8", - ) - try: - os.chmod(path, 0o600) - except Exception: - # Ignore permission setting failures. - pass - - -def _try_import_codex_cli_token() -> CodexToken | None: - codex_path = Path.home() / ".codex" / "auth.json" - if not codex_path.exists(): - return None - try: - data = json.loads(codex_path.read_text(encoding="utf-8")) - tokens = data.get("tokens") or {} - access = tokens.get("access_token") - refresh = tokens.get("refresh_token") - account_id = tokens.get("account_id") - if not access or not refresh or not account_id: - return None - try: - mtime = codex_path.stat().st_mtime - expires = int(mtime * 1000 + 60 * 60 * 1000) - except Exception: - expires = int(time.time() * 1000 + 60 * 60 * 1000) - token = CodexToken( - access=str(access), - refresh=str(refresh), - expires=expires, - account_id=str(account_id), - ) - _save_token_file(token) - return token - except Exception: - return None - - -class _FileLock: - """Simple file lock to reduce concurrent refreshes.""" - - def __init__(self, path: Path): - self._path = path - self._fp = None - - def __enter__(self) -> "_FileLock": - self._path.parent.mkdir(parents=True, exist_ok=True) - self._fp = open(self._path, "a+") - try: - import fcntl - - fcntl.flock(self._fp.fileno(), fcntl.LOCK_EX) - except Exception: - # Non-POSIX or failed lock: continue without locking. - pass - return self - - def __exit__(self, exc_type, exc, tb) -> None: - try: - import fcntl - - fcntl.flock(self._fp.fileno(), fcntl.LOCK_UN) - except Exception: - pass - try: - if self._fp: - self._fp.close() - except Exception: - pass - - -def get_codex_token() -> CodexToken: - """Get an available token (refresh if needed).""" - token = _load_token_file() or _try_import_codex_cli_token() - if not token: - raise RuntimeError("Codex OAuth credentials not found. Please run the login command.") - - # Refresh 60 seconds early. - now_ms = int(time.time() * 1000) - if token.expires - now_ms > 60 * 1000: - return token - - lock_path = _get_token_path().with_suffix(".lock") - with _FileLock(lock_path): - # Re-read to avoid stale token if another process refreshed it. - token = _load_token_file() or token - now_ms = int(time.time() * 1000) - if token.expires - now_ms > 60 * 1000: - return token - try: - refreshed = _refresh_token(token.refresh) - _save_token_file(refreshed) - return refreshed - except Exception: - # If refresh fails, re-read the file to avoid false negatives. - latest = _load_token_file() - if latest and latest.expires - now_ms > 0: - return latest - raise - - -def ensure_codex_token_available() -> None: - """Ensure a valid token is available; raise if not.""" - _ = get_codex_token() - - -async def _read_stdin_line() -> str: - loop = asyncio.get_running_loop() - if hasattr(loop, "add_reader") and sys.stdin: - future: asyncio.Future[str] = loop.create_future() - - def _on_readable() -> None: - line = sys.stdin.readline() - if not future.done(): - future.set_result(line) - - try: - loop.add_reader(sys.stdin, _on_readable) - except Exception: - return await loop.run_in_executor(None, sys.stdin.readline) - - try: - return await future - finally: - try: - loop.remove_reader(sys.stdin) - except Exception: - pass - - return await loop.run_in_executor(None, sys.stdin.readline) - - -async def _await_manual_input( - on_manual_code_input: Callable[[str], None], -) -> str: - await asyncio.sleep(MANUAL_PROMPT_DELAY_SEC) - on_manual_code_input("Paste the authorization code (or full redirect URL), or wait for the browser callback:") - return await _read_stdin_line() - - -def login_codex_oauth_interactive( - on_auth: Callable[[str], None] | None = None, - on_prompt: Callable[[str], str] | None = None, - on_status: Callable[[str], None] | None = None, - on_progress: Callable[[str], None] | None = None, - on_manual_code_input: Callable[[str], None] = None, - originator: str = DEFAULT_ORIGINATOR, -) -> CodexToken: - """Interactive login flow.""" - async def _login_async() -> CodexToken: - verifier, challenge = _generate_pkce() - state = _create_state() - - params = { - "response_type": "code", - "client_id": CLIENT_ID, - "redirect_uri": REDIRECT_URI, - "scope": SCOPE, - "code_challenge": challenge, - "code_challenge_method": "S256", - "state": state, - "id_token_add_organizations": "true", - "codex_cli_simplified_flow": "true", - "originator": originator, - } - url = f"{AUTHORIZE_URL}?{urllib.parse.urlencode(params)}" - - loop = asyncio.get_running_loop() - code_future: asyncio.Future[str] = loop.create_future() - - def _notify(code_value: str) -> None: - if code_future.done(): - return - loop.call_soon_threadsafe(code_future.set_result, code_value) - - server, server_error = _start_local_server(state, on_code=_notify) - if on_auth: - on_auth(url) - else: - webbrowser.open(url) - - if not server and server_error and on_status: - on_status( - f"Local callback server could not start ({server_error}). " - "You will need to paste the callback URL or authorization code." - ) - - code: str | None = None - try: - if server: - if on_progress and not on_manual_code_input: - on_progress("Waiting for browser callback...") - - tasks: list[asyncio.Task[Any]] = [] - callback_task = asyncio.create_task(asyncio.wait_for(code_future, timeout=120)) - tasks.append(callback_task) - manual_task = asyncio.create_task(_await_manual_input(on_manual_code_input)) - tasks.append(manual_task) - - done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - for task in pending: - task.cancel() - - for task in done: - try: - result = task.result() - except asyncio.TimeoutError: - result = None - if not result: - continue - if task is manual_task: - parsed_code, parsed_state = _parse_authorization_input(result) - if parsed_state and parsed_state != state: - raise RuntimeError("State validation failed.") - code = parsed_code - else: - code = result - if code: - break - - if not code: - prompt = "Please paste the callback URL or authorization code:" - if on_prompt: - raw = await loop.run_in_executor(None, on_prompt, prompt) - else: - raw = await loop.run_in_executor(None, input, prompt) - parsed_code, parsed_state = _parse_authorization_input(raw) - if parsed_state and parsed_state != state: - raise RuntimeError("State validation failed.") - code = parsed_code - - if not code: - raise RuntimeError("Authorization code not found.") - - if on_progress: - on_progress("Exchanging authorization code for tokens...") - token = await _exchange_code_for_token_async(code, verifier) - _save_token_file(token) - return token - finally: - if server: - server.shutdown() - server.server_close() - - try: - asyncio.get_running_loop() - except RuntimeError: - return asyncio.run(_login_async()) - - result: list[CodexToken] = [] - error: list[Exception] = [] - - def _runner() -> None: - try: - result.append(asyncio.run(_login_async())) - except Exception as exc: - error.append(exc) - - thread = threading.Thread(target=_runner) - thread.start() - thread.join() - if error: - raise error[0] - return result[0] diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 213f8c5..93be424 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -82,7 +82,7 @@ def login( console.print(f"[red]Unsupported provider: {provider}[/red]") raise typer.Exit(1) - from nanobot.auth.codex_oauth import login_codex_oauth_interactive + from nanobot.auth.codex import login_codex_oauth_interactive def on_auth(url: str) -> None: console.print("[cyan]A browser window will open for login. If it doesn't, open this URL manually:[/cyan]") @@ -205,7 +205,7 @@ def gateway( from nanobot.bus.queue import MessageBus from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider - from nanobot.auth.codex_oauth import ensure_codex_token_available + from nanobot.auth.codex import ensure_codex_token_available from nanobot.agent.loop import AgentLoop from nanobot.channels.manager import ChannelManager from nanobot.cron.service import CronService @@ -341,7 +341,7 @@ def agent( from nanobot.bus.queue import MessageBus from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider - from nanobot.auth.codex_oauth import ensure_codex_token_available + from nanobot.auth.codex import ensure_codex_token_available from nanobot.agent.loop import AgentLoop config = load_config() diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py index 2081180..ec0383c 100644 --- a/nanobot/providers/openai_codex_provider.py +++ b/nanobot/providers/openai_codex_provider.py @@ -1,4 +1,4 @@ -"""OpenAI Codex Responses Provider。""" +"""OpenAI Codex Responses Provider.""" from __future__ import annotations @@ -9,7 +9,7 @@ from typing import Any, AsyncGenerator import httpx -from nanobot.auth.codex_oauth import get_codex_token +from nanobot.auth.codex import get_codex_token from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api" @@ -17,7 +17,7 @@ DEFAULT_ORIGINATOR = "nanobot" class OpenAICodexProvider(LLMProvider): - """使用 Codex OAuth 调用 Responses 接口。""" + """Use Codex OAuth to call the Responses API.""" def __init__(self, default_model: str = "openai-codex/gpt-5.1-codex"): super().__init__(api_key=None, api_base=None) @@ -56,37 +56,18 @@ class OpenAICodexProvider(LLMProvider): url = _resolve_codex_url(DEFAULT_CODEX_BASE_URL) try: - async with httpx.AsyncClient(timeout=60.0) as client: - try: - async with client.stream("POST", url, headers=headers, json=body) as response: - if response.status_code != 200: - text = await response.aread() - raise RuntimeError( - _friendly_error(response.status_code, text.decode("utf-8", "ignore")) - ) - content, tool_calls, finish_reason = await _consume_sse(response) - return LLMResponse( - content=content, - tool_calls=tool_calls, - finish_reason=finish_reason, - ) - except Exception as e: - # 证书校验失败时降级关闭校验(存在安全风险) - if "CERTIFICATE_VERIFY_FAILED" not in str(e): - raise - async with httpx.AsyncClient(timeout=60.0, verify=False) as insecure_client: - async with insecure_client.stream("POST", url, headers=headers, json=body) as response: - if response.status_code != 200: - text = await response.aread() - raise RuntimeError( - _friendly_error(response.status_code, text.decode("utf-8", "ignore")) - ) - content, tool_calls, finish_reason = await _consume_sse(response) - return LLMResponse( - content=content, - tool_calls=tool_calls, - finish_reason=finish_reason, - ) + try: + content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=True) + except Exception as e: + # Certificate verification failed, downgrade to disable verification (security risk) + if "CERTIFICATE_VERIFY_FAILED" not in str(e): + raise + content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=False) + return LLMResponse( + content=content, + tool_calls=tool_calls, + finish_reason=finish_reason, + ) except Exception as e: return LLMResponse( content=f"Error calling Codex: {str(e)}", @@ -124,17 +105,31 @@ def _build_headers(account_id: str, token: str) -> dict[str, str]: } +async def _request_codex( + url: str, + headers: dict[str, str], + body: dict[str, Any], + verify: bool, +) -> tuple[str, list[ToolCallRequest], str]: + async with httpx.AsyncClient(timeout=60.0, verify=verify) as client: + async with client.stream("POST", url, headers=headers, json=body) as response: + if response.status_code != 200: + text = await response.aread() + raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore"))) + return await _consume_sse(response) + + def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: - # nanobot 工具定义已是 OpenAI function schema + # Nanobot tool definitions already use the OpenAI function schema. converted: list[dict[str, Any]] = [] for tool in tools: name = tool.get("name") if not isinstance(name, str) or not name: - # 忽略无效工具,避免被 Codex 拒绝 + # Skip invalid tools to avoid Codex rejection. continue params = tool.get("parameters") or {} if not isinstance(params, dict): - # 参数必须是 JSON Schema 对象 + # Parameters must be a JSON Schema object. params = {} converted.append( { @@ -164,7 +159,7 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[st continue if role == "assistant": - # 先处理文本 + # Handle text first. if isinstance(content, str) and content: input_items.append( { @@ -175,7 +170,7 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[st "id": f"msg_{idx}", } ) - # 再处理工具调用 + # Then handle tool calls. for tool_call in msg.get("tool_calls", []) or []: fn = tool_call.get("function") or {} call_id = tool_call.get("id") or f"call_{idx}" @@ -329,5 +324,5 @@ def _map_finish_reason(status: str | None) -> str: def _friendly_error(status_code: int, raw: str) -> str: if status_code == 429: - return "ChatGPT 使用额度已达上限或触发限流,请稍后再试。" + return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later." return f"HTTP {status_code}: {raw}"