Index: mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
===================================================================
--- mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -39,6 +39,42 @@
     RewritePatternSet &patterns,
     const WarpExecuteOnLane0LoweringOptions &options);
 
+/// Callback returning, for a given transfer_write, an affine map over the
+/// written vector's dimensions whose results select the dimensions to
+/// distribute across the lanes of the warp. Only maps with a single result are
+/// currently supported.
+using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
+
+/// Distribute transfer_write ops based on the affine map returned by
+/// `distributionMapFn`. Only unmasked writes that are not followed by a
+/// side-effecting op, and whose operands other than the written vector are
+/// defined outside of the warp op, are rewritten. When the vector cannot be
+/// evenly distributed, a `vector<1xT>` write can instead be moved into its own
+/// warp op.
+/// Example:
+/// ```
+/// vector.warp_execute_on_lane_0(%id)[32] {
+///   ...
+///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
+///   vector.yield
+/// }
+/// ```
+/// To
+/// ```
+/// %r = vector.warp_execute_on_lane_0(%id)[32] -> (vector<1xf32>) {
+///   ...
+///   vector.yield %v : vector<32xf32>
+/// }
+/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
+/// ```
+void populateDistributeTransferWriteOpPatterns(
+    RewritePatternSet &patterns, DistributionMapFn distributionMapFn);
+
+/// Hoist out of the region of `op` every side-effect-free, region-free
+/// operation that produces no vector result and whose operands are defined
+/// outside of the region (or by other hoisted operations).
+void moveScalarUniformCode(WarpExecuteOnLane0Op op);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
Index: mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Transforms/SideEffectUtils.h"
 
 using namespace mlir;
 using namespace mlir::vector;
@@ -93,8 +95,8 @@
     if (resultType == val.getType()) {
       // Result type and yielded value type are the same. This is a broadcast.
       // E.g.:
-      // %r = vector_ext.warp_execute_on_lane_0(...) -> (f32) {
-      //   vector_ext.yield %cst : f32
+      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+      //   vector.yield %cst : f32
       // }
       // Both types are f32. The constant %cst is broadcasted to all lanes.
       // This is described in more detail in the documentation of the op.
@@ -131,6 +133,56 @@
   return success();
 }
 
