diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -269,10 +269,10 @@
   /// Returns the loop ops generated from tiling.
   ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
 
-private:
   /// Returns true if the tile loop nest has no tile loops.
   bool isEmpty();
 
+private:
   /// Returns true if the tile loop nest invariants are satisfied:
   /// - The `rootOp` has been tiled at least once.
   /// - The number of tile loop operations and dimensions match.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -458,5 +458,9 @@
     return failure();
   fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
 
+  // Exit if the tile loop nest is empty since all tile sizes are zero.
+  if (tileLoopNest.isEmpty())
+    return failure();
+
   return tileLoopNest;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -592,10 +592,6 @@
   SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                      options.tileSizes.begin() +
                                          rootOp.getNumLoops());
-  if (llvm::all_of(rootTileSizes, [](int64_t ts) { return ts == 0; })) {
-    return rewriter.notifyMatchFailure(
-        op, "all tile sizes are zero, nothing to do");
-  }
   SmallVector<int64_t> rootInterchange =
       options.tileInterchange.empty()
           ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
@@ -603,6 +599,11 @@
                                  options.tileInterchange.begin() +
                                      rootOp.getNumLoops());
 
+  // Check `rootTileSizes` contains at least one non-zero tile size.
+  if (llvm::all_of(rootTileSizes, [](int64_t ts) { return ts == 0; }))
+    return rewriter.notifyMatchFailure(
+        op, "expect at least one non-zero tile size");
+
   // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
   // It has to be a permutation since the tiling cannot tile the same loop
   // dimension multiple times.
@@ -623,7 +624,7 @@
   // Apply the filter if specified.
   for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
     filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
-  return failure();
+  return success();
 }
 
 /// Linalg generic interchange pattern.
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.generic fuse tile-sizes=0,0 run-enable-pass=false" -cse -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=0,0,0 run-enable-pass=false" -split-input-file | FileCheck %s
 
 builtin.func @no_fuse_gemm(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index