Add UsesKeyMixin & make all key names settable as parameters

This commit is contained in:
Jeff Epler 2024-10-23 11:42:31 -05:00
parent 06482245b7
commit d38a98ad90
5 changed files with 34 additions and 24 deletions

View file

@ -9,16 +9,17 @@ from typing import AsyncGenerator, Any
import httpx
from ..core import AutoAskMixin, Backend
from ..key import get_key
from ..key import UsesKeyMixin
from ..session import Assistant, Role, Session, User
class Anthropic(AutoAskMixin):
class Anthropic(AutoAskMixin, UsesKeyMixin):
@dataclass
class Parameters:
url: str = "https://api.anthropic.com"
model: str = "claude-3-5-sonnet-20240620"
max_new_tokens: int = 1000
api_key_name = "anthropic_api_key"
def __init__(self) -> None:
super().__init__()
@ -88,10 +89,6 @@ Answer each question accurately and thoroughly.
session.extend([User(query), Assistant("".join(new_content))])
@classmethod
def get_key(cls) -> str:
return get_key("anthropic_api_key")
def factory() -> Backend:
"""Uses the anthropic text-generation-interface web API"""

View file

@ -9,11 +9,11 @@ from typing import Any, AsyncGenerator
import httpx
from ..core import AutoAskMixin, Backend
from ..key import get_key
from ..key import UsesKeyMixin
from ..session import Assistant, Role, Session, User
class HuggingFace(AutoAskMixin):
class HuggingFace(AutoAskMixin, UsesKeyMixin):
@dataclass
class Parameters:
url: str = "https://api-inference.huggingface.co"
@ -24,6 +24,7 @@ class HuggingFace(AutoAskMixin):
after_user: str = """ [/INST] """
after_assistant: str = """ </s><s>[INST] """
stop_token_id = 2
api_key_name = "huggingface_api_token"
def __init__(self) -> None:
super().__init__()
@ -110,10 +111,6 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
session.extend([User(query), Assistant("".join(new_content))])
@classmethod
def get_key(cls) -> str:
return get_key("huggingface_api_token")
def factory() -> Backend:
"""Uses the huggingface text-generation-interface web API"""

View file

@ -9,16 +9,17 @@ from typing import AsyncGenerator, Any
import httpx
from ..core import AutoAskMixin
from ..key import get_key
from ..key import UsesKeyMixin
from ..session import Assistant, Session, User
class Mistral(AutoAskMixin):
class Mistral(AutoAskMixin, UsesKeyMixin):
@dataclass
class Parameters:
url: str = "https://api.mistral.ai"
model: str = "open-mistral-7b"
max_new_tokens: int = 1000
api_key_name = "mistral_api_key"
def __init__(self) -> None:
super().__init__()
@ -91,9 +92,5 @@ Answer each question accurately and thoroughly.
session.extend([User(query), Assistant("".join(new_content))])
@classmethod
def get_key(cls) -> str:
return get_key("mistral_api_key")
factory = Mistral

View file

@ -12,7 +12,7 @@ import httpx
import tiktoken
from ..core import Backend
from ..key import get_key
from ..key import UsesKeyMixin
from ..session import Assistant, Message, Session, User, session_to_list
@ -63,7 +63,7 @@ class EncodingMeta:
return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)
class ChatGPT:
class ChatGPT(UsesKeyMixin):
@dataclass
class Parameters:
model: str = "gpt-4o-mini"
@ -81,6 +81,11 @@ class ChatGPT:
top_p: float | None = None
"""The model temperature for sampling"""
api_key_name: str = "openai_api_key"
"""The OpenAI API key"""
parameters: Parameters
def __init__(self) -> None:
self.parameters = self.Parameters()
@ -171,10 +176,6 @@ class ChatGPT:
session.extend([User(query), Assistant("".join(new_content))])
@classmethod
def get_key(cls) -> str:
return get_key("openai_api_key")
def factory() -> Backend:
"""Uses the OpenAI chat completion API"""

View file

@ -2,11 +2,29 @@
#
# SPDX-License-Identifier: MIT
from typing import Protocol
import functools
import platformdirs
class APIKeyProtocol(Protocol):
@property
def api_key_name(self) -> str:
...
class HasKeyProtocol(Protocol):
@property
def parameters(self) -> APIKeyProtocol:
...
class UsesKeyMixin:
def get_key(self: HasKeyProtocol) -> str:
return get_key(self.parameters.api_key_name)
class NoKeyAvailable(Exception):
pass