Merge pull request #30 from jepler/multiline2
Switch chap tui to a multiline text field
This commit is contained in:
commit
9f6ace394a
21 changed files with 205 additions and 144 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -8,3 +8,4 @@ __pycache__
|
|||
/dist
|
||||
/src/chap/__version__.py
|
||||
/venv
|
||||
/keys.log
|
||||
|
|
|
|||
|
|
@ -6,12 +6,8 @@ default_language_version:
|
|||
python: python3
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
rev: v4.5.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
|
|
@ -19,23 +15,20 @@ repos:
|
|||
- id: trailing-whitespace
|
||||
exclude: tests
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.2.4
|
||||
rev: v2.2.6
|
||||
hooks:
|
||||
- id: codespell
|
||||
args: [-w]
|
||||
- repo: https://github.com/fsfe/reuse-tool
|
||||
rev: v1.1.2
|
||||
rev: v2.1.0
|
||||
hooks:
|
||||
- id: reuse
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.1.6
|
||||
hooks:
|
||||
- id: isort
|
||||
name: isort (python)
|
||||
args: ['--profile', 'black']
|
||||
- repo: https://github.com/pycqa/pylint
|
||||
rev: v2.17.0
|
||||
hooks:
|
||||
- id: pylint
|
||||
additional_dependencies: [click,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
|
||||
args: ['--source-roots', 'src']
|
||||
# Run the linter.
|
||||
- id: ruff
|
||||
args: [ --fix, --preview ]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
|
|
|
|||
12
.pylintrc
12
.pylintrc
|
|
@ -1,12 +0,0 @@
|
|||
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
disable=
|
||||
duplicate-code,
|
||||
invalid-name,
|
||||
line-too-long,
|
||||
missing-class-docstring,
|
||||
missing-function-docstring,
|
||||
missing-module-docstring,
|
||||
27
README.md
27
README.md
|
|
@ -9,11 +9,11 @@ SPDX-License-Identifier: MIT
|
|||
|
||||
# chap - A Python interface to chatgpt and other LLMs, including a terminal user interface (tui)
|
||||
|
||||

|
||||

|
||||
|
||||
## System requirements
|
||||
|
||||
Chap is developed on Linux with Python 3.11. Due to use of the `X | Y` style of type hints, it is known to not work on Python 3.9 and older. The target minimum Python version is 3.11 (debian stable).
|
||||
Chap is primarily developed on Linux with Python 3.11. Moderate effort will be made to support versions back to Python 3.9 (Debian oldstable).
|
||||
|
||||
## Installation
|
||||
|
||||
|
|
@ -82,7 +82,28 @@ Put your OpenAI API key in the platform configuration directory for chap, e.g.,
|
|||
* `chap grep needle`
|
||||
|
||||
## Interactive terminal usage
|
||||
* `chap tui`
|
||||
The interactive terminal mode is accessed via `chap tui`.
|
||||
|
||||
There are a variety of keyboard shortcuts to be aware of:
|
||||
* tab/shift-tab to move between the entry field and the conversation, or between conversation items
|
||||
* While in the text box, F9 or (if supported by your terminal) alt+enter to submit multiline text
|
||||
* while on a conversation item:
|
||||
* ctrl+x to re-draft the message. This
|
||||
* saves a copy of the session in an auto-named file in the conversations folder
|
||||
* removes the conversation from this message to the end
|
||||
* puts the user's message in the text box to edit
|
||||
* ctrl+x to re-submit the message. This
|
||||
* saves a copy of the session in an auto-named file in the conversations folder
|
||||
* removes the conversation from this message to the end
|
||||
* puts the user's message in the text box
|
||||
* and submits it immediately
|
||||
* ctrl+y to yank the message. This places the response part of the current
|
||||
interaction in the operating system clipboard to be pasted (e..g, with
|
||||
ctrl+v or command+v in other software)
|
||||
* ctrl+q to toggle whether this message may be included in the contextual history for a future query.
|
||||
The exact way history is submitted is determined by the back-end, often by
|
||||
counting messages or tokens, but the ctrl+q toggle ensures this message (both the user
|
||||
and assistant message parts) are not considered.
|
||||
|
||||
## Sessions & Command-line Parameters
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ src_path = project_root / "src"
|
|||
sys.path.insert(0, str(src_path))
|
||||
|
||||
if __name__ == "__main__":
|
||||
# pylint: disable=import-error,no-name-in-module
|
||||
from chap.core import main
|
||||
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -20,13 +20,14 @@ name="chap"
|
|||
authors = [{name = "Jeff Epler", email = "jepler@gmail.com"}]
|
||||
description = "Interact with the OpenAI ChatGPT API (and other text generators)"
|
||||
dynamic = ["readme","version","dependencies"]
|
||||
requires-python = ">=3.10"
|
||||
requires-python = ">=3.9"
|
||||
keywords = ["llm", "tui", "chatgpt"]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ click
|
|||
httpx
|
||||
lorem-text
|
||||
platformdirs
|
||||
pyperclip
|
||||
simple_parsing
|
||||
textual>=0.18.0
|
||||
textual[syntax]
|
||||
tiktoken
|
||||
websockets
|
||||
|
|
|
|||
|
|
@ -89,9 +89,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
*,
|
||||
max_query_size: int = 5,
|
||||
timeout: float = 180,
|
||||
) -> AsyncGenerator[
|
||||
str, None
|
||||
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
) -> AsyncGenerator[str, None]:
|
||||
new_content: list[str] = []
|
||||
inputs = self.make_full_query(session + [User(query)], max_query_size)
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -55,9 +55,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
*,
|
||||
max_query_size: int = 5,
|
||||
timeout: float = 180,
|
||||
) -> AsyncGenerator[
|
||||
str, None
|
||||
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
) -> AsyncGenerator[str, None]:
|
||||
params = {
|
||||
"prompt": self.make_full_query(session + [User(query)], max_query_size),
|
||||
"stream": True,
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class Lorem:
|
|||
session: Session,
|
||||
query: str,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
data = self.ask(session, query)[-1]
|
||||
data = self.ask(session, query)
|
||||
for word, opt_sep in ipartition(data):
|
||||
yield word + opt_sep
|
||||
await asyncio.sleep(
|
||||
|
|
@ -56,7 +56,7 @@ class Lorem:
|
|||
self,
|
||||
session: Session,
|
||||
query: str,
|
||||
) -> str: # pylint: disable=unused-argument
|
||||
) -> str:
|
||||
new_content = cast(
|
||||
str,
|
||||
lorem.paragraphs(
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class ChatGPT:
|
|||
json={
|
||||
"model": self.parameters.model,
|
||||
"messages": session_to_list(full_prompt),
|
||||
}, # pylint: disable=no-member
|
||||
},
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.get_key()}",
|
||||
},
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ USER: Hello, AI.
|
|||
|
||||
AI: Hello! How can I assist you today?"""
|
||||
|
||||
async def aask( # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
async def aask(
|
||||
self,
|
||||
session: Session,
|
||||
query: str,
|
||||
|
|
@ -60,12 +60,11 @@ AI: Hello! How can I assist you today?"""
|
|||
}
|
||||
full_prompt = session + [User(query)]
|
||||
del full_prompt[1:-max_query_size]
|
||||
new_data = old_data = full_query = (
|
||||
"\n".join(f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt)
|
||||
+ f"\n{role_map.get('assistant')}"
|
||||
)
|
||||
new_data = old_data = full_query = "\n".join(
|
||||
f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt
|
||||
) + f"\n{role_map.get('assistant')}"
|
||||
try:
|
||||
async with websockets.connect( # pylint: disable=no-member
|
||||
async with websockets.connect(
|
||||
f"ws://{self.parameters.server_hostname}:7860/queue/join"
|
||||
) as websocket:
|
||||
while content := json.loads(await websocket.recv()):
|
||||
|
|
@ -127,7 +126,7 @@ AI: Hello! How can I assist you today?"""
|
|||
# stop generation by closing the websocket here
|
||||
if content["msg"] == "process_completed":
|
||||
break
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
except Exception as e:
|
||||
content = f"\nException: {e!r}"
|
||||
new_data += content
|
||||
yield content
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Iterable, Protocol
|
||||
from typing import Iterable, Optional, Protocol
|
||||
|
||||
import click
|
||||
import rich
|
||||
|
|
@ -40,7 +40,7 @@ class DumbPrinter:
|
|||
|
||||
|
||||
class WrappingPrinter:
|
||||
def __init__(self, width: int | None = None) -> None:
|
||||
def __init__(self, width: Optional[int] = None) -> None:
|
||||
self._width = width or rich.get_console().width
|
||||
self._column = 0
|
||||
self._line = ""
|
||||
|
|
@ -122,4 +122,4 @@ def main(obj: Obj, prompt: str, print_prompt: bool) -> None:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -38,4 +38,4 @@ def main(obj: Obj, no_system: bool) -> None:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ def list_files_matching_rx(
|
|||
"*.json"
|
||||
):
|
||||
try:
|
||||
session = session_from_file(conversation) # pylint: disable=no-member
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
session = session_from_file(conversation)
|
||||
except Exception as e:
|
||||
print(f"Failed to read {conversation}: {e}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
|
|
@ -67,4 +67,4 @@ def main(
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -82,4 +82,4 @@ def main(output_directory: pathlib.Path, files: list[TextIO]) -> None:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -45,4 +45,4 @@ def main(obj: Obj, no_system: bool) -> None:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -54,3 +54,5 @@ Input {
|
|||
Markdown {
|
||||
margin: 0 1 0 0;
|
||||
}
|
||||
|
||||
SubmittableTextArea { height: 3 }
|
||||
|
|
|
|||
|
|
@ -3,30 +3,53 @@
|
|||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import cast
|
||||
from typing import Any, Optional, cast, TYPE_CHECKING
|
||||
|
||||
import click
|
||||
from markdown_it import MarkdownIt
|
||||
from textual import work
|
||||
from textual._ansi_sequences import ANSI_SEQUENCES_KEYS
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Container, Horizontal, VerticalScroll
|
||||
from textual.widgets import Button, Footer, Input, LoadingIndicator, Markdown
|
||||
from textual.keys import Keys
|
||||
from textual.widgets import Button, Footer, LoadingIndicator, Markdown, TextArea
|
||||
|
||||
from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path
|
||||
from ..session import Assistant, Message, Session, User, new_session, session_to_file
|
||||
|
||||
|
||||
# workaround for pyperclip being un-typed
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def pyperclip_copy(data: str) -> None:
|
||||
...
|
||||
else:
|
||||
from pyperclip import copy as pyperclip_copy
|
||||
|
||||
|
||||
# Monkeypatch alt+enter as meaning "F9", WFM
|
||||
# ignore typing here because ANSI_SEQUENCES_KEYS is a Mapping[] which is read-only as
|
||||
# far as mypy is concerned.
|
||||
ANSI_SEQUENCES_KEYS["\x1b\r"] = (Keys.F9,) # type: ignore
|
||||
ANSI_SEQUENCES_KEYS["\x1b\n"] = (Keys.F9,) # type: ignore
|
||||
|
||||
|
||||
class SubmittableTextArea(TextArea):
|
||||
BINDINGS = [
|
||||
Binding("f9", "submit", "Submit", show=True),
|
||||
Binding("tab", "focus_next", show=False, priority=True), # no inserting tabs
|
||||
]
|
||||
|
||||
|
||||
def parser_factory() -> MarkdownIt:
|
||||
parser = MarkdownIt()
|
||||
parser.options["html"] = False
|
||||
return parser
|
||||
|
||||
|
||||
class ChapMarkdown(
|
||||
Markdown, can_focus=True, can_focus_children=False
|
||||
): # pylint: disable=function-redefined
|
||||
class ChapMarkdown(Markdown, can_focus=True, can_focus_children=False):
|
||||
BINDINGS = [
|
||||
Binding("ctrl+y", "yank", "Yank text", show=True),
|
||||
Binding("ctrl+r", "resubmit", "resubmit", show=True),
|
||||
|
|
@ -56,10 +79,10 @@ class Tui(App[None]):
|
|||
]
|
||||
|
||||
def __init__(
|
||||
self, api: Backend | None = None, session: Session | None = None
|
||||
self, api: Optional[Backend] = None, session: Optional[Session] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.api = api or get_api("lorem")
|
||||
self.api = api or get_api(click.Context(click.Command("chap tui")), "lorem")
|
||||
self.session = (
|
||||
new_session(self.api.system_message) if session is None else session
|
||||
)
|
||||
|
|
@ -73,8 +96,8 @@ class Tui(App[None]):
|
|||
return cast(VerticalScroll, self.query_one("#wait"))
|
||||
|
||||
@property
|
||||
def input(self) -> Input:
|
||||
return self.query_one(Input)
|
||||
def input(self) -> SubmittableTextArea:
|
||||
return self.query_one(SubmittableTextArea)
|
||||
|
||||
@property
|
||||
def cancel_button(self) -> CancelButton:
|
||||
|
|
@ -94,7 +117,9 @@ class Tui(App[None]):
|
|||
Container(id="pad"),
|
||||
id="content",
|
||||
)
|
||||
yield Input(placeholder="Prompt")
|
||||
s = SubmittableTextArea(language="markdown")
|
||||
s.show_line_numbers = False
|
||||
yield s
|
||||
with Horizontal(id="wait"):
|
||||
yield LoadingIndicator()
|
||||
yield CancelButton(label="❌ Stop Generation", id="cancel", disabled=True)
|
||||
|
|
@ -103,8 +128,8 @@ class Tui(App[None]):
|
|||
self.container.scroll_end(animate=False)
|
||||
self.input.focus()
|
||||
|
||||
async def on_input_submitted(self, event: Input.Submitted) -> None:
|
||||
self.get_completion(event.value)
|
||||
async def action_submit(self) -> None:
|
||||
self.get_completion(self.input.text)
|
||||
|
||||
@work(exclusive=True)
|
||||
async def get_completion(self, query: str) -> None:
|
||||
|
|
@ -162,10 +187,10 @@ class Tui(App[None]):
|
|||
try:
|
||||
await asyncio.gather(render_fun(), get_token_fun())
|
||||
finally:
|
||||
self.input.value = ""
|
||||
self.input.clear()
|
||||
all_output = self.session[-1].content
|
||||
output.update(all_output)
|
||||
output._markdown = all_output # pylint: disable=protected-access
|
||||
output._markdown = all_output
|
||||
self.container.scroll_end()
|
||||
|
||||
for markdown in self.container.children:
|
||||
|
|
@ -183,8 +208,8 @@ class Tui(App[None]):
|
|||
def action_yank(self) -> None:
|
||||
widget = self.focused
|
||||
if isinstance(widget, ChapMarkdown):
|
||||
content = widget._markdown or "" # pylint: disable=protected-access
|
||||
subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)
|
||||
content = widget._markdown or ""
|
||||
pyperclip_copy(content)
|
||||
|
||||
def action_toggle_history(self) -> None:
|
||||
widget = self.focused
|
||||
|
|
@ -204,9 +229,7 @@ class Tui(App[None]):
|
|||
async def action_stop_generating(self) -> None:
|
||||
self.workers.cancel_all()
|
||||
|
||||
async def on_button_pressed( # pylint: disable=unused-argument
|
||||
self, event: Button.Pressed
|
||||
) -> None:
|
||||
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
self.workers.cancel_all()
|
||||
|
||||
async def action_quit(self) -> None:
|
||||
|
|
@ -235,19 +258,31 @@ class Tui(App[None]):
|
|||
session_to_file(self.session, new_session_path())
|
||||
|
||||
query = self.session[idx].content
|
||||
self.input.value = query
|
||||
self.input.load_text(query)
|
||||
|
||||
del self.session[idx:]
|
||||
for child in self.container.children[idx:-1]:
|
||||
await child.remove()
|
||||
|
||||
self.input.focus()
|
||||
self.on_text_area_changed()
|
||||
if resubmit:
|
||||
await self.input.action_submit()
|
||||
await self.action_submit()
|
||||
|
||||
def on_text_area_changed(self, event: Any = None) -> None:
|
||||
height = self.input.document.get_size(self.input.indent_width)[1]
|
||||
max_height = max(3, self.size.height - 6)
|
||||
if height >= max_height:
|
||||
self.input.styles.height = max_height
|
||||
elif height <= 3:
|
||||
self.input.styles.height = 3
|
||||
else:
|
||||
self.input.styles.height = height
|
||||
|
||||
|
||||
@command_uses_new_session
|
||||
def main(obj: Obj) -> None:
|
||||
@click.option("--replace-system-prompt/--no-replace-system-prompt", default=False)
|
||||
def main(obj: Obj, replace_system_prompt: bool) -> None:
|
||||
"""Start interactive terminal user interface session"""
|
||||
api = obj.api
|
||||
assert api is not None
|
||||
|
|
@ -256,6 +291,9 @@ def main(obj: Obj) -> None:
|
|||
session_filename = obj.session_filename
|
||||
assert session_filename is not None
|
||||
|
||||
if replace_system_prompt:
|
||||
session[0].content = obj.system_message or api.system_message
|
||||
|
||||
tui = Tui(api, session)
|
||||
tui.run()
|
||||
|
||||
|
|
@ -268,4 +306,4 @@ def main(obj: Obj) -> None:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
main()
|
||||
|
|
|
|||
132
src/chap/core.py
132
src/chap/core.py
|
|
@ -1,7 +1,7 @@
|
|||
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
|
|
@ -10,23 +10,38 @@ import os
|
|||
import pathlib
|
||||
import pkgutil
|
||||
import subprocess
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from types import UnionType
|
||||
from typing import Any, AsyncGenerator, Callable, cast
|
||||
from dataclasses import MISSING, Field, dataclass, fields
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
get_origin,
|
||||
get_args,
|
||||
)
|
||||
import sys
|
||||
|
||||
import click
|
||||
import platformdirs
|
||||
from simple_parsing.docstring import get_attribute_docstring
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from . import backends, commands # pylint: disable=no-name-in-module
|
||||
from . import backends, commands
|
||||
from .session import Message, Session, System, session_from_file
|
||||
|
||||
UnionType: type
|
||||
if sys.version_info >= (3, 10):
|
||||
from types import UnionType
|
||||
else:
|
||||
UnionType = type(Union[int, float])
|
||||
|
||||
conversations_path = platformdirs.user_state_path("chap") / "conversations"
|
||||
conversations_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class ABackend(Protocol): # pylint: disable=too-few-public-methods
|
||||
class ABackend(Protocol):
|
||||
def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
|
||||
"""Make a query, updating the session with the query and response, returning the query token by token"""
|
||||
|
||||
|
|
@ -39,7 +54,7 @@ class Backend(ABackend, Protocol):
|
|||
"""Make a query, updating the session with the query and response, returning the query"""
|
||||
|
||||
|
||||
class AutoAskMixin: # pylint: disable=too-few-public-methods
|
||||
class AutoAskMixin:
|
||||
"""Mixin class for backends implementing aask"""
|
||||
|
||||
def ask(self, session: Session, query: str) -> str:
|
||||
|
|
@ -54,20 +69,50 @@ class AutoAskMixin: # pylint: disable=too-few-public-methods
|
|||
return "".join(tokens)
|
||||
|
||||
|
||||
def last_session_path() -> pathlib.Path | None:
|
||||
def last_session_path() -> Optional[pathlib.Path]:
|
||||
result = max(
|
||||
conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def new_session_path(opt_path: pathlib.Path | None = None) -> pathlib.Path:
|
||||
def new_session_path(opt_path: Optional[pathlib.Path] = None) -> pathlib.Path:
|
||||
return opt_path or conversations_path / (
|
||||
datetime.datetime.now().isoformat().replace(":", "_") + ".json"
|
||||
)
|
||||
|
||||
|
||||
def configure_api_from_environment(api_name: str, api: Backend) -> None:
|
||||
def get_field_type(field: Field[Any]) -> Any:
|
||||
field_type = field.type
|
||||
if isinstance(field_type, str):
|
||||
raise RuntimeError(
|
||||
"parameters dataclass may not use 'from __future__ import annotations"
|
||||
)
|
||||
origin = get_origin(field_type)
|
||||
if origin in (Union, UnionType):
|
||||
for arg in get_args(field_type):
|
||||
if arg is not None:
|
||||
return arg
|
||||
return field_type
|
||||
|
||||
|
||||
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
|
||||
field_type = get_field_type(field)
|
||||
try:
|
||||
if field_type is bool:
|
||||
tv = click.types.BoolParamType().convert(value, None, ctx)
|
||||
else:
|
||||
tv = field_type(value)
|
||||
return tv
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(
|
||||
f"Invalid value for {field.name} with value {value}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def configure_api_from_environment(
|
||||
ctx: click.Context, api_name: str, api: Backend
|
||||
) -> None:
|
||||
if not hasattr(api, "parameters"):
|
||||
return
|
||||
|
||||
|
|
@ -76,26 +121,21 @@ def configure_api_from_environment(api_name: str, api: Backend) -> None:
|
|||
value = os.environ.get(envvar)
|
||||
if value is None:
|
||||
continue
|
||||
try:
|
||||
tv = field.type(value)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(
|
||||
f"Invalid value for {field.name} with value {value}: {e}"
|
||||
) from e
|
||||
tv = convert_str_to_field(ctx, field, value)
|
||||
setattr(api.parameters, field.name, tv)
|
||||
|
||||
|
||||
def get_api(name: str = "openai_chatgpt") -> Backend:
|
||||
def get_api(ctx: click.Context, name: str = "openai_chatgpt") -> Backend:
|
||||
name = name.replace("-", "_")
|
||||
result = cast(
|
||||
backend = cast(
|
||||
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
|
||||
)
|
||||
configure_api_from_environment(name, result)
|
||||
return result
|
||||
configure_api_from_environment(ctx, name, backend)
|
||||
return backend
|
||||
|
||||
|
||||
def do_session_continue(
|
||||
ctx: click.Context, param: click.Parameter, value: pathlib.Path | None
|
||||
ctx: click.Context, param: click.Parameter, value: Optional[pathlib.Path]
|
||||
) -> None:
|
||||
if value is None:
|
||||
return
|
||||
|
|
@ -108,9 +148,7 @@ def do_session_continue(
|
|||
ctx.obj.session_filename = value
|
||||
|
||||
|
||||
def do_session_last(
|
||||
ctx: click.Context, param: click.Parameter, value: bool
|
||||
) -> None: # pylint: disable=unused-argument
|
||||
def do_session_last(ctx: click.Context, param: click.Parameter, value: bool) -> None:
|
||||
if not value:
|
||||
return
|
||||
do_session_continue(ctx, param, last_session_path())
|
||||
|
|
@ -138,18 +176,14 @@ def colonstr(arg: str) -> tuple[str, str]:
|
|||
return cast(tuple[str, str], tuple(arg.split(":", 1)))
|
||||
|
||||
|
||||
def set_system_message( # pylint: disable=unused-argument
|
||||
ctx: click.Context, param: click.Parameter, value: str
|
||||
) -> None:
|
||||
def set_system_message(ctx: click.Context, param: click.Parameter, value: str) -> None:
|
||||
if value and value.startswith("@"):
|
||||
with open(value[1:], "r", encoding="utf-8") as f:
|
||||
value = f.read().rstrip()
|
||||
ctx.obj.system_message = value
|
||||
|
||||
|
||||
def set_backend( # pylint: disable=unused-argument
|
||||
ctx: click.Context, param: click.Parameter, value: str
|
||||
) -> None:
|
||||
def set_backend(ctx: click.Context, param: click.Parameter, value: str) -> None:
|
||||
if value == "list":
|
||||
formatter = ctx.make_formatter()
|
||||
format_backend_list(formatter)
|
||||
|
|
@ -157,7 +191,7 @@ def set_backend( # pylint: disable=unused-argument
|
|||
ctx.exit()
|
||||
|
||||
try:
|
||||
ctx.obj.api = get_api(value)
|
||||
ctx.obj.api = get_api(ctx, value)
|
||||
except ModuleNotFoundError as e:
|
||||
raise click.BadParameter(str(e))
|
||||
|
||||
|
|
@ -172,15 +206,13 @@ def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None:
|
|||
if doc:
|
||||
doc += " "
|
||||
doc += f"(Default: {default!r})"
|
||||
f_type = f.type
|
||||
if isinstance(f_type, UnionType):
|
||||
f_type = f_type.__args__[0]
|
||||
f_type = get_field_type(f)
|
||||
typename = f_type.__name__
|
||||
rows.append((f"-B {name}:{typename.upper()}", doc))
|
||||
formatter.write_dl(rows)
|
||||
|
||||
|
||||
def set_backend_option( # pylint: disable=unused-argument
|
||||
def set_backend_option(
|
||||
ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]]
|
||||
) -> None:
|
||||
api = ctx.obj.api
|
||||
|
|
@ -195,16 +227,8 @@ def set_backend_option( # pylint: disable=unused-argument
|
|||
field = all_fields.get(name)
|
||||
if field is None:
|
||||
raise click.BadParameter(f"Invalid parameter {name}")
|
||||
f_type = field.type
|
||||
if isinstance(f_type, UnionType):
|
||||
f_type = f_type.__args__[0]
|
||||
try:
|
||||
tv = f_type(value)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(
|
||||
f"Invalid value for {name} with value {value}: {e}"
|
||||
) from e
|
||||
setattr(api.parameters, name, tv)
|
||||
tv = convert_str_to_field(ctx, field, value)
|
||||
setattr(api.parameters, field.name, tv)
|
||||
|
||||
for kv in opts:
|
||||
set_one_backend_option(kv)
|
||||
|
|
@ -264,9 +288,7 @@ def command_uses_new_session(f_in: click.decorators.FC) -> click.Command:
|
|||
return click.command()(f)
|
||||
|
||||
|
||||
def version_callback( # pylint: disable=unused-argument
|
||||
ctx: click.Context, param: click.Parameter, value: None
|
||||
) -> None:
|
||||
def version_callback(ctx: click.Context, param: click.Parameter, value: None) -> None:
|
||||
if not value or ctx.resilient_parsing:
|
||||
return
|
||||
|
||||
|
|
@ -293,18 +315,18 @@ def version_callback( # pylint: disable=unused-argument
|
|||
|
||||
@dataclass
|
||||
class Obj:
|
||||
api: Backend | None = None
|
||||
system_message: str | None = None
|
||||
session: list[Message] | None = None
|
||||
session_filename: pathlib.Path | None = None
|
||||
api: Optional[Backend] = None
|
||||
system_message: Optional[str] = None
|
||||
session: Optional[list[Message]] = None
|
||||
session_filename: Optional[pathlib.Path] = None
|
||||
|
||||
|
||||
class MyCLI(click.MultiCommand):
|
||||
def make_context(
|
||||
self,
|
||||
info_name: str | None,
|
||||
info_name: Optional[str],
|
||||
args: list[str],
|
||||
parent: click.Context | None = None,
|
||||
parent: Optional[click.Context] = None,
|
||||
**extra: Any,
|
||||
) -> click.Context:
|
||||
result = super().make_context(info_name, args, parent, obj=Obj(), **extra)
|
||||
|
|
@ -332,7 +354,7 @@ class MyCLI(click.MultiCommand):
|
|||
self, ctx: click.Context, formatter: click.HelpFormatter
|
||||
) -> None:
|
||||
super().format_options(ctx, formatter)
|
||||
api = ctx.obj.api or get_api()
|
||||
api = ctx.obj.api or get_api(ctx)
|
||||
if hasattr(api, "parameters"):
|
||||
format_backend_help(api, formatter)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ from __future__ import annotations
|
|||
import json
|
||||
import pathlib
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import cast
|
||||
from typing import Union, cast
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
# not an enum.Enum because these objects are not json-serializable, sigh
|
||||
class Role: # pylint: disable=too-few-public-methods
|
||||
class Role:
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
|
|
@ -65,11 +65,11 @@ def session_from_json(data: str) -> Session:
|
|||
return [Message(**mapping) for mapping in j]
|
||||
|
||||
|
||||
def session_from_file(path: pathlib.Path | str) -> Session:
|
||||
def session_from_file(path: Union[pathlib.Path, str]) -> Session:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return session_from_json(f.read())
|
||||
|
||||
|
||||
def session_to_file(session: Session, path: pathlib.Path | str) -> None:
|
||||
def session_to_file(session: Session, path: Union[pathlib.Path, str]) -> None:
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(session_to_json(session))
|
||||
|
|
|
|||
Loading…
Reference in a new issue