diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -999,6 +999,7 @@
   }];
   let hasFolder = 1;
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def GPU_MemsetOp : GPU_Op<"memset",
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1064,6 +1064,55 @@
   printer << "]";
 }
 
+namespace {
+
+/// Erases a common case of copy ops where a destination value is used only by
+/// the copy op, alloc and dealloc ops.
+struct EraseTrivialCopyOp : public OpRewritePattern<MemcpyOp> {
+  using OpRewritePattern<MemcpyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MemcpyOp op,
+                                PatternRewriter &rewriter) const override {
+    Value dest = op.dst();
+    // If `dest` is a block argument, we cannot remove `op`.
+    if (dest.isa<BlockArgument>())
+      return failure();
+    auto isDeallocLikeOpActingOnVal = [](Operation *op, Value val) {
+      auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
+      if (!memOp)
+        return false;
+      llvm::SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
+          memOpEffects;
+      memOp.getEffects(memOpEffects);
+      return llvm::none_of(memOpEffects, [val](auto &effect) {
+        return effect.getValue() == val &&
+               !isa<MemoryEffects::Free>(effect.getEffect());
+      });
+    };
+    // We can erase `op` iff `dest` has no other use apart from its
+    // use by `op` and dealloc ops.
+    if (llvm::any_of(dest.getUsers(), [isDeallocLikeOpActingOnVal, op,
+                                       dest](Operation *user) {
+          return user != op && !isDeallocLikeOpActingOnVal(user, dest);
+        }))
+      return failure();
+
+    if (op.asyncDependencies().size() > 1 ||
+        ((op.asyncDependencies().empty() && op.asyncToken()) ||
+         (!op.asyncDependencies().empty() && !op.asyncToken())))
+      return failure();
+    rewriter.replaceOp(op, op.asyncDependencies());
+    return success();
+  }
+};
+
+} // end anonymous namespace
+
+void MemcpyOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<EraseTrivialCopyOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // GPU_SubgroupMmaLoadMatrixOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/GPU/canonicalize.mlir b/mlir/test/Dialect/GPU/canonicalize.mlir
--- a/mlir/test/Dialect/GPU/canonicalize.mlir
+++ b/mlir/test/Dialect/GPU/canonicalize.mlir
@@ -28,6 +28,60 @@
 // CHECK-NEXT: gpu.alloc async [%[[TOKEN1]]] ()
 // CHECK-NEXT: return
 
+// CHECK-LABEL: func @fold_memcpy_op
+func @fold_memcpy_op(%arg0: i1) {
+  %cst = arith.constant 0.000000e+00 : f16
+  %1 = memref.alloc() : memref<2xf16>
+  %2 = gpu.wait async
+  %memref, %asyncToken = gpu.alloc async [%2] () : memref<2xf16>
+  gpu.wait [%2]
+  affine.store %cst, %memref[0] : memref<2xf16>
+  %3 = gpu.wait async
+  %4 = gpu.memcpy async [%3] %1, %memref : memref<2xf16>, memref<2xf16>
+  gpu.wait [%3]
+  %5 = scf.if %arg0 -> (i1) {
+    memref.dealloc %1 : memref<2xf16>
+    scf.yield %arg0 : i1
+  } else {
+    memref.dealloc %1 : memref<2xf16>
+    scf.yield %arg0 : i1
+  }
+  return
+}
+// CHECK-NOT: gpu.memcpy
+
+// We cannot fold memcpy here as dest is a block argument.
+// CHECK-LABEL: func @do_not_fold_memcpy_op1
+func @do_not_fold_memcpy_op1(%arg0: i1, %arg1: memref<2xf16>) {
+  %cst = arith.constant 0.000000e+00 : f16
+  %2 = gpu.wait async
+  %memref, %asyncToken = gpu.alloc async [%2] () : memref<2xf16>
+  gpu.wait [%2]
+  affine.store %cst, %memref[0] : memref<2xf16>
+  %3 = gpu.wait async
+  %4 = gpu.memcpy async [%3] %arg1, %memref : memref<2xf16>, memref<2xf16>
+  gpu.wait [%3]
+  return
+}
+// CHECK: gpu.memcpy
+
+// We cannot fold gpu.memcpy here as dest is read by another op.
+// CHECK-LABEL: func @do_not_fold_memcpy_op2
+func @do_not_fold_memcpy_op2(%arg0: i1, %arg1: index) -> f16 {
+  %cst = arith.constant 0.000000e+00 : f16
+  %1 = memref.alloc() : memref<2xf16>
+  %2 = gpu.wait async
+  %memref, %asyncToken = gpu.alloc async [%2] () : memref<2xf16>
+  gpu.wait [%2]
+  affine.store %cst, %memref[0] : memref<2xf16>
+  %3 = gpu.wait async
+  %4 = gpu.memcpy async [%3] %1, %memref : memref<2xf16>, memref<2xf16>
+  gpu.wait [%3]
+  %5 = memref.load %1[%arg1] : memref<2xf16>
+  return %5 : f16
+}
+// CHECK: gpu.memcpy
+
 // CHECK-LABEL: @memcpy_after_cast
 func @memcpy_after_cast(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
   // CHECK-NOT: memref.cast
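For reference, a minimal MLIR sketch of the rewrite EraseTrivialCopyOp performs, assuming an
async copy with a single dependency whose destination is only deallocated afterwards; the
values %src, %dst and the token %t0 here are hypothetical and not taken from the patch:

  // Before canonicalization: %dst is written by the copy and then only freed.
  %t0 = gpu.wait async
  %dst = memref.alloc() : memref<2xf16>
  %t1 = gpu.memcpy async [%t0] %dst, %src : memref<2xf16>, memref<2xf16>
  memref.dealloc %dst : memref<2xf16>

  // After canonicalization: the copy is erased and uses of its token %t1 are
  // replaced by the single dependency %t0, per
  // rewriter.replaceOp(op, op.asyncDependencies()). The now-dead alloc/dealloc
  // pair may then be cleaned up by other patterns.
  %t0 = gpu.wait async
  %dst = memref.alloc() : memref<2xf16>
  memref.dealloc %dst : memref<2xf16>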