diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3158,6 +3158,9 @@
 def SPV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
 def SPV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
 def SPV_OC_OpGroupNonUniformIAdd : I32EnumAttrCase<"OpGroupNonUniformIAdd", 349>;
+def SPV_OC_OpGroupNonUniformFAdd : I32EnumAttrCase<"OpGroupNonUniformFAdd", 350>;
+def SPV_OC_OpGroupNonUniformIMul : I32EnumAttrCase<"OpGroupNonUniformIMul", 351>;
+def SPV_OC_OpGroupNonUniformFMul : I32EnumAttrCase<"OpGroupNonUniformFMul", 352>;
 def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 
 def SPV_OpcodeAttr :
@@ -3205,7 +3208,9 @@
     SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
     SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed,
     SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot,
-    SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpSubgroupBallotKHR
+    SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd,
+    SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul,
+    SPV_OC_OpSubgroupBallotKHR
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
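The three new opcode cases carry the instruction token values assigned by the SPIR-V specification (349 for OpGroupNonUniformIAdd was already present). Purely as an illustrative aside — the dialect derives its real enum from the `I32EnumAttrCase` definitions above — a standalone C++ mirror of the mapping looks like this:

```cpp
#include <cstdint>

// Hypothetical stand-alone mirror of the opcode numbering added above.
// The MLIR SPIR-V dialect generates the actual enum from the
// I32EnumAttrCase definitions; this is not dialect API.
enum class Opcode : uint32_t {
  OpGroupNonUniformIAdd = 349, // pre-existing
  OpGroupNonUniformFAdd = 350, // new in this patch
  OpGroupNonUniformIMul = 351, // new in this patch
  OpGroupNonUniformFMul = 352, // new in this patch
};
```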
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
@@ -14,6 +14,35 @@
 #ifndef SPIRV_NON_UNIFORM_OPS
 #define SPIRV_NON_UNIFORM_OPS
 
+class SPV_GroupNonUniformArithmeticOp<string mnemonic, Type type,
+      list<OpTrait> traits = []> : SPV_Op<mnemonic, traits> {
+
+  let availability = [
+    MinVersion<SPV_V_1_3>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_GroupNonUniformArithmetic,
+                SPV_C_GroupNonUniformClustered,
+                SPV_C_GroupNonUniformPartitionedNV]>
+  ];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$execution_scope,
+    SPV_GroupOperationAttr:$group_operation,
+    SPV_ScalarOrVectorOf<type>:$value,
+    SPV_Optional<SPV_I32>:$cluster_size
+  );
+
+  let results = (outs
+    SPV_ScalarOrVectorOf<type>:$result
+  );
+
+  let parser = [{ return parseGroupNonUniformArithmeticOp(parser, result); }];
+  let printer = [{ printGroupNonUniformArithmeticOp(getOperation(), p); }];
+  let verifier = [{ return ::verifyGroupNonUniformArithmeticOp(getOperation()); }];
+}
+
 // -----
 
 def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
@@ -120,7 +149,110 @@
 
 // -----
 
-def SPV_GroupNonUniformIAddOp : SPV_Op<"GroupNonUniformIAdd", []> {
+def SPV_GroupNonUniformFAddOp :
+    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformFAdd", SPV_Float, []> {
+  let summary = [{
+    A floating point add group operation of all Value operands contributed
+    by active invocations in the group.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    The identity I for Operation is 0. If Operation is ClusteredReduce,
+    ClusterSize must be specified.
+
+    The type of Value must be the same as Result Type. The method used to
+    perform the group operation on the contributed Value(s) from active
+    invocations is implementation defined.
+
+    ClusterSize is the size of cluster to use. ClusterSize must be a scalar
+    of integer type, whose Signedness operand is 0. ClusterSize must come
+    from a constant instruction. ClusterSize must be at least 1, and must be
+    a power of 2. If ClusterSize is greater than the declared SubGroupSize,
+    executing this instruction results in undefined behavior.
+
+    ### Custom assembly form
+
+    ```
+    scope ::= `"Workgroup"` | `"Subgroup"`
+    operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    non-uniform-fadd-op ::= ssa-id `=` `spv.GroupNonUniformFAdd` scope operation
+                            ssa-use ( `cluster_size` `(` ssa-use `)` )?
+                            `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %four = spv.constant 4 : i32
+    %scalar = ... : f32
+    %vector = ... : vector<4xf32>
+    %0 = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %scalar : f32
+    %1 = spv.GroupNonUniformFAdd "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_GroupNonUniformFMulOp :
+    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformFMul", SPV_Float, []> {
+  let summary = [{
+    A floating point multiply group operation of all Value operands
+    contributed by active invocations in the group.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of floating-point type.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    The identity I for Operation is 1. If Operation is ClusteredReduce,
+    ClusterSize must be specified.
+
+    The type of Value must be the same as Result Type. The method used to
+    perform the group operation on the contributed Value(s) from active
+    invocations is implementation defined.
+
+    ClusterSize is the size of cluster to use. ClusterSize must be a scalar
+    of integer type, whose Signedness operand is 0. ClusterSize must come
+    from a constant instruction. ClusterSize must be at least 1, and must be
+    a power of 2. If ClusterSize is greater than the declared SubGroupSize,
+    executing this instruction results in undefined behavior.
+
+    ### Custom assembly form
+
+    ```
+    scope ::= `"Workgroup"` | `"Subgroup"`
+    operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
+    float-scalar-vector-type ::= float-type |
+                                 `vector<` integer-literal `x` float-type `>`
+    non-uniform-fmul-op ::= ssa-id `=` `spv.GroupNonUniformFMul` scope operation
+                            ssa-use ( `cluster_size` `(` ssa-use `)` )?
+                            `:` float-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %four = spv.constant 4 : i32
+    %scalar = ... : f32
+    %vector = ... : vector<4xf32>
+    %0 = spv.GroupNonUniformFMul "Workgroup" "Reduce" %scalar : f32
+    %1 = spv.GroupNonUniformFMul "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
+def SPV_GroupNonUniformIAddOp :
+    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformIAdd", SPV_Integer, []> {
   let summary = [{
     An integer add group operation of all Value operands contributed by
     active invocations in the group.
@@ -164,24 +296,55 @@
     %1 = spv.GroupNonUniformIAdd "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
     ```
   }];
+}
 
-  let availability = [
-    MinVersion<SPV_V_1_3>,
-    MaxVersion<SPV_V_1_5>,
-    Extension<[]>,
-    Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformClustered, SPV_C_GroupNonUniformPartitionedNV]>
-  ];
+// -----
 
-  let arguments = (ins
-    SPV_ScopeAttr:$execution_scope,
-    SPV_GroupOperationAttr:$group_operation,
-    SPV_ScalarOrVectorOf<SPV_Integer>:$value,
-    SPV_Optional<SPV_I32>:$cluster_size
-  );
+def SPV_GroupNonUniformIMulOp :
+    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformIMul", SPV_Integer, []> {
+  let summary = [{
+    An integer multiply group operation of all Value operands contributed by
+    active invocations in the group.
+  }];
 
-  let results = (outs
-    SPV_ScalarOrVectorOf<SPV_Integer>:$result
-  );
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    The identity I for Operation is 1. If Operation is ClusteredReduce,
+    ClusterSize must be specified.
+
+    The type of Value must be the same as Result Type.
+
+    ClusterSize is the size of cluster to use. ClusterSize must be a scalar
+    of integer type, whose Signedness operand is 0. ClusterSize must come
+    from a constant instruction. ClusterSize must be at least 1, and must be
+    a power of 2. If ClusterSize is greater than the declared SubGroupSize,
+    executing this instruction results in undefined behavior.
+
+    ### Custom assembly form
+
+    ```
+    scope ::= `"Workgroup"` | `"Subgroup"`
+    operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
+    integer-scalar-vector-type ::= integer-type |
+                                   `vector<` integer-literal `x` integer-type `>`
+    non-uniform-imul-op ::= ssa-id `=` `spv.GroupNonUniformIMul` scope operation
+                            ssa-use ( `cluster_size` `(` ssa-use `)` )?
+                            `:` integer-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %four = spv.constant 4 : i32
+    %scalar = ... : i32
+    %vector = ... : vector<4xi32>
+    %0 = spv.GroupNonUniformIMul "Workgroup" "Reduce" %scalar : i32
+    %1 = spv.GroupNonUniformIMul "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
+    ```
+  }];
 }
 
 // -----
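The `ClusteredReduce` group operation referenced throughout these descriptions partitions the invocations into clusters of `ClusterSize` consecutive invocations, each of which reduces independently starting from the op's identity (0 for FAdd, 1 for FMul and IMul). A host-side C++ sketch of that semantics for `GroupNonUniformFAdd`, assuming every invocation is active and a valid cluster size — the function name is hypothetical, not dialect API:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch of "ClusteredReduce" for FAdd: invocation i receives the sum over
// its cluster of `clusterSize` consecutive invocations. Assumes all
// invocations are active and clusterSize >= 1.
std::vector<float> clusteredReduceFAdd(const std::vector<float> &values,
                                       std::size_t clusterSize) {
  std::vector<float> results(values.size());
  for (std::size_t base = 0; base < values.size(); base += clusterSize) {
    std::size_t end = std::min(base + clusterSize, values.size());
    float sum = 0.0f; // the identity I for FAdd
    for (std::size_t i = base; i < end; ++i)
      sum += values[i];
    for (std::size_t i = base; i < end; ++i)
      results[i] = sum;
  }
  return results;
}
```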
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -588,6 +588,88 @@
   return success();
 }
 
+static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
+                                                    OperationState &state) {
+  spirv::Scope executionScope;
+  spirv::GroupOperation groupOperation;
+  OpAsmParser::OperandType valueInfo;
+  if (parseEnumAttribute(executionScope, parser, state,
+                         kExecutionScopeAttrName) ||
+      parseEnumAttribute(groupOperation, parser, state,
+                         kGroupOperationAttrName) ||
+      parser.parseOperand(valueInfo))
+    return failure();
+
+  Optional<OpAsmParser::OperandType> clusterSizeInfo;
+  if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
+    clusterSizeInfo = OpAsmParser::OperandType();
+    if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
+        parser.parseRParen())
+      return failure();
+  }
+
+  Type resultType;
+  if (parser.parseColonType(resultType))
+    return failure();
+
+  if (parser.resolveOperand(valueInfo, resultType, state.operands))
+    return failure();
+
+  if (clusterSizeInfo.hasValue()) {
+    Type i32Type = parser.getBuilder().getIntegerType(32);
+    if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
+      return failure();
+  }
+
+  return parser.addTypeToList(resultType, state.types);
+}
+
+static void printGroupNonUniformArithmeticOp(Operation *groupOp,
+                                             OpAsmPrinter &printer) {
+  printer << groupOp->getName() << " \""
+          << stringifyScope(static_cast<spirv::Scope>(
+                 groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
+                     .getInt()))
+          << "\" \""
+          << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
+                 groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
+                     .getInt()))
+          << "\" " << groupOp->getOperand(0);
+
+  if (groupOp->getNumOperands() > 1)
+    printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
+  printer << " : " << groupOp->getResult(0).getType();
+}
+
+static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
+  spirv::Scope scope = static_cast<spirv::Scope>(
+      groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return groupOp->emitOpError(
+        "execution scope must be 'Workgroup' or 'Subgroup'");
+
+  spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
+      groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
+  if (operation == spirv::GroupOperation::ClusteredReduce &&
+      groupOp->getNumOperands() == 1)
+    return groupOp->emitOpError("cluster size operand must be provided for "
+                                "'ClusteredReduce' group operation");
+  if (groupOp->getNumOperands() > 1) {
+    Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
+    int32_t clusterSize = 0;
+
+    // TODO(antiagainst): support specialization constant here.
+    if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
+      return groupOp->emitOpError(
+          "cluster size operand must come from a constant op");
+
+    if (!llvm::isPowerOf2_32(clusterSize))
+      return groupOp->emitOpError(
+          "cluster size operand must be a power of two");
+  }
+  return success();
+}
+
 // Parses an op that has no inputs and no outputs.
 static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) {
   if (parser.parseOptionalAttrDict(state.attributes))
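The shared verifier rejects cluster sizes that are not powers of two via `llvm::isPowerOf2_32`, which also returns false for zero and so covers the spec's "at least 1" requirement. The underlying test is the classic `n & (n - 1)` bit trick; a self-contained equivalent:

```cpp
#include <cstdint>

// Mirrors the cluster-size rule enforced by verifyGroupNonUniformArithmeticOp
// above: ClusterSize must be at least 1 and a power of 2. This is the same
// test llvm::isPowerOf2_32 performs.
bool isValidClusterSize(uint32_t clusterSize) {
  return clusterSize != 0 && (clusterSize & (clusterSize - 1)) == 0;
}
```

The `ClusterSize <= SubGroupSize` requirement is not checked here: exceeding the subgroup size is runtime undefined behavior per the spec, not a statically known property of the IR.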
@@ -1939,83 +2021,7 @@
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spv.GroupNonUniformIAddOp
-//===----------------------------------------------------------------------===//
-
-static ParseResult parseGroupNonUniformIAddOp(OpAsmParser &parser,
-                                              OperationState &state) {
-  spirv::Scope executionScope;
-  spirv::GroupOperation groupOperation;
-  OpAsmParser::OperandType valueInfo;
-  if (parseEnumAttribute(executionScope, parser, state,
-                         kExecutionScopeAttrName) ||
-      parseEnumAttribute(groupOperation, parser, state,
-                         kGroupOperationAttrName) ||
-      parser.parseOperand(valueInfo))
-    return failure();
-
-  Optional<OpAsmParser::OperandType> clusterSizeInfo;
-  if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
-    clusterSizeInfo = OpAsmParser::OperandType();
-    if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
-        parser.parseRParen())
-      return failure();
-  }
-
-  Type resultType;
-  if (parser.parseColonType(resultType))
-    return failure();
-
-  if (parser.resolveOperand(valueInfo, resultType, state.operands))
-    return failure();
-
-  if (clusterSizeInfo.hasValue()) {
-    Type i32Type = parser.getBuilder().getIntegerType(32);
-    if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
-      return failure();
-  }
-
-  return parser.addTypeToList(resultType, state.types);
-}
-
-static void print(spirv::GroupNonUniformIAddOp groupOp, OpAsmPrinter &printer) {
-  printer << spirv::GroupNonUniformIAddOp::getOperationName() << " \""
-          << stringifyScope(groupOp.execution_scope()) << "\" \""
-          << stringifyGroupOperation(groupOp.group_operation()) << "\" "
-          << groupOp.value();
-  if (!groupOp.cluster_size().empty())
-    printer << " " << kClusterSize << '(' << groupOp.cluster_size() << ')';
-  printer << " : " << groupOp.getType();
-}
-
-static LogicalResult verify(spirv::GroupNonUniformIAddOp groupOp) {
-  spirv::Scope scope = groupOp.execution_scope();
-  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
-    return groupOp.emitOpError(
-        "execution scope must be 'Workgroup' or 'Subgroup'");
-
-  spirv::GroupOperation operation = groupOp.group_operation();
-  if (operation == spirv::GroupOperation::ClusteredReduce &&
-      groupOp.cluster_size().empty())
-    return groupOp.emitOpError("cluster size operand must be provided for "
-                               "'ClusteredReduce' group operation");
-
-  if (!groupOp.cluster_size().empty()) {
-    Operation *sizeOp = (*groupOp.cluster_size().begin()).getDefiningOp();
-    int32_t clusterSize = 0;
-
-    // TODO(antiagainst): support specialization constant here.
-    if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
-      return groupOp.emitOpError(
-          "cluster size operand must come from a constant op");
-
-    if (!llvm::isPowerOf2_32(clusterSize))
-      return groupOp.emitOpError("cluster size operand must be a power of two");
-  }
-
-  return success();
-}
 
 //===----------------------------------------------------------------------===//
 // spv.IAdd
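Replacing the typed `spirv::GroupNonUniformIAddOp` hooks with ones taking a plain `Operation *` is what lets all four arithmetic ops share a single parser, printer, and verifier: the enum attributes are read back as raw integers and cast to their scoped enum types. A minimal standalone illustration of that cast, using a hypothetical stand-in for `spirv::Scope` with the token values fixed by the SPIR-V spec:

```cpp
#include <cstdint>

// Hypothetical mirror of a few spirv::Scope token values (per the SPIR-V
// spec), showing what static_cast<spirv::Scope>(attr.getInt()) does in the
// shared verifier above.
enum class Scope : uint32_t { CrossDevice = 0, Device = 1, Workgroup = 2, Subgroup = 3 };

// Returns true for the two scopes the shared verifier accepts.
bool isValidExecutionScope(uint32_t attrValue) {
  Scope scope = static_cast<Scope>(attrValue); // as stored in the IntegerAttr
  return scope == Scope::Workgroup || scope == Scope::Subgroup;
}
```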
diff --git a/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
@@ -15,6 +15,20 @@
     spv.ReturnValue %0: i1
   }
 
+  // CHECK-LABEL: @group_non_uniform_fadd_reduce
+  func @group_non_uniform_fadd_reduce(%val: f32) -> f32 {
+    // CHECK: %{{.+}} = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %{{.+}} : f32
+    %0 = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %val : f32
+    spv.ReturnValue %0: f32
+  }
+
+  // CHECK-LABEL: @group_non_uniform_fmul_reduce
+  func @group_non_uniform_fmul_reduce(%val: f32) -> f32 {
+    // CHECK: %{{.+}} = spv.GroupNonUniformFMul "Workgroup" "Reduce" %{{.+}} : f32
+    %0 = spv.GroupNonUniformFMul "Workgroup" "Reduce" %val : f32
+    spv.ReturnValue %0: f32
+  }
+
   // CHECK-LABEL: @group_non_uniform_iadd_reduce
   func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
     // CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %{{.+}} : i32
@@ -29,4 +43,12 @@
     %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xi32>
     spv.ReturnValue %0: vector<2xi32>
   }
+
+  // CHECK-LABEL: @group_non_uniform_imul_reduce
+  func @group_non_uniform_imul_reduce(%val: i32) -> i32 {
+    // CHECK: %{{.+}} = spv.GroupNonUniformIMul "Workgroup" "Reduce" %{{.+}} : i32
+    %0 = spv.GroupNonUniformIMul "Workgroup" "Reduce" %val : i32
+    spv.ReturnValue %0: i32
+  }
+
 }
diff --git a/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
--- a/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
@@ -42,6 +42,46 @@
 // -----
 
 //===----------------------------------------------------------------------===//
+// spv.GroupNonUniformFAdd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_fadd_reduce
+func @group_non_uniform_fadd_reduce(%val: f32) -> f32 {
+  // CHECK: %{{.+}} = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %{{.+}} : f32
+  %0 = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %val : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @group_non_uniform_fadd_clustered_reduce
+func @group_non_uniform_fadd_clustered_reduce(%val: vector<2xf32>) -> vector<2xf32> {
+  %four = spv.constant 4 : i32
+  // CHECK: %{{.+}} = spv.GroupNonUniformFAdd "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xf32>
+  %0 = spv.GroupNonUniformFAdd "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xf32>
+  return %0: vector<2xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformFMul
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_fmul_reduce
+func @group_non_uniform_fmul_reduce(%val: f32) -> f32 {
+  // CHECK: %{{.+}} = spv.GroupNonUniformFMul "Workgroup" "Reduce" %{{.+}} : f32
+  %0 = spv.GroupNonUniformFMul "Workgroup" "Reduce" %val : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @group_non_uniform_fmul_clustered_reduce
+func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vector<2xf32> {
+  %four = spv.constant 4 : i32
+  // CHECK: %{{.+}} = spv.GroupNonUniformFMul "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xf32>
+  %0 = spv.GroupNonUniformFMul "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xf32>
+  return %0: vector<2xf32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // spv.GroupNonUniformIAdd
 //===----------------------------------------------------------------------===//
 
@@ -92,3 +132,24 @@
   %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%five) : vector<2xi32>
   return %0: vector<2xi32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformIMul
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_imul_reduce
+func @group_non_uniform_imul_reduce(%val: i32) -> i32 {
+  // CHECK: %{{.+}} = spv.GroupNonUniformIMul "Workgroup" "Reduce" %{{.+}} : i32
+  %0 = spv.GroupNonUniformIMul "Workgroup" "Reduce" %val : i32
+  return %0: i32
+}
+
+// CHECK-LABEL: @group_non_uniform_imul_clustered_reduce
+func @group_non_uniform_imul_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
+  %four = spv.constant 4 : i32
+  // CHECK: %{{.+}} = spv.GroupNonUniformIMul "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xi32>
+  %0 = spv.GroupNonUniformIMul "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xi32>
+  return %0: vector<2xi32>
+}
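One closing note on the identity values quoted in the op descriptions (0 for FAdd, 1 for IMul and FMul): for an `ExclusiveScan`, the identity is exactly what invocation 0 receives, since it scans over an empty prefix. A host-side C++ sketch for `GroupNonUniformIMul`, again assuming all invocations are active (the function name is hypothetical):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of "ExclusiveScan" for IMul: invocation i receives the product of
// the values from invocations [0, i), so invocation 0 receives the
// identity 1. Assumes all invocations are active.
std::vector<int32_t> exclusiveScanIMul(const std::vector<int32_t> &values) {
  std::vector<int32_t> results(values.size());
  int32_t acc = 1; // the identity I for IMul
  for (std::size_t i = 0; i < values.size(); ++i) {
    results[i] = acc;
    acc *= values[i];
  }
  return results;
}
```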