diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -78,6 +78,8 @@ public: GenericLoopNestRangeBuilder(MutableArrayRef ivs, ArrayRef ranges); + GenericLoopNestRangeBuilder(MutableArrayRef ivs, + ArrayRef ranges); void operator()(std::function fun = nullptr) { (*builder)(fun); } private: diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -15,6 +15,7 @@ namespace mlir { namespace linalg { +struct LinalgTilingOptions; //===----------------------------------------------------------------------===// // Transformations exposed as function calls. @@ -34,10 +35,6 @@ /// An empty vector is interpreted as the identity permutation and the /// transformation returns early. /// -/// When non-null, the optional pointer `folder` is used to call into the -/// `createAndFold` builder method. If `folder` is null, the regular `create` -/// method is called. -/// /// Returns a struct containing the tiled loops in the specified order /// and the cloned op if successful, llvm::None otherwise. /// @@ -46,26 +43,7 @@ /// integers, in the range 0..`tileSizes.size()` without duplications /// (i.e. `[1,1,2]` is an invalid permutation). Optional tileLinalgOp(OpBuilder &b, LinalgOp op, - ArrayRef tileSizes, - ArrayRef interchangeVector = {}, - OperationFolder *folder = nullptr); -Optional -tileLinalgOpToParallelLoops(OpBuilder &b, LinalgOp op, - ArrayRef tileSizes, - ArrayRef interchangeVector = {}, - OperationFolder *folder = nullptr); - -/// Performs standalone tiling of a single LinalgOp by constant `tileSizes`. -/// See `tileLinalgOp(... 
ArrayRef tileSizes,)` for more details -Optional tileLinalgOp(OpBuilder &b, LinalgOp op, - ArrayRef tileSizes, - ArrayRef interchangeVector = {}, - OperationFolder *folder = nullptr); -Optional -tileLinalgOpToParallelLoops(OpBuilder &b, LinalgOp op, - ArrayRef tileSizes, - ArrayRef interchangeVector = {}, - OperationFolder *folder = nullptr); + const LinalgTilingOptions &options); /// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`. /// This is an in-place transformation controlled by `interchangeVector`. @@ -203,15 +181,35 @@ enum class LinalgTilingLoopType { Loops = 0, AffineLoops = 1, - ParallelLoops = 2 + ParallelLoops = 2, }; +using TileSizeComputationFunction = + std::function(OpBuilder &, Operation *)>; struct LinalgTilingOptions { - /// The tile sizes by which to tile. - SmallVector tileSizes{}; - LinalgTilingOptions &setTileSizes(ArrayRef ts) { - tileSizes.assign(ts.begin(), ts.end()); + /// Computation function that returns the tile sizes for each operation. + /// Delayed construction of constant tile sizes should occur to interoperate + /// with folding. + TileSizeComputationFunction tileSizeComputationFunction = nullptr; + LinalgTilingOptions & + setTileSizeComputationFunction(const TileSizeComputationFunction &fun) { + tileSizeComputationFunction = fun; + return *this; + } + /// Set the `tileSizeComputationFunction` to return the values `ts`. The + /// values must not fold away when tiling. Otherwise, use a more robust + /// `tileSizeComputationFunction`. `ts` is copied into the stored callback. + LinalgTilingOptions &setTileSizes(ValueRange ts) { + auto sizes = llvm::to_vector<4>(ts); + tileSizeComputationFunction = [sizes](OpBuilder &, Operation *) { + return sizes; + }; return *this; } + /// Convenience function to set the `tileSizeComputationFunction` to a + /// function that computes tile sizes at the point they are needed. Allows + /// proper interaction with folding. + LinalgTilingOptions &setTileSizes(ArrayRef ts); + /// The interchange vector to reorder the tiled loops. 
SmallVector interchangeVector{}; LinalgTilingOptions &setInterchange(ArrayRef interchange) { @@ -226,6 +223,12 @@ } }; +/// Canonicalization patterns relevant to apply after tiling patterns. These are +/// applied automatically by the tiling pass but need to be applied manually +/// when tiling is called programmatically. +OwningRewritePatternList +getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx); + struct LinalgBaseTilingPattern : public RewritePattern { LinalgBaseTilingPattern(StringRef opName, MLIRContext *context, LinalgTilingOptions options, diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -78,6 +78,60 @@ namespace mlir { namespace edsc { +static void unpackRanges(ArrayRef rangeOps, SmallVectorImpl &lbs, + SmallVectorImpl &ubs, + SmallVectorImpl &steps) { + for (Value range : rangeOps) { + assert(range.getType() && "expected linalg.range type"); + assert(range.getDefiningOp() && "need operations to extract range parts"); + RangeOp rangeOp = cast(range.getDefiningOp()); + lbs.emplace_back(rangeOp.min()); + ubs.emplace_back(rangeOp.max()); + steps.emplace_back(rangeOp.step()); + } +} + +static void unpackRanges(ArrayRef ranges, + SmallVectorImpl &lbs, + SmallVectorImpl &ubs, + SmallVectorImpl &steps) { + for (SubViewOp::Range range : ranges) { + lbs.emplace_back(range.offset); + ubs.emplace_back(range.size); + steps.emplace_back(range.stride); + } +} + +template <> +GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( + MutableArrayRef ivs, ArrayRef ranges) { + builder = std::make_unique(ivs, ranges); +} + +template <> +GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( + MutableArrayRef ivs, ArrayRef ranges) { + SmallVector lbs, ubs, steps; + unpackRanges(ranges, lbs, ubs, steps); + SmallVector constantSteps; + constantSteps.reserve(steps.size()); + for (Value v : steps) { + auto op = 
v.getDefiningOp(); + assert(op && "Affine loops require constant steps"); + constantSteps.push_back(op.getValue()); + } + builder = + std::make_unique(ivs, lbs, ubs, constantSteps); +} + +template <> +GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( + MutableArrayRef ivs, ArrayRef ranges) { + SmallVector lbs, ubs, steps; + unpackRanges(ranges, lbs, ubs, steps); + builder = std::make_unique(ivs, lbs, ubs, steps); +} + template <> GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( MutableArrayRef ivs, ArrayRef ranges) { @@ -87,32 +141,24 @@ template <> GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( MutableArrayRef ivs, ArrayRef ranges) { - SmallVector lbs; - SmallVector ubs; - SmallVector steps; - for (Value range : ranges) { - assert(range.getType() && "expected linalg.range type"); - assert(range.getDefiningOp() && "need operations to extract range parts"); - RangeOp rangeOp = cast(range.getDefiningOp()); - lbs.emplace_back(rangeOp.min()); - ubs.emplace_back(rangeOp.max()); - steps.emplace_back(rangeOp.step()); + SmallVector lbs, ubs, steps; + unpackRanges(ranges, lbs, ubs, steps); + SmallVector constantSteps; + constantSteps.reserve(steps.size()); + for (Value v : steps) { + auto op = v.getDefiningOp(); + assert(op && "Affine loops require constant steps"); + constantSteps.push_back(op.getValue()); } - builder = std::make_unique(ivs, lbs, ubs, steps); + builder = + std::make_unique(ivs, lbs, ubs, constantSteps); } template <> GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( MutableArrayRef ivs, ArrayRef ranges) { SmallVector lbs, ubs, steps; - for (Value range : ranges) { - assert(range.getType() && "expected linalg.range type"); - assert(range.getDefiningOp() && "need operations to extract range parts"); - RangeOp rangeOp = cast(range.getDefiningOp()); - lbs.emplace_back(rangeOp.min()); - ubs.emplace_back(rangeOp.max()); - steps.emplace_back(rangeOp.step()); - } + unpackRanges(ranges, lbs, ubs, steps); builder = 
std::make_unique(ivs, lbs, ubs, steps); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -290,6 +290,34 @@ return true; } +static bool isSameSubView(Value a, Value b) { + if (a == b) + return true; + auto sva = a.getDefiningOp(); + auto svb = b.getDefiningOp(); + if (!sva || !svb) + return false; + if (!isSameSubView(sva.getViewSource(), svb.getViewSource())) + return false; + if (sva.getType() != svb.getType()) + return false; + if (sva.getRank() != svb.getRank()) + return false; + if (sva.getNumOperands() != svb.getNumOperands()) + return false; + if (sva.static_offsets() != svb.static_offsets()) + return false; + if (sva.static_sizes() != svb.static_sizes()) + return false; + if (sva.static_strides() != svb.static_strides()) + return false; + /// Skip the "viewSource" operand. + for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx) + if (sva.getOperand(idx) != svb.getOperand(idx)) + return false; + return true; +} + static Optional fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, const LinalgDependenceGraph &graph, OperationFolder *folder, @@ -305,7 +333,7 @@ // Check that the dependence is indeed on the input `consumerIdx` view. 
auto consumedView = dependence.indexingView; - if (consumer.getBuffer(consumerIdx) != consumedView) + if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView)) continue; // Consumer consumes this view, `isStructurallyFusableProducer` also checks diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -38,8 +38,9 @@ #define DEBUG_TYPE "linalg-tiling" static bool isZero(Value v) { - return isa_and_nonnull(v.getDefiningOp()) && - cast(v.getDefiningOp()).getValue() == 0; + if (auto cst = v.getDefiningOp()) + return cst.getValue() == 0; + return false; } using LoopIndexToRangeIndexMap = DenseMap; @@ -55,11 +56,11 @@ // indices of newly created loops. static std::tuple, LoopIndexToRangeIndexMap> makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map, - ArrayRef allViewSizes, ArrayRef allTileSizes, - OperationFolder *folder) { + ArrayRef allViewSizes, + ArrayRef allTileSizes) { assert(allTileSizes.size() == map.getNumResults()); // Apply `map` to get view sizes in loop order. - auto viewSizes = applyMapToValues(b, loc, map, allViewSizes, folder); + auto viewSizes = applyMapToValues(b, loc, map, allViewSizes); SmallVector tileSizes(allTileSizes.begin(), allTileSizes.end()); // Traverse the tile sizes, which are in loop order, erase zeros everywhere. @@ -76,10 +77,9 @@ // Create a new range with the applied tile sizes. 
SmallVector res; - for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) { - res.push_back(SubViewOp::Range{folded_std_constant_index(folder, 0), - viewSizes[idx], tileSizes[idx]}); - } + for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) + res.push_back(SubViewOp::Range{std_constant_index(0), viewSizes[idx], + tileSizes[idx]}); return std::make_tuple(res, loopIndexToRangeIndex); } @@ -222,10 +222,11 @@ return false; } -static SmallVector -makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, - ArrayRef ivs, ArrayRef tileSizes, - ArrayRef viewSizes, OperationFolder *folder) { +static SmallVector makeTiledViews(OpBuilder &b, Location loc, + LinalgOp linalgOp, + ArrayRef ivs, + ArrayRef tileSizes, + ArrayRef viewSizes) { assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); assert(ivs.size() == static_cast(llvm::count_if( @@ -240,8 +241,7 @@ SmallVector lbs, subViewSizes; for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) { bool isTiled = !isZero(tileSizes[idx]); - lbs.push_back(isTiled ? ivs[idxIvs++] - : (Value)folded_std_constant_index(folder, 0)); + lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)std_constant_index(0)); subViewSizes.push_back(isTiled ? tileSizes[idx] : viewSizes[idx]); } @@ -270,18 +270,18 @@ strides.reserve(rank); for (unsigned r = 0; r < rank; ++r) { if (!isTiled(map.getSubMap({r}), tileSizes)) { - offsets.push_back(folded_std_constant_index(folder, 0)); + offsets.push_back(std_constant_index(0)); sizes.push_back(std_dim(view, r)); - strides.push_back(folded_std_constant_index(folder, 1)); + strides.push_back(std_constant_index(1)); continue; } // Tiling creates a new slice at the proper index, the slice step is 1 // (i.e. the slice view does not subsample, stepping occurs in the loop). 
auto m = map.getSubMap({r}); - auto offset = applyMapToValues(b, loc, m, lbs, folder).front(); + auto offset = applyMapToValues(b, loc, m, lbs).front(); offsets.push_back(offset); - auto size = applyMapToValues(b, loc, m, subViewSizes, folder).front(); + auto size = applyMapToValues(b, loc, m, subViewSizes).front(); // The size of the subview should be trimmed to avoid out-of-bounds // accesses, unless we statically know the subview size divides the view @@ -297,64 +297,64 @@ getAffineDimExpr(/*position=*/1, b.getContext()) - getAffineDimExpr(/*position=*/2, b.getContext())}, b.getContext()); - auto d = folded_std_dim(folder, view, r); - size = folded_affine_min(folder, b.getIndexType(), minMap, - ValueRange{size, d, offset}); + auto d = std_dim(view, r); + size = + affine_min(b.getIndexType(), minMap, ValueRange{size, d, offset}); } sizes.push_back(size); - strides.push_back(folded_std_constant_index(folder, 1)); + strides.push_back(std_constant_index(1)); } res.push_back(b.create(loc, view, offsets, sizes, strides)); } - // Traverse the mins/maxes and erase those that don't have uses left. - // This is a special type of folding that we only apply when `folder` is - // defined. - if (folder) - for (auto v : llvm::concat(lbs, subViewSizes)) - if (v.use_empty()) - v.getDefiningOp()->erase(); - return res; } template Optional static tileLinalgOpImpl( - OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef interchangeVector, OperationFolder *folder) { + OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) { + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(op); + ScopedContext scope(b, op.getLoc()); + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); // 1. Enforce the convention that "tiling by zero" skips tiling a particular // dimension. This convention is significantly simpler to handle instead of // adjusting affine maps to account for missing dimensions. 
- assert(op.getNumParallelLoops() + op.getNumReductionLoops() + - op.getNumWindowLoops() == - tileSizes.size() && - "expected matching number of tile sizes and loops"); + auto nLoops = op.getNumLoops(); + SmallVector tileSizeVector = + options.tileSizeComputationFunction(b, op); + if (tileSizeVector.size() < nLoops) { + auto zero = std_constant_index(0); + tileSizeVector.append(nLoops - tileSizeVector.size(), zero); + } + + ArrayRef tileSizes = tileSizeVector; + // Initial tile sizes may be too big, only take the first nLoops. + tileSizes = tileSizes.take_front(nLoops); + + if (llvm::all_of(tileSizes, isZero)) + return llvm::None; if (auto convOp = dyn_cast(op.getOperation())) { // For conv op only support tiling along batch dimension (which is the first // loop). - if (convOp.padding() && - !llvm::all_of(tileSizes.drop_front(), - [](Value val) { return isZero(val); })) + if (convOp.padding() && !llvm::all_of(tileSizes.drop_front(), isZero)) return llvm::None; } // If interchangeVector is empty, use the identity. Build the permutation map // otherwise. - auto invPermutationMap = AffineMap::getMultiDimIdentityMap( - tileSizes.size(), ScopedContext::getContext()); - if (!interchangeVector.empty()) + auto invPermutationMap = + AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext()); + if (!options.interchangeVector.empty()) invPermutationMap = inversePermutation(AffineMap::getPermutationMap( - interchangeVector, ScopedContext::getContext())); + options.interchangeVector, b.getContext())); if (!invPermutationMap) return llvm::None; - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - ScopedContext scope(b, op.getLoc()); // 2. Build the tiled loop ranges. 
auto viewSizes = getViewSizes(b, op); // The flattened loopToOperandRangesMaps is expected to be an invertible @@ -368,22 +368,15 @@ SmallVector loopRanges; LoopIndexToRangeIndexMap loopIndexToRangeIndex; - std::tie(loopRanges, loopIndexToRangeIndex) = - makeTiledLoopRanges(b, scope.getLocation(), viewSizesToLoopsMap, - viewSizes, tileSizes, folder); - if (!interchangeVector.empty()) - applyPermutationToVector(loopRanges, interchangeVector); + std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges( + b, scope.getLocation(), viewSizesToLoopsMap, viewSizes, tileSizes); + if (!options.interchangeVector.empty()) + applyPermutationToVector(loopRanges, options.interchangeVector); // 3. Create the tiled loops. LinalgOp res = op; SmallVector ivs(loopRanges.size()); - // Convert SubViewOp::Range to linalg_range. - SmallVector linalgRanges; - for (auto &range : loopRanges) { - linalgRanges.push_back( - linalg_range(range.offset, range.size, range.stride)); - } - GenericLoopNestRangeBuilder(ivs, linalgRanges)([&] { + GenericLoopNestRangeBuilder(ivs, loopRanges)([&] { auto &b = ScopedContext::getBuilderRef(); auto loc = ScopedContext::getLocation(); SmallVector ivValues(ivs.begin(), ivs.end()); @@ -393,11 +386,10 @@ // assuming that loopRanges have previously been permuted by // (i,j,k)->(k,i,j) So this permutation should be the inversePermutation of // that one: (d0,d1,d2)->(d2,d0,d1) - if (!interchangeVector.empty()) - ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues, folder); + if (!options.interchangeVector.empty()) + ivValues = applyMapToValues(b, loc, invPermutationMap, ivValues); - auto views = - makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes, folder); + auto views = makeTiledViews(b, loc, op, ivValues, tileSizes, viewSizes); auto operands = getAssumedNonViewOperands(op); views.append(operands.begin(), operands.end()); res = op.clone(b, loc, views); @@ -413,102 +405,106 @@ loops.push_back(iv.cast().getOwner()->getParentOp()); 
assert(loops.back() && "no owner found for induction variable!"); } - return TiledLinalgOp{res, loops}; } -template -static Optional -tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef interchangeVector, - OperationFolder *folder) { - assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); - if (tileSizes.empty()) - return llvm::None; - - // The following uses the convention that "tiling by zero" skips tiling a - // particular dimension. This convention is significantly simpler to handle - // instead of adjusting affine maps to account for missing dimensions. - auto nLoops = op.getNumParallelLoops() + op.getNumReductionLoops() + - op.getNumWindowLoops(); - tileSizes = tileSizes.take_front(nLoops); - // If only 0 tilings are left, then return. - if (llvm::all_of(tileSizes, [](int64_t v) { return v == 0; })) - return llvm::None; - - if (auto convOp = dyn_cast(op.getOperation())) { - // For conv op only support tiling along batch dimension (which is the first - // loop). - if (convOp.padding() && !llvm::all_of(tileSizes.drop_front(), - [](int64_t val) { return val == 0; })) - return llvm::None; - } +Optional +mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, + const LinalgTilingOptions &options) { + if (options.loopType == LinalgTilingLoopType::Loops) + return tileLinalgOpImpl(b, op, options); + if (options.loopType == LinalgTilingLoopType::ParallelLoops) + return tileLinalgOpImpl(b, op, options); + // TODO: Impl tiling to affine loops when it makes sense. + return llvm::None; +} - // Create a builder for tile size constants. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - ScopedContext scope(b, op.getLoc()); +namespace { +/// Helper classes for type list expansion. 
+template +class CanonicalizationPatternList; + +template <> +class CanonicalizationPatternList<> { +public: + static void insert(OwningRewritePatternList &patterns, MLIRContext *ctx) {} +}; - // Materialize concrete tile size values to pass the generic tiling function. - SmallVector tileSizeValues; - tileSizeValues.reserve(tileSizes.size()); - for (auto ts : tileSizes) - tileSizeValues.push_back(folded_std_constant_index(folder, ts)); - // Pad tile sizes with zero values to enforce our convention. - if (tileSizeValues.size() < nLoops) { - for (unsigned i = tileSizeValues.size(); i < nLoops; ++i) - tileSizeValues.push_back(folded_std_constant_index(folder, 0)); +template +class CanonicalizationPatternList { +public: + static void insert(OwningRewritePatternList &patterns, MLIRContext *ctx) { + OpTy::getCanonicalizationPatterns(patterns, ctx); + CanonicalizationPatternList::insert(patterns, ctx); } +}; - return tileLinalgOpImpl(b, op, tileSizeValues, interchangeVector, - folder); -} +/// Helper classes for type list expansion. 
+template +class RewritePatternList; -Optional -mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef interchangeVector, - OperationFolder *folder) { - return tileLinalgOpImpl(b, op, tileSizes, interchangeVector, - folder); -} +template <> +class RewritePatternList<> { +public: + static void insert(OwningRewritePatternList &patterns, + const LinalgTilingOptions &options, MLIRContext *ctx) {} +}; -Optional mlir::linalg::tileLinalgOpToParallelLoops( - OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef interchangeVector, OperationFolder *folder) { - return tileLinalgOpImpl(b, op, tileSizes, interchangeVector, - folder); -} +template +class RewritePatternList { +public: + static void insert(OwningRewritePatternList &patterns, + const LinalgTilingOptions &options, MLIRContext *ctx) { + patterns.insert>(ctx, options, + LinalgMarker({}, "tiled")); + RewritePatternList::insert(patterns, options, ctx); + } +}; +} // namespace -Optional mlir::linalg::tileLinalgOp( - OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef interchangeVector, OperationFolder *folder) { - return tileLinalgOpImpl(b, op, tileSizes, interchangeVector, - folder); +OwningRewritePatternList +mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) { + OwningRewritePatternList patterns; + AffineApplyOp::getCanonicalizationPatterns(patterns, ctx); + AffineForOp::getCanonicalizationPatterns(patterns, ctx); + AffineMinOp::getCanonicalizationPatterns(patterns, ctx); + AffineMaxOp::getCanonicalizationPatterns(patterns, ctx); + scf::ForOp::getCanonicalizationPatterns(patterns, ctx); + scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx); + ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx); + SubViewOp::getCanonicalizationPatterns(patterns, ctx); + ViewOp::getCanonicalizationPatterns(patterns, ctx); + CanonicalizationPatternList< +#define GET_OP_LIST +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >::insert(patterns, 
ctx); + return patterns; } -Optional mlir::linalg::tileLinalgOpToParallelLoops( - OpBuilder &b, LinalgOp op, ArrayRef tileSizes, - ArrayRef interchangeVector, OperationFolder *folder) { - return tileLinalgOpImpl(b, op, tileSizes, interchangeVector, - folder); +/// Populate the given list with patterns that apply Linalg tiling. +static void insertTilingPatterns(OwningRewritePatternList &patterns, + const LinalgTilingOptions &options, + MLIRContext *ctx) { + RewritePatternList< +#define GET_OP_LIST +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >::insert(patterns, options, ctx); } -template -static void tileLinalgOps(FuncOp f, ArrayRef tileSizes) { - OpBuilder b(f); - OperationFolder folder(f.getContext()); - f.walk([tileSizes, &b, &folder](LinalgOp op) { - if (!op.hasBufferSemantics()) - return; - auto opLoopsPair = tileLinalgOpImpl( - b, op, tileSizes, /*interchangeVector=*/{}, &folder); - // If tiling occurred successfully, erase old op. - if (opLoopsPair) - op.erase(); - }); - f.walk([](LinalgOp op) { - if (isOpTriviallyDead(op)) - op.erase(); +static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType, + FuncOp funcOp, + ArrayRef tileSizes) { + auto options = + LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType); + MLIRContext *ctx = funcOp.getContext(); + OwningRewritePatternList patterns; + insertTilingPatterns(patterns, options, ctx); + applyPatternsAndFoldGreedily(funcOp, patterns); + applyPatternsAndFoldGreedily(funcOp, + getLinalgTilingCanonicalizationPatterns(ctx)); + // Drop the marker. 
+ funcOp.walk([](LinalgOp op) { + op.removeAttr(LinalgTransforms::kLinalgTransformMarker); }); } @@ -518,7 +514,8 @@ LinalgTilingPass(ArrayRef sizes) { tileSizes = sizes; } void runOnFunction() override { - tileLinalgOps(getFunction(), tileSizes); + applyTilingToLoopPatterns(LinalgTilingLoopType::Loops, getFunction(), + tileSizes); } }; @@ -530,7 +527,8 @@ } void runOnFunction() override { - tileLinalgOps(getFunction(), tileSizes); + applyTilingToLoopPatterns(LinalgTilingLoopType::ParallelLoops, + getFunction(), tileSizes); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -98,6 +98,24 @@ rewriter.getContext())); } +/// Installs a callback that owns a copy of `ts`. On each invocation it creates +/// one value per tile size at the start of the first block of the body of the +/// parent found via `getParentOfType`, then restores the insertion point. +LinalgTilingOptions & +mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef ts) { + SmallVector tileSizes(ts.begin(), ts.end()); + tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart( + &op->getParentOfType().getBody().front()); + return llvm::to_vector<4>(llvm::map_range(tileSizes, [&](int64_t s) { + Value v = b.create(op->getLoc(), s); + return v; + })); + }; + return *this; +} + /// Linalg base tiling pattern. mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( StringRef opName, MLIRContext *context, LinalgTilingOptions options, @@ -112,14 +130,7 @@ return failure(); if (failed(marker.checkAndNotify(rewriter, linalgOp))) return failure(); - Optional res; - if (options.loopType == LinalgTilingLoopType::Loops) - res = tileLinalgOp(rewriter, linalgOp, options.tileSizes, - options.interchangeVector); - else if (options.loopType == LinalgTilingLoopType::ParallelLoops) - res = tileLinalgOpToParallelLoops(rewriter, linalgOp, options.tileSizes, - options.interchangeVector); - // TODO: Impl tiling to affine loops when it makes sense. 
+ Optional res = tileLinalgOp(rewriter, linalgOp, options); if (!res) return failure(); diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir --- a/mlir/test/Dialect/Linalg/tile.mlir +++ b/mlir/test/Dialect/Linalg/tile.mlir @@ -1,7 +1,7 @@ -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2" | FileCheck %s -check-prefix=TILE-2 -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,2" | FileCheck %s -check-prefix=TILE-02 -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,2" | FileCheck %s -check-prefix=TILE-002 -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" | FileCheck %s -check-prefix=TILE-234 +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2" -mlir-disable-threading=true | FileCheck %s -check-prefix=TILE-2 +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,2" -mlir-disable-threading=true | FileCheck %s -check-prefix=TILE-02 +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,2" -mlir-disable-threading=true | FileCheck %s -check-prefix=TILE-002 +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -mlir-disable-threading=true | FileCheck %s -check-prefix=TILE-234 // TILE-2-DAG: #[[strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // TILE-02-DAG: #[[strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> @@ -13,26 +13,16 @@ // TILE-002-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // TILE-234-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// TILE-2-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)> -// TILE-02-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)> -// TILE-002-DAG: #[[bound_map:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)> -// TILE-234-DAG: #[[bound_map_2:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)> -// TILE-234-DAG: #[[bound_map_3:.*]] = affine_map<(d0, d1, d2) -> (3, d1 - d2)> -// TILE-234-DAG: #[[bound_map_4:.*]] = affine_map<(d0, d1, d2) -> (4, d1 - d2)> +// TILE-2-DAG: 
#[[bound_map:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> +// TILE-02-DAG: #[[bound_map:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> +// TILE-002-DAG: #[[bound_map:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> +// TILE-234-DAG: #[[bound_map_2:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> +// TILE-234-DAG: #[[bound_map_3:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)> +// TILE-234-DAG: #[[bound_map_4:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> -// TILE-2-DAG: #[[strided1D_dynamic:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> -// TILE-02-DAG: #[[strided1D_dynamic:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> -// T_ILE-002-DAG: #[[strided1D_dynamic:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> -// TILE-234-DAG: #[[strided1D_dynamic:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> - -// TILE-2-DAG: #[[strided2D_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> -// TILE-02-DAG: #[[strided2D_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> -// TILE-002-DAG: #[[strided2D_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> -// TILE-234-DAG: #[[strided2D_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> - -// REACTIVATE_ME_TILE-2-DAG: #[[stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> -// REACTIVATE_ME_TILE-02-DAG: #[[stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> -// REACTIVATE_ME_TILE-234-DAG: #[[stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> +// TILE-2-DAG: #[[stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> +// TILE-02-DAG: #[[stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> +// TILE-234-DAG: #[[stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> func @matmul(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.matmul(%arg0, %arg1, %arg2) : memref, memref, memref @@ -40,55 +30,51 @@ } 
// TILE-2-LABEL: func @matmul( // TILE-2-DAG: %[[C0:.*]] = constant 0 : index -// TILE-2-DAG: %[[C1:.*]] = constant 1 : index // TILE-2-DAG: %[[C2:.*]] = constant 2 : index // TILE-2: %[[M:.*]] = dim %{{.*}}, 0 : memref // TILE-2: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { // TILE-2: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-2: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]]) +// TILE-2: %[[szM:.*]] = affine.min #[[bound_map]](%[[I]])[%[[localM]]] // TILE-2: %[[K:.*]] = dim %{{.*}}, 1 : memref -// TILE-2: %[[sAi:.*]] = subview %{{.*}}[%[[I]], %[[C0]]] [%[[szM]], %[[K]]] [%[[C1]], %[[C1]]] : memref to memref +// TILE-2: %[[sAi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[szM]], %[[K]]] [1, 1] : memref to memref // TILE-2: %[[localK:.*]] = dim %{{.*}}, 0 -// TILE-2: %[[szK:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localK]], %[[I]]) +// TILE-2: %[[szK:.*]] = affine.min #[[bound_map]](%[[I]])[%[[localK]]] // TILE-2: %[[N:.*]] = dim %{{.*}}, 1 : memref -// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]], %[[C0]]] [%[[szK]], %[[N]]] [%[[C1]], %[[C1]]] : memref to memref -// TILE-2: linalg.matmul(%[[sAi]], %{{.*}}, %[[sCi]]) : memref, memref, memref +// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[szK]], %[[N]]] [1, 1] : memref to memref +// TILE-2: linalg.matmul(%[[sAi]], %{{.*}}, %[[sCi]]) : memref, memref, memref // TILE-02-LABEL: func @matmul( // TILE-02-DAG: %[[C0:.*]] = constant 0 : index -// TILE-02-DAG: %[[C1:.*]] = constant 1 : index // TILE-02-DAG: %[[C2:.*]] = constant 2 : index // TILE-02: %[[N:.*]] = dim %arg1, 1 : memref // TILE-02: scf.for %[[J:.*]] = %{{.*}} to %[[N]] step %{{.*}} { // TILE-02: %[[K:.*]] = dim %{{.*}}, 0 : memref // TILE-02: %[[localN:.*]] = dim %{{.*}}, 1 -// TILE-02: %[[szN:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localN]], %[[J]]) -// TILE-02: %[[sBj:.*]] = subview %{{.*}}[%[[C0]], %[[J]]] [%[[K]], %[[szN]]] [%[[C1]], %[[C1]]] : memref to memref +// TILE-02: %[[szN:.*]] = affine.min 
#[[bound_map]](%[[J]])[%[[localN]]] +// TILE-02: %[[sBj:.*]] = subview %{{.*}}[0, %[[J]]] [%[[K]], %[[szN]]] [1, 1] : memref to memref // TILE-02: %[[M:.*]] = dim %{{.*}}, 0 : memref // TILE-02: %[[localK:.*]] = dim %{{.*}}, 1 -// TILE-02: %[[szK:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localK]], %[[J]]) -// TILE-02: %[[sCj:.*]] = subview %{{.*}}[%[[C0]], %[[J]]] [%[[M]], %[[szK]]] [%[[C1]], %[[C1]]] : memref to memref -// TILE-02: linalg.matmul(%{{.*}}, %[[sBj]], %[[sCj]]) : memref, memref, memref +// TILE-02: %[[szK:.*]] = affine.min #[[bound_map]](%[[J]])[%[[localK]]] +// TILE-02: %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [%[[M]], %[[szK]]] [1, 1] : memref to memref +// TILE-02: linalg.matmul(%{{.*}}, %[[sBj]], %[[sCj]]) : memref, memref, memref // TILE-002-LABEL: func @matmul( // TILE-002-DAG: %[[C0:.*]] = constant 0 : index -// TILE-002-DAG: %[[C1:.*]] = constant 1 : index // TILE-002-DAG: %[[C2:.*]] = constant 2 : index // TILE-002: %[[ubK:.*]] = dim %{{.*}}, 1 : memref // TILE-002: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[ubK]] step %{{.*}} { // TILE-002: %[[M:.*]] = dim %{{.*}}, 0 : memref // TILE-002: %[[localK:.*]] = dim %{{.*}}, 1 -// TILE-002: %[[szK:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localK]], %[[K]]) -// TILE-002: %[[sAj:.*]] = subview %{{.*}}[%[[C0]], %[[K]]] [%[[M]], %[[szK]]] [%[[C1]], %[[C1]]] : memref to memref +// TILE-002: %[[szK:.*]] = affine.min #[[bound_map]](%[[K]])[%[[localK]]] +// TILE-002: %[[sAj:.*]] = subview %{{.*}}[0, %[[K]]] [%[[M]], %[[szK]]] [1, 1] : memref to memref // TILE-002: %[[localK:.*]] = dim %{{.*}}, 0 -// TILE-002: %[[szK:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localK]], %[[K]]) +// TILE-002: %[[szK:.*]] = affine.min #[[bound_map]](%[[K]])[%[[localK]]] // TILE-002: %[[N:.*]] = dim %{{.*}}, 1 : memref -// TILE-002: %[[sBj:.*]] = subview %{{.*}}[%[[K]], %[[C0]]] [%[[szK]], %[[N]]] [%[[C1]], %[[C1]]] : memref to memref -// TILE-002: linalg.matmul(%[[sAj]], %[[sBj]], %{{.*}}) : memref, memref, memref +// 
TILE-002: %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[szK]], %[[N]]] [1, 1] : memref to memref +// TILE-002: linalg.matmul(%[[sAj]], %[[sBj]], %{{.*}}) : memref, memref, memref // TILE-234-LABEL: func @matmul( // TILE-234-DAG: %[[C0:.*]] = constant 0 : index -// TILE-234-DAG: %[[C1:.*]] = constant 1 : index // TILE-234-DAG: %[[C2:.*]] = constant 2 : index // TILE-234-DAG: %[[C3:.*]] = constant 3 : index // TILE-234-DAG: %[[C4:.*]] = constant 4 : index @@ -99,22 +85,22 @@ // TILE-234: scf.for %[[J:.*]] = %{{.*}}{{.*}} to %[[ubN]] step %{{.*}} { // TILE-234: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[ubK]] step %{{.*}} { // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]]) +// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[I]])[%[[localM]]] // TILE-234: %[[localK:.*]] = dim %{{.*}}, 1 -// TILE-234: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[localK]], %[[K]]) -// TILE-234: %[[sAik:.*]] = subview %{{.*}}[%[[I]], %[[K]]] [%[[szM]], %[[szK]]] [%[[C1]], %[[C1]]] : memref to memref +// TILE-234: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[K]])[%[[localK]]] +// TILE-234: %[[sAik:.*]] = subview %{{.*}}[%[[I]], %[[K]]] [%[[szM]], %[[szK]]] [1, 1] : memref to memref // TILE-234: %[[localK:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[localK]], %[[K]]) +// TILE-234: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[K]])[%[[localK]]] // TILE-234: %[[localN:.*]] = dim %{{.*}}, 1 -// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]]) -// TILE-234: %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%[[szK]], %[[szN]]] [%[[C1]], %[[C1]]] : memref to memref +// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[J]])[%[[localN]]] +// TILE-234: %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%[[szK]], %[[szN]]] [1, 1] : memref to memref // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szM:.*]] = affine.min 
#[[bound_map_2]](%[[C2]], %[[localM]], %[[I]]) +// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[I]])[%[[localM]]] // TILE-234: %[[localN:.*]] = dim %{{.*}}, 1 -// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]]) -// TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [%[[C1]], %[[C1]]] : memref to memref +// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[J]])[%[[localN]]] +// TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [1, 1] : memref to memref // -// TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref, memref, memref +// TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref, memref, memref // When the buffer shapes are known at compile time, it is possible to avoid // the "min" in subview size computation. This test uses buffer sizes divisible @@ -125,106 +111,107 @@ return } // TILE-2-LABEL: func @matmul_static( +// TILE-2-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref +// TILE-2-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref +// TILE-2-SAME: %[[ARG2:[0-9a-zA-Z]*]]: memref // TILE-2-DAG: %[[C0:.*]] = constant 0 : index -// TILE-2-DAG: %[[C1:.*]] = constant 1 : index // TILE-2-DAG: %[[C2:.*]] = constant 2 : index -// TILE-2: %[[M:.*]] = dim %{{.*}}, 0 : memref<10x16xf32, #[[strided2D]]> -// TILE-2: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { -// TILE-2: %[[K:.*]] = dim %{{.*}}, 1 : memref<10x16xf32, #[[strided2D]]> -// TILE-2: %[[sAi:.*]] = subview %{{.*}}[%[[I]], %[[C0]]] [%[[C2]], %[[K]]] [%[[C1]], %[[C1]]] : memref<10x16xf32, #[[strided2D]]> to memref -// TILE-2: %[[N:.*]] = dim %{{.*}}, 1 : memref<10x12xf32, #[[strided2D]]> -// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]], %[[C0]]] [%[[C2]], %[[N]]] [%[[C1]], %[[C1]]] : memref<10x12xf32, #[[strided2D]]> to memref +// TILE-2-DAG: %[[M:.*]] = constant 10 : index +// TILE-2: scf.for %[[I:.*]] = %{{.*}} to %[[M]] step %{{.*}} { +// TILE-2: %[[MIN2:.*]] = affine.min #map2(%[[I]]) +// 
TILE-2: %[[sAi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN2]], 16] [1, 1] : memref<10x16xf32, #[[strided2D]]> to memref +// TILE-2: %[[MIN22:.*]] = affine.min #map2(%[[I]]) +// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN22]], 12] [1, 1] : memref<10x12xf32, #[[strided2D]]> to memref // TILE-2: linalg.matmul(%[[sAi]], %{{.*}}, %[[sCi]]) // TILE-02-LABEL: func @matmul_static( // TILE-02-DAG: %[[C0:.*]] = constant 0 : index -// TILE-02-DAG: %[[C1:.*]] = constant 1 : index // TILE-02-DAG: %[[C2:.*]] = constant 2 : index -// TILE-02: %[[N:.*]] = dim %arg1, 1 : memref<16x12xf32, #[[strided2D]]> +// TILE-02-DAG: %[[N:.*]] = constant 12 : index // TILE-02: scf.for %[[J:.*]] = %{{.*}} to %[[N]] step %{{.*}} { -// TILE-02: %[[K:.*]] = dim %{{.*}}, 0 : memref<16x12xf32, #[[strided2D]]> -// TILE-02-NOT: affine.min -// TILE-02: %[[sBj:.*]] = subview %{{.*}}[%[[C0]], %[[J]]] [%[[K]], %[[C2]]] [%[[C1]], %[[C1]]] : memref<16x12xf32, #[[strided2D]]> to memref -// TILE-02: %[[M:.*]] = dim %{{.*}}, 0 : memref<10x12xf32, #[[strided2D]]> -// TILE-02-NOT: affine.min -// TILE-02: %[[sCj:.*]] = subview %{{.*}}[%[[C0]], %[[J]]] [%[[M]], %[[C2]]] [%[[C1]], %[[C1]]] : memref<10x12xf32, #[[strided2D]]> to memref -// TILE-02: linalg.matmul(%{{.*}}, %[[sBj]], %[[sCj]]) : memref<10x16xf32, #[[strided2D]]>, memref, memref +// TILE-02: %[[MIN2:.*]] = affine.min #map2(%[[J]]) +// TILE-02: %[[sBj:.*]] = subview %{{.*}}[0, %[[J]]] [16, %[[MIN2]]] [1, 1] : memref<16x12xf32, #[[strided2D]]> to memref<16x?xf32, #[[strided2D]]> +// TILE-02: %[[MIN22:.*]] = affine.min #map2(%[[J]]) +// TILE-02: %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [10, %[[MIN22]]] [1, 1] : memref<10x12xf32, #[[strided2D]]> to memref<10x?xf32, #[[strided2D]]> +// TILE-02: linalg.matmul(%{{.*}}, %[[sBj]], %[[sCj]]) : memref<10x16xf32, #[[strided2D]]>, memref<16x?xf32, #[[strided2D]]>, memref<10x?xf32, #[[strided2D]]> // TILE-002-LABEL: func @matmul_static( // TILE-002-DAG: %[[C0:.*]] = constant 0 : index -// TILE-002-DAG: 
%[[C1:.*]] = constant 1 : index // TILE-002-DAG: %[[C2:.*]] = constant 2 : index -// TILE-002: %[[ubK:.*]] = dim %{{.*}}, 1 : memref<10x16xf32, #[[strided2D]]> -// TILE-002: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[ubK]] step %{{.*}} { -// TILE-002: %[[M:.*]] = dim %{{.*}}, 0 : memref<10x16xf32, #[[strided2D]]> -// TILE-002-NOT: affine.min -// TILE-002: %[[sAj:.*]] = subview %{{.*}}[%[[C0]], %[[K]]] [%[[M]], %[[C2]]] [%[[C1]], %[[C1]]] : memref<10x16xf32, #[[strided2D]]> to memref -// TILE-002: %[[N:.*]] = dim %{{.*}}, 1 : memref<16x12xf32, #[[strided2D]]> -// TILE-002-NOT: affine.min -// TILE-002: %[[sBj:.*]] = subview %{{.*}}[%[[K]], %[[C0]]] [%[[C2]], %[[N]]] [%[[C1]], %[[C1]]] : memref<16x12xf32, #[[strided2D]]> to memref -// TILE-002: linalg.matmul(%[[sAj]], %[[sBj]], %{{.*}}) : memref, memref, memref<10x12xf32, #[[strided2D]]> +// TILE-002-DAG: %[[C16:.*]] = constant 16 : index +// TILE-002: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} { +// TILE-002: %[[MIN2:.*]] = affine.min #map2(%[[K]]) +// TILE-002: %[[sAj:.*]] = subview %{{.*}}[0, %[[K]]] [10, %[[MIN2]]] [1, 1] : memref<10x16xf32, #[[strided2D]]> to memref<10x?xf32, #[[strided2D]]> +// TILE-002: %[[MIN22:.*]] = affine.min #map2(%[[K]]) +// TILE-002: %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[MIN22]], 12] [1, 1] : memref<16x12xf32, #[[strided2D]]> to memref +// TILE-002: linalg.matmul(%[[sAj]], %[[sBj]], %{{.*}}) : memref<10x?xf32, #[[strided2D]]>, memref, memref<10x12xf32, #[[strided2D]]> // TILE-234-LABEL: func @matmul_static( // TILE-234-DAG: %[[C0:.*]] = constant 0 : index -// TILE-234-DAG: %[[C1:.*]] = constant 1 : index // TILE-234-DAG: %[[C2:.*]] = constant 2 : index // TILE-234-DAG: %[[C3:.*]] = constant 3 : index // TILE-234-DAG: %[[C4:.*]] = constant 4 : index -// TILE-234: %[[ubM:.*]] = dim %{{.*}}, 0 : memref<10x16xf32, #[[strided2D]]> -// TILE-234: %[[ubK:.*]] = dim %{{.*}}, 1 : memref<10x16xf32, #[[strided2D]]> -// TILE-234: %[[ubN:.*]] = dim %{{.*}}, 1 : memref<16x12xf32, 
#[[strided2D]]> -// TILE-234: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[ubM]] step %{{.*}} { -// TILE-234: scf.for %[[J:.*]] = %{{.*}}{{.*}} to %[[ubN]] step %{{.*}} { -// TILE-234: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[ubK]] step %{{.*}} { -// TILE-234-NOT: affine.min -// TILE-234: %[[sAik:.*]] = subview %{{.*}}[%[[I]], %[[K]]] [%[[C2]], %[[C4]]] [%[[C1]], %[[C1]]] : memref<10x16xf32, #[[strided2D]]> to memref -// TILE-234-NOT: affine.min -// TILE-234: %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%[[C4]], %[[C3]]] [%[[C1]], %[[C1]]] : memref<16x12xf32, #[[strided2D]]> to memref -// TILE-234-NOT: affine.min -// TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[C2]], %[[C3]]] [%[[C1]], %[[C1]]] : memref<10x12xf32, #[[strided2D]]> to memref +// TILE-234-DAG: %[[C10:.*]] = constant 10 : index +// TILE-234-DAG: %[[C16:.*]] = constant 16 : index +// TILE-234-DAG: %[[C12:.*]] = constant 12 : index +// TILE-234: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[C10]] step %{{.*}} { +// TILE-234: scf.for %[[J:.*]] = %{{.*}}{{.*}} to %[[C12]] step %{{.*}} { +// TILE-234: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} { +// TILE-234: %[[sAik:.*]] = subview %{{.*}}[%[[I]], %[[K]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x16xf32, #[[strided2D]]> to memref +// TILE-234: %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x12xf32, #[[strided2D]]> to memref +// TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x12xf32, #[[strided2D]]> to memref // -// TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref, memref, memref +// TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref, memref, memref func @matvec(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.matvec(%arg0, %arg1, %arg2) : memref, memref, memref return } // TILE-2-LABEL: func @matvec( +// TILE-2-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref +// TILE-2-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref +// TILE-2-SAME: 
%[[ARG2:[0-9a-zA-Z]*]]: memref // TILE-2-DAG: %[[C0:.*]] = constant 0 : index -// TILE-2-DAG: %[[C1:.*]] = constant 1 : index // TILE-2-DAG: %[[C2:.*]] = constant 2 : index // TILE-2: %[[M:.*]] = dim %{{.*}}, 0 : memref // TILE-2: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { -// TILE-2: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-2: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]]) +// TILE-2: %[[localM:.*]] = dim %[[ARG0]], 0 +// TILE-2: %[[szM:.*]] = affine.min #[[bound_map]](%[[I]])[%[[localM]]] // TILE-2: %[[N:.*]] = dim %{{.*}}, 1 : memref -// TILE-2: %[[sAi:.*]] = subview %{{.*}}[%[[I]], %[[C0]]] [%[[szM]], %[[N]]] [%[[C1]], %[[C1]]] : memref to memref +// TILE-2: %[[sAi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[szM]], %[[N]]] [1, 1] : memref to memref // TILE-2: %[[localN:.*]] = dim %{{.*}}, 0 -// TILE-2: %[[szN:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localN]], %[[I]]) -// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szN]]] [%[[C1]]] : memref to memref -// TILE-2: linalg.matvec(%[[sAi]], %{{.*}}, %[[sCi]]) : memref, memref, memref +// TILE-2: %[[szN:.*]] = affine.min #[[bound_map]](%[[I]])[%[[localN]]] +// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szN]]] [1] : memref to memref +// TILE-2: linalg.matvec(%[[sAi]], %{{.*}}, %[[sCi]]) : memref, memref, memref // TILE-02-LABEL: func @matvec( +// TILE-02-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref +// TILE-02-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref +// TILE-02-SAME: %[[ARG2:[0-9a-zA-Z]*]]: memref // TILE-02-DAG: %[[C0:.*]] = constant 0 : index -// TILE-02-DAG: %[[C1:.*]] = constant 1 : index // TILE-02-DAG: %[[C2:.*]] = constant 2 : index // TILE-02: %[[K:.*]] = dim %{{.*}}, 1 : memref // TILE-02: scf.for %[[J]] = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { // TILE-02: %[[M:.*]] = dim %{{.*}}, 0 : memref // TILE-02: %[[localN:.*]] = dim %{{.*}}, 1 -// TILE-02: %[[szN:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localN]], %[[J]]) -// TILE-02: %[[sAj:.*]] = subview %{{.*}}[%[[C0]], 
%[[J]]] [%[[M]], %[[szN]]] [%[[C1]], %[[C1]]] : memref to memref +// TILE-02: %[[szN:.*]] = affine.min #[[bound_map]](%[[J]])[%[[localN]]] +// TILE-02: %[[sAj:.*]] = subview %{{.*}}[0, %[[J]]] [%[[M]], %[[szN]]] [1, 1] : memref to memref // TILE-02: %[[localN:.*]] = dim %{{.*}}, 0 -// TILE-02: %[[szN:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localN]], %[[J]]) -// TILE-02: %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [%[[C1]]] : memref to memref -// TILE-02: linalg.matvec(%[[sAj]], %[[sBj]], %{{.*}}) : memref, memref, memref +// TILE-02: %[[szN:.*]] = affine.min #[[bound_map]](%[[J]])[%[[localN]]] +// TILE-02: %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [1] : memref to memref +// TILE-02: linalg.matvec(%[[sAj]], %[[sBj]], %{{.*}}) : memref, memref, memref // TILE-002-LABEL: func @matvec( +// TILE-002-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref +// TILE-002-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref +// TILE-002-SAME: %[[ARG2:[0-9a-zA-Z]*]]: memref // TILE-002-NOT: scf.for // TILE-234-LABEL: func @matvec( +// TILE-234-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref +// TILE-234-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref +// TILE-234-SAME: %[[ARG2:[0-9a-zA-Z]*]]: memref // TILE-234-DAG: %[[C0:.*]] = constant 0 : index -// TILE-234-DAG: %[[C1:.*]] = constant 1 : index // TILE-234-DAG: %[[C2:.*]] = constant 2 : index // TILE-234-DAG: %[[C3:.*]] = constant 3 : index // TILE-234: %[[M:.*]] = dim %{{.*}}, 0 : memref @@ -232,18 +219,18 @@ // TILE-234: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { // TILE-234: scf.for %[[J:.*]] = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]]) +// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[I]])[%[[localM]]] // TILE-234: %[[localN:.*]] = dim %{{.*}}, 1 -// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]]) -// TILE-234: %[[sAij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], 
%[[szN]]] [%[[C1]], %[[C1]]] : memref to memref +// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[J]])[%[[localN]]] +// TILE-234: %[[sAij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [1, 1] : memref to memref // TILE-234: %[[localN:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[C3]], %[[localN]], %[[J]]) -// TILE-234: %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [%[[C1]]] : memref to memref +// TILE-234: %[[szN:.*]] = affine.min #[[bound_map_3]](%[[J]])[%[[localN]]] +// TILE-234: %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [1] : memref to memref // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]]) -// TILE-234: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref to memref +// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[I]])[%[[localM]]] +// TILE-234: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref // -// TILE-234: linalg.matvec(%[[sAij]], %[[sBj]], %[[sCi]]) : memref, memref, memref +// TILE-234: linalg.matvec(%[[sAij]], %[[sBj]], %[[sCi]]) : memref, memref, memref func @dot(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.dot(%arg0, %arg1, %arg2) : memref, memref, memref @@ -251,17 +238,16 @@ } // TILE-2-LABEL: func @dot( // TILE-2-DAG: %[[C0:.*]] = constant 0 : index -// TILE-2-DAG: %[[C1:.*]] = constant 1 : index // TILE-2-DAG: %[[C2:.*]] = constant 2 : index // TILE-2: %[[M:.*]] = dim %{{.*}}, 0 : memref // TILE-2: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { // TILE-2: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-2: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]]) -// TILE-2: %[[sAi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref to memref +// TILE-2: %[[szM:.*]] = affine.min #[[bound_map]](%[[I]])[%[[localM]]] +// TILE-2: %[[sAi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref // TILE-2: %[[localM:.*]] 
= dim %{{.*}}, 0 -// TILE-2: %[[szM:.*]] = affine.min #[[bound_map]](%[[C2]], %[[localM]], %[[I]]) -// TILE-2: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref to memref -// TILE-2: linalg.dot(%[[sAi]], %[[sBi]], {{.*}}) : memref, memref, memref +// TILE-2: %[[szM:.*]] = affine.min #[[bound_map]](%[[I]])[%[[localM]]] +// TILE-2: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref +// TILE-2: linalg.dot(%[[sAi]], %[[sBi]], {{.*}}) : memref, memref, memref // TILE-02-LABEL: func @dot( // TILE-02-NOT: scf.for @@ -271,17 +257,16 @@ // TILE-234-LABEL: func @dot( // TILE-234-DAG: %[[C0:.*]] = constant 0 : index -// TILE-234-DAG: %[[C1:.*]] = constant 1 : index // TILE-234-DAG: %[[C2:.*]] = constant 2 : index // TILE-234: %[[ubK:.*]] = dim %{{.*}}, 0 : memref // TILE-234: scf.for %[[I:.*]] = %{{.*}} to %[[ubK]] step %{{.*}} { // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]]) -// TILE-234: %[[sAi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref to memref +// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[I]])[%[[localM]]] +// TILE-234: %[[sAi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref // TILE-234: %[[localM:.*]] = dim %{{.*}}, 0 -// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[C2]], %[[localM]], %[[I]]) -// TILE-234: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [%[[C1]]] : memref to memref -// TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : memref, memref, memref +// TILE-234: %[[szM:.*]] = affine.min #[[bound_map_2]](%[[I]])[%[[localM]]] +// TILE-234: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref +// TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : memref, memref, memref func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) { linalg.fill(%arg0, %arg1) : memref<127x99xf32>, f32 @@ -291,13 +276,13 @@ // TILE-2: for // TILE-2-NOT: for // TILE-2: subview{{.*}} : 
memref<127x99xf32> -// TILE-2: linalg.fill{{.*}} : memref, f32 +// TILE-2: linalg.fill{{.*}} : memref, f32 // TILE-02-LABEL: func @fill_static // TILE-02: for // TILE-02-NOT: for // TILE-02: subview{{.*}} : memref<127x99xf32> -// TILE-02: linalg.fill{{.*}} : memref, f32 +// TILE-02: linalg.fill{{.*}} : memref<127x?xf32, #[[stride_99_1_layout_map]]>, f32 // TILE-002-LABEL: func @fill_static // TILE-002-NOT: for @@ -308,7 +293,7 @@ // TILE-234: for // TILE-234-NOT: for // TILE-234: subview{{.*}} : memref<127x99xf32> -// TILE-234: linalg.fill{{.*}} : memref, f32 +// TILE-234: linalg.fill{{.*}} : memref, f32 func @fill(%arg0: memref, %arg1: f32) { diff --git a/mlir/test/Dialect/Linalg/tile_conv.mlir b/mlir/test/Dialect/Linalg/tile_conv.mlir --- a/mlir/test/Dialect/Linalg/tile_conv.mlir +++ b/mlir/test/Dialect/Linalg/tile_conv.mlir @@ -1,10 +1,9 @@ // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,0,0,4" | FileCheck %s -check-prefix=TILE-23004 // TILE-23004-DAG: #[[D0x30pS0x10:.*]] = affine_map<(d0) -> (d0 * 30)> -// TILE-23004-DAG: #[[S0x10p90:.*]] = affine_map<()[s0] -> (s0 * 10 + 90)> +// TILE-23004-DAG: #[[S0x10p90D0x30pS1:.*]] = affine_map<(d0)[s0, s1] -> (s0 * 10 + 90, d0 * -30 + s1)> // TILE-23004-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)> -// TILE-23004-DAG: #[[strided4D_dynamic:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)> -// TILE-23004-DAG: #[[bound_map_4:.*]] = affine_map<(d0, d1, d2) -> (4, d1 - d2)> +// TILE-23004-DAG: #[[bound_map_4:.*]] = affine_map<(d0)[s0] -> (4, -d0 + s0)> func @conv(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.conv(%arg0, %arg1, %arg2) {dilations = [10, 20], strides = [30, 40]} : memref, memref, memref @@ -13,7 +12,6 @@ // TILE-23004-LABEL: func @conv( // TILE-23004: %{{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref) { // TILE-23004-DAG: %[[C0:.*]] = constant 0 : index -// TILE-23004-DAG: 
%[[C1:.*]] = constant 1 : index // TILE-23004-DAG: %[[C2:.*]] = constant 2 : index // TILE-23004-DAG: %[[C3:.*]] = constant 3 : index // TILE-23004-DAG: %[[C4:.*]] = constant 4 : index @@ -27,19 +25,20 @@ // TILE-23004: %[[Z0:.*]] = dim %{{.*}}, 0 : memref // TILE-23004: %[[Z1:.*]] = dim %{{.*}}, 1 : memref // TILE-23004: %[[Z2:.*]] = dim %{{.*}}, 2 : memref -// TILE-23004: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[Z2]], %[[ivK]]) +// TILE-23004: %[[szK:.*]] = affine.min #[[bound_map_4]](%[[ivK]])[%[[Z2]]] // TILE-23004: %[[K:.*]] = dim %{{.*}}, 3 : memref -// TILE-23004: %[[FilterView:.*]] = subview %{{.*}}[%[[C0]], %[[C0]], %[[ivK]], %[[C0]]] [%[[Z0]], %[[Z1]], %[[szK]], %[[K]]] [%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref to memref +// TILE-23004: %[[FilterView:.*]] = subview %{{.*}}[0, 0, %[[ivK]], 0] [%[[Z0]], %[[Z1]], %[[szK]], %[[K]]] [1, 1, 1, 1] : memref to memref // // TILE-23004: %[[J1:.*]] = affine.apply #[[D0x30pS0x10]](%[[ivJ]]) -// T__ILE-23004: %[[I1pStep:.*]] = affine.apply #[[S0x10p90]]()[%[[I1]]] +// TILE-23004: %[[PaddedInput0b:.*]] = dim %{{.*}}, 1 : memref +// TILE-23004: %[[I1pStep:.*]] = affine.min #[[S0x10p90D0x30pS1]](%[[ivJ]])[%[[PaddedInput0]], %[[PaddedInput0b]]] // TILE-23004: %[[SZ2:.*]] = dim %{{.*}}, 2 : memref // TILE-23004: %[[dim3:.*]] = dim %{{.*}}, 3 -// TILE-23004: %[[sz3:.*]] = affine.min #[[bound_map_4]](%[[C4]], %[[dim3]], %[[ivK]] -// TILE-23004: %[[InputView:.*]] = subview %{{.*}}[%[[ivI]], %[[J1]], %[[C0]], %[[ivK]]] [%{{.*}}, %{{.*}}, %[[SZ2]], %[[sz3]]] [%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref to memref +// TILE-23004: %[[sz3:.*]] = affine.min #[[bound_map_4]](%[[ivK]])[%[[dim3]]] +// TILE-23004: %[[InputView:.*]] = subview %{{.*}}[%[[ivI]], %[[J1]], 0, %[[ivK]]] [%{{.*}}, %{{.*}}, %[[SZ2]], %[[sz3]]] [1, 1, 1, 1] : memref to memref // // TILE-23004: %[[X0:.*]] = dim %{{.*}}, 2 : memref // TILE-23004: %[[X1:.*]] = dim %{{.*}}, 3 : memref -// TILE-23004: %[[OutputView:.*]] = subview %{{.*}}[%[[ivI]], 
%[[ivJ]], %[[C0]], %[[C0]]] [%{{.*}}, %{{.*}}, %[[X0]], %[[X1]]] [%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref to memref +// TILE-23004: %[[OutputView:.*]] = subview %{{.*}}[%[[ivI]], %[[ivJ]], 0, 0] [%{{.*}}, %{{.*}}, %[[X0]], %[[X1]]] [1, 1, 1, 1] : memref to memref // -// TILE-23004: linalg.conv(%[[FilterView]], %[[InputView]], %[[OutputView]]) {dilations = [10, 20], strides = [30, 40]} : memref, memref, memref +// TILE-23004: linalg.conv(%[[FilterView]], %[[InputView]], %[[OutputView]]) {dilations = [10, 20], strides = [30, 40]} : memref, memref, memref diff --git a/mlir/test/Dialect/Linalg/tile_conv_padding.mlir b/mlir/test/Dialect/Linalg/tile_conv_padding.mlir --- a/mlir/test/Dialect/Linalg/tile_conv_padding.mlir +++ b/mlir/test/Dialect/Linalg/tile_conv_padding.mlir @@ -1,10 +1,9 @@ -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,0,0,4" | FileCheck %s -check-prefix=TILE-23004 -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2" | FileCheck %s -check-prefix=TILE-20000 +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,0,0,4" | FileCheck %s -check-prefix=TILE-23004 --dump-input-on-failure +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2" | FileCheck %s -check-prefix=TILE-20000 --dump-input-on-failure // TILE-23004-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)> // TILE-20000-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)> -// TILE-20000-DAG: #[[minmap:.*]] = affine_map<(d0, d1, d2) -> (2, d1 - d2)> -// TILE-20000-DAG: #[[subviewstride:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)> +// TILE-20000-DAG: #[[minmap:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)> func @conv_padding(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.conv(%arg0, %arg1, %arg2) {dilations = [10, 20], padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [30, 40]} : 
memref, memref, memref @@ -21,20 +20,19 @@ // TILE-20000-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref // TILE-20000-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref) // TILE-20000-DAG: %[[C0:.*]] = constant 0 : index -// TILE-20000-DAG: %[[C1:.*]] = constant 1 : index // TILE-20000-DAG: %[[C2:.*]] = constant 2 : index // TILE-20000: %[[B:.*]] = dim %[[ARG1]], 0 // TILE-20000: scf.for %[[ivI:.*]] = %[[C0]] to %[[B]] step %[[C2]] { // TILE-20000: %[[DIM10:.*]] = dim %[[ARG1]], 0 -// TILE-20000: %[[EXTENT:.*]] = affine.min #[[minmap]](%[[C2]], %[[DIM10]], %[[ivI]]) +// TILE-20000: %[[EXTENT:.*]] = affine.min #[[minmap]](%[[ivI]])[%[[DIM10]]] // TILE-20000: %[[DIM11:.*]] = dim %[[ARG1]], 1 // TILE-20000: %[[DIM12:.*]] = dim %[[ARG1]], 2 // TILE-20000: %[[DIM13:.*]] = dim %[[ARG1]], 3 -// TILE-20000: %[[SUBVIEW1:.*]] = subview %[[ARG1]][%[[ivI]], %[[C0]], %[[C0]], %[[C0]]] [%[[EXTENT]], %[[DIM11]], %[[DIM12]], %[[DIM13]]] +// TILE-20000: %[[SUBVIEW1:.*]] = subview %[[ARG1]][%[[ivI]], 0, 0, 0] [%[[EXTENT]], %[[DIM11]], %[[DIM12]], %[[DIM13]]] // TILE-20000: %[[DIM20:.*]] = dim %[[ARG2]], 0 -// TILE-20000: %[[EXTENT:.*]] = affine.min #[[minmap]](%[[C2]], %[[DIM20]], %[[ivI]]) +// TILE-20000: %[[EXTENT:.*]] = affine.min #[[minmap]](%[[ivI]])[%[[DIM20]]] // TILE-20000: %[[DIM21:.*]] = dim %[[ARG2]], 1 // TILE-20000: %[[DIM22:.*]] = dim %[[ARG2]], 2 // TILE-20000: %[[DIM23:.*]] = dim %[[ARG2]], 3 -// TILE-20000: %[[SUBVIEW2:.*]] = subview %[[ARG2]][%[[ivI]], %[[C0]], %[[C0]], %[[C0]]] [%[[EXTENT]], %[[DIM21]], %[[DIM22]], %[[DIM23]]] +// TILE-20000: %[[SUBVIEW2:.*]] = subview %[[ARG2]][%[[ivI]], 0, 0, 0] [%[[EXTENT]], %[[DIM21]], %[[DIM22]], %[[DIM23]]] // TILE-20000: linalg.conv(%[[ARG0]], %[[SUBVIEW1]], %[[SUBVIEW2]]) diff --git a/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir b/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir --- a/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir +++ b/mlir/test/Dialect/Linalg/tile_indexed_generic.mlir @@ -65,8 +65,8 @@ return } // 
TILE-10n25-LABEL: func @indexed_generic_matrix -// TILE-10n25: %[[C25:.*]] = constant 25 : index -// TILE-10n25: %[[C10:.*]] = constant 10 : index +// TILE-10n25-DAG: %[[C25:.*]] = constant 25 : index +// TILE-10n25-DAG: %[[C10:.*]] = constant 10 : index // TILE-10n25: scf.for %[[K:.*]] = {{.*}} step %[[C10]] // TILE-10n25: scf.for %[[L:.*]] = {{.*}} step %[[C25]] // TILE-10n25: linalg.indexed_generic diff --git a/mlir/test/Dialect/Linalg/tile_parallel.mlir b/mlir/test/Dialect/Linalg/tile_parallel.mlir --- a/mlir/test/Dialect/Linalg/tile_parallel.mlir +++ b/mlir/test/Dialect/Linalg/tile_parallel.mlir @@ -26,7 +26,6 @@ // TILE-2-LABEL: func @sum( // TILE-2-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { // TILE-2-DAG: [[C0:%.*]] = constant 0 : index -// TILE-2-DAG: [[C1:%.*]] = constant 1 : index // TILE-2-DAG: [[C2:%.*]] = constant 2 : index // TILE-2: [[LHS_ROWS:%.*]] = dim [[LHS]], 0 // TILE-2: scf.parallel ([[I:%.*]]) = ([[C0]]) to ([[LHS_ROWS]]) step ([[C2]]) { @@ -39,7 +38,6 @@ // TILE-02-LABEL: func @sum( // TILE-02-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { // TILE-02-DAG: [[C0:%.*]] = constant 0 : index -// TILE-02-DAG: [[C1:%.*]] = constant 1 : index // TILE-02-DAG: [[C2:%.*]] = constant 2 : index // TILE-02: [[LHS_COLS:%.*]] = dim [[LHS]], 1 // TILE-02: scf.parallel ([[I:%.*]]) = ([[C0]]) to ([[LHS_COLS]]) step ([[C2]]) { @@ -57,7 +55,6 @@ // TILE-234-LABEL: func @sum( // TILE-234-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { // TILE-234-DAG: [[C0:%.*]] = constant 0 : index -// TILE-234-DAG: [[C1:%.*]] = constant 1 : index // TILE-234-DAG: [[C2:%.*]] = constant 2 : index // TILE-234-DAG: [[C3:%.*]] = constant 3 : index // TILE-234: [[LHS_ROWS:%.*]] = dim [[LHS]], 0 diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ 
b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -170,8 +170,9 @@ } void fillL1TilingAndMatmulToVectorPatterns( - MLIRContext *context, StringRef startMarker, + FuncOp funcOp, StringRef startMarker, SmallVectorImpl &patternsVector) { + MLIRContext *context = funcOp.getContext(); patternsVector.emplace_back(LinalgTilingPattern( context, LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}), @@ -195,7 +196,7 @@ } else { SmallVector stage1Patterns; if (testMatmulToVectorPatterns1dTiling) { - fillL1TilingAndMatmulToVectorPatterns(&getContext(), "START", + fillL1TilingAndMatmulToVectorPatterns(getFunction(), "START", stage1Patterns); } else if (testMatmulToVectorPatterns2dTiling) { stage1Patterns.emplace_back( @@ -204,7 +205,7 @@ .setTileSizes({768, 264, 768}) .setInterchange({1, 2, 0}), LinalgMarker({"START"}, "L2"))); - fillL1TilingAndMatmulToVectorPatterns(&getContext(), "L2", + fillL1TilingAndMatmulToVectorPatterns(getFunction(), "L2", stage1Patterns); } OwningRewritePatternList stage2Patterns =