Remove the LinalgTileAndFuseTensorOps pass ("linalg-tile-and-fuse-tensor-ops").

This patch deletes the pass in all three places it is defined:
  * Passes.h: drops the createLinalgTileAndFuseTensorOpsPass() declaration.
  * Passes.td: drops the LinalgTileAndFuseTensorOps definition, i.e. the
    "linalg-tile-and-fuse-tensor-ops" flag, its "tile-sizes" and
    "tile-interchange" list options, and its dependent dialects.
  * FusionOnTensors.cpp: drops the pass implementation. The pass:
    1. Walked the function and picked as root the op whose backward slice
       counted the most ops of one type. That type's template argument is
       stripped in this copy; presumably it is LinalgOp (confirm).
    2. Required at least one tile size per root loop.
    3. Required the interchange to be empty or the same length as the tile
       sizes.
    4. Required the interchange prefix to be a permutation.
    5. Called tileConsumerAndFuseProducers.
    6. Replaced the root's uses with getRootOpReplacementResults().
    Failures were reported on llvm::errs() and via signalPassFailure().
  The function whose tail appears as hunk context ("return tileLoopNest;")
  is kept. Presumably that is tileConsumerAndFuseProducers itself; verify.

NOTE(review): this copy of the patch looks extraction-garbled and will not
apply as-is. Most newlines were collapsed onto single lines, and angle
bracket template arguments appear to have been stripped (for example
"std::unique_ptr createLinalgDetensorizePass();", "std::unique_ptr>",
"SetVector backwardSlice;", "isa(op)", "FailureOr tileLoopNest" and
"std::make_unique()"). Regenerate it from the source tree rather than
hand-editing it.
NOTE(review): confirm that every remaining user of the removed pass is
deleted or migrated in this same change. That includes lit tests invoking
the linalg-tile-and-fuse-tensor-ops flag and any callers of
createLinalgTileAndFuseTensorOpsPass. None are visible in this excerpt.

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -75,9 +75,6 @@ /// work on primitive types, if possible. std::unique_ptr createLinalgDetensorizePass(); -/// Create a pass to tile a LinalgOp and fuse its producers. -std::unique_ptr> createLinalgTileAndFuseTensorOpsPass(); - //===----------------------------------------------------------------------===// /// Linalg strategy passes. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -221,20 +221,6 @@ }]; } -def LinalgTileAndFuseTensorOps - : FunctionPass<"linalg-tile-and-fuse-tensor-ops"> { - let summary = "Tile a LinalgOp and fuse its producers."; - let constructor = "mlir::createLinalgTileAndFuseTensorOpsPass()"; - let options = [ - ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes", - "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">, - ListOption<"tileInterchange", "tile-interchange", "int64_t", - "Tile loop interchange", - "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">, - ]; - let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"]; -} - def LinalgStrategyTileAndFusePass : FunctionPass<"linalg-strategy-tile-and-fuse-pass"> { let summary = "Configurable pass to apply pattern-based tiling and fusion."; diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -457,80 +457,3 @@ return tileLoopNest; } - -namespace { -struct LinalgTileAndFuseTensorOps - : public LinalgTileAndFuseTensorOpsBase { - - void 
notifyFailure(StringRef message) {
-    llvm::errs() << " - LinalgTileAndFuseTensorOps: " << message << "\n";
-    signalPassFailure();
-  }
-
-  void runOnFunction() override {
-    FuncOp funcOp = getFunction();
-    OpBuilder b(funcOp.getContext());
-
-    // Heuristic to find a good operation to tile and start fusion. Walk all
-    // operations and select the one with the maximal backward slice of fusion
-    // candidates.
-    LinalgOp rootOp = nullptr;
-    int64_t numFusionCandidates = -1;
-    funcOp.walk([&](LinalgOp linalgOp) {
-      SetVector backwardSlice;
-      getBackwardSlice(linalgOp, &backwardSlice);
-      int64_t backwardSliceSize = count_if(
-          backwardSlice, [](Operation *op) { return isa(op); });
-      if (backwardSliceSize > numFusionCandidates) {
-        rootOp = linalgOp;
-        numFusionCandidates = backwardSliceSize;
-      }
-    });
-    if (!rootOp)
-      return notifyFailure("expect to find a root operation");
-
-    // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
-    if (tileSizes.size() < rootOp.getNumLoops())
-      return notifyFailure("expect #tile sizes >= #loops");
-
-    // Check `tileInterchange` contains no entries or as many as `tileSizes`.
-    if (!tileInterchange.empty() &&
-        tileInterchange.size() != tileSizes.size()) {
-      return notifyFailure(
-          "expect the number of tile sizes and interchange dims to match");
-    }
-
-    // Copy the `tileSizes` and `tileInterchange` prefixes needed to tile
-    // `rootOp` or use the identity interchange if `tileInterchange` is empty.
-    SmallVector rootTileSizes(
-        tileSizes.begin(), tileSizes.begin() + rootOp.getNumLoops());
-    SmallVector rootInterchange =
-        tileInterchange.empty()
-            ? llvm::to_vector<6>(llvm::seq(0, rootOp.getNumLoops()))
-            : SmallVector(tileInterchange.begin(),
-                                   tileInterchange.begin() +
-                                       rootOp.getNumLoops());
-
-    // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
-    // It has to be a permutation since the tiling cannot tile the same loop
-    // dimension multiple times.
-    if (!isPermutation(rootInterchange))
-      return notifyFailure(
-          "expect the tile interchange permutes the root loops");
-
-    // Tile `rootOp` and fuse its producers.
-    FailureOr tileLoopNest =
-        tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, rootInterchange);
-    if (failed(tileLoopNest))
-      return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly");
-
-    // Replace all uses of the tiled loop operation.
-    rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
-  }
-};
-} // namespace
-
-std::unique_ptr>
-mlir::createLinalgTileAndFuseTensorOpsPass() {
-  return std::make_unique();
-}