Merge pull request #25 from jepler/gpt-4-turbo
Add gpt-4-1106-preview (gpt-4-turbo) to model list
This commit is contained in:
commit
03e0adfbb1
5 changed files with 31 additions and 29 deletions
2
Makefile
2
Makefile
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
.PHONY: mypy
|
||||
mypy: venv/bin/mypy
|
||||
venv/bin/mypy --strict -p chap
|
||||
venv/bin/mypy --strict --no-warn-unused-ignores -p chap
|
||||
|
||||
venv/bin/mypy:
|
||||
python -mvenv venv
|
||||
|
|
|
|||
|
|
@ -42,3 +42,5 @@ write_to = "src/chap/__version__.py"
|
|||
[tool.setuptools.dynamic]
|
||||
readme = {file = ["README.md"], content-type="text/markdown"}
|
||||
dependencies = {file = "requirements.txt"}
|
||||
[tool.setuptools.package-data]
|
||||
"pkgname" = ["py.typed"]
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
import functools
|
||||
import json
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, cast
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ class EncodingMeta:
|
|||
encoding: tiktoken.Encoding
|
||||
tokens_per_message: int
|
||||
tokens_per_name: int
|
||||
tokens_overhead: int
|
||||
|
||||
@functools.lru_cache()
|
||||
def encode(self, s: str) -> list[int]:
|
||||
|
|
@ -27,47 +29,38 @@ class EncodingMeta:
|
|||
|
||||
def num_tokens_for_message(self, message: Message) -> int:
|
||||
# n.b. chap doesn't use message.name yet
|
||||
return len(self.encode(message.role)) + len(self.encode(message.content))
|
||||
return (
|
||||
len(self.encode(message.role))
|
||||
+ len(self.encode(message.content))
|
||||
+ self.tokens_per_message
|
||||
)
|
||||
|
||||
def num_tokens_for_messages(self, messages: Session) -> int:
|
||||
return sum(self.num_tokens_for_message(message) for message in messages) + 3
|
||||
return (
|
||||
sum(self.num_tokens_for_message(message) for message in messages)
|
||||
+ self.tokens_overhead
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@functools.cache
|
||||
def from_model(cls, model: str) -> "EncodingMeta":
|
||||
if model == "gpt-3.5-turbo":
|
||||
# print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
|
||||
model = "gpt-3.5-turbo-0613"
|
||||
if model == "gpt-4":
|
||||
# print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
||||
model = "gpt-4-0613"
|
||||
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
print("Warning: model not found. Using cl100k_base encoding.")
|
||||
warnings.warn("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
if model in {
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
}:
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
elif model == "gpt-3.5-turbo-0301":
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
tokens_overhead = 3
|
||||
|
||||
if model == "gpt-3.5-turbo-0301":
|
||||
tokens_per_message = (
|
||||
4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
)
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"""EncodingMeta is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
|
||||
)
|
||||
return cls(encoding, tokens_per_message, tokens_per_name)
|
||||
|
||||
return cls(encoding, tokens_per_message, tokens_per_name, tokens_overhead)
|
||||
|
||||
|
||||
class ChatGPT:
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import pathlib
|
|||
import pkgutil
|
||||
import subprocess
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from types import UnionType
|
||||
from typing import Any, AsyncGenerator, Callable, cast
|
||||
|
||||
import click
|
||||
|
|
@ -171,7 +172,10 @@ def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None:
|
|||
if doc:
|
||||
doc += " "
|
||||
doc += f"(Default: {default!r})"
|
||||
typename = f.type.__name__
|
||||
f_type = f.type
|
||||
if isinstance(f_type, UnionType):
|
||||
f_type = f_type.__args__[0]
|
||||
typename = f_type.__name__
|
||||
rows.append((f"-B {name}:{typename.upper()}", doc))
|
||||
formatter.write_dl(rows)
|
||||
|
||||
|
|
@ -191,8 +195,11 @@ def set_backend_option( # pylint: disable=unused-argument
|
|||
field = all_fields.get(name)
|
||||
if field is None:
|
||||
raise click.BadParameter(f"Invalid parameter {name}")
|
||||
f_type = field.type
|
||||
if isinstance(f_type, UnionType):
|
||||
f_type = f_type.__args__[0]
|
||||
try:
|
||||
tv = field.type(value)
|
||||
tv = f_type(value)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(
|
||||
f"Invalid value for {name} with value {value}: {e}"
|
||||
|
|
|
|||
0
src/chap/py.typed
Normal file
0
src/chap/py.typed
Normal file
Loading…
Reference in a new issue