diff --git a/src/chap/backends/textgen.py b/src/chap/backends/textgen.py index 949c67f..f857680 100644 --- a/src/chap/backends/textgen.py +++ b/src/chap/backends/textgen.py @@ -61,8 +61,8 @@ AI: Hello! How can I assist you today?""" full_prompt = session + [User(query)] del full_prompt[1:-max_query_size] new_data = old_data = full_query = "\n".join( - f"{role_map.get(q.role,'')}{q.content}\n" for q in full_prompt - ) + f"\n{role_map.get('assistant')}" + f"{role_map.get(q.role,)}{q.content}\n" for q in full_prompt + ) + f"\n{role_map.get(Role.ASSISTANT)}" try: async with websockets.connect( f"ws://{self.parameters.server_hostname}:7860/queue/join" diff --git a/src/chap/session.py b/src/chap/session.py index 5868296..e680718 100644 --- a/src/chap/session.py +++ b/src/chap/session.py @@ -4,6 +4,7 @@ from __future__ import annotations +import enum import json import pathlib from dataclasses import asdict, dataclass @@ -12,8 +13,7 @@ from typing import Union, cast from typing_extensions import TypedDict -# not an enum.Enum because these objects are not json-serializable, sigh -class Role: +class Role(str, enum.Enum): ASSISTANT = "assistant" SYSTEM = "system" USER = "user" @@ -23,7 +23,7 @@ class Role: class Message: """Represents one Message within a chap Session""" - role: str + role: Role content: str