diff --git a/Lib/test/test_clinic.py b/Lib/test/test_clinic.py index 580d54e0eb094d..dc88d388fd57dc 100644 --- a/Lib/test/test_clinic.py +++ b/Lib/test/test_clinic.py @@ -1277,12 +1277,8 @@ def test_base_invalid_syntax(self): os.stat invalid syntax: int = 42 """ - err = dedent(r""" - Function 'stat' has an invalid parameter declaration: - \s+'invalid syntax: int = 42' - """).strip() - with self.assertRaisesRegex(ClinicError, err): - self.parse_function(block) + err = "Function 'stat' has an invalid parameter declaration: 'invalid syntax: int = 42'" + self.expect_failure(block, err, lineno=2) def test_param_default_invalid_syntax(self): block = """ @@ -1290,7 +1286,7 @@ def test_param_default_invalid_syntax(self): os.stat x: int = invalid syntax """ - err = r"Syntax error: 'x = invalid syntax\n'" + err = "Function 'stat' has an invalid parameter declaration:" self.expect_failure(block, err, lineno=2) def test_cloning_nonexistent_function_correctly_fails(self): @@ -2510,7 +2506,7 @@ def test_cannot_specify_pydefault_without_default(self): self.expect_failure(block, err, lineno=1) def test_vararg_cannot_take_default_value(self): - err = "Vararg can't take a default value!" + err = "Function 'fn' has an invalid parameter declaration:" block = """ fn *args: tuple = None diff --git a/Tools/clinic/libclinic/dsl_parser.py b/Tools/clinic/libclinic/dsl_parser.py index 282ff64cd33089..eca41531f7c8e9 100644 --- a/Tools/clinic/libclinic/dsl_parser.py +++ b/Tools/clinic/libclinic/dsl_parser.py @@ -877,43 +877,16 @@ def parse_parameter(self, line: str) -> None: # handle "as" for parameters too c_name = None - name, have_as_token, trailing = line.partition(' as ') - if have_as_token: - name = name.strip() - if ' ' not in name: - fields = trailing.strip().split(' ') - if not fields: - fail("Invalid 'as' clause!") - c_name = fields[0] - if c_name.endswith(':'): - name += ':' - c_name = c_name[:-1] - fields[0] = name - line = ' '.join(fields) - - default: str | None - base, equals, default = line.rpartition('=') - if not equals: - base = default - default = None - - module = None + m = re.match(r'(?:\* *)?\w+( +as +(\w+))', line) + if m: + c_name = m[2] + line = line[:m.start(1)] + line[m.end(1):] + try: - ast_input = f"def x({base}): pass" + ast_input = f"def x({line}\n): pass" module = ast.parse(ast_input) except SyntaxError: - try: - # the last = was probably inside a function call, like - # c: int(accept={str}) - # so assume there was no actual default value. - default = None - ast_input = f"def x({line}): pass" - module = ast.parse(ast_input) - except SyntaxError: - pass - if not module: - fail(f"Function {self.function.name!r} has an invalid parameter declaration:\n\t", - repr(line)) + fail(f"Function {self.function.name!r} has an invalid parameter declaration: {line!r}") function = module.body[0] assert isinstance(function, ast.FunctionDef) @@ -922,9 +895,6 @@ def parse_parameter(self, line: str) -> None: if len(function_args.args) > 1: fail(f"Function {self.function.name!r} has an " f"invalid parameter declaration (comma?): {line!r}") - if function_args.defaults or function_args.kw_defaults: - fail(f"Function {self.function.name!r} has an " - f"invalid parameter declaration (default value?): {line!r}") if function_args.kwarg: fail(f"Function {self.function.name!r} has an " f"invalid parameter declaration (**kwargs?): {line!r}") @@ -944,7 +914,7 @@ def parse_parameter(self, line: str) -> None: name = 'varpos_' + name value: object - if not default: + if not function_args.defaults: if is_vararg: value = NULL else: @@ -955,17 +925,13 @@ def parse_parameter(self, line: str) -> None: if 'py_default' in kwargs: fail("You can't specify py_default without specifying a default value!") else: - if is_vararg: - fail("Vararg can't take a default value!") + expr = function_args.defaults[0] + default = ast_input[expr.col_offset: expr.end_col_offset].strip() if self.parameter_state is ParamState.REQUIRED: self.parameter_state = ParamState.OPTIONAL - default = default.strip() bad = False - ast_input = f"x = {default}" try: - module = ast.parse(ast_input) - if 'c_default' not in kwargs: # we can only represent very simple data values in C. # detect whether default is okay, via a denylist @@ -992,13 +958,14 @@ def bad_node(self, node: ast.AST) -> None: visit_Starred = bad_node denylist = DetectBadNodes() - denylist.visit(module) + denylist.visit(expr) bad = denylist.bad else: # if they specify a c_default, we can be more lenient about the default value. # but at least make an attempt at ensuring it's a valid expression. + code = compile(ast.Expression(expr), '', 'eval') try: - value = eval(default) + value = eval(code) except NameError: pass # probably a named constant except Exception as e: @@ -1010,9 +977,6 @@ def bad_node(self, node: ast.AST) -> None: if bad: fail(f"Unsupported expression as default value: {default!r}") - assignment = module.body[0] - assert isinstance(assignment, ast.Assign) - expr = assignment.value # mild hack: explicitly support NULL as a default value c_default: str | None if isinstance(expr, ast.Name) and expr.id == 'NULL': @@ -1064,8 +1028,6 @@ def bad_node(self, node: ast.AST) -> None: else: c_default = py_default - except SyntaxError as e: - fail(f"Syntax error: {e.text!r}") except (ValueError, AttributeError): value = unknown c_default = kwargs.get("c_default")