diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -717,6 +717,30 @@
   let hasRegionVerifier = 1;
 }
 
+def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
+    [SameOperandsAndResultType]>,
+    Arguments<(ins AnyType:$value,
+               GPU_AllReduceOperationAttr:$op)>,
+    Results<(outs AnyType)> {
+  let summary = "Reduce values among subgroup.";
+  let description = [{
+    The `subgroup_reduce` op reduces the value of every work item across a
+    subgroup. The result is equal for all work items of a subgroup.
+
+    Example:
+
+    ```mlir
+    %1 = gpu.subgroup_reduce add %0 : (f32) -> (f32)
+    ```
+
+    Either none or all work items of a subgroup need to execute this op
+    in convergence.
+  }];
+  let assemblyFormat = [{ custom<AllReduceOperation>($op) $value attr-dict
+    `:` functional-type(operands, results) }];
+  let hasVerifier = 1;
+}
+
 def GPU_ShuffleOpXor : I32EnumAttrCase<"XOR", 0, "xor">;
 def GPU_ShuffleOpDown : I32EnumAttrCase<"DOWN", 1, "down">;
 def GPU_ShuffleOpUp : I32EnumAttrCase<"UP", 2, "up">;
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -309,6 +309,19 @@
 // AllReduceOp
 //===----------------------------------------------------------------------===//
 
+/// Returns false if `opName` is a bitwise reduction (`and`, `or`, `xor`) and
+/// `resType` is not an integer type; returns true otherwise.
+static bool verifyReduceOpAndType(gpu::AllReduceOperation opName,
+                                  Type resType) {
+  if ((opName == gpu::AllReduceOperation::AND ||
+       opName == gpu::AllReduceOperation::OR ||
+       opName == gpu::AllReduceOperation::XOR) &&
+      !resType.isa<IntegerType>())
+    return false;
+
+  return true;
+}
+
 LogicalResult gpu::AllReduceOp::verifyRegions() {
   if (getBody().empty() != getOp().has_value())
     return emitError("expected either an op attribute or a non-empty body");
@@ -333,10 +346,7 @@
       return emitError("expected gpu.yield op in region");
   } else {
     gpu::AllReduceOperation opName = *getOp();
-    if ((opName == gpu::AllReduceOperation::AND ||
-         opName == gpu::AllReduceOperation::OR ||
-         opName == gpu::AllReduceOperation::XOR) &&
-        !getType().isa<IntegerType>()) {
+    if (!verifyReduceOpAndType(opName, getType())) {
       return emitError()
              << '`' << gpu::stringifyAllReduceOperation(opName)
              << "` accumulator is only compatible with Integer type";
@@ -364,6 +374,19 @@
   attr.print(printer);
 }
 
+//===----------------------------------------------------------------------===//
+// SubgroupReduceOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult gpu::SubgroupReduceOp::verify() {
+  gpu::AllReduceOperation opName = getOp();
+  if (!verifyReduceOpAndType(opName, getType())) {
+    return emitError() << '`' << gpu::stringifyAllReduceOperation(opName)
+                       << "` accumulator is only compatible with Integer type";
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AsyncOpInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -245,6 +245,14 @@
 
 // -----
 
+func.func @subgroup_reduce_invalid_op_type(%arg0 : f32) {
+  // expected-error@+1 {{`and` accumulator is only compatible with Integer type}}
+  %res = gpu.subgroup_reduce and %arg0 : (f32) -> (f32)
+  return
+}
+
+// -----
+
 func.func @reduce_incorrect_region_arguments(%arg0 : f32) {
   // expected-error@+1 {{expected two region arguments}}
   %res = gpu.all_reduce %arg0 {
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -85,6 +85,9 @@
       %one = arith.constant 1.0 : f32
       %sum = gpu.all_reduce add %one {} : (f32) -> (f32)
 
+      // CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} : (f32) -> f32
+      %sum_subgroup = gpu.subgroup_reduce add %one : (f32) -> f32
+
       %width = arith.constant 7 : i32
       %offset = arith.constant 3 : i32
       // CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : f32