diff --git a/.gitignore b/.gitignore index 0d69b13..a2578ad 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ __pycache__ *.egg-info/ /build /dist +/venv diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1437bd3 --- /dev/null +++ b/Makefile @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: 2023 Jeff Epler +# +# SPDX-License-Identifier: MIT + +.PHONY: mypy +mypy: venv/bin/mypy + venv/bin/mypy --strict -p chap + +venv/bin/mypy: + python -mvenv venv + venv/bin/pip install chap mypy + +.PHONY: clean +clean: + rm -rf venv diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..4dbbc58 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023 Jeff Epler +# +# SPDX-License-Identifier: MIT + +[mypy] +mypy_path = src diff --git a/src/chap/backends/replay.py b/src/chap/backends/replay.py index 19e242b..ea55f03 100644 --- a/src/chap/backends/replay.py +++ b/src/chap/backends/replay.py @@ -3,10 +3,10 @@ # SPDX-License-Identifier: MIT import asyncio -import functools import os import random from dataclasses import dataclass +from typing import Iterable, AsyncGenerator, Callable, Any import click @@ -14,10 +14,20 @@ from ..session import ( # pylint: disable=relative-beyond-top-level Assistant, Session, User, + session_from_file ) +from ..core import Backend +from typing import TYPE_CHECKING, TypeVar +if TYPE_CHECKING: + F = TypeVar('F', bound=Callable[..., Any]) + def cached(f: F) -> F: return f +else: + from functools import lru_cache + def cached(f): + return lru_cache()(f) -def ipartition(s, sep=" "): +def ipartition(s: str, sep:str=" ") -> Iterable[tuple[str, str]]: rest = s while rest: first, opt_sep, rest = rest.partition(sep) @@ -27,43 +37,46 @@ def ipartition(s, sep=" "): class Replay: @dataclass class Parameters: - session: str = None + session: str | None = None """Complete path to an existing session file""" delay_mu: float = 0.035 """Average delay between tokens""" delay_sigma: float = 0.02 """Standard deviation of token delay""" - def __init__(self): + def __init__(self) -> None: self.parameters = self.Parameters() @property - @functools.lru_cache() - def _session(self): + @cached + def _session(self) -> Session: if self.parameters.session is None: raise click.BadParameter( "Must specify -B session:/full/path/to/existing_session.json" ) session_file = os.path.expanduser(self.parameters.session) - with open(session_file, "r", encoding="utf-8") as f: - return Session.from_json(f.read()) + return session_from_file(session_file) @property - @functools.lru_cache() - def _assistant_responses(self): + @cached + def _assistant_responses(self) -> Session: return [ - message for message in self._session.session if message.role == "assistant" + message for message in self._session if message.role == "assistant" ] @property - def system_message(self): + def system_message(self) -> str: num_assistant_responses = len(self._assistant_responses) return ( f"Replay of {self.parameters.session} with {num_assistant_responses} responses. " - f"The original session system message was:\n\n{self._session.session[0].content}" + f"The original session system message was:\n\n{self._session[0].content}" ) - async def aask(self, session, query): + @system_message.setter + def system_message(self, value: str) -> None: + raise AttributeError("Read-only attribute 'system_message'") + + async def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]: data = self.ask(session, query) for word, opt_sep in ipartition(data): yield word + opt_sep @@ -72,20 +85,20 @@ class Replay: ) def ask( - self, session, query, *, max_query_size=5, timeout=60 - ): # pylint: disable=unused-argument + self, session: Session, query: str + ) -> str: # pylint: disable=unused-argument if self._assistant_responses: idx = sum( - 1 for message in session.session if message.role == "assistant" + 1 for message in session if message.role == "assistant" ) % len(self._assistant_responses) new_content = self._assistant_responses[idx].content else: new_content = "(No assistant responses in session)" - session.session.extend([User(query), Assistant(new_content)]) + session.extend([User(query), Assistant(new_content)]) return new_content -def factory(): +def factory() -> Backend: """Replay an existing session file. Useful for testing.""" return Replay()