+/// Move the region of `warpOp` into a new WarpExecuteOnLane0Op created before
+/// it, yielding `newYieldedValues`; `warpOp` is left with an empty region.
+static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    OpBuilder &b, WarpExecuteOnLane0Op warpOp, ValueRange newYieldedValues,
+    TypeRange newReturnTypes) {
+  // Create the new op right before the old one, with the new result types.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(warpOp);
+  auto newWarpOp = b.create<WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  newOpBody.takeBody(opBody);
+  auto yield =
+      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+  yield.operandsMutable().assign(newYieldedValues);
+  return newWarpOp;
+}
+
+/// Same as above, but appending results; `warpOp` is replaced and erased.
+static WarpExecuteOnLane0Op
+moveRegionToNewWarpOpAndAppendReturns(RewriterBase &rewriter,
+                                      WarpExecuteOnLane0Op warpOp,
+                                      ValueRange newYieldedValues,
+                                      TypeRange newReturnTypes) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  types.append(newReturnTypes.begin(), newReturnTypes.end());
+  auto yield = cast<vector::YieldOp>(warpOp.getBody()->getTerminator());
+  SmallVector<Value> yieldValues(yield.getOperands().begin(),
+                                 yield.getOperands().end());
+  yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
+  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
+      rewriter, warpOp, yieldValues, types);
+  // Redirect uses of the old results to the new op and erase the old op,
+  // whose region was moved out, through the rewriter so it gets notified.
+  rewriter.replaceOp(
+      warpOp, newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
+
+/// Helper to know if an op can be hoisted out of the region.
+static bool canBeHoisted(Operation *op, + function_ref definedOutside) { + return llvm::all_of(op->getOperands(), definedOutside) && + isSideEffectFree(op) && op->getNumRegions() == 0; +} + namespace { struct WarpOpToScfForPattern : public OpRewritePattern { @@ -149,6 +201,140 @@ const WarpExecuteOnLane0LoweringOptions &options; }; +struct WarpOpTransferWrite : public OpRewritePattern { + WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn, + PatternBenefit b = 1) + : OpRewritePattern(ctx, b), + distributionMapFn(fn) {} + + /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that + /// are multiples of the distribution ratio are supported at the moment. + LogicalResult tryDistributeOp(RewriterBase &rewriter, + vector::TransferWriteOp writeOp, + WarpExecuteOnLane0Op warpOp) const { + AffineMap map = distributionMapFn(writeOp); + SmallVector targetShape(writeOp.getVectorType().getShape().begin(), + writeOp.getVectorType().getShape().end()); + assert(map.getNumResults() == 1 && + "multi-dim distribution not implemented yet"); + for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { + unsigned position = map.getDimPosition(i); + if (targetShape[position] % warpOp.getWarpSize() != 0) + return failure(); + targetShape[position] = targetShape[position] / warpOp.getWarpSize(); + } + VectorType targetType = + VectorType::get(targetShape, writeOp.getVectorType().getElementType()); + + SmallVector yieldValues = {writeOp.getVector()}; + SmallVector retTypes = {targetType}; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, yieldValues, retTypes); + rewriter.setInsertionPointAfter(newWarpOp); + + // Move op outside of region: Insert clone at the insertion point and delete + // the old op. 
+ auto newWriteOp = + cast(rewriter.clone(*writeOp.getOperation())); + rewriter.eraseOp(writeOp); + + rewriter.setInsertionPoint(newWriteOp); + AffineMap indexMap = map.compose(newWriteOp.getPermutationMap()); + Location loc = newWriteOp.getLoc(); + SmallVector indices(newWriteOp.getIndices().begin(), + newWriteOp.getIndices().end()); + for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) { + AffineExpr d0, d1; + bindDims(newWarpOp.getContext(), d0, d1); + auto indexExpr = std::get<0>(it).dyn_cast(); + if (!indexExpr) + continue; + unsigned indexPos = indexExpr.getPosition(); + unsigned vectorPos = std::get<1>(it).cast().getPosition(); + auto scale = + getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext()); + indices[indexPos] = + makeComposedAffineApply(rewriter, loc, d0 + scale * d1, + {indices[indexPos], newWarpOp.getLaneid()}); + } + newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back()); + newWriteOp.getIndicesMutable().assign(indices); + + return success(); + } + + /// Extract TransferWriteOps of vector<1x> into a separate warp op. + LogicalResult tryExtractOp(RewriterBase &rewriter, + vector::TransferWriteOp writeOp, + WarpExecuteOnLane0Op warpOp) const { + Location loc = writeOp.getLoc(); + VectorType vecType = writeOp.getVectorType(); + + // Only vector<1x> is supported at the moment. + if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1) + return failure(); + + // Do not process warp ops that contain only TransferWriteOps. + if (llvm::all_of(warpOp.getOps(), [](Operation &op) { + return isa(&op); + })) + return failure(); + + SmallVector yieldValues = {writeOp.getVector()}; + SmallVector retTypes = {vecType}; + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( + rewriter, warpOp, yieldValues, retTypes); + rewriter.setInsertionPointAfter(newWarpOp); + + // Create a second warp op that contains only writeOp. 
+ auto secondWarpOp = rewriter.create( + loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize()); + Block &body = secondWarpOp.getBodyRegion().front(); + rewriter.setInsertionPointToStart(&body); + auto newWriteOp = + cast(rewriter.clone(*writeOp.getOperation())); + newWriteOp.getVectorMutable().assign( + newWarpOp.getResult(newWarpOp.getNumResults() - 1)); + rewriter.eraseOp(writeOp); + rewriter.create(newWarpOp.getLoc()); + return success(); + } + + LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, + PatternRewriter &rewriter) const override { + // Ops with mask not supported yet. + if (writeOp.getMask()) + return failure(); + + auto warpOp = dyn_cast(writeOp->getParentOp()); + if (!warpOp) + return failure(); + + // There must be no op with a side effect after writeOp. + Operation *nextOp = writeOp.getOperation(); + while ((nextOp = nextOp->getNextNode())) + if (!isSideEffectFree(nextOp)) + return failure(); + + if (!llvm::all_of(writeOp->getOperands(), [&](Value value) { + return writeOp.getVector() == value || + warpOp.isDefinedOutsideOfRegion(value); + })) + return failure(); + + if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp))) + return success(); + + if (succeeded(tryExtractOp(rewriter, writeOp, warpOp))) + return success(); + + return failure(); + } + +private: + DistributionMapFn distributionMapFn; +}; + } // namespace void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern( @@ -156,3 +342,36 @@ const WarpExecuteOnLane0LoweringOptions &options) { patterns.add(patterns.getContext(), options); } + +void mlir::vector::populateDistributeTransferWriteOpPatterns( + RewritePatternSet &patterns, DistributionMapFn distributionMapFn) { + patterns.add(patterns.getContext(), distributionMapFn); +} + +void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) { + Block *body = warpOp.getBody(); + + // Keep track of the ops we want to hoist. 
+ llvm::SmallSetVector opsToMove; + + // Helper to check if a value is or will be defined outside of the region. + auto isDefinedOutsideOfBody = [&](Value value) { + auto *definingOp = value.getDefiningOp(); + return (definingOp && opsToMove.count(definingOp)) || + warpOp.isDefinedOutsideOfRegion(value); + }; + + // Do not use walk here, as we do not want to go into nested regions and hoist + // operations from there. + for (auto &op : body->without_terminator()) { + bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) { + return result.getType().isa(); + }); + if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody)) + opsToMove.insert(&op); + } + + // Move all the ops marked as uniform outside of the region. + for (Operation *op : opsToMove) + op->moveBefore(warpOp); +} Index: mlir/test/Dialect/Vector/vector-warp-distribute.mlir =================================================================== --- mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1,4 +1,6 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if | FileCheck %s --check-prefix=CHECK-SCF-IF +// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform" | FileCheck --check-prefixes=CHECK-HOIST %s +// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" | FileCheck --check-prefixes=CHECK-D %s // CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3> // CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3> @@ -52,3 +54,76 @@ "some_use"(%r#1) : (vector<2xf32>) -> () return } + +// ----- + +// CHECK-D-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 * 2 + 32)> + +// CHECK-ALL-LABEL: func @warp( +// CHECK-HOIST: memref.subview +// CHECK-HOIST: memref.subview +// CHECK-HOIST: 
memref.subview +// CHECK-HOIST: vector.warp_execute_on_lane_0 + +// CHECK-D: %[[R:.*]]:2 = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>, vector<1xf32>) { +// CHECK-D: arith.addf {{.*}} : vector<32xf32> +// CHECK-D: arith.addf {{.*}} : vector<64xf32> +// CHECK-D: vector.yield %{{.*}}, %{{.*}} : vector<64xf32>, vector<32xf32> +// CHECK-D-DAG: vector.transfer_write %[[R]]#1, %{{.*}}[%{{.*}}] {in_bounds = [true]} : vector<1xf32>, memref<128xf32 +// CHECK-D-DAG: %[[ID1:.*]] = affine.apply #[[MAP1]]()[%{{.*}}] +// CHECK-D-DAG: vector.transfer_write %[[R]]#0, %2[%[[ID1]]] {in_bounds = [true]} : vector<2xf32>, memref<128xf32 + +// CHECK-ALL-NOT: vector.warp_execute_on_lane_0 +// CHECK-ALL: vector.transfer_read {{.*}} vector<1xf32> +// CHECK-ALL: vector.transfer_read {{.*}} vector<1xf32> +// CHECK-ALL: vector.transfer_read {{.*}} vector<2xf32> +// CHECK-ALL: vector.transfer_read {{.*}} vector<2xf32> +// CHECK-ALL: arith.addf {{.*}} : vector<1xf32> +// CHECK-ALL: arith.addf {{.*}} : vector<2xf32> +// CHECK-ALL: vector.transfer_write {{.*}} : vector<1xf32> +// CHECK-ALL: vector.transfer_write {{.*}} : vector<2xf32> + +#map0 = affine_map<(d0)[s0] -> (d0 + s0)> +func.func @warp(%laneid: index, %arg1: memref<1024xf32>, %arg2: memref<1024xf32>, + %arg3: memref<1024xf32>, %gid : index) { + vector.warp_execute_on_lane_0(%laneid)[32] { + %sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0> + %sb = memref.subview %arg2[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0> + %sc = memref.subview %arg3[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map0> + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %cst = arith.constant 0.000000e+00 : f32 + %2 = vector.transfer_read %sa[%c0], %cst : memref<128xf32, #map0>, vector<32xf32> + %3 = vector.transfer_read %sa[%c32], %cst : memref<128xf32, #map0>, vector<32xf32> + %4 = vector.transfer_read %sb[%c0], %cst : memref<128xf32, #map0>, vector<64xf32> + %5 = 
vector.transfer_read %sb[%c32], %cst : memref<128xf32, #map0>, vector<64xf32> + %6 = arith.addf %2, %3 : vector<32xf32> + %7 = arith.addf %4, %5 : vector<64xf32> + vector.transfer_write %6, %sc[%c0] : vector<32xf32>, memref<128xf32, #map0> + vector.transfer_write %7, %sc[%c32] : vector<64xf32>, memref<128xf32, #map0> + } + return +} + +// ----- + +// CHECK-D-LABEL: func @warp_extract( +// CHECK-D: %[[WARPOP:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) +// CHECK-D: "test.dummy_op" +// CHECK-D: vector.yield %{{.*}} : vector<1xf32> +// CHECK-D: } +// CHECK-D: vector.warp_execute_on_lane_0(%{{.*}})[32] { +// CHECK-D: vector.transfer_write %[[WARPOP]], %{{.*}}[%{{.*}}] {{.*}} : vector<1xf32> +// CHECK-D: } + +#map2 = affine_map<(d0)[s0] -> (d0 + s0)> + +func.func @warp_extract(%laneid: index, %arg1: memref<1024xf32>, %gid : index) { + vector.warp_execute_on_lane_0(%laneid)[32] { + %sa = memref.subview %arg1[%gid] [128] [1] : memref<1024xf32> to memref<128xf32, #map2> + %c0 = arith.constant 0 : index + %v = "test.dummy_op"() : () -> (vector<1xf32>) + vector.transfer_write %v, %sa[%c0] : vector<1xf32>, memref<128xf32, #map2> + } + return +} \ No newline at end of file Index: mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp =================================================================== --- mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -801,7 +801,8 @@ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution) void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } StringRef getArgument() const final { return "test-vector-warp-distribute"; } @@ -817,8 +818,43 @@ llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"), llvm::cl::init(false)}; + Option distributeTransferWriteOps{ + *this, "distribute-transfer-write", + llvm::cl::desc("Test distribution of transfer write"), + llvm::cl::init(false)}; + 
+ Option hoistUniform{*this, "hoist-uniform", + llvm::cl::desc("Test hoist uniform"), + llvm::cl::init(false)}; + void runOnOperation() override { RewritePatternSet patterns(&getContext()); + + getOperation().walk([&](Operation *op) { + if (auto warpOp = dyn_cast(op)) { + if (hoistUniform) { + moveScalarUniformCode(warpOp); + } + WalkResult::interrupt(); + } + }); + MLIRContext *ctx = &getContext(); + if (distributeTransferWriteOps) { + auto distributionFn = [](vector::TransferWriteOp writeOp) { + // Create a map (d0, d1) -> (d1) to distribute along the inner + // dimension. Once we support n-d distribution we can add more + // complex cases. + int64_t vecRank = writeOp.getVectorType().getRank(); + OpBuilder builder(writeOp.getContext()); + auto map = + AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1)); + return map; + }; + RewritePatternSet patterns(ctx); + populateDistributeTransferWriteOpPatterns(patterns, distributionFn); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } + WarpExecuteOnLane0LoweringOptions options; options.warpAllocationFn = allocateGlobalSharedMemory; options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,