Merge pull request #11 from jepler/backend-option-help

Add backend option help; add chatgpt max-request-tokens

commit 02de0b3163 (7 changed files with 148 additions and 19 deletions)
@@ -32,5 +32,5 @@ repos:
     rev: v2.17.0
     hooks:
     - id: pylint
-      additional_dependencies: [click,dataclasses_json,httpx,lorem-text,'textual>=0.18.0',websockets]
+      additional_dependencies: [click,dataclasses_json,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
       args: ['--source-roots', 'src']
@@ -26,7 +26,9 @@ dependencies = [
     "httpx",
     "lorem-text",
     "platformdirs",
+    "simple_parsing",
     "textual>=0.18.0",
+    "tiktoken",
     "websockets",
 ]
 classifiers = [
@@ -22,9 +22,13 @@ class Lorem:
     @dataclass
     class Parameters:
         delay_mu: float = 0.035
+        """Average delay between tokens"""
         delay_sigma: float = 0.02
+        """Standard deviation of token delay"""
         paragraph_lo: int = 1
+        """Minimum response paragraph count"""
         paragraph_hi: int = 5
+        """Maximum response paragraph count (inclusive)"""

     def __init__(self):
         self.parameters = self.Parameters()
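Note: the field docstrings added above are not cosmetic. core.py (later in this diff) reads them at runtime to build the `--backend-help` output, via `simple_parsing.docstring.get_attribute_docstring`. A minimal standalone sketch of that mechanism:

```python
# Minimal sketch (not part of the diff): recovering a dataclass field's
# docstring at runtime, the way core.py's format_backend_help does below.
from dataclasses import dataclass, fields

from simple_parsing.docstring import get_attribute_docstring


@dataclass
class Parameters:
    delay_mu: float = 0.035
    """Average delay between tokens"""


for f in fields(Parameters):
    doc = get_attribute_docstring(Parameters, f.name).docstring_below
    print(f"-B {f.name.replace('_', '-')}:{f.type.__name__.upper()}", "-", doc)
    # prints: -B delay-mu:FLOAT - Average delay between tokens
```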
@@ -2,28 +2,105 @@
 #
 # SPDX-License-Identifier: MIT

+import functools
 import json
+from dataclasses import dataclass

 import httpx
+import tiktoken

 from ..key import get_key
 from ..session import Assistant, Session, User


+@dataclass(frozen=True)
+class EncodingMeta:
+    encoding: tiktoken.Encoding
+    tokens_per_message: int
+    tokens_per_name: int
+
+    @functools.lru_cache()
+    def encode(self, s):
+        return self.encoding.encode(s)
+
+    def num_tokens_for_message(self, message):
+        # n.b. chap doesn't use message.name yet
+        return len(self.encode(message.role)) + len(self.encode(message.content))
+
+    def num_tokens_for_messages(self, messages):
+        return sum(self.num_tokens_for_message(message) for message in messages) + 3
+
+    @classmethod
+    @functools.cache
+    def from_model(cls, model):
+        if model == "gpt-3.5-turbo":
+            # print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
+            model = "gpt-3.5-turbo-0613"
+        if model == "gpt-4":
+            # print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
+            model = "gpt-4-0613"
+
+        try:
+            encoding = tiktoken.encoding_for_model(model)
+        except KeyError:
+            print("Warning: model not found. Using cl100k_base encoding.")
+            encoding = tiktoken.get_encoding("cl100k_base")
+
+        if model in {
+            "gpt-3.5-turbo-0613",
+            "gpt-3.5-turbo-16k-0613",
+            "gpt-4-0314",
+            "gpt-4-32k-0314",
+            "gpt-4-0613",
+            "gpt-4-32k-0613",
+        }:
+            tokens_per_message = 3
+            tokens_per_name = 1
+        elif model == "gpt-3.5-turbo-0301":
+            tokens_per_message = (
+                4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
+            )
+            tokens_per_name = -1  # if there's a name, the role is omitted
+        else:
+            raise NotImplementedError(
+                f"""EncodingMeta is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
+            )
+        return cls(encoding, tokens_per_message, tokens_per_name)
+
+
 class ChatGPT:
     @dataclass
     class Parameters:
         model: str = "gpt-3.5-turbo"
         """The model to use. The most common alternative value is 'gpt-4'."""
+
+        max_request_tokens: int = 1024
+        """The approximate greatest number of tokens to send in a request. When the session is long, the system prompt and 1 or more of the most recent interaction steps are sent."""

     def __init__(self):
         self.parameters = self.Parameters()

     system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language."

-    def ask(self, session, query, *, max_query_size=5, timeout=60):
-        full_prompt = Session(session.session + [User(query)])
-        del full_prompt.session[1:-max_query_size]
+    def make_full_prompt(self, all_history):
+        encoding = EncodingMeta.from_model(self.parameters.model)
+        result = [all_history[0]]  # Assumed to be system prompt
+        left = self.parameters.max_request_tokens - encoding.num_tokens_for_messages(
+            result
+        )
+        parts = []
+        for message in reversed(all_history[1:]):
+            msglen = encoding.num_tokens_for_message(message)
+            if left >= msglen:
+                left -= msglen
+                parts.append(message)
+            else:
+                break
+        result.extend(reversed(parts))
+        return Session(result)
+
+    def ask(self, session, query, *, timeout=60):
+        full_prompt = self.make_full_prompt(session.session + [User(query)])
         response = httpx.post(
             "https://api.openai.com/v1/chat/completions",
             json={
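Note on the change above: `make_full_prompt` replaces the old fixed-depth truncation (`max_query_size=5`) with a token budget. The system prompt is always kept; the remaining history is taken newest-first until `max_request_tokens` is spent. A simplified, self-contained sketch of the same greedy trim (`Msg` stands in for chap's message types, and the token accounting is approximate):

```python
# Simplified sketch of the token-budgeted trimming; Msg is a stand-in
# for chap's User/Assistant message types.
from dataclasses import dataclass

import tiktoken


@dataclass
class Msg:
    role: str
    content: str


ENC = tiktoken.get_encoding("cl100k_base")


def num_tokens(m: Msg) -> int:
    # Simplified accounting; EncodingMeta above tracks per-message and
    # per-name overhead separately, per the model's ChatML framing.
    return len(ENC.encode(m.role)) + len(ENC.encode(m.content))


def trim(history: list[Msg], budget: int) -> list[Msg]:
    result = [history[0]]  # always keep the system prompt
    left = budget - num_tokens(history[0])
    kept: list[Msg] = []
    for m in reversed(history[1:]):  # walk newest-first
        cost = num_tokens(m)
        if left < cost:
            break  # stop at the first message that no longer fits
        left -= cost
        kept.append(m)
    return result + list(reversed(kept))  # restore chronological order
```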
@@ -51,10 +128,8 @@ class ChatGPT:
         session.session.extend([User(query), Assistant(result)])
         return result

-    async def aask(self, session, query, *, max_query_size=5, timeout=60):
-        full_prompt = Session(session.session + [User(query)])
-        del full_prompt.session[1:-max_query_size]
-
+    async def aask(self, session, query, *, timeout=60):
+        full_prompt = self.make_full_prompt(session.session + [User(query)])
         new_content = []
         try:
             async with httpx.AsyncClient(timeout=timeout) as client:
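The body of `aask` is truncated by this hunk; it streams the completion over httpx's `AsyncClient`. For orientation only, here is a generic sketch of SSE-style streaming against the chat completions endpoint; the names and parsing are illustrative, not chap's exact code:

```python
# Illustrative only: generic httpx streaming against the chat completions
# endpoint. Not chap's exact aask body, which this hunk does not show.
import json

import httpx


async def stream_completion(messages, api_key, *, timeout=60):
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream(
            "POST",
            "https://api.openai.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {api_key}"},
            json={"model": "gpt-3.5-turbo", "messages": messages, "stream": True},
        ) as response:
            async for line in response.aiter_lines():
                # Server-sent events arrive as "data: {...}" lines
                if line.startswith("data: ") and line != "data: [DONE]":
                    chunk = json.loads(line[len("data: "):])
                    delta = chunk["choices"][0]["delta"]
                    if "content" in delta:
                        yield delta["content"]
```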
@@ -8,7 +8,7 @@ import sys
 import click
 import rich

-from ..core import uses_new_session
+from ..core import command_uses_new_session

 if sys.stdout.isatty():
     bold = "\033[1m"

@@ -78,8 +78,7 @@ def verbose_ask(api, session, q, **kw):
     return result


-@click.command
-@uses_new_session
+@command_uses_new_session
 @click.argument("prompt", nargs=-1, required=True)
 def main(obj, prompt):
     """Ask a question (command-line argument is passed as prompt)"""
@@ -6,14 +6,13 @@ import asyncio
 import subprocess
 import sys

 import click
 from markdown_it import MarkdownIt
 from textual.app import App
 from textual.binding import Binding
 from textual.containers import Container, VerticalScroll
 from textual.widgets import Footer, Input, Markdown

-from ..core import get_api, uses_new_session
+from ..core import command_uses_new_session, get_api
 from ..session import Assistant, Session, User


@@ -115,8 +114,7 @@ class Tui(App):
         subprocess.run(["xsel", "-ib"], input=content.encode("utf-8"), check=False)


-@click.command
-@uses_new_session
+@command_uses_new_session
 def main(obj):
     """Start interactive terminal user interface session"""
     api = obj.api
@@ -8,10 +8,11 @@ import importlib
 import pathlib
 import pkgutil
 import subprocess
-from dataclasses import dataclass, fields
+from dataclasses import MISSING, dataclass, fields

 import click
 import platformdirs
+from simple_parsing.docstring import get_attribute_docstring

 from . import commands  # pylint: disable=no-name-in-module
 from .session import Session
@@ -88,7 +89,40 @@ def set_system_message(ctx, param, value):  # pylint: disable=unused-argument


 def set_backend(ctx, param, value):  # pylint: disable=unused-argument
-    ctx.obj.api = get_api(value)
+    try:
+        ctx.obj.api = get_api(value)
+    except ModuleNotFoundError as e:
+        raise click.BadParameter(str(e))
+
+
+def format_backend_help(api, formatter):
+    with formatter.section(f"Backend options for {api.__class__.__name__}"):
+        rows = []
+        for f in fields(api.parameters):
+            name = f.name.replace("_", "-")
+            default = f.default if f.default_factory is MISSING else f.default_factory()
+            doc = get_attribute_docstring(type(api.parameters), f.name).docstring_below
+            if doc:
+                doc += " "
+            doc += f"(Default: {default})"
+            rows.append((f"-B {name}:{f.type.__name__.upper()}", doc))
+        formatter.write_dl(rows)
+
+
+def backend_help(ctx, param, value):  # pylint: disable=unused-argument
+    if ctx.resilient_parsing or not value:
+        return
+
+    api = ctx.obj.api or get_api()
+
+    if not hasattr(api, "parameters"):
+        click.utils.echo(f"{api.__class__.__name__} does not support parameters")
+    else:
+        formatter = ctx.make_formatter()
+        format_backend_help(api, formatter)
+        click.utils.echo(formatter.getvalue().rstrip("\n"))
+
+    ctx.exit()


 def set_backend_option(ctx, param, opts):  # pylint: disable=unused-argument
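The help rendering above uses two click formatter primitives: `section()` writes an indented heading and `write_dl()` a two-column definition list. A self-contained demo of those calls; the rows are invented examples of what `--backend-help` output can look like:

```python
# Self-contained demo of the click formatter calls used above;
# the rows are made-up examples, not actual chap output.
import click

formatter = click.HelpFormatter()
with formatter.section("Backend options for ChatGPT"):
    formatter.write_dl(
        [
            ("-B model:STR", "The model to use. (Default: gpt-3.5-turbo)"),
            ("-B max-request-tokens:INT", "Approximate request budget. (Default: 1024)"),
        ]
    )
print(formatter.getvalue().rstrip("\n"))
```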
@@ -97,7 +131,7 @@ def set_backend_option(ctx, param, opts):  # pylint: disable=unused-argument
         raise click.BadParameter(
             f"{api.__class__.__name__} does not support parameters"
         )
-    all_fields = dict((f.name, f) for f in fields(api.parameters))
+    all_fields = dict((f.name.replace("_", "-"), f) for f in fields(api.parameters))

     def set_one_backend_option(kv):
         name, value = kv
@@ -137,7 +171,15 @@ def uses_existing_session(f):
     return f


-def uses_new_session(f):
+class CommandWithBackendHelp(click.Command):
+    def format_options(self, ctx, formatter):
+        super().format_options(ctx, formatter)
+        api = ctx.obj.api or get_api()
+        if hasattr(api, "parameters"):
+            format_backend_help(api, formatter)
+
+
+def command_uses_new_session(f):
     f = click.option(
         "--system-message",
         "-S",
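`CommandWithBackendHelp` follows the standard click extension pattern: subclass `click.Command`, override `format_options()`, and attach it via `click.command(cls=...)` (done in the last hunk below). A minimal standalone version of the pattern, with a hypothetical extra help section:

```python
# Minimal, self-contained version of the pattern used above: a
# click.Command subclass that appends an extra section to --help output.
import click


class CommandWithExtraHelp(click.Command):
    def format_options(self, ctx, formatter):
        super().format_options(ctx, formatter)  # normal option help first
        with formatter.section("Backend options"):
            formatter.write_dl([("-B model:STR", "The model to use")])


@click.command(cls=CommandWithExtraHelp)
def main():
    """Demo command; run with --help to see the extra section."""


if __name__ == "__main__":
    main()
```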
@@ -155,6 +197,14 @@ def uses_new_session(f):
         expose_value=False,
         is_eager=True,
     )(f)
+    f = click.option(
+        "--backend-help",
+        is_flag=True,
+        is_eager=True,
+        callback=backend_help,
+        expose_value=False,
+        help="Show information about backend options",
+    )(f)
     f = click.option(
         "--backend-option",
         "-B",
@@ -172,6 +222,7 @@ def uses_new_session(f):
         callback=do_session_new,
         expose_value=False,
     )(f)
+    f = click.command(cls=CommandWithBackendHelp)(f)
     return f