Index: mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h =================================================================== --- mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h +++ mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h @@ -69,9 +69,15 @@ void populatePropagateWarpVectorDistributionPatterns( RewritePatternSet &pattern); -/// Collect patterns to distribute vector reduction ops using GPU warp shuffle -/// ops. -void populateReductionToGPUWarpShufflePatterns(RewritePatternSet &pattern); +/// Lambda signature to compute a reduction of a distributed value for the given +/// reduction kind and size. +using DistributedReductionFn = + std::function<Value(Location, OpBuilder &, Value, CombiningKind, uint32_t)>; + +/// Collect patterns to distribute vector reduction ops using given lambda to +/// distribute reduction op. +void populateDistributeReduction(RewritePatternSet &pattern, + DistributedReductionFn distributedReductionFn); } // namespace vector } // namespace mlir Index: mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp =================================================================== --- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -8,7 +8,6 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" @@ -725,11 +724,10 @@ }; /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op. -/// The vector is reduced in parallel. Currently limited to vector<32x...> -/// values. Every lane reduces two scalars, 5 times in a row. -/// E.g.: +/// The vector is reduced in parallel. Currently limited to vector size matching +/// the warpOp size. 
E.g.: /// ``` -/// %r = vector_ext.warp_execute_on_lane_0(%laneid) -> (f32) { +/// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) { /// %0 = "some_def"() : () -> (vector<32xf32>) /// %1 = vector.reduction "add", %0 : vector<32xf32> into f32 /// vector_ext.yield %1 : f32 @@ -737,22 +735,19 @@ /// ``` /// is lowered to: /// ``` -/// %0 = vector_ext.warp_execute_on_lane_0(%laneid) -> (vector<1xf32>) { +/// %0 = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { /// %1 = "some_def"() : () -> (vector<32xf32>) /// vector_ext.yield %1 : vector<32xf32> /// } /// %a = vector.extract %0[0] : vector<1xf32> -/// %r0, %s0 = gpu.shuffle xor %e, %c1, %c32 : f32 -/// %a0 = arith.addf %a, %r0 : f32 -/// %r1, %s1 = gpu.shuffle xor %a0, %c2, %c32 : f32 -/// %a1 = arith.addf %a0, %r1 : f32 -/// ... -/// %r4, %s4 = gpu.shuffle xor %a3, %c16, %c32 : f32 -/// %r = arith.addf %a3, %r4 : f32 +/// %r = ("warp.reduction %a") /// ``` -struct ReductionToGPUWarpShuffle - : public OpRewritePattern<WarpExecuteOnLane0Op> { - using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern; +struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> { + WarpOpReduction(MLIRContext *context, + DistributedReductionFn distributedReductionFn, + PatternBenefit benefit = 1) + : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit), + distributedReductionFn(distributedReductionFn) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { @@ -793,22 +788,15 @@ // Every lane has one scalar value. These should be reduced. Value laneValVec = newWarpOp.getResult(numResults); Value laneVal = rewriter.create<vector::ExtractOp>(yieldLoc, laneValVec, 0); - - // Parallel reduction using butterfly shuffles. 
- for (uint64_t i = 1; i < newWarpOp.getWarpSize(); i <<= 1) { - Value shuffled = - rewriter - .create<gpu::ShuffleOp>(reductionOp.getLoc(), laneVal, i, - /*width=*/newWarpOp.getWarpSize(), - /*mode=*/gpu::ShuffleMode::XOR) - .result(); - laneVal = makeArithReduction(rewriter, reductionOp.getLoc(), - reductionOp.getKind(), laneVal, shuffled); - } - + laneVal = + distributedReductionFn(reductionOp.getLoc(), rewriter, laneVal, + reductionOp.getKind(), newWarpOp.getWarpSize()); newWarpOp.getResult(operandIndex).replaceAllUsesWith(laneVal); return success(); } + +private: + DistributedReductionFn distributedReductionFn; }; } // namespace @@ -831,9 +819,10 @@ patterns.getContext()); } -void mlir::vector::populateReductionToGPUWarpShufflePatterns( - RewritePatternSet &patterns) { - patterns.add<ReductionToGPUWarpShuffle>(patterns.getContext()); +void mlir::vector::populateDistributeReduction( + RewritePatternSet &patterns, + DistributedReductionFn distributedReductionFn) { + patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn); } void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) { Index: mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp =================================================================== --- mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -804,6 +804,21 @@ return builder.create(loc, memrefType, symbolName); } +static Value warpReduction(Location loc, OpBuilder &builder, Value input, + CombiningKind kind, uint32_t size) { + Value laneVal = input; + // Parallel reduction using butterfly shuffles. 
+ for (uint64_t i = 1; i < size; i <<= 1) { + Value shuffled = builder + .create<gpu::ShuffleOp>(loc, laneVal, i, + /*width=*/size, + /*mode=*/gpu::ShuffleMode::XOR) + .result(); + laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); + } + return laneVal; +} + struct TestVectorDistribution : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution) @@ -869,7 +884,7 @@ if (propagateDistribution) { RewritePatternSet patterns(ctx); vector::populatePropagateWarpVectorDistributionPatterns(patterns); - vector::populateReductionToGPUWarpShufflePatterns(patterns); + vector::populateDistributeReduction(patterns, warpReduction); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } WarpExecuteOnLane0LoweringOptions options; Index: utils/bazel/llvm-project-overlay/mlir/BUILD.bazel =================================================================== --- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3133,7 +3133,6 @@ ":BufferizationTransforms", ":DialectUtils", ":FuncDialect", - ":GPUDialect", ":IR", ":LinalgDialect", ":MemRefDialect",