Reorganize session, ditching dataclasses_json
This needs more testing (or more typing), and it breaks the plugins I just released. Oof.
This commit is contained in:
parent
29cc2edfb3
commit
396ef3164b
15 changed files with 112 additions and 95 deletions
|
|
@ -37,5 +37,5 @@ repos:
|
|||
rev: v2.17.0
|
||||
hooks:
|
||||
- id: pylint
|
||||
additional_dependencies: [click,dataclasses_json,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
|
||||
additional_dependencies: [click,httpx,lorem-text,simple-parsing,'textual>=0.18.0',tiktoken,websockets]
|
||||
args: ['--source-roots', 'src']
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
# SPDX-License-Identifier: Unlicense
|
||||
|
||||
click
|
||||
dataclasses_json
|
||||
httpx
|
||||
lorem-text
|
||||
platformdirs
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from dataclasses import dataclass
|
|||
import httpx
|
||||
|
||||
from ..key import get_key
|
||||
from ..session import Assistant, User
|
||||
from ..session import Assistant, Role, User
|
||||
|
||||
|
||||
class HuggingFace:
|
||||
|
|
@ -39,11 +39,11 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
if not content:
|
||||
continue
|
||||
result.append(content)
|
||||
if m.role == "system":
|
||||
if m.role == Role.SYSTEM:
|
||||
result.append(self.parameters.after_system)
|
||||
elif m.role == "assistant":
|
||||
elif m.role == Role.ASSISTANT:
|
||||
result.append(self.parameters.after_assistant)
|
||||
elif m.role == "user":
|
||||
elif m.role == Role.USER:
|
||||
result.append(self.parameters.after_user)
|
||||
full_query = "".join(result)
|
||||
return full_query
|
||||
|
|
@ -82,7 +82,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
self, session, query, *, max_query_size=5, timeout=180
|
||||
): # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
new_content = []
|
||||
inputs = self.make_full_query(session.session + [User(query)], max_query_size)
|
||||
inputs = self.make_full_query(session + [User(query)], max_query_size)
|
||||
try:
|
||||
async for content in self.chained_query(inputs, timeout=timeout):
|
||||
if not new_content:
|
||||
|
|
@ -99,13 +99,13 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
new_content.append(content)
|
||||
yield content
|
||||
|
||||
session.session.extend([User(query), Assistant("".join(new_content))])
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
def ask(self, session, query, *, max_query_size=5, timeout=60):
|
||||
asyncio.run(
|
||||
self.aask(session, query, max_query_size=max_query_size, timeout=timeout)
|
||||
)
|
||||
return session.session[-1].message
|
||||
return session[-1].content
|
||||
|
||||
@classmethod
|
||||
def get_key(cls):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from dataclasses import dataclass
|
|||
|
||||
import httpx
|
||||
|
||||
from ..session import Assistant, User
|
||||
from ..session import Assistant, Role, User
|
||||
|
||||
|
||||
class LlamaCpp:
|
||||
|
|
@ -37,11 +37,11 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
if not content:
|
||||
continue
|
||||
result.append(content)
|
||||
if m.role == "system":
|
||||
if m.role == Role.SYSTEM:
|
||||
result.append(self.parameters.after_system)
|
||||
elif m.role == "assistant":
|
||||
elif m.role == Role.ASSISTANT:
|
||||
result.append(self.parameters.after_assistant)
|
||||
elif m.role == "user":
|
||||
elif m.role == Role.USER:
|
||||
result.append(self.parameters.after_user)
|
||||
full_query = "".join(result)
|
||||
return full_query
|
||||
|
|
@ -50,9 +50,7 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
self, session, query, *, max_query_size=5, timeout=180
|
||||
): # pylint: disable=unused-argument,too-many-locals,too-many-branches
|
||||
params = {
|
||||
"prompt": self.make_full_query(
|
||||
session.session + [User(query)], max_query_size
|
||||
),
|
||||
"prompt": self.make_full_query(session + [User(query)], max_query_size),
|
||||
"stream": True,
|
||||
"stop": ["</s>", "<s>", "[INST]"],
|
||||
}
|
||||
|
|
@ -87,13 +85,13 @@ A dialog, where USER interacts with AI. AI is helpful, kind, obedient, honest, a
|
|||
new_content.append(content)
|
||||
yield content
|
||||
|
||||
session.session.extend([User(query), Assistant("".join(new_content))])
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
def ask(self, session, query, *, max_query_size=5, timeout=60):
|
||||
asyncio.run(
|
||||
self.aask(session, query, max_query_size=max_query_size, timeout=timeout)
|
||||
)
|
||||
return session.session[-1].message
|
||||
return session[-1].content
|
||||
|
||||
|
||||
def factory():
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class Lorem:
|
|||
new_content = lorem.paragraphs(
|
||||
random.randint(self.parameters.paragraph_lo, self.parameters.paragraph_hi)
|
||||
).replace("\n", "\n\n")
|
||||
session.session.extend([User(query), Assistant("".join(new_content))])
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
return new_content
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import httpx
|
|||
import tiktoken
|
||||
|
||||
from ..key import get_key
|
||||
from ..session import Assistant, Session, User
|
||||
from ..session import Assistant, User
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -97,10 +97,10 @@ class ChatGPT:
|
|||
else:
|
||||
break
|
||||
result.extend(reversed(parts))
|
||||
return Session(result)
|
||||
return result
|
||||
|
||||
def ask(self, session, query, *, timeout=60):
|
||||
full_prompt = self.make_full_prompt(session.session + [User(query)])
|
||||
full_prompt = self.make_full_prompt(session + [User(query)])
|
||||
response = httpx.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
json={
|
||||
|
|
@ -125,11 +125,11 @@ class ChatGPT:
|
|||
print("Failure", response.status_code, response.text)
|
||||
return None
|
||||
|
||||
session.session.extend([User(query), Assistant(result)])
|
||||
session.extend([User(query), Assistant(result)])
|
||||
return result
|
||||
|
||||
async def aask(self, session, query, *, timeout=60):
|
||||
full_prompt = self.make_full_prompt(session.session + [User(query)])
|
||||
full_prompt = self.make_full_prompt(session + [User(query)])
|
||||
new_content = []
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
|
|
@ -167,7 +167,7 @@ class ChatGPT:
|
|||
new_content.append(content)
|
||||
yield content
|
||||
|
||||
session.session.extend([User(query), Assistant("".join(new_content))])
|
||||
session.extend([User(query), Assistant("".join(new_content))])
|
||||
|
||||
@classmethod
|
||||
def get_key(cls):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import uuid
|
|||
import websockets
|
||||
|
||||
from ..key import get_key
|
||||
from ..session import Assistant, Session, User
|
||||
from ..session import Assistant, Role, User
|
||||
|
||||
|
||||
class Textgen:
|
||||
|
|
@ -44,15 +44,13 @@ AI: Hello! How can I assist you today?"""
|
|||
session_hash = str(uuid.uuid4())
|
||||
|
||||
role_map = {
|
||||
"user": "USER: ",
|
||||
"assistant": "AI: ",
|
||||
Role.USER: "USER: ",
|
||||
Role.ASSISTANT: "AI: ",
|
||||
}
|
||||
full_prompt = Session(session.session + [User(query)])
|
||||
del full_prompt.session[1:-max_query_size]
|
||||
full_prompt = session + [User(query)]
|
||||
del full_prompt[1:-max_query_size]
|
||||
new_data = old_data = full_query = (
|
||||
"\n".join(
|
||||
f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt.session
|
||||
)
|
||||
"\n".join(f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt)
|
||||
+ f"\n{role_map.get('assistant')}"
|
||||
)
|
||||
try:
|
||||
|
|
@ -124,13 +122,13 @@ AI: Hello! How can I assist you today?"""
|
|||
yield content
|
||||
|
||||
all_response = new_data[len(full_query) :]
|
||||
session.session.extend([User(query), Assistant(all_response)])
|
||||
session.extend([User(query), Assistant(all_response)])
|
||||
|
||||
def ask(self, session, query, *, max_query_size=5, timeout=60):
|
||||
asyncio.run(
|
||||
self.aask(session, query, max_query_size=max_query_size, timeout=timeout)
|
||||
)
|
||||
return session.session[-1].message
|
||||
return session[-1].content
|
||||
|
||||
|
||||
def factory():
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import click
|
|||
import rich
|
||||
|
||||
from ..core import command_uses_new_session
|
||||
from ..session import session_to_file
|
||||
|
||||
bold = "\033[1m"
|
||||
nobold = "\033[m"
|
||||
|
|
@ -102,8 +103,7 @@ def main(obj, prompt, print_prompt):
|
|||
|
||||
print(f"Saving session to {session_filename}", file=sys.stderr)
|
||||
if response is not None:
|
||||
with open(session_filename, "w", encoding="utf-8") as f:
|
||||
f.write(session.to_json())
|
||||
session_to_file(session, session_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
import click
|
||||
|
||||
from ..core import command_uses_existing_session
|
||||
from ..session import Role
|
||||
|
||||
|
||||
@command_uses_existing_session
|
||||
|
|
@ -14,13 +15,13 @@ def main(obj, no_system):
|
|||
session = obj.session
|
||||
|
||||
first = True
|
||||
for row in session.session:
|
||||
for row in session:
|
||||
if not first:
|
||||
print()
|
||||
first = False
|
||||
if row.role == "user":
|
||||
if row.role == Role.USER:
|
||||
decoration = "**"
|
||||
elif row.role == "system":
|
||||
elif row.role == Role.SYSTEM:
|
||||
if no_system:
|
||||
continue
|
||||
decoration = "_"
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@ from __future__ import annotations
|
|||
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import click
|
||||
import rich
|
||||
|
||||
from ..core import conversations_path as default_conversations_path
|
||||
from ..session import Message, Session
|
||||
from ..session import Message, session_from_file
|
||||
from .render import to_markdown
|
||||
|
||||
|
||||
|
|
@ -22,11 +23,15 @@ def list_files_matching_rx(
|
|||
for conversation in (conversations_path or default_conversations_path).glob(
|
||||
"*.json"
|
||||
):
|
||||
with open(conversation, "r", encoding="utf-8") as f:
|
||||
session = Session.from_json(f.read()) # pylint: disable=no-member
|
||||
for message in session.session:
|
||||
if isinstance(message.content, str) and rx.search(message.content):
|
||||
yield conversation, message
|
||||
try:
|
||||
session = session_from_file(conversation) # pylint: disable=no-member
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
print(f"Failed to read {conversation}: {e}", file=sys.stderr)
|
||||
continue
|
||||
|
||||
for message in session:
|
||||
if isinstance(message.content, str) and rx.search(message.content):
|
||||
yield conversation, message
|
||||
|
||||
|
||||
@click.command
|
||||
|
|
|
|||
|
|
@ -11,27 +11,27 @@ import click
|
|||
import rich
|
||||
|
||||
from ..core import conversations_path, new_session_path
|
||||
from ..session import Message, Session
|
||||
from ..session import Message, Role, new_session, session_to_file
|
||||
|
||||
console = rich.get_console()
|
||||
|
||||
|
||||
def iter_sessions(name, content, session_in, node_id):
|
||||
node = content["mapping"][node_id]
|
||||
session = Session(session_in.session[:])
|
||||
session = session_in[:]
|
||||
|
||||
if "message" in node:
|
||||
role = node["message"]["author"]["role"]
|
||||
text_content = "".join(node["message"]["content"]["parts"])
|
||||
session.session.append(Message(role=role, content=text_content))
|
||||
session.append(Message(role=role, content=text_content))
|
||||
|
||||
if children := node.get("children"):
|
||||
for c in children:
|
||||
yield from iter_sessions(name, content, session, c)
|
||||
else:
|
||||
title = content.get("title") or "Untitled"
|
||||
session.session[0] = Message(
|
||||
"system",
|
||||
session[0] = Message(
|
||||
Role.SYSTEM,
|
||||
f"# {title}\nChatGPT session imported from {name}, branch {node_id}.\n\n",
|
||||
)
|
||||
yield node_id, session
|
||||
|
|
@ -40,7 +40,7 @@ def iter_sessions(name, content, session_in, node_id):
|
|||
def do_import(output_directory, f):
|
||||
stem = pathlib.Path(f.name).stem
|
||||
content = json.load(f)
|
||||
session = Session.new_session()
|
||||
session = new_session()
|
||||
|
||||
default_branch = content["current_node"]
|
||||
console.print(f"Importing [bold]{f.name}[nobold]")
|
||||
|
|
@ -52,8 +52,7 @@ def do_import(output_directory, f):
|
|||
session_filename = new_session_path(
|
||||
output_directory / (f"{stem}_{branch}.json")
|
||||
)
|
||||
with open(session_filename, "w", encoding="utf-8") as f_out:
|
||||
f_out.write(session.to_json()) # pylint: disable=no-member
|
||||
session_to_file(session, session_filename)
|
||||
console.print(f" -> {session_filename}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,13 +8,14 @@ from markdown_it import MarkdownIt
|
|||
from rich.markdown import Markdown
|
||||
|
||||
from ..core import command_uses_existing_session
|
||||
from ..session import Role
|
||||
|
||||
|
||||
def to_markdown(message):
|
||||
role = message.role
|
||||
if role == "user":
|
||||
if role == Role.USER:
|
||||
style = "bold"
|
||||
elif role == "system":
|
||||
elif role == Role.SYSTEM:
|
||||
style = "italic"
|
||||
else:
|
||||
style = "none"
|
||||
|
|
@ -33,11 +34,11 @@ def main(obj, no_system):
|
|||
|
||||
console = rich.get_console()
|
||||
first = True
|
||||
for row in session.session:
|
||||
for row in session:
|
||||
if not first:
|
||||
console.print()
|
||||
first = False
|
||||
if no_system and row.role == "system":
|
||||
if no_system and row.role == Role.SYSTEM:
|
||||
continue
|
||||
console.print(to_markdown(row))
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from textual.containers import Container, Horizontal, VerticalScroll
|
|||
from textual.widgets import Button, Footer, Input, LoadingIndicator, Markdown
|
||||
|
||||
from ..core import command_uses_new_session, get_api, new_session_path
|
||||
from ..session import Assistant, Session, User
|
||||
from ..session import Assistant, User, new_session, session_to_file
|
||||
|
||||
|
||||
def parser_factory():
|
||||
|
|
@ -57,7 +57,9 @@ class Tui(App):
|
|||
def __init__(self, api=None, session=None):
|
||||
super().__init__()
|
||||
self.api = api or get_api("lorem")
|
||||
self.session = session or Session.new_session(self.api.system_message)
|
||||
self.session = (
|
||||
new_session(self.api.system_message) if session is None else session
|
||||
)
|
||||
|
||||
@property
|
||||
def spinner(self):
|
||||
|
|
@ -82,7 +84,7 @@ class Tui(App):
|
|||
def compose(self):
|
||||
yield Footer()
|
||||
yield VerticalScroll(
|
||||
*[markdown_for_step(step) for step in self.session.session],
|
||||
*[markdown_for_step(step) for step in self.session],
|
||||
# The pad container helps reduce flickering when rendering fresh
|
||||
# content and scrolling. (it's not clear why this makes a
|
||||
# difference and it'd be nice to be rid of the workaround)
|
||||
|
|
@ -122,13 +124,13 @@ class Tui(App):
|
|||
markdown.disabled = True
|
||||
|
||||
# Construct a fake session with only select items
|
||||
session = Session()
|
||||
for si, wi in zip(self.session.session, self.container.children):
|
||||
session = []
|
||||
for si, wi in zip(self.session, self.container.children):
|
||||
if not wi.has_class("history_exclude"):
|
||||
session.session.append(si)
|
||||
session.append(si)
|
||||
|
||||
message = Assistant("")
|
||||
self.session.session.extend(
|
||||
self.session.extend(
|
||||
[
|
||||
User(query),
|
||||
message,
|
||||
|
|
@ -158,7 +160,7 @@ class Tui(App):
|
|||
await asyncio.gather(render_fun(), get_token_fun())
|
||||
finally:
|
||||
self.input.value = ""
|
||||
all_output = self.session.session[-1].content
|
||||
all_output = self.session[-1].content
|
||||
output.update(all_output)
|
||||
output._markdown = all_output # pylint: disable=protected-access
|
||||
self.container.scroll_end()
|
||||
|
|
@ -227,13 +229,12 @@ class Tui(App):
|
|||
widget = children[idx]
|
||||
|
||||
# Save a copy of the discussion before this deletion
|
||||
with open(new_session_path(), "w", encoding="utf-8") as f:
|
||||
f.write(self.session.to_json())
|
||||
session_to_file(self.session, new_session_path())
|
||||
|
||||
query = self.session.session[idx].content
|
||||
query = self.session[idx].content
|
||||
self.input.value = query
|
||||
|
||||
del self.session.session[idx:]
|
||||
del self.session[idx:]
|
||||
for child in self.container.children[idx:-1]:
|
||||
await child.remove()
|
||||
|
||||
|
|
@ -257,8 +258,7 @@ def main(obj):
|
|||
|
||||
print(f"Saving session to {session_filename}", file=sys.stderr)
|
||||
|
||||
with open(session_filename, "w", encoding="utf-8") as f:
|
||||
f.write(session.to_json())
|
||||
session_to_file(session, session_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import platformdirs
|
|||
from simple_parsing.docstring import get_attribute_docstring
|
||||
|
||||
from . import backends, commands # pylint: disable=no-name-in-module
|
||||
from .session import Session
|
||||
from .session import Message, System, session_from_file
|
||||
|
||||
conversations_path = platformdirs.user_state_path("chap") / "conversations"
|
||||
conversations_path.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -76,8 +76,7 @@ def do_session_continue(ctx, param, value):
|
|||
raise click.BadParameter(
|
||||
param, "--continue-session, --last and --new-session are mutually exclusive"
|
||||
)
|
||||
with open(value, "r", encoding="utf-8") as f:
|
||||
ctx.obj.session = Session.from_json(f.read()) # pylint: disable=no-member
|
||||
ctx.obj.session = session_from_file(value)
|
||||
ctx.obj.session_filename = value
|
||||
|
||||
|
||||
|
|
@ -96,7 +95,7 @@ def do_session_new(ctx, param, value):
|
|||
)
|
||||
session_filename = new_session_path(value)
|
||||
system_message = ctx.obj.system_message or ctx.obj.api.system_message
|
||||
ctx.obj.session = Session.new_session(system_message)
|
||||
ctx.obj.session = [System(system_message)]
|
||||
ctx.obj.session_filename = session_filename
|
||||
|
||||
|
||||
|
|
@ -250,7 +249,7 @@ def version_callback(ctx, param, value) -> None: # pylint: disable=unused-argum
|
|||
class Obj:
|
||||
api: object = None
|
||||
system_message: object = None
|
||||
session: Session | None = None
|
||||
session: list[Message] | None = None
|
||||
|
||||
|
||||
class MyCLI(click.MultiCommand):
|
||||
|
|
|
|||
|
|
@ -2,12 +2,17 @@
|
|||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from dataclasses_json import dataclass_json
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
|
||||
# not an enum.Enum because these objects are not json-serializable, sigh
|
||||
class Role:  # pylint: disable=too-few-public-methods
    """Names of the roles a message can carry.

    The values are plain strings rather than enum members, so they compare
    equal to ordinary role strings and can be handed directly to
    ``json.dumps``.
    """

    ASSISTANT = "assistant"
    SYSTEM = "system"
    USER = "user"
|
||||
|
||||
|
||||
@dataclass_json
|
||||
@dataclass
|
||||
class Message:
|
||||
"""Represents one Message within a chap Session"""
|
||||
|
|
@ -17,27 +22,39 @@ class Message:
|
|||
|
||||
|
||||
def Assistant(content):
    """Return a ``Message`` with the assistant role and the given content."""
    return Message(Role.ASSISTANT, content)
|
||||
|
||||
|
||||
def System(content):
    """Return a ``Message`` with the system role and the given content."""
    return Message(Role.SYSTEM, content)
|
||||
|
||||
|
||||
def User(content):
    """Return a ``Message`` with the user role and the given content."""
    return Message(Role.USER, content)
|
||||
|
||||
|
||||
@dataclass_json
|
||||
@dataclass
|
||||
class Session:
|
||||
"""Represents a series of Messages"""
|
||||
def new_session(
    system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language.",
):
    """Return a fresh session: a list holding a single system message.

    system_message is the text used for that initial system message.
    """
    return [System(system_message)]
|
||||
|
||||
session: list[Message] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def new_session(
|
||||
cls,
|
||||
system_message: str = "You're named Chap (they/them). Chap strives to answer truthfully and frankly. You have opinions, preferences, and emotions, you can say when you don't know, and you prefer to speak in plain language.",
|
||||
):
|
||||
return Session([System(system_message)])
|
||||
def session_to_json(session):
    """Serialize a session to a JSON string.

    session is an iterable of dataclass instances (the messages). Each one is
    converted with ``dataclasses.asdict`` and the result is encoded as a JSON
    array, in iteration order.
    """
    return json.dumps([asdict(message) for message in session])
|
||||
|
||||
|
||||
def session_from_json(data):
    """Parse a JSON string into a list of ``Message`` objects.

    Two layouts are accepted: a JSON array of message mappings, or a JSON
    object whose "session" key holds such an array. Each mapping is passed
    to ``Message`` as keyword arguments.
    """
    j = json.loads(data)
    if isinstance(j, dict):
        # Object layout: the message list lives under "session" (presumably
        # the format written when sessions were a dataclass with that field).
        j = j["session"]
    return [Message(**mapping) for mapping in j]
|
||||
|
||||
|
||||
def session_from_file(path):
    """Read the UTF-8 text file at path and return the session parsed from it."""
    with open(path, "r", encoding="utf-8") as f:
        return session_from_json(f.read())
|
||||
|
||||
|
||||
def session_to_file(session, path):
    """Write session to path as UTF-8 encoded JSON.

    Returns whatever ``f.write`` returns for the serialized text.
    """
    # Serialize before opening: mode "w" truncates the file immediately, so a
    # serialization error must surface before existing contents are destroyed.
    data = session_to_json(session)
    with open(path, "w", encoding="utf-8") as f:
        return f.write(data)
|
||||
|
|
|
|||
Loading…
Reference in a new issue