diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -719,6 +719,39 @@
   }
 };
 
+/// Pattern to move out vector.extract of single element vector. Those don't
+/// need to be distributed and can just be propagated outside of the region.
+struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::ExtractOp>(op); });
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
+    // Only single-element source vectors are handled: they need no
+    // distribution, so they can be yielded with their original type.
+    if (extractOp.getVectorType().getNumElements() != 1)
+      return failure();
+    Location loc = extractOp.getLoc();
+    SmallVector<size_t> newRetIndices;
+    // Additionally yield the extract's source vector, undistributed, so the
+    // extraction itself can be recreated after the new warp op.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value newExtract = rewriter.create<vector::ExtractOp>(
+        loc, newWarpOp->getResult(newRetIndices[0]), extractOp.getPosition());
+    // The original warp result is left without uses; this pattern does not
+    // erase it.
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+    return success();
+  }
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't change
 /// the order of execution. This creates a new scf.for region after the
@@ -915,8 +948,8 @@
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp,
-               WarpOpConstant>(patterns.getContext());
+               WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
+               WarpOpScfForOp, WarpOpConstant>(patterns.getContext());
 }
 
 void mlir::vector::populateDistributeReduction(
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -611,3 +611,21 @@
   }
   return %r : vector<1xf32>
 }
+
+// -----
+
+// CHECK-PROP-LABEL: func.func @vector_extract_simple(
+// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
+// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<1xf32>
+// CHECK-PROP: vector.yield %[[V]] : vector<1xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][0] : vector<1xf32>
+// CHECK-PROP: return %[[E]] : f32
+func.func @vector_extract_simple(%laneid: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<1xf32>)
+    %1 = vector.extract %0[0] : vector<1xf32>
+    vector.yield %1 : f32
+  }
+  return %r : f32
+}