diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1033,8 +1033,13 @@
     Value broadcastFromTid = rewriter.create<AffineApplyOp>(
         loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition());
     // Extract at position: pos % elementsPerLane
-    Value pos = rewriter.create<AffineApplyOp>(loc, sym0 % elementsPerLane,
-                                               extractOp.getPosition());
+    Value pos =
+        elementsPerLane == 1
+            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
+            : rewriter
+                  .create<AffineApplyOp>(loc, sym0 % elementsPerLane,
+                                         extractOp.getPosition())
+                  .getResult();
     Value extracted =
         rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
 
@@ -1049,6 +1054,91 @@
   WarpShuffleFromIdxFn warpShuffleFromIdxFn;
 };
 
+/// Pattern to move a vector.insertelement op out of a WarpExecuteOnLane0Op.
+/// If the yielded vector is not distributed (e.g., 0-d vectors), the op is
+/// simply moved out of the region. Otherwise, only the lane owning element
+/// `pos`, i.e. lane `pos floordiv elementsPerLane`, inserts the scalar at
+/// position `pos mod elementsPerLane` of its distributed vector; all other
+/// lanes keep their distributed vector unchanged.
+struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::InsertElementOp>(op); });
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
+    VectorType vecType = insertOp.getDestVectorType();
+    VectorType distrType =
+        warpOp.getResult(operandNumber).getType().cast<VectorType>();
+    bool hasPos = static_cast<bool>(insertOp.getPosition());
+
+    // Yield destination vector, source scalar and position from warp op.
+    SmallVector<Value> additionalResults{insertOp.getDest(),
+                                         insertOp.getSource()};
+    SmallVector<Type> additionalResultTypes{distrType,
+                                            insertOp.getSource().getType()};
+    if (hasPos) {
+      additionalResults.push_back(insertOp.getPosition());
+      additionalResultTypes.push_back(insertOp.getPosition().getType());
+    }
+    Location loc = insertOp.getLoc();
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, additionalResults, additionalResultTypes,
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+    Value newSource = newWarpOp->getResult(newRetIndices[1]);
+    Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
+
+    if (vecType == distrType) {
+      // Broadcast: Simply move the vector.insertelement op out.
+      Value newInsert = rewriter.create<vector::InsertElementOp>(
+          loc, newSource, distributedVec, newPos);
+      newWarpOp->getResult(operandNumber).replaceAllUsesWith(newInsert);
+      return success();
+    }
+
+    // This is a distribution. Only one lane should insert.
+    int64_t elementsPerLane = distrType.getShape()[0];
+    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
+    // tid of inserting thread: pos / elementsPerLane. Lane i owns elements
+    // [i * elementsPerLane, (i + 1) * elementsPerLane), hence floor division.
+    Value insertingLane = rewriter.create<AffineApplyOp>(
+        loc, sym0.floorDiv(elementsPerLane), newPos);
+    // Insert position: pos % elementsPerLane
+    Value pos =
+        elementsPerLane == 1
+            ? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
+            : rewriter
+                  .create<AffineApplyOp>(loc, sym0 % elementsPerLane, newPos)
+                  .getResult();
+    Value isInsertingLane = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
+    Value newResult =
+        rewriter
+            .create<scf::IfOp>(
+                loc, distrType, isInsertingLane,
+                /*thenBuilder=*/
+                [&](OpBuilder &builder, Location loc) {
+                  Value newInsert = builder.create<vector::InsertElementOp>(
+                      loc, newSource, distributedVec, pos);
+                  builder.create<scf::YieldOp>(loc, newInsert);
+                },
+                /*elseBuilder=*/
+                [&](OpBuilder &builder, Location loc) {
+                  builder.create<scf::YieldOp>(loc, distributedVec);
+                })
+            .getResult(0);
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+    return success();
+  }
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't change
 /// the order of execution. This creates a new scf.for region after the
