Update to newest API & add typing

This commit is contained in:
Jeff Epler 2023-11-09 11:09:18 -06:00
parent 8ef439f815
commit 593c7cee51
No known key found for this signature in database
GPG key ID: D5BF15AB975AB4DE
4 changed files with 54 additions and 19 deletions

1
.gitignore vendored
View file

@ -6,3 +6,4 @@ __pycache__
*.egg-info/
/build
/dist
/venv

15
Makefile Normal file
View file

@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
#
# SPDX-License-Identifier: MIT
.PHONY: mypy
mypy: venv/bin/mypy
venv/bin/mypy --strict -p chap
venv/bin/mypy:
python -mvenv venv
venv/bin/pip install chap mypy
.PHONY: clean
clean:
rm -rf venv

6
mypy.ini Normal file
View file

@ -0,0 +1,6 @@
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
#
# SPDX-License-Identifier: MIT
[mypy]
mypy_path = src

View file

@ -3,10 +3,10 @@
# SPDX-License-Identifier: MIT
import asyncio
import functools
import os
import random
from dataclasses import dataclass
from typing import Iterable, AsyncGenerator, Callable, Any
import click
@ -14,10 +14,20 @@ from ..session import ( # pylint: disable=relative-beyond-top-level
Assistant,
Session,
User,
session_from_file
)
from ..core import Backend
from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
F = TypeVar('F', bound=Callable[..., Any])
def cached(f: F) -> F: return f
else:
from functools import lru_cache
def cached(f):
return lru_cache()(f)
def ipartition(s, sep=" "):
def ipartition(s: str, sep:str=" ") -> Iterable[tuple[str, str]]:
rest = s
while rest:
first, opt_sep, rest = rest.partition(sep)
@ -27,43 +37,46 @@ def ipartition(s, sep=" "):
class Replay:
@dataclass
class Parameters:
session: str = None
session: str | None = None
"""Complete path to an existing session file"""
delay_mu: float = 0.035
"""Average delay between tokens"""
delay_sigma: float = 0.02
"""Standard deviation of token delay"""
def __init__(self):
def __init__(self) -> None:
self.parameters = self.Parameters()
@property
@functools.lru_cache()
def _session(self):
@cached
def _session(self) -> Session:
if self.parameters.session is None:
raise click.BadParameter(
"Must specify -B session:/full/path/to/existing_session.json"
)
session_file = os.path.expanduser(self.parameters.session)
with open(session_file, "r", encoding="utf-8") as f:
return Session.from_json(f.read())
return session_from_file(session_file)
@property
@functools.lru_cache()
def _assistant_responses(self):
@cached
def _assistant_responses(self) -> Session:
return [
message for message in self._session.session if message.role == "assistant"
message for message in self._session if message.role == "assistant"
]
@property
def system_message(self):
def system_message(self) -> str:
num_assistant_responses = len(self._assistant_responses)
return (
f"Replay of {self.parameters.session} with {num_assistant_responses} responses. "
f"The original session system message was:\n\n{self._session.session[0].content}"
f"The original session system message was:\n\n{self._session[0].content}"
)
async def aask(self, session, query):
@system_message.setter
def system_message(self, value: str) -> None:
raise AttributeError("Read-only attribute 'system_message'")
async def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
data = self.ask(session, query)
for word, opt_sep in ipartition(data):
yield word + opt_sep
@ -72,20 +85,20 @@ class Replay:
)
def ask(
self, session, query, *, max_query_size=5, timeout=60
): # pylint: disable=unused-argument
self, session: Session, query: str
) -> str: # pylint: disable=unused-argument
if self._assistant_responses:
idx = sum(
1 for message in session.session if message.role == "assistant"
1 for message in session if message.role == "assistant"
) % len(self._assistant_responses)
new_content = self._assistant_responses[idx].content
else:
new_content = "(No assistant responses in session)"
session.session.extend([User(query), Assistant(new_content)])
session.extend([User(query), Assistant(new_content)])
return new_content
def factory():
def factory() -> Backend:
"""Replay an existing session file. Useful for testing."""
return Replay()