Compare commits

..

No commits in common. "main" and "document-mypy" have entirely different histories.

28 changed files with 383 additions and 740 deletions

48
.github/workflows/codeql.yml vendored Normal file
View file

@ -0,0 +1,48 @@
# SPDX-FileCopyrightText: 2022 Jeff Epler
#
# SPDX-License-Identifier: CC0-1.0
name: "CodeQL"
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
schedule:
- cron: "53 3 * * 5"
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ python ]
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Install Dependencies (python)
run: pip3 install -r requirements-dev.txt
- name: Initialize CodeQL
uses: github/codeql-action/init@v2
with:
languages: ${{ matrix.language }}
queries: +security-and-quality
- name: Autobuild
uses: github/codeql-action/autobuild@v2
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v2
with:
category: "/language:${{ matrix.language }}"

View file

@ -18,10 +18,10 @@ jobs:
GITHUB_CONTEXT: ${{ toJson(github) }} GITHUB_CONTEXT: ${{ toJson(github) }}
run: echo "$GITHUB_CONTEXT" run: echo "$GITHUB_CONTEXT"
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 uses: actions/setup-python@v4
with: with:
python-version: 3.11 python-version: 3.11

View file

@ -17,10 +17,10 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- name: pre-commit - name: pre-commit
uses: pre-commit/action@v3.0.1 uses: pre-commit/action@v3.0.0
- name: Make patch - name: Make patch
if: failure() if: failure()
@ -28,7 +28,7 @@ jobs:
- name: Upload patch - name: Upload patch
if: failure() if: failure()
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v3
with: with:
name: patch name: patch
path: ~/pre-commit.patch path: ~/pre-commit.patch
@ -36,10 +36,10 @@ jobs:
test-release: test-release:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 uses: actions/setup-python@v4
with: with:
python-version: 3.11 python-version: 3.11
@ -53,7 +53,7 @@ jobs:
run: python -mbuild run: python -mbuild
- name: Upload artifacts - name: Upload artifacts
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v3
with: with:
name: dist name: dist
path: dist/* path: dist/*

1
.gitignore vendored
View file

@ -8,4 +8,3 @@ __pycache__
/dist /dist
/src/chap/__version__.py /src/chap/__version__.py
/venv /venv
/keys.log

View file

@ -6,8 +6,12 @@ default_language_version:
python: python3 python: python3
repos: repos:
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0 rev: v4.4.0
hooks: hooks:
- id: check-yaml - id: check-yaml
- id: end-of-file-fixer - id: end-of-file-fixer
@ -15,20 +19,23 @@ repos:
- id: trailing-whitespace - id: trailing-whitespace
exclude: tests exclude: tests
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell
rev: v2.2.6 rev: v2.2.4
hooks: hooks:
- id: codespell - id: codespell
args: [-w] args: [-w]
- repo: https://github.com/fsfe/reuse-tool - repo: https://github.com/fsfe/reuse-tool
rev: v2.1.0 rev: v1.1.2
hooks: hooks:
- id: reuse - id: reuse
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/pycqa/isort
# Ruff version. rev: 5.12.0
rev: v0.1.6
hooks: hooks:
# Run the linter. - id: isort
- id: ruff name: isort (python)
args: [ --fix, --preview ] args: ['--profile', 'black']
# Run the formatter. - repo: https://github.com/pycqa/pylint
- id: ruff-format rev: v2.17.0
hooks:
- id: pylint
additional_dependencies: [click,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
args: ['--source-roots', 'src']

12
.pylintrc Normal file
View file

@ -0,0 +1,12 @@
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
#
# SPDX-License-Identifier: MIT
[MESSAGES CONTROL]
disable=
duplicate-code,
invalid-name,
line-too-long,
missing-class-docstring,
missing-function-docstring,
missing-module-docstring,

121
LICENSES/CC0-1.0.txt Normal file
View file

@ -0,0 +1,121 @@
Creative Commons Legal Code
CC0 1.0 Universal
CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
HEREUNDER.
Statement of Purpose
The laws of most jurisdictions throughout the world automatically confer
exclusive Copyright and Related Rights (defined below) upon the creator
and subsequent owner(s) (each and all, an "owner") of an original work of
authorship and/or a database (each, a "Work").
Certain owners wish to permanently relinquish those rights to a Work for
the purpose of contributing to a commons of creative, cultural and
scientific works ("Commons") that the public can reliably and without fear
of later claims of infringement build upon, modify, incorporate in other
works, reuse and redistribute as freely as possible in any form whatsoever
and for any purposes, including without limitation commercial purposes.
These owners may contribute to the Commons to promote the ideal of a free
culture and the further production of creative, cultural and scientific
works, or to gain reputation or greater distribution for their Work in
part through the use and efforts of others.
For these and/or other purposes and motivations, and without any
expectation of additional consideration or compensation, the person
associating CC0 with a Work (the "Affirmer"), to the extent that he or she
is an owner of Copyright and Related Rights in the Work, voluntarily
elects to apply CC0 to the Work and publicly distribute the Work under its
terms, with knowledge of his or her Copyright and Related Rights in the
Work and the meaning and intended legal effect of CC0 on those rights.
1. Copyright and Related Rights. A Work made available under CC0 may be
protected by copyright and related or neighboring rights ("Copyright and
Related Rights"). Copyright and Related Rights include, but are not
limited to, the following:
i. the right to reproduce, adapt, distribute, perform, display,
communicate, and translate a Work;
ii. moral rights retained by the original author(s) and/or performer(s);
iii. publicity and privacy rights pertaining to a person's image or
likeness depicted in a Work;
iv. rights protecting against unfair competition in regards to a Work,
subject to the limitations in paragraph 4(a), below;
v. rights protecting the extraction, dissemination, use and reuse of data
in a Work;
vi. database rights (such as those arising under Directive 96/9/EC of the
European Parliament and of the Council of 11 March 1996 on the legal
protection of databases, and under any national implementation
thereof, including any amended or successor version of such
directive); and
vii. other similar, equivalent or corresponding rights throughout the
world based on applicable law or treaty, and any national
implementations thereof.
2. Waiver. To the greatest extent permitted by, but not in contravention
of, applicable law, Affirmer hereby overtly, fully, permanently,
irrevocably and unconditionally waives, abandons, and surrenders all of
Affirmer's Copyright and Related Rights and associated claims and causes
of action, whether now known or unknown (including existing as well as
future claims and causes of action), in the Work (i) in all territories
worldwide, (ii) for the maximum duration provided by applicable law or
treaty (including future time extensions), (iii) in any current or future
medium and for any number of copies, and (iv) for any purpose whatsoever,
including without limitation commercial, advertising or promotional
purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
member of the public at large and to the detriment of Affirmer's heirs and
successors, fully intending that such Waiver shall not be subject to
revocation, rescission, cancellation, termination, or any other legal or
equitable action to disrupt the quiet enjoyment of the Work by the public
as contemplated by Affirmer's express Statement of Purpose.
3. Public License Fallback. Should any part of the Waiver for any reason
be judged legally invalid or ineffective under applicable law, then the
Waiver shall be preserved to the maximum extent permitted taking into
account Affirmer's express Statement of Purpose. In addition, to the
extent the Waiver is so judged Affirmer hereby grants to each affected
person a royalty-free, non transferable, non sublicensable, non exclusive,
irrevocable and unconditional license to exercise Affirmer's Copyright and
Related Rights in the Work (i) in all territories worldwide, (ii) for the
maximum duration provided by applicable law or treaty (including future
time extensions), (iii) in any current or future medium and for any number
of copies, and (iv) for any purpose whatsoever, including without
limitation commercial, advertising or promotional purposes (the
"License"). The License shall be deemed effective as of the date CC0 was
applied by Affirmer to the Work. Should any part of the License for any
reason be judged legally invalid or ineffective under applicable law, such
partial invalidity or ineffectiveness shall not invalidate the remainder
of the License, and in such case Affirmer hereby affirms that he or she
will not (i) exercise any of his or her remaining Copyright and Related
Rights in the Work or (ii) assert any associated claims and causes of
action with respect to the Work, in either case contrary to Affirmer's
express Statement of Purpose.
4. Limitations and Disclaimers.
a. No trademark or patent rights held by Affirmer are waived, abandoned,
surrendered, licensed or otherwise affected by this document.
b. Affirmer offers the Work as-is and makes no representations or
warranties of any kind concerning the Work, express, implied,
statutory or otherwise, including without limitation warranties of
title, merchantability, fitness for a particular purpose, non
infringement, or the absence of latent or other defects, accuracy, or
the present or absence of errors, whether or not discoverable, all to
the greatest extent permissible under applicable law.
c. Affirmer disclaims responsibility for clearing rights of other persons
that may apply to the Work or any use thereof, including without
limitation any person's Copyright and Related Rights in the Work.
Further, Affirmer disclaims responsibility for obtaining any necessary
consents, permissions or other rights required for any use of the
Work.
d. Affirmer understands and acknowledges that Creative Commons is not a
party to this document and has no duty or obligation with respect to
this CC0 or use of the Work.

View file

@ -9,11 +9,11 @@ SPDX-License-Identifier: MIT
# chap - A Python interface to chatgpt and other LLMs, including a terminal user interface (tui) # chap - A Python interface to chatgpt and other LLMs, including a terminal user interface (tui)
![Chap screencast](https://raw.githubusercontent.com/jepler/chap/main/chap.gif) ![Chap screencast](https://github.com/jepler/chap/blob/main/chap.gif)
## System requirements ## System requirements
Chap is primarily developed on Linux with Python 3.11. Moderate effort will be made to support versions back to Python 3.9 (Debian oldstable). Chap is developed on Linux with Python 3.11. Due to use of the `X | Y` style of type hints, it is known to not work on Python 3.9 and older. The target minimum Python version is 3.11 (debian stable).
## Installation ## Installation
@ -81,54 +81,8 @@ Put your OpenAI API key in the platform configuration directory for chap, e.g.,
* `chap grep needle` * `chap grep needle`
## `@FILE` arguments
It's useful to set a bunch of related arguments together, for instance to fully
configure a back-end. This functionality is implemented via `@FILE` arguments.
Before any other command-line argument parsing is performed, `@FILE` arguments are expanded:
* An `@FILE` argument is searched relative to the current directory
* An `@:FILE` argument is searched relative to the configuration directory (e.g., $HOME/.config/chap/presets)
* If an argument starts with a literal `@`, double it: `@@`
* `@.` stops processing any further `@FILE` arguments and leaves them unchanged.
The contents of an `@FILE` are parsed according to `shlex.split(comments=True)`.
Comments are supported.
A typical content might look like this:
```
# cfg/gpt-4o: Use more expensive gpt 4o and custom prompt
--backend openai-chatgpt
-B model:gpt-4o
-s :my-custom-system-message.txt
```
and you might use it with
```
chap @:cfg/gpt-4o ask what version of gpt is this
```
## Interactive terminal usage ## Interactive terminal usage
The interactive terminal mode is accessed via `chap tui`. * `chap tui`
There are a variety of keyboard shortcuts to be aware of:
* tab/shift-tab to move between the entry field and the conversation, or between conversation items
* While in the text box, F9 or (if supported by your terminal) alt+enter to submit multiline text
* while on a conversation item:
* ctrl+x to re-draft the message. This
* saves a copy of the session in an auto-named file in the conversations folder
* removes the conversation from this message to the end
* puts the user's message in the text box to edit
* ctrl+x to re-submit the message. This
* saves a copy of the session in an auto-named file in the conversations folder
* removes the conversation from this message to the end
* puts the user's message in the text box
* and submits it immediately
* ctrl+y to yank the message. This places the response part of the current
interaction in the operating system clipboard to be pasted (e..g, with
ctrl+v or command+v in other software)
* ctrl+q to toggle whether this message may be included in the contextual history for a future query.
The exact way history is submitted is determined by the back-end, often by
counting messages or tokens, but the ctrl+q toggle ensures this message (both the user
and assistant message parts) are not considered.
## Sessions & Command-line Parameters ## Sessions & Command-line Parameters
@ -144,26 +98,21 @@ an existing session with `-s`. Or, you can continue the last session with
You can set the "system message" with the `-S` flag. You can set the "system message" with the `-S` flag.
You can select the text generating backend with the `-b` flag: You can select the text generating backend with the `-b` flag:
* openai-chatgpt: the default, paid API, best quality results. Also works with compatible API implementations including llama-cpp when the correct backend URL is specified. * openai-chatgpt: the default, paid API, best quality results
* llama-cpp: Works with [llama.cpp's http server](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md) and can run locally with various models, * llama-cpp: Works with [llama.cpp's http server](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md) and can run locally with various models,
though it is [optimized for models that use the llama2-style prompting](https://huggingface.co/blog/llama2#how-to-prompt-llama-2). Consider using llama.cpp's OpenAI compatible API with the openai-chatgpt backend instead, in which case the server can apply the chat template. though it is [optimized for models that use the llama2-style prompting](https://huggingface.co/blog/llama2#how-to-prompt-llama-2).
Set the server URL with `-B url:...`.
* textgen: Works with https://github.com/oobabooga/text-generation-webui and can run locally with various models. * textgen: Works with https://github.com/oobabooga/text-generation-webui and can run locally with various models.
Needs the server URL in *$configuration_directory/textgen\_url*. Needs the server URL in *$configuration_directory/textgen\_url*.
* mistral: Works with the [mistral paid API](https://docs.mistral.ai/).
* anthropic: Works with the [anthropic paid API](https://docs.anthropic.com/en/home).
* huggingface: Works with the [huggingface API](https://huggingface.co/docs/api-inference/index), which includes a free tier.
* lorem: local non-AI lorem generator for testing * lorem: local non-AI lorem generator for testing
Backends have settings such as URLs and where API keys are stored. use `chap --backend
<BACKEND> --help` to list settings for a particular backend.
## Environment variables ## Environment variables
The backend can be set with the `CHAP_BACKEND` environment variable. The backend can be set with the `CHAP_BACKEND` environment variable.
Backend settings can be set with `CHAP_<backend_name>_<parameter_name>`, with `backend_name` and `parameter_name` all in caps. Backend settings can be set with `CHAP_<backend_name>_<parameter_name>`, with `backend_name` and `parameter_name` all in caps.
For instance, `CHAP_LLAMA_CPP_URL=http://server.local:8080/completion` changes the default server URL for the llama-cpp backend. For instance, `CHAP_LLAMA_CPP_URL=http://server.local:8080/completion` changes the default server URL for the llama-cpp back-end.
## Importing from ChatGPT ## Importing from ChatGPT

View file

@ -14,6 +14,7 @@ src_path = project_root / "src"
sys.path.insert(0, str(src_path)) sys.path.insert(0, str(src_path))
if __name__ == "__main__": if __name__ == "__main__":
# pylint: disable=import-error,no-name-in-module
from chap.core import main from chap.core import main
main() main()

View file

@ -20,19 +20,15 @@ name="chap"
authors = [{name = "Jeff Epler", email = "jepler@gmail.com"}] authors = [{name = "Jeff Epler", email = "jepler@gmail.com"}]
description = "Interact with the OpenAI ChatGPT API (and other text generators)" description = "Interact with the OpenAI ChatGPT API (and other text generators)"
dynamic = ["readme","version","dependencies"] dynamic = ["readme","version","dependencies"]
requires-python = ">=3.9"
keywords = ["llm", "tui", "chatgpt"]
classifiers = [ classifiers = [
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy", "Programming Language :: Python :: Implementation :: PyPy",
"Programming Language :: Python :: Implementation :: CPython",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
] ]
[project.urls] [project.urls]
homepage = "https://github.com/jepler/chap" homepage = "https://github.com/jepler/chap"

View file

@ -6,8 +6,7 @@ click
httpx httpx
lorem-text lorem-text
platformdirs platformdirs
pyperclip
simple_parsing simple_parsing
textual[syntax] >= 4 textual>=0.18.0
tiktoken tiktoken
websockets websockets

View file

@ -1,95 +0,0 @@
# SPDX-FileCopyrightText: 2024 Jeff Epler <jepler@gmail.com>
#
# SPDX-License-Identifier: MIT
import json
from dataclasses import dataclass
from typing import AsyncGenerator, Any
import httpx
from ..core import AutoAskMixin, Backend
from ..key import UsesKeyMixin
from ..session import Assistant, Role, Session, User
class Anthropic(AutoAskMixin, UsesKeyMixin):
@dataclass
class Parameters:
url: str = "https://api.anthropic.com"
model: str = "claude-3-5-sonnet-20240620"
max_new_tokens: int = 1000
api_key_name = "anthropic_api_key"
def __init__(self) -> None:
super().__init__()
self.parameters = self.Parameters()
system_message = """\
Answer each question accurately and thoroughly.
"""
def make_full_query(self, messages: Session, max_query_size: int) -> dict[str, Any]:
system = [m.content for m in messages if m.role == Role.SYSTEM]
messages = [m for m in messages if m.role != Role.SYSTEM and m.content]
del messages[:-max_query_size]
result = dict(
model=self.parameters.model,
max_tokens=self.parameters.max_new_tokens,
messages=[dict(role=str(m.role), content=m.content) for m in messages],
stream=True,
)
if system and system[0]:
result["system"] = system[0]
return result
async def aask(
self,
session: Session,
query: str,
*,
max_query_size: int = 5,
timeout: float = 180,
) -> AsyncGenerator[str, None]:
new_content: list[str] = []
params = self.make_full_query(session + [User(query)], max_query_size)
try:
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream(
"POST",
f"{self.parameters.url}/v1/messages",
json=params,
headers={
"x-api-key": self.get_key(),
"content-type": "application/json",
"anthropic-version": "2023-06-01",
"anthropic-beta": "messages-2023-12-15",
},
) as response:
if response.status_code == 200:
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line.removeprefix("data:").strip()
j = json.loads(data)
content = j.get("delta", {}).get("text", "")
if content:
new_content.append(content)
yield content
else:
content = f"\nFailed with {response=!r}"
new_content.append(content)
yield content
async for line in response.aiter_lines():
new_content.append(line)
yield line
except httpx.HTTPError as e:
content = f"\nException: {e!r}"
new_content.append(content)
yield content
session.extend([User(query), Assistant("".join(new_content))])
def factory() -> Backend:
"""Uses the anthropic text-generation-interface web API"""
return Anthropic()

View file

@ -9,11 +9,11 @@ from typing import Any, AsyncGenerator
import httpx import httpx
from ..core import AutoAskMixin, Backend from ..core import AutoAskMixin, Backend
from ..key import UsesKeyMixin from ..key import get_key
from ..session import Assistant, Role, Session, User from ..session import Assistant, Role, Session, User
class HuggingFace(AutoAskMixin, UsesKeyMixin): class HuggingFace(AutoAskMixin):
@dataclass @dataclass
class Parameters: class Parameters:
url: str = "https://api-inference.huggingface.co" url: str = "https://api-inference.huggingface.co"
@ -24,7 +24,6 @@ class HuggingFace(AutoAskMixin, UsesKeyMixin):
after_user: str = """ [/INST] """ after_user: str = """ [/INST] """
after_assistant: str = """ </s><s>[INST] """ after_assistant: str = """ </s><s>[INST] """
stop_token_id = 2 stop_token_id = 2
api_key_name = "huggingface_api_token"
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -90,7 +89,9 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
*, *,
max_query_size: int = 5, max_query_size: int = 5,
timeout: float = 180, timeout: float = 180,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[
str, None
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
new_content: list[str] = [] new_content: list[str] = []
inputs = self.make_full_query(session + [User(query)], max_query_size) inputs = self.make_full_query(session + [User(query)], max_query_size)
try: try:
@ -111,6 +112,10 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
session.extend([User(query), Assistant("".join(new_content))]) session.extend([User(query), Assistant("".join(new_content))])
@classmethod
def get_key(cls) -> str:
return get_key("huggingface_api_token")
def factory() -> Backend: def factory() -> Backend:
"""Uses the huggingface text-generation-interface web API""" """Uses the huggingface text-generation-interface web API"""

View file

@ -18,16 +18,10 @@ class LlamaCpp(AutoAskMixin):
url: str = "http://localhost:8080/completion" url: str = "http://localhost:8080/completion"
"""The URL of a llama.cpp server's completion endpoint.""" """The URL of a llama.cpp server's completion endpoint."""
start_prompt: str = "<|begin_of_text|>" start_prompt: str = """<s>[INST] <<SYS>>\n"""
system_format: str = ( after_system: str = "\n<</SYS>>\n\n"
"<|start_header_id|>system<|end_header_id|>\n\n{}<|eot_id|>" after_user: str = """ [/INST] """
) after_assistant: str = """ </s><s>[INST] """
user_format: str = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
assistant_format: str = (
"<|start_header_id|>assistant<|end_header_id|>\n\n{}<|eot_id|>"
)
end_prompt: str = "<|start_header_id|>assistant<|end_header_id|>\n\n"
stop: str | None = None
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -40,16 +34,17 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
def make_full_query(self, messages: Session, max_query_size: int) -> str: def make_full_query(self, messages: Session, max_query_size: int) -> str:
del messages[1:-max_query_size] del messages[1:-max_query_size]
result = [self.parameters.start_prompt] result = [self.parameters.start_prompt]
formats = {
Role.SYSTEM: self.parameters.system_format,
Role.USER: self.parameters.user_format,
Role.ASSISTANT: self.parameters.assistant_format,
}
for m in messages: for m in messages:
content = (m.content or "").strip() content = (m.content or "").strip()
if not content: if not content:
continue continue
result.append(formats[m.role].format(content)) result.append(content)
if m.role == Role.SYSTEM:
result.append(self.parameters.after_system)
elif m.role == Role.ASSISTANT:
result.append(self.parameters.after_assistant)
elif m.role == Role.USER:
result.append(self.parameters.after_user)
full_query = "".join(result) full_query = "".join(result)
return full_query return full_query
@ -60,11 +55,13 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
*, *,
max_query_size: int = 5, max_query_size: int = 5,
timeout: float = 180, timeout: float = 180,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[
str, None
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
params = { params = {
"prompt": self.make_full_query(session + [User(query)], max_query_size), "prompt": self.make_full_query(session + [User(query)], max_query_size),
"stream": True, "stream": True,
"stop": ["</s>", "<s>", "[INST]", "<|eot_id|>"], "stop": ["</s>", "<s>", "[INST]"],
} }
new_content: list[str] = [] new_content: list[str] = []
try: try:
@ -101,10 +98,5 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
def factory() -> Backend: def factory() -> Backend:
"""Uses the llama.cpp completion web API """Uses the llama.cpp completion web API"""
Note: Consider using the openai-chatgpt backend with a custom URL instead.
The llama.cpp server will automatically apply common chat templates with the
openai-chatgpt backend, while chat templates must be manually configured client side
with this backend."""
return LlamaCpp() return LlamaCpp()

View file

@ -45,7 +45,7 @@ class Lorem:
session: Session, session: Session,
query: str, query: str,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
data = self.ask(session, query) data = self.ask(session, query)[-1]
for word, opt_sep in ipartition(data): for word, opt_sep in ipartition(data):
yield word + opt_sep yield word + opt_sep
await asyncio.sleep( await asyncio.sleep(
@ -56,7 +56,7 @@ class Lorem:
self, self,
session: Session, session: Session,
query: str, query: str,
) -> str: ) -> str: # pylint: disable=unused-argument
new_content = cast( new_content = cast(
str, str,
lorem.paragraphs( lorem.paragraphs(

View file

@ -1,96 +0,0 @@
# SPDX-FileCopyrightText: 2024 Jeff Epler <jepler@gmail.com>
#
# SPDX-License-Identifier: MIT
import json
from dataclasses import dataclass
from typing import AsyncGenerator, Any
import httpx
from ..core import AutoAskMixin
from ..key import UsesKeyMixin
from ..session import Assistant, Session, User
class Mistral(AutoAskMixin, UsesKeyMixin):
@dataclass
class Parameters:
url: str = "https://api.mistral.ai"
model: str = "open-mistral-7b"
max_new_tokens: int = 1000
api_key_name = "mistral_api_key"
def __init__(self) -> None:
super().__init__()
self.parameters = self.Parameters()
system_message = """\
Answer each question accurately and thoroughly.
"""
def make_full_query(self, messages: Session, max_query_size: int) -> dict[str, Any]:
messages = [m for m in messages if m.content]
del messages[1:-max_query_size]
result = dict(
model=self.parameters.model,
max_tokens=self.parameters.max_new_tokens,
messages=[dict(role=str(m.role), content=m.content) for m in messages],
stream=True,
)
return result
async def aask(
self,
session: Session,
query: str,
*,
max_query_size: int = 5,
timeout: float = 180,
) -> AsyncGenerator[str, None]:
new_content: list[str] = []
params = self.make_full_query(session + [User(query)], max_query_size)
try:
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream(
"POST",
f"{self.parameters.url}/v1/chat/completions",
json=params,
headers={
"Authorization": f"Bearer {self.get_key()}",
"content-type": "application/json",
"accept": "application/json",
"model": "application/json",
},
) as response:
if response.status_code == 200:
async for line in response.aiter_lines():
if line.startswith("data:"):
data = line.removeprefix("data:").strip()
if data == "[DONE]":
break
j = json.loads(data)
content = (
j.get("choices", [{}])[0]
.get("delta", {})
.get("content", "")
)
if content:
new_content.append(content)
yield content
else:
content = f"\nFailed with {response=!r}"
new_content.append(content)
yield content
async for line in response.aiter_lines():
new_content.append(line)
yield line
except httpx.HTTPError as e:
content = f"\nException: {e!r}"
new_content.append(content)
yield content
session.extend([User(query), Assistant("".join(new_content))])
factory = Mistral

View file

@ -12,7 +12,7 @@ import httpx
import tiktoken import tiktoken
from ..core import Backend from ..core import Backend
from ..key import UsesKeyMixin from ..key import get_key
from ..session import Assistant, Message, Session, User, session_to_list from ..session import Assistant, Message, Session, User, session_to_list
@ -63,29 +63,15 @@ class EncodingMeta:
return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead) return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)
class ChatGPT(UsesKeyMixin): class ChatGPT:
@dataclass @dataclass
class Parameters: class Parameters:
model: str = "gpt-4o-mini" model: str = "gpt-3.5-turbo"
"""The model to use. The most common alternative value is 'gpt-4o'.""" """The model to use. The most common alternative value is 'gpt-4'."""
max_request_tokens: int = 1024 max_request_tokens: int = 1024
"""The approximate greatest number of tokens to send in a request. When the session is long, the system prompt and 1 or more of the most recent interaction steps are sent.""" """The approximate greatest number of tokens to send in a request. When the session is long, the system prompt and 1 or more of the most recent interaction steps are sent."""
url: str = "https://api.openai.com/v1/chat/completions"
"""The URL of a chatgpt-compatible server's completion endpoint. Notably, llama.cpp's server is compatible with this backend, and can automatically apply common chat templates too."""
temperature: float | None = None
"""The model temperature for sampling"""
top_p: float | None = None
"""The model temperature for sampling"""
api_key_name: str = "openai_api_key"
"""The OpenAI API key"""
parameters: Parameters
def __init__(self) -> None: def __init__(self) -> None:
self.parameters = self.Parameters() self.parameters = self.Parameters()
@ -111,11 +97,11 @@ class ChatGPT(UsesKeyMixin):
def ask(self, session: Session, query: str, *, timeout: float = 60) -> str: def ask(self, session: Session, query: str, *, timeout: float = 60) -> str:
full_prompt = self.make_full_prompt(session + [User(query)]) full_prompt = self.make_full_prompt(session + [User(query)])
response = httpx.post( response = httpx.post(
self.parameters.url, "https://api.openai.com/v1/chat/completions",
json={ json={
"model": self.parameters.model, "model": self.parameters.model,
"messages": session_to_list(full_prompt), "messages": session_to_list(full_prompt),
}, }, # pylint: disable=no-member
headers={ headers={
"Authorization": f"Bearer {self.get_key()}", "Authorization": f"Bearer {self.get_key()}",
}, },
@ -142,12 +128,10 @@ class ChatGPT(UsesKeyMixin):
async with httpx.AsyncClient(timeout=timeout) as client: async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream( async with client.stream(
"POST", "POST",
self.parameters.url, "https://api.openai.com/v1/chat/completions",
headers={"authorization": f"Bearer {self.get_key()}"}, headers={"authorization": f"Bearer {self.get_key()}"},
json={ json={
"model": self.parameters.model, "model": self.parameters.model,
"temperature": self.parameters.temperature,
"top_p": self.parameters.top_p,
"stream": True, "stream": True,
"messages": session_to_list(full_prompt), "messages": session_to_list(full_prompt),
}, },
@ -176,6 +160,10 @@ class ChatGPT(UsesKeyMixin):
session.extend([User(query), Assistant("".join(new_content))]) session.extend([User(query), Assistant("".join(new_content))])
@classmethod
def get_key(cls) -> str:
return get_key("openai_api_key")
def factory() -> Backend: def factory() -> Backend:
"""Uses the OpenAI chat completion API""" """Uses the OpenAI chat completion API"""

View file

@ -29,7 +29,7 @@ USER: Hello, AI.
AI: Hello! How can I assist you today?""" AI: Hello! How can I assist you today?"""
async def aask( async def aask( # pylint: disable=unused-argument,too-many-locals,too-many-branches
self, self,
session: Session, session: Session,
query: str, query: str,
@ -60,11 +60,12 @@ AI: Hello! How can I assist you today?"""
} }
full_prompt = session + [User(query)] full_prompt = session + [User(query)]
del full_prompt[1:-max_query_size] del full_prompt[1:-max_query_size]
new_data = old_data = full_query = "\n".join( new_data = old_data = full_query = (
f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt "\n".join(f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt)
) + f"\n{role_map.get('assistant')}" + f"\n{role_map.get('assistant')}"
)
try: try:
async with websockets.connect( async with websockets.connect( # pylint: disable=no-member
f"ws://{self.parameters.server_hostname}:7860/queue/join" f"ws://{self.parameters.server_hostname}:7860/queue/join"
) as websocket: ) as websocket:
while content := json.loads(await websocket.recv()): while content := json.loads(await websocket.recv()):
@ -126,7 +127,7 @@ AI: Hello! How can I assist you today?"""
# stop generation by closing the websocket here # stop generation by closing the websocket here
if content["msg"] == "process_completed": if content["msg"] == "process_completed":
break break
except Exception as e: except Exception as e: # pylint: disable=broad-exception-caught
content = f"\nException: {e!r}" content = f"\nException: {e!r}"
new_data += content new_data += content
yield content yield content

View file

@ -4,7 +4,7 @@
import asyncio import asyncio
import sys import sys
from typing import Iterable, Optional, Protocol from typing import Iterable, Protocol
import click import click
import rich import rich
@ -40,7 +40,7 @@ class DumbPrinter:
class WrappingPrinter: class WrappingPrinter:
def __init__(self, width: Optional[int] = None) -> None: def __init__(self, width: int | None = None) -> None:
self._width = width or rich.get_console().width self._width = width or rich.get_console().width
self._column = 0 self._column = 0
self._line = "" self._line = ""
@ -100,9 +100,8 @@ def verbose_ask(api: Backend, session: Session, q: str, print_prompt: bool) -> s
@command_uses_new_session @command_uses_new_session
@click.option("--print-prompt/--no-print-prompt", default=True) @click.option("--print-prompt/--no-print-prompt", default=True)
@click.option("--stdin/--no-stdin", "use_stdin", default=False) @click.argument("prompt", nargs=-1, required=True)
@click.argument("prompt", nargs=-1) def main(obj: Obj, prompt: str, print_prompt: bool) -> None:
def main(obj: Obj, prompt: list[str], use_stdin: bool, print_prompt: bool) -> None:
"""Ask a question (command-line argument is passed as prompt)""" """Ask a question (command-line argument is passed as prompt)"""
session = obj.session session = obj.session
assert session is not None assert session is not None
@ -113,16 +112,9 @@ def main(obj: Obj, prompt: list[str], use_stdin: bool, print_prompt: bool) -> No
api = obj.api api = obj.api
assert api is not None assert api is not None
if use_stdin:
if prompt:
raise click.UsageError("Can't use 'prompt' together with --stdin")
joined_prompt = sys.stdin.read()
else:
joined_prompt = " ".join(prompt)
# symlink_session_filename(session_filename) # symlink_session_filename(session_filename)
response = verbose_ask(api, session, joined_prompt, print_prompt=print_prompt) response = verbose_ask(api, session, " ".join(prompt), print_prompt=print_prompt)
print(f"Saving session to {session_filename}", file=sys.stderr) print(f"Saving session to {session_filename}", file=sys.stderr)
if response is not None: if response is not None:
@ -130,4 +122,4 @@ def main(obj: Obj, prompt: list[str], use_stdin: bool, print_prompt: bool) -> No
if __name__ == "__main__": if __name__ == "__main__":
main() main() # pylint: disable=no-value-for-parameter

View file

@ -38,4 +38,4 @@ def main(obj: Obj, no_system: bool) -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() main() # pylint: disable=no-value-for-parameter

View file

@ -24,8 +24,8 @@ def list_files_matching_rx(
"*.json" "*.json"
): ):
try: try:
session = session_from_file(conversation) session = session_from_file(conversation) # pylint: disable=no-member
except Exception as e: except Exception as e: # pylint: disable=broad-exception-caught
print(f"Failed to read {conversation}: {e}", file=sys.stderr) print(f"Failed to read {conversation}: {e}", file=sys.stderr)
continue continue
@ -67,4 +67,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
main() main() # pylint: disable=no-value-for-parameter

View file

@ -82,4 +82,4 @@ def main(output_directory: pathlib.Path, files: list[TextIO]) -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() main() # pylint: disable=no-value-for-parameter

View file

@ -45,4 +45,4 @@ def main(obj: Obj, no_system: bool) -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() main() # pylint: disable=no-value-for-parameter

View file

@ -36,9 +36,6 @@ Markdown:focus-within {
Markdown { Markdown {
border-left: heavy transparent; border-left: heavy transparent;
} }
MarkdownFence {
max-height: 9999;
}
Footer { Footer {
dock: top; dock: top;
} }
@ -57,5 +54,3 @@ Input {
Markdown { Markdown {
margin: 0 1 0 0; margin: 0 1 0 0;
} }
SubmittableTextArea { height: auto; min-height: 5; margin: 0; border: none; border-left: heavy $primary }

View file

@ -3,58 +3,35 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import asyncio import asyncio
import subprocess
import sys import sys
from typing import Any, Optional, cast, TYPE_CHECKING from typing import cast
import click
from markdown_it import MarkdownIt from markdown_it import MarkdownIt
from textual import work from textual import work
from textual._ansi_sequences import ANSI_SEQUENCES_KEYS
from textual.app import App, ComposeResult from textual.app import App, ComposeResult
from textual.binding import Binding from textual.binding import Binding
from textual.containers import Container, Horizontal, VerticalScroll from textual.containers import Container, Horizontal, VerticalScroll
from textual.keys import Keys from textual.widgets import Button, Footer, Input, LoadingIndicator, Markdown
from textual.widgets import Button, Footer, LoadingIndicator, Markdown, TextArea
from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path
from ..session import Assistant, Message, Session, User, new_session, session_to_file from ..session import Assistant, Message, Session, User, new_session, session_to_file
# workaround for pyperclip being un-typed
if TYPE_CHECKING:
def pyperclip_copy(data: str) -> None:
...
else:
from pyperclip import copy as pyperclip_copy
# Monkeypatch alt+enter as meaning "F9", WFM
# ignore typing here because ANSI_SEQUENCES_KEYS is a Mapping[] which is read-only as
# far as mypy is concerned.
ANSI_SEQUENCES_KEYS["\x1b\r"] = (Keys.F9,) # type: ignore
ANSI_SEQUENCES_KEYS["\x1b\n"] = (Keys.F9,) # type: ignore
class SubmittableTextArea(TextArea):
BINDINGS = [
Binding("f9", "app.submit", "Submit", show=True),
Binding("tab", "focus_next", show=False, priority=True), # no inserting tabs
]
def parser_factory() -> MarkdownIt: def parser_factory() -> MarkdownIt:
parser = MarkdownIt() parser = MarkdownIt()
parser.options["html"] = False parser.options["html"] = False
return parser return parser
class ChapMarkdown(Markdown, can_focus=True, can_focus_children=False): class ChapMarkdown(
Markdown, can_focus=True, can_focus_children=False
): # pylint: disable=function-redefined
BINDINGS = [ BINDINGS = [
Binding("ctrl+c", "app.yank", "Copy text", show=True), Binding("ctrl+y", "yank", "Yank text", show=True),
Binding("ctrl+r", "app.resubmit", "resubmit", show=True), Binding("ctrl+r", "resubmit", "resubmit", show=True),
Binding("ctrl+x", "app.redraft", "redraft", show=True), Binding("ctrl+x", "redraft", "redraft", show=True),
Binding("ctrl+q", "app.toggle_history", "history toggle", show=True), Binding("ctrl+q", "toggle_history", "history toggle", show=True),
] ]
@ -75,14 +52,14 @@ class CancelButton(Button):
class Tui(App[None]): class Tui(App[None]):
CSS_PATH = "tui.css" CSS_PATH = "tui.css"
BINDINGS = [ BINDINGS = [
Binding("ctrl+q", "quit", "Quit", show=True, priority=True), Binding("ctrl+c", "quit", "Quit", show=True, priority=True),
] ]
def __init__( def __init__(
self, api: Optional[Backend] = None, session: Optional[Session] = None self, api: Backend | None = None, session: Session | None = None
) -> None: ) -> None:
super().__init__() super().__init__()
self.api = api or get_api(click.Context(click.Command("chap tui")), "lorem") self.api = api or get_api("lorem")
self.session = ( self.session = (
new_session(self.api.system_message) if session is None else session new_session(self.api.system_message) if session is None else session
) )
@ -96,8 +73,8 @@ class Tui(App[None]):
return cast(VerticalScroll, self.query_one("#wait")) return cast(VerticalScroll, self.query_one("#wait"))
@property @property
def input(self) -> SubmittableTextArea: def input(self) -> Input:
return self.query_one(SubmittableTextArea) return self.query_one(Input)
@property @property
def cancel_button(self) -> CancelButton: def cancel_button(self) -> CancelButton:
@ -117,9 +94,7 @@ class Tui(App[None]):
Container(id="pad"), Container(id="pad"),
id="content", id="content",
) )
s = SubmittableTextArea(language="markdown") yield Input(placeholder="Prompt")
s.show_line_numbers = False
yield s
with Horizontal(id="wait"): with Horizontal(id="wait"):
yield LoadingIndicator() yield LoadingIndicator()
yield CancelButton(label="❌ Stop Generation", id="cancel", disabled=True) yield CancelButton(label="❌ Stop Generation", id="cancel", disabled=True)
@ -128,8 +103,8 @@ class Tui(App[None]):
self.container.scroll_end(animate=False) self.container.scroll_end(animate=False)
self.input.focus() self.input.focus()
async def action_submit(self) -> None: async def on_input_submitted(self, event: Input.Submitted) -> None:
self.get_completion(self.input.text) self.get_completion(event.value)
@work(exclusive=True) @work(exclusive=True)
async def get_completion(self, query: str) -> None: async def get_completion(self, query: str) -> None:
@ -145,6 +120,7 @@ class Tui(App[None]):
await self.container.mount_all( await self.container.mount_all(
[markdown_for_step(User(query)), output], before="#pad" [markdown_for_step(User(query)), output], before="#pad"
) )
tokens: list[str] = []
update: asyncio.Queue[bool] = asyncio.Queue(1) update: asyncio.Queue[bool] = asyncio.Queue(1)
for markdown in self.container.children: for markdown in self.container.children:
@ -165,22 +141,15 @@ class Tui(App[None]):
) )
async def render_fun() -> None: async def render_fun() -> None:
old_len = 0
while await update.get(): while await update.get():
content = message.content if tokens:
new_len = len(content) output.update("".join(tokens).strip())
new_content = content[old_len:new_len]
if new_content:
if old_len:
await output.append(new_content)
else:
output.update(content)
self.container.scroll_end() self.container.scroll_end()
old_len = new_len await asyncio.sleep(0.1)
await asyncio.sleep(0.01)
async def get_token_fun() -> None: async def get_token_fun() -> None:
async for token in self.api.aask(session, query): async for token in self.api.aask(session, query):
tokens.append(token)
message.content += token message.content += token
try: try:
update.put_nowait(True) update.put_nowait(True)
@ -193,10 +162,10 @@ class Tui(App[None]):
try: try:
await asyncio.gather(render_fun(), get_token_fun()) await asyncio.gather(render_fun(), get_token_fun())
finally: finally:
self.input.clear() self.input.value = ""
all_output = self.session[-1].content all_output = self.session[-1].content
output.update(all_output) output.update(all_output)
output._markdown = all_output output._markdown = all_output # pylint: disable=protected-access
self.container.scroll_end() self.container.scroll_end()
for markdown in self.container.children: for markdown in self.container.children:
@ -214,8 +183,8 @@ class Tui(App[None]):
def action_yank(self) -> None: def action_yank(self) -> None:
widget = self.focused widget = self.focused
if isinstance(widget, ChapMarkdown): if isinstance(widget, ChapMarkdown):
content = widget._markdown or "" content = widget._markdown or "" # pylint: disable=protected-access
pyperclip_copy(content) subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)
def action_toggle_history(self) -> None: def action_toggle_history(self) -> None:
widget = self.focused widget = self.focused
@ -235,7 +204,9 @@ class Tui(App[None]):
async def action_stop_generating(self) -> None: async def action_stop_generating(self) -> None:
self.workers.cancel_all() self.workers.cancel_all()
async def on_button_pressed(self, event: Button.Pressed) -> None: async def on_button_pressed( # pylint: disable=unused-argument
self, event: Button.Pressed
) -> None:
self.workers.cancel_all() self.workers.cancel_all()
async def action_quit(self) -> None: async def action_quit(self) -> None:
@ -264,31 +235,19 @@ class Tui(App[None]):
session_to_file(self.session, new_session_path()) session_to_file(self.session, new_session_path())
query = self.session[idx].content query = self.session[idx].content
self.input.load_text(query) self.input.value = query
del self.session[idx:] del self.session[idx:]
for child in self.container.children[idx:-1]: for child in self.container.children[idx:-1]:
await child.remove() await child.remove()
self.input.focus() self.input.focus()
self.on_text_area_changed()
if resubmit: if resubmit:
await self.action_submit() await self.input.action_submit()
def on_text_area_changed(self, event: Any = None) -> None:
height = self.input.document.get_size(self.input.indent_width)[1]
max_height = max(3, self.size.height - 6)
if height >= max_height:
self.input.styles.height = max_height
elif height <= 3:
self.input.styles.height = 3
else:
self.input.styles.height = height
@command_uses_new_session @command_uses_new_session
@click.option("--replace-system-prompt/--no-replace-system-prompt", default=False) def main(obj: Obj) -> None:
def main(obj: Obj, replace_system_prompt: bool) -> None:
"""Start interactive terminal user interface session""" """Start interactive terminal user interface session"""
api = obj.api api = obj.api
assert api is not None assert api is not None
@ -297,11 +256,6 @@ def main(obj: Obj, replace_system_prompt: bool) -> None:
session_filename = obj.session_filename session_filename = obj.session_filename
assert session_filename is not None assert session_filename is not None
if replace_system_prompt:
session[0].content = (
api.system_message if obj.system_message is None else obj.system_message
)
tui = Tui(api, session) tui = Tui(api, session)
tui.run() tui.run()
@ -314,4 +268,4 @@ def main(obj: Obj, replace_system_prompt: bool) -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() main() # pylint: disable=no-value-for-parameter

View file

@ -1,53 +1,32 @@
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com> # SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
# pylint: disable=import-outside-toplevel
from collections.abc import Sequence
import asyncio import asyncio
import datetime import datetime
import io
import importlib import importlib
import os import os
import pathlib import pathlib
import pkgutil import pkgutil
import subprocess import subprocess
import shlex from dataclasses import MISSING, dataclass, fields
import textwrap from types import UnionType
from dataclasses import MISSING, Field, dataclass, fields from typing import Any, AsyncGenerator, Callable, cast
from typing import (
Any,
AsyncGenerator,
Callable,
Optional,
Union,
IO,
cast,
get_origin,
get_args,
)
import sys
import click import click
import platformdirs import platformdirs
from simple_parsing.docstring import get_attribute_docstring from simple_parsing.docstring import get_attribute_docstring
from typing_extensions import Protocol from typing_extensions import Protocol
from . import backends, commands
from .session import Message, Session, System, session_from_file
UnionType: type from . import backends, commands # pylint: disable=no-name-in-module
if sys.version_info >= (3, 10): from .session import Message, Session, System, session_from_file
from types import UnionType
else:
UnionType = type(Union[int, float])
conversations_path = platformdirs.user_state_path("chap") / "conversations" conversations_path = platformdirs.user_state_path("chap") / "conversations"
conversations_path.mkdir(parents=True, exist_ok=True) conversations_path.mkdir(parents=True, exist_ok=True)
configuration_path = platformdirs.user_config_path("chap")
preset_path = configuration_path / "preset"
class ABackend(Protocol): class ABackend(Protocol): # pylint: disable=too-few-public-methods
def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]: def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
"""Make a query, updating the session with the query and response, returning the query token by token""" """Make a query, updating the session with the query and response, returning the query token by token"""
@ -60,7 +39,7 @@ class Backend(ABackend, Protocol):
"""Make a query, updating the session with the query and response, returning the query""" """Make a query, updating the session with the query and response, returning the query"""
class AutoAskMixin: class AutoAskMixin: # pylint: disable=too-few-public-methods
"""Mixin class for backends implementing aask""" """Mixin class for backends implementing aask"""
def ask(self, session: Session, query: str) -> str: def ask(self, session: Session, query: str) -> str:
@ -75,50 +54,20 @@ class AutoAskMixin:
return "".join(tokens) return "".join(tokens)
def last_session_path() -> Optional[pathlib.Path]: def last_session_path() -> pathlib.Path | None:
result = max( result = max(
conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None
) )
return result return result
def new_session_path(opt_path: Optional[pathlib.Path] = None) -> pathlib.Path: def new_session_path(opt_path: pathlib.Path | None = None) -> pathlib.Path:
return opt_path or conversations_path / ( return opt_path or conversations_path / (
datetime.datetime.now().isoformat().replace(":", "_") + ".json" datetime.datetime.now().isoformat().replace(":", "_") + ".json"
) )
def get_field_type(field: Field[Any]) -> Any: def configure_api_from_environment(api_name: str, api: Backend) -> None:
field_type = field.type
if isinstance(field_type, str):
raise RuntimeError(
"parameters dataclass may not use 'from __future__ import annotations"
)
origin = get_origin(field_type)
if origin in (Union, UnionType):
for arg in get_args(field_type):
if arg is not None:
return arg
return field_type
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
field_type = get_field_type(field)
try:
if field_type is bool:
tv = click.types.BoolParamType().convert(value, None, ctx)
else:
tv = field_type(value)
return tv
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {field.name} with value {value}: {e}"
) from e
def configure_api_from_environment(
ctx: click.Context, api_name: str, api: Backend
) -> None:
if not hasattr(api, "parameters"): if not hasattr(api, "parameters"):
return return
@ -127,25 +76,26 @@ def configure_api_from_environment(
value = os.environ.get(envvar) value = os.environ.get(envvar)
if value is None: if value is None:
continue continue
tv = convert_str_to_field(ctx, field, value) try:
tv = field.type(value)
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {field.name} with value {value}: {e}"
) from e
setattr(api.parameters, field.name, tv) setattr(api.parameters, field.name, tv)
def get_api(ctx: click.Context | None = None, name: str | None = None) -> Backend: def get_api(name: str = "openai_chatgpt") -> Backend:
if ctx is None:
ctx = click.Context(click.Command("chap"))
if name is None:
name = os.environ.get("CHAP_BACKEND", "openai_chatgpt")
name = name.replace("-", "_") name = name.replace("-", "_")
backend = cast( result = cast(
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory() Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
) )
configure_api_from_environment(ctx, name, backend) configure_api_from_environment(name, result)
return backend return result
def do_session_continue( def do_session_continue(
ctx: click.Context, param: click.Parameter, value: Optional[pathlib.Path] ctx: click.Context, param: click.Parameter, value: pathlib.Path | None
) -> None: ) -> None:
if value is None: if value is None:
return return
@ -158,7 +108,9 @@ def do_session_continue(
ctx.obj.session_filename = value ctx.obj.session_filename = value
def do_session_last(ctx: click.Context, param: click.Parameter, value: bool) -> None: def do_session_last(
ctx: click.Context, param: click.Parameter, value: bool
) -> None: # pylint: disable=unused-argument
if not value: if not value:
return return
do_session_continue(ctx, param, last_session_path()) do_session_continue(ctx, param, last_session_path())
@ -186,22 +138,18 @@ def colonstr(arg: str) -> tuple[str, str]:
return cast(tuple[str, str], tuple(arg.split(":", 1))) return cast(tuple[str, str], tuple(arg.split(":", 1)))
def set_system_message(ctx: click.Context, param: click.Parameter, value: str) -> None: def set_system_message( # pylint: disable=unused-argument
if value is None: ctx: click.Context, param: click.Parameter, value: str
return ) -> None:
if value and value.startswith("@"):
with open(value[1:], "r", encoding="utf-8") as f:
value = f.read().rstrip()
ctx.obj.system_message = value ctx.obj.system_message = value
def set_system_message_from_file( def set_backend( # pylint: disable=unused-argument
ctx: click.Context, param: click.Parameter, value: io.TextIOWrapper ctx: click.Context, param: click.Parameter, value: str
) -> None: ) -> None:
if value is None:
return
content = value.read().strip()
ctx.obj.system_message = content
def set_backend(ctx: click.Context, param: click.Parameter, value: str) -> None:
if value == "list": if value == "list":
formatter = ctx.make_formatter() formatter = ctx.make_formatter()
format_backend_list(formatter) format_backend_list(formatter)
@ -209,7 +157,7 @@ def set_backend(ctx: click.Context, param: click.Parameter, value: str) -> None:
ctx.exit() ctx.exit()
try: try:
ctx.obj.api = get_api(ctx, value) ctx.obj.api = get_api(value)
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
raise click.BadParameter(str(e)) raise click.BadParameter(str(e))
@ -224,13 +172,15 @@ def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None:
if doc: if doc:
doc += " " doc += " "
doc += f"(Default: {default!r})" doc += f"(Default: {default!r})"
f_type = get_field_type(f) f_type = f.type
if isinstance(f_type, UnionType):
f_type = f_type.__args__[0]
typename = f_type.__name__ typename = f_type.__name__
rows.append((f"-B {name}:{typename.upper()}", doc)) rows.append((f"-B {name}:{typename.upper()}", doc))
formatter.write_dl(rows) formatter.write_dl(rows)
def set_backend_option( def set_backend_option( # pylint: disable=unused-argument
ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]] ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]]
) -> None: ) -> None:
api = ctx.obj.api api = ctx.obj.api
@ -245,8 +195,16 @@ def set_backend_option(
field = all_fields.get(name) field = all_fields.get(name)
if field is None: if field is None:
raise click.BadParameter(f"Invalid parameter {name}") raise click.BadParameter(f"Invalid parameter {name}")
tv = convert_str_to_field(ctx, field, value) f_type = field.type
setattr(api.parameters, field.name, tv) if isinstance(f_type, UnionType):
f_type = f_type.__args__[0]
try:
tv = f_type(value)
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {name} with value {value}: {e}"
) from e
setattr(api.parameters, name, tv)
for kv in opts: for kv in opts:
set_one_backend_option(kv) set_one_backend_option(kv)
@ -306,7 +264,9 @@ def command_uses_new_session(f_in: click.decorators.FC) -> click.Command:
return click.command()(f) return click.command()(f)
def version_callback(ctx: click.Context, param: click.Parameter, value: None) -> None: def version_callback( # pylint: disable=unused-argument
ctx: click.Context, param: click.Parameter, value: None
) -> None:
if not value or ctx.resilient_parsing: if not value or ctx.resilient_parsing:
return return
@ -333,54 +293,18 @@ def version_callback(ctx: click.Context, param: click.Parameter, value: None) ->
@dataclass @dataclass
class Obj: class Obj:
api: Optional[Backend] = None api: Backend | None = None
system_message: Optional[str] = None system_message: str | None = None
session: Optional[list[Message]] = None session: list[Message] | None = None
session_filename: Optional[pathlib.Path] = None session_filename: pathlib.Path | None = None
def maybe_add_txt_extension(fn: pathlib.Path) -> pathlib.Path: class MyCLI(click.MultiCommand):
if not fn.exists():
fn1 = pathlib.Path(str(fn) + ".txt")
if fn1.exists():
fn = fn1
return fn
def expand_splats(args: list[str]) -> list[str]:
result = []
saw_at_dot = False
for a in args:
if a == "@.":
saw_at_dot = True
continue
if saw_at_dot:
result.append(a)
continue
if a.startswith("@@"): ## double @ to escape an argument that starts with @
result.append(a[1:])
continue
if not a.startswith("@"):
result.append(a)
continue
if a.startswith("@:"):
fn: pathlib.Path = preset_path / a[2:]
else:
fn = pathlib.Path(a[1:])
fn = maybe_add_txt_extension(fn)
with open(fn, "r", encoding="utf-8") as f:
content = f.read()
parts = shlex.split(content)
result.extend(expand_splats(parts))
return result
class MyCLI(click.Group):
def make_context( def make_context(
self, self,
info_name: Optional[str], info_name: str | None,
args: list[str], args: list[str],
parent: Optional[click.Context] = None, parent: click.Context | None = None,
**extra: Any, **extra: Any,
) -> click.Context: ) -> click.Context:
result = super().make_context(info_name, args, parent, obj=Obj(), **extra) result = super().make_context(info_name, args, parent, obj=Obj(), **extra)
@ -404,117 +328,14 @@ class MyCLI(click.Group):
except ModuleNotFoundError as exc: except ModuleNotFoundError as exc:
raise click.UsageError(f"Invalid subcommand {cmd_name!r}", ctx) from exc raise click.UsageError(f"Invalid subcommand {cmd_name!r}", ctx) from exc
def gather_preset_info(self) -> list[tuple[str, str]]:
result = []
for p in preset_path.glob("*"):
if p.is_file():
with p.open() as f:
first_line = f.readline()
if first_line.startswith("#"):
help_str = first_line[1:].strip()
else:
help_str = "(A comment on the first line would be shown here)"
result.append((f"@:{p.name}", help_str))
return result
def format_splat_options(
self, ctx: click.Context, formatter: click.HelpFormatter
) -> None:
with formatter.section("Splats"):
formatter.write_text(
"Before any other command-line argument parsing is performed, @FILE arguments are expanded:"
)
formatter.write_paragraph()
formatter.indent()
formatter.write_dl(
[
("@FILE", "Argument is searched relative to the current directory"),
(
"@:FILE",
"Argument is searched relative to the configuration directory (e.g., $HOME/.config/chap/preset)",
),
("@@…", "If an argument starts with a literal '@', double it"),
(
"@.",
"Stops processing any further `@FILE` arguments and leaves them unchanged.",
),
]
)
formatter.dedent()
formatter.write_paragraph()
formatter.write_text(
textwrap.dedent(
"""\
The contents of an `@FILE` are parsed by `shlex.split(comments=True)`.
Comments are supported. If the filename ends in .txt,
the extension may be omitted."""
)
)
formatter.write_paragraph()
if preset_info := self.gather_preset_info():
formatter.write_text("Presets found:")
formatter.write_paragraph()
formatter.indent()
formatter.write_dl(preset_info)
formatter.dedent()
def format_options( def format_options(
self, ctx: click.Context, formatter: click.HelpFormatter self, ctx: click.Context, formatter: click.HelpFormatter
) -> None: ) -> None:
self.format_splat_options(ctx, formatter)
super().format_options(ctx, formatter) super().format_options(ctx, formatter)
api = ctx.obj.api or get_api(ctx) api = ctx.obj.api or get_api()
if hasattr(api, "parameters"): if hasattr(api, "parameters"):
format_backend_help(api, formatter) format_backend_help(api, formatter)
def main(
self,
args: Sequence[str] | None = None,
prog_name: str | None = None,
complete_var: str | None = None,
standalone_mode: bool = True,
windows_expand_args: bool = True,
**extra: Any,
) -> Any:
if args is None:
args = sys.argv[1:]
if os.name == "nt" and windows_expand_args:
args = click.utils._expand_args(args)
else:
args = list(args)
args = expand_splats(args)
return super().main(
args,
prog_name=prog_name,
complete_var=complete_var,
standalone_mode=standalone_mode,
windows_expand_args=windows_expand_args,
**extra,
)
class ConfigRelativeFile(click.File):
def __init__(self, mode: str, where: str) -> None:
super().__init__(mode)
self.where = where
def convert(
self,
value: str | os.PathLike[str] | IO[Any],
param: click.Parameter | None,
ctx: click.Context | None,
) -> Any:
if isinstance(value, str):
if value.startswith(":"):
value = configuration_path / self.where / value[1:]
else:
value = pathlib.Path(value)
if isinstance(value, pathlib.Path):
value = maybe_add_txt_extension(value)
return super().convert(value, param, ctx)
main = MyCLI( main = MyCLI(
help="Commandline interface to ChatGPT", help="Commandline interface to ChatGPT",
@ -526,14 +347,6 @@ main = MyCLI(
help="Show the version and exit", help="Show the version and exit",
callback=version_callback, callback=version_callback,
), ),
click.Option(
("--system-message-file", "-s"),
type=ConfigRelativeFile("r", where="prompt"),
default=None,
callback=set_system_message_from_file,
expose_value=False,
help=f"Set the system message from a file. If the filename starts with `:` it is relative to the {configuration_path}/prompt. If the filename ends in .txt, the extension may be omitted.",
),
click.Option( click.Option(
("--system-message", "-S"), ("--system-message", "-S"),
type=str, type=str,

View file

@ -2,58 +2,20 @@
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import json
import subprocess
from typing import Protocol
import functools import functools
import platformdirs import platformdirs
class APIKeyProtocol(Protocol):
    """Structural type for parameter objects that name an API key."""

    @property
    def api_key_name(self) -> str:
        """Name under which this backend's API key is stored."""
class HasKeyProtocol(Protocol):
    """Structural type for objects whose parameters carry an API key name."""

    @property
    def parameters(self) -> APIKeyProtocol:
        """Parameter object exposing ``api_key_name``."""
class UsesKeyMixin:
    """Mixin adding ``get_key()`` to classes matching ``HasKeyProtocol``."""

    def get_key(self: HasKeyProtocol) -> str:
        """Fetch this object's API key via the module-level ``get_key``."""
        key_name = self.parameters.api_key_name
        return get_key(key_name)
class NoKeyAvailable(Exception):
    """Raised when no stored API key can be found for a backend."""
_key_path_base = platformdirs.user_config_path("chap") _key_path_base = platformdirs.user_config_path("chap")
USE_PASSWORD_STORE = _key_path_base / "USE_PASSWORD_STORE"
if USE_PASSWORD_STORE.exists(): @functools.cache
content = USE_PASSWORD_STORE.read_text(encoding="utf-8") def get_key(name: str, what: str = "openai api key") -> str:
if content.strip():
cfg = json.loads(content)
pass_command: list[str] = cfg.get("PASS_COMMAND", ["pass", "show"])
pass_prefix: str = cfg.get("PASS_PREFIX", "chap/")
@functools.cache
def get_key(name: str, what: str = "api key") -> str:
    """Fetch the key ``name`` from the password store, caching the result.

    :param name: key name; appended to ``pass_prefix`` to form the
        store entry. ``"-"`` is a sentinel returned unchanged.
    :param what: human-readable description of the key.
        NOTE(review): unused in this branch — presumably kept so both
        ``get_key`` variants share one signature; confirm.
    :returns: the key material (first output line of the pass command)
    """
    if name == "-":
        return "-"
    key_path = f"{pass_prefix}{name}"
    # Typically runs ``pass show <prefix><name>`` (see PASS_COMMAND config);
    # the secret is the first line of its stdout.
    command = pass_command + [key_path]
    return subprocess.check_output(command, encoding="utf-8").split("\n")[0]
else:
@functools.cache
def get_key(name: str, what: str = "api key") -> str:
key_path = _key_path_base / name key_path = _key_path_base / name
if not key_path.exists(): if not key_path.exists():
raise NoKeyAvailable( raise NoKeyAvailable(

View file

@ -7,13 +7,13 @@ from __future__ import annotations
import json import json
import pathlib import pathlib
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import Union, cast from typing import cast
from typing_extensions import TypedDict from typing_extensions import TypedDict
# not an enum.Enum because these objects are not json-serializable, sigh # not an enum.Enum because these objects are not json-serializable, sigh
class Role:
    """String constants naming the chat roles used in session messages."""

    ASSISTANT = "assistant"
    SYSTEM = "system"
    USER = "user"
@ -65,11 +65,11 @@ def session_from_json(data: str) -> Session:
return [Message(**mapping) for mapping in j] return [Message(**mapping) for mapping in j]
def session_from_file(path: Union[pathlib.Path, str]) -> Session:
    """Read the JSON file at ``path`` and return the decoded Session."""
    text = pathlib.Path(path).read_text(encoding="utf-8")
    return session_from_json(text)
def session_to_file(session: Session, path: Union[pathlib.Path, str]) -> None:
    """Serialize ``session`` to JSON and write it to ``path``."""
    payload = session_to_json(session)
    pathlib.Path(path).write_text(payload, encoding="utf-8")