Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC Choice: expand Choice token normalization + remove str requirement #2796

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/click/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2150,11 +2150,11 @@ def human_readable_name(self) -> str:
"""
return self.name # type: ignore

def make_metavar(self) -> str:
def make_metavar(self, ctx: Context) -> str:
if self.metavar is not None:
return self.metavar

metavar = self.type.get_metavar(self)
metavar = self.type.get_metavar(param=self, ctx=ctx)

if metavar is None:
metavar = self.type.name.upper()
Expand Down Expand Up @@ -2696,7 +2696,7 @@ def _write_opts(opts: cabc.Sequence[str]) -> str:
any_prefix_is_slash = True

if not self.is_flag and not self.count:
rv += f" {self.make_metavar()}"
rv += f" {self.make_metavar(ctx=ctx)}"

return rv

Expand Down Expand Up @@ -2977,10 +2977,10 @@ def human_readable_name(self) -> str:
return self.metavar
return self.name.upper() # type: ignore

def make_metavar(self) -> str:
def make_metavar(self, ctx: Context) -> str:
if self.metavar is not None:
return self.metavar
var = self.type.get_metavar(self)
var = self.type.get_metavar(param=self, ctx=ctx)
if not var:
var = self.name.upper() # type: ignore
if not self.required:
Expand All @@ -3007,10 +3007,10 @@ def _parse_decls(
return name, [arg], []

def get_usage_pieces(self, ctx: Context) -> list[str]:
return [self.make_metavar()]
return [self.make_metavar(ctx)]

def get_error_hint(self, ctx: Context) -> str:
return f"'{self.make_metavar()}'"
return f"'{self.make_metavar(ctx)}'"

def add_to_parser(self, parser: _OptionParser, ctx: Context) -> None:
parser.add_argument(dest=self.name, nargs=self.nargs, obj=self)
Expand Down
4 changes: 3 additions & 1 deletion src/click/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ def format_message(self) -> str:

msg = self.message
if self.param is not None:
msg_extra = self.param.type.get_missing_message(self.param)
msg_extra = self.param.type.get_missing_message(
param=self.param, ctx=self.ctx
)
if msg_extra:
if msg:
msg += f". {msg_extra}"
Expand Down
109 changes: 75 additions & 34 deletions src/click/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from .core import Parameter
from .shell_completion import CompletionItem

ParamTypeValue = t.TypeVar("ParamTypeValue")


class ParamType:
"""Represents the type of a parameter. Validates and converts values
Expand Down Expand Up @@ -86,10 +88,10 @@ def __call__(
if value is not None:
return self.convert(value, param, ctx)

def get_metavar(self, param: Parameter) -> str | None:
def get_metavar(self, param: Parameter, ctx: Context) -> str | None:
"""Returns the metavar default for this param if it provides one."""

def get_missing_message(self, param: Parameter) -> str | None:
def get_missing_message(self, param: Parameter, ctx: Context | None) -> str | None:
"""Optionally might return extra information about a missing
parameter.

Expand Down Expand Up @@ -227,7 +229,7 @@ def __repr__(self) -> str:
return "STRING"


class Choice(ParamType):
class Choice(ParamType, t.Generic[ParamTypeValue]):
"""The choice type allows a value to be checked against a fixed set
of supported values. All of these values have to be strings.

Expand All @@ -247,7 +249,7 @@ class Choice(ParamType):
name = "choice"

def __init__(
self, choices: cabc.Sequence[str], case_sensitive: bool = True
self, choices: cabc.Sequence[ParamTypeValue], case_sensitive: bool = True
) -> None:
self.choices = choices
self.case_sensitive = case_sensitive
Expand All @@ -258,14 +260,52 @@ def to_info_dict(self) -> dict[str, t.Any]:
info_dict["case_sensitive"] = self.case_sensitive
return info_dict

def get_metavar(self, param: Parameter) -> str:
def normalized_mapping(
self, ctx: Context | None = None
) -> cabc.Mapping[ParamTypeValue, str]:
"""
Returns mapping where keys are the original choices and the values are
the normalized values that are accepted via the command line.

.. versionadded:: 8.2.0
"""
return {
choice: self.normalize_choice(
choice=choice,
ctx=ctx,
)
for choice in self.choices
}

def normalize_choice(self, choice: ParamTypeValue, ctx: Context | None) -> str:
"""
Normalize a choice value.

By default use ``ctx.token_normalize_func`` and if not case sensitive,
convert to a lowecase/casefolded value.

.. versionadded:: 8.2.0
"""
normed_value = str(choice)

if ctx is not None and ctx.token_normalize_func is not None:
normed_value = ctx.token_normalize_func(normed_value)

if not self.case_sensitive:
normed_value = normed_value.casefold()

return normed_value

def get_metavar(self, param: Parameter, ctx: Context) -> str | None:
if param.param_type_name == "option" and not param.show_choices: # type: ignore
choice_metavars = [
convert_type(type(choice)).name.upper() for choice in self.choices
]
choices_str = "|".join([*dict.fromkeys(choice_metavars)])
else:
choices_str = "|".join([str(i) for i in self.choices])
choices_str = "|".join(
[str(i) for i in self.normalized_mapping(ctx=ctx).values()]
)

# Use curly braces to indicate a required argument.
if param.required and param.param_type_name == "argument":
Expand All @@ -274,46 +314,47 @@ def get_metavar(self, param: Parameter) -> str:
# Use square braces to indicate an option or optional argument.
return f"[{choices_str}]"

def get_missing_message(self, param: Parameter) -> str:
return _("Choose from:\n\t{choices}").format(choices=",\n\t".join(self.choices))
def get_missing_message(self, param: Parameter, ctx: Context | None) -> str:
"""
Message shown when no choice is passed.

.. versionchanged:: 8.2.0 Added ``ctx`` argument.
"""
return _("Choose from:\n\t{choices}").format(
choices=",\n\t".join(self.normalized_mapping(ctx=ctx).values())
)

def convert(
self, value: t.Any, param: Parameter | None, ctx: Context | None
) -> t.Any:
# Match through normalization and case sensitivity
# first do token_normalize_func, then lowercase
# preserve original `value` to produce an accurate message in
# `self.fail`
normed_value = value
normed_choices = {choice: choice for choice in self.choices}

if ctx is not None and ctx.token_normalize_func is not None:
normed_value = ctx.token_normalize_func(value)
normed_choices = {
ctx.token_normalize_func(normed_choice): original
for normed_choice, original in normed_choices.items()
}

if not self.case_sensitive:
normed_value = normed_value.casefold()
normed_choices = {
normed_choice.casefold(): original
for normed_choice, original in normed_choices.items()
}
normed_value = self.normalize_choice(choice=value, ctx=ctx)
normalized_mapping = self.normalized_mapping(ctx=ctx)
original_choice = next(
(
original
for original, normalized in normalized_mapping.items()
if normalized == normed_value
),
None,
)

if normed_value in normed_choices:
return normed_choices[normed_value]
if not original_choice:
self.fail(
self.get_invalid_choice_message(value=value, ctx=ctx),
param=param,
ctx=ctx,
)

self.fail(self.get_invalid_choice_message(value), param, ctx)
return original_choice

def get_invalid_choice_message(self, value: t.Any) -> str:
def get_invalid_choice_message(self, value: t.Any, ctx: Context | None) -> str:
"""Get the error message when the given choice is invalid.

:param value: The invalid value.

.. versionadded:: 8.2
"""
choices_str = ", ".join(map(repr, self.choices))
choices_str = ", ".join(map(repr, self.normalized_mapping(ctx=ctx).values()))
return ngettext(
"{value!r} is not {choice}.",
"{value!r} is not one of {choices}.",
Expand Down Expand Up @@ -382,7 +423,7 @@ def to_info_dict(self) -> dict[str, t.Any]:
info_dict["formats"] = self.formats
return info_dict

def get_metavar(self, param: Parameter) -> str:
def get_metavar(self, param: Parameter, ctx: Context) -> str | None:
return f"[{'|'.join(self.formats)}]"

def _try_to_convert_date(self, value: t.Any, format: str) -> datetime | None:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,41 @@ def cli(method):
assert "--method [foo|bar|baz]" in result.output


def test_choice_option_normalization(runner):
@click.command()
@click.option(
"--method",
type=click.Choice(
["SCREAMING_SNAKE_CASE", "snake_case", "PascalCase", "kebab-case"],
case_sensitive=False,
),
)
def cli(method):
click.echo(method)

result = runner.invoke(cli, ["--method=snake_case"])
assert not result.exception, result.output
assert result.output == "snake_case\n"

# Even though it's case sensitive, the choice's original value is preserved
result = runner.invoke(cli, ["--method=pascalcase"])
assert not result.exception, result.output
assert result.output == "PascalCase\n"

result = runner.invoke(cli, ["--method=meh"])
assert result.exit_code == 2
assert (
"Invalid value for '--method': 'meh' is not one of "
"'screaming_snake_case', 'snake_case', 'pascalcase', 'kebab-case'."
) in result.output

result = runner.invoke(cli, ["--help"])
assert (
"--method [screaming_snake_case|snake_case|pascalcase|kebab-case]"
in result.output
)


def test_choice_argument(runner):
@click.command()
@click.argument("method", type=click.Choice(["foo", "bar", "baz"]))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def cmd(arg):

def test_formatting_custom_type_metavar(runner):
class MyType(click.ParamType):
def get_metavar(self, param):
def get_metavar(self, param: click.Parameter, ctx: click.Context):
return "MY_TYPE"

@click.command("foo")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,5 +248,5 @@ def test_invalid_path_with_esc_sequence():

def test_choice_get_invalid_choice_message():
choice = click.Choice(["a", "b", "c"])
message = choice.get_invalid_choice_message("d")
message = choice.get_invalid_choice_message("d", ctx=None)
assert message == "'d' is not one of 'a', 'b', 'c'."