diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -895,6 +895,41 @@
   }
 };
 
+/// Pattern to move a vector.extractelement of a 0-D vector out of the
+/// WarpExecuteOnLane0Op region. Such an extract does not need to be
+/// distributed: the 0-D source vector is returned from a new warp op and the
+/// extract is re-created right after that op.
+///
+/// The pattern does not apply when `getWarpResult` finds no operand defined by
+/// a vector.extractelement, or when the extract's source vector is not 0-D.
+struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+      return isa<vector::ExtractElementOp>(op);
+    });
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
+    // Only rank-0 source vectors are handled by this pattern.
+    if (extractOp.getVectorType().getRank() != 0)
+      return failure();
+    Location loc = extractOp.getLoc();
+    SmallVector<size_t> newRetIndices;
+    // Return the extract's source vector from a new warp op.
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value newExtract = rewriter.create<vector::ExtractElementOp>(
+        loc, newWarpOp->getResult(newRetIndices[0]));
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+    return success();
+  }
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't change
 /// the order of execution.
This creates a new scf.for region after the @@ -1093,8 +1121,9 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); + WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement, + WarpOpForwardOperand, WarpOpScfForOp, WarpOpConstant>( + patterns.getContext(), benefit); } void mlir::vector::populateDistributeReduction( diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -632,6 +632,24 @@ // ----- +// CHECK-PROP-LABEL: func.func @vector_extractelement_simple( +// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector) { +// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector +// CHECK-PROP: vector.yield %[[V]] : vector +// CHECK-PROP: } +// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][] : vector +// CHECK-PROP: return %[[E]] : f32 +func.func @vector_extractelement_simple(%laneid: index) -> (f32) { + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) { + %0 = "some_def"() : () -> (vector) + %1 = vector.extractelement %0[] : vector + vector.yield %1 : f32 + } + return %r : f32 +} + +// ----- + // CHECK-PROP: func @lane_dependent_warp_propagate_read // CHECK-PROP-SAME: %[[ID:.*]]: index func.func @lane_dependent_warp_propagate_read(