diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -937,6 +937,11 @@
       return FuseTensorReshapeOpAsProducer<GenericOp>::fuse(
           reshapeOpProducer, genericOpConsumer, consumerIdx, rewriter, folder);
+    } else if (auto indexedGenericOpConsumer =
+                   dyn_cast<IndexedGenericOp>(consumer)) {
+      return FuseTensorReshapeOpAsProducer<IndexedGenericOp>::fuse(
+          reshapeOpProducer, indexedGenericOpConsumer, consumerIdx, rewriter,
+          folder);
     }
   } else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) {
     if (auto genericOpConsumer = dyn_cast<GenericOp>(consumer)) {
@@ -954,6 +959,11 @@
     if (genericOpProducer.hasTensorSemantics())
       return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse(
           genericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
+  } else if (auto indexedGenericOpProducer =
+                 dyn_cast<IndexedGenericOp>(producer)) {
+    if (indexedGenericOpProducer.hasTensorSemantics())
+      return FuseTensorReshapeOpAsConsumer<IndexedGenericOp>::fuse(
+          indexedGenericOpProducer, reshapeOp, consumerIdx, rewriter, folder);
   }
   return nullptr;
 }
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -421,3 +421,68 @@
 // CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
 // CHECK: linalg.yield %[[VAL4]] : i32
 // CHECK-NOT: linalg.indexed_generic
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x?xi32>)
+  -> tensor<?x?x4x?xi32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
+                                    affine_map<(i, j, k, l) -> (j, k)>,
+                                    affine_map<(i, j, k, l) -> (l)>] :
+    tensor<?x?x?xi32> into tensor<?x?x4x?xi32>
+  %1 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"] } %0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32):  // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = addi %arg6, %2 : i32
+    linalg.yield %3 : i32
+  }: tensor<?x?x4x?xi32> -> tensor<?x?x4x?xi32>
+  return %1 : tensor<?x?x4x?xi32>
+}
+
+// CHECK-LABEL: func @indexed_generic_op_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 1
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @indexed_generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?x4x5xi32>)
+  -> tensor<?x?xi32> {
+  %0 = linalg.indexed_generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"] } %arg0 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: i32):  // no predecessors
+    %2 = index_cast %arg2 : index to i32
+    %3 = addi %arg6, %2 : i32
+    linalg.yield %3 : i32
+  }: tensor<?x?x4x5xi32> -> tensor<?x?x4x5xi32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
+                                 affine_map<(i, j, k, l) -> (j, k, l)>] :
+    tensor<?x?x4x5xi32> into tensor<?x?xi32>
+  return %1 : tensor<?x?xi32>
+}
+
+// CHECK-LABEL: func @indexed_generic_op_reshape_consumer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.indexed_generic
+// CHECK-SAME: args_in = 1
+// CHECK-SAME: args_out = 1
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape