re-work backend parameters again to work on 3.9 and 3.11
This commit is contained in:
parent
877246ac28
commit
86059a5d85
1 changed files with 32 additions and 9 deletions
|
|
@ -11,7 +11,17 @@ import pathlib
|
|||
import pkgutil
|
||||
import subprocess
|
||||
from dataclasses import MISSING, Field, dataclass, fields
|
||||
from typing import Any, AsyncGenerator, Callable, Optional, Union, cast
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
get_origin,
|
||||
get_args,
|
||||
)
|
||||
import sys
|
||||
|
||||
import click
|
||||
import platformdirs
|
||||
|
|
@ -21,8 +31,11 @@ from typing_extensions import Protocol
|
|||
from . import backends, commands
|
||||
from .session import Message, Session, System, session_from_file
|
||||
|
||||
# 3.9 compatible version of `from types import UnionType`
|
||||
UnionType = type(Union[int, list])
|
||||
UnionType: type
|
||||
if sys.version_info >= (3, 10):
|
||||
from types import UnionType
|
||||
else:
|
||||
UnionType = type(Union[int, float])
|
||||
|
||||
conversations_path = platformdirs.user_state_path("chap") / "conversations"
|
||||
conversations_path.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -69,10 +82,22 @@ def new_session_path(opt_path: Optional[pathlib.Path] = None) -> pathlib.Path:
|
|||
)
|
||||
|
||||
|
||||
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
|
||||
def get_field_type(field: Field[Any]) -> Any:
|
||||
field_type = field.type
|
||||
if isinstance(field_type, UnionType):
|
||||
field_type = field_type.__args__[0]
|
||||
if isinstance(field_type, str):
|
||||
raise RuntimeError(
|
||||
"parameters dataclass may not use 'from __future__ import annotations"
|
||||
)
|
||||
origin = get_origin(field_type)
|
||||
if origin in (Union, UnionType):
|
||||
for arg in get_args(field_type):
|
||||
if arg is not None:
|
||||
return arg
|
||||
return field_type
|
||||
|
||||
|
||||
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
|
||||
field_type = get_field_type(field)
|
||||
try:
|
||||
if field_type is bool:
|
||||
tv = click.types.BoolParamType().convert(value, None, ctx)
|
||||
|
|
@ -181,9 +206,7 @@ def format_backend_help(api: Backend, formatter: click.HelpFormatter) -> None:
|
|||
if doc:
|
||||
doc += " "
|
||||
doc += f"(Default: {default!r})"
|
||||
f_type = f.type
|
||||
if isinstance(f_type, UnionType):
|
||||
f_type = f_type.__args__[0]
|
||||
f_type = get_field_type(f)
|
||||
typename = f_type.__name__
|
||||
rows.append((f"-B {name}:{typename.upper()}", doc))
|
||||
formatter.write_dl(rows)
|
||||
|
|
|
|||
Loading…
Reference in a new issue