@@ -1303,7 +1393,8 @@
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
                WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant>(patterns.getContext(), benefit);
+               WarpOpConstant, WarpOpInsertElement>(patterns.getContext(),
+                                                    benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -930,3 +930,67 @@
 // CHECK-SCF-IF: return %[[R0]], %[[R1]] : vector<1x64x1xf32>, vector<1x2x128xf32>
   return %r#0, %r#1 : vector<1x64x1xf32>, vector<1x2x128xf32>
 }
+
+// -----
+
+// CHECK-PROP:   #[[$MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 3)>
+// CHECK-PROP:   #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
+// CHECK-PROP-LABEL: func @vector_insertelement_1d(
+//  CHECK-PROP-SAME:     %[[LANEID:.*]]: index, %[[POS:.*]]: index
+//       CHECK-PROP:   %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, f32)
+//       CHECK-PROP:   %[[INSERTING_LANE:.*]] = affine.apply #[[$MAP]]()[%[[POS]]]
+//       CHECK-PROP:   %[[INSERTING_POS:.*]] = affine.apply #[[$MAP1]]()[%[[POS]]]
+//       CHECK-PROP:   %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[INSERTING_LANE]] : index
+//       CHECK-PROP:   %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<3xf32>) {
+//       CHECK-PROP:     %[[INSERT:.*]] = vector.insertelement %[[W]]#1, %[[W]]#0[%[[INSERTING_POS]] : index]
+//       CHECK-PROP:     scf.yield %[[INSERT]]
+//       CHECK-PROP:   } else {
+//       CHECK-PROP:     scf.yield %[[W]]#0
+//       CHECK-PROP:   }
+//       CHECK-PROP:   return %[[R]]
+func.func @vector_insertelement_1d(%laneid: index, %pos: index) -> (vector<3xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
+    %0 = "some_def"() : () -> (vector<96xf32>)
+    %f = "another_def"() : () -> (f32)
+    %1 = vector.insertelement %f, %0[%pos : index] : vector<96xf32>
+    vector.yield %1 : vector<96xf32>
+  }
+  return %r : vector<3xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insertelement_1d_broadcast(
+//  CHECK-PROP-SAME:     %[[LANEID:.*]]: index, %[[POS:.*]]: index
+//       CHECK-PROP:   %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<96xf32>, f32)
+//       CHECK-PROP:     %[[VEC:.*]] = "some_def"
+//       CHECK-PROP:     %[[VAL:.*]] = "another_def"
+//       CHECK-PROP:     vector.yield %[[VEC]], %[[VAL]]
+//       CHECK-PROP:   vector.insertelement %[[W]]#1, %[[W]]#0[%[[POS]] : index] : vector<96xf32>
+func.func @vector_insertelement_1d_broadcast(%laneid: index, %pos: index) -> (vector<96xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<96xf32>) {
+    %0 = "some_def"() : () -> (vector<96xf32>)
+    %f = "another_def"() : () -> (f32)
+    %1 = vector.insertelement %f, %0[%pos : index] : vector<96xf32>
+    vector.yield %1 : vector<96xf32>
+  }
+  return %r : vector<96xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insertelement_0d(
+//       CHECK-PROP:   %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<f32>, f32)
+//       CHECK-PROP:     %[[VEC:.*]] = "some_def"
+//       CHECK-PROP:     %[[VAL:.*]] = "another_def"
+//       CHECK-PROP:     vector.yield %[[VEC]], %[[VAL]]
+//       CHECK-PROP:   vector.insertelement %[[W]]#1, %[[W]]#0[] : vector<f32>
+func.func @vector_insertelement_0d(%laneid: index) -> (vector<f32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<f32>) {
+    %0 = "some_def"() : () -> (vector<f32>)
+    %f = "another_def"() : () -> (f32)
+    %1 = vector.insertelement %f, %0[] : vector<f32>
+    vector.yield %1 : vector<f32>
+  }
+  return %r : vector<f32>
+}