Merge pull request #40 from jepler/improve-key-handling
This commit is contained in:
commit
993b17845c
5 changed files with 61 additions and 33 deletions
|
|
@ -9,16 +9,17 @@ from typing import AsyncGenerator, Any
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from ..core import AutoAskMixin, Backend
|
from ..core import AutoAskMixin, Backend
|
||||||
from ..key import get_key
|
from ..key import UsesKeyMixin
|
||||||
from ..session import Assistant, Role, Session, User
|
from ..session import Assistant, Role, Session, User
|
||||||
|
|
||||||
|
|
||||||
class Anthropic(AutoAskMixin):
|
class Anthropic(AutoAskMixin, UsesKeyMixin):
|
||||||
@dataclass
|
@dataclass
|
||||||
class Parameters:
|
class Parameters:
|
||||||
url: str = "https://api.anthropic.com"
|
url: str = "https://api.anthropic.com"
|
||||||
model: str = "claude-3-5-sonnet-20240620"
|
model: str = "claude-3-5-sonnet-20240620"
|
||||||
max_new_tokens: int = 1000
|
max_new_tokens: int = 1000
|
||||||
|
api_key_name = "anthropic_api_key"
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -88,10 +89,6 @@ Answer each question accurately and thoroughly.
|
||||||
|
|
||||||
session.extend([User(query), Assistant("".join(new_content))])
|
session.extend([User(query), Assistant("".join(new_content))])
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_key(cls) -> str:
|
|
||||||
return get_key("anthropic_api_key")
|
|
||||||
|
|
||||||
|
|
||||||
def factory() -> Backend:
|
def factory() -> Backend:
|
||||||
"""Uses the anthropic text-generation-interface web API"""
|
"""Uses the anthropic text-generation-interface web API"""
|
||||||
|
|
|
||||||
|
|
@ -9,11 +9,11 @@ from typing import Any, AsyncGenerator
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from ..core import AutoAskMixin, Backend
|
from ..core import AutoAskMixin, Backend
|
||||||
from ..key import get_key
|
from ..key import UsesKeyMixin
|
||||||
from ..session import Assistant, Role, Session, User
|
from ..session import Assistant, Role, Session, User
|
||||||
|
|
||||||
|
|
||||||
class HuggingFace(AutoAskMixin):
|
class HuggingFace(AutoAskMixin, UsesKeyMixin):
|
||||||
@dataclass
|
@dataclass
|
||||||
class Parameters:
|
class Parameters:
|
||||||
url: str = "https://api-inference.huggingface.co"
|
url: str = "https://api-inference.huggingface.co"
|
||||||
|
|
@ -24,6 +24,7 @@ class HuggingFace(AutoAskMixin):
|
||||||
after_user: str = """ [/INST] """
|
after_user: str = """ [/INST] """
|
||||||
after_assistant: str = """ </s><s>[INST] """
|
after_assistant: str = """ </s><s>[INST] """
|
||||||
stop_token_id = 2
|
stop_token_id = 2
|
||||||
|
api_key_name = "huggingface_api_token"
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -110,10 +111,6 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
||||||
|
|
||||||
session.extend([User(query), Assistant("".join(new_content))])
|
session.extend([User(query), Assistant("".join(new_content))])
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_key(cls) -> str:
|
|
||||||
return get_key("huggingface_api_token")
|
|
||||||
|
|
||||||
|
|
||||||
def factory() -> Backend:
|
def factory() -> Backend:
|
||||||
"""Uses the huggingface text-generation-interface web API"""
|
"""Uses the huggingface text-generation-interface web API"""
|
||||||
|
|
|
||||||
|
|
@ -9,16 +9,17 @@ from typing import AsyncGenerator, Any
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from ..core import AutoAskMixin
|
from ..core import AutoAskMixin
|
||||||
from ..key import get_key
|
from ..key import UsesKeyMixin
|
||||||
from ..session import Assistant, Session, User
|
from ..session import Assistant, Session, User
|
||||||
|
|
||||||
|
|
||||||
class Mistral(AutoAskMixin):
|
class Mistral(AutoAskMixin, UsesKeyMixin):
|
||||||
@dataclass
|
@dataclass
|
||||||
class Parameters:
|
class Parameters:
|
||||||
url: str = "https://api.mistral.ai"
|
url: str = "https://api.mistral.ai"
|
||||||
model: str = "open-mistral-7b"
|
model: str = "open-mistral-7b"
|
||||||
max_new_tokens: int = 1000
|
max_new_tokens: int = 1000
|
||||||
|
api_key_name = "mistral_api_key"
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -91,9 +92,5 @@ Answer each question accurately and thoroughly.
|
||||||
|
|
||||||
session.extend([User(query), Assistant("".join(new_content))])
|
session.extend([User(query), Assistant("".join(new_content))])
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_key(cls) -> str:
|
|
||||||
return get_key("mistral_api_key")
|
|
||||||
|
|
||||||
|
|
||||||
factory = Mistral
|
factory = Mistral
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ import httpx
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
from ..core import Backend
|
from ..core import Backend
|
||||||
from ..key import get_key
|
from ..key import UsesKeyMixin
|
||||||
from ..session import Assistant, Message, Session, User, session_to_list
|
from ..session import Assistant, Message, Session, User, session_to_list
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -63,7 +63,7 @@ class EncodingMeta:
|
||||||
return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)
|
return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)
|
||||||
|
|
||||||
|
|
||||||
class ChatGPT:
|
class ChatGPT(UsesKeyMixin):
|
||||||
@dataclass
|
@dataclass
|
||||||
class Parameters:
|
class Parameters:
|
||||||
model: str = "gpt-4o-mini"
|
model: str = "gpt-4o-mini"
|
||||||
|
|
@ -81,6 +81,11 @@ class ChatGPT:
|
||||||
top_p: float | None = None
|
top_p: float | None = None
|
||||||
"""The model temperature for sampling"""
|
"""The model temperature for sampling"""
|
||||||
|
|
||||||
|
api_key_name: str = "openai_api_key"
|
||||||
|
"""The OpenAI API key"""
|
||||||
|
|
||||||
|
parameters: Parameters
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.parameters = self.Parameters()
|
self.parameters = self.Parameters()
|
||||||
|
|
||||||
|
|
@ -171,10 +176,6 @@ class ChatGPT:
|
||||||
|
|
||||||
session.extend([User(query), Assistant("".join(new_content))])
|
session.extend([User(query), Assistant("".join(new_content))])
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_key(cls) -> str:
|
|
||||||
return get_key("openai_api_key")
|
|
||||||
|
|
||||||
|
|
||||||
def factory() -> Backend:
|
def factory() -> Backend:
|
||||||
"""Uses the OpenAI chat completion API"""
|
"""Uses the OpenAI chat completion API"""
|
||||||
|
|
|
||||||
|
|
@ -2,25 +2,61 @@
|
||||||
#
|
#
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
from typing import Protocol
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
import platformdirs
|
import platformdirs
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyProtocol(Protocol):
|
||||||
|
@property
|
||||||
|
def api_key_name(self) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class HasKeyProtocol(Protocol):
|
||||||
|
@property
|
||||||
|
def parameters(self) -> APIKeyProtocol:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class UsesKeyMixin:
|
||||||
|
def get_key(self: HasKeyProtocol) -> str:
|
||||||
|
return get_key(self.parameters.api_key_name)
|
||||||
|
|
||||||
|
|
||||||
class NoKeyAvailable(Exception):
|
class NoKeyAvailable(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
_key_path_base = platformdirs.user_config_path("chap")
|
_key_path_base = platformdirs.user_config_path("chap")
|
||||||
|
|
||||||
|
USE_PASSWORD_STORE = _key_path_base / "USE_PASSWORD_STORE"
|
||||||
|
|
||||||
@functools.cache
|
if USE_PASSWORD_STORE.exists():
|
||||||
def get_key(name: str, what: str = "openai api key") -> str:
|
content = USE_PASSWORD_STORE.read_text(encoding="utf-8")
|
||||||
key_path = _key_path_base / name
|
if content.strip():
|
||||||
if not key_path.exists():
|
cfg = json.loads(content)
|
||||||
raise NoKeyAvailable(
|
pass_command: list[str] = cfg.get("PASS_COMMAND", ["pass", "show"])
|
||||||
f"Place your {what} in {key_path} and run the program again"
|
pass_prefix: str = cfg.get("PASS_PREFIX", "chap/")
|
||||||
)
|
|
||||||
|
|
||||||
with open(key_path, encoding="utf-8") as f:
|
@functools.cache
|
||||||
return f.read().strip()
|
def get_key(name: str, what: str = "api key") -> str:
|
||||||
|
key_path = f"{pass_prefix}{name}"
|
||||||
|
command = pass_command + [key_path]
|
||||||
|
return subprocess.check_output(command, encoding="utf-8").split("\n")[0]
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def get_key(name: str, what: str = "api key") -> str:
|
||||||
|
key_path = _key_path_base / name
|
||||||
|
if not key_path.exists():
|
||||||
|
raise NoKeyAvailable(
|
||||||
|
f"Place your {what} in {key_path} and run the program again"
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(key_path, encoding="utf-8") as f:
|
||||||
|
return f.read().strip()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue