diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -2353,6 +2353,53 @@
     SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const
   ]>;
 
+def SPV_GO_Reduce : I32EnumAttrCase<"Reduce", 0> {
+  list<Availability> availability = [
+    Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformBallot, SPV_C_Kernel]>
+  ];
+}
+def SPV_GO_InclusiveScan : I32EnumAttrCase<"InclusiveScan", 1> {
+  list<Availability> availability = [
+    Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformBallot, SPV_C_Kernel]>
+  ];
+}
+def SPV_GO_ExclusiveScan : I32EnumAttrCase<"ExclusiveScan", 2> {
+  list<Availability> availability = [
+    Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformBallot, SPV_C_Kernel]>
+  ];
+}
+def SPV_GO_ClusteredReduce : I32EnumAttrCase<"ClusteredReduce", 3> {
+  list<Availability> availability = [
+    MinVersion<SPV_V_1_3>,
+    Capability<[SPV_C_GroupNonUniformClustered]>
+  ];
+}
+def SPV_GO_PartitionedReduceNV : I32EnumAttrCase<"PartitionedReduceNV", 6> {
+  list<Availability> availability = [
+    Extension<[SPV_NV_shader_subgroup_partitioned]>,
+    Capability<[SPV_C_GroupNonUniformPartitionedNV]>
+  ];
+}
+def SPV_GO_PartitionedInclusiveScanNV : I32EnumAttrCase<"PartitionedInclusiveScanNV", 7> {
+  list<Availability> availability = [
+    Extension<[SPV_NV_shader_subgroup_partitioned]>,
+    Capability<[SPV_C_GroupNonUniformPartitionedNV]>
+  ];
+}
+def SPV_GO_PartitionedExclusiveScanNV : I32EnumAttrCase<"PartitionedExclusiveScanNV", 8> {
+  list<Availability> availability = [
+    Extension<[SPV_NV_shader_subgroup_partitioned]>,
+    Capability<[SPV_C_GroupNonUniformPartitionedNV]>
+  ];
+}
+
+def SPV_GroupOperationAttr :
+    SPV_I32EnumAttr<"GroupOperation", "valid SPIR-V GroupOperation", [
+      SPV_GO_Reduce, SPV_GO_InclusiveScan, SPV_GO_ExclusiveScan,
+      SPV_GO_ClusteredReduce, SPV_GO_PartitionedReduceNV,
+      SPV_GO_PartitionedInclusiveScanNV, SPV_GO_PartitionedExclusiveScanNV
+    ]>;
+
 def SPV_IF_Unknown : I32EnumAttrCase<"Unknown", 0>;
 def SPV_IF_Rgba32f : I32EnumAttrCase<"Rgba32f", 1> {
   list<Availability> availability = [
@@ -3108,7 +3155,9 @@
 def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
 def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>;
 def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
+def SPV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
 def SPV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
+def SPV_OC_OpGroupNonUniformIAdd : I32EnumAttrCase<"OpGroupNonUniformIAdd", 349>;
 def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 
 def SPV_OpcodeAttr :
@@ -3155,7 +3204,8 @@
     SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
     SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
     SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed,
-    SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpSubgroupBallotKHR
+    SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot,
+    SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpSubgroupBallotKHR
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
@@ -72,5 +72,118 @@
 
 // -----
 
-#endif // SPIRV_NON_UNIFORM_OPS
+def SPV_GroupNonUniformElectOp : SPV_Op<"GroupNonUniformElect", []> {
+  let summary = [{
+    Result is true only in the active invocation with the lowest id in the
+    group, otherwise result is false.
+  }];
+
+  let description = [{
+    Result Type must be a Boolean type.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    ### Custom assembly form
+
+    ```
+    scope ::= `"Workgroup"` | `"Subgroup"`
+    non-uniform-elect-op ::= ssa-id `=` `spv.GroupNonUniformElect` scope
+                             `:` `i1`
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.GroupNonUniformElect "Workgroup" : i1
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_3>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_GroupNonUniform]>
+  ];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$execution_scope
+  );
+
+  let results = (outs
+    SPV_Bool:$result
+  );
+
+  let builders = [
+    OpBuilder<[{Builder *builder, OperationState &state, spirv::Scope}]>
+  ];
+}
+
+// -----
+
+def SPV_GroupNonUniformIAddOp : SPV_Op<"GroupNonUniformIAdd", []> {
+  let summary = [{
+    An integer add group operation of all Value operands contributed active
+    by invocations in the group.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    The identity I for Operation is 0. If Operation is ClusteredReduce,
+    ClusterSize must be specified.
+
+    The type of Value must be the same as Result Type.
+
+    ClusterSize is the size of cluster to use. ClusterSize must be a scalar
+    of integer type, whose Signedness operand is 0. ClusterSize must come
+    from a constant instruction. ClusterSize must be at least 1, and must be
+    a power of 2. If ClusterSize is greater than the declared SubGroupSize,
+    executing this instruction results in undefined behavior.
+
+    ### Custom assembly form
+
+    ```
+    scope ::= `"Workgroup"` | `"Subgroup"`
+    operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
+    integer-scalar-vector-type ::= integer-type |
+                                   `vector<` integer-literal `x` integer-type `>`
+    non-uniform-iadd-op ::= ssa-id `=` `spv.GroupNonUniformIAdd` scope operation
+                            ssa-use ( `cluster_size` `(` ssa_use `)` )?
+                            `:` integer-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %four = spv.constant 4 : i32
+    %scalar = ... : i32
+    %vector = ... : vector<4xi32>
+    %0 = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %scalar : i32
+    %1 = spv.GroupNonUniformIAdd "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_3>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformClustered, SPV_C_GroupNonUniformPartitionedNV]>
+  ];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$execution_scope,
+    SPV_GroupOperationAttr:$group_operation,
+    SPV_ScalarOrVectorOf<SPV_Integer>:$value,
+    SPV_Optional<SPV_Integer>:$cluster_size
+  );
+
+  let results = (outs
+    SPV_ScalarOrVectorOf<SPV_Integer>:$result
+  );
+}
+
+// -----
+
+#endif // SPIRV_NON_UNIFORM_OPS
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -32,10 +32,12 @@
 static constexpr const char kAlignmentAttrName[] = "alignment";
 static constexpr const char kBranchWeightAttrName[] = "branch_weights";
 static constexpr const char kCallee[] = "callee";
+static constexpr const char kClusterSize[] = "cluster_size";
 static constexpr const char kDefaultValueAttrName[] = "default_value";
 static constexpr const char kExecutionScopeAttrName[] = "execution_scope";
 static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics";
 static constexpr const char kFnNameAttrName[] = "fn";
+static constexpr const char kGroupOperationAttrName[] = "group_operation";
 static constexpr const char kIndicesAttrName[] = "indices";
 static constexpr const char kInitializerAttrName[] = "initializer";
 static constexpr const char kInterfaceAttrName[] = "interface";
@@ -53,9 +55,8 @@
 // Common utility functions
 //===----------------------------------------------------------------------===//
 
-static LogicalResult extractValueFromConstOp(Operation *op,
-                                             int32_t &indexValue) {
-  auto constOp = dyn_cast<spirv::ConstantOp>(op);
+static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
+  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
   if (!constOp) {
     return failure();
   }
@@ -64,7 +65,7 @@
   if (!integerValueAttr) {
     return failure();
   }
-  indexValue = integerValueAttr.getInt();
+  value = integerValueAttr.getInt();
   return success();
 }
 
@@ -1889,6 +1890,122 @@
 }
 
 //===----------------------------------------------------------------------===//
+// spv.GroupNonUniformElectOp
+//===----------------------------------------------------------------------===//
+
+void spirv::GroupNonUniformElectOp::build(Builder *builder,
+                                          OperationState &state,
+                                          spirv::Scope scope) {
+  build(builder, state, builder->getI1Type(), scope);
+}
+
+static ParseResult parseGroupNonUniformElectOp(OpAsmParser &parser,
+                                               OperationState &state) {
+  spirv::Scope executionScope;
+  Type resultType;
+  if (parseEnumAttribute(executionScope, parser, state,
+                         kExecutionScopeAttrName) ||
+      parser.parseColonType(resultType))
+    return failure();
+
+  return parser.addTypeToList(resultType, state.types);
+}
+
+static void print(spirv::GroupNonUniformElectOp groupOp,
+                  OpAsmPrinter &printer) {
+  printer << spirv::GroupNonUniformElectOp::getOperationName() << " \""
+          << stringifyScope(groupOp.execution_scope())
+          << "\" : " << groupOp.getType();
+}
+
+static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
+  spirv::Scope scope = groupOp.execution_scope();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return groupOp.emitOpError(
+        "execution scope must be 'Workgroup' or 'Subgroup'");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformIAddOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseGroupNonUniformIAddOp(OpAsmParser &parser,
+                                              OperationState &state) {
+  spirv::Scope executionScope;
+  spirv::GroupOperation groupOperation;
+  OpAsmParser::OperandType valueInfo;
+  if (parseEnumAttribute(executionScope, parser, state,
+                         kExecutionScopeAttrName) ||
+      parseEnumAttribute(groupOperation, parser, state,
+                         kGroupOperationAttrName) ||
+      parser.parseOperand(valueInfo))
+    return failure();
+
+  Optional<OpAsmParser::OperandType> clusterSizeInfo;
+  if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
+    clusterSizeInfo = OpAsmParser::OperandType();
+    if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
+        parser.parseRParen())
+      return failure();
+  }
+
+  Type resultType;
+  if (parser.parseColonType(resultType))
+    return failure();
+
+  if (parser.resolveOperand(valueInfo, resultType, state.operands))
+    return failure();
+
+  if (clusterSizeInfo.hasValue()) {
+    Type i32Type = parser.getBuilder().getIntegerType(32);
+    if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
+      return failure();
+  }
+
+  return parser.addTypeToList(resultType, state.types);
+}
+
+static void print(spirv::GroupNonUniformIAddOp groupOp, OpAsmPrinter &printer) {
+  printer << spirv::GroupNonUniformIAddOp::getOperationName() << " \""
+          << stringifyScope(groupOp.execution_scope()) << "\" \""
+          << stringifyGroupOperation(groupOp.group_operation()) << "\" "
+          << groupOp.value();
+  if (!groupOp.cluster_size().empty())
+    printer << " " << kClusterSize << '(' << groupOp.cluster_size() << ')';
+  printer << " : " << groupOp.getType();
+}
+
+static LogicalResult verify(spirv::GroupNonUniformIAddOp groupOp) {
+  spirv::Scope scope = groupOp.execution_scope();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return groupOp.emitOpError(
+        "execution scope must be 'Workgroup' or 'Subgroup'");
+
+  spirv::GroupOperation operation = groupOp.group_operation();
+  if (operation == spirv::GroupOperation::ClusteredReduce &&
+      groupOp.cluster_size().empty())
+    return groupOp.emitOpError("cluster size operand must be provided for "
+                               "'ClusteredReduce' group operation");
+
+  if (!groupOp.cluster_size().empty()) {
+    Operation *sizeOp = (*groupOp.cluster_size().begin()).getDefiningOp();
+    int32_t clusterSize = 0;
+
+    // TODO(antiagainst): support specialization constant here.
+    if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
+      return groupOp.emitOpError(
+          "cluster size operand must come from a constant op");
+
+    if (!llvm::isPowerOf2_32(clusterSize))
+      return groupOp.emitOpError("cluster size operand must be a power of two");
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.IAdd
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
@@ -7,4 +7,26 @@
     %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
     spv.ReturnValue %0: vector<4xi32>
   }
+
+  // CHECK-LABEL: @group_non_uniform_elect
+  func @group_non_uniform_elect() -> i1 {
+    // CHECK: %{{.+}} = spv.GroupNonUniformElect "Workgroup" : i1
+    %0 = spv.GroupNonUniformElect "Workgroup" : i1
+    spv.ReturnValue %0: i1
+  }
+
+  // CHECK-LABEL: @group_non_uniform_iadd_reduce
+  func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
+    // CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %{{.+}} : i32
+    %0 = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %val : i32
+    spv.ReturnValue %0: i32
+  }
+
+  // CHECK-LABEL: @group_non_uniform_iadd_clustered_reduce
+  func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
+    %four = spv.constant 4 : i32
+    // CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xi32>
+    %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xi32>
+    spv.ReturnValue %0: vector<2xi32>
+  }
 }
