diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -883,6 +883,43 @@ } }; +struct WarpOpShapeCast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *operand = getWarpResult( + warpOp, [](Operation *op) { return isa(op); }); + if (!operand) + return failure(); + unsigned int operandNumber = operand->getOperandNumber(); + auto destVecType = + cast(warpOp->getResultTypes()[operandNumber]); + auto shapeCast = operand->get().getDefiningOp(); + Location loc = shapeCast.getLoc(); + Value shapeCastSrc = shapeCast.getSource(); + auto shapeCastSrcType = cast(shapeCastSrc.getType()); + + // Only handle the trivial shape cast with a single element for now. + // TODO: Support more cases. + if (shapeCastSrcType.getNumElements() != 1 || + destVecType.getNumElements() != 1) + return failure(); + + // For a single-element shape cast, the source is yielded unchanged, so + // every lane receives the same value and performs the cast itself. This + // is always valid because the cast maps one element to one element. 
+ SmallVector newRetIndices; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, {shapeCastSrc}, {shapeCastSrcType}, newRetIndices); + rewriter.setInsertionPointAfter(newWarpOp); + Value newShapeCast = rewriter.create( + loc, destVecType, newWarpOp->getResult(newRetIndices[0])); + rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), + newShapeCast); + return success(); + } +}; + struct WarpOpBroadcast : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, @@ -1559,10 +1596,11 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) { - patterns.add( - patterns.getContext(), benefit); + patterns + .add( + patterns.getContext(), benefit); patterns.add(patterns.getContext(), warpShuffleFromIdxFn, benefit); patterns.add(patterns.getContext(), distributionMapFn, diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1173,3 +1173,42 @@ vector.print %r : vector<1x2xf32> return } + +// ----- + +// CHECK-PROP-LABEL: func @distribute_single_element_shape_cast( +// CHECK-PROP: %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector) +// CHECK-PROP: %[[some_def:.*]] = "some_def" +// CHECK-PROP: vector.yield %[[some_def]] : vector +// CHECK-PROP: %[[s:.*]] = vector.shape_cast %[[r]] : vector to vector<1x1xf32> +// CHECK-PROP: vector.print %[[s]] : vector<1x1xf32> +func.func @distribute_single_element_shape_cast(%laneid: index) { + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x1xf32>) { + %0 = "some_def"() : () -> (vector) + %1 = vector.shape_cast %0 : vector to vector<1x1xf32> + vector.yield 
%1 : vector<1x1xf32> + } + vector.print %r : vector<1x1xf32> + return +} + +// ----- + +// TODO: Distribute non-trivial shape cast when possible. + +// CHECK-PROP-LABEL: func @dont_distribute_nontrivial_shape_cast( +// CHECK-PROP: %[[r:.*]] = vector.warp_execute_on_lane_0{{.*}} -> (vector<1x2xf32>) +// CHECK-PROP: %[[some_def:.*]] = "some_def" +// CHECK-PROP: %[[s:.*]] = vector.shape_cast %[[some_def]] : vector<2x32xf32> to vector<32x2xf32> +// CHECK-PROP: vector.yield %[[s]] : vector<32x2xf32> +// CHECK-PROP: vector.print %[[r]] : vector<1x2xf32> +func.func @dont_distribute_nontrivial_shape_cast(%laneid: index) { + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2xf32>) { + %0 = "some_def"() : () -> (vector<2x32xf32>) + %1 = vector.shape_cast %0 : vector<2x32xf32> to vector<32x2xf32> + vector.yield %1 : vector<32x2xf32> + } + vector.print %r : vector<1x2xf32> + return +} +