diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1133,6 +1133,149 @@
   }
 };
 
+/// Pattern to distribute a `vector.insert` op whose result is yielded from a
+/// WarpExecuteOnLane0Op. Three cases are handled:
+/// * 1-D dest: the op is rewritten in place to `vector.insertelement`, so
+///   that it can be picked up by the WarpOpInsertElement pattern.
+/// * The yielded vector is not distributed (broadcast): the insert is moved
+///   after the warp op and performed by every lane.
+/// * The yielded vector is distributed: if the distributed dim is a dim of
+///   the source vector, every lane inserts its slice of the source.
+///   Otherwise, a single lane (guarded by an `scf.if` on the lane id)
+///   inserts the entire source, which may be a vector or a scalar.
+struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::InsertOp>(op); });
+    if (!operand)
+      return failure();
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
+    Location loc = insertOp.getLoc();
+
+    // "vector.insert %v, %v[] : ..." can be canonicalized to %v.
+    if (insertOp.getPosition().empty())
+      return failure();
+
+    // Rewrite vector.insert with 1d dest to vector.insertelement.
+    if (insertOp.getDestVectorType().getRank() == 1) {
+      assert(insertOp.getPosition().size() == 1 && "expected 1 index");
+      int64_t pos = insertOp.getPosition()[0].cast<IntegerAttr>().getInt();
+      rewriter.setInsertionPoint(insertOp);
+      rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
+          insertOp, insertOp.getSource(), insertOp.getDest(),
+          rewriter.create<arith::ConstantIndexOp>(loc, pos));
+      return success();
+    }
+
+    if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
+      // There is no distribution, this is a broadcast. Simply move the insert
+      // out of the warp op.
+      SmallVector<size_t> newRetIndices;
+      WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
+          {insertOp.getSourceType(), insertOp.getDestVectorType()},
+          newRetIndices);
+      rewriter.setInsertionPointAfter(newWarpOp);
+      Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
+      Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
+      Value newResult = rewriter.create<vector::InsertOp>(
+          loc, distributedSrc, distributedDest, insertOp.getPosition());
+      newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+      return success();
+    }
+
+    // Find the distributed dimension. There should be exactly one.
+    auto distrDestType =
+        warpOp.getResult(operandNumber).getType().cast<VectorType>();
+    auto yieldedType = operand->get().getType().cast<VectorType>();
+    int64_t distrDestDim = -1;
+    for (int64_t i = 0; i < yieldedType.getRank(); ++i) {
+      if (distrDestType.getDimSize(i) != yieldedType.getDimSize(i)) {
+        // Keep this assert here in case WarpExecuteOnLane0Op gets extended to
+        // support distributing multiple dimensions in the future.
+        assert(distrDestDim == -1 && "found multiple distributed dims");
+        distrDestDim = i;
+      }
+    }
+    assert(distrDestDim != -1 && "could not find distributed dimension");
+
+    // Compute the distributed source type.
+    // E.g.: vector.insert %s, %d [2] : vector<96xf32> into vector<128x96xf32>
+    // Case 1: distrDestDim = 1 (dim of size 96). In that case, each lane will
+    //         insert a smaller vector<3xf32>.
+    // Case 2: distrDestDim = 0 (dim of size 128) => distrSrcDim = -1. In that
+    //         case, one lane will insert the source vector<96xf32>. The other
+    //         lanes will not do anything.
+    // The source may also be a scalar (the position indexes every dest
+    // dimension). Then distrSrcDim is always negative and the scalar is
+    // inserted by a single lane, as in case 2.
+    int64_t distrSrcDim =
+        distrDestDim - static_cast<int64_t>(insertOp.getPosition().size());
+    Type distrSrcType = insertOp.getSourceType();
+    if (distrSrcDim >= 0) {
+      // distrSrcDim >= 0 implies posSize <= distrDestDim < destRank, so the
+      // source is a vector of rank destRank - posSize >= 1: cannot fail.
+      auto srcVecType = insertOp.getSourceType().cast<VectorType>();
+      SmallVector<int64_t> distrSrcShape(srcVecType.getShape().begin(),
+                                         srcVecType.getShape().end());
+      distrSrcShape[distrSrcDim] = distrDestType.getDimSize(distrDestDim);
+      distrSrcType =
+          VectorType::get(distrSrcShape, distrDestType.getElementType());
+    }
+
+    // Yield source and dest vectors from warp op.
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {insertOp.getSource(), insertOp.getDest()},
+        {distrSrcType, distrDestType}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value distributedSrc = newWarpOp->getResult(newRetIndices[0]);
+    Value distributedDest = newWarpOp->getResult(newRetIndices[1]);
+
+    // Insert into the distributed vector.
+    Value newResult;
+    if (distrSrcDim >= 0) {
+      // Every lane inserts a small piece.
+      newResult = rewriter.create<vector::InsertOp>(
+          loc, distributedSrc, distributedDest, insertOp.getPosition());
+    } else {
+      // One lane inserts the entire source (vector or scalar).
+      int64_t elementsPerLane = distrDestType.getDimSize(distrDestDim);
+      SmallVector<int64_t> newPos = llvm::to_vector(
+          llvm::map_range(insertOp.getPosition(), [](Attribute attr) {
+            return attr.cast<IntegerAttr>().getInt();
+          }));
+      // tid of inserting lane: pos / elementsPerLane
+      Value insertingLane = rewriter.create<arith::ConstantIndexOp>(
+          loc, newPos[distrDestDim] / elementsPerLane);
+      Value isInsertingLane = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
+      // Insert position: pos % elementsPerLane
+      newPos[distrDestDim] %= elementsPerLane;
+      auto insertingBuilder = [&](OpBuilder &builder, Location loc) {
+        Value newInsert = builder.create<vector::InsertOp>(
+            loc, distributedSrc, distributedDest, newPos);
+        builder.create<scf::YieldOp>(loc, newInsert);
+      };
+      auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) {
+        builder.create<scf::YieldOp>(loc, distributedDest);
+      };
+      newResult = rewriter
+                      .create<scf::IfOp>(loc, distrDestType, isInsertingLane,
+                                         /*thenBuilder=*/insertingBuilder,
+                                         /*elseBuilder=*/nonInsertingBuilder)
+                      .getResult(0);
+    }
+
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newResult);
+    return success();
+  }
+};
+
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
 /// the scf.ForOp is the last operation in the region so that it doesn't change
 /// the order of execution. This creates a new scf.for region after the
