Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add virtual arrays #1277

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
76 changes: 62 additions & 14 deletions src/coffea/nanoevents/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from coffea.nanoevents.schemas import BaseSchema, NanoAODSchema
from coffea.nanoevents.util import key_to_tuple, quote, tuple_to_key, unquote
from coffea.util import _remove_not_interpretable
from coffea.util import _remove_not_interpretable, deprecate

_offsets_label = quote(",!offsets")

Expand Down Expand Up @@ -206,6 +206,9 @@ def __call__(self, form):
return awkward.forms.form.from_dict(self.schemaclass(lform, self.version).form)


# Backends a NanoEventsFactory may be constructed with; any other ``mode``
# value is rejected with a ValueError by the factory.
allowed_modes = frozenset(("eager", "virtual", "dask"))


class NanoEventsFactory:
"""
A factory class to build NanoEvents objects.
Expand All @@ -214,8 +217,10 @@ class NanoEventsFactory:
the constructor args are properly set.
"""

def __init__(self, schema, mapping, partition_key, cache=None, is_dask=False):
self._is_dask = is_dask
def __init__(self, schema, mapping, partition_key, cache=None, mode="eager"):
if mode not in allowed_modes:
raise ValueError(f"Invalid mode {mode}, valid modes are {allowed_modes}")
self._mode = mode
self._schema = schema
self._mapping = mapping
self._partition_key = partition_key
Expand Down Expand Up @@ -252,7 +257,8 @@ def from_root(
access_log=None,
iteritems_options={},
use_ak_forth=True,
delayed=True,
delayed=True, # deprecated
mode=None, # mode takes precedence over delayed
known_base_form=None,
decompression_executor=None,
interpretation_executor=None,
Expand Down Expand Up @@ -291,6 +297,8 @@ def from_root(
Toggle using awkward_forth to interpret branches in root file.
delayed:
Nanoevents will use dask as a backend to construct a delayed task graph representing your analysis.
mode:
Nanoevents will use "eager", "virtual", or "dask" as a backend. 'mode' will take precedence over 'delayed'.
known_base_form:
If the base form of the input file is known ahead of time we can skip opening a single file and parsing metadata.
decompression_executor (None or Executor with a ``submit`` method):
Expand All @@ -313,7 +321,20 @@ def from_root(
"""
)

if delayed and steps_per_file is not uproot._util.unset:
if mode is None:
deprecate(
RuntimeError(
"The 'delayed' argument is deprecated, please use 'mode' instead. "
"If you are using 'delayed=True' to construct a dask graph, please use 'mode=dask'"
),
"<unknown>",
)
mode = "dask" if delayed else "virtual"

if mode not in allowed_modes:
raise ValueError(f"Invalid mode {mode}, valid modes are {allowed_modes}")

if mode == "dask" and steps_per_file is not uproot._util.unset:
warnings.warn(
f"""You have set steps_per_file to {steps_per_file}, this should only be used for a
small number of inputs (e.g. for early-stage/exploratory analysis) since it does not
Expand All @@ -326,7 +347,7 @@ def from_root(
)

if (
delayed
mode == "dask"
and not isinstance(schemaclass, FunctionType)
and schemaclass.__dask_capable__
):
Expand Down Expand Up @@ -355,12 +376,14 @@ def from_root(
**uproot_options,
)

return cls(map_schema, opener, None, cache=None, is_dask=True)
elif delayed and not schemaclass.__dask_capable__:
return cls(map_schema, opener, None, cache=None, mode="dask")
elif mode == "dask" and not schemaclass.__dask_capable__:
warnings.warn(
f"{schemaclass} is not dask capable despite requesting delayed mode, generating non-dask nanoevents",
RuntimeWarning,
)
# fall through to virtual mode
mode = "virtual"

if isinstance(file, uproot.reading.ReadOnlyDirectory):
tree = file[treepath]
Expand Down Expand Up @@ -390,6 +413,7 @@ def from_root(
cache={},
access_log=access_log,
use_ak_forth=use_ak_forth,
virtual=mode == "virtual",
)
mapping.preload_column_source(partition_key[0], partition_key[1], tree)

Expand All @@ -405,6 +429,7 @@ def from_root(
persistent_cache,
schemaclass,
metadata,
mode=mode,
)

@classmethod
Expand All @@ -421,7 +446,8 @@ def from_parquet(
parquet_options={},
skyhook_options={},
access_log=None,
delayed=True,
delayed=True, # deprecated
mode=None, # mode takes precedence over delayed
):
"""Quickly build NanoEvents from a parquet file

Expand Down Expand Up @@ -452,6 +478,8 @@ def from_parquet(
Pass a list instance to record which branches were lazily accessed by this instance
delayed:
Nanoevents will use dask as a backend to construct a delayed task graph representing your analysis.
mode:
Nanoevents will use "eager", "virtual", or "dask" as a backend. 'mode' will take precedence over 'delayed'.

Returns
-------
Expand All @@ -471,8 +499,21 @@ def from_parquet(
io.IOBase,
)

if mode is None:
deprecate(
RuntimeError(
"The 'delayed' argument is deprecated, please use 'mode' instead. "
"If you are using 'delayed=True' to construct a dask graph, please use 'mode=dask'"
),
"<unknown>",
)
mode = "dask" if delayed else "virtual"

if mode not in allowed_modes:
raise ValueError(f"Invalid mode {mode}, valid modes are {allowed_modes}")

if (
delayed
mode == "dask"
and not isinstance(schemaclass, FunctionType)
and schemaclass.__dask_capable__
):
Expand All @@ -490,8 +531,8 @@ def from_parquet(
)
else:
raise TypeError("Invalid file type (%s)" % (str(type(file))))
return cls(map_schema, opener, None, cache=None, is_dask=True)
elif delayed and not schemaclass.__dask_capable__:
return cls(map_schema, opener, None, cache=None, mode="dask")
elif mode == "dask" and not schemaclass.__dask_capable__:
warnings.warn(
f"{schemaclass} is not dask capable despite allowing dask, generating non-dask nanoevents"
)
Expand Down Expand Up @@ -528,6 +569,7 @@ def from_parquet(
entry_start,
entry_stop,
access_log=access_log,
virtual=mode == "virtual",
)

format_ = "parquet"
Expand Down Expand Up @@ -558,6 +600,7 @@ def from_parquet(
persistent_cache,
schemaclass,
metadata,
mode,
)

@classmethod
Expand Down Expand Up @@ -640,6 +683,7 @@ def from_preloaded(
persistent_cache,
schemaclass,
metadata,
mode="eager",
)

@classmethod
Expand All @@ -652,6 +696,7 @@ def _from_mapping(
persistent_cache,
schemaclass,
metadata,
mode,
):
"""Quickly build NanoEvents from a root file

Expand All @@ -674,6 +719,8 @@ def _from_mapping(
A schema class deriving from `BaseSchema` and implementing the desired view of the file
metadata : dict
Arbitrary metadata to add to the `base.NanoEvents` object
mode:
Nanoevents will use "eager", "virtual", or "dask" as a backend.

"""
if persistent_cache is not None:
Expand All @@ -690,7 +737,7 @@ def _from_mapping(
mapping,
tuple_to_key(partition_key),
cache=runtime_cache,
is_dask=False,
mode=mode,
)

def __len__(self):
Expand All @@ -711,7 +758,7 @@ def events(self):
If the factory is not running in delayed mode, this is an awkward
array of the events.
"""
if self._is_dask:
if self._mode == "dask":
events = self._mapping(form_mapping=self._schema)
report = None
if isinstance(events, tuple):
Expand All @@ -730,6 +777,7 @@ def events(self):
buffer_key=partial(_key_formatter, self._partition_key),
behavior=self._schema.behavior(),
attrs={"@events_factory": self},
allow_noncanonical_form=True,
)
self._events = weakref.ref(events)

Expand Down
120 changes: 67 additions & 53 deletions src/coffea/nanoevents/mapping/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import abstractmethod
from collections.abc import Mapping
from functools import partial

import numpy
from cachetools import LRUCache
Expand All @@ -21,14 +22,22 @@ class BaseSourceMapping(Mapping):
_debug = False

def __init__(
    self,
    fileopener,
    start,
    stop,
    cache=None,
    access_log=None,
    use_ak_forth=False,
    virtual=False,
):
    """Store the mapping configuration and run ``setup()``.

    Parameters
    ----------
    fileopener:
        Stored as ``self._fileopener``.
        NOTE(review): presumably a callable/handle used to open the
        column source -- confirm against the concrete subclasses.
    start, stop:
        Stored as ``self._start`` / ``self._stop``.
        NOTE(review): presumably the entry range served by this mapping;
        note that ``__getitem__`` takes its start/stop from the key itself.
    cache:
        Stored as ``self._cache`` (default ``None``).
    access_log:
        Optional list-like; when not ``None``, ``__getitem__`` appends the
        name of every column handle it loads.
    use_ak_forth:
        Forwarded to ``extract_column`` when a column is loaded.
    virtual:
        When true, ``__getitem__`` returns a zero-argument callable that
        materializes the buffer on demand instead of the buffer itself.
    """
    self._fileopener = fileopener
    self._cache = cache
    self._access_log = access_log
    self._start = start
    self._stop = stop
    self._use_ak_forth = use_ak_forth
    self._virtual = virtual
    # Subclass hook; must run after all attributes above are assigned.
    self.setup()

def setup(self):
Expand Down Expand Up @@ -75,61 +84,66 @@ def interpret_key(cls, key):
return uuid, treepath, start, stop, nodes

def __getitem__(self, key):
    """Evaluate the form key ``key`` into a bare numpy array.

    The key is decoded by ``interpret_key`` into
    ``(uuid, treepath, start, stop, nodes)``; ``nodes`` is then executed
    as a small stack program:

    * a plain node is pushed onto the stack;
    * ``"!skip"`` causes the node that follows it to be ignored;
    * a node starting with ``"!load"`` pops a handle name, records it in
      the access log (if one was supplied), fetches the column handle and
      pushes the extracted column for ``[start, stop)``;
      ``"!loadallowmissing"`` sets ``allow_missing`` for that load;
    * any other ``"!name"`` node calls ``transforms.name(stack)``.

    Exactly one item must remain on the stack; it is converted with
    ``awkward.to_numpy`` (for awkward contents) or ``numpy.array``.

    Returns
    -------
    The resulting array, or -- when the mapping was created with
    ``virtual=True`` -- a zero-argument callable (``functools.partial``)
    that performs the evaluation when invoked. In the virtual case no
    decoding, loading or validation happens until the callable is called.

    Raises
    ------
    RuntimeError
        On an unknown transform name, if the stack does not end with
        exactly one item, or if the result cannot be turned into a numpy
        array. In virtual mode these surface when the callable is invoked.
    """

    def _getitem(key):
        uuid, treepath, start, stop, nodes = self.interpret_key(key)
        if self._debug:
            print("Getting (", key, ") :", uuid, treepath, start, stop, nodes)
        stack = []
        skip = False
        for node in nodes:
            if skip:
                # The previous node was "!skip": drop this one.
                skip = False
                continue
            elif node == "!skip":
                skip = True
                continue
            elif node.startswith("!load"):
                handle_name = stack.pop()
                if self._access_log is not None:
                    self._access_log.append(handle_name)
                allow_missing = node == "!loadallowmissing"
                handle = self.get_column_handle(
                    self._column_source(uuid, treepath), handle_name, allow_missing
                )
                stack.append(
                    self.extract_column(
                        handle,
                        start,
                        stop,
                        allow_missing,
                        use_ak_forth=self._use_ak_forth,
                    )
                )
            elif node.startswith("!"):
                tname = node[1:]
                if not hasattr(transforms, tname):
                    raise RuntimeError(
                        f"Syntax error in form_key: no transform named {tname}"
                    )
                # Transforms operate on the stack in place.
                getattr(transforms, tname)(stack)
            else:
                stack.append(node)
        if len(stack) != 1:
            raise RuntimeError(f"Syntax error in form key {nodes}")
        out = stack.pop()
        import awkward

        if isinstance(out, awkward.contents.Content):
            out = awkward.to_numpy(out)
        else:
            try:
                out = numpy.array(out)
            except ValueError:
                if self._debug:
                    print(out)
                raise RuntimeError(
                    f"Left with non-bare array after evaluating form key {nodes}"
                )
        return out

    if self._virtual:
        # Defer all work: hand back a generator for the buffer.
        return partial(_getitem, key)
    return _getitem(key)

@abstractmethod
def __len__(self):
Expand Down
Loading
Loading