diff --git a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Arith/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Arith/Utils/Utils.h @@ -76,18 +76,18 @@ Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr); -/// Create a cast from an index-like value (index or integer) to another -/// index-like value. If the value type and the target type are the same, it -/// returns the original value. -Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, - Type targetType, Value value); - /// Similar to the other overload, but converts multiple OpFoldResults into /// Values. SmallVector getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, ArrayRef valueOrAttrVec); +/// Create a cast from an index-like value (index or integer) to another +/// index-like value. If the value type and the target type are the same, it +/// returns the original value. +Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, + Type targetType, Value value); + /// Converts a scalar value `operand` to type `toType`. If the value doesn't /// convert, a warning will be issued and the operand is returned as is (which /// will presumably yield a verification issue downstream). diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -19,6 +19,7 @@ class AffineDialect; class ModuleOp; +class RewriterBase; namespace arith { class WideIntEmulationConverter; @@ -102,6 +103,11 @@ /// "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> () /// } /// ``` +FailureOr multiBuffer(RewriterBase &rewriter, + memref::AllocOp allocOp, + unsigned multiplier, + bool skipOverrideAnalysis = false); +/// Call into `multiBuffer` with locally constructed IRRewriter. 
FailureOr multiBuffer(memref::AllocOp allocOp, unsigned multiplier, bool skipOverrideAnalysis = false); diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SCF_SCF_H #define MLIR_DIALECT_SCF_SCF_H +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -573,17 +573,17 @@ /// Get lower bounds as values. SmallVector getLowerBound(OpBuilder &b) { - return getAsValues(b, getLoc(), getMixedLowerBound()); + return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedLowerBound()); } /// Get upper bounds as values. SmallVector getUpperBound(OpBuilder &b) { - return getAsValues(b, getLoc(), getMixedUpperBound()); + return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedUpperBound()); } /// Get steps as values. SmallVector getStep(OpBuilder &b) { - return getAsValues(b, getLoc(), getMixedStep()); + return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedStep()); } int64_t getRank() { return getStaticLowerBound().size(); } diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -85,12 +85,20 @@ bool isEqualConstantIntOrValueArray(ArrayRef ofrs1, ArrayRef ofrs2); -/// Helper function to convert a vector of `OpFoldResult`s into a vector of -/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold -/// result if it casts to a `Value` or create an index-type constant if it -/// casts to `IntegerAttr`. 
No other attribute types are supported. -SmallVector getAsValues(OpBuilder &b, Location loc, - ArrayRef valueOrAttrVec); +// To convert an OpFoldResult to a Value of index type, see: +// mlir/include/mlir/Dialect/Arith/Utils/Utils.h +// TODO: find a better common landing place. +// +// Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, +// OpFoldResult ofr); + +// To convert a vector of OpFoldResults to Values of index type, see: +// mlir/include/mlir/Dialect/Arith/Utils/Utils.h +// TODO: find a better common landing place. +// +// SmallVector +// getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, +// ArrayRef valueOrAttrVec); /// Return a vector of OpFoldResults with the same size a staticValues, but /// all elements for which ShapedType::isDynamic is true, will be replaced by diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -674,7 +674,7 @@ return !isConstantIntValue(ofr, 0); })); SmallVector materializedNonZeroNumThreads = - getAsValues(b, loc, nonZeroNumThreads); + getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads); // 2. Create the ForallOp with an empty region. 
scf::ForallOp forallOp = b.create( diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -15,9 +15,13 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Interfaces/LoopLikeInterface.h" +#include "llvm/Support/Debug.h" using namespace mlir; +#define DEBUG_TYPE "memref-transforms" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") + //===----------------------------------------------------------------------===// // MemRefMultiBufferOp //===----------------------------------------------------------------------===// @@ -27,25 +31,36 @@ transform::TransformState &state) { SmallVector results; ArrayRef payloadOps = state.getPayloadOps(getTarget()); + IRRewriter rewriter(getContext()); for (auto *op : payloadOps) { bool canApplyMultiBuffer = true; auto target = cast(op); + LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";); // Skip allocations not used in a loop. 
for (Operation *user : target->getUsers()) { + if (isa(user)) + continue; auto loop = user->getParentOfType(); if (!loop) { + LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n"; + DBGS() << "----due to user: " << *user;); canApplyMultiBuffer = false; break; } } - if (!canApplyMultiBuffer) + if (!canApplyMultiBuffer) { + LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";); continue; + } auto newBuffer = - memref::multiBuffer(target, getFactor(), getSkipAnalysis()); - if (failed(newBuffer)) + memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis()); + + if (failed(newBuffer)) { + LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";); return emitSilenceableFailure(target->getLoc()) << "op failed to multibuffer"; + } results.push_back(*newBuffer); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -11,10 +11,16 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/LoopLikeInterface.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" using namespace mlir; @@ -35,46 +41,52 @@ /// propagate the type change. Changing the memref type may require propagating /// it through subview ops so we cannot just do a replaceAllUse but need to /// propagate the type change and erase old subview ops. 
-static void replaceUsesAndPropagateType(Operation *oldOp, Value val, - OpBuilder &builder) { - SmallVector opToDelete; +static void replaceUsesAndPropagateType(RewriterBase &rewriter, + Operation *oldOp, Value val) { + SmallVector opsToDelete; SmallVector operandsToReplace; + + // Save the operand to replace / delete later (avoid iterator invalidation). + // TODO: can we use an early_inc iterator? for (OpOperand &use : oldOp->getUses()) { + // Non-subview ops will be replaced by `val`. auto subviewUse = dyn_cast(use.getOwner()); if (!subviewUse) { - // Save the operand to and replace outside the loop to not invalidate the - // iterator. operandsToReplace.push_back(&use); continue; } - builder.setInsertionPoint(subviewUse); + + // `subview(old_op)` is replaced by a new `subview(val)`. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(subviewUse); Type newType = memref::SubViewOp::inferRankReducedResultType( subviewUse.getType().getShape(), val.getType().cast(), subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), subviewUse.getStaticStrides()); - Value newSubview = builder.create( + Value newSubview = rewriter.create( subviewUse->getLoc(), newType.cast(), val, subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), subviewUse.getMixedStrides()); - replaceUsesAndPropagateType(subviewUse, newSubview, builder); - opToDelete.push_back(use.getOwner()); + + // Ouch recursion ... is this really necessary? + replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); + + opsToDelete.push_back(use.getOwner()); } - for (OpOperand *operand : operandsToReplace) + + // Perform late replacement. + // TODO: can we use an early_inc iterator? + for (OpOperand *operand : operandsToReplace) { + Operation *op = operand->getOwner(); + rewriter.startRootUpdate(op); operand->set(val); - // Clean up old subview ops. 
- for (Operation *op : opToDelete) - op->erase(); -} + rewriter.finalizeRootUpdate(op); + } -/// Helper to convert get a value from an OpFoldResult or create it at the -/// builder insert point. -static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder, - Location loc) { - Value value = res.dyn_cast(); - if (value) - return value; - return builder.create( - loc, res.dyn_cast().cast().getInt()); + // Perform late op erasure. + // TODO: can we use an early_inc iterator? + for (Operation *op : opsToDelete) + rewriter.eraseOp(op); } // Transformation to do multi-buffering/array expansion to remove dependencies @@ -83,28 +95,37 @@ // This is not a pattern as it requires propagating the new memref type to its // uses and requires updating subview ops. FailureOr -mlir::memref::multiBuffer(memref::AllocOp allocOp, unsigned multiplier, +mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, + unsigned multiBufferingFactor, bool skipOverrideAnalysis) { - LLVM_DEBUG(DBGS() << "Try multibuffer: " << allocOp << "\n"); + LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n"); DominanceInfo dom(allocOp->getParentOp()); LoopLikeOpInterface candidateLoop; for (Operation *user : allocOp->getUsers()) { auto parentLoop = user->getParentOfType(); if (!parentLoop) { - LLVM_DEBUG(DBGS() << "Skip user: no parent loop\n"); + if (isa(user)) { + // Allow dealloc outside of any loop. + // TODO: The whole precondition function here is very brittle and will + // need to be rethought and isolated into a cleaner analysis. + continue; + } + LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n"); + LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n"); return failure(); } if (!skipOverrideAnalysis) { /// Make sure there is no loop-carried dependency on the allocation. 
if (!overrideBuffer(user, allocOp.getResult())) { - LLVM_DEBUG(DBGS() << "Skip user: found loop-carried dependence\n"); + LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n"); continue; } // If this user doesn't dominate all the other users keep looking. if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) { return !dom.dominates(user, otherUser); })) { - LLVM_DEBUG(DBGS() << "Skip user: does not dominate all other users\n"); + LLVM_DEBUG( + DBGS() << "--Skip user: does not dominate all other users\n"); continue; } } else { @@ -114,17 +135,19 @@ })) { LLVM_DEBUG( DBGS() - << "Skip user: not all other users are in the parent loop\n"); + << "--Skip user: not all other users are in the parent loop\n"); continue; } } candidateLoop = parentLoop; break; } + if (!candidateLoop) { LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n"); return failure(); } + std::optional inductionVar = candidateLoop.getSingleInductionVar(); std::optional lowerBound = candidateLoop.getSingleLowerBound(); std::optional singleStep = candidateLoop.getSingleStep(); @@ -138,51 +161,89 @@ return failure(); } - OpBuilder builder(candidateLoop); - SmallVector newShape(1, multiplier); - ArrayRef oldShape = allocOp.getType().getShape(); - newShape.append(oldShape.begin(), oldShape.end()); - auto newMemref = MemRefType::get(newShape, allocOp.getType().getElementType(), - MemRefLayoutAttrInterface(), - allocOp.getType().getMemorySpace()); - builder.setInsertionPoint(allocOp); + LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n"); + + // 1. Construct the multi-buffered memref type. 
+ ArrayRef originalShape = allocOp.getType().getShape(); + SmallVector multiBufferedShape{multiBufferingFactor}; + llvm::append_range(multiBufferedShape, originalShape); + LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n"); + MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType()) + .setShape(multiBufferedShape) + .setLayout(MemRefLayoutAttrInterface()); + LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n"); + + // 2. Create the multi-buffered alloc. Location loc = allocOp->getLoc(); - auto newAlloc = builder.create(loc, newMemref, ValueRange{}, - allocOp->getAttrs()); - builder.setInsertionPoint(&candidateLoop.getLoopBody().front(), - candidateLoop.getLoopBody().front().begin()); - - SmallVector operands = {*inductionVar}; - AffineExpr induc = getAffineDimExpr(0, allocOp.getContext()); - unsigned dimCount = 1; - auto getAffineExpr = [&](OpFoldResult e) -> AffineExpr { - if (std::optional constValue = getConstantIntValue(e)) { - return getAffineConstantExpr(*constValue, allocOp.getContext()); - } - auto value = getOrCreateValue(e, builder, candidateLoop->getLoc()); - operands.push_back(value); - return getAffineDimExpr(dimCount++, allocOp.getContext()); - }; - auto init = getAffineExpr(*lowerBound); - auto step = getAffineExpr(*singleStep); - - AffineExpr expr = ((induc - init).floorDiv(step)) % multiplier; - auto map = AffineMap::get(dimCount, 0, expr); - Value bufferIndex = builder.create(loc, map, operands); - SmallVector offsets, sizes, strides; - offsets.push_back(bufferIndex); - offsets.append(oldShape.size(), builder.getIndexAttr(0)); - strides.assign(oldShape.size() + 1, builder.getIndexAttr(1)); - sizes.push_back(builder.getIndexAttr(1)); - for (int64_t size : oldShape) - sizes.push_back(builder.getIndexAttr(size)); - auto dstMemref = - memref::SubViewOp::inferRankReducedResultType( - allocOp.getType().getShape(), newMemref, offsets, sizes, strides) - .cast(); - Value subview = builder.create(loc, dstMemref, 
newAlloc, - offsets, sizes, strides); - replaceUsesAndPropagateType(allocOp, subview, builder); - allocOp.erase(); - return newAlloc; + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(allocOp); + auto mbAlloc = rewriter.create( + loc, mbMemRefType, ValueRange{}, allocOp->getAttrs()); + LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n"); + + // 3. Within the loop, build the modular leading index (i.e. each loop + // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor). + rewriter.setInsertionPointToStart(&candidateLoop.getLoopBody().front()); + Value ivVal = *inductionVar; + Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound); + Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep); + AffineExpr iv, lb, step; + bindDims(rewriter.getContext(), iv, lb, step); + Value bufferIndex = makeComposedAffineApply( + rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor, + {ivVal, lbVal, stepVal}); + LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n"); + + // 4. Build the subview accessing the particular slice, taking modular + // rotation into account. + int64_t mbMemRefTypeRank = mbMemRefType.getRank(); + IntegerAttr zero = rewriter.getIndexAttr(0); + IntegerAttr one = rewriter.getIndexAttr(1); + SmallVector offsets(mbMemRefTypeRank, zero); + SmallVector sizes(mbMemRefTypeRank, one); + SmallVector strides(mbMemRefTypeRank, one); + // Offset is [bufferIndex, 0 ... 0 ]. + offsets.front() = bufferIndex; + // Sizes is [1, original_size_0 ... original_size_n ]. + for (int64_t i = 0, e = originalShape.size(); i != e; ++i) + sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]); + // Strides is [1, 1 ... 1 ]. 
+ auto dstMemref = memref::SubViewOp::inferRankReducedResultType( + originalShape, mbMemRefType, offsets, sizes, strides) + .cast(); + Value subview = rewriter.create(loc, dstMemref, mbAlloc, + offsets, sizes, strides); + LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); + + // 5. Due to the recursive nature of replaceUsesAndPropagateType, we need to + // handle dealloc uses separately. + for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) { + auto deallocOp = dyn_cast(use.getOwner()); + if (!deallocOp) + continue; + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(deallocOp); + auto newDeallocOp = + rewriter.create(deallocOp->getLoc(), mbAlloc); + (void)newDeallocOp; + LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n"); + rewriter.eraseOp(deallocOp); + } + + // 6. RAUW with the particular slice, taking modular rotation into account. + replaceUsesAndPropagateType(rewriter, allocOp, subview); + + // 7. Finally, erase the old allocOp. + rewriter.eraseOp(allocOp); + + return mbAlloc; +} + +FailureOr +mlir::memref::multiBuffer(memref::AllocOp allocOp, + unsigned multiBufferingFactor, + bool skipOverrideAnalysis) { + IRRewriter rewriter(allocOp->getContext()); + return multiBuffer(rewriter, allocOp, multiBufferingFactor, + skipOverrideAnalysis); } diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -922,9 +922,10 @@ continue; // If the dest type of the cast does not preserve static information in // the source type. 
- if (!tensor::preservesStaticInformation(incomingCast.getDest().getType(), - incomingCast.getSource().getType())) - continue; + if (!tensor::preservesStaticInformation( + incomingCast.getDest().getType(), + incomingCast.getSource().getType())) + continue; if (!std::get<1>(it).hasOneUse()) continue; diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Interfaces/InferTypeOpInterface.h" @@ -146,7 +147,8 @@ auto resultShape = getReshapeOutputShapeFromInputShape( b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(), reshapeOp.getReassociationMaps()); - reifiedReturnShapes.push_back(getAsValues(b, loc, resultShape)); + reifiedReturnShapes.push_back( + getValueOrCreateConstantIndexOp(b, loc, resultShape)); return success(); } }; diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -146,18 +146,6 @@ return true; } -/// Helper function to convert a vector of `OpFoldResult`s into a vector of -/// `Value`s. For each `OpFoldResult` in `valueOrAttrVec` return the fold result -/// if it casts to a `Value` or create an index-type constant if it casts to -/// `IntegerAttr`. No other attribute types are supported. 
-SmallVector getAsValues(OpBuilder &b, Location loc, - ArrayRef valueOrAttrVec) { - return llvm::to_vector<4>( - llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { - return getValueOrCreateConstantIndexOp(b, loc, value); - })); -} - /// Return a vector of OpFoldResults with the same size a staticValues, but all /// elements for which ShapedType::isDynamic is true, will be replaced by /// dynamicValues. diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -219,3 +219,40 @@ // Verify that the returned handle is usable. transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation } + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)> + +// CHECK-LABEL: func @multi_buffer_dealloc +func.func @multi_buffer_dealloc(%in: memref<16xf32>) { + // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32> + // expected-remark @below {{transformed}} + %tmp = memref.alloc() : memref<4xf32> + + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C4:.*]] = arith.constant 4 : index + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + + // CHECK: scf.for %[[IV:.*]] = %[[C0]] + scf.for %i0 = %c0 to %c16 step %c4 { + // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]]) + // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>> + "some_write_read"(%tmp) : (memref<4xf32>) ->() + } + + // CHECK-NOT: memref.dealloc {{.*}} : memref<4xf32> + // CHECK: memref.dealloc %[[A]] : memref<2x4xf32> + memref.dealloc %tmp : memref<4xf32> + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc"> + %1 = 
transform.memref.multibuffer %0 {factor = 2 : i64, skip_analysis} : (!transform.op<"memref.alloc">) -> !pdl.operation + // Verify that the returned handle is usable. + transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation +}