From 86fa38b504b38c31bb5e21df50e1182986a34b11 Mon Sep 17 00:00:00 2001 From: Jeff Epler Date: Thu, 9 Nov 2023 12:02:41 -0600 Subject: [PATCH] Fix pre-commit errors --- src/chap/backends/replay.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/chap/backends/replay.py b/src/chap/backends/replay.py index ea55f03..682262f 100644 --- a/src/chap/backends/replay.py +++ b/src/chap/backends/replay.py @@ -6,28 +6,32 @@ import asyncio import os import random from dataclasses import dataclass -from typing import Iterable, AsyncGenerator, Callable, Any +from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Iterable, TypeVar import click +from ..core import Backend # pylint: disable=relative-beyond-top-level from ..session import ( # pylint: disable=relative-beyond-top-level Assistant, Session, User, - session_from_file + session_from_file, ) -from ..core import Backend -from typing import TYPE_CHECKING, TypeVar if TYPE_CHECKING: - F = TypeVar('F', bound=Callable[..., Any]) - def cached(f: F) -> F: return f + F = TypeVar("F", bound=Callable[..., Any]) + + def cached(f: F) -> F: + return f + else: from functools import lru_cache + def cached(f): return lru_cache()(f) -def ipartition(s: str, sep:str=" ") -> Iterable[tuple[str, str]]: + +def ipartition(s: str, sep: str = " ") -> Iterable[tuple[str, str]]: rest = s while rest: first, opt_sep, rest = rest.partition(sep) @@ -60,9 +64,7 @@ class Replay: @property @cached def _assistant_responses(self) -> Session: - return [ - message for message in self._session if message.role == "assistant" - ] + return [message for message in self._session if message.role == "assistant"] @property def system_message(self) -> str: @@ -85,12 +87,12 @@ class Replay: ) def ask( - self, session: Session, query: str + self, session: Session, query: str ) -> str: # pylint: disable=unused-argument if self._assistant_responses: - idx = sum( - 1 for message in session if message.role == "assistant" - ) % len(self._assistant_responses) + idx = sum(1 for message in session if message.role == "assistant") % len( + self._assistant_responses + ) new_content = self._assistant_responses[idx].content else: