Index: mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
===================================================================
--- mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -40,7 +40,7 @@
     const WarpExecuteOnLane0LoweringOptions &options,
     PatternBenefit benefit = 1);
 
-using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
+using DistributionMapFn = std::function<AffineMap(Value)>;
 
 /// Distribute transfer_write ops based on the affine map returned by
 /// `distributionMapFn`.
@@ -67,9 +67,12 @@
 /// region.
 void moveScalarUniformCode(WarpExecuteOnLane0Op op);
 
-/// Collect patterns to propagate warp distribution.
+/// Collect patterns to propagate warp distribution. `distributionMapFn` is
+/// used to decide how a value should be distributed when this cannot be
+/// inferred from its uses.
 void populatePropagateWarpVectorDistributionPatterns(
-    RewritePatternSet &pattern, PatternBenefit benefit = 1);
+    RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
+    PatternBenefit benefit = 1);
 
 /// Lambda signature to compute a reduction of a distributed value for the
 /// given reduction kind and size.
Index: mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -421,6 +421,25 @@
   return newWriteOp;
 }
 
+/// Return the distributed vector type based on the original type and the
+/// distribution map.
+static VectorType getDistributedType(VectorType originalType, AffineMap map,
+                                     int64_t warpSize) {
+  if (map.getNumResults() != 1)
+    return VectorType();
+  SmallVector<int64_t> targetShape(originalType.getShape().begin(),
+                                   originalType.getShape().end());
+  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+    unsigned position = map.getDimPosition(i);
+    if (targetShape[position] % warpSize != 0)
+      return VectorType();
+    targetShape[position] = targetShape[position] / warpSize;
+  }
+  VectorType targetType =
+      VectorType::get(targetShape, originalType.getElementType());
+  return targetType;
+}
+
 /// Distribute transfer_write ops based on the affine map returned by
 /// `distributionMapFn`.
 /// Example:
@@ -456,29 +475,19 @@
     if (writtenVectorType.getRank() == 0)
       return failure();
 
-    // 2. Compute the distribution map.
-    AffineMap map = distributionMapFn(writeOp);
-    if (map.getNumResults() != 1)
-      return writeOp->emitError("multi-dim distribution not implemented yet");
-
-    // 3. Compute the targetType using the distribution map.
-    SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
-                                     writtenVectorType.getShape().end());
-    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
-      unsigned position = map.getDimPosition(i);
-      if (targetShape[position] % warpOp.getWarpSize() != 0)
-        return failure();
-      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
-    }
+    // 2. Compute the distributed type.
+    AffineMap map = distributionMapFn(writeOp.getVector());
     VectorType targetType =
-        VectorType::get(targetShape, writtenVectorType.getElementType());
+        getDistributedType(writtenVectorType, map, warpOp.getWarpSize());
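+    // getDistributedType returns a null type when the map has more than one
+    // result or when a distributed dimension does not divide evenly by the
+    // warp size; treat that as a match failure.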
+    if (!targetType)
+      return failure();
 
-    // 4. clone the write into a new WarpExecuteOnLane0Op to separate it from
+    // 3. Clone the write into a new WarpExecuteOnLane0Op to separate it from
     // the rest.
     vector::TransferWriteOp newWriteOp =
         cloneWriteOp(rewriter, warpOp, writeOp, targetType);
 
-    // 5. Reindex the write using the distribution map.
+    // 4. Reindex the write using the distribution map.
     auto newWarpOp =
         newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
     rewriter.setInsertionPoint(newWriteOp);
@@ -494,7 +503,8 @@
         continue;
       unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
-      auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
+      auto scale =
+          rewriter.getAffineConstantExpr(targetType.getDimSize(vectorPos));
       indices[indexPos] =
           makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                   {indices[indexPos], newWarpOp.getLaneid()});
@@ -956,6 +966,10 @@
 /// }
 /// ```
 struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
+
+  WarpOpScfForOp(MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1)
+      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
+        distributionMapFn(std::move(fn)) {}
   using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
@@ -966,6 +980,38 @@
     auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
     if (!forOp)
       return failure();
+    // Collect Values that come from the warp op but are defined outside of
+    // the forOp. Those values need to be returned by the original warpOp and
+    // passed to the new op.
+    llvm::SmallSetVector<Value, 32> escapingValues;
+    SmallVector<Type> inputTypes;
+    SmallVector<Type> distTypes;
+    forOp.getBodyRegion().walk([&](Operation *op) {
+      for (Value operand : op->getOperands()) {
+        Operation *parent = operand.getParentRegion()->getParentOp();
+        if (forOp->isAncestor(parent))
+          continue;
+        if (warpOp->isAncestor(parent)) {
+          if (!escapingValues.insert(operand))
+            continue;
+          Type distType = operand.getType();
+          if (auto vecType = distType.dyn_cast<VectorType>()) {
+            AffineMap map = distributionMapFn(operand);
+            distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+          }
+          inputTypes.push_back(operand.getType());
+          distTypes.push_back(distType);
+        }
+      }
+    });
+
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, escapingValues.getArrayRef(), distTypes,
+        newRetIndices);
+    yield = cast<vector::YieldOp>(
+        newWarpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+
     SmallVector<Value> newOperands;
     SmallVector<unsigned> resultIdx;
     // Collect all the outputs coming from the forOp.
@@ -973,28 +1019,42 @@
       if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
         continue;
       auto forResult = yieldOperand.get().cast<OpResult>();
-      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
+      newOperands.push_back(
+          newWarpOp.getResult(yieldOperand.getOperandNumber()));
       yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
       resultIdx.push_back(yieldOperand.getOperandNumber());
     }
+
     OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointAfter(warpOp);
+    rewriter.setInsertionPointAfter(newWarpOp);
+
     // Create a new for op outside the region with a WarpExecuteOnLane0Op
     // region inside.
     auto newForOp = rewriter.create<scf::ForOp>(
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
+
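+    // Feed the new for op's distributed iter_args plus the distributed
+    // escaping values into the inner warp op; warpInputType keeps their
+    // original, undistributed types, which become the types of the inner
+    // region's block arguments.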
+    SmallVector<Value> warpInput(newForOp.getRegionIterArgs().begin(),
+                                 newForOp.getRegionIterArgs().end());
+    SmallVector<Type> warpInputType(forOp.getResultTypes().begin(),
+                                    forOp.getResultTypes().end());
+    llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
+    for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
+      warpInput.push_back(newWarpOp.getResult(retIdx));
+      argIndexMapping[escapingValues[i]] = warpInputType.size();
+      warpInputType.push_back(inputTypes[i]);
+    }
     auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
-        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
-        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
-        forOp.getResultTypes());
+        newWarpOp.getLoc(), newForOp.getResultTypes(), newWarpOp.getLaneid(),
+        newWarpOp.getWarpSize(), warpInput, warpInputType);
 
     SmallVector<Value> argMapping;
     argMapping.push_back(newForOp.getInductionVar());
     for (Value args : innerWarp.getBody()->getArguments()) {
       argMapping.push_back(args);
     }
+    argMapping.resize(forOp.getBody()->getNumArguments());
     SmallVector<Value> yieldOperands;
     for (Value operand : forOp.getBody()->getTerminator()->getOperands())
       yieldOperands.push_back(operand);
@@ -1008,12 +1068,23 @@
     rewriter.eraseOp(forOp);
     // Replace the warpOp result coming from the original ForOp.
     for (const auto &res : llvm::enumerate(resultIdx)) {
-      warpOp.getResult(res.value())
+      newWarpOp.getResult(res.value())
           .replaceAllUsesWith(newForOp.getResult(res.index()));
-      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
+      newForOp->setOperand(res.index() + 3, newWarpOp.getResult(res.value()));
     }
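+    // Rewire any use of an escaping value inside the new loop to the
+    // corresponding block argument of the inner warp op.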
+    newForOp.walk([&](Operation *op) {
+      for (OpOperand &operand : op->getOpOperands()) {
+        auto it = argIndexMapping.find(operand.get());
+        if (it == argIndexMapping.end())
+          continue;
+        operand.set(innerWarp.getBodyRegion().getArgument(it->second));
+      }
+    });
     return success();
   }
+
+private:
+  DistributionMapFn distributionMapFn;
 };
 
 /// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
@@ -1119,11 +1190,14 @@
 }
 
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
+    PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpForwardOperand, WarpOpConstant,
-               WarpOpScfForOp>(patterns.getContext(), benefit);
+               WarpOpBroadcast, WarpOpForwardOperand, WarpOpConstant>(
+      patterns.getContext(), benefit);
+  patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
+                               benefit);
 }
 
 void mlir::vector::populateDistributeReduction(
Index: mlir/test/Dialect/Vector/vector-warp-distribute.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -349,6 +349,40 @@
 
 // -----
 
+// CHECK-PROP-LABEL: func @warp_scf_for_use_from_above(
+// CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
+// CHECK-PROP:   %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
+// CHECK-PROP:   %[[USE:.*]] = "some_def_above"() : () -> vector<128xf32>
+// CHECK-PROP:   vector.yield %[[INI1]], %[[USE]] : vector<128xf32>, vector<128xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[F:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[FARG:.*]] = %[[INI]]#0) -> (vector<4xf32>) {
+// CHECK-PROP:   %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%[[FARG]], %[[INI]]#1 : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>) {
+// CHECK-PROP:   ^bb0(%[[ARG0:.*]]: vector<128xf32>, %[[ARG1:.*]]: vector<128xf32>):
+// CHECK-PROP:     %[[ACC:.*]] = "some_def"(%[[ARG0]], %[[ARG1]]) : (vector<128xf32>, vector<128xf32>) -> vector<128xf32>
+// CHECK-PROP:     vector.yield %[[ACC]] : vector<128xf32>
+// CHECK-PROP:   }
+// CHECK-PROP:   scf.yield %[[W]] : vector<4xf32>
+// CHECK-PROP: }
+// CHECK-PROP: "some_use"(%[[F]]) : (vector<4xf32>) -> ()
+func.func @warp_scf_for_use_from_above(%arg0: index) {
+  %c128 = arith.constant 128 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %0 = vector.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
+    %ini = "some_def"() : () -> (vector<128xf32>)
+    %use_from_above = "some_def_above"() : () -> (vector<128xf32>)
+    %3 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini) -> (vector<128xf32>) {
+      %acc = "some_def"(%arg4, %use_from_above) : (vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
+      scf.yield %acc : vector<128xf32>
+    }
+    vector.yield %3 : vector<128xf32>
+  }
+  "some_use"(%0) : (vector<4xf32>) -> ()
+  return
+}
+
+// -----
+
 // CHECK-PROP-LABEL: func @warp_scf_for_swap(
 // CHECK-PROP: %[[INI:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
 // CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
Index: mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
===================================================================
--- mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -746,24 +746,26 @@
       }
     });
     MLIRContext *ctx = &getContext();
+    auto distributionFn = [](Value val) {
+      // Create a map (d0, d1) -> (d1) to distribute along the inner
+      // dimension. Once we support n-d distribution we can add more
+      // complex cases.
+      VectorType vecType = val.getType().dyn_cast<VectorType>();
+      int64_t vecRank = vecType ? vecType.getRank() : 0;
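+      // Non-vector values are uniform across the warp; return an empty map
+      // so they are left undistributed.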
+      OpBuilder builder(val.getContext());
+      if (vecRank == 0)
+        return AffineMap::get(val.getContext());
+      return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+    };
     if (distributeTransferWriteOps) {
-      auto distributionFn = [](vector::TransferWriteOp writeOp) {
-        // Create a map (d0, d1) -> (d1) to distribute along the inner
-        // dimension. Once we support n-d distribution we can add more
-        // complex cases.
-        int64_t vecRank = writeOp.getVectorType().getRank();
-        OpBuilder builder(writeOp.getContext());
-        auto map =
-            AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
-        return map;
-      };
       RewritePatternSet patterns(ctx);
       populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     }
     if (propagateDistribution) {
       RewritePatternSet patterns(ctx);
-      vector::populatePropagateWarpVectorDistributionPatterns(patterns);
+      vector::populatePropagateWarpVectorDistributionPatterns(patterns,
+                                                              distributionFn);
       vector::populateDistributeReduction(patterns, warpReduction);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     }