Index: mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td =================================================================== --- mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -25,6 +25,9 @@ iterations. This transform expands the size of an allocation by a given multiplicative factor and fixes up any users of the multibuffered allocation. + If `skip_verification` is not set, the transformation will only apply + if it can prove that there is no data being carried across loop + iterations. #### Return modes @@ -34,7 +37,8 @@ let arguments = (ins PDL_Operation:$target, - ConfinedAttr<I64Attr, [IntPositive]>:$factor); + ConfinedAttr<I64Attr, [IntPositive]>:$factor, + UnitAttr:$skip_verification); let results = (outs PDL_Operation:$transformed); Index: mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h =================================================================== --- mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -78,7 +78,10 @@ /// on the temporary allocation between consecutive loop iterations. /// It returns the new allocation if the original allocation was multi-buffered /// and returns failure() otherwise. -/// Example: +/// When `skipOverrideVerification` is true, the transformation is applied +/// without checking that the buffer is overwritten at the beginning of each +/// iteration. This implies that the user knows that no data is carried across +/// loop iterations. 
Example: /// ``` /// %0 = memref.alloc() : memref<4x128xf32> /// scf.for %iv = %c1 to %c1024 step %c3 { @@ -100,7 +103,8 @@ /// } /// ``` FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp, - unsigned multiplier); + unsigned multiplier, + bool skipOverrideVerification = false); //===----------------------------------------------------------------------===// // Passes Index: mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp =================================================================== --- mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -24,7 +24,8 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::applyToOne( memref::AllocOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - auto newBuffer = memref::multiBuffer(target, getFactor()); + auto newBuffer = + memref::multiBuffer(target, getFactor(), getSkipVerification()); if (failed(newBuffer)) { Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note); diag << "op failed to multibuffer"; Index: mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp =================================================================== --- mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -77,22 +77,32 @@ // Returns success if the transformation happened and failure otherwise. // This is not a pattern as it requires propagating the new memref type to its // uses and requires updating subview ops. 
-FailureOr<memref::AllocOp> mlir::memref::multiBuffer(memref::AllocOp allocOp, - unsigned multiplier) { +FailureOr<memref::AllocOp> +mlir::memref::multiBuffer(memref::AllocOp allocOp, unsigned multiplier, + bool skipOverrideVerification) { DominanceInfo dom(allocOp->getParentOp()); LoopLikeOpInterface candidateLoop; for (Operation *user : allocOp->getUsers()) { auto parentLoop = user->getParentOfType<LoopLikeOpInterface>(); if (!parentLoop) return failure(); - /// Make sure there is no loop carried dependency on the allocation. - if (!overrideBuffer(user, allocOp.getResult())) - continue; - // If this user doesn't dominate all the other users keep looking. - if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { - return !dom.dominates(user, otherUser); - })) - continue; + if (!skipOverrideVerification) { + /// Make sure there is no loop carried dependency on the allocation. + if (!overrideBuffer(user, allocOp.getResult())) + continue; + // If this user doesn't dominate all the other users keep looking. + if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { + return !isa<memref::DeallocOp>(otherUser) && + !dom.dominates(user, otherUser); + })) + continue; + } else { + if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { + return !isa<memref::DeallocOp>(otherUser) && + !parentLoop->isProperAncestor(otherUser); + })) + continue; + } candidateLoop = parentLoop; break; } Index: mlir/test/Dialect/MemRef/transform-ops.mlir =================================================================== --- mlir/test/Dialect/MemRef/transform-ops.mlir +++ mlir/test/Dialect/MemRef/transform-ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)> // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> @@ -35,3 +35,37 @@ // Verify that the 
returned handle is usable. transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation } + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)> + +// CHECK-LABEL: func @multi_buffer_no_verification +func.func @multi_buffer_no_verification(%in: memref<16xf32>) { + // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32> + // expected-remark @below {{transformed}} + %tmp = memref.alloc() : memref<4xf32> + + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + + // CHECK: scf.for %[[IV:.*]] = %[[C0]] + scf.for %i0 = %c0 to %c16 step %c4 { + // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]]) + // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>> + // CHECK: "some_write_read"(%[[SV]]) + "some_write_read"(%tmp) : (memref<4xf32>) ->() + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.memref.multibuffer %0 {factor = 2 : i64, skip_verification} + // Verify that the returned handle is usable. + transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation +}