diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -278,6 +278,10 @@
     LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                       << *dependence.dependentOpView.op << "\n");
     auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
+    if (isa<IndexedGenericOp>(dependence.dependentOpView.op)) {
+      LLVM_DEBUG(dbgs() << "Not fusing indexed_generic producer");
+      continue;
+    }
 
     // Check that the dependence is indeed on the input `consumerIdx` view.
     auto consumedView = dependence.indexingView;
diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -672,6 +672,49 @@
 
 // -----
 
+//
+// We should not be fusing indexed_generic into a generic yet.
+// https://bugs.llvm.org/show_bug.cgi?id=44875
+//
+
+#map0 = affine_map<(d0)[s0,s1] -> (d0 * s1 + s0)>
+#pointwise_map = affine_map<(d0) -> (d0)>
+#pointwise_1d_trait = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = [#pointwise_map, #pointwise_map],
+  iterator_types = ["parallel"]
+}
+
+func @nofuse_indexed_generic(%A: memref<?xf32>, %B: memref<?xf32>, %C: memref<?xf32>) {
+  linalg.indexed_generic #pointwise_1d_trait %A, %B {
+  ^bb0(%i: index, %a: f32, %b: f32):
+    linalg.yield %a : f32
+  }: memref<?xf32>, memref<?xf32>
+
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %dB = dim %B, 0 : memref<?xf32>
+  loop.for %i = %c0 to %dB step %c10 {
+    %subB = subview %B[%i][%c10][%c1] : memref<?xf32> to memref<?xf32, #map0>
+    %subC = subview %C[%i][%c10][%c1] : memref<?xf32> to memref<?xf32, #map0>
+    linalg.generic #pointwise_1d_trait %subB, %subC {
+    ^bb0(%b: f32, %c: f32):
+      linalg.yield %b : f32
+    }: memref<?xf32, #map0>, memref<?xf32, #map0>
+  }
+  return
+}
+// CHECK-LABEL: func @nofuse_indexed_generic
+// CHECK-NOT: loop.for
+// CHECK: linalg.indexed_generic
+// CHECK: loop.for
+// CHECK-NOT: linalg.indexed_generic
+// CHECK: linalg.generic
+
+// -----
+
 #map0 = affine_map<(d0, d1) -> (d0)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
 #map2 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>