Reorganize session, ditching dataclasses_json

This needs more testing (or more typing), and it breaks the plugins
I just released. Oof.
This commit is contained in:
Jeff Epler 2023-11-08 19:41:51 -06:00
parent 29cc2edfb3
commit 396ef3164b
No known key found for this signature in database
GPG key ID: D5BF15AB975AB4DE
15 changed files with 112 additions and 95 deletions

View file

@ -37,5 +37,5 @@ repos:
rev: v2.17.0
hooks:
- id: pylint
additional_dependencies: [click,dataclasses_json,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
additional_dependencies: [click,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
args: ['--source-roots', 'src']

View file

@ -3,7 +3,6 @@
# SPDX-License-Identifier: Unlicense
click
dataclasses_json
httpx
lorem-text
platformdirs

View file

@ -9,7 +9,7 @@ from dataclasses import dataclass
import httpx
from ..key import get_key
from ..session import Assistant, User
from ..session import Assistant, Role, User
class HuggingFace:
@ -39,11 +39,11 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
if not content:
continue
result.append(content)
if m.role == "system":
if m.role == Role.SYSTEM:
result.append(self.parameters.after_system)
elif m.role == "assistant":
elif m.role == Role.ASSISTANT:
result.append(self.parameters.after_assistant)
elif m.role == "user":
elif m.role == Role.USER:
result.append(self.parameters.after_user)
full_query = "".join(result)
return full_query
@ -82,7 +82,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
self, session, query, *, max_query_size=5, timeout=180
): # pylint: disable=unused-argument,too-many-locals,too-many-branches
new_content = []
inputs = self.make_full_query(session.session + [User(query)], max_query_size)
inputs = self.make_full_query(session + [User(query)], max_query_size)
try:
async for content in self.chained_query(inputs, timeout=timeout):
if not new_content:
@ -99,13 +99,13 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
new_content.append(content)
yield content
session.session.extend([User(query), Assistant("".join(new_content))])
session.extend([User(query), Assistant("".join(new_content))])
def ask(self, session, query, *, max_query_size=5, timeout=60):
asyncio.run(
self.aask(session, query, max_query_size=max_query_size, timeout=timeout)
)
return session.session[-1].message
return session[-1].content
@classmethod
def get_key(cls):

View file

@ -8,7 +8,7 @@ from dataclasses import dataclass
import httpx
from ..session import Assistant, User
from ..session import Assistant, Role, User
class LlamaCpp:
@ -37,11 +37,11 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
if not content:
continue
result.append(content)
if m.role == "system":
if m.role == Role.SYSTEM:
result.append(self.parameters.after_system)
elif m.role == "assistant":
elif m.role == Role.ASSISTANT:
result.append(self.parameters.after_assistant)
elif m.role == "user":
elif m.role == Role.USER:
result.append(self.parameters.after_user)
full_query = "".join(result)
return full_query
@ -50,9 +50,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
self, session, query, *, max_query_size=5, timeout=180
): # pylint: disable=unused-argument,too-many-locals,too-many-branches
params = {
"prompt": self.make_full_query(
session.session + [User(query)], max_query_size
),
"prompt": self.make_full_query(session + [User(query)], max_query_size),
"stream": True,
"stop": ["</s>", "<s>", "[INST]"],
}
@ -87,13 +85,13 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
new_content.append(content)
yield content
session.session.extend([User(query), Assistant("".join(new_content))])
session.extend([User(query), Assistant("".join(new_content))])
def ask(self, session, query, *, max_query_size=5, timeout=60):
asyncio.run(
self.aask(session, query, max_query_size=max_query_size, timeout=timeout)
)
return session.session[-1].message
return session[-1].content
def factory():

View file

@ -51,7 +51,7 @@ class Lorem:
new_content = lorem.paragraphs(
random.randint(self.parameters.paragraph_lo, self.parameters.paragraph_hi)
).replace("\n", "\n\n")
session.session.extend([User(query), Assistant("".join(new_content))])
session.extend([User(query), Assistant("".join(new_content))])
return new_content

View file

