diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9f58a4f..660766a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -46,6 +46,9 @@ jobs: - name: install deps run: pip install -r requirements-dev.txt + - name: check types with mypy + run: make + - name: Build release run: python -mbuild diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..94b117d --- /dev/null +++ b/Makefile @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: 2023 Jeff Epler +# +# SPDX-License-Identifier: MIT + +.PHONY: mypy +mypy: venv/bin/mypy + venv/bin/mypy --strict -p chap + +venv/bin/mypy: + python -mvenv venv + venv/bin/pip install -r requirements.txt mypy + +.PHONY: clean +clean: + rm -rf venv diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..4dbbc58 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023 Jeff Epler +# +# SPDX-License-Identifier: MIT + +[mypy] +mypy_path = src diff --git a/src/chap/backends/huggingface.py b/src/chap/backends/huggingface.py index 91859ff..60766b3 100644 --- a/src/chap/backends/huggingface.py +++ b/src/chap/backends/huggingface.py @@ -2,17 +2,18 @@ # # SPDX-License-Identifier: MIT -import asyncio import json from dataclasses import dataclass +from typing import Any, AsyncGenerator import httpx +from ..core import AutoAskMixin, Backend from ..key import get_key -from ..session import Assistant, Role, User +from ..session import Assistant, Role, Session, User -class HuggingFace: +class HuggingFace(AutoAskMixin): @dataclass class Parameters: url: str = "https://api-inference.huggingface.co" @@ -24,14 +25,15 @@ class HuggingFace: after_assistant: str = """ [INST] """ stop_token_id = 2 - def __init__(self): + def __init__(self) -> None: + super().__init__() self.parameters = self.Parameters() system_message = """\ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits. """ - def make_full_query(self, messages, max_query_size): + def make_full_query(self, messages: Session, max_query_size: int) -> str: del messages[1:-max_query_size] result = [self.parameters.start_prompt] for m in messages: @@ -48,7 +50,9 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a full_query = "".join(result) return full_query - async def chained_query(self, inputs, timeout): + async def chained_query( + self, inputs: Any, timeout: float + ) -> AsyncGenerator[str, None]: async with httpx.AsyncClient(timeout=timeout) as client: while inputs: params = { @@ -79,9 +83,16 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a return async def aask( - self, session, query, *, max_query_size=5, timeout=180 - ): # pylint: disable=unused-argument,too-many-locals,too-many-branches - new_content = [] + self, + session: Session, + query: str, + *, + max_query_size: int = 5, + timeout: float = 180, + ) -> AsyncGenerator[ + str, None + ]: # pylint: disable=unused-argument,too-many-locals,too-many-branches + new_content: list[str] = [] inputs = self.make_full_query(session + [User(query)], max_query_size) try: async for content in self.chained_query(inputs, timeout=timeout): @@ -101,17 +112,11 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a session.extend([User(query), Assistant("".join(new_content))]) - def ask(self, session, query, *, max_query_size=5, timeout=60): - asyncio.run( - self.aask(session, query, max_query_size=max_query_size, timeout=timeout) - ) - return session[-1].content - @classmethod - def get_key(cls): + def get_key(cls) -> str: return get_key("huggingface_api_token") -def factory(): +def factory() -> Backend: """Uses the huggingface text-generation-interface web API""" return HuggingFace() diff --git a/src/chap/backends/llama_cpp.py b/src/chap/backends/llama_cpp.py index cf3a838..5695f8c 100644 --- a/src/chap/backends/llama_cpp.py +++ b/src/chap/backends/llama_cpp.py @@ -2,16 +2,17 @@ # # SPDX-License-Identifier: MIT -import asyncio import json from dataclasses import dataclass +from typing import AsyncGenerator import httpx -from ..session import Assistant, Role, User +from ..core import AutoAskMixin, Backend +from ..session import Assistant, Role, Session, User -class LlamaCpp: +class LlamaCpp(AutoAskMixin): @dataclass class Parameters: url: str = "http://localhost:8080/completion" @@ -22,14 +23,15 @@ class LlamaCpp: after_user: str = """ [/INST] """ after_assistant: str = """ [INST] """ - def __init__(self): + def __init__(self) -> None: + super().__init__() self.parameters = self.Parameters() system_message = """\ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits. """ - def make_full_query(self, messages, max_query_size): + def make_full_query(self, messages: Session, max_query_size: int) -> str: del messages[1:-max_query_size] result = [self.parameters.start_prompt] for m in messages: @@ -47,14 +49,21 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a return full_query async def aask( - self, session, query, *, max_query_size=5, timeout=180 - ): # pylint: disable=unused-argument,too-many-locals,too-many-branches + self, + session: Session, + query: str, + *, + max_query_size: int = 5, + timeout: float = 180, + ) -> AsyncGenerator[ + str, None + ]: # pylint: disable=unused-argument,too-many-locals,too-many-branches params = { "prompt": self.make_full_query(session + [User(query)], max_query_size), "stream": True, "stop": ["", "", "[INST]"], } - new_content = [] + new_content: list[str] = [] try: async with httpx.AsyncClient(timeout=timeout) as client: async with client.stream( @@ -87,13 +96,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a session.extend([User(query), Assistant("".join(new_content))]) - def ask(self, session, query, *, max_query_size=5, timeout=60): - asyncio.run( - self.aask(session, query, max_query_size=max_query_size, timeout=timeout) - ) - return session[-1].content - -def factory(): +def factory() -> Backend: """Uses the llama.cpp completion web API""" return LlamaCpp() diff --git a/src/chap/backends/lorem.py b/src/chap/backends/lorem.py index 4bcc845..80364a8 100644 --- a/src/chap/backends/lorem.py +++ b/src/chap/backends/lorem.py @@ -5,13 +5,16 @@ import asyncio import random from dataclasses import dataclass +from typing import AsyncGenerator, Iterable, cast -from lorem_text import lorem +# lorem is not type annotated +from lorem_text import lorem # type: ignore -from ..session import Assistant, User +from ..core import Backend +from ..session import Assistant, Session, User -def ipartition(s, sep=" "): +def ipartition(s: str, sep: str = " ") -> Iterable[tuple[str, str]]: rest = s while rest: first, opt_sep, rest = rest.partition(sep) @@ -30,15 +33,19 @@ class Lorem: paragraph_hi: int = 5 """Maximum response paragraph count (inclusive)""" - def __init__(self): + def __init__(self) -> None: self.parameters = self.Parameters() system_message = ( "(It doesn't matter what you ask, this backend will respond with lorem)" ) - async def aask(self, session, query, *, max_query_size=5, timeout=60): - data = self.ask(session, query, max_query_size=max_query_size, timeout=timeout) + async def aask( + self, + session: Session, + query: str, + ) -> AsyncGenerator[str, None]: + data = self.ask(session, query)[-1] for word, opt_sep in ipartition(data): yield word + opt_sep await asyncio.sleep( @@ -46,15 +53,22 @@ class Lorem: ) def ask( - self, session, query, *, max_query_size=5, timeout=60 - ): # pylint: disable=unused-argument - new_content = lorem.paragraphs( - random.randint(self.parameters.paragraph_lo, self.parameters.paragraph_hi) + self, + session: Session, + query: str, + ) -> str: # pylint: disable=unused-argument + new_content = cast( + str, + lorem.paragraphs( + random.randint( + self.parameters.paragraph_lo, self.parameters.paragraph_hi + ) + ), ).replace("\n", "\n\n") session.extend([User(query), Assistant("".join(new_content))]) - return new_content + return session[-1].content -def factory(): +def factory() -> Backend: """That just prints 'lorem' text. Useful for testing.""" return Lorem() diff --git a/src/chap/backends/openai_chatgpt.py b/src/chap/backends/openai_chatgpt.py index fbcbae6..a6d8fc7 100644 --- a/src/chap/backends/openai_chatgpt.py +++ b/src/chap/backends/openai_chatgpt.py @@ -5,12 +5,14 @@ import functools import json from dataclasses import dataclass +from typing import AsyncGenerator, cast import httpx import tiktoken +from ..core import Backend from ..key import get_key -from ..session import Assistant, User +from ..session import Assistant, Message, Session, User, session_to_list @dataclass(frozen=True) @@ -20,19 +22,19 @@ class EncodingMeta: tokens_per_name: int @functools.lru_cache() - def encode(self, s): + def encode(self, s: str) -> list[int]: return self.encoding.encode(s) - def num_tokens_for_message(self, message): + def num_tokens_for_message(self, message: Message) -> int: # n.b. chap doesn't use message.name yet return len(self.encode(message.role)) + len(self.encode(message.content)) - def num_tokens_for_messages(self, messages): + def num_tokens_for_messages(self, messages: Session) -> int: return sum(self.num_tokens_for_message(message) for message in messages) + 3 @classmethod @functools.cache - def from_model(cls, model): + def from_model(cls, model: str) -> "EncodingMeta": if model == "gpt-3.5-turbo": # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.") model = "gpt-3.5-turbo-0613" @@ -77,12 +79,12 @@ class ChatGPT: max_request_tokens: int = 1024 """The approximate greatest number of tokens to send in a request. When the session is long, the system prompt and 1 or more of the most recent interaction steps are sent.""" - def __init__(self): + def __init__(self) -> None: self.parameters = self.Parameters() system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language." - def make_full_prompt(self, all_history): + def make_full_prompt(self, all_history: Session) -> Session: encoding = EncodingMeta.from_model(self.parameters.model) result = [all_history[0]] # Assumed to be system prompt left = self.parameters.max_request_tokens - encoding.num_tokens_for_messages( @@ -99,15 +101,13 @@ class ChatGPT: result.extend(reversed(parts)) return result - def ask(self, session, query, *, timeout=60): + def ask(self, session: Session, query: str, *, timeout: float = 60) -> str: full_prompt = self.make_full_prompt(session + [User(query)]) response = httpx.post( "https://api.openai.com/v1/chat/completions", json={ "model": self.parameters.model, - "messages": full_prompt.to_dict()[ # pylint: disable=no-member - "session" - ], + "messages": session_to_list(full_prompt), }, # pylint: disable=no-member headers={ "Authorization": f"Bearer {self.get_key()}", @@ -115,20 +115,20 @@ class ChatGPT: timeout=timeout, ) if response.status_code != 200: - print("Failure", response.status_code, response.text) - return None + return f"Failure {response.text} ({response.status_code})" try: j = response.json() - result = j["choices"][0]["message"]["content"] + result = cast(str, j["choices"][0]["message"]["content"]) except (KeyError, IndexError, json.decoder.JSONDecodeError): - print("Failure", response.status_code, response.text) - return None + return f"Failure {response.text} ({response.status_code})" session.extend([User(query), Assistant(result)]) return result - async def aask(self, session, query, *, timeout=60): + async def aask( + self, session: Session, query: str, *, timeout: float = 60 + ) -> AsyncGenerator[str, None]: full_prompt = self.make_full_prompt(session + [User(query)]) new_content = [] try: @@ -140,9 +140,7 @@ class ChatGPT: json={ "model": self.parameters.model, "stream": True, - "messages": full_prompt.to_dict()[ # pylint: disable=no-member - "session" - ], # pylint: disable=no-member + "messages": session_to_list(full_prompt), }, ) as response: if response.status_code == 200: @@ -170,10 +168,10 @@ class ChatGPT: session.extend([User(query), Assistant("".join(new_content))]) @classmethod - def get_key(cls): + def get_key(cls) -> str: return get_key("openai_api_key") -def factory(): +def factory() -> Backend: """Uses the OpenAI chat completion API""" return ChatGPT() diff --git a/src/chap/backends/textgen.py b/src/chap/backends/textgen.py index 39eaa40..4999b1f 100644 --- a/src/chap/backends/textgen.py +++ b/src/chap/backends/textgen.py @@ -2,19 +2,25 @@ # # SPDX-License-Identifier: MIT -import asyncio import json import uuid +from dataclasses import dataclass +from typing import AsyncGenerator import websockets -from ..key import get_key -from ..session import Assistant, Role, User +from ..core import AutoAskMixin, Backend +from ..session import Assistant, Role, Session, User -class Textgen: - def __init__(self): - self.server = get_key("textgen_url", "textgen server URL") +class Textgen(AutoAskMixin): + @dataclass + class Parameters: + server_hostname: str = "localhost" + + def __init__(self) -> None: + super().__init__() + self.parameters = self.Parameters() system_message = """\ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits. @@ -23,9 +29,14 @@ USER: Hello, AI. AI: Hello! How can I assist you today?""" - async def aask( - self, session, query, *, max_query_size=5, timeout=60 - ): # pylint: disable=unused-argument,too-many-locals,too-many-branches + async def aask( # pylint: disable=unused-argument,too-many-locals,too-many-branches + self, + session: Session, + query: str, + *, + max_query_size: int = 5, + timeout: float = 60, + ) -> AsyncGenerator[str, None]: params = { "max_new_tokens": 200, "do_sample": True, @@ -55,7 +66,7 @@ AI: Hello! How can I assist you today?""" ) try: async with websockets.connect( # pylint: disable=no-member - f"ws://{self.server}:7860/queue/join" + f"ws://{self.parameters.server_hostname}:7860/queue/join" ) as websocket: while content := json.loads(await websocket.recv()): if content["msg"] == "send_hash": @@ -124,13 +135,7 @@ AI: Hello! How can I assist you today?""" all_response = new_data[len(full_query) :] session.extend([User(query), Assistant(all_response)]) - def ask(self, session, query, *, max_query_size=5, timeout=60): - asyncio.run( - self.aask(session, query, max_query_size=max_query_size, timeout=timeout) - ) - return session[-1].content - -def factory(): +def factory() -> Backend: """Uses the textgen completion API""" return Textgen() diff --git a/src/chap/commands/ask.py b/src/chap/commands/ask.py index 6a3fbcf..8c4f43e 100644 --- a/src/chap/commands/ask.py +++ b/src/chap/commands/ask.py @@ -4,43 +4,52 @@ import asyncio import sys +from typing import Iterable, Protocol import click import rich -from ..core import command_uses_new_session -from ..session import session_to_file +from ..core import Backend, Obj, command_uses_new_session +from ..session import Session, session_to_file bold = "\033[1m" nobold = "\033[m" -def ipartition(s, sep): +def ipartition(s: str, sep: str) -> Iterable[tuple[str, str]]: rest = s while rest: first, opt_sep, rest = rest.partition(sep) yield (first, opt_sep) +class Printable(Protocol): + def raw(self, s: str) -> None: + ... + + def add(self, s: str) -> None: + ... + + class DumbPrinter: - def raw(self, s): + def raw(self, s: str) -> None: pass - def add(self, s): + def add(self, s: str) -> None: print(s, end="") class WrappingPrinter: - def __init__(self, width=None): + def __init__(self, width: int | None = None) -> None: self._width = width or rich.get_console().width self._column = 0 self._line = "" self._sp = "" - def raw(self, s): + def raw(self, s: str) -> None: print(s, end="") - def add(self, s): + def add(self, s: str) -> None: for line, opt_nl in ipartition(s, "\n"): for word, opt_sp in ipartition(line, " "): newlen = len(self._line) + len(self._sp) + len(word) @@ -64,15 +73,16 @@ class WrappingPrinter: self._sp = "" -def verbose_ask(api, session, q, print_prompt, **kw): +def verbose_ask(api: Backend, session: Session, q: str, print_prompt: bool) -> str: + printer: Printable if sys.stdout.isatty(): printer = WrappingPrinter() else: printer = DumbPrinter() - tokens = [] + tokens: list[str] = [] - async def work(): - async for token in api.aask(session, q, **kw): + async def work() -> None: + async for token in api.aask(session, q): printer.add(token) if print_prompt: @@ -91,11 +101,16 @@ def verbose_ask(api, session, q, print_prompt, **kw): @command_uses_new_session @click.option("--print-prompt/--no-print-prompt", default=True) @click.argument("prompt", nargs=-1, required=True) -def main(obj, prompt, print_prompt): +def main(obj: Obj, prompt: str, print_prompt: bool) -> None: """Ask a question (command-line argument is passed as prompt)""" session = obj.session + assert session is not None + session_filename = obj.session_filename + assert session_filename is not None + api = obj.api + assert api is not None # symlink_session_filename(session_filename) diff --git a/src/chap/commands/cat.py b/src/chap/commands/cat.py index b11da1a..9eb5e2c 100644 --- a/src/chap/commands/cat.py +++ b/src/chap/commands/cat.py @@ -4,15 +4,17 @@ import click -from ..core import command_uses_existing_session +from ..core import Obj, command_uses_existing_session from ..session import Role @command_uses_existing_session @click.option("--no-system", is_flag=True) -def main(obj, no_system): +def main(obj: Obj, no_system: bool) -> None: """Print session in plaintext""" session = obj.session + if not session: + return first = True for row in session: diff --git a/src/chap/commands/grep.py b/src/chap/commands/grep.py index 0b81776..7afadc7 100644 --- a/src/chap/commands/grep.py +++ b/src/chap/commands/grep.py @@ -18,7 +18,7 @@ from .render import to_markdown def list_files_matching_rx( - rx: re.Pattern, conversations_path: Optional[pathlib.Path] = None + rx: re.Pattern[str], conversations_path: Optional[pathlib.Path] = None ) -> Iterable[Tuple[pathlib.Path, Message]]: for conversation in (conversations_path or default_conversations_path).glob( "*.json" @@ -39,7 +39,9 @@ def list_files_matching_rx( @click.option("--files-with-matches", "-l", is_flag=True) @click.option("--fixed-strings", "--literal", "-F", is_flag=True) @click.argument("pattern", nargs=1, required=True) -def main(ignore_case, files_with_matches, fixed_strings, pattern): +def main( + ignore_case: bool, files_with_matches: bool, fixed_strings: bool, pattern: str +) -> None: """Search sessions for pattern""" console = rich.get_console() if fixed_strings: @@ -47,7 +49,7 @@ def main(ignore_case, files_with_matches, fixed_strings, pattern): rx = re.compile(pattern, re.I if ignore_case else 0) last_file = None - for f, m in list_files_matching_rx(rx, ignore_case): + for f, m in list_files_matching_rx(rx, None): if f != last_file: if files_with_matches: print(f) diff --git a/src/chap/commands/import.py b/src/chap/commands/import.py index 7fe0e9a..c1a20a9 100644 --- a/src/chap/commands/import.py +++ b/src/chap/commands/import.py @@ -6,17 +6,20 @@ from __future__ import annotations import json import pathlib +from typing import Any, Iterator, TextIO import click import rich from ..core import conversations_path, new_session_path -from ..session import Message, Role, new_session, session_to_file +from ..session import Message, Role, Session, new_session, session_to_file console = rich.get_console() -def iter_sessions(name, content, session_in, node_id): +def iter_sessions( + name: str, content: Any, session_in: Session, node_id: str +) -> Iterator[tuple[str, Session]]: node = content["mapping"][node_id] session = session_in[:] @@ -37,7 +40,7 @@ def iter_sessions(name, content, session_in, node_id): yield node_id, session -def do_import(output_directory, f): +def do_import(output_directory: pathlib.Path, f: TextIO) -> None: stem = pathlib.Path(f.name).stem content = json.load(f) session = new_session() @@ -66,7 +69,7 @@ def do_import(output_directory, f): @click.argument( "files", nargs=-1, required=True, type=click.File("r", encoding="utf-8") ) -def main(output_directory, files): +def main(output_directory: pathlib.Path, files: list[TextIO]) -> None: """Import files from the ChatGPT webui This understands the format produced by diff --git a/src/chap/commands/render.py b/src/chap/commands/render.py index 5358416..bf814f5 100644 --- a/src/chap/commands/render.py +++ b/src/chap/commands/render.py @@ -7,11 +7,11 @@ import rich from markdown_it import MarkdownIt from rich.markdown import Markdown -from ..core import command_uses_existing_session -from ..session import Role +from ..core import Obj, command_uses_existing_session +from ..session import Message, Role -def to_markdown(message): +def to_markdown(message: Message) -> Markdown: role = message.role if role == Role.USER: style = "bold" @@ -28,9 +28,10 @@ def to_markdown(message): @command_uses_existing_session @click.option("--no-system", is_flag=True) -def main(obj, no_system): +def main(obj: Obj, no_system: bool) -> None: """Print session with formatting""" session = obj.session + assert session is not None console = rich.get_console() first = True diff --git a/src/chap/commands/tui.py b/src/chap/commands/tui.py index 95683e9..7863175 100644 --- a/src/chap/commands/tui.py +++ b/src/chap/commands/tui.py @@ -5,19 +5,20 @@ import asyncio import subprocess import sys +from typing import cast from markdown_it import MarkdownIt from textual import work -from textual.app import App +from textual.app import App, ComposeResult from textual.binding import Binding from textual.containers import Container, Horizontal, VerticalScroll from textual.widgets import Button, Footer, Input, LoadingIndicator, Markdown -from ..core import command_uses_new_session, get_api, new_session_path -from ..session import Assistant, User, new_session, session_to_file +from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path +from ..session import Assistant, Message, Session, User, new_session, session_to_file -def parser_factory(): +def parser_factory() -> MarkdownIt: parser = MarkdownIt() parser.options["html"] = False return parser @@ -34,7 +35,7 @@ class ChapMarkdown( ] -def markdown_for_step(step): +def markdown_for_step(step: Message) -> ChapMarkdown: return ChapMarkdown( step.content.strip() or "…", classes="role_" + step.role, @@ -48,13 +49,15 @@ class CancelButton(Button): ] -class Tui(App): +class Tui(App[None]): CSS_PATH = "tui.css" BINDINGS = [ Binding("ctrl+c", "quit", "Quit", show=True, priority=True), ] - def __init__(self, api=None, session=None): + def __init__( + self, api: Backend | None = None, session: Session | None = None + ) -> None: super().__init__() self.api = api or get_api("lorem") self.session = ( @@ -62,26 +65,26 @@ class Tui(App): ) @property - def spinner(self): + def spinner(self) -> LoadingIndicator: return self.query_one(LoadingIndicator) @property - def wait(self): - return self.query_one("#wait") + def wait(self) -> VerticalScroll: + return cast(VerticalScroll, self.query_one("#wait")) @property - def input(self): + def input(self) -> Input: return self.query_one(Input) @property - def cancel_button(self): + def cancel_button(self) -> CancelButton: return self.query_one(CancelButton) @property - def container(self): - return self.query_one("#content") + def container(self) -> VerticalScroll: + return cast(VerticalScroll, self.query_one("#content")) - def compose(self): + def compose(self) -> ComposeResult: yield Footer() yield VerticalScroll( *[markdown_for_step(step) for step in self.session], @@ -100,11 +103,11 @@ class Tui(App): self.container.scroll_end(animate=False) self.input.focus() - async def on_input_submitted(self, event) -> None: + async def on_input_submitted(self, event: Input.Submitted) -> None: self.get_completion(event.value) @work(exclusive=True) - async def get_completion(self, query): + async def get_completion(self, query: str) -> None: self.scroll_end() self.input.styles.display = "none" @@ -117,8 +120,8 @@ class Tui(App): await self.container.mount_all( [markdown_for_step(User(query)), output], before="#pad" ) - tokens = [] - update = asyncio.Queue(1) + tokens: list[str] = [] + update: asyncio.Queue[bool] = asyncio.Queue(1) for markdown in self.container.children: markdown.disabled = True @@ -137,14 +140,14 @@ class Tui(App): ] ) - async def render_fun(): + async def render_fun() -> None: while await update.get(): if tokens: output.update("".join(tokens).strip()) self.container.scroll_end() await asyncio.sleep(0.1) - async def get_token_fun(): + async def get_token_fun() -> None: async for token in self.api.aask(session, query): tokens.append(token) message.content += token @@ -174,16 +177,16 @@ class Tui(App): self.cancel_button.disabled = True self.input.focus() - def scroll_end(self): + def scroll_end(self) -> None: self.call_after_refresh(self.container.scroll_end) - def action_yank(self): + def action_yank(self) -> None: widget = self.focused if isinstance(widget, ChapMarkdown): - content = widget._markdown # pylint: disable=protected-access + content = widget._markdown or "" # pylint: disable=protected-access subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False) - def action_toggle_history(self): + def action_toggle_history(self) -> None: widget = self.focused if not isinstance(widget, ChapMarkdown): return @@ -199,23 +202,25 @@ class Tui(App): for m in children[idx : idx + 2]: m.toggle_class("history_exclude") - async def action_stop_generating(self): + async def action_stop_generating(self) -> None: self.workers.cancel_all() - async def on_button_pressed(self, event): # pylint: disable=unused-argument + async def on_button_pressed( # pylint: disable=unused-argument + self, event: Button.Pressed + ) -> None: self.workers.cancel_all() - async def action_quit(self): + async def action_quit(self) -> None: self.workers.cancel_all() self.exit() - async def action_resubmit(self): + async def action_resubmit(self) -> None: await self.redraft_or_resubmit(True) - async def action_redraft(self): + async def action_redraft(self) -> None: await self.redraft_or_resubmit(False) - async def redraft_or_resubmit(self, resubmit): + async def redraft_or_resubmit(self, resubmit: bool) -> None: widget = self.focused if not isinstance(widget, ChapMarkdown): return @@ -244,11 +249,14 @@ class Tui(App): @command_uses_new_session -def main(obj): +def main(obj: Obj) -> None: """Start interactive terminal user interface session""" api = obj.api + assert api is not None session = obj.session + assert session is not None session_filename = obj.session_filename + assert session_filename is not None tui = Tui(api, session) tui.run() diff --git a/src/chap/core.py b/src/chap/core.py index cd1a705..ad6bc24 100644 --- a/src/chap/core.py +++ b/src/chap/core.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MIT # pylint: disable=import-outside-toplevel +import asyncio import datetime import importlib import os @@ -10,33 +11,62 @@ import pathlib import pkgutil import subprocess from dataclasses import MISSING, dataclass, fields +from typing import Any, AsyncGenerator, Callable, cast import click import platformdirs from simple_parsing.docstring import get_attribute_docstring +from typing_extensions import Protocol from . import backends, commands # pylint: disable=no-name-in-module -from .session import Message, System, session_from_file +from .session import Message, Session, System, session_from_file conversations_path = platformdirs.user_state_path("chap") / "conversations" conversations_path.mkdir(parents=True, exist_ok=True) -def last_session_path(): +class ABackend(Protocol): # pylint: disable=too-few-public-methods + def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]: + """Make a query, updating the session with the query and response, returning the query token by token""" + + +class Backend(ABackend, Protocol): + parameters: Any + system_message: str + + def ask(self, session: Session, query: str) -> str: + """Make a query, updating the session with the query and response, returning the query""" + + +class AutoAskMixin: # pylint: disable=too-few-public-methods + """Mixin class for backends implementing aask""" + + def ask(self, session: Session, query: str) -> str: + tokens: list[str] = [] + + async def inner() -> None: + # https://github.com/pylint-dev/pylint/issues/5761 + async for token in self.aask(session, query): # type: ignore + tokens.append(token) + + asyncio.run(inner()) + return "".join(tokens) + + +def last_session_path() -> pathlib.Path | None: result = max( conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None ) - print(result) return result -def new_session_path(opt_path=None): +def new_session_path(opt_path: pathlib.Path | None = None) -> pathlib.Path: return opt_path or conversations_path / ( datetime.datetime.now().isoformat().replace(":", "_") + ".json" ) -def configure_api_from_environment(api_name, api): +def configure_api_from_environment(api_name: str, api: Backend) -> None: if not hasattr(api, "parameters"): return @@ -54,44 +84,46 @@ def configure_api_from_environment(api_name, api): setattr(api.parameters, field.name, tv) -def get_api(name="openai_chatgpt"): +def get_api(name: str = "openai_chatgpt") -> Backend: name = name.replace("-", "_") - result = importlib.import_module(f"{__package__}.backends.{name}").factory() + result = cast( + Backend, importlib.import_module(f"{__package__}.backends.{name}").factory() + ) configure_api_from_environment(name, result) return result -def ask(*args, **kw): - return get_api().ask(*args, **kw) - - -def aask(*args, **kw): - return get_api().aask(*args, **kw) - - -def do_session_continue(ctx, param, value): +def do_session_continue( + ctx: click.Context, param: click.Parameter, value: pathlib.Path | None +) -> None: if value is None: return if ctx.obj.session is not None: raise click.BadParameter( - param, "--continue-session, --last and --new-session are mutually exclusive" + "--continue-session, --last and --new-session are mutually exclusive", + param=param, ) ctx.obj.session = session_from_file(value) ctx.obj.session_filename = value -def do_session_last(ctx, param, value): # pylint: disable=unused-argument +def do_session_last( + ctx: click.Context, param: click.Parameter, value: bool +) -> None: # pylint: disable=unused-argument if not value: return do_session_continue(ctx, param, last_session_path()) -def do_session_new(ctx, param, value): +def do_session_new( + ctx: click.Context, param: click.Parameter, value: pathlib.Path +) -> None: if ctx.obj.session is not None: if value is None: return - raise click.BadOptionUsage( - param, "--continue-session, --last and --new-session are mutually exclusive" + raise click.BadParameter( + "--continue-session, --last and --new-session are mutually exclusive", + param=param, ) session_filename = new_session_path(value) system_message = ctx.obj.system_message or ctx.obj.api.system_message @@ -99,20 +131,24 @@ def do_session_new(ctx, param, value): ctx.obj.session_filename = session_filename -def colonstr(arg): +def colonstr(arg: str) -> tuple[str, str]: if ":" not in arg: raise click.BadParameter("must be of the form 'name:value'") - return arg.split(":", 1) + return cast(tuple[str, str], tuple(arg.split(":", 1))) -def set_system_message(ctx, param, value): # pylint: disable=unused-argument +def set_system_message( # pylint: disable=unused-argument + ctx: click.Context, param: click.Parameter, value: str +) -> None: if value and value.startswith("@"): with open(value[1:], "r", encoding="utf-8") as f: value = f.read().rstrip() ctx.obj.system_message = value -def set_backend(ctx, param, value): # pylint: disable=unused-argument +def set_backend( # pylint: disable=unused-argument + ctx: click.Context, param: click.Parameter, value: str +) -> None: if value == "list": formatter = ctx.make_formatter() format_backend_list(formatter) @@ -125,7 +161,7 @@ def set_backend(ctx, param, value): # pylint: disable=unused-argument raise click.BadParameter(str(e)) -def format_backend_help(api, formatter): +def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None: with formatter.section(f"Backend options for {api.__class__.__name__}"): rows = [] for f in fields(api.parameters): @@ -135,11 +171,14 @@ def format_backend_help(api, formatter): if doc: doc += " " doc += f"(Default: {default!r})" - rows.append((f"-B {name}:{f.type.__name__.upper()}", doc)) + typename = f.type.__name__ + rows.append((f"-B {name}:{typename.upper()}", doc)) formatter.write_dl(rows) -def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument +def set_backend_option( # pylint: disable=unused-argument + ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]] +) -> None: api = ctx.obj.api if not hasattr(api, "parameters"): raise click.BadParameter( @@ -147,7 +186,7 @@ def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument ) all_fields = dict((f.name.replace("_", "-"), f) for f in fields(api.parameters)) - def set_one_backend_option(kv): + def set_one_backend_option(kv: tuple[str, str]) -> None: name, value = kv field = all_fields.get(name) if field is None: @@ -164,7 +203,7 @@ def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument set_one_backend_option(kv) -def format_backend_list(formatter): +def format_backend_list(formatter: click.HelpFormatter) -> None: all_backends = [] for pi in pkgutil.walk_packages(backends.__path__): name = pi.name @@ -186,7 +225,7 @@ def format_backend_list(formatter): formatter.write_dl(rows) -def uses_session(f): +def uses_session(f: click.decorators.FC) -> Callable[[], None]: f = click.option( "--continue-session", "-s", @@ -198,18 +237,15 @@ def uses_session(f): f = click.option( "--last", is_flag=True, callback=do_session_last, expose_value=False )(f) - f = click.pass_obj(f) - return f + return click.pass_obj(f) -def command_uses_existing_session(f): - f = uses_session(f) - f = click.command()(f) - return f +def command_uses_existing_session(f: click.decorators.FC) -> click.Command: + return click.command()(uses_session(f)) -def command_uses_new_session(f): - f = uses_session(f) +def command_uses_new_session(f_in: click.decorators.FC) -> click.Command: + f = uses_session(f_in) f = click.option( "--new-session", "-n", @@ -218,11 +254,12 @@ def command_uses_new_session(f): callback=do_session_new, expose_value=False, )(f) - f = click.command()(f) - return f + return click.command()(f) -def version_callback(ctx, param, value) -> None: # pylint: disable=unused-argument +def version_callback( # pylint: disable=unused-argument + ctx: click.Context, param: click.Parameter, value: None +) -> None: if not value or ctx.resilient_parsing: return @@ -247,17 +284,24 @@ def version_callback(ctx, param, value) -> None: # pylint: disable=unused-argum @dataclass class Obj: - api: object = None - system_message: object = None + api: Backend | None = None + system_message: str | None = None session: list[Message] | None = None + session_filename: pathlib.Path | None = None class MyCLI(click.MultiCommand): - def make_context(self, info_name, args, parent=None, **extra): + def make_context( + self, + info_name: str | None, + args: list[str], + parent: click.Context | None = None, + **extra: Any, + ) -> click.Context: result = super().make_context(info_name, args, parent, obj=Obj(), **extra) return result - def list_commands(self, ctx): + def list_commands(self, ctx: click.Context) -> list[str]: rv = [] for pi in pkgutil.walk_packages(commands.__path__): name = pi.name @@ -266,13 +310,18 @@ class MyCLI(click.MultiCommand): rv.sort() return rv - def get_command(self, ctx, cmd_name): + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command: try: - return importlib.import_module("." + cmd_name, commands.__name__).main + return cast( + click.Command, + importlib.import_module("." + cmd_name, commands.__name__).main, + ) except ModuleNotFoundError as exc: raise click.UsageError(f"Invalid subcommand {cmd_name!r}", ctx) from exc - def format_options(self, ctx, formatter): + def format_options( + self, ctx: click.Context, formatter: click.HelpFormatter + ) -> None: super().format_options(ctx, formatter) api = ctx.obj.api or get_api() if hasattr(api, "parameters"): diff --git a/src/chap/key.py b/src/chap/key.py index 60cd14f..e292905 100644 --- a/src/chap/key.py +++ b/src/chap/key.py @@ -15,7 +15,7 @@ _key_path_base = platformdirs.user_config_path("chap") @functools.cache -def get_key(name, what="openai api key"): +def get_key(name: str, what: str = "openai api key") -> str: key_path = _key_path_base / name if not key_path.exists(): raise NoKeyAvailable( diff --git a/src/chap/session.py b/src/chap/session.py index 7489094..3982325 100644 --- a/src/chap/session.py +++ b/src/chap/session.py @@ -2,8 +2,14 @@ # # SPDX-License-Identifier: MIT +from __future__ import annotations + import json +import pathlib from dataclasses import asdict, dataclass +from typing import cast + +from typing_extensions import TypedDict # not an enum.Enum because these objects are not json-serializable, sigh @@ -21,40 +27,49 @@ class Message: content: str -def Assistant(content): +MessageDict = TypedDict("MessageDict", {"role": str, "content": int}) +Session = list[Message] +SessionDicts = list[MessageDict] + + +def Assistant(content: str) -> Message: return Message(Role.ASSISTANT, content) -def System(content): +def System(content: str) -> Message: return Message(Role.SYSTEM, content) -def User(content): +def User(content: str) -> Message: return Message(Role.USER, content) def new_session( system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language.", -): +) -> Session: return [System(system_message)] -def session_to_json(session): - return json.dumps([asdict(message) for message in session]) +def session_to_json(session: Session) -> str: + return json.dumps(session_to_list(session)) -def session_from_json(data): +def session_to_list(session: Session) -> SessionDicts: + return [cast(MessageDict, asdict(message)) for message in session] + + +def session_from_json(data: str) -> Session: j = json.loads(data) if isinstance(j, dict): j = j["session"] return [Message(**mapping) for mapping in j] -def session_from_file(path): +def session_from_file(path: pathlib.Path | str) -> Session: with open(path, "r", encoding="utf-8") as f: return session_from_json(f.read()) -def session_to_file(session, path): +def session_to_file(session: Session, path: pathlib.Path | str) -> None: with open(path, "w", encoding="utf-8") as f: - return f.write(session_to_json(session)) + f.write(session_to_json(session))