Fix parsing of bool-type backend options
this made it necessary to plumb 'ctx' through some additional places, all to make it easy to call BoolParamType().convert().
This commit is contained in:
parent
5daf7e7d2a
commit
4a36cce0e0
2 changed files with 29 additions and 23 deletions
|
|
@ -7,6 +7,7 @@ import subprocess
|
|||
import sys
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import click
|
||||
from markdown_it import MarkdownIt
|
||||
from textual import work
|
||||
from textual.app import App, ComposeResult
|
||||
|
|
@ -66,7 +67,7 @@ class Tui(App[None]):
|
|||
self, api: Optional[Backend] = None, session: Optional[Session] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.api = api or get_api("lorem")
|
||||
self.api = api or get_api(click.Context(click.Command("chap tui")), "lorem")
|
||||
self.session = (
|
||||
new_session(self.api.system_message) if session is None else session
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import os
|
|||
import pathlib
|
||||
import pkgutil
|
||||
import subprocess
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from dataclasses import MISSING, Field, dataclass, fields
|
||||
from typing import Any, AsyncGenerator, Callable, Optional, Union, cast
|
||||
|
||||
import click
|
||||
|
|
@ -69,7 +69,25 @@ def new_session_path(opt_path: Optional[pathlib.Path] = None) -> pathlib.Path:
|
|||
)
|
||||
|
||||
|
||||
def configure_api_from_environment(api_name: str, api: Backend) -> None:
|
||||
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
|
||||
field_type = field.type
|
||||
if isinstance(field_type, UnionType):
|
||||
field_type = field_type.__args__[0]
|
||||
try:
|
||||
if field_type is bool:
|
||||
tv = click.types.BoolParamType().convert(value, None, ctx)
|
||||
else:
|
||||
tv = field_type(value)
|
||||
return tv
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(
|
||||
f"Invalid value for {field.name} with value {value}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def configure_api_from_environment(
|
||||
ctx: click.Context, api_name: str, api: Backend
|
||||
) -> None:
|
||||
if not hasattr(api, "parameters"):
|
||||
return
|
||||
|
||||
|
|
@ -78,21 +96,16 @@ def configure_api_from_environment(api_name: str, api: Backend) -> None:
|
|||
value = os.environ.get(envvar)
|
||||
if value is None:
|
||||
continue
|
||||
try:
|
||||
tv = field.type(value)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(
|
||||
f"Invalid value for {field.name} with value {value}: {e}"
|
||||
) from e
|
||||
tv = convert_str_to_field(ctx, field, value)
|
||||
setattr(api.parameters, field.name, tv)
|
||||
|
||||
|
||||
def get_api(name: str = "openai_chatgpt") -> Backend:
|
||||
def get_api(ctx: click.Context, name: str = "openai_chatgpt") -> Backend:
|
||||
name = name.replace("-", "_")
|
||||
backend = cast(
|
||||
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
|
||||
)
|
||||
configure_api_from_environment(name, backend)
|
||||
configure_api_from_environment(ctx, name, backend)
|
||||
return backend
|
||||
|
||||
|
||||
|
|
@ -159,7 +172,7 @@ def set_backend( # pylint: disable=unused-argument
|
|||
ctx.exit()
|
||||
|
||||
try:
|
||||
ctx.obj.api = get_api(value)
|
||||
ctx.obj.api = get_api(ctx, value)
|
||||
except ModuleNotFoundError as e:
|
||||
raise click.BadParameter(str(e))
|
||||
|
||||
|
|
@ -197,16 +210,8 @@ def set_backend_option( # pylint: disable=unused-argument
|
|||
field = all_fields.get(name)
|
||||
if field is None:
|
||||
raise click.BadParameter(f"Invalid parameter {name}")
|
||||
f_type = field.type
|
||||
if isinstance(f_type, UnionType):
|
||||
f_type = f_type.__args__[0]
|
||||
try:
|
||||
tv = f_type(value)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(
|
||||
f"Invalid value for {name} with value {value}: {e}"
|
||||
) from e
|
||||
setattr(api.parameters, name, tv)
|
||||
tv = convert_str_to_field(ctx, field, value)
|
||||
setattr(api.parameters, field.name, tv)
|
||||
|
||||
for kv in opts:
|
||||
set_one_backend_option(kv)
|
||||
|
|
@ -334,7 +339,7 @@ class MyCLI(click.MultiCommand):
|
|||
self, ctx: click.Context, formatter: click.HelpFormatter
|
||||
) -> None:
|
||||
super().format_options(ctx, formatter)
|
||||
api = ctx.obj.api or get_api()
|
||||
api = ctx.obj.api or get_api(ctx)
|
||||
if hasattr(api, "parameters"):
|
||||
format_backend_help(api, formatter)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue