diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -756,13 +756,14 @@ // 4 . Check if the operation is a reduction. SmallVector> reductionOperands; for (Value operand : op->getOperands()) { - auto arg = operand.dyn_cast(); - if (!arg || arg.getArgNumber() < linalgOp.getNumDpsInputs()) + auto blockArg = operand.dyn_cast(); + if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() || + blockArg.getArgNumber() < linalgOp.getNumDpsInputs()) continue; SmallVector reductionOps; Value reduceValue = matchReduction( linalgOp.getRegionOutputArgs(), - arg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps); + blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps); if (!reduceValue) continue; reductionOperands.push_back(std::make_pair(reduceValue, operand));