diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -887,6 +887,8 @@ let assemblyFormat = [{ custom<AllReduceOperation>($op) $value (`uniform` $uniform^)? $body attr-dict `:` functional-type(operands, results) }]; + + let hasFolder = 1; let hasRegionVerifier = 1; } @@ -913,6 +915,8 @@ let assemblyFormat = [{ custom<AllReduceOperation>($op) $value (`uniform` $uniform^)? attr-dict `:` functional-type(operands, results) }]; + + let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -431,6 +431,27 @@ return success(); } +static bool canMakeGroupOpUniform(Operation *op) { + auto launchOp = dyn_cast<gpu::LaunchOp>(op->getParentOp()); + if (!launchOp) + return false; + + Region &body = launchOp.getBody(); + assert(!body.empty() && "Invalid region"); + + // Only convert ops in gpu::launch entry block for now. + return op->getBlock() == &body.front(); +} + +OpFoldResult gpu::AllReduceOp::fold(FoldAdaptor /*adaptor*/) { + if (!getUniform() && canMakeGroupOpUniform(*this)) { + setUniform(true); + return getResult(); + } + + return nullptr; +} + // TODO: Support optional custom attributes (without dialect prefix). 
static ParseResult parseAllReduceOperation(AsmParser &parser, AllReduceOperationAttr &attr) { @@ -464,6 +485,15 @@ return success(); } +OpFoldResult gpu::SubgroupReduceOp::fold(FoldAdaptor /*adaptor*/) { + if (!getUniform() && canMakeGroupOpUniform(*this)) { + setUniform(true); + return getResult(); + } + + return nullptr; +} + //===----------------------------------------------------------------------===// // AsyncOpInterface //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/GPU/canonicalize.mlir b/mlir/test/Dialect/GPU/canonicalize.mlir --- a/mlir/test/Dialect/GPU/canonicalize.mlir +++ b/mlir/test/Dialect/GPU/canonicalize.mlir @@ -170,8 +170,8 @@ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C1]], %{{.*}} = %[[C1]], %{{.*}} = %[[C1]]) threads(%[[TIDX:.*]], %{{.*}}, %{{.*}}) in (%{{.*}} = %c32, %{{.*}} = %[[C1]], %{{.*}} = %[[C1]]) { -// CHECK-NEXT: arith.divui %[[TIDX]], %c32 : index -// CHECK-NEXT: arith.muli %{{.*}}, %c2 : index +// CHECK-NEXT: arith.divui %[[TIDX]], %c32 : index +// CHECK-NEXT: arith.muli %{{.*}}, %c2 : index // CHECK-NEXT: memref.load %memref[%{{.*}}, %[[C0]], %[[C0]]] : memref<2x16x16xf32> // CHECK-NEXT: arith.addi %{{.*}}, %[[C1]] : index // CHECK-NEXT: memref.load %memref[%{{.*}}, %[[C0]], %[[C0]]] : memref<2x16x16xf32> @@ -179,3 +179,41 @@ // CHECK-NEXT: memref.store %{{.*}}, %memref[%{{.*}}, %[[C0]], %[[C0]]] : memref<2x16x16xf32> // CHECK-NEXT: gpu.terminator // CHECK-NEXT: } + +// ----- + +// CHECK-LABEL: func @make_reduce_uniform +// CHECK: gpu.launch blocks +// CHECK: %[[V1:.*]] = "test.test2"() : () -> i32 +// CHECK: %[[V2:.*]] = gpu.all_reduce add %[[V1]] uniform { +// CHECK: "test.test3"(%[[V2]]) : (i32) -> () +func.func @make_reduce_uniform() { + %0:6 = "test.test1"() : () -> (index, index, index, index, index, index) + gpu.launch 
blocks(%arg0, %arg1, %arg2) in (%arg6 = %0#0, %arg7 = %0#1, %arg8 = %0#2) + threads(%arg3, %arg4, %arg5) in (%arg9 = %0#3, %arg10 = %0#4, %arg11 = %0#5) { + %1 = "test.test2"() : () -> i32 + %2 = gpu.all_reduce add %1 {} : (i32) -> (i32) + "test.test3"(%2) : (i32) -> () + gpu.terminator + } + return +} + +// ----- + +// CHECK-LABEL: func @make_subgroup_reduce_uniform +// CHECK: gpu.launch blocks +// CHECK: %[[V1:.*]] = "test.test2"() : () -> i32 +// CHECK: %[[V2:.*]] = gpu.subgroup_reduce add %[[V1]] uniform +// CHECK: "test.test3"(%[[V2]]) : (i32) -> () +func.func @make_subgroup_reduce_uniform() { + %0:6 = "test.test1"() : () -> (index, index, index, index, index, index) + gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %0#0, %arg7 = %0#1, %arg8 = %0#2) + threads(%arg3, %arg4, %arg5) in (%arg9 = %0#3, %arg10 = %0#4, %arg11 = %0#5) { + %1 = "test.test2"() : () -> i32 + %2 = gpu.subgroup_reduce add %1 : (i32) -> (i32) + "test.test3"(%2) : (i32) -> () + gpu.terminator + } + return +}