re-work backend parameters again to work on 3.9 and 3.11

This commit is contained in:
Jeff Epler 2023-12-12 16:58:54 -06:00
parent 877246ac28
commit 86059a5d85
No known key found for this signature in database
GPG key ID: D5BF15AB975AB4DE

View file

@ -11,7 +11,17 @@ import pathlib
import pkgutil
import subprocess
from dataclasses import MISSING, Field, dataclass, fields
from typing import Any, AsyncGenerator, Callable, Optional, Union, cast
from typing import (
Any,
AsyncGenerator,
Callable,
Optional,
Union,
cast,
get_origin,
get_args,
)
import sys
import click
import platformdirs
@ -21,8 +31,11 @@ from typing_extensions import Protocol
from . import backends, commands
from .session import Message, Session, System, session_from_file
# 3.9 compatible version of `from types import UnionType`
UnionType = type(Union[int, list])
UnionType: type
if sys.version_info >= (3, 10):
from types import UnionType
else:
UnionType = type(Union[int, float])
conversations_path = platformdirs.user_state_path("chap") / "conversations"
conversations_path.mkdir(parents=True, exist_ok=True)
@ -69,10 +82,22 @@ def new_session_path(opt_path: Optional[pathlib.Path] = None) -> pathlib.Path:
)
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
def get_field_type(field: Field[Any]) -> Any:
field_type = field.type
if isinstance(field_type, UnionType):
field_type = field_type.__args__[0]
if isinstance(field_type, str):
raise RuntimeError(
"parameters dataclass may not use 'from __future__ import annotations"
)
origin = get_origin(field_type)
if origin in (Union, UnionType):
for arg in get_args(field_type):
if arg is not None:
return arg
return field_type
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
field_type = get_field_type(field)
try:
if field_type is bool:
tv = click.types.BoolParamType().convert(value, None, ctx)
@ -181,9 +206,7 @@ def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None:
if doc:
doc += " "
doc += f"(Default: {default!r})"
f_type = f.type
if isinstance(f_type, UnionType):
f_type = f_type.__args__[0]
f_type = get_field_type(f)
typename = f_type.__name__
rows.append((f"-B {name}:{typename.upper()}", doc))
formatter.write_dl(rows)