diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -10,16 +10,113 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/SideEffectUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/Support/ErrorHandling.h"
 #include <utility>

 using namespace mlir;
 using namespace mlir::vector;

+/// Helper to create the load / store operations that transit values across
+/// the parallel / sequential boundary when rewriting a WarpExecuteOnLane0Op.
+struct DistributedLoadStoreHelper {
+  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
+                             Value laneId, Value zero)
+      : sequentialVal(sequentialVal), distributedVal(distributedVal),
+        laneId(laneId), zero(zero) {
+    sequentialType = sequentialVal.getType();
+    distributedType = distributedVal.getType();
+    sequentialVectorType = sequentialType.dyn_cast<VectorType>();
+    distributedVectorType = distributedType.dyn_cast<VectorType>();
+    if (distributedVectorType)
+      distributedTrailingSize = distributedVectorType.getShape().back();
+  }
+
+  /// Build the offset `laneId * distributedTrailingSize` along the most minor
+  /// dimension, as an affine.apply.
+  Value buildDistributedOffset(RewriterBase &b, Location loc) {
+    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
+    return b.createOrFold<AffineApplyOp>(loc, tid * distributedTrailingSize,
+                                         ArrayRef<Value>{laneId});
+  }
+
+  Operation *buildStore(RewriterBase &b, Location loc, Value val,
+                        Value buffer) {
+    assert((val == distributedVal || val == sequentialVal) &&
+           "Must store either the preregistered distributed or the "
+           "preregistered sequential value.");
+    if (val.getType().isa<VectorType>()) {
+      SmallVector<Value> indices(sequentialVectorType.getRank() - 1, zero);
+      if (val == distributedVal)
+        indices.push_back(buildDistributedOffset(b, loc));
+      else
+        indices.push_back(zero);
+      SmallVector<bool> inBounds(indices.size(), true);
+      return b.create<vector::TransferWriteOp>(
+          loc, val, buffer, indices,
+          ArrayRef<bool>(inBounds.begin(), inBounds.end()));
+    }
+    return b.create<memref::StoreOp>(loc, val, buffer, zero);
+  }
+
+  /// Load a value of `type` from `buffer`. When `broadcastMode` is true, load
+  /// a scalar at offset 0 that is implicitly broadcast to all lanes;
+  /// otherwise load either the preregistered sequential or distributed
+  /// vector type.
+  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer,
+                  bool broadcastMode = false) {
+    if (broadcastMode)
+      return b.create<memref::LoadOp>(loc, buffer, zero);
+
+    assert(type.isa<VectorType>() && "must be a vector type!");
+    SmallVector<Value> indices(sequentialVectorType.getRank() - 1, zero);
+    assert((type == sequentialVectorType || type == distributedVectorType) &&
+           "Must load either the preregistered distributed or the "
+           "preregistered sequential type.");
+    if (type == distributedVectorType)
+      indices.push_back(buildDistributedOffset(b, loc));
+    else
+      indices.push_back(zero);
+    SmallVector<bool> inBounds(indices.size(), true);
+    return b.create<vector::TransferReadOp>(
+        loc, type.cast<VectorType>(), buffer, indices,
+        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
+  }
+
+  Value sequentialVal, distributedVal, laneId, zero;
+  Type sequentialType, distributedType;
+  VectorType sequentialVectorType, distributedVectorType;
+  int64_t distributedTrailingSize;
+};
+
+/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where only the
+/// single thread with `laneId == 0` executes the entirety of the computation.
+///
+/// After the transformation:
+///   - the IR within the scf.if op can be thought of as executing
+///     sequentially (from the point of view of threads along `laneId`).
+///   - the IR outside of the scf.if op can be thought of as executing in
+///     parallel (from the point of view of threads along `laneId`).
+///
+/// Values that need to transit through the parallel / sequential and the
+/// sequential / parallel boundaries do so via reads and writes to a temporary
+/// memory location.
+///
+/// The transformation proceeds in multiple steps:
+///   1. Create the scf.if op.
+///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
+///      within the scf.if to transit the values captured from above.
+///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
+///      consistent within the scf.if.
+///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
+///   5. Insert appropriate writes within scf.if and reads after the scf.if to
+///      transit the values returned by the op.
+///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
+///      consistent after the scf.if.
+///   7. Perform late cleanups.
+///
+/// All this assumes the vector distribution occurs along the most minor
+/// dimension.
+/// TODO: is the most minor dimension expected to be a multiple of the warp
+/// size?
 static LogicalResult
 rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                       const WarpExecuteOnLane0LoweringOptions &options) {
@@ -32,7 +129,7 @@
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(warpOp);

-  // Create scf.if op.
+  // Step 1: Create scf.if op.
   Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   Value isLane0 = rewriter.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
@@ -40,93 +137,78 @@
                                               /*withElseRegion=*/false);
   rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

-  // Store vectors that are defined outside of warpOp into the scratch pad
-  // buffer.
+  // Step 2: Insert appropriate (alloc, write)-pairs before the scf.if and
+  // reads within the scf.if to transit the values captured from above.
   SmallVector<Value> bbArgReplacements;
   for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
-    Value val = it.value();
-    Value bbArg = warpOpBody->getArgument(it.index());
-
-    rewriter.setInsertionPoint(ifOp);
-    Value buffer =
-        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());
+    Value sequentialVal = warpOpBody->getArgument(it.index());
+    Value distributedVal = it.value();
+    DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
+                                      warpOp.getLaneid(), c0);

-    // Store arg vector into buffer.
+    // Create buffer before the ifOp.
     rewriter.setInsertionPoint(ifOp);
-    auto vectorType = val.getType().cast<VectorType>();
-    int64_t storeSize = vectorType.getShape()[0];
-    Value storeOffset = rewriter.create<arith::MulIOp>(
-        loc, warpOp.getLaneid(),
-        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
-    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);
-
-    // Load bbArg vector from buffer.
+    Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
+                                            sequentialVal.getType());
+    // Store distributed vector into buffer, before the ifOp.
+    helper.buildStore(rewriter, loc, distributedVal, buffer);
+    // Load sequential vector from buffer, inside the ifOp.
     rewriter.setInsertionPointToStart(ifOp.thenBlock());
-    auto bbArgType = bbArg.getType().cast<VectorType>();
-    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
-    bbArgReplacements.push_back(loadOp);
+    bbArgReplacements.push_back(
+        helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
   }

-  // Insert sync after all the stores and before all the loads.
+  // Step 3: Insert sync after all the stores and before all the loads.
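+  // Note: the synchronization below goes through the user-provided
+  // `warpSyncronizationFn` hook (on GPU targets this is typically a
+  // gpu.barrier, as in the tests). It is skipped when no value transits
+  // through memory, i.e. when the op captures no arguments.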
   if (!warpOp.getArgs().empty()) {
     rewriter.setInsertionPoint(ifOp);
     options.warpSyncronizationFn(loc, rewriter, warpOp);
   }

-  // Move body of warpOp to ifOp.
+  // Step 4: Move body of warpOp to ifOp.
   rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

-  // Rewrite terminator and compute replacements of WarpOp results.
+  // Step 5: Insert appropriate writes within scf.if and reads after the
+  // scf.if to transit the values returned by the op.
+  // TODO: at this point, we can reuse the shared memory from previous buffers.
   SmallVector<Value> replacements;
   auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
   Location yieldLoc = yieldOp.getLoc();
   for (const auto &it : llvm::enumerate(yieldOp.operands())) {
-    Value val = it.value();
-    Type resultType = warpOp->getResultTypes()[it.index()];
+    Value sequentialVal = it.value();
+    Value distributedVal = warpOp->getResult(it.index());
+    DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
+                                      warpOp.getLaneid(), c0);
+
+    // Create buffer before the ifOp.
     rewriter.setInsertionPoint(ifOp);
-    Value buffer =
-        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());
+    Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
+                                            sequentialVal.getType());

-    // Store yielded value into buffer.
+    // Store yielded value into buffer, inside the ifOp, before the terminator.
     rewriter.setInsertionPoint(yieldOp);
-    if (val.getType().isa<VectorType>())
-      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
-    else
-      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);
+    helper.buildStore(rewriter, loc, sequentialVal, buffer);

-    // Load value from buffer (after warpOp).
+    // Load distributed value from buffer, after the warpOp.
     rewriter.setInsertionPointAfter(ifOp);
-    if (resultType == val.getType()) {
-      // Result type and yielded value type are the same. This is a broadcast.
-      // E.g.:
-      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
-      //   vector.yield %cst : f32
-      // }
-      // Both types are f32. The constant %cst is broadcasted to all lanes.
-      // This is described in more detail in the documentation of the op.
-      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
-      replacements.push_back(loadOp);
-    } else {
-      auto loadedVectorType = resultType.cast<VectorType>();
-      int64_t loadSize = loadedVectorType.getShape()[0];
-
-      // loadOffset = laneid * loadSize
-      Value loadOffset = rewriter.create<arith::MulIOp>(
-          loc, warpOp.getLaneid(),
-          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
-      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
-                                                     buffer, loadOffset);
-      replacements.push_back(loadOp);
-    }
+    bool broadcastMode = (sequentialVal.getType() == distributedVal.getType());
+    // If the result type and the yielded value type are the same, this is a
+    // broadcast. E.g.:
+    // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+    //   vector.yield %cst : f32
+    // }
+    // Both types are f32. The constant %cst is broadcasted to all lanes.
+    // This is described in more detail in the documentation of the op.
+    replacements.push_back(helper.buildLoad(
+        rewriter, loc, distributedVal.getType(), buffer, broadcastMode));
   }

-  // Insert sync after all the stores and before all the loads.
+  // Step 6: Insert sync after all the stores and before all the loads.
   if (!yieldOp.operands().empty()) {
     rewriter.setInsertionPointAfter(ifOp);
     options.warpSyncronizationFn(loc, rewriter, warpOp);
   }

-  // Delete terminator and add empty scf.yield.
+  // Step 7: Delete terminator and add empty scf.yield.
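+  // The yielded values have already been stored to buffers in Step 5, so the
+  // original terminator carries no remaining information; scf.if requires an
+  // scf.yield terminator in its then-region.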
   rewriter.eraseOp(yieldOp);
   rewriter.setInsertionPointToEnd(ifOp.thenBlock());
   rewriter.create<scf::YieldOp>(yieldLoc);
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -4,7 +4,9 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute=propagate-distribution -canonicalize | FileCheck --check-prefixes=CHECK-PROP %s
 // RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" -canonicalize | FileCheck --check-prefixes=CHECK-DIST-AND-PROP %s
-
+// CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-SCF-IF-DAG: #[[$TIMES4:.*]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-SCF-IF-DAG: #[[$TIMES8:.*]] = affine_map<()[s0] -> (s0 * 8)>
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_32xf32 : memref<32xf32, 3>
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_64xf32 : memref<64xf32, 3>
 // CHECK-SCF-IF-DAG: memref.global "private" @__shared_128xf32 : memref<128xf32, 3>
@@ -16,17 +18,14 @@
 func.func @rewrite_warp_op_to_scf_if(%laneid: index, %v0: vector<4xf32>, %v1: vector<8xf32>) {
 // CHECK-SCF-IF-DAG: %[[c0:.*]] = arith.constant 0 : index
-// CHECK-SCF-IF-DAG: %[[c2:.*]] = arith.constant 2 : index
-// CHECK-SCF-IF-DAG: %[[c4:.*]] = arith.constant 4 : index
-// CHECK-SCF-IF-DAG: %[[c8:.*]] = arith.constant 8 : index
 // CHECK-SCF-IF: %[[is_lane_0:.*]] = arith.cmpi eq, %[[laneid]], %[[c0]]
 // CHECK-SCF-IF: %[[buffer_v0:.*]] = memref.get_global @__shared_128xf32
-// CHECK-SCF-IF: %[[s0:.*]] = arith.muli %[[laneid]], %[[c4]]
-// CHECK-SCF-IF: vector.store %[[v0]], %[[buffer_v0]][%[[s0]]]
+// CHECK-SCF-IF: %[[s0:.*]] = affine.apply #[[$TIMES4]]()[%[[laneid]]]
+// CHECK-SCF-IF: vector.transfer_write %[[v0]], %[[buffer_v0]][%[[s0]]]
 // CHECK-SCF-IF: %[[buffer_v1:.*]] = memref.get_global @__shared_256xf32
-// CHECK-SCF-IF: %[[s1:.*]] = arith.muli %[[laneid]], %[[c8]]
-// CHECK-SCF-IF: vector.store %[[v1]], %[[buffer_v1]][%[[s1]]]
+// CHECK-SCF-IF: %[[s1:.*]] = affine.apply #[[$TIMES8]]()[%[[laneid]]]
+// CHECK-SCF-IF: vector.transfer_write %[[v1]], %[[buffer_v1]][%[[s1]]]
 // CHECK-SCF-IF-DAG: gpu.barrier
 // CHECK-SCF-IF-DAG: %[[buffer_def_0:.*]] = memref.get_global @__shared_32xf32
@@ -36,21 +35,21 @@
   %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
       args(%v0, %v1 : vector<4xf32>, vector<8xf32>) -> (vector<1xf32>, vector<2xf32>) {
   ^bb0(%arg0: vector<128xf32>, %arg1: vector<256xf32>):
-// CHECK-SCF-IF: %[[arg1:.*]] = vector.load %[[buffer_v1]][%[[c0]]] : memref<256xf32, 3>, vector<256xf32>
-// CHECK-SCF-IF: %[[arg0:.*]] = vector.load %[[buffer_v0]][%[[c0]]] : memref<128xf32, 3>, vector<128xf32>
+// CHECK-SCF-IF: %[[arg1:.*]] = vector.transfer_read %[[buffer_v1]][%[[c0]]], %{{.*}} {in_bounds = [true]} : memref<256xf32, 3>, vector<256xf32>
+// CHECK-SCF-IF: %[[arg0:.*]] = vector.transfer_read %[[buffer_v0]][%[[c0]]], %{{.*}} {in_bounds = [true]} : memref<128xf32, 3>, vector<128xf32>
 // CHECK-SCF-IF: %[[def_0:.*]] = "some_def"(%[[arg0]]) : (vector<128xf32>) -> vector<32xf32>
 // CHECK-SCF-IF: %[[def_1:.*]] = "some_def"(%[[arg1]]) : (vector<256xf32>) -> vector<64xf32>
     %2 = "some_def"(%arg0) : (vector<128xf32>) -> vector<32xf32>
     %3 = "some_def"(%arg1) : (vector<256xf32>) -> vector<64xf32>
-// CHECK-SCF-IF: vector.store %[[def_0]], %[[buffer_def_0]][%[[c0]]]
-// CHECK-SCF-IF: vector.store %[[def_1]], %[[buffer_def_1]][%[[c0]]]
+// CHECK-SCF-IF: vector.transfer_write %[[def_0]], %[[buffer_def_0]][%[[c0]]]
+// CHECK-SCF-IF: vector.transfer_write %[[def_1]], %[[buffer_def_1]][%[[c0]]]
     vector.yield %2, %3 : vector<32xf32>, vector<64xf32>
   }
 // CHECK-SCF-IF: }
 // CHECK-SCF-IF: gpu.barrier
-// CHECK-SCF-IF: %[[o1:.*]] = arith.muli %[[laneid]], %[[c2]]
-// CHECK-SCF-IF: %[[r1:.*]] = vector.load %[[buffer_def_1]][%[[o1]]] : memref<64xf32, 3>, vector<2xf32>
-// CHECK-SCF-IF: %[[r0:.*]] = vector.load %[[buffer_def_0]][%[[laneid]]] : memref<32xf32, 3>, vector<1xf32>
+// CHECK-SCF-IF: %[[o1:.*]] = affine.apply #[[$TIMES2]]()[%[[laneid]]]
+// CHECK-SCF-IF: %[[r1:.*]] = vector.transfer_read %[[buffer_def_1]][%[[o1]]], %{{.*}} {in_bounds = [true]} : memref<64xf32, 3>, vector<2xf32>
+// CHECK-SCF-IF: %[[r0:.*]] = vector.transfer_read %[[buffer_def_0]][%[[laneid]]], %{{.*}} {in_bounds = [true]} : memref<32xf32, 3>, vector<1xf32>
 // CHECK-SCF-IF: "some_use"(%[[r0]]) : (vector<1xf32>) -> ()
 // CHECK-SCF-IF: "some_use"(%[[r1]]) : (vector<2xf32>) -> ()
   "some_use"(%r#0) : (vector<1xf32>) -> ()
@@ -631,3 +630,23 @@
   }
   return %r : f32
 }
+
+// -----
+
+// CHECK-PROP: func @lane_dependent_warp_propagate_read
+// CHECK-PROP-SAME: %[[ID:.*]]: index
+func.func @lane_dependent_warp_propagate_read(
+    %laneid: index, %src: memref<1x1024xf32>, %dest: memref<1x1024xf32>) {
+  // CHECK-PROP-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-PROP-NOT: lane_dependent_warp_propagate_read
+  // CHECK-PROP-DAG: %[[R0:.*]] = vector.transfer_read %arg1[%[[C0]], %[[ID]]], %{{.*}} : memref<1x1024xf32>, vector<1x1xf32>
+  // CHECK-PROP: vector.transfer_write %[[R0]], {{.*}} : vector<1x1xf32>, memref<1x1024xf32>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x1xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0], %cst : memref<1x1024xf32>, vector<1x32xf32>
+    vector.yield %2 : vector<1x32xf32>
+  }
+  vector.transfer_write %r, %dest[%c0, %laneid] : vector<1x1xf32>, memref<1x1024xf32>
+  return
+}
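
For reviewers who want a quick mental model, here is a minimal before / after sketch of the rewrite (illustration only, not part of the patch). It assumes the test's memref.global-based allocation function and gpu.barrier synchronization, a warp size of 32, one captured vector<4xf32> argument and one vector<1xf32> result; all SSA names are made up for the example.

Before:

  %r = vector.warp_execute_on_lane_0(%laneid)[32]
      args(%v : vector<4xf32>) -> (vector<1xf32>) {
  ^bb0(%arg0: vector<128xf32>):
    %def = "some_def"(%arg0) : (vector<128xf32>) -> vector<32xf32>
    vector.yield %def : vector<32xf32>
  }

After (the steps annotated in the order they appear in the IR):

  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %is_lane_0 = arith.cmpi eq, %laneid, %c0 : index
  // Step 2: transit the captured value through shared memory; each lane
  // writes its 4-element slice at offset laneid * 4.
  %buffer_v = memref.get_global @__shared_128xf32 : memref<128xf32, 3>
  %off = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%laneid]
  vector.transfer_write %v, %buffer_v[%off] {in_bounds = [true]}
      : vector<4xf32>, memref<128xf32, 3>
  // Step 3: make the writes of all lanes visible to lane 0.
  gpu.barrier
  %buffer_r = memref.get_global @__shared_32xf32 : memref<32xf32, 3>
  scf.if %is_lane_0 {
    // Sequential part: load the full vector, compute, store the result.
    %arg = vector.transfer_read %buffer_v[%c0], %cst {in_bounds = [true]}
        : memref<128xf32, 3>, vector<128xf32>
    %def = "some_def"(%arg) : (vector<128xf32>) -> vector<32xf32>
    vector.transfer_write %def, %buffer_r[%c0] {in_bounds = [true]}
        : vector<32xf32>, memref<32xf32, 3>
  }
  // Step 6: make lane 0's writes visible to all lanes.
  gpu.barrier
  // Each lane reads back its 1-element slice (offset = laneid * 1, which
  // folds to %laneid).
  %r = vector.transfer_read %buffer_r[%laneid], %cst {in_bounds = [true]}
      : memref<32xf32, 3>, vector<1xf32>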