diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -752,13 +752,18 @@
   // 4 . Check if the operation is a reduction.
   SmallVector<std::pair<Value, Value>> reductionOperands;
   for (Value operand : op->getOperands()) {
-    auto arg = operand.dyn_cast<BlockArgument>();
-    if (!arg || arg.getArgNumber() < linalgOp.getNumDpsInputs())
+    // Only block arguments of the linalg op's own body can be reduction
+    // operands. Block arguments owned by an enclosing region (e.g. an scf.for
+    // induction variable used inside the body) must be skipped, otherwise
+    // their argument number would be misread as an output-argument index.
+    auto blockArg = operand.dyn_cast<BlockArgument>();
+    if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() ||
+        blockArg.getArgNumber() < linalgOp.getNumDpsInputs())
       continue;
     SmallVector<Operation *> reductionOps;
     Value reduceValue = matchReduction(
         linalgOp.getRegionOutputArgs(),
-        arg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
+        blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps);
     if (!reduceValue)
       continue;
     reductionOperands.push_back(std::make_pair(reduceValue, operand));
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1853,3 +1853,45 @@
   %0 = transform.structured.match ops{["func.func"]} in %arg0
   %1 = transform.structured.vectorize %0
 }
+
+// -----
+
+// Regression test: %arg0 (the scf.for induction variable) is a block argument
+// of the enclosing loop, not of the linalg.generic body. Its argument number
+// used to be interpreted as an index into the generic's output arguments, so
+// %13 was incorrectly detected as a reduction and vectorization failed.
+
+func.func @wrong_reduction_detection(%input: tensor<120x64xf32>) -> tensor<120x64xf32> {
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c64 = arith.constant 64 : index
+  %cst_6 = arith.constant 4.000000e+00 : f32
+  %1 = scf.for %arg0 = %c0 to %c64 step %c4 iter_args(%arg1 = %input) -> (tensor<120x64xf32>) {
+    %extracted_slice = tensor.extract_slice %arg1[%c0, %arg0] [1, 4] [1, 1] : tensor<120x64xf32> to tensor<1x4xf32>
+    %10 = linalg.fill {__internal_linalg_transform__ = "1"} ins(%cst_6 : f32) outs(%extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32>
+    %11 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%10 : tensor<1x4xf32>) {
+    ^bb0(%out: f32):
+      %12 = linalg.index 0 : index
+      %13 = arith.addi %arg0, %12 : index
+      %18 = arith.index_cast %13 : index to i32
+      %20 = arith.uitofp %18 : i32 to f32
+      %67 = arith.mulf %out, %20 : f32
+      linalg.yield %67 : f32
+    } -> tensor<1x4xf32>
+    %inserted_slice = tensor.insert_slice %11 into %arg1[%c0, %arg0] [1, 4] [1, 1] : tensor<1x4xf32> into tensor<120x64xf32>
+    scf.yield %inserted_slice : tensor<120x64xf32>
+  }
+  return %1 : tensor<120x64xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1
+}
+
+// CHECK-LABEL: @wrong_reduction_detection
+// CHECK: vector.broadcast
+// CHECK: vector.transfer_write
+