Merge pull request #40 from jepler/improve-key-handling
Commit 993b17845c

5 changed files with 61 additions and 33 deletions
Anthropic backend:

@@ -9,16 +9,17 @@ from typing import AsyncGenerator, Any
 import httpx

 from ..core import AutoAskMixin, Backend
-from ..key import get_key
+from ..key import UsesKeyMixin
 from ..session import Assistant, Role, Session, User


-class Anthropic(AutoAskMixin):
+class Anthropic(AutoAskMixin, UsesKeyMixin):
     @dataclass
     class Parameters:
         url: str = "https://api.anthropic.com"
         model: str = "claude-3-5-sonnet-20240620"
         max_new_tokens: int = 1000
+        api_key_name = "anthropic_api_key"

     def __init__(self) -> None:
         super().__init__()
@@ -88,10 +89,6 @@ Answer each question accurately and thoroughly.

         session.extend([User(query), Assistant("".join(new_content))])

-    @classmethod
-    def get_key(cls) -> str:
-        return get_key("anthropic_api_key")
-

 def factory() -> Backend:
     """Uses the anthropic text-generation-interface web API"""
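Across the four backends in this pull request the change is the same: the per-class get_key classmethod that hard-coded a key name is deleted, the class gains UsesKeyMixin, and the key name moves into Parameters.api_key_name. The same pattern repeats in the HuggingFace, Mistral, and ChatGPT hunks below. The sketch that follows shows the resulting lookup flow in isolation; get_key is stubbed and AutoAskMixin/Backend are omitted, so it illustrates the pattern rather than reproducing chap's actual modules.

# Minimal sketch of the new key-lookup flow (get_key is a stub here; the real
# implementation is in the key-module diff at the end of this page).
from dataclasses import dataclass


def get_key(name: str) -> str:
    # Stand-in for chap's key loader, which reads the named key file or
    # queries a password store.
    return f"<contents of key '{name}'>"


class UsesKeyMixin:
    def get_key(self) -> str:
        # Resolve whichever key name the backend declares on its Parameters.
        return get_key(self.parameters.api_key_name)


class Anthropic(UsesKeyMixin):  # AutoAskMixin omitted for brevity
    @dataclass
    class Parameters:
        api_key_name: str = "anthropic_api_key"

    def __init__(self) -> None:
        self.parameters = self.Parameters()


print(Anthropic().get_key())  # resolves "anthropic_api_key"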
HuggingFace backend:

@@ -9,11 +9,11 @@ from typing import Any, AsyncGenerator
 import httpx

 from ..core import AutoAskMixin, Backend
-from ..key import get_key
+from ..key import UsesKeyMixin
 from ..session import Assistant, Role, Session, User


-class HuggingFace(AutoAskMixin):
+class HuggingFace(AutoAskMixin, UsesKeyMixin):
     @dataclass
     class Parameters:
         url: str = "https://api-inference.huggingface.co"
@@ -24,6 +24,7 @@ class HuggingFace(AutoAskMixin):
         after_user: str = """ [/INST] """
         after_assistant: str = """ </s><s>[INST] """
         stop_token_id = 2
+        api_key_name = "huggingface_api_token"

     def __init__(self) -> None:
         super().__init__()
@@ -110,10 +111,6 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a

         session.extend([User(query), Assistant("".join(new_content))])

-    @classmethod
-    def get_key(cls) -> str:
-        return get_key("huggingface_api_token")
-

 def factory() -> Backend:
     """Uses the huggingface text-generation-interface web API"""
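The after_user and after_assistant fields visible as context above are Llama-style turn delimiters. The snippet below only illustrates how such delimiters frame a conversation; the turn list and the leading "<s>[INST] " tag are assumptions for the example, not chap's actual prompt-assembly code.

# Illustration only: framing a conversation with [INST]-style delimiters like the
# Parameters fields above. Not chap's real prompt builder.
after_user = """ [/INST] """
after_assistant = """ </s><s>[INST] """

turns = [
    ("user", "What is the capital of France?"),
    ("assistant", "Paris."),
    ("user", "And of Italy?"),
]

parts = ["<s>[INST] "]  # leading tag assumed for the example
for role, text in turns:
    parts.append(text)
    parts.append(after_user if role == "user" else after_assistant)

# Produces: <s>[INST] What is the capital of France? [/INST] Paris. </s><s>[INST] And of Italy? [/INST]
print("".join(parts))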
Mistral backend:

@@ -9,16 +9,17 @@ from typing import AsyncGenerator, Any
 import httpx

 from ..core import AutoAskMixin
-from ..key import get_key
+from ..key import UsesKeyMixin
 from ..session import Assistant, Session, User


-class Mistral(AutoAskMixin):
+class Mistral(AutoAskMixin, UsesKeyMixin):
     @dataclass
     class Parameters:
         url: str = "https://api.mistral.ai"
         model: str = "open-mistral-7b"
         max_new_tokens: int = 1000
+        api_key_name = "mistral_api_key"

     def __init__(self) -> None:
         super().__init__()
@@ -91,9 +92,5 @@ Answer each question accurately and thoroughly.

         session.extend([User(query), Assistant("".join(new_content))])

-    @classmethod
-    def get_key(cls) -> str:
-        return get_key("mistral_api_key")
-

 factory = Mistral
OpenAI ChatGPT backend:

@@ -12,7 +12,7 @@ import httpx
 import tiktoken

 from ..core import Backend
-from ..key import get_key
+from ..key import UsesKeyMixin
 from ..session import Assistant, Message, Session, User, session_to_list


@@ -63,7 +63,7 @@ class EncodingMeta:
         return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)


-class ChatGPT:
+class ChatGPT(UsesKeyMixin):
     @dataclass
     class Parameters:
         model: str = "gpt-4o-mini"
@@ -81,6 +81,11 @@ class ChatGPT:
         top_p: float | None = None
         """The model temperature for sampling"""

+        api_key_name: str = "openai_api_key"
+        """The OpenAI API key"""
+
+    parameters: Parameters
+
     def __init__(self) -> None:
         self.parameters = self.Parameters()

@@ -171,10 +176,6 @@ class ChatGPT:

         session.extend([User(query), Assistant("".join(new_content))])

-    @classmethod
-    def get_key(cls) -> str:
-        return get_key("openai_api_key")
-

 def factory() -> Backend:
     """Uses the OpenAI chat completion API"""
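Besides switching ChatGPT to UsesKeyMixin, this file adds a parameters: Parameters annotation next to the new api_key_name field. Together with the Protocol classes introduced in the key module below, that annotation is what lets a type checker see that each backend structurally satisfies HasKeyProtocol without inheriting from it. A minimal sketch of that idea follows; ExampleBackend and ExampleParameters are invented stand-ins for the real classes.

# Structural-typing sketch: the protocol names follow the key-module diff below;
# ExampleBackend/ExampleParameters are invented for illustration.
from dataclasses import dataclass
from typing import Protocol


class APIKeyProtocol(Protocol):
    @property
    def api_key_name(self) -> str: ...


class HasKeyProtocol(Protocol):
    @property
    def parameters(self) -> APIKeyProtocol: ...


@dataclass
class ExampleParameters:
    api_key_name: str = "example_api_key"


class ExampleBackend:
    parameters: ExampleParameters  # the annotation the ChatGPT hunk adds

    def __init__(self) -> None:
        self.parameters = ExampleParameters()


def key_name(backend: HasKeyProtocol) -> str:
    # Any object whose .parameters exposes .api_key_name type-checks here;
    # no inheritance from HasKeyProtocol is required.
    return backend.parameters.api_key_name


print(key_name(ExampleBackend()))  # example_api_key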
Key-handling module:

@@ -2,25 +2,61 @@
 #
 # SPDX-License-Identifier: MIT

+import json
+import subprocess
+from typing import Protocol
 import functools

 import platformdirs


+class APIKeyProtocol(Protocol):
+    @property
+    def api_key_name(self) -> str:
+        ...
+
+
+class HasKeyProtocol(Protocol):
+    @property
+    def parameters(self) -> APIKeyProtocol:
+        ...
+
+
+class UsesKeyMixin:
+    def get_key(self: HasKeyProtocol) -> str:
+        return get_key(self.parameters.api_key_name)
+
+
 class NoKeyAvailable(Exception):
     pass


 _key_path_base = platformdirs.user_config_path("chap")

+USE_PASSWORD_STORE = _key_path_base / "USE_PASSWORD_STORE"

-@functools.cache
-def get_key(name: str, what: str = "openai api key") -> str:
-    key_path = _key_path_base / name
-    if not key_path.exists():
-        raise NoKeyAvailable(
-            f"Place your {what} in {key_path} and run the program again"
-        )
+if USE_PASSWORD_STORE.exists():
+    content = USE_PASSWORD_STORE.read_text(encoding="utf-8")
+    if content.strip():
+        cfg = json.loads(content)
+        pass_command: list[str] = cfg.get("PASS_COMMAND", ["pass", "show"])
+        pass_prefix: str = cfg.get("PASS_PREFIX", "chap/")

-    with open(key_path, encoding="utf-8") as f:
-        return f.read().strip()
+    @functools.cache
+    def get_key(name: str, what: str = "api key") -> str:
+        key_path = f"{pass_prefix}{name}"
+        command = pass_command + [key_path]
+        return subprocess.check_output(command, encoding="utf-8").split("\n")[0]
+
+else:
+
+    @functools.cache
+    def get_key(name: str, what: str = "api key") -> str:
+        key_path = _key_path_base / name
+        if not key_path.exists():
+            raise NoKeyAvailable(
+                f"Place your {what} in {key_path} and run the program again"
+            )
+
+        with open(key_path, encoding="utf-8") as f:
+            return f.read().strip()
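The rewritten key module keeps the old behaviour, reading the key from a file named after api_key_name under the chap config directory, as the fallback, and adds an opt-in password-store path: when a USE_PASSWORD_STORE marker file exists, get_key() shells out to a pass-style command and keeps only the first line of its output. The sketch below shows both modes; the JSON keys and defaults come straight from the cfg.get() calls above, while the concrete ~/.config/chap path is an assumption based on platformdirs defaults on Linux.

# Sketch of the two lookup modes, assuming platformdirs resolves the chap
# config directory to ~/.config/chap (Linux default).
from pathlib import Path

config_dir = Path.home() / ".config" / "chap"  # _key_path_base on Linux
marker = config_dir / "USE_PASSWORD_STORE"

# File mode (marker absent): get_key("openai_api_key") reads and strips
#   ~/.config/chap/openai_api_key

# Password-store mode (marker present): with a marker containing just "{}",
# the defaults apply and get_key("openai_api_key") runs roughly
#   pass show chap/openai_api_key
# keeping only the first line of the output. The JSON can override both knobs:
example_marker = '{"PASS_COMMAND": ["pass", "show"], "PASS_PREFIX": "work/chap/"}'
# marker.write_text(example_marker, encoding="utf-8")  # -> pass show work/chap/openai_api_key

In either mode, @functools.cache means each key name is resolved at most once per process.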