Fully type annotate mypy
with two (commented) exceptions
This commit is contained in:
parent
eac422ff17
commit
fc69800594
17 changed files with 345 additions and 201 deletions
3
.github/workflows/test.yml
vendored
3
.github/workflows/test.yml
vendored
|
|
@ -46,6 +46,9 @@ jobs:
|
|||
- name: install deps
|
||||
run: pip install -r requirements-dev.txt
|
||||
|
||||
- name: check types with mypy
|
||||
run: make
|
||||
|
||||
- name: Build release
|
||||
run: python -mbuild
|
||||
|
||||
|
|
|
|||
15
Makefile
Normal file
15
Makefile
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
.PHONY: mypy
|
||||
mypy: venv/bin/mypy
|
||||
venv/bin/mypy --strict -p chap
|
||||
|
||||
venv/bin/mypy:
|
||||
python -mvenv venv
|
||||
venv/bin/pip install -r requirements.txt mypy
|
||||
|
||||
.PHONY: clean
|
||||
clean:
|
||||
rm -rf venv
|
||||
6
mypy.ini
Normal file
6
mypy.ini
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
[mypy]
|
||||
mypy_path = src
|
||||
|
|
@ -2,17 +2,18 @@
|
|||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core import AutoAskMixin, Backend
|
||||
from ..key import get_key
|
||||
from ..session import Assistant, Role, User
|
||||
from ..session import Assistant, Role, Session, User
|
||||
|
||||
|
||||
class HuggingFace:
|
||||
class HuggingFace(AutoAskMixin):
|
||||
@dataclass
|
||||
class Parameters:
|
||||
url: str = "https://api-inference.huggingface.co"
|
||||
|
|
@ -24,14 +25,15 @@ class HuggingFace:
|
|||
after_assistant: str = """ </s><s>[INST] """
|
||||
stop_token_id = 2
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.parameters = self.Parameters()
|
||||
|
||||
system_message = """\
|
||||
A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.
|
||||
"""
|
||||
|
||||
def make_full_query(self, messages, max_query_size):
|
||||
def make_full_query(self, messages: Session, max_query_size: int) -> str:
|
||||
del messages[1:-max_query_size]
|
||||
result = [self.parameters.start_prompt]
|
||||
for m in messages:
|
||||
|
|
@ -48,7 +50,9 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
full_query = "".join(result)
|
||||
return full_query
|
||||
|
||||
async def chained_query(self, inputs, timeout):
|
||||
async def chained_query(
|
||||
self, inputs: Any, timeout: float
|
||||
) -> AsyncGenerator[str, None]:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
while inputs:
|
||||
params = {
|
||||
|
|
@ -79,9 +83,16 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
return
|
||||
|
||||
async def aask(
|
||||
self, session, query, *, max_query_size=5, timeout=180
|
||||
): # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
new_content = []
|
||||
self,
|
||||
session: Session,
|
||||
query: str,
|
||||
*,
|
||||
max_query_size: int = 5,
|
||||
timeout: float = 180,
|
||||
) -> AsyncGenerator[
|
||||
str, None
|
||||
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
new_content: list[str] = []
|
||||
inputs = self.make_full_query(session + [User(query)], max_query_size)
|
||||
try:
|
||||
async for content in self.chained_query(inputs, timeout=timeout):
|
||||
|
|
@ -101,17 +112,11 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
def ask(self, session, query, *, max_query_size=5, timeout=60):
|
||||
asyncio.run(
|
||||
self.aask(session, query, max_query_size=max_query_size, timeout=timeout)
|
||||
)
|
||||
return session[-1].content
|
||||
|
||||
@classmethod
|
||||
def get_key(cls):
|
||||
def get_key(cls) -> str:
|
||||
return get_key("huggingface_api_token")
|
||||
|
||||
|
||||
def factory():
|
||||
def factory() -> Backend:
|
||||
"""Uses the huggingface text-generation-interface web API"""
|
||||
return HuggingFace()
|
||||
|
|
|
|||
|
|
@ -2,16 +2,17 @@
|
|||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
from ..session import Assistant, Role, User
|
||||
from ..core import AutoAskMixin, Backend
|
||||
from ..session import Assistant, Role, Session, User
|
||||
|
||||
|
||||
class LlamaCpp:
|
||||
class LlamaCpp(AutoAskMixin):
|
||||
@dataclass
|
||||
class Parameters:
|
||||
url: str = "http://localhost:8080/completion"
|
||||
|
|
@ -22,14 +23,15 @@ class LlamaCpp:
|
|||
after_user: str = """ [/INST] """
|
||||
after_assistant: str = """ </s><s>[INST] """
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.parameters = self.Parameters()
|
||||
|
||||
system_message = """\
|
||||
A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.
|
||||
"""
|
||||
|
||||
def make_full_query(self, messages, max_query_size):
|
||||
def make_full_query(self, messages: Session, max_query_size: int) -> str:
|
||||
del messages[1:-max_query_size]
|
||||
result = [self.parameters.start_prompt]
|
||||
for m in messages:
|
||||
|
|
@ -47,14 +49,21 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
return full_query
|
||||
|
||||
async def aask(
|
||||
self, session, query, *, max_query_size=5, timeout=180
|
||||
): # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
self,
|
||||
session: Session,
|
||||
query: str,
|
||||
*,
|
||||
max_query_size: int = 5,
|
||||
timeout: float = 180,
|
||||
) -> AsyncGenerator[
|
||||
str, None
|
||||
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
params = {
|
||||
"prompt": self.make_full_query(session + [User(query)], max_query_size),
|
||||
"stream": True,
|
||||
"stop": ["</s>", "<s>", "[INST]"],
|
||||
}
|
||||
new_content = []
|
||||
new_content: list[str] = []
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream(
|
||||
|
|
@ -87,13 +96,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
def ask(self, session, query, *, max_query_size=5, timeout=60):
|
||||
asyncio.run(
|
||||
self.aask(session, query, max_query_size=max_query_size, timeout=timeout)
|
||||
)
|
||||
return session[-1].content
|
||||
|
||||
|
||||
def factory():
|
||||
def factory() -> Backend:
|
||||
"""Uses the llama.cpp completion web API"""
|
||||
return LlamaCpp()
|
||||
|
|
|
|||
|
|
@ -5,13 +5,16 @@
|
|||
import asyncio
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, Iterable, cast
|
||||
|
||||
from lorem_text import lorem
|
||||
# lorem is not type annotated
|
||||
from lorem_text import lorem # type: ignore
|
||||
|
||||
from ..session import Assistant, User
|
||||
from ..core import Backend
|
||||
from ..session import Assistant, Session, User
|
||||
|
||||
|
||||
def ipartition(s, sep=" "):
|
||||
def ipartition(s: str, sep: str = " ") -> Iterable[tuple[str, str]]:
|
||||
rest = s
|
||||
while rest:
|
||||
first, opt_sep, rest = rest.partition(sep)
|
||||
|
|
@ -30,15 +33,19 @@ class Lorem:
|
|||
paragraph_hi: int = 5
|
||||
"""Maximum response paragraph count (inclusive)"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.parameters = self.Parameters()
|
||||
|
||||
system_message = (
|
||||
"(It doesn't matter what you ask, this backend will respond with lorem)"
|
||||
)
|
||||
|
||||
async def aask(self, session, query, *, max_query_size=5, timeout=60):
|
||||
data = self.ask(session, query, max_query_size=max_query_size, timeout=timeout)
|
||||
async def aask(
|
||||
self,
|
||||
session: Session,
|
||||
query: str,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
data = self.ask(session, query)[-1]
|
||||
for word, opt_sep in ipartition(data):
|
||||
yield word + opt_sep
|
||||
await asyncio.sleep(
|
||||
|
|
@ -46,15 +53,22 @@ class Lorem:
|
|||
)
|
||||
|
||||
def ask(
|
||||
self, session, query, *, max_query_size=5, timeout=60
|
||||
): # pylint: disable=unused-argument
|
||||
new_content = lorem.paragraphs(
|
||||
random.randint(self.parameters.paragraph_lo, self.parameters.paragraph_hi)
|
||||
self,
|
||||
session: Session,
|
||||
query: str,
|
||||
) -> str: # pylint: disable=unused-argument
|
||||
new_content = cast(
|
||||
str,
|
||||
lorem.paragraphs(
|
||||
random.randint(
|
||||
self.parameters.paragraph_lo, self.parameters.paragraph_hi
|
||||
)
|
||||
),
|
||||
).replace("\n", "\n\n")
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
return new_content
|
||||
return session[-1].content
|
||||
|
||||
|
||||
def factory():
|
||||
def factory() -> Backend:
|
||||
"""That just prints 'lorem' text. Useful for testing."""
|
||||
return Lorem()
|
||||
|
|
|
|||
|
|
@ -5,12 +5,14 @@
|
|||
import functools
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, cast
|
||||
|
||||
import httpx
|
||||
import tiktoken
|
||||
|
||||
from ..core import Backend
|
||||
from ..key import get_key
|
||||
from ..session import Assistant, User
|
||||
from ..session import Assistant, Message, Session, User, session_to_list
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -20,19 +22,19 @@ class EncodingMeta:
|
|||
tokens_per_name: int
|
||||
|
||||
@functools.lru_cache()
|
||||
def encode(self, s):
|
||||
def encode(self, s: str) -> list[int]:
|
||||
return self.encoding.encode(s)
|
||||
|
||||
def num_tokens_for_message(self, message):
|
||||
def num_tokens_for_message(self, message: Message) -> int:
|
||||
# n.b. chap doesn't use message.name yet
|
||||
return len(self.encode(message.role)) + len(self.encode(message.content))
|
||||
|
||||
def num_tokens_for_messages(self, messages):
|
||||
def num_tokens_for_messages(self, messages: Session) -> int:
|
||||
return sum(self.num_tokens_for_message(message) for message in messages) + 3
|
||||
|
||||
@classmethod
|
||||
@functools.cache
|
||||
def from_model(cls, model):
|
||||
def from_model(cls, model: str) -> "EncodingMeta":
|
||||
if model == "gpt-3.5-turbo":
|
||||
# print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
|
||||
model = "gpt-3.5-turbo-0613"
|
||||
|
|
@ -77,12 +79,12 @@ class ChatGPT:
|
|||
max_request_tokens: int = 1024
|
||||
"""The approximate greatest number of tokens to send in a request. When the session is long, the system prompt and 1 or more of the most recent interaction steps are sent."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.parameters = self.Parameters()
|
||||
|
||||
system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language."
|
||||
|
||||
def make_full_prompt(self, all_history):
|
||||
def make_full_prompt(self, all_history: Session) -> Session:
|
||||
encoding = EncodingMeta.from_model(self.parameters.model)
|
||||
result = [all_history[0]] # Assumed to be system prompt
|
||||
left = self.parameters.max_request_tokens - encoding.num_tokens_for_messages(
|
||||
|
|
@ -99,15 +101,13 @@ class ChatGPT:
|
|||
result.extend(reversed(parts))
|
||||
return result
|
||||
|
||||
def ask(self, session, query, *, timeout=60):
|
||||
def ask(self, session: Session, query: str, *, timeout: float = 60) -> str:
|
||||
full_prompt = self.make_full_prompt(session + [User(query)])
|
||||
response = httpx.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
json={
|
||||
"model": self.parameters.model,
|
||||
"messages": full_prompt.to_dict()[ # pylint: disable=no-member
|
||||
"session"
|
||||
],
|
||||
"messages": session_to_list(full_prompt),
|
||||
}, # pylint: disable=no-member
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.get_key()}",
|
||||
|
|
@ -115,20 +115,20 @@ class ChatGPT:
|
|||
timeout=timeout,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
print("Failure", response.status_code, response.text)
|
||||
return None
|
||||
return f"Failure {response.text} ({response.status_code})"
|
||||
|
||||
try:
|
||||
j = response.json()
|
||||
result = j["choices"][0]["message"]["content"]
|
||||
result = cast(str, j["choices"][0]["message"]["content"])
|
||||
except (KeyError, IndexError, json.decoder.JSONDecodeError):
|
||||
print("Failure", response.status_code, response.text)
|
||||
return None
|
||||
return f"Failure {response.text} ({response.status_code})"
|
||||
|
||||
session.extend([User(query), Assistant(result)])
|
||||
return result
|
||||
|
||||
async def aask(self, session, query, *, timeout=60):
|
||||
async def aask(
|
||||
self, session: Session, query: str, *, timeout: float = 60
|
||||
) -> AsyncGenerator[str, None]:
|
||||
full_prompt = self.make_full_prompt(session + [User(query)])
|
||||
new_content = []
|
||||
try:
|
||||
|
|
@ -140,9 +140,7 @@ class ChatGPT:
|
|||
json={
|
||||
"model": self.parameters.model,
|
||||
"stream": True,
|
||||
"messages": full_prompt.to_dict()[ # pylint: disable=no-member
|
||||
"session"
|
||||
], # pylint: disable=no-member
|
||||
"messages": session_to_list(full_prompt),
|
||||
},
|
||||
) as response:
|
||||
if response.status_code == 200:
|
||||
|
|
@ -170,10 +168,10 @@ class ChatGPT:
|
|||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
@classmethod
|
||||
def get_key(cls):
|
||||
def get_key(cls) -> str:
|
||||
return get_key("openai_api_key")
|
||||
|
||||
|
||||
def factory():
|
||||
def factory() -> Backend:
|
||||
"""Uses the OpenAI chat completion API"""
|
||||
return ChatGPT()
|
||||
|
|
|
|||
|
|
@ -2,19 +2,25 @@
|
|||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import websockets
|
||||
|
||||
from ..key import get_key
|
||||
from ..session import Assistant, Role, User
|
||||
from ..core import AutoAskMixin, Backend
|
||||
from ..session import Assistant, Role, Session, User
|
||||
|
||||
|
||||
class Textgen:
|
||||
def __init__(self):
|
||||
self.server = get_key("textgen_url", "textgen server URL")
|
||||
class Textgen(AutoAskMixin):
|
||||
@dataclass
|
||||
class Parameters:
|
||||
server_hostname: str = "localhost"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.parameters = self.Parameters()
|
||||
|
||||
system_message = """\
|
||||
A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, and knows its own limits.
|
||||
|
|
@ -23,9 +29,14 @@ USER: Hello, AI.
|
|||
|
||||
AI: Hello! How can I assist you today?"""
|
||||
|
||||
async def aask(
|
||||
self, session, query, *, max_query_size=5, timeout=60
|
||||
): # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
async def aask( # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
self,
|
||||
session: Session,
|
||||
query: str,
|
||||
*,
|
||||
max_query_size: int = 5,
|
||||
timeout: float = 60,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
params = {
|
||||
"max_new_tokens": 200,
|
||||
"do_sample": True,
|
||||
|
|
@ -55,7 +66,7 @@ AI: Hello! How can I assist you today?"""
|
|||
)
|
||||
try:
|
||||
async with websockets.connect( # pylint: disable=no-member
|
||||
f"ws://{self.server}:7860/queue/join"
|
||||
f"ws://{self.parameters.server_hostname}:7860/queue/join"
|
||||
) as websocket:
|
||||
while content := json.loads(await websocket.recv()):
|
||||
if content["msg"] == "send_hash":
|
||||
|
|
@ -124,13 +135,7 @@ AI: Hello! How can I assist you today?"""
|
|||
all_response = new_data[len(full_query) :]
|
||||
session.extend([User(query), Assistant(all_response)])
|
||||
|
||||
def ask(self, session, query, *, max_query_size=5, timeout=60):
|
||||
asyncio.run(
|
||||
self.aask(session, query, max_query_size=max_query_size, timeout=timeout)
|
||||
)
|
||||
return session[-1].content
|
||||
|
||||
|
||||
def factory():
|
||||
def factory() -> Backend:
|
||||
"""Uses the textgen completion API"""
|
||||
return Textgen()
|
||||
|
|
|
|||
|
|
@ -4,43 +4,52 @@
|
|||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Iterable, Protocol
|
||||
|
||||
import click
|
||||
import rich
|
||||
|
||||
from ..core import command_uses_new_session
|
||||
from ..session import session_to_file
|
||||
from ..core import Backend, Obj, command_uses_new_session
|
||||
from ..session import Session, session_to_file
|
||||
|
||||
bold = "\033[1m"
|
||||
nobold = "\033[m"
|
||||
|
||||
|
||||
def ipartition(s, sep):
|
||||
def ipartition(s: str, sep: str) -> Iterable[tuple[str, str]]:
|
||||
rest = s
|
||||
while rest:
|
||||
first, opt_sep, rest = rest.partition(sep)
|
||||
yield (first, opt_sep)
|
||||
|
||||
|
||||
class Printable(Protocol):
|
||||
def raw(self, s: str) -> None:
|
||||
...
|
||||
|
||||
def add(self, s: str) -> None:
|
||||
...
|
||||
|
||||
|
||||
class DumbPrinter:
|
||||
def raw(self, s):
|
||||
def raw(self, s: str) -> None:
|
||||
pass
|
||||
|
||||
def add(self, s):
|
||||
def add(self, s: str) -> None:
|
||||
print(s, end="")
|
||||
|
||||
|
||||
class WrappingPrinter:
|
||||
def __init__(self, width=None):
|
||||
def __init__(self, width: int | None = None) -> None:
|
||||
self._width = width or rich.get_console().width
|
||||
self._column = 0
|
||||
self._line = ""
|
||||
self._sp = ""
|
||||
|
||||
def raw(self, s):
|
||||
def raw(self, s: str) -> None:
|
||||
print(s, end="")
|
||||
|
||||
def add(self, s):
|
||||
def add(self, s: str) -> None:
|
||||
for line, opt_nl in ipartition(s, "\n"):
|
||||
for word, opt_sp in ipartition(line, " "):
|
||||
newlen = len(self._line) + len(self._sp) + len(word)
|
||||
|
|
@ -64,15 +73,16 @@ class WrappingPrinter:
|
|||
self._sp = ""
|
||||
|
||||
|
||||
def verbose_ask(api, session, q, print_prompt, **kw):
|
||||
def verbose_ask(api: Backend, session: Session, q: str, print_prompt: bool) -> str:
|
||||
printer: Printable
|
||||
if sys.stdout.isatty():
|
||||
printer = WrappingPrinter()
|
||||
else:
|
||||
printer = DumbPrinter()
|
||||
tokens = []
|
||||
tokens: list[str] = []
|
||||
|
||||
async def work():
|
||||
async for token in api.aask(session, q, **kw):
|
||||
async def work() -> None:
|
||||
async for token in api.aask(session, q):
|
||||
printer.add(token)
|
||||
|
||||
if print_prompt:
|
||||
|
|
@ -91,11 +101,16 @@ def verbose_ask(api, session, q, print_prompt, **kw):
|
|||
@command_uses_new_session
|
||||
@click.option("--print-prompt/--no-print-prompt", default=True)
|
||||
@click.argument("prompt", nargs=-1, required=True)
|
||||
def main(obj, prompt, print_prompt):
|
||||
def main(obj: Obj, prompt: str, print_prompt: bool) -> None:
|
||||
"""Ask a question (command-line argument is passed as prompt)"""
|
||||
session = obj.session
|
||||
assert session is not None
|
||||
|
||||
session_filename = obj.session_filename
|
||||
assert session_filename is not None
|
||||
|
||||
api = obj.api
|
||||
assert api is not None
|
||||
|
||||
# symlink_session_filename(session_filename)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,15 +4,17 @@
|
|||
|
||||
import click
|
||||
|
||||
from ..core import command_uses_existing_session
|
||||
from ..core import Obj, command_uses_existing_session
|
||||
from ..session import Role
|
||||
|
||||
|
||||
@command_uses_existing_session
|
||||
@click.option("--no-system", is_flag=True)
|
||||
def main(obj, no_system):
|
||||
def main(obj: Obj, no_system: bool) -> None:
|
||||
"""Print session in plaintext"""
|
||||
session = obj.session
|
||||
if not session:
|
||||
return
|
||||
|
||||
first = True
|
||||
for row in session:
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from .render import to_markdown
|
|||
|
||||
|
||||
def list_files_matching_rx(
|
||||
rx: re.Pattern, conversations_path: Optional[pathlib.Path] = None
|
||||
rx: re.Pattern[str], conversations_path: Optional[pathlib.Path] = None
|
||||
) -> Iterable[Tuple[pathlib.Path, Message]]:
|
||||
for conversation in (conversations_path or default_conversations_path).glob(
|
||||
"*.json"
|
||||
|
|
@ -39,7 +39,9 @@ def list_files_matching_rx(
|
|||
@click.option("--files-with-matches", "-l", is_flag=True)
|
||||
@click.option("--fixed-strings", "--literal", "-F", is_flag=True)
|
||||
@click.argument("pattern", nargs=1, required=True)
|
||||
def main(ignore_case, files_with_matches, fixed_strings, pattern):
|
||||
def main(
|
||||
ignore_case: bool, files_with_matches: bool, fixed_strings: bool, pattern: str
|
||||
) -> None:
|
||||
"""Search sessions for pattern"""
|
||||
console = rich.get_console()
|
||||
if fixed_strings:
|
||||
|
|
@ -47,7 +49,7 @@ def main(ignore_case, files_with_matches, fixed_strings, pattern):
|
|||
|
||||
rx = re.compile(pattern, re.I if ignore_case else 0)
|
||||
last_file = None
|
||||
for f, m in list_files_matching_rx(rx, ignore_case):
|
||||
for f, m in list_files_matching_rx(rx, None):
|
||||
if f != last_file:
|
||||
if files_with_matches:
|
||||
print(f)
|
||||
|
|
|
|||
|
|
@ -6,17 +6,20 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
import pathlib
|
||||
from typing import Any, Iterator, TextIO
|
||||
|
||||
import click
|
||||
import rich
|
||||
|
||||
from ..core import conversations_path, new_session_path
|
||||
from ..session import Message, Role, new_session, session_to_file
|
||||
from ..session import Message, Role, Session, new_session, session_to_file
|
||||
|
||||
console = rich.get_console()
|
||||
|
||||
|
||||
def iter_sessions(name, content, session_in, node_id):
|
||||
def iter_sessions(
|
||||
name: str, content: Any, session_in: Session, node_id: str
|
||||
) -> Iterator[tuple[str, Session]]:
|
||||
node = content["mapping"][node_id]
|
||||
session = session_in[:]
|
||||
|
||||
|
|
@ -37,7 +40,7 @@ def iter_sessions(name, content, session_in, node_id):
|
|||
yield node_id, session
|
||||
|
||||
|
||||
def do_import(output_directory, f):
|
||||
def do_import(output_directory: pathlib.Path, f: TextIO) -> None:
|
||||
stem = pathlib.Path(f.name).stem
|
||||
content = json.load(f)
|
||||
session = new_session()
|
||||
|
|
@ -66,7 +69,7 @@ def do_import(output_directory, f):
|
|||
@click.argument(
|
||||
"files", nargs=-1, required=True, type=click.File("r", encoding="utf-8")
|
||||
)
|
||||
def main(output_directory, files):
|
||||
def main(output_directory: pathlib.Path, files: list[TextIO]) -> None:
|
||||
"""Import files from the ChatGPT webui
|
||||
|
||||
This understands the format produced by
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@ import rich
|
|||
from markdown_it import MarkdownIt
|
||||
from rich.markdown import Markdown
|
||||
|
||||
from ..core import command_uses_existing_session
|
||||
from ..session import Role
|
||||
from ..core import Obj, command_uses_existing_session
|
||||
from ..session import Message, Role
|
||||
|
||||
|
||||
def to_markdown(message):
|
||||
def to_markdown(message: Message) -> Markdown:
|
||||
role = message.role
|
||||
if role == Role.USER:
|
||||
style = "bold"
|
||||
|
|
@ -28,9 +28,10 @@ def to_markdown(message):
|
|||
|
||||
@command_uses_existing_session
|
||||
@click.option("--no-system", is_flag=True)
|
||||
def main(obj, no_system):
|
||||
def main(obj: Obj, no_system: bool) -> None:
|
||||
"""Print session with formatting"""
|
||||
session = obj.session
|
||||
assert session is not None
|
||||
|
||||
console = rich.get_console()
|
||||
first = True
|
||||
|
|
|
|||
|
|
@ -5,19 +5,20 @@
|
|||
import asyncio
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import cast
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
from textual import work
|
||||
from textual.app import App
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Container, Horizontal, VerticalScroll
|
||||
from textual.widgets import Button, Footer, Input, LoadingIndicator, Markdown
|
||||
|
||||
from ..core import command_uses_new_session, get_api, new_session_path
|
||||
from ..session import Assistant, User, new_session, session_to_file
|
||||
from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path
|
||||
from ..session import Assistant, Message, Session, User, new_session, session_to_file
|
||||
|
||||
|
||||
def parser_factory():
|
||||
def parser_factory() -> MarkdownIt:
|
||||
parser = MarkdownIt()
|
||||
parser.options["html"] = False
|
||||
return parser
|
||||
|
|
@ -34,7 +35,7 @@ class ChapMarkdown(
|
|||
]
|
||||
|
||||
|
||||
def markdown_for_step(step):
|
||||
def markdown_for_step(step: Message) -> ChapMarkdown:
|
||||
return ChapMarkdown(
|
||||
step.content.strip() or "…",
|
||||
classes="role_" + step.role,
|
||||
|
|
@ -48,13 +49,15 @@ class CancelButton(Button):
|
|||
]
|
||||
|
||||
|
||||
class Tui(App):
|
||||
class Tui(App[None]):
|
||||
CSS_PATH = "tui.css"
|
||||
BINDINGS = [
|
||||
Binding("ctrl+c", "quit", "Quit", show=True, priority=True),
|
||||
]
|
||||
|
||||
def __init__(self, api=None, session=None):
|
||||
def __init__(
|
||||
self, api: Backend | None = None, session: Session | None = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.api = api or get_api("lorem")
|
||||
self.session = (
|
||||
|
|
@ -62,26 +65,26 @@ class Tui(App):
|
|||
)
|
||||
|
||||
@property
|
||||
def spinner(self):
|
||||
def spinner(self) -> LoadingIndicator:
|
||||
return self.query_one(LoadingIndicator)
|
||||
|
||||
@property
|
||||
def wait(self):
|
||||
return self.query_one("#wait")
|
||||
def wait(self) -> VerticalScroll:
|
||||
return cast(VerticalScroll, self.query_one("#wait"))
|
||||
|
||||
@property
|
||||
def input(self):
|
||||
def input(self) -> Input:
|
||||
return self.query_one(Input)
|
||||
|
||||
@property
|
||||
def cancel_button(self):
|
||||
def cancel_button(self) -> CancelButton:
|
||||
return self.query_one(CancelButton)
|
||||
|
||||
@property
|
||||
def container(self):
|
||||
return self.query_one("#content")
|
||||
def container(self) -> VerticalScroll:
|
||||
return cast(VerticalScroll, self.query_one("#content"))
|
||||
|
||||
def compose(self):
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Footer()
|
||||
yield VerticalScroll(
|
||||
*[markdown_for_step(step) for step in self.session],
|
||||
|
|
@ -100,11 +103,11 @@ class Tui(App):
|
|||
self.container.scroll_end(animate=False)
|
||||
self.input.focus()
|
||||
|
||||
async def on_input_submitted(self, event) -> None:
|
||||
async def on_input_submitted(self, event: Input.Submitted) -> None:
|
||||
self.get_completion(event.value)
|
||||
|
||||
@work(exclusive=True)
|
||||
async def get_completion(self, query):
|
||||
async def get_completion(self, query: str) -> None:
|
||||
self.scroll_end()
|
||||
|
||||
self.input.styles.display = "none"
|
||||
|
|
@ -117,8 +120,8 @@ class Tui(App):
|
|||
await self.container.mount_all(
|
||||
[markdown_for_step(User(query)), output], before="#pad"
|
||||
)
|
||||
tokens = []
|
||||
update = asyncio.Queue(1)
|
||||
tokens: list[str] = []
|
||||
update: asyncio.Queue[bool] = asyncio.Queue(1)
|
||||
|
||||
for markdown in self.container.children:
|
||||
markdown.disabled = True
|
||||
|
|
@ -137,14 +140,14 @@ class Tui(App):
|
|||
]
|
||||
)
|
||||
|
||||
async def render_fun():
|
||||
async def render_fun() -> None:
|
||||
while await update.get():
|
||||
if tokens:
|
||||
output.update("".join(tokens).strip())
|
||||
self.container.scroll_end()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def get_token_fun():
|
||||
async def get_token_fun() -> None:
|
||||
async for token in self.api.aask(session, query):
|
||||
tokens.append(token)
|
||||
message.content += token
|
||||
|
|
@ -174,16 +177,16 @@ class Tui(App):
|
|||
self.cancel_button.disabled = True
|
||||
self.input.focus()
|
||||
|
||||
def scroll_end(self):
|
||||
def scroll_end(self) -> None:
|
||||
self.call_after_refresh(self.container.scroll_end)
|
||||
|
||||
def action_yank(self):
|
||||
def action_yank(self) -> None:
|
||||
widget = self.focused
|
||||
if isinstance(widget, ChapMarkdown):
|
||||
content = widget._markdown # pylint: disable=protected-access
|
||||
content = widget._markdown or "" # pylint: disable=protected-access
|
||||
subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)
|
||||
|
||||
def action_toggle_history(self):
|
||||
def action_toggle_history(self) -> None:
|
||||
widget = self.focused
|
||||
if not isinstance(widget, ChapMarkdown):
|
||||
return
|
||||
|
|
@ -199,23 +202,25 @@ class Tui(App):
|
|||
for m in children[idx : idx + 2]:
|
||||
m.toggle_class("history_exclude")
|
||||
|
||||
async def action_stop_generating(self):
|
||||
async def action_stop_generating(self) -> None:
|
||||
self.workers.cancel_all()
|
||||
|
||||
async def on_button_pressed(self, event): # pylint: disable=unused-argument
|
||||
async def on_button_pressed( # pylint: disable=unused-argument
|
||||
self, event: Button.Pressed
|
||||
) -> None:
|
||||
self.workers.cancel_all()
|
||||
|
||||
async def action_quit(self):
|
||||
async def action_quit(self) -> None:
|
||||
self.workers.cancel_all()
|
||||
self.exit()
|
||||
|
||||
async def action_resubmit(self):
|
||||
async def action_resubmit(self) -> None:
|
||||
await self.redraft_or_resubmit(True)
|
||||
|
||||
async def action_redraft(self):
|
||||
async def action_redraft(self) -> None:
|
||||
await self.redraft_or_resubmit(False)
|
||||
|
||||
async def redraft_or_resubmit(self, resubmit):
|
||||
async def redraft_or_resubmit(self, resubmit: bool) -> None:
|
||||
widget = self.focused
|
||||
if not isinstance(widget, ChapMarkdown):
|
||||
return
|
||||
|
|
@ -244,11 +249,14 @@ class Tui(App):
|
|||
|
||||
|
||||
@command_uses_new_session
|
||||
def main(obj):
|
||||
def main(obj: Obj) -> None:
|
||||
"""Start interactive terminal user interface session"""
|
||||
api = obj.api
|
||||
assert api is not None
|
||||
session = obj.session
|
||||
assert session is not None
|
||||
session_filename = obj.session_filename
|
||||
assert session_filename is not None
|
||||
|
||||
tui = Tui(api, session)
|
||||
tui.run()
|
||||
|
|
|
|||
147
src/chap/core.py
147
src/chap/core.py
|
|
@ -3,6 +3,7 @@
|
|||
# SPDX-License-Identifier: MIT
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import importlib
|
||||
import os
|
||||
|
|
@ -10,33 +11,62 @@ import pathlib
|
|||
import pkgutil
|
||||
import subprocess
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from typing import Any, AsyncGenerator, Callable, cast
|
||||
|
||||
import click
|
||||
import platformdirs
|
||||
from simple_parsing.docstring import get_attribute_docstring
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from . import backends, commands # pylint: disable=no-name-in-module
|
||||
from .session import Message, System, session_from_file
|
||||
from .session import Message, Session, System, session_from_file
|
||||
|
||||
conversations_path = platformdirs.user_state_path("chap") / "conversations"
|
||||
conversations_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def last_session_path():
|
||||
class ABackend(Protocol): # pylint: disable=too-few-public-methods
|
||||
def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
|
||||
"""Make a query, updating the session with the query and response, returning the query token by token"""
|
||||
|
||||
|
||||
class Backend(ABackend, Protocol):
|
||||
parameters: Any
|
||||
system_message: str
|
||||
|
||||
def ask(self, session: Session, query: str) -> str:
|
||||
"""Make a query, updating the session with the query and response, returning the query"""
|
||||
|
||||
|
||||
class AutoAskMixin: # pylint: disable=too-few-public-methods
|
||||
"""Mixin class for backends implementing aask"""
|
||||
|
||||
def ask(self, session: Session, query: str) -> str:
|
||||
tokens: list[str] = []
|
||||
|
||||
async def inner() -> None:
|
||||
# https://github.com/pylint-dev/pylint/issues/5761
|
||||
async for token in self.aask(session, query): # type: ignore
|
||||
tokens.append(token)
|
||||
|
||||
asyncio.run(inner())
|
||||
return "".join(tokens)
|
||||
|
||||
|
||||
def last_session_path() -> pathlib.Path | None:
|
||||
result = max(
|
||||
conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None
|
||||
)
|
||||
print(result)
|
||||
return result
|
||||
|
||||
|
||||
def new_session_path(opt_path=None):
|
||||
def new_session_path(opt_path: pathlib.Path | None = None) -> pathlib.Path:
|
||||
return opt_path or conversations_path / (
|
||||
datetime.datetime.now().isoformat().replace(":", "_") + ".json"
|
||||
)
|
||||
|
||||
|
||||
def configure_api_from_environment(api_name, api):
|
||||
def configure_api_from_environment(api_name: str, api: Backend) -> None:
|
||||
if not hasattr(api, "parameters"):
|
||||
return
|
||||
|
||||
|
|
@ -54,44 +84,46 @@ def configure_api_from_environment(api_name, api):
|
|||
setattr(api.parameters, field.name, tv)
|
||||
|
||||
|
||||
def get_api(name="openai_chatgpt"):
|
||||
def get_api(name: str = "openai_chatgpt") -> Backend:
|
||||
name = name.replace("-", "_")
|
||||
result = importlib.import_module(f"{__package__}.backends.{name}").factory()
|
||||
result = cast(
|
||||
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
|
||||
)
|
||||
configure_api_from_environment(name, result)
|
||||
return result
|
||||
|
||||
|
||||
def ask(*args, **kw):
|
||||
return get_api().ask(*args, **kw)
|
||||
|
||||
|
||||
def aask(*args, **kw):
|
||||
return get_api().aask(*args, **kw)
|
||||
|
||||
|
||||
def do_session_continue(ctx, param, value):
|
||||
def do_session_continue(
|
||||
ctx: click.Context, param: click.Parameter, value: pathlib.Path | None
|
||||
) -> None:
|
||||
if value is None:
|
||||
return
|
||||
if ctx.obj.session is not None:
|
||||
raise click.BadParameter(
|
||||
param, "--continue-session, --last and --new-session are mutually exclusive"
|
||||
"--continue-session, --last and --new-session are mutually exclusive",
|
||||
param=param,
|
||||
)
|
||||
ctx.obj.session = session_from_file(value)
|
||||
ctx.obj.session_filename = value
|
||||
|
||||
|
||||
def do_session_last(ctx, param, value): # pylint: disable=unused-argument
|
||||
def do_session_last(
|
||||
ctx: click.Context, param: click.Parameter, value: bool
|
||||
) -> None: # pylint: disable=unused-argument
|
||||
if not value:
|
||||
return
|
||||
do_session_continue(ctx, param, last_session_path())
|
||||
|
||||
|
||||
def do_session_new(ctx, param, value):
|
||||
def do_session_new(
|
||||
ctx: click.Context, param: click.Parameter, value: pathlib.Path
|
||||
) -> None:
|
||||
if ctx.obj.session is not None:
|
||||
if value is None:
|
||||
return
|
||||
raise click.BadOptionUsage(
|
||||
param, "--continue-session, --last and --new-session are mutually exclusive"
|
||||
raise click.BadParameter(
|
||||
"--continue-session, --last and --new-session are mutually exclusive",
|
||||
param=param,
|
||||
)
|
||||
session_filename = new_session_path(value)
|
||||
system_message = ctx.obj.system_message or ctx.obj.api.system_message
|
||||
|
|
@ -99,20 +131,24 @@ def do_session_new(ctx, param, value):
|
|||
ctx.obj.session_filename = session_filename
|
||||
|
||||
|
||||
def colonstr(arg):
|
||||
def colonstr(arg: str) -> tuple[str, str]:
|
||||
if ":" not in arg:
|
||||
raise click.BadParameter("must be of the form 'name:value'")
|
||||
return arg.split(":", 1)
|
||||
return cast(tuple[str, str], tuple(arg.split(":", 1)))
|
||||
|
||||
|
||||
def set_system_message(ctx, param, value): # pylint: disable=unused-argument
|
||||
def set_system_message( # pylint: disable=unused-argument
|
||||
ctx: click.Context, param: click.Parameter, value: str
|
||||
) -> None:
|
||||
if value and value.startswith("@"):
|
||||
with open(value[1:], "r", encoding="utf-8") as f:
|
||||
value = f.read().rstrip()
|
||||
ctx.obj.system_message = value
|
||||
|
||||
|
||||
def set_backend(ctx, param, value): # pylint: disable=unused-argument
|
||||
def set_backend( # pylint: disable=unused-argument
|
||||
ctx: click.Context, param: click.Parameter, value: str
|
||||
) -> None:
|
||||
if value == "list":
|
||||
formatter = ctx.make_formatter()
|
||||
format_backend_list(formatter)
|
||||
|
|
@ -125,7 +161,7 @@ def set_backend(ctx, param, value): # pylint: disable=unused-argument
|
|||
raise click.BadParameter(str(e))
|
||||
|
||||
|
||||
def format_backend_help(api, formatter):
|
||||
def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None:
|
||||
with formatter.section(f"Backend options for {api.__class__.__name__}"):
|
||||
rows = []
|
||||
for f in fields(api.parameters):
|
||||
|
|
@ -135,11 +171,14 @@ def format_backend_help(api, formatter):
|
|||
if doc:
|
||||
doc += " "
|
||||
doc += f"(Default: {default!r})"
|
||||
rows.append((f"-B {name}:{f.type.__name__.upper()}", doc))
|
||||
typename = f.type.__name__
|
||||
rows.append((f"-B {name}:{typename.upper()}", doc))
|
||||
formatter.write_dl(rows)
|
||||
|
||||
|
||||
def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument
|
||||
def set_backend_option( # pylint: disable=unused-argument
|
||||
ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]]
|
||||
) -> None:
|
||||
api = ctx.obj.api
|
||||
if not hasattr(api, "parameters"):
|
||||
raise click.BadParameter(
|
||||
|
|
@ -147,7 +186,7 @@ def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument
|
|||
)
|
||||
all_fields = dict((f.name.replace("_", "-"), f) for f in fields(api.parameters))
|
||||
|
||||
def set_one_backend_option(kv):
|
||||
def set_one_backend_option(kv: tuple[str, str]) -> None:
|
||||
name, value = kv
|
||||
field = all_fields.get(name)
|
||||
if field is None:
|
||||
|
|
@ -164,7 +203,7 @@ def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument
|
|||
set_one_backend_option(kv)
|
||||
|
||||
|
||||
def format_backend_list(formatter):
|
||||
def format_backend_list(formatter: click.HelpFormatter) -> None:
|
||||
all_backends = []
|
||||
for pi in pkgutil.walk_packages(backends.__path__):
|
||||
name = pi.name
|
||||
|
|
@ -186,7 +225,7 @@ def format_backend_list(formatter):
|
|||
formatter.write_dl(rows)
|
||||
|
||||
|
||||
def uses_session(f):
|
||||
def uses_session(f: click.decorators.FC) -> Callable[[], None]:
|
||||
f = click.option(
|
||||
"--continue-session",
|
||||
"-s",
|
||||
|
|
@ -198,18 +237,15 @@ def uses_session(f):
|
|||
f = click.option(
|
||||
"--last", is_flag=True, callback=do_session_last, expose_value=False
|
||||
)(f)
|
||||
f = click.pass_obj(f)
|
||||
return f
|
||||
return click.pass_obj(f)
|
||||
|
||||
|
||||
def command_uses_existing_session(f):
|
||||
f = uses_session(f)
|
||||
f = click.command()(f)
|
||||
return f
|
||||
def command_uses_existing_session(f: click.decorators.FC) -> click.Command:
|
||||
return click.command()(uses_session(f))
|
||||
|
||||
|
||||
def command_uses_new_session(f):
|
||||
f = uses_session(f)
|
||||
def command_uses_new_session(f_in: click.decorators.FC) -> click.Command:
|
||||
f = uses_session(f_in)
|
||||
f = click.option(
|
||||
"--new-session",
|
||||
"-n",
|
||||
|
|
@ -218,11 +254,12 @@ def command_uses_new_session(f):
|
|||
callback=do_session_new,
|
||||
expose_value=False,
|
||||
)(f)
|
||||
f = click.command()(f)
|
||||
return f
|
||||
return click.command()(f)
|
||||
|
||||
|
||||
def version_callback(ctx, param, value) -> None: # pylint: disable=unused-argument
|
||||
def version_callback( # pylint: disable=unused-argument
|
||||
ctx: click.Context, param: click.Parameter, value: None
|
||||
) -> None:
|
||||
if not value or ctx.resilient_parsing:
|
||||
return
|
||||
|
||||
|
|
@ -247,17 +284,24 @@ def version_callback(ctx, param, value) -> None: # pylint: disable=unused-argum
|
|||
|
||||
@dataclass
|
||||
class Obj:
|
||||
api: object = None
|
||||
system_message: object = None
|
||||
api: Backend | None = None
|
||||
system_message: str | None = None
|
||||
session: list[Message] | None = None
|
||||
session_filename: pathlib.Path | None = None
|
||||
|
||||
|
||||
class MyCLI(click.MultiCommand):
|
||||
def make_context(self, info_name, args, parent=None, **extra):
|
||||
def make_context(
|
||||
self,
|
||||
info_name: str | None,
|
||||
args: list[str],
|
||||
parent: click.Context | None = None,
|
||||
**extra: Any,
|
||||
) -> click.Context:
|
||||
result = super().make_context(info_name, args, parent, obj=Obj(), **extra)
|
||||
return result
|
||||
|
||||
def list_commands(self, ctx):
|
||||
def list_commands(self, ctx: click.Context) -> list[str]:
|
||||
rv = []
|
||||
for pi in pkgutil.walk_packages(commands.__path__):
|
||||
name = pi.name
|
||||
|
|
@ -266,13 +310,18 @@ class MyCLI(click.MultiCommand):
|
|||
rv.sort()
|
||||
return rv
|
||||
|
||||
def get_command(self, ctx, cmd_name):
|
||||
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command:
|
||||
try:
|
||||
return importlib.import_module("." + cmd_name, commands.__name__).main
|
||||
return cast(
|
||||
click.Command,
|
||||
importlib.import_module("." + cmd_name, commands.__name__).main,
|
||||
)
|
||||
except ModuleNotFoundError as exc:
|
||||
raise click.UsageError(f"Invalid subcommand {cmd_name!r}", ctx) from exc
|
||||
|
||||
def format_options(self, ctx, formatter):
|
||||
def format_options(
|
||||
self, ctx: click.Context, formatter: click.HelpFormatter
|
||||
) -> None:
|
||||
super().format_options(ctx, formatter)
|
||||
api = ctx.obj.api or get_api()
|
||||
if hasattr(api, "parameters"):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ _key_path_base = platformdirs.user_config_path("chap")
|
|||
|
||||
|
||||
@functools.cache
|
||||
def get_key(name, what="openai api key"):
|
||||
def get_key(name: str, what: str = "openai api key") -> str:
|
||||
key_path = _key_path_base / name
|
||||
if not key_path.exists():
|
||||
raise NoKeyAvailable(
|
||||
|
|
|
|||
|
|
@ -2,8 +2,14 @@
|
|||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import cast
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
# not an enum.Enum because these objects are not json-serializable, sigh
|
||||
|
|
@ -21,40 +27,49 @@ class Message:
|
|||
content: str
|
||||
|
||||
|
||||
def Assistant(content):
|
||||
MessageDict = TypedDict("MessageDict", {"role": str, "content": int})
|
||||
Session = list[Message]
|
||||
SessionDicts = list[MessageDict]
|
||||
|
||||
|
||||
def Assistant(content: str) -> Message:
|
||||
return Message(Role.ASSISTANT, content)
|
||||
|
||||
|
||||
def System(content):
|
||||
def System(content: str) -> Message:
|
||||
return Message(Role.SYSTEM, content)
|
||||
|
||||
|
||||
def User(content):
|
||||
def User(content: str) -> Message:
|
||||
return Message(Role.USER, content)
|
||||
|
||||
|
||||
def new_session(
|
||||
system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language.",
|
||||
):
|
||||
) -> Session:
|
||||
return [System(system_message)]
|
||||
|
||||
|
||||
def session_to_json(session):
|
||||
return json.dumps([asdict(message) for message in session])
|
||||
def session_to_json(session: Session) -> str:
|
||||
return json.dumps(session_to_list(session))
|
||||
|
||||
|
||||
def session_from_json(data):
|
||||
def session_to_list(session: Session) -> SessionDicts:
|
||||
return [cast(MessageDict, asdict(message)) for message in session]
|
||||
|
||||
|
||||
def session_from_json(data: str) -> Session:
|
||||
j = json.loads(data)
|
||||
if isinstance(j, dict):
|
||||
j = j["session"]
|
||||
return [Message(**mapping) for mapping in j]
|
||||
|
||||
|
||||
def session_from_file(path):
|
||||
def session_from_file(path: pathlib.Path | str) -> Session:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return session_from_json(f.read())
|
||||
|
||||
|
||||
def session_to_file(session, path):
|
||||
def session_to_file(session: Session, path: pathlib.Path | str) -> None:
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
return f.write(session_to_json(session))
|
||||
f.write(session_to_json(session))
|
||||
|
|
|
|||
Loading…
Reference in a new issue