From 57f6654662dcf0188ae4d9976e166d59addb019a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 13 Nov 2024 14:02:54 -0600 Subject: [PATCH] schedule: update types --- loopy/schedule/__init__.py | 85 ++++++++++++++++------------------- loopy/schedule/tree.py | 23 +++++----- loopy/transform/precompute.py | 5 ++- 3 files changed, 53 insertions(+), 60 deletions(-) diff --git a/loopy/schedule/__init__.py b/loopy/schedule/__init__.py index 1364be850..a9121de8c 100644 --- a/loopy/schedule/__init__.py +++ b/loopy/schedule/__init__.py @@ -27,20 +27,11 @@ import logging import sys +from collections.abc import Hashable, Iterator, Mapping, Sequence, Set from dataclasses import dataclass, replace from typing import ( TYPE_CHECKING, - AbstractSet, Any, - Dict, - FrozenSet, - Hashable, - Iterator, - Mapping, - Optional, - Sequence, - Set, - Tuple, TypeVar, ) @@ -155,7 +146,7 @@ class Barrier(ScheduleItem): def gather_schedule_block( schedule: Sequence[ScheduleItem], start_idx: int - ) -> Tuple[Sequence[ScheduleItem], int]: + ) -> tuple[Sequence[ScheduleItem], int]: assert isinstance(schedule[start_idx], BeginBlockItem) level = 0 @@ -176,7 +167,7 @@ def gather_schedule_block( def generate_sub_sched_items( schedule: Sequence[ScheduleItem], start_idx: int - ) -> Iterator[Tuple[int, ScheduleItem]]: + ) -> Iterator[tuple[int, ScheduleItem]]: if not isinstance(schedule[start_idx], BeginBlockItem): yield start_idx, schedule[start_idx] @@ -203,7 +194,7 @@ def generate_sub_sched_items( def get_insn_ids_for_block_at( schedule: Sequence[ScheduleItem], start_idx: int - ) -> FrozenSet[str]: + ) -> frozenset[str]: return frozenset( sub_sched_item.insn_id for i, sub_sched_item in generate_sub_sched_items( @@ -212,7 +203,7 @@ def get_insn_ids_for_block_at( def find_used_inames_within( - kernel: LoopKernel, sched_index: int) -> AbstractSet[str]: + kernel: LoopKernel, sched_index: int) -> set[str]: assert kernel.linearization is not None sched_item = kernel.linearization[sched_index] @@ -234,7 +225,7 @@ def find_used_inames_within( return result -def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]]: +def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, set[str]]: """Returns a dictionary mapping inames to other inames that are always nested with them. """ @@ -257,11 +248,11 @@ def find_loop_nest_with_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str] return result -def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[str]]: +def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, set[str]]: """Returns a dictionary mapping inames to other inames that are always nested around them. """ - result: Dict[str, Set[str]] = {} + result: dict[str, set[str]] = {} all_inames = kernel.all_inames() @@ -299,14 +290,14 @@ def find_loop_nest_around_map(kernel: LoopKernel) -> Mapping[str, AbstractSet[st def find_loop_insn_dep_map( kernel: LoopKernel, - loop_nest_with_map: Mapping[str, AbstractSet[str]], - loop_nest_around_map: Mapping[str, AbstractSet[str]] - ) -> Mapping[str, AbstractSet[str]]: + loop_nest_with_map: Mapping[str, Set[str]], + loop_nest_around_map: Mapping[str, Set[str]] + ) -> Mapping[str, set[str]]: """Returns a dictionary mapping inames to other instruction ids that need to be scheduled before the iname should be eligible for scheduling. """ - result: Dict[str, Set[str]] = {} + result: dict[str, set[str]] = {} from loopy.kernel.data import ConcurrentTag, IlpBaseTag for insn in kernel.instructions: @@ -372,7 +363,7 @@ def find_loop_insn_dep_map( def group_insn_counts(kernel: LoopKernel) -> Mapping[str, int]: - result: Dict[str, int] = {} + result: dict[str, int] = {} for insn in kernel.instructions: for grp in insn.groups: @@ -382,7 +373,7 @@ def group_insn_counts(kernel: LoopKernel) -> Mapping[str, int]: def gen_dependencies_except( - kernel: LoopKernel, insn_id: str, except_insn_ids: AbstractSet[str] + kernel: LoopKernel, insn_id: str, except_insn_ids: Set[str] ) -> Iterator[str]: insn = kernel.id_to_insn[insn_id] for dep_id in insn.depends_on: @@ -396,9 +387,9 @@ def gen_dependencies_except( def get_priority_tiers( - wanted: AbstractSet[int], - priorities: AbstractSet[Sequence[int]] - ) -> Iterator[AbstractSet[int]]: + wanted: Set[int], + priorities: Set[Sequence[int]] + ) -> Iterator[set[int]]: # Get highest priority tier candidates: These are the first inames # of all the given priority constraints candidates = set() @@ -677,24 +668,24 @@ class SchedulerState: order with instruction priorities as tie breaker. """ kernel: LoopKernel - loop_nest_around_map: Mapping[str, AbstractSet[str]] - loop_insn_dep_map: Mapping[str, AbstractSet[str]] + loop_nest_around_map: Mapping[str, set[str]] + loop_insn_dep_map: Mapping[str, set[str]] - breakable_inames: AbstractSet[str] - ilp_inames: AbstractSet[str] - vec_inames: AbstractSet[str] - concurrent_inames: AbstractSet[str] + breakable_inames: set[str] + ilp_inames: set[str] + vec_inames: set[str] + concurrent_inames: set[str] - insn_ids_to_try: Optional[AbstractSet[str]] + insn_ids_to_try: set[str] | None active_inames: Sequence[str] - entered_inames: FrozenSet[str] - enclosing_subkernel_inames: Tuple[str, ...] + entered_inames: frozenset[str] + enclosing_subkernel_inames: tuple[str, ...] schedule: Sequence[ScheduleItem] - scheduled_insn_ids: AbstractSet[str] - unscheduled_insn_ids: AbstractSet[str] + scheduled_insn_ids: frozenset[str] + unscheduled_insn_ids: set[str] preschedule: Sequence[ScheduleItem] - prescheduled_insn_ids: AbstractSet[str] - prescheduled_inames: AbstractSet[str] + prescheduled_insn_ids: set[str] + prescheduled_inames: set[str] may_schedule_global_barriers: bool within_subkernel: bool group_insn_counts: Mapping[str, int] @@ -702,7 +693,7 @@ class SchedulerState: insns_in_topologically_sorted_order: Sequence[InstructionBase] @property - def last_entered_loop(self) -> Optional[str]: + def last_entered_loop(self) -> str | None: if self.active_inames: return self.active_inames[-1] else: @@ -718,7 +709,7 @@ def get_insns_in_topologically_sorted_order( kernel: LoopKernel) -> Sequence[InstructionBase]: from pytools.graph import compute_topological_order - rev_dep_map: Dict[str, Set[str]] = { + rev_dep_map: dict[str, set[str]] = { not_none(insn.id): set() for insn in kernel.instructions} for insn in kernel.instructions: for dep in insn.depends_on: @@ -733,7 +724,7 @@ def get_insns_in_topologically_sorted_order( # Instead of returning these features as a key, we assign an id to # each set of features to avoid comparing them which can be expensive. insn_id_to_feature_id = {} - insn_features: Dict[Hashable, int] = {} + insn_features: dict[Hashable, int] = {} for insn in kernel.instructions: feature = (insn.within_inames, insn.groups, insn.conflicts_with_groups) if feature not in insn_features: @@ -890,7 +881,7 @@ def _get_outermost_diverging_inames( tree: LoopTree, within1: InameStrSet, within2: InameStrSet - ) -> Tuple[InameStr, InameStr]: + ) -> tuple[InameStr, InameStr]: """ For loop nestings *within1* and *within2*, returns the first inames at which the loops nests diverge in the loop nesting tree *tree*. @@ -2180,7 +2171,7 @@ def __init__(self, kernel): def generate_loop_schedules( kernel: LoopKernel, callables_table: CallablesTable, - debug_args: Optional[Dict[str, Any]] = None) -> Iterator[LoopKernel]: + debug_args: Mapping[str, Any] | None = None) -> Iterator[LoopKernel]: """ .. warning:: @@ -2236,7 +2227,7 @@ def _postprocess_schedule(kernel, callables_table, gen_sched): def _generate_loop_schedules_inner( kernel: LoopKernel, callables_table: CallablesTable, - debug_args: Optional[Dict[str, Any]]) -> Iterator[LoopKernel]: + debug_args: Mapping[str, Any] | None) -> Iterator[LoopKernel]: if debug_args is None: debug_args = {} @@ -2337,7 +2328,7 @@ def _generate_loop_schedules_inner( get_insns_in_topologically_sorted_order(kernel)), ) - schedule_gen_kwargs: Dict[str, Any] = {} + schedule_gen_kwargs: dict[str, Any] = {} def print_longest_dead_end(): if debug.interactive: @@ -2402,7 +2393,7 @@ def print_longest_dead_end(): schedule_cache: WriteOncePersistentDict[ - Tuple[LoopKernel, CallablesTable], + tuple[LoopKernel, CallablesTable], LoopKernel ] = WriteOncePersistentDict( "loopy-schedule-cache-v4-"+DATA_MODEL_VERSION, diff --git a/loopy/schedule/tree.py b/loopy/schedule/tree.py index 253ff5f84..e98724f83 100644 --- a/loopy/schedule/tree.py +++ b/loopy/schedule/tree.py @@ -34,9 +34,10 @@ THE SOFTWARE. """ +from collections.abc import Hashable, Iterator, Sequence from dataclasses import dataclass from functools import cached_property -from typing import Generic, Hashable, Iterator, List, Optional, Sequence, Tuple, TypeVar +from typing import Generic, TypeVar from immutables import Map @@ -70,11 +71,11 @@ class Tree(Generic[NodeT]): this allocates a new stack frame for each iteration of the operation. """ - _parent_to_children: Map[NodeT, Tuple[NodeT, ...]] - _child_to_parent: Map[NodeT, Optional[NodeT]] + _parent_to_children: Map[NodeT, tuple[NodeT, ...]] + _child_to_parent: Map[NodeT, NodeT | None] @staticmethod - def from_root(root: NodeT) -> "Tree[NodeT]": + def from_root(root: NodeT) -> Tree[NodeT]: return Tree(Map({root: ()}), Map({root: None})) @@ -89,7 +90,7 @@ def root(self) -> NodeT: return guess @memoize_method - def ancestors(self, node: NodeT) -> Tuple[NodeT, ...]: + def ancestors(self, node: NodeT) -> tuple[NodeT, ...]: """ Returns a :class:`tuple` of nodes that are ancestors of *node*. """ @@ -104,7 +105,7 @@ def ancestors(self, node: NodeT) -> Tuple[NodeT, ...]: return (parent,) + self.ancestors(parent) - def parent(self, node: NodeT) -> Optional[NodeT]: + def parent(self, node: NodeT) -> NodeT | None: """ Returns the parent of *node*. """ @@ -112,7 +113,7 @@ def parent(self, node: NodeT) -> Optional[NodeT]: return self._child_to_parent[node] - def children(self, node: NodeT) -> Tuple[NodeT, ...]: + def children(self, node: NodeT) -> tuple[NodeT, ...]: """ Returns the children of *node*. """ @@ -150,7 +151,7 @@ def __contains__(self, node: NodeT) -> bool: """Return *True* if *node* is a node in the tree.""" return node in self._child_to_parent - def add_node(self, node: NodeT, parent: NodeT) -> "Tree[NodeT]": + def add_node(self, node: NodeT, parent: NodeT) -> Tree[NodeT]: """ Returns a :class:`Tree` with added node *node* having a parent *parent*. @@ -165,7 +166,7 @@ def add_node(self, node: NodeT, parent: NodeT) -> "Tree[NodeT]": .set(node, ())), self._child_to_parent.set(node, parent)) - def replace_node(self, node: NodeT, new_node: NodeT) -> "Tree[NodeT]": + def replace_node(self, node: NodeT, new_node: NodeT) -> Tree[NodeT]: """ Returns a copy of *self* with *node* replaced with *new_node*. """ @@ -207,7 +208,7 @@ def replace_node(self, node: NodeT, new_node: NodeT) -> "Tree[NodeT]": return Tree(parent_to_children_mut.finish(), child_to_parent_mut.finish()) - def move_node(self, node: NodeT, new_parent: Optional[NodeT]) -> "Tree[NodeT]": + def move_node(self, node: NodeT, new_parent: NodeT | None) -> Tree[NodeT]: """ Returns a copy of *self* with node *node* as a child of *new_parent*. """ @@ -262,7 +263,7 @@ def __str__(self) -> str: ├── D └── E """ - def rec(node: NodeT) -> List[str]: + def rec(node: NodeT) -> list[str]: children_result = [rec(c) for c in self.children(node)] def post_process_non_last_child(children: Sequence[str]) -> list[str]: diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py index c2cd0a5ca..b0fbb5468 100644 --- a/loopy/transform/precompute.py +++ b/loopy/transform/precompute.py @@ -155,7 +155,8 @@ def storage_axis_exprs(storage_axis_sources, args) -> Sequence[ExpressionT]: # {{{ gather rule invocations class RuleInvocationGatherer(RuleAwareIdentityMapper): - def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within): + def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within) \ + -> None: super().__init__(rule_mapping_context) from loopy.symbolic import SubstitutionRuleExpander @@ -167,7 +168,7 @@ def __init__(self, rule_mapping_context, kernel, subst_name, subst_tag, within): self.subst_tag = subst_tag self.within = within - self.access_descriptors: List[RuleAccessDescriptor] = [] + self.access_descriptors: list[RuleAccessDescriptor] = [] def map_substitution(self, name, tag, arguments, expn_state): process_me = name == self.subst_name