diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -523,6 +523,7 @@
   let parser = [{ return parseLaunchOp(parser, result); }];
   let printer = [{ printLaunchOp(p, *this); }];
   let verifier = [{ return ::verify(*this); }];
+  let hasCanonicalizer = 1;
 }
 
 def GPU_ReturnOp : GPU_Op<"return", [HasParent<"GPUFuncOp">, NoSideEffect,
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/FunctionImplementation.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -530,6 +531,61 @@
                  parser.parseOptionalAttrDict(result.attributes));
 }
 
+/// Simplify the gpu.launch when the range of the thread and block IDs is
+/// trivially known to be one.
+///
+/// For every block/thread id whose corresponding grid/block size is a
+/// constant equal to one, all uses of the id are replaced by a single zero
+/// index constant created at the start of the launch body.
+struct FoldLaunchArguments : public OpRewritePattern<LaunchOp> {
+  using OpRewritePattern<LaunchOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(LaunchOp op,
+                                PatternRewriter &rewriter) const override {
+    // Returns true if `size` is defined by an integer constant equal to one.
+    auto isTriviallyOne = [](Value size) {
+      IntegerAttr cst;
+      return matchPattern(size, m_Constant(&cst)) && cst.getInt() == 1;
+    };
+
+    // If the range implies a single value for `id`, replace `id`'s uses by
+    // zero.
+    Value zero;
+    bool simplified = false;
+    auto constPropIdUses = [&](Value id, Value size) {
+      if (!isTriviallyOne(size))
+        return;
+      // An id without uses has nothing to fold. Bailing out here also keeps
+      // the pattern from reporting success again on a launch it has already
+      // simplified, which would prevent the greedy driver from converging.
+      if (id.use_empty())
+        return;
+      if (!simplified) {
+        // Create a zero value the first time.
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointToStart(&op.body().front());
+        zero = rewriter.create<ConstantIndexOp>(op.getLoc(), /*value=*/0);
+      }
+      id.replaceAllUsesWith(zero);
+      simplified = true;
+    };
+    constPropIdUses(op.getBlockIds().x, op.gridSizeX());
+    constPropIdUses(op.getBlockIds().y, op.gridSizeY());
+    constPropIdUses(op.getBlockIds().z, op.gridSizeZ());
+    constPropIdUses(op.getThreadIds().x, op.blockSizeX());
+    constPropIdUses(op.getThreadIds().y, op.blockSizeY());
+    constPropIdUses(op.getThreadIds().z, op.blockSizeZ());
+
+    // Only report success if at least one id was actually replaced.
+    return simplified ? success() : failure();
+  }
+};
+
+/// Adds the FoldLaunchArguments pattern to `rewrites`.
+void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
+                                           MLIRContext *context) {
+  rewrites.add<FoldLaunchArguments>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // LaunchFuncOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/GPU/canonicalize.mlir b/mlir/test/Dialect/GPU/canonicalize.mlir
--- a/mlir/test/Dialect/GPU/canonicalize.mlir
+++ b/mlir/test/Dialect/GPU/canonicalize.mlir
@@ -31,3 +31,59 @@
   %1 = memref.dim %0, %c0 : memref
   return %1 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @simplify_gpu_launch
+func @simplify_gpu_launch() attributes {llvm.emit_c_interface} {
+  %cst = constant 0.000000e+00 : f32
+  %c1 = constant 1 : index
+  %c32 = constant 32 : index
+  %c16 = constant 16 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %0 = memref.alloc() : memref<2x16x16xf32>
+  scf.for %arg0 = %c0 to %c2 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      scf.for %arg2 = %c0 to %c16 step %c1 {
+        memref.store %cst, %0[%arg0, %arg1, %arg2] : memref<2x16x16xf32>
+      }
+    }
+  }
+  %1 = gpu.wait async
+  %memref, %asyncToken = gpu.alloc async [%1] () : memref<2x16x16xf32>
+  %2 = gpu.memcpy async [%1] %memref, %0 : memref<2x16x16xf32>, memref<2x16x16xf32>
+  gpu.wait [%1]
+  gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1)
+    threads(%arg3, %arg4, %arg5) in (%arg9 = %c32, %arg10 = %c1, %arg11 = %c1) {
+    %3 = muli %arg5, %c32 : index
+    %4 = muli %arg4, %c32 : index
+    %5 = addi %3, %4 : index
+    %6 = addi %5, %arg3 : index
+    %7 = divi_unsigned %6, %c32 : index
+    %8 = muli %arg0, %c16 : index
+    %9 = muli %arg1, %c2 : index
+    %10 = muli %7, %c2 : index
+    %11 = addi %9, %10 : index
+    %12 = memref.load %memref[%11, %c0, %8] : memref<2x16x16xf32>
+    %13 = addi %11, %c1 : index
+    %14 = memref.load %memref[%13, %c0, %8] : memref<2x16x16xf32>
+    memref.store %12, %memref[%11, %c0, %8] : memref<2x16x16xf32>
+    memref.store %14, %memref[%13, %c0, %8] : memref<2x16x16xf32>
+    gpu.terminator
+  }
+  return
+}
+
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C1]], %{{.*}} = %[[C1]], %{{.*}} = %[[C1]]) threads(%[[TIDX:.*]], %{{.*}}, %{{.*}}) in (%{{.*}} = %c32, %{{.*}} = %[[C1]], %{{.*}} = %[[C1]]) {
+// CHECK-NEXT: divi_unsigned %[[TIDX]], %c32 : index
+// CHECK-NEXT: muli %{{.*}}, %c2 : index
+// CHECK-NEXT: memref.load %memref[%{{.*}}, %[[C0]], %[[C0]]] : memref<2x16x16xf32>
+// CHECK-NEXT: addi %{{.*}}, %[[C1]] : index
+// CHECK-NEXT: memref.load %memref[%{{.*}}, %[[C0]], %[[C0]]] : memref<2x16x16xf32>
+// CHECK-NEXT: memref.store %{{.*}}, %memref[%{{.*}}, %[[C0]], %[[C0]]] : memref<2x16x16xf32>
+// CHECK-NEXT: memref.store %{{.*}}, %memref[%{{.*}}, %[[C0]], %[[C0]]] : memref<2x16x16xf32>
+// CHECK-NEXT: gpu.terminator
+// CHECK-NEXT: }