diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -592,6 +592,10 @@
   SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
                                      options.tileSizes.begin() +
                                          rootOp.getNumLoops());
+  // Bail out early when every tile size is zero: tiling would be a no-op, so
+  // report a match failure instead of producing an unchanged op.
+  if (llvm::all_of(rootTileSizes, [](int64_t ts) { return ts == 0; })) {
+    return rewriter.notifyMatchFailure(
+        op, "all tile sizes are zero, nothing to do");
+  }
   SmallVector<int64_t> rootInterchange =
       options.tileInterchange.empty()
           ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
diff --git a/mlir/test/Dialect/Linalg/tile_and_fuse_no_fuse.mlir b/mlir/test/Dialect/Linalg/tile_and_fuse_no_fuse.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile_and_fuse_no_fuse.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.generic fuse tile-sizes=0,0 run-enable-pass=false" -cse -split-input-file | FileCheck %s
+
+builtin.func @gemm_bias_add(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>, %arg3 : tensor<?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0, %arg3 : tensor<?x?xf32>, tensor<?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+    ^bb0(%arg4: f32, %arg5 : f32, %arg6 : f32):
+      %1 = arith.addf %arg4, %arg5 : f32
+      linalg.yield %1 : f32
+    } -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+// CHECK-LABEL: @gemm_bias_add(
+//   CHECK-NOT:   scf.for
+//       CHECK:   linalg.matmul
+//       CHECK:   linalg.generic