Update to newest API & add typing
This commit is contained in:
parent
8ef439f815
commit
593c7cee51
4 changed files with 54 additions and 19 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -6,3 +6,4 @@ __pycache__
|
|||
*.egg-info/
|
||||
/build
|
||||
/dist
|
||||
/venv
|
||||
|
|
|
|||
15
Makefile
Normal file
15
Makefile
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
.PHONY: mypy
|
||||
mypy: venv/bin/mypy
|
||||
venv/bin/mypy --strict -p chap
|
||||
|
||||
venv/bin/mypy:
|
||||
python -mvenv venv
|
||||
venv/bin/pip install chap mypy
|
||||
|
||||
.PHONY: clean
|
||||
clean:
|
||||
rm -rf venv
|
||||
6
mypy.ini
Normal file
6
mypy.ini
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
[mypy]
|
||||
mypy_path = src
|
||||
|
|
@ -3,10 +3,10 @@
|
|||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, AsyncGenerator, Callable, Any
|
||||
|
||||
import click
|
||||
|
||||
|
|
@ -14,10 +14,20 @@ from ..session import ( # pylint: disable=relative-beyond-top-level
|
|||
Assistant,
|
||||
Session,
|
||||
User,
|
||||
session_from_file
|
||||
)
|
||||
from ..core import Backend
|
||||
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
if TYPE_CHECKING:
|
||||
F = TypeVar('F', bound=Callable[..., Any])
|
||||
def cached(f: F) -> F: return f
|
||||
else:
|
||||
from functools import lru_cache
|
||||
def cached(f):
|
||||
return lru_cache()(f)
|
||||
|
||||
def ipartition(s, sep=" "):
|
||||
def ipartition(s: str, sep:str=" ") -> Iterable[tuple[str, str]]:
|
||||
rest = s
|
||||
while rest:
|
||||
first, opt_sep, rest = rest.partition(sep)
|
||||
|
|
@ -27,43 +37,46 @@ def ipartition(s, sep=" "):
|
|||
class Replay:
|
||||
@dataclass
|
||||
class Parameters:
|
||||
session: str = None
|
||||
session: str | None = None
|
||||
"""Complete path to an existing session file"""
|
||||
delay_mu: float = 0.035
|
||||
"""Average delay between tokens"""
|
||||
delay_sigma: float = 0.02
|
||||
"""Standard deviation of token delay"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.parameters = self.Parameters()
|
||||
|
||||
@property
|
||||
@functools.lru_cache()
|
||||
def _session(self):
|
||||
@cached
|
||||
def _session(self) -> Session:
|
||||
if self.parameters.session is None:
|
||||
raise click.BadParameter(
|
||||
"Must specify -B session:/full/path/to/existing_session.json"
|
||||
)
|
||||
session_file = os.path.expanduser(self.parameters.session)
|
||||
with open(session_file, "r", encoding="utf-8") as f:
|
||||
return Session.from_json(f.read())
|
||||
return session_from_file(session_file)
|
||||
|
||||
@property
|
||||
@functools.lru_cache()
|
||||
def _assistant_responses(self):
|
||||
@cached
|
||||
def _assistant_responses(self) -> Session:
|
||||
return [
|
||||
message for message in self._session.session if message.role == "assistant"
|
||||
message for message in self._session if message.role == "assistant"
|
||||
]
|
||||
|
||||
@property
|
||||
def system_message(self):
|
||||
def system_message(self) -> str:
|
||||
num_assistant_responses = len(self._assistant_responses)
|
||||
return (
|
||||
f"Replay of {self.parameters.session} with {num_assistant_responses} responses. "
|
||||
f"The original session system message was:\n\n{self._session.session[0].content}"
|
||||
f"The original session system message was:\n\n{self._session[0].content}"
|
||||
)
|
||||
|
||||
async def aask(self, session, query):
|
||||
@system_message.setter
|
||||
def system_message(self, value: str) -> None:
|
||||
raise AttributeError("Read-only attribute 'system_message'")
|
||||
|
||||
async def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
|
||||
data = self.ask(session, query)
|
||||
for word, opt_sep in ipartition(data):
|
||||
yield word + opt_sep
|
||||
|
|
@ -72,20 +85,20 @@ class Replay:
|
|||
)
|
||||
|
||||
def ask(
|
||||
self, session, query, *, max_query_size=5, timeout=60
|
||||
): # pylint: disable=unused-argument
|
||||
self, session: Session, query: str
|
||||
) -> str: # pylint: disable=unused-argument
|
||||
if self._assistant_responses:
|
||||
idx = sum(
|
||||
1 for message in session.session if message.role == "assistant"
|
||||
1 for message in session if message.role == "assistant"
|
||||
) % len(self._assistant_responses)
|
||||
|
||||
new_content = self._assistant_responses[idx].content
|
||||
else:
|
||||
new_content = "(No assistant responses in session)"
|
||||
session.session.extend([User(query), Assistant(new_content)])
|
||||
session.extend([User(query), Assistant(new_content)])
|
||||
return new_content
|
||||
|
||||
|
||||
def factory():
|
||||
def factory() -> Backend:
|
||||
"""Replay an existing session file. Useful for testing."""
|
||||
return Replay()
|
||||
|
|
|
|||
Loading…
Reference in a new issue