diff --git a/src/chap/backends/anthropic.py b/src/chap/backends/anthropic.py index f3d8f41..ad4a8f9 100644 --- a/src/chap/backends/anthropic.py +++ b/src/chap/backends/anthropic.py @@ -9,16 +9,17 @@ from typing import AsyncGenerator, Any import httpx from ..core import AutoAskMixin, Backend -from ..key import get_key +from ..key import UsesKeyMixin from ..session import Assistant, Role, Session, User -class Anthropic(AutoAskMixin): +class Anthropic(AutoAskMixin, UsesKeyMixin): @dataclass class Parameters: url: str = "https://api.anthropic.com" model: str = "claude-3-5-sonnet-20240620" max_new_tokens: int = 1000 + api_key_name = "anthropic_api_key" def __init__(self) -> None: super().__init__() @@ -88,10 +89,6 @@ Answer each question accurately and thoroughly. session.extend([User(query), Assistant("".join(new_content))]) - @classmethod - def get_key(cls) -> str: - return get_key("anthropic_api_key") - def factory() -> Backend: """Uses the anthropic text-generation-interface web API""" diff --git a/src/chap/backends/huggingface.py b/src/chap/backends/huggingface.py index cdd2629..ff0a3ef 100644 --- a/src/chap/backends/huggingface.py +++ b/src/chap/backends/huggingface.py @@ -9,11 +9,11 @@ from typing import Any, AsyncGenerator import httpx from ..core import AutoAskMixin, Backend -from ..key import get_key +from ..key import UsesKeyMixin from ..session import Assistant, Role, Session, User -class HuggingFace(AutoAskMixin): +class HuggingFace(AutoAskMixin, UsesKeyMixin): @dataclass class Parameters: url: str = "https://api-inference.huggingface.co" @@ -24,6 +24,7 @@ class HuggingFace(AutoAskMixin): after_user: str = """ [/INST] """ after_assistant: str = """ [INST] """ stop_token_id = 2 + api_key_name = "huggingface_api_token" def __init__(self) -> None: super().__init__() @@ -110,10 +111,6 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a session.extend([User(query), Assistant("".join(new_content))]) - @classmethod - def get_key(cls) -> str: - return get_key("huggingface_api_token") - def factory() -> Backend: """Uses the huggingface text-generation-interface web API""" diff --git a/src/chap/backends/mistral.py b/src/chap/backends/mistral.py index c254632..93cdcaf 100644 --- a/src/chap/backends/mistral.py +++ b/src/chap/backends/mistral.py @@ -9,16 +9,17 @@ from typing import AsyncGenerator, Any import httpx from ..core import AutoAskMixin -from ..key import get_key +from ..key import UsesKeyMixin from ..session import Assistant, Session, User -class Mistral(AutoAskMixin): +class Mistral(AutoAskMixin, UsesKeyMixin): @dataclass class Parameters: url: str = "https://api.mistral.ai" model: str = "open-mistral-7b" max_new_tokens: int = 1000 + api_key_name = "mistral_api_key" def __init__(self) -> None: super().__init__() @@ -91,9 +92,5 @@ Answer each question accurately and thoroughly. session.extend([User(query), Assistant("".join(new_content))]) - @classmethod - def get_key(cls) -> str: - return get_key("mistral_api_key") - factory = Mistral diff --git a/src/chap/backends/openai_chatgpt.py b/src/chap/backends/openai_chatgpt.py index cd7dc16..2caf2ae 100644 --- a/src/chap/backends/openai_chatgpt.py +++ b/src/chap/backends/openai_chatgpt.py @@ -12,7 +12,7 @@ import httpx import tiktoken from ..core import Backend -from ..key import get_key +from ..key import UsesKeyMixin from ..session import Assistant, Message, Session, User, session_to_list @@ -63,7 +63,7 @@ class EncodingMeta: return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead) -class ChatGPT: +class ChatGPT(UsesKeyMixin): @dataclass class Parameters: model: str = "gpt-4o-mini" @@ -81,6 +81,11 @@ class ChatGPT: top_p: float | None = None """The model temperature for sampling""" + api_key_name: str = "openai_api_key" + """The OpenAI API key""" + + parameters: Parameters + def __init__(self) -> None: self.parameters = self.Parameters() @@ -171,10 +176,6 @@ class ChatGPT: session.extend([User(query), Assistant("".join(new_content))]) - @classmethod - def get_key(cls) -> str: - return get_key("openai_api_key") - def factory() -> Backend: """Uses the OpenAI chat completion API""" diff --git a/src/chap/key.py b/src/chap/key.py index e292905..edb8381 100644 --- a/src/chap/key.py +++ b/src/chap/key.py @@ -2,25 +2,61 @@ # # SPDX-License-Identifier: MIT +import json +import subprocess +from typing import Protocol import functools import platformdirs +class APIKeyProtocol(Protocol): + @property + def api_key_name(self) -> str: + ... + + +class HasKeyProtocol(Protocol): + @property + def parameters(self) -> APIKeyProtocol: + ... + + +class UsesKeyMixin: + def get_key(self: HasKeyProtocol) -> str: + return get_key(self.parameters.api_key_name) + + class NoKeyAvailable(Exception): pass _key_path_base = platformdirs.user_config_path("chap") +USE_PASSWORD_STORE = _key_path_base / "USE_PASSWORD_STORE" -@functools.cache -def get_key(name: str, what: str = "openai api key") -> str: - key_path = _key_path_base / name - if not key_path.exists(): - raise NoKeyAvailable( - f"Place your {what} in {key_path} and run the program again" - ) +if USE_PASSWORD_STORE.exists(): + content = USE_PASSWORD_STORE.read_text(encoding="utf-8") + if content.strip(): + cfg = json.loads(content) + pass_command: list[str] = cfg.get("PASS_COMMAND", ["pass", "show"]) + pass_prefix: str = cfg.get("PASS_PREFIX", "chap/") - with open(key_path, encoding="utf-8") as f: - return f.read().strip() + @functools.cache + def get_key(name: str, what: str = "api key") -> str: + key_path = f"{pass_prefix}{name}" + command = pass_command + [key_path] + return subprocess.check_output(command, encoding="utf-8").split("\n")[0] + +else: + + @functools.cache + def get_key(name: str, what: str = "api key") -> str: + key_path = _key_path_base / name + if not key_path.exists(): + raise NoKeyAvailable( + f"Place your {what} in {key_path} and run the program again" + ) + + with open(key_path, encoding="utf-8") as f: + return f.read().strip()