diff --git a/src/chap/backends/huggingface.py b/src/chap/backends/huggingface.py
new file mode 100644
index 0000000..043c3db
--- /dev/null
+++ b/src/chap/backends/huggingface.py
@@ -0,0 +1,167 @@
+# SPDX-FileCopyrightText: 2023 Jeff Epler
+#
+# SPDX-License-Identifier: MIT
+
+import asyncio
+import json
+from dataclasses import dataclass
+
+import httpx
+
+from ..key import get_key
+from ..session import Assistant, User
+
+
+class HuggingFace:
+    """Chat backend that streams replies from a Hugging Face inference endpoint."""
+
+    @dataclass
+    class Parameters:
+        url: str = "https://api-inference.huggingface.co"
+        """The base URL of the inference API; requests go to URL/models/MODEL."""
+
+        model: str = "mistralai/Mistral-7B-Instruct-v0.1"
+        """The name of the model to query."""
+
+        # NOTE(review): max_new_tokens is declared but never added to the
+        # request built in chained_query -- confirm whether it should be sent.
+        max_new_tokens: int = 250
+
+        start_prompt: str = """[INST] <>\n"""
+        """Text placed at the very start of every prompt."""
+
+        after_system: str = "\n<>\n\n"
+        """Text appended after each system message."""
+
+        after_user: str = """ [/INST] """
+        """Text appended after each user message."""
+
+        after_assistant: str = """ [INST] """
+        """Text appended after each assistant message."""
+
+        stop_token_id: int = 2
+        """Streaming stops as soon as a token with this id is received."""
+
+    def __init__(self):
+        self.parameters = self.Parameters()
+
+    system_message = """\
+A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.
+"""
+
+    def make_full_query(self, messages, max_query_size):
+        """Render `messages` as one prompt string.
+
+        For a positive `max_query_size`, only the first message and the last
+        `max_query_size` messages are used.  Messages whose content is empty
+        or whitespace are skipped; every other message contributes its
+        stripped content followed by the separator configured for its role.
+        """
+        # Trim a copy so that the caller's list is never modified.
+        messages = list(messages)
+        del messages[1:-max_query_size]
+        result = [self.parameters.start_prompt]
+        for m in messages:
+            content = (m.content or "").strip()
+            if not content:
+                continue
+            result.append(content)
+            if m.role == "system":
+                result.append(self.parameters.after_system)
+            elif m.role == "assistant":
+                result.append(self.parameters.after_assistant)
+            elif m.role == "user":
+                result.append(self.parameters.after_user)
+        return "".join(result)
+
+    async def chained_query(self, inputs, timeout):
+        """Stream generated text for `inputs`, yielding it piece by piece.
+
+        Every "data:" line of the response is parsed as JSON and its token
+        text is yielded until the configured stop token arrives.  If the
+        stream ends without the stop token but reported "generated_text",
+        that text is posted as the input of a follow-up request.  A non-200
+        response yields a failure message and ends the stream.
+        """
+        async with httpx.AsyncClient(timeout=timeout) as client:
+            while inputs:
+                params = {
+                    "inputs": inputs,
+                    "stream": True,
+                }
+                # Cleared here; only a "generated_text" field in the response
+                # sets it again and keeps the loop going.
+                inputs = None
+                async with client.stream(
+                    "POST",
+                    f"{self.parameters.url}/models/{self.parameters.model}",
+                    json=params,
+                    headers={
+                        "Authorization": f"Bearer {self.get_key()}",
+                    },
+                ) as response:
+                    if response.status_code == 200:
+                        async for line in response.aiter_lines():
+                            if line.startswith("data:"):
+                                data = line.removeprefix("data:").strip()
+                                j = json.loads(data)
+                                token = j.get("token", {})
+                                inputs = j.get("generated_text", inputs)
+                                if token.get("id") == self.parameters.stop_token_id:
+                                    return
+                                yield token.get("text", "")
+                    else:
+                        yield f"\nFailed with {response=!r}"
+                        return
+
+    async def aask(
+        self, session, query, *, max_query_size=5, timeout=180
+    ):  # pylint: disable=unused-argument,too-many-locals,too-many-branches
+        """Ask `query` in the context of `session`, yielding the reply as it streams.
+
+        HTTP errors are reported as part of the reply instead of being raised.
+        Once the stream ends, the query and the complete reply are appended to
+        `session.session`.
+        """
+        new_content = []
+        inputs = self.make_full_query(session.session + [User(query)], max_query_size)
+        try:
+            async for content in self.chained_query(inputs, timeout=timeout):
+                if not new_content:
+                    # Nothing has been emitted yet: drop leading whitespace.
+                    content = content.lstrip()
+                if content:
+                    new_content.append(content)
+                    yield content
+
+        except httpx.HTTPError as e:
+            content = f"\nException: {e!r}"
+            new_content.append(content)
+            yield content
+
+        session.session.extend([User(query), Assistant("".join(new_content))])
+
+    def ask(self, session, query, *, max_query_size=5, timeout=60):
+        """Blocking variant of `aask`; returns the text of the assistant's reply."""
+
+        async def drain():
+            # aask is an async generator and asyncio.run() only accepts a
+            # coroutine, so the generator is consumed inside one.
+            async for _ in self.aask(
+                session, query, max_query_size=max_query_size, timeout=timeout
+            ):
+                pass
+
+        asyncio.run(drain())
+        # Messages expose their text as `.content`, as make_full_query reads it.
+        return session.session[-1].content
+
+    @classmethod
+    def get_key(cls):
+        """Look up "huggingface_api_token" through the project's get_key helper."""
+        return get_key("huggingface_api_token")
+
+
+def factory():
+    """Uses the huggingface text-generation-interface web API"""
+    return HuggingFace()
diff --git a/src/chap/backends/llama_cpp.py b/src/chap/backends/llama_cpp.py
index 4854b43..53a0c9f 100644
--- a/src/chap/backends/llama_cpp.py
+++ b/src/chap/backends/llama_cpp.py
@@ -17,6 +17,11 @@ class LlamaCpp:
         url: str = "http://localhost:8080/completion"
         """The URL of a llama.cpp server's completion endpoint."""
 