diff --git a/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
--- a/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
@@ -4,7 +4,7 @@
 // spv.GroupNonUniformBallot
 //===----------------------------------------------------------------------===//
 
-func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
   // CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32>
   %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
   return %0: vector<4xi32>
@@ -12,8 +12,83 @@
 
 // -----
 
-func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
   // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spv.GroupNonUniformBallot "Device" %predicate : vector<4xi32>
   return %0: vector<4xi32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformElect
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_elect
+func @group_non_uniform_elect() -> i1 {
+  // CHECK: %{{.+}} = spv.GroupNonUniformElect "Workgroup" : i1
+  %0 = spv.GroupNonUniformElect "Workgroup" : i1
+  return %0: i1
+}
+
+// -----
+
+func @group_non_uniform_elect() -> i1 {
+  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  %0 = spv.GroupNonUniformElect "CrossDevice" : i1
+  return %0: i1
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformIAdd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_iadd_reduce
+func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
+  // CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %{{.+}} : i32
+  %0 = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %val : i32
+  return %0: i32
+}
+
+// CHECK-LABEL: @group_non_uniform_iadd_clustered_reduce
+func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
+  %four = spv.constant 4 : i32
+  // CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xi32>
+  %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xi32>
+  return %0: vector<2xi32>
+}
+
+// -----
+
+func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
+  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  %0 = spv.GroupNonUniformIAdd "Device" "Reduce" %val : i32
+  return %0: i32
+}
+
+// -----
+
+func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
+  // expected-error @+1 {{cluster size operand must be provided for 'ClusteredReduce' group operation}}
+  %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val : vector<2xi32>
+  return %0: vector<2xi32>
+}
+
+// -----
+
+func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>, %size: i32) -> vector<2xi32> {
+  // expected-error @+1 {{cluster size operand must come from a constant op}}
+  %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%size) : vector<2xi32>
+  return %0: vector<2xi32>
+}
+
+// -----
+
+func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
+  %five = spv.constant 5 : i32
+  // expected-error @+1 {{cluster size operand must be a power of two}}
+  %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%five) : vector<2xi32>
+  return %0: vector<2xi32>
+}