Fully type annotate, check with mypy

with two (commented) exceptions
Jeff Epler 2023-11-09 09:48:22 -06:00
parent eac422ff17
commit fc69800594
No known key found for this signature in database
GPG key ID: D5BF15AB975AB4DE
17 changed files with 345 additions and 201 deletions


@ -46,6 +46,9 @@ jobs:
- name: install deps
run: pip install -r requirements-dev.txt
- name: check types with mypy
run: make
- name: Build release
run: python -mbuild

Makefile (new file, 15 lines)

@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
#
# SPDX-License-Identifier: MIT
.PHONY: mypy
mypy: venv/bin/mypy
venv/bin/mypy --strict -p chap
venv/bin/mypy:
python -mvenv venv
venv/bin/pip install -r requirements.txt mypy
.PHONY: clean
clean:
rm -rf venv

mypy.ini (new file, 6 lines)

@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
#
# SPDX-License-Identifier: MIT
[mypy]
mypy_path = src
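The Makefile target runs mypy --strict -p chap, and mypy_path = src points mypy at the src/ layout. Under --strict every function must carry explicit annotations; a minimal illustration (not code from this commit) of what that enforces:

# rejected by mypy --strict: "Function is missing a type annotation"
def greet(name):
    return "hello " + name

# accepted: parameters and return type are spelled out
def greet_typed(name: str) -> str:
    return "hello " + name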


@ -2,17 +2,18 @@
#
# SPDX-License-Identifier: MIT
import asyncio
import json
from dataclasses import dataclass
from typing import Any, AsyncGenerator
import httpx
from ..core import AutoAskMixin, Backend
from ..key import get_key
from ..session import Assistant, Role, User
from ..session import Assistant, Role, Session, User
class HuggingFace:
class HuggingFace(AutoAskMixin):
@dataclass
class Parameters:
url: str = "https://api-inference.huggingface.co"
@ -24,14 +25,15 @@ class HuggingFace:
after_assistant: str = """ </s><s>[INST] """
stop_token_id = 2
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.parameters = self.Parameters()
system_message = """\
A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.
"""
def make_full_query(self, messages, max_query_size):
def make_full_query(self, messages: Session, max_query_size: int) -> str:
del messages[1:-max_query_size]
result = [self.parameters.start_prompt]
for m in messages:
@ -48,7 +50,9 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
full_query = "".join(result)
return full_query
async def chained_query(self, inputs, timeout):
async def chained_query(
self, inputs: Any, timeout: float
) -> AsyncGenerator[str, None]:
async with httpx.AsyncClient(timeout=timeout) as client:
while inputs:
params = {
@ -79,9 +83,16 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
return
async def aask(
self, session, query, *, max_query_size=5, timeout=180
): # pylint: disable=unused-argument,too-many-locals,too-many-branches
new_content = []
self,
session: Session,
query: str,
*,
max_query_size: int = 5,
timeout: float = 180,
) -> AsyncGenerator[
str, None
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
new_content: list[str] = []
inputs = self.make_full_query(session + [User(query)], max_query_size)
try:
async for content in self.chained_query(inputs, timeout=timeout):
@ -101,17 +112,11 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
session.extend([User(query), Assistant("".join(new_content))])
def ask(self, session, query, *, max_query_size=5, timeout=60):
asyncio.run(
self.aask(session, query, max_query_size=max_query_size, timeout=timeout)
)
return session[-1].content
@classmethod
def get_key(cls):
def get_key(cls) -> str:
return get_key("huggingface_api_token")
def factory():
def factory() -> Backend:
"""Uses the huggingface text-generation-interface web API"""
return HuggingFace()
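Every backend's aask now has the explicit return type AsyncGenerator[str, None], so callers iterate it with async for inside an event loop. A hedged sketch of such a caller, using the keyless lorem backend (module and package paths assumed to match the repository's src/chap layout):

import asyncio

from chap.core import get_api
from chap.session import new_session

async def stream_answer() -> None:
    api = get_api("lorem")                      # any backend with the new signature works here
    session = new_session(api.system_message)
    async for token in api.aask(session, "Hello, AI."):
        print(token, end="", flush=True)
    print()

asyncio.run(stream_answer())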


@ -2,16 +2,17 @@
#
# SPDX-License-Identifier: MIT
import asyncio
import json
from dataclasses import dataclass
from typing import AsyncGenerator
import httpx
from ..session import Assistant, Role, User
from ..core import AutoAskMixin, Backend
from ..session import Assistant, Role, Session, User
class LlamaCpp:
class LlamaCpp(AutoAskMixin):
@dataclass
class Parameters:
url: str = "http://localhost:8080/completion"
@ -22,14 +23,15 @@ class LlamaCpp:
after_user: str = """ [/INST] """
after_assistant: str = """ </s><s>[INST] """
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.parameters = self.Parameters()
system_message = """\
A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.
"""
def make_full_query(self, messages, max_query_size):
def make_full_query(self, messages: Session, max_query_size: int) -> str:
del messages[1:-max_query_size]
result = [self.parameters.start_prompt]
for m in messages:
@ -47,14 +49,21 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
return full_query
async def aask(
self, session, query, *, max_query_size=5, timeout=180
): # pylint: disable=unused-argument,too-many-locals,too-many-branches
self,
session: Session,
query: str,
*,
max_query_size: int = 5,
timeout: float = 180,
) -> AsyncGenerator[
str, None
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
params = {
"prompt": self.make_full_query(session + [User(query)], max_query_size),
"stream": True,
"stop": ["</s>", "<s>", "[INST]"],
}
new_content = []
new_content: list[str] = []
try:
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream(
@ -87,13 +96,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
session.extend([User(query), Assistant("".join(new_content))])
def ask(self, session, query, *, max_query_size=5, timeout=60):
asyncio.run(
self.aask(session, query, max_query_size=max_query_size, timeout=timeout)
)
return session[-1].content
def factory():
def factory() -> Backend:
"""Uses the llama.cpp completion web API"""
return LlamaCpp()
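make_full_query now takes a Session and returns the single prompt string sent to the completion endpoint, interleaving the separators held in Parameters. A simplified sketch of that assembly (the start prompt, separator values, and system-message handling here are illustrative; the real method also trims to max_query_size):

from chap.session import Assistant, Role, Session, System, User

def build_prompt(
    messages: Session,
    after_user: str = " [/INST] ",
    after_assistant: str = " </s><s>[INST] ",
) -> str:
    parts = ["<s>[INST] "]                      # illustrative start prompt
    for m in messages:
        parts.append(m.content)
        parts.append(after_assistant if m.role == Role.ASSISTANT else after_user)
    return "".join(parts)

print(build_prompt([System("Be helpful."), User("Hi"), Assistant("Hello!")]))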


@ -5,13 +5,16 @@
import asyncio
import random
from dataclasses import dataclass
from typing import AsyncGenerator, Iterable, cast
from lorem_text import lorem
# lorem is not type annotated
from lorem_text import lorem # type: ignore
from ..session import Assistant, User
from ..core import Backend
from ..session import Assistant, Session, User
def ipartition(s, sep=" "):
def ipartition(s: str, sep: str = " ") -> Iterable[tuple[str, str]]:
rest = s
while rest:
first, opt_sep, rest = rest.partition(sep)
@ -30,15 +33,19 @@ class Lorem:
paragraph_hi: int = 5
"""Maximum response paragraph count (inclusive)"""
def __init__(self):
def __init__(self) -> None:
self.parameters = self.Parameters()
system_message = (
"(It doesn't matter what you ask, this backend will respond with lorem)"
)
async def aask(self, session, query, *, max_query_size=5, timeout=60):
data = self.ask(session, query, max_query_size=max_query_size, timeout=timeout)
async def aask(
self,
session: Session,
query: str,
) -> AsyncGenerator[str, None]:
data = self.ask(session, query)[-1]
for word, opt_sep in ipartition(data):
yield word + opt_sep
await asyncio.sleep(
@ -46,15 +53,22 @@ class Lorem:
)
def ask(
self, session, query, *, max_query_size=5, timeout=60
): # pylint: disable=unused-argument
new_content = lorem.paragraphs(
random.randint(self.parameters.paragraph_lo, self.parameters.paragraph_hi)
self,
session: Session,
query: str,
) -> str: # pylint: disable=unused-argument
new_content = cast(
str,
lorem.paragraphs(
random.randint(
self.parameters.paragraph_lo, self.parameters.paragraph_hi
)
),
).replace("\n", "\n\n")
session.extend([User(query), Assistant("".join(new_content))])
return new_content
return session[-1].content
def factory():
def factory() -> Backend:
"""That just prints 'lorem' text. Useful for testing."""
return Lorem()
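ipartition splits a string while keeping each separator attached to the piece before it, which is what lets this backend stream one word at a time. For example (module path assumed to be chap.backends.lorem):

from chap.backends.lorem import ipartition

# each tuple is (piece, separator that followed it); the final separator is ""
print(list(ipartition("lorem ipsum dolor")))
# [('lorem', ' '), ('ipsum', ' '), ('dolor', '')]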


@ -5,12 +5,14 @@
import functools
import json
from dataclasses import dataclass
from typing import AsyncGenerator, cast
import httpx
import tiktoken
from ..core import Backend
from ..key import get_key
from ..session import Assistant, User
from ..session import Assistant, Message, Session, User, session_to_list
@dataclass(frozen=True)
@ -20,19 +22,19 @@ class EncodingMeta:
tokens_per_name: int
@functools.lru_cache()
def encode(self, s):
def encode(self, s: str) -> list[int]:
return self.encoding.encode(s)
def num_tokens_for_message(self, message):
def num_tokens_for_message(self, message: Message) -> int:
# n.b. chap doesn't use message.name yet
return len(self.encode(message.role)) + len(self.encode(message.content))
def num_tokens_for_messages(self, messages):
def num_tokens_for_messages(self, messages: Session) -> int:
return sum(self.num_tokens_for_message(message) for message in messages) + 3
@classmethod
@functools.cache
def from_model(cls, model):
def from_model(cls, model: str) -> "EncodingMeta":
if model == "gpt-3.5-turbo":
# print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
model = "gpt-3.5-turbo-0613"
@ -77,12 +79,12 @@ class ChatGPT:
max_request_tokens: int = 1024
"""The approximate greatest number of tokens to send in a request. When the session is long, the system prompt and 1 or more of the most recent interaction steps are sent."""
def __init__(self):
def __init__(self) -> None:
self.parameters = self.Parameters()
system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language."
def make_full_prompt(self, all_history):
def make_full_prompt(self, all_history: Session) -> Session:
encoding = EncodingMeta.from_model(self.parameters.model)
result = [all_history[0]] # Assumed to be system prompt
left = self.parameters.max_request_tokens - encoding.num_tokens_for_messages(
@ -99,15 +101,13 @@ class ChatGPT:
result.extend(reversed(parts))
return result
def ask(self, session, query, *, timeout=60):
def ask(self, session: Session, query: str, *, timeout: float = 60) -> str:
full_prompt = self.make_full_prompt(session + [User(query)])
response = httpx.post(
"https://api.openai.com/v1/chat/completions",
json={
"model": self.parameters.model,
"messages": full_prompt.to_dict()[ # pylint: disable=no-member
"session"
],
"messages": session_to_list(full_prompt),
}, # pylint: disable=no-member
headers={
"Authorization": f"Bearer {self.get_key()}",
@ -115,20 +115,20 @@ class ChatGPT:
timeout=timeout,
)
if response.status_code != 200:
print("Failure", response.status_code, response.text)
return None
return f"Failure {response.text} ({response.status_code})"
try:
j = response.json()
result = j["choices"][0]["message"]["content"]
result = cast(str, j["choices"][0]["message"]["content"])
except (KeyError, IndexError, json.decoder.JSONDecodeError):
print("Failure", response.status_code, response.text)
return None
return f"Failure {response.text} ({response.status_code})"
session.extend([User(query), Assistant(result)])
return result
async def aask(self, session, query, *, timeout=60):
async def aask(
self, session: Session, query: str, *, timeout: float = 60
) -> AsyncGenerator[str, None]:
full_prompt = self.make_full_prompt(session + [User(query)])
new_content = []
try:
@ -140,9 +140,7 @@ class ChatGPT:
json={
"model": self.parameters.model,
"stream": True,
"messages": full_prompt.to_dict()[ # pylint: disable=no-member
"session"
], # pylint: disable=no-member
"messages": session_to_list(full_prompt),
},
) as response:
if response.status_code == 200:
@ -170,10 +168,10 @@ class ChatGPT:
session.extend([User(query), Assistant("".join(new_content))])
@classmethod
def get_key(cls):
def get_key(cls) -> str:
return get_key("openai_api_key")
def factory():
def factory() -> Backend:
"""Uses the OpenAI chat completion API"""
return ChatGPT()
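make_full_prompt keeps the system prompt plus as many of the most recent messages as fit in max_request_tokens, counting tokens with EncodingMeta. A simplified sketch of that trimming policy, with a rough character-based counter standing in for tiktoken:

from chap.session import Session

def rough_tokens(text: str) -> int:
    # stand-in for EncodingMeta.num_tokens_for_message: roughly one token per four characters
    return max(1, len(text) // 4)

def trim_to_budget(history: Session, budget: int) -> Session:
    result: Session = [history[0]]              # the system prompt is always kept
    left = budget - rough_tokens(history[0].content)
    kept: Session = []
    for message in reversed(history[1:]):       # walk the history newest-first
        cost = rough_tokens(message.content)
        if cost > left:
            break
        kept.append(message)
        left -= cost
    result.extend(reversed(kept))               # restore chronological order
    return result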


@ -2,19 +2,25 @@
#
# SPDX-License-Identifier: MIT
import asyncio
import json
import uuid
from dataclasses import dataclass
from typing import AsyncGenerator
import websockets
from ..key import get_key
from ..session import Assistant, Role, User
from ..core import AutoAskMixin, Backend
from ..session import Assistant, Role, Session, User
class Textgen:
def __init__(self):
self.server = get_key("textgen_url", "textgen server URL")
class Textgen(AutoAskMixin):
@dataclass
class Parameters:
server_hostname: str = "localhost"
def __init__(self) -> None:
super().__init__()
self.parameters = self.Parameters()
system_message = """\
A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.
@ -23,9 +29,14 @@ USER: Hello, AI.
AI: Hello! How can I assist you today?"""
async def aask(
self, session, query, *, max_query_size=5, timeout=60
): # pylint: disable=unused-argument,too-many-locals,too-many-branches
async def aask( # pylint: disable=unused-argument,too-many-locals,too-many-branches
self,
session: Session,
query: str,
*,
max_query_size: int = 5,
timeout: float = 60,
) -> AsyncGenerator[str, None]:
params = {
"max_new_tokens": 200,
"do_sample": True,
@ -55,7 +66,7 @@ AI: Hello! How can I assist you today?"""
)
try:
async with websockets.connect( # pylint: disable=no-member
f"ws://{self.server}:7860/queue/join"
f"ws://{self.parameters.server_hostname}:7860/queue/join"
) as websocket:
while content := json.loads(await websocket.recv()):
if content["msg"] == "send_hash":
@ -124,13 +135,7 @@ AI: Hello! How can I assist you today?"""
all_response = new_data[len(full_query) :]
session.extend([User(query), Assistant(all_response)])
def ask(self, session, query, *, max_query_size=5, timeout=60):
asyncio.run(
self.aask(session, query, max_query_size=max_query_size, timeout=timeout)
)
return session[-1].content
def factory():
def factory() -> Backend:
"""Uses the textgen completion API"""
return Textgen()
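Moving the server location into a Parameters dataclass (instead of reading textgen_url from the key store) lets this backend share the generic -B name:value plumbing in core.py, which introspects the dataclass with dataclasses.fields(). A small sketch mirroring what format_backend_help does with that metadata (the field docstring is illustrative):

from dataclasses import dataclass, fields

@dataclass
class Parameters:
    server_hostname: str = "localhost"
    """The textgen server's hostname"""

for f in fields(Parameters):
    # the help output is built from the same field name, type, and default
    print(f"-B {f.name.replace('_', '-')}:{f.type.__name__.upper()} (Default: {f.default!r})")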


@ -4,43 +4,52 @@
import asyncio
import sys
from typing import Iterable, Protocol
import click
import rich
from ..core import command_uses_new_session
from ..session import session_to_file
from ..core import Backend, Obj, command_uses_new_session
from ..session import Session, session_to_file
bold = "\033[1m"
nobold = "\033[m"
def ipartition(s, sep):
def ipartition(s: str, sep: str) -> Iterable[tuple[str, str]]:
rest = s
while rest:
first, opt_sep, rest = rest.partition(sep)
yield (first, opt_sep)
class Printable(Protocol):
def raw(self, s: str) -> None:
...
def add(self, s: str) -> None:
...
class DumbPrinter:
def raw(self, s):
def raw(self, s: str) -> None:
pass
def add(self, s):
def add(self, s: str) -> None:
print(s, end="")
class WrappingPrinter:
def __init__(self, width=None):
def __init__(self, width: int | None = None) -> None:
self._width = width or rich.get_console().width
self._column = 0
self._line = ""
self._sp = ""
def raw(self, s):
def raw(self, s: str) -> None:
print(s, end="")
def add(self, s):
def add(self, s: str) -> None:
for line, opt_nl in ipartition(s, "\n"):
for word, opt_sp in ipartition(line, " "):
newlen = len(self._line) + len(self._sp) + len(word)
@ -64,15 +73,16 @@ class WrappingPrinter:
self._sp = ""
def verbose_ask(api, session, q, print_prompt, **kw):
def verbose_ask(api: Backend, session: Session, q: str, print_prompt: bool) -> str:
printer: Printable
if sys.stdout.isatty():
printer = WrappingPrinter()
else:
printer = DumbPrinter()
tokens = []
tokens: list[str] = []
async def work():
async for token in api.aask(session, q, **kw):
async def work() -> None:
async for token in api.aask(session, q):
printer.add(token)
if print_prompt:
@ -91,11 +101,16 @@ def verbose_ask(api, session, q, print_prompt, **kw):
@command_uses_new_session
@click.option("--print-prompt/--no-print-prompt", default=True)
@click.argument("prompt", nargs=-1, required=True)
def main(obj, prompt, print_prompt):
def main(obj: Obj, prompt: str, print_prompt: bool) -> None:
"""Ask a question (command-line argument is passed as prompt)"""
session = obj.session
assert session is not None
session_filename = obj.session_filename
assert session_filename is not None
api = obj.api
assert api is not None
# symlink_session_filename(session_filename)
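Printable is a typing.Protocol, so DumbPrinter and WrappingPrinter satisfy it structurally; nothing has to inherit from it. A short illustration with a hypothetical third printer:

import sys
from typing import Protocol

class Printable(Protocol):
    def raw(self, s: str) -> None: ...
    def add(self, s: str) -> None: ...

class StderrPrinter:                            # illustrative, not part of the repository
    def raw(self, s: str) -> None:
        print(s, end="", file=sys.stderr)
    def add(self, s: str) -> None:
        self.raw(s)

printer: Printable = StderrPrinter()            # accepted by mypy purely on matching methods
printer.add("hello\n")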


@ -4,15 +4,17 @@
import click
from ..core import command_uses_existing_session
from ..core import Obj, command_uses_existing_session
from ..session import Role
@command_uses_existing_session
@click.option("--no-system", is_flag=True)
def main(obj, no_system):
def main(obj: Obj, no_system: bool) -> None:
"""Print session in plaintext"""
session = obj.session
if not session:
return
first = True
for row in session:


@ -18,7 +18,7 @@ from .render import to_markdown
def list_files_matching_rx(
rx: re.Pattern, conversations_path: Optional[pathlib.Path] = None
rx: re.Pattern[str], conversations_path: Optional[pathlib.Path] = None
) -> Iterable[Tuple[pathlib.Path, Message]]:
for conversation in (conversations_path or default_conversations_path).glob(
"*.json"
@ -39,7 +39,9 @@ def list_files_matching_rx(
@click.option("--files-with-matches", "-l", is_flag=True)
@click.option("--fixed-strings", "--literal", "-F", is_flag=True)
@click.argument("pattern", nargs=1, required=True)
def main(ignore_case, files_with_matches, fixed_strings, pattern):
def main(
ignore_case: bool, files_with_matches: bool, fixed_strings: bool, pattern: str
) -> None:
"""Search sessions for pattern"""
console = rich.get_console()
if fixed_strings:
@ -47,7 +49,7 @@ def main(ignore_case, files_with_matches, fixed_strings, pattern):
rx = re.compile(pattern, re.I if ignore_case else 0)
last_file = None
for f, m in list_files_matching_rx(rx, ignore_case):
for f, m in list_files_matching_rx(rx, None):
if f != last_file:
if files_with_matches:
print(f)


@ -6,17 +6,20 @@ from __future__ import annotations
import json
import pathlib
from typing import Any, Iterator, TextIO
import click
import rich
from ..core import conversations_path, new_session_path
from ..session import Message, Role, new_session, session_to_file
from ..session import Message, Role, Session, new_session, session_to_file
console = rich.get_console()
def iter_sessions(name, content, session_in, node_id):
def iter_sessions(
name: str, content: Any, session_in: Session, node_id: str
) -> Iterator[tuple[str, Session]]:
node = content["mapping"][node_id]
session = session_in[:]
@ -37,7 +40,7 @@ def iter_sessions(name, content, session_in, node_id):
yield node_id, session
def do_import(output_directory, f):
def do_import(output_directory: pathlib.Path, f: TextIO) -> None:
stem = pathlib.Path(f.name).stem
content = json.load(f)
session = new_session()
@ -66,7 +69,7 @@ def do_import(output_directory, f):
@click.argument(
"files", nargs=-1, required=True, type=click.File("r", encoding="utf-8")
)
def main(output_directory, files):
def main(output_directory: pathlib.Path, files: list[TextIO]) -> None:
"""Import files from the ChatGPT webui
This understands the format produced by


@ -7,11 +7,11 @@ import rich
from markdown_it import MarkdownIt
from rich.markdown import Markdown
from ..core import command_uses_existing_session
from ..session import Role
from ..core import Obj, command_uses_existing_session
from ..session import Message, Role
def to_markdown(message):
def to_markdown(message: Message) -> Markdown:
role = message.role
if role == Role.USER:
style = "bold"
@ -28,9 +28,10 @@ def to_markdown(message):
@command_uses_existing_session
@click.option("--no-system", is_flag=True)
def main(obj, no_system):
def main(obj: Obj, no_system: bool) -> None:
"""Print session with formatting"""
session = obj.session
assert session is not None
console = rich.get_console()
first = True


@ -5,19 +5,20 @@
import asyncio
import subprocess
import sys
from typing import cast
from markdown_it import MarkdownIt
from textual import work
from textual.app import App
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Container, Horizontal, VerticalScroll
from textual.widgets import Button, Footer, Input, LoadingIndicator, Markdown
from ..core import command_uses_new_session, get_api, new_session_path
from ..session import Assistant, User, new_session, session_to_file
from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path
from ..session import Assistant, Message, Session, User, new_session, session_to_file
def parser_factory():
def parser_factory() -> MarkdownIt:
parser = MarkdownIt()
parser.options["html"] = False
return parser
@ -34,7 +35,7 @@ class ChapMarkdown(
]
def markdown_for_step(step):
def markdown_for_step(step: Message) -> ChapMarkdown:
return ChapMarkdown(
step.content.strip() or "",
classes="role_" + step.role,
@ -48,13 +49,15 @@ class CancelButton(Button):
]
class Tui(App):
class Tui(App[None]):
CSS_PATH = "tui.css"
BINDINGS = [
Binding("ctrl+c", "quit", "Quit", show=True, priority=True),
]
def __init__(self, api=None, session=None):
def __init__(
self, api: Backend | None = None, session: Session | None = None
) -> None:
super().__init__()
self.api = api or get_api("lorem")
self.session = (
@ -62,26 +65,26 @@ class Tui(App):
)
@property
def spinner(self):
def spinner(self) -> LoadingIndicator:
return self.query_one(LoadingIndicator)
@property
def wait(self):
return self.query_one("#wait")
def wait(self) -> VerticalScroll:
return cast(VerticalScroll, self.query_one("#wait"))
@property
def input(self):
def input(self) -> Input:
return self.query_one(Input)
@property
def cancel_button(self):
def cancel_button(self) -> CancelButton:
return self.query_one(CancelButton)
@property
def container(self):
return self.query_one("#content")
def container(self) -> VerticalScroll:
return cast(VerticalScroll, self.query_one("#content"))
def compose(self):
def compose(self) -> ComposeResult:
yield Footer()
yield VerticalScroll(
*[markdown_for_step(step) for step in self.session],
@ -100,11 +103,11 @@ class Tui(App):
self.container.scroll_end(animate=False)
self.input.focus()
async def on_input_submitted(self, event) -> None:
async def on_input_submitted(self, event: Input.Submitted) -> None:
self.get_completion(event.value)
@work(exclusive=True)
async def get_completion(self, query):
async def get_completion(self, query: str) -> None:
self.scroll_end()
self.input.styles.display = "none"
@ -117,8 +120,8 @@ class Tui(App):
await self.container.mount_all(
[markdown_for_step(User(query)), output], before="#pad"
)
tokens = []
update = asyncio.Queue(1)
tokens: list[str] = []
update: asyncio.Queue[bool] = asyncio.Queue(1)
for markdown in self.container.children:
markdown.disabled = True
@ -137,14 +140,14 @@ class Tui(App):
]
)
async def render_fun():
async def render_fun() -> None:
while await update.get():
if tokens:
output.update("".join(tokens).strip())
self.container.scroll_end()
await asyncio.sleep(0.1)
async def get_token_fun():
async def get_token_fun() -> None:
async for token in self.api.aask(session, query):
tokens.append(token)
message.content += token
@ -174,16 +177,16 @@ class Tui(App):
self.cancel_button.disabled = True
self.input.focus()
def scroll_end(self):
def scroll_end(self) -> None:
self.call_after_refresh(self.container.scroll_end)
def action_yank(self):
def action_yank(self) -> None:
widget = self.focused
if isinstance(widget, ChapMarkdown):
content = widget._markdown # pylint: disable=protected-access
content = widget._markdown or "" # pylint: disable=protected-access
subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)
def action_toggle_history(self):
def action_toggle_history(self) -> None:
widget = self.focused
if not isinstance(widget, ChapMarkdown):
return
@ -199,23 +202,25 @@ class Tui(App):
for m in children[idx : idx + 2]:
m.toggle_class("history_exclude")
async def action_stop_generating(self):
async def action_stop_generating(self) -> None:
self.workers.cancel_all()
async def on_button_pressed(self, event): # pylint: disable=unused-argument
async def on_button_pressed( # pylint: disable=unused-argument
self, event: Button.Pressed
) -> None:
self.workers.cancel_all()
async def action_quit(self):
async def action_quit(self) -> None:
self.workers.cancel_all()
self.exit()
async def action_resubmit(self):
async def action_resubmit(self) -> None:
await self.redraft_or_resubmit(True)
async def action_redraft(self):
async def action_redraft(self) -> None:
await self.redraft_or_resubmit(False)
async def redraft_or_resubmit(self, resubmit):
async def redraft_or_resubmit(self, resubmit: bool) -> None:
widget = self.focused
if not isinstance(widget, ChapMarkdown):
return
@ -244,11 +249,14 @@ class Tui(App):
@command_uses_new_session
def main(obj):
def main(obj: Obj) -> None:
"""Start interactive terminal user interface session"""
api = obj.api
assert api is not None
session = obj.session
assert session is not None
session_filename = obj.session_filename
assert session_filename is not None
tui = Tui(api, session)
tui.run()
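The new properties wrap query_one in typing.cast because a selector-only query_one returns a general widget type. cast only informs the type checker; a minimal illustration of that behaviour:

from typing import cast

def shout(value: object) -> str:
    # cast() performs no runtime conversion or check; it only tells mypy what to assume,
    # so a wrong cast shows up as an ordinary runtime error, never a mypy error
    text = cast(str, value)
    return text.upper()

print(shout("hello"))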


@ -3,6 +3,7 @@
# SPDX-License-Identifier: MIT
# pylint: disable=import-outside-toplevel
import asyncio
import datetime
import importlib
import os
@ -10,33 +11,62 @@ import pathlib
import pkgutil
import subprocess
from dataclasses import MISSING, dataclass, fields
from typing import Any, AsyncGenerator, Callable, cast
import click
import platformdirs
from simple_parsing.docstring import get_attribute_docstring
from typing_extensions import Protocol
from . import backends, commands # pylint: disable=no-name-in-module
from .session import Message, System, session_from_file
from .session import Message, Session, System, session_from_file
conversations_path = platformdirs.user_state_path("chap") / "conversations"
conversations_path.mkdir(parents=True, exist_ok=True)
def last_session_path():
class ABackend(Protocol): # pylint: disable=too-few-public-methods
def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
"""Make a query, updating the session with the query and response, returning the query token by token"""
class Backend(ABackend, Protocol):
parameters: Any
system_message: str
def ask(self, session: Session, query: str) -> str:
"""Make a query, updating the session with the query and response, returning the query"""
class AutoAskMixin: # pylint: disable=too-few-public-methods
"""Mixin class for backends implementing aask"""
def ask(self, session: Session, query: str) -> str:
tokens: list[str] = []
async def inner() -> None:
# https://github.com/pylint-dev/pylint/issues/5761
async for token in self.aask(session, query): # type: ignore
tokens.append(token)
asyncio.run(inner())
return "".join(tokens)
def last_session_path() -> pathlib.Path | None:
result = max(
conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None
)
print(result)
return result
def new_session_path(opt_path=None):
def new_session_path(opt_path: pathlib.Path | None = None) -> pathlib.Path:
return opt_path or conversations_path / (
datetime.datetime.now().isoformat().replace(":", "_") + ".json"
)
def configure_api_from_environment(api_name, api):
def configure_api_from_environment(api_name: str, api: Backend) -> None:
if not hasattr(api, "parameters"):
return
@ -54,44 +84,46 @@ def configure_api_from_environment(api_name, api):
setattr(api.parameters, field.name, tv)
def get_api(name="openai_chatgpt"):
def get_api(name: str = "openai_chatgpt") -> Backend:
name = name.replace("-", "_")
result = importlib.import_module(f"{__package__}.backends.{name}").factory()
result = cast(
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
)
configure_api_from_environment(name, result)
return result
def ask(*args, **kw):
return get_api().ask(*args, **kw)
def aask(*args, **kw):
return get_api().aask(*args, **kw)
def do_session_continue(ctx, param, value):
def do_session_continue(
ctx: click.Context, param: click.Parameter, value: pathlib.Path | None
) -> None:
if value is None:
return
if ctx.obj.session is not None:
raise click.BadParameter(
param, "--continue-session, --last and --new-session are mutually exclusive"
"--continue-session, --last and --new-session are mutually exclusive",
param=param,
)
ctx.obj.session = session_from_file(value)
ctx.obj.session_filename = value
def do_session_last(ctx, param, value): # pylint: disable=unused-argument
def do_session_last(
ctx: click.Context, param: click.Parameter, value: bool
) -> None: # pylint: disable=unused-argument
if not value:
return
do_session_continue(ctx, param, last_session_path())
def do_session_new(ctx, param, value):
def do_session_new(
ctx: click.Context, param: click.Parameter, value: pathlib.Path
) -> None:
if ctx.obj.session is not None:
if value is None:
return
raise click.BadOptionUsage(
param, "--continue-session, --last and --new-session are mutually exclusive"
raise click.BadParameter(
"--continue-session, --last and --new-session are mutually exclusive",
param=param,
)
session_filename = new_session_path(value)
system_message = ctx.obj.system_message or ctx.obj.api.system_message
@ -99,20 +131,24 @@ def do_session_new(ctx, param, value):
ctx.obj.session_filename = session_filename
def colonstr(arg):
def colonstr(arg: str) -> tuple[str, str]:
if ":" not in arg:
raise click.BadParameter("must be of the form 'name:value'")
return arg.split(":", 1)
return cast(tuple[str, str], tuple(arg.split(":", 1)))
def set_system_message(ctx, param, value): # pylint: disable=unused-argument
def set_system_message( # pylint: disable=unused-argument
ctx: click.Context, param: click.Parameter, value: str
) -> None:
if value and value.startswith("@"):
with open(value[1:], "r", encoding="utf-8") as f:
value = f.read().rstrip()
ctx.obj.system_message = value
def set_backend(ctx, param, value): # pylint: disable=unused-argument
def set_backend( # pylint: disable=unused-argument
ctx: click.Context, param: click.Parameter, value: str
) -> None:
if value == "list":
formatter = ctx.make_formatter()
format_backend_list(formatter)
@ -125,7 +161,7 @@ def set_backend(ctx, param, value): # pylint: disable=unused-argument
raise click.BadParameter(str(e))
def format_backend_help(api, formatter):
def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None:
with formatter.section(f"Backend options for {api.__class__.__name__}"):
rows = []
for f in fields(api.parameters):
@ -135,11 +171,14 @@ def format_backend_help(api, formatter):
if doc:
doc += " "
doc += f"(Default: {default!r})"
rows.append((f"-B {name}:{f.type.__name__.upper()}", doc))
typename = f.type.__name__
rows.append((f"-B {name}:{typename.upper()}", doc))
formatter.write_dl(rows)
def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument
def set_backend_option( # pylint: disable=unused-argument
ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]]
) -> None:
api = ctx.obj.api
if not hasattr(api, "parameters"):
raise click.BadParameter(
@ -147,7 +186,7 @@ def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument
)
all_fields = dict((f.name.replace("_", "-"), f) for f in fields(api.parameters))
def set_one_backend_option(kv):
def set_one_backend_option(kv: tuple[str, str]) -> None:
name, value = kv
field = all_fields.get(name)
if field is None:
@ -164,7 +203,7 @@ def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument
set_one_backend_option(kv)
def format_backend_list(formatter):
def format_backend_list(formatter: click.HelpFormatter) -> None:
all_backends = []
for pi in pkgutil.walk_packages(backends.__path__):
name = pi.name
@ -186,7 +225,7 @@ def format_backend_list(formatter):
formatter.write_dl(rows)
def uses_session(f):
def uses_session(f: click.decorators.FC) -> Callable[[], None]:
f = click.option(
"--continue-session",
"-s",
@ -198,18 +237,15 @@ def uses_session(f):
f = click.option(
"--last", is_flag=True, callback=do_session_last, expose_value=False
)(f)
f = click.pass_obj(f)
return f
return click.pass_obj(f)
def command_uses_existing_session(f):
f = uses_session(f)
f = click.command()(f)
return f
def command_uses_existing_session(f: click.decorators.FC) -> click.Command:
return click.command()(uses_session(f))
def command_uses_new_session(f):
f = uses_session(f)
def command_uses_new_session(f_in: click.decorators.FC) -> click.Command:
f = uses_session(f_in)
f = click.option(
"--new-session",
"-n",
@ -218,11 +254,12 @@ def command_uses_new_session(f):
callback=do_session_new,
expose_value=False,
)(f)
f = click.command()(f)
return f
return click.command()(f)
def version_callback(ctx, param, value) -> None: # pylint: disable=unused-argument
def version_callback( # pylint: disable=unused-argument
ctx: click.Context, param: click.Parameter, value: None
) -> None:
if not value or ctx.resilient_parsing:
return
@ -247,17 +284,24 @@ def version_callback(ctx, param, value) -> None: # pylint: disable=unused-argum
@dataclass
class Obj:
api: object = None
system_message: object = None
api: Backend | None = None
system_message: str | None = None
session: list[Message] | None = None
session_filename: pathlib.Path | None = None
class MyCLI(click.MultiCommand):
def make_context(self, info_name, args, parent=None, **extra):
def make_context(
self,
info_name: str | None,
args: list[str],
parent: click.Context | None = None,
**extra: Any,
) -> click.Context:
result = super().make_context(info_name, args, parent, obj=Obj(), **extra)
return result
def list_commands(self, ctx):
def list_commands(self, ctx: click.Context) -> list[str]:
rv = []
for pi in pkgutil.walk_packages(commands.__path__):
name = pi.name
@ -266,13 +310,18 @@ class MyCLI(click.MultiCommand):
rv.sort()
return rv
def get_command(self, ctx, cmd_name):
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command:
try:
return importlib.import_module("." + cmd_name, commands.__name__).main
return cast(
click.Command,
importlib.import_module("." + cmd_name, commands.__name__).main,
)
except ModuleNotFoundError as exc:
raise click.UsageError(f"Invalid subcommand {cmd_name!r}", ctx) from exc
def format_options(self, ctx, formatter):
def format_options(
self, ctx: click.Context, formatter: click.HelpFormatter
) -> None:
super().format_options(ctx, formatter)
api = ctx.obj.api or get_api()
if hasattr(api, "parameters"):
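With the new Backend protocol and AutoAskMixin, a backend only needs an async aask generator (plus parameters and system_message); the synchronous ask comes from the mixin. A hedged sketch of the smallest possible backend under those assumptions (EchoBackend is illustrative, not part of this commit):

from dataclasses import dataclass
from typing import AsyncGenerator

from chap.core import AutoAskMixin, Backend
from chap.session import Assistant, Session, User

class EchoBackend(AutoAskMixin):
    @dataclass
    class Parameters:
        prefix: str = "echo: "

    system_message = "Repeats the user's words back."

    def __init__(self) -> None:
        super().__init__()
        self.parameters = self.Parameters()

    async def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
        reply = self.parameters.prefix + query
        session.extend([User(query), Assistant(reply)])
        yield reply

def factory() -> Backend:
    # structural check: EchoBackend provides aask, ask (via the mixin), parameters, system_message
    return EchoBackend()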


@ -15,7 +15,7 @@ _key_path_base = platformdirs.user_config_path("chap")
@functools.cache
def get_key(name, what="openai api key"):
def get_key(name: str, what: str = "openai api key") -> str:
key_path = _key_path_base / name
if not key_path.exists():
raise NoKeyAvailable(


@ -2,8 +2,14 @@
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import json
import pathlib
from dataclasses import asdict, dataclass
from typing import cast
from typing_extensions import TypedDict
# not an enum.Enum because these objects are not json-serializable, sigh
@ -21,40 +27,49 @@ class Message:
content: str
def Assistant(content):
MessageDict = TypedDict("MessageDict", {"role": str, "content": str})
Session = list[Message]
SessionDicts = list[MessageDict]
def Assistant(content: str) -> Message:
return Message(Role.ASSISTANT, content)
def System(content):
def System(content: str) -> Message:
return Message(Role.SYSTEM, content)
def User(content):
def User(content: str) -> Message:
return Message(Role.USER, content)
def new_session(
system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language.",
):
) -> Session:
return [System(system_message)]
def session_to_json(session):
return json.dumps([asdict(message) for message in session])
def session_to_json(session: Session) -> str:
return json.dumps(session_to_list(session))
def session_from_json(data):
def session_to_list(session: Session) -> SessionDicts:
return [cast(MessageDict, asdict(message)) for message in session]
def session_from_json(data: str) -> Session:
j = json.loads(data)
if isinstance(j, dict):
j = j["session"]
return [Message(**mapping) for mapping in j]
def session_from_file(path):
def session_from_file(path: pathlib.Path | str) -> Session:
with open(path, "r", encoding="utf-8") as f:
return session_from_json(f.read())
def session_to_file(session, path):
def session_to_file(session: Session, path: pathlib.Path | str) -> None:
with open(path, "w", encoding="utf-8") as f:
return f.write(session_to_json(session))
f.write(session_to_json(session))
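Session is now an alias for list[Message], and because Message is a plain dataclass the JSON helpers round-trip it cleanly. For example:

from chap.session import Assistant, User, new_session, session_from_json, session_to_json

session = new_session("You are terse.")
session.extend([User("Hi"), Assistant("Hello.")])

data = session_to_json(session)                 # a JSON array of {"role": ..., "content": ...} objects
assert session_from_json(data) == session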