Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate SPV_NV_cooperative_vector #5972

Merged
merged 1 commit into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DEPS
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ vars = {

're2_revision': '6dcd83d60f7944926bfd308cc13979fc53dd69ca',

'spirv_headers_revision': '2b2e05e088841c63c0b6fd4c9fb380d8688738d3',
'spirv_headers_revision': '767e901c986e9755a17e7939b3046fc2911a4bbd',
}

deps = {
Expand Down
3 changes: 3 additions & 0 deletions include/spirv-tools/libspirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,9 @@ typedef enum spv_operand_type_t {
SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS,
SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS,

SPV_OPERAND_TYPE_COOPERATIVE_VECTOR_MATRIX_LAYOUT,
SPV_OPERAND_TYPE_COMPONENT_TYPE,

// This is a sentinel value, and does not represent an operand type.
// It should come last.
SPV_OPERAND_TYPE_NUM_OPERAND_TYPES,
Expand Down
2 changes: 2 additions & 0 deletions source/opcode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ int32_t spvOpcodeIsComposite(const spv::Op opcode) {
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
return true;
default:
return false;
Expand Down Expand Up @@ -381,6 +382,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
case spv::Op::OpTypeAccelerationStructureNV:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
// case spv::Op::OpTypeAccelerationStructureKHR: covered by
// spv::Op::OpTypeAccelerationStructureNV
case spv::Op::OpTypeRayQueryKHR:
Expand Down
6 changes: 6 additions & 0 deletions source/operand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ const char* spvOperandTypeStr(spv_operand_type_t type) {
return "quantization mode";
case SPV_OPERAND_TYPE_OVERFLOW_MODES:
return "overflow mode";
case SPV_OPERAND_TYPE_COOPERATIVE_VECTOR_MATRIX_LAYOUT:
return "cooperative vector matrix layout";
case SPV_OPERAND_TYPE_COMPONENT_TYPE:
return "component type";

case SPV_OPERAND_TYPE_NONE:
return "NONE";
Expand Down Expand Up @@ -399,6 +403,8 @@ bool spvOperandIsConcrete(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_NAMED_MAXIMUM_NUMBER_OF_REGISTERS:
case SPV_OPERAND_TYPE_FPENCODING:
case SPV_OPERAND_TYPE_TENSOR_CLAMP_MODE:
case SPV_OPERAND_TYPE_COOPERATIVE_VECTOR_MATRIX_LAYOUT:
case SPV_OPERAND_TYPE_COMPONENT_TYPE:
return true;
default:
break;
Expand Down
5 changes: 5 additions & 0 deletions source/opt/eliminate_dead_members_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForExtract(
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
Expand Down Expand Up @@ -255,6 +256,7 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForAccessChain(
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
Expand Down Expand Up @@ -516,6 +518,7 @@ bool EliminateDeadMembersPass::UpdateAccessChain(Instruction* inst) {
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
new_operands.emplace_back(inst->GetInOperand(i));
type_id = type_inst->GetSingleWordInOperand(0);
break;
Expand Down Expand Up @@ -591,6 +594,7 @@ bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) {
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
Expand Down Expand Up @@ -654,6 +658,7 @@ bool EliminateDeadMembersPass::UpdateCompositeInsert(Instruction* inst) {
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
Expand Down
5 changes: 4 additions & 1 deletion source/opt/folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {

// Returns the element width of |type|.
uint32_t ElementWidth(const analysis::Type* type) {
if (const analysis::Vector* vec_type = type->AsVector()) {
if (const analysis::CooperativeVectorNV* coopvec_type =
type->AsCooperativeVectorNV()) {
return ElementWidth(coopvec_type->component_type());
} else if (const analysis::Vector* vec_type = type->AsVector()) {
return ElementWidth(vec_type->element_type());
} else if (const analysis::Float* float_type = type->AsFloat()) {
return float_type->width();
Expand Down
26 changes: 26 additions & 0 deletions source/opt/type_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,20 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
0, id, operands);
break;
}
case Type::kCooperativeVectorNV: {
auto coop_vec = type->AsCooperativeVectorNV();
uint32_t const component_type =
GetTypeInstruction(coop_vec->component_type());
if (component_type == 0) {
return 0;
}
typeInst = MakeUnique<Instruction>(
context(), spv::Op::OpTypeCooperativeVectorNV, 0, id,
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_ID, {component_type}},
{SPV_OPERAND_TYPE_ID, {coop_vec->components()}}});
break;
}
default:
assert(false && "Unexpected type");
break;
Expand Down Expand Up @@ -721,6 +735,14 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
tv_type->dim_id(), tv_type->has_dimensions_id(), tv_type->perm());
break;
}
case Type::kCooperativeVectorNV: {
const CooperativeVectorNV* cv_type = type.AsCooperativeVectorNV();
const Type* component_type = cv_type->component_type();
rebuilt_ty = MakeUnique<CooperativeVectorNV>(
RebuildType(GetId(component_type), *component_type),
cv_type->components());
break;
}
default:
assert(false && "Unhandled type");
return nullptr;
Expand Down Expand Up @@ -970,6 +992,10 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
inst.GetSingleWordInOperand(1), inst.GetSingleWordInOperand(2),
inst.GetSingleWordInOperand(3), inst.GetSingleWordInOperand(4));
break;
case spv::Op::OpTypeCooperativeVectorNV:
type = new CooperativeVectorNV(GetType(inst.GetSingleWordInOperand(0)),
inst.GetSingleWordInOperand(1));
break;
case spv::Op::OpTypeRayQueryKHR:
type = new RayQueryKHR();
break;
Expand Down
32 changes: 32 additions & 0 deletions source/opt/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ std::unique_ptr<Type> Type::Clone() const {
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(CooperativeVectorNV);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
#undef DeclareKindCase
Expand Down Expand Up @@ -181,6 +182,7 @@ bool Type::operator==(const Type& other) const {
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(CooperativeVectorNV);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
DeclareKindCase(TensorLayoutNV);
Expand Down Expand Up @@ -240,6 +242,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(CooperativeVectorNV);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
DeclareKindCase(TensorLayoutNV);
Expand Down Expand Up @@ -835,6 +838,35 @@ bool TensorViewNV::IsSameImpl(const Type* that, IsSameCache*) const {
has_dimensions_id_ == tv->has_dimensions_id_ && perm_ == tv->perm_;
}

CooperativeVectorNV::CooperativeVectorNV(const Type* type,
const uint32_t components)
: Type(kCooperativeVectorNV),
component_type_(type),
components_(components) {
assert(type != nullptr);
assert(components != 0);
}

std::string CooperativeVectorNV::str() const {
std::ostringstream oss;
oss << "<" << component_type_->str() << ", " << components_ << ">";
return oss.str();
}

size_t CooperativeVectorNV::ComputeExtraStateHash(size_t hash,
SeenTypes* seen) const {
hash = hash_combine(hash, components_);
return component_type_->ComputeHashValue(hash, seen);
}

bool CooperativeVectorNV::IsSameImpl(const Type* that,
IsSameCache* seen) const {
const CooperativeVectorNV* mt = that->AsCooperativeVectorNV();
if (!mt) return false;
return component_type_->IsSameImpl(mt->component_type_, seen) &&
components_ == mt->components_ && HasSameDecorations(that);
}

} // namespace analysis
} // namespace opt
} // namespace spvtools
27 changes: 27 additions & 0 deletions source/opt/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class NamedBarrier;
class AccelerationStructureNV;
class CooperativeMatrixNV;
class CooperativeMatrixKHR;
class CooperativeVectorNV;
class RayQueryKHR;
class HitObjectNV;
class TensorLayoutNV;
Expand Down Expand Up @@ -108,6 +109,7 @@ class Type {
kAccelerationStructureNV,
kCooperativeMatrixNV,
kCooperativeMatrixKHR,
kCooperativeVectorNV,
kRayQueryKHR,
kHitObjectNV,
kTensorLayoutNV,
Expand Down Expand Up @@ -213,6 +215,7 @@ class Type {
DeclareCastMethod(AccelerationStructureNV)
DeclareCastMethod(CooperativeMatrixNV)
DeclareCastMethod(CooperativeMatrixKHR)
DeclareCastMethod(CooperativeVectorNV)
DeclareCastMethod(RayQueryKHR)
DeclareCastMethod(HitObjectNV)
DeclareCastMethod(TensorLayoutNV)
Expand Down Expand Up @@ -742,6 +745,30 @@ class TensorViewNV : public Type {
std::vector<uint32_t> perm_;
};

class CooperativeVectorNV : public Type {
public:
CooperativeVectorNV(const Type* type, const uint32_t components);
CooperativeVectorNV(const CooperativeVectorNV&) = default;

std::string str() const override;

CooperativeVectorNV* AsCooperativeVectorNV() override { return this; }
const CooperativeVectorNV* AsCooperativeVectorNV() const override {
return this;
}

size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;

const Type* component_type() const { return component_type_; }
uint32_t components() const { return components_; }

private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;

const Type* component_type_;
const uint32_t components_;
};

#define DefineParameterlessType(type, name) \
class type : public Type { \
public: \
Expand Down
61 changes: 54 additions & 7 deletions source/val/validate_arithmetics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,34 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
bool supportsCoopMat =
(opcode != spv::Op::OpFMul && opcode != spv::Op::OpFRem &&
opcode != spv::Op::OpFMod);
bool supportsCoopVec =
(opcode != spv::Op::OpFRem && opcode != spv::Op::OpFMod);
if (!_.IsFloatScalarType(result_type) &&
!_.IsFloatVectorType(result_type) &&
!(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)) &&
!(opcode == spv::Op::OpFMul &&
_.IsCooperativeMatrixKHRType(result_type) &&
_.IsFloatCooperativeMatrixType(result_type)))
_.IsFloatCooperativeMatrixType(result_type)) &&
!(supportsCoopVec && _.IsFloatCooperativeVectorNVType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected floating scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);

for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) {
if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
if (supportsCoopVec && _.IsCooperativeVectorNVType(result_type)) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (!_.IsCooperativeVectorNVType(type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
<< operand_index;
}
spv_result_t ret =
_.CooperativeVectorDimensionsMatch(inst, type_id, result_type);
if (ret != SPV_SUCCESS) return ret;
} else if (supportsCoopMat &&
_.IsCooperativeMatrixKHRType(result_type)) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (!_.IsCooperativeMatrixKHRType(type_id) ||
!_.IsFloatCooperativeMatrixType(type_id)) {
Expand All @@ -76,17 +91,32 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
case spv::Op::OpUDiv:
case spv::Op::OpUMod: {
bool supportsCoopMat = (opcode == spv::Op::OpUDiv);
bool supportsCoopVec = (opcode == spv::Op::OpUDiv);
if (!_.IsUnsignedIntScalarType(result_type) &&
!_.IsUnsignedIntVectorType(result_type) &&
!(supportsCoopMat &&
_.IsUnsignedIntCooperativeMatrixType(result_type)))
_.IsUnsignedIntCooperativeMatrixType(result_type)) &&
!(supportsCoopVec &&
_.IsUnsignedIntCooperativeVectorNVType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected unsigned int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);

for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) {
if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
if (supportsCoopVec && _.IsCooperativeVectorNVType(result_type)) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (!_.IsCooperativeVectorNVType(type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
<< operand_index;
}
spv_result_t ret =
_.CooperativeVectorDimensionsMatch(inst, type_id, result_type);
if (ret != SPV_SUCCESS) return ret;
} else if (supportsCoopMat &&
_.IsCooperativeMatrixKHRType(result_type)) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (!_.IsCooperativeMatrixKHRType(type_id) ||
!_.IsUnsignedIntCooperativeMatrixType(type_id)) {
Expand Down Expand Up @@ -117,11 +147,14 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
bool supportsCoopMat =
(opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem &&
opcode != spv::Op::OpSMod);
bool supportsCoopVec =
(opcode != spv::Op::OpSRem && opcode != spv::Op::OpSMod);
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
!(opcode == spv::Op::OpIMul &&
_.IsCooperativeMatrixKHRType(result_type) &&
_.IsIntCooperativeMatrixType(result_type)))
_.IsIntCooperativeMatrixType(result_type)) &&
!(supportsCoopVec && _.IsIntCooperativeVectorNVType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
Expand All @@ -133,6 +166,18 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
++operand_index) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);

if (supportsCoopVec && _.IsCooperativeVectorNVType(result_type)) {
if (!_.IsCooperativeVectorNVType(type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
<< operand_index;
}
spv_result_t ret =
_.CooperativeVectorDimensionsMatch(inst, type_id, result_type);
if (ret != SPV_SUCCESS) return ret;
}

if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
if (!_.IsCooperativeMatrixKHRType(type_id) ||
!_.IsIntCooperativeMatrixType(type_id)) {
Expand All @@ -151,7 +196,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
!(opcode == spv::Op::OpIMul &&
_.IsCooperativeMatrixKHRType(result_type) &&
_.IsIntCooperativeMatrixType(result_type))))
_.IsIntCooperativeMatrixType(result_type)) &&
!(supportsCoopVec && _.IsIntCooperativeVectorNVType(result_type))))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as operand: "
<< spvOpcodeString(opcode) << " operand index "
Expand Down Expand Up @@ -210,7 +256,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
}

case spv::Op::OpVectorTimesScalar: {
if (!_.IsFloatVectorType(result_type))
if (!_.IsFloatVectorType(result_type) &&
!_.IsFloatCooperativeVectorNVType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected float vector type as Result Type: "
<< spvOpcodeString(opcode);
Expand Down
Loading