diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -262,6 +262,28 @@ const WarpExecuteOnLane0LoweringOptions &options; }; +/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute +/// op with the proper return type. +/// The new write op is updated to write the result of the new warp execute op. +/// The old `writeOp` is deleted. +static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter, + WarpExecuteOnLane0Op warpOp, + vector::TransferWriteOp writeOp, + VectorType targetType) { + assert(writeOp->getParentOp() == warpOp && + "write must be nested immediately under warp"); + OpBuilder::InsertionGuard g(rewriter); + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, ValueRange{{writeOp.getVector()}}, + TypeRange{targetType}); + rewriter.setInsertionPointAfter(newWarpOp); + auto newWriteOp = + cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation())); + rewriter.eraseOp(writeOp); + newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back()); + return newWriteOp; +} + /// Distribute transfer_write ops based on the affine map returned by /// `distributionMapFn`. /// Example: @@ -290,11 +312,21 @@ LogicalResult tryDistributeOp(RewriterBase &rewriter, vector::TransferWriteOp writeOp, WarpExecuteOnLane0Op warpOp) const { + VectorType writtenVectorType = writeOp.getVectorType(); + + // 1. If the write is 0-D, we just clone it into a new WarpExecuteOnLane0Op + // to separate it from the rest. + if (writtenVectorType.getRank() == 0) + return failure(); + + // 2. Compute the distribution map. 
AffineMap map = distributionMapFn(writeOp); - SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(), - writeOp.getVectorType().getShape().end()); - assert(map.getNumResults() == 1 && - "multi-dim distribution not implemented yet"); + if (map.getNumResults() != 1) + return writeOp->emitError("multi-dim distribution not implemented yet"); + + // 3. Compute the targetType using the distribution map. + SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(), + writtenVectorType.getShape().end()); for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { unsigned position = map.getDimPosition(i); if (targetShape[position] % warpOp.getWarpSize() != 0) @@ -302,20 +334,16 @@ targetShape[position] = targetShape[position] / warpOp.getWarpSize(); } VectorType targetType = - VectorType::get(targetShape, writeOp.getVectorType().getElementType()); - - SmallVector<Value> yieldValues = {writeOp.getVector()}; - SmallVector<Type> retTypes = {targetType}; - WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, yieldValues, retTypes); - rewriter.setInsertionPointAfter(newWarpOp); + VectorType::get(targetShape, writtenVectorType.getElementType()); - // Move op outside of region: Insert clone at the insertion point and delete - // the old op. - auto newWriteOp = - cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation())); - rewriter.eraseOp(writeOp); + // 4. clone the write into a new WarpExecuteOnLane0Op to separate it from + // the rest. + vector::TransferWriteOp newWriteOp = + cloneWriteOp(rewriter, warpOp, writeOp, targetType); + // 5. Reindex the write using the distribution map. 
+ auto newWarpOp = + newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>(); rewriter.setInsertionPoint(newWriteOp); AffineMap indexMap = map.compose(newWriteOp.getPermutationMap()); Location loc = newWriteOp.getLoc(); @@ -329,13 +357,11 @@ continue; unsigned indexPos = indexExpr.getPosition(); unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition(); - auto scale = - getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext()); + auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]); indices[indexPos] = makeComposedAffineApply(rewriter, loc, d0 + scale * d1, {indices[indexPos], newWarpOp.getLaneid()}); } - newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back()); newWriteOp.getIndicesMutable().assign(indices); return success(); @@ -634,7 +660,7 @@ Value broadcasted = rewriter.create<vector::BroadcastOp>( loc, destVecType, newWarpOp->getResults().back()); newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted); return success(); } }; diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -491,3 +491,23 @@ } return %r : f32 } + +// ----- + +func.func @vector_reduction(%laneid: index, %m0: memref<4x2x32xf32>, %m1: memref<f32>) { + %c0 = arith.constant 0: index + %f0 = arith.constant 0.0: f32 + // CHECK-D: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) { + // CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] { + // CHECK-D: vector.transfer_write %[[R]], %{{.*}}[] : vector<f32>, memref<f32> + vector.warp_execute_on_lane_0(%laneid)[32] { + %0 = vector.transfer_read %m0[%c0, %c0, %c0], %f0 {in_bounds = [true]} : memref<4x2x32xf32>, vector<32xf32> + %1 = vector.transfer_read %m1[], %f0 : memref<f32>, vector<f32> + %2 = vector.extractelement %1[] : vector<f32> + %3 = vector.reduction <add>, %0 : vector<32xf32> into f32 + %4 = arith.addf %3, %2 : f32 + %5 = vector.broadcast %4 : f32 to 
vector<f32> + vector.transfer_write %5, %m1[] : vector<f32>, memref<f32> + } + return +}