Do not fuse a constant operand into a linalg.generic op when the loop bounds
of the fused op can no longer be computed.

FusionOnTensors.cpp: after the constant operand's indexing map is erased from
fusedIndexMaps, the pattern reports a match failure if
inversePermutation(concatAffineMaps(fusedIndexMaps)) evaluates to false.

fusion-tensor.mlir: adds @no_fuse_constant_with_reduction, which checks that a
linalg.generic reducing a constant tensor<3x2xf32> still takes that constant
as its `ins` operand.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1103,6 +1103,12 @@
           linalgOp.indexing_maps().getAsValueRange());
       fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
 
+      // Bail out if the fused op's shapes-to-loops map is not computable.
+      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
+        return rewriter.notifyMatchFailure(
+            linalgOp, "fused op loop bound computation failed");
+      }
+
       // The operands list is same as the linalgOp with the argument for
       // constant index dropped.
       SmallVector fusedOperands(linalgOp.getInputs());
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -678,3 +678,26 @@
   } -> tensor<1x8xindex>
   return %1 : tensor<1x8xindex>
 }
+
+// -----
+
+// CHECK-LABEL: func @no_fuse_constant_with_reduction
+func @no_fuse_constant_with_reduction() -> tensor<3xf32>
+{
+  // CHECK: %[[CONST:.+]] = constant {{.+}} : tensor<3x2xf32>
+  // CHECK: %[[RESULT:.+]] = linalg.generic
+  // CHECK-SAME: ins(%[[CONST]] : tensor<3x2xf32>)
+  // CHECK: return %[[RESULT]]
+  %three = constant dense<3.0> : tensor<3x2xf32>
+  %init = linalg.init_tensor [3] : tensor<3xf32>
+  %result = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+     ins(%three : tensor<3x2xf32>) outs(%init : tensor<3xf32>) {
+     ^bb0(%arg0 : f32, %arg1 : f32):
+        %0 = addf %arg0, %arg1 : f32
+        linalg.yield %0 : f32
+  } -> tensor<3xf32>
+  return %result : tensor<3xf32>
+}