Merge pull request #11 from jepler/backend-option-help

Add backend option help; add chatgpt max-request-tokens
This commit is contained in:
Jeff Epler 2023-09-24 11:01:13 -05:00 committed by GitHub
commit 02de0b3163
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 148 additions and 19 deletions

View file

@ -32,5 +32,5 @@ repos:
rev: v2.17.0
hooks:
- id: pylint
additional_dependencies: [click,dataclasses_json,httpx,lorem-text,'textual>=0.18.0',websockets]
additional_dependencies: [click,dataclasses_json,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
args: ['--source-roots', 'src']

View file

@ -26,7 +26,9 @@ dependencies = [
"httpx",
"lorem-text",
"platformdirs",
"simple_parsing",
"textual>=0.18.0",
"tiktoken",
"websockets",
]
classifiers = [

View file

@ -22,9 +22,13 @@ class Lorem:
@dataclass
class Parameters:
    """Tunable options for the Lorem test backend.

    NOTE(review): the bare string literals after each field are read from
    source by simple_parsing (see format_backend_help) to build the
    --backend-help output; they must remain in place.
    """

    delay_mu: float = 0.035
    """Average delay between tokens"""
    delay_sigma: float = 0.02
    """Standard deviation of token delay"""
    paragraph_lo: int = 1
    """Minimum response paragraph count"""
    paragraph_hi: int = 5
    """Maximum response paragraph count (inclusive)"""
def __init__(self):
    # Each backend instance gets its own mutable copy of the defaults,
    # so -B overrides apply per-instance rather than class-wide.
    self.parameters = self.Parameters()

View file

@ -2,28 +2,105 @@
#
# SPDX-License-Identifier: MIT
import functools
import json
from dataclasses import dataclass
import httpx
import tiktoken
from ..key import get_key
from ..session import Assistant, Session, User
@dataclass(frozen=True)
class EncodingMeta:
    """Token-accounting helper for a chat model.

    Bundles a tiktoken encoding with the per-message / per-name token
    overheads from OpenAI's token-counting recipe, so the size of a chat
    request can be estimated before it is sent.
    """

    encoding: "tiktoken.Encoding"  # string annotation: importable without tiktoken
    tokens_per_message: int
    tokens_per_name: int

    @functools.lru_cache()
    def encode(self, s):
        """Tokenize *s*, caching results (roles and contents recur often)."""
        return self.encoding.encode(s)

    def num_tokens_for_message(self, message):
        """Return the token cost of one message, including the fixed
        per-message framing overhead.

        Fix: the per-message overhead was previously never applied even
        though it was stored, undercounting each message by 3-4 tokens.
        """
        # n.b. chap doesn't use message.name yet, so tokens_per_name is
        # intentionally not applied here.
        return (
            self.tokens_per_message
            + len(self.encode(message.role))
            + len(self.encode(message.content))
        )

    def num_tokens_for_messages(self, messages):
        """Return the token cost of a conversation; the trailing +3
        accounts for the assistant reply priming tokens."""
        return sum(self.num_tokens_for_message(message) for message in messages) + 3

    @classmethod
    @functools.cache
    def from_model(cls, model):
        """Build an EncodingMeta for *model*, pinning the floating
        gpt-3.5-turbo / gpt-4 aliases to their -0613 snapshots."""
        if model == "gpt-3.5-turbo":
            # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
            model = "gpt-3.5-turbo-0613"
        if model == "gpt-4":
            # print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
            model = "gpt-4-0613"

        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            print("Warning: model not found. Using cl100k_base encoding.")
            encoding = tiktoken.get_encoding("cl100k_base")

        if model in {
            "gpt-3.5-turbo-0613",
            "gpt-3.5-turbo-16k-0613",
            "gpt-4-0314",
            "gpt-4-32k-0314",
            "gpt-4-0613",
            "gpt-4-32k-0613",
        }:
            tokens_per_message = 3
            tokens_per_name = 1
        elif model == "gpt-3.5-turbo-0301":
            tokens_per_message = (
                4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
            )
            tokens_per_name = -1  # if there's a name, the role is omitted
        else:
            raise NotImplementedError(
                f"""EncodingMeta is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
            )
        return cls(encoding, tokens_per_message, tokens_per_name)
class ChatGPT:
@dataclass
class Parameters:
    """User-adjustable ChatGPT backend options (set with -B name:value).

    NOTE(review): the bare string literals after each field are read from
    source by simple_parsing to build --backend-help output; keep them.
    """

    model: str = "gpt-3.5-turbo"
    """The model to use. The most common alternative value is 'gpt-4'."""
    max_request_tokens: int = 1024
    """The approximate greatest number of tokens to send in a request. When the session is long, the system prompt and 1 or more of the most recent interaction steps are sent."""
def __init__(self):
    # Per-instance copy of the defaults so -B overrides don't leak
    # across backend instances.
    self.parameters = self.Parameters()
system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language."
def ask(self, session, query, *, max_query_size=5, timeout=60):
full_prompt = Session(session.session + [User(query)])
del full_prompt.session[1:-max_query_size]
def make_full_prompt(self, all_history):
    """Build a Session that fits within max_request_tokens.

    The first history entry (assumed to be the system prompt) is always
    kept; then as many of the most recent messages as the remaining token
    budget allows are included, in chronological order.
    """
    meta = EncodingMeta.from_model(self.parameters.model)
    prompt = [all_history[0]]  # Assumed to be system prompt
    budget = self.parameters.max_request_tokens - meta.num_tokens_for_messages(prompt)

    # Walk newest-to-oldest, keeping messages while they fit.
    kept = []
    for message in reversed(all_history[1:]):
        cost = meta.num_tokens_for_message(message)
        if budget < cost:
            break
        budget -= cost
        kept.append(message)

    kept.reverse()  # restore chronological order
    return Session(prompt + kept)
def ask(self, session, query, *, timeout=60):
full_prompt = self.make_full_prompt(session.session + [User(query)])
response = httpx.post(
"https://api.openai.com/v1/chat/completions",
json={
@ -51,10 +128,8 @@ class ChatGPT:
session.session.extend([User(query), Assistant(result)])
return result
async def aask(self, session, query, *, max_query_size=5, timeout=60):
full_prompt = Session(session.session + [User(query)])
del full_prompt.session[1:-max_query_size]
async def aask(self, session, query, *, timeout=60):
full_prompt = self.make_full_prompt(session.session + [User(query)])
new_content = []
try:
async with httpx.AsyncClient(timeout=timeout) as client:

View file

@ -8,7 +8,7 @@ import sys
import click
import rich
from ..core import uses_new_session
from ..core import command_uses_new_session
if sys.stdout.isatty():
bold = "\033[1m"
@ -78,8 +78,7 @@ def verbose_ask(api, session, q, **kw):
return result
@click.command
@uses_new_session
@command_uses_new_session
@click.argument("prompt", nargs=-1, required=True)
def main(obj, prompt):
"""Ask a question (command-line argument is passed as prompt)"""

View file

@ -6,14 +6,13 @@ import asyncio
import subprocess
import sys
import click
from markdown_it import MarkdownIt
from textual.app import App
from textual.binding import Binding
from textual.containers import Container, VerticalScroll
from textual.widgets import Footer, Input, Markdown
from ..core import get_api, uses_new_session
from ..core import command_uses_new_session, get_api
from ..session import Assistant, Session, User
@ -115,8 +114,7 @@ class Tui(App):
subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)
@click.command
@uses_new_session
@command_uses_new_session
def main(obj):
"""Start interactive terminal user interface session"""
api = obj.api

View file

@ -8,10 +8,11 @@ import importlib
import pathlib
import pkgutil
import subprocess
from dataclasses import dataclass, fields
from dataclasses import MISSING, dataclass, fields
import click
import platformdirs
from simple_parsing.docstring import get_attribute_docstring
from . import commands # pylint: disable=no-name-in-module
from .session import Session
@ -88,7 +89,40 @@ def set_system_message(ctx, param, value): # pylint: disable=unused-argument
def set_backend(ctx, param, value):  # pylint: disable=unused-argument
    """Click callback: resolve the --backend name into an API object."""
    try:
        ctx.obj.api = get_api(value)
    except ModuleNotFoundError as e:
        # Report a friendly CLI error instead of a traceback; chain the
        # original exception (raise ... from e) so debugging context and
        # pylint's raise-missing-from are both satisfied.
        raise click.BadParameter(str(e)) from e
def format_backend_help(api, formatter):
    """Append a help section listing *api*'s -B backend options."""
    with formatter.section(f"Backend options for {api.__class__.__name__}"):
        entries = []
        for field in fields(api.parameters):
            option = field.name.replace("_", "-")
            # A dataclass field has either a plain default or a factory.
            if field.default_factory is MISSING:
                default = field.default
            else:
                default = field.default_factory()
            doc = get_attribute_docstring(type(api.parameters), field.name).docstring_below
            if doc:
                doc += " "
            doc += f"(Default: {default})"
            entries.append((f"-B {option}:{field.type.__name__.upper()}", doc))
        formatter.write_dl(entries)
def backend_help(ctx, param, value):  # pylint: disable=unused-argument
    """Eager --backend-help callback: print backend options and exit."""
    if ctx.resilient_parsing or not value:
        return
    api = ctx.obj.api or get_api()
    if hasattr(api, "parameters"):
        formatter = ctx.make_formatter()
        format_backend_help(api, formatter)
        click.utils.echo(formatter.getvalue().rstrip("\n"))
    else:
        click.utils.echo(f"{api.__class__.__name__} does not support parameters")
    ctx.exit()
def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument
@ -97,7 +131,7 @@ def set_backend_option(ctx, param, opts): # pylint: disable=unused-argument
raise click.BadParameter(
f"{api.__class__.__name__} does not support parameters"
)
all_fields = dict((f.name, f) for f in fields(api.parameters))
all_fields = dict((f.name.replace("_", "-"), f) for f in fields(api.parameters))
def set_one_backend_option(kv):
name, value = kv
@ -137,7 +171,15 @@ def uses_existing_session(f):
return f
def uses_new_session(f):
class CommandWithBackendHelp(click.Command):
    """click.Command that appends backend-specific -B option help to --help."""

    def format_options(self, ctx, formatter):
        super().format_options(ctx, formatter)
        backend = ctx.obj.api or get_api()
        if hasattr(backend, "parameters"):
            format_backend_help(backend, formatter)
def command_uses_new_session(f):
f = click.option(
"--system-message",
"-S",
@ -155,6 +197,14 @@ def uses_new_session(f):
expose_value=False,
is_eager=True,
)(f)
f = click.option(
"--backend-help",
is_flag=True,
is_eager=True,
callback=backend_help,
expose_value=False,
help="Show information about backend options",
)(f)
f = click.option(
"--backend-option",
"-B",
@ -172,6 +222,7 @@ def uses_new_session(f):
callback=do_session_new,
expose_value=False,
)(f)
f = click.command(cls=CommandWithBackendHelp)(f)
return f