diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -888,6 +888,8 @@
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) attr-dict
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def GPU_AllocOp : GPU_Op<"alloc", [
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1184,6 +1184,78 @@
   return foldMemRefCast(*this);
 }
 
+//===----------------------------------------------------------------------===//
+// GPU_WaitOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Drop the async dependencies of a gpu.wait op that are produced by
+/// `%t = gpu.wait async []` ops, i.e. ops without async dependencies of their
+/// own: `... gpu.wait ... [%t, ...]` becomes `... gpu.wait ... [...]`.
+struct EraseRedundantGpuWaitOpPairs : public OpRewritePattern<WaitOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WaitOp op,
+                                PatternRewriter &rewriter) const final {
+    auto predicate = [](Value value) {
+      auto waitOp = value.getDefiningOp<WaitOp>();
+      return waitOp && waitOp->getNumOperands() == 0;
+    };
+    if (llvm::none_of(op.asyncDependencies(), predicate))
+      return failure();
+    SmallVector<Value> validOperands;
+    for (Value operand : op.asyncDependencies())
+      if (!predicate(operand))
+        validOperands.push_back(operand);
+    // In-place modifications inside a pattern must go through the rewriter so
+    // that it is notified of the update.
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(validOperands); });
+    return success();
+  }
+};
+
+/// Simplify trivial gpu.wait ops. The following cases are handled, in order:
+/// 1. gpu.wait [] ops, i.e. gpu.wait ops that neither have any async
+///    dependencies nor return any token, are erased.
+/// 2. %t1 = gpu.wait async [%t0] ops are erased after replacing uses of %t1
+///    with %t0.
+/// 3. %t = gpu.wait async ... ops, where %t has no uses, are erased
+///    (regardless of async dependencies).
+struct SimplifyGpuWaitOp : public OpRewritePattern<WaitOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WaitOp op,
+                                PatternRewriter &rewriter) const final {
+    // Erase gpu.wait ops that neither have any async dependencies nor return
+    // any async token.
+    if (op.asyncDependencies().empty() && !op.asyncToken()) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+    // %t1 = gpu.wait async [%t0] merely aliases %t0: forward it to all users.
+    if (llvm::hasSingleElement(op.asyncDependencies()) && op.asyncToken()) {
+      rewriter.replaceOp(op, op.asyncDependencies());
+      return success();
+    }
+    // Erase %t = gpu.wait async ... ops, where %t has no uses.
+    if (op.asyncToken() && op.asyncToken().use_empty()) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+    return failure();
+  }
+};
+
+} // end anonymous namespace
+
+void WaitOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<EraseRedundantGpuWaitOpPairs, SimplifyGpuWaitOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // GPU_AllocOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/GPU/canonicalize.mlir b/mlir/test/Dialect/GPU/canonicalize.mlir
--- a/mlir/test/Dialect/GPU/canonicalize.mlir
+++ b/mlir/test/Dialect/GPU/canonicalize.mlir
@@ -1,5 +1,33 @@
 // RUN: mlir-opt %s -canonicalize --split-input-file -allow-unregistered-dialect | FileCheck %s
 
+// Fold all the gpu.wait ops as they are redundant.
+// CHECK-LABEL: func @fold_wait_op_test1
+func @fold_wait_op_test1() {
+  %1 = gpu.wait async
+  gpu.wait []
+  %3 = gpu.wait async
+  gpu.wait [%3]
+  return
+}
+// CHECK-NOT: gpu.wait
+
+// Erase the blocking gpu.wait ops; both async tokens still feed gpu.alloc.
+// CHECK-LABEL: func @fold_wait_op_test2
+func @fold_wait_op_test2(%arg0: i1) -> (memref<5xf16>, memref<5xf16>) {
+  %0 = gpu.wait async
+  %memref, %asyncToken = gpu.alloc async [%0] () : memref<5xf16>
+  gpu.wait [%0]
+  %1 = gpu.wait async [%0]
+  %memref_0, %asyncToken_0 = gpu.alloc async [%1] () : memref<5xf16>
+  gpu.wait [%1]
+  return %memref, %memref_0 : memref<5xf16>, memref<5xf16>
+}
+// CHECK-NEXT: %[[TOKEN0:.*]] = gpu.wait async
+// CHECK-NEXT: gpu.alloc async [%[[TOKEN0]]] ()
+// CHECK-NEXT: %[[TOKEN1:.*]] = gpu.wait async
+// CHECK-NEXT: gpu.alloc async [%[[TOKEN1]]] ()
+// CHECK-NEXT: return
+
 // CHECK-LABEL: @memcpy_after_cast
 func @memcpy_after_cast(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
   // CHECK-NOT: memref.cast