diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..241a938 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,15 @@ +root = true + +[*] +charset = utf-8 +indent_style = space +indent_size = 4 +insert_final_newline = true +end_of_line = lf + +[*.{yml,yaml}] +indent_size = 2 + + +[*.json] +indent_size = 2 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 583206f..7250ca2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.4.5 + rev: v0.11.4 hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix ] diff --git a/ARCHITECTURE.rst b/ARCHITECTURE.rst index 37daf23..1ad7454 100644 --- a/ARCHITECTURE.rst +++ b/ARCHITECTURE.rst @@ -103,6 +103,35 @@ FTL functions. Other related level classes for the user are provided in ``fluent_compiler.resource`` and ``fluent_compiler.escapers``. +AST Types +~~~~~~~~~ + +As we are translating from one language to another, it is easy to get confused +about types of Abstract Syntax Tree objects in the different languages, +especially as sometimes we use identical class names. Here is a quick overview: + +- FTL (Fluent Translation List) has its own AST types. In the ``compiler.py`` + module, these are imported as ``fl_ast``. So, for example, + ``fl_ast.VariableReference`` is an AST node representing a variable reference + in a Fluent document. + +- Python AST. This is the end product we generate. It is imported directly into + the ``ast_compat.py`` module. From there it is imported into the + ``codegen.py`` module as ``py_ast``. + +- Codegen AST. We have our own layer of classes for Python code generation which + are used by the ``compiler.py`` module, which represent a simplified Python + AST with conveniences for easier construction, and eventually emit Python AST. + The base classes used here are ``CodeGenAst`` and ``CodeGenAstList``. 
+ + This module is imported into ``compiler.py`` as ``codegen``. + +So, for example, in ``compiler.py`` you find both ``fl_ast.VariableReference`` +and ``codegen.VariableReference``. In ``codegen.py`` you find both ``If`` (the +``If`` AST node defined in the ``codegen.py`` module) and ``py_ast.If`` (the +Python AST node from Python stdlib). If you get lost remember which module/layer +you are in. + Tests ~~~~~ diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f5b140c..57f4d71 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,14 @@ fluent_compiler 1.2 (unreleased) * Dropped Python 3.7 support * Compiler performance improvements - thanks `@leamingrad `_. +* Switched from `attrs `_ to stdlib + `dataclasses `_, and added + lots of type signatures and cleaned up internally. + * The documented API is still exactly the same as it was. However, if you were + depending on implementation details, like the fact that ``CompiledFtl`` was + an attrs dataclass when it is now a stdlib dataclass, there may be some + small backwards incompatibilities. In addition if you are running a type + checker, you may notice some differences as more things will be checked. fluent_compiler 1.1 (2024-04-02) -------------------------------- diff --git a/docs/escaping.rst b/docs/escaping.rst index 8e9dd48..c649812 100644 --- a/docs/escaping.rst +++ b/docs/escaping.rst @@ -24,10 +24,11 @@ passed to the ``FluentBundle`` constructor or to ``compile_messages``. An ``escaper`` is an object that defines the following set of attributes. The object could be a module, or a simple namespace object you could create using -``types.SimpleNamespace``, or an instance of a class with appropriate -methods defined. The attributes are: -- ``name`` - a simple text value that is used in error messages. 
+- ``name: str`` - a simple text value that is used in error messages. - ``select(**hints)`` @@ -35,7 +36,7 @@ methods defined. The attributes are: given message (or message attribute). It is passed a number of hints as keyword arguments, currently only the following: - - ``message_id`` - a string that is the name of the message or term. For terms + - ``message_id: str`` - a string that is the name of the message or term. For terms it is a string with a leading dash - e.g. ``-brand-name``. For message attributes, it is a string in the form ``messsage-name.attribute-name`` @@ -48,10 +49,10 @@ methods defined. The attributes are: ``select`` callable of each escaper in the list of escapers is tried in turn, and the first to return ``True`` is used. -- ``output_type`` - the type of values that are returned by ``escape``, +- ``output_type: type`` - the type of values that are returned by ``escape``, ``mark_escape``, and ``join``, and therefore by the whole message. -- ``escape(text_to_be_escaped)`` +- ``escape(text_to_be_escaped: str)`` A callable that will escape the passed in text. It must return a value that is an instance of ``output_type`` (or a subclass). @@ -64,12 +65,12 @@ methods defined. The attributes are: A callable that marks the passed in text as markup i.e. already escaped. It must return a value that is an instance of ``output_type`` (or a subclass). -- ``join(parts)`` +- ``join(parts: Iterable)`` A callable that accepts an iterable of components, each of type ``output_type``, and combines them into a larger value of the same type. -- ``use_isolating`` +- ``use_isolating: bool | None`` A boolean that determines whether the normal bidi isolating characters should be inserted. If it is ``None`` the value from the ``FluentBundle`` will be @@ -77,11 +78,11 @@ methods defined. 
The attributes are: The escaping functions need to obey some rules: -- escape must be idempotent: +- ``escape`` must be idempotent: ``escape(escape(text)) == escape(text)`` -- escape must be a no-op on the output of ``mark_escaped``: +- ``escape`` must be a no-op on the output of ``mark_escaped``: ``escape(mark_escaped(text)) == mark_escaped(text)`` @@ -101,13 +102,13 @@ This example is for .. code-block:: python - from fluent_compiler.utils import SimpleNamespace + from fluent_compiler.escapers import Escaper from markupsafe import Markup, escape empty_markup = Markup('') - html_escaper = SimpleNamespace( - select=lambda message_id=None, **hints: message_id.endswith('-html'), + html_escaper = Escaper( + select=lambda message_id, **hints: message_id.endswith('-html'), output_type=Markup, mark_escaped=Markup, escape=escape, diff --git a/pyproject.toml b/pyproject.toml index 3abb2e8..99e6f24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,10 +35,11 @@ classifiers = [ dynamic = ["version"] dependencies = [ - "attrs>=19.3.0", "babel>=2.12.0", + "backports-strenum>=1.2.4 ; python_full_version < '3.11'", "fluent-syntax>=0.14", "pytz>=2025.2", + "typing-extensions>=4.13.0 ; python_full_version < '3.10'", ] [project.urls] @@ -75,7 +76,8 @@ target-version = ['py39'] [tool.ruff] line-length = 120 -target-version = 'py37' +target-version = 'py38' + [tool.ruff.lint] ignore = ["E501","E731"] @@ -92,11 +94,14 @@ known-first-party = ["fluent_compiler"] dev = [ "ast-decompiler>=0.8", "beautifulsoup4>=4.7.1", + "fluent-runtime>=0.4.0", "hypothesis>=4.9.0", + "ipython>=8.12.3", "markdown>=3.0.1", "markupsafe>=1.1.1", "pre-commit>=3.5.0", "pytest>=7.4.4", - "tox-uv>=1.13.1", + "pytest-benchmark>=4.0.0", "tox>=4.25.0", + "tox-uv>=1.13.1", ] diff --git a/requirements-test.txt b/requirements-test.txt deleted file mode 100644 index 5120e18..0000000 --- a/requirements-test.txt +++ /dev/null @@ -1,7 +0,0 @@ -ast_decompiler>=0.4 -beautifulsoup4>=4.7.1 -hypothesis>=4.9.0 
-Markdown>=3.0.1 -MarkupSafe>=1.1.1 -pytest -six diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index d74172c..0000000 --- a/requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -# This file is used in test runs, to pin versions of runtime dependencies -# (those in setup.cfg), to help test stability, because some tests -# depend on exact behaviour (e.g. of Babel which bundles specific versions -# of i18n data) and are hard to write concisely otherwise. - -fluent.syntax>=0.14 -attrs==19.3.0 -babel==2.9.1 -pytz==2018.9 diff --git a/src/fluent_compiler/ast_compat.py b/src/fluent_compiler/ast_compat.py index 0ead310..300dbf6 100644 --- a/src/fluent_compiler/ast_compat.py +++ b/src/fluent_compiler/ast_compat.py @@ -25,6 +25,7 @@ def NewAst(...): """ import ast import sys +from typing import TypedDict, TypeVar # This is a very limited subset of Python AST: # - only the things needed by codegen.py @@ -61,21 +62,36 @@ def NewAst(...): arg = ast.arg keyword = ast.keyword walk = ast.walk +Constant = ast.Constant +AST = ast.AST +stmt = ast.stmt +expr = ast.expr -if sys.version_info >= (3, 8): - Constant = ast.Constant +# `compile` builtin needs these attributes on AST nodes. +# It's hard to get something sensible we can put for line/col numbers so we put arbitrary values. + + +class DefaultAstArgs(TypedDict): + lineno: int + col_offset: int + + +DEFAULT_AST_ARGS: DefaultAstArgs = {"lineno": 1, "col_offset": 1} +# Some AST types have different requirements: +DEFAULT_AST_ARGS_MODULE = dict() +DEFAULT_AST_ARGS_ADD = dict() +DEFAULT_AST_ARGS_ARGUMENTS = dict() + +T = TypeVar("T") + + +if sys.version_info < (3, 9): + # Old versions need an `Index` object here: + def subscript_slice_object(value: T) -> T: + return ast.Index(value, **DEFAULT_AST_ARGS) + else: - # For Python 3.7, in terms of runtime behaviour we could also use - # Constant for Str/Num, but this seems to trigger bugs when decompiling with - # ast_decompiler, which is needed by tests. 
So we use the more normal - # ast that Python 3.7 use for this code. - def Constant(arg, **kwargs): - if isinstance(arg, str): - return ast.Str(arg, **kwargs) - elif isinstance(arg, (int, float)): - return ast.Num(arg, **kwargs) - elif arg is None: - return ast.NameConstant(arg, **kwargs) - else: - raise NotImplementedError(f"Constant not implemented for args of type {type(arg)}") + # New versions need nothing. + def subscript_slice_object(value: T) -> T: + return value diff --git a/src/fluent_compiler/bundle.py b/src/fluent_compiler/bundle.py index 55f7dc0..dd6a889 100644 --- a/src/fluent_compiler/bundle.py +++ b/src/fluent_compiler/bundle.py @@ -1,4 +1,10 @@ -from .compiler import compile_messages +from __future__ import annotations + +from typing import Any, Callable, Sequence + +from fluent_compiler.escapers import IsEscaper + +from .compiler import CompilationErrorItem, compile_messages from .resource import FtlResource from .utils import ATTRIBUTE_SEPARATOR, TERM_SIGIL @@ -16,7 +22,14 @@ class FluentBundle: """ - def __init__(self, locale, resources, functions=None, use_isolating=True, escapers=None): + def __init__( + self, + locale: str, + resources: Sequence[FtlResource], + functions: dict[str, Callable] | None = None, + use_isolating: bool = True, + escapers: Sequence[IsEscaper] | None = None, + ): self.locale = locale compiled_ftl = compile_messages( locale, @@ -29,7 +42,14 @@ def __init__(self, locale, resources, functions=None, use_isolating=True, escape self._compilation_errors = compiled_ftl.errors @classmethod - def from_string(cls, locale, text, functions=None, use_isolating=True, escapers=None): + def from_string( + cls, + locale: str, + text: str, + functions: dict[str, Callable] | None = None, + use_isolating: bool = True, + escapers: Sequence[IsEscaper] | None = None, + ) -> FluentBundle: return cls( locale, [FtlResource.from_string(text)], @@ -48,14 +68,14 @@ def from_files(cls, locale, filenames, functions=None, use_isolating=True, escap 
escapers=escapers, ) - def has_message(self, message_id): + def has_message(self, message_id: str) -> bool: if message_id.startswith(TERM_SIGIL) or ATTRIBUTE_SEPARATOR in message_id: return False return message_id in self._compiled_messages - def format(self, message_id, args=None): + def format(self, message_id: str, args: Any | None = None) -> Any: errors = [] return self._compiled_messages[message_id](args, errors), errors - def check_messages(self): + def check_messages(self) -> list[CompilationErrorItem]: return self._compilation_errors diff --git a/src/fluent_compiler/codegen.py b/src/fluent_compiler/codegen.py index 4c2e9b8..164d147 100644 --- a/src/fluent_compiler/codegen.py +++ b/src/fluent_compiler/codegen.py @@ -1,12 +1,19 @@ """ Utilities for doing Python code generation """ +from __future__ import annotations +import decimal import keyword import platform import re +from abc import ABC, abstractmethod +from typing import Callable, Iterable, Protocol, Sequence, Union, runtime_checkable -from . import ast_compat as ast +from . import ast_compat as py_ast +from .ast_compat import DEFAULT_AST_ARGS, DEFAULT_AST_ARGS_ADD, DEFAULT_AST_ARGS_ARGUMENTS, DEFAULT_AST_ARGS_MODULE +from .compat import TypeAlias +from .source import FtlSource from .utils import allowable_keyword_arg_name, allowable_name # This module provides simple utilities for building up Python source code. It @@ -47,11 +54,17 @@ PROPERTY_TYPE = "PROPERTY_TYPE" PROPERTY_RETURN_TYPE = "PROPERTY_RETURN_TYPE" -UNKNOWN_TYPE = object -SENSITIVE_FUNCTIONS = [ +# UNKNOWN_TYPE is just an alias for `object` for clarity. +UNKNOWN_TYPE: type = object +# It is important for our usage of it that UNKNOWN_TYPE is a `type`, +# and the most general `type`. +assert isinstance(UNKNOWN_TYPE, type) + + +SENSITIVE_FUNCTIONS = { # builtin functions that we should never be calling from our code # generation. 
This is a defense-in-depth mechansim to stop our code - # generation become a code exectution vulnerability, we also have + # generation becoming a code execution vulnerability. We also have # higher level code that ensures we are not generating calls # to arbitrary Python functions. # This is not a comprehensive list of functions we are not using, but @@ -72,52 +85,44 @@ "object", "reload", "type", -] +} -class PythonAst: +class CodeGenAst(ABC): """ Base class representing a simplified Python AST (not the real one). Generates real `ast.*` nodes via `as_ast()` method. """ - def as_ast(self): + @abstractmethod + def as_ast(self) -> py_ast.AST: raise NotImplementedError(f"{self.__class__!r}.as_ast()") - @property - def child_elements(self): - raise NotImplementedError(f"{self.__class__!r}.child_elements") + child_elements: list[str] = NotImplemented -class PythonAstList: +class CodeGenAstList(ABC): """ - Alternative base class to PythonAst when we have code that wants to return a + Alternative base class to CodeGenAst when we have code that wants to return a list of AST objects. """ - def as_ast_list(self): + @abstractmethod + def as_ast_list(self, allow_empty: bool = True) -> list[py_ast.stmt]: raise NotImplementedError(f"{self.__class__!r}.as_ast_list()") - @property - def child_elements(self): - raise NotImplementedError(f"child_elements needs to be created on {type(self)}") + child_elements: list[str] = NotImplemented -# `compile` builtin needs these attributes on AST nodes. -# It's hard to get something sensible we can put for line/col numbers so we put arbitrary values. 
-DEFAULT_AST_ARGS = dict(lineno=1, col_offset=1) -# Some AST types have different requirements: -DEFAULT_AST_ARGS_MODULE = dict() -DEFAULT_AST_ARGS_ADD = dict() -DEFAULT_AST_ARGS_ARGUMENTS = dict() +CodeGenAstType: TypeAlias = Union[CodeGenAst, CodeGenAstList] class Scope: - def __init__(self, parent_scope=None): + def __init__(self, parent_scope: Scope | None = None): self.parent_scope = parent_scope self.names = set() self._function_arg_reserved_names = set() - self._properties = {} + self._properties: dict[str, dict[str, object]] = {} self._assignments = {} def is_name_in_use(self, name: str) -> bool: @@ -141,7 +146,13 @@ def is_name_reserved_function_arg(self, name: str) -> bool: def is_name_reserved(self, name: str) -> bool: return self.is_name_in_use(name) or self.is_name_reserved_function_arg(name) - def reserve_name(self, requested, function_arg=False, is_builtin=False, properties=None): + def reserve_name( + self, + requested: str, + function_arg: bool = False, + is_builtin: bool = False, + properties: dict[str, object] | None = None, + ): """ Reserve a name as being in use in a scope. @@ -150,7 +161,7 @@ def reserve_name(self, requested, function_arg=False, is_builtin=False, properti (e.g. the type associated with a name) """ - def _add(final): + def _add(final: str): self.names.add(final) self._properties[final] = properties or {} return final @@ -185,7 +196,7 @@ def _is_name_allowed(name: str) -> bool: return _add(attempt) - def reserve_function_arg_name(self, name): + def reserve_function_arg_name(self, name: str): """ Reserve a name for *later* use as a function argument. 
This does not result in that name being considered 'in use' in the current scope, but will @@ -198,29 +209,33 @@ def reserve_function_arg_name(self, name): raise AssertionError(f"Can't reserve '{name}' as function arg name as it is already reserved") self._function_arg_reserved_names.add(name) - def get_name_properties(self, name): + def get_name_properties(self, name) -> dict[str, object]: """ Gets a dictionary of properties for the name. Raises exception if the name is not reserved in this scope or parent """ if name in self._properties: return self._properties[name] + if self.parent_scope is None: + raise LookupError(f"{name} not found in properties") return self.parent_scope.get_name_properties(name) - def set_name_properties(self, name, props): + def set_name_properties(self, name: str, props: dict[str, object]): """ Sets a dictionary of properties for the name. Raises exception if the name is not reserved in this scope or parent. """ scope = self while True: + if scope is None: + raise LookupError(f"{name} not found in properties") if name in scope._properties: scope._properties[name].update(props) break else: scope = scope.parent_scope - def find_names_by_property(self, prop_name, prop_val): + def find_names_by_property(self, prop_name: str, prop_val: object) -> list[str]: """ Retrieve all names that match the supplied property name and value """ @@ -231,13 +246,13 @@ def find_names_by_property(self, prop_name, prop_val): if k == prop_name and v == prop_val ] - def has_assignment(self, name): + def has_assignment(self, name: str) -> bool: return name in self._assignments - def register_assignment(self, name): + def register_assignment(self, name: str) -> None: self._assignments[name] = None - def variable(self, name): + def variable(self, name: str) -> VariableReference: # Convenience utility for returning a VariableReference return VariableReference(name, self) @@ -246,7 +261,7 @@ def variable(self, name): _IDENTIFIER_START_RE = re.compile("^[a-zA-Z_]") 
-def cleanup_name(name): +def cleanup_name(name: str) -> str: """ Convert name to a allowable identifier """ @@ -257,39 +272,50 @@ def cleanup_name(name): return name -class Statement: +class Statement(CodeGenAst): pass -class _Assignment(Statement, PythonAst): +@runtime_checkable +class SupportsNameAssignment(Protocol): + def has_assignment_for_name(self, name: str) -> bool: + ... + + +class _Assignment(Statement): child_elements = ["value"] - def __init__(self, name, value): + def __init__(self, name: str, value: Expression): self.name = name self.value = value def as_ast(self): if not allowable_name(self.name): raise AssertionError(f"Expected {self.name} to be a valid Python identifier") - return ast.Assign( - targets=[ast.Name(id=self.name, ctx=ast.Store(), **DEFAULT_AST_ARGS)], + return py_ast.Assign( + targets=[py_ast.Name(id=self.name, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)], value=self.value.as_ast(), **DEFAULT_AST_ARGS, ) + def has_assignment_for_name(self, name: str) -> bool: + return self.name == name + -class Block(PythonAstList): +class Block(CodeGenAstList): child_elements = ["statements"] - def __init__(self, scope, parent_block=None): + def __init__(self, scope: Scope, parent_block: Block | None = None): self.scope = scope - self.statements = [] + # We allow `Expression` here for things like MethodCall which + # are bare expressions that are still useful for side effects + self.statements: list[Block | Statement | Expression] = [] self.parent_block = parent_block - def as_ast_list(self, allow_empty=True): + def as_ast_list(self, allow_empty: bool = True) -> list[py_ast.stmt]: retval = [] for s in self.statements: - if hasattr(s, "as_ast_list"): + if isinstance(s, CodeGenAstList): retval.extend(s.as_ast_list(allow_empty=True)) else: if isinstance(s, Statement): @@ -297,13 +323,13 @@ def as_ast_list(self, allow_empty=True): else: # Things like bare function/method calls need to be wrapped # in `Expr` to match the way Python parses. 
- retval.append(ast.Expr(s.as_ast(), **DEFAULT_AST_ARGS)) + retval.append(py_ast.Expr(s.as_ast(), **DEFAULT_AST_ARGS)) if len(retval) == 0 and not allow_empty: - return [ast.Pass(**DEFAULT_AST_ARGS)] + return [py_ast.Pass(**DEFAULT_AST_ARGS)] return retval - def add_statement(self, statement): + def add_statement(self, statement: Statement | Block | Expression) -> None: self.statements.append(statement) if isinstance(statement, Block): if statement.parent_block is None: @@ -315,7 +341,7 @@ def add_statement(self, statement): ) # Safe alternatives to Block.statements being manipulated directly: - def add_assignment(self, name, value, allow_multiple=False): + def add_assignment(self, name: str, value: Expression, allow_multiple: bool = False): """ Adds an assigment of the form: @@ -332,37 +358,34 @@ def add_assignment(self, name, value, allow_multiple=False): self.add_statement(_Assignment(name, value)) - def add_function(self, func_name, func): + def add_function(self, func_name: str, func: Function) -> None: assert func.func_name == func_name self.add_statement(func) - def add_return(self, value): + def add_return(self, value: Expression) -> None: self.add_statement(Return(value)) - def has_assignment_for_name(self, name): + def has_assignment_for_name(self, name: str) -> bool: for s in self.statements: - if isinstance(s, _Assignment) and s.name == name: + if isinstance(s, SupportsNameAssignment) and s.has_assignment_for_name(name): return True - elif hasattr(s, "has_assignment_for_name"): - if s.has_assignment_for_name(name): - return True if self.parent_block is not None: return self.parent_block.has_assignment_for_name(name) return False -class Module(Block, PythonAst): +class Module(Block, CodeGenAst): def __init__(self): scope = Scope(parent_scope=None) Block.__init__(self, scope) - def as_ast(self): - return ast.Module(body=self.as_ast_list(), type_ignores=[], **DEFAULT_AST_ARGS_MODULE) + def as_ast(self) -> py_ast.Module: + return 
py_ast.Module(body=self.as_ast_list(), type_ignores=[], **DEFAULT_AST_ARGS_MODULE) - def as_multiple_module_ast(self): - retval = [] + def as_multiple_module_ast(self) -> Iterable[py_ast.Module]: + retval: list[py_ast.Module] = [] for item in self.as_ast_list(): - mod = ast.Module(body=[item], type_ignores=[], **DEFAULT_AST_ARGS_MODULE) + mod = py_ast.Module(body=[item], type_ignores=[], **DEFAULT_AST_ARGS_MODULE) if hasattr(item, "filename"): # For use by compile_messages mod.filename = item.filename @@ -370,10 +393,16 @@ def as_multiple_module_ast(self): return retval -class Function(Scope, Statement, PythonAst): +class Function(Scope, Statement): child_elements = ["body"] - def __init__(self, name, args=None, parent_scope=None, source=None): + def __init__( + self, + name: str, + args: Sequence[str] | None = None, + parent_scope: Scope | None = None, + source: FtlSource | None = None, + ): super().__init__(parent_scope=parent_scope) self.body = Block(self) self.func_name = name @@ -386,18 +415,18 @@ def __init__(self, name, args=None, parent_scope=None, source=None): self.args = args self.source = source - def as_ast(self): + def as_ast(self) -> py_ast.stmt: if not allowable_name(self.func_name): raise AssertionError(f"Expected '{self.func_name}' to be a valid Python identifier") for arg in self.args: if not allowable_name(arg): raise AssertionError(f"Expected '{arg}' to be a valid Python identifier") - func_def = ast.FunctionDef( + func_def = py_ast.FunctionDef( name=self.func_name, - args=ast.arguments( + args=py_ast.arguments( posonlyargs=[], - args=([ast.arg(arg=arg_name, annotation=None, **DEFAULT_AST_ARGS) for arg_name in self.args]), + args=([py_ast.arg(arg=arg_name, annotation=None, **DEFAULT_AST_ARGS) for arg_name in self.args]), vararg=None, kwonlyargs=[], kw_defaults=[], @@ -411,45 +440,45 @@ def as_ast(self): returns=None, # ast_decompiler compat **DEFAULT_AST_ARGS, ) - if self.source is not None and self.source.filename is not None: - 
func_def.filename = self.source.filename # See Module.as_multiple_module_ast + if (source := self.source) is not None and source.filename is not None: + func_def.filename = source.filename # See Module.as_multiple_module_ast # It's hard to get good line numbers for all AST objects, but # if we put the FTL line number of the main message on all nodes # this gets us a lot of the benefit for a smallish cost - def add_lineno(node): - node.lineno = self.source.row + def add_lineno(node: py_ast.AST): + node.lineno = source.row traverse(func_def, add_lineno) return func_def - def add_return(self, value): + def add_return(self, value: Expression): self.body.add_return(value) -class Return(Statement, PythonAst): +class Return(Statement): child_elements = ["value"] - def __init__(self, value): + def __init__(self, value: Expression): self.value = value def as_ast(self): - return ast.Return(self.value.as_ast(), **DEFAULT_AST_ARGS) + return py_ast.Return(self.value.as_ast(), **DEFAULT_AST_ARGS) def __repr__(self): return f"Return({repr(self.value)}" -class If(Statement, PythonAst): +class If(Statement): child_elements = ["if_blocks", "conditions", "else_block"] - def __init__(self, parent_scope, parent_block=None): + def __init__(self, parent_scope: Scope, parent_block: Block | None = None): # We model a "compound if statement" as a list of if blocks # (if/elif/elif etc), each with their own condition, with a final else # block. Note this is quite different from Python's AST for the same # thing, so conversion to AST is more complex because of this. 
- self.if_blocks = [] - self.conditions = [] + self.if_blocks: list[Block] = [] + self.conditions: list[Expression] = [] self._parent_block = parent_block self.else_block = Block(parent_scope, parent_block=self._parent_block) self._parent_scope = parent_scope @@ -460,7 +489,7 @@ def add_if(self, condition): self.conditions.append(condition) return new_if - def finalize(self): + def finalize(self) -> Block | Statement: if not self.if_blocks: # Unusual case of no conditions, only default case, but it # simplifies other code to be able to handle this uniformly. We can @@ -468,10 +497,10 @@ def finalize(self): return self.else_block return self - def as_ast(self): + def as_ast(self) -> py_ast.If: if len(self.if_blocks) == 0: raise AssertionError("Should have called `finalize` on If") - if_ast = ast.If(test=None, orelse=[], **DEFAULT_AST_ARGS) + if_ast = empty_If() current_if = if_ast previous_if = None for condition, if_block in zip(self.conditions, self.if_blocks): @@ -481,34 +510,35 @@ def as_ast(self): previous_if.orelse.append(current_if) previous_if = current_if - current_if = ast.If(test=None, orelse=[], **DEFAULT_AST_ARGS) + current_if = empty_If() if self.else_block.statements: + assert previous_if is not None previous_if.orelse = self.else_block.as_ast_list() return if_ast -class Try(Statement, PythonAst): +class Try(Statement): child_elements = ["catch_exceptions", "try_block", "except_block", "else_block"] - def __init__(self, catch_exceptions, parent_scope): + def __init__(self, catch_exceptions: Sequence[Expression], parent_scope: Scope): self.catch_exceptions = catch_exceptions self.try_block = Block(parent_scope) self.except_block = Block(parent_scope) self.else_block = Block(parent_scope) - def as_ast(self): - return ast.Try( + def as_ast(self) -> py_ast.Try: + return py_ast.Try( body=self.try_block.as_ast_list(allow_empty=False), handlers=[ - ast.ExceptHandler( + py_ast.ExceptHandler( type=( self.catch_exceptions[0].as_ast() if 
len(self.catch_exceptions) == 1 - else ast.Tuple( + else py_ast.Tuple( elts=[e.as_ast() for e in self.catch_exceptions], - ctx=ast.Load(), + ctx=py_ast.Load(), **DEFAULT_AST_ARGS, ) ), @@ -522,7 +552,7 @@ def as_ast(self): **DEFAULT_AST_ARGS, ) - def has_assignment_for_name(self, name): + def has_assignment_for_name(self, name: str) -> bool: if ( self.try_block.has_assignment_for_name(name) or self.else_block.has_assignment_for_name(name) ) and self.except_block.has_assignment_for_name(name): @@ -530,10 +560,14 @@ def has_assignment_for_name(self, name): return False -class Expression(PythonAst): +class Expression(CodeGenAst): # type represents the Python type this expression will produce, # if we know it (UNKNOWN_TYPE otherwise). - type = UNKNOWN_TYPE + type: type = UNKNOWN_TYPE + + @abstractmethod + def as_ast(self) -> py_ast.expr: + raise NotImplementedError() class String(Expression): @@ -541,11 +575,11 @@ class String(Expression): type = str - def __init__(self, string_value): + def __init__(self, string_value: str): self.string_value = string_value - def as_ast(self): - return ast.Constant( + def as_ast(self) -> py_ast.expr: + return py_ast.Constant( self.string_value, kind=None, # 3.8, indicates no prefix, needed only for tests **DEFAULT_AST_ARGS, @@ -554,19 +588,19 @@ def as_ast(self): def __repr__(self): return f"String({repr(self.string_value)})" - def __eq__(self, other): + def __eq__(self, other: object): return isinstance(other, String) and other.string_value == self.string_value class Number(Expression): child_elements = [] - def __init__(self, number): + def __init__(self, number: int | float | decimal.Decimal): self.number = number self.type = type(number) - def as_ast(self): - return ast.Constant(self.number, **DEFAULT_AST_ARGS) + def as_ast(self) -> py_ast.expr: + return py_ast.Constant(self.number, **DEFAULT_AST_ARGS) def __repr__(self): return f"Number({repr(self.number)})" @@ -579,22 +613,21 @@ def __init__(self, items): self.items = items 
self.type = list - def as_ast(self): - return ast.List(elts=[i.as_ast() for i in self.items], ctx=ast.Load(), **DEFAULT_AST_ARGS) + def as_ast(self) -> py_ast.expr: + return py_ast.List(elts=[i.as_ast() for i in self.items], ctx=py_ast.Load(), **DEFAULT_AST_ARGS) class Dict(Expression): child_elements = ["pairs"] - def __init__(self, pairs): - # pairs is a list of key-value pairs (PythonAst object, PythonAst object) + def __init__(self, pairs: Sequence[tuple[Expression, Expression]]): self.pairs = pairs self.type = dict - def as_ast(self): - return ast.Dict( - keys=[k.as_ast() for k, v in self.pairs], - values=[v.as_ast() for k, v in self.pairs], + def as_ast(self) -> py_ast.expr: + return py_ast.Dict( + keys=[k.as_ast() for k, _ in self.pairs], + values=[v.as_ast() for _, v in self.pairs], **DEFAULT_AST_ARGS, ) @@ -604,16 +637,19 @@ class StringJoinBase(Expression): type = str - def __init__(self, parts): + def __init__(self, parts: Sequence[Expression]): self.parts = parts def __repr__(self): return f"{self.__class__.__name__}([{', '.join(repr(p) for p in self.parts)}])" @classmethod - def build(cls, parts): + def build(cls: type[StringJoinBase], parts: Sequence[Expression]) -> StringJoinBase | Expression: + """ + Build a string join operation, but return a simpler expression if possible. + """ # Merge adjacent String objects. 
- new_parts = [] + new_parts: list[Expression] = [] for part in parts: if len(new_parts) > 0 and isinstance(new_parts[-1], String) and isinstance(part, String): new_parts[-1] = String(new_parts[-1].string_value + part.string_value) @@ -630,33 +666,33 @@ def build(cls, parts): class FStringJoin(StringJoinBase): - def as_ast(self): + def as_ast(self) -> py_ast.expr: # f-strings - values = [] + values: list[py_ast.expr] = [] for part in self.parts: if isinstance(part, String): values.append(part.as_ast()) else: values.append( - ast.FormattedValue( + py_ast.FormattedValue( value=part.as_ast(), conversion=-1, format_spec=None, **DEFAULT_AST_ARGS, ) ) - return ast.JoinedStr(values=values, **DEFAULT_AST_ARGS) + return py_ast.JoinedStr(values=values, **DEFAULT_AST_ARGS) class ConcatJoin(StringJoinBase): - def as_ast(self): + def as_ast(self) -> py_ast.expr: # Concatenate with + left = self.parts[0].as_ast() for part in self.parts[1:]: right = part.as_ast() - left = ast.BinOp( + left = py_ast.BinOp( left=left, - op=ast.Add(**DEFAULT_AST_ARGS_ADD), + op=py_ast.Add(**DEFAULT_AST_ARGS_ADD), right=right, **DEFAULT_AST_ARGS, ) @@ -677,19 +713,21 @@ def as_ast(self): class VariableReference(Expression): child_elements = [] - def __init__(self, name, scope): + def __init__(self, name: str, scope: Scope): if not scope.is_name_in_use(name): raise AssertionError(f"Cannot refer to undefined variable '{name}'") self.name = name - self.type = scope.get_name_properties(name).get(PROPERTY_TYPE, UNKNOWN_TYPE) + looked_up_type = scope.get_name_properties(name).get(PROPERTY_TYPE, UNKNOWN_TYPE) + assert isinstance(looked_up_type, type) + self.type = looked_up_type - def as_ast(self): + def as_ast(self) -> py_ast.expr: if not allowable_name(self.name, allow_builtin=True): raise AssertionError(f"Expected {self.name} to be a valid Python identifier") - return ast.Name(id=self.name, ctx=ast.Load(), **DEFAULT_AST_ARGS) + return py_ast.Name(id=self.name, ctx=py_ast.Load(), **DEFAULT_AST_ARGS) def 
__eq__(self, other): - return type(other) == type(self) and other.name == self.name + return type(other) is type(self) and other.name == self.name def __repr__(self): return f"VariableReference({repr(self.name)})" @@ -698,7 +736,14 @@ def __repr__(self): class FunctionCall(Expression): child_elements = ["args", "kwargs"] - def __init__(self, function_name, args, kwargs, scope, expr_type=UNKNOWN_TYPE): + def __init__( + self, + function_name: str, + args: Sequence[Expression], + kwargs: dict[str, Expression], + scope: Scope, + expr_type: type = UNKNOWN_TYPE, + ): if not scope.is_name_in_use(function_name): raise AssertionError(f"Cannot call unknown function '{function_name}'") self.function_name = function_name @@ -706,10 +751,12 @@ def __init__(self, function_name, args, kwargs, scope, expr_type=UNKNOWN_TYPE): self.kwargs = kwargs if expr_type is UNKNOWN_TYPE: # Try to find out automatically - expr_type = scope.get_name_properties(function_name).get(PROPERTY_RETURN_TYPE, expr_type) + looked_up_return_type = scope.get_name_properties(function_name).get(PROPERTY_RETURN_TYPE, expr_type) + assert isinstance(looked_up_return_type, type) + expr_type = looked_up_return_type self.type = expr_type - def as_ast(self): + def as_ast(self) -> py_ast.expr: if not allowable_name(self.function_name, allow_builtin=True): raise AssertionError(f"Expected {self.function_name} to be a valid Python identifier or builtin") @@ -735,15 +782,15 @@ def as_ast(self): # decompiles to something more recognisably correct, we pretend this # is necessary). 
kwarg_pairs = list(sorted(self.kwargs.items())) - kwarg_names, kwarg_values = [k for k, v in kwarg_pairs], [v for k, v in kwarg_pairs] - return ast.Call( - func=ast.Name(id=self.function_name, ctx=ast.Load(), **DEFAULT_AST_ARGS), + kwarg_names, kwarg_values = [k for k, _ in kwarg_pairs], [v for _, v in kwarg_pairs] + return py_ast.Call( + func=py_ast.Name(id=self.function_name, ctx=py_ast.Load(), **DEFAULT_AST_ARGS), args=[arg.as_ast() for arg in self.args], keywords=[ - ast.keyword( + py_ast.keyword( arg=None, - value=ast.Dict( - keys=[ast.Constant(k, kind=None, **DEFAULT_AST_ARGS) for k in kwarg_names], + value=py_ast.Dict( + keys=[py_ast.Constant(k, kind=None, **DEFAULT_AST_ARGS) for k in kwarg_names], values=[v.as_ast() for v in kwarg_values], **DEFAULT_AST_ARGS, ), @@ -754,11 +801,12 @@ def as_ast(self): ) # Normal `my_function(foo=bar)` syntax - return ast.Call( - func=ast.Name(id=self.function_name, ctx=ast.Load(), **DEFAULT_AST_ARGS), + return py_ast.Call( + func=py_ast.Name(id=self.function_name, ctx=py_ast.Load(), **DEFAULT_AST_ARGS), args=[arg.as_ast() for arg in self.args], keywords=[ - ast.keyword(arg=name, value=value.as_ast(), **DEFAULT_AST_ARGS) for name, value in self.kwargs.items() + py_ast.keyword(arg=name, value=value.as_ast(), **DEFAULT_AST_ARGS) + for name, value in self.kwargs.items() ], **DEFAULT_AST_ARGS, ) @@ -770,21 +818,21 @@ def __repr__(self): class MethodCall(Expression): child_elements = ["obj", "args"] - def __init__(self, obj, method_name, args, expr_type=UNKNOWN_TYPE): + def __init__(self, obj: Expression, method_name: str, args: Sequence[Expression], expr_type: type = UNKNOWN_TYPE): # We can't check method_name because we don't know the type of obj yet. 
self.obj = obj self.method_name = method_name self.args = args self.type = expr_type - def as_ast(self): + def as_ast(self) -> py_ast.expr: if not allowable_name(self.method_name, for_method=True): raise AssertionError(f"Expected {self.method_name} to be a valid Python identifier") - return ast.Call( - func=ast.Attribute( + return py_ast.Call( + func=py_ast.Attribute( value=self.obj.as_ast(), attr=self.method_name, - ctx=ast.Load(), + ctx=py_ast.Load(), **DEFAULT_AST_ARGS, ), args=[arg.as_ast() for arg in self.args], @@ -799,16 +847,16 @@ def __repr__(self): class DictLookup(Expression): child_elements = ["lookup_obj", "lookup_arg"] - def __init__(self, lookup_obj, lookup_arg, expr_type=UNKNOWN_TYPE): + def __init__(self, lookup_obj: Expression, lookup_arg: Expression, expr_type: type = UNKNOWN_TYPE): self.lookup_obj = lookup_obj self.lookup_arg = lookup_arg self.type = expr_type - def as_ast(self): - return ast.Subscript( + def as_ast(self) -> py_ast.expr: + return py_ast.Subscript( value=self.lookup_obj.as_ast(), - slice=ast.Index(value=self.lookup_arg.as_ast(), **DEFAULT_AST_ARGS), - ctx=ast.Load(), + slice=py_ast.subscript_slice_object(self.lookup_arg.as_ast()), + ctx=py_ast.Load(), **DEFAULT_AST_ARGS, ) @@ -819,14 +867,14 @@ def as_ast(self): class NoneExpr(Expression): type = type(None) - def as_ast(self): - return ast.Constant(value=None, **DEFAULT_AST_ARGS) + def as_ast(self) -> py_ast.expr: + return py_ast.Constant(value=None, **DEFAULT_AST_ARGS) class BinaryOperator(Expression): child_elements = ["left", "right"] - def __init__(self, left, right): + def __init__(self, left: Expression, right: Expression): self.left = left self.right = right @@ -834,11 +882,11 @@ def __init__(self, left, right): class Equals(BinaryOperator): type = bool - def as_ast(self): - return ast.Compare( + def as_ast(self) -> py_ast.expr: + return py_ast.Compare( left=self.left.as_ast(), comparators=[self.right.as_ast()], - ops=[ast.Eq()], + ops=[py_ast.Eq()], **DEFAULT_AST_ARGS, ) 
@@ -847,8 +895,8 @@ class BoolOp(BinaryOperator): type = bool op = NotImplemented - def as_ast(self): - return ast.BoolOp( + def as_ast(self) -> py_ast.expr: + return py_ast.BoolOp( op=self.op(), values=[self.left.as_ast(), self.right.as_ast()], **DEFAULT_AST_ARGS, @@ -856,18 +904,18 @@ def as_ast(self): class Or(BoolOp): - op = ast.Or + op = py_ast.Or -def traverse(ast_node, func): +def traverse(ast_node: py_ast.AST, func: Callable[[py_ast.AST], None]): """ Apply 'func' to ast_node (which is `ast.*` object) """ - for node in ast.walk(ast_node): + for node in py_ast.walk(ast_node): func(node) -def simplify(codegen_ast, simplifier): +def simplify(codegen_ast: CodeGenAst, simplifier: Callable[[CodeGenAst, list[bool]], CodeGenAst]): changes = [True] # Wrap `simplifier` (which takes additional `changes` arg) @@ -881,11 +929,14 @@ def rewriter(node): return codegen_ast -def rewriting_traverse(node, func): +def rewriting_traverse( + node: CodeGenAstType | list | tuple | dict, + func: Callable[[CodeGenAstType], CodeGenAstType], +): """ - Apply 'func' to node and all sub PythonAst nodes + Apply 'func' to node and all sub CodeGenAst nodes """ - if isinstance(node, (PythonAst, PythonAstList)): + if isinstance(node, (CodeGenAst, CodeGenAstList)): new_node = func(node) if new_node is not node: morph_into(node, new_node) @@ -900,9 +951,17 @@ def rewriting_traverse(node, func): rewriting_traverse(v, func) -def morph_into(item, new_item): +def morph_into(item: object, new_item: object) -> None: # This naughty little function allows us to make `item` behave like # `new_item` in every way, except it maintains the identity of `item`, so # that we don't have to rewrite a tree of objects with new objects. item.__class__ = new_item.__class__ item.__dict__ = new_item.__dict__ + + +def empty_If() -> py_ast.If: + """ + Create an empty If ast node. The `test` attribute + must be added later. 
+ """ + return py_ast.If(test=None, orelse=[], **DEFAULT_AST_ARGS) # type: ignore[reportArgumentType] diff --git a/src/fluent_compiler/compat.py b/src/fluent_compiler/compat.py new file mode 100644 index 0000000..966d01d --- /dev/null +++ b/src/fluent_compiler/compat.py @@ -0,0 +1,11 @@ +try: + from enum import StrEnum +except ImportError: + from backports.strenum import StrEnum # type: ignore[reportMissingImports] + +try: + from typing import TypeAlias +except ImportError: + from typing_extensions import TypeAlias # type: ignore[reportMissingImports] + +__all__ = ["StrEnum", "TypeAlias"] diff --git a/src/fluent_compiler/compiler.py b/src/fluent_compiler/compiler.py index 47bd169..c93870f 100644 --- a/src/fluent_compiler/compiler.py +++ b/src/fluent_compiler/compiler.py @@ -1,35 +1,27 @@ # The heart of the FTL -> Python compiler. See the architecture docs in # ARCHITECTURE.rst for the big picture, and comments on compile_expr below. +from __future__ import annotations import builtins import contextlib -from collections import OrderedDict +import dataclasses +import decimal +from dataclasses import dataclass, field from functools import singledispatch +from typing import Any, Callable, ContextManager, Generator, Iterable, Mapping, Sequence, Tuple, Union -import attr import babel +import babel.plural from fluent.syntax import FluentParser -from fluent.syntax.ast import ( - Attribute, - BaseNode, - FunctionReference, - Identifier, - Junk, - Message, - MessageReference, - NumberLiteral, - Pattern, - Placeable, - SelectExpression, - StringLiteral, - Term, - TermReference, - TextElement, - VariableReference, -) +from fluent.syntax import ast as fl_ast +from typing_extensions import TypeGuard + +from fluent_compiler.source import FtlSource +from . import ast_compat as py_ast from . 
import codegen, runtime from .builtins import BUILTINS +from .compat import TypeAlias from .errors import ( FluentCyclicReferenceError, FluentDuplicateMessageId, @@ -37,11 +29,22 @@ FluentJunkFound, FluentReferenceError, ) -from .escapers import EscaperJoin, RegisteredEscaper, escaper_for_message, escapers_compatible, identity, null_escaper +from .escapers import ( + EscaperJoin, + IsEscaper, + NullEscaper, + RegisteredEscaper, + escaper_for_message, + escapers_compatible, + identity, + null_escaper, +) +from .resource import FtlResource from .types import FluentDateType, FluentNone, FluentNumber, FluentType from .utils import ( ATTRIBUTE_SEPARATOR, TERM_SIGIL, + FunctionArgSpec, args_match, ast_to_id, attribute_ast_to_id, @@ -57,7 +60,7 @@ BUILTIN_NUMBER = "NUMBER" BUILTIN_DATETIME = "DATETIME" -BUILTIN_RETURN_TYPES = { +BUILTIN_RETURN_TYPES: dict[str, type] = { BUILTIN_NUMBER: FluentNumber, BUILTIN_DATETIME: FluentDateType, } @@ -80,90 +83,101 @@ PROPERTY_EXTERNAL_ARG = "PROPERTY_EXTERNAL_ARG" -@attr.s +@dataclass class CurrentEnvironment: + # TODO make fields not optional, and the whole of `CurrentEnvironment` optional instead # The parts of CompilerEnvironment that we want to mutate (and restore) # temporarily for some parts of a call chain. 
- message_id = attr.ib(default=None) - ftl_resource = attr.ib(default=None) - term_args = attr.ib(default=None) - in_select_expression = attr.ib(default=False) - escaper = attr.ib(default=null_escaper) + message_id: str + ftl_resource: FtlResource + term_args: dict | None = None + in_select_expression: bool = False + escaper: RegisteredEscaper | NullEscaper = field(default_factory=lambda: null_escaper) + + +NumberType: TypeAlias = Union[float, decimal.Decimal] + +PluralFormFunc: TypeAlias = Callable[[NumberType], Union[str, None]] +MessageFunc: TypeAlias = Callable[[Union[dict, None], list], str] -@attr.s +CompilationErrorItem: TypeAlias = Tuple[Union[str, None], Exception] + + +@dataclass class CompilerEnvironment: - locale = attr.ib() - plural_form_function = attr.ib() - use_isolating = attr.ib() - message_mapping = attr.ib(factory=dict) - errors = attr.ib(factory=list) - escapers = attr.ib(default=None) - functions = attr.ib(factory=dict) - function_renames = attr.ib(factory=dict) - functions_arg_spec = attr.ib(factory=dict) - message_ids_to_ast = attr.ib(factory=dict) - term_ids_to_ast = attr.ib(factory=dict) - current = attr.ib(factory=CurrentEnvironment) - - def add_current_message_error(self, error): - self.errors.append((self.current.message_id, error)) - - def escaper_for_message(self, message_id=None): + locale: babel.Locale + plural_form_function: PluralFormFunc + use_isolating: bool + current: CurrentEnvironment + message_mapping: dict[str, str] = field(default_factory=dict) + errors: list[CompilationErrorItem] = field(default_factory=list) + escapers: Sequence[RegisteredEscaper] = field(default_factory=list) + functions: Mapping[str, Callable] = field(default_factory=dict) + function_renames: dict[str, str] = field(default_factory=dict) + functions_arg_spec: dict[str, FunctionArgSpec] = field(default_factory=dict) + message_ids_to_ast: dict[str, fl_ast.Message | fl_ast.Attribute] = field(default_factory=dict) + term_ids_to_ast: dict[str, fl_ast.Term | 
fl_ast.Attribute] = field(default_factory=dict) + + def add_current_message_error(self, error: Exception): + message_id = self.current.message_id if self.current else None + self.errors.append((message_id, error)) + + def escaper_for_message(self, message_id: str | None = None) -> RegisteredEscaper | NullEscaper: return escaper_for_message(self.escapers, message_id=message_id) @contextlib.contextmanager - def modified(self, **replacements): + def modified(self, **replacements) -> Generator[CompilerEnvironment, None, None]: """ Context manager that modifies the 'current' attribute of the environment, restoring the old data at the end. """ # CurrentEnvironment only has immutable args at the moment, so the - # shallow copy returned by attr.evolve is fine. + # shallow copy returned by dataclasses.replace is fine. old_current = self.current - self.current = attr.evolve(old_current, **replacements) + if old_current is None: + self.current = CurrentEnvironment(**replacements) + else: + self.current = dataclasses.replace(old_current, **replacements) yield self self.current = old_current - def modified_for_term_reference(self, term_args=None): + def modified_for_term_reference( + self, + term_args: dict[str, codegen.CodeGenAst] | None = None, + ) -> ContextManager[CompilerEnvironment]: return self.modified(term_args=term_args if term_args is not None else {}) - def should_use_isolating(self): + def should_use_isolating(self) -> bool: if self.current.escaper.use_isolating is None: return self.use_isolating return self.current.escaper.use_isolating -class FtlSource: - """ - Object used to specify the origin of a chunk of FTL - """ - - def __init__(self, ast_node, ftl_resource): - self.ast_node = ast_node - self.ftl_resource = ftl_resource - self.filename = self.ftl_resource.filename - self.row, self.column = span_to_position(ast_node.span, ftl_resource.text) - - -@attr.s +@dataclass class CompiledFtl: # A dictionary of message IDs to Python functions. 
This is the primary # output that is needed to execute the FTL - the functions simply need to be # called with a dictionary of external arguments, and a list to which # runtime errors will be added. - message_functions = attr.ib(factory=dict) + message_functions: dict[str, MessageFunc] = field(default_factory=dict) # A list of parsing and compilation errors, where each item is # (message_id or None, exception object) - errors = attr.ib(factory=list) + errors: list[CompilationErrorItem] = field(default_factory=list) # Compiled output as Python AST. - module_ast = attr.ib(default=None) + module_ast: py_ast.Module | None = None - locale = attr.ib(default=None) + locale: str | None = None -def compile_messages(locale, resources, use_isolating=True, functions=None, escapers=None): +def compile_messages( + locale: str, + resources: Sequence[FtlResource], + use_isolating: bool = True, + functions: dict[str, Callable] | None = None, + escapers: Sequence[IsEscaper] | None = None, +) -> CompiledFtl: """ Compile a list of FtlResource to a Python module, and returns a CompiledFtl objects @@ -211,14 +225,16 @@ def compile_messages(locale, resources, use_isolating=True, functions=None, esca ) -def _parse_resources(ftl_resources): +def _parse_resources( + ftl_resources: Sequence[FtlResource], +) -> tuple[Mapping[str, TermOrMessage], list[CompilationErrorItem]]: parsing_issues = [] - output_dict = OrderedDict() + output_dict = dict() for ftl_resource in ftl_resources: parser = FluentParser() resource = parser.parse(ftl_resource.text) for item in resource.body: - if isinstance(item, (Message, Term)): + if isinstance(item, (fl_ast.Message, fl_ast.Term)): full_id = ast_to_id(item) if full_id in output_dict: parsing_issues.append( @@ -233,7 +249,7 @@ def _parse_resources(ftl_resources): for attribute in item.attributes: attribute.ftl_resource = ftl_resource output_dict[full_id] = item - elif isinstance(item, Junk): + elif isinstance(item, fl_ast.Junk): parsing_issues.append( ( None, 
@@ -256,7 +272,16 @@ def _parse_resources(ftl_resources): return output_dict, parsing_issues -def messages_to_module(messages, locale, use_isolating=True, functions=None, escapers=None): +TermOrMessage: TypeAlias = Union[fl_ast.Message, fl_ast.Term] + + +def messages_to_module( + messages: Mapping[str, TermOrMessage], + locale: babel.Locale, + use_isolating: bool = True, + functions: Mapping[str, Callable] | None = None, + escapers: Sequence[IsEscaper] | None = None, +) -> tuple: """ Compile a set of {id: Message/Term objects} to a Python module, returning a tuple: (codegen.Module object, dictionary mapping message IDs to Python functions, @@ -265,13 +290,13 @@ def messages_to_module(messages, locale, use_isolating=True, functions=None, esc if functions is None: functions = {} - message_ids_to_ast = OrderedDict(get_message_function_ast(messages)) - term_ids_to_ast = OrderedDict(get_term_ast(messages)) + message_ids_to_ast = dict(get_message_function_ast(messages)) + term_ids_to_ast = dict(get_term_ast(messages)) # Plural form function plural_form_for_number_main = babel.plural.to_python(locale.plural_form) - def plural_form_for_number(number): + def plural_form_for_number(number: NumberType) -> str | None: try: return plural_form_for_number_main(number) except TypeError: @@ -280,6 +305,8 @@ def plural_form_for_number(number): return None function_arg_errors = [] + # TODO - to avoid issues with incomplete CompilerEnvironment/CurrentEnvironment, + # maybe don't create this object until we can set it up fully compiler_env = CompilerEnvironment( locale=locale, plural_form_function=plural_form_for_number, @@ -290,6 +317,8 @@ def plural_form_for_number(number): }, message_ids_to_ast=message_ids_to_ast, term_ids_to_ast=term_ids_to_ast, + # We will fix this up before we use it: + current=None, # type: ignore[reportArgumentType] ) for err in function_arg_errors: compiler_env.add_current_message_error(err) @@ -305,15 +334,15 @@ def plural_form_for_number(number): 
module_globals[LOCALE_NAME] = locale # Return types of known functions. - known_return_types = {} + known_return_types: dict[str, type] = {} known_return_types.update(BUILTIN_RETURN_TYPES) known_return_types.update(runtime.RETURN_TYPES) module_globals[PLURAL_FORM_FOR_NUMBER_NAME] = plural_form_for_number known_return_types[PLURAL_FORM_FOR_NUMBER_NAME] = str - def get_name_properties(name): - properties = {} + def get_name_properties(name: str) -> dict[str, object]: + properties: dict[str, object] = {} if name in known_return_types: properties[codegen.PROPERTY_RETURN_TYPE] = known_return_types[name] return properties @@ -325,15 +354,14 @@ def get_name_properties(name): assert name == k, f"Expected {name}=={k}" # Reserve names for escapers - if compiler_env.escapers is not None: - for escaper in compiler_env.escapers: - for name, func, properties in escaper.get_reserved_names_with_properties(): - assigned_name = module.scope.reserve_name(name, properties=properties) - # We've chosen the names to not clash with anything that - # we've already set up. - assert assigned_name == name - assert assigned_name not in module_globals - module_globals[assigned_name] = func + for escaper in compiler_env.escapers: + for name, func, properties in escaper.get_reserved_names_with_properties(): + assigned_name = module.scope.reserve_name(name, properties=properties) + # We've chosen the names to not clash with anything that + # we've already set up. 
+ assert assigned_name == name + assert assigned_name not in module_globals + module_globals[assigned_name] = func # Reserve names for function arguments, so that we always # know the name of these arguments without needing to do @@ -375,9 +403,11 @@ def get_name_properties(name): return (module, compiler_env.message_mapping, module_globals, compiler_env.errors) -def get_message_function_ast(message_dict): +def get_message_function_ast( + message_dict: Mapping[str, TermOrMessage] +) -> Iterable[tuple[str, fl_ast.Attribute | fl_ast.Message]]: for msg_id, msg in message_dict.items(): - if isinstance(msg, Term): + if isinstance(msg, fl_ast.Term): continue if msg.value is not None: # has a body yield (msg_id, msg) @@ -385,10 +415,10 @@ def get_message_function_ast(message_dict): yield (attribute_ast_to_id(attribute, msg), attribute) -def get_term_ast(message_dict): +def get_term_ast(message_dict: Mapping[str, TermOrMessage]) -> Iterable[tuple[str, fl_ast.Attribute | fl_ast.Term]]: for term_id, term in message_dict.items(): - if isinstance(term, Message): - pass + if isinstance(term, fl_ast.Message): + continue if term.value is not None: # has a body yield (term_id, term) @@ -396,7 +426,7 @@ def get_term_ast(message_dict): yield (attribute_ast_to_id(attribute, term), attribute) -def suggested_function_name_for_msg_id(msg_id): +def suggested_function_name_for_msg_id(msg_id: str) -> str: # Scope.reserve_name does further sanitising of name, which we don't need to # worry about. It also ensures we don't get dupes. 
So the fact that this # method will produce occasional collisions is not an issue - here we are @@ -406,7 +436,13 @@ def suggested_function_name_for_msg_id(msg_id): return msg_id.replace(ATTRIBUTE_SEPARATOR, "__").replace("-", "_") -def compile_message(msg, msg_id, function_name, module, compiler_env): +def compile_message( + msg: fl_ast.Attribute | fl_ast.Message, + msg_id: str, + function_name: str, + module: codegen.Module, + compiler_env: CompilerEnvironment, +) -> codegen.Function: msg_func = codegen.Function( parent_scope=module.scope, name=function_name, @@ -423,12 +459,17 @@ def compile_message(msg, msg_id, function_name, module, compiler_env): ) else: return_expression = compile_expr(msg, function_block, compiler_env) + assert isinstance(return_expression, codegen.Expression), f"Expected Expression, got {return_expression}" # > return $return_expression msg_func.add_return(return_expression) return msg_func -def traverse_ast(node, func, exclude_attributes=None): +def traverse_ast( + node: fl_ast.BaseNode, + func: Callable[[object], None], + exclude_attributes: list[tuple[type[TermOrMessage], str]] | None = None, +) -> None: """ Postorder-traverse this node and apply `func` to all child nodes. 
@@ -438,7 +479,7 @@ def traverse_ast(node, func, exclude_attributes=None): def visit(value): """Call `func` on `value` and its descendants.""" - if isinstance(value, BaseNode): + if isinstance(value, fl_ast.BaseNode): return traverse_ast(value, func, exclude_attributes=exclude_attributes) if isinstance(value, list): return func(list(map(visit, value))) @@ -454,7 +495,7 @@ def visit(value): return func(node) -def contains_reference_cycle(msg, compiler_env): +def contains_reference_cycle(msg: fl_ast.Attribute | fl_ast.Message, compiler_env: CompilerEnvironment) -> bool: """ Returns True if the message 'msg' contains a cyclic reference, in the context of the other messages provided in compiler_env @@ -491,11 +532,11 @@ def contains_reference_cycle(msg, compiler_env): # references. exclude_attributes = [ # Message and Term attributes have already been loaded into the message_ids_to_ast dict, - (Message, "attributes"), - (Term, "attributes"), + (fl_ast.Message, "attributes"), + (fl_ast.Term, "attributes"), # for speed - (Message, "comment"), - (Term, "comment"), + (fl_ast.Message, "comment"), + (fl_ast.Term, "comment"), ] # We need to keep track of visited nodes. If we use just a single set for @@ -514,7 +555,7 @@ def contains_reference_cycle(msg, compiler_env): checks = [] def checker(node): - if isinstance(node, BaseNode): + if isinstance(node, fl_ast.BaseNode): node_id = id(node) if node_id in visited_node_stack[-1]: checks.append(True) @@ -527,7 +568,7 @@ def checker(node): # different nodes (messages via a runtime function call, terms via # inlining), including the fallback strategies that are used. sub_node = None - if isinstance(node, (MessageReference, TermReference)): + if isinstance(node, (fl_ast.MessageReference, fl_ast.TermReference)): ref_id = reference_to_id(node) if ref_id in message_ids_to_ast: sub_node = message_ids_to_ast[ref_id] @@ -559,7 +600,7 @@ def checker(node): # # The `compile_expr_XXXX functions` form the heart of handling all FTL syntax. 
# They convert FTL AST nodes (as created by fluent.syntax parser) -# into Python expressions (in the form of our `codegen.PythonAst` objects). +# into Python expressions (in the form of our `codegen.CodeGenAst` objects). # # The first `compile_expr` function is decorated with `@singledispatch`, # so we can then dispatch to other functions based on the type of the first @@ -567,7 +608,7 @@ def checker(node): # `if isinstance(ast, XXX): handle_XXX(...)`, or other similar visitor patterns. # # The basic structure is that each `compile_expr` returns a single -# codegen.PythonAst object that corresponds to the passed in FTL AST (the first +# codegen.CodeGenAst object that corresponds to the passed in FTL AST (the first # argument). That is, the overall strategy is to compile each FTL AST object to # a single Python expression. # @@ -634,7 +675,7 @@ def checker(node): # -> compile_expr_message_reference # # -# Note that some of the codegen.PythonAst objects can simplify themselves as +# Note that some of the codegen.CodeGenAst objects can simplify themselves as # they are being built or finalised, and further transformations (i.e. # simplifications and optimizations) are done after we've built up a complete # Python AST for the function. So the easy one-to-one correspondence above will @@ -649,7 +690,9 @@ def checker(node): @singledispatch -def compile_expr(element, block, compiler_env): +def compile_expr( + element: fl_ast.BaseNode, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.CodeGenAst: """ Compiles a Fluent expression into a Python one, return an object of type codegen.Expression. 
@@ -661,30 +704,36 @@ def compile_expr(element, block, compiler_env): raise NotImplementedError(f"Cannot handle object of type {type(element).__name__}") -@compile_expr.register(Message) -def compile_expr_message(message, block, compiler_env): +@compile_expr.register(fl_ast.Message) +def compile_expr_message( + message: fl_ast.Message, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.CodeGenAst: return compile_expr(message.value, block, compiler_env) -@compile_expr.register(Term) +@compile_expr.register(fl_ast.Term) def compile_expr_term(term, block, compiler_env): return compile_expr(term.value, block, compiler_env) -@compile_expr.register(Attribute) -def compile_expr_attribute(attribute, block, compiler_env): +@compile_expr.register(fl_ast.Attribute) +def compile_expr_attribute( + attribute: fl_ast.Attribute, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.CodeGenAst: return compile_expr(attribute.value, block, compiler_env) -@compile_expr.register(Pattern) -def compile_expr_pattern(pattern, block, compiler_env): +@compile_expr.register(fl_ast.Pattern) +def compile_expr_pattern( + pattern: fl_ast.Pattern, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.CodeGenAst: parts = [] subelements = pattern.elements use_isolating = compiler_env.should_use_isolating() and len(subelements) > 1 for element in pattern.elements: - wrap_this_with_isolating = use_isolating and not isinstance(element, TextElement) + wrap_this_with_isolating = use_isolating and not isinstance(element, fl_ast.TextElement) if wrap_this_with_isolating: parts.append(wrap_with_escaper(codegen.String(FSI), block, compiler_env)) parts.append(compile_expr(element, block, compiler_env)) @@ -692,43 +741,60 @@ def compile_expr_pattern(pattern, block, compiler_env): parts.append(wrap_with_escaper(codegen.String(PDI), block, compiler_env)) # > f'$[p for p in parts]' - return EscaperJoin.build( + return EscaperJoin.build_with_escaper( 
[finalize_expr_as_output_type(p, block, compiler_env) for p in parts], compiler_env.current.escaper, block.scope, ) -@compile_expr.register(TextElement) -def compile_expr_text(text, block, compiler_env): +@compile_expr.register(fl_ast.TextElement) +def compile_expr_text( + text: fl_ast.TextElement, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.CodeGenAst: return wrap_with_mark_escaped(codegen.String(text.value), block, compiler_env) -@compile_expr.register(StringLiteral) -def compile_expr_string_expression(expr, block, compiler_env): +@compile_expr.register(fl_ast.StringLiteral) +def compile_expr_string_expression( + expr: fl_ast.StringLiteral, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.CodeGenAst: return codegen.String(expr.parse()["value"]) -@compile_expr.register(NumberLiteral) -def compile_expr_number_expression(expr, block, compiler_env): +@compile_expr.register(fl_ast.NumberLiteral) +def compile_expr_number_expression( + expr: fl_ast.NumberLiteral, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.FunctionCall: number_expr = codegen.Number(numeric_to_native(expr.value)) # > NUMBER($number_expr) return codegen.FunctionCall(BUILTIN_NUMBER, [number_expr], {}, block.scope) -@compile_expr.register(Placeable) -def compile_expr_placeable(placeable, block, compiler_env): +@compile_expr.register(fl_ast.Placeable) +def compile_expr_placeable( + placeable: fl_ast.Placeable, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.CodeGenAst: return compile_expr(placeable.expression, block, compiler_env) -@compile_expr.register(MessageReference) -def compile_expr_message_reference(reference, block, compiler_env): +@compile_expr.register(fl_ast.MessageReference) +def compile_expr_message_reference( + reference: fl_ast.MessageReference, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.Expression: return handle_message_reference(reference, block, compiler_env) -def 
compile_term(term, block, compiler_env, new_escaper, term_args=None): +def compile_term( + term: fl_ast.Term | fl_ast.Attribute, + block: codegen.Block, + compiler_env: CompilerEnvironment, + new_escaper: NullEscaper | RegisteredEscaper, + term_args: dict[str, codegen.CodeGenAst] | None = None, +) -> codegen.CodeGenAst: current_escaper = compiler_env.current.escaper if not escapers_compatible(current_escaper, new_escaper): + # TODO bug here when attribute is passed term_id = ast_to_id(term) error = TypeError( f"Escaper {new_escaper.name} for term {term_id} cannot be used from calling context with {current_escaper.name} escaper" @@ -742,11 +808,14 @@ def compile_term(term, block, compiler_env, new_escaper, term_args=None): return compile_expr(term.value, block, compiler_env) -@compile_expr.register(TermReference) -def compile_expr_term_reference(reference, block, compiler_env): - term, new_escaper, err_obj = lookup_term_reference(reference, block, compiler_env) - if term is None: - return err_obj +@compile_expr.register(fl_ast.TermReference) +def compile_expr_term_reference( + reference: fl_ast.TermReference, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.CodeGenAst: + looked_up = lookup_term_reference(reference, block, compiler_env) + if isinstance(looked_up, codegen.CodeGenAst): + return looked_up + term, new_escaper = looked_up if reference.arguments: args = [compile_expr(arg, block, compiler_env) for arg in reference.arguments.positional] kwargs = { @@ -765,10 +834,13 @@ def compile_expr_term_reference(reference, block, compiler_env): return compile_term(term, block, compiler_env, new_escaper, term_args=kwargs) -@compile_expr.register(SelectExpression) -def compile_expr_select_expression(select_expr, block, compiler_env): +@compile_expr.register(fl_ast.SelectExpression) +def compile_expr_select_expression( + select_expr: fl_ast.SelectExpression, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.CodeGenAst: with 
compiler_env.modified(in_select_expression=True): key_value = compile_expr(select_expr.selector, block, compiler_env) + assert isinstance(key_value, codegen.Expression) static_retval = resolve_select_expression_statically(select_expr, key_value, block, compiler_env) if static_retval is not None: return static_retval @@ -824,6 +896,7 @@ def compile_expr_select_expression(select_expr, block, compiler_env): condition = condition1 cur_block = if_statement.add_if(condition) assigned_value = compile_expr(variant.value, cur_block, compiler_env) + assert isinstance(assigned_value, codegen.Expression) cur_block.add_assignment(return_tmp_name, assigned_value, allow_multiple=not first) first = False assigned_types.append(assigned_value.type) @@ -837,14 +910,18 @@ def compile_expr_select_expression(select_expr, block, compiler_env): return block.scope.variable(return_tmp_name) -@compile_expr.register(Identifier) -def compile_expr_variant_name(name, block, compiler_env): +@compile_expr.register(fl_ast.Identifier) +def compile_expr_variant_name( + name: fl_ast.Identifier, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.String: # TODO - handle numeric literals here? 
return codegen.String(name.name) -@compile_expr.register(VariableReference) -def compile_expr_variable_reference(argument, block, compiler_env): +@compile_expr.register(fl_ast.VariableReference) +def compile_expr_variable_reference( + argument: fl_ast.VariableReference, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.CodeGenAst: name = argument.id.name if compiler_env.current.term_args is not None: # We are in a term, all args are passed explicitly, not inherited from @@ -876,7 +953,7 @@ def compile_expr_variable_reference(argument, block, compiler_env): # or # > $tmp_name = handle_argument($tmp_name, "$name", locale, errors) escaper = compiler_env.current.escaper - if escaper is null_escaper: + if isinstance(escaper, NullEscaper): handle_argument_func_call = codegen.FunctionCall( "handle_argument", [ @@ -901,14 +978,14 @@ def compile_expr_variable_reference(argument, block, compiler_env): {}, block.scope, ) + if block.scope.has_assignment(arg_tmp_name): # already assigned to this, can re-use + block.add_assignment(arg_handled_tmp_name, handle_argument_func_call) + return block.scope.variable(arg_handled_tmp_name) - if block.scope.has_assignment(arg_tmp_name): # already assigned to this, can re-use - if not wrap_with_handle_argument: + else: + if block.scope.has_assignment(arg_tmp_name): # already assigned to this, can re-use return block.scope.variable(arg_tmp_name) - block.add_assignment(arg_handled_tmp_name, handle_argument_func_call) - return block.scope.variable(arg_handled_tmp_name) - # Add try/except/else to lookup variable. 
try_except = codegen.Try( [ @@ -950,8 +1027,10 @@ def compile_expr_variable_reference(argument, block, compiler_env): return block.scope.variable(arg_handled_tmp_name) -@compile_expr.register(FunctionReference) -def compile_expr_function_reference(expr, block, compiler_env): +@compile_expr.register(fl_ast.FunctionReference) +def compile_expr_function_reference( + expr: fl_ast.FunctionReference, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.CodeGenAst: args = [compile_expr(arg, block, compiler_env) for arg in expr.arguments.positional] kwargs = {kwarg.name.name: compile_expr(kwarg.value, block, compiler_env) for kwarg in expr.arguments.named} @@ -994,11 +1073,11 @@ def compile_expr_function_reference(expr, block, compiler_env): # Compiler utilities and common code: -def add_msg_error_with_expr(block, exception_expr): +def add_msg_error_with_expr(block: codegen.Block, exception_expr: codegen.CodeGenAst): block.add_statement(codegen.MethodCall(block.scope.variable(ERRORS_NAME), "append", [exception_expr])) -def add_static_msg_error(block, exception): +def add_static_msg_error(block: codegen.Block, exception: Exception) -> None: """ Given a block and an exception object, inspect the object and add the code to the scope needed to create and add that exception to the returned errors @@ -1016,7 +1095,7 @@ def add_static_msg_error(block, exception): ) -def do_message_call(msg_id, block, compiler_env): +def do_message_call(msg_id: str, block: codegen.Block, compiler_env: CompilerEnvironment) -> codegen.Expression: current_escaper = compiler_env.current.escaper new_escaper = compiler_env.escaper_for_message(msg_id) if not escapers_compatible(current_escaper, new_escaper): @@ -1042,7 +1121,11 @@ def do_message_call(msg_id, block, compiler_env): return wrap_with_escaper(func_call, block, compiler_env) -def finalize_expr_as_output_type(codegen_ast, block, compiler_env): +def finalize_expr_as_output_type( + codegen_ast: codegen.Expression, + block: 
codegen.Block, + compiler_env: CompilerEnvironment, +) -> codegen.Expression: """ Wrap an outputted Python expression with code to ensure that it will return a string (or the correct output type for the escaper) @@ -1064,7 +1147,7 @@ def finalize_expr_as_output_type(codegen_ast, block, compiler_env): block, compiler_env, ) - if escaper is null_escaper: + if isinstance(escaper, NullEscaper): # > handle_output($python_expr, locale, errors) return codegen.FunctionCall( "handle_output", @@ -1094,29 +1177,35 @@ def finalize_expr_as_output_type(codegen_ast, block, compiler_env): ) -def is_cldr_plural_form_key(key_expr): - return isinstance(key_expr, Identifier) and key_expr.name in CLDR_PLURAL_FORMS +def is_cldr_plural_form_key(key_expr: fl_ast.BaseNode) -> bool: + return isinstance(key_expr, fl_ast.Identifier) and key_expr.name in CLDR_PLURAL_FORMS def is_NUMBER_call_expr(expr): """ Returns True if the object is a FTL ast.FunctionReference representing a call to NUMBER """ - return isinstance(expr, FunctionReference) and expr.id.name == "NUMBER" + return isinstance(expr, fl_ast.FunctionReference) and expr.id.name == "NUMBER" -def lookup_term_reference(ref, block, compiler_env): +def lookup_term_reference( + ref: fl_ast.TermReference, block: codegen.Block, compiler_env: CompilerEnvironment +) -> tuple[fl_ast.Term | fl_ast.Attribute, RegisteredEscaper | NullEscaper] | codegen.CodeGenAst: + """ + Looks up term reference, and returns either: + - a tuple containing Term/Attribute and the escaper needed, + - OR a CodeGenAst object representing an error if not found. + """ # This could be turned into 'handle_term_reference', (similar to # 'handle_message_reference' below) once VariantList and VariantExpression # go away. 
term_id = reference_to_id(ref) if term_id in compiler_env.term_ids_to_ast: + term_ast = compiler_env.term_ids_to_ast[term_id] return ( - compiler_env.term_ids_to_ast[term_id], + term_ast, compiler_env.escaper_for_message(term_id), - None, ) - return compiler_env.term_ids_to_ast[term_id], None # Fallback to parent if ref.attribute: parent_id = reference_to_id(ref, ignore_attributes=True) @@ -1127,12 +1216,13 @@ def lookup_term_reference(ref, block, compiler_env): return ( compiler_env.term_ids_to_ast[parent_id], compiler_env.escaper_for_message(parent_id), - None, ) - return None, None, unknown_reference(term_id, block, ref, compiler_env) + return unknown_reference(term_id, block, ref, compiler_env) -def handle_message_reference(ref, block, compiler_env): +def handle_message_reference( + ref: fl_ast.MessageReference, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.Expression: msg_id = reference_to_id(ref) if msg_id in compiler_env.message_ids_to_ast: return do_message_call(msg_id, block, compiler_env) @@ -1147,14 +1237,14 @@ def handle_message_reference(ref, block, compiler_env): return unknown_reference(msg_id, block, ref, compiler_env) -def make_fluent_none(name, scope): +def make_fluent_none(name: str | None, scope: codegen.Scope) -> codegen.ObjectCreation: # > FluentNone(name) # OR # > FluentNone() return codegen.ObjectCreation("FluentNone", [codegen.String(name)] if name else [], {}, scope) -def numeric_to_native(val): +def numeric_to_native(val: str) -> float | int: """ Given a numeric string (as defined by fluent spec), return an int or float @@ -1166,7 +1256,7 @@ def numeric_to_native(val): return int(val) -def reserve_and_assign_name(block, suggested_name, value): +def reserve_and_assign_name(block: codegen.Block, suggested_name: str, value: codegen.CodeGenAst) -> str: """ Reserves a name for the value in the scope block and adds assignment if necessary, returning the name reserved. 
@@ -1182,9 +1272,14 @@ def reserve_and_assign_name(block, suggested_name, value): return name -def resolve_select_expression_statically(select_expr, key_ast, block, compiler_env): +def resolve_select_expression_statically( + select_expr: fl_ast.SelectExpression, + key_ast: codegen.Expression, + block: codegen.Block, + compiler_env: CompilerEnvironment, +) -> codegen.CodeGenAst | None: """ - Resolve a select expression statically, given a codegen.PythonAst object + Resolve a select expression statically, given a codegen.CodeGenAst object `key_ast` representing the key value, or return None if not possible. """ key_is_fluent_none = is_fluent_none(key_ast) @@ -1211,15 +1306,17 @@ def resolve_select_expression_statically(select_expr, key_ast, block, compiler_e found = variant break if key_is_string: - if isinstance(variant.key, Identifier) and key_ast.string_value == variant.key.name: + if isinstance(variant.key, fl_ast.Identifier) and key_ast.string_value == variant.key.name: found = variant break elif key_is_number: - if isinstance(variant.key, NumberLiteral) and key_number_value == numeric_to_native(variant.key.value): + if isinstance(variant.key, fl_ast.NumberLiteral) and key_number_value == numeric_to_native( + variant.key.value + ): found = variant break elif ( - isinstance(variant.key, Identifier) + isinstance(variant.key, fl_ast.Identifier) and compiler_env.plural_form_function(key_number_value) == variant.key.name ): found = variant @@ -1230,19 +1327,26 @@ def resolve_select_expression_statically(select_expr, key_ast, block, compiler_e return compile_expr(found.value, block, compiler_env) -def unknown_reference(name, block, ast_node, compiler_env): +def unknown_reference( + name: str, + block: codegen.Block, + ast_node: fl_ast.MessageReference | fl_ast.TermReference, + compiler_env: CompilerEnvironment, +) -> codegen.Expression: error = unknown_reference_error_obj(name, ast_node, compiler_env) add_static_msg_error(block, error) 
compiler_env.add_current_message_error(error) return make_fluent_none(name, block.scope) -def display_ast_location(ast_node, compiler_env): +def display_ast_location(ast_node: fl_ast.SyntaxNode, compiler_env: CompilerEnvironment) -> str: ftl_resource = compiler_env.current.ftl_resource return display_location(ftl_resource.filename, span_to_position(ast_node.span, ftl_resource.text)) -def unknown_reference_error_obj(ref_id, source_ast_node, compiler_env): +def unknown_reference_error_obj( + ref_id: str, source_ast_node: fl_ast.MessageReference | fl_ast.TermReference, compiler_env: CompilerEnvironment +) -> FluentReferenceError: location = display_ast_location(source_ast_node, compiler_env) if ATTRIBUTE_SEPARATOR in ref_id: return FluentReferenceError(f"{location}: Unknown attribute: {ref_id}") @@ -1251,18 +1355,22 @@ def unknown_reference_error_obj(ref_id, source_ast_node, compiler_env): return FluentReferenceError(f"{location}: Unknown message: {ref_id}") -def wrap_with_escaper(codegen_ast, block, compiler_env): +def wrap_with_escaper( + codegen_ast: codegen.Expression, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.Expression: escaper = compiler_env.current.escaper - if escaper is null_escaper or escaper.escape is identity: + if isinstance(escaper, NullEscaper) or escaper.escape is identity: return codegen_ast if escaper.output_type is codegen_ast.type: return codegen_ast return codegen.FunctionCall(escaper.escape_name(), [codegen_ast], {}, block.scope) -def wrap_with_mark_escaped(codegen_ast, block, compiler_env): +def wrap_with_mark_escaped( + codegen_ast: codegen.Expression, block: codegen.Block, compiler_env: CompilerEnvironment +) -> codegen.Expression: escaper = compiler_env.current.escaper - if escaper is null_escaper or escaper.mark_escaped is identity: + if isinstance(escaper, NullEscaper) or escaper.mark_escaped is identity: return codegen_ast if escaper.output_type is codegen_ast.type: return codegen_ast @@ -1272,11 +1380,11 @@ 
def wrap_with_mark_escaped(codegen_ast, block, compiler_env): # AST checking and simplification -def is_DATETIME_function_call(codegen_ast): +def is_DATETIME_function_call(codegen_ast: Any) -> TypeGuard[codegen.FunctionCall]: return isinstance(codegen_ast, codegen.FunctionCall) and codegen_ast.function_name == BUILTIN_DATETIME -def is_fluent_none(codegen_ast): +def is_fluent_none(codegen_ast: codegen.Expression) -> TypeGuard[codegen.ObjectCreation]: return ( isinstance(codegen_ast, codegen.ObjectCreation) and codegen_ast.function_name == "FluentNone" @@ -1284,15 +1392,15 @@ def is_fluent_none(codegen_ast): ) -def is_NUMBER_function_call(codegen_ast): +def is_NUMBER_function_call(codegen_ast: Any) -> TypeGuard[codegen.FunctionCall]: return isinstance(codegen_ast, codegen.FunctionCall) and codegen_ast.function_name == BUILTIN_NUMBER class Simplifier: - def __init__(self, compiler_env): + def __init__(self, compiler_env: CompilerEnvironment): self.compiler_env = compiler_env - def __call__(self, codegen_ast, changes): + def __call__(self, codegen_ast: Any, changes: list[Any | bool]) -> codegen.CodeGenAst: # Simplifications we can do on the AST tree. We append to # changes if we made a change, and either mutate codegen_ast or # return a new/different object. diff --git a/src/fluent_compiler/errors.py b/src/fluent_compiler/errors.py index 82082b9..7bc4452 100644 --- a/src/fluent_compiler/errors.py +++ b/src/fluent_compiler/errors.py @@ -1,8 +1,8 @@ class FluentError(ValueError): # This equality method exists to make exact tests for exceptions much # simpler to write, at least for our own errors. 
- def __eq__(self, other): - return (other.__class__ == self.__class__) and other.args == self.args + def __eq__(self, other: object) -> bool: + return type(other) is type(self) and other.args == self.args class FluentFormatError(FluentError): diff --git a/src/fluent_compiler/escapers.py b/src/fluent_compiler/escapers.py index bb5adeb..849f14d 100644 --- a/src/fluent_compiler/escapers.py +++ b/src/fluent_compiler/escapers.py @@ -1,5 +1,15 @@ -from types import SimpleNamespace +from __future__ import annotations +from typing import TYPE_CHECKING, Callable, Final, Generic, Sequence, TypeVar + +from attr import dataclass +from typing_extensions import Protocol, runtime_checkable + +if TYPE_CHECKING: + from .codegen import Expression + from .compiler import CompilerEnvironment + +from . import ast_compat as py_ast from . import codegen @@ -13,25 +23,91 @@ def identity(value): # Default string join function and sentinel value -default_join = "".join +def default_join(items: Sequence[str]) -> str: + return "".join(items) -def select_always(message_id=None, **kwargs): +def select_always(message_id: str | None = None, **kwargs: object) -> bool: return True -null_escaper = SimpleNamespace( - select=select_always, - output_type=str, - escape=identity, - mark_escaped=identity, - join=default_join, - name="null_escaper", +T = TypeVar("T") + + +@runtime_checkable +class IsEscaper(Protocol[T]): + output_type: Final[type[T]] + name: Final[str] + use_isolating: Final[bool | None] + + def select(self, message_id: str, **kwargs: object) -> bool: + ... + + def escape(self, unescaped: str, /) -> T: + ... + + def mark_escaped(self, value: str, /) -> T: + ... + + def join(self, parts: Sequence[T], /) -> T: + ... 
+ + +@dataclass(frozen=True) +class Escaper(Generic[T]): + select: Callable[..., bool] + output_type: type[T] + escape: Callable[[str], T] + mark_escaped: Callable[[str], T] + join: Callable[[Sequence[T]], T] + name: str + use_isolating: bool | None + + +class NullEscaper: + # select = select_always + # output_type = str + # escape = identity + # mark_escaped = identity + # join = default_join + def __init__(self) -> None: + self.name = "null_escaper" + self.use_isolating = None + self.output_type = str + + def select(self, message_id: str, **kwargs: object) -> bool: + return True + + def escape(self, unescaped: str) -> str: + return unescaped + + def mark_escaped(self, value: str, /) -> str: + return value + + def join(self, parts: Sequence[str], /) -> str: + return "".join(parts) + + +null_escaper = NullEscaper() + +# Some tests for the types above: +_1: IsEscaper[str] = NullEscaper() + + +_2: IsEscaper[str] = Escaper( + name="x", use_isolating=None, + select=lambda **kwargs: True, + output_type=str, + escape=lambda unescaped, /: unescaped, + mark_escaped=lambda value, /: value, + join="".join, ) -def escapers_compatible(outer_escaper, inner_escaper): +def escapers_compatible( + outer_escaper: NullEscaper | RegisteredEscaper, inner_escaper: NullEscaper | RegisteredEscaper +) -> bool: # Messages with no escaper defined can always be used from other messages, # because the outer message will do the escaping, and the inner message will # always return a simple string which must be handle by all escapers. 
@@ -43,11 +119,12 @@ def escapers_compatible(outer_escaper, inner_escaper): return outer_escaper.name == inner_escaper.name -def escaper_for_message(escapers, message_id): - if escapers is not None: - for escaper in escapers: - if escaper.select(message_id=message_id): - return escaper +def escaper_for_message( + escapers: Sequence[RegisteredEscaper], message_id: str | None +) -> RegisteredEscaper | NullEscaper: + for escaper in escapers: + if escaper.select(message_id=message_id): + return escaper return null_escaper @@ -58,7 +135,7 @@ class RegisteredEscaper: functions are called in the compiler environment. """ - def __init__(self, escaper, compiler_env): + def __init__(self, escaper: IsEscaper, compiler_env: CompilerEnvironment): self._escaper = escaper self._compiler_env = compiler_env @@ -66,30 +143,32 @@ def __repr__(self): return f"" @property - def select(self): + def select(self) -> Callable: return self._escaper.select @property - def output_type(self): + def output_type(self) -> type: return self._escaper.output_type @property - def escape(self): + def escape(self) -> Callable: return self._escaper.escape @property - def mark_escaped(self): + def mark_escaped(self) -> Callable: return self._escaper.mark_escaped @property - def join(self): + def join(self) -> Callable: return self._escaper.join @property - def name(self): + def name(self) -> str: return self._escaper.name - def get_reserved_names_with_properties(self): + def get_reserved_names_with_properties( + self, + ) -> list[tuple[str, object, dict[str, object]]]: # escaper.output_type, escaper.mark_escaped, escaper.escape, escaper.join return [ (self.output_type_name(), self._escaper.output_type, {}), @@ -110,35 +189,35 @@ def get_reserved_names_with_properties(self): ), ] - def _prefix(self): + def _prefix(self) -> str: idx = self._compiler_env.escapers.index(self) return f"escaper_{idx}_" - def output_type_name(self): + def output_type_name(self) -> str: return f"{self._prefix()}_output_type" - 
def mark_escaped_name(self): + def mark_escaped_name(self) -> str: return f"{self._prefix()}_mark_escaped" - def escape_name(self): + def escape_name(self) -> str: return f"{self._prefix()}_escape" - def join_name(self): + def join_name(self) -> str: return f"{self._prefix()}_join" @property - def use_isolating(self): + def use_isolating(self) -> bool | None: return getattr(self._escaper, "use_isolating", None) -class EscaperJoin(codegen.StringJoin): - def __init__(self, parts, escaper, scope): +class EscaperJoin(codegen.StringJoinBase): + def __init__(self, parts: Sequence[Expression], escaper: RegisteredEscaper, scope: codegen.Scope): super().__init__(parts) self.type = escaper.output_type self.escaper = escaper self.scope = scope - def as_ast(self): + def as_ast(self) -> py_ast.expr: if self.escaper.join is default_join: return super().as_ast() else: @@ -151,8 +230,10 @@ def as_ast(self): ).as_ast() @classmethod - def build(cls, parts, escaper, scope): - if escaper.name == null_escaper.name: + def build_with_escaper( + cls, parts: Sequence[Expression], escaper: RegisteredEscaper | NullEscaper, scope: codegen.Scope + ) -> codegen.CodeGenAst: + if isinstance(escaper, NullEscaper): return codegen.StringJoin.build(parts) new_parts = [] @@ -161,17 +242,18 @@ def build(cls, parts, escaper, scope): if len(new_parts) > 0: last_part = new_parts[-1] # Merge string literals wrapped in mark_escaped calls - if all( - ( - isinstance(p, codegen.FunctionCall) - and p.function_name == escaper.mark_escaped_name() - and isinstance(p.args[0], codegen.String) - ) - for p in [last_part, part] + if ( + isinstance(last_part, codegen.FunctionCall) + and last_part.function_name == escaper.mark_escaped_name() + and (isinstance(last_part_args0 := last_part.args[0], codegen.String)) + ) and ( + isinstance(part, codegen.FunctionCall) + and part.function_name == escaper.mark_escaped_name() + and isinstance(part_args0 := part.args[0], codegen.String) ): new_parts[-1] = codegen.FunctionCall( 
last_part.function_name, - [codegen.String(last_part.args[0].string_value + part.args[0].string_value)], + [codegen.String(last_part_args0.string_value + part_args0.string_value)], {}, scope, ) diff --git a/src/fluent_compiler/resource.py b/src/fluent_compiler/resource.py index ddcc335..137fb85 100644 --- a/src/fluent_compiler/resource.py +++ b/src/fluent_compiler/resource.py @@ -1,20 +1,22 @@ -import attr +from __future__ import annotations +from dataclasses import dataclass -@attr.s + +@dataclass class FtlResource: """ Represents an (unparsed) FTL file (contents and optional filename) """ - text = attr.ib() - filename = attr.ib(default=None) + text: str + filename: str | None = None @classmethod - def from_string(cls, text): + def from_string(cls, text: str) -> FtlResource: return cls(text) @classmethod - def from_file(cls, filename, encoding="utf-8"): + def from_file(cls, filename: str, encoding="utf-8"): with open(filename, "rb") as f: return cls(text=f.read().decode(encoding), filename=filename) diff --git a/src/fluent_compiler/runtime.py b/src/fluent_compiler/runtime.py index 87c76d8..faf1147 100644 --- a/src/fluent_compiler/runtime.py +++ b/src/fluent_compiler/runtime.py @@ -1,7 +1,11 @@ # Runtime functions for compiled messages +from __future__ import annotations from datetime import date, datetime from decimal import Decimal +from typing import Callable + +from babel.core import Locale from .errors import FluentCyclicReferenceError, FluentFormatError, FluentReferenceError from .types import FluentNone, FluentType, fluent_date, fluent_number @@ -18,7 +22,7 @@ ] -RETURN_TYPES = { +RETURN_TYPES: dict[str, type] = { "handle_argument": object, "handle_output": str, "FluentReferenceError": FluentReferenceError, @@ -27,7 +31,13 @@ } -def handle_argument_with_escaper(arg, name, output_type, locale, errors): +def handle_argument_with_escaper( + arg: object, + name: str, + output_type: type, + locale: Locale, + errors: list[Exception], +) -> object: # This needs to 
be synced with resolver.handle_variable_reference if isinstance(arg, output_type): return arg @@ -41,7 +51,7 @@ def handle_argument_with_escaper(arg, name, output_type, locale, errors): return name -def handle_argument(arg, name, locale, errors): +def handle_argument(arg: object, name: str, locale: Locale, errors: list[Exception]) -> object: # handle_argument_with_escaper specialized to null escaper # This needs to be synced with resolver.handle_variable_reference if isinstance(arg, str): @@ -54,7 +64,13 @@ def handle_argument(arg, name, locale, errors): return name -def handle_output_with_escaper(val, output_type, escaper_escape, locale, errors): +def handle_output_with_escaper( + val: object, + output_type: type, + escaper_escape: Callable, + locale: Locale, + errors: list[Exception], +) -> object: if isinstance(val, output_type): return val elif isinstance(val, str): @@ -67,7 +83,7 @@ def handle_output_with_escaper(val, output_type, escaper_escape, locale, errors) raise TypeError(f"Cannot handle object {val} of type {type(val).__name__}") -def handle_output(val, locale, errors): +def handle_output(val: object, locale: Locale, errors: list[Exception]) -> str: # handle_output_with_escaper specialized to null_escaper if isinstance(val, str): return val diff --git a/src/fluent_compiler/source.py b/src/fluent_compiler/source.py new file mode 100644 index 0000000..d2d721d --- /dev/null +++ b/src/fluent_compiler/source.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from fluent.syntax import ast as fl_ast + +from .resource import FtlResource +from .utils import ( + span_to_position, +) + + +class FtlSource: + """ + Object used to specify the origin of a chunk of FTL + """ + + def __init__(self, ast_node: fl_ast.Attribute | fl_ast.Message, ftl_resource: FtlResource): + self.ast_node = ast_node + self.ftl_resource = ftl_resource + self.filename = self.ftl_resource.filename + assert ast_node.span is not None + self.row, self.column = 
span_to_position(ast_node.span, ftl_resource.text) diff --git a/src/fluent_compiler/types.py b/src/fluent_compiler/types.py index 5a49461..f88f385 100644 --- a/src/fluent_compiler/types.py +++ b/src/fluent_compiler/types.py @@ -1,45 +1,61 @@ +from __future__ import annotations + import warnings -from datetime import date, datetime +from dataclasses import dataclass +from datetime import date, datetime, tzinfo from decimal import Decimal +from typing import Any, Literal, get_args, overload -import attr import pytz +from babel.core import Locale from babel.dates import format_date, format_time, get_datetime_format, get_timezone from babel.numbers import NumberPattern, get_currency_name, get_currency_unit_pattern, parse_pattern -FORMAT_STYLE_DECIMAL = "decimal" -FORMAT_STYLE_CURRENCY = "currency" -FORMAT_STYLE_PERCENT = "percent" -FORMAT_STYLE_OPTIONS = { - FORMAT_STYLE_DECIMAL, - FORMAT_STYLE_CURRENCY, - FORMAT_STYLE_PERCENT, -} - -CURRENCY_DISPLAY_SYMBOL = "symbol" -CURRENCY_DISPLAY_CODE = "code" -CURRENCY_DISPLAY_NAME = "name" -CURRENCY_DISPLAY_OPTIONS = { - CURRENCY_DISPLAY_SYMBOL, - CURRENCY_DISPLAY_CODE, - CURRENCY_DISPLAY_NAME, -} - -DATE_STYLE_OPTIONS = { +from .compat import StrEnum, TypeAlias + + +class FormatStyle(StrEnum): + DECIMAL = "decimal" + CURRENCY = "currency" + PERCENT = "percent" + + +class CurrencyDisplay(StrEnum): + SYMBOL = "symbol" + CODE = "code" + NAME = "name" + + +DateStyle: TypeAlias = Literal[ "full", "long", "medium", "short", None, -} +] -TIME_STYLE_OPTIONS = { +TimeStyle: TypeAlias = Literal[ "full", "long", "medium", "short", None, -} +] + + +# Backwards compat constants in case anyone is importing these +FORMAT_STYLE_DECIMAL = FormatStyle.DECIMAL +FORMAT_STYLE_CURRENCY = FormatStyle.CURRENCY +FORMAT_STYLE_PERCENT = FormatStyle.PERCENT +FORMAT_STYLE_OPTIONS: set[FormatStyle] = set(FormatStyle) + +CURRENCY_DISPLAY_SYMBOL = CurrencyDisplay.SYMBOL +CURRENCY_DISPLAY_CODE = CurrencyDisplay.CODE +CURRENCY_DISPLAY_NAME = 
CurrencyDisplay.NAME +CURRENCY_DISPLAY_OPTIONS: set[CurrencyDisplay] = set(CurrencyDisplay) + +DATE_STYLE_OPTIONS: set[str | None] = set(get_args(DateStyle)) +TIME_STYLE_OPTIONS: set[str | None] = set(get_args(TimeStyle)) class FluentType: @@ -48,20 +64,20 @@ def format(self, locale): class FluentNone(FluentType): - def __init__(self, name=None): + def __init__(self, name: str | None = None): self.name = name - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, FluentNone) and self.name == other.name - def format(self, locale): + def format(self, locale: Locale) -> str: return self.name or "???" def __repr__(self): return f"" -@attr.s +@dataclass class NumberFormatOptions: # We follow the Intl.NumberFormat parameter names here, # rather than using underscores as per PEP8, so that @@ -70,31 +86,29 @@ class NumberFormatOptions: # Keyword args available to FTL authors must be synced to fluent_number.ftl_arg_spec below # See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/NumberFormat - style = attr.ib( - default=FORMAT_STYLE_DECIMAL, - validator=attr.validators.in_(FORMAT_STYLE_OPTIONS), - ) - currency = attr.ib(default=None) - currencyDisplay = attr.ib( - default=CURRENCY_DISPLAY_SYMBOL, - validator=attr.validators.in_(CURRENCY_DISPLAY_OPTIONS), - ) - useGrouping = attr.ib(default=True) - minimumIntegerDigits = attr.ib(default=None) - minimumFractionDigits = attr.ib(default=None) - maximumFractionDigits = attr.ib(default=None) - minimumSignificantDigits = attr.ib(default=None) - maximumSignificantDigits = attr.ib(default=None) + style: FormatStyle = FormatStyle.DECIMAL + currency: str | None = None + currencyDisplay: CurrencyDisplay = CurrencyDisplay.SYMBOL + useGrouping: bool = True + minimumIntegerDigits: int | None = None + minimumFractionDigits: int | None = None + maximumFractionDigits: int | None = None + minimumSignificantDigits: int | None = None + maximumSignificantDigits: int | None 
= None + + def __post_init__(self): + self.currencyDisplay = CurrencyDisplay(self.currencyDisplay) + self.style = FormatStyle(self.style) class FluentNumber(FluentType): default_number_format_options = NumberFormatOptions() - def __new__(cls, value, **kwargs): + def __new__(cls: type[FluentNumber], value: Decimal | float | FluentFloat | int, **kwargs) -> FluentNumber: self = super().__new__(cls, value) return self._init(value, kwargs) - def _init(self, value, kwargs): + def _init(self, value: Decimal | float | FluentNumber | int, kwargs: dict[str, int | str | bool]) -> FluentNumber: self.options = merge_options( NumberFormatOptions, getattr(value, "options", self.default_number_format_options), @@ -106,7 +120,7 @@ def _init(self, value, kwargs): return self - def format(self, locale): + def format(self, locale: Locale) -> str: if self.options.style == FORMAT_STYLE_DECIMAL: base_pattern = locale.decimal_formats.get(None) pattern = self._apply_options(base_pattern) @@ -116,14 +130,14 @@ def format(self, locale): pattern = self._apply_options(base_pattern) return pattern.apply(self, locale) elif self.options.style == FORMAT_STYLE_CURRENCY: - if self.options.currencyDisplay == "name": + if self.options.currencyDisplay == CurrencyDisplay.NAME: return self._format_currency_long_name(locale) else: base_pattern = locale.currency_formats["standard"] pattern = self._apply_options(base_pattern) return pattern.apply(self, locale, currency=self.options.currency, currency_digits=False) - def _apply_options(self, pattern): + def _apply_options(self, pattern: NumberPattern) -> NumberPattern: # We are essentially trying to copy the # https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/NumberFormat # API using Babel number formatting routines, which is slightly awkward @@ -173,7 +187,7 @@ def replacer(s): return pattern - def _format_currency_long_name(self, locale): + def _format_currency_long_name(self, locale: Locale) -> str: # This reproduces some of 
bable.numbers._format_currency_long_name # Step 3. unit_pattern = get_currency_unit_pattern(self.options.currency, count=self, locale=locale) @@ -193,7 +207,29 @@ def _format_currency_long_name(self, locale): return unit_pattern.format(number_part, display_name) -def merge_options(options_class, base, kwargs): +@overload +def merge_options( + options_class: type[DateFormatOptions], + base: DateFormatOptions | None, + kwargs: dict[str, int | str | bool], +) -> DateFormatOptions: + ... + + +@overload +def merge_options( + options_class: type[NumberFormatOptions], + base: NumberFormatOptions | None, + kwargs: dict[str, int | str | bool], +) -> NumberFormatOptions: + ... + + +def merge_options( + options_class: type[DateFormatOptions] | type[NumberFormatOptions], + base: DateFormatOptions | NumberFormatOptions | None, + kwargs: dict[str, int | str | bool], +) -> DateFormatOptions | NumberFormatOptions: """ Given an 'options_class', an optional 'base' object to copy from, and some keyword arguments, create a new options instance @@ -242,7 +278,7 @@ class FluentDecimal(FluentNumber, Decimal): pass -def fluent_number(number, **kwargs): +def fluent_number(number: Any, **kwargs) -> FluentNone | FluentNumber: if isinstance(number, FluentNumber) and not kwargs: return number if isinstance(number, int): @@ -279,7 +315,7 @@ def fluent_number(number, **kwargs): _UNGROUPED_PATTERN = parse_pattern("#0") -def clone_pattern(pattern): +def clone_pattern(pattern: NumberPattern) -> NumberPattern: return NumberPattern( pattern.pattern, pattern.prefix, @@ -292,31 +328,37 @@ def clone_pattern(pattern): ) -@attr.s +@dataclass class DateFormatOptions: # Parameters. 
# See https://projectfluent.org/fluent/guide/functions.html#datetime # Developer only - timeZone = attr.ib(default=None) + timeZone: tzinfo | None = None # Other # Keyword args available to FTL authors must be synced to fluent_date.ftl_arg_spec below - hour12 = attr.ib(default=None) - weekday = attr.ib(default=None) - era = attr.ib(default=None) - year = attr.ib(default=None) - month = attr.ib(default=None) - day = attr.ib(default=None) - hour = attr.ib(default=None) - minute = attr.ib(default=None) - second = attr.ib(default=None) - timeZoneName = attr.ib(default=None) + hour12: bool | None = None + weekday: str | None = None + era: str | None = None + year: str | None = None + month: str | None = None + day: str | None = None + hour: str | None = None + minute: str | None = None + second: str | None = None + timeZoneName: str | None = None # See https://github.com/tc39/proposal-ecma402-datetime-style - dateStyle = attr.ib(default=None, validator=attr.validators.in_(DATE_STYLE_OPTIONS)) - timeStyle = attr.ib(default=None, validator=attr.validators.in_(TIME_STYLE_OPTIONS)) + dateStyle: DateStyle | None = None + timeStyle: TimeStyle | None = None + + def __post_init__(self): + if self.dateStyle not in DATE_STYLE_OPTIONS: + raise ValueError(f"{self.dateStyle} is not a valid option for dateStyle, choose from: {DATE_STYLE_OPTIONS}") + if self.timeStyle not in TIME_STYLE_OPTIONS: + raise ValueError(f"{self.timeStyle} is not a valid option for timeStyle, choose from: {TIME_STYLE_OPTIONS}") _SUPPORTED_DATETIME_OPTIONS = ["dateStyle", "timeStyle", "timeZone"] @@ -327,7 +369,7 @@ class FluentDateType(FluentType): # some Python implementation (e.g. PyPy) implement some methods. # So we leave those alone, and implement another `_init_options` # which is called from other constructors.
- def _init_options(self, dt_obj, kwargs): + def _init_options(self, dt_obj: date | datetime | FluentDate, kwargs: dict[str, str | bool]): if "timeStyle" in kwargs and not isinstance(self, datetime): raise TypeError("timeStyle option can only be specified for datetime instances, not date instance") @@ -336,9 +378,9 @@ def _init_options(self, dt_obj, kwargs): if k not in _SUPPORTED_DATETIME_OPTIONS: warnings.warn(f"FluentDateType option {k} is not yet supported") - def format(self, locale): + def format(self, locale: Locale) -> str: if isinstance(self, datetime): - selftz = _ensure_datetime_tzinfo(self, tzinfo=self.options.timeZone) + selftz = _ensure_datetime_tzinfo(self, tzinfo_obj=self.options.timeZone) else: selftz = self @@ -364,23 +406,23 @@ def format(self, locale): ) -def _ensure_datetime_tzinfo(dt, tzinfo=None): +def _ensure_datetime_tzinfo(dt: FluentDateTime, tzinfo_obj: tzinfo | None = None) -> FluentDateTime: """ Ensure the datetime passed has an attached tzinfo. """ # Adapted from babel's function. 
if dt.tzinfo is None: dt = dt.replace(tzinfo=pytz.UTC) - if tzinfo is not None: - dt = dt.astimezone(get_timezone(tzinfo)) - if hasattr(tzinfo, "normalize"): # pytz - dt = tzinfo.normalize(datetime) + if tzinfo_obj is not None: + dt = dt.astimezone(get_timezone(tzinfo_obj)) + if hasattr(tzinfo_obj, "normalize"): # pytz + dt = tzinfo_obj.normalize(dt) return dt class FluentDate(FluentDateType, date): @classmethod - def from_date(cls, dt_obj, **kwargs): + def from_date(cls, dt_obj: date | FluentDate, **kwargs) -> FluentDate: obj = cls(dt_obj.year, dt_obj.month, dt_obj.day) obj._init_options(dt_obj, kwargs) return obj @@ -388,7 +430,7 @@ def from_date(cls, dt_obj, **kwargs): class FluentDateTime(FluentDateType, datetime): @classmethod - def from_date_time(cls, dt_obj, **kwargs): + def from_date_time(cls, dt_obj: datetime, **kwargs) -> FluentDateTime: obj = cls( dt_obj.year, dt_obj.month, @@ -403,7 +445,7 @@ def from_date_time(cls, dt_obj, **kwargs): return obj -def fluent_date(dt, **kwargs): +def fluent_date(dt: date | datetime | FluentDateType, **kwargs) -> FluentDateType: if isinstance(dt, FluentDateType) and not kwargs: return dt if isinstance(dt, datetime): diff --git a/src/fluent_compiler/utils.py b/src/fluent_compiler/utils.py index b6f0db5..5c3abc9 100644 --- a/src/fluent_compiler/utils.py +++ b/src/fluent_compiler/utils.py @@ -1,21 +1,28 @@ +from __future__ import annotations + import builtins import inspect import keyword import re +from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Union -from fluent.syntax.ast import Term, TermReference +from fluent.syntax.ast import Attribute, Message, MessageReference, Span, Term, TermReference +from .compat import TypeAlias from .errors import FluentFormatError TERM_SIGIL = "-" ATTRIBUTE_SEPARATOR = "." +if TYPE_CHECKING: + from .
import codegen + -class Any: +class AnyArgType: pass -Any = Any() +AnyArg = AnyArgType() # From spec: @@ -25,13 +32,13 @@ class Any: NAMED_ARG_RE = re.compile(r"^[a-zA-Z][a-zA-Z0-9_-]*$") -def allowable_keyword_arg_name(name): +def allowable_keyword_arg_name(name: str) -> re.Match | None: # We limit to what Fluent allows for NamedArgument - Python allows anything # if you use **kwarg call and receiving syntax. return NAMED_ARG_RE.match(name) -def ast_to_id(ast): +def ast_to_id(ast: Message | Term) -> str: """ Returns a string reference for a Term or Message """ @@ -40,14 +47,14 @@ def ast_to_id(ast): return ast.id.name -def attribute_ast_to_id(attribute, parent_ast): +def attribute_ast_to_id(attribute: Attribute, parent_ast: Message | Term) -> str: """ Returns a string reference for an Attribute, given Attribute and parent Term or Message """ return "".join([ast_to_id(parent_ast), ATTRIBUTE_SEPARATOR, attribute.id.name]) -def allowable_name(ident, for_method=False, allow_builtin=False): +def allowable_name(ident: str, for_method: bool = False, allow_builtin: bool = False) -> bool: if keyword.iskeyword(ident): return False @@ -61,7 +68,10 @@ def allowable_name(ident, for_method=False, allow_builtin=False): return True -def inspect_function_args(function, name, errors): +FunctionArgSpec: TypeAlias = Tuple[Union[int, AnyArgType], Union[List[str], AnyArgType]] + + +def inspect_function_args(function: Callable, name: str, errors: list[Any]) -> FunctionArgSpec: """ For a Python function, returns a 2 tuple containing: (number of positional args or Any, @@ -76,7 +86,7 @@ def inspect_function_args(function, name, errors): parameters = list(sig.parameters.values()) positional = ( - Any + AnyArg if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in parameters) else len( list( @@ -88,14 +98,19 @@ def inspect_function_args(function, name, errors): ) keywords = ( - Any + AnyArg if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters) else [p.name for p in 
parameters if p.default != inspect.Parameter.empty] ) return sanitize_function_args((positional, keywords), name, errors) -def args_match(function_name, args, kwargs, arg_spec): +def args_match( + function_name: str, + args: list[codegen.CodeGenAst], + kwargs: dict[str, codegen.CodeGenAst], + arg_spec: FunctionArgSpec, +) -> Any: """ Checks the passed in args/kwargs against the function arg_spec and returns data for calling the function correctly. @@ -115,13 +130,13 @@ def args_match(function_name, args, kwargs, arg_spec): positional_arg_count, allowed_kwargs = arg_spec match = True for kwarg_name, kwarg_val in kwargs.items(): - if (allowed_kwargs is Any and allowable_keyword_arg_name(kwarg_name)) or ( - allowed_kwargs is not Any and kwarg_name in allowed_kwargs + if (allowed_kwargs is AnyArg and allowable_keyword_arg_name(kwarg_name)) or ( + allowed_kwargs is not AnyArg and kwarg_name in allowed_kwargs ): sanitized_kwargs[kwarg_name] = kwarg_val else: errors.append(TypeError(f"{function_name}() got an unexpected keyword argument '{kwarg_name}'")) - if positional_arg_count is Any: + if positional_arg_count is AnyArg: sanitized_args = args else: sanitized_args = tuple(args[0:positional_arg_count]) @@ -143,7 +158,7 @@ def args_match(function_name, args, kwargs, arg_spec): return (match, sanitized_args, sanitized_kwargs, errors) -def reference_to_id(ref, ignore_attributes=False): +def reference_to_id(ref: TermReference | MessageReference, ignore_attributes: bool = False) -> str: """ Returns a string reference for a MessageReference or TermReference AST node. @@ -164,13 +179,13 @@ def reference_to_id(ref, ignore_attributes=False): return start -def sanitize_function_args(arg_spec, name, errors): +def sanitize_function_args(arg_spec: Any, name: str, errors: list[Any]) -> FunctionArgSpec: """ Check function arg spec is legitimate, returning a cleaned up version, and adding any errors to errors list. 
""" positional_args, keyword_args = arg_spec - if keyword_args is Any: + if keyword_args is AnyArg: cleaned_kwargs = keyword_args else: cleaned_kwargs = [] @@ -182,7 +197,7 @@ def sanitize_function_args(arg_spec, name, errors): return (positional_args, cleaned_kwargs) -def span_to_position(span, source_text): +def span_to_position(span: Span, source_text: str) -> tuple[int, int]: start = span.start relevant = source_text[0:start] row = relevant.count("\n") + 1 @@ -190,6 +205,6 @@ def span_to_position(span, source_text): return row, col -def display_location(filename, position): +def display_location(filename: str | None, position: tuple[int, int]) -> str: row, col = position return f"{filename if filename else ''}:{row}:{col}" diff --git a/tests/format/test_escapers.py b/tests/format/test_escapers.py index 840dc99..a23265b 100644 --- a/tests/format/test_escapers.py +++ b/tests/format/test_escapers.py @@ -1,38 +1,42 @@ import operator -import unittest from functools import reduce +from typing import Sequence +import pytest from bs4 import BeautifulSoup from markdown import markdown from markupsafe import Markup, escape from fluent_compiler.bundle import FluentBundle +from fluent_compiler.escapers import IsEscaper from ..utils import dedent_ftl +def assertTypeAndValueEqual(val1, val2): + assert val1 == val2 + assert type(val1) is type(val2) + + # An escaper for MarkupSafe with instrumentation so we can check behaviour class HtmlEscaper: name = "HtmlEscaper" output_type = Markup use_isolating = False - def __init__(self, test_case): - self.test_case = test_case - - def select(self, message_id=None, **hints): + def select(self, message_id: str, **kwargs: object): return message_id.endswith("-html") - def mark_escaped(self, escaped): - self.test_case.assertEqual(type(escaped), str) + def mark_escaped(self, escaped: str) -> Markup: + assert type(escaped) is str return Markup(escaped) - def escape(self, unescaped): + def escape(self, unescaped: str) -> Markup: return 
escape(unescaped) - def join(self, parts): + def join(self, parts: Sequence[Markup]) -> Markup: for p in parts: - self.test_case.assertEqual(type(p), Markup) + assert type(p) is Markup return Markup("").join(parts) @@ -79,15 +83,13 @@ def __init__(self, text): class MarkdownEscaper: name = "MarkdownEscaper" output_type = Markdown + use_isolating = False - def __init__(self, test_case): - self.test_case = test_case - - def select(self, message_id=None, **hints): + def select(self, message_id: str, **kwargs: object): return message_id.endswith("-md") def mark_escaped(self, escaped): - self.test_case.assertEqual(type(escaped), str) + assert type(escaped) is str return LiteralMarkdown(escaped) def escape(self, unescaped): @@ -98,371 +100,408 @@ def escape(self, unescaped): def join(self, parts): for p in parts: - self.test_case.assertTrue(isinstance(p, Markdown)) + assert isinstance(p, Markdown) return reduce(operator.add, parts, empty_markdown) -class TestHtmlEscaping(unittest.TestCase): - def setUp(self): - escaper = HtmlEscaper(self) +@pytest.fixture(scope="session") +def html_escaping_bundle() -> FluentBundle: + escaper: IsEscaper[Markup] = HtmlEscaper() - # A function that outputs '> ' that needs to be escaped. Part of the - # point of this is to ensure that escaping is being done at the correct - # point - it is no good to escape string input when it enters, it has to - # be done at the end of the formatting process. - def QUOTE(arg): - return "\n" + "\n".join(f"> {line}" for line in arg.split("\n")) + # A function that outputs '> ' that needs to be escaped. Part of the + # point of this is to ensure that escaping is being done at the correct + # point - it is no good to escape string input when it enters, it has to + # be done at the end of the formatting process. 
+ def QUOTE(arg): + return "\n" + "\n".join(f"> {line}" for line in arg.split("\n")) - self.bundle = FluentBundle.from_string( - "en-US", - dedent_ftl( - """ - not-html-message = x < y + return FluentBundle.from_string( + "en-US", + dedent_ftl( + """ + not-html-message = x < y - simple-html = This is great. + simple-html = This is great. - argument-html = This thing is called { $arg }. + argument-html = This thing is called { $arg }. - -term-html = Jack & Jill + -term-html = Jack & Jill - -term-plain = Jack & Jill + -term-plain = Jack & Jill - references-html-term-html = { -term-html } are great! + references-html-term-html = { -term-html } are great! - references-plain-term-html = { -term-plain } are great! + references-plain-term-html = { -term-plain } are great! - references-html-term-plain = { -term-html } are great! + references-html-term-plain = { -term-html } are great! - attribute-argument-html = A link to { $place } + attribute-argument-html = A link to { $place } - compound-message-html = A message about { $arg }. { argument-html } + compound-message-html = A message about { $arg }. { argument-html } - function-html = You said: { QUOTE($text) } + function-html = You said: { QUOTE($text) } - parent-plain = Some stuff - .attr-html = Some HTML stuff - .attr-plain = This & That + parent-plain = Some stuff + .attr-html = Some HTML stuff + .attr-plain = This & That - references-html-message-plain = Plain. { simple-html } + references-html-message-plain = Plain. { simple-html } - references-html-message-attr-plain = Plain. { parent-plain.attr-html } + references-html-message-attr-plain = Plain. { parent-plain.attr-html } - references-html-message-attr-html = HTML. { parent-plain.attr-html } + references-html-message-attr-html = HTML. { parent-plain.attr-html } - references-plain-message-attr-html = HTML. { parent-plain.attr-plain } + references-plain-message-attr-html = HTML. 
{ parent-plain.attr-plain } - -brand-plain = { $variant -> - [short] A&B - *[long] A & B - } + -brand-plain = { $variant -> + [short] A&B + *[long] A & B + } - -brand-html = { $variant -> - [superscript] CoolBrand2 - *[normal] CoolBrand2 - } + -brand-html = { $variant -> + [superscript] CoolBrand2 + *[normal] CoolBrand2 + } - references-html-variant-plain = { -brand-html(variant: "superscript") } is cool + references-html-variant-plain = { -brand-html(variant: "superscript") } is cool - references-html-variant-html = { -brand-html(variant: "superscript") } is cool + references-html-variant-html = { -brand-html(variant: "superscript") } is cool - references-plain-variant-plain = { -brand-plain(variant: "short") } is awesome + references-plain-variant-plain = { -brand-plain(variant: "short") } is awesome - references-plain-variant-html = { -brand-plain(variant: "short") } is awesome - """ - ), - use_isolating=True, - functions={"QUOTE": QUOTE}, - escapers=[escaper], - ) + references-plain-variant-html = { -brand-plain(variant: "short") } is awesome + """ + ), + use_isolating=True, + functions={"QUOTE": QUOTE}, + escapers=[escaper], + ) + + +def test_html_select_false(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("not-html-message") + assertTypeAndValueEqual(val, "x < y") - def assertTypeAndValueEqual(self, val1, val2): - self.assertEqual(val1, val2) - self.assertEqual(type(val1), type(val2)) - - def test_select_false(self): - val, errs = self.bundle.format("not-html-message") - self.assertTypeAndValueEqual(val, "x < y") - - def test_simple(self): - val, errs = self.bundle.format("simple-html") - self.assertTypeAndValueEqual(val, Markup("This is great.")) - self.assertEqual(errs, []) - - def test_argument_is_escaped(self): - val, errs = self.bundle.format("argument-html", {"arg": "Jack & Jill"}) - self.assertTypeAndValueEqual(val, Markup("This thing is called Jack & Jill.")) - self.assertEqual(errs, []) - - def 
test_argument_already_escaped(self): - val, errs = self.bundle.format("argument-html", {"arg": Markup("Jack")}) - self.assertTypeAndValueEqual(val, Markup("This thing is called Jack.")) - self.assertEqual(errs, []) - - def test_included_html_term(self): - val, errs = self.bundle.format("references-html-term-html") - self.assertTypeAndValueEqual(val, Markup("Jack & Jill are great!")) - self.assertEqual(errs, []) - - def test_included_plain_term(self): - val, errs = self.bundle.format("references-plain-term-html") - self.assertTypeAndValueEqual(val, Markup("Jack & Jill are great!")) - self.assertEqual(errs, []) - - def test_included_html_term_from_plain(self): - val, errs = self.bundle.format("references-html-term-plain") - self.assertTypeAndValueEqual(val, "\u2068-term-html\u2069 are great!") - self.assertEqual(type(errs[0]), TypeError) - - def test_compound_message(self): - val, errs = self.bundle.format("compound-message-html", {"arg": "Jack & Jill"}) - self.assertTypeAndValueEqual( - val, - Markup("A message about Jack & Jill. 
" "This thing is called Jack & Jill."), - ) - self.assertEqual(errs, []) - def test_function(self): - val, errs = self.bundle.format("function-html", {"text": "Jack & Jill"}) - self.assertTypeAndValueEqual(val, Markup("You said: \n> Jack & Jill")) - self.assertEqual(errs, []) +def test_html_simple(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("simple-html") + assertTypeAndValueEqual(val, Markup("This is great.")) + assert errs == [] - def test_plain_parent(self): - val, errs = self.bundle.format("parent-plain") - self.assertTypeAndValueEqual(val, "Some stuff") - self.assertEqual(errs, []) - def test_html_attribute(self): - val, errs = self.bundle.format("parent-plain.attr-html") - self.assertTypeAndValueEqual(val, Markup("Some HTML stuff")) - self.assertEqual(errs, []) +def test_html_argument_is_escaped(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("argument-html", {"arg": "Jack & Jill"}) + assertTypeAndValueEqual(val, Markup("This thing is called Jack & Jill.")) + assert errs == [] - def test_html_message_reference_from_plain(self): - val, errs = self.bundle.format("references-html-message-plain") - self.assertTypeAndValueEqual(val, "Plain. \u2068simple-html\u2069") - self.assertEqual(len(errs), 1) - self.assertEqual(type(errs[0]), TypeError) - # Message attr references - def test_html_message_attr_reference_from_plain(self): - val, errs = self.bundle.format("references-html-message-attr-plain") - self.assertTypeAndValueEqual(val, "Plain. 
\u2068parent-plain.attr-html\u2069") - self.assertEqual(len(errs), 1) - self.assertEqual(type(errs[0]), TypeError) +def test_html_argument_already_escaped(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("argument-html", {"arg": Markup("Jack")}) + assertTypeAndValueEqual(val, Markup("This thing is called Jack.")) + assert errs == [] - def test_html_message_attr_reference_from_html(self): - val, errs = self.bundle.format("references-html-message-attr-html") - self.assertTypeAndValueEqual(val, Markup("HTML. Some HTML stuff")) - self.assertEqual(errs, []) - def test_plain_message_attr_reference_from_html(self): - val, errs = self.bundle.format("references-plain-message-attr-html") - self.assertTypeAndValueEqual(val, Markup("HTML. This & That")) - self.assertEqual(errs, []) +def test_included_html_term(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("references-html-term-html") + assertTypeAndValueEqual(val, Markup("Jack & Jill are great!")) + assert errs == [] - # Term variant references - def test_html_variant_from_plain(self): - val, errs = self.bundle.format("references-html-variant-plain") - self.assertTypeAndValueEqual(val, "\u2068-brand-html\u2069 is cool") - self.assertEqual(len(errs), 1) - self.assertEqual(type(errs[0]), TypeError) - def test_html_variant_from_html(self): - val, errs = self.bundle.format("references-html-variant-html") - self.assertTypeAndValueEqual(val, Markup("CoolBrand2 is cool")) - self.assertEqual(errs, []) +def test_included_plain_term(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("references-plain-term-html") + assertTypeAndValueEqual(val, Markup("Jack & Jill are great!")) + assert errs == [] - def test_plain_variant_from_plain(self): - val, errs = self.bundle.format("references-plain-variant-plain") - self.assertTypeAndValueEqual(val, "\u2068A&B\u2069 is awesome") - self.assertEqual(errs, []) - def test_plain_variant_from_html(self): - val, 
errs = self.bundle.format("references-plain-variant-html") - self.assertTypeAndValueEqual(val, Markup("A&B is awesome")) - self.assertEqual(errs, []) +def test_included_html_term_from_plain(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("references-html-term-plain") + assertTypeAndValueEqual(val, "\u2068-term-html\u2069 are great!") + assert type(errs[0]) is TypeError + assert ( + errs[0].args[0] + == "Escaper HtmlEscaper for term -term-html cannot be used from calling context with null_escaper escaper" + ) - def test_use_isolating(self): - val, errs = self.bundle.format("attribute-argument-html", {"url": "http://example.com", "place": "A Place"}) - self.assertTypeAndValueEqual(val, Markup('A link to A Place')) +def test_html_compound_message(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("compound-message-html", {"arg": "Jack & Jill"}) + assertTypeAndValueEqual( + val, + Markup("A message about Jack & Jill. " "This thing is called Jack & Jill."), + ) + assert errs == [] -class TestMarkdownEscaping(unittest.TestCase): - maxDiff = None - def setUp(self): - escaper = MarkdownEscaper(self) +def test_html_function(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("function-html", {"text": "Jack & Jill"}) + assertTypeAndValueEqual(val, Markup("You said: \n> Jack & Jill")) + assert errs == [] - # This QUOTE function outputs Markdown that should not be removed. - def QUOTE(arg): - return Markdown("\n" + "\n".join(f"> {line}" for line in arg.split("\n"))) - self.bundle = FluentBundle.from_string( - "en-US", - dedent_ftl( - """ - not-md-message = **some text** +def test_html_plain_parent(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("parent-plain") + assertTypeAndValueEqual(val, "Some stuff") + assert errs == [] - simple-md = This is **great** - argument-md = This **thing** is called { $arg }. 
+def test_html_attribute(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("parent-plain.attr-html") + assertTypeAndValueEqual(val, Markup("Some HTML stuff")) + assert errs == [] - -term-md = **Jack** & __Jill__ - -term-plain = **Jack & Jill** +def test_html_message_reference_from_plain(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("references-html-message-plain") + assertTypeAndValueEqual(val, "Plain. \u2068simple-html\u2069") + assert len(errs) == 1 + assert type(errs[0]) is TypeError - term-md-ref-md = { -term-md } are **great!** - term-plain-ref-md = { -term-plain } are **great!** +# Message attr references +def test_html_message_attr_reference_from_plain(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("references-html-message-attr-plain") + assertTypeAndValueEqual(val, "Plain. \u2068parent-plain.attr-html\u2069") + assert len(errs) == 1 + assert type(errs[0]) is TypeError + assert ( + errs[0].args[0] + == "Escaper HtmlEscaper for message parent-plain.attr-html cannot be used from calling context with null_escaper escaper" + ) - embedded-argument-md = A [link to { $place }]({ $url }) - compound-message-md = A message about { $arg }. { argument-md } +def test_html_message_attr_reference_from_html(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("references-html-message-attr-html") + assertTypeAndValueEqual(val, Markup("HTML. Some HTML stuff")) + assert errs == [] - function-md = You said: { QUOTE($text) } - parent-plain = Some stuff - .attr-md = Some **Markdown** stuff - .attr-plain = This and **That** +def test_plain_message_attr_reference_from_html(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("references-plain-message-attr-html") + assertTypeAndValueEqual(val, Markup("HTML. This & That")) + assert errs == [] - references-md-message-plain = Plain. { simple-md } - references-md-attr-plain = Plain. 
{ parent-plain.attr-md } +# Term variant references +def test_html_variant_from_plain(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("references-html-variant-plain") + assertTypeAndValueEqual(val, "\u2068-brand-html\u2069 is cool") + assert len(errs) == 1 + assert type(errs[0]) is TypeError - references-md-attr-md = **Markdown**. { parent-plain.attr-md } - references-plain-attr-md = **Markdown**. { parent-plain.attr-plain } +def test_html_variant_from_html(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("references-html-variant-html") + assertTypeAndValueEqual(val, Markup("CoolBrand2 is cool")) + assert errs == [] - -brand-plain = { $variant -> - [short] *A&B* - *[long] *A & B* - } - -brand-md = { $variant -> - [bolded] CoolBrand **2** - *[normal] CoolBrand2 - } +def test_html_plain_variant_from_plain(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("references-plain-variant-plain") + assertTypeAndValueEqual(val, "\u2068A&B\u2069 is awesome") + assert errs == [] - references-md-variant-plain = { -brand-md(variant: "bolded") } is cool - references-md-variant-md = { -brand-md(variant: "bolded") } is cool +def test_plain_variant_from_html(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format("references-plain-variant-html") + assertTypeAndValueEqual(val, Markup("A&B is awesome")) + assert errs == [] - references-plain-variant-plain = { -brand-plain(variant: "short") } is awesome - references-plain-variant-md = { -brand-plain(variant: "short") } is awesome +def test_use_isolating(html_escaping_bundle: FluentBundle): + val, errs = html_escaping_bundle.format( + "attribute-argument-html", {"url": "http://example.com", "place": "A Place"} + ) + assertTypeAndValueEqual(val, Markup('A link to A Place')) + + +@pytest.fixture(scope="session") +def markdown_escaping_bundle() -> FluentBundle: + escaper: IsEscaper[Markdown] = MarkdownEscaper() + + # This QUOTE 
function outputs Markdown that should not be removed. + def QUOTE(arg): + return Markdown("\n" + "\n".join(f"> {line}" for line in arg.split("\n"))) + + return FluentBundle.from_string( + "en-US", + dedent_ftl( """ - ), - use_isolating=False, - functions={"QUOTE": QUOTE}, - escapers=[escaper], - ) + not-md-message = **some text** + + simple-md = This is **great** + + argument-md = This **thing** is called { $arg }. + + -term-md = **Jack** & __Jill__ + + -term-plain = **Jack & Jill** + + term-md-ref-md = { -term-md } are **great!** + + term-plain-ref-md = { -term-plain } are **great!** + + embedded-argument-md = A [link to { $place }]({ $url }) + + compound-message-md = A message about { $arg }. { argument-md } + + function-md = You said: { QUOTE($text) } + + parent-plain = Some stuff + .attr-md = Some **Markdown** stuff + .attr-plain = This and **That** + + references-md-message-plain = Plain. { simple-md } + + references-md-attr-plain = Plain. { parent-plain.attr-md } + + references-md-attr-md = **Markdown**. { parent-plain.attr-md } + + references-plain-attr-md = **Markdown**. 
{ parent-plain.attr-plain } + + -brand-plain = { $variant -> + [short] *A&B* + *[long] *A & B* + } + + -brand-md = { $variant -> + [bolded] CoolBrand **2** + *[normal] CoolBrand2 + } + + references-md-variant-plain = { -brand-md(variant: "bolded") } is cool + + references-md-variant-md = { -brand-md(variant: "bolded") } is cool + + references-plain-variant-plain = { -brand-plain(variant: "short") } is awesome + + references-plain-variant-md = { -brand-plain(variant: "short") } is awesome + """ + ), + use_isolating=False, + functions={"QUOTE": QUOTE}, + escapers=[escaper], + ) - def test_strip_markdown(self): - self.assertEqual( - StrippedMarkdown("**Some bolded** and __italic__ text"), - Markdown("Some bolded and italic text"), - ) - self.assertEqual( - StrippedMarkdown( - """ + +def test_strip_markdown(): + assert StrippedMarkdown("**Some bolded** and __italic__ text") == Markdown("Some bolded and italic text") + assert ( + StrippedMarkdown( + """ > A quotation > about something """ - ), - Markdown("\nA quotation\nabout something\n"), ) + == Markdown("\nA quotation\nabout something\n") + ) - def test_select_false(self): - val, errs = self.bundle.format("not-md-message") - self.assertEqual(val, "**some text**") - - def test_simple(self): - val, errs = self.bundle.format("simple-md") - self.assertEqual(val, Markdown("This is **great**")) - self.assertEqual(errs, []) - - def test_argument_is_escaped(self): - val, errs = self.bundle.format("argument-md", {"arg": "**Jack**"}) - self.assertEqual(val, Markdown("This **thing** is called Jack.")) - self.assertEqual(errs, []) - - def test_argument_already_escaped(self): - val, errs = self.bundle.format("argument-md", {"arg": Markdown("**Jack**")}) - self.assertEqual(val, Markdown("This **thing** is called **Jack**.")) - self.assertEqual(errs, []) - - def test_included_md(self): - val, errs = self.bundle.format("term-md-ref-md") - self.assertEqual(val, Markdown("**Jack** & __Jill__ are **great!**")) - self.assertEqual(errs, 
[]) - - def test_included_plain(self): - val, errs = self.bundle.format("term-plain-ref-md") - self.assertEqual(val, Markdown("Jack & Jill are **great!**")) - self.assertEqual(errs, []) - - def test_compound_message(self): - val, errs = self.bundle.format("compound-message-md", {"arg": "**Jack & Jill**"}) - self.assertEqual( - val, - Markdown("A message about Jack & Jill. " "This **thing** is called Jack & Jill."), - ) - self.assertEqual(errs, []) - - def test_function(self): - val, errs = self.bundle.format("function-md", {"text": "Jack & Jill"}) - self.assertEqual(val, Markdown("You said: \n> Jack & Jill")) - self.assertEqual(errs, []) - - def test_plain_parent(self): - val, errs = self.bundle.format("parent-plain") - self.assertEqual(val, "Some stuff") - self.assertEqual(errs, []) - - def test_md_attribute(self): - val, errs = self.bundle.format("parent-plain.attr-md") - self.assertEqual(val, Markdown("Some **Markdown** stuff")) - self.assertEqual(errs, []) - - def test_md_message_reference_from_plain(self): - val, errs = self.bundle.format("references-md-message-plain") - self.assertEqual(val, "Plain. simple-md") - self.assertEqual(len(errs), 1) - self.assertEqual(type(errs[0]), TypeError) - - def test_md_attr_reference_from_plain(self): - val, errs = self.bundle.format("references-md-attr-plain") - self.assertEqual(val, "Plain. parent-plain.attr-md") - self.assertEqual(len(errs), 1) - self.assertEqual(type(errs[0]), TypeError) - - def test_md_reference_from_md(self): - val, errs = self.bundle.format("references-md-attr-md") - self.assertEqual(val, Markdown("**Markdown**. Some **Markdown** stuff")) - self.assertEqual(errs, []) - - def test_plain_reference_from_md(self): - val, errs = self.bundle.format("references-plain-attr-md") - self.assertEqual(val, Markdown("**Markdown**. 
This and That")) - self.assertEqual(errs, []) - - def test_md_variant_from_plain(self): - val, errs = self.bundle.format("references-md-variant-plain") - self.assertEqual(val, "-brand-md is cool") - self.assertEqual(len(errs), 1) - self.assertEqual(type(errs[0]), TypeError) - - def test_md_variant_from_md(self): - val, errs = self.bundle.format("references-md-variant-md") - self.assertEqual(val, Markdown("CoolBrand **2** is cool")) - self.assertEqual(errs, []) - - def test_plain_variant_from_plain(self): - val, errs = self.bundle.format("references-plain-variant-plain") - self.assertEqual(val, "*A&B* is awesome") - self.assertEqual(errs, []) - - def test_plain_variant_from_md(self): - val, errs = self.bundle.format("references-plain-variant-md") - self.assertEqual(val, Markdown("A&B is awesome")) - self.assertEqual(errs, []) + +def test_md_select_false(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("not-md-message") + assert val == "**some text**" + + +def test_md_simple(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("simple-md") + assert val == Markdown("This is **great**") + assert errs == [] + + +def test_md_argument_is_escaped(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("argument-md", {"arg": "**Jack**"}) + assert val == Markdown("This **thing** is called Jack.") + assert errs == [] + + +def test_md_argument_already_escaped(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("argument-md", {"arg": Markdown("**Jack**")}) + assert val == Markdown("This **thing** is called **Jack**.") + assert errs == [] + + +def test_included_md(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("term-md-ref-md") + assert val == Markdown("**Jack** & __Jill__ are **great!**") + assert errs == [] + + +def test_included_plain(markdown_escaping_bundle: FluentBundle): + val, errs = 
markdown_escaping_bundle.format("term-plain-ref-md") + assert val == Markdown("Jack & Jill are **great!**") + assert errs == [] + + +def test_md_compound_message(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("compound-message-md", {"arg": "**Jack & Jill**"}) + assert val == Markdown("A message about Jack & Jill. " "This **thing** is called Jack & Jill.") + assert errs == [] + + +def test_md_function(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("function-md", {"text": "Jack & Jill"}) + assert val == Markdown("You said: \n> Jack & Jill") + assert errs == [] + + +def test_md_plain_parent(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("parent-plain") + assert val == "Some stuff" + assert errs == [] + + +def test_md_attribute(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("parent-plain.attr-md") + assert val == Markdown("Some **Markdown** stuff") + assert errs == [] + + +def test_md_message_reference_from_plain(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("references-md-message-plain") + assert val == "Plain. simple-md" + assert len(errs) == 1 + assert type(errs[0]) is TypeError + + +def test_md_attr_reference_from_plain(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("references-md-attr-plain") + assert val == "Plain. parent-plain.attr-md" + assert len(errs) == 1 + assert type(errs[0]) is TypeError + + +def test_md_reference_from_md(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("references-md-attr-md") + assert val == Markdown("**Markdown**. Some **Markdown** stuff") + assert errs == [] + + +def test_plain_reference_from_md(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("references-plain-attr-md") + assert val == Markdown("**Markdown**. 
This and That") + assert errs == [] + + +def test_md_variant_from_plain(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("references-md-variant-plain") + assert val == "-brand-md is cool" + assert len(errs) == 1 + assert type(errs[0]) is TypeError + + +def test_md_variant_from_md(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("references-md-variant-md") + assert val == Markdown("CoolBrand **2** is cool") + assert errs == [] + + +def test_md_plain_variant_from_plain(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("references-plain-variant-plain") + assert val == "*A&B* is awesome" + assert errs == [] + + +def test_plain_variant_from_md(markdown_escaping_bundle: FluentBundle): + val, errs = markdown_escaping_bundle.format("references-plain-variant-md") + assert val == Markdown("A&B is awesome") + assert errs == [] diff --git a/tests/test_utils.py b/tests/test_utils.py index 494f111..93fa13d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,7 @@ import unittest from fluent_compiler.errors import FluentFormatError -from fluent_compiler.utils import Any, inspect_function_args +from fluent_compiler.utils import AnyArg, inspect_function_args class TestInspectFunctionArgs(unittest.TestCase): @@ -11,16 +11,16 @@ def test_inspect_function_args_positional(self): self.assertEqual(inspect_function_args(lambda x, y: None, "name", []), (2, [])) def test_inspect_function_args_var_positional(self): - self.assertEqual(inspect_function_args(lambda *args: None, "name", []), (Any, [])) + self.assertEqual(inspect_function_args(lambda *args: None, "name", []), (AnyArg, [])) def test_inspect_function_args_keywords(self): self.assertEqual(inspect_function_args(lambda x, y=1, z=2: None, "name", []), (1, ["y", "z"])) def test_inspect_function_args_var_keywords(self): - self.assertEqual(inspect_function_args(lambda x, **kwargs: None, "name", []), (1, Any)) + 
self.assertEqual(inspect_function_args(lambda x, **kwargs: None, "name", []), (1, AnyArg)) def test_inspect_function_args_var_positional_plus_keywords(self): - self.assertEqual(inspect_function_args(lambda x, y=1, *args: None, "name", []), (Any, ["y"])) + self.assertEqual(inspect_function_args(lambda x, y=1, *args: None, "name", []), (AnyArg, ["y"])) def test_inspect_function_args_bad_keyword_args(self): def foo(): diff --git a/tools/benchmarks/README.md b/tools/benchmarks/README.md index f0c5920..6db24e2 100644 --- a/tools/benchmarks/README.md +++ b/tools/benchmarks/README.md @@ -1,22 +1,19 @@ -To run the benchmarks, do the following from this directory: - - $ pip install -r requirements.txt - -Then, run any of the benchmarks you want as scripts: +To run the benchmarks, after installing the project in the normal way +for development, run any of the benchmarks you want as scripts: $ ./runtime.py $ ./compiler.py You can also run them using py.test with extra args: - $ py.test --benchmark-warmup=on runtime.py -k interpolation + $ pytest --benchmark-warmup=on runtime.py -k interpolation The “plural form” tests are the cases where GNU gettext performs most favourably, partly because it uses a much simpler (and incorrect) function for deciding plural forms, while we use the more complex ones from CLDR. You can exclude those by doing: - $ py.test --benchmark-warmup=on runtime.py -k 'not plural' + $ pytest --benchmark-warmup=on runtime.py -k 'not plural' To profile the benchmark suite, we recommend py-spy as a good tool. 
Install py-spy: https://github.com/benfred/py-spy diff --git a/tools/benchmarks/generate_ftl_file.py b/tools/benchmarks/generate_ftl_file.py index 1eb03af..82c5f4a 100644 --- a/tools/benchmarks/generate_ftl_file.py +++ b/tools/benchmarks/generate_ftl_file.py @@ -1,12 +1,12 @@ import argparse import random +from dataclasses import dataclass -import attrs from fluent.syntax import serialize from fluent.syntax.ast import Comment, Identifier, Message, Pattern, Placeable, Resource, TextElement, VariableReference -@attrs.frozen +@dataclass(frozen=True) class ItemRatios: """ Represent the ratios of different items inside the generated ftl file @@ -16,7 +16,7 @@ class ItemRatios: comment: int -@attrs.frozen +@dataclass(frozen=True) class ElementCountRatios: """ Represent the ratios of the different count of elements within a pattern @@ -28,7 +28,7 @@ class ElementCountRatios: four: int -@attrs.frozen +@dataclass(frozen=True) class Config: filename: str num_items: int diff --git a/tools/benchmarks/requirements.txt b/tools/benchmarks/requirements.txt deleted file mode 100644 index ec19b3f..0000000 --- a/tools/benchmarks/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -pytest -pytest-benchmark -fluent.runtime==0.3 -fluent.syntax==0.17