Merge pull request #30 from jepler/multiline2
Switch chap tui to a multiline text field
This commit is contained in:
commit
9f6ace394a
21 changed files with 205 additions and 144 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -8,3 +8,4 @@ __pycache__
|
||||||
/dist
|
/dist
|
||||||
/src/chap/__version__.py
|
/src/chap/__version__.py
|
||||||
/venv
|
/venv
|
||||||
|
/keys.log
|
||||||
|
|
|
||||||
|
|
@ -6,12 +6,8 @@ default_language_version:
|
||||||
python: python3
|
python: python3
|
||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/psf/black
|
|
||||||
rev: 23.1.0
|
|
||||||
hooks:
|
|
||||||
- id: black
|
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.4.0
|
rev: v4.5.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
|
|
@ -19,23 +15,20 @@ repos:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
exclude: tests
|
exclude: tests
|
||||||
- repo: https://github.com/codespell-project/codespell
|
- repo: https://github.com/codespell-project/codespell
|
||||||
rev: v2.2.4
|
rev: v2.2.6
|
||||||
hooks:
|
hooks:
|
||||||
- id: codespell
|
- id: codespell
|
||||||
args: [-w]
|
args: [-w]
|
||||||
- repo: https://github.com/fsfe/reuse-tool
|
- repo: https://github.com/fsfe/reuse-tool
|
||||||
rev: v1.1.2
|
rev: v2.1.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: reuse
|
- id: reuse
|
||||||
- repo: https://github.com/pycqa/isort
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: 5.12.0
|
# Ruff version.
|
||||||
|
rev: v0.1.6
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
# Run the linter.
|
||||||
name: isort (python)
|
- id: ruff
|
||||||
args: ['--profile', 'black']
|
args: [ --fix, --preview ]
|
||||||
- repo: https://github.com/pycqa/pylint
|
# Run the formatter.
|
||||||
rev: v2.17.0
|
- id: ruff-format
|
||||||
hooks:
|
|
||||||
- id: pylint
|
|
||||||
additional_dependencies: [click,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
|
|
||||||
args: ['--source-roots', 'src']
|
|
||||||
|
|
|
||||||
12
.pylintrc
12
.pylintrc
|
|
@ -1,12 +0,0 @@
|
||||||
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
[MESSAGES CONTROL]
|
|
||||||
disable=
|
|
||||||
duplicate-code,
|
|
||||||
invalid-name,
|
|
||||||
line-too-long,
|
|
||||||
missing-class-docstring,
|
|
||||||
missing-function-docstring,
|
|
||||||
missing-module-docstring,
|
|
||||||
27
README.md
27
README.md
|
|
@ -9,11 +9,11 @@ SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
# chap - A Python interface to chatgpt and other LLMs, including a terminal user interface (tui)
|
# chap - A Python interface to chatgpt and other LLMs, including a terminal user interface (tui)
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
## System requirements
|
## System requirements
|
||||||
|
|
||||||
Chap is developed on Linux with Python 3.11. Due to use of the `X | Y` style of type hints, it is known to not work on Python 3.9 and older. The target minimum Python version is 3.11 (debian stable).
|
Chap is primarily developed on Linux with Python 3.11. Moderate effort will be made to support versions back to Python 3.9 (Debian oldstable).
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
|
|
@ -82,7 +82,28 @@ Put your OpenAI API key in the platform configuration directory for chap, e.g.,
|
||||||
* `chap grep needle`
|
* `chap grep needle`
|
||||||
|
|
||||||
## Interactive terminal usage
|
## Interactive terminal usage
|
||||||
* `chap tui`
|
The interactive terminal mode is accessed via `chap tui`.
|
||||||
|
|
||||||
|
There are a variety of keyboard shortcuts to be aware of:
|
||||||
|
* tab/shift-tab to move between the entry field and the conversation, or between conversation items
|
||||||
|
* While in the text box, F9 or (if supported by your terminal) alt+enter to submit multiline text
|
||||||
|
* while on a conversation item:
|
||||||
|
* ctrl+x to re-draft the message. This
|
||||||
|
* saves a copy of the session in an auto-named file in the conversations folder
|
||||||
|
* removes the conversation from this message to the end
|
||||||
|
* puts the user's message in the text box to edit
|
||||||
|
* ctrl+x to re-submit the message. This
|
||||||
|
* saves a copy of the session in an auto-named file in the conversations folder
|
||||||
|
* removes the conversation from this message to the end
|
||||||
|
* puts the user's message in the text box
|
||||||
|
* and submits it immediately
|
||||||
|
* ctrl+y to yank the message. This places the response part of the current
|
||||||
|
interaction in the operating system clipboard to be pasted (e..g, with
|
||||||
|
ctrl+v or command+v in other software)
|
||||||
|
* ctrl+q to toggle whether this message may be included in the contextual history for a future query.
|
||||||
|
The exact way history is submitted is determined by the back-end, often by
|
||||||
|
counting messages or tokens, but the ctrl+q toggle ensures this message (both the user
|
||||||
|
and assistant message parts) are not considered.
|
||||||
|
|
||||||
## Sessions & Command-line Parameters
|
## Sessions & Command-line Parameters
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ src_path = project_root / "src"
|
||||||
sys.path.insert(0, str(src_path))
|
sys.path.insert(0, str(src_path))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# pylint: disable=import-error,no-name-in-module
|
|
||||||
from chap.core import main
|
from chap.core import main
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -20,13 +20,14 @@ name="chap"
|
||||||
authors = [{name = "Jeff Epler", email = "jepler@gmail.com"}]
|
authors = [{name = "Jeff Epler", email = "jepler@gmail.com"}]
|
||||||
description = "Interact with the OpenAI ChatGPT API (and other text generators)"
|
description = "Interact with the OpenAI ChatGPT API (and other text generators)"
|
||||||
dynamic = ["readme","version","dependencies"]
|
dynamic = ["readme","version","dependencies"]
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.9"
|
||||||
keywords = ["llm", "tui", "chatgpt"]
|
keywords = ["llm", "tui", "chatgpt"]
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 3 - Alpha",
|
"Development Status :: 3 - Alpha",
|
||||||
"License :: OSI Approved :: MIT License",
|
"License :: OSI Approved :: MIT License",
|
||||||
"Operating System :: OS Independent",
|
"Operating System :: OS Independent",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,8 @@ click
|
||||||
httpx
|
httpx
|
||||||
lorem-text
|
lorem-text
|
||||||
platformdirs
|
platformdirs
|
||||||
|
pyperclip
|
||||||
simple_parsing
|
simple_parsing
|
||||||
textual>=0.18.0
|
textual[syntax]
|
||||||
tiktoken
|
tiktoken
|
||||||
websockets
|
websockets
|
||||||
|
|
|
||||||
|
|
@ -89,9 +89,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
||||||
*,
|
*,
|
||||||
max_query_size: int = 5,
|
max_query_size: int = 5,
|
||||||
timeout: float = 180,
|
timeout: float = 180,
|
||||||
) -> AsyncGenerator[
|
) -> AsyncGenerator[str, None]:
|
||||||
str, None
|
|
||||||
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
|
||||||
new_content: list[str] = []
|
new_content: list[str] = []
|
||||||
inputs = self.make_full_query(session + [User(query)], max_query_size)
|
inputs = self.make_full_query(session + [User(query)], max_query_size)
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -55,9 +55,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
||||||
*,
|
*,
|
||||||
max_query_size: int = 5,
|
max_query_size: int = 5,
|
||||||
timeout: float = 180,
|
timeout: float = 180,
|
||||||
) -> AsyncGenerator[
|
) -> AsyncGenerator[str, None]:
|
||||||
str, None
|
|
||||||
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
|
||||||
params = {
|
params = {
|
||||||
"prompt": self.make_full_query(session + [User(query)], max_query_size),
|
"prompt": self.make_full_query(session + [User(query)], max_query_size),
|
||||||
"stream": True,
|
"stream": True,
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ class Lorem:
|
||||||
session: Session,
|
session: Session,
|
||||||
query: str,
|
query: str,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
data = self.ask(session, query)[-1]
|
data = self.ask(session, query)
|
||||||
for word, opt_sep in ipartition(data):
|
for word, opt_sep in ipartition(data):
|
||||||
yield word + opt_sep
|
yield word + opt_sep
|
||||||
await asyncio.sleep(
|
await asyncio.sleep(
|
||||||
|
|
@ -56,7 +56,7 @@ class Lorem:
|
||||||
self,
|
self,
|
||||||
session: Session,
|
session: Session,
|
||||||
query: str,
|
query: str,
|
||||||
) -> str: # pylint: disable=unused-argument
|
) -> str:
|
||||||
new_content = cast(
|
new_content = cast(
|
||||||
str,
|
str,
|
||||||
lorem.paragraphs(
|
lorem.paragraphs(
|
||||||
|
|
|
||||||
|
|
@ -101,7 +101,7 @@ class ChatGPT:
|
||||||
json={
|
json={
|
||||||
"model": self.parameters.model,
|
"model": self.parameters.model,
|
||||||
"messages": session_to_list(full_prompt),
|
"messages": session_to_list(full_prompt),
|
||||||
}, # pylint: disable=no-member
|
},
|
||||||
headers={
|
headers={
|
||||||
"Authorization": f"Bearer {self.get_key()}",
|
"Authorization": f"Bearer {self.get_key()}",
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ USER: Hello, AI.
|
||||||
|
|
||||||
AI: Hello! How can I assist you today?"""
|
AI: Hello! How can I assist you today?"""
|
||||||
|
|
||||||
async def aask( # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
async def aask(
|
||||||
self,
|
self,
|
||||||
session: Session,
|
session: Session,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -60,12 +60,11 @@ AI: Hello! How can I assist you today?"""
|
||||||
}
|
}
|
||||||
full_prompt = session + [User(query)]
|
full_prompt = session + [User(query)]
|
||||||
del full_prompt[1:-max_query_size]
|
del full_prompt[1:-max_query_size]
|
||||||
new_data = old_data = full_query = (
|
new_data = old_data = full_query = "\n".join(
|
||||||
"\n".join(f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt)
|
f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt
|
||||||
+ f"\n{role_map.get('assistant')}"
|
) + f"\n{role_map.get('assistant')}"
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
async with websockets.connect( # pylint: disable=no-member
|
async with websockets.connect(
|
||||||
f"ws://{self.parameters.server_hostname}:7860/queue/join"
|
f"ws://{self.parameters.server_hostname}:7860/queue/join"
|
||||||
) as websocket:
|
) as websocket:
|
||||||
while content := json.loads(await websocket.recv()):
|
while content := json.loads(await websocket.recv()):
|
||||||
|
|
@ -127,7 +126,7 @@ AI: Hello! How can I assist you today?"""
|
||||||
# stop generation by closing the websocket here
|
# stop generation by closing the websocket here
|
||||||
if content["msg"] == "process_completed":
|
if content["msg"] == "process_completed":
|
||||||
break
|
break
|
||||||
except Exception as e: # pylint: disable=broad-exception-caught
|
except Exception as e:
|
||||||
content = f"\nException: {e!r}"
|
content = f"\nException: {e!r}"
|
||||||
new_data += content
|
new_data += content
|
||||||
yield content
|
yield content
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from typing import Iterable, Protocol
|
from typing import Iterable, Optional, Protocol
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import rich
|
import rich
|
||||||
|
|
@ -40,7 +40,7 @@ class DumbPrinter:
|
||||||
|
|
||||||
|
|
||||||
class WrappingPrinter:
|
class WrappingPrinter:
|
||||||
def __init__(self, width: int | None = None) -> None:
|
def __init__(self, width: Optional[int] = None) -> None:
|
||||||
self._width = width or rich.get_console().width
|
self._width = width or rich.get_console().width
|
||||||
self._column = 0
|
self._column = 0
|
||||||
self._line = ""
|
self._line = ""
|
||||||
|
|
@ -122,4 +122,4 @@ def main(obj: Obj, prompt: str, print_prompt: bool) -> None:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main() # pylint: disable=no-value-for-parameter
|
main()
|
||||||
|
|
|
||||||
|
|
@ -38,4 +38,4 @@ def main(obj: Obj, no_system: bool) -> None:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main() # pylint: disable=no-value-for-parameter
|
main()
|
||||||
|
|
|
||||||
|
|
@ -24,8 +24,8 @@ def list_files_matching_rx(
|
||||||
"*.json"
|
"*.json"
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
session = session_from_file(conversation) # pylint: disable=no-member
|
session = session_from_file(conversation)
|
||||||
except Exception as e: # pylint: disable=broad-exception-caught
|
except Exception as e:
|
||||||
print(f"Failed to read {conversation}: {e}", file=sys.stderr)
|
print(f"Failed to read {conversation}: {e}", file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -67,4 +67,4 @@ def main(
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main() # pylint: disable=no-value-for-parameter
|
main()
|
||||||
|
|
|
||||||
|
|
@ -82,4 +82,4 @@ def main(output_directory: pathlib.Path, files: list[TextIO]) -> None:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main() # pylint: disable=no-value-for-parameter
|
main()
|
||||||
|
|
|
||||||
|
|
@ -45,4 +45,4 @@ def main(obj: Obj, no_system: bool) -> None:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main() # pylint: disable=no-value-for-parameter
|
main()
|
||||||
|
|
|
||||||
|
|
@ -54,3 +54,5 @@ Input {
|
||||||
Markdown {
|
Markdown {
|
||||||
margin: 0 1 0 0;
|
margin: 0 1 0 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SubmittableTextArea { height: 3 }
|
||||||
|
|
|
||||||
|
|
@ -3,30 +3,53 @@
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
from typing import cast
|
from typing import Any, Optional, cast, TYPE_CHECKING
|
||||||
|
|
||||||
|
import click
|
||||||
from markdown_it import MarkdownIt
|
from markdown_it import MarkdownIt
|
||||||
from textual import work
|
from textual import work
|
||||||
|
from textual._ansi_sequences import ANSI_SEQUENCES_KEYS
|
||||||
from textual.app import App, ComposeResult
|
from textual.app import App, ComposeResult
|
||||||
from textual.binding import Binding
|
from textual.binding import Binding
|
||||||
from textual.containers import Container, Horizontal, VerticalScroll
|
from textual.containers import Container, Horizontal, VerticalScroll
|
||||||
from textual.widgets import Button, Footer, Input, LoadingIndicator, Markdown
|
from textual.keys import Keys
|
||||||
|
from textual.widgets import Button, Footer, LoadingIndicator, Markdown, TextArea
|
||||||
|
|
||||||
from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path
|
from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path
|
||||||
from ..session import Assistant, Message, Session, User, new_session, session_to_file
|
from ..session import Assistant, Message, Session, User, new_session, session_to_file
|
||||||
|
|
||||||
|
|
||||||
|
# workaround for pyperclip being un-typed
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
def pyperclip_copy(data: str) -> None:
|
||||||
|
...
|
||||||
|
else:
|
||||||
|
from pyperclip import copy as pyperclip_copy
|
||||||
|
|
||||||
|
|
||||||
|
# Monkeypatch alt+enter as meaning "F9", WFM
|
||||||
|
# ignore typing here because ANSI_SEQUENCES_KEYS is a Mapping[] which is read-only as
|
||||||
|
# far as mypy is concerned.
|
||||||
|
ANSI_SEQUENCES_KEYS["\x1b\r"] = (Keys.F9,) # type: ignore
|
||||||
|
ANSI_SEQUENCES_KEYS["\x1b\n"] = (Keys.F9,) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class SubmittableTextArea(TextArea):
|
||||||
|
BINDINGS = [
|
||||||
|
Binding("f9", "submit", "Submit", show=True),
|
||||||
|
Binding("tab", "focus_next", show=False, priority=True), # no inserting tabs
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def parser_factory() -> MarkdownIt:
|
def parser_factory() -> MarkdownIt:
|
||||||
parser = MarkdownIt()
|
parser = MarkdownIt()
|
||||||
parser.options["html"] = False
|
parser.options["html"] = False
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
class ChapMarkdown(
|
class ChapMarkdown(Markdown, can_focus=True, can_focus_children=False):
|
||||||
Markdown, can_focus=True, can_focus_children=False
|
|
||||||
): # pylint: disable=function-redefined
|
|
||||||
BINDINGS = [
|
BINDINGS = [
|
||||||
Binding("ctrl+y", "yank", "Yank text", show=True),
|
Binding("ctrl+y", "yank", "Yank text", show=True),
|
||||||
Binding("ctrl+r", "resubmit", "resubmit", show=True),
|
Binding("ctrl+r", "resubmit", "resubmit", show=True),
|
||||||
|
|
@ -56,10 +79,10 @@ class Tui(App[None]):
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, api: Backend | None = None, session: Session | None = None
|
self, api: Optional[Backend] = None, session: Optional[Session] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.api = api or get_api("lorem")
|
self.api = api or get_api(click.Context(click.Command("chap tui")), "lorem")
|
||||||
self.session = (
|
self.session = (
|
||||||
new_session(self.api.system_message) if session is None else session
|
new_session(self.api.system_message) if session is None else session
|
||||||
)
|
)
|
||||||
|
|
@ -73,8 +96,8 @@ class Tui(App[None]):
|
||||||
return cast(VerticalScroll, self.query_one("#wait"))
|
return cast(VerticalScroll, self.query_one("#wait"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input(self) -> Input:
|
def input(self) -> SubmittableTextArea:
|
||||||
return self.query_one(Input)
|
return self.query_one(SubmittableTextArea)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cancel_button(self) -> CancelButton:
|
def cancel_button(self) -> CancelButton:
|
||||||
|
|
@ -94,7 +117,9 @@ class Tui(App[None]):
|
||||||
Container(id="pad"),
|
Container(id="pad"),
|
||||||
id="content",
|
id="content",
|
||||||
)
|
)
|
||||||
yield Input(placeholder="Prompt")
|
s = SubmittableTextArea(language="markdown")
|
||||||
|
s.show_line_numbers = False
|
||||||
|
yield s
|
||||||
with Horizontal(id="wait"):
|
with Horizontal(id="wait"):
|
||||||
yield LoadingIndicator()
|
yield LoadingIndicator()
|
||||||
yield CancelButton(label="❌ Stop Generation", id="cancel", disabled=True)
|
yield CancelButton(label="❌ Stop Generation", id="cancel", disabled=True)
|
||||||
|
|
@ -103,8 +128,8 @@ class Tui(App[None]):
|
||||||
self.container.scroll_end(animate=False)
|
self.container.scroll_end(animate=False)
|
||||||
self.input.focus()
|
self.input.focus()
|
||||||
|
|
||||||
async def on_input_submitted(self, event: Input.Submitted) -> None:
|
async def action_submit(self) -> None:
|
||||||
self.get_completion(event.value)
|
self.get_completion(self.input.text)
|
||||||
|
|
||||||
@work(exclusive=True)
|
@work(exclusive=True)
|
||||||
async def get_completion(self, query: str) -> None:
|
async def get_completion(self, query: str) -> None:
|
||||||
|
|
@ -162,10 +187,10 @@ class Tui(App[None]):
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(render_fun(), get_token_fun())
|
await asyncio.gather(render_fun(), get_token_fun())
|
||||||
finally:
|
finally:
|
||||||
self.input.value = ""
|
self.input.clear()
|
||||||
all_output = self.session[-1].content
|
all_output = self.session[-1].content
|
||||||
output.update(all_output)
|
output.update(all_output)
|
||||||
output._markdown = all_output # pylint: disable=protected-access
|
output._markdown = all_output
|
||||||
self.container.scroll_end()
|
self.container.scroll_end()
|
||||||
|
|
||||||
for markdown in self.container.children:
|
for markdown in self.container.children:
|
||||||
|
|
@ -183,8 +208,8 @@ class Tui(App[None]):
|
||||||
def action_yank(self) -> None:
|
def action_yank(self) -> None:
|
||||||
widget = self.focused
|
widget = self.focused
|
||||||
if isinstance(widget, ChapMarkdown):
|
if isinstance(widget, ChapMarkdown):
|
||||||
content = widget._markdown or "" # pylint: disable=protected-access
|
content = widget._markdown or ""
|
||||||
subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)
|
pyperclip_copy(content)
|
||||||
|
|
||||||
def action_toggle_history(self) -> None:
|
def action_toggle_history(self) -> None:
|
||||||
widget = self.focused
|
widget = self.focused
|
||||||
|
|
@ -204,9 +229,7 @@ class Tui(App[None]):
|
||||||
async def action_stop_generating(self) -> None:
|
async def action_stop_generating(self) -> None:
|
||||||
self.workers.cancel_all()
|
self.workers.cancel_all()
|
||||||
|
|
||||||
async def on_button_pressed( # pylint: disable=unused-argument
|
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
self, event: Button.Pressed
|
|
||||||
) -> None:
|
|
||||||
self.workers.cancel_all()
|
self.workers.cancel_all()
|
||||||
|
|
||||||
async def action_quit(self) -> None:
|
async def action_quit(self) -> None:
|
||||||
|
|
@ -235,19 +258,31 @@ class Tui(App[None]):
|
||||||
session_to_file(self.session, new_session_path())
|
session_to_file(self.session, new_session_path())
|
||||||
|
|
||||||
query = self.session[idx].content
|
query = self.session[idx].content
|
||||||
self.input.value = query
|
self.input.load_text(query)
|
||||||
|
|
||||||
del self.session[idx:]
|
del self.session[idx:]
|
||||||
for child in self.container.children[idx:-1]:
|
for child in self.container.children[idx:-1]:
|
||||||
await child.remove()
|
await child.remove()
|
||||||
|
|
||||||
self.input.focus()
|
self.input.focus()
|
||||||
|
self.on_text_area_changed()
|
||||||
if resubmit:
|
if resubmit:
|
||||||
await self.input.action_submit()
|
await self.action_submit()
|
||||||
|
|
||||||
|
def on_text_area_changed(self, event: Any = None) -> None:
|
||||||
|
height = self.input.document.get_size(self.input.indent_width)[1]
|
||||||
|
max_height = max(3, self.size.height - 6)
|
||||||
|
if height >= max_height:
|
||||||
|
self.input.styles.height = max_height
|
||||||
|
elif height <= 3:
|
||||||
|
self.input.styles.height = 3
|
||||||
|
else:
|
||||||
|
self.input.styles.height = height
|
||||||
|
|
||||||
|
|
||||||
@command_uses_new_session
|
@command_uses_new_session
|
||||||
def main(obj: Obj) -> None:
|
@click.option("--replace-system-prompt/--no-replace-system-prompt", default=False)
|
||||||
|
def main(obj: Obj, replace_system_prompt: bool) -> None:
|
||||||
"""Start interactive terminal user interface session"""
|
"""Start interactive terminal user interface session"""
|
||||||
api = obj.api
|
api = obj.api
|
||||||
assert api is not None
|
assert api is not None
|
||||||
|
|
@ -256,6 +291,9 @@ def main(obj: Obj) -> None:
|
||||||
session_filename = obj.session_filename
|
session_filename = obj.session_filename
|
||||||
assert session_filename is not None
|
assert session_filename is not None
|
||||||
|
|
||||||
|
if replace_system_prompt:
|
||||||
|
session[0].content = obj.system_message or api.system_message
|
||||||
|
|
||||||
tui = Tui(api, session)
|
tui = Tui(api, session)
|
||||||
tui.run()
|
tui.run()
|
||||||
|
|
||||||
|
|
@ -268,4 +306,4 @@ def main(obj: Obj) -> None:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main() # pylint: disable=no-value-for-parameter
|
main()
|
||||||
|
|
|
||||||
132
src/chap/core.py
132
src/chap/core.py
|
|
@ -1,7 +1,7 @@
|
||||||
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
||||||
#
|
#
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
# pylint: disable=import-outside-toplevel
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
|
|
@ -10,23 +10,38 @@ import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from dataclasses import MISSING, dataclass, fields
|
from dataclasses import MISSING, Field, dataclass, fields
|
||||||
from types import UnionType
|
from typing import (
|
||||||
from typing import Any, AsyncGenerator, Callable, cast
|
Any,
|
||||||
|
AsyncGenerator,
|
||||||
|
Callable,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
get_origin,
|
||||||
|
get_args,
|
||||||
|
)
|
||||||
|
import sys
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import platformdirs
|
import platformdirs
|
||||||
from simple_parsing.docstring import get_attribute_docstring
|
from simple_parsing.docstring import get_attribute_docstring
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
from . import backends, commands # pylint: disable=no-name-in-module
|
from . import backends, commands
|
||||||
from .session import Message, Session, System, session_from_file
|
from .session import Message, Session, System, session_from_file
|
||||||
|
|
||||||
|
UnionType: type
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
from types import UnionType
|
||||||
|
else:
|
||||||
|
UnionType = type(Union[int, float])
|
||||||
|
|
||||||
conversations_path = platformdirs.user_state_path("chap") / "conversations"
|
conversations_path = platformdirs.user_state_path("chap") / "conversations"
|
||||||
conversations_path.mkdir(parents=True, exist_ok=True)
|
conversations_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
class ABackend(Protocol): # pylint: disable=too-few-public-methods
|
class ABackend(Protocol):
|
||||||
def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
|
def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
|
||||||
"""Make a query, updating the session with the query and response, returning the query token by token"""
|
"""Make a query, updating the session with the query and response, returning the query token by token"""
|
||||||
|
|
||||||
|
|
@ -39,7 +54,7 @@ class Backend(ABackend, Protocol):
|
||||||
"""Make a query, updating the session with the query and response, returning the query"""
|
"""Make a query, updating the session with the query and response, returning the query"""
|
||||||
|
|
||||||
|
|
||||||
class AutoAskMixin: # pylint: disable=too-few-public-methods
|
class AutoAskMixin:
|
||||||
"""Mixin class for backends implementing aask"""
|
"""Mixin class for backends implementing aask"""
|
||||||
|
|
||||||
def ask(self, session: Session, query: str) -> str:
|
def ask(self, session: Session, query: str) -> str:
|
||||||
|
|
@ -54,20 +69,50 @@ class AutoAskMixin: # pylint: disable=too-few-public-methods
|
||||||
return "".join(tokens)
|
return "".join(tokens)
|
||||||
|
|
||||||
|
|
||||||
def last_session_path() -> pathlib.Path | None:
|
def last_session_path() -> Optional[pathlib.Path]:
|
||||||
result = max(
|
result = max(
|
||||||
conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None
|
conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def new_session_path(opt_path: pathlib.Path | None = None) -> pathlib.Path:
|
def new_session_path(opt_path: Optional[pathlib.Path] = None) -> pathlib.Path:
|
||||||
return opt_path or conversations_path / (
|
return opt_path or conversations_path / (
|
||||||
datetime.datetime.now().isoformat().replace(":", "_") + ".json"
|
datetime.datetime.now().isoformat().replace(":", "_") + ".json"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def configure_api_from_environment(api_name: str, api: Backend) -> None:
|
def get_field_type(field: Field[Any]) -> Any:
|
||||||
|
field_type = field.type
|
||||||
|
if isinstance(field_type, str):
|
||||||
|
raise RuntimeError(
|
||||||
|
"parameters dataclass may not use 'from __future__ import annotations"
|
||||||
|
)
|
||||||
|
origin = get_origin(field_type)
|
||||||
|
if origin in (Union, UnionType):
|
||||||
|
for arg in get_args(field_type):
|
||||||
|
if arg is not None:
|
||||||
|
return arg
|
||||||
|
return field_type
|
||||||
|
|
||||||
|
|
||||||
|
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
|
||||||
|
field_type = get_field_type(field)
|
||||||
|
try:
|
||||||
|
if field_type is bool:
|
||||||
|
tv = click.types.BoolParamType().convert(value, None, ctx)
|
||||||
|
else:
|
||||||
|
tv = field_type(value)
|
||||||
|
return tv
|
||||||
|
except ValueError as e:
|
||||||
|
raise click.BadParameter(
|
||||||
|
f"Invalid value for {field.name} with value {value}: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def configure_api_from_environment(
|
||||||
|
ctx: click.Context, api_name: str, api: Backend
|
||||||
|
) -> None:
|
||||||
if not hasattr(api, "parameters"):
|
if not hasattr(api, "parameters"):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -76,26 +121,21 @@ def configure_api_from_environment(api_name: str, api: Backend) -> None:
|
||||||
value = os.environ.get(envvar)
|
value = os.environ.get(envvar)
|
||||||
if value is None:
|
if value is None:
|
||||||
continue
|
continue
|
||||||
try:
|
tv = convert_str_to_field(ctx, field, value)
|
||||||
tv = field.type(value)
|
|
||||||
except ValueError as e:
|
|
||||||
raise click.BadParameter(
|
|
||||||
f"Invalid value for {field.name} with value {value}: {e}"
|
|
||||||
) from e
|
|
||||||
setattr(api.parameters, field.name, tv)
|
setattr(api.parameters, field.name, tv)
|
||||||
|
|
||||||
|
|
||||||
def get_api(name: str = "openai_chatgpt") -> Backend:
|
def get_api(ctx: click.Context, name: str = "openai_chatgpt") -> Backend:
|
||||||
name = name.replace("-", "_")
|
name = name.replace("-", "_")
|
||||||
result = cast(
|
backend = cast(
|
||||||
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
|
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
|
||||||
)
|
)
|
||||||
configure_api_from_environment(name, result)
|
configure_api_from_environment(ctx, name, backend)
|
||||||
return result
|
return backend
|
||||||
|
|
||||||
|
|
||||||
def do_session_continue(
|
def do_session_continue(
|
||||||
ctx: click.Context, param: click.Parameter, value: pathlib.Path | None
|
ctx: click.Context, param: click.Parameter, value: Optional[pathlib.Path]
|
||||||
) -> None:
|
) -> None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return
|
return
|
||||||
|
|
@ -108,9 +148,7 @@ def do_session_continue(
|
||||||
ctx.obj.session_filename = value
|
ctx.obj.session_filename = value
|
||||||
|
|
||||||
|
|
||||||
def do_session_last(
|
def do_session_last(ctx: click.Context, param: click.Parameter, value: bool) -> None:
|
||||||
ctx: click.Context, param: click.Parameter, value: bool
|
|
||||||
) -> None: # pylint: disable=unused-argument
|
|
||||||
if not value:
|
if not value:
|
||||||
return
|
return
|
||||||
do_session_continue(ctx, param, last_session_path())
|
do_session_continue(ctx, param, last_session_path())
|
||||||
|
|
@ -138,18 +176,14 @@ def colonstr(arg: str) -> tuple[str, str]:
|
||||||
return cast(tuple[str, str], tuple(arg.split(":", 1)))
|
return cast(tuple[str, str], tuple(arg.split(":", 1)))
|
||||||
|
|
||||||
|
|
||||||
def set_system_message( # pylint: disable=unused-argument
|
def set_system_message(ctx: click.Context, param: click.Parameter, value: str) -> None:
|
||||||
ctx: click.Context, param: click.Parameter, value: str
|
|
||||||
) -> None:
|
|
||||||
if value and value.startswith("@"):
|
if value and value.startswith("@"):
|
||||||
with open(value[1:], "r", encoding="utf-8") as f:
|
with open(value[1:], "r", encoding="utf-8") as f:
|
||||||
value = f.read().rstrip()
|
value = f.read().rstrip()
|
||||||
ctx.obj.system_message = value
|
ctx.obj.system_message = value
|
||||||
|
|
||||||
|
|
||||||
def set_backend( # pylint: disable=unused-argument
|
def set_backend(ctx: click.Context, param: click.Parameter, value: str) -> None:
|
||||||
ctx: click.Context, param: click.Parameter, value: str
|
|
||||||
) -> None:
|
|
||||||
if value == "list":
|
if value == "list":
|
||||||
formatter = ctx.make_formatter()
|
formatter = ctx.make_formatter()
|
||||||
format_backend_list(formatter)
|
format_backend_list(formatter)
|
||||||
|
|
@ -157,7 +191,7 @@ def set_backend( # pylint: disable=unused-argument
|
||||||
ctx.exit()
|
ctx.exit()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ctx.obj.api = get_api(value)
|
ctx.obj.api = get_api(ctx, value)
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
raise click.BadParameter(str(e))
|
raise click.BadParameter(str(e))
|
||||||
|
|
||||||
|
|
@ -172,15 +206,13 @@ def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None:
|
||||||
if doc:
|
if doc:
|
||||||
doc += " "
|
doc += " "
|
||||||
doc += f"(Default: {default!r})"
|
doc += f"(Default: {default!r})"
|
||||||
f_type = f.type
|
f_type = get_field_type(f)
|
||||||
if isinstance(f_type, UnionType):
|
|
||||||
f_type = f_type.__args__[0]
|
|
||||||
typename = f_type.__name__
|
typename = f_type.__name__
|
||||||
rows.append((f"-B {name}:{typename.upper()}", doc))
|
rows.append((f"-B {name}:{typename.upper()}", doc))
|
||||||
formatter.write_dl(rows)
|
formatter.write_dl(rows)
|
||||||
|
|
||||||
|
|
||||||
def set_backend_option( # pylint: disable=unused-argument
|
def set_backend_option(
|
||||||
ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]]
|
ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]]
|
||||||
) -> None:
|
) -> None:
|
||||||
api = ctx.obj.api
|
api = ctx.obj.api
|
||||||
|
|
@ -195,16 +227,8 @@ def set_backend_option( # pylint: disable=unused-argument
|
||||||
field = all_fields.get(name)
|
field = all_fields.get(name)
|
||||||
if field is None:
|
if field is None:
|
||||||
raise click.BadParameter(f"Invalid parameter {name}")
|
raise click.BadParameter(f"Invalid parameter {name}")
|
||||||
f_type = field.type
|
tv = convert_str_to_field(ctx, field, value)
|
||||||
if isinstance(f_type, UnionType):
|
setattr(api.parameters, field.name, tv)
|
||||||
f_type = f_type.__args__[0]
|
|
||||||
try:
|
|
||||||
tv = f_type(value)
|
|
||||||
except ValueError as e:
|
|
||||||
raise click.BadParameter(
|
|
||||||
f"Invalid value for {name} with value {value}: {e}"
|
|
||||||
) from e
|
|
||||||
setattr(api.parameters, name, tv)
|
|
||||||
|
|
||||||
for kv in opts:
|
for kv in opts:
|
||||||
set_one_backend_option(kv)
|
set_one_backend_option(kv)
|
||||||
|
|
@ -264,9 +288,7 @@ def command_uses_new_session(f_in: click.decorators.FC) -> click.Command:
|
||||||
return click.command()(f)
|
return click.command()(f)
|
||||||
|
|
||||||
|
|
||||||
def version_callback( # pylint: disable=unused-argument
|
def version_callback(ctx: click.Context, param: click.Parameter, value: None) -> None:
|
||||||
ctx: click.Context, param: click.Parameter, value: None
|
|
||||||
) -> None:
|
|
||||||
if not value or ctx.resilient_parsing:
|
if not value or ctx.resilient_parsing:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -293,18 +315,18 @@ def version_callback( # pylint: disable=unused-argument
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Obj:
|
class Obj:
|
||||||
api: Backend | None = None
|
api: Optional[Backend] = None
|
||||||
system_message: str | None = None
|
system_message: Optional[str] = None
|
||||||
session: list[Message] | None = None
|
session: Optional[list[Message]] = None
|
||||||
session_filename: pathlib.Path | None = None
|
session_filename: Optional[pathlib.Path] = None
|
||||||
|
|
||||||
|
|
||||||
class MyCLI(click.MultiCommand):
|
class MyCLI(click.MultiCommand):
|
||||||
def make_context(
|
def make_context(
|
||||||
self,
|
self,
|
||||||
info_name: str | None,
|
info_name: Optional[str],
|
||||||
args: list[str],
|
args: list[str],
|
||||||
parent: click.Context | None = None,
|
parent: Optional[click.Context] = None,
|
||||||
**extra: Any,
|
**extra: Any,
|
||||||
) -> click.Context:
|
) -> click.Context:
|
||||||
result = super().make_context(info_name, args, parent, obj=Obj(), **extra)
|
result = super().make_context(info_name, args, parent, obj=Obj(), **extra)
|
||||||
|
|
@ -332,7 +354,7 @@ class MyCLI(click.MultiCommand):
|
||||||
self, ctx: click.Context, formatter: click.HelpFormatter
|
self, ctx: click.Context, formatter: click.HelpFormatter
|
||||||
) -> None:
|
) -> None:
|
||||||
super().format_options(ctx, formatter)
|
super().format_options(ctx, formatter)
|
||||||
api = ctx.obj.api or get_api()
|
api = ctx.obj.api or get_api(ctx)
|
||||||
if hasattr(api, "parameters"):
|
if hasattr(api, "parameters"):
|
||||||
format_backend_help(api, formatter)
|
format_backend_help(api, formatter)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,13 +7,13 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import pathlib
|
import pathlib
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from typing import cast
|
from typing import Union, cast
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
|
||||||
# not an enum.Enum because these objects are not json-serializable, sigh
|
# not an enum.Enum because these objects are not json-serializable, sigh
|
||||||
class Role: # pylint: disable=too-few-public-methods
|
class Role:
|
||||||
ASSISTANT = "assistant"
|
ASSISTANT = "assistant"
|
||||||
SYSTEM = "system"
|
SYSTEM = "system"
|
||||||
USER = "user"
|
USER = "user"
|
||||||
|
|
@ -65,11 +65,11 @@ def session_from_json(data: str) -> Session:
|
||||||
return [Message(**mapping) for mapping in j]
|
return [Message(**mapping) for mapping in j]
|
||||||
|
|
||||||
|
|
||||||
def session_from_file(path: pathlib.Path | str) -> Session:
|
def session_from_file(path: Union[pathlib.Path, str]) -> Session:
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
return session_from_json(f.read())
|
return session_from_json(f.read())
|
||||||
|
|
||||||
|
|
||||||
def session_to_file(session: Session, path: pathlib.Path | str) -> None:
|
def session_to_file(session: Session, path: Union[pathlib.Path, str]) -> None:
|
||||||
with open(path, "w", encoding="utf-8") as f:
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
f.write(session_to_json(session))
|
f.write(session_to_json(session))
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue