Compare commits
No commits in common. "main" and "document-mypy" have entirely different histories.
main
...
document-m
28 changed files with 383 additions and 740 deletions
48
.github/workflows/codeql.yml
vendored
Normal file
48
.github/workflows/codeql.yml
vendored
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
# SPDX-FileCopyrightText: 2022 Jeff Epler
|
||||
#
|
||||
# SPDX-License-Identifier: CC0-1.0
|
||||
|
||||
name: "CodeQL"
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
schedule:
|
||||
- cron: "53 3 * * 5"
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
name: Analyze
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
language: [ python ]
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install Dependencies (python)
|
||||
run: pip3 install -r requirements-dev.txt
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v2
|
||||
with:
|
||||
languages: ${{ matrix.language }}
|
||||
queries: +security-and-quality
|
||||
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v2
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v2
|
||||
with:
|
||||
category: "/language:${{ matrix.language }}"
|
||||
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
|
|
@ -18,10 +18,10 @@ jobs:
|
|||
GITHUB_CONTEXT: ${{ toJson(github) }}
|
||||
run: echo "$GITHUB_CONTEXT"
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.11
|
||||
|
||||
|
|
|
|||
12
.github/workflows/test.yml
vendored
12
.github/workflows/test.yml
vendored
|
|
@ -17,10 +17,10 @@ jobs:
|
|||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: pre-commit
|
||||
uses: pre-commit/action@v3.0.1
|
||||
uses: pre-commit/action@v3.0.0
|
||||
|
||||
- name: Make patch
|
||||
if: failure()
|
||||
|
|
@ -28,7 +28,7 @@ jobs:
|
|||
|
||||
- name: Upload patch
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: patch
|
||||
path: ~/pre-commit.patch
|
||||
|
|
@ -36,10 +36,10 @@ jobs:
|
|||
test-release:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: 3.11
|
||||
|
||||
|
|
@ -53,7 +53,7 @@ jobs:
|
|||
run: python -mbuild
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: dist
|
||||
path: dist/*
|
||||
|
|
|
|||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -8,4 +8,3 @@ __pycache__
|
|||
/dist
|
||||
/src/chap/__version__.py
|
||||
/venv
|
||||
/keys.log
|
||||
|
|
|
|||
|
|
@ -6,8 +6,12 @@ default_language_version:
|
|||
python: python3
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
|
|
@ -15,20 +19,23 @@ repos:
|
|||
- id: trailing-whitespace
|
||||
exclude: tests
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.2.6
|
||||
rev: v2.2.4
|
||||
hooks:
|
||||
- id: codespell
|
||||
args: [-w]
|
||||
- repo: https://github.com/fsfe/reuse-tool
|
||||
rev: v2.1.0
|
||||
rev: v1.1.2
|
||||
hooks:
|
||||
- id: reuse
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.1.6
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff
|
||||
args: [ --fix, --preview ]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
- id: isort
|
||||
name: isort (python)
|
||||
args: ['--profile', 'black']
|
||||
- repo: https://github.com/pycqa/pylint
|
||||
rev: v2.17.0
|
||||
hooks:
|
||||
- id: pylint
|
||||
additional_dependencies: [click,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
|
||||
args: ['--source-roots', 'src']
|
||||
|
|
|
|||
12
.pylintrc
Normal file
12
.pylintrc
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
disable=
|
||||
duplicate-code,
|
||||
invalid-name,
|
||||
line-too-long,
|
||||
missing-class-docstring,
|
||||
missing-function-docstring,
|
||||
missing-module-docstring,
|
||||
121
LICENSES/CC0-1.0.txt
Normal file
121
LICENSES/CC0-1.0.txt
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
Creative Commons Legal Code
|
||||
|
||||
CC0 1.0 Universal
|
||||
|
||||
CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
|
||||
LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
|
||||
ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
|
||||
INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
|
||||
REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
|
||||
PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
|
||||
THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
|
||||
HEREUNDER.
|
||||
|
||||
Statement of Purpose
|
||||
|
||||
The laws of most jurisdictions throughout the world automatically confer
|
||||
exclusive Copyright and Related Rights (defined below) upon the creator
|
||||
and subsequent owner(s) (each and all, an "owner") of an original work of
|
||||
authorship and/or a database (each, a "Work").
|
||||
|
||||
Certain owners wish to permanently relinquish those rights to a Work for
|
||||
the purpose of contributing to a commons of creative, cultural and
|
||||
scientific works ("Commons") that the public can reliably and without fear
|
||||
of later claims of infringement build upon, modify, incorporate in other
|
||||
works, reuse and redistribute as freely as possible in any form whatsoever
|
||||
and for any purposes, including without limitation commercial purposes.
|
||||
These owners may contribute to the Commons to promote the ideal of a free
|
||||
culture and the further production of creative, cultural and scientific
|
||||
works, or to gain reputation or greater distribution for their Work in
|
||||
part through the use and efforts of others.
|
||||
|
||||
For these and/or other purposes and motivations, and without any
|
||||
expectation of additional consideration or compensation, the person
|
||||
associating CC0 with a Work (the "Affirmer"), to the extent that he or she
|
||||
is an owner of Copyright and Related Rights in the Work, voluntarily
|
||||
elects to apply CC0 to the Work and publicly distribute the Work under its
|
||||
terms, with knowledge of his or her Copyright and Related Rights in the
|
||||
Work and the meaning and intended legal effect of CC0 on those rights.
|
||||
|
||||
1. Copyright and Related Rights. A Work made available under CC0 may be
|
||||
protected by copyright and related or neighboring rights ("Copyright and
|
||||
Related Rights"). Copyright and Related Rights include, but are not
|
||||
limited to, the following:
|
||||
|
||||
i. the right to reproduce, adapt, distribute, perform, display,
|
||||
communicate, and translate a Work;
|
||||
ii. moral rights retained by the original author(s) and/or performer(s);
|
||||
iii. publicity and privacy rights pertaining to a person's image or
|
||||
likeness depicted in a Work;
|
||||
iv. rights protecting against unfair competition in regards to a Work,
|
||||
subject to the limitations in paragraph 4(a), below;
|
||||
v. rights protecting the extraction, dissemination, use and reuse of data
|
||||
in a Work;
|
||||
vi. database rights (such as those arising under Directive 96/9/EC of the
|
||||
European Parliament and of the Council of 11 March 1996 on the legal
|
||||
protection of databases, and under any national implementation
|
||||
thereof, including any amended or successor version of such
|
||||
directive); and
|
||||
vii. other similar, equivalent or corresponding rights throughout the
|
||||
world based on applicable law or treaty, and any national
|
||||
implementations thereof.
|
||||
|
||||
2. Waiver. To the greatest extent permitted by, but not in contravention
|
||||
of, applicable law, Affirmer hereby overtly, fully, permanently,
|
||||
irrevocably and unconditionally waives, abandons, and surrenders all of
|
||||
Affirmer's Copyright and Related Rights and associated claims and causes
|
||||
of action, whether now known or unknown (including existing as well as
|
||||
future claims and causes of action), in the Work (i) in all territories
|
||||
worldwide, (ii) for the maximum duration provided by applicable law or
|
||||
treaty (including future time extensions), (iii) in any current or future
|
||||
medium and for any number of copies, and (iv) for any purpose whatsoever,
|
||||
including without limitation commercial, advertising or promotional
|
||||
purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
|
||||
member of the public at large and to the detriment of Affirmer's heirs and
|
||||
successors, fully intending that such Waiver shall not be subject to
|
||||
revocation, rescission, cancellation, termination, or any other legal or
|
||||
equitable action to disrupt the quiet enjoyment of the Work by the public
|
||||
as contemplated by Affirmer's express Statement of Purpose.
|
||||
|
||||
3. Public License Fallback. Should any part of the Waiver for any reason
|
||||
be judged legally invalid or ineffective under applicable law, then the
|
||||
Waiver shall be preserved to the maximum extent permitted taking into
|
||||
account Affirmer's express Statement of Purpose. In addition, to the
|
||||
extent the Waiver is so judged Affirmer hereby grants to each affected
|
||||
person a royalty-free, non transferable, non sublicensable, non exclusive,
|
||||
irrevocable and unconditional license to exercise Affirmer's Copyright and
|
||||
Related Rights in the Work (i) in all territories worldwide, (ii) for the
|
||||
maximum duration provided by applicable law or treaty (including future
|
||||
time extensions), (iii) in any current or future medium and for any number
|
||||
of copies, and (iv) for any purpose whatsoever, including without
|
||||
limitation commercial, advertising or promotional purposes (the
|
||||
"License"). The License shall be deemed effective as of the date CC0 was
|
||||
applied by Affirmer to the Work. Should any part of the License for any
|
||||
reason be judged legally invalid or ineffective under applicable law, such
|
||||
partial invalidity or ineffectiveness shall not invalidate the remainder
|
||||
of the License, and in such case Affirmer hereby affirms that he or she
|
||||
will not (i) exercise any of his or her remaining Copyright and Related
|
||||
Rights in the Work or (ii) assert any associated claims and causes of
|
||||
action with respect to the Work, in either case contrary to Affirmer's
|
||||
express Statement of Purpose.
|
||||
|
||||
4. Limitations and Disclaimers.
|
||||
|
||||
a. No trademark or patent rights held by Affirmer are waived, abandoned,
|
||||
surrendered, licensed or otherwise affected by this document.
|
||||
b. Affirmer offers the Work as-is and makes no representations or
|
||||
warranties of any kind concerning the Work, express, implied,
|
||||
statutory or otherwise, including without limitation warranties of
|
||||
title, merchantability, fitness for a particular purpose, non
|
||||
infringement, or the absence of latent or other defects, accuracy, or
|
||||
the present or absence of errors, whether or not discoverable, all to
|
||||
the greatest extent permissible under applicable law.
|
||||
c. Affirmer disclaims responsibility for clearing rights of other persons
|
||||
that may apply to the Work or any use thereof, including without
|
||||
limitation any person's Copyright and Related Rights in the Work.
|
||||
Further, Affirmer disclaims responsibility for obtaining any necessary
|
||||
consents, permissions or other rights required for any use of the
|
||||
Work.
|
||||
d. Affirmer understands and acknowledges that Creative Commons is not a
|
||||
party to this document and has no duty or obligation with respect to
|
||||
this CC0 or use of the Work.
|
||||
65
README.md
65
README.md
|
|
@ -9,11 +9,11 @@ SPDX-License-Identifier: MIT
|
|||
|
||||
# chap - A Python interface to chatgpt and other LLMs, including a terminal user interface (tui)
|
||||
|
||||

|
||||

|
||||
|
||||
## System requirements
|
||||
|
||||
Chap is primarily developed on Linux with Python 3.11. Moderate effort will be made to support versions back to Python 3.9 (Debian oldstable).
|
||||
Chap is developed on Linux with Python 3.11. Due to use of the `X | Y` style of type hints, it is known to not work on Python 3.9 and older. The target minimum Python version is 3.11 (debian stable).
|
||||
|
||||
## Installation
|
||||
|
||||
|
|
@ -81,54 +81,8 @@ Put your OpenAI API key in the platform configuration directory for chap, e.g.,
|
|||
|
||||
* `chap grep needle`
|
||||
|
||||
## `@FILE` arguments
|
||||
|
||||
It's useful to set a bunch of related arguments together, for instance to fully
|
||||
configure a back-end. This functionality is implemented via `@FILE` arguments.
|
||||
|
||||
Before any other command-line argument parsing is performed, `@FILE` arguments are expanded:
|
||||
|
||||
* An `@FILE` argument is searched relative to the current directory
|
||||
* An `@:FILE` argument is searched relative to the configuration directory (e.g., $HOME/.config/chap/presets)
|
||||
* If an argument starts with a literal `@`, double it: `@@`
|
||||
* `@.` stops processing any further `@FILE` arguments and leaves them unchanged.
|
||||
The contents of an `@FILE` are parsed according to `shlex.split(comments=True)`.
|
||||
Comments are supported.
|
||||
A typical content might look like this:
|
||||
```
|
||||
# cfg/gpt-4o: Use more expensive gpt 4o and custom prompt
|
||||
--backend openai-chatgpt
|
||||
-B model:gpt-4o
|
||||
-s :my-custom-system-message.txt
|
||||
```
|
||||
and you might use it with
|
||||
```
|
||||
chap @:cfg/gpt-4o ask what version of gpt is this
|
||||
```
|
||||
|
||||
## Interactive terminal usage
|
||||
The interactive terminal mode is accessed via `chap tui`.
|
||||
|
||||
There are a variety of keyboard shortcuts to be aware of:
|
||||
* tab/shift-tab to move between the entry field and the conversation, or between conversation items
|
||||
* While in the text box, F9 or (if supported by your terminal) alt+enter to submit multiline text
|
||||
* while on a conversation item:
|
||||
* ctrl+x to re-draft the message. This
|
||||
* saves a copy of the session in an auto-named file in the conversations folder
|
||||
* removes the conversation from this message to the end
|
||||
* puts the user's message in the text box to edit
|
||||
* ctrl+x to re-submit the message. This
|
||||
* saves a copy of the session in an auto-named file in the conversations folder
|
||||
* removes the conversation from this message to the end
|
||||
* puts the user's message in the text box
|
||||
* and submits it immediately
|
||||
* ctrl+y to yank the message. This places the response part of the current
|
||||
interaction in the operating system clipboard to be pasted (e..g, with
|
||||
ctrl+v or command+v in other software)
|
||||
* ctrl+q to toggle whether this message may be included in the contextual history for a future query.
|
||||
The exact way history is submitted is determined by the back-end, often by
|
||||
counting messages or tokens, but the ctrl+q toggle ensures this message (both the user
|
||||
and assistant message parts) are not considered.
|
||||
* `chap tui`
|
||||
|
||||
## Sessions & Command-line Parameters
|
||||
|
||||
|
|
@ -144,26 +98,21 @@ an existing session with `-s`. Or, you can continue the last session with
|
|||
You can set the "system message" with the `-S` flag.
|
||||
|
||||
You can select the text generating backend with the `-b` flag:
|
||||
* openai-chatgpt: the default, paid API, best quality results. Also works with compatible API implementations including llama-cpp when the correct backend URL is specified.
|
||||
* openai-chatgpt: the default, paid API, best quality results
|
||||
* llama-cpp: Works with [llama.cpp's http server](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md) and can run locally with various models,
|
||||
though it is [optimized for models that use the llama2-style prompting](https://huggingface.co/blog/llama2#how-to-prompt-llama-2). Consider using llama.cpp's OpenAI compatible API with the openai-chatgpt backend instead, in which case the server can apply the chat template.
|
||||
though it is [optimized for models that use the llama2-style prompting](https://huggingface.co/blog/llama2#how-to-prompt-llama-2).
|
||||
Set the server URL with `-B url:...`.
|
||||
* textgen: Works with https://github.com/oobabooga/text-generation-webui and can run locally with various models.
|
||||
Needs the server URL in *$configuration_directory/textgen\_url*.
|
||||
* mistral: Works with the [mistral paid API](https://docs.mistral.ai/).
|
||||
* anthropic: Works with the [anthropic paid API](https://docs.anthropic.com/en/home).
|
||||
* huggingface: Works with the [huggingface API](https://huggingface.co/docs/api-inference/index), which includes a free tier.
|
||||
* lorem: local non-AI lorem generator for testing
|
||||
|
||||
Backends have settings such as URLs and where API keys are stored. use `chap --backend
|
||||
<BACKEND> --help` to list settings for a particular backend.
|
||||
|
||||
## Environment variables
|
||||
|
||||
The backend can be set with the `CHAP_BACKEND` environment variable.
|
||||
|
||||
Backend settings can be set with `CHAP_<backend_name>_<parameter_name>`, with `backend_name` and `parameter_name` all in caps.
|
||||
|
||||
For instance, `CHAP_LLAMA_CPP_URL=http://server.local:8080/completion` changes the default server URL for the llama-cpp backend.
|
||||
For instance, `CHAP_LLAMA_CPP_URL=http://server.local:8080/completion` changes the default server URL for the llama-cpp back-end.
|
||||
|
||||
## Importing from ChatGPT
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ src_path = project_root / "src"
|
|||
sys.path.insert(0, str(src_path))
|
||||
|
||||
if __name__ == "__main__":
|
||||
# pylint: disable=import-error,no-name-in-module
|
||||
from chap.core import main
|
||||
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -20,19 +20,15 @@ name="chap"
|
|||
authors = [{name = "Jeff Epler", email = "jepler@gmail.com"}]
|
||||
description = "Interact with the OpenAI ChatGPT API (and other text generators)"
|
||||
dynamic = ["readme","version","dependencies"]
|
||||
requires-python = ">=3.9"
|
||||
keywords = ["llm", "tui", "chatgpt"]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: Implementation :: CPython",
|
||||
"Programming Language :: Python :: Implementation :: PyPy",
|
||||
"Programming Language :: Python :: Implementation :: CPython",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
[project.urls]
|
||||
homepage = "https://github.com/jepler/chap"
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ click
|
|||
httpx
|
||||
lorem-text
|
||||
platformdirs
|
||||
pyperclip
|
||||
simple_parsing
|
||||
textual[syntax] >= 4
|
||||
textual>=0.18.0
|
||||
tiktoken
|
||||
websockets
|
||||
|
|
|
|||
|
|
@ -1,95 +0,0 @@
|
|||
# SPDX-FileCopyrightText: 2024 Jeff Epler <jepler@gmail.com>
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, Any
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core import AutoAskMixin, Backend
|
||||
from ..key import UsesKeyMixin
|
||||
from ..session import Assistant, Role, Session, User
|
||||
|
||||
|
||||
class Anthropic(AutoAskMixin, UsesKeyMixin):
|
||||
@dataclass
|
||||
class Parameters:
|
||||
url: str = "https://api.anthropic.com"
|
||||
model: str = "claude-3-5-sonnet-20240620"
|
||||
max_new_tokens: int = 1000
|
||||
api_key_name = "anthropic_api_key"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.parameters = self.Parameters()
|
||||
|
||||
system_message = """\
|
||||
Answer each question accurately and thoroughly.
|
||||
"""
|
||||
|
||||
def make_full_query(self, messages: Session, max_query_size: int) -> dict[str, Any]:
|
||||
system = [m.content for m in messages if m.role == Role.SYSTEM]
|
||||
messages = [m for m in messages if m.role != Role.SYSTEM and m.content]
|
||||
del messages[:-max_query_size]
|
||||
result = dict(
|
||||
model=self.parameters.model,
|
||||
max_tokens=self.parameters.max_new_tokens,
|
||||
messages=[dict(role=str(m.role), content=m.content) for m in messages],
|
||||
stream=True,
|
||||
)
|
||||
if system and system[0]:
|
||||
result["system"] = system[0]
|
||||
return result
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
session: Session,
|
||||
query: str,
|
||||
*,
|
||||
max_query_size: int = 5,
|
||||
timeout: float = 180,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
new_content: list[str] = []
|
||||
params = self.make_full_query(session + [User(query)], max_query_size)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.parameters.url}/v1/messages",
|
||||
json=params,
|
||||
headers={
|
||||
"x-api-key": self.get_key(),
|
||||
"content-type": "application/json",
|
||||
"anthropic-version": "2023-06-01",
|
||||
"anthropic-beta": "messages-2023-12-15",
|
||||
},
|
||||
) as response:
|
||||
if response.status_code == 200:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data:"):
|
||||
data = line.removeprefix("data:").strip()
|
||||
j = json.loads(data)
|
||||
content = j.get("delta", {}).get("text", "")
|
||||
if content:
|
||||
new_content.append(content)
|
||||
yield content
|
||||
else:
|
||||
content = f"\nFailed with {response=!r}"
|
||||
new_content.append(content)
|
||||
yield content
|
||||
async for line in response.aiter_lines():
|
||||
new_content.append(line)
|
||||
yield line
|
||||
except httpx.HTTPError as e:
|
||||
content = f"\nException: {e!r}"
|
||||
new_content.append(content)
|
||||
yield content
|
||||
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
|
||||
def factory() -> Backend:
|
||||
"""Uses the anthropic text-generation-interface web API"""
|
||||
return Anthropic()
|
||||
|
|
@ -9,11 +9,11 @@ from typing import Any, AsyncGenerator
|
|||
import httpx
|
||||
|
||||
from ..core import AutoAskMixin, Backend
|
||||
from ..key import UsesKeyMixin
|
||||
from ..key import get_key
|
||||
from ..session import Assistant, Role, Session, User
|
||||
|
||||
|
||||
class HuggingFace(AutoAskMixin, UsesKeyMixin):
|
||||
class HuggingFace(AutoAskMixin):
|
||||
@dataclass
|
||||
class Parameters:
|
||||
url: str = "https://api-inference.huggingface.co"
|
||||
|
|
@ -24,7 +24,6 @@ class HuggingFace(AutoAskMixin, UsesKeyMixin):
|
|||
after_user: str = """ [/INST] """
|
||||
after_assistant: str = """ </s><s>[INST] """
|
||||
stop_token_id = 2
|
||||
api_key_name = "huggingface_api_token"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
|
@ -90,7 +89,9 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
*,
|
||||
max_query_size: int = 5,
|
||||
timeout: float = 180,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
) -> AsyncGenerator[
|
||||
str, None
|
||||
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
new_content: list[str] = []
|
||||
inputs = self.make_full_query(session + [User(query)], max_query_size)
|
||||
try:
|
||||
|
|
@ -111,6 +112,10 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
@classmethod
|
||||
def get_key(cls) -> str:
|
||||
return get_key("huggingface_api_token")
|
||||
|
||||
|
||||
def factory() -> Backend:
|
||||
"""Uses the huggingface text-generation-interface web API"""
|
||||
|
|
|
|||
|
|
@ -18,16 +18,10 @@ class LlamaCpp(AutoAskMixin):
|
|||
url: str = "http://localhost:8080/completion"
|
||||
"""The URL of a llama.cpp server's completion endpoint."""
|
||||
|
||||
start_prompt: str = "<|begin_of_text|>"
|
||||
system_format: str = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
)
|
||||
user_format: str = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
assistant_format: str = (
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
)
|
||||
end_prompt: str = "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
stop: str | None = None
|
||||
start_prompt: str = """<s>[INST] <<SYS>>\n"""
|
||||
after_system: str = "\n<</SYS>>\n\n"
|
||||
after_user: str = """ [/INST] """
|
||||
after_assistant: str = """ </s><s>[INST] """
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
|
@ -40,16 +34,17 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
def make_full_query(self, messages: Session, max_query_size: int) -> str:
|
||||
del messages[1:-max_query_size]
|
||||
result = [self.parameters.start_prompt]
|
||||
formats = {
|
||||
Role.SYSTEM: self.parameters.system_format,
|
||||
Role.USER: self.parameters.user_format,
|
||||
Role.ASSISTANT: self.parameters.assistant_format,
|
||||
}
|
||||
for m in messages:
|
||||
content = (m.content or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
result.append(formats[m.role].format(content))
|
||||
result.append(content)
|
||||
if m.role == Role.SYSTEM:
|
||||
result.append(self.parameters.after_system)
|
||||
elif m.role == Role.ASSISTANT:
|
||||
result.append(self.parameters.after_assistant)
|
||||
elif m.role == Role.USER:
|
||||
result.append(self.parameters.after_user)
|
||||
full_query = "".join(result)
|
||||
return full_query
|
||||
|
||||
|
|
@ -60,11 +55,13 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
*,
|
||||
max_query_size: int = 5,
|
||||
timeout: float = 180,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
) -> AsyncGenerator[
|
||||
str, None
|
||||
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
params = {
|
||||
"prompt": self.make_full_query(session + [User(query)], max_query_size),
|
||||
"stream": True,
|
||||
"stop": ["</s>", "<s>", "[INST]", "<|eot_id|>"],
|
||||
"stop": ["</s>", "<s>", "[INST]"],
|
||||
}
|
||||
new_content: list[str] = []
|
||||
try:
|
||||
|
|
@ -101,10 +98,5 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
|
||||
|
||||
def factory() -> Backend:
|
||||
"""Uses the llama.cpp completion web API
|
||||
|
||||
Note: Consider using the openai-chatgpt backend with a custom URL instead.
|
||||
The llama.cpp server will automatically apply common chat templates with the
|
||||
openai-chatgpt backend, while chat templates must be manually configured client side
|
||||
with this backend."""
|
||||
"""Uses the llama.cpp completion web API"""
|
||||
return LlamaCpp()
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class Lorem:
|
|||
session: Session,
|
||||
query: str,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
data = self.ask(session, query)
|
||||
data = self.ask(session, query)[-1]
|
||||
for word, opt_sep in ipartition(data):
|
||||
yield word + opt_sep
|
||||
await asyncio.sleep(
|
||||
|
|
@ -56,7 +56,7 @@ class Lorem:
|
|||
self,
|
||||
session: Session,
|
||||
query: str,
|
||||
) -> str:
|
||||
) -> str: # pylint: disable=unused-argument
|
||||
new_content = cast(
|
||||
str,
|
||||
lorem.paragraphs(
|
||||
|
|
|
|||
|
|
@ -1,96 +0,0 @@
|
|||
# SPDX-FileCopyrightText: 2024 Jeff Epler <jepler@gmail.com>
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, Any
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core import AutoAskMixin
|
||||
from ..key import UsesKeyMixin
|
||||
from ..session import Assistant, Session, User
|
||||
|
||||
|
||||
class Mistral(AutoAskMixin, UsesKeyMixin):
|
||||
@dataclass
|
||||
class Parameters:
|
||||
url: str = "https://api.mistral.ai"
|
||||
model: str = "open-mistral-7b"
|
||||
max_new_tokens: int = 1000
|
||||
api_key_name = "mistral_api_key"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.parameters = self.Parameters()
|
||||
|
||||
system_message = """\
|
||||
Answer each question accurately and thoroughly.
|
||||
"""
|
||||
|
||||
def make_full_query(self, messages: Session, max_query_size: int) -> dict[str, Any]:
|
||||
messages = [m for m in messages if m.content]
|
||||
del messages[1:-max_query_size]
|
||||
result = dict(
|
||||
model=self.parameters.model,
|
||||
max_tokens=self.parameters.max_new_tokens,
|
||||
messages=[dict(role=str(m.role), content=m.content) for m in messages],
|
||||
stream=True,
|
||||
)
|
||||
return result
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
session: Session,
|
||||
query: str,
|
||||
*,
|
||||
max_query_size: int = 5,
|
||||
timeout: float = 180,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
new_content: list[str] = []
|
||||
params = self.make_full_query(session + [User(query)], max_query_size)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{self.parameters.url}/v1/chat/completions",
|
||||
json=params,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.get_key()}",
|
||||
"content-type": "application/json",
|
||||
"accept": "application/json",
|
||||
"model": "application/json",
|
||||
},
|
||||
) as response:
|
||||
if response.status_code == 200:
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data:"):
|
||||
data = line.removeprefix("data:").strip()
|
||||
if data == "[DONE]":
|
||||
break
|
||||
j = json.loads(data)
|
||||
content = (
|
||||
j.get("choices", [{}])[0]
|
||||
.get("delta", {})
|
||||
.get("content", "")
|
||||
)
|
||||
if content:
|
||||
new_content.append(content)
|
||||
yield content
|
||||
else:
|
||||
content = f"\nFailed with {response=!r}"
|
||||
new_content.append(content)
|
||||
yield content
|
||||
async for line in response.aiter_lines():
|
||||
new_content.append(line)
|
||||
yield line
|
||||
except httpx.HTTPError as e:
|
||||
content = f"\nException: {e!r}"
|
||||
new_content.append(content)
|
||||
yield content
|
||||
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
|
||||
factory = Mistral
|
||||
|
|
@ -12,7 +12,7 @@ import httpx
|
|||
import tiktoken
|
||||
|
||||
from ..core import Backend
|
||||
from ..key import UsesKeyMixin
|
||||
from ..key import get_key
|
||||
from ..session import Assistant, Message, Session, User, session_to_list
|
||||
|
||||
|
||||
|
|
@ -63,29 +63,15 @@ class EncodingMeta:
|
|||
return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)
|
||||
|
||||
|
||||
class ChatGPT(UsesKeyMixin):
|
||||
class ChatGPT:
|
||||
@dataclass
|
||||
class Parameters:
|
||||
model: str = "gpt-4o-mini"
|
||||
"""The model to use. The most common alternative value is 'gpt-4o'."""
|
||||
model: str = "gpt-3.5-turbo"
|
||||
"""The model to use. The most common alternative value is 'gpt-4'."""
|
||||
|
||||
max_request_tokens: int = 1024
|
||||
"""The approximate greatest number of tokens to send in a request. When the session is long, the system prompt and 1 or more of the most recent interaction steps are sent."""
|
||||
|
||||
url: str = "https://api.openai.com/v1/chat/completions"
|
||||
"""The URL of a chatgpt-compatible server's completion endpoint. Notably, llama.cpp's server is compatible with this backend, and can automatically apply common chat templates too."""
|
||||
|
||||
temperature: float | None = None
|
||||
"""The model temperature for sampling"""
|
||||
|
||||
top_p: float | None = None
|
||||
"""The model temperature for sampling"""
|
||||
|
||||
api_key_name: str = "openai_api_key"
|
||||
"""The OpenAI API key"""
|
||||
|
||||
parameters: Parameters
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.parameters = self.Parameters()
|
||||
|
||||
|
|
@ -111,11 +97,11 @@ class ChatGPT(UsesKeyMixin):
|
|||
def ask(self, session: Session, query: str, *, timeout: float = 60) -> str:
|
||||
full_prompt = self.make_full_prompt(session + [User(query)])
|
||||
response = httpx.post(
|
||||
self.parameters.url,
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
json={
|
||||
"model": self.parameters.model,
|
||||
"messages": session_to_list(full_prompt),
|
||||
},
|
||||
}, # pylint: disable=no-member
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.get_key()}",
|
||||
},
|
||||
|
|
@ -142,12 +128,10 @@ class ChatGPT(UsesKeyMixin):
|
|||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
self.parameters.url,
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
headers={"authorization": f"Bearer {self.get_key()}"},
|
||||
json={
|
||||
"model": self.parameters.model,
|
||||
"temperature": self.parameters.temperature,
|
||||
"top_p": self.parameters.top_p,
|
||||
"stream": True,
|
||||
"messages": session_to_list(full_prompt),
|
||||
},
|
||||
|
|
@ -176,6 +160,10 @@ class ChatGPT(UsesKeyMixin):
|
|||
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
@classmethod
|
||||
def get_key(cls) -> str:
|
||||
return get_key("openai_api_key")
|
||||
|
||||
|
||||
def factory() -> Backend:
|
||||
"""Uses the OpenAI chat completion API"""
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ USER: Hello, AI.
|
|||
|
||||
AI: Hello! How can I assist you today?"""
|
||||
|
||||
async def aask(
|
||||
async def aask( # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
self,
|
||||
session: Session,
|
||||
query: str,
|
||||
|
|
@ -60,11 +60,12 @@ AI: Hello! How can I assist you today?"""
|
|||
}
|
||||
full_prompt = session + [User(query)]
|
||||
del full_prompt[1:-max_query_size]
|
||||
new_data = old_data = full_query = "\n".join(
|
||||
f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt
|
||||
) + f"\n{role_map.get('assistant')}"
|
||||
new_data = old_data = full_query = (
|
||||
"\n".join(f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt)
|
||||
+ f"\n{role_map.get('assistant')}"
|
||||
)
|
||||
try:
|
||||
async with websockets.connect(
|
||||
async with websockets.connect( # pylint: disable=no-member
|
||||
f"ws://{self.parameters.server_hostname}:7860/queue/join"
|
||||
) as websocket:
|
||||
while content := json.loads(await websocket.recv()):
|
||||
|
|
@ -126,7 +127,7 @@ AI: Hello! How can I assist you today?"""
|
|||
# stop generation by closing the websocket here
|
||||
if content["msg"] == "process_completed":
|
||||
break
|
||||
except Exception as e:
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
content = f"\nException: {e!r}"
|
||||
new_data += content
|
||||
yield content
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import Iterable, Optional, Protocol
|
||||
from typing import Iterable, Protocol
|
||||
|
||||
import click
|
||||
import rich
|
||||
|
|
@ -40,7 +40,7 @@ class DumbPrinter:
|
|||
|
||||
|
||||
class WrappingPrinter:
|
||||
def __init__(self, width: Optional[int] = None) -> None:
|
||||
def __init__(self, width: int | None = None) -> None:
|
||||
self._width = width or rich.get_console().width
|
||||
self._column = 0
|
||||
self._line = ""
|
||||
|
|
@ -100,9 +100,8 @@ def verbose_ask(api: Backend, session: Session, q: str, print_prompt: bool) -> s
|
|||
|
||||
@command_uses_new_session
|
||||
@click.option("--print-prompt/--no-print-prompt", default=True)
|
||||
@click.option("--stdin/--no-stdin", "use_stdin", default=False)
|
||||
@click.argument("prompt", nargs=-1)
|
||||
def main(obj: Obj, prompt: list[str], use_stdin: bool, print_prompt: bool) -> None:
|
||||
@click.argument("prompt", nargs=-1, required=True)
|
||||
def main(obj: Obj, prompt: str, print_prompt: bool) -> None:
|
||||
"""Ask a question (command-line argument is passed as prompt)"""
|
||||
session = obj.session
|
||||
assert session is not None
|
||||
|
|
@ -113,16 +112,9 @@ def main(obj: Obj, prompt: list[str], use_stdin: bool, print_prompt: bool) -> No
|
|||
api = obj.api
|
||||
assert api is not None
|
||||
|
||||
if use_stdin:
|
||||
if prompt:
|
||||
raise click.UsageError("Can't use 'prompt' together with --stdin")
|
||||
|
||||
joined_prompt = sys.stdin.read()
|
||||
else:
|
||||
joined_prompt = " ".join(prompt)
|
||||
# symlink_session_filename(session_filename)
|
||||
|
||||
response = verbose_ask(api, session, joined_prompt, print_prompt=print_prompt)
|
||||
response = verbose_ask(api, session, " ".join(prompt), print_prompt=print_prompt)
|
||||
|
||||
print(f"Saving session to {session_filename}", file=sys.stderr)
|
||||
if response is not None:
|
||||
|
|
@ -130,4 +122,4 @@ def main(obj: Obj, prompt: list[str], use_stdin: bool, print_prompt: bool) -> No
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
|
|
|
|||
|
|
@ -38,4 +38,4 @@ def main(obj: Obj, no_system: bool) -> None:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ def list_files_matching_rx(
|
|||
"*.json"
|
||||
):
|
||||
try:
|
||||
session = session_from_file(conversation)
|
||||
except Exception as e:
|
||||
session = session_from_file(conversation) # pylint: disable=no-member
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
print(f"Failed to read {conversation}: {e}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
|
|
@ -67,4 +67,4 @@ def main(
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
|
|
|
|||
|
|
@ -82,4 +82,4 @@ def main(output_directory: pathlib.Path, files: list[TextIO]) -> None:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
|
|
|
|||
|
|
@ -45,4 +45,4 @@ def main(obj: Obj, no_system: bool) -> None:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
|
|
|
|||
|
|
@ -36,9 +36,6 @@ Markdown:focus-within {
|
|||
Markdown {
|
||||
border-left: heavy transparent;
|
||||
}
|
||||
MarkdownFence {
|
||||
max-height: 9999;
|
||||
}
|
||||
Footer {
|
||||
dock: top;
|
||||
}
|
||||
|
|
@ -57,5 +54,3 @@ Input {
|
|||
Markdown {
|
||||
margin: 0 1 0 0;
|
||||
}
|
||||
|
||||
SubmittableTextArea { height: auto; min-height: 5; margin: 0; border: none; border-left: heavy $primary }
|
||||
|
|
|
|||
|
|
@ -3,58 +3,35 @@
|
|||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Any, Optional, cast, TYPE_CHECKING
|
||||
from typing import cast
|
||||
|
||||
import click
|
||||
from markdown_it import MarkdownIt
|
||||
from textual import work
|
||||
from textual._ansi_sequences import ANSI_SEQUENCES_KEYS
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Container, Horizontal, VerticalScroll
|
||||
from textual.keys import Keys
|
||||
from textual.widgets import Button, Footer, LoadingIndicator, Markdown, TextArea
|
||||
from textual.widgets import Button, Footer, Input, LoadingIndicator, Markdown
|
||||
|
||||
from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path
|
||||
from ..session import Assistant, Message, Session, User, new_session, session_to_file
|
||||
|
||||
|
||||
# workaround for pyperclip being un-typed
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def pyperclip_copy(data: str) -> None:
|
||||
...
|
||||
else:
|
||||
from pyperclip import copy as pyperclip_copy
|
||||
|
||||
|
||||
# Monkeypatch alt+enter as meaning "F9", WFM
|
||||
# ignore typing here because ANSI_SEQUENCES_KEYS is a Mapping[] which is read-only as
|
||||
# far as mypy is concerned.
|
||||
ANSI_SEQUENCES_KEYS["\x1b\r"] = (Keys.F9,) # type: ignore
|
||||
ANSI_SEQUENCES_KEYS["\x1b\n"] = (Keys.F9,) # type: ignore
|
||||
|
||||
|
||||
class SubmittableTextArea(TextArea):
|
||||
BINDINGS = [
|
||||
Binding("f9", "app.submit", "Submit", show=True),
|
||||
Binding("tab", "focus_next", show=False, priority=True), # no inserting tabs
|
||||
]
|
||||
|
||||
|
||||
def parser_factory() -> MarkdownIt:
|
||||
parser = MarkdownIt()
|
||||
parser.options["html"] = False
|
||||
return parser
|
||||
|
||||
|
||||
class ChapMarkdown(Markdown, can_focus=True, can_focus_children=False):
|
||||
class ChapMarkdown(
|
||||
Markdown, can_focus=True, can_focus_children=False
|
||||
): # pylint: disable=function-redefined
|
||||
BINDINGS = [
|
||||
Binding("ctrl+c", "app.yank", "Copy text", show=True),
|
||||
Binding("ctrl+r", "app.resubmit", "resubmit", show=True),
|
||||
Binding("ctrl+x", "app.redraft", "redraft", show=True),
|
||||
Binding("ctrl+q", "app.toggle_history", "history toggle", show=True),
|
||||
Binding("ctrl+y", "yank", "Yank text", show=True),
|
||||
Binding("ctrl+r", "resubmit", "resubmit", show=True),
|
||||
Binding("ctrl+x", "redraft", "redraft", show=True),
|
||||
Binding("ctrl+q", "toggle_history", "history toggle", show=True),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -75,14 +52,14 @@ class CancelButton(Button):
|
|||
class Tui(App[None]):
|
||||
CSS_PATH = "tui.css"
|
||||
BINDINGS = [
|
||||
Binding("ctrl+q", "quit", "Quit", show=True, priority=True),
|
||||
Binding("ctrl+c", "quit", "Quit", show=True, priority=True),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self, api: Optional[Backend] = None, session: Optional[Session] = None
|
||||
self, api: Backend | None = None, session: Session | None = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.api = api or get_api(click.Context(click.Command("chap tui")), "lorem")
|
||||
self.api = api or get_api("lorem")
|
||||
self.session = (
|
||||
new_session(self.api.system_message) if session is None else session
|
||||
)
|
||||
|
|
@ -96,8 +73,8 @@ class Tui(App[None]):
|
|||
return cast(VerticalScroll, self.query_one("#wait"))
|
||||
|
||||
@property
|
||||
def input(self) -> SubmittableTextArea:
|
||||
return self.query_one(SubmittableTextArea)
|
||||
def input(self) -> Input:
|
||||
return self.query_one(Input)
|
||||
|
||||
@property
|
||||
def cancel_button(self) -> CancelButton:
|
||||
|
|
@ -117,9 +94,7 @@ class Tui(App[None]):
|
|||
Container(id="pad"),
|
||||
id="content",
|
||||
)
|
||||
s = SubmittableTextArea(language="markdown")
|
||||
s.show_line_numbers = False
|
||||
yield s
|
||||
yield Input(placeholder="Prompt")
|
||||
with Horizontal(id="wait"):
|
||||
yield LoadingIndicator()
|
||||
yield CancelButton(label="❌ Stop Generation", id="cancel", disabled=True)
|
||||
|
|
@ -128,8 +103,8 @@ class Tui(App[None]):
|
|||
self.container.scroll_end(animate=False)
|
||||
self.input.focus()
|
||||
|
||||
async def action_submit(self) -> None:
|
||||
self.get_completion(self.input.text)
|
||||
async def on_input_submitted(self, event: Input.Submitted) -> None:
|
||||
self.get_completion(event.value)
|
||||
|
||||
@work(exclusive=True)
|
||||
async def get_completion(self, query: str) -> None:
|
||||
|
|
@ -145,6 +120,7 @@ class Tui(App[None]):
|
|||
await self.container.mount_all(
|
||||
[markdown_for_step(User(query)), output], before="#pad"
|
||||
)
|
||||
tokens: list[str] = []
|
||||
update: asyncio.Queue[bool] = asyncio.Queue(1)
|
||||
|
||||
for markdown in self.container.children:
|
||||
|
|
@ -165,22 +141,15 @@ class Tui(App[None]):
|
|||
)
|
||||
|
||||
async def render_fun() -> None:
|
||||
old_len = 0
|
||||
while await update.get():
|
||||
content = message.content
|
||||
new_len = len(content)
|
||||
new_content = content[old_len:new_len]
|
||||
if new_content:
|
||||
if old_len:
|
||||
await output.append(new_content)
|
||||
else:
|
||||
output.update(content)
|
||||
self.container.scroll_end()
|
||||
old_len = new_len
|
||||
await asyncio.sleep(0.01)
|
||||
if tokens:
|
||||
output.update("".join(tokens).strip())
|
||||
self.container.scroll_end()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def get_token_fun() -> None:
|
||||
async for token in self.api.aask(session, query):
|
||||
tokens.append(token)
|
||||
message.content += token
|
||||
try:
|
||||
update.put_nowait(True)
|
||||
|
|
@ -193,10 +162,10 @@ class Tui(App[None]):
|
|||
try:
|
||||
await asyncio.gather(render_fun(), get_token_fun())
|
||||
finally:
|
||||
self.input.clear()
|
||||
self.input.value = ""
|
||||
all_output = self.session[-1].content
|
||||
output.update(all_output)
|
||||
output._markdown = all_output
|
||||
output._markdown = all_output # pylint: disable=protected-access
|
||||
self.container.scroll_end()
|
||||
|
||||
for markdown in self.container.children:
|
||||
|
|
@ -214,8 +183,8 @@ class Tui(App[None]):
|
|||
def action_yank(self) -> None:
|
||||
widget = self.focused
|
||||
if isinstance(widget, ChapMarkdown):
|
||||
content = widget._markdown or ""
|
||||
pyperclip_copy(content)
|
||||
content = widget._markdown or "" # pylint: disable=protected-access
|
||||
subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)
|
||||
|
||||
def action_toggle_history(self) -> None:
|
||||
widget = self.focused
|
||||
|
|
@ -235,7 +204,9 @@ class Tui(App[None]):
|
|||
async def action_stop_generating(self) -> None:
|
||||
self.workers.cancel_all()
|
||||
|
||||
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
async def on_button_pressed( # pylint: disable=unused-argument
|
||||
self, event: Button.Pressed
|
||||
) -> None:
|
||||
self.workers.cancel_all()
|
||||
|
||||
async def action_quit(self) -> None:
|
||||
|
|
@ -264,31 +235,19 @@ class Tui(App[None]):
|
|||
session_to_file(self.session, new_session_path())
|
||||
|
||||
query = self.session[idx].content
|
||||
self.input.load_text(query)
|
||||
self.input.value = query
|
||||
|
||||
del self.session[idx:]
|
||||
for child in self.container.children[idx:-1]:
|
||||
await child.remove()
|
||||
|
||||
self.input.focus()
|
||||
self.on_text_area_changed()
|
||||
if resubmit:
|
||||
await self.action_submit()
|
||||
|
||||
def on_text_area_changed(self, event: Any = None) -> None:
|
||||
height = self.input.document.get_size(self.input.indent_width)[1]
|
||||
max_height = max(3, self.size.height - 6)
|
||||
if height >= max_height:
|
||||
self.input.styles.height = max_height
|
||||
elif height <= 3:
|
||||
self.input.styles.height = 3
|
||||
else:
|
||||
self.input.styles.height = height
|
||||
await self.input.action_submit()
|
||||
|
||||
|
||||
@command_uses_new_session
|
||||
@click.option("--replace-system-prompt/--no-replace-system-prompt", default=False)
|
||||
def main(obj: Obj, replace_system_prompt: bool) -> None:
|
||||
def main(obj: Obj) -> None:
|
||||
"""Start interactive terminal user interface session"""
|
||||
api = obj.api
|
||||
assert api is not None
|
||||
|
|
@ -297,11 +256,6 @@ def main(obj: Obj, replace_system_prompt: bool) -> None:
|
|||
session_filename = obj.session_filename
|
||||
assert session_filename is not None
|
||||
|
||||
if replace_system_prompt:
|
||||
session[0].content = (
|
||||
api.system_message if obj.system_message is None else obj.system_message
|
||||
)
|
||||
|
||||
tui = Tui(api, session)
|
||||
tui.run()
|
||||
|
||||
|
|
@ -314,4 +268,4 @@ def main(obj: Obj, replace_system_prompt: bool) -> None:
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main() # pylint: disable=no-value-for-parameter
|
||||
|
|
|
|||
305
src/chap/core.py
305
src/chap/core.py
|
|
@ -1,53 +1,32 @@
|
|||
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
# pylint: disable=import-outside-toplevel
|
||||
|
||||
|
||||
from collections.abc import Sequence
|
||||
import asyncio
|
||||
import datetime
|
||||
import io
|
||||
import importlib
|
||||
import os
|
||||
import pathlib
|
||||
import pkgutil
|
||||
import subprocess
|
||||
import shlex
|
||||
import textwrap
|
||||
from dataclasses import MISSING, Field, dataclass, fields
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Optional,
|
||||
Union,
|
||||
IO,
|
||||
cast,
|
||||
get_origin,
|
||||
get_args,
|
||||
)
|
||||
import sys
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from types import UnionType
|
||||
from typing import Any, AsyncGenerator, Callable, cast
|
||||
|
||||
import click
|
||||
import platformdirs
|
||||
from simple_parsing.docstring import get_attribute_docstring
|
||||
from typing_extensions import Protocol
|
||||
from . import backends, commands
|
||||
from .session import Message, Session, System, session_from_file
|
||||
|
||||
UnionType: type
|
||||
if sys.version_info >= (3, 10):
|
||||
from types import UnionType
|
||||
else:
|
||||
UnionType = type(Union[int, float])
|
||||
from . import backends, commands # pylint: disable=no-name-in-module
|
||||
from .session import Message, Session, System, session_from_file
|
||||
|
||||
conversations_path = platformdirs.user_state_path("chap") / "conversations"
|
||||
conversations_path.mkdir(parents=True, exist_ok=True)
|
||||
configuration_path = platformdirs.user_config_path("chap")
|
||||
preset_path = configuration_path / "preset"
|
||||
|
||||
|
||||
class ABackend(Protocol):
|
||||
class ABackend(Protocol): # pylint: disable=too-few-public-methods
|
||||
def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
|
||||
"""Make a query, updating the session with the query and response, returning the query token by token"""
|
||||
|
||||
|
|
@ -60,7 +39,7 @@ class Backend(ABackend, Protocol):
|
|||
"""Make a query, updating the session with the query and response, returning the query"""
|
||||
|
||||
|
||||
class AutoAskMixin:
|
||||
class AutoAskMixin: # pylint: disable=too-few-public-methods
|
||||
"""Mixin class for backends implementing aask"""
|
||||
|
||||
def ask(self, session: Session, query: str) -> str:
|
||||
|
|
@ -75,50 +54,20 @@ class AutoAskMixin:
|
|||
return "".join(tokens)
|
||||
|
||||
|
||||
def last_session_path() -> Optional[pathlib.Path]:
|
||||
def last_session_path() -> pathlib.Path | None:
|
||||
result = max(
|
||||
conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def new_session_path(opt_path: Optional[pathlib.Path] = None) -> pathlib.Path:
|
||||
def new_session_path(opt_path: pathlib.Path | None = None) -> pathlib.Path:
|
||||
return opt_path or conversations_path / (
|
||||
datetime.datetime.now().isoformat().replace(":", "_") + ".json"
|
||||
)
|
||||
|
||||
|
||||
def get_field_type(field: Field[Any]) -> Any:
|
||||
field_type = field.type
|
||||
if isinstance(field_type, str):
|
||||
raise RuntimeError(
|
||||
"parameters dataclass may not use 'from __future__ import annotations"
|
||||
)
|
||||
origin = get_origin(field_type)
|
||||
if origin in (Union, UnionType):
|
||||
for arg in get_args(field_type):
|
||||
if arg is not None:
|
||||
return arg
|
||||
return field_type
|
||||
|
||||
|
||||
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
|
||||
field_type = get_field_type(field)
|
||||
try:
|
||||
if field_type is bool:
|
||||
tv = click.types.BoolParamType().convert(value, None, ctx)
|
||||
else:
|
||||
tv = field_type(value)
|
||||
return tv
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(
|
||||
f"Invalid value for {field.name} with value {value}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def configure_api_from_environment(
|
||||
ctx: click.Context, api_name: str, api: Backend
|
||||
) -> None:
|
||||
def configure_api_from_environment(api_name: str, api: Backend) -> None:
|
||||
if not hasattr(api, "parameters"):
|
||||
return
|
||||
|
||||
|
|
@ -127,25 +76,26 @@ def configure_api_from_environment(
|
|||
value = os.environ.get(envvar)
|
||||
if value is None:
|
||||
continue
|
||||
tv = convert_str_to_field(ctx, field, value)
|
||||
try:
|
||||
tv = field.type(value)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(
|
||||
f"Invalid value for {field.name} with value {value}: {e}"
|
||||
) from e
|
||||
setattr(api.parameters, field.name, tv)
|
||||
|
||||
|
||||
def get_api(ctx: click.Context | None = None, name: str | None = None) -> Backend:
|
||||
if ctx is None:
|
||||
ctx = click.Context(click.Command("chap"))
|
||||
if name is None:
|
||||
name = os.environ.get("CHAP_BACKEND", "openai_chatgpt")
|
||||
def get_api(name: str = "openai_chatgpt") -> Backend:
|
||||
name = name.replace("-", "_")
|
||||
backend = cast(
|
||||
result = cast(
|
||||
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
|
||||
)
|
||||
configure_api_from_environment(ctx, name, backend)
|
||||
return backend
|
||||
configure_api_from_environment(name, result)
|
||||
return result
|
||||
|
||||
|
||||
def do_session_continue(
|
||||
ctx: click.Context, param: click.Parameter, value: Optional[pathlib.Path]
|
||||
ctx: click.Context, param: click.Parameter, value: pathlib.Path | None
|
||||
) -> None:
|
||||
if value is None:
|
||||
return
|
||||
|
|
@ -158,7 +108,9 @@ def do_session_continue(
|
|||
ctx.obj.session_filename = value
|
||||
|
||||
|
||||
def do_session_last(ctx: click.Context, param: click.Parameter, value: bool) -> None:
|
||||
def do_session_last(
|
||||
ctx: click.Context, param: click.Parameter, value: bool
|
||||
) -> None: # pylint: disable=unused-argument
|
||||
if not value:
|
||||
return
|
||||
do_session_continue(ctx, param, last_session_path())
|
||||
|
|
@ -186,22 +138,18 @@ def colonstr(arg: str) -> tuple[str, str]:
|
|||
return cast(tuple[str, str], tuple(arg.split(":", 1)))
|
||||
|
||||
|
||||
def set_system_message(ctx: click.Context, param: click.Parameter, value: str) -> None:
|
||||
if value is None:
|
||||
return
|
||||
def set_system_message( # pylint: disable=unused-argument
|
||||
ctx: click.Context, param: click.Parameter, value: str
|
||||
) -> None:
|
||||
if value and value.startswith("@"):
|
||||
with open(value[1:], "r", encoding="utf-8") as f:
|
||||
value = f.read().rstrip()
|
||||
ctx.obj.system_message = value
|
||||
|
||||
|
||||
def set_system_message_from_file(
|
||||
ctx: click.Context, param: click.Parameter, value: io.TextIOWrapper
|
||||
def set_backend( # pylint: disable=unused-argument
|
||||
ctx: click.Context, param: click.Parameter, value: str
|
||||
) -> None:
|
||||
if value is None:
|
||||
return
|
||||
content = value.read().strip()
|
||||
ctx.obj.system_message = content
|
||||
|
||||
|
||||
def set_backend(ctx: click.Context, param: click.Parameter, value: str) -> None:
|
||||
if value == "list":
|
||||
formatter = ctx.make_formatter()
|
||||
format_backend_list(formatter)
|
||||
|
|
@ -209,7 +157,7 @@ def set_backend(ctx: click.Context, param: click.Parameter, value: str) -> None:
|
|||
ctx.exit()
|
||||
|
||||
try:
|
||||
ctx.obj.api = get_api(ctx, value)
|
||||
ctx.obj.api = get_api(value)
|
||||
except ModuleNotFoundError as e:
|
||||
raise click.BadParameter(str(e))
|
||||
|
||||
|
|
@ -224,13 +172,15 @@ def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None:
|
|||
if doc:
|
||||
doc += " "
|
||||
doc += f"(Default: {default!r})"
|
||||
f_type = get_field_type(f)
|
||||
f_type = f.type
|
||||
if isinstance(f_type, UnionType):
|
||||
f_type = f_type.__args__[0]
|
||||
typename = f_type.__name__
|
||||
rows.append((f"-B {name}:{typename.upper()}", doc))
|
||||
formatter.write_dl(rows)
|
||||
|
||||
|
||||
def set_backend_option(
|
||||
def set_backend_option( # pylint: disable=unused-argument
|
||||
ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]]
|
||||
) -> None:
|
||||
api = ctx.obj.api
|
||||
|
|
@ -245,8 +195,16 @@ def set_backend_option(
|
|||
field = all_fields.get(name)
|
||||
if field is None:
|
||||
raise click.BadParameter(f"Invalid parameter {name}")
|
||||
tv = convert_str_to_field(ctx, field, value)
|
||||
setattr(api.parameters, field.name, tv)
|
||||
f_type = field.type
|
||||
if isinstance(f_type, UnionType):
|
||||
f_type = f_type.__args__[0]
|
||||
try:
|
||||
tv = f_type(value)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(
|
||||
f"Invalid value for {name} with value {value}: {e}"
|
||||
) from e
|
||||
setattr(api.parameters, name, tv)
|
||||
|
||||
for kv in opts:
|
||||
set_one_backend_option(kv)
|
||||
|
|
@ -306,7 +264,9 @@ def command_uses_new_session(f_in: click.decorators.FC) -> click.Command:
|
|||
return click.command()(f)
|
||||
|
||||
|
||||
def version_callback(ctx: click.Context, param: click.Parameter, value: None) -> None:
|
||||
def version_callback( # pylint: disable=unused-argument
|
||||
ctx: click.Context, param: click.Parameter, value: None
|
||||
) -> None:
|
||||
if not value or ctx.resilient_parsing:
|
||||
return
|
||||
|
||||
|
|
@ -333,54 +293,18 @@ def version_callback(ctx: click.Context, param: click.Parameter, value: None) ->
|
|||
|
||||
@dataclass
|
||||
class Obj:
|
||||
api: Optional[Backend] = None
|
||||
system_message: Optional[str] = None
|
||||
session: Optional[list[Message]] = None
|
||||
session_filename: Optional[pathlib.Path] = None
|
||||
api: Backend | None = None
|
||||
system_message: str | None = None
|
||||
session: list[Message] | None = None
|
||||
session_filename: pathlib.Path | None = None
|
||||
|
||||
|
||||
def maybe_add_txt_extension(fn: pathlib.Path) -> pathlib.Path:
|
||||
if not fn.exists():
|
||||
fn1 = pathlib.Path(str(fn) + ".txt")
|
||||
if fn1.exists():
|
||||
fn = fn1
|
||||
return fn
|
||||
|
||||
|
||||
def expand_splats(args: list[str]) -> list[str]:
|
||||
result = []
|
||||
saw_at_dot = False
|
||||
for a in args:
|
||||
if a == "@.":
|
||||
saw_at_dot = True
|
||||
continue
|
||||
if saw_at_dot:
|
||||
result.append(a)
|
||||
continue
|
||||
if a.startswith("@@"): ## double @ to escape an argument that starts with @
|
||||
result.append(a[1:])
|
||||
continue
|
||||
if not a.startswith("@"):
|
||||
result.append(a)
|
||||
continue
|
||||
if a.startswith("@:"):
|
||||
fn: pathlib.Path = preset_path / a[2:]
|
||||
else:
|
||||
fn = pathlib.Path(a[1:])
|
||||
fn = maybe_add_txt_extension(fn)
|
||||
with open(fn, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
parts = shlex.split(content)
|
||||
result.extend(expand_splats(parts))
|
||||
return result
|
||||
|
||||
|
||||
class MyCLI(click.Group):
|
||||
class MyCLI(click.MultiCommand):
|
||||
def make_context(
|
||||
self,
|
||||
info_name: Optional[str],
|
||||
info_name: str | None,
|
||||
args: list[str],
|
||||
parent: Optional[click.Context] = None,
|
||||
parent: click.Context | None = None,
|
||||
**extra: Any,
|
||||
) -> click.Context:
|
||||
result = super().make_context(info_name, args, parent, obj=Obj(), **extra)
|
||||
|
|
@ -404,117 +328,14 @@ class MyCLI(click.Group):
|
|||
except ModuleNotFoundError as exc:
|
||||
raise click.UsageError(f"Invalid subcommand {cmd_name!r}", ctx) from exc
|
||||
|
||||
def gather_preset_info(self) -> list[tuple[str, str]]:
|
||||
result = []
|
||||
for p in preset_path.glob("*"):
|
||||
if p.is_file():
|
||||
with p.open() as f:
|
||||
first_line = f.readline()
|
||||
if first_line.startswith("#"):
|
||||
help_str = first_line[1:].strip()
|
||||
else:
|
||||
help_str = "(A comment on the first line would be shown here)"
|
||||
result.append((f"@:{p.name}", help_str))
|
||||
return result
|
||||
|
||||
def format_splat_options(
|
||||
self, ctx: click.Context, formatter: click.HelpFormatter
|
||||
) -> None:
|
||||
with formatter.section("Splats"):
|
||||
formatter.write_text(
|
||||
"Before any other command-line argument parsing is performed, @FILE arguments are expanded:"
|
||||
)
|
||||
formatter.write_paragraph()
|
||||
formatter.indent()
|
||||
formatter.write_dl(
|
||||
[
|
||||
("@FILE", "Argument is searched relative to the current directory"),
|
||||
(
|
||||
"@:FILE",
|
||||
"Argument is searched relative to the configuration directory (e.g., $HOME/.config/chap/preset)",
|
||||
),
|
||||
("@@…", "If an argument starts with a literal '@', double it"),
|
||||
(
|
||||
"@.",
|
||||
"Stops processing any further `@FILE` arguments and leaves them unchanged.",
|
||||
),
|
||||
]
|
||||
)
|
||||
formatter.dedent()
|
||||
formatter.write_paragraph()
|
||||
formatter.write_text(
|
||||
textwrap.dedent(
|
||||
"""\
|
||||
The contents of an `@FILE` are parsed by `shlex.split(comments=True)`.
|
||||
Comments are supported. If the filename ends in .txt,
|
||||
the extension may be omitted."""
|
||||
)
|
||||
)
|
||||
formatter.write_paragraph()
|
||||
if preset_info := self.gather_preset_info():
|
||||
formatter.write_text("Presets found:")
|
||||
formatter.write_paragraph()
|
||||
formatter.indent()
|
||||
formatter.write_dl(preset_info)
|
||||
formatter.dedent()
|
||||
|
||||
def format_options(
|
||||
self, ctx: click.Context, formatter: click.HelpFormatter
|
||||
) -> None:
|
||||
self.format_splat_options(ctx, formatter)
|
||||
super().format_options(ctx, formatter)
|
||||
api = ctx.obj.api or get_api(ctx)
|
||||
api = ctx.obj.api or get_api()
|
||||
if hasattr(api, "parameters"):
|
||||
format_backend_help(api, formatter)
|
||||
|
||||
def main(
|
||||
self,
|
||||
args: Sequence[str] | None = None,
|
||||
prog_name: str | None = None,
|
||||
complete_var: str | None = None,
|
||||
standalone_mode: bool = True,
|
||||
windows_expand_args: bool = True,
|
||||
**extra: Any,
|
||||
) -> Any:
|
||||
if args is None:
|
||||
args = sys.argv[1:]
|
||||
if os.name == "nt" and windows_expand_args:
|
||||
args = click.utils._expand_args(args)
|
||||
else:
|
||||
args = list(args)
|
||||
|
||||
args = expand_splats(args)
|
||||
|
||||
return super().main(
|
||||
args,
|
||||
prog_name=prog_name,
|
||||
complete_var=complete_var,
|
||||
standalone_mode=standalone_mode,
|
||||
windows_expand_args=windows_expand_args,
|
||||
**extra,
|
||||
)
|
||||
|
||||
|
||||
class ConfigRelativeFile(click.File):
    """A click.File whose ``:name`` values resolve inside the chap config tree.

    A string value beginning with ``:`` is looked up under
    ``{configuration_path}/{where}``; any other string is treated as an
    ordinary filesystem path.  Path values may have a missing ``.txt``
    extension supplied by maybe_add_txt_extension().
    """

    def __init__(self, mode: str, where: str) -> None:
        super().__init__(mode)
        # Sub-directory of the configuration tree searched for ":" values.
        self.where = where

    def convert(
        self,
        value: str | os.PathLike[str] | IO[Any],
        param: click.Parameter | None,
        ctx: click.Context | None,
    ) -> Any:
        """Resolve config-relative names, then defer to click.File.convert."""
        if isinstance(value, str):
            value = (
                configuration_path / self.where / value[1:]
                if value.startswith(":")
                else pathlib.Path(value)
            )
        if isinstance(value, pathlib.Path):
            value = maybe_add_txt_extension(value)
        return super().convert(value, param, ctx)
|
||||
|
||||
|
||||
main = MyCLI(
|
||||
help="Commandline interface to ChatGPT",
|
||||
|
|
@ -526,14 +347,6 @@ main = MyCLI(
|
|||
help="Show the version and exit",
|
||||
callback=version_callback,
|
||||
),
|
||||
click.Option(
|
||||
("--system-message-file", "-s"),
|
||||
type=ConfigRelativeFile("r", where="prompt"),
|
||||
default=None,
|
||||
callback=set_system_message_from_file,
|
||||
expose_value=False,
|
||||
help=f"Set the system message from a file. If the filename starts with `:` it is relative to the {configuration_path}/prompt. If the filename ends in .txt, the extension may be omitted.",
|
||||
),
|
||||
click.Option(
|
||||
("--system-message", "-S"),
|
||||
type=str,
|
||||
|
|
|
|||
|
|
@ -2,63 +2,25 @@
|
|||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
from typing import Protocol
|
||||
import functools
|
||||
|
||||
import platformdirs
|
||||
|
||||
|
||||
class APIKeyProtocol(Protocol):
    """Structural type for objects that name the API key they use."""

    @property
    def api_key_name(self) -> str:
        """Name under which this backend's API key is stored."""
|
||||
|
||||
|
||||
class HasKeyProtocol(Protocol):
    """Structural type for backends whose parameters carry an api_key_name."""

    @property
    def parameters(self) -> APIKeyProtocol:
        """Parameter object exposing ``api_key_name``."""
|
||||
|
||||
|
||||
class UsesKeyMixin:
    """Mixin giving a backend a get_key() based on its parameters' key name."""

    def get_key(self: HasKeyProtocol) -> str:
        """Fetch this backend's API key via the module-level get_key()."""
        return get_key(self.parameters.api_key_name)
|
||||
|
||||
|
||||
class NoKeyAvailable(Exception):
    """Raised when no API key can be located for a backend."""
|
||||
|
||||
|
||||
# Per-user configuration directory for chap (platformdirs chooses the
# platform-appropriate location, e.g. ~/.config/chap on Linux).
_key_path_base = platformdirs.user_config_path("chap")

# Presence of this file switches key lookup to the pass(1) password store;
# its (optional) JSON content configures the pass command and key prefix.
USE_PASSWORD_STORE = _key_path_base / "USE_PASSWORD_STORE"
|
||||
|
||||
if USE_PASSWORD_STORE.exists():
    # Key storage is delegated to pass(1).  The marker file may optionally
    # contain JSON overriding the pass command and the prefix for key names.
    content = USE_PASSWORD_STORE.read_text(encoding="utf-8")
    if content.strip():
        cfg = json.loads(content)
        pass_command: list[str] = cfg.get("PASS_COMMAND", ["pass", "show"])
        pass_prefix: str = cfg.get("PASS_PREFIX", "chap/")
    else:
        # BUGFIX: an existing-but-empty marker file previously left
        # pass_command / pass_prefix undefined, so get_key() raised
        # NameError.  Fall back to the same defaults used for "{}".
        pass_command = ["pass", "show"]
        pass_prefix = "chap/"

    @functools.cache
    def get_key(name: str, what: str = "api key") -> str:
        """Fetch the key *name* from the pass(1) password store.

        The literal name "-" is passed through unchanged.  Propagates
        subprocess.CalledProcessError if the pass command fails.
        """
        if name == "-":
            return "-"
        key_path = f"{pass_prefix}{name}"
        command = pass_command + [key_path]
        # pass prints the secret on the first line of its output.
        return subprocess.check_output(command, encoding="utf-8").split("\n")[0]

else:

    @functools.cache
    def get_key(name: str, what: str = "api key") -> str:
        """Read the key *name* from a plain file in the chap config directory.

        Raises NoKeyAvailable (with a hint describing *what* belongs in the
        file) when the key file does not exist.
        """
        key_path = _key_path_base / name
        if not key_path.exists():
            raise NoKeyAvailable(
                f"Place your {what} in {key_path} and run the program again"
            )

        with open(key_path, encoding="utf-8") as f:
            return f.read().strip()
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ from __future__ import annotations
|
|||
import json
|
||||
import pathlib
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Union, cast
|
||||
from typing import cast
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
# not an enum.Enum because enum members are not json-serializable, sigh
class Role:  # pylint: disable=too-few-public-methods
    """Plain-string constants naming the three chat message roles."""

    ASSISTANT = "assistant"
    SYSTEM = "system"
    USER = "user"
|
||||
|
|
@ -65,11 +65,11 @@ def session_from_json(data: str) -> Session:
|
|||
return [Message(**mapping) for mapping in j]
|
||||
|
||||
|
||||
def session_from_file(path: pathlib.Path | str) -> Session:
    """Read *path* as UTF-8 JSON and deserialize it into a Session."""
    text = pathlib.Path(path).read_text(encoding="utf-8")
    return session_from_json(text)
|
||||
|
||||
|
||||
def session_to_file(session: Session, path: pathlib.Path | str) -> None:
    """Serialize *session* to JSON and write it to *path* as UTF-8."""
    pathlib.Path(path).write_text(session_to_json(session), encoding="utf-8")
|
||||
|
|
|
|||
Loading…
Reference in a new issue