diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -108,7 +108,8 @@ def AffineForOp : Affine_Op<"for", [AutomaticAllocationScope, ImplicitAffineTerminator, RecursiveSideEffects, - DeclareOpInterfaceMethods<LoopLikeOpInterface>]> { + DeclareOpInterfaceMethods<LoopLikeOpInterface, + ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep"]>]> { let summary = "for operation"; let description = [{ Syntax: diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -27,7 +27,7 @@ } // namespace vector namespace memref { - +class AllocOp; //===----------------------------------------------------------------------===// // Patterns //===----------------------------------------------------------------------===// @@ -51,6 +51,33 @@ /// terms of shapes of its input operands. void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns); +/// Transformation to do multi-buffering/array expansion to remove dependencies +/// on the temporary allocation between consecutive loop iterations. +/// It returns success() if the allocation was multi-buffered and failure() +/// otherwise. 
+/// Example: +/// ``` +/// %0 = memref.alloc() : memref<4x128xf32> +/// scf.for %iv = %c1 to %c1024 step %c3 { +/// memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32> +/// "some_use"(%0) : (memref<4x128xf32>) -> () +/// } +/// ``` +/// into: +/// ``` +/// %0 = memref.alloc() : memref<5x4x128xf32> +/// scf.for %iv = %c1 to %c1024 step %c3 { +/// %s = arith.subi %iv, %c1 : index +/// %d = arith.divsi %s, %c3 : index +/// %i = arith.remsi %d, %c5 : index +/// %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] : +/// memref<5x4x128xf32> to memref<4x128xf32, #map0> +/// memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, #map0> +/// "some_use"(%sv) : (memref<4x128xf32, #map0>) -> () +/// } +/// ``` +LogicalResult multiBuffer(memref::AllocOp allocOp, unsigned multiplier); + //===----------------------------------------------------------------------===// // Passes //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -110,7 +110,8 @@ } def ForOp : SCF_Op<"for", - [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface>, + [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface, + ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep"]>, DeclareOpInterfaceMethods<RegionBranchOpInterface>, SingleBlockImplicitTerminator<"scf::YieldOp">, RecursiveSideEffects]> { diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td @@ -44,6 +44,42 @@ }], "::mlir::LogicalResult", "moveOutOfLoop", (ins "::mlir::ArrayRef<::mlir::Operation *>":$ops) >, + InterfaceMethod<[{ + If there is a single induction variable, return it; otherwise return + llvm::None. 
+ }], + /*retTy=*/"::mlir::Optional<::mlir::Value>", + /*methodName=*/"getSingleInductionVar", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return llvm::None; + }] + >, + InterfaceMethod<[{ + Return the single lower bound value or attribute if it exists, otherwise + return llvm::None. + }], + /*retTy=*/"::mlir::Optional<::mlir::OpFoldResult>", + /*methodName=*/"getSingleLowerBound", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return llvm::None; + }] + >, + InterfaceMethod<[{ + Return the single step value or attribute if it exists, otherwise + return llvm::None. + }], + /*retTy=*/"::mlir::Optional<::mlir::OpFoldResult>", + /*methodName=*/"getSingleStep", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return llvm::None; + }] + >, ]; } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1820,6 +1820,22 @@ return !region().isAncestor(value.getParentRegion()); } +Optional<Value> AffineForOp::getSingleInductionVar() { + return getInductionVar(); +} + +Optional<OpFoldResult> AffineForOp::getSingleLowerBound() { + if (!hasConstantLowerBound()) + return llvm::None; + Builder b(getContext()); + return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound())); +} + +Optional<OpFoldResult> AffineForOp::getSingleStep() { + Builder b(getContext()); + return OpFoldResult(b.getI64IntegerAttr(getStep())); +} + LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) { for (auto *op : ops) op->moveBefore(*this); diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ ComposeSubView.cpp ExpandOps.cpp FoldSubViewOps.cpp + MultiBuffer.cpp NormalizeMemRefs.cpp ResolveShapedTypeResultDims.cpp @@ -16,6 +17,7 @@
MLIRAffineUtils MLIRArithmetic MLIRInferTypeOpInterface + MLIRLoopLikeInterface MLIRMemRef MLIRPass MLIRStandard
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -0,0 +1,166 @@
+//===- MultiBuffer.cpp - Multi buffering transformation ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the multi-buffering transformation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+
+using namespace mlir;
+
+/// Return true if `op` fully overwrites the given `buffer` value. Currently
+/// only a `memref.copy` whose target is `buffer` qualifies.
+static bool overrideBuffer(Operation *op, Value buffer) {
+  auto copyOp = dyn_cast<memref::CopyOp>(op);
+  if (!copyOp)
+    return false;
+  return copyOp.target() == buffer;
+}
+
+/// Replace the uses of `oldOp` with the given `val` and, for subview uses,
+/// propagate the type change. Changing the memref type may require propagating
+/// it through subview ops, so a plain replaceAllUsesWith is not enough: each
+/// subview user is re-created with the new type and the old one is erased.
+static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
+                                        OpBuilder &builder) {
+  SmallVector<Operation *> opToDelete;
+  SmallVector<OpOperand *> operandsToReplace;
+  for (OpOperand &use : oldOp->getUses()) {
+    auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
+    if (!subviewUse) {
+      // Record the operand and update it after the loop so that the use-list
+      // iterator is not invalidated.
+      operandsToReplace.push_back(&use);
+      continue;
+    }
+    builder.setInsertionPoint(subviewUse);
+    Type newType = memref::SubViewOp::inferRankReducedResultType(
+        subviewUse.getType().getRank(), val.getType().cast<MemRefType>(),
+        extractFromI64ArrayAttr(subviewUse.static_offsets()),
+        extractFromI64ArrayAttr(subviewUse.static_sizes()),
+        extractFromI64ArrayAttr(subviewUse.static_strides()));
+    Value newSubview = builder.create<memref::SubViewOp>(
+        subviewUse->getLoc(), newType.cast<MemRefType>(), val,
+        subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
+        subviewUse.getMixedStrides());
+    // The old subview's users now need the new subview's type as well.
+    replaceUsesAndPropagateType(subviewUse, newSubview, builder);
+    opToDelete.push_back(use.getOwner());
+  }
+  for (OpOperand *operand : operandsToReplace)
+    operand->set(val);
+  // Clean up old subview ops.
+  for (Operation *op : opToDelete)
+    op->erase();
+}
+
+/// Return the Value held by `res`, or materialize its integer attribute as an
+/// index constant at the builder's current insertion point.
+static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder,
+                              Location loc) {
+  Value value = res.dyn_cast<Value>();
+  if (value)
+    return value;
+  return builder.create<arith::ConstantIndexOp>(
+      loc, res.dyn_cast<Attribute>().cast<IntegerAttr>().getInt());
+}
+
+// Transformation to do multi-buffering/array expansion to remove dependencies
+// on the temporary allocation between consecutive loop iterations.
+// Returns success if the transformation happened and failure otherwise.
+// This is not a pattern as it requires propagating the new memref type to its
+// uses and requires updating subview ops.
+LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
+                                        unsigned multiplier) {
+  // A zero multiplier would create an empty buffer and a `mod 0` index.
+  if (multiplier == 0)
+    return failure();
+  // The new allocation and the subview sizes are derived from the static
+  // shape; dynamic sizes would require forwarding the alloc's size operands.
+  if (!allocOp.getType().hasStaticShape())
+    return failure();
+  DominanceInfo dom(allocOp->getParentOp());
+  LoopLikeOpInterface candidateLoop;
+  for (Operation *user : allocOp->getUsers()) {
+    auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
+    if (!parentLoop)
+      return failure();
+    // Make sure there is no loop carried dependency on the allocation.
+    if (!overrideBuffer(user, allocOp.getResult()))
+      continue;
+    // If this user doesn't dominate all the other users keep looking.
+    if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
+          return !dom.dominates(user, otherUser);
+        }))
+      continue;
+    candidateLoop = parentLoop;
+    break;
+  }
+  if (!candidateLoop)
+    return failure();
+  // The new alloc replaces the old one in place while the subview into it is
+  // created at the start of the loop body, so the allocation must dominate
+  // the candidate loop (i.e. be defined outside of it).
+  if (!dom.dominates(allocOp.getOperation(), candidateLoop.getOperation()))
+    return failure();
+  llvm::Optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
+  llvm::Optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
+  llvm::Optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
+  if (!inductionVar || !lowerBound || !singleStep)
+    return failure();
+  OpBuilder builder(candidateLoop);
+  Value stepValue =
+      getOrCreateValue(*singleStep, builder, candidateLoop->getLoc());
+  Value lowerBoundValue =
+      getOrCreateValue(*lowerBound, builder, candidateLoop->getLoc());
+  // The new buffer has a leading dimension of size `multiplier`.
+  SmallVector<int64_t, 4> newShape(1, multiplier);
+  ArrayRef<int64_t> oldShape = allocOp.getType().getShape();
+  newShape.append(oldShape.begin(), oldShape.end());
+  auto newMemref = MemRefType::get(newShape, allocOp.getType().getElementType(),
+                                   MemRefLayoutAttrInterface(),
+                                   allocOp.getType().getMemorySpace());
+  builder.setInsertionPoint(allocOp);
+  Location loc = allocOp->getLoc();
+  auto newAlloc = builder.create<memref::AllocOp>(loc, newMemref);
+  builder.setInsertionPoint(&candidateLoop.getLoopBody().front(),
+                            candidateLoop.getLoopBody().front().begin());
+  // bufferIndex = ((iv - lb) floordiv step) mod multiplier, i.e. the
+  // iteration number modulo the number of buffers.
+  AffineExpr induc = getAffineDimExpr(0, allocOp.getContext());
+  AffineExpr init = getAffineDimExpr(1, allocOp.getContext());
+  AffineExpr step = getAffineDimExpr(2, allocOp.getContext());
+  AffineExpr expr = ((induc - init).floorDiv(step)) % multiplier;
+  auto map = AffineMap::get(3, 0, expr);
+  std::array<Value, 3> operands = {*inductionVar, lowerBoundValue, stepValue};
+  Value bufferIndex = builder.create<AffineApplyOp>(loc, map, operands);
+  // Take slice `bufferIndex` of the new buffer (offsets [bufferIndex, 0, ...],
+  // sizes [1, oldShape...], unit strides) and rank-reduce it back to the
+  // original rank.
+  SmallVector<OpFoldResult> offsets, sizes, strides;
+  offsets.push_back(bufferIndex);
+  offsets.append(oldShape.size(), builder.getIndexAttr(0));
+  strides.assign(oldShape.size() + 1, builder.getIndexAttr(1));
+  sizes.push_back(builder.getIndexAttr(1));
+  for (int64_t size : oldShape)
+    sizes.push_back(builder.getIndexAttr(size));
+  auto dstMemref =
+      memref::SubViewOp::inferRankReducedResultType(
+          allocOp.getType().getRank(), newMemref, offsets, sizes, strides)
+          .cast<MemRefType>();
+  Value subview = builder.create<memref::SubViewOp>(loc, dstMemref, newAlloc,
+                                                    offsets, sizes, strides);
+  replaceUsesAndPropagateType(allocOp, subview, builder);
+  allocOp.erase();
+  return success();
+}
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -320,6 +320,16 @@ return RegionBranchOpInterface::verifyTypes(*this); } +Optional<Value> ForOp::getSingleInductionVar() { return getInductionVar(); } + +Optional<OpFoldResult> ForOp::getSingleLowerBound() { + return OpFoldResult(getLowerBound()); +} + +Optional<OpFoldResult> ForOp::getSingleStep() { + return OpFoldResult(getStep()); +} + /// Prints the initialization list in the form of /// (%inner = %outer, %inner2 = %outer2, <...>) /// where 'inner' values are assumed to be region arguments and 'outer' values diff --git a/mlir/test/Dialect/MemRef/multibuffer.mlir b/mlir/test/Dialect/MemRef/multibuffer.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/MemRef/multibuffer.mlir @@ -0,0 +1,110 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -test-multi-buffering=multiplier=5 -cse -split-input-file | FileCheck %s + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (((d0 - d1) floordiv d2) mod 5)> + +// CHECK-LABEL: func @multi_buffer +func @multi_buffer(%a: memref<1024x1024xf32>) { +// CHECK-DAG: %[[A:.*]] = memref.alloc() : memref<5x4x128xf32> +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index + %0 = memref.alloc() : memref<4x128xf32> + %c1024 = arith.constant 1024 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index +// CHECK: scf.for %[[IV:.*]] = %[[C1]] + scf.for %arg2 = %c1 to %c1024 step %c3 { +// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]], %[[C1]], %[[C3]]) +//
CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, #[[$MAP0]]> + %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] : + memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> +// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, #[[$MAP0]]> + memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32> +// CHECK: "some_use"(%[[SV]]) : (memref<4x128xf32, #[[$MAP0]]>) -> () + "some_use"(%0) : (memref<4x128xf32>) -> () +// CHECK: "some_use"(%[[SV]]) : (memref<4x128xf32, #[[$MAP0]]>) -> () + "some_use"(%0) : (memref<4x128xf32>) -> () + } + return +} + +// ----- + +// CHECK-LABEL: func @multi_buffer_affine +func @multi_buffer_affine(%a: memref<1024x1024xf32>) { +// CHECK-DAG: %[[A:.*]] = memref.alloc() : memref<5x4x128xf32> +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index + %0 = memref.alloc() : memref<4x128xf32> + %c1024 = arith.constant 1024 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index +// CHECK: affine.for %[[IV:.*]] = 1 + affine.for %arg2 = 1 to 1024 step 3 { +// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]], %[[C1]], %[[C3]]) +// CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, #[[$MAP0]]> + %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] : + memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> +// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, #[[$MAP0]]> + memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32> +// CHECK: "some_use"(%[[SV]]) : (memref<4x128xf32, #[[$MAP0]]>) -> () + "some_use"(%0) : (memref<4x128xf32>) -> () +// CHECK: "some_use"(%[[SV]]) : (memref<4x128xf32, 
#[[$MAP0]]>) -> () + "some_use"(%0) : (memref<4x128xf32>) -> () + } + return +} + +// ----- + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (((d0 - d1) floordiv d2) mod 5)> + +// CHECK-LABEL: func @multi_buffer_subview_use +func @multi_buffer_subview_use(%a: memref<1024x1024xf32>) { +// CHECK-DAG: %[[A:.*]] = memref.alloc() : memref<5x4x128xf32> +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index + %0 = memref.alloc() : memref<4x128xf32> + %c1024 = arith.constant 1024 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index +// CHECK: scf.for %[[IV:.*]] = %[[C1]] + scf.for %arg2 = %c1 to %c1024 step %c3 { +// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]], %[[C1]], %[[C3]]) +// CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, #[[$MAP0]]> + %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] : + memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> +// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, #[[$MAP0]]> + memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32> +// CHECK: %[[SV1:.*]] = memref.subview %[[SV]][0, 1] [4, 127] [1, 1] : memref<4x128xf32, #[[$MAP0]]> to memref<4x127xf32, #[[$MAP0]]> + %s = memref.subview %0[0, 1] [4, 127] [1, 1] : + memref<4x128xf32> to memref<4x127xf32, affine_map<(d0, d1) -> (d0 * 128 + d1 + 1)>> +// CHECK: "some_use"(%[[SV1]]) : (memref<4x127xf32, #[[$MAP0]]>) -> () + "some_use"(%s) : (memref<4x127xf32, affine_map<(d0, d1) -> (d0 * 128 + d1 + 1)>>) -> () +// CHECK: "some_use"(%[[SV]]) : (memref<4x128xf32, #[[$MAP0]]>) -> () + "some_use"(%0) : (memref<4x128xf32>) -> () + } + return +} + +// ----- + +// CHECK-LABEL: func @multi_buffer_negative +func 
@multi_buffer_negative(%a: memref<1024x1024xf32>) { +// CHECK-NOT: %{{.*}} = memref.alloc() : memref<5x4x128xf32> +// CHECK: %{{.*}} = memref.alloc() : memref<4x128xf32> + %0 = memref.alloc() : memref<4x128xf32> + %c1024 = arith.constant 1024 : index + %c0 = arith.constant 0 : index + %c3 = arith.constant 3 : index + scf.for %arg2 = %c0 to %c1024 step %c3 { + "blocking_use"(%0) : (memref<4x128xf32>) -> () + %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] : + memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> + memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32> + "some_use"(%0) : (memref<4x128xf32>) -> () + } + return +} + diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt --- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt +++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt @@ -1,11 +1,13 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRMemRefTestPasses TestComposeSubView.cpp + TestMultiBuffer.cpp EXCLUDE_FROM_LIBMLIR LINK_LIBS PUBLIC MLIRPass + MLIRMemRef MLIRMemRefTransforms MLIRTestDialect ) diff --git a/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp b/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp @@ -0,0 +1,53 @@ +//===- TestMultiBuffer.cpp - Test multi buffering -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+struct TestMultiBufferingPass
+    : public PassWrapper<TestMultiBufferingPass, OperationPass<FuncOp>> {
+  TestMultiBufferingPass() = default;
+  TestMultiBufferingPass(const TestMultiBufferingPass &pass)
+      : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect>();
+  }
+  StringRef getArgument() const final { return "test-multi-buffering"; }
+  StringRef getDescription() const final {
+    return "Test multi buffering transformation";
+  }
+  void runOnOperation() override;
+  Option<unsigned> multiplier{
+      *this, "multiplier",
+      llvm::cl::desc(
+          "Decide how many versions of the buffer should be created"),
+      llvm::cl::init(2)};
+};
+
+void TestMultiBufferingPass::runOnOperation() {
+  SmallVector<memref::AllocOp> allocs;
+  getOperation().walk(
+      [&allocs](memref::AllocOp alloc) { allocs.push_back(alloc); });
+  for (memref::AllocOp alloc : allocs)
+    (void)memref::multiBuffer(alloc, multiplier);
+}
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestMultiBuffering() {
+  PassRegistration<TestMultiBufferingPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -74,6 +74,7 @@ void registerTestDynamicPipelinePass(); void registerTestExpandTanhPass(); void registerTestComposeSubView(); +void registerTestMultiBuffering(); void registerTestGpuParallelLoopMappingPass(); void registerTestIRVisitorsPass(); void registerTestGenericIRVisitorsPass(); @@ -164,6 +165,7 @@ mlir::test::registerTestDynamicPipelinePass(); mlir::test::registerTestExpandTanhPass(); mlir::test::registerTestComposeSubView(); + mlir::test::registerTestMultiBuffering(); mlir::test::registerTestGpuParallelLoopMappingPass(); mlir::test::registerTestIRVisitorsPass(); mlir::test::registerTestGenericIRVisitorsPass(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8037,6 +8037,7 @@ ":ArithmeticTransforms", ":IR", ":InferTypeOpInterface", + ":LoopLikeInterface", ":MemRefDialect", ":MemRefPassIncGen", ":Pass", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -438,6 +438,7 @@ deps = [ ":TestDialect", "//mlir:Affine", + "//mlir:MemRefDialect", "//mlir:MemRefTransforms", "//mlir:Pass", "//mlir:Transforms",