From f10fc96b5a734480054d8ba539d45ccc4d291658 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 12 Nov 2024 15:38:33 -0600 Subject: [PATCH] Fix tag_inames to apply multiple tags, type it --- loopy/transform/iname.py | 91 +++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 47 deletions(-) diff --git a/loopy/transform/iname.py b/loopy/transform/iname.py index 97257745c..1fe886422 100644 --- a/loopy/transform/iname.py +++ b/loopy/transform/iname.py @@ -21,10 +21,14 @@ """ +from collections.abc import Iterable, Mapping, Sequence from typing import Any, FrozenSet, Optional +from typing_extensions import TypeAlias + import islpy as isl from islpy import dim_type +from pytools.tag import Tag from loopy.diagnostic import LoopyError from loopy.kernel import LoopKernel @@ -675,9 +679,18 @@ def untag_inames(kernel, iname_to_untag, tag_type): # {{{ tag inames +_Tags_ish: TypeAlias = Tag | Sequence[Tag] | str | Sequence[str] + + @for_each_kernel -def tag_inames(kernel, iname_to_tag, force=False, - ignore_nonexistent=False): +def tag_inames( + kernel: LoopKernel, + iname_to_tag: (Mapping[str, _Tags_ish] + | Sequence[tuple[str, _Tags_ish]] + | str), + force: bool = False, + ignore_nonexistent: bool = False + ) -> LoopKernel: """Tag an iname :arg iname_to_tag: a list of tuples ``(iname, new_tag)``. *new_tag* is given @@ -697,74 +710,67 @@ def tag_inames(kernel, iname_to_tag, force=False, """ if isinstance(iname_to_tag, str): - def parse_kv(s): + def parse_kv(s: str) -> tuple[str, str]: colon_index = s.find(":") if colon_index == -1: raise ValueError("tag decl '%s' has no colon" % s) return (s[:colon_index].strip(), s[colon_index+1:].strip()) - iname_to_tag = [ + iname_to_tags_seq = [ parse_kv(s) for s in iname_to_tag.split(",") if s.strip()] + elif isinstance(iname_to_tag, Mapping): + iname_to_tags_seq = list(iname_to_tag.items()) + else: + iname_to_tags_seq = iname_to_tag if not iname_to_tag: return kernel - # convert dict to list of tuples - if isinstance(iname_to_tag, dict): - iname_to_tag = list(iname_to_tag.items()) - # flatten iterables of tags for each iname - try: - from collections.abc import Iterable - except ImportError: - from collections import Iterable # pylint:disable=no-name-in-module - - unpack_iname_to_tag = [] - for iname, tags in iname_to_tag: + unpack_iname_to_tag: list[tuple[str, Tag | str]] = [] + for iname, tags in iname_to_tags_seq: if isinstance(tags, Iterable) and not isinstance(tags, str): for tag in tags: unpack_iname_to_tag.append((iname, tag)) else: unpack_iname_to_tag.append((iname, tags)) - iname_to_tag = unpack_iname_to_tag from loopy.kernel.data import parse_tag as inner_parse_tag - def parse_tag(tag): + def parse_tag(tag: Tag | str) -> Iterable[Tag]: if isinstance(tag, str): if tag.startswith("like."): - tags = kernel.iname_tags(tag[5:]) - if len(tags) == 0: - return None - if len(tags) == 1: - return tags[0] - else: - raise LoopyError("cannot use like for multiple tags (for now)") + return kernel.iname_tags(tag[5:]) elif tag == "unused.g": return find_unused_axis_tag(kernel, "g") elif tag == "unused.l": return find_unused_axis_tag(kernel, "l") - return inner_parse_tag(tag) - - iname_to_tag = [(iname, parse_tag(tag)) for iname, tag in iname_to_tag] + result = inner_parse_tag(tag) + if result is None: + return [] + else: + return [result] - # {{{ globbing + iname_to_parsed_tag = [ + (iname, subtag) + for iname, tag in unpack_iname_to_tag + for subtag in parse_tag(tag) + ] + knl_inames = dict(kernel.inames) all_inames = kernel.all_inames() from loopy.match import re_from_glob - new_iname_to_tag = {} - for iname, new_tag in iname_to_tag: + + for iname, new_tag in iname_to_parsed_tag: if "*" in iname or "?" in iname: match_re = re_from_glob(iname) - for sub_iname in all_inames: - if match_re.match(sub_iname): - new_iname_to_tag[sub_iname] = new_tag - + inames = [sub_iname for sub_iname in all_inames + if match_re.match(sub_iname)] else: if iname not in all_inames: if ignore_nonexistent: @@ -772,22 +778,13 @@ def parse_tag(tag): else: raise LoopyError("iname '%s' does not exist" % iname) - new_iname_to_tag[iname] = new_tag - - iname_to_tag = new_iname_to_tag - del new_iname_to_tag + inames = [iname] - # }}} - - knl_inames = kernel.inames.copy() - for name, new_tag in iname_to_tag.items(): - if not new_tag: + if new_tag is None: continue - if name not in kernel.all_inames(): - raise ValueError("cannot tag '%s'--not known" % name) - - knl_inames[name] = knl_inames[name].tagged(new_tag) + for sub_iname in inames: + knl_inames[sub_iname] = knl_inames[sub_iname].tagged(new_tag) return kernel.copy(inames=knl_inames)