Fix pre-commit errors

This commit is contained in:
Jeff Epler 2023-11-09 12:02:41 -06:00
parent e4d3f14f84
commit 86fa38b504
No known key found for this signature in database
GPG key ID: D5BF15AB975AB4DE

View file

@ -6,28 +6,32 @@ import asyncio
import os
import random
from dataclasses import dataclass
from typing import Iterable, AsyncGenerator, Callable, Any
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Iterable, TypeVar
import click
from ..core import Backend # pylint: disable=relative-beyond-top-level
from ..session import ( # pylint: disable=relative-beyond-top-level
Assistant,
Session,
User,
session_from_file
session_from_file,
)
from ..core import Backend
from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
F = TypeVar('F', bound=Callable[..., Any])
def cached(f: F) -> F: return f
F = TypeVar("F", bound=Callable[..., Any])
def cached(f: F) -> F:
return f
else:
from functools import lru_cache
def cached(f):
return lru_cache()(f)
def ipartition(s: str, sep:str=" ") -> Iterable[tuple[str, str]]:
def ipartition(s: str, sep: str = " ") -> Iterable[tuple[str, str]]:
rest = s
while rest:
first, opt_sep, rest = rest.partition(sep)
@ -60,9 +64,7 @@ class Replay:
@property
@cached
def _assistant_responses(self) -> Session:
return [
message for message in self._session if message.role == "assistant"
]
return [message for message in self._session if message.role == "assistant"]
@property
def system_message(self) -> str:
@ -88,9 +90,9 @@ class Replay:
self, session: Session, query: str
) -> str: # pylint: disable=unused-argument
if self._assistant_responses:
idx = sum(
1 for message in session if message.role == "assistant"
) % len(self._assistant_responses)
idx = sum(1 for message in session if message.role == "assistant") % len(
self._assistant_responses
)
new_content = self._assistant_responses[idx].content
else: