Add UsesKeyMixin & make all key names settable as parameters
This commit is contained in:
parent
06482245b7
commit
d38a98ad90
5 changed files with 34 additions and 24 deletions
|
|
@ -9,16 +9,17 @@ from typing import AsyncGenerator, Any
|
|||
import httpx
|
||||
|
||||
from ..core import AutoAskMixin, Backend
|
||||
from ..key import get_key
|
||||
from ..key import UsesKeyMixin
|
||||
from ..session import Assistant, Role, Session, User
|
||||
|
||||
|
||||
class Anthropic(AutoAskMixin):
|
||||
class Anthropic(AutoAskMixin, UsesKeyMixin):
|
||||
@dataclass
|
||||
class Parameters:
|
||||
url: str = "https://api.anthropic.com"
|
||||
model: str = "claude-3-5-sonnet-20240620"
|
||||
max_new_tokens: int = 1000
|
||||
api_key_name = "anthropic_api_key"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
|
@ -88,10 +89,6 @@ Answer each question accurately and thoroughly.
|
|||
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
@classmethod
|
||||
def get_key(cls) -> str:
|
||||
return get_key("anthropic_api_key")
|
||||
|
||||
|
||||
def factory() -> Backend:
|
||||
"""Uses the anthropic text-generation-interface web API"""
|
||||
|
|
|
|||
|
|
@ -9,11 +9,11 @@ from typing import Any, AsyncGenerator
|
|||
import httpx
|
||||
|
||||
from ..core import AutoAskMixin, Backend
|
||||
from ..key import get_key
|
||||
from ..key import UsesKeyMixin
|
||||
from ..session import Assistant, Role, Session, User
|
||||
|
||||
|
||||
class HuggingFace(AutoAskMixin):
|
||||
class HuggingFace(AutoAskMixin, UsesKeyMixin):
|
||||
@dataclass
|
||||
class Parameters:
|
||||
url: str = "https://api-inference.huggingface.co"
|
||||
|
|
@ -24,6 +24,7 @@ class HuggingFace(AutoAskMixin):
|
|||
after_user: str = """ [/INST] """
|
||||
after_assistant: str = """ </s><s>[INST] """
|
||||
stop_token_id = 2
|
||||
api_key_name = "huggingface_api_token"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
|
@ -110,10 +111,6 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
@classmethod
|
||||
def get_key(cls) -> str:
|
||||
return get_key("huggingface_api_token")
|
||||
|
||||
|
||||
def factory() -> Backend:
|
||||
"""Uses the huggingface text-generation-interface web API"""
|
||||
|
|
|
|||
|
|
@ -9,16 +9,17 @@ from typing import AsyncGenerator, Any
|
|||
import httpx
|
||||
|
||||
from ..core import AutoAskMixin
|
||||
from ..key import get_key
|
||||
from ..key import UsesKeyMixin
|
||||
from ..session import Assistant, Session, User
|
||||
|
||||
|
||||
class Mistral(AutoAskMixin):
|
||||
class Mistral(AutoAskMixin, UsesKeyMixin):
|
||||
@dataclass
|
||||
class Parameters:
|
||||
url: str = "https://api.mistral.ai"
|
||||
model: str = "open-mistral-7b"
|
||||
max_new_tokens: int = 1000
|
||||
api_key_name = "mistral_api_key"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
|
@ -91,9 +92,5 @@ Answer each question accurately and thoroughly.
|
|||
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
@classmethod
|
||||
def get_key(cls) -> str:
|
||||
return get_key("mistral_api_key")
|
||||
|
||||
|
||||
factory = Mistral
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import httpx
|
|||
import tiktoken
|
||||
|
||||
from ..core import Backend
|
||||
from ..key import get_key
|
||||
from ..key import UsesKeyMixin
|
||||
from ..session import Assistant, Message, Session, User, session_to_list
|
||||
|
||||
|
||||
|
|
@ -63,7 +63,7 @@ class EncodingMeta:
|
|||
return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)
|
||||
|
||||
|
||||
class ChatGPT:
|
||||
class ChatGPT(UsesKeyMixin):
|
||||
@dataclass
|
||||
class Parameters:
|
||||
model: str = "gpt-4o-mini"
|
||||
|
|
@ -81,6 +81,11 @@ class ChatGPT:
|
|||
top_p: float | None = None
|
||||
"""The model temperature for sampling"""
|
||||
|
||||
api_key_name: str = "openai_api_key"
|
||||
"""The OpenAI API key"""
|
||||
|
||||
parameters: Parameters
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.parameters = self.Parameters()
|
||||
|
||||
|
|
@ -171,10 +176,6 @@ class ChatGPT:
|
|||
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
@classmethod
|
||||
def get_key(cls) -> str:
|
||||
return get_key("openai_api_key")
|
||||
|
||||
|
||||
def factory() -> Backend:
|
||||
"""Uses the OpenAI chat completion API"""
|
||||
|
|
|
|||
|
|
@ -2,11 +2,29 @@
|
|||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from typing import Protocol
|
||||
import functools
|
||||
|
||||
import platformdirs
|
||||
|
||||
|
||||
class APIKeyProtocol(Protocol):
|
||||
@property
|
||||
def api_key_name(self) -> str:
|
||||
...
|
||||
|
||||
|
||||
class HasKeyProtocol(Protocol):
|
||||
@property
|
||||
def parameters(self) -> APIKeyProtocol:
|
||||
...
|
||||
|
||||
|
||||
class UsesKeyMixin:
|
||||
def get_key(self: HasKeyProtocol) -> str:
|
||||
return get_key(self.parameters.api_key_name)
|
||||
|
||||
|
||||
class NoKeyAvailable(Exception):
|
||||
pass
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue