Merge pull request #30 from jepler/multiline2

Switch chap tui to a multiline text field
This commit is contained in:
Jeff Epler 2023-12-12 17:05:05 -06:00 committed by GitHub
commit 9f6ace394a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 205 additions and 144 deletions

1
.gitignore vendored
View file

@ -8,3 +8,4 @@ __pycache__
/dist /dist
/src/chap/__version__.py /src/chap/__version__.py
/venv /venv
/keys.log

View file

@ -6,12 +6,8 @@ default_language_version:
python: python3 python: python3
repos: repos:
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0 rev: v4.5.0
hooks: hooks:
- id: check-yaml - id: check-yaml
- id: end-of-file-fixer - id: end-of-file-fixer
@ -19,23 +15,20 @@ repos:
- id: trailing-whitespace - id: trailing-whitespace
exclude: tests exclude: tests
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
rev: v2.2.4 rev: v2.2.6
hooks: hooks:
- id: codespell - id: codespell
args: [-w] args: [-w]
- repo: https://github.com/fsfe/reuse-tool - repo: https://github.com/fsfe/reuse-tool
rev: v1.1.2 rev: v2.1.0
hooks: hooks:
- id: reuse - id: reuse
- repo: https://github.com/pycqa/isort - repo: https://github.com/astral-sh/ruff-pre-commit
rev: 5.12.0 # Ruff version.
rev: v0.1.6
hooks: hooks:
- id: isort # Run the linter.
name: isort (python) - id: ruff
args: ['--profile', 'black'] args: [ --fix, --preview ]
- repo: https://github.com/pycqa/pylint # Run the formatter.
rev: v2.17.0 - id: ruff-format
hooks:
- id: pylint
additional_dependencies: [click,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
args: ['--source-roots', 'src']

View file

@ -1,12 +0,0 @@
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
#
# SPDX-License-Identifier: MIT
[MESSAGES CONTROL]
disable=
duplicate-code,
invalid-name,
line-too-long,
missing-class-docstring,
missing-function-docstring,
missing-module-docstring,

View file

@ -9,11 +9,11 @@ SPDX-License-Identifier: MIT
# chap - A Python interface to chatgpt and other LLMs, including a terminal user interface (tui) # chap - A Python interface to chatgpt and other LLMs, including a terminal user interface (tui)
![Chap screencast](https://github.com/jepler/chap/blob/main/chap.gif) ![Chap screencast](https://raw.githubusercontent.com/jepler/chap/main/chap.gif)
## System requirements ## System requirements
Chap is developed on Linux with Python 3.11. Due to use of the `X | Y` style of type hints, it is known to not work on Python 3.9 and older. The target minimum Python version is 3.11 (debian stable). Chap is primarily developed on Linux with Python 3.11. Moderate effort will be made to support versions back to Python 3.9 (Debian oldstable).
## Installation ## Installation
@ -82,7 +82,28 @@ Put your OpenAI API key in the platform configuration directory for chap, e.g.,
* `chap grep needle` * `chap grep needle`
## Interactive terminal usage ## Interactive terminal usage
* `chap tui` The interactive terminal mode is accessed via `chap tui`.
There are a variety of keyboard shortcuts to be aware of:
* tab/shift-tab to move between the entry field and the conversation, or between conversation items
* While in the text box, F9 or (if supported by your terminal) alt+enter to submit multiline text
* while on a conversation item:
* ctrl+x to re-draft the message. This
* saves a copy of the session in an auto-named file in the conversations folder
* removes the conversation from this message to the end
* puts the user's message in the text box to edit
* ctrl+x to re-submit the message. This
* saves a copy of the session in an auto-named file in the conversations folder
* removes the conversation from this message to the end
* puts the user's message in the text box
* and submits it immediately
* ctrl+y to yank the message. This places the response part of the current
interaction in the operating system clipboard to be pasted (e..g, with
ctrl+v or command+v in other software)
* ctrl+q to toggle whether this message may be included in the contextual history for a future query.
The exact way history is submitted is determined by the back-end, often by
counting messages or tokens, but the ctrl+q toggle ensures this message (both the user
and assistant message parts) are not considered.
## Sessions & Command-line Parameters ## Sessions & Command-line Parameters

View file

@ -14,7 +14,6 @@ src_path = project_root / "src"
sys.path.insert(0, str(src_path)) sys.path.insert(0, str(src_path))
if __name__ == "__main__": if __name__ == "__main__":
# pylint: disable=import-error,no-name-in-module
from chap.core import main from chap.core import main
main() main()

View file

@ -20,13 +20,14 @@ name="chap"
authors = [{name = "Jeff Epler", email = "jepler@gmail.com"}] authors = [{name = "Jeff Epler", email = "jepler@gmail.com"}]
description = "Interact with the OpenAI ChatGPT API (and other text generators)" description = "Interact with the OpenAI ChatGPT API (and other text generators)"
dynamic = ["readme","version","dependencies"] dynamic = ["readme","version","dependencies"]
requires-python = ">=3.10" requires-python = ">=3.9"
keywords = ["llm", "tui", "chatgpt"] keywords = ["llm", "tui", "chatgpt"]
classifiers = [ classifiers = [
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",

View file

@ -6,7 +6,8 @@ click
httpx httpx
lorem-text lorem-text
platformdirs platformdirs
pyperclip
simple_parsing simple_parsing
textual>=0.18.0 textual[syntax]
tiktoken tiktoken
websockets websockets

View file

@ -89,9 +89,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
*, *,
max_query_size: int = 5, max_query_size: int = 5,
timeout: float = 180, timeout: float = 180,
) -> AsyncGenerator[ ) -> AsyncGenerator[str, None]:
str, None
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
new_content: list[str] = [] new_content: list[str] = []
inputs = self.make_full_query(session + [User(query)], max_query_size) inputs = self.make_full_query(session + [User(query)], max_query_size)
try: try:

View file

@ -55,9 +55,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
*, *,
max_query_size: int = 5, max_query_size: int = 5,
timeout: float = 180, timeout: float = 180,
) -> AsyncGenerator[ ) -> AsyncGenerator[str, None]:
str, None
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
params = { params = {
"prompt": self.make_full_query(session + [User(query)], max_query_size), "prompt": self.make_full_query(session + [User(query)], max_query_size),
"stream": True, "stream": True,

View file

@ -45,7 +45,7 @@ class Lorem:
session: Session, session: Session,
query: str, query: str,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
data = self.ask(session, query)[-1] data = self.ask(session, query)
for word, opt_sep in ipartition(data): for word, opt_sep in ipartition(data):
yield word + opt_sep yield word + opt_sep
await asyncio.sleep( await asyncio.sleep(
@ -56,7 +56,7 @@ class Lorem:
self, self,
session: Session, session: Session,
query: str, query: str,
) -> str: # pylint: disable=unused-argument ) -> str:
new_content = cast( new_content = cast(
str, str,
lorem.paragraphs( lorem.paragraphs(

View file

@ -101,7 +101,7 @@ class ChatGPT:
json={ json={
"model": self.parameters.model, "model": self.parameters.model,
"messages": session_to_list(full_prompt), "messages": session_to_list(full_prompt),
}, # pylint: disable=no-member },
headers={ headers={
"Authorization": f"Bearer {self.get_key()}", "Authorization": f"Bearer {self.get_key()}",
}, },

View file

@ -29,7 +29,7 @@ USER: Hello, AI.
AI: Hello! How can I assist you today?""" AI: Hello! How can I assist you today?"""
async def aask( # pylint: disable=unused-argument,too-many-locals,too-many-branches async def aask(
self, self,
session: Session, session: Session,
query: str, query: str,
@ -60,12 +60,11 @@ AI: Hello! How can I assist you today?"""
} }
full_prompt = session + [User(query)] full_prompt = session + [User(query)]
del full_prompt[1:-max_query_size] del full_prompt[1:-max_query_size]
new_data = old_data = full_query = ( new_data = old_data = full_query = "\n".join(
"\n".join(f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt) f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt
+ f"\n{role_map.get('assistant')}" ) + f"\n{role_map.get('assistant')}"
)
try: try:
async with websockets.connect( # pylint: disable=no-member async with websockets.connect(
f"ws://{self.parameters.server_hostname}:7860/queue/join" f"ws://{self.parameters.server_hostname}:7860/queue/join"
) as websocket: ) as websocket:
while content := json.loads(await websocket.recv()): while content := json.loads(await websocket.recv()):
@ -127,7 +126,7 @@ AI: Hello! How can I assist you today?"""
# stop generation by closing the websocket here # stop generation by closing the websocket here
if content["msg"] == "process_completed": if content["msg"] == "process_completed":
break break
except Exception as e: # pylint: disable=broad-exception-caught except Exception as e:
content = f"\nException: {e!r}" content = f"\nException: {e!r}"
new_data += content new_data += content
yield content yield content

View file

@ -4,7 +4,7 @@
import asyncio import asyncio
import sys import sys
from typing import Iterable, Protocol from typing import Iterable, Optional, Protocol
import click import click
import rich import rich
@ -40,7 +40,7 @@ class DumbPrinter:
class WrappingPrinter: class WrappingPrinter:
def __init__(self, width: int | None = None) -> None: def __init__(self, width: Optional[int] = None) -> None:
self._width = width or rich.get_console().width self._width = width or rich.get_console().width
self._column = 0 self._column = 0
self._line = "" self._line = ""
@ -122,4 +122,4 @@ def main(obj: Obj, prompt: str, print_prompt: bool) -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter main()

View file

@ -38,4 +38,4 @@ def main(obj: Obj, no_system: bool) -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter main()

View file

@ -24,8 +24,8 @@ def list_files_matching_rx(
"*.json" "*.json"
): ):
try: try:
session = session_from_file(conversation) # pylint: disable=no-member session = session_from_file(conversation)
except Exception as e: # pylint: disable=broad-exception-caught except Exception as e:
print(f"Failed to read {conversation}: {e}", file=sys.stderr) print(f"Failed to read {conversation}: {e}", file=sys.stderr)
continue continue
@ -67,4 +67,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter main()

View file

@ -82,4 +82,4 @@ def main(output_directory: pathlib.Path, files: list[TextIO]) -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter main()

View file

@ -45,4 +45,4 @@ def main(obj: Obj, no_system: bool) -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter main()

View file

@ -54,3 +54,5 @@ Input {
Markdown { Markdown {
margin: 0 1 0 0; margin: 0 1 0 0;
} }
SubmittableTextArea { height: 3 }

View file

@ -3,30 +3,53 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import asyncio import asyncio
import subprocess
import sys import sys
from typing import cast from typing import Any, Optional, cast, TYPE_CHECKING
import click
from markdown_it import MarkdownIt from markdown_it import MarkdownIt
from textual import work from textual import work
from textual._ansi_sequences import ANSI_SEQUENCES_KEYS
from textual.app import App, ComposeResult from textual.app import App, ComposeResult
from textual.binding import Binding from textual.binding import Binding
from textual.containers import Container, Horizontal, VerticalScroll from textual.containers import Container, Horizontal, VerticalScroll
from textual.widgets import Button, Footer, Input, LoadingIndicator, Markdown from textual.keys import Keys
from textual.widgets import Button, Footer, LoadingIndicator, Markdown, TextArea
from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path
from ..session import Assistant, Message, Session, User, new_session, session_to_file from ..session import Assistant, Message, Session, User, new_session, session_to_file
# workaround for pyperclip being un-typed
if TYPE_CHECKING:
def pyperclip_copy(data: str) -> None:
...
else:
from pyperclip import copy as pyperclip_copy
# Monkeypatch alt+enter as meaning "F9", WFM
# ignore typing here because ANSI_SEQUENCES_KEYS is a Mapping[] which is read-only as
# far as mypy is concerned.
ANSI_SEQUENCES_KEYS["\x1b\r"] = (Keys.F9,) # type: ignore
ANSI_SEQUENCES_KEYS["\x1b\n"] = (Keys.F9,) # type: ignore
class SubmittableTextArea(TextArea):
BINDINGS = [
Binding("f9", "submit", "Submit", show=True),
Binding("tab", "focus_next", show=False, priority=True), # no inserting tabs
]
def parser_factory() -> MarkdownIt: def parser_factory() -> MarkdownIt:
parser = MarkdownIt() parser = MarkdownIt()
parser.options["html"] = False parser.options["html"] = False
return parser return parser
class ChapMarkdown( class ChapMarkdown(Markdown, can_focus=True, can_focus_children=False):
Markdown, can_focus=True, can_focus_children=False
): # pylint: disable=function-redefined
BINDINGS = [ BINDINGS = [
Binding("ctrl+y", "yank", "Yank text", show=True), Binding("ctrl+y", "yank", "Yank text", show=True),
Binding("ctrl+r", "resubmit", "resubmit", show=True), Binding("ctrl+r", "resubmit", "resubmit", show=True),
@ -56,10 +79,10 @@ class Tui(App[None]):
] ]
def __init__( def __init__(
self, api: Backend | None = None, session: Session | None = None self, api: Optional[Backend] = None, session: Optional[Session] = None
) -> None: ) -> None:
super().__init__() super().__init__()
self.api = api or get_api("lorem") self.api = api or get_api(click.Context(click.Command("chap tui")), "lorem")
self.session = ( self.session = (
new_session(self.api.system_message) if session is None else session new_session(self.api.system_message) if session is None else session
) )
@ -73,8 +96,8 @@ class Tui(App[None]):
return cast(VerticalScroll, self.query_one("#wait")) return cast(VerticalScroll, self.query_one("#wait"))
@property @property
def input(self) -> Input: def input(self) -> SubmittableTextArea:
return self.query_one(Input) return self.query_one(SubmittableTextArea)
@property @property
def cancel_button(self) -> CancelButton: def cancel_button(self) -> CancelButton:
@ -94,7 +117,9 @@ class Tui(App[None]):
Container(id="pad"), Container(id="pad"),
id="content", id="content",
) )
yield Input(placeholder="Prompt") s = SubmittableTextArea(language="markdown")
s.show_line_numbers = False
yield s
with Horizontal(id="wait"): with Horizontal(id="wait"):
yield LoadingIndicator() yield LoadingIndicator()
yield CancelButton(label="❌ Stop Generation", id="cancel", disabled=True) yield CancelButton(label="❌ Stop Generation", id="cancel", disabled=True)
@ -103,8 +128,8 @@ class Tui(App[None]):
self.container.scroll_end(animate=False) self.container.scroll_end(animate=False)
self.input.focus() self.input.focus()
async def on_input_submitted(self, event: Input.Submitted) -> None: async def action_submit(self) -> None:
self.get_completion(event.value) self.get_completion(self.input.text)
@work(exclusive=True) @work(exclusive=True)
async def get_completion(self, query: str) -> None: async def get_completion(self, query: str) -> None:
@ -162,10 +187,10 @@ class Tui(App[None]):
try: try:
await asyncio.gather(render_fun(), get_token_fun()) await asyncio.gather(render_fun(), get_token_fun())
finally: finally:
self.input.value = "" self.input.clear()
all_output = self.session[-1].content all_output = self.session[-1].content
output.update(all_output) output.update(all_output)
output._markdown = all_output # pylint: disable=protected-access output._markdown = all_output
self.container.scroll_end() self.container.scroll_end()
for markdown in self.container.children: for markdown in self.container.children:
@ -183,8 +208,8 @@ class Tui(App[None]):
def action_yank(self) -> None: def action_yank(self) -> None:
widget = self.focused widget = self.focused
if isinstance(widget, ChapMarkdown): if isinstance(widget, ChapMarkdown):
content = widget._markdown or "" # pylint: disable=protected-access content = widget._markdown or ""
subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False) pyperclip_copy(content)
def action_toggle_history(self) -> None: def action_toggle_history(self) -> None:
widget = self.focused widget = self.focused
@ -204,9 +229,7 @@ class Tui(App[None]):
async def action_stop_generating(self) -> None: async def action_stop_generating(self) -> None:
self.workers.cancel_all() self.workers.cancel_all()
async def on_button_pressed( # pylint: disable=unused-argument async def on_button_pressed(self, event: Button.Pressed) -> None:
self, event: Button.Pressed
) -> None:
self.workers.cancel_all() self.workers.cancel_all()
async def action_quit(self) -> None: async def action_quit(self) -> None:
@ -235,19 +258,31 @@ class Tui(App[None]):
session_to_file(self.session, new_session_path()) session_to_file(self.session, new_session_path())
query = self.session[idx].content query = self.session[idx].content
self.input.value = query self.input.load_text(query)
del self.session[idx:] del self.session[idx:]
for child in self.container.children[idx:-1]: for child in self.container.children[idx:-1]:
await child.remove() await child.remove()
self.input.focus() self.input.focus()
self.on_text_area_changed()
if resubmit: if resubmit:
await self.input.action_submit() await self.action_submit()
def on_text_area_changed(self, event: Any = None) -> None:
height = self.input.document.get_size(self.input.indent_width)[1]
max_height = max(3, self.size.height - 6)
if height >= max_height:
self.input.styles.height = max_height
elif height <= 3:
self.input.styles.height = 3
else:
self.input.styles.height = height
@command_uses_new_session @command_uses_new_session
def main(obj: Obj) -> None: @click.option("--replace-system-prompt/--no-replace-system-prompt", default=False)
def main(obj: Obj, replace_system_prompt: bool) -> None:
"""Start interactive terminal user interface session""" """Start interactive terminal user interface session"""
api = obj.api api = obj.api
assert api is not None assert api is not None
@ -256,6 +291,9 @@ def main(obj: Obj) -> None:
session_filename = obj.session_filename session_filename = obj.session_filename
assert session_filename is not None assert session_filename is not None
if replace_system_prompt:
session[0].content = obj.system_message or api.system_message
tui = Tui(api, session) tui = Tui(api, session)
tui.run() tui.run()
@ -268,4 +306,4 @@ def main(obj: Obj) -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter main()

View file

@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com> # SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
# pylint: disable=import-outside-toplevel
import asyncio import asyncio
import datetime import datetime
@ -10,23 +10,38 @@ import os
import pathlib import pathlib
import pkgutil import pkgutil
import subprocess import subprocess
from dataclasses import MISSING, dataclass, fields from dataclasses import MISSING, Field, dataclass, fields
from types import UnionType from typing import (
from typing import Any, AsyncGenerator, Callable, cast Any,
AsyncGenerator,
Callable,
Optional,
Union,
cast,
get_origin,
get_args,
)
import sys
import click import click
import platformdirs import platformdirs
from simple_parsing.docstring import get_attribute_docstring from simple_parsing.docstring import get_attribute_docstring
from typing_extensions import Protocol from typing_extensions import Protocol
from . import backends, commands # pylint: disable=no-name-in-module from . import backends, commands
from .session import Message, Session, System, session_from_file from .session import Message, Session, System, session_from_file
UnionType: type
if sys.version_info >= (3, 10):
from types import UnionType
else:
UnionType = type(Union[int, float])
conversations_path = platformdirs.user_state_path("chap") / "conversations" conversations_path = platformdirs.user_state_path("chap") / "conversations"
conversations_path.mkdir(parents=True, exist_ok=True) conversations_path.mkdir(parents=True, exist_ok=True)
class ABackend(Protocol): # pylint: disable=too-few-public-methods class ABackend(Protocol):
def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]: def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
"""Make a query, updating the session with the query and response, returning the query token by token""" """Make a query, updating the session with the query and response, returning the query token by token"""
@ -39,7 +54,7 @@ class Backend(ABackend, Protocol):
"""Make a query, updating the session with the query and response, returning the query""" """Make a query, updating the session with the query and response, returning the query"""
class AutoAskMixin: # pylint: disable=too-few-public-methods class AutoAskMixin:
"""Mixin class for backends implementing aask""" """Mixin class for backends implementing aask"""
def ask(self, session: Session, query: str) -> str: def ask(self, session: Session, query: str) -> str:
@ -54,20 +69,50 @@ class AutoAskMixin: # pylint: disable=too-few-public-methods
return "".join(tokens) return "".join(tokens)
def last_session_path() -> pathlib.Path | None: def last_session_path() -> Optional[pathlib.Path]:
result = max( result = max(
conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None
) )
return result return result
def new_session_path(opt_path: pathlib.Path | None = None) -> pathlib.Path: def new_session_path(opt_path: Optional[pathlib.Path] = None) -> pathlib.Path:
return opt_path or conversations_path / ( return opt_path or conversations_path / (
datetime.datetime.now().isoformat().replace(":", "_") + ".json" datetime.datetime.now().isoformat().replace(":", "_") + ".json"
) )
def configure_api_from_environment(api_name: str, api: Backend) -> None: def get_field_type(field: Field[Any]) -> Any:
field_type = field.type
if isinstance(field_type, str):
raise RuntimeError(
"parameters dataclass may not use 'from __future__ import annotations"
)
origin = get_origin(field_type)
if origin in (Union, UnionType):
for arg in get_args(field_type):
if arg is not None:
return arg
return field_type
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
field_type = get_field_type(field)
try:
if field_type is bool:
tv = click.types.BoolParamType().convert(value, None, ctx)
else:
tv = field_type(value)
return tv
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {field.name} with value {value}: {e}"
) from e
def configure_api_from_environment(
ctx: click.Context, api_name: str, api: Backend
) -> None:
if not hasattr(api, "parameters"): if not hasattr(api, "parameters"):
return return
@ -76,26 +121,21 @@ def configure_api_from_environment(api_name: str, api: Backend) -> None:
value = os.environ.get(envvar) value = os.environ.get(envvar)
if value is None: if value is None:
continue continue
try: tv = convert_str_to_field(ctx, field, value)
tv = field.type(value)
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {field.name} with value {value}: {e}"
) from e
setattr(api.parameters, field.name, tv) setattr(api.parameters, field.name, tv)
def get_api(name: str = "openai_chatgpt") -> Backend: def get_api(ctx: click.Context, name: str = "openai_chatgpt") -> Backend:
name = name.replace("-", "_") name = name.replace("-", "_")
result = cast( backend = cast(
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory() Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
) )
configure_api_from_environment(name, result) configure_api_from_environment(ctx, name, backend)
return result return backend
def do_session_continue( def do_session_continue(
ctx: click.Context, param: click.Parameter, value: pathlib.Path | None ctx: click.Context, param: click.Parameter, value: Optional[pathlib.Path]
) -> None: ) -> None:
if value is None: if value is None:
return return
@ -108,9 +148,7 @@ def do_session_continue(
ctx.obj.session_filename = value ctx.obj.session_filename = value
def do_session_last( def do_session_last(ctx: click.Context, param: click.Parameter, value: bool) -> None:
ctx: click.Context, param: click.Parameter, value: bool
) -> None: # pylint: disable=unused-argument
if not value: if not value:
return return
do_session_continue(ctx, param, last_session_path()) do_session_continue(ctx, param, last_session_path())
@ -138,18 +176,14 @@ def colonstr(arg: str) -> tuple[str, str]:
return cast(tuple[str, str], tuple(arg.split(":", 1))) return cast(tuple[str, str], tuple(arg.split(":", 1)))
def set_system_message( # pylint: disable=unused-argument def set_system_message(ctx: click.Context, param: click.Parameter, value: str) -> None:
ctx: click.Context, param: click.Parameter, value: str
) -> None:
if value and value.startswith("@"): if value and value.startswith("@"):
with open(value[1:], "r", encoding="utf-8") as f: with open(value[1:], "r", encoding="utf-8") as f:
value = f.read().rstrip() value = f.read().rstrip()
ctx.obj.system_message = value ctx.obj.system_message = value
def set_backend( # pylint: disable=unused-argument def set_backend(ctx: click.Context, param: click.Parameter, value: str) -> None:
ctx: click.Context, param: click.Parameter, value: str
) -> None:
if value == "list": if value == "list":
formatter = ctx.make_formatter() formatter = ctx.make_formatter()
format_backend_list(formatter) format_backend_list(formatter)
@ -157,7 +191,7 @@ def set_backend( # pylint: disable=unused-argument
ctx.exit() ctx.exit()
try: try:
ctx.obj.api = get_api(value) ctx.obj.api = get_api(ctx, value)
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
raise click.BadParameter(str(e)) raise click.BadParameter(str(e))
@ -172,15 +206,13 @@ def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None:
if doc: if doc:
doc += " " doc += " "
doc += f"(Default: {default!r})" doc += f"(Default: {default!r})"
f_type = f.type f_type = get_field_type(f)
if isinstance(f_type, UnionType):
f_type = f_type.__args__[0]
typename = f_type.__name__ typename = f_type.__name__
rows.append((f"-B {name}:{typename.upper()}", doc)) rows.append((f"-B {name}:{typename.upper()}", doc))
formatter.write_dl(rows) formatter.write_dl(rows)
def set_backend_option( # pylint: disable=unused-argument def set_backend_option(
ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]] ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]]
) -> None: ) -> None:
api = ctx.obj.api api = ctx.obj.api
@ -195,16 +227,8 @@ def set_backend_option( # pylint: disable=unused-argument
field = all_fields.get(name) field = all_fields.get(name)
if field is None: if field is None:
raise click.BadParameter(f"Invalid parameter {name}") raise click.BadParameter(f"Invalid parameter {name}")
f_type = field.type tv = convert_str_to_field(ctx, field, value)
if isinstance(f_type, UnionType): setattr(api.parameters, field.name, tv)
f_type = f_type.__args__[0]
try:
tv = f_type(value)
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {name} with value {value}: {e}"
) from e
setattr(api.parameters, name, tv)
for kv in opts: for kv in opts:
set_one_backend_option(kv) set_one_backend_option(kv)
@ -264,9 +288,7 @@ def command_uses_new_session(f_in: click.decorators.FC) -> click.Command:
return click.command()(f) return click.command()(f)
def version_callback( # pylint: disable=unused-argument def version_callback(ctx: click.Context, param: click.Parameter, value: None) -> None:
ctx: click.Context, param: click.Parameter, value: None
) -> None:
if not value or ctx.resilient_parsing: if not value or ctx.resilient_parsing:
return return
@ -293,18 +315,18 @@ def version_callback( # pylint: disable=unused-argument
@dataclass @dataclass
class Obj: class Obj:
api: Backend | None = None api: Optional[Backend] = None
system_message: str | None = None system_message: Optional[str] = None
session: list[Message] | None = None session: Optional[list[Message]] = None
session_filename: pathlib.Path | None = None session_filename: Optional[pathlib.Path] = None
class MyCLI(click.MultiCommand): class MyCLI(click.MultiCommand):
def make_context( def make_context(
self, self,
info_name: str | None, info_name: Optional[str],
args: list[str], args: list[str],
parent: click.Context | None = None, parent: Optional[click.Context] = None,
**extra: Any, **extra: Any,
) -> click.Context: ) -> click.Context:
result = super().make_context(info_name, args, parent, obj=Obj(), **extra) result = super().make_context(info_name, args, parent, obj=Obj(), **extra)
@ -332,7 +354,7 @@ class MyCLI(click.MultiCommand):
self, ctx: click.Context, formatter: click.HelpFormatter self, ctx: click.Context, formatter: click.HelpFormatter
) -> None: ) -> None:
super().format_options(ctx, formatter) super().format_options(ctx, formatter)
api = ctx.obj.api or get_api() api = ctx.obj.api or get_api(ctx)
if hasattr(api, "parameters"): if hasattr(api, "parameters"):
format_backend_help(api, formatter) format_backend_help(api, formatter)

View file

@ -7,13 +7,13 @@ from __future__ import annotations
import json import json
import pathlib import pathlib
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import cast from typing import Union, cast
from typing_extensions import TypedDict from typing_extensions import TypedDict
# not an enum.Enum because these objects are not json-serializable, sigh # not an enum.Enum because these objects are not json-serializable, sigh
class Role: # pylint: disable=too-few-public-methods class Role:
ASSISTANT = "assistant" ASSISTANT = "assistant"
SYSTEM = "system" SYSTEM = "system"
USER = "user" USER = "user"
@ -65,11 +65,11 @@ def session_from_json(data: str) -> Session:
return [Message(**mapping) for mapping in j] return [Message(**mapping) for mapping in j]
def session_from_file(path: pathlib.Path | str) -> Session: def session_from_file(path: Union[pathlib.Path, str]) -> Session:
with open(path, "r", encoding="utf-8") as f: with open(path, "r", encoding="utf-8") as f:
return session_from_json(f.read()) return session_from_json(f.read())
def session_to_file(session: Session, path: pathlib.Path | str) -> None: def session_to_file(session: Session, path: Union[pathlib.Path, str]) -> None:
with open(path, "w", encoding="utf-8") as f: with open(path, "w", encoding="utf-8") as f:
f.write(session_to_json(session)) f.write(session_to_json(session))