diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -122,6 +122,8 @@ combinerOp) .Case( [&](auto op) { return CombiningKind::ADD; }) + .Case([&](auto op) { return CombiningKind::ADD; }) + .Case([&](auto op) { return CombiningKind::ADD; }) /* NOTE(review): the two cases added above map a subtraction combiner to CombiningKind::ADD (their template arguments were lost in this copy; presumably arith::SubIOp / arith::SubFOp, cf. the arith.subf test added below). Subtraction is neither associative nor commutative, so reusing ADD looks incorrect: for the test body `arith.subf %arg7, %arg9` (%arg9 is the outs/accumulator argument) the scalar loop computes acc = x - acc, an alternating sum, and an `acc - x` body would compute acc - sum(x), whereas an ADD multi-reduction computes acc + sum(x). TODO confirm intent; otherwise drop these cases or negate the reduced operand before emitting an ADD reduction. */ .Case([&](auto op) { return CombiningKind::AND; }) .Case([&](auto op) { return CombiningKind::MAXSI; }) .Case([&](auto op) { return CombiningKind::MAXF; }) diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir --- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir @@ -153,3 +153,33 @@ %2 = transform.structured.vectorize %0 } } + +// ----- + +// CHECK-LABEL: @vectorize_subred +func.func @vectorize_subred(%arg0: tensor<32xf32>, %out: tensor) -> tensor { + // CHECK: %[[vA:.+]] = vector.transfer_read + // CHECK: %[[vB:.+]] = vector.transfer_read + // CHECK: %[[vC:.+]] = vector.extractelement + // CHECK: %[[vR:.+]] = vector.multi_reduction , %[[vA]], %[[vC]] + %0 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], + iterator_types = ["reduction"]} + ins(%arg0 : tensor<32xf32>) + outs(%out : tensor) { + ^bb0(%arg7: f32, %arg9: f32): + %1 = arith.subf %arg7, %arg9 : f32 + linalg.yield %1 : f32 + } -> tensor + return %0 : tensor +} + +transform.with_pdl_patterns { + ^bb0(%IN_MAT1: !pdl.operation): + transform.sequence %IN_MAT1 failures(propagate) { + ^bb1(%IN_MAT2: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %IN_MAT2 + %1 = get_closest_isolated_parent %0 + transform.structured.vectorize %1 + } +}