diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -76,13 +76,6 @@ //===----------------------------------------------------------------------===// /// Linalg strategy passes. //===----------------------------------------------------------------------===// -/// Create a LinalgStrategyTileAndFusePass. -std::unique_ptr> -createLinalgStrategyTileAndFusePass( - StringRef opName = "", const linalg::LinalgTilingAndFusionOptions &opt = {}, - const linalg::LinalgTransformationFilter &filter = - linalg::LinalgTransformationFilter()); - /// Create a LinalgStrategyTilePass. std::unique_ptr> createLinalgStrategyTilePass( StringRef opName = "", diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -162,18 +162,6 @@ ]; } -def LinalgStrategyTileAndFusePass - : Pass<"linalg-strategy-tile-and-fuse-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply pattern-based tiling and fusion."; - let constructor = "mlir::createLinalgStrategyTileAndFusePass()"; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", - "Which linalg op within the func is the anchor to latch on.">, - ]; -} - def LinalgStrategyTilePass : Pass<"linalg-strategy-tile-pass", "func::FuncOp"> { let summary = "Configurable pass to apply pattern-based linalg tiling."; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -30,23 +30,6 @@ LinalgTransformationFilter::FilterFunction filter = nullptr; }; -/// Represent one application of LinalgStrategyTileAndFusePass. -struct TileAndFuse : public Transformation { - TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options, - LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)), opName(name), - options(std::move(options)) {} - - void addToPassPipeline(OpPassManager &pm, - LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, m)); - } - -private: - std::string opName; - linalg::LinalgTilingAndFusionOptions options; -}; - /// Represent one application of LinalgStrategyTilePass. struct Tile : public Transformation { Tile(StringRef name, linalg::LinalgTilingOptions options, @@ -66,22 +49,6 @@ /// Codegen strategy controls how a Linalg op is progressively lowered. struct CodegenStrategy { - /// Append a pattern to tile the Op `opName` and fuse its producers with - /// tiling and fusion `options`. - CodegenStrategy & - tileAndFuse(StringRef opName, const LinalgTilingAndFusionOptions &options, - const LinalgTransformationFilter::FilterFunction &f = nullptr) { - transformationSequence.emplace_back( - std::make_unique(opName, options, f)); - return *this; - } - /// Conditionally append a pattern to tile the Op `opName` and fuse its - /// producers with tiling and fusion `options`. - CodegenStrategy & - tileAndFuseIf(bool b, StringRef opName, LinalgTilingAndFusionOptions options, - LinalgTransformationFilter::FilterFunction f = nullptr) { - return b ? tileAndFuse(opName, std::move(options), std::move(f)) : *this; - } /// Append a pattern to add a level of tiling for Op `opName` with tiling /// `options`. CodegenStrategy & diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -787,42 +787,6 @@ } }; -/// -/// Linalg tile and fuse tensor ops pattern. -/// -/// Apply tiling and fusion as a pattern. -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `tileConsumerAndFuseProducers` for more details. -struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern { - // Entry point to match any LinalgOp. - LinalgTileAndFuseTensorOpsPattern( - MLIRContext *context, LinalgTilingAndFusionOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - // Entry point to match a specific LinalgOp. - LinalgTileAndFuseTensorOpsPattern( - StringRef opName, MLIRContext *context, - LinalgTilingAndFusionOptions options, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - FailureOr - returningMatchAndRewrite(Operation *op, PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); - } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; - /// Tile sizes and interchange used to tile the root operation. - LinalgTilingAndFusionOptions options; -}; - /// /// Linalg generalization pattern. /// diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -445,14 +445,6 @@ DenseMap> tiledRootAndFusedOpsLoops; }; -/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the -/// `tileSizes`, `tileInterchange`, and `tileDistribution` parameters to control -/// the tiling. -FailureOr tileConsumerAndFuseProducers( - OpBuilder &b, LinalgOp consumerOp, ArrayRef tileSizes, - ArrayRef tileInterchange, - const Optional &tileDistribution); - //===----------------------------------------------------------------------===// // Generic op region utilities //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -53,8 +53,8 @@ SCFTilingOptions &setTileSizes(ArrayRef ts); /// The interchange vector to reorder the tiled loops. - SmallVector interchangeVector = {}; - SCFTilingOptions &setInterchange(ArrayRef interchange) { + SmallVector interchangeVector = {}; + SCFTilingOptions &setInterchange(ArrayRef interchange) { interchangeVector = llvm::to_vector(interchange); return *this; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Interfaces/TilingInterface.h" @@ -99,38 +100,36 @@ results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); } - //===----------------------------------------------------------------------===// // FuseOp //===----------------------------------------------------------------------===// /// Apply a tiling transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. -static LogicalResult -applyTilingToAll(Operation *transformOp, ArrayRef payloadOps, - unsigned numLoops, - transform::TransformResults &transformResults, - function_ref(LinalgOp)> applyFn) { +static LogicalResult applyTilingToAll( + Operation *transformOp, ArrayRef payloadOps, unsigned numLoops, + transform::TransformResults &transformResults, + function_ref(TilingInterface)> + applyFn) { SmallVector tiledLinalgOps; SmallVector> loopOps(numLoops); for (unsigned int i = 0; i < numLoops; ++i) loopOps[i].reserve(payloadOps.size()); for (Operation *target : payloadOps) { - auto linalgOp = dyn_cast(target); - if (!linalgOp) - return transformOp->emitError("only LinalgOps are supported"); + auto tilingInterfaceOp = dyn_cast(target); + if (!tilingInterfaceOp) + return transformOp->emitError("only TilingInterface ops are supported"); - FailureOr tiled = applyFn(linalgOp); + FailureOr tiled = applyFn(tilingInterfaceOp); if (failed(tiled)) return failure(); - tiledLinalgOps.push_back(tiled->op); - if (tiled->loops.size() != numLoops) - // Not enough loops were generated. This usually means that the input size - // was smaller than the tiling size. - // TODO: LinalgTilingPattern should return failure(). - return failure(); + tiledLinalgOps.push_back(tiled->tiledAndFusedOps.front()); + assert(tiled->loops.size() == numLoops && + "Mismatched number of loops, tile and fuse transform should have " + "failed"); + for (unsigned int i = 0; i < numLoops; ++i) loopOps[i].push_back(tiled->loops[i]); } @@ -138,6 +137,7 @@ transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); for (unsigned int i = 0; i < numLoops; ++i) transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); + return success(); } @@ -172,27 +172,23 @@ DiagnosedSilenceableFailure transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { - LinalgTilingAndFusionOptions fusionOptions; - fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes()); - fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange()); + SmallVector tileSizes = extractFromI64ArrayAttr(getTileSizes()); + SmallVector tileInterchange = + extractFromI64ArrayAttr(getTileInterchange()); + scf::SCFTilingOptions tilingOptions; + tilingOptions.interchangeVector = tileInterchange; + tilingOptions = tilingOptions.setTileSizes(tileSizes); + scf::SCFTileAndFuseOptions tileAndFuseOptions; + tileAndFuseOptions.tilingOptions = tilingOptions; LogicalResult result = applyTilingToAll( getOperation(), state.getPayloadOps(getTarget()), - fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0), - transformResults, [&](LinalgOp linalgOp) -> FailureOr { - LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); + tileSizes.size() - llvm::count(tileSizes, 0), transformResults, + [&](TilingInterface tilingInterfaceOp) + -> FailureOr { SimpleRewriter rewriter(getContext()); - rewriter.setInsertionPoint(linalgOp); - FailureOr tileLoopNest = - pattern.returningMatchAndRewrite(linalgOp, rewriter); - if (failed(tileLoopNest)) - return failure(); - - TiledLinalgOp tiledLinalgOp; - tiledLinalgOp.op = tileLoopNest->getRootOp(); - tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), - tileLoopNest->getLoopOps().end()}; - return tiledLinalgOp; + return tileConsumerAndFuseProducerGreedilyUsingSCFForOp( + rewriter, tilingInterfaceOp, tileAndFuseOptions); }); return DiagnosedSilenceableFailure(result); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -414,68 +414,3 @@ } return result; } - -//===----------------------------------------------------------------------===// -// Tile and fuse entry-points. -//===----------------------------------------------------------------------===// - -FailureOr mlir::linalg::tileConsumerAndFuseProducers( - OpBuilder &b, LinalgOp consumerOp, ArrayRef tileSizes, - ArrayRef tileInterchange, - const Optional &tileDistribution) { - assert(tileSizes.size() == tileInterchange.size() && - "expect the number of tile sizes and interchange dims to match"); - assert(isPermutation(tileInterchange) && - "expect tile interchange is a permutation"); - - // Create an empty tile loop nest. - TileLoopNest tileLoopNest(consumerOp); - - // Search the number of outer parallel loops to separate them from possible - // inner reduction dimensions. - SmallVector iterTypes = consumerOp.getIteratorTypesArray(); - applyPermutationToVector(iterTypes, tileInterchange); - auto *it = find_if_not(iterTypes, isParallelIterator); - int64_t split = std::distance(iterTypes.begin(), it); - - // Helper to fuse the producers greedily using a queue of fusion candidates. - auto fuseProducersGreedily = [&](ArrayRef operands) { - SmallVector candidates(operands.begin(), operands.end()); - while (!candidates.empty()) { - FailureOr fusedProducer = - tileLoopNest.fuseProducer(b, candidates.pop_back_val()); - if (failed(fusedProducer)) - continue; - candidates.append(fusedProducer->getInputAndOutputOperands()); - } - }; - - // Perform tiling and fusion in two steps. We need to respect the loop - // interchange here; filter parellel dimensions based on their order *after* - // permutation but pass in the original configuration *before* permuation, - // given the tiling and interchange happen together. - SmallVector outerTileSizes(tileSizes.size(), 0); - SmallVector innerTileSizes(tileSizes.size(), 0); - for (int64_t i : tileInterchange.take_front(split)) - outerTileSizes[i] = tileSizes[i]; - for (int64_t i : tileInterchange.drop_front(split)) - innerTileSizes[i] = tileSizes[i]; - - // Tile the outer parallel loops and fuse the output operands. - if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange, - tileDistribution))) - return failure(); - fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands()); - - // Tile the remaining loops and fuse the input operands. - if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange, - tileDistribution))) - return failure(); - fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands()); - - // Exit if the tile loop nest is empty since all tile sizes are zero. - if (tileLoopNest.isEmpty()) - return failure(); - - return tileLoopNest; -} diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -51,44 +51,6 @@ namespace { -/// Configurable pass to apply pattern-based tiling and fusion. -struct LinalgStrategyTileAndFusePass - : public impl::LinalgStrategyTileAndFusePassBase< - LinalgStrategyTileAndFusePass> { - - LinalgStrategyTileAndFusePass() = default; - - LinalgStrategyTileAndFusePass(StringRef opName, - LinalgTilingAndFusionOptions opt, - LinalgTransformationFilter filt) - : options(std::move(opt)), filter(std::move(filt)) { - this->anchorOpName.setValue(opName.str()); - } - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - - RewritePatternSet tilingAndFusionPattern(funcOp.getContext()); - if (!anchorOpName.empty()) { - tilingAndFusionPattern.add( - anchorOpName, funcOp.getContext(), options, filter); - } else { - tilingAndFusionPattern.add( - funcOp.getContext(), options, filter); - } - // Search the root operation using bottom up traversal. - GreedyRewriteConfig config; - config.useTopDownTraversal = false; - (void)applyPatternsAndFoldGreedily( - funcOp, std::move(tilingAndFusionPattern), config); - } - - LinalgTilingAndFusionOptions options; - LinalgTransformationFilter filter; -}; - /// Configurable pass to apply pattern-based linalg tiling. struct LinalgStrategyTilePass : public impl::LinalgStrategyTilePassBase { @@ -139,15 +101,6 @@ }; } // namespace -/// Create a LinalgStrategyTileAndFusePass. -std::unique_ptr> -mlir::createLinalgStrategyTileAndFusePass( - StringRef opName, const LinalgTilingAndFusionOptions &options, - const LinalgTransformationFilter &filter) { - return std::make_unique(opName, options, - filter); -} - /// Create a LinalgStrategyTilePass. std::unique_ptr> mlir::createLinalgStrategyTilePass(StringRef opName, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -447,82 +447,6 @@ return paddedOp; } -/// Linalg tile and fuse tensor ops pattern. -mlir::linalg::LinalgTileAndFuseTensorOpsPattern:: - LinalgTileAndFuseTensorOpsPattern(MLIRContext *context, - LinalgTilingAndFusionOptions options, - LinalgTransformationFilter f, - PatternBenefit benefit) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context), - filter(std::move(f)), options(std::move(options)) {} - -mlir::linalg::LinalgTileAndFuseTensorOpsPattern:: - LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context, - LinalgTilingAndFusionOptions options, - LinalgTransformationFilter f, - PatternBenefit benefit) - : RewritePattern(opName, benefit, context), filter(std::move(f)), - options(std::move(options)) {} - -FailureOr -mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite( - Operation *op, PatternRewriter &rewriter) const { - LinalgOp rootOp = dyn_cast(op); - if (!rootOp) - return failure(); - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - - // Check `tileSizes` contains a tile size for every `rootOp` loop dimension. - if (options.tileSizes.size() < rootOp.getNumLoops()) - return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops"); - - // Check `tileInterchange` contains no entries or as many as `tileSizes`. - if (!options.tileInterchange.empty() && - options.tileInterchange.size() != options.tileSizes.size()) - return rewriter.notifyMatchFailure( - op, "expect the number of tile sizes and interchange dims to match"); - - // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`. - SmallVector rootTileSizes(options.tileSizes.begin(), - options.tileSizes.begin() + - rootOp.getNumLoops()); - SmallVector rootInterchange = - options.tileInterchange.empty() - ? llvm::to_vector<6>(llvm::seq(0, rootOp.getNumLoops())) - : SmallVector(options.tileInterchange.begin(), - options.tileInterchange.begin() + - rootOp.getNumLoops()); - - // Check `rootTileSizes` contains non-zero tile sizes. - if (llvm::count(rootTileSizes, 0) == static_cast(rootTileSizes.size())) - return rewriter.notifyMatchFailure( - op, "expect at least one non-zero tile size"); - - // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions. - // It has to be a permutation since the tiling cannot tile the same loop - // dimension multiple times. - if (!isPermutation(rootInterchange)) - return rewriter.notifyMatchFailure( - op, "expect the tile interchange permutes the root loops"); - - // Tile `rootOp` and fuse its producers. - FailureOr tileLoopNest = - tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes, - rootInterchange, options.tileDistribution); - if (failed(tileLoopNest)) - return rewriter.notifyMatchFailure( - op, "tileConsumerAndFuseProducers failed unexpectedly"); - - // Replace all uses of the tiled loop operation. - rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults()); - - // Apply the filter if specified. - for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps()) - filter.replaceLinalgTransformationFilter(rewriter, linalgOp); - return tileLoopNest; -} - /// Linalg generalization pattern. mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern( MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -45,12 +45,12 @@ /// Helper method to adjust the interchange vector to match the iteration /// domain. -static SmallVector -fillInterchangeVector(ArrayRef interchangeVector, +static SmallVector +fillInterchangeVector(ArrayRef interchangeVector, size_t iterationDomainSize) { - SmallVector filledVector = llvm::to_vector(interchangeVector); + SmallVector filledVector = llvm::to_vector(interchangeVector); if (filledVector.size() < iterationDomainSize) { - auto range = llvm::seq(filledVector.size(), iterationDomainSize); + auto range = llvm::seq(filledVector.size(), iterationDomainSize); filledVector.append(range.begin(), range.end()); } if (filledVector.size() > iterationDomainSize) @@ -61,23 +61,23 @@ /// Helper method to apply permutation to a vector template static SmallVector applyPermutationToVector(const SmallVector &vector, - ArrayRef interchange) { + ArrayRef interchange) { assert(interchange.size() == vector.size()); return llvm::to_vector( - llvm::map_range(interchange, [&](unsigned val) { return vector[val]; })); + llvm::map_range(interchange, [&](int64_t val) { return vector[val]; })); } /// Helper method to apply to invert a permutation. -static SmallVector -invertPermutationVector(ArrayRef interchange) { - SmallVector inversion(interchange.size()); +static SmallVector +invertPermutationVector(ArrayRef interchange) { + SmallVector inversion(interchange.size()); for (const auto &pos : llvm::enumerate(interchange)) { inversion[pos.value()] = pos.index(); } return inversion; } /// Method to check if an interchange vector is a permutation. -static bool isPermutation(ArrayRef interchange) { - llvm::SmallDenseSet seenVals; +static bool isPermutation(ArrayRef interchange) { + llvm::SmallDenseSet seenVals; for (auto val : interchange) { if (seenVals.count(val)) return false; @@ -298,7 +298,7 @@ { // If there is an interchange specified, permute the iteration domain and // the tile sizes. - SmallVector interchangeVector; + SmallVector interchangeVector; if (!options.interchangeVector.empty()) { interchangeVector = fillInterchangeVector(options.interchangeVector, iterationDomain.size()); @@ -365,7 +365,7 @@ // 5. Yield all the results of the tiled operation. The surrounding loop // nest is modified to insert a destructive update pattern to yield // from the loop nest values to replace the untiled op with. - unsigned numResults = op->getNumResults(); + int64_t numResults = op->getNumResults(); SmallVector> resultOffsetsList(numResults), resultSizesList(numResults); for (auto result : llvm::enumerate(op->getResults())) { @@ -443,7 +443,7 @@ // 1. First tile the consumer. scf::SCFTileAndFuseResult tileAndFuseResult; - llvm::SmallDenseMap yieldedValueToResultNumber; + llvm::SmallDenseMap yieldedValueToResultNumber; { FailureOr tilingResult = tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); @@ -566,7 +566,7 @@ *destinationIterArg.value()); } if (iterArgNumber) { - unsigned resultNumber = fusableProducer.getResultNumber(); + int64_t resultNumber = fusableProducer.getResultNumber(); if (auto producerOp = dyn_cast(fusableProducer.getOwner())) { SmallVector destination = diff --git a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-fuse-and-distribute-options -split-input-file | FileCheck %s - -// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)> -// CHECK: #[[ADDMAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> -// CHECK: func @fill_matmul_tensors( -// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor -// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor) -> tensor { -func.func @fill_matmul_tensors( - %arg0: tensor, %arg1: tensor) - -> tensor { -// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y -// CHECK-DAG: %[[NBLOCKSY:.*]] = gpu.grid_dim y -// CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x -// CHECK-DAG: %[[NBLOCKSX:.*]] = gpu.grid_dim x -// CHECK-DAG: %[[INIT:.+]] = tensor.empty -// CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]] -// CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]] -// CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]] -// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[INIT]]) -> (tensor) { -// CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDX]], %[[C8]]] -// CHECK: %[[LBX:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]] -// CHECK: %[[STEPX:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSX]], %[[C8]]] -// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor) { -// CHECK: %[[OUTSLICEA:.+]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [%{{.*}}, %{{.*}}] [1, 1] : tensor to tensor -// CHECK: %[[OUTSLICEB:.+]] = tensor.extract_slice %{{.*}}[0, %{{.*}}] [%{{.*}}, %{{.*}}] [1, 1] : tensor to tensor -// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TC1]] -// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[SLICE]] -// CHECK: %[[sTD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[FILL]]) -> (tensor) { -// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[OUTSLICEA]][{{.*}}] : tensor to tensor -// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[OUTSLICEB]][{{.*}}] : tensor to tensor -// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor to tensor -// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor, tensor) -// CHECK-SAME: outs(%[[sTC]] : tensor) -> tensor -// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor into tensor -// CHECK: scf.yield %[[TD]] : tensor -// CHECK: %[[TD2:.*]] = tensor.insert_slice %[[sTD2]] into %[[TC1]][{{.*}}] : tensor into tensor -// CHECK: scf.yield %[[TD2]] : tensor -// CHECK: scf.yield %[[TD1]] : tensor - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %cst = arith.constant 0.0 : f32 - %0 = tensor.dim %arg0, %c0 : tensor - %1 = tensor.dim %arg1, %c1 : tensor - %2 = tensor.empty(%0, %1) : tensor - %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor) -> tensor - %4 = linalg.matmul {__internal_linalg_transform__ = "tensors_fuse_distribute1"} - ins(%arg0, %arg1: tensor, tensor) - outs(%3: tensor) - -> tensor - -// CHECK: return %[[TD0]] : tensor - return %4 : tensor -} diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -61,7 +61,7 @@ %five = arith.constant 5.0 : f32 %init = tensor.empty() : tensor<12x25xf32> -// CHECK: %[[INIT:.+]] = tensor.empty() +// CHECK-DAG: %[[INIT:.+]] = tensor.empty() // CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index // CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index // CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[FOR_ARG0:.+]] = %[[INIT]]) @@ -69,6 +69,9 @@ // CHECK: %[[OUT_SLICE0:.+]] = tensor.extract_slice %[[INPUT]][%[[IV0]], 0, %[[IV1]]] // CHECK: %[[OUT_SLICE1:.+]] = tensor.extract_slice %[[FOR_ARG1]][%[[IV0]], %[[IV1]]] // CHECK: %[[FILL:.+]] = linalg.fill {{.+}} outs(%[[OUT_SLICE1]] : tensor) +// +// Extra 4 constant is introduced, discard it. +// CHECK: arith.constant 4 : index // CHECK: %[[C4:.+]] = arith.constant 4 : index // CHECK: scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %[[C4]] iter_args(%[[FOR_ARG2:.+]] = %[[FILL]]) // CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %[[OUT_SLICE0]] @@ -92,6 +95,7 @@ transform.sequence %arg0 failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1, %loops:3 = transform.structured.fuse %0 {tile_sizes = [5, 4, 7], tile_interchange = [0, 2, 1]} + %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [5, 0, 7], tile_interchange = [0, 2, 1]} + %2, %loops_2 = transform.structured.tile %1 [0, 4] } } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -65,10 +65,6 @@ *this, "test-tile-and-distribute-options", llvm::cl::desc("Test tile and distribute options"), llvm::cl::init(false)}; - Option testTileFuseAndDistributionOptions{ - *this, "test-tile-fuse-and-distribute-options", - llvm::cl::desc("Test tile, fuse and distribute options"), - llvm::cl::init(false)}; Option testVectorTransferForwardingPatterns{ *this, "test-vector-transfer-forwarding-patterns", llvm::cl::desc( @@ -415,27 +411,6 @@ } } -static void fillTileFuseAndDistributePatterns(MLIRContext *context, - RewritePatternSet &patterns) { - LinalgLoopDistributionOptions cyclicNprocsEqNiters; - SmallVector distributionMethod = { - DistributionMethod::Cyclic, DistributionMethod::Cyclic}; - cyclicNprocsEqNiters.procInfo = - [distributionMethod](OpBuilder &b, Location loc, - ArrayRef parallelLoopRanges) { - return getGpuProcIds( - b, loc, parallelLoopRanges, distributionMethod); - }; - patterns.add( - MatmulOp::getOperationName(), context, - LinalgTilingAndFusionOptions() - .setTileSizes({8, 8, 4}) - .setDistributionOptions(cyclicNprocsEqNiters), - LinalgTransformationFilter( - StringAttr::get(context, "tensors_fuse_distribute1"), - StringAttr::get(context, "tensors_after_fuse_distribute1"))); -} - static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { RewritePatternSet forwardPattern(funcOp.getContext()); forwardPattern.add(funcOp.getContext()); @@ -552,12 +527,6 @@ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); return; } - if (testTileFuseAndDistributionOptions) { - RewritePatternSet patterns(&getContext()); - fillTileFuseAndDistributePatterns(&getContext(), patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - return; - } if (testPatterns) return applyPatterns(getOperation()); if (testVectorTransferForwardingPatterns) diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -199,7 +199,7 @@ RewritePatternSet &patterns, StringRef filterName, ArrayRef tileSizes, - ArrayRef interchange = {}) { + ArrayRef interchange = {}) { scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); linalg::LinalgTransformationFilter filter( @@ -211,7 +211,7 @@ RewritePatternSet &patterns, StringRef filterName, ArrayRef tileSizes, - ArrayRef interchange = {}) { + ArrayRef interchange = {}) { scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange( interchange);