Distribute vector.reduction ops that carry an accumulator.

When the reduction has an accumulator, the accumulator is yielded from the
warp region as an additional result, and the cross-lane reduction result is
combined with it through vector::makeArithReduction. A FileCheck test
(@vector_reduction_acc) covers the new path.

NOTE(review): the line breaks and indentation of the context lines below were
reconstructed from the hunk line counts; confirm them against the target tree
(or apply with whitespace-insensitive matching) before relying on this patch.

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -869,6 +869,10 @@
     SmallVector<Value> yieldValues = {reductionOp.getVector()};
     SmallVector<Type> retTypes = {
         VectorType::get({numElements}, reductionOp.getType())};
+    if (reductionOp.getAcc()) {
+      yieldValues.push_back(reductionOp.getAcc());
+      retTypes.push_back(reductionOp.getAcc().getType());
+    }
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, yieldValues, retTypes, newRetIndices);
@@ -882,6 +886,11 @@
     Value fullReduce =
         distributedReductionFn(reductionOp.getLoc(), rewriter, perLaneReduction,
                                reductionOp.getKind(), newWarpOp.getWarpSize());
+    if (reductionOp.getAcc()) {
+      fullReduce = vector::makeArithReduction(
+          rewriter, reductionOp.getLoc(), reductionOp.getKind(), fullReduce,
+          newWarpOp.getResult(newRetIndices[1]));
+    }
     newWarpOp.getResult(operandIndex).replaceAllUsesWith(fullReduce);
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -548,6 +548,42 @@
 
 // -----
 
+// CHECK-PROP-LABEL: func @vector_reduction_acc(
+// CHECK-PROP-SAME: %[[laneid:.*]]: index)
+// CHECK-PROP-DAG: %[[c1:.*]] = arith.constant 1 : i32
+// CHECK-PROP-DAG: %[[c2:.*]] = arith.constant 2 : i32
+// CHECK-PROP-DAG: %[[c4:.*]] = arith.constant 4 : i32
+// CHECK-PROP-DAG: %[[c8:.*]] = arith.constant 8 : i32
+// CHECK-PROP-DAG: %[[c16:.*]] = arith.constant 16 : i32
+// CHECK-PROP-DAG: %[[c32:.*]] = arith.constant 32 : i32
+// CHECK-PROP: %[[warp_op:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[32] -> (vector<2xf32>, f32) {
+// CHECK-PROP: vector.yield %{{.*}}, %{{.*}} : vector<64xf32>, f32
+// CHECK-PROP: }
+// CHECK-PROP: %[[a:.*]] = vector.reduction <add>, %[[warp_op]]#0 : vector<2xf32> into f32
+// CHECK-PROP: %[[r0:.*]], %{{.*}} = gpu.shuffle xor %[[a]], %[[c1]], %[[c32]]
+// CHECK-PROP: %[[a0:.*]] = arith.addf %[[a]], %[[r0]]
+// CHECK-PROP: %[[r1:.*]], %{{.*}} = gpu.shuffle xor %[[a0]], %[[c2]], %[[c32]]
+// CHECK-PROP: %[[a1:.*]] = arith.addf %[[a0]], %[[r1]]
+// CHECK-PROP: %[[r2:.*]], %{{.*}} = gpu.shuffle xor %[[a1]], %[[c4]], %[[c32]]
+// CHECK-PROP: %[[a2:.*]] = arith.addf %[[a1]], %[[r2]]
+// CHECK-PROP: %[[r3:.*]], %{{.*}} = gpu.shuffle xor %[[a2]], %[[c8]], %[[c32]]
+// CHECK-PROP: %[[a3:.*]] = arith.addf %[[a2]], %[[r3]]
+// CHECK-PROP: %[[r4:.*]], %{{.*}} = gpu.shuffle xor %[[a3]], %[[c16]], %[[c32]]
+// CHECK-PROP: %[[a4:.*]] = arith.addf %[[a3]], %[[r4]]
+// CHECK-PROP: %[[a5:.*]] = arith.addf %[[a4]], %[[warp_op]]#1
+// CHECK-PROP: return %[[a5]] : f32
+func.func @vector_reduction_acc(%laneid: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<64xf32>)
+    %1 = "some_def"() : () -> (f32)
+    %2 = vector.reduction <add>, %0, %1 : vector<64xf32> into f32
+    vector.yield %2 : f32
+  }
+  return %r : f32
+}
+
+// -----
+
 // CHECK-PROP-LABEL: func @warp_duplicate_yield(
 func.func @warp_duplicate_yield(%laneid: index) -> (vector<1xf32>, vector<1xf32>) {
 // CHECK-PROP: %{{.*}}:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>, vector<1xf32>)