Skip to content

Commit

Permalink
Support measurements of subveqs for bracket (#2416)
Browse files Browse the repository at this point in the history
* Add missing passes for translation to OpenQasm2

Signed-off-by: Anna Gringauze <[email protected]>

* Fix failing tests

Signed-off-by: Anna Gringauze <[email protected]>

* Fix failing tests and format

Signed-off-by: Anna Gringauze <[email protected]>

* Address CR comments

Signed-off-by: Anna Gringauze <[email protected]>

* Address CR comments

Signed-off-by: Anna Gringauze <[email protected]>

* Fix failing tests

Signed-off-by: Anna Gringauze <[email protected]>

* Fix failing tests

Signed-off-by: Anna Gringauze <[email protected]>

* * Added Python tests for the 'braket' target

Signed-off-by: Pradnya Khalate <[email protected]>

* * Skip expand measurements pass
* Two more tests working

Signed-off-by: Pradnya Khalate <[email protected]>

* * Ignore classical operations in OpenQASM2.0 translation
* Additional tests

Co-authored-by: Eric Schweitz <[email protected]>

Signed-off-by: Pradnya Khalate <[email protected]>

* * Support U3 gate

Signed-off-by: Pradnya Khalate <[email protected]>

* * Test for other simulators
* Test for asynchronous sampling API
* Failing test for observe API
* More tests to cover all native gates, custom operations
* One more test - check kernel that takes arguments

Signed-off-by: Pradnya Khalate <[email protected]>

* * Remove the `combine-quantum-alloc` pass since multiple registers are
  not supported.
* Simplify test setup since mock server isn't being used.

Signed-off-by: Pradnya Khalate <[email protected]>

* * Clean-up test, use 'Amazon Braket' in messages

Signed-off-by: Pradnya Khalate <[email protected]>

* * Decomposition patterns for R1, CRz, Sdg

Co-authored-by: Bettina Heim <[email protected]>

Signed-off-by: Pradnya Khalate <[email protected]>

* Remove and add some passes to braket

Signed-off-by: Anna Gringauze <[email protected]>

* * Added Python tests for the 'braket' target

Signed-off-by: Pradnya Khalate <[email protected]>

* * Skip expand measurements pass
* Two more tests working

Signed-off-by: Pradnya Khalate <[email protected]>

* * Ignore classical operations in OpenQASM2.0 translation
* Additional tests

Co-authored-by: Eric Schweitz <[email protected]>

Signed-off-by: Pradnya Khalate <[email protected]>

* * Support U3 gate

Signed-off-by: Pradnya Khalate <[email protected]>

* * Test for other simulators
* Test for asynchronous sampling API
* Failing test for observe API
* More tests to cover all native gates, custom operations
* One more test - check kernel that takes arguments

Signed-off-by: Pradnya Khalate <[email protected]>

* * Remove the `combine-quantum-alloc` pass since multiple registers are
  not supported.
* Simplify test setup since mock server isn't being used.

Signed-off-by: Pradnya Khalate <[email protected]>

* * Clean-up test, use 'Amazon Braket' in messages

Signed-off-by: Pradnya Khalate <[email protected]>

* * Decomposition patterns for R1, CRz, Sdg

Co-authored-by: Bettina Heim <[email protected]>

Signed-off-by: Pradnya Khalate <[email protected]>

* * Addd `tdg` decomposition

Signed-off-by: Pradnya Khalate <[email protected]>

* * Control modifier fixes

Signed-off-by: Pradnya Khalate <[email protected]>

* * More tests

Signed-off-by: Pradnya Khalate <[email protected]>

* * Failing test for multiple measurement ops

Signed-off-by: Pradnya Khalate <[email protected]>

* Update lib/Optimizer/Transforms/DecompositionPatterns.cpp

Co-authored-by: Eric Schweitz <[email protected]>
Signed-off-by: Pradnya Khalate <[email protected]>

* Add combine-measurements pass

Signed-off-by: Anna Gringauze <[email protected]>

* * Remove the decomposition patterns for 'SAdjToSZ' and 'TAdjToR1' since
  the existing 'SToR1' and 'TToR1' are sufficient
* Disable 'R1toU3' pattern on all pipelines
* Clean -up comments

Signed-off-by: Pradnya Khalate <[email protected]>

* * Restore 'translateOperatorName' function, and the corresponding test

Signed-off-by: Pradnya Khalate <[email protected]>

* * Correct the comment about global pahse on R1ToU3
* Simpler command-line invocation in test

Signed-off-by: Pradnya Khalate <[email protected]>

* Made the tests work end to end

Signed-off-by: Anna Gringauze <[email protected]>

* Update translate tests

Signed-off-by: Anna Gringauze <[email protected]>

* support multimple qubit measurements

Signed-off-by: Anna Gringauze <[email protected]>

* Added tests

Signed-off-by: Anna Gringauze <[email protected]>

* Address CR comments

Signed-off-by: Anna Gringauze <[email protected]>

* Address CR comments and remove printing

Signed-off-by: Anna Gringauze <[email protected]>

* Address CR comments and remove printing

Signed-off-by: Anna Gringauze <[email protected]>

* Fix test failures and address CR comments

Signed-off-by: Anna Gringauze <[email protected]>

* Fix test failures

Signed-off-by: Anna Gringauze <[email protected]>

* Update combine-measurements after merging with main

Signed-off-by: Anna Gringauze <[email protected]>

* Address CR comments

Signed-off-by: Anna Gringauze <[email protected]>

* Format

Signed-off-by: Anna Gringauze <[email protected]>

* Address CR comments

* DCO Remediation Commit for Anna Gringauze <[email protected]>

I, Anna Gringauze <[email protected]>, hereby add my Signed-off-by to this commit: 3966ee3

Signed-off-by: Anna Gringauze <[email protected]>

---------

Signed-off-by: Anna Gringauze <[email protected]>
Signed-off-by: Pradnya Khalate <[email protected]>
Signed-off-by: Pradnya Khalate <[email protected]>
Co-authored-by: Pradnya Khalate <[email protected]>
Co-authored-by: Pradnya Khalate <[email protected]>
Co-authored-by: Eric Schweitz <[email protected]>
  • Loading branch information
4 people authored Nov 27, 2024
1 parent a4660c9 commit 1dea79e
Show file tree
Hide file tree
Showing 19 changed files with 674 additions and 37 deletions.
26 changes: 26 additions & 0 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,32 @@ def CheckKernelCalls : Pass<"check-kernel-calls", "mlir::func::FuncOp"> {
}];
}

