Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a few state-related cc ops #2354

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
85 changes: 85 additions & 0 deletions include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1397,4 +1397,89 @@ def CustomUnitarySymbolOp :
}];
}

//===----------------------------------------------------------------------===//
// Quantum states
//===----------------------------------------------------------------------===//

def quake_CreateStateOp : QuakeOp<"create_state", [Pure] > {
let summary = "Create state from data";
let description = [{
This operation takes a pointer to state data and creates a quantum state.
The operation can be optimized away in DeleteStates pass, or replaced
by an intrinsic runtime call on simulators.

```mlir
%0 = quake.create_state %data %len: !cc.ptr<!cc.state>
```
}];

let arguments = (ins
cc_PointerType:$data,
AnySignlessInteger:$length
);
let results = (outs cc_PointerType:$result);
let assemblyFormat = [{
$data `,` $length `:` functional-type(operands, results) attr-dict
}];
}

def QuakeOp_DeleteStateOp : QuakeOp<"delete_state", [] > {
let summary = "Delete quantum state";
let description = [{
This operation takes a pointer to the state and deletes the state object.
The operation can be created in in DeleteStates pass, and replaced later
by an intrinsic runtime call on simulators.

```mlir
quake.delete_state %state : (!cc.ptr<!cc.state>) -> ()
```
}];

let arguments = (ins cc_PointerType:$state);
let results = (outs);
let assemblyFormat = [{
$state `:` functional-type(operands, results) attr-dict
}];
}

def quake_GetNumberOfQubitsOp : QuakeOp<"get_number_of_qubits", [Pure] > {
let summary = "Get number of qubits from a quantum state";
let description = [{
This operation takes a state pointer argument and returns a number of
qubits in the state. The operation can be optimized away in some passes
line ReplaceStateByKernel or DeleteStates, or replaced by an intrinsic
runtime call on simulators.

```mlir
%0 = quake.get_number_of_qubits %state : (!cc.ptr<!cc.state>) -> i64
```
}];

let arguments = (ins cc_PointerType:$state);
let results = (outs AnySignlessInteger:$result);
let assemblyFormat = [{
$state `:` functional-type(operands, results) attr-dict
}];
}

def QuakeOp_GetStateOp : QuakeOp<"get_state", [Pure] > {
let summary = "Get state from kernel with the provided name.";
let description = [{
This operation is created by argument synthesis of state pointer arguments
for quantum devices. It takes a kernel name as ASCIIZ string literal value
and returns the kernel's quantum state. The operation is replaced by a call
to the kernel with the provided name in ReplaceStateByKernel pass.

```mlir
%0 = quake.get_state "callee" : !cc.ptr<!cc.state>
```
}];

let arguments = (ins StrAttr:$calleeName);
let results = (outs cc_PointerType:$result);
let assemblyFormat = [{
$calleeName `:` qualified(type(results)) attr-dict
}];
}

#endif // CUDAQ_OPTIMIZER_DIALECT_QUAKE_OPS
5 changes: 2 additions & 3 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -778,9 +778,8 @@ def DeleteStates : Pass<"delete-states", "mlir::ModuleOp"> {
func.func @foo() attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} {
%c8_i64 = arith.constant 8 : i64
%0 = cc.address_of @foo.rodata_synth_0 : !cc.ptr<!cc.array<complex<f32> x 8>>
%3 = cc.cast %0 : (!cc.ptr<!cc.array<complex<f32> x 8>>) -> !cc.ptr<i8>
%4 = call @__nvqpp_cudaq_state_createFromData_fp32(%3, %c8_i64) : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
%5 = call @__nvqpp_cudaq_state_numberOfQubits(%4) : (!cc.ptr<!cc.state>) -> i64
%4 = cc.create_state %3, %c8_i64 : (!cc.ptr<!cc.array<complex<f32> x 8>>, i64) -> !cc.ptr<!cc.state>
%5 = cc.get_number_of_qubits %4 : (!cc.ptr<!cc.state>) -> i64
%6 = quake.alloca !quake.veq<?>[%5 : i64]
%7 = quake.init_state %6, %4 : (!quake.veq<?>, !cc.ptr<!cc.state>) -> !quake.veq<?>

Expand Down
13 changes: 3 additions & 10 deletions lib/Frontend/nvqpp/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2697,19 +2697,12 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
initials = load.getPtrvalue();
}
if (isStateType(initials.getType())) {
IRBuilder irBuilder(builder.getContext());
auto mod =
builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
auto result =
irBuilder.loadIntrinsic(mod, getNumQubitsFromCudaqState);
assert(succeeded(result) && "loading intrinsic should never fail");
Value state = initials;
auto i64Ty = builder.getI64Type();
auto numQubits = builder.create<func::CallOp>(
loc, i64Ty, getNumQubitsFromCudaqState, ValueRange{state});
auto numQubits =
builder.create<quake::GetNumberOfQubitsOp>(loc, i64Ty, state);
auto veqTy = quake::VeqType::getUnsized(ctx);
Value alloc = builder.create<quake::AllocaOp>(loc, veqTy,
numQubits.getResult(0));
Value alloc = builder.create<quake::AllocaOp>(loc, veqTy, numQubits);
return pushValue(builder.create<quake::InitializeStateOp>(
loc, veqTy, alloc, state));
}
Expand Down
89 changes: 88 additions & 1 deletion lib/Optimizer/CodeGen/QuakeToCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

