diff --git a/src/chap/backends/llama_cpp.py b/src/chap/backends/llama_cpp.py
index 381b825..99ddece 100644
--- a/src/chap/backends/llama_cpp.py
+++ b/src/chap/backends/llama_cpp.py
@@ -18,10 +18,10 @@ class LlamaCpp(AutoAskMixin):
url: str = "http://localhost:8080/completion"
"""The URL of a llama.cpp server's completion endpoint."""
- start_prompt: str = """[INST] <>\n"""
- after_system: str = "\n<>\n\n"
- after_user: str = """ [/INST] """
- after_assistant: str = """ [INST] """
+ start_prompt: str = ""
+ system_format: str = "<>{}<>"
+ user_format: str = " [INST] {} [/INST]"
+ assistant_format: str = " {}"
def __init__(self) -> None:
super().__init__()
@@ -34,18 +34,18 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
def make_full_query(self, messages: Session, max_query_size: int) -> str:
del messages[1:-max_query_size]
result = [self.parameters.start_prompt]
+ formats = {
+ Role.SYSTEM: self.parameters.system_format,
+ Role.USER: self.parameters.user_format,
+ Role.ASSISTANT: self.parameters.assistant_format,
+ }
for m in messages:
content = (m.content or "").strip()
if not content:
continue
- result.append(content)
- if m.role == Role.SYSTEM:
- result.append(self.parameters.after_system)
- elif m.role == Role.ASSISTANT:
- result.append(self.parameters.after_assistant)
- elif m.role == Role.USER:
- result.append(self.parameters.after_user)
+ result.append(formats[m.role].format(content))
full_query = "".join(result)
+ print("fq", full_query)
return full_query
async def aask(
diff --git a/src/chap/backends/mistral.py b/src/chap/backends/mistral.py
new file mode 100644
index 0000000..c254632
--- /dev/null
+++ b/src/chap/backends/mistral.py
@@ -0,0 +1,99 @@
+# SPDX-FileCopyrightText: 2024 Jeff Epler
+#
+# SPDX-License-Identifier: MIT
+
+import json
+from dataclasses import dataclass
+from typing import AsyncGenerator, Any
+
+import httpx
+
+from ..core import AutoAskMixin
+from ..key import get_key
+from ..session import Assistant, Session, User
+
+
+class Mistral(AutoAskMixin):
+ @dataclass
+ class Parameters:
+ url: str = "https://api.mistral.ai"
+ model: str = "open-mistral-7b"
+ max_new_tokens: int = 1000
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.parameters = self.Parameters()
+
+ system_message = """\
+Answer each question accurately and thoroughly.
+"""
+
+ def make_full_query(self, messages: Session, max_query_size: int) -> dict[str, Any]:
+ messages = [m for m in messages if m.content]
+ del messages[1:-max_query_size]
+ result = dict(
+ model=self.parameters.model,
+ max_tokens=self.parameters.max_new_tokens,
+ messages=[dict(role=str(m.role), content=m.content) for m in messages],
+ stream=True,
+ )
+ return result
+
+ async def aask(
+ self,
+ session: Session,
+ query: str,
+ *,
+ max_query_size: int = 5,
+ timeout: float = 180,
+ ) -> AsyncGenerator[str, None]:
+ new_content: list[str] = []
+ params = self.make_full_query(session + [User(query)], max_query_size)
+ try:
+ async with httpx.AsyncClient(timeout=timeout) as client:
+ async with client.stream(
+ "POST",
+ f"{self.parameters.url}/v1/chat/completions",
+ json=params,
+ headers={
+ "Authorization": f"Bearer {self.get_key()}",
+ "content-type": "application/json",
+ "accept": "application/json",
+ "model": "application/json",
+ },
+ ) as response:
+ if response.status_code == 200:
+ async for line in response.aiter_lines():
+ if line.startswith("data:"):
+ data = line.removeprefix("data:").strip()
+ if data == "[DONE]":
+ break
+ j = json.loads(data)
+ content = (
+ j.get("choices", [{}])[0]
+ .get("delta", {})
+ .get("content", "")
+ )
+ if content:
+ new_content.append(content)
+ yield content
+ else:
+ content = f"\nFailed with {response=!r}"
+ new_content.append(content)
+ yield content
+ async for line in response.aiter_lines():
+ new_content.append(line)
+ yield line
+ except httpx.HTTPError as e:
+ content = f"\nException: {e!r}"
+ new_content.append(content)
+ yield content
+
+ session.extend([User(query), Assistant("".join(new_content))])
+
+ @classmethod
+ def get_key(cls) -> str:
+ return get_key("mistral_api_key")
+
+
+factory = Mistral