Compare commits
No commits in common. "main" and "document-mypy" have entirely different histories.
main
...
document-m
28 changed files with 383 additions and 740 deletions
48
.github/workflows/codeql.yml
vendored
Normal file
48
.github/workflows/codeql.yml
vendored
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
# SPDX-FileCopyrightText: 2022 Jeff Epler
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: CC0-1.0
|
||||||
|
|
||||||
|
name: "CodeQL"
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ "main" ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ "main" ]
|
||||||
|
schedule:
|
||||||
|
- cron: "53 3 * * 5"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
analyze:
|
||||||
|
name: Analyze
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
actions: read
|
||||||
|
contents: read
|
||||||
|
security-events: write
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
language: [ python ]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Install Dependencies (python)
|
||||||
|
run: pip3 install -r requirements-dev.txt
|
||||||
|
|
||||||
|
- name: Initialize CodeQL
|
||||||
|
uses: github/codeql-action/init@v2
|
||||||
|
with:
|
||||||
|
languages: ${{ matrix.language }}
|
||||||
|
queries: +security-and-quality
|
||||||
|
|
||||||
|
- name: Autobuild
|
||||||
|
uses: github/codeql-action/autobuild@v2
|
||||||
|
|
||||||
|
- name: Perform CodeQL Analysis
|
||||||
|
uses: github/codeql-action/analyze@v2
|
||||||
|
with:
|
||||||
|
category: "/language:${{ matrix.language }}"
|
||||||
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
|
|
@ -18,10 +18,10 @@ jobs:
|
||||||
GITHUB_CONTEXT: ${{ toJson(github) }}
|
GITHUB_CONTEXT: ${{ toJson(github) }}
|
||||||
run: echo "$GITHUB_CONTEXT"
|
run: echo "$GITHUB_CONTEXT"
|
||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: 3.11
|
python-version: 3.11
|
||||||
|
|
||||||
|
|
|
||||||
12
.github/workflows/test.yml
vendored
12
.github/workflows/test.yml
vendored
|
|
@ -17,10 +17,10 @@ jobs:
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: pre-commit
|
- name: pre-commit
|
||||||
uses: pre-commit/action@v3.0.1
|
uses: pre-commit/action@v3.0.0
|
||||||
|
|
||||||
- name: Make patch
|
- name: Make patch
|
||||||
if: failure()
|
if: failure()
|
||||||
|
|
@ -28,7 +28,7 @@ jobs:
|
||||||
|
|
||||||
- name: Upload patch
|
- name: Upload patch
|
||||||
if: failure()
|
if: failure()
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: patch
|
name: patch
|
||||||
path: ~/pre-commit.patch
|
path: ~/pre-commit.patch
|
||||||
|
|
@ -36,10 +36,10 @@ jobs:
|
||||||
test-release:
|
test-release:
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: 3.11
|
python-version: 3.11
|
||||||
|
|
||||||
|
|
@ -53,7 +53,7 @@ jobs:
|
||||||
run: python -mbuild
|
run: python -mbuild
|
||||||
|
|
||||||
- name: Upload artifacts
|
- name: Upload artifacts
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: dist
|
name: dist
|
||||||
path: dist/*
|
path: dist/*
|
||||||
|
|
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -8,4 +8,3 @@ __pycache__
|
||||||
/dist
|
/dist
|
||||||
/src/chap/__version__.py
|
/src/chap/__version__.py
|
||||||
/venv
|
/venv
|
||||||
/keys.log
|
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,12 @@ default_language_version:
|
||||||
python: python3
|
python: python3
|
||||||
|
|
||||||
repos:
|
repos:
|
||||||
|
- repo: https://github.com/psf/black
|
||||||
|
rev: 23.1.0
|
||||||
|
hooks:
|
||||||
|
- id: black
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.5.0
|
rev: v4.4.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
|
|
@ -15,20 +19,23 @@ repos:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
exclude: tests
|
exclude: tests
|
||||||
- repo: https://github.com/codespell-project/codespell
|
- repo: https://github.com/codespell-project/codespell
|
||||||
rev: v2.2.6
|
rev: v2.2.4
|
||||||
hooks:
|
hooks:
|
||||||
- id: codespell
|
- id: codespell
|
||||||
args: [-w]
|
args: [-w]
|
||||||
- repo: https://github.com/fsfe/reuse-tool
|
- repo: https://github.com/fsfe/reuse-tool
|
||||||
rev: v2.1.0
|
rev: v1.1.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: reuse
|
- id: reuse
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/pycqa/isort
|
||||||
# Ruff version.
|
rev: 5.12.0
|
||||||
rev: v0.1.6
|
|
||||||
hooks:
|
hooks:
|
||||||
# Run the linter.
|
- id: isort
|
||||||
- id: ruff
|
name: isort (python)
|
||||||
args: [ --fix, --preview ]
|
args: ['--profile', 'black']
|
||||||
# Run the formatter.
|
- repo: https://github.com/pycqa/pylint
|
||||||
- id: ruff-format
|
rev: v2.17.0
|
||||||
|
hooks:
|
||||||
|
- id: pylint
|
||||||
|
additional_dependencies: [click,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
|
||||||
|
args: ['--source-roots', 'src']
|
||||||
|
|
|
||||||
12
.pylintrc
Normal file
12
.pylintrc
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
[MESSAGES CONTROL]
|
||||||
|
disable=
|
||||||
|
duplicate-code,
|
||||||
|
invalid-name,
|
||||||
|
line-too-long,
|
||||||
|
missing-class-docstring,
|
||||||
|
missing-function-docstring,
|
||||||
|
missing-module-docstring,
|
||||||
121
LICENSES/CC0-1.0.txt
Normal file
121
LICENSES/CC0-1.0.txt
Normal file
|
|
@ -0,0 +1,121 @@
|
||||||
|
Creative Commons Legal Code
|
||||||
|
|
||||||
|
CC0 1.0 Universal
|
||||||
|
|
||||||
|
CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
|
||||||
|
LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
|
||||||
|
ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
|
||||||
|
INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
|
||||||
|
REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
|
||||||
|
PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
|
||||||
|
THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
|
||||||
|
HEREUNDER.
|
||||||
|
|
||||||
|
Statement of Purpose
|
||||||
|
|
||||||
|
The laws of most jurisdictions throughout the world automatically confer
|
||||||
|
exclusive Copyright and Related Rights (defined below) upon the creator
|
||||||
|
and subsequent owner(s) (each and all, an "owner") of an original work of
|
||||||
|
authorship and/or a database (each, a "Work").
|
||||||
|
|
||||||
|
Certain owners wish to permanently relinquish those rights to a Work for
|
||||||
|
the purpose of contributing to a commons of creative, cultural and
|
||||||
|
scientific works ("Commons") that the public can reliably and without fear
|
||||||
|
of later claims of infringement build upon, modify, incorporate in other
|
||||||
|
works, reuse and redistribute as freely as possible in any form whatsoever
|
||||||
|
and for any purposes, including without limitation commercial purposes.
|
||||||
|
These owners may contribute to the Commons to promote the ideal of a free
|
||||||
|
culture and the further production of creative, cultural and scientific
|
||||||
|
works, or to gain reputation or greater distribution for their Work in
|
||||||
|
part through the use and efforts of others.
|
||||||
|
|
||||||
|
For these and/or other purposes and motivations, and without any
|
||||||
|
expectation of additional consideration or compensation, the person
|
||||||
|
associating CC0 with a Work (the "Affirmer"), to the extent that he or she
|
||||||
|
is an owner of Copyright and Related Rights in the Work, voluntarily
|
||||||
|
elects to apply CC0 to the Work and publicly distribute the Work under its
|
||||||
|
terms, with knowledge of his or her Copyright and Related Rights in the
|
||||||
|
Work and the meaning and intended legal effect of CC0 on those rights.
|
||||||
|
|
||||||
|
1. Copyright and Related Rights. A Work made available under CC0 may be
|
||||||
|
protected by copyright and related or neighboring rights ("Copyright and
|
||||||
|
Related Rights"). Copyright and Related Rights include, but are not
|
||||||
|
limited to, the following:
|
||||||
|
|
||||||
|
i. the right to reproduce, adapt, distribute, perform, display,
|
||||||
|
communicate, and translate a Work;
|
||||||
|
ii. moral rights retained by the original author(s) and/or performer(s);
|
||||||
|
iii. publicity and privacy rights pertaining to a person's image or
|
||||||
|
likeness depicted in a Work;
|
||||||
|
iv. rights protecting against unfair competition in regards to a Work,
|
||||||
|
subject to the limitations in paragraph 4(a), below;
|
||||||
|
v. rights protecting the extraction, dissemination, use and reuse of data
|
||||||
|
in a Work;
|
||||||
|
vi. database rights (such as those arising under Directive 96/9/EC of the
|
||||||
|
European Parliament and of the Council of 11 March 1996 on the legal
|
||||||
|
protection of databases, and under any national implementation
|
||||||
|
thereof, including any amended or successor version of such
|
||||||
|
directive); and
|
||||||
|
vii. other similar, equivalent or corresponding rights throughout the
|
||||||
|
world based on applicable law or treaty, and any national
|
||||||
|
implementations thereof.
|
||||||
|
|
||||||
|
2. Waiver. To the greatest extent permitted by, but not in contravention
|
||||||
|
of, applicable law, Affirmer hereby overtly, fully, permanently,
|
||||||
|
irrevocably and unconditionally waives, abandons, and surrenders all of
|
||||||
|
Affirmer's Copyright and Related Rights and associated claims and causes
|
||||||
|
of action, whether now known or unknown (including existing as well as
|
||||||
|
future claims and causes of action), in the Work (i) in all territories
|
||||||
|
worldwide, (ii) for the maximum duration provided by applicable law or
|
||||||
|
treaty (including future time extensions), (iii) in any current or future
|
||||||
|
medium and for any number of copies, and (iv) for any purpose whatsoever,
|
||||||
|
including without limitation commercial, advertising or promotional
|
||||||
|
purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
|
||||||
|
member of the public at large and to the detriment of Affirmer's heirs and
|
||||||
|
successors, fully intending that such Waiver shall not be subject to
|
||||||
|
revocation, rescission, cancellation, termination, or any other legal or
|
||||||
|
equitable action to disrupt the quiet enjoyment of the Work by the public
|
||||||
|
as contemplated by Affirmer's express Statement of Purpose.
|
||||||
|
|
||||||
|
3. Public License Fallback. Should any part of the Waiver for any reason
|
||||||
|
be judged legally invalid or ineffective under applicable law, then the
|
||||||
|
Waiver shall be preserved to the maximum extent permitted taking into
|
||||||
|
account Affirmer's express Statement of Purpose. In addition, to the
|
||||||
|
extent the Waiver is so judged Affirmer hereby grants to each affected
|
||||||
|
person a royalty-free, non transferable, non sublicensable, non exclusive,
|
||||||
|
irrevocable and unconditional license to exercise Affirmer's Copyright and
|
||||||
|
Related Rights in the Work (i) in all territories worldwide, (ii) for the
|
||||||
|
maximum duration provided by applicable law or treaty (including future
|
||||||
|
time extensions), (iii) in any current or future medium and for any number
|
||||||
|
of copies, and (iv) for any purpose whatsoever, including without
|
||||||
|
limitation commercial, advertising or promotional purposes (the
|
||||||
|
"License"). The License shall be deemed effective as of the date CC0 was
|
||||||
|
applied by Affirmer to the Work. Should any part of the License for any
|
||||||
|
reason be judged legally invalid or ineffective under applicable law, such
|
||||||
|
partial invalidity or ineffectiveness shall not invalidate the remainder
|
||||||
|
of the License, and in such case Affirmer hereby affirms that he or she
|
||||||
|
will not (i) exercise any of his or her remaining Copyright and Related
|
||||||
|
Rights in the Work or (ii) assert any associated claims and causes of
|
||||||
|
action with respect to the Work, in either case contrary to Affirmer's
|
||||||
|
express Statement of Purpose.
|
||||||
|
|
||||||
|
4. Limitations and Disclaimers.
|
||||||
|
|
||||||
|
a. No trademark or patent rights held by Affirmer are waived, abandoned,
|
||||||
|
surrendered, licensed or otherwise affected by this document.
|
||||||
|
b. Affirmer offers the Work as-is and makes no representations or
|
||||||
|
warranties of any kind concerning the Work, express, implied,
|
||||||
|
statutory or otherwise, including without limitation warranties of
|
||||||
|
title, merchantability, fitness for a particular purpose, non
|
||||||
|
infringement, or the absence of latent or other defects, accuracy, or
|
||||||
|
the present or absence of errors, whether or not discoverable, all to
|
||||||
|
the greatest extent permissible under applicable law.
|
||||||
|
c. Affirmer disclaims responsibility for clearing rights of other persons
|
||||||
|
that may apply to the Work or any use thereof, including without
|
||||||
|
limitation any person's Copyright and Related Rights in the Work.
|
||||||
|
Further, Affirmer disclaims responsibility for obtaining any necessary
|
||||||
|
consents, permissions or other rights required for any use of the
|
||||||
|
Work.
|
||||||
|
d. Affirmer understands and acknowledges that Creative Commons is not a
|
||||||
|
party to this document and has no duty or obligation with respect to
|
||||||
|
this CC0 or use of the Work.
|
||||||
65
README.md
65
README.md
|
|
@ -9,11 +9,11 @@ SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
# chap - A Python interface to chatgpt and other LLMs, including a terminal user interface (tui)
|
# chap - A Python interface to chatgpt and other LLMs, including a terminal user interface (tui)
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
## System requirements
|
## System requirements
|
||||||
|
|
||||||
Chap is primarily developed on Linux with Python 3.11. Moderate effort will be made to support versions back to Python 3.9 (Debian oldstable).
|
Chap is developed on Linux with Python 3.11. Due to use of the `X | Y` style of type hints, it is known to not work on Python 3.9 and older. The target minimum Python version is 3.11 (debian stable).
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
|
|
@ -81,54 +81,8 @@ Put your OpenAI API key in the platform configuration directory for chap, e.g.,
|
||||||
|
|
||||||
* `chap grep needle`
|
* `chap grep needle`
|
||||||
|
|
||||||
## `@FILE` arguments
|
|
||||||
|
|
||||||
It's useful to set a bunch of related arguments together, for instance to fully
|
|
||||||
configure a back-end. This functionality is implemented via `@FILE` arguments.
|
|
||||||
|
|
||||||
Before any other command-line argument parsing is performed, `@FILE` arguments are expanded:
|
|
||||||
|
|
||||||
* An `@FILE` argument is searched relative to the current directory
|
|
||||||
* An `@:FILE` argument is searched relative to the configuration directory (e.g., $HOME/.config/chap/presets)
|
|
||||||
* If an argument starts with a literal `@`, double it: `@@`
|
|
||||||
* `@.` stops processing any further `@FILE` arguments and leaves them unchanged.
|
|
||||||
The contents of an `@FILE` are parsed according to `shlex.split(comments=True)`.
|
|
||||||
Comments are supported.
|
|
||||||
A typical content might look like this:
|
|
||||||
```
|
|
||||||
# cfg/gpt-4o: Use more expensive gpt 4o and custom prompt
|
|
||||||
--backend openai-chatgpt
|
|
||||||
-B model:gpt-4o
|
|
||||||
-s :my-custom-system-message.txt
|
|
||||||
```
|
|
||||||
and you might use it with
|
|
||||||
```
|
|
||||||
chap @:cfg/gpt-4o ask what version of gpt is this
|
|
||||||
```
|
|
||||||
|
|
||||||
## Interactive terminal usage
|
## Interactive terminal usage
|
||||||
The interactive terminal mode is accessed via `chap tui`.
|
* `chap tui`
|
||||||
|
|
||||||
There are a variety of keyboard shortcuts to be aware of:
|
|
||||||
* tab/shift-tab to move between the entry field and the conversation, or between conversation items
|
|
||||||
* While in the text box, F9 or (if supported by your terminal) alt+enter to submit multiline text
|
|
||||||
* while on a conversation item:
|
|
||||||
* ctrl+x to re-draft the message. This
|
|
||||||
* saves a copy of the session in an auto-named file in the conversations folder
|
|
||||||
* removes the conversation from this message to the end
|
|
||||||
* puts the user's message in the text box to edit
|
|
||||||
* ctrl+x to re-submit the message. This
|
|
||||||
* saves a copy of the session in an auto-named file in the conversations folder
|
|
||||||
* removes the conversation from this message to the end
|
|
||||||
* puts the user's message in the text box
|
|
||||||
* and submits it immediately
|
|
||||||
* ctrl+y to yank the message. This places the response part of the current
|
|
||||||
interaction in the operating system clipboard to be pasted (e..g, with
|
|
||||||
ctrl+v or command+v in other software)
|
|
||||||
* ctrl+q to toggle whether this message may be included in the contextual history for a future query.
|
|
||||||
The exact way history is submitted is determined by the back-end, often by
|
|
||||||
counting messages or tokens, but the ctrl+q toggle ensures this message (both the user
|
|
||||||
and assistant message parts) are not considered.
|
|
||||||
|
|
||||||
## Sessions & Command-line Parameters
|
## Sessions & Command-line Parameters
|
||||||
|
|
||||||
|
|
@ -144,26 +98,21 @@ an existing session with `-s`. Or, you can continue the last session with
|
||||||
You can set the "system message" with the `-S` flag.
|
You can set the "system message" with the `-S` flag.
|
||||||
|
|
||||||
You can select the text generating backend with the `-b` flag:
|
You can select the text generating backend with the `-b` flag:
|
||||||
* openai-chatgpt: the default, paid API, best quality results. Also works with compatible API implementations including llama-cpp when the correct backend URL is specified.
|
* openai-chatgpt: the default, paid API, best quality results
|
||||||
* llama-cpp: Works with [llama.cpp's http server](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md) and can run locally with various models,
|
* llama-cpp: Works with [llama.cpp's http server](https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md) and can run locally with various models,
|
||||||
though it is [optimized for models that use the llama2-style prompting](https://huggingface.co/blog/llama2#how-to-prompt-llama-2). Consider using llama.cpp's OpenAI compatible API with the openai-chatgpt backend instead, in which case the server can apply the chat template.
|
though it is [optimized for models that use the llama2-style prompting](https://huggingface.co/blog/llama2#how-to-prompt-llama-2).
|
||||||
|
Set the server URL with `-B url:...`.
|
||||||
* textgen: Works with https://github.com/oobabooga/text-generation-webui and can run locally with various models.
|
* textgen: Works with https://github.com/oobabooga/text-generation-webui and can run locally with various models.
|
||||||
Needs the server URL in *$configuration_directory/textgen\_url*.
|
Needs the server URL in *$configuration_directory/textgen\_url*.
|
||||||
* mistral: Works with the [mistral paid API](https://docs.mistral.ai/).
|
|
||||||
* anthropic: Works with the [anthropic paid API](https://docs.anthropic.com/en/home).
|
|
||||||
* huggingface: Works with the [huggingface API](https://huggingface.co/docs/api-inference/index), which includes a free tier.
|
|
||||||
* lorem: local non-AI lorem generator for testing
|
* lorem: local non-AI lorem generator for testing
|
||||||
|
|
||||||
Backends have settings such as URLs and where API keys are stored. use `chap --backend
|
|
||||||
<BACKEND> --help` to list settings for a particular backend.
|
|
||||||
|
|
||||||
## Environment variables
|
## Environment variables
|
||||||
|
|
||||||
The backend can be set with the `CHAP_BACKEND` environment variable.
|
The backend can be set with the `CHAP_BACKEND` environment variable.
|
||||||
|
|
||||||
Backend settings can be set with `CHAP_<backend_name>_<parameter_name>`, with `backend_name` and `parameter_name` all in caps.
|
Backend settings can be set with `CHAP_<backend_name>_<parameter_name>`, with `backend_name` and `parameter_name` all in caps.
|
||||||
|
|
||||||
For instance, `CHAP_LLAMA_CPP_URL=http://server.local:8080/completion` changes the default server URL for the llama-cpp backend.
|
For instance, `CHAP_LLAMA_CPP_URL=http://server.local:8080/completion` changes the default server URL for the llama-cpp back-end.
|
||||||
|
|
||||||
## Importing from ChatGPT
|
## Importing from ChatGPT
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ src_path = project_root / "src"
|
||||||
sys.path.insert(0, str(src_path))
|
sys.path.insert(0, str(src_path))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# pylint: disable=import-error,no-name-in-module
|
||||||
from chap.core import main
|
from chap.core import main
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -20,19 +20,15 @@ name="chap"
|
||||||
authors = [{name = "Jeff Epler", email = "jepler@gmail.com"}]
|
authors = [{name = "Jeff Epler", email = "jepler@gmail.com"}]
|
||||||
description = "Interact with the OpenAI ChatGPT API (and other text generators)"
|
description = "Interact with the OpenAI ChatGPT API (and other text generators)"
|
||||||
dynamic = ["readme","version","dependencies"]
|
dynamic = ["readme","version","dependencies"]
|
||||||
requires-python = ">=3.9"
|
|
||||||
keywords = ["llm", "tui", "chatgpt"]
|
|
||||||
classifiers = [
|
classifiers = [
|
||||||
"Development Status :: 3 - Alpha",
|
|
||||||
"License :: OSI Approved :: MIT License",
|
|
||||||
"Operating System :: OS Independent",
|
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
"Programming Language :: Python :: 3.11",
|
"Programming Language :: Python :: 3.11",
|
||||||
"Programming Language :: Python :: 3.12",
|
"Programming Language :: Python :: 3.12",
|
||||||
"Programming Language :: Python :: Implementation :: CPython",
|
|
||||||
"Programming Language :: Python :: Implementation :: PyPy",
|
"Programming Language :: Python :: Implementation :: PyPy",
|
||||||
|
"Programming Language :: Python :: Implementation :: CPython",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
]
|
]
|
||||||
[project.urls]
|
[project.urls]
|
||||||
homepage = "https://github.com/jepler/chap"
|
homepage = "https://github.com/jepler/chap"
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,7 @@ click
|
||||||
httpx
|
httpx
|
||||||
lorem-text
|
lorem-text
|
||||||
platformdirs
|
platformdirs
|
||||||
pyperclip
|
|
||||||
simple_parsing
|
simple_parsing
|
||||||
textual[syntax] >= 4
|
textual>=0.18.0
|
||||||
tiktoken
|
tiktoken
|
||||||
websockets
|
websockets
|
||||||
|
|
|
||||||
|
|
@ -1,95 +0,0 @@
|
||||||
# SPDX-FileCopyrightText: 2024 Jeff Epler <jepler@gmail.com>
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
import json
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import AsyncGenerator, Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from ..core import AutoAskMixin, Backend
|
|
||||||
from ..key import UsesKeyMixin
|
|
||||||
from ..session import Assistant, Role, Session, User
|
|
||||||
|
|
||||||
|
|
||||||
class Anthropic(AutoAskMixin, UsesKeyMixin):
|
|
||||||
@dataclass
|
|
||||||
class Parameters:
|
|
||||||
url: str = "https://api.anthropic.com"
|
|
||||||
model: str = "claude-3-5-sonnet-20240620"
|
|
||||||
max_new_tokens: int = 1000
|
|
||||||
api_key_name = "anthropic_api_key"
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.parameters = self.Parameters()
|
|
||||||
|
|
||||||
system_message = """\
|
|
||||||
Answer each question accurately and thoroughly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def make_full_query(self, messages: Session, max_query_size: int) -> dict[str, Any]:
|
|
||||||
system = [m.content for m in messages if m.role == Role.SYSTEM]
|
|
||||||
messages = [m for m in messages if m.role != Role.SYSTEM and m.content]
|
|
||||||
del messages[:-max_query_size]
|
|
||||||
result = dict(
|
|
||||||
model=self.parameters.model,
|
|
||||||
max_tokens=self.parameters.max_new_tokens,
|
|
||||||
messages=[dict(role=str(m.role), content=m.content) for m in messages],
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
if system and system[0]:
|
|
||||||
result["system"] = system[0]
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def aask(
|
|
||||||
self,
|
|
||||||
session: Session,
|
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
max_query_size: int = 5,
|
|
||||||
timeout: float = 180,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
new_content: list[str] = []
|
|
||||||
params = self.make_full_query(session + [User(query)], max_query_size)
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
|
||||||
async with client.stream(
|
|
||||||
"POST",
|
|
||||||
f"{self.parameters.url}/v1/messages",
|
|
||||||
json=params,
|
|
||||||
headers={
|
|
||||||
"x-api-key": self.get_key(),
|
|
||||||
"content-type": "application/json",
|
|
||||||
"anthropic-version": "2023-06-01",
|
|
||||||
"anthropic-beta": "messages-2023-12-15",
|
|
||||||
},
|
|
||||||
) as response:
|
|
||||||
if response.status_code == 200:
|
|
||||||
async for line in response.aiter_lines():
|
|
||||||
if line.startswith("data:"):
|
|
||||||
data = line.removeprefix("data:").strip()
|
|
||||||
j = json.loads(data)
|
|
||||||
content = j.get("delta", {}).get("text", "")
|
|
||||||
if content:
|
|
||||||
new_content.append(content)
|
|
||||||
yield content
|
|
||||||
else:
|
|
||||||
content = f"\nFailed with {response=!r}"
|
|
||||||
new_content.append(content)
|
|
||||||
yield content
|
|
||||||
async for line in response.aiter_lines():
|
|
||||||
new_content.append(line)
|
|
||||||
yield line
|
|
||||||
except httpx.HTTPError as e:
|
|
||||||
content = f"\nException: {e!r}"
|
|
||||||
new_content.append(content)
|
|
||||||
yield content
|
|
||||||
|
|
||||||
session.extend([User(query), Assistant("".join(new_content))])
|
|
||||||
|
|
||||||
|
|
||||||
def factory() -> Backend:
|
|
||||||
"""Uses the anthropic text-generation-interface web API"""
|
|
||||||
return Anthropic()
|
|
||||||
|
|
@ -9,11 +9,11 @@ from typing import Any, AsyncGenerator
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from ..core import AutoAskMixin, Backend
|
from ..core import AutoAskMixin, Backend
|
||||||
from ..key import UsesKeyMixin
|
from ..key import get_key
|
||||||
from ..session import Assistant, Role, Session, User
|
from ..session import Assistant, Role, Session, User
|
||||||
|
|
||||||
|
|
||||||
class HuggingFace(AutoAskMixin, UsesKeyMixin):
|
class HuggingFace(AutoAskMixin):
|
||||||
@dataclass
|
@dataclass
|
||||||
class Parameters:
|
class Parameters:
|
||||||
url: str = "https://api-inference.huggingface.co"
|
url: str = "https://api-inference.huggingface.co"
|
||||||
|
|
@ -24,7 +24,6 @@ class HuggingFace(AutoAskMixin, UsesKeyMixin):
|
||||||
after_user: str = """ [/INST] """
|
after_user: str = """ [/INST] """
|
||||||
after_assistant: str = """ </s><s>[INST] """
|
after_assistant: str = """ </s><s>[INST] """
|
||||||
stop_token_id = 2
|
stop_token_id = 2
|
||||||
api_key_name = "huggingface_api_token"
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -90,7 +89,9 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
||||||
*,
|
*,
|
||||||
max_query_size: int = 5,
|
max_query_size: int = 5,
|
||||||
timeout: float = 180,
|
timeout: float = 180,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[
|
||||||
|
str, None
|
||||||
|
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||||
new_content: list[str] = []
|
new_content: list[str] = []
|
||||||
inputs = self.make_full_query(session + [User(query)], max_query_size)
|
inputs = self.make_full_query(session + [User(query)], max_query_size)
|
||||||
try:
|
try:
|
||||||
|
|
@ -111,6 +112,10 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
||||||
|
|
||||||
session.extend([User(query), Assistant("".join(new_content))])
|
session.extend([User(query), Assistant("".join(new_content))])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_key(cls) -> str:
|
||||||
|
return get_key("huggingface_api_token")
|
||||||
|
|
||||||
|
|
||||||
def factory() -> Backend:
|
def factory() -> Backend:
|
||||||
"""Uses the huggingface text-generation-interface web API"""
|
"""Uses the huggingface text-generation-interface web API"""
|
||||||
|
|
|
||||||
|
|
@ -18,16 +18,10 @@ class LlamaCpp(AutoAskMixin):
|
||||||
url: str = "http://localhost:8080/completion"
|
url: str = "http://localhost:8080/completion"
|
||||||
"""The URL of a llama.cpp server's completion endpoint."""
|
"""The URL of a llama.cpp server's completion endpoint."""
|
||||||
|
|
||||||
start_prompt: str = "<|begin_of_text|>"
|
start_prompt: str = """<s>[INST] <<SYS>>\n"""
|
||||||
system_format: str = (
|
after_system: str = "\n<</SYS>>\n\n"
|
||||||
"<|start_header_id|>system<|end_header_id|>\n\n{}<|eot_id|>"
|
after_user: str = """ [/INST] """
|
||||||
)
|
after_assistant: str = """ </s><s>[INST] """
|
||||||
user_format: str = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
|
||||||
assistant_format: str = (
|
|
||||||
"<|start_header_id|>assistant<|end_header_id|>\n\n{}<|eot_id|>"
|
|
||||||
)
|
|
||||||
end_prompt: str = "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
|
||||||
stop: str | None = None
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -40,16 +34,17 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
||||||
def make_full_query(self, messages: Session, max_query_size: int) -> str:
|
def make_full_query(self, messages: Session, max_query_size: int) -> str:
|
||||||
del messages[1:-max_query_size]
|
del messages[1:-max_query_size]
|
||||||
result = [self.parameters.start_prompt]
|
result = [self.parameters.start_prompt]
|
||||||
formats = {
|
|
||||||
Role.SYSTEM: self.parameters.system_format,
|
|
||||||
Role.USER: self.parameters.user_format,
|
|
||||||
Role.ASSISTANT: self.parameters.assistant_format,
|
|
||||||
}
|
|
||||||
for m in messages:
|
for m in messages:
|
||||||
content = (m.content or "").strip()
|
content = (m.content or "").strip()
|
||||||
if not content:
|
if not content:
|
||||||
continue
|
continue
|
||||||
result.append(formats[m.role].format(content))
|
result.append(content)
|
||||||
|
if m.role == Role.SYSTEM:
|
||||||
|
result.append(self.parameters.after_system)
|
||||||
|
elif m.role == Role.ASSISTANT:
|
||||||
|
result.append(self.parameters.after_assistant)
|
||||||
|
elif m.role == Role.USER:
|
||||||
|
result.append(self.parameters.after_user)
|
||||||
full_query = "".join(result)
|
full_query = "".join(result)
|
||||||
return full_query
|
return full_query
|
||||||
|
|
||||||
|
|
@ -60,11 +55,13 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
||||||
*,
|
*,
|
||||||
max_query_size: int = 5,
|
max_query_size: int = 5,
|
||||||
timeout: float = 180,
|
timeout: float = 180,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[
|
||||||
|
str, None
|
||||||
|
]: # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||||
params = {
|
params = {
|
||||||
"prompt": self.make_full_query(session + [User(query)], max_query_size),
|
"prompt": self.make_full_query(session + [User(query)], max_query_size),
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"stop": ["</s>", "<s>", "[INST]", "<|eot_id|>"],
|
"stop": ["</s>", "<s>", "[INST]"],
|
||||||
}
|
}
|
||||||
new_content: list[str] = []
|
new_content: list[str] = []
|
||||||
try:
|
try:
|
||||||
|
|
@ -101,10 +98,5 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
||||||
|
|
||||||
|
|
||||||
def factory() -> Backend:
|
def factory() -> Backend:
|
||||||
"""Uses the llama.cpp completion web API
|
"""Uses the llama.cpp completion web API"""
|
||||||
|
|
||||||
Note: Consider using the openai-chatgpt backend with a custom URL instead.
|
|
||||||
The llama.cpp server will automatically apply common chat templates with the
|
|
||||||
openai-chatgpt backend, while chat templates must be manually configured client side
|
|
||||||
with this backend."""
|
|
||||||
return LlamaCpp()
|
return LlamaCpp()
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ class Lorem:
|
||||||
session: Session,
|
session: Session,
|
||||||
query: str,
|
query: str,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
data = self.ask(session, query)
|
data = self.ask(session, query)[-1]
|
||||||
for word, opt_sep in ipartition(data):
|
for word, opt_sep in ipartition(data):
|
||||||
yield word + opt_sep
|
yield word + opt_sep
|
||||||
await asyncio.sleep(
|
await asyncio.sleep(
|
||||||
|
|
@ -56,7 +56,7 @@ class Lorem:
|
||||||
self,
|
self,
|
||||||
session: Session,
|
session: Session,
|
||||||
query: str,
|
query: str,
|
||||||
) -> str:
|
) -> str: # pylint: disable=unused-argument
|
||||||
new_content = cast(
|
new_content = cast(
|
||||||
str,
|
str,
|
||||||
lorem.paragraphs(
|
lorem.paragraphs(
|
||||||
|
|
|
||||||
|
|
@ -1,96 +0,0 @@
|
||||||
# SPDX-FileCopyrightText: 2024 Jeff Epler <jepler@gmail.com>
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
import json
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import AsyncGenerator, Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from ..core import AutoAskMixin
|
|
||||||
from ..key import UsesKeyMixin
|
|
||||||
from ..session import Assistant, Session, User
|
|
||||||
|
|
||||||
|
|
||||||
class Mistral(AutoAskMixin, UsesKeyMixin):
|
|
||||||
@dataclass
|
|
||||||
class Parameters:
|
|
||||||
url: str = "https://api.mistral.ai"
|
|
||||||
model: str = "open-mistral-7b"
|
|
||||||
max_new_tokens: int = 1000
|
|
||||||
api_key_name = "mistral_api_key"
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.parameters = self.Parameters()
|
|
||||||
|
|
||||||
system_message = """\
|
|
||||||
Answer each question accurately and thoroughly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def make_full_query(self, messages: Session, max_query_size: int) -> dict[str, Any]:
|
|
||||||
messages = [m for m in messages if m.content]
|
|
||||||
del messages[1:-max_query_size]
|
|
||||||
result = dict(
|
|
||||||
model=self.parameters.model,
|
|
||||||
max_tokens=self.parameters.max_new_tokens,
|
|
||||||
messages=[dict(role=str(m.role), content=m.content) for m in messages],
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def aask(
|
|
||||||
self,
|
|
||||||
session: Session,
|
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
max_query_size: int = 5,
|
|
||||||
timeout: float = 180,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
new_content: list[str] = []
|
|
||||||
params = self.make_full_query(session + [User(query)], max_query_size)
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
|
||||||
async with client.stream(
|
|
||||||
"POST",
|
|
||||||
f"{self.parameters.url}/v1/chat/completions",
|
|
||||||
json=params,
|
|
||||||
headers={
|
|
||||||
"Authorization": f"Bearer {self.get_key()}",
|
|
||||||
"content-type": "application/json",
|
|
||||||
"accept": "application/json",
|
|
||||||
"model": "application/json",
|
|
||||||
},
|
|
||||||
) as response:
|
|
||||||
if response.status_code == 200:
|
|
||||||
async for line in response.aiter_lines():
|
|
||||||
if line.startswith("data:"):
|
|
||||||
data = line.removeprefix("data:").strip()
|
|
||||||
if data == "[DONE]":
|
|
||||||
break
|
|
||||||
j = json.loads(data)
|
|
||||||
content = (
|
|
||||||
j.get("choices", [{}])[0]
|
|
||||||
.get("delta", {})
|
|
||||||
.get("content", "")
|
|
||||||
)
|
|
||||||
if content:
|
|
||||||
new_content.append(content)
|
|
||||||
yield content
|
|
||||||
else:
|
|
||||||
content = f"\nFailed with {response=!r}"
|
|
||||||
new_content.append(content)
|
|
||||||
yield content
|
|
||||||
async for line in response.aiter_lines():
|
|
||||||
new_content.append(line)
|
|
||||||
yield line
|
|
||||||
except httpx.HTTPError as e:
|
|
||||||
content = f"\nException: {e!r}"
|
|
||||||
new_content.append(content)
|
|
||||||
yield content
|
|
||||||
|
|
||||||
session.extend([User(query), Assistant("".join(new_content))])
|
|
||||||
|
|
||||||
|
|
||||||
factory = Mistral
|
|
||||||
|
|
@ -12,7 +12,7 @@ import httpx
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
from ..core import Backend
|
from ..core import Backend
|
||||||
from ..key import UsesKeyMixin
|
from ..key import get_key
|
||||||
from ..session import Assistant, Message, Session, User, session_to_list
|
from ..session import Assistant, Message, Session, User, session_to_list
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -63,29 +63,15 @@ class EncodingMeta:
|
||||||
return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)
|
return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)
|
||||||
|
|
||||||
|
|
||||||
class ChatGPT(UsesKeyMixin):
|
class ChatGPT:
|
||||||
@dataclass
|
@dataclass
|
||||||
class Parameters:
|
class Parameters:
|
||||||
model: str = "gpt-4o-mini"
|
model: str = "gpt-3.5-turbo"
|
||||||
"""The model to use. The most common alternative value is 'gpt-4o'."""
|
"""The model to use. The most common alternative value is 'gpt-4'."""
|
||||||
|
|
||||||
max_request_tokens: int = 1024
|
max_request_tokens: int = 1024
|
||||||
"""The approximate greatest number of tokens to send in a request. When the session is long, the system prompt and 1 or more of the most recent interaction steps are sent."""
|
"""The approximate greatest number of tokens to send in a request. When the session is long, the system prompt and 1 or more of the most recent interaction steps are sent."""
|
||||||
|
|
||||||
url: str = "https://api.openai.com/v1/chat/completions"
|
|
||||||
"""The URL of a chatgpt-compatible server's completion endpoint. Notably, llama.cpp's server is compatible with this backend, and can automatically apply common chat templates too."""
|
|
||||||
|
|
||||||
temperature: float | None = None
|
|
||||||
"""The model temperature for sampling"""
|
|
||||||
|
|
||||||
top_p: float | None = None
|
|
||||||
"""The model temperature for sampling"""
|
|
||||||
|
|
||||||
api_key_name: str = "openai_api_key"
|
|
||||||
"""The OpenAI API key"""
|
|
||||||
|
|
||||||
parameters: Parameters
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.parameters = self.Parameters()
|
self.parameters = self.Parameters()
|
||||||
|
|
||||||
|
|
@ -111,11 +97,11 @@ class ChatGPT(UsesKeyMixin):
|
||||||
def ask(self, session: Session, query: str, *, timeout: float = 60) -> str:
|
def ask(self, session: Session, query: str, *, timeout: float = 60) -> str:
|
||||||
full_prompt = self.make_full_prompt(session + [User(query)])
|
full_prompt = self.make_full_prompt(session + [User(query)])
|
||||||
response = httpx.post(
|
response = httpx.post(
|
||||||
self.parameters.url,
|
"https://api.openai.com/v1/chat/completions",
|
||||||
json={
|
json={
|
||||||
"model": self.parameters.model,
|
"model": self.parameters.model,
|
||||||
"messages": session_to_list(full_prompt),
|
"messages": session_to_list(full_prompt),
|
||||||
},
|
}, # pylint: disable=no-member
|
||||||
headers={
|
headers={
|
||||||
"Authorization": f"Bearer {self.get_key()}",
|
"Authorization": f"Bearer {self.get_key()}",
|
||||||
},
|
},
|
||||||
|
|
@ -142,12 +128,10 @@ class ChatGPT(UsesKeyMixin):
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||||
async with client.stream(
|
async with client.stream(
|
||||||
"POST",
|
"POST",
|
||||||
self.parameters.url,
|
"https://api.openai.com/v1/chat/completions",
|
||||||
headers={"authorization": f"Bearer {self.get_key()}"},
|
headers={"authorization": f"Bearer {self.get_key()}"},
|
||||||
json={
|
json={
|
||||||
"model": self.parameters.model,
|
"model": self.parameters.model,
|
||||||
"temperature": self.parameters.temperature,
|
|
||||||
"top_p": self.parameters.top_p,
|
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"messages": session_to_list(full_prompt),
|
"messages": session_to_list(full_prompt),
|
||||||
},
|
},
|
||||||
|
|
@ -176,6 +160,10 @@ class ChatGPT(UsesKeyMixin):
|
||||||
|
|
||||||
session.extend([User(query), Assistant("".join(new_content))])
|
session.extend([User(query), Assistant("".join(new_content))])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_key(cls) -> str:
|
||||||
|
return get_key("openai_api_key")
|
||||||
|
|
||||||
|
|
||||||
def factory() -> Backend:
|
def factory() -> Backend:
|
||||||
"""Uses the OpenAI chat completion API"""
|
"""Uses the OpenAI chat completion API"""
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ USER: Hello, AI.
|
||||||
|
|
||||||
AI: Hello! How can I assist you today?"""
|
AI: Hello! How can I assist you today?"""
|
||||||
|
|
||||||
async def aask(
|
async def aask( # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||||
self,
|
self,
|
||||||
session: Session,
|
session: Session,
|
||||||
query: str,
|
query: str,
|
||||||
|
|
@ -60,11 +60,12 @@ AI: Hello! How can I assist you today?"""
|
||||||
}
|
}
|
||||||
full_prompt = session + [User(query)]
|
full_prompt = session + [User(query)]
|
||||||
del full_prompt[1:-max_query_size]
|
del full_prompt[1:-max_query_size]
|
||||||
new_data = old_data = full_query = "\n".join(
|
new_data = old_data = full_query = (
|
||||||
f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt
|
"\n".join(f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt)
|
||||||
) + f"\n{role_map.get('assistant')}"
|
+ f"\n{role_map.get('assistant')}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
async with websockets.connect(
|
async with websockets.connect( # pylint: disable=no-member
|
||||||
f"ws://{self.parameters.server_hostname}:7860/queue/join"
|
f"ws://{self.parameters.server_hostname}:7860/queue/join"
|
||||||
) as websocket:
|
) as websocket:
|
||||||
while content := json.loads(await websocket.recv()):
|
while content := json.loads(await websocket.recv()):
|
||||||
|
|
@ -126,7 +127,7 @@ AI: Hello! How can I assist you today?"""
|
||||||
# stop generation by closing the websocket here
|
# stop generation by closing the websocket here
|
||||||
if content["msg"] == "process_completed":
|
if content["msg"] == "process_completed":
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e: # pylint: disable=broad-exception-caught
|
||||||
content = f"\nException: {e!r}"
|
content = f"\nException: {e!r}"
|
||||||
new_data += content
|
new_data += content
|
||||||
yield content
|
yield content
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from typing import Iterable, Optional, Protocol
|
from typing import Iterable, Protocol
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import rich
|
import rich
|
||||||
|
|
@ -40,7 +40,7 @@ class DumbPrinter:
|
||||||
|
|
||||||
|
|
||||||
class WrappingPrinter:
|
class WrappingPrinter:
|
||||||
def __init__(self, width: Optional[int] = None) -> None:
|
def __init__(self, width: int | None = None) -> None:
|
||||||
self._width = width or rich.get_console().width
|
self._width = width or rich.get_console().width
|
||||||
self._column = 0
|
self._column = 0
|
||||||
self._line = ""
|
self._line = ""
|
||||||
|
|
@ -100,9 +100,8 @@ def verbose_ask(api: Backend, session: Session, q: str, print_prompt: bool) -> s
|
||||||
|
|
||||||
@command_uses_new_session
|
@command_uses_new_session
|
||||||
@click.option("--print-prompt/--no-print-prompt", default=True)
|
@click.option("--print-prompt/--no-print-prompt", default=True)
|
||||||
@click.option("--stdin/--no-stdin", "use_stdin", default=False)
|
@click.argument("prompt", nargs=-1, required=True)
|
||||||
@click.argument("prompt", nargs=-1)
|
def main(obj: Obj, prompt: str, print_prompt: bool) -> None:
|
||||||
def main(obj: Obj, prompt: list[str], use_stdin: bool, print_prompt: bool) -> None:
|
|
||||||
"""Ask a question (command-line argument is passed as prompt)"""
|
"""Ask a question (command-line argument is passed as prompt)"""
|
||||||
session = obj.session
|
session = obj.session
|
||||||
assert session is not None
|
assert session is not None
|
||||||
|
|
@ -113,16 +112,9 @@ def main(obj: Obj, prompt: list[str], use_stdin: bool, print_prompt: bool) -> No
|
||||||
api = obj.api
|
api = obj.api
|
||||||
assert api is not None
|
assert api is not None
|
||||||
|
|
||||||
if use_stdin:
|
|
||||||
if prompt:
|
|
||||||
raise click.UsageError("Can't use 'prompt' together with --stdin")
|
|
||||||
|
|
||||||
joined_prompt = sys.stdin.read()
|
|
||||||
else:
|
|
||||||
joined_prompt = " ".join(prompt)
|
|
||||||
# symlink_session_filename(session_filename)
|
# symlink_session_filename(session_filename)
|
||||||
|
|
||||||
response = verbose_ask(api, session, joined_prompt, print_prompt=print_prompt)
|
response = verbose_ask(api, session, " ".join(prompt), print_prompt=print_prompt)
|
||||||
|
|
||||||
print(f"Saving session to {session_filename}", file=sys.stderr)
|
print(f"Saving session to {session_filename}", file=sys.stderr)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
|
|
@ -130,4 +122,4 @@ def main(obj: Obj, prompt: list[str], use_stdin: bool, print_prompt: bool) -> No
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main() # pylint: disable=no-value-for-parameter
|
||||||
|
|
|
||||||
|
|
@ -38,4 +38,4 @@ def main(obj: Obj, no_system: bool) -> None:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main() # pylint: disable=no-value-for-parameter
|
||||||
|
|
|
||||||
|
|
@ -24,8 +24,8 @@ def list_files_matching_rx(
|
||||||
"*.json"
|
"*.json"
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
session = session_from_file(conversation)
|
session = session_from_file(conversation) # pylint: disable=no-member
|
||||||
except Exception as e:
|
except Exception as e: # pylint: disable=broad-exception-caught
|
||||||
print(f"Failed to read {conversation}: {e}", file=sys.stderr)
|
print(f"Failed to read {conversation}: {e}", file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -67,4 +67,4 @@ def main(
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main() # pylint: disable=no-value-for-parameter
|
||||||
|
|
|
||||||
|
|
@ -82,4 +82,4 @@ def main(output_directory: pathlib.Path, files: list[TextIO]) -> None:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main() # pylint: disable=no-value-for-parameter
|
||||||
|
|
|
||||||
|
|
@ -45,4 +45,4 @@ def main(obj: Obj, no_system: bool) -> None:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main() # pylint: disable=no-value-for-parameter
|
||||||
|
|
|
||||||
|
|
@ -36,9 +36,6 @@ Markdown:focus-within {
|
||||||
Markdown {
|
Markdown {
|
||||||
border-left: heavy transparent;
|
border-left: heavy transparent;
|
||||||
}
|
}
|
||||||
MarkdownFence {
|
|
||||||
max-height: 9999;
|
|
||||||
}
|
|
||||||
Footer {
|
Footer {
|
||||||
dock: top;
|
dock: top;
|
||||||
}
|
}
|
||||||
|
|
@ -57,5 +54,3 @@ Input {
|
||||||
Markdown {
|
Markdown {
|
||||||
margin: 0 1 0 0;
|
margin: 0 1 0 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
SubmittableTextArea { height: auto; min-height: 5; margin: 0; border: none; border-left: heavy $primary }
|
|
||||||
|
|
|
||||||
|
|
@ -3,58 +3,35 @@
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Optional, cast, TYPE_CHECKING
|
from typing import cast
|
||||||
|
|
||||||
import click
|
|
||||||
from markdown_it import MarkdownIt
|
from markdown_it import MarkdownIt
|
||||||
from textual import work
|
from textual import work
|
||||||
from textual._ansi_sequences import ANSI_SEQUENCES_KEYS
|
|
||||||
from textual.app import App, ComposeResult
|
from textual.app import App, ComposeResult
|
||||||
from textual.binding import Binding
|
from textual.binding import Binding
|
||||||
from textual.containers import Container, Horizontal, VerticalScroll
|
from textual.containers import Container, Horizontal, VerticalScroll
|
||||||
from textual.keys import Keys
|
from textual.widgets import Button, Footer, Input, LoadingIndicator, Markdown
|
||||||
from textual.widgets import Button, Footer, LoadingIndicator, Markdown, TextArea
|
|
||||||
|
|
||||||
from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path
|
from ..core import Backend, Obj, command_uses_new_session, get_api, new_session_path
|
||||||
from ..session import Assistant, Message, Session, User, new_session, session_to_file
|
from ..session import Assistant, Message, Session, User, new_session, session_to_file
|
||||||
|
|
||||||
|
|
||||||
# workaround for pyperclip being un-typed
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
|
|
||||||
def pyperclip_copy(data: str) -> None:
|
|
||||||
...
|
|
||||||
else:
|
|
||||||
from pyperclip import copy as pyperclip_copy
|
|
||||||
|
|
||||||
|
|
||||||
# Monkeypatch alt+enter as meaning "F9", WFM
|
|
||||||
# ignore typing here because ANSI_SEQUENCES_KEYS is a Mapping[] which is read-only as
|
|
||||||
# far as mypy is concerned.
|
|
||||||
ANSI_SEQUENCES_KEYS["\x1b\r"] = (Keys.F9,) # type: ignore
|
|
||||||
ANSI_SEQUENCES_KEYS["\x1b\n"] = (Keys.F9,) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class SubmittableTextArea(TextArea):
|
|
||||||
BINDINGS = [
|
|
||||||
Binding("f9", "app.submit", "Submit", show=True),
|
|
||||||
Binding("tab", "focus_next", show=False, priority=True), # no inserting tabs
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def parser_factory() -> MarkdownIt:
|
def parser_factory() -> MarkdownIt:
|
||||||
parser = MarkdownIt()
|
parser = MarkdownIt()
|
||||||
parser.options["html"] = False
|
parser.options["html"] = False
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
class ChapMarkdown(Markdown, can_focus=True, can_focus_children=False):
|
class ChapMarkdown(
|
||||||
|
Markdown, can_focus=True, can_focus_children=False
|
||||||
|
): # pylint: disable=function-redefined
|
||||||
BINDINGS = [
|
BINDINGS = [
|
||||||
Binding("ctrl+c", "app.yank", "Copy text", show=True),
|
Binding("ctrl+y", "yank", "Yank text", show=True),
|
||||||
Binding("ctrl+r", "app.resubmit", "resubmit", show=True),
|
Binding("ctrl+r", "resubmit", "resubmit", show=True),
|
||||||
Binding("ctrl+x", "app.redraft", "redraft", show=True),
|
Binding("ctrl+x", "redraft", "redraft", show=True),
|
||||||
Binding("ctrl+q", "app.toggle_history", "history toggle", show=True),
|
Binding("ctrl+q", "toggle_history", "history toggle", show=True),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -75,14 +52,14 @@ class CancelButton(Button):
|
||||||
class Tui(App[None]):
|
class Tui(App[None]):
|
||||||
CSS_PATH = "tui.css"
|
CSS_PATH = "tui.css"
|
||||||
BINDINGS = [
|
BINDINGS = [
|
||||||
Binding("ctrl+q", "quit", "Quit", show=True, priority=True),
|
Binding("ctrl+c", "quit", "Quit", show=True, priority=True),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, api: Optional[Backend] = None, session: Optional[Session] = None
|
self, api: Backend | None = None, session: Session | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.api = api or get_api(click.Context(click.Command("chap tui")), "lorem")
|
self.api = api or get_api("lorem")
|
||||||
self.session = (
|
self.session = (
|
||||||
new_session(self.api.system_message) if session is None else session
|
new_session(self.api.system_message) if session is None else session
|
||||||
)
|
)
|
||||||
|
|
@ -96,8 +73,8 @@ class Tui(App[None]):
|
||||||
return cast(VerticalScroll, self.query_one("#wait"))
|
return cast(VerticalScroll, self.query_one("#wait"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input(self) -> SubmittableTextArea:
|
def input(self) -> Input:
|
||||||
return self.query_one(SubmittableTextArea)
|
return self.query_one(Input)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cancel_button(self) -> CancelButton:
|
def cancel_button(self) -> CancelButton:
|
||||||
|
|
@ -117,9 +94,7 @@ class Tui(App[None]):
|
||||||
Container(id="pad"),
|
Container(id="pad"),
|
||||||
id="content",
|
id="content",
|
||||||
)
|
)
|
||||||
s = SubmittableTextArea(language="markdown")
|
yield Input(placeholder="Prompt")
|
||||||
s.show_line_numbers = False
|
|
||||||
yield s
|
|
||||||
with Horizontal(id="wait"):
|
with Horizontal(id="wait"):
|
||||||
yield LoadingIndicator()
|
yield LoadingIndicator()
|
||||||
yield CancelButton(label="❌ Stop Generation", id="cancel", disabled=True)
|
yield CancelButton(label="❌ Stop Generation", id="cancel", disabled=True)
|
||||||
|
|
@ -128,8 +103,8 @@ class Tui(App[None]):
|
||||||
self.container.scroll_end(animate=False)
|
self.container.scroll_end(animate=False)
|
||||||
self.input.focus()
|
self.input.focus()
|
||||||
|
|
||||||
async def action_submit(self) -> None:
|
async def on_input_submitted(self, event: Input.Submitted) -> None:
|
||||||
self.get_completion(self.input.text)
|
self.get_completion(event.value)
|
||||||
|
|
||||||
@work(exclusive=True)
|
@work(exclusive=True)
|
||||||
async def get_completion(self, query: str) -> None:
|
async def get_completion(self, query: str) -> None:
|
||||||
|
|
@ -145,6 +120,7 @@ class Tui(App[None]):
|
||||||
await self.container.mount_all(
|
await self.container.mount_all(
|
||||||
[markdown_for_step(User(query)), output], before="#pad"
|
[markdown_for_step(User(query)), output], before="#pad"
|
||||||
)
|
)
|
||||||
|
tokens: list[str] = []
|
||||||
update: asyncio.Queue[bool] = asyncio.Queue(1)
|
update: asyncio.Queue[bool] = asyncio.Queue(1)
|
||||||
|
|
||||||
for markdown in self.container.children:
|
for markdown in self.container.children:
|
||||||
|
|
@ -165,22 +141,15 @@ class Tui(App[None]):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def render_fun() -> None:
|
async def render_fun() -> None:
|
||||||
old_len = 0
|
|
||||||
while await update.get():
|
while await update.get():
|
||||||
content = message.content
|
if tokens:
|
||||||
new_len = len(content)
|
output.update("".join(tokens).strip())
|
||||||
new_content = content[old_len:new_len]
|
|
||||||
if new_content:
|
|
||||||
if old_len:
|
|
||||||
await output.append(new_content)
|
|
||||||
else:
|
|
||||||
output.update(content)
|
|
||||||
self.container.scroll_end()
|
self.container.scroll_end()
|
||||||
old_len = new_len
|
await asyncio.sleep(0.1)
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
|
|
||||||
async def get_token_fun() -> None:
|
async def get_token_fun() -> None:
|
||||||
async for token in self.api.aask(session, query):
|
async for token in self.api.aask(session, query):
|
||||||
|
tokens.append(token)
|
||||||
message.content += token
|
message.content += token
|
||||||
try:
|
try:
|
||||||
update.put_nowait(True)
|
update.put_nowait(True)
|
||||||
|
|
@ -193,10 +162,10 @@ class Tui(App[None]):
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(render_fun(), get_token_fun())
|
await asyncio.gather(render_fun(), get_token_fun())
|
||||||
finally:
|
finally:
|
||||||
self.input.clear()
|
self.input.value = ""
|
||||||
all_output = self.session[-1].content
|
all_output = self.session[-1].content
|
||||||
output.update(all_output)
|
output.update(all_output)
|
||||||
output._markdown = all_output
|
output._markdown = all_output # pylint: disable=protected-access
|
||||||
self.container.scroll_end()
|
self.container.scroll_end()
|
||||||
|
|
||||||
for markdown in self.container.children:
|
for markdown in self.container.children:
|
||||||
|
|
@ -214,8 +183,8 @@ class Tui(App[None]):
|
||||||
def action_yank(self) -> None:
|
def action_yank(self) -> None:
|
||||||
widget = self.focused
|
widget = self.focused
|
||||||
if isinstance(widget, ChapMarkdown):
|
if isinstance(widget, ChapMarkdown):
|
||||||
content = widget._markdown or ""
|
content = widget._markdown or "" # pylint: disable=protected-access
|
||||||
pyperclip_copy(content)
|
subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)
|
||||||
|
|
||||||
def action_toggle_history(self) -> None:
|
def action_toggle_history(self) -> None:
|
||||||
widget = self.focused
|
widget = self.focused
|
||||||
|
|
@ -235,7 +204,9 @@ class Tui(App[None]):
|
||||||
async def action_stop_generating(self) -> None:
|
async def action_stop_generating(self) -> None:
|
||||||
self.workers.cancel_all()
|
self.workers.cancel_all()
|
||||||
|
|
||||||
async def on_button_pressed(self, event: Button.Pressed) -> None:
|
async def on_button_pressed( # pylint: disable=unused-argument
|
||||||
|
self, event: Button.Pressed
|
||||||
|
) -> None:
|
||||||
self.workers.cancel_all()
|
self.workers.cancel_all()
|
||||||
|
|
||||||
async def action_quit(self) -> None:
|
async def action_quit(self) -> None:
|
||||||
|
|
@ -264,31 +235,19 @@ class Tui(App[None]):
|
||||||
session_to_file(self.session, new_session_path())
|
session_to_file(self.session, new_session_path())
|
||||||
|
|
||||||
query = self.session[idx].content
|
query = self.session[idx].content
|
||||||
self.input.load_text(query)
|
self.input.value = query
|
||||||
|
|
||||||
del self.session[idx:]
|
del self.session[idx:]
|
||||||
for child in self.container.children[idx:-1]:
|
for child in self.container.children[idx:-1]:
|
||||||
await child.remove()
|
await child.remove()
|
||||||
|
|
||||||
self.input.focus()
|
self.input.focus()
|
||||||
self.on_text_area_changed()
|
|
||||||
if resubmit:
|
if resubmit:
|
||||||
await self.action_submit()
|
await self.input.action_submit()
|
||||||
|
|
||||||
def on_text_area_changed(self, event: Any = None) -> None:
|
|
||||||
height = self.input.document.get_size(self.input.indent_width)[1]
|
|
||||||
max_height = max(3, self.size.height - 6)
|
|
||||||
if height >= max_height:
|
|
||||||
self.input.styles.height = max_height
|
|
||||||
elif height <= 3:
|
|
||||||
self.input.styles.height = 3
|
|
||||||
else:
|
|
||||||
self.input.styles.height = height
|
|
||||||
|
|
||||||
|
|
||||||
@command_uses_new_session
|
@command_uses_new_session
|
||||||
@click.option("--replace-system-prompt/--no-replace-system-prompt", default=False)
|
def main(obj: Obj) -> None:
|
||||||
def main(obj: Obj, replace_system_prompt: bool) -> None:
|
|
||||||
"""Start interactive terminal user interface session"""
|
"""Start interactive terminal user interface session"""
|
||||||
api = obj.api
|
api = obj.api
|
||||||
assert api is not None
|
assert api is not None
|
||||||
|
|
@ -297,11 +256,6 @@ def main(obj: Obj, replace_system_prompt: bool) -> None:
|
||||||
session_filename = obj.session_filename
|
session_filename = obj.session_filename
|
||||||
assert session_filename is not None
|
assert session_filename is not None
|
||||||
|
|
||||||
if replace_system_prompt:
|
|
||||||
session[0].content = (
|
|
||||||
api.system_message if obj.system_message is None else obj.system_message
|
|
||||||
)
|
|
||||||
|
|
||||||
tui = Tui(api, session)
|
tui = Tui(api, session)
|
||||||
tui.run()
|
tui.run()
|
||||||
|
|
||||||
|
|
@ -314,4 +268,4 @@ def main(obj: Obj, replace_system_prompt: bool) -> None:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main() # pylint: disable=no-value-for-parameter
|
||||||
|
|
|
||||||
305
src/chap/core.py
305
src/chap/core.py
|
|
@ -1,53 +1,32 @@
|
||||||
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
# SPDX-FileCopyrightText: 2023 Jeff Epler <jepler@gmail.com>
|
||||||
#
|
#
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import io
|
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import shlex
|
from dataclasses import MISSING, dataclass, fields
|
||||||
import textwrap
|
from types import UnionType
|
||||||
from dataclasses import MISSING, Field, dataclass, fields
|
from typing import Any, AsyncGenerator, Callable, cast
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
AsyncGenerator,
|
|
||||||
Callable,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
IO,
|
|
||||||
cast,
|
|
||||||
get_origin,
|
|
||||||
get_args,
|
|
||||||
)
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import platformdirs
|
import platformdirs
|
||||||
from simple_parsing.docstring import get_attribute_docstring
|
from simple_parsing.docstring import get_attribute_docstring
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
from . import backends, commands
|
|
||||||
from .session import Message, Session, System, session_from_file
|
|
||||||
|
|
||||||
UnionType: type
|
from . import backends, commands # pylint: disable=no-name-in-module
|
||||||
if sys.version_info >= (3, 10):
|
from .session import Message, Session, System, session_from_file
|
||||||
from types import UnionType
|
|
||||||
else:
|
|
||||||
UnionType = type(Union[int, float])
|
|
||||||
|
|
||||||
conversations_path = platformdirs.user_state_path("chap") / "conversations"
|
conversations_path = platformdirs.user_state_path("chap") / "conversations"
|
||||||
conversations_path.mkdir(parents=True, exist_ok=True)
|
conversations_path.mkdir(parents=True, exist_ok=True)
|
||||||
configuration_path = platformdirs.user_config_path("chap")
|
|
||||||
preset_path = configuration_path / "preset"
|
|
||||||
|
|
||||||
|
|
||||||
class ABackend(Protocol):
|
class ABackend(Protocol): # pylint: disable=too-few-public-methods
|
||||||
def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
|
def aask(self, session: Session, query: str) -> AsyncGenerator[str, None]:
|
||||||
"""Make a query, updating the session with the query and response, returning the query token by token"""
|
"""Make a query, updating the session with the query and response, returning the query token by token"""
|
||||||
|
|
||||||
|
|
@ -60,7 +39,7 @@ class Backend(ABackend, Protocol):
|
||||||
"""Make a query, updating the session with the query and response, returning the query"""
|
"""Make a query, updating the session with the query and response, returning the query"""
|
||||||
|
|
||||||
|
|
||||||
class AutoAskMixin:
|
class AutoAskMixin: # pylint: disable=too-few-public-methods
|
||||||
"""Mixin class for backends implementing aask"""
|
"""Mixin class for backends implementing aask"""
|
||||||
|
|
||||||
def ask(self, session: Session, query: str) -> str:
|
def ask(self, session: Session, query: str) -> str:
|
||||||
|
|
@ -75,50 +54,20 @@ class AutoAskMixin:
|
||||||
return "".join(tokens)
|
return "".join(tokens)
|
||||||
|
|
||||||
|
|
||||||
def last_session_path() -> Optional[pathlib.Path]:
|
def last_session_path() -> pathlib.Path | None:
|
||||||
result = max(
|
result = max(
|
||||||
conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None
|
conversations_path.glob("*.json"), key=lambda p: p.stat().st_mtime, default=None
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def new_session_path(opt_path: Optional[pathlib.Path] = None) -> pathlib.Path:
|
def new_session_path(opt_path: pathlib.Path | None = None) -> pathlib.Path:
|
||||||
return opt_path or conversations_path / (
|
return opt_path or conversations_path / (
|
||||||
datetime.datetime.now().isoformat().replace(":", "_") + ".json"
|
datetime.datetime.now().isoformat().replace(":", "_") + ".json"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_field_type(field: Field[Any]) -> Any:
|
def configure_api_from_environment(api_name: str, api: Backend) -> None:
|
||||||
field_type = field.type
|
|
||||||
if isinstance(field_type, str):
|
|
||||||
raise RuntimeError(
|
|
||||||
"parameters dataclass may not use 'from __future__ import annotations"
|
|
||||||
)
|
|
||||||
origin = get_origin(field_type)
|
|
||||||
if origin in (Union, UnionType):
|
|
||||||
for arg in get_args(field_type):
|
|
||||||
if arg is not None:
|
|
||||||
return arg
|
|
||||||
return field_type
|
|
||||||
|
|
||||||
|
|
||||||
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
|
|
||||||
field_type = get_field_type(field)
|
|
||||||
try:
|
|
||||||
if field_type is bool:
|
|
||||||
tv = click.types.BoolParamType().convert(value, None, ctx)
|
|
||||||
else:
|
|
||||||
tv = field_type(value)
|
|
||||||
return tv
|
|
||||||
except ValueError as e:
|
|
||||||
raise click.BadParameter(
|
|
||||||
f"Invalid value for {field.name} with value {value}: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
def configure_api_from_environment(
|
|
||||||
ctx: click.Context, api_name: str, api: Backend
|
|
||||||
) -> None:
|
|
||||||
if not hasattr(api, "parameters"):
|
if not hasattr(api, "parameters"):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -127,25 +76,26 @@ def configure_api_from_environment(
|
||||||
value = os.environ.get(envvar)
|
value = os.environ.get(envvar)
|
||||||
if value is None:
|
if value is None:
|
||||||
continue
|
continue
|
||||||
tv = convert_str_to_field(ctx, field, value)
|
try:
|
||||||
|
tv = field.type(value)
|
||||||
|
except ValueError as e:
|
||||||
|
raise click.BadParameter(
|
||||||
|
f"Invalid value for {field.name} with value {value}: {e}"
|
||||||
|
) from e
|
||||||
setattr(api.parameters, field.name, tv)
|
setattr(api.parameters, field.name, tv)
|
||||||
|
|
||||||
|
|
||||||
def get_api(ctx: click.Context | None = None, name: str | None = None) -> Backend:
|
def get_api(name: str = "openai_chatgpt") -> Backend:
|
||||||
if ctx is None:
|
|
||||||
ctx = click.Context(click.Command("chap"))
|
|
||||||
if name is None:
|
|
||||||
name = os.environ.get("CHAP_BACKEND", "openai_chatgpt")
|
|
||||||
name = name.replace("-", "_")
|
name = name.replace("-", "_")
|
||||||
backend = cast(
|
result = cast(
|
||||||
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
|
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
|
||||||
)
|
)
|
||||||
configure_api_from_environment(ctx, name, backend)
|
configure_api_from_environment(name, result)
|
||||||
return backend
|
return result
|
||||||
|
|
||||||
|
|
||||||
def do_session_continue(
|
def do_session_continue(
|
||||||
ctx: click.Context, param: click.Parameter, value: Optional[pathlib.Path]
|
ctx: click.Context, param: click.Parameter, value: pathlib.Path | None
|
||||||
) -> None:
|
) -> None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return
|
return
|
||||||
|
|
@ -158,7 +108,9 @@ def do_session_continue(
|
||||||
ctx.obj.session_filename = value
|
ctx.obj.session_filename = value
|
||||||
|
|
||||||
|
|
||||||
def do_session_last(ctx: click.Context, param: click.Parameter, value: bool) -> None:
|
def do_session_last(
|
||||||
|
ctx: click.Context, param: click.Parameter, value: bool
|
||||||
|
) -> None: # pylint: disable=unused-argument
|
||||||
if not value:
|
if not value:
|
||||||
return
|
return
|
||||||
do_session_continue(ctx, param, last_session_path())
|
do_session_continue(ctx, param, last_session_path())
|
||||||
|
|
@ -186,22 +138,18 @@ def colonstr(arg: str) -> tuple[str, str]:
|
||||||
return cast(tuple[str, str], tuple(arg.split(":", 1)))
|
return cast(tuple[str, str], tuple(arg.split(":", 1)))
|
||||||
|
|
||||||
|
|
||||||
def set_system_message(ctx: click.Context, param: click.Parameter, value: str) -> None:
|
def set_system_message( # pylint: disable=unused-argument
|
||||||
if value is None:
|
ctx: click.Context, param: click.Parameter, value: str
|
||||||
return
|
) -> None:
|
||||||
|
if value and value.startswith("@"):
|
||||||
|
with open(value[1:], "r", encoding="utf-8") as f:
|
||||||
|
value = f.read().rstrip()
|
||||||
ctx.obj.system_message = value
|
ctx.obj.system_message = value
|
||||||
|
|
||||||
|
|
||||||
def set_system_message_from_file(
|
def set_backend( # pylint: disable=unused-argument
|
||||||
ctx: click.Context, param: click.Parameter, value: io.TextIOWrapper
|
ctx: click.Context, param: click.Parameter, value: str
|
||||||
) -> None:
|
) -> None:
|
||||||
if value is None:
|
|
||||||
return
|
|
||||||
content = value.read().strip()
|
|
||||||
ctx.obj.system_message = content
|
|
||||||
|
|
||||||
|
|
||||||
def set_backend(ctx: click.Context, param: click.Parameter, value: str) -> None:
|
|
||||||
if value == "list":
|
if value == "list":
|
||||||
formatter = ctx.make_formatter()
|
formatter = ctx.make_formatter()
|
||||||
format_backend_list(formatter)
|
format_backend_list(formatter)
|
||||||
|
|
@ -209,7 +157,7 @@ def set_backend(ctx: click.Context, param: click.Parameter, value: str) -> None:
|
||||||
ctx.exit()
|
ctx.exit()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ctx.obj.api = get_api(ctx, value)
|
ctx.obj.api = get_api(value)
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
raise click.BadParameter(str(e))
|
raise click.BadParameter(str(e))
|
||||||
|
|
||||||
|
|
@ -224,13 +172,15 @@ def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None:
|
||||||
if doc:
|
if doc:
|
||||||
doc += " "
|
doc += " "
|
||||||
doc += f"(Default: {default!r})"
|
doc += f"(Default: {default!r})"
|
||||||
f_type = get_field_type(f)
|
f_type = f.type
|
||||||
|
if isinstance(f_type, UnionType):
|
||||||
|
f_type = f_type.__args__[0]
|
||||||
typename = f_type.__name__
|
typename = f_type.__name__
|
||||||
rows.append((f"-B {name}:{typename.upper()}", doc))
|
rows.append((f"-B {name}:{typename.upper()}", doc))
|
||||||
formatter.write_dl(rows)
|
formatter.write_dl(rows)
|
||||||
|
|
||||||
|
|
||||||
def set_backend_option(
|
def set_backend_option( # pylint: disable=unused-argument
|
||||||
ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]]
|
ctx: click.Context, param: click.Parameter, opts: list[tuple[str, str]]
|
||||||
) -> None:
|
) -> None:
|
||||||
api = ctx.obj.api
|
api = ctx.obj.api
|
||||||
|
|
@ -245,8 +195,16 @@ def set_backend_option(
|
||||||
field = all_fields.get(name)
|
field = all_fields.get(name)
|
||||||
if field is None:
|
if field is None:
|
||||||
raise click.BadParameter(f"Invalid parameter {name}")
|
raise click.BadParameter(f"Invalid parameter {name}")
|
||||||
tv = convert_str_to_field(ctx, field, value)
|
f_type = field.type
|
||||||
setattr(api.parameters, field.name, tv)
|
if isinstance(f_type, UnionType):
|
||||||
|
f_type = f_type.__args__[0]
|
||||||
|
try:
|
||||||
|
tv = f_type(value)
|
||||||
|
except ValueError as e:
|
||||||
|
raise click.BadParameter(
|
||||||
|
f"Invalid value for {name} with value {value}: {e}"
|
||||||
|
) from e
|
||||||
|
setattr(api.parameters, name, tv)
|
||||||
|
|
||||||
for kv in opts:
|
for kv in opts:
|
||||||
set_one_backend_option(kv)
|
set_one_backend_option(kv)
|
||||||
|
|
@ -306,7 +264,9 @@ def command_uses_new_session(f_in: click.decorators.FC) -> click.Command:
|
||||||
return click.command()(f)
|
return click.command()(f)
|
||||||
|
|
||||||
|
|
||||||
def version_callback(ctx: click.Context, param: click.Parameter, value: None) -> None:
|
def version_callback( # pylint: disable=unused-argument
|
||||||
|
ctx: click.Context, param: click.Parameter, value: None
|
||||||
|
) -> None:
|
||||||
if not value or ctx.resilient_parsing:
|
if not value or ctx.resilient_parsing:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -333,54 +293,18 @@ def version_callback(ctx: click.Context, param: click.Parameter, value: None) ->
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Obj:
|
class Obj:
|
||||||
api: Optional[Backend] = None
|
api: Backend | None = None
|
||||||
system_message: Optional[str] = None
|
system_message: str | None = None
|
||||||
session: Optional[list[Message]] = None
|
session: list[Message] | None = None
|
||||||
session_filename: Optional[pathlib.Path] = None
|
session_filename: pathlib.Path | None = None
|
||||||
|
|
||||||
|
|
||||||
def maybe_add_txt_extension(fn: pathlib.Path) -> pathlib.Path:
|
class MyCLI(click.MultiCommand):
|
||||||
if not fn.exists():
|
|
||||||
fn1 = pathlib.Path(str(fn) + ".txt")
|
|
||||||
if fn1.exists():
|
|
||||||
fn = fn1
|
|
||||||
return fn
|
|
||||||
|
|
||||||
|
|
||||||
def expand_splats(args: list[str]) -> list[str]:
|
|
||||||
result = []
|
|
||||||
saw_at_dot = False
|
|
||||||
for a in args:
|
|
||||||
if a == "@.":
|
|
||||||
saw_at_dot = True
|
|
||||||
continue
|
|
||||||
if saw_at_dot:
|
|
||||||
result.append(a)
|
|
||||||
continue
|
|
||||||
if a.startswith("@@"): ## double @ to escape an argument that starts with @
|
|
||||||
result.append(a[1:])
|
|
||||||
continue
|
|
||||||
if not a.startswith("@"):
|
|
||||||
result.append(a)
|
|
||||||
continue
|
|
||||||
if a.startswith("@:"):
|
|
||||||
fn: pathlib.Path = preset_path / a[2:]
|
|
||||||
else:
|
|
||||||
fn = pathlib.Path(a[1:])
|
|
||||||
fn = maybe_add_txt_extension(fn)
|
|
||||||
with open(fn, "r", encoding="utf-8") as f:
|
|
||||||
content = f.read()
|
|
||||||
parts = shlex.split(content)
|
|
||||||
result.extend(expand_splats(parts))
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class MyCLI(click.Group):
|
|
||||||
def make_context(
|
def make_context(
|
||||||
self,
|
self,
|
||||||
info_name: Optional[str],
|
info_name: str | None,
|
||||||
args: list[str],
|
args: list[str],
|
||||||
parent: Optional[click.Context] = None,
|
parent: click.Context | None = None,
|
||||||
**extra: Any,
|
**extra: Any,
|
||||||
) -> click.Context:
|
) -> click.Context:
|
||||||
result = super().make_context(info_name, args, parent, obj=Obj(), **extra)
|
result = super().make_context(info_name, args, parent, obj=Obj(), **extra)
|
||||||
|
|
@ -404,117 +328,14 @@ class MyCLI(click.Group):
|
||||||
except ModuleNotFoundError as exc:
|
except ModuleNotFoundError as exc:
|
||||||
raise click.UsageError(f"Invalid subcommand {cmd_name!r}", ctx) from exc
|
raise click.UsageError(f"Invalid subcommand {cmd_name!r}", ctx) from exc
|
||||||
|
|
||||||
def gather_preset_info(self) -> list[tuple[str, str]]:
|
|
||||||
result = []
|
|
||||||
for p in preset_path.glob("*"):
|
|
||||||
if p.is_file():
|
|
||||||
with p.open() as f:
|
|
||||||
first_line = f.readline()
|
|
||||||
if first_line.startswith("#"):
|
|
||||||
help_str = first_line[1:].strip()
|
|
||||||
else:
|
|
||||||
help_str = "(A comment on the first line would be shown here)"
|
|
||||||
result.append((f"@:{p.name}", help_str))
|
|
||||||
return result
|
|
||||||
|
|
||||||
def format_splat_options(
|
|
||||||
self, ctx: click.Context, formatter: click.HelpFormatter
|
|
||||||
) -> None:
|
|
||||||
with formatter.section("Splats"):
|
|
||||||
formatter.write_text(
|
|
||||||
"Before any other command-line argument parsing is performed, @FILE arguments are expanded:"
|
|
||||||
)
|
|
||||||
formatter.write_paragraph()
|
|
||||||
formatter.indent()
|
|
||||||
formatter.write_dl(
|
|
||||||
[
|
|
||||||
("@FILE", "Argument is searched relative to the current directory"),
|
|
||||||
(
|
|
||||||
"@:FILE",
|
|
||||||
"Argument is searched relative to the configuration directory (e.g., $HOME/.config/chap/preset)",
|
|
||||||
),
|
|
||||||
("@@…", "If an argument starts with a literal '@', double it"),
|
|
||||||
(
|
|
||||||
"@.",
|
|
||||||
"Stops processing any further `@FILE` arguments and leaves them unchanged.",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
formatter.dedent()
|
|
||||||
formatter.write_paragraph()
|
|
||||||
formatter.write_text(
|
|
||||||
textwrap.dedent(
|
|
||||||
"""\
|
|
||||||
The contents of an `@FILE` are parsed by `shlex.split(comments=True)`.
|
|
||||||
Comments are supported. If the filename ends in .txt,
|
|
||||||
the extension may be omitted."""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
formatter.write_paragraph()
|
|
||||||
if preset_info := self.gather_preset_info():
|
|
||||||
formatter.write_text("Presets found:")
|
|
||||||
formatter.write_paragraph()
|
|
||||||
formatter.indent()
|
|
||||||
formatter.write_dl(preset_info)
|
|
||||||
formatter.dedent()
|
|
||||||
|
|
||||||
def format_options(
|
def format_options(
|
||||||
self, ctx: click.Context, formatter: click.HelpFormatter
|
self, ctx: click.Context, formatter: click.HelpFormatter
|
||||||
) -> None:
|
) -> None:
|
||||||
self.format_splat_options(ctx, formatter)
|
|
||||||
super().format_options(ctx, formatter)
|
super().format_options(ctx, formatter)
|
||||||
api = ctx.obj.api or get_api(ctx)
|
api = ctx.obj.api or get_api()
|
||||||
if hasattr(api, "parameters"):
|
if hasattr(api, "parameters"):
|
||||||
format_backend_help(api, formatter)
|
format_backend_help(api, formatter)
|
||||||
|
|
||||||
def main(
|
|
||||||
self,
|
|
||||||
args: Sequence[str] | None = None,
|
|
||||||
prog_name: str | None = None,
|
|
||||||
complete_var: str | None = None,
|
|
||||||
standalone_mode: bool = True,
|
|
||||||
windows_expand_args: bool = True,
|
|
||||||
**extra: Any,
|
|
||||||
) -> Any:
|
|
||||||
if args is None:
|
|
||||||
args = sys.argv[1:]
|
|
||||||
if os.name == "nt" and windows_expand_args:
|
|
||||||
args = click.utils._expand_args(args)
|
|
||||||
else:
|
|
||||||
args = list(args)
|
|
||||||
|
|
||||||
args = expand_splats(args)
|
|
||||||
|
|
||||||
return super().main(
|
|
||||||
args,
|
|
||||||
prog_name=prog_name,
|
|
||||||
complete_var=complete_var,
|
|
||||||
standalone_mode=standalone_mode,
|
|
||||||
windows_expand_args=windows_expand_args,
|
|
||||||
**extra,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigRelativeFile(click.File):
|
|
||||||
def __init__(self, mode: str, where: str) -> None:
|
|
||||||
super().__init__(mode)
|
|
||||||
self.where = where
|
|
||||||
|
|
||||||
def convert(
|
|
||||||
self,
|
|
||||||
value: str | os.PathLike[str] | IO[Any],
|
|
||||||
param: click.Parameter | None,
|
|
||||||
ctx: click.Context | None,
|
|
||||||
) -> Any:
|
|
||||||
if isinstance(value, str):
|
|
||||||
if value.startswith(":"):
|
|
||||||
value = configuration_path / self.where / value[1:]
|
|
||||||
else:
|
|
||||||
value = pathlib.Path(value)
|
|
||||||
if isinstance(value, pathlib.Path):
|
|
||||||
value = maybe_add_txt_extension(value)
|
|
||||||
return super().convert(value, param, ctx)
|
|
||||||
|
|
||||||
|
|
||||||
main = MyCLI(
|
main = MyCLI(
|
||||||
help="Commandline interface to ChatGPT",
|
help="Commandline interface to ChatGPT",
|
||||||
|
|
@ -526,14 +347,6 @@ main = MyCLI(
|
||||||
help="Show the version and exit",
|
help="Show the version and exit",
|
||||||
callback=version_callback,
|
callback=version_callback,
|
||||||
),
|
),
|
||||||
click.Option(
|
|
||||||
("--system-message-file", "-s"),
|
|
||||||
type=ConfigRelativeFile("r", where="prompt"),
|
|
||||||
default=None,
|
|
||||||
callback=set_system_message_from_file,
|
|
||||||
expose_value=False,
|
|
||||||
help=f"Set the system message from a file. If the filename starts with `:` it is relative to the {configuration_path}/prompt. If the filename ends in .txt, the extension may be omitted.",
|
|
||||||
),
|
|
||||||
click.Option(
|
click.Option(
|
||||||
("--system-message", "-S"),
|
("--system-message", "-S"),
|
||||||
type=str,
|
type=str,
|
||||||
|
|
|
||||||
|
|
@ -2,58 +2,20 @@
|
||||||
#
|
#
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import json
|
|
||||||
import subprocess
|
|
||||||
from typing import Protocol
|
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
import platformdirs
|
import platformdirs
|
||||||
|
|
||||||
|
|
||||||
class APIKeyProtocol(Protocol):
|
|
||||||
@property
|
|
||||||
def api_key_name(self) -> str:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class HasKeyProtocol(Protocol):
|
|
||||||
@property
|
|
||||||
def parameters(self) -> APIKeyProtocol:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class UsesKeyMixin:
|
|
||||||
def get_key(self: HasKeyProtocol) -> str:
|
|
||||||
return get_key(self.parameters.api_key_name)
|
|
||||||
|
|
||||||
|
|
||||||
class NoKeyAvailable(Exception):
|
class NoKeyAvailable(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
_key_path_base = platformdirs.user_config_path("chap")
|
_key_path_base = platformdirs.user_config_path("chap")
|
||||||
|
|
||||||
USE_PASSWORD_STORE = _key_path_base / "USE_PASSWORD_STORE"
|
|
||||||
|
|
||||||
if USE_PASSWORD_STORE.exists():
|
@functools.cache
|
||||||
content = USE_PASSWORD_STORE.read_text(encoding="utf-8")
|
def get_key(name: str, what: str = "openai api key") -> str:
|
||||||
if content.strip():
|
|
||||||
cfg = json.loads(content)
|
|
||||||
pass_command: list[str] = cfg.get("PASS_COMMAND", ["pass", "show"])
|
|
||||||
pass_prefix: str = cfg.get("PASS_PREFIX", "chap/")
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def get_key(name: str, what: str = "api key") -> str:
|
|
||||||
if name == "-":
|
|
||||||
return "-"
|
|
||||||
key_path = f"{pass_prefix}{name}"
|
|
||||||
command = pass_command + [key_path]
|
|
||||||
return subprocess.check_output(command, encoding="utf-8").split("\n")[0]
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def get_key(name: str, what: str = "api key") -> str:
|
|
||||||
key_path = _key_path_base / name
|
key_path = _key_path_base / name
|
||||||
if not key_path.exists():
|
if not key_path.exists():
|
||||||
raise NoKeyAvailable(
|
raise NoKeyAvailable(
|
||||||
|
|
|
||||||
|
|
@ -7,13 +7,13 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import pathlib
|
import pathlib
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Union, cast
|
from typing import cast
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
|
||||||
# not an enum.Enum because these objects are not json-serializable, sigh
|
# not an enum.Enum because these objects are not json-serializable, sigh
|
||||||
class Role:
|
class Role: # pylint: disable=too-few-public-methods
|
||||||
ASSISTANT = "assistant"
|
ASSISTANT = "assistant"
|
||||||
SYSTEM = "system"
|
SYSTEM = "system"
|
||||||
USER = "user"
|
USER = "user"
|
||||||
|
|
@ -65,11 +65,11 @@ def session_from_json(data: str) -> Session:
|
||||||
return [Message(**mapping) for mapping in j]
|
return [Message(**mapping) for mapping in j]
|
||||||
|
|
||||||
|
|
||||||
def session_from_file(path: Union[pathlib.Path, str]) -> Session:
|
def session_from_file(path: pathlib.Path | str) -> Session:
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
return session_from_json(f.read())
|
return session_from_json(f.read())
|
||||||
|
|
||||||
|
|
||||||
def session_to_file(session: Session, path: Union[pathlib.Path, str]) -> None:
|
def session_to_file(session: Session, path: pathlib.Path | str) -> None:
|
||||||
with open(path, "w", encoding="utf-8") as f:
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
f.write(session_to_json(session))
|
f.write(session_to_json(session))
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue