[mlir][SPIRV] Add spv.GroupBroadcast

Adds OpGroupBroadcast (opcode 263) to the SPIR-V opcode enum, defines the
spv.GroupBroadcast op in SPIRVGroupOps.td, verifies in SPIRVOps.cpp that the
execution scope is Workgroup or Subgroup and that a vector LocalId has 2 or 3
components, and adds op syntax/verifier tests plus serialization tests.

NOTE(review): padding/alignment whitespace in context lines (e.g. the aligned
opcode definitions in SPIRVBase.td) may not match the tree byte-for-byte; use
`git apply --ignore-whitespace` if a hunk is rejected.

diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3231,6 +3231,7 @@
 def SPV_OC_OpReturn                    : I32EnumAttrCase<"OpReturn", 253>;
 def SPV_OC_OpReturnValue               : I32EnumAttrCase<"OpReturnValue", 254>;
 def SPV_OC_OpUnreachable               : I32EnumAttrCase<"OpUnreachable", 255>;
+def SPV_OC_OpGroupBroadcast            : I32EnumAttrCase<"OpGroupBroadcast", 263>;
 def SPV_OC_OpNoLine                    : I32EnumAttrCase<"OpNoLine", 317>;
 def SPV_OC_OpModuleProcessed           : I32EnumAttrCase<"OpModuleProcessed", 330>;
 def SPV_OC_OpGroupNonUniformElect      : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
@@ -3297,8 +3298,8 @@
       SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor,
       SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
       SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
-      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine,
-      SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
+      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpGroupBroadcast,
+      SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect,
       SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd,
       SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul,
       SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin,
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
@@ -17,6 +17,82 @@
 
 // -----
 
+def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
+    [NoSideEffect, AllTypesMatch<["value", "result"]>]> {
+  let summary = [{
+    Return the Value of the invocation identified by the local id LocalId to
+    all invocations in the group.
+  }];
+
+  let description = [{
+    All invocations of this module within Execution must reach this point of
+    execution.
+
+    Behavior is undefined if this instruction is used in control flow that
+    is non-uniform within Execution.
+
+    Result Type must be a scalar or vector of floating-point type, integer
+    type, or Boolean type.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    The type of Value must be the same as Result Type.
+
+    LocalId must be an integer datatype. It can be a scalar, or a vector
+    with 2 components or a vector with 3 components. LocalId must be the
+    same for all invocations in the group.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    scope ::= `"Workgroup"` | `"Subgroup"`
+    integer-float-scalar-vector-type ::= integer-type | float-type |
+                               `vector<` integer-literal `x` integer-type `>` |
+                               `vector<` integer-literal `x` float-type `>`
+    localid-type ::= integer-type |
+                   `vector<` integer-literal `x` integer-type `>`
+    group-broadcast-op ::= ssa-id `=` `spv.GroupBroadcast` scope ssa-use,
+                 ssa-use `:` integer-float-scalar-vector-type `,` localid-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %scalar_value = ... : f32
+    %vector_value = ... : vector<4xf32>
+    %scalar_localid = ... : i32
+    %vector_localid = ... : vector<3xi32>
+    %0 = spv.GroupBroadcast "Subgroup" %scalar_value, %scalar_localid : f32, i32
+    %1 = spv.GroupBroadcast "Workgroup" %vector_value, %vector_localid :
+      vector<4xf32>, vector<3xi32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_Groups]>
+  ];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$execution_scope,
+    SPV_Type:$value,
+    SPV_ScalarOrVectorOf<SPV_Integer>:$localid
+  );
+
+  let results = (outs
+    SPV_Type:$result
+  );
+
+  let assemblyFormat = [{
+    $execution_scope operands attr-dict `:` type($value) `,` type($localid)
+  }];
+
+}
+
+// -----
+
 def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
   let summary = "See extension SPV_KHR_shader_ballot";
 
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1993,6 +1993,26 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.GroupBroadcast
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) {
+  spirv::Scope scope = broadcastOp.execution_scope();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return broadcastOp.emitOpError(
+        "execution scope must be 'Workgroup' or 'Subgroup'");
+
+  // A vector LocalId must have exactly 2 or 3 components.
+  if (auto localIdTy = broadcastOp.localid().getType().dyn_cast<VectorType>())
+    if (!(localIdTy.getNumElements() == 2 || localIdTy.getNumElements() == 3))
+      return broadcastOp.emitOpError("localid is a vector and can be with only "
+                                     "2 or 3 components, actual number is ")
+             << localIdTy.getNumElements();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spv.GroupNonUniformBallotOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
@@ -7,4 +7,16 @@
     %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
     spv.ReturnValue %0: vector<4xi32>
   }
+  // CHECK-LABEL: @group_broadcast_1
+  spv.func @group_broadcast_1(%value: f32, %localid: i32) -> f32 "None" {
+    // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32
+    %0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, i32
+    spv.ReturnValue %0: f32
+  }
+  // CHECK-LABEL: @group_broadcast_2
+  spv.func @group_broadcast_2(%value: f32, %localid: vector<3xi32>) -> f32 "None" {
+    // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, vector<3xi32>
+    %0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
+    spv.ReturnValue %0: f32
+  }
 }
diff --git a/mlir/test/Dialect/SPIRV/group-ops.mlir b/mlir/test/Dialect/SPIRV/group-ops.mlir
--- a/mlir/test/Dialect/SPIRV/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/group-ops.mlir
@@ -9,3 +9,55 @@
   %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
   return %0: vector<4xi32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GroupBroadcast
+//===----------------------------------------------------------------------===//
+
+func @group_broadcast_scalar(%value: f32, %localid: i32) -> f32 {
+  // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, i32
+  %0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, i32
+  return %0: f32
+}
+
+// -----
+
+func @group_broadcast_scalar_vector(%value: f32, %localid: vector<3xi32>) -> f32 {
+  // CHECK: spv.GroupBroadcast "Workgroup" %{{.*}}, %{{.*}} : f32, vector<3xi32>
+  %0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
+  return %0: f32
+}
+
+// -----
+
+func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32>) -> vector<4xf32> {
+  // CHECK: spv.GroupBroadcast "Subgroup" %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32>
+  %0 = spv.GroupBroadcast "Subgroup" %value, %localid : vector<4xf32>, vector<3xi32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
+func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32>) -> f32 {
+  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  %0 = spv.GroupBroadcast "Device" %value, %localid : f32, vector<3xi32>
+  return %0: f32
+}
+
+// -----
+
+func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3xf32>) -> f32 {
+  // expected-error @+1 {{operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
+  %0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<3xf32>
+  return %0: f32
+}
+
+// -----
+
+func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi32>) -> f32 {
+  // expected-error @+1 {{localid is a vector and can be with only 2 or 3 components, actual number is 4}}
+  %0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<4xi32>
+  return %0: f32
+}