[mlir][linalg] Fail tile-and-fuse on an empty tile loop nest.

Make TileLoopNest::isEmpty() public so callers outside the class can use it.
FusionOnTensors.cpp now returns failure() after producer fusion when the tile
loop nest is empty because all tile sizes are zero. Transforms.cpp drops its
own all-zero tile size check and relies on that failure instead. The pattern
in Transforms.cpp now returns success() instead of failure() once it has
replaced the uses of the root op and updated the transformation filters of
the tiled and fused ops.

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -269,10 +269,10 @@
   /// Returns the loop ops generated from tiling.
   ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
 
-private:
   /// Returns true if the tile loop nest has no tile loops.
   bool isEmpty();
 
+private:
   /// Returns true if the tile loop nest invariants are satisfied:
   /// - The `rootOp` has been tiled at least once.
   /// - The number of tile loop operations and dimensions match.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -458,5 +458,9 @@
     return failure();
   fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
 
+  // Exit if the tile loop nest is empty since all tile sizes are zero.
+  if (tileLoopNest.isEmpty())
+    return failure();
+
   return tileLoopNest;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -592,10 +592,6 @@
   SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                      options.tileSizes.begin() +
                                          rootOp.getNumLoops());
-  if (llvm::all_of(rootTileSizes, [](int64_t ts) { return ts == 0; })) {
-    return rewriter.notifyMatchFailure(
-        op, "all tile sizes are zero, nothing to do");
-  }
   SmallVector<int64_t> rootInterchange =
       options.tileInterchange.empty()
           ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
@@ -623,7 +619,7 @@
   // Apply the filter if specified.
   for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
     filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
-  return failure();
+  return success();
 }
 
 /// Linalg generic interchange pattern.