@ -10,7 +10,7 @@ import httpx
import tiktoken
from ..key import get_key
from ..session import Assistant, Session, User
from ..session import Assistant, User
@dataclass(frozen=True)
@ -97,10 +97,10 @@ class ChatGPT:
else:
break
result.extend(reversed(parts))
return Session(result)
return result
def ask(self, session, query, *, timeout=60):
full_prompt = self.make_full_prompt(session.session + [User(query)])
full_prompt = self.make_full_prompt(session + [User(query)])
response = httpx.post(
"https://api.openai.com/v1/chat/completions",
json={
@ -125,11 +125,11 @@ class ChatGPT:
print("Failure", response.status_code, response.text)
return None
session.session.extend([User(query), Assistant(result)])
session.extend([User(query), Assistant(result)])
return result
async def aask(self, session, query, *, timeout=60):
full_prompt = self.make_full_prompt(session.session + [User(query)])
full_prompt = self.make_full_prompt(session + [User(query)])
new_content = []
try:
async with httpx.AsyncClient(timeout=timeout) as client:
@ -167,7 +167,7 @@ class ChatGPT:
new_content.append(content)
yield content
session.session.extend([User(query), Assistant("".join(new_content))])
session.extend([User(query), Assistant("".join(new_content))])
@classmethod
def get_key(cls):

View file

@ -9,7 +9,7 @@ import uuid
import websockets
from ..key import get_key
from ..session import Assistant, Session, User
from ..session import Assistant, Role, User
class Textgen:
@ -44,15 +44,13 @@ AI: Hello! How can I assist you today?"""
session_hash = str(uuid.uuid4())
role_map = {
"user": "USER: ",
"assistant": "AI: ",
Role.USER: "USER: ",
Role.ASSISTANT: "AI: ",
}
full_prompt = Session(session.session + [User(query)])
del full_prompt.session[1:-max_query_size]
full_prompt = session + [User(query)]
del full_prompt[1:-max_query_size]
new_data = old_data = full_query = (
"\n".join(
f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt.session
)
"\n".join(f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt)
+ f"\n{role_map.get('assistant')}"
)
try:
@ -124,13 +122,13 @@ AI: Hello! How can I assist you today?"""
yield content
all_response = new_data[len(full_query) :]
session.session.extend([User(query), Assistant(all_response)])
session.extend([User(query), Assistant(all_response)])
def ask(self, session, query, *, max_query_size=5, timeout=60):
asyncio.run(
self.aask(session, query, max_query_size=max_query_size, timeout=timeout)
)
return session.session[-1].message
return session[-1].content
def factory():

View file

@ -9,6 +9,7 @@ import click
import rich
from ..core import command_uses_new_session
from ..session import session_to_file
bold = "\033[1m"
nobold = "\033[m"
@ -102,8 +103,7 @@ def main(obj, prompt, print_prompt):
print(f"Saving session to {session_filename}", file=sys.stderr)
if response is not None:
with open(session_filename, "w", encoding="utf-8") as f:
f.write(session.to_json())
session_to_file(session, session_filename)
if __name__ == "__main__":

View file

@ -5,6 +5,7 @@
import click
from ..core import command_uses_existing_session
from ..session import Role
@command_uses_existing_session
@ -14,13 +15,13 @@ def main(obj, no_system):
session = obj.session
first = True
for row in session.session:
for row in session:
if not first:
print()
first = False
if row.role == "user":
if row.role == Role.USER:
decoration = "**"
elif row.role == "system":
elif row.role == Role.SYSTEM:
if no_system:
continue
decoration = "_"

View file

@ -6,13 +6,14 @@ from __future__ import annotations
import pathlib
import re
import sys
from typing import Iterable, Optional, Tuple
import click
import rich
from ..core import conversations_path as default_conversations_path
from ..session import Message, Session
from ..session import Message, session_from_file
from .render import to_markdown
@ -22,11 +23,15 @@ def list_files_matching_rx(
for conversation in (conversations_path or default_conversations_path).glob(
"*.json"
):
with open(conversation, "r", encoding="utf-8") as f:
session = Session.from_json(f.read()) # pylint: disable=no-member
for message in session.session:
if isinstance(message.content, str) and rx.search(message.content):
yield conversation, message
try:
session = session_from_file(conversation) # pylint: disable=no-member
except Exception as e: # pylint: disable=broad-exception-caught
print(f"Failed to read {conversation}: {e}", file=sys.stderr)
continue
for message in session:
if isinstance(message.content, str) and rx.search(message.content):
yield conversation, message
@click.command

View file

@ -11,27 +11,27 @@ import click
import rich
from ..core import conversations_path, new_session_path
from ..session import Message, Session
from ..session import Message, Role, new_session, session_to_file
console = rich.get_console()
def iter_sessions(name, content, session_in, node_id):
node = content["mapping"][node_id]
session = Session(session_in.session[:])
session = session_in[:]
if "message" in node:
role = node["message"]["author"]["role"]
text_content = "".join(node["message"]["content"]["parts"])
session.session.append(Message(role=role, content=text_content))
session.append(Message(role=role, content=text_content))
if children := node.get("children"):
for c in children:
yield from iter_sessions(name, content, session, c)
else:
title = content.get("title") or "Untitled"
session.session[0] = Message(
"system",
session[0] = Message(
Role.SYSTEM,
f"# {title}\nChatGPT session imported from {name}, branch {node_id}.\n\n",
)
yield node_id, session
@ -40,7 +40,7 @@ def iter_sessions(name, content, session_in, node_id):
def do_import(output_directory, f):
stem = pathlib.Path(f.name).stem
content = json.load(f)
session = Session.new_session()
session = new_session()
default_branch = content["current_node"]
console.print(f"Importing [bold]{f.name}[nobold]")
@ -52,8 +52,7 @@ def do_import(output_directory, f):
session_filename = new_session_path(
output_directory / (f"{stem}_{branch}.json")
)
with open(session_filename, "w", encoding="utf-8") as f_out:
f_out.write(session.to_json()) # pylint: disable=no-member
session_to_file(session, session_filename)
console.print(f" -> {session_filename}")

View file

@ -8,13 +8,14 @@ from markdown_it import MarkdownIt
from rich.markdown import Markdown
from ..core import command_uses_existing_session
from ..session import Role
def to_markdown(message):
role = message.role
if role == "user":
if role == Role.USER:
style = "bold"
elif role == "system":
elif role == Role.SYSTEM:
style = "italic"
else:
style = "none"
@ -33,11 +34,11 @@ def main(obj, no_system):
console = rich.get_console()
first = True
for row in session.session:
for row in session:
if not first:
console.print()
first = False
if no_system and row.role == "system":
if no_system and row.role == Role.SYSTEM:
continue
console.print(to_markdown(row))

View file

@ -14,7 +14,7 @@ from textual.containers import Container, Horizontal, VerticalScroll
from textual.widgets import Button, Footer, Input, LoadingIndicator, Markdown
from ..core import command_uses_new_session, get_api, new_session_path
from ..session import Assistant, Session, User
from ..session import Assistant, User, new_session, session_to_file
def parser_factory():
@ -57,7 +57,9 @@ class Tui(App):
def __init__(self, api=None, session=None):
super().__init__()
self.api = api or get_api("lorem")
self.session = session or Session.new_session(self.api.system_message)
self.session = (
new_session(self.api.system_message) if session is None else session
)
@property
def spinner(self):
@ -82,7 +84,7 @@ class Tui(App):
def compose(self):
yield Footer()
yield VerticalScroll(
*[markdown_for_step(step) for step in self.session.session],
*[markdown_for_step(step) for step in self.session],
# The pad container helps reduce flickering when rendering fresh
# content and scrolling. (it's not clear why this makes a
# difference and it'd be nice to be rid of the workaround)
@ -122,13 +124,13 @@ class Tui(App):
markdown.disabled = True
# Construct a fake session with only select items
session = Session()
for si, wi in zip(self.session.session, self.container.children):
session = []
for si, wi in zip(self.session, self.container.children):
if not wi.has_class("history_exclude"):
session.session.append(si)
session.append(si)
message = Assistant("")
self.session.session.extend(
self.session.extend(
[
User(query),
message,
@ -158,7 +160,7 @@ class Tui(App):
await asyncio.gather(render_fun(), get_token_fun())
finally:
self.input.value = ""
all_output = self.session.session[-1].content
all_output = self.session[-1].content
output.update(all_output)
output._markdown = all_output # pylint: disable=protected-access
self.container.scroll_end()
@ -227,13 +229,12 @@ class Tui(App):
widget = children[idx]
# Save a copy of the discussion before this deletion
with open(new_session_path(), "w", encoding="utf-8") as f:
f.write(self.session.to_json())
session_to_file(self.session, new_session_path())
query = self.session.session[idx].content
query = self.session[idx].content
self.input.value = query
del self.session.session[idx:]
del self.session[idx:]
for child in self.container.children[idx:-1]:
await child.remove()
@ -257,8 +258,7 @@ def main(obj):
print(f"Saving session to {session_filename}", file=sys.stderr)
with open(session_filename, "w", encoding="utf-8") as f:
f.write(session.to_json())
session_to_file(session, session_filename)
if __name__ == "__main__":

View file

@ -16,7 +16,7 @@ import platformdirs
from simple_parsing.docstring import get_attribute_docstring
from . import backends, commands # pylint: disable=no-name-in-module
from .session import Session
from .session import Message, System, session_from_file
conversations_path = platformdirs.user_state_path("chap") / "conversations"
conversations_path.mkdir(parents=True, exist_ok=True)
@ -76,8 +76,7 @@ def do_session_continue(ctx, param, value):
raise click.BadParameter(
param, "--continue-session, --last and --new-session are mutually exclusive"
)
with open(value, "r", encoding="utf-8") as f:
ctx.obj.session = Session.from_json(f.read()) # pylint: disable=no-member
ctx.obj.session = session_from_file(value)
ctx.obj.session_filename = value
@ -96,7 +95,7 @@ def do_session_new(ctx, param, value):
)
session_filename = new_session_path(value)
system_message = ctx.obj.system_message or ctx.obj.api.system_message
ctx.obj.session = Session.new_session(system_message)
ctx.obj.session = [System(system_message)]
ctx.obj.session_filename = session_filename
@ -250,7 +249,7 @@ def version_callback(ctx, param, value) -> None: # pylint: disable=unused-argum
class Obj:
api: object = None
system_message: object = None
session: Session | None = None
session: list[Message] | None = None
class MyCLI(click.MultiCommand):

View file

@ -2,12 +2,17 @@
#
# SPDX-License-Identifier: MIT
from dataclasses import dataclass, field
from dataclasses_json import dataclass_json
import json
from dataclasses import asdict, dataclass
# Not an enum.Enum because these objects are not json-serializable, sigh
class Role:  # pylint: disable=too-few-public-methods
    """Plain string constants used as the ``role`` of a Message."""

    ASSISTANT = "assistant"
    SYSTEM = "system"
    USER = "user"
@dataclass_json
@dataclass
class Message:
"""Represents one Message within a chap Session"""
@ -17,27 +22,39 @@ class Message:
def Assistant(content):
return Message("assistant", content)
return Message(Role.ASSISTANT, content)
def System(content):
return Message("system", content)
return Message(Role.SYSTEM, content)
def User(content):
return Message("user", content)
return Message(Role.USER, content)
@dataclass_json
@dataclass
class Session:
"""Represents a series of Messages"""
def new_session(
    system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language.",
):
    """Start a new session.

    The result is a list whose only element is the system message built
    from ``system_message``.
    """
    opening = System(system_message)
    return [opening]
session: list[Message] = field(default_factory=list)
@classmethod
def new_session(
cls,
system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language.",
):
return Session([System(system_message)])
def session_to_json(session):
    """Serialize a session to a JSON string.

    Every message in ``session`` is turned into a plain dict with
    ``dataclasses.asdict``; the list of those dicts is encoded as a JSON
    array.
    """
    records = []
    for message in session:
        records.append(asdict(message))
    return json.dumps(records)
def session_from_json(data):
    """Decode a JSON document into a list of Message objects.

    Two layouts are accepted: a bare JSON array of message mappings, or a
    JSON object whose ``"session"`` key holds that array (presumably the
    layout written by the earlier ``Session`` dataclass -- confirm against
    existing saved files).  Each mapping is passed to ``Message`` as keyword
    arguments.
    """
    decoded = json.loads(data)
    records = decoded["session"] if isinstance(decoded, dict) else decoded
    messages = []
    for record in records:
        messages.append(Message(**record))
    return messages
def session_from_file(path):
    """Load a session from the UTF-8 encoded JSON file at ``path``."""
    with open(path, "r", encoding="utf-8") as stream:
        text = stream.read()
    return session_from_json(text)
def session_to_file(session, path):
    """Write ``session`` to ``path`` as UTF-8 encoded JSON.

    Returns whatever the file object's ``write`` call returns.
    """
    # Serialize before opening: mode "w" truncates the target as soon as it
    # is opened, so encoding first keeps an existing file intact if
    # serialization raises.
    payload = session_to_json(session)
    with open(path, "w", encoding="utf-8") as f:
        return f.write(payload)