Fix parsing of bool-type backend options
this made it necessary to plumb 'ctx' through some additional places, all to make it easy to call BoolParamType().convert().
This commit is contained in:
parent
5daf7e7d2a
commit
4a36cce0e0
2 changed files with 29 additions and 23 deletions
|
|
@ -7,6 +7,7 @@ import subprocess
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
import click
|
||||||
from markdown_it import MarkdownIt
|
from markdown_it import MarkdownIt
|
||||||
from textual import work
|
from textual import work
|
||||||
from textual.app import App, ComposeResult
|
from textual.app import App, ComposeResult
|
||||||
|
|
@ -66,7 +67,7 @@ class Tui(App[None]):
|
||||||
self, api: Optional[Backend] = None, session: Optional[Session] = None
|
self, api: Optional[Backend] = None, session: Optional[Session] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.api = api or get_api("lorem")
|
self.api = api or get_api(click.Context(click.Command("chap tui")), "lorem")
|
||||||
self.session = (
|
self.session = (
|
||||||
new_session(self.api.system_message) if session is None else session
|
new_session(self.api.system_message) if session is None else session
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from dataclasses import MISSING, dataclass, fields
|
from dataclasses import MISSING, Field, dataclass, fields
|
||||||
from typing import Any, AsyncGenerator, Callable, Optional, Union, cast
|
from typing import Any, AsyncGenerator, Callable, Optional, Union, cast
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
@ -69,7 +69,25 @@ def new_session_path(opt_path: Optional[pathlib.Path] = None) -> pathlib.Path:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def configure_api_from_environment(api_name: str, api: Backend) -> None:
|
def convert_str_to_field(ctx: click.Context, field: Field[Any], value: str) -> Any:
|
||||||
|
field_type = field.type
|
||||||
|
if isinstance(field_type, UnionType):
|
||||||
|
field_type = field_type.__args__[0]
|
||||||
|
try:
|
||||||
|
if field_type is bool:
|
||||||
|
tv = click.types.BoolParamType().convert(value, None, ctx)
|
||||||
|
else:
|
||||||
|
tv = field_type(value)
|
||||||
|
return tv
|
||||||
|
except ValueError as e:
|
||||||
|
raise click.BadParameter(
|
||||||
|
f"Invalid value for {field.name} with value {value}: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def configure_api_from_environment(
|
||||||
|
ctx: click.Context, api_name: str, api: Backend
|
||||||
|
) -> None:
|
||||||
if not hasattr(api, "parameters"):
|
if not hasattr(api, "parameters"):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -78,21 +96,16 @@ def configure_api_from_environment(api_name: str, api: Backend) -> None:
|
||||||
value = os.environ.get(envvar)
|
value = os.environ.get(envvar)
|
||||||
if value is None:
|
if value is None:
|
||||||
continue
|
continue
|
||||||
try:
|
tv = convert_str_to_field(ctx, field, value)
|
||||||
tv = field.type(value)
|
|
||||||
except ValueError as e:
|
|
||||||
raise click.BadParameter(
|
|
||||||
f"Invalid value for {field.name} with value {value}: {e}"
|
|
||||||
) from e
|
|
||||||
setattr(api.parameters, field.name, tv)
|
setattr(api.parameters, field.name, tv)
|
||||||
|
|
||||||
|
|
||||||
def get_api(name: str = "openai_chatgpt") -> Backend:
|
def get_api(ctx: click.Context, name: str = "openai_chatgpt") -> Backend:
|
||||||
name = name.replace("-", "_")
|
name = name.replace("-", "_")
|
||||||
backend = cast(
|
backend = cast(
|
||||||
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
|
Backend, importlib.import_module(f"{__package__}.backends.{name}").factory()
|
||||||
)
|
)
|
||||||
configure_api_from_environment(name, backend)
|
configure_api_from_environment(ctx, name, backend)
|
||||||
return backend
|
return backend
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -159,7 +172,7 @@ def set_backend( # pylint: disable=unused-argument
|
||||||
ctx.exit()
|
ctx.exit()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ctx.obj.api = get_api(value)
|
ctx.obj.api = get_api(ctx, value)
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
raise click.BadParameter(str(e))
|
raise click.BadParameter(str(e))
|
||||||
|
|
||||||
|
|
@ -197,16 +210,8 @@ def set_backend_option( # pylint: disable=unused-argument
|
||||||
field = all_fields.get(name)
|
field = all_fields.get(name)
|
||||||
if field is None:
|
if field is None:
|
||||||
raise click.BadParameter(f"Invalid parameter {name}")
|
raise click.BadParameter(f"Invalid parameter {name}")
|
||||||
f_type = field.type
|
tv = convert_str_to_field(ctx, field, value)
|
||||||
if isinstance(f_type, UnionType):
|
setattr(api.parameters, field.name, tv)
|
||||||
f_type = f_type.__args__[0]
|
|
||||||
try:
|
|
||||||
tv = f_type(value)
|
|
||||||
except ValueError as e:
|
|
||||||
raise click.BadParameter(
|
|
||||||
f"Invalid value for {name} with value {value}: {e}"
|
|
||||||
) from e
|
|
||||||
setattr(api.parameters, name, tv)
|
|
||||||
|
|
||||||
for kv in opts:
|
for kv in opts:
|
||||||
set_one_backend_option(kv)
|
set_one_backend_option(kv)
|
||||||
|
|
@ -334,7 +339,7 @@ class MyCLI(click.MultiCommand):
|
||||||
self, ctx: click.Context, formatter: click.HelpFormatter
|
self, ctx: click.Context, formatter: click.HelpFormatter
|
||||||
) -> None:
|
) -> None:
|
||||||
super().format_options(ctx, formatter)
|
super().format_options(ctx, formatter)
|
||||||
api = ctx.obj.api or get_api()
|
api = ctx.obj.api or get_api(ctx)
|
||||||
if hasattr(api, "parameters"):
|
if hasattr(api, "parameters"):
|
||||||
format_backend_help(api, formatter)
|
format_backend_help(api, formatter)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue