diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 65291c8..2b082b7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -32,5 +32,5 @@ repos:
     rev: v2.17.0
     hooks:
     -   id: pylint
-        additional_dependencies: [click,dataclasses_json,httpx,lorem-text,'textual>=0.18.0',websockets]
+        additional_dependencies: [click,dataclasses_json,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
         args: ['--source-roots', 'src']
diff --git a/pyproject.toml b/pyproject.toml
index 1460525..116c81a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,9 @@ dependencies = [
     "httpx",
     "lorem-text",
     "platformdirs",
+    "simple_parsing",
     "textual>=0.18.0",
+    "tiktoken",
     "websockets",
 ]
 classifiers = [
diff --git a/src/chap/backends/lorem.py b/src/chap/backends/lorem.py
index dd94281..563a3af 100644
--- a/src/chap/backends/lorem.py
+++ b/src/chap/backends/lorem.py
@@ -22,9 +22,13 @@ class Lorem:
     @dataclass
     class Parameters:
         delay_mu: float = 0.035
+        """Average delay between tokens"""
         delay_sigma: float = 0.02
+        """Standard deviation of token delay"""
         paragraph_lo: int = 1
+        """Minimum response paragraph count"""
         paragraph_hi: int = 5
+        """Maximum response paragraph count (inclusive)"""

     def __init__(self):
         self.parameters = self.Parameters()
diff --git a/src/chap/backends/openai_chatgpt.py b/src/chap/backends/openai_chatgpt.py
index d4346fd..b555764 100644
--- a/src/chap/backends/openai_chatgpt.py
+++ b/src/chap/backends/openai_chatgpt.py
@@ -2,28 +2,105 @@
 #
 # SPDX-License-Identifier: MIT

+import functools
 import json
 from dataclasses import dataclass

 import httpx
+import tiktoken

 from ..key import get_key
 from ..session import Assistant, Session, User


+@dataclass(frozen=True)
+class EncodingMeta:
+    encoding: tiktoken.Encoding
+    tokens_per_message: int
+    tokens_per_name: int
+
+    @functools.lru_cache()
+    def encode(self, s):
+        return self.encoding.encode(s)
+
+    def num_tokens_for_message(self, message):
+        # n.b. chap doesn't use message.name yet
+        return self.tokens_per_message + len(self.encode(message.role)) + len(self.encode(message.content))
+
+    def num_tokens_for_messages(self, messages):
+        return sum(self.num_tokens_for_message(message) for message in messages) + 3  # every reply is primed with <|start|>assistant<|message|>
+
+    @classmethod
+    @functools.cache
+    def from_model(cls, model):
+        if model == "gpt-3.5-turbo":
+            # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
+            model = "gpt-3.5-turbo-0613"
+        if model == "gpt-4":
+            # print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+            model = "gpt-4-0613"
+
+        try:
+            encoding = tiktoken.encoding_for_model(model)
+        except KeyError:
+            print("Warning: model not found. Using cl100k_base encoding.")
+            encoding = tiktoken.get_encoding("cl100k_base")
+
+        if model in {
+            "gpt-3.5-turbo-0613",
+            "gpt-3.5-turbo-16k-0613",
+            "gpt-4-0314",
+            "gpt-4-32k-0314",
+            "gpt-4-0613",
+            "gpt-4-32k-0613",
+        }:
+            tokens_per_message = 3
+            tokens_per_name = 1
+        elif model == "gpt-3.5-turbo-0301":
+            tokens_per_message = (
+                4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+            )
+            tokens_per_name = -1  # if there's a name, the role is omitted
+        else:
+            raise NotImplementedError(
+                f"""EncodingMeta is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
+            )
+        return cls(encoding, tokens_per_message, tokens_per_name)
+
+
 class ChatGPT:
     @dataclass
     class Parameters:
         model: str = "gpt-3.5-turbo"
+        """The model to use. The most common alternative value is 'gpt-4'."""
+
+        max_request_tokens: int = 1024
+        """The approximate maximum number of tokens to send in a request. When the session is long, the system prompt and one or more of the most recent interaction steps are sent."""

     def __init__(self):
         self.parameters = self.Parameters()

     system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language."

-    def ask(self, session, query, *, max_query_size=5, timeout=60):
-        full_prompt = Session(session.session + [User(query)])
-        del full_prompt.session[1:-max_query_size]
+    def make_full_prompt(self, all_history):
+        encoding = EncodingMeta.from_model(self.parameters.model)
+        result = [all_history[0]]  # Assumed to be the system prompt
+        left = self.parameters.max_request_tokens - encoding.num_tokens_for_messages(
+            result
+        )
+        parts = []
+        for message in reversed(all_history[1:]):
+            msglen = encoding.num_tokens_for_message(message)
+            if left >= msglen:
+                left -= msglen
+                parts.append(message)
+            else:
+                break
+        result.extend(reversed(parts))
+        return Session(result)
+
+    def ask(self, session, query, *, timeout=60):
+        full_prompt = self.make_full_prompt(session.session + [User(query)])
         response = httpx.post(
             "https://api.openai.com/v1/chat/completions",
             json={
@@ -51,10 +128,8 @@ class ChatGPT:
         session.session.extend([User(query), Assistant(result)])
         return result

-    async def aask(self, session, query, *, max_query_size=5, timeout=60):
-        full_prompt = Session(session.session + [User(query)])
-        del full_prompt.session[1:-max_query_size]
-
+    async def aask(self, session, query, *, timeout=60):
+        full_prompt = self.make_full_prompt(session.session + [User(query)])
         new_content = []
         try:
             async with httpx.AsyncClient(timeout=timeout) as client:
diff --git a/src/chap/commands/ask.py b/src/chap/commands/ask.py
index dbaf029..e426160 100644
--- a/src/chap/commands/ask.py
+++ b/src/chap/commands/ask.py
@@ -8,7 +8,7 @@ import sys
 import click
 import rich

-from ..core import uses_new_session
+from ..core import command_uses_new_session

 if sys.stdout.isatty():
     bold = "\033[1m"
@@ -78,8 +78,7 @@ def verbose_ask(api, session, q, **kw):
     return result


-@click.command
-@uses_new_session
+@command_uses_new_session
 @click.argument("prompt", nargs=-1, required=True)
 def main(obj, prompt):
     """Ask a question (command-line argument is passed as prompt)"""
diff --git a/src/chap/commands/tui.py b/src/chap/commands/tui.py
index f9d0a2e..d38e045 100644
--- a/src/chap/commands/tui.py
+++ b/src/chap/commands/tui.py
@@ -6,14 +6,13 @@
 import asyncio
 import subprocess
 import sys

-import click
 from markdown_it import MarkdownIt
 from textual.app import App
 from textual.binding import Binding
 from textual.containers import Container, VerticalScroll
 from textual.widgets import Footer, Input, Markdown

-from ..core import get_api, uses_new_session
+from ..core import command_uses_new_session, get_api
 from ..session import Assistant, Session, User

@@ -115,8 +114,7 @@ class Tui(App):
         subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)


-@click.command
-@uses_new_session
+@command_uses_new_session
 def main(obj):
     """Start interactive terminal user interface session"""
     api = obj.api
diff --git a/src/chap/core.py b/src/chap/core.py
index 3cc425d..cc4b6bd 100644
--- a/src/chap/core.py
+++ b/src/chap/core.py
@@ -8,10 +8,11 @@ import importlib
 import pathlib
 import pkgutil
 import subprocess
-from dataclasses import dataclass, fields
+from dataclasses import MISSING, dataclass, fields

 import click
 import platformdirs
+from simple_parsing.docstring import get_attribute_docstring

 from . import commands  # pylint: disable=no-name-in-module
 from .session import Session
@@ -88,7 +89,40 @@ def set_system_message(ctx, param, value):  # pylint: disable=unused-argument


 def set_backend(ctx, param, value):  # pylint: disable=unused-argument
-    ctx.obj.api = get_api(value)
+    try:
+        ctx.obj.api = get_api(value)
+    except ModuleNotFoundError as e:
+        raise click.BadParameter(str(e)) from e
+
+
+def format_backend_help(api, formatter):
+    with formatter.section(f"Backend options for {api.__class__.__name__}"):
+        rows = []
+        for f in fields(api.parameters):
+            name = f.name.replace("_", "-")
+            default = f.default if f.default_factory is MISSING else f.default_factory()
+            doc = get_attribute_docstring(type(api.parameters), f.name).docstring_below
+            if doc:
+                doc += " "
+            doc += f"(Default: {default})"
+            rows.append((f"-B {name}:{f.type.__name__.upper()}", doc))
+        formatter.write_dl(rows)
+
+
+def backend_help(ctx, param, value):  # pylint: disable=unused-argument
+    if ctx.resilient_parsing or not value:
+        return
+
+    api = ctx.obj.api or get_api()
+
+    if not hasattr(api, "parameters"):
+        click.utils.echo(f"{api.__class__.__name__} does not support parameters")
+    else:
+        formatter = ctx.make_formatter()
+        format_backend_help(api, formatter)
+        click.utils.echo(formatter.getvalue().rstrip("\n"))
+
+    ctx.exit()


 def set_backend_option(ctx, param, opts):  # pylint: disable=unused-argument
@@ -97,7 +131,7 @@ def set_backend_option(ctx, param, opts):  # pylint: disable=unused-argument
         raise click.BadParameter(
             f"{api.__class__.__name__} does not support parameters"
         )
-    all_fields = dict((f.name, f) for f in fields(api.parameters))
+    all_fields = dict((f.name.replace("_", "-"), f) for f in fields(api.parameters))

     def set_one_backend_option(kv):
         name, value = kv
@@ -137,7 +171,15 @@ def uses_existing_session(f):
     return f


-def uses_new_session(f):
+class CommandWithBackendHelp(click.Command):
+    def format_options(self, ctx, formatter):
+        super().format_options(ctx, formatter)
+        api = ctx.obj.api or get_api()
+        if hasattr(api, "parameters"):
+            format_backend_help(api, formatter)
+
+
+def command_uses_new_session(f):
     f = click.option(
         "--system-message",
         "-S",
@@ -155,6 +197,14 @@
         expose_value=False,
         is_eager=True,
     )(f)
+    f = click.option(
+        "--backend-help",
+        is_flag=True,
+        is_eager=True,
+        callback=backend_help,
+        expose_value=False,
+        help="Show information about backend options",
+    )(f)
     f = click.option(
         "--backend-option",
         "-B",
@@ -172,6 +222,7 @@
         callback=do_session_new,
         expose_value=False,
     )(f)
+    f = click.command(cls=CommandWithBackendHelp)(f)
     return f
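
Reviewer note: the core algorithmic change above is make_full_prompt, which replaces the old fixed-count truncation (max_query_size) with a token-budget walk over the history. A minimal standalone sketch of that strategy, using hypothetical stand-ins for chap's session types and a whitespace split as a rough proxy for tiktoken's BPE counts (not the real chap classes):

    from dataclasses import dataclass

    @dataclass
    class Message:  # hypothetical stand-in for chap's User/Assistant messages
        role: str
        content: str

    def token_len(message: Message) -> int:
        # Rough proxy for EncodingMeta.num_tokens_for_message: 3 per-message
        # overhead tokens plus the encoded role and content.
        return 3 + len(message.role.split()) + len(message.content.split())

    def make_full_prompt(history: list[Message], budget: int) -> list[Message]:
        result = [history[0]]                      # system prompt is always kept
        left = budget - token_len(history[0]) - 3  # 3 tokens prime the reply
        parts: list[Message] = []
        for message in reversed(history[1:]):      # newest exchanges first
            cost = token_len(message)
            if left < cost:
                break                              # older history no longer fits
            left -= cost
            parts.append(message)
        result.extend(reversed(parts))             # back to chronological order
        return result

With this in place the budget is adjustable from the command line through the new backend-option machinery; judging from the "-B name:TYPE" rows emitted by format_backend_help (the parsing of the -B argument itself is outside this diff), usage would look something like:

    chap ask --backend-help
    chap ask -B max-request-tokens:2048 "summarize this session"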