Fix parsing of bool-type backend options

this made it necessary to plumb 'ctx' through some additional
places, all to make it easy to call BoolParamType().convert().
This commit is contained in:
Jeff Epler 2023-11-16 20:26:44 -06:00
parent 5daf7e7d2a
commit 4a36cce0e0
No known key found for this signature in database
GPG key ID: D5BF15AB975AB4DE
2 changed files with 29 additions and 23 deletions

View file

@ -7,6 +7,7 @@ import subprocess
import sys
from typing import Any, Optional, cast
import click
from markdown_it import MarkdownIt
from textual import work
from textual.app import App, ComposeResult
@ -66,7 +67,7 @@ class Tui(App[None]):
self, api: Optional[Backend] = None, session: Optional[Session] = None
) -> None:
super().__init__()
self.api = api or get_api("lorem")
self.api = api or get_api(click.Context(click.Command("chap tui")), "lorem")
self.session = (
new_session(self.api.system_message) if session is None else session
)

View file

@ -10,7 +10,7 @@ import os
import pathlib
import pkgutil
import subprocess
from dataclasses import MISSING, dataclass, fields
from dataclasses import MISSING, Field, dataclass, fields
from typing import Any, AsyncGenerator, Callable, Optional, Union, cast
import click
@ -69,7 +69,25 @@ def new_session_path(opt_path: Optional[pathlib.Path] = None) -> pathlib.Path:
)
def configure_api_from_environment(api_name: str, api: Backend) -> None:
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
field_type = field.type
if isinstance(field_type, UnionType):
field_type = field_type.__args__[0]
try:
if field_type is bool:
tv = click.types.BoolParamType().convert(value, None, ctx)
else:
tv = field_type(value)
return tv
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {field.name} with value {value}: {e}"
) from e
def configure_api_from_environment(
ctx: click.Context, api_name: str, api: Backend
) -> None:
if not hasattr(api, "parameters"):
return
@ -78,21 +96,16 @@ def configure_api_from_environment(api_name: str, api: Backend) -> None:
value = os.environ.get(envvar)
if value is None:
continue
try:
tv = field.type(value)
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {field.name} with value {value}: {e}"
) from e
tv = convert_str_to_field(ctx, field, value)
setattr(api.parameters, field.name, tv)
def get_api(name: str = "openai_chatgpt") -> Backend:
def get_api(ctx: click.Context, name: str = "openai_chatgpt") -> Backend:
name = name.replace("-", "_")
backend = cast(
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
)
configure_api_from_environment(name, backend)
configure_api_from_environment(ctx, name, backend)
return backend
@ -159,7 +172,7 @@ def set_backend( # pylint: disable=unused-argument
ctx.exit()
try:
ctx.obj.api = get_api(value)
ctx.obj.api = get_api(ctx, value)
except ModuleNotFoundError as e:
raise click.BadParameter(str(e))
@ -197,16 +210,8 @@ def set_backend_option( # pylint: disable=unused-argument
field = all_fields.get(name)
if field is None:
raise click.BadParameter(f"Invalid parameter {name}")
f_type = field.type
if isinstance(f_type, UnionType):
f_type = f_type.__args__[0]
try:
tv = f_type(value)
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {name} with value {value}: {e}"
) from e
setattr(api.parameters, name, tv)
tv = convert_str_to_field(ctx, field, value)
setattr(api.parameters, field.name, tv)
for kv in opts:
set_one_backend_option(kv)
@ -334,7 +339,7 @@ class MyCLI(click.MultiCommand):
self, ctx: click.Context, formatter: click.HelpFormatter
) -> None:
super().format_options(ctx, formatter)
api = ctx.obj.api or get_api()
api = ctx.obj.api or get_api(ctx)
if hasattr(api, "parameters"):
format_backend_help(api, formatter)