diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -994,7 +994,7 @@ LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { // Only apply to elementwise linalg on tensor. - if (!genericOp.hasTensorSemantics() || + if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() || genericOp.getNumParallelLoops() != genericOp.getNumLoops()) return failure(); // Only support identity output maps. It could be extended to permuations if diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir --- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir +++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir @@ -124,3 +124,30 @@ // CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>) // CHECK: tensor.expand_shape %[[OP]] // CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32> + +// ----- + +func @generic_op_index_semantics(%A: tensor, %B: tensor<16xi64>, %init: tensor) -> tensor { + %0 = tensor.expand_shape %A [[0, 1], [2]] + : tensor into tensor + %2 = linalg.generic {indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%0, %B : tensor, tensor<16xi64>) + outs(%init : tensor) { + ^bb0(%arg1: i64, %arg2: i64, %arg3: i64): // no predecessors + %index = linalg.index 0 : index + %1 = arith.index_cast %index : index to i64 + %add = arith.addi %arg1, %1 : i64 + %s = arith.subi %add, %arg2 : i64 + linalg.yield %s : i64 + } -> tensor + return %2 : tensor +} +// CHECK: func @generic_op_index_semantics +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK: %[[RESULT:.+]] = linalg.generic +// CHECK-SAME: ins(%[[RESHAPE]] +// CHECK: return %[[RESULT]]