diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -228,8 +228,12 @@
          isMemoryEffectFree(op) && op->getNumRegions() == 0;
 }
 
-/// Return a value yielded by `warpOp` which statifies the filter lamdba
-/// condition and is not dead.
+/// Return a value yielded by `warpOp` which satisfies the filter lambda
+/// condition, is not dead (the matching `warpOp` result has uses), and whose
+/// defining op has no use other than this yield. The single-use requirement
+/// avoids duplicating the defining op (e.g. a reduction) when it also has
+/// users inside the region. A value yielded at several positions counts as
+/// several uses and is therefore never selected.
 static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                 const std::function<bool(Operation *)> &fn) {
   auto yield = cast<vector::YieldOp>(
@@ -237,7 +241,7 @@
   for (OpOperand &yieldOperand : yield->getOpOperands()) {
     Value yieldValues = yieldOperand.get();
     Operation *definedOp = yieldValues.getDefiningOp();
-    if (definedOp && fn(definedOp)) {
+    if (definedOp && definedOp->hasOneUse() && fn(definedOp)) {
       if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
         return &yieldOperand;
     }
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -1109,3 +1109,22 @@
   }
   return %r : vector<4x96xf32>
 }
+// -----
+
+// Verify that we don't duplicate the reduction.
+// CHECK-PROP-LABEL: func @vector_reduction_no_duplicate(
+//  CHECK-PROP-SAME:     %[[laneid:.*]]: index)
+//       CHECK-PROP:   %[[warp_op:.*]] = vector.warp_execute_on_lane_0(%[[laneid]])[32] -> (f32) {
+//       CHECK-PROP:     vector.reduction <add>, %{{.*}} : vector<32xf32> into f32
+//       CHECK-PROP:     vector.yield %{{.*}} : f32
+//       CHECK-PROP:   }
+//  CHECK-PROP-NEXT:   return %[[warp_op]] : f32
+func.func @vector_reduction_no_duplicate(%laneid: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<32xf32>)
+    %1 = vector.reduction <add>, %0 : vector<32xf32> into f32
+    "some_blocking_use"(%1) : (f32) -> ()
+    vector.yield %1 : f32
+  }
+  return %r : f32
+}