Skip to content

Commit

Permalink
Revert "Revert commit 5d910b6 (#909)" (#910)
Browse files Browse the repository at this point in the history
**Context:** The so-called offending commit was not the actual cause of
CPL rc/rc/rc not passing. It was a problem with the Enzyme cache

**Description of the Change:** Revert the commit that reverts the commit

**Benefits:** The bug fixed by this commit will be fixed.

**Possible Drawbacks:** Not so clean commit history
  • Loading branch information
rauletorresc authored Jul 8, 2024
1 parent 2007d9d commit 0b8213d
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 2 deletions.
1 change: 1 addition & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
[(#822)](https://github.com/PennyLaneAI/catalyst/pull/822)
[(#834)](https://github.com/PennyLaneAI/catalyst/pull/834)
[(#882)](https://github.com/PennyLaneAI/catalyst/pull/882)
[(#907)](https://github.com/PennyLaneAI/catalyst/pull/907)

- When using callbacks that do not return any values, such as `catalyst.debug.callback` and
`catalyst.debug.print`, these functions are marked as 'inactive' and do not contribute to or
Expand Down
6 changes: 4 additions & 2 deletions frontend/catalyst/api_extensions/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,12 @@ class MemrefCallable(FlatCallable):

CACHE = {}

def __new__(cls, func, results_aval, *_args, **_kwargs):
def __new__(cls, func, results_aval, *args, **kwargs):
# Hash-cons: https://en.wikipedia.org/wiki/Hash_consing
absargs, abskwargs = tree_map(shaped_abstractify, (args, kwargs))
flat_params, _ = tree_flatten((absargs, abskwargs))
flat_results_aval, _ = tree_flatten(results_aval)
cache_key = (func, *flat_results_aval)
cache_key = (func, *flat_params, *flat_results_aval)
if cls.CACHE.get(cache_key):
return cls.CACHE.get(cache_key)

Expand Down
60 changes: 60 additions & 0 deletions frontend/test/lit/test_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2024 Xanadu Quantum Technologies Inc.

Check notice on line 1 in frontend/test/lit/test_callback.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/lit/test_callback.py#L1

Missing module docstring (missing-module-docstring)

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# RUN: %PYTHON %s | FileCheck %s

import pennylane as qml
from catalyst import pure_callback


def i(x):

Check notice on line 21 in frontend/test/lit/test_callback.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/lit/test_callback.py#L21

Missing function or method docstring (missing-function-docstring)
return x


# CHECK-LABEL: module @one_callback_cached
@qml.qjit
# CHECK-NOT: catalyst.callback @callback
# CHECK-LABEL: func.func public @jit_one_callback_cached
def one_callback_cached(x: float):
"""Single callback is created, but called twice"""
c = pure_callback(i, float)
return c(x), c(x)


# CHECK-LABEL: catalyst.callback @callback
# CHECK-NOT: catalyst.callback @callback
print(one_callback_cached.mlir)


@pure_callback
def always_return_float(x) -> float:

Check notice on line 41 in frontend/test/lit/test_callback.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/lit/test_callback.py#L41

Missing function or method docstring (missing-function-docstring)
if x == 0.0:
return x
else:
return x + 0.0


# CHECK-LABEL: module @test2
@qml.qjit
# CHECK-NOT: catalyst.callback @callback
# CHECK-LABEL: func.func public @jit_test2
def test2():

Check notice on line 52 in frontend/test/lit/test_callback.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/lit/test_callback.py#L52

Missing function or method docstring (missing-function-docstring)
return always_return_float(0.0), always_return_float(1)


# CHECK-LABEL: catalyst.callback @callback
# CHECK-LABEL: catalyst.callback @callback
# CHECK-NOT: catalyst.callback @callback

print(test2.mlir)

0 comments on commit 0b8213d

Please sign in to comment.