Merge pull request #40 from jepler/improve-key-handling

commit 993b17845c
Jeff Epler, 2024-10-23 11:47:20 -05:00, committed via GitHub
No known key found for this signature in database (GPG key ID: B5690EEEBB952194)
5 changed files with 61 additions and 33 deletions

View file

@@ -9,16 +9,17 @@ from typing import AsyncGenerator, Any
 import httpx
 
 from ..core import AutoAskMixin, Backend
-from ..key import get_key
+from ..key import UsesKeyMixin
 from ..session import Assistant, Role, Session, User
 
 
-class Anthropic(AutoAskMixin):
+class Anthropic(AutoAskMixin, UsesKeyMixin):
     @dataclass
     class Parameters:
         url: str = "https://api.anthropic.com"
         model: str = "claude-3-5-sonnet-20240620"
         max_new_tokens: int = 1000
+        api_key_name = "anthropic_api_key"
 
     def __init__(self) -> None:
         super().__init__()
@@ -88,10 +89,6 @@ Answer each question accurately and thoroughly.
         session.extend([User(query), Assistant("".join(new_content))])
 
-    @classmethod
-    def get_key(cls) -> str:
-        return get_key("anthropic_api_key")
-
 
 def factory() -> Backend:
     """Uses the anthropic text-generation-interface web API"""

View file

@@ -9,11 +9,11 @@ from typing import Any, AsyncGenerator
 import httpx
 
 from ..core import AutoAskMixin, Backend
-from ..key import get_key
+from ..key import UsesKeyMixin
 from ..session import Assistant, Role, Session, User
 
 
-class HuggingFace(AutoAskMixin):
+class HuggingFace(AutoAskMixin, UsesKeyMixin):
     @dataclass
     class Parameters:
         url: str = "https://api-inference.huggingface.co"
@@ -24,6 +24,7 @@ class HuggingFace(AutoAskMixin):
         after_user: str = """ [/INST] """
         after_assistant: str = """ </s><s>[INST] """
         stop_token_id = 2
+        api_key_name = "huggingface_api_token"
 
     def __init__(self) -> None:
         super().__init__()
@@ -110,10 +111,6 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
         session.extend([User(query), Assistant("".join(new_content))])
 
-    @classmethod
-    def get_key(cls) -> str:
-        return get_key("huggingface_api_token")
-
 
 def factory() -> Backend:
     """Uses the huggingface text-generation-interface web API"""

View file

@@ -9,16 +9,17 @@ from typing import AsyncGenerator, Any
 import httpx
 
 from ..core import AutoAskMixin
-from ..key import get_key
+from ..key import UsesKeyMixin
 from ..session import Assistant, Session, User
 
 
-class Mistral(AutoAskMixin):
+class Mistral(AutoAskMixin, UsesKeyMixin):
     @dataclass
     class Parameters:
         url: str = "https://api.mistral.ai"
         model: str = "open-mistral-7b"
         max_new_tokens: int = 1000
+        api_key_name = "mistral_api_key"
 
     def __init__(self) -> None:
         super().__init__()
@@ -91,9 +92,5 @@ Answer each question accurately and thoroughly.
         session.extend([User(query), Assistant("".join(new_content))])
 
-    @classmethod
-    def get_key(cls) -> str:
-        return get_key("mistral_api_key")
-
 
 factory = Mistral

View file

@@ -12,7 +12,7 @@ import httpx
 import tiktoken
 
 from ..core import Backend
-from ..key import get_key
+from ..key import UsesKeyMixin
 from ..session import Assistant, Message, Session, User, session_to_list
@@ -63,7 +63,7 @@ class EncodingMeta:
         return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)
 
 
-class ChatGPT:
+class ChatGPT(UsesKeyMixin):
     @dataclass
     class Parameters:
         model: str = "gpt-4o-mini"
@@ -81,6 +81,11 @@ class ChatGPT:
         top_p: float | None = None
         """The model temperature for sampling"""
 
+        api_key_name: str = "openai_api_key"
+        """The OpenAI API key"""
+
+    parameters: Parameters
+
     def __init__(self) -> None:
         self.parameters = self.Parameters()
@@ -171,10 +176,6 @@ class ChatGPT:
         session.extend([User(query), Assistant("".join(new_content))])
 
-    @classmethod
-    def get_key(cls) -> str:
-        return get_key("openai_api_key")
-
 
 def factory() -> Backend:
     """Uses the OpenAI chat completion API"""

View file

@@ -2,25 +2,61 @@
 #
 # SPDX-License-Identifier: MIT
 
+import json
+import subprocess
+from typing import Protocol
+
 import functools
 
 import platformdirs
 
 
+class APIKeyProtocol(Protocol):
+    @property
+    def api_key_name(self) -> str:
+        ...
+
+
+class HasKeyProtocol(Protocol):
+    @property
+    def parameters(self) -> APIKeyProtocol:
+        ...
+
+
+class UsesKeyMixin:
+    def get_key(self: HasKeyProtocol) -> str:
+        return get_key(self.parameters.api_key_name)
+
+
 class NoKeyAvailable(Exception):
     pass
 
 
 _key_path_base = platformdirs.user_config_path("chap")
+USE_PASSWORD_STORE = _key_path_base / "USE_PASSWORD_STORE"
 
-@functools.cache
-def get_key(name: str, what: str = "openai api key") -> str:
-    key_path = _key_path_base / name
-    if not key_path.exists():
-        raise NoKeyAvailable(
-            f"Place your {what} in {key_path} and run the program again"
-        )
-    with open(key_path, encoding="utf-8") as f:
-        return f.read().strip()
+if USE_PASSWORD_STORE.exists():
+    content = USE_PASSWORD_STORE.read_text(encoding="utf-8")
+    if content.strip():
+        cfg = json.loads(content)
+    pass_command: list[str] = cfg.get("PASS_COMMAND", ["pass", "show"])
+    pass_prefix: str = cfg.get("PASS_PREFIX", "chap/")
+
+    @functools.cache
+    def get_key(name: str, what: str = "api key") -> str:
+        key_path = f"{pass_prefix}{name}"
+        command = pass_command + [key_path]
+        return subprocess.check_output(command, encoding="utf-8").split("\n")[0]
+
+else:
+
+    @functools.cache
+    def get_key(name: str, what: str = "api key") -> str:
+        key_path = _key_path_base / name
+        if not key_path.exists():
+            raise NoKeyAvailable(
+                f"Place your {what} in {key_path} and run the program again"
+            )
+        with open(key_path, encoding="utf-8") as f:
+            return f.read().strip()
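
With this change, key.py chooses its key source once at import time: if a USE_PASSWORD_STORE marker file exists in the chap config directory, get_key shells out to pass (or whatever PASS_COMMAND specifies) under the PASS_PREFIX namespace; otherwise the original plain-file lookup is kept. A sketch of opting in, using only the PASS_COMMAND and PASS_PREFIX settings shown in this diff; the gopass substitution is just an example, and the exact directory comes from platformdirs, so it varies by platform:

# Sketch only: enable the password-store branch by writing the marker file.
import json

import platformdirs

cfg_dir = platformdirs.user_config_path("chap")
cfg_dir.mkdir(parents=True, exist_ok=True)

# A JSON body overrides the defaults ("pass show" and the "chap/" prefix);
# here it swaps in gopass purely as an illustration.
(cfg_dir / "USE_PASSWORD_STORE").write_text(
    json.dumps({"PASS_COMMAND": ["gopass", "show"], "PASS_PREFIX": "chap/"}),
    encoding="utf-8",
)

# On the next run, get_key("anthropic_api_key") executes
#   gopass show chap/anthropic_api_key
# and returns the first line of its output.

Keys are then stored under that prefix, e.g. with pass insert chap/anthropic_api_key (or the gopass equivalent).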