From f14db3d203343a49b535e7c04102bdd1d5f3deeb Mon Sep 17 00:00:00 2001 From: boschmitt <7152025+boschmitt@users.noreply.github.com> Date: Thu, 31 Oct 2024 16:10:16 +0100 Subject: [PATCH] Fix logic to find real module path With python device kernel interoperability, users can write quantum kernels in C++ and bind them to python. In such cases, the common pattern is to have a C++ module that gets imported into a python module. For example, if we have a python package named `foo` to which we add C++ extensions using pybind11. The common pattern is to end up with a with a module named `_cppfoo` (or whaterver). Then, we import all of its symbols to `foo`: foo/__init__.py: from ._cppfoo import * Now, if `_cppfoo` contains a binded device kernel named `bar`, then users are able to access it using `foo.bar(...)`. This, however, is not the real path of `bar`, the real path is `foo._cppfoo.bar(..)`. Currently, binded device kernels get registered with their real path name, and thus when the python AST bridge parse another kernel that uses `foo.bar(...)`, it needs to figure it out if that is its real path or not. This commit attemps to improve the robustness of discovering this real path because as-is it fails on some simple cases. This is how it works: In Python, many objects have a module attribute, which indicates the module in which the object was defined. This should be the case for functions. Thus the idea here is to walk the provide path until we reach the function object and ask it for its `__module__`. Signed-off-by: boschmitt <7152025+boschmitt@users.noreply.github.com> --- python/cudaq/kernel/ast_bridge.py | 98 +++++++++++++++------------- python/tests/interop/qlib.py | 10 +++ python/tests/interop/test_interop.py | 23 ++++++- 3 files changed, 84 insertions(+), 47 deletions(-) create mode 100644 python/tests/interop/qlib.py diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index 6f04773c3f..9d13d6350f 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -9,6 +9,7 @@ import hashlib import graphlib import sys, os +from types import ModuleType from typing import Callable from collections import deque import numpy as np @@ -1296,60 +1297,66 @@ def visit_Call(self, node): """ global globalRegisteredOperations - if self.verbose: - print("[Visit Call] {}".format( - ast.unparse(node) if hasattr(ast, 'unparse') else node)) + #if self.verbose: + print("[Visit Call] {}".format( + ast.unparse(node) if hasattr(ast, 'unparse') else node)) self.currentNode = node - # do not walk the FunctionDef decorator_list arguments if isinstance(node.func, ast.Attribute): - if hasattr( - node.func.value, 'id' - ) and node.func.value.id == 'cudaq' and node.func.attr == 'kernel': - return + # When `node.func` is an attribute, then we have the case where the + # call has the following form: `..<...>.`. + value = node.func.value + + # First, we walk all the components until we reach a name. + components = [node.func.attr] + while isinstance(value, ast.Attribute): + components.append(value.attr) + value = value.value + components.append(value.id) + components = components[::-1] - # If we have a `func = ast.Attribute``, then it could be that - # we have a previously defined kernel function call with manually specified module names - # e.g. `cudaq.lib.test.hello.fermionic_swap``. In this case, we assume - # FindDepKernels has found something like this, loaded it, and now we just - # want to get the function name and call it. + # Check whether this is our known decorator `@cudaq.kernel`. If it + # is then we gracefully ignore it. + if components[0] == 'cudaq' and components[1] == 'kernel': + return - # First let's check for registered C++ kernels - cppDevModNames = [] - value = node.func.value - if isinstance(value, ast.Name) and value.id != 'cudaq': - cppDevModNames = [node.func.attr, value.id] - else: - while isinstance(value, ast.Attribute): - cppDevModNames.append(value.attr) - value = value.value - if isinstance(value, ast.Name): - cppDevModNames.append(value.id) - break - - devKey = '.'.join(cppDevModNames[::-1]) - - def get_full_module_path(partial_path): - parts = partial_path.split('.') - for module_name, module in sys.modules.items(): - if module_name.endswith(parts[0]): - try: - obj = module - for part in parts[1:]: - obj = getattr(obj, part) - return f"{module_name}.{'.'.join(parts[1:])}" - except AttributeError: - continue - return partial_path - - devKey = get_full_module_path(devKey) - if cudaq_runtime.isRegisteredDeviceModule(devKey): + # Get full module path. + # + # Note: Here we skip anything that starts with `cudaq.` because not + # all constructs are backed by an python object. See issue # + modPath = "" + if components[0] != 'cudaq': + if components[0] in sys.modules: + module = sys.modules[components[0]] + obj = module + for attribute in components[1:]: + obj = getattr(obj, attribute) + if hasattr(obj, + '__module__') and obj.__module__ != obj.__name__: + modPath = obj.__module__ + else: + modPath = obj.__name__ + else: + import inspect + current_frame = inspect.currentframe() + mod = None + while current_frame is not None: + local_vars = current_frame.f_locals + if components[0] in local_vars: + mod = local_vars[components[0]] + break + current_frame = current_frame.f_back + + if isinstance(mod, ModuleType): + modPath = mod.__name__ + + if cudaq_runtime.isRegisteredDeviceModule(modPath): maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel( - self.module, devKey + '.' + node.func.attr) + self.module, modPath + '.' + node.func.attr) if maybeKernelName == None: maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel( - self.module, devKey) + self.module, modPath) if maybeKernelName != None: otherKernel = SymbolTable( self.module.operation)[maybeKernelName] @@ -1368,7 +1375,6 @@ def get_full_module_path(partial_path): func.CallOp(otherKernel, values) return - # Start by seeing if we have mod1.mod2.mod3... moduleNames = [] value = node.func.value while isinstance(value, ast.Attribute): diff --git a/python/tests/interop/qlib.py b/python/tests/interop/qlib.py new file mode 100644 index 0000000000..cea8b1861a --- /dev/null +++ b/python/tests/interop/qlib.py @@ -0,0 +1,10 @@ +# ============================================================================ # +# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. # +# ============================================================================ # + +from cudaq_test_cpp_algo import * + diff --git a/python/tests/interop/test_interop.py b/python/tests/interop/test_interop.py index 78d46e576e..0d52d8e4e7 100644 --- a/python/tests/interop/test_interop.py +++ b/python/tests/interop/test_interop.py @@ -222,6 +222,27 @@ def callUCCSD(): callUCCSD() +def test_cpp_kernel_from_python_3(): + + import qlib + + # Sanity checks + print(qlib.qstd.qft) + print(qlib.qstd.another) + + @cudaq.kernel + def callQftAndAnother(): + q = cudaq.qvector(4) + qlib.qstd.qft(q) + h(q) + qlib.qstd.another(q, 2) + + callQftAndAnother() + + counts = cudaq.sample(callQftAndAnother) + counts.dump() + assert len(counts) == 1 and '0010' in counts + def test_capture(): @cudaq.kernel def takesCapture(s : int): @@ -232,4 +253,4 @@ def takesCapture(s : int): @cudaq.kernel(verbose=True) def entry(): takesCapture(spin) - entry.compile() \ No newline at end of file + entry.compile()