Fix parsing of bool-type backend options

this made it necessary to plumb 'ctx' through some additional
places, all to make it easy to call BoolParamType().convert().
This commit is contained in:
Jeff Epler 2023-11-16 20:26:44 -06:00
parent 5daf7e7d2a
commit 4a36cce0e0
No known key found for this signature in database
GPG key ID: D5BF15AB975AB4DE
2 changed files with 29 additions and 23 deletions

View file

@ -7,6 +7,7 @@ import subprocess
import sys import sys
from typing import Any, Optional, cast from typing import Any, Optional, cast
import click
from markdown_it import MarkdownIt from markdown_it import MarkdownIt
from textual import work from textual import work
from textual.app import App, ComposeResult from textual.app import App, ComposeResult
@ -66,7 +67,7 @@ class Tui(App[None]):
self, api: Optional[Backend] = None, session: Optional[Session] = None self, api: Optional[Backend] = None, session: Optional[Session] = None
) -> None: ) -> None:
super().__init__() super().__init__()
self.api = api or get_api("lorem") self.api = api or get_api(click.Context(click.Command("chap tui")), "lorem")
self.session = ( self.session = (
new_session(self.api.system_message) if session is None else session new_session(self.api.system_message) if session is None else session
) )

View file

@ -10,7 +10,7 @@ import os
import pathlib import pathlib
import pkgutil import pkgutil
import subprocess import subprocess
from dataclasses import MISSING, dataclass, fields from dataclasses import MISSING, Field, dataclass, fields
from typing import Any, AsyncGenerator, Callable, Optional, Union, cast from typing import Any, AsyncGenerator, Callable, Optional, Union, cast
import click import click
@ -69,7 +69,25 @@ def new_session_path(opt_path: Optional[pathlib.Path] = None) -> pathlib.Path:
) )
def configure_api_from_environment(api_name: str, api: Backend) -> None: def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
field_type = field.type
if isinstance(field_type, UnionType):
field_type = field_type.__args__[0]
try:
if field_type is bool:
tv = click.types.BoolParamType().convert(value, None, ctx)
else:
tv = field_type(value)
return tv
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {field.name} with value {value}: {e}"
) from e
def configure_api_from_environment(
ctx: click.Context, api_name: str, api: Backend
) -> None:
if not hasattr(api, "parameters"): if not hasattr(api, "parameters"):
return return
@ -78,21 +96,16 @@ def configure_api_from_environment(api_name: str, api: Backend) -> None:
value = os.environ.get(envvar) value = os.environ.get(envvar)
if value is None: if value is None:
continue continue
try: tv = convert_str_to_field(ctx, field, value)
tv = field.type(value)
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {field.name} with value {value}: {e}"
) from e
setattr(api.parameters, field.name, tv) setattr(api.parameters, field.name, tv)
def get_api(name: str = "openai_chatgpt") -> Backend: def get_api(ctx: click.Context, name: str = "openai_chatgpt") -> Backend:
name = name.replace("-", "_") name = name.replace("-", "_")
backend = cast( backend = cast(
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory() Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
) )
configure_api_from_environment(name, backend) configure_api_from_environment(ctx, name, backend)
return backend return backend
@ -159,7 +172,7 @@ def set_backend( # pylint: disable=unused-argument
ctx.exit() ctx.exit()
try: try:
ctx.obj.api = get_api(value) ctx.obj.api = get_api(ctx, value)
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
raise click.BadParameter(str(e)) raise click.BadParameter(str(e))
@ -197,16 +210,8 @@ def set_backend_option( # pylint: disable=unused-argument
field = all_fields.get(name) field = all_fields.get(name)
if field is None: if field is None:
raise click.BadParameter(f"Invalid parameter {name}") raise click.BadParameter(f"Invalid parameter {name}")
f_type = field.type tv = convert_str_to_field(ctx, field, value)
if isinstance(f_type, UnionType): setattr(api.parameters, field.name, tv)
f_type = f_type.__args__[0]
try:
tv = f_type(value)
except ValueError as e:
raise click.BadParameter(
f"Invalid value for {name} with value {value}: {e}"
) from e
setattr(api.parameters, name, tv)
for kv in opts: for kv in opts:
set_one_backend_option(kv) set_one_backend_option(kv)
@ -334,7 +339,7 @@ class MyCLI(click.MultiCommand):
self, ctx: click.Context, formatter: click.HelpFormatter self, ctx: click.Context, formatter: click.HelpFormatter
) -> None: ) -> None:
super().format_options(ctx, formatter) super().format_options(ctx, formatter)
api = ctx.obj.api or get_api() api = ctx.obj.api or get_api(ctx)
if hasattr(api, "parameters"): if hasattr(api, "parameters"):
format_backend_help(api, formatter) format_backend_help(api, formatter)