Skip to content

Commit

Permalink
[core] Add checks to the CC dialect types.
Browse files Browse the repository at this point in the history
We do not allow CC dialect types to "contain" quantum types.

Fix some issues in the C++ bridge that were resulting in these
invalid hybrid types being created.

Signed-off-by: Eric Schweitz <[email protected]>
  • Loading branch information
schweitzpgi committed Mar 5, 2025
1 parent d77d6de commit 26042f2
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 6 deletions.
4 changes: 4 additions & 0 deletions include/cudaq/Optimizer/Dialect/CC/CCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def cc_StructType : CCType<"Struct", "struct",
);

let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;

let builders = [
TypeBuilder<(ins CArg<"llvm::StringRef">:$name,
Expand Down Expand Up @@ -154,6 +155,7 @@ def cc_ArrayType : CCType<"Array", "array"> {
);

let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;

let extraClassDeclaration = [{
using SizeType = std::int64_t;
Expand Down Expand Up @@ -249,6 +251,8 @@ def cc_StdVectorType : CCType<"Stdvec", "stdvec", [],

let assemblyFormat = "`<` qualified($elementType) `>`";

let genVerifyDecl = 1;

let builders = [
TypeBuilderWithInferredContext<(ins "mlir::Type":$elementType), [{
return Base::get(elementType.getContext(), elementType);
Expand Down
9 changes: 6 additions & 3 deletions lib/Frontend/nvqpp/ASTBridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,14 +430,17 @@ namespace cudaq::details {

bool QuakeBridgeVisitor::generateFunctionDeclaration(
StringRef funcName, const clang::FunctionDecl *x) {
auto loc = toLocation(x);
allowUnknownRecordType = true;
if (!TraverseType(x->getType()))
emitFatalError(loc, "failed to generate type for kernel function");
if (!TraverseType(x->getType())) {
reportClangError(x, mangler, "failed to generate type for kernel function");
typeStack.clear();
return false;
}
allowUnknownRecordType = false;
if (!doSyntaxChecks(x))
return false;
auto funcTy = cast<FunctionType>(popType());
auto loc = toLocation(x);
[[maybe_unused]] auto fnPair = getOrAddFunc(loc, funcName, funcTy);
assert(fnPair.first && "expected FuncOp to be created");
if (!isa<clang::CXXMethodDecl>(x) || x->isStatic())
Expand Down
10 changes: 9 additions & 1 deletion lib/Frontend/nvqpp/ConvertDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,15 @@ bool QuakeBridgeVisitor::interceptRecordDecl(clang::RecordDecl *x) {
// Traverse template argument 0 to get the vector's element type.
if (!cts || !TraverseType(cts->getTemplateArgs()[0].getAsType()))
return false;
return pushType(cc::StdvecType::get(ctx, popType()));
auto ty = popType();
if (quake::isQuantumType(ty)) {
if (ty == quake::RefType::get(ctx))
return pushType(quake::VeqType::getUnsized(ctx));
cudaq::emitFatalError(toLocation(x->getSourceRange()),
"std::vector element type is not supported");
return false;
}
return pushType(cc::StdvecType::get(ctx, ty));
}
// std::vector<bool> => cc.stdvec<i1>
if (name.equals("_Bit_reference") || name.equals("__bit_reference")) {
Expand Down
24 changes: 22 additions & 2 deletions lib/Frontend/nvqpp/ConvertType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,21 @@ bool QuakeBridgeVisitor::VisitRecordDecl(clang::RecordDecl *x) {
SmallVector<Type> fieldTys =
lastTypes(std::distance(x->field_begin(), x->field_end()));
auto [width, alignInBytes] = getWidthAndAlignment(x);

// This is a struq if it is not empty and all members are quantum references.
bool isStruq = !fieldTys.empty();
for (auto ty : fieldTys)
bool quantumMembers = false;
for (auto ty : fieldTys) {
if (quake::isQuantumType(ty))
quantumMembers = true;
if (!quake::isQuantumReferenceType(ty))
isStruq = false;
}
if (quantumMembers && !isStruq) {
reportClangError(x, mangler,
"hybrid quantum-classical struct types are not allowed");
return false;
}

auto ty = [&]() -> Type {
if (isStruq)
Expand Down Expand Up @@ -458,7 +469,16 @@ bool QuakeBridgeVisitor::VisitRValueReferenceType(

bool QuakeBridgeVisitor::VisitConstantArrayType(clang::ConstantArrayType *t) {
auto size = t->getSize().getZExtValue();
return pushType(cc::ArrayType::get(builder.getContext(), popType(), size));
auto ty = popType();
if (quake::isQuantumType(ty)) {
auto *ctx = builder.getContext();
if (ty == quake::RefType::get(ctx))
return pushType(quake::VeqType::getUnsized(ctx));
emitFatalError(builder.getUnknownLoc(),
"array element type is not supported");
return false;
}
return pushType(cc::ArrayType::get(builder.getContext(), ty, size));
}

bool QuakeBridgeVisitor::pushType(Type t) {
Expand Down
29 changes: 29 additions & 0 deletions lib/Optimizer/Dialect/CC/CCTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "cudaq/Optimizer/Dialect/CC/CCTypes.h"
#include "cudaq/Optimizer/Dialect/CC/CCDialect.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeTypes.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
Expand Down Expand Up @@ -103,6 +104,16 @@ cc::StructType::getPreferredAlignment(const DataLayout &dataLayout,
return getAlignment();
}

LogicalResult
cc::StructType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::StringAttr, llvm::ArrayRef<mlir::Type> members,
bool, bool, unsigned long, unsigned int) {
for (auto ty : members)
if (quake::isQuantumType(ty))
return emitError() << "cc.struct may not contain quake types: " << ty;
return success();
}

//===----------------------------------------------------------------------===//
// ArrayType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -139,6 +150,24 @@ void cc::ArrayType::print(AsmPrinter &printer) const {
printer << '>';
}

LogicalResult
cc::ArrayType::verify(function_ref<InFlightDiagnostic()> emitError, Type eleTy,
long) {
if (quake::isQuantumType(eleTy))
return emitError() << "cc.array may not have a quake element type: "
<< eleTy;
return success();
}

LogicalResult
cc::StdvecType::verify(function_ref<InFlightDiagnostic()> emitError,
Type eleTy) {
if (quake::isQuantumType(eleTy))
return emitError() << "cc.stdvec may not have a quake element type: "
<< eleTy;
return success();
}

} // namespace cudaq

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions test/AST-error/struct_quantum_and_classical.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct test {

__qpu__ void hello(cudaq::qubit &q) { h(q); }

// expected-error@+1 {{failed to generate type for kernel function}}
__qpu__ void kernel(test t) {
h(t.q);
hello(t.q[0]);
Expand Down

0 comments on commit 26042f2

Please sign in to comment.