diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -10,9 +10,9 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/Transforms/SideEffectUtils.h"
 #include "llvm/ADT/SetVector.h"
 #include <utility>
@@ -20,122 +20,115 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-static LogicalResult
-rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
-                      const WarpExecuteOnLane0LoweringOptions &options) {
-  assert(warpOp.getBodyRegion().hasOneBlock() &&
-         "expected WarpOp with single block");
-  Block *warpOpBody = &warpOp.getBodyRegion().front();
-  Location loc = warpOp.getLoc();
+/// TODO: add an analysis step that determines which vector dimension should be
+/// used for distribution.
+static llvm::Optional<int64_t>
+getDistributedVectorDim(VectorType distributedVectorType) {
+  return (distributedVectorType)
+             ? llvm::Optional<int64_t>(distributedVectorType.getRank() - 1)
+             : llvm::None;
+}
 
-  // Passed all checks. Start rewriting.
-  OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(warpOp);
+static llvm::Optional<int64_t>
+getDistributedSize(VectorType distributedVectorType) {
+  auto dim = getDistributedVectorDim(distributedVectorType);
+  return (dim) ? llvm::Optional<int64_t>(distributedVectorType.getDimSize(*dim))
+               : llvm::None;
+}
 
-  // Create scf.if op.
-  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                                 warpOp.getLaneid(), c0);
-  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
-                                         /*withElseRegion=*/false);
-  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
-
-  // Store vectors that are defined outside of warpOp into the scratch pad
-  // buffer.
-  SmallVector<Value> bbArgReplacements;
-  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
-    Value val = it.value();
-    Value bbArg = warpOpBody->getArgument(it.index());
-
-    rewriter.setInsertionPoint(ifOp);
-    Value buffer =
-        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());
-
-    // Store arg vector into buffer.
-    rewriter.setInsertionPoint(ifOp);
-    auto vectorType = val.getType().cast<VectorType>();
-    int64_t storeSize = vectorType.getShape()[0];
-    Value storeOffset = rewriter.create<arith::MulIOp>(
-        loc, warpOp.getLaneid(),
-        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
-    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);
-
-    // Load bbArg vector from buffer.
-    rewriter.setInsertionPointToStart(ifOp.thenBlock());
-    auto bbArgType = bbArg.getType().cast<VectorType>();
-    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
-    bbArgReplacements.push_back(loadOp);
+namespace {
+
+/// Helper struct to create the load / store operations that permit transit
+/// through the parallel / sequential and the sequential / parallel boundaries
+/// when performing `rewriteWarpOpToScfFor`.
+///
+/// All this assumes the vector distribution occurs along the most minor
+/// distributed vector dimension.
+/// TODO: which is expected to be a multiple of the warp size?
+/// TODO: add an analysis step that determines which vector dimension should
+/// be used for distribution.
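+///
+/// For example, assuming a 32-lane warp, a sequential type vector<128xf32>
+/// and a distributed type vector<4xf32> (the shapes exercised in the tests),
+/// the accesses built by this helper are roughly:
+///
+///   %offset = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%laneid]
+///   vector.transfer_write %distributed, %buffer[%offset] {in_bounds = [true]}
+///       : vector<4xf32>, memref<128xf32, 3>
+///   %sequential = vector.transfer_read %buffer[%c0], %pad {in_bounds = [true]}
+///       : memref<128xf32, 3>, vector<128xf32>
+///
+/// where %pad is the transfer_read padding value and the memory space of
+/// %buffer depends on the allocation function in use.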
+struct DistributedLoadStoreHelper {
+  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
+                             Value laneId, Value zero)
+      : sequentialVal(sequentialVal), distributedVal(distributedVal),
+        laneId(laneId), zero(zero) {
+    sequentialType = sequentialVal.getType();
+    distributedType = distributedVal.getType();
+    sequentialVectorType = sequentialType.dyn_cast<VectorType>();
+    distributedVectorType = distributedType.dyn_cast<VectorType>();
   }
 
-  // Insert sync after all the stores and before all the loads.
-  if (!warpOp.getArgs().empty()) {
-    rewriter.setInsertionPoint(ifOp);
-    options.warpSyncronizationFn(loc, rewriter, warpOp);
+  Value buildDistributedOffset(RewriterBase &b, Location loc) {
+    auto maybeDistributedSize = getDistributedSize(distributedVectorType);
+    assert(maybeDistributedSize &&
+           "at this point, a distributed size must be determined");
+    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
+    return b.createOrFold<AffineApplyOp>(loc, tid * (*maybeDistributedSize),
+                                         ArrayRef<Value>{laneId});
   }
 
-  // Move body of warpOp to ifOp.
-  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
-
-  // Rewrite terminator and compute replacements of WarpOp results.
-  SmallVector<Value> replacements;
-  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
-  Location yieldLoc = yieldOp.getLoc();
-  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
-    Value val = it.value();
-    Type resultType = warpOp->getResultTypes()[it.index()];
-    rewriter.setInsertionPoint(ifOp);
-    Value buffer =
-        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());
-
-    // Store yielded value into buffer.
-    rewriter.setInsertionPoint(yieldOp);
-    if (val.getType().isa<VectorType>())
-      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
-    else
-      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);
-
-    // Load value from buffer (after warpOp).
-    rewriter.setInsertionPointAfter(ifOp);
-    if (resultType == val.getType()) {
-      // Result type and yielded value type are the same. This is a broadcast.
-      // E.g.:
-      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
-      //   vector.yield %cst : f32
-      // }
-      // Both types are f32. The constant %cst is broadcasted to all lanes.
-      // This is described in more detail in the documentation of the op.
-      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
-      replacements.push_back(loadOp);
-    } else {
-      auto loadedVectorType = resultType.cast<VectorType>();
-      int64_t loadSize = loadedVectorType.getShape()[0];
-
-      // loadOffset = laneid * loadSize
-      Value loadOffset = rewriter.create<arith::MulIOp>(
-          loc, warpOp.getLaneid(),
-          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
-      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
-                                                     buffer, loadOffset);
-      replacements.push_back(loadOp);
+  Operation *buildStore(RewriterBase &b, Location loc, Value val,
+                        Value buffer) {
+    assert((val == distributedVal || val == sequentialVal) &&
+           "Must store either the preregistered distributed or the "
+           "preregistered sequential value.");
+    // Vector case must use vector::TransferWriteOp which will later lower to
+    // vector.store or memref.store depending on further lowerings.
+    if (val.getType().isa<VectorType>()) {
+      SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
+      auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
+      assert(maybeDistributedDim && "must be able to deduce distributed dim");
+      indices[*maybeDistributedDim] =
+          (val == distributedVal) ? buildDistributedOffset(b, loc) : zero;
+      SmallVector<bool> inBounds(indices.size(), true);
+      return b.create<vector::TransferWriteOp>(
+          loc, val, buffer, indices,
+          ArrayRef<bool>(inBounds.begin(), inBounds.end()));
     }
+    // Scalar case can directly use memref.store.
+    return b.create<memref::StoreOp>(loc, val, buffer, zero);
   }
 
-  // Insert sync after all the stores and before all the loads.
-  if (!yieldOp.operands().empty()) {
-    rewriter.setInsertionPointAfter(ifOp);
-    options.warpSyncronizationFn(loc, rewriter, warpOp);
+  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer,
+                  bool broadcastMode = false) {
+    // When broadcastMode is true, this is a broadcast.
+    // E.g.:
+    // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+    //   vector.yield %cst : f32
+    // }
+    // Both types are f32. The constant %cst is broadcasted to all lanes.
+    // This is described in more detail in the documentation of the op.
+    if (broadcastMode)
+      return b.create<memref::LoadOp>(loc, buffer, zero);
+
+    // Other cases must be vector atm.
+    // Vector case must use vector::TransferReadOp which will later lower to
+    // vector.load or memref.load depending on further lowerings.
+    assert(type.isa<VectorType>() && "must be a vector type");
+    assert((type == distributedVectorType || type == sequentialVectorType) &&
+           "Must load either the preregistered distributed or the "
+           "preregistered sequential type.");
+    auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
+    assert(maybeDistributedDim && "must be able to deduce distributed dim");
+    SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
+    if (type == distributedVectorType) {
+      indices[*maybeDistributedDim] = buildDistributedOffset(b, loc);
+    } else {
+      indices[*maybeDistributedDim] = zero;
+    }
+    SmallVector<bool> inBounds(indices.size(), true);
+    return b.create<vector::TransferReadOp>(
+        loc, type.cast<VectorType>(), buffer, indices,
+        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
   }
 
-  // Delete terminator and add empty scf.yield.
-  rewriter.eraseOp(yieldOp);
-  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
-  rewriter.create<scf::YieldOp>(yieldLoc);
-
-  // Compute replacements for WarpOp results.
-  rewriter.replaceOp(warpOp, replacements);
+  Value sequentialVal, distributedVal, laneId, zero;
+  Type sequentialType, distributedType;
+  VectorType sequentialVectorType, distributedVectorType;
+};
 
-  return success();
-}
+} // namespace
 
 /// Helper to create a new WarpExecuteOnLane0Op with different signature.
 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
@@ -261,6 +254,37 @@
 namespace {
 
+/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
+/// thread `laneId` executes the entirety of the computation.
+///
+/// After the transformation:
+///   - the IR within the scf.if op can be thought of as executing sequentially
+///     (from the point of view of threads along `laneId`).
+///   - the IR outside of the scf.if op can be thought of as executing in
+///     parallel (from the point of view of threads along `laneId`).
+///
+/// Values that need to transit through the parallel / sequential and the
+/// sequential / parallel boundaries do so via reads and writes to a temporary
+/// memory location.
+///
+/// The transformation proceeds in multiple steps:
+///   1. Create the scf.if op.
+///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
+///      within the scf.if to transit the values captured from above.
+///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
+///      consistent within the scf.if.
+///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
+///   5. Insert appropriate writes within scf.if and reads after the scf.if to
+///      transit the values returned by the op.
+///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
+///      consistent after the scf.if.
+///   7. Perform late cleanups.
+///
+/// All this assumes the vector distribution occurs along the most minor
+/// distributed vector dimension.
+/// TODO: which is expected to be a multiple of the warp size?
+/// TODO: add an analysis step that determines which vector dimension should be
+/// used for distribution.
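+///
+/// For example (simplified from the tests; the exact buffer allocation and
+/// synchronization ops depend on the warpAllocationFn / warpSyncronizationFn
+/// options), a warp op such as:
+///
+///   %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+///     %def = "some_def"() : () -> vector<32xf32>
+///     vector.yield %def : vector<32xf32>
+///   }
+///
+/// is rewritten into roughly:
+///
+///   %buffer = memref.get_global @__shared_32xf32 : memref<32xf32, 3>
+///   %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
+///   scf.if %is_lane_0 {
+///     %def = "some_def"() : () -> vector<32xf32>
+///     vector.transfer_write %def, %buffer[%c0] {in_bounds = [true]}
+///         : vector<32xf32>, memref<32xf32, 3>
+///   }
+///   gpu.barrier
+///   %r = vector.transfer_read %buffer[%laneid], %pad {in_bounds = [true]}
+///       : memref<32xf32, 3>, vector<1xf32>
+///
+/// with %c0 the index-0 constant and %pad the transfer_read padding value.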
 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
   WarpOpToScfForPattern(MLIRContext *context,
                         const WarpExecuteOnLane0LoweringOptions &options,
@@ -270,7 +294,106 @@
 
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
+    assert(warpOp.getBodyRegion().hasOneBlock() &&
+           "expected WarpOp with single block");
+    Block *warpOpBody = &warpOp.getBodyRegion().front();
+    Location loc = warpOp.getLoc();
+
+    // Passed all checks. Start rewriting.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(warpOp);
+
+    // Step 1: Create scf.if op.
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value isLane0 = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
+    auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
+                                           /*withElseRegion=*/false);
+    rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
+
+    // Step 2: insert appropriate (alloc, write)-pairs before the scf.if and
+    // reads within the scf.if to transit the values captured from above.
+    SmallVector<Value> bbArgReplacements;
+    for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
+      Value sequentialVal = warpOpBody->getArgument(it.index());
+      Value distributedVal = it.value();
+      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
+                                        warpOp.getLaneid(), c0);
+
+      // Create buffer before the ifOp.
+      rewriter.setInsertionPoint(ifOp);
+      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
+                                              sequentialVal.getType());
+      // Store distributed vector into buffer, before the ifOp.
+      helper.buildStore(rewriter, loc, distributedVal, buffer);
+      // Load sequential vector from buffer, inside the ifOp.
+      rewriter.setInsertionPointToStart(ifOp.thenBlock());
+      bbArgReplacements.push_back(
+          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
+    }
+
+    // Step 3. Insert sync after all the stores and before all the loads.
+    if (!warpOp.getArgs().empty()) {
+      rewriter.setInsertionPoint(ifOp);
+      options.warpSyncronizationFn(loc, rewriter, warpOp);
+    }
+
+    // Step 4. Move body of warpOp to ifOp.
+    rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
+
+    // Step 5. Insert appropriate writes within scf.if and reads after the
+    // scf.if to transit the values returned by the op.
+    // TODO: at this point, we can reuse the shared memory from previous
+    // buffers.
+    SmallVector<Value> replacements;
+    auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
+    Location yieldLoc = yieldOp.getLoc();
+    for (const auto &it : llvm::enumerate(yieldOp.operands())) {
+      Value sequentialVal = it.value();
+      Value distributedVal = warpOp->getResult(it.index());
+      DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
+                                        warpOp.getLaneid(), c0);
+
+      // Create buffer before the ifOp.
+      rewriter.setInsertionPoint(ifOp);
+      Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
+                                              sequentialVal.getType());
+
+      // Store yielded value into buffer, inside the ifOp, before the
+      // terminator.
+      rewriter.setInsertionPoint(yieldOp);
+      helper.buildStore(rewriter, loc, sequentialVal, buffer);
+
+      // Load distributed value from buffer, after the warpOp.
+      rewriter.setInsertionPointAfter(ifOp);
+      bool broadcastMode =
+          (sequentialVal.getType() == distributedVal.getType());
+      // Result type and yielded value type are the same. This is a broadcast.
+      // E.g.:
+      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+      //   vector.yield %cst : f32
+      // }
+      // Both types are f32. The constant %cst is broadcasted to all lanes.
+      // This is described in more detail in the documentation of the op.
+      replacements.push_back(helper.buildLoad(
+          rewriter, loc, distributedVal.getType(), buffer, broadcastMode));
+    }
+
+    // Step 6. Insert sync after all the stores and before all the loads.
+    if (!yieldOp.operands().empty()) {
+      rewriter.setInsertionPointAfter(ifOp);
+      options.warpSyncronizationFn(loc, rewriter, warpOp);
+    }
+
+    // Step 7. Delete terminator and add empty scf.yield.
+    rewriter.eraseOp(yieldOp);
+    rewriter.setInsertionPointToEnd(ifOp.thenBlock());
+    rewriter.create<scf::YieldOp>(yieldLoc);
+
+    // Compute replacements for WarpOp results.
+    rewriter.replaceOp(warpOp, replacements);
+
+    return success();
   }
 
 private:
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -4,7 +4,9 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
 // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
-
+// CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-SCF-IF-DAG: #[[$TIMES4:.*]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-SCF-IF-DAG: #[[$TIMES8:.*]] = affine_map<()[s0] -> (s0 * 8)>
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3>
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_128xf32 : memref<128xf32, 3>
@@ -16,17 +18,14 @@
 func.func @rewrite_warp_op_to_scf_if(%laneid: index, %v0: vector<4xf32>, %v1: vector<8xf32>) {
 // CHECK-SCF-IF-DAG: %[[c0:.*]] = arith.constant 0 : index
-// CHECK-SCF-IF-DAG: %[[c2:.*]] = arith.constant 2 : index
-// CHECK-SCF-IF-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-SCF-IF-DAG: %[[c8:.*]] = arith.constant 8 : index
 // CHECK-SCF-IF: %[[is_lane_0:.*]] = arith.cmpi eq, %[[laneid]], %[[c0]]
 
 // CHECK-SCF-IF: %[[buffer_v0:.*]] = memref.get_global @__shared_128xf32
-// CHECK-SCF-IF: %[[s0:.*]] = arith.muli %[[laneid]], %[[c4]]
-// CHECK-SCF-IF: vector.store %[[v0]], %[[buffer_v0]][%[[s0]]]
+// CHECK-SCF-IF: %[[s0:.*]] = affine.apply #[[$TIMES4]]()[%[[laneid]]]
+// CHECK-SCF-IF: vector.transfer_write %[[v0]], %[[buffer_v0]][%[[s0]]]
 // CHECK-SCF-IF: %[[buffer_v1:.*]] = memref.get_global @__shared_256xf32
-// CHECK-SCF-IF: %[[s1:.*]] = arith.muli %[[laneid]], %[[c8]]
-// CHECK-SCF-IF: vector.store %[[v1]], %[[buffer_v1]][%[[s1]]]
+// CHECK-SCF-IF: %[[s1:.*]] = affine.apply #[[$TIMES8]]()[%[[laneid]]]
+// CHECK-SCF-IF: vector.transfer_write %[[v1]], %[[buffer_v1]][%[[s1]]]
 
 // CHECK-SCF-IF-DAG: gpu.barrier
 // CHECK-SCF-IF-DAG: %[[buffer_def_0:.*]] = memref.get_global @__shared_32xf32
@@ -36,21 +35,21 @@
   %r:2 = vector.warp_execute_on_lane_0(%laneid)[32] args(%v0, %v1 : vector<4xf32>, vector<8xf32>) -> (vector<1xf32>, vector<2xf32>) {
   ^bb0(%arg0: vector<128xf32>, %arg1: vector<256xf32>):
-// CHECK-SCF-IF: %[[arg1:.*]] = vector.load %[[buffer_v1]][%[[c0]]] : memref<256xf32, 3>, vector<256xf32>
-// CHECK-SCF-IF: %[[arg0:.*]] = vector.load %[[buffer_v0]][%[[c0]]] : memref<128xf32, 3>, vector<128xf32>
+// CHECK-SCF-IF: %[[arg1:.*]] = vector.transfer_read %[[buffer_v1]][%[[c0]]], %{{.*}} {in_bounds = [true]} : memref<256xf32, 3>, vector<256xf32>
+// CHECK-SCF-IF: %[[arg0:.*]] = vector.transfer_read %[[buffer_v0]][%[[c0]]], %{{.*}} {in_bounds = [true]} : memref<128xf32, 3>, vector<128xf32>
 // CHECK-SCF-IF: %[[def_0:.*]] = "some_def"(%[[arg0]]) : (vector<128xf32>) -> vector<32xf32>
 // CHECK-SCF-IF: %[[def_1:.*]] = "some_def"(%[[arg1]]) : (vector<256xf32>) -> vector<64xf32>
     %2 = "some_def"(%arg0) : (vector<128xf32>) -> vector<32xf32>
    %3 = "some_def"(%arg1) : (vector<256xf32>) -> vector<64xf32>
-// CHECK-SCF-IF: vector.store %[[def_0]], %[[buffer_def_0]][%[[c0]]]
-// CHECK-SCF-IF: vector.store %[[def_1]], %[[buffer_def_1]][%[[c0]]]
+// CHECK-SCF-IF: vector.transfer_write %[[def_0]], %[[buffer_def_0]][%[[c0]]]
+// CHECK-SCF-IF: vector.transfer_write %[[def_1]], %[[buffer_def_1]][%[[c0]]]
     vector.yield %2, %3 : vector<32xf32>, vector<64xf32>
   }
 // CHECK-SCF-IF: }
 // CHECK-SCF-IF: gpu.barrier
-// CHECK-SCF-IF: %[[o1:.*]] = arith.muli %[[laneid]], %[[c2]]
-// CHECK-SCF-IF: %[[r1:.*]] = vector.load %[[buffer_def_1]][%[[o1]]] : memref<64xf32, 3>, vector<2xf32>
-// CHECK-SCF-IF: %[[r0:.*]] = vector.load %[[buffer_def_0]][%[[laneid]]] : memref<32xf32, 3>, vector<1xf32>
+// CHECK-SCF-IF: %[[o1:.*]] = affine.apply #[[$TIMES2]]()[%[[laneid]]]
+// CHECK-SCF-IF: %[[r1:.*]] = vector.transfer_read %[[buffer_def_1]][%[[o1]]], %{{.*}} {in_bounds = [true]} : memref<64xf32, 3>, vector<2xf32>
+// CHECK-SCF-IF: %[[r0:.*]] = vector.transfer_read %[[buffer_def_0]][%[[laneid]]], %{{.*}} {in_bounds = [true]} : memref<32xf32, 3>, vector<1xf32>
 // CHECK-SCF-IF: "some_use"(%[[r0]]) : (vector<1xf32>) -> ()
 // CHECK-SCF-IF: "some_use"(%[[r1]]) : (vector<2xf32>) -> ()
   "some_use"(%r#0) : (vector<1xf32>) -> ()
@@ -631,3 +630,23 @@
   }
   return %r : f32
 }
+
+// -----
+
+// CHECK-PROP: func @lane_dependent_warp_propagate_read
+// CHECK-PROP-SAME: %[[ID:.*]]: index
+func.func @lane_dependent_warp_propagate_read(
+    %laneid: index, %src: memref<1x1024xf32>, %dest: memref<1x1024xf32>) {
+  // CHECK-PROP-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-PROP-NOT: lane_dependent_warp_propagate_read
+  // CHECK-PROP-DAG: %[[R0:.*]] = vector.transfer_read %arg1[%[[C0]], %[[ID]]], %{{.*}} : memref<1x1024xf32>, vector<1x1xf32>
+  // CHECK-PROP: vector.transfer_write %[[R0]], {{.*}} : vector<1x1xf32>, memref<1x1024xf32>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x1xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<1x1024xf32>, vector<1x32xf32>
+    vector.yield %2 : vector<1x32xf32>
+  }
+  vector.transfer_write %r, %dest[%c0, %laneid] : vector<1x1xf32>, memref<1x1024xf32>
+  return
+}