def CombineMeasurements :
Pass<"combine-measurements", "mlir::func::FuncOp"> {
let summary = "Extends mesurements on subveqs adds output names";
let description = [{
Replace a pattern such as:
```
func.func @kernel() attributes {"cudaq-entrypoint"} {
%1 = ... : !quake.veq<4>
%2 = quake.subveq %1, %c2, %c3 : (!quake.veq<4>, i32, i32) ->
!quake.veq<2>
%measOut = quake.mz %2 : (!quake.veq<2>) -> !cc.stdvec<!quake.measure>
}
```
with:
```
func.func @kernel() attributes {"cudaq-entrypoint", ["output_names",
"[[[0,[1,\22q0\22]],[1,[2,\22q1\22]]]]"]} {
%1 = ... : !quake.veq<4>
%measOut = quake.mz %1 : (!quake.veq<4>) -> !cc.stdvec<!quake.measure>
}
```
}];
let dependentDialects = ["cudaq::cc::CCDialect", "quake::QuakeDialect"];
}


def CombineQuantumAllocations :
Pass<"combine-quantum-alloc", "mlir::func::FuncOp"> {
let summary = "Combines quake alloca operations.";
Expand Down
18 changes: 3 additions & 15 deletions lib/Optimizer/CodeGen/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,11 @@ void cudaq::opt::commonPipelineConvertToQIR(
}

void cudaq::opt::addPipelineTranslateToOpenQASM(PassManager &pm) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addNestedPass<func::FuncOp>(createClassicalMemToReg());
pm.addPass(createLoopUnroll());
pm.addPass(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createLiftArrayAlloc());
pm.addPass(createGlobalizeArrayValues());
pm.addPass(createStatePreparation());
pm.addNestedPass<func::FuncOp>(createGetConcreteMatrix());
pm.addPass(createUnitarySynthesis());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addPass(createSymbolDCEPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
pm.addNestedPass<func::FuncOp>(createMultiControlDecompositionPass());
pm.addPass(createDecompositionPass(
{.enabledPatterns = {"CCZToCX", "RxAdjToRx", "RyAdjToRy", "RzAdjToRz"}}));
pm.addPass(createCanonicalizerPass());
}

void cudaq::opt::addPipelineTranslateToIQMJson(PassManager &pm) {
Expand Down
1 change: 1 addition & 0 deletions lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_cudaq_library(OptTransforms
ApplyOpSpecialization.cpp
ArgumentSynthesis.cpp
BasisConversion.cpp
CombineMeasurements.cpp
CombineQuantumAlloc.cpp
ConstPropComplex.cpp
Decomposition.cpp
Expand Down
263 changes: 263 additions & 0 deletions lib/Optimizer/Transforms/CombineMeasurements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
/*******************************************************************************
* Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "PassDetails.h"
#include "cudaq/Optimizer/Builder/Factory.h"
#include "cudaq/Optimizer/CodeGen/QIRAttributeNames.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "nlohmann/json.hpp"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

namespace cudaq::opt {
#define GEN_PASS_DEF_COMBINEMEASUREMENTS
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
} // namespace cudaq::opt

#define DEBUG_TYPE "combine-measurements"

using namespace mlir;

namespace {

// After combine-quantum-alloc, we have one top allocation per function.
// The following type is used to store qubit mapping from result qubit
// index to the actual qubit index and register name.
// map[result] --> [qb,regName]
// Note: register name is currently not used in `OpenQasm2` backends,
// so we supply a bogus name.
using OutputNamesType =
std::map<std::size_t, std::pair<std::size_t, std::string>>;

struct Analysis {
Analysis() = default;
Analysis(const Analysis &) = delete;
Analysis(Analysis &&) = delete;
Analysis &operator=(const Analysis &) = delete;

mlir::DenseMap<mlir::Value, std::size_t> measurements;
OutputNamesType resultQubitVals;
quake::MzOp lastMeasurement;

bool empty() const { return measurements.empty(); }

LogicalResult analyze(func::FuncOp func) {
quake::AllocaOp qalloc;
std::size_t currentOffset = 0;

for (auto &block : func.getRegion()) {
for (auto &op : block) {
if (auto alloc = dyn_cast_or_null<quake::AllocaOp>(&op)) {
if (qalloc)
return op.emitError("Multiple qalloc statements found");

qalloc = alloc;
} else if (auto measure = dyn_cast_or_null<quake::MzOp>(&op)) {
if (!measure.use_empty()) {
measure.emitWarning("Measurements with uses are not supported");
return success();
}

auto veqOp = measure.getOperand(0);
auto ty = veqOp.getType();

std::size_t size = 0;
if (auto veqTy = dyn_cast<quake::RefType>(ty))
size = 1;
else if (auto veqTy = dyn_cast<quake::VeqType>(ty)) {
size = veqTy.getSize();
if (size == 0)
return op.emitError("Unknown measurement size");
}

measurements[measure.getMeasOut()] = currentOffset;
lastMeasurement = measure;
currentOffset += size;
}
}
}

return success();
}
};

class ExtendQubitMeasurePattern : public OpRewritePattern<quake::MzOp> {
public:
using OpRewritePattern::OpRewritePattern;

explicit ExtendQubitMeasurePattern(MLIRContext *ctx, Analysis &analysis)
: OpRewritePattern(ctx), analysis(analysis) {}

// Replace a pattern such as:
// ```
// %0 = ...: !quake.veq<2>
// %1 = quake.extract_ref %0[0] : (!quake.veq<2>) -> !quake.ref
// %measOut = quake.mz %1 : (!quake.ref) -> !quake.measure
// ```
// with:
// ```
// %1 = ... : !quake.veq<4>
// %measOut = quake.mz %1 : (!quake.veq<4>) -> !cc.stdvec<!quake.measure>
// ```
// And collect output names information: `"[[[0,[1,"q0"]],[1,[2,"q1"]]]]"`
LogicalResult matchAndRewrite(quake::MzOp measure,
PatternRewriter &rewriter) const override {

auto veqOp = measure.getOperand(0);
if (auto extract = veqOp.getDefiningOp<quake::ExtractRefOp>()) {
auto veq = extract.getVeq();
std::size_t idx;

if (extract.hasConstantIndex())
idx = extract.getConstantIndex();
else if (auto cst =
extract.getIndex().getDefiningOp<arith::ConstantIntOp>())
idx = static_cast<std::size_t>(cst.value());
else
return extract.emitError("Non-constant index in ExtractRef");

auto offset = analysis.measurements[measure.getMeasOut()];
analysis.resultQubitVals[offset] =
std::make_pair(idx, std::to_string(idx));

auto resultType = cudaq::cc::StdvecType::get(measure.getType(0));
if (measure == analysis.lastMeasurement)
rewriter.replaceOpWithNewOp<quake::MzOp>(measure, TypeRange{resultType},
ValueRange{veq},
measure.getRegisterNameAttr());
else if (measure.use_empty())
rewriter.eraseOp(measure);
}

return failure();
}

private:
Analysis &analysis;
};

class ExtendVeqMeasurePattern : public OpRewritePattern<quake::MzOp> {
public:
using OpRewritePattern::OpRewritePattern;

explicit ExtendVeqMeasurePattern(MLIRContext *ctx, Analysis &analysis)
: OpRewritePattern(ctx), analysis(analysis) {}

// Replace a pattern such as:
// ```
// %1 = ... : !quake.veq<4>
// %2 = quake.subveq %1, %c1, %c2 : (!quake.veq<4>, i32, i32) ->
// !quake.veq<2>
// %measOut = quake.mz %2 : (!quake.veq<2>) -> !cc.stdvec<!quake.measure>
// ```
// with:
// ```
// %1 = ... : !quake.veq<4>
// %measOut = quake.mz %1 : (!quake.veq<4>) -> !cc.stdvec<!quake.measure>
// ```
// And collect output names information: `"[[[0,[1,"q0"]],[1,[2,"q1"]]]]"`
LogicalResult matchAndRewrite(quake::MzOp measure,
PatternRewriter &rewriter) const override {
auto veqOp = measure.getOperand(0);
if (auto subveq = veqOp.getDefiningOp<quake::SubVeqOp>()) {
std::size_t low;
if (subveq.hasConstantLowerBound())
low = subveq.getConstantLowerBound();
else {
auto value = cudaq::opt::factory::getIntIfConstant(subveq.getLower());
if (!value.has_value())
return subveq.emitError("Non-constant lower index in subveq");
low = static_cast<std::size_t>(value.value());
}

std::size_t high;
if (subveq.hasConstantUpperBound())
high = subveq.getConstantUpperBound();
else {
auto value = cudaq::opt::factory::getIntIfConstant(subveq.getUpper());
if (!value.has_value())
return subveq.emitError("Non-constant upper index in subveq");
high = static_cast<std::size_t>(value.value());
}

for (std::size_t i = low; i <= high; i++) {
auto start = analysis.measurements[measure.getMeasOut()];
auto offset = i - low + start;
analysis.resultQubitVals[offset] = std::make_pair(i, std::to_string(i));
}

if (measure == analysis.lastMeasurement)
rewriter.replaceOpWithNewOp<quake::MzOp>(
measure, measure.getResultTypes(), ValueRange{subveq.getVeq()},
measure.getRegisterNameAttr());
else if (measure.use_empty())
rewriter.eraseOp(measure);

return success();
}

return failure();
}

private:
Analysis &analysis;
};

class CombineMeasurementsPass
: public cudaq::opt::impl::CombineMeasurementsBase<
CombineMeasurementsPass> {
public:
using CombineMeasurementsBase::CombineMeasurementsBase;

void runOnOperation() override {
auto *ctx = &getContext();
func::FuncOp func = getOperation();
OpBuilder builder(func);

LLVM_DEBUG(llvm::dbgs() << "Function before combining measurements:\n"
<< func << "\n\n");

// Analyze the function to find all qubit mappings.
Analysis analysis;
if (failed(analysis.analyze(func))) {
func.emitOpError("Combining measurements failed");
signalPassFailure();
}

if (analysis.empty())
return;

// Extend measurement into one last full measurement.
RewritePatternSet patterns(ctx);
patterns.insert<ExtendQubitMeasurePattern, ExtendVeqMeasurePattern>(
ctx, analysis);
if (failed(applyPatternsAndFoldGreedily(func.getOperation(),
std::move(patterns)))) {
func.emitOpError("Combining measurements failed");
signalPassFailure();
}

// Add output names mapping attribute.
if (!analysis.resultQubitVals.empty()) {
nlohmann::json resultQubitJSON{analysis.resultQubitVals};
func->setAttr(cudaq::opt::QIROutputNamesAttrName,
builder.getStringAttr(resultQubitJSON.dump()));
}

LLVM_DEBUG(llvm::dbgs() << "Function after combining measurements:\n"
<< func << "\n\n");
}
};
} // namespace
40 changes: 38 additions & 2 deletions python/runtime/cudaq/platform/py_alt_launch_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,8 +561,8 @@ MlirModule synthesizeKernel(const std::string &name, MlirModule module,
pm.addNestedPass<func::FuncOp>(cudaq::opt::createExpandMeasurementsPass());
pm.addNestedPass<func::FuncOp>(cudaq::opt::createClassicalMemToReg());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(cudaq::opt::createLoopNormalize());
pm.addNestedPass<func::FuncOp>(cudaq::opt::createLoopUnroll());
pm.addPass(cudaq::opt::createLoopNormalize());
pm.addPass(cudaq::opt::createLoopUnroll());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
DefaultTimingManager tm;
tm.setEnabled(cudaq::isTimingTagEnabled(cudaq::TIMING_JIT_PASSES));
Expand Down Expand Up @@ -634,10 +634,46 @@ std::string getASM(const std::string &name, MlirModule module,
auto cloned = unwrap(module).clone();
auto context = cloned.getContext();

// Get additional debug values
auto disableMLIRthreading = getEnvBool("CUDAQ_MLIR_DISABLE_THREADING", false);
auto enablePrintMLIREachPass =
getEnvBool("CUDAQ_MLIR_PRINT_EACH_PASS", false);

PassManager pm(context);
pm.addPass(cudaq::opt::createLambdaLiftingPass());
// Run most of the passes from hardware pipelines.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addNestedPass<func::FuncOp>(cudaq::opt::createClassicalMemToReg());
pm.addPass(cudaq::opt::createLoopNormalize());
pm.addPass(cudaq::opt::createLoopUnroll());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(cudaq::opt::createLiftArrayAlloc());
pm.addPass(cudaq::opt::createGlobalizeArrayValues());
pm.addPass(cudaq::opt::createStatePreparation());
pm.addNestedPass<func::FuncOp>(cudaq::opt::createGetConcreteMatrix());
pm.addPass(cudaq::opt::createUnitarySynthesis());
pm.addPass(cudaq::opt::createApplyOpSpecializationPass());
cudaq::opt::addAggressiveEarlyInlining(pm);
pm.addPass(createSymbolDCEPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createCSEPass());
pm.addNestedPass<func::FuncOp>(
cudaq::opt::createMultiControlDecompositionPass());
pm.addPass(cudaq::opt::createDecompositionPass(
{.enabledPatterns = {"SToR1", "TToR1", "R1ToU3", "U3ToRotations",
"CHToCX", "CCZToCX", "CRzToCX", "CRyToCX", "CRxToCX",
"CR1ToCX", "CCZToCX", "RxAdjToRx", "RyAdjToRy",
"RzAdjToRz"}}));
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(cudaq::opt::createExpandControlVeqs());
pm.addNestedPass<func::FuncOp>(cudaq::opt::createCombineQuantumAllocations());
cudaq::opt::addPipelineTranslateToOpenQASM(pm);

if (disableMLIRthreading || enablePrintMLIREachPass)
context->disableMultithreading();
if (enablePrintMLIREachPass)
pm.enableIRPrinting();
if (failed(pm.run(cloned)))
throw std::runtime_error("getASM: code generation failed.");
std::free(rawArgs);
Expand Down
Loading

0 comments on commit 1dea79e

Please sign in to comment.