@@ -1390,8 +1533,8 @@
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
                WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpInsertElement>(patterns.getContext(),
-                                                    benefit);
+               WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
+      patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -994,3 +994,125 @@
   }
   return %r : vector<f32>
 }
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insert_1d(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index
+// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-PROP-DAG: %[[C26:.*]] = arith.constant 26 : index
+// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, f32)
+// CHECK-PROP: %[[VEC:.*]] = "some_def"
+// CHECK-PROP: %[[VAL:.*]] = "another_def"
+// CHECK-PROP: vector.yield %[[VEC]], %[[VAL]]
+// CHECK-PROP: %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[C26]]
+// CHECK-PROP: %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<3xf32>) {
+// CHECK-PROP: %[[INSERT:.*]] = vector.insertelement %[[W]]#1, %[[W]]#0[%[[C1]] : index]
+// CHECK-PROP: scf.yield %[[INSERT]]
+// CHECK-PROP: } else {
+// CHECK-PROP: scf.yield %[[W]]#0
+// CHECK-PROP: }
+// CHECK-PROP: return %[[R]]
+func.func @vector_insert_1d(%laneid: index) -> (vector<3xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<3xf32>) {
+    %0 = "some_def"() : () -> (vector<96xf32>)
+    %f = "another_def"() : () -> (f32)
+    %1 = vector.insert %f, %0[76] : f32 into vector<96xf32>
+    vector.yield %1 : vector<96xf32>
+  }
+  return %r : vector<3xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insert_2d_distr_src(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index
+// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<3xf32>, vector<4x3xf32>)
+// CHECK-PROP: %[[VEC:.*]] = "some_def"
+// CHECK-PROP: %[[VAL:.*]] = "another_def"
+// CHECK-PROP: vector.yield %[[VAL]], %[[VEC]]
+// CHECK-PROP: %[[INSERT:.*]] = vector.insert %[[W]]#0, %[[W]]#1 [2] : vector<3xf32> into vector<4x3xf32>
+// CHECK-PROP: return %[[INSERT]]
+func.func @vector_insert_2d_distr_src(%laneid: index) -> (vector<4x3xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x3xf32>) {
+    %0 = "some_def"() : () -> (vector<4x96xf32>)
+    %s = "another_def"() : () -> (vector<96xf32>)
+    %1 = vector.insert %s, %0[2] : vector<96xf32> into vector<4x96xf32>
+    vector.yield %1 : vector<4x96xf32>
+  }
+  return %r : vector<4x3xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insert_2d_distr_pos(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index
+// CHECK-PROP: %[[C19:.*]] = arith.constant 19 : index
+// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<96xf32>, vector<4x96xf32>)
+// CHECK-PROP: %[[VEC:.*]] = "some_def"
+// CHECK-PROP: %[[VAL:.*]] = "another_def"
+// CHECK-PROP: vector.yield %[[VAL]], %[[VEC]]
+// CHECK-PROP: %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[C19]]
+// CHECK-PROP: %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<4x96xf32>) {
+// CHECK-PROP: %[[INSERT:.*]] = vector.insert %[[W]]#0, %[[W]]#1 [3] : vector<96xf32> into vector<4x96xf32>
+// CHECK-PROP: scf.yield %[[INSERT]]
+// CHECK-PROP: } else {
+// CHECK-PROP: scf.yield %[[W]]#1
+// CHECK-PROP: }
+// CHECK-PROP: return %[[R]]
+func.func @vector_insert_2d_distr_pos(%laneid: index) -> (vector<4x96xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) {
+    %0 = "some_def"() : () -> (vector<128x96xf32>)
+    %s = "another_def"() : () -> (vector<96xf32>)
+    %1 = vector.insert %s, %0[79] : vector<96xf32> into vector<128x96xf32>
+    vector.yield %1 : vector<128x96xf32>
+  }
+  return %r : vector<4x96xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insert_2d_broadcast(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index
+// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (vector<96xf32>, vector<4x96xf32>)
+// CHECK-PROP: %[[VEC:.*]] = "some_def"
+// CHECK-PROP: %[[VAL:.*]] = "another_def"
+// CHECK-PROP: vector.yield %[[VAL]], %[[VEC]]
+// CHECK-PROP: %[[INSERT:.*]] = vector.insert %[[W]]#0, %[[W]]#1 [2] : vector<96xf32> into vector<4x96xf32>
+// CHECK-PROP: return %[[INSERT]]
+func.func @vector_insert_2d_broadcast(%laneid: index) -> (vector<4x96xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x96xf32>) {
+    %0 = "some_def"() : () -> (vector<4x96xf32>)
+    %s = "another_def"() : () -> (vector<96xf32>)
+    %1 = vector.insert %s, %0[2] : vector<96xf32> into vector<4x96xf32>
+    vector.yield %1 : vector<4x96xf32>
+  }
+  return %r : vector<4x96xf32>
+}
+
+// -----
+
+// CHECK-PROP-LABEL: func @vector_insert_2d_scalar(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index
+// CHECK-PROP: %[[C25:.*]] = arith.constant 25 : index
+// CHECK-PROP: %[[W:.*]]:2 = vector.warp_execute_on_lane_0{{.*}} -> (f32, vector<4x3xf32>)
+// CHECK-PROP: %[[VEC:.*]] = "some_def"
+// CHECK-PROP: %[[VAL:.*]] = "another_def"
+// CHECK-PROP: vector.yield %[[VAL]], %[[VEC]]
+// CHECK-PROP: %[[SHOULD_INSERT:.*]] = arith.cmpi eq, %[[LANEID]], %[[C25]]
+// CHECK-PROP: %[[R:.*]] = scf.if %[[SHOULD_INSERT]] -> (vector<4x3xf32>) {
+// CHECK-PROP: %[[INSERT:.*]] = vector.insert %[[W]]#0, %[[W]]#1 [2, 1] : f32 into vector<4x3xf32>
+// CHECK-PROP: scf.yield %[[INSERT]]
+// CHECK-PROP: } else {
+// CHECK-PROP: scf.yield %[[W]]#1
+// CHECK-PROP: }
+// CHECK-PROP: return %[[R]]
+func.func @vector_insert_2d_scalar(%laneid: index) -> (vector<4x3xf32>) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x3xf32>) {
+    %0 = "some_def"() : () -> (vector<4x96xf32>)
+    %f = "another_def"() : () -> (f32)
+    %1 = vector.insert %f, %0[2, 76] : f32 into vector<4x96xf32>
+    vector.yield %1 : vector<4x96xf32>
+  }
+  return %r : vector<4x3xf32>
+}