Index: mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -524,6 +524,44 @@
   }
 };
 
+/// Sink out a splat constant op feeding into a warp op yield.
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
+///   ...
+///   %cst = arith.constant dense<2.0> : vector<32xf32>
+///   vector.yield %cst : vector<32xf32>
+/// }
+/// ```
+/// To
+/// ```
+/// vector.warp_execute_on_lane_0(%arg0) {
+///   ...
+/// }
+/// %0 = arith.constant dense<2.0> : vector<1xf32>
+/// ```
+struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<arith::ConstantOp>(op); });
+    if (!yieldOperand)
+      return failure();
+    auto constantOp = yieldOperand->get().getDefiningOp<arith::ConstantOp>();
+    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
+    if (!dense)
+      return failure();
+    Value result = warpOp.getResult(yieldOperand->getOperandNumber());
+    Attribute scalarAttr = dense.getSplatValue<Attribute>();
+    Attribute newAttr = DenseElementsAttr::get(result.getType(), scalarAttr);
+    Location loc = warpOp.getLoc();
+    rewriter.setInsertionPointAfter(warpOp);
+    Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
+    result.replaceAllUsesWith(distConstant);
+    return success();
+  }
+};
+
 /// Sink out transfer_read op feeding into a warp op yield.
/// ``` /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { @@ -868,8 +906,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( RewritePatternSet &patterns) { patterns.add( - patterns.getContext()); + WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp, + WarpOpConstant>(patterns.getContext()); } void mlir::vector::populateDistributeReduction( Index: mlir/test/Dialect/Vector/vector-warp-distribute.mlir =================================================================== --- mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -562,3 +562,16 @@ } return %r#0, %r#1 : vector<1xf32>, vector<1xf32> } + +// ----- + +// CHECK-PROP-LABEL: func @warp_constant( +// CHECK-PROP: %[[C:.*]] = arith.constant dense<2.000000e+00> : vector<1xf32> +// CHECK-PROP: return %[[C]] : vector<1xf32> +func.func @warp_constant(%laneid: index) -> (vector<1xf32>) { + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) { + %cst = arith.constant dense<2.0> : vector<32xf32> + vector.yield %cst : vector<32xf32> + } + return %r : vector<1xf32> +}