Index: mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -20,20 +20,33 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-/// TODO: add an analysis step that determines which vector dimension should be
-/// used for distribution.
-static llvm::Optional<int64_t>
-getDistributedVectorDim(VectorType distributedVectorType) {
-  return (distributedVectorType)
-             ? llvm::Optional<int64_t>(distributedVectorType.getRank() - 1)
-             : llvm::None;
-}
-
-static llvm::Optional<int64_t>
-getDistributedSize(VectorType distributedVectorType) {
-  auto dim = getDistributedVectorDim(distributedVectorType);
-  return (dim) ? llvm::Optional<int64_t>(distributedVectorType.getDimSize(*dim))
-               : llvm::None;
+/// Currently the distribution map is implicit based on the vector shape. In the
+/// future it will be part of the op.
+/// Example:
+/// ```
+/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
+///   ...
+///   vector.yield %3 : vector<32x16x64xf32>
+/// }
+/// ```
+/// Would have an implicit map of:
+/// `(d0, d1, d2) -> (d0, d2)`
+static AffineMap calculateImplicitMap(VectorType sequentialType,
+                                      VectorType distributedType) {
+  SmallVector<AffineExpr> perm;
+  perm.reserve(1);
+  // Check which dimensions of the sequential type are different than the
+  // dimensions of the distributed type to know the distributed dimensions. Then
+  // associate each distributed dimension to an ID in order.
+  for (unsigned i = 0, e = sequentialType.getRank(); i < e; i++) {
+    if (sequentialType.getDimSize(i) != distributedType.getDimSize(i))
+      perm.push_back(getAffineDimExpr(i, distributedType.getContext()));
+  }
+  auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
+                            distributedType.getContext());
+  assert(map.getNumResults() <= 1 &&
+         "only support distribution along one dimension for now.");
+  return map;
 }
 
 namespace {
@@ -42,28 +55,23 @@
 /// through the parallel / sequential and the sequential / parallel boundaries
 /// when performing `rewriteWarpOpToScfFor`.
 ///
-/// All this assumes the vector distribution occurs along the most minor
-/// distributed vector dimension.
-/// TODO: which is expected to be a multiple of the warp size ?
-/// TODO: add an analysis step that determines which vector dimension should
-/// be used for distribution.
+/// The vector distribution dimension is inferred from the vector types.
 struct DistributedLoadStoreHelper {
   DistributedLoadStoreHelper(Value sequentialVal, Value distributedVal,
                              Value laneId, Value zero)
       : sequentialVal(sequentialVal), distributedVal(distributedVal),
         laneId(laneId), zero(zero) {
-    sequentialType = sequentialVal.getType();
-    distributedType = distributedVal.getType();
-    sequentialVectorType = sequentialType.dyn_cast<VectorType>();
-    distributedVectorType = distributedType.dyn_cast<VectorType>();
+    sequentialVectorType = sequentialVal.getType().dyn_cast<VectorType>();
+    distributedVectorType = distributedVal.getType().dyn_cast<VectorType>();
+    if (sequentialVectorType && distributedVectorType)
+      distributionMap =
+          calculateImplicitMap(sequentialVectorType, distributedVectorType);
   }
 
-  Value buildDistributedOffset(RewriterBase &b, Location loc) {
-    auto maybeDistributedSize = getDistributedSize(distributedVectorType);
-    assert(maybeDistributedSize &&
-           "at this point, a distributed size must be determined");
+  Value buildDistributedOffset(RewriterBase &b, Location loc, int64_t index) {
+    int64_t distributedSize = distributedVectorType.getDimSize(index);
     AffineExpr tid = getAffineSymbolExpr(0, b.getContext());
-    return b.createOrFold<AffineApplyOp>(loc, tid * (*maybeDistributedSize),
+    return b.createOrFold<AffineApplyOp>(loc, tid * distributedSize,
                                          ArrayRef<Value>{laneId});
   }
 
@@ -79,27 +87,24 @@
     assert((val == distributedVal || val == sequentialVal) &&
            "Must store either the preregistered distributed or the "
           "preregistered sequential value.");
+    // Scalar case can directly use memref.store.
+    if (!val.getType().isa<VectorType>())
+      return b.create<memref::StoreOp>(loc, val, buffer, zero);
+
     // Vector case must use vector::TransferWriteOp which will later lower to
     // vector.store of memref.store depending on further lowerings.
-    if (val.getType().isa<VectorType>()) {
-      int64_t rank = sequentialVectorType.getRank();
-      if (rank == 0) {
-        return b.create<vector::TransferWriteOp>(loc, val, buffer, ValueRange{},
-                                                 ArrayRef<bool>{});
+    int64_t rank = sequentialVectorType.getRank();
+    SmallVector<Value> indices(rank, zero);
+    if (val == distributedVal) {
+      for (auto dimExpr : distributionMap.getResults()) {
+        int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
+        indices[index] = buildDistributedOffset(b, loc, index);
       }
-      SmallVector<Value> indices(rank, zero);
-      auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
-      assert(maybeDistributedDim && "must be able to deduce distributed dim");
-      if (val == distributedVal)
-        indices[*maybeDistributedDim] =
-            (val == distributedVal) ? buildDistributedOffset(b, loc) : zero;
-      SmallVector<bool> inBounds(indices.size(), true);
-      return b.create<vector::TransferWriteOp>(
-          loc, val, buffer, indices,
-          ArrayRef<bool>(inBounds.begin(), inBounds.end()));
     }
-    // Scalar case can directly use memref.store.
-    return b.create<memref::StoreOp>(loc, val, buffer, zero);
+    SmallVector<bool> inBounds(indices.size(), true);
+    return b.create<vector::TransferWriteOp>(
+        loc, val, buffer, indices,
+        ArrayRef<bool>(inBounds.begin(), inBounds.end()));
   }
 
   /// Create a load during the process of distributing the
@@ -122,36 +127,24 @@
   ///   // Both types are f32. The constant %cst is broadcasted to all lanes.
   /// ```
   /// This behavior described in more detail in the documentation of the op.
-  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer,
-                  bool broadcastMode = false) {
-    if (broadcastMode) {
-      // Broadcast mode may occur for either scalar or vector operands.
-      auto vectorType = type.dyn_cast<VectorType>();
-      auto shape = buffer.getType().cast<MemRefType>();
-      if (vectorType) {
-        SmallVector<bool> inBounds(shape.getRank(), true);
-        return b.create<vector::TransferReadOp>(
-            loc, vectorType, buffer,
-            /*indices=*/SmallVector<Value>(shape.getRank(), zero),
-            ArrayRef<bool>(inBounds.begin(), inBounds.end()));
-      }
+  Value buildLoad(RewriterBase &b, Location loc, Type type, Value buffer) {
+
+    // Scalar case can directly use memref.load.
+    if (!type.isa<VectorType>())
       return b.create<memref::LoadOp>(loc, buffer, zero);
-    }
 
     // Other cases must be vector atm.
     // Vector case must use vector::TransferReadOp which will later lower to
     // vector.read of memref.read depending on further lowerings.
-    assert(type.isa<VectorType>() && "must be a vector type");
     assert((type == distributedVectorType || type == sequentialVectorType) &&
            "Must store either the preregistered distributed or the "
            "preregistered sequential type.");
-    auto maybeDistributedDim = getDistributedVectorDim(distributedVectorType);
-    assert(maybeDistributedDim && "must be able to deduce distributed dim");
     SmallVector<Value> indices(sequentialVectorType.getRank(), zero);
     if (type == distributedVectorType) {
-      indices[*maybeDistributedDim] = buildDistributedOffset(b, loc);
-    } else {
-      indices[*maybeDistributedDim] = zero;
+      for (auto dimExpr : distributionMap.getResults()) {
+        int64_t index = dimExpr.cast<AffineDimExpr>().getPosition();
+        indices[index] = buildDistributedOffset(b, loc, index);
+      }
     }
     SmallVector<bool> inBounds(indices.size(), true);
     return b.create<vector::TransferReadOp>(
@@ -160,8 +153,8 @@
   }
 
   Value sequentialVal, distributedVal, laneId, zero;
-  Type sequentialType, distributedType;
   VectorType sequentialVectorType, distributedVectorType;
+  AffineMap distributionMap;
 };
 
 } // namespace
@@ -262,32 +255,6 @@
   return rewriter.create(res);
 }
 
-/// Currently the distribution map is implicit based on the vector shape. In the
-/// future it will be part of the op.
-/// Example:
-/// ```
-/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
-///   ...
-///   vector.yield %3 : vector<32x16x64xf32>
-/// }
-/// ```
-/// Would have an implicit map of:
-/// `(d0, d1, d2) -> (d0, d2)`
-static AffineMap calculateImplicitMap(Value yield, Value ret) {
-  auto srcType = yield.getType().cast<VectorType>();
-  auto dstType = ret.getType().cast<VectorType>();
-  SmallVector<AffineExpr> perm;
-  // Check which dimensions of the yield value are different than the dimensions
-  // of the result to know the distributed dimensions. Then associate each
-  // distributed dimension to an ID in order.
-  for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
-    if (srcType.getDimSize(i) != dstType.getDimSize(i))
-      perm.push_back(getAffineDimExpr(i, yield.getContext()));
-  }
-  auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
-  return map;
-}
-
 namespace {
 
 /// Rewrite a WarpExecuteOnLane0Op into a predicated scf.if op where the single
@@ -318,13 +285,10 @@
 ///
 /// All this assumes the vector distribution occurs along the most minor
 /// distributed vector dimension.
-/// TODO: which is expected to be a multiple of the warp size ?
-/// TODO: add an analysis step that determines which vector dimension should be
-/// used for distribution.
-struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  WarpOpToScfForPattern(MLIRContext *context,
-                        const WarpExecuteOnLane0LoweringOptions &options,
-                        PatternBenefit benefit = 1)
+struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  WarpOpToScfIfPattern(MLIRContext *context,
+                       const WarpExecuteOnLane0LoweringOptions &options,
+                       PatternBenefit benefit = 1)
       : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
         options(options) {}
 
@@ -364,10 +328,8 @@
       helper.buildStore(rewriter, loc, distributedVal, buffer);
       // Load sequential vector from buffer, inside the ifOp.
       rewriter.setInsertionPointToStart(ifOp.thenBlock());
-      bool broadcastMode =
-          (sequentialVal.getType() == distributedVal.getType());
-      bbArgReplacements.push_back(helper.buildLoad(
-          rewriter, loc, sequentialVal.getType(), buffer, broadcastMode));
+      bbArgReplacements.push_back(
+          helper.buildLoad(rewriter, loc, sequentialVal.getType(), buffer));
     }
 
     // Step 3. Insert sync after all the stores and before all the loads.
@@ -404,8 +366,6 @@
       // Load distributed value from buffer, after the warpOp.
       rewriter.setInsertionPointAfter(ifOp);
-      bool broadcastMode =
-          (sequentialVal.getType() == distributedVal.getType());
      // Result type and yielded value type are the same. This is a broadcast.
       // E.g.:
       // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
@@ -413,8 +373,8 @@
       // }
       // Both types are f32. The constant %cst is broadcasted to all lanes.
       // This is described in more detail in the documentation of the op.
-      replacements.push_back(helper.buildLoad(
-          rewriter, loc, distributedVal.getType(), buffer, broadcastMode));
+      replacements.push_back(
+          helper.buildLoad(rewriter, loc, distributedVal.getType(), buffer));
     }
 
     // Step 6. Insert sync after all the stores and before all the loads.
@@ -758,7 +718,9 @@
 
     SmallVector<Value, 4> indices(read.getIndices().begin(),
                                   read.getIndices().end());
-    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
+    auto sequentialType = read.getResult().getType().cast<VectorType>();
+    auto distributedType = distributedVal.getType().cast<VectorType>();
+    AffineMap map = calculateImplicitMap(sequentialType, distributedType);
     AffineMap indexMap = map.compose(read.getPermutationMap());
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(warpOp);
@@ -1118,7 +1080,7 @@
 void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
     RewritePatternSet &patterns,
     const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit) {
-  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options, benefit);
+  patterns.add<WarpOpToScfIfPattern>(patterns.getContext(), options, benefit);
 }
 
 void mlir::vector::populateDistributeTransferWriteOpPatterns(
Index: mlir/test/Dialect/Vector/vector-warp-distribute.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -734,3 +734,45 @@
 //  CHECK-SCF-IF: return %[[RS0]], %[[RV0]], %[[RV1]], %[[RV2]] : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
   return %r#0, %r#1, %r#2, %r#3 : f32, vector<f32>, vector<1xf32>, vector<1x1xf32>
 }
+
+// -----
+
+// CHECK-SCF-IF-DAG: #[[$TIMES2:.*]] = affine_map<()[s0] -> (s0 * 2)>
+
+// CHECK-SCF-IF: func @warp_execute_nd_distribute
+// CHECK-SCF-IF-SAME: (%[[LANEID:.*]]: index
+func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, %v1: vector<1x2x128xf32>)
+    -> (vector<1x64x1xf32>, vector<1x2x128xf32>) {
+  // CHECK-SCF-IF-DAG: %[[C0:.*]] = arith.constant 0 : index
+
+  // CHECK-SCF-IF: vector.transfer_write %{{.*}}, %{{.*}}[%[[LANEID]], %c0, %c0] {in_bounds = [true, true, true]} : vector<1x64x1xf32>, memref<32x64x1xf32, 3>
+  // CHECK-SCF-IF: %[[RID:.*]] = affine.apply #[[$TIMES2]]()[%[[LANEID]]]
+  // CHECK-SCF-IF: vector.transfer_write %{{.*}}, %{{.*}}[%[[C0]], %[[RID]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x2x128xf32>, memref<1x64x128xf32, 3>
+  // CHECK-SCF-IF: gpu.barrier
+
+  // CHECK-SCF-IF: scf.if{{.*}}{
+  %r:2 = vector.warp_execute_on_lane_0(%laneid)[32]
+      args(%v0, %v1 : vector<1x64x1xf32>, vector<1x2x128xf32>) -> (vector<1x64x1xf32>, vector<1x2x128xf32>) {
+    ^bb0(%arg0: vector<32x64x1xf32>, %arg1: vector<1x64x128xf32>):
+
+    // CHECK-SCF-IF-DAG: %[[SR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<32x64x1xf32>
+    // CHECK-SCF-IF-DAG: %[[SR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x64x128xf32>
+    // CHECK-SCF-IF: %[[W0:.*]] = "some_def_0"(%[[SR0]]) : (vector<32x64x1xf32>) -> vector<32x64x1xf32>
+    // CHECK-SCF-IF: %[[W1:.*]] = "some_def_1"(%[[SR1]]) : (vector<1x64x128xf32>) -> vector<1x64x128xf32>
+    // CHECK-SCF-IF-DAG: vector.transfer_write %[[W0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<32x64x1xf32>, memref<32x64x1xf32, 3>
+    // CHECK-SCF-IF-DAG: vector.transfer_write %[[W1]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x64x128xf32>, memref<1x64x128xf32, 3>
+
+    %r0 = "some_def_0"(%arg0) : (vector<32x64x1xf32>) -> vector<32x64x1xf32>
+    %r1 = "some_def_1"(%arg1) : (vector<1x64x128xf32>) -> vector<1x64x128xf32>
+
+    // CHECK-SCF-IF-NOT: vector.yield
+    vector.yield %r0, %r1 : vector<32x64x1xf32>, vector<1x64x128xf32>
+  }
+
+  // CHECK-SCF-IF: gpu.barrier
+  // CHECK-SCF-IF: %[[WID:.*]] = affine.apply #[[$TIMES2]]()[%[[LANEID]]]
+  // CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32>
+  // CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32>
+  // CHECK-SCF-IF: return %[[R0]], %[[R1]] : vector<1x64x1xf32>, vector<1x2x128xf32>
+  return %r#0, %r#1 : vector<1x64x1xf32>, vector<1x2x128xf32>
+}
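
Editor's note (not part of the patch): the new test exercises the implicit-map logic introduced above. For the second operand, the sequential type vector<1x64x128xf32> and the distributed type vector<1x2x128xf32> differ only in dimension 1 (64 vs. 2), so calculateImplicitMap produces `(d0, d1, d2) -> (d1)` and buildDistributedOffset places each lane at `laneId * 2` along that dimension, which is exactly the `affine_map<()[s0] -> (s0 * 2)>` checked by #[[$TIMES2]]. The standalone C++ sketch below mirrors that computation with plain integers for illustration only; it does not use the MLIR API, and the shape values are taken from the test.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the shape comparison done by calculateImplicitMap: a dimension is
// "distributed" when the sequential and distributed extents differ. The
// per-lane offset along it is laneId * distributedExtent, as in
// buildDistributedOffset.
static int findDistributedDim(const std::vector<int64_t> &sequentialShape,
                              const std::vector<int64_t> &distributedShape) {
  int dim = -1;
  for (size_t i = 0; i < sequentialShape.size(); ++i) {
    if (sequentialShape[i] != distributedShape[i]) {
      assert(dim == -1 && "only one distributed dimension is supported");
      dim = static_cast<int>(i);
    }
  }
  return dim;
}

int main() {
  // Shapes from the warp_execute_nd_distribute test: vector<1x64x128xf32> is
  // yielded sequentially, vector<1x2x128xf32> is seen by each of the 32 lanes.
  std::vector<int64_t> sequentialShape = {1, 64, 128};
  std::vector<int64_t> distributedShape = {1, 2, 128};

  int dim = findDistributedDim(sequentialShape, distributedShape);
  std::printf("implicit map: (d0, d1, d2) -> (d%d)\n", dim);

  // Offset of lane 5 along the distributed dimension, i.e. laneId * 2,
  // matching affine_map<()[s0] -> (s0 * 2)> in the CHECK lines.
  int64_t laneId = 5;
  int64_t size = distributedShape[dim];
  std::printf("lane %lld covers rows [%lld, %lld)\n", (long long)laneId,
              (long long)(laneId * size), (long long)((laneId + 1) * size));
  return 0;
}
```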