diff --git a/src/chap/backends/llama_cpp.py b/src/chap/backends/llama_cpp.py
index 381b825..99ddece 100644
--- a/src/chap/backends/llama_cpp.py
+++ b/src/chap/backends/llama_cpp.py
@@ -18,10 +18,10 @@ class LlamaCpp(AutoAskMixin):
         url: str = "http://localhost:8080/completion"
         """The URL of a llama.cpp server's completion endpoint."""
 
-        start_prompt: str = """[INST] <<SYS>>\n"""
-        after_system: str = "\n<</SYS>>\n\n"
-        after_user: str = """ [/INST] """
-        after_assistant: str = """ [INST] """
+        start_prompt: str = ""
+        system_format: str = "<<SYS>>{}<</SYS>>"
+        user_format: str = " [INST] {} [/INST]"
+        assistant_format: str = " {}"
 
     def __init__(self) -> None:
         super().__init__()
@@ -34,18 +34,17 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
     def make_full_query(self, messages: Session, max_query_size: int) -> str:
         del messages[1:-max_query_size]
         result = [self.parameters.start_prompt]
+        formats = {
+            Role.SYSTEM: self.parameters.system_format,
+            Role.USER: self.parameters.user_format,
+            Role.ASSISTANT: self.parameters.assistant_format,
+        }
         for m in messages:
             content = (m.content or "").strip()
             if not content:
                 continue
-            result.append(content)
-            if m.role == Role.SYSTEM:
-                result.append(self.parameters.after_system)
-            elif m.role == Role.ASSISTANT:
-                result.append(self.parameters.after_assistant)
-            elif m.role == Role.USER:
-                result.append(self.parameters.after_user)
+            result.append(formats[m.role].format(content))
         full_query = "".join(result)
         return full_query
 
     async def aask(
diff --git a/src/chap/backends/mistral.py b/src/chap/backends/mistral.py
new file mode 100644
index 0000000..c254632
--- /dev/null
+++ b/src/chap/backends/mistral.py
@@ -0,0 +1,98 @@
+# SPDX-FileCopyrightText: 2024 Jeff Epler
+#
+# SPDX-License-Identifier: MIT
+
+import json
+from dataclasses import dataclass
+from typing import AsyncGenerator, Any
+
+import httpx
+
+from ..core import AutoAskMixin
+from ..key import get_key
+from ..session import Assistant, Session, User
+
+
+class Mistral(AutoAskMixin):
+    @dataclass
+    class Parameters:
+        url: str = "https://api.mistral.ai"
+        model: str = "open-mistral-7b"
+        max_new_tokens: int = 1000
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.parameters = self.Parameters()
+
+    system_message = """\
+Answer each question accurately and thoroughly.
+"""
+
+    def make_full_query(self, messages: Session, max_query_size: int) -> dict[str, Any]:
+        messages = [m for m in messages if m.content]
+        del messages[1:-max_query_size]
+        result = dict(
+            model=self.parameters.model,
+            max_tokens=self.parameters.max_new_tokens,
+            messages=[dict(role=str(m.role), content=m.content) for m in messages],
+            stream=True,
+        )
+        return result
+
+    async def aask(
+        self,
+        session: Session,
+        query: str,
+        *,
+        max_query_size: int = 5,
+        timeout: float = 180,
+    ) -> AsyncGenerator[str, None]:
+        new_content: list[str] = []
+        params = self.make_full_query(session + [User(query)], max_query_size)
+        try:
+            async with httpx.AsyncClient(timeout=timeout) as client:
+                async with client.stream(
+                    "POST",
+                    f"{self.parameters.url}/v1/chat/completions",
+                    json=params,
+                    headers={
+                        "Authorization": f"Bearer {self.get_key()}",
+                        "content-type": "application/json",
+                        "accept": "application/json",
+                    },
+                ) as response:
+                    if response.status_code == 200:
+                        async for line in response.aiter_lines():
+                            if line.startswith("data:"):
+                                data = line.removeprefix("data:").strip()
+                                if data == "[DONE]":
+                                    break
+                                j = json.loads(data)
+                                content = (
+                                    j.get("choices", [{}])[0]
+                                    .get("delta", {})
+                                    .get("content", "")
+                                )
+                                if content:
+                                    new_content.append(content)
+                                    yield content
+                    else:
+                        content = f"\nFailed with {response=!r}"
+                        new_content.append(content)
+                        yield content
+                        async for line in response.aiter_lines():
+                            new_content.append(line)
+                            yield line
+        except httpx.HTTPError as e:
+            content = f"\nException: {e!r}"
+            new_content.append(content)
+            yield content
+
+        session.extend([User(query), Assistant("".join(new_content))])
+
+    @classmethod
+    def get_key(cls) -> str:
+        return get_key("mistral_api_key")
+
+
+factory = Mistral