+        start_prompt: str = """[INST] <>\n"""
+        after_system: str = "\n<>\n\n"
+        after_user: str = """ [/INST] """
+        after_assistant: str = """ [INST] """
+
     def __init__(self):
         self.parameters = self.Parameters()
 
@@ -26,29 +31,30 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
 
     def make_full_query(self, messages, max_query_size):
         del messages[1:-max_query_size]
-        rows = []
+        result = [self.parameters.start_prompt]
         for m in messages:
             content = (m.content or "").strip()
             if not content:
                 continue
+            result.append(content)
             if m.role == "system":
-                rows.append(f"ASSISTANT'S RULE: {content}\n")
+                result.append(self.parameters.after_system)
             elif m.role == "assistant":
-                rows.append(f"ASSISTANT: {content}\n")
+                result.append(self.parameters.after_assistant)
             elif m.role == "user":
-                rows.append(f"USER: {content}")
-        rows.append("ASSISTANT: ")
-        full_query = ("\n".join(rows)).rstrip()
+                result.append(self.parameters.after_user)
+        full_query = "".join(result)
         return full_query
 
     async def aask(
-        self, session, query, *, max_query_size=5, timeout=60
+        self, session, query, *, max_query_size=5, timeout=180
     ):  # pylint: disable=unused-argument,too-many-locals,too-many-branches
         params = {
             "prompt": self.make_full_query(
                 session.session + [User(query)], max_query_size
             ),
             "stream": True,
+            "stop": ["", "", "[INST]"],
         }
         new_content = []
         try:
diff --git a/src/chap/commands/tui.css b/src/chap/commands/tui.css
index b20924f..f6fc55c 100644
--- a/src/chap/commands/tui.css
+++ b/src/chap/commands/tui.css
@@ -4,6 +4,15 @@
  * SPDX-License-Identifier: MIT
  */
 
