diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2702,6 +2702,10 @@
     bool isDefinedOutsideOfRegion(Value value) {
      return !getRegion().isAncestor(value.getParentRegion());
    }
+    // Helper method to replace all uses of the laneId operand by the constant
+    // 0 inside the region. This is a necessary prerequisite to perform any
+    // kind of hoisting of IR that is inside the region.
+    LogicalResult replaceAllUsesOfLaneWithin(RewriterBase &b);
  }];
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5088,6 +5088,30 @@
       verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
 }
 
+// Helper method to replace all uses of the laneId operand by the constant 0
+// inside the region. This is a necessary prerequisite to perform any kind of
+// hoisting of IR that is inside the region.
+// Returns success if any replacement occurred, failure otherwise.
+LogicalResult
+WarpExecuteOnLane0Op::replaceAllUsesOfLaneWithin(RewriterBase &b) {
+  OpBuilder::InsertionGuard g(b);
+  // Create the constant above the op so that it dominates all uses inside the
+  // region.
+  b.setInsertionPoint(*this);
+  Value zero = b.create<arith::ConstantIndexOp>(getLoc(), 0);
+  Value laneId = getLaneid();
+  // TODO: This could be lifted to e.g. Operation::replaceNestedUsesOfWith.
+  bool applied = false;
+  for (Operation *user : laneId.getUsers()) {
+    if (!getOperation()->isProperAncestor(user))
+      continue;
+    b.startRootUpdate(user);
+    user->replaceUsesOfWith(laneId, zero);
+    b.finalizeRootUpdate(user);
+    applied = true;
+  }
+  return success(applied);
+}
+
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                        CombiningKind kind, Value v1, Value v2) {
   Type t1 = getElementTypeOrSelf(v1.getType());
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -10,16 +10,110 @@
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/SideEffectUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/Support/ErrorHandling.h"
 #include <utility>
 
 using namespace mlir;
 using namespace mlir::vector;
 
+/// Helper to build the load / store ops that transit a value through a
+/// scratchpad buffer across the sequential / parallel boundary.
+struct DistributedLoadStoreHelper {
+  DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
+                             Value laneId, Value zero)
+      : sequentialVal(sequentialVal), distributedVal(distributedVal),
+        laneId(laneId), zero(zero) {
+    sequentialType = sequentialVal.getType();
+    distributedType = distributedVal.getType();
+    sequentialVectorType = sequentialType.dyn_cast<VectorType>();
+    distributedVectorType = distributedType.dyn_cast<VectorType>();
+    if (distributedVectorType)
+      distributedTrailingSize = distributedVectorType.getShape().back();
+  }
+
+  /// Build the offset `laneId * distributedTrailingSize` at which a lane
+  /// reads / writes its distributed slice along the most minor dimension.
+  Value buildDistributedOffset(RewriterBase &b, Location loc) {
+    AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
+    return b.createOrFold<AffineApplyOp>(loc, tid * distributedTrailingSize,
+                                         ArrayRef<Value>{laneId});
+  }
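+
+  /// Store `val` into `buffer`. An illustrative sketch, assuming a warp of
+  /// size 32 that distributes a sequential vector<32xf32> (kept in a
+  /// memref<32xf32> buffer) into per-lane vector<1xf32> slices:
+  ///   - a distributed value is written at the lane-dependent offset:
+  ///       vector.transfer_write %val, %buffer[%laneid]
+  ///           : vector<1xf32>, memref<32xf32>
+  ///   - a sequential value is written in full at offset 0:
+  ///       vector.transfer_write %val, %buffer[%c0]
+  ///           : vector<32xf32>, memref<32xf32>
+  /// Non-vector values are stored with a plain memref.store at index 0.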
+  Operation *buildStore(RewriterBase &b, Location loc, Value val,
+                        Value buffer) {
+    assert((val == distributedVal || val == sequentialVal) &&
+           "Must store either the preregistered distributed or the "
+           "preregistered sequential value.");
+    if (val.getType().isa<VectorType>()) {
+      SmallVector<Value> indices(sequentialVectorType.getRank() - 1, zero);
+      if (val == distributedVal)
+        indices.push_back(buildDistributedOffset(b, loc));
+      else
+        indices.push_back(zero);
+      SmallVector<bool> inBounds(indices.size(), true);
+      return b.create<vector::TransferWriteOp>(
+          loc, val, buffer, indices,
+          ArrayRef<bool>(inBounds.begin(), inBounds.end()));
+    }
+    return b.create<memref::StoreOp>(loc, val, buffer, zero);
+  }
+
+  /// Load a value of `type` from `buffer`. In `broadcastMode`, a scalar is
+  /// loaded from position 0 and implicitly broadcast to all lanes.
+  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer,
+                  bool broadcastMode = false) {
+    if (broadcastMode)
+      return b.create<memref::LoadOp>(loc, buffer, zero);
+
+    assert(type.isa<VectorType>() && "must be a vector type!");
+    SmallVector<Value> indices(sequentialVectorType.getRank() - 1, zero);
+    assert((type == sequentialVectorType || type == distributedVectorType) &&
+           "Must load either the preregistered distributed or the "
+           "preregistered sequential type.");
+    if (type == distributedVectorType)
+      indices.push_back(buildDistributedOffset(b, loc));
+    else
+      indices.push_back(zero);
+    return b.create<vector::TransferReadOp>(loc, type.cast<VectorType>(),
+                                            buffer, indices);
+  }
+
+  Value sequentialVal, distributedVal, laneId, zero;
+  Type sequentialType, distributedType;
+  VectorType sequentialVectorType, distributedVectorType;
+  int64_t distributedTrailingSize = 0;
+};
+
+/// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the
+/// single thread with `laneId == 0` executes the entirety of the computation.
+///
+/// After the transformation:
+///   - the IR within the scf.if op can be thought of as executing sequentially
+///     (from the point of view of threads along `laneId`).
+///   - the IR outside of the scf.if op can be thought of as executing in
+///     parallel (from the point of view of threads along `laneId`).
+///
+/// Values that need to transit through the parallel / sequential and the
+/// sequential / parallel boundaries do so via reads and writes to a temporary
+/// memory location.
+///
+/// The transformation proceeds in multiple steps:
+///   1. Create the scf.if op.
+///   2. Insert appropriate (alloc, write)-pairs before the scf.if and reads
+///      within the scf.if to transit the values captured from above.
+///   3. Synchronize before the scf.if to ensure all writes inserted in 2. are
+///      consistent within the scf.if.
+///   4. Move the body of the WarpExecuteOnLane0Op inside the scf.if.
+///   5. Insert appropriate writes within the scf.if and reads after the
+///      scf.if to transit the values returned by the op.
+///   6. Synchronize after the scf.if to ensure all writes inserted in 5. are
+///      consistent after the scf.if.
+///   7. Perform late cleanups.
+///
+/// All this assumes the vector distribution occurs along the most minor
+/// distributed vector dimension.
+/// TODO: is the most minor dimension always expected to be a multiple of the
+/// warp size?
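+///
+/// For example, assuming a warp of size 32 (an illustrative sketch only; the
+/// actual allocation and synchronization ops are whatever the
+/// `warpAllocationFn` and `warpSyncronizationFn` hooks produce):
+///
+///   %r = vector.warp_execute_on_lane_0(%laneid)[32]
+///       args(%v : vector<1xf32>) -> (vector<1xf32>) { ... }
+///
+/// roughly becomes:
+///
+///   vector.transfer_write %v, %buf0[%laneid]   // distributed write
+///   <sync>
+///   %cond = arith.cmpi eq, %laneid, %c0
+///   scf.if %cond {
+///     %arg = vector.transfer_read %buf0[%c0]   // sequential read
+///     ...                                      // original body
+///     vector.transfer_write %res, %buf1[%c0]   // sequential write
+///   }
+///   <sync>
+///   %r = vector.transfer_read %buf1[%laneid]   // distributed read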
 static LogicalResult
 rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                       const WarpExecuteOnLane0LoweringOptions &options) {
@@ -32,7 +126,7 @@
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(warpOp);
 
-  // Create scf.if op.
+  // Step 1. Create the scf.if op.
   Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   Value isLane0 = rewriter.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::eq, warpOp.getLaneid(), c0);
@@ -40,93 +134,89 @@
                                      /*withElseRegion=*/false);
   rewriter.eraseOp(ifOp.thenBlock()->getTerminator());
 
-  // Store vectors that are defined outside of warpOp into the scratch pad
-  // buffer.
+  // Step 2. Insert appropriate (alloc, write)-pairs before the scf.if and
+  // reads within the scf.if to transit the values captured from above.
   SmallVector<Value> bbArgReplacements;
   for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
-    Value val = it.value();
-    Value bbArg = warpOpBody->getArgument(it.index());
-
-    rewriter.setInsertionPoint(ifOp);
-    Value buffer =
-        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());
+    Value sequentialVal = warpOpBody->getArgument(it.index());
+    Value distributedVal = it.value();
+    DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
+                                      warpOp.getLaneid(), c0);
 
-    // Store arg vector into buffer.
+    // Create the buffer before the ifOp.
     rewriter.setInsertionPoint(ifOp);
-    auto vectorType = val.getType().cast<VectorType>();
-    int64_t storeSize = vectorType.getShape()[0];
-    Value storeOffset = rewriter.create<arith::MulIOp>(
-        loc, warpOp.getLaneid(),
-        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
-    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);
-
-    // Load bbArg vector from buffer.
+    Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
+                                            sequentialVal.getType());
+    // Store the distributed vector into the buffer, before the ifOp.
+    helper.buildStore(rewriter, loc, distributedVal, buffer);
+    // Load the sequential vector from the buffer, inside the ifOp.
     rewriter.setInsertionPointToStart(ifOp.thenBlock());
-    auto bbArgType = bbArg.getType().cast<VectorType>();
-    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
-    bbArgReplacements.push_back(loadOp);
+    bbArgReplacements.push_back(
+        helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
   }
 
-  // Insert sync after all the stores and before all the loads.
+  // Step 3. Insert a sync after all the stores and before all the loads.
   if (!warpOp.getArgs().empty()) {
     rewriter.setInsertionPoint(ifOp);
     options.warpSyncronizationFn(loc, rewriter, warpOp);
   }
 
-  // Move body of warpOp to ifOp.
+  // Step 4. Move the body of warpOp to the ifOp.
   rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
 
-  // Rewrite terminator and compute replacements of WarpOp results.
+  // Step 5. Insert appropriate writes within the scf.if and reads after the
+  // scf.if to transit the values returned by the op.
+  // TODO: at this point, we can reuse the shared memory from previous buffers.
   SmallVector<Value> replacements;
   auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
   Location yieldLoc = yieldOp.getLoc();
   for (const auto &it : llvm::enumerate(yieldOp.operands())) {
-    Value val = it.value();
-    Type resultType = warpOp->getResultTypes()[it.index()];
+    Value sequentialVal = it.value();
+    Value distributedVal = warpOp->getResult(it.index());
+    DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
+                                      warpOp.getLaneid(), c0);
+
+    // Create the buffer before the ifOp.
     rewriter.setInsertionPoint(ifOp);
-    Value buffer =
-        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());
+    Value buffer = options.warpAllocationFn(loc, rewriter, warpOp,
+                                            sequentialVal.getType());
 
-    // Store yielded value into buffer.
+    // Store the yielded value into the buffer, inside the ifOp, before the
+    // terminator.
     rewriter.setInsertionPoint(yieldOp);
-    if (val.getType().isa<VectorType>())
-      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
-    else
-      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);
-
-    // Load value from buffer (after warpOp).
+    helper.buildStore(rewriter, loc, sequentialVal, buffer);
+
+    // Load the distributed value from the buffer, after the warpOp.
     rewriter.setInsertionPointAfter(ifOp);
-    if (resultType == val.getType()) {
-      // Result type and yielded value type are the same. This is a broadcast.
-      // E.g.:
-      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
-      //   vector.yield %cst : f32
-      // }
-      // Both types are f32. The constant %cst is broadcasted to all lanes.
-      // This is described in more detail in the documentation of the op.
-      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
-      replacements.push_back(loadOp);
-    } else {
-      auto loadedVectorType = resultType.cast<VectorType>();
-      int64_t loadSize = loadedVectorType.getShape()[0];
-
-      // loadOffset = laneid * loadSize
-      Value loadOffset = rewriter.create<arith::MulIOp>(
-          loc, warpOp.getLaneid(),
-          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
-      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
-                                                     buffer, loadOffset);
-      replacements.push_back(loadOp);
-    }
+    bool broadcastMode = (sequentialVal.getType() == distributedVal.getType());
+    // When the result type and the yielded value type are the same, this is a
+    // broadcast. E.g.:
+    // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
+    //   vector.yield %cst : f32
+    // }
+    // Both types are f32. The constant %cst is broadcasted to all lanes.
+    // This is described in more detail in the documentation of the op.
+    replacements.push_back(helper.buildLoad(
+        rewriter, loc, distributedVal.getType(), buffer, broadcastMode));
   }
 
-  // Insert sync after all the stores and before all the loads.
+  // Step 6. Insert a sync after all the stores and before all the loads.
   if (!yieldOp.operands().empty()) {
     rewriter.setInsertionPointAfter(ifOp);
     options.warpSyncronizationFn(loc, rewriter, warpOp);
   }
 
-  // Delete terminator and add empty scf.yield.
+  // Step 7. Delete the terminator and add an empty scf.yield.
   rewriter.eraseOp(yieldOp);
   rewriter.setInsertionPointToEnd(ifOp.thenBlock());
   rewriter.create<scf::YieldOp>(yieldLoc);
@@ -261,6 +351,20 @@
 
 namespace {
 
+/// Pattern to make the body of a WarpExecuteOnLane0Op independent of the
+/// laneId. It is crucial to apply this pattern with a higher benefit, before
+/// any hoisting or distribution pattern is applied.
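+///
+/// An illustrative sketch of the effect: the region conceptually executes on
+/// lane 0 only, so uses of `%laneid` nested inside the region can be replaced
+/// by the constant 0:
+///
+///   vector.warp_execute_on_lane_0(%laneid)[32] {
+///     memref.store %cst, %buf[%laneid] : memref<32xf32>
+///   }
+///
+/// becomes:
+///
+///   %c0 = arith.constant 0 : index
+///   vector.warp_execute_on_lane_0(%laneid)[32] {
+///     memref.store %cst, %buf[%c0] : memref<32xf32>
+///   }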
+struct ReplaceAllUsesOfLaneWithinWarpExecuteOnLane0
+    : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  ReplaceAllUsesOfLaneWithinWarpExecuteOnLane0(MLIRContext *context)
+      : OpRewritePattern<WarpExecuteOnLane0Op>(context, /*benefit=*/10) {}
+
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    return warpOp.replaceAllUsesOfLaneWithin(rewriter);
+  }
+};
+
 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
   WarpOpToScfForPattern(MLIRContext *context,
                         const WarpExecuteOnLane0LoweringOptions &options,
@@ -931,24 +1035,32 @@
 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
     RewritePatternSet &patterns,
     const WarpExecuteOnLane0LoweringOptions &options) {
+  patterns.add<ReplaceAllUsesOfLaneWithinWarpExecuteOnLane0>(
+      patterns.getContext());
   patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
 }
 
 void mlir::vector::populateDistributeTransferWriteOpPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) {
+  patterns.add<ReplaceAllUsesOfLaneWithinWarpExecuteOnLane0>(
+      patterns.getContext());
   patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
 }
 
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
-      patterns.getContext());
+  patterns
+      .add<ReplaceAllUsesOfLaneWithinWarpExecuteOnLane0, WarpOpElementwise,
+           WarpOpTransferRead, WarpOpDeadResult, WarpOpBroadcast,
+           WarpOpForwardOperand, WarpOpScfForOp>(patterns.getContext());
 }
 
 void mlir::vector::populateDistributeReduction(
     RewritePatternSet &patterns,
     DistributedReductionFn distributedReductionFn) {
+  patterns.add<ReplaceAllUsesOfLaneWithinWarpExecuteOnLane0>(
+      patterns.getContext());
   patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
 }
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -631,3 +631,21 @@
   }
   return %r : f32
 }
+
+// -----
+
+// CHECK-PROP-LABEL: func @lane_dependent_warp_propagate_read
+//  CHECK-PROP-SAME:     %[[ID:[a-zA-Z0-9]+]]: index
+func.func @lane_dependent_warp_propagate_read(%laneid: index, %src: memref<1x1024xf32>, %dest: memref<1x1024xf32>) {
+  // CHECK-PROP-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-PROP-NOT: warp_execute_on_lane_0
+  // CHECK-PROP-DAG: %[[R0:.*]] = vector.transfer_read %arg1[%[[C0]], %[[ID]]], %{{.*}} : memref<1x1024xf32>, vector<1x1xf32>
+  // CHECK-PROP: vector.transfer_write %[[R0]], {{.*}} : vector<1x1xf32>, memref<1x1024xf32>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x1xf32>) {
+    %2 = vector.transfer_read %src[%laneid, %c0], %cst : memref<1x1024xf32>, vector<1x32xf32>
+    vector.yield %2 : vector<1x32xf32>
+  }
+  vector.transfer_write %r, %dest[%c0, %laneid] : vector<1x1xf32>, memref<1x1024xf32>
+  return
+}