#include "QuakeToCodegen.h"
#include "CodeGenOps.h"
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/CodeGen/Passes.h"
#include "cudaq/Optimizer/CodeGen/QIRFunctionNames.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
Expand Down Expand Up @@ -62,10 +65,94 @@ class ExpandComplexCast : public OpRewritePattern<cudaq::cc::CastOp> {
return success();
}
};

class CreateStateOpPattern : public OpRewritePattern<quake::CreateStateOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::CreateStateOp createStateOp,
PatternRewriter &rewriter) const override {
auto module = createStateOp->getParentOfType<ModuleOp>();
auto loc = createStateOp.getLoc();
auto ctx = createStateOp.getContext();
auto buffer = createStateOp.getOperand(0);
auto size = createStateOp.getOperand(1);

auto bufferTy = buffer.getType();
auto ptrTy = cast<cudaq::cc::PointerType>(bufferTy);
auto arrTy = cast<cudaq::cc::ArrayType>(ptrTy.getElementType());
auto eleTy = arrTy.getElementType();
auto is64Bit = isa<Float64Type>(eleTy);

if (auto cTy = dyn_cast<ComplexType>(eleTy))
is64Bit = isa<Float64Type>(cTy.getElementType());

auto createStateFunc = is64Bit ? cudaq::createCudaqStateFromDataFP64
: cudaq::createCudaqStateFromDataFP32;
cudaq::IRBuilder irBuilder(ctx);
auto result = irBuilder.loadIntrinsic(module, createStateFunc);
assert(succeeded(result) && "loading intrinsic should never fail");

auto stateTy = cudaq::cc::StateType::get(ctx);
auto statePtrTy = cudaq::cc::PointerType::get(stateTy);
auto i8PtrTy = cudaq::cc::PointerType::get(rewriter.getI8Type());
auto cast = rewriter.create<cudaq::cc::CastOp>(loc, i8PtrTy, buffer);

rewriter.replaceOpWithNewOp<func::CallOp>(
createStateOp, statePtrTy, createStateFunc, ValueRange{cast, size});
return success();
}
};

class DeleteStateOpPattern : public OpRewritePattern<quake::DeleteStateOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::DeleteStateOp deleteStateOp,
PatternRewriter &rewriter) const override {
auto module = deleteStateOp->getParentOfType<ModuleOp>();
auto ctx = deleteStateOp.getContext();
auto state = deleteStateOp.getOperand();

cudaq::IRBuilder irBuilder(ctx);
auto result = irBuilder.loadIntrinsic(module, cudaq::deleteCudaqState);
assert(succeeded(result) && "loading intrinsic should never fail");

rewriter.replaceOpWithNewOp<func::CallOp>(deleteStateOp, std::nullopt,
cudaq::deleteCudaqState,
mlir::ValueRange{state});
return success();
}
};

class GetNumberOfQubitsOpPattern
: public OpRewritePattern<quake::GetNumberOfQubitsOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::GetNumberOfQubitsOp getNumQubitsOp,
PatternRewriter &rewriter) const override {
auto module = getNumQubitsOp->getParentOfType<ModuleOp>();
auto ctx = getNumQubitsOp.getContext();
auto state = getNumQubitsOp.getOperand();

cudaq::IRBuilder irBuilder(ctx);
auto result =
irBuilder.loadIntrinsic(module, cudaq::getNumQubitsFromCudaqState);
assert(succeeded(result) && "loading intrinsic should never fail");

rewriter.replaceOpWithNewOp<func::CallOp>(
getNumQubitsOp, rewriter.getI64Type(),
cudaq::getNumQubitsFromCudaqState, state);
return success();
}
};

} // namespace

void cudaq::codegen::populateQuakeToCodegenPatterns(
mlir::RewritePatternSet &patterns) {
auto *ctx = patterns.getContext();
patterns.insert<CodeGenRAIIPattern, ExpandComplexCast>(ctx);
patterns.insert<CodeGenRAIIPattern, ExpandComplexCast, CreateStateOpPattern,
DeleteStateOpPattern, GetNumberOfQubitsOpPattern>(ctx);
}
Loading
Loading