+.role_user.history_exclude, .role_assistant.history_exclude {
+    color: $text-disabled;
+    border-left: dashed $primary;
+}
+.role_assistant.history_exclude:focus-within {
+    color: $text-disabled;
+    border-left: dashed $secondary;
+}
+
 .role_system {
     text-style: italic;
     color: $text-muted;
diff --git a/src/chap/commands/tui.py b/src/chap/commands/tui.py
index 33f71d4..35ea6bd 100644
--- a/src/chap/commands/tui.py
+++ b/src/chap/commands/tui.py
@@ -29,6 +29,7 @@ class Markdown(
         Binding("ctrl+y", "yank", "Yank text", show=True),
         Binding("ctrl+r", "resubmit", "resubmit", show=True),
         Binding("ctrl+x", "delete", "delete to end", show=True),
+        Binding("ctrl+q", "toggle_history", "history toggle", show=True),
     ]
 
 
@@ -43,7 +44,7 @@ def markdown_for_step(step):
 class Tui(App):
     CSS_PATH = "tui.css"
     BINDINGS = [
-        Binding("ctrl+q", "app.quit", "Quit", show=True, priority=True),
+        Binding("ctrl+c", "app.quit", "Quit", show=True, priority=True),
     ]
 
     def __init__(self, api=None, session=None):
@@ -82,6 +83,15 @@ class Tui(App):
         tokens = []
         update = asyncio.Queue(1)
 
+        # Construct a fake session with only select items
+        session = Session()
+        for si, wi in zip(self.session.session, self.container.children):
+            if not wi.has_class("history_exclude"):
+                session.session.append(si)
+        # Remember where the filtered history ends so that only messages
+        # appended by this query are copied back into the real session.
+        history_len = len(session.session)
+
         async def render_fun():
             while await update.get():
                 if tokens:
@@ -90,7 +100,7 @@ class Tui(App):
                 await asyncio.sleep(0.1)
 
         async def get_token_fun():
-            async for token in self.api.aask(self.session, event.value):
+            async for token in self.api.aask(session, event.value):
                 tokens.append(token)
                 try:
                     update.put_nowait(True)
@@ -102,6 +112,7 @@ class Tui(App):
             await asyncio.gather(render_fun(), get_token_fun())
             self.input.value = ""
         finally:
+            self.session.session.extend(session.session[history_len:])
             all_output = self.session.session[-1].content
             output.update(all_output)
             output._markdown = all_output  # pylint: disable=protected-access
@@ -118,6 +129,26 @@ class Tui(App):
         content = widget._markdown  # pylint: disable=protected-access
         subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)
 
+    def action_toggle_history(self):
+        """Toggle whether the focused exchange is sent as history with new queries.
+
+        The `history_exclude` class is toggled on the user widget of the
+        focused exchange and on the widget that follows it, when there is one.
+        """
+        widget = self.focused
+        if not isinstance(widget, Markdown):
+            return
+        children = self.container.children
+        idx = children.index(widget)
+        # Walk back to the user message that starts this exchange.
+        while idx > 1 and not children[idx].has_class("role_user"):
+            idx -= 1
+
+        # The final user message may not have a following widget yet.
+        children[idx].toggle_class("history_exclude")
+        if idx + 1 < len(children):
+            children[idx + 1].toggle_class("history_exclude")
+
     async def action_resubmit(self):
         await self.delete_or_resubmit(True)
 
@@ -130,7 +161,7 @@ class Tui(App):
             return
         children = self.container.children
         idx = children.index(widget)
-        while idx > 1 and not "role_user" in children[idx].classes:
+        while idx > 1 and not children[idx].has_class("role_user"):
             idx -= 1
             widget = children[idx]
 
diff --git a/src/chap/core.py b/src/chap/core.py
index 44c93b3..effd035 100644
--- a/src/chap/core.py
+++ b/src/chap/core.py
@@ -134,7 +134,7 @@ def format_backend_help(api, formatter):
         doc = get_attribute_docstring(type(api.parameters), f.name).docstring_below
         if doc:
             doc += " "
-        doc += f"(Default: {default})"
+        doc += f"(Default: {default!r})"
         rows.append((f"-B {name}:{f.type.__name__.upper()}", doc))
     formatter.write_dl(rows)
 