diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -28,6 +28,9 @@ iterations. This transform expands the size of an allocation by a given multiplicative factor and fixes up any users of the multibuffered allocation. + If `skip_analysis` is not set, the transformation will only apply + if it can prove that there is no data being carried across loop + iterations. #### Return modes @@ -37,7 +40,8 @@ let arguments = (ins Transform_MemRefAllocOp:$target, - ConfinedAttr:$factor); + ConfinedAttr:$factor, + UnitAttr:$skip_analysis); let results = (outs PDL_Operation:$transformed); diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -78,7 +78,10 @@ /// on the temporary allocation between consecutive loop iterations. /// It returns the new allocation if the original allocation was multi-buffered /// and returns failure() otherwise. -/// Example: +/// When `skipOverrideAnalysis` is set, the pass will apply the transformation +/// without checking that the buffer is overwritten at the beginning of each +/// iteration. This implies that user knows that there is no data carried across +/// loop iterations. 
Example: /// ``` /// %0 = memref.alloc() : memref<4x128xf32> /// scf.for %iv = %c1 to %c1024 step %c3 { @@ -100,7 +103,8 @@ /// } /// ``` FailureOr multiBuffer(memref::AllocOp allocOp, - unsigned multiplier); + unsigned multiplier, + bool skipOverrideAnalysis = false); //===----------------------------------------------------------------------===// // Passes diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -41,7 +41,8 @@ if (!canApplyMultiBuffer) continue; - auto newBuffer = memref::multiBuffer(target, getFactor()); + auto newBuffer = + memref::multiBuffer(target, getFactor(), getSkipAnalysis()); if (failed(newBuffer)) return emitSilenceableFailure(target->getLoc()) << "op failed to multibuffer"; diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -82,8 +82,9 @@ // Returns success if the transformation happened and failure otherwise. // This is not a pattern as it requires propagating the new memref type to its // uses and requires updating subview ops. -FailureOr mlir::memref::multiBuffer(memref::AllocOp allocOp, - unsigned multiplier) { +FailureOr +mlir::memref::multiBuffer(memref::AllocOp allocOp, unsigned multiplier, + bool skipOverrideAnalysis) { LLVM_DEBUG(DBGS() << "Try multibuffer: " << allocOp << "\n"); DominanceInfo dom(allocOp->getParentOp()); LoopLikeOpInterface candidateLoop; @@ -93,17 +94,29 @@ LLVM_DEBUG(DBGS() << "Skip user: no parent loop\n"); return failure(); } - /// Make sure there is no loop-carried dependency on the allocation. 
- if (!overrideBuffer(user, allocOp.getResult())) { - LLVM_DEBUG(DBGS() << "Skip user: found loop-carried dependence\n"); - continue; - } - // If this user doesn't dominate all the other users keep looking. - if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { - return !dom.dominates(user, otherUser); - })) { - LLVM_DEBUG(DBGS() << "Skip user: does not dominate all other users\n"); - continue; + if (!skipOverrideAnalysis) { + /// Make sure there is no loop-carried dependency on the allocation. + if (!overrideBuffer(user, allocOp.getResult())) { + LLVM_DEBUG(DBGS() << "Skip user: found loop-carried dependence\n"); + continue; + } + // If this user doesn't dominate all the other users keep looking. + if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { + return !dom.dominates(user, otherUser); + })) { + LLVM_DEBUG(DBGS() << "Skip user: does not dominate all other users\n"); + continue; + } + } else { + if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { + return !isa(otherUser) && + !parentLoop->isProperAncestor(otherUser); + })) { + LLVM_DEBUG( + DBGS() + << "Skip user: not all other users are in the parent loop\n"); + continue; + } } candidateLoop = parentLoop; break; diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -185,3 +185,37 @@ // Verify that the returned handle is usable. 
transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation } + +// ----- + + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)> + +// CHECK-LABEL: func @multi_buffer_no_analysis +func.func @multi_buffer_no_analysis(%in: memref<16xf32>) { + // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32> + // expected-remark @below {{transformed}} + %tmp = memref.alloc() : memref<4xf32> + + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + + // CHECK: scf.for %[[IV:.*]] = %[[C0]] + scf.for %i0 = %c0 to %c16 step %c4 { + // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]]) + // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>> + "some_write_read"(%tmp) : (memref<4xf32>) ->() + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc"> + %1 = transform.memref.multibuffer %0 {factor = 2 : i64, skip_analysis} : (!transform.op<"memref.alloc">) -> !pdl.operation + // Verify that the returned handle is usable. + transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation +}