Skip to content

Commit

Permalink
Validate SPV_NV_cooperative_vector
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffbolznv committed Jan 29, 2025
1 parent d99e54e commit f957854
Show file tree
Hide file tree
Showing 24 changed files with 1,810 additions and 95 deletions.
2 changes: 1 addition & 1 deletion DEPS
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ vars = {

're2_revision': '6dcd83d60f7944926bfd308cc13979fc53dd69ca',

'spirv_headers_revision': '2b2e05e088841c63c0b6fd4c9fb380d8688738d3',
'spirv_headers_revision': '767e901c986e9755a17e7939b3046fc2911a4bbd',
}

deps = {
Expand Down
3 changes: 3 additions & 0 deletions include/spirv-tools/libspirv.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,9 @@ typedef enum spv_operand_type_t {
SPV_OPERAND_TYPE_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS,
SPV_OPERAND_TYPE_OPTIONAL_MATRIX_MULTIPLY_ACCUMULATE_OPERANDS,

SPV_OPERAND_TYPE_COOPERATIVE_VECTOR_MATRIX_LAYOUT,
SPV_OPERAND_TYPE_COMPONENT_TYPE,

// This is a sentinel value, and does not represent an operand type.
// It should come last.
SPV_OPERAND_TYPE_NUM_OPERAND_TYPES,
Expand Down
2 changes: 2 additions & 0 deletions source/opcode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ int32_t spvOpcodeIsComposite(const spv::Op opcode) {
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
return true;
default:
return false;
Expand Down Expand Up @@ -381,6 +382,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
case spv::Op::OpTypeAccelerationStructureNV:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
// case spv::Op::OpTypeAccelerationStructureKHR: covered by
// spv::Op::OpTypeAccelerationStructureNV
case spv::Op::OpTypeRayQueryKHR:
Expand Down
6 changes: 6 additions & 0 deletions source/operand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,10 @@ const char* spvOperandTypeStr(spv_operand_type_t type) {
return "quantization mode";
case SPV_OPERAND_TYPE_OVERFLOW_MODES:
return "overflow mode";
case SPV_OPERAND_TYPE_COOPERATIVE_VECTOR_MATRIX_LAYOUT:
return "cooperative vector matrix layout";
case SPV_OPERAND_TYPE_COMPONENT_TYPE:
return "component type";

case SPV_OPERAND_TYPE_NONE:
return "NONE";
Expand Down Expand Up @@ -399,6 +403,8 @@ bool spvOperandIsConcrete(spv_operand_type_t type) {
case SPV_OPERAND_TYPE_NAMED_MAXIMUM_NUMBER_OF_REGISTERS:
case SPV_OPERAND_TYPE_FPENCODING:
case SPV_OPERAND_TYPE_TENSOR_CLAMP_MODE:
case SPV_OPERAND_TYPE_COOPERATIVE_VECTOR_MATRIX_LAYOUT:
case SPV_OPERAND_TYPE_COMPONENT_TYPE:
return true;
default:
break;
Expand Down
5 changes: 5 additions & 0 deletions source/opt/eliminate_dead_members_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForExtract(
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
Expand Down Expand Up @@ -255,6 +256,7 @@ void EliminateDeadMembersPass::MarkMembersAsLiveForAccessChain(
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
Expand Down Expand Up @@ -516,6 +518,7 @@ bool EliminateDeadMembersPass::UpdateAccessChain(Instruction* inst) {
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
new_operands.emplace_back(inst->GetInOperand(i));
type_id = type_inst->GetSingleWordInOperand(0);
break;
Expand Down Expand Up @@ -591,6 +594,7 @@ bool EliminateDeadMembersPass::UpdateCompsiteExtract(Instruction* inst) {
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
Expand Down Expand Up @@ -654,6 +658,7 @@ bool EliminateDeadMembersPass::UpdateCompositeInsert(Instruction* inst) {
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
case spv::Op::OpTypeCooperativeVectorNV:
type_id = type_inst->GetSingleWordInOperand(0);
break;
default:
Expand Down
5 changes: 4 additions & 1 deletion source/opt/folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ int32_t ImageOperandsMaskInOperandIndex(Instruction* inst) {

// Returns the element width of |type|.
uint32_t ElementWidth(const analysis::Type* type) {
if (const analysis::Vector* vec_type = type->AsVector()) {
if (const analysis::CooperativeVectorNV* coopvec_type =
type->AsCooperativeVectorNV()) {
return ElementWidth(coopvec_type->component_type());
} else if (const analysis::Vector* vec_type = type->AsVector()) {
return ElementWidth(vec_type->element_type());
} else if (const analysis::Float* float_type = type->AsFloat()) {
return float_type->width();
Expand Down
26 changes: 26 additions & 0 deletions source/opt/type_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,20 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
0, id, operands);
break;
}
case Type::kCooperativeVectorNV: {
auto coop_vec = type->AsCooperativeVectorNV();
uint32_t const component_type =
GetTypeInstruction(coop_vec->component_type());
if (component_type == 0) {
return 0;
}
typeInst = MakeUnique<Instruction>(
context(), spv::Op::OpTypeCooperativeVectorNV, 0, id,
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_ID, {component_type}},
{SPV_OPERAND_TYPE_ID, {coop_vec->components()}}});
break;
}
default:
assert(false && "Unexpected type");
break;
Expand Down Expand Up @@ -721,6 +735,14 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
tv_type->dim_id(), tv_type->has_dimensions_id(), tv_type->perm());
break;
}
case Type::kCooperativeVectorNV: {
const CooperativeVectorNV* cv_type = type.AsCooperativeVectorNV();
const Type* component_type = cv_type->component_type();
rebuilt_ty = MakeUnique<CooperativeVectorNV>(
RebuildType(GetId(component_type), *component_type),
cv_type->components());
break;
}
default:
assert(false && "Unhandled type");
return nullptr;
Expand Down Expand Up @@ -970,6 +992,10 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
inst.GetSingleWordInOperand(1), inst.GetSingleWordInOperand(2),
inst.GetSingleWordInOperand(3), inst.GetSingleWordInOperand(4));
break;
case spv::Op::OpTypeCooperativeVectorNV:
type = new CooperativeVectorNV(GetType(inst.GetSingleWordInOperand(0)),
inst.GetSingleWordInOperand(1));
break;
case spv::Op::OpTypeRayQueryKHR:
type = new RayQueryKHR();
break;
Expand Down
32 changes: 32 additions & 0 deletions source/opt/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ std::unique_ptr<Type> Type::Clone() const {
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(CooperativeVectorNV);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
#undef DeclareKindCase
Expand Down Expand Up @@ -181,6 +182,7 @@ bool Type::operator==(const Type& other) const {
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(CooperativeVectorNV);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
DeclareKindCase(TensorLayoutNV);
Expand Down Expand Up @@ -240,6 +242,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
DeclareKindCase(AccelerationStructureNV);
DeclareKindCase(CooperativeMatrixNV);
DeclareKindCase(CooperativeMatrixKHR);
DeclareKindCase(CooperativeVectorNV);
DeclareKindCase(RayQueryKHR);
DeclareKindCase(HitObjectNV);
DeclareKindCase(TensorLayoutNV);
Expand Down Expand Up @@ -835,6 +838,35 @@ bool TensorViewNV::IsSameImpl(const Type* that, IsSameCache*) const {
has_dimensions_id_ == tv->has_dimensions_id_ && perm_ == tv->perm_;
}

CooperativeVectorNV::CooperativeVectorNV(const Type* type,
const uint32_t components)
: Type(kCooperativeVectorNV),
component_type_(type),
components_(components) {
assert(type != nullptr);
assert(components != 0);
}

std::string CooperativeVectorNV::str() const {
std::ostringstream oss;
oss << "<" << component_type_->str() << ", " << components_ << ">";
return oss.str();
}

size_t CooperativeVectorNV::ComputeExtraStateHash(size_t hash,
SeenTypes* seen) const {
hash = hash_combine(hash, components_);
return component_type_->ComputeHashValue(hash, seen);
}

bool CooperativeVectorNV::IsSameImpl(const Type* that,
IsSameCache* seen) const {
const CooperativeVectorNV* mt = that->AsCooperativeVectorNV();
if (!mt) return false;
return component_type_->IsSameImpl(mt->component_type_, seen) &&
components_ == mt->components_ && HasSameDecorations(that);
}

} // namespace analysis
} // namespace opt
} // namespace spvtools
27 changes: 27 additions & 0 deletions source/opt/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class NamedBarrier;
class AccelerationStructureNV;
class CooperativeMatrixNV;
class CooperativeMatrixKHR;
class CooperativeVectorNV;
class RayQueryKHR;
class HitObjectNV;
class TensorLayoutNV;
Expand Down Expand Up @@ -108,6 +109,7 @@ class Type {
kAccelerationStructureNV,
kCooperativeMatrixNV,
kCooperativeMatrixKHR,
kCooperativeVectorNV,
kRayQueryKHR,
kHitObjectNV,
kTensorLayoutNV,
Expand Down Expand Up @@ -213,6 +215,7 @@ class Type {
DeclareCastMethod(AccelerationStructureNV)
DeclareCastMethod(CooperativeMatrixNV)
DeclareCastMethod(CooperativeMatrixKHR)
DeclareCastMethod(CooperativeVectorNV)
DeclareCastMethod(RayQueryKHR)
DeclareCastMethod(HitObjectNV)
DeclareCastMethod(TensorLayoutNV)
Expand Down Expand Up @@ -742,6 +745,30 @@ class TensorViewNV : public Type {
std::vector<uint32_t> perm_;
};

class CooperativeVectorNV : public Type {
public:
CooperativeVectorNV(const Type* type, const uint32_t components);
CooperativeVectorNV(const CooperativeVectorNV&) = default;

std::string str() const override;

CooperativeVectorNV* AsCooperativeVectorNV() override { return this; }
const CooperativeVectorNV* AsCooperativeVectorNV() const override {
return this;
}

size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;

const Type* component_type() const { return component_type_; }
uint32_t components() const { return components_; }

private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;

const Type* component_type_;
const uint32_t components_;
};

#define DefineParameterlessType(type, name) \
class type : public Type { \
public: \
Expand Down
61 changes: 54 additions & 7 deletions source/val/validate_arithmetics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,34 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
bool supportsCoopMat =
(opcode != spv::Op::OpFMul && opcode != spv::Op::OpFRem &&
opcode != spv::Op::OpFMod);
bool supportsCoopVec =
(opcode != spv::Op::OpFRem && opcode != spv::Op::OpFMod);
if (!_.IsFloatScalarType(result_type) &&
!_.IsFloatVectorType(result_type) &&
!(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)) &&
!(opcode == spv::Op::OpFMul &&
_.IsCooperativeMatrixKHRType(result_type) &&
_.IsFloatCooperativeMatrixType(result_type)))
_.IsFloatCooperativeMatrixType(result_type)) &&
!(supportsCoopVec && _.IsFloatCooperativeVectorNVType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected floating scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);

for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) {
if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
if (supportsCoopVec && _.IsCooperativeVectorNVType(result_type)) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (!_.IsCooperativeVectorNVType(type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
<< operand_index;
}
spv_result_t ret =
_.CooperativeVectorDimensionsMatch(inst, type_id, result_type);
if (ret != SPV_SUCCESS) return ret;
} else if (supportsCoopMat &&
_.IsCooperativeMatrixKHRType(result_type)) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (!_.IsCooperativeMatrixKHRType(type_id) ||
!_.IsFloatCooperativeMatrixType(type_id)) {
Expand All @@ -76,17 +91,32 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
case spv::Op::OpUDiv:
case spv::Op::OpUMod: {
bool supportsCoopMat = (opcode == spv::Op::OpUDiv);
bool supportsCoopVec = (opcode == spv::Op::OpUDiv);
if (!_.IsUnsignedIntScalarType(result_type) &&
!_.IsUnsignedIntVectorType(result_type) &&
!(supportsCoopMat &&
_.IsUnsignedIntCooperativeMatrixType(result_type)))
_.IsUnsignedIntCooperativeMatrixType(result_type)) &&
!(supportsCoopVec &&
_.IsUnsignedIntCooperativeVectorNVType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected unsigned int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);

for (size_t operand_index = 2; operand_index < inst->operands().size();
++operand_index) {
if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
if (supportsCoopVec && _.IsCooperativeVectorNVType(result_type)) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (!_.IsCooperativeVectorNVType(type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
<< operand_index;
}
spv_result_t ret =
_.CooperativeVectorDimensionsMatch(inst, type_id, result_type);
if (ret != SPV_SUCCESS) return ret;
} else if (supportsCoopMat &&
_.IsCooperativeMatrixKHRType(result_type)) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
if (!_.IsCooperativeMatrixKHRType(type_id) ||
!_.IsUnsignedIntCooperativeMatrixType(type_id)) {
Expand Down Expand Up @@ -117,11 +147,14 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
bool supportsCoopMat =
(opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem &&
opcode != spv::Op::OpSMod);
bool supportsCoopVec =
(opcode != spv::Op::OpSRem && opcode != spv::Op::OpSMod);
if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
!(opcode == spv::Op::OpIMul &&
_.IsCooperativeMatrixKHRType(result_type) &&
_.IsIntCooperativeMatrixType(result_type)))
_.IsIntCooperativeMatrixType(result_type)) &&
!(supportsCoopVec && _.IsIntCooperativeVectorNVType(result_type)))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as Result Type: "
<< spvOpcodeString(opcode);
Expand All @@ -133,6 +166,18 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
++operand_index) {
const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);

if (supportsCoopVec && _.IsCooperativeVectorNVType(result_type)) {
if (!_.IsCooperativeVectorNVType(type_id)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected arithmetic operands to be of Result Type: "
<< spvOpcodeString(opcode) << " operand index "
<< operand_index;
}
spv_result_t ret =
_.CooperativeVectorDimensionsMatch(inst, type_id, result_type);
if (ret != SPV_SUCCESS) return ret;
}

if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
if (!_.IsCooperativeMatrixKHRType(type_id) ||
!_.IsIntCooperativeMatrixType(type_id)) {
Expand All @@ -151,7 +196,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
!(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
!(opcode == spv::Op::OpIMul &&
_.IsCooperativeMatrixKHRType(result_type) &&
_.IsIntCooperativeMatrixType(result_type))))
_.IsIntCooperativeMatrixType(result_type)) &&
!(supportsCoopVec && _.IsIntCooperativeVectorNVType(result_type))))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected int scalar or vector type as operand: "
<< spvOpcodeString(opcode) << " operand index "
Expand Down Expand Up @@ -210,7 +256,8 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
}

case spv::Op::OpVectorTimesScalar: {
if (!_.IsFloatVectorType(result_type))
if (!_.IsFloatVectorType(result_type) &&
!_.IsFloatCooperativeVectorNVType(result_type))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected float vector type as Result Type: "
<< spvOpcodeString(opcode);
Expand Down
Loading

0 comments on commit f957854

Please sign in to comment.