diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5752,22 +5752,26 @@
       expandedVecType.getElementType() != distributedVecType.getElementType())
     return op->emitOpError(
         "expected distributed vectors to have same rank and element type.");
-  bool foundDistributedDim = false;
+
+  SmallVector<int64_t> scales(expandedVecType.getRank(), 1);
   for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
-    if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
-      continue;
-    if (expandedVecType.getDimSize(i) ==
-        distributedVecType.getDimSize(i) * warpSize) {
-      if (foundDistributedDim)
-        return op->emitOpError()
-               << "expected only one dimension to be distributed from "
-               << expandedVecType << " to " << distributedVecType;
-      foundDistributedDim = true;
+    int64_t eDim = expandedVecType.getDimSize(i);
+    int64_t dDim = distributedVecType.getDimSize(i);
+    if (eDim == dDim)
       continue;
-    }
-    return op->emitOpError() << "incompatible distribution dimensions from "
-                             << expandedVecType << " to " << distributedVecType;
+    if (eDim % dDim != 0)
+      return op->emitOpError()
+             << "expected expanded vector dimension #" << i << " (" << eDim
+             << ") to be a multiple of the distributed vector dimension ("
+             << dDim << ")";
+    scales[i] = eDim / dDim;
   }
+  if (std::accumulate(scales.begin(), scales.end(), 1,
+                      std::multiplies<int64_t>()) != warpSize)
+    return op->emitOpError()
+           << "incompatible distribution dimensions from " << expandedVecType
+           << " to " << distributedVecType << " with warp size = " << warpSize;
+
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/SetVector.h"
+#include <numeric>
 #include <utility>
 
 using namespace mlir;
@@ -45,8 +46,6 @@
   }
   auto map = AffineMap::get(sequentialType.getRank(), 0, perm,
                             distributedType.getContext());
-  assert(map.getNumResults() <= 1 &&
-         "only support distribution along one dimension for now.");
   return map;
 }
 
@@ -702,6 +701,48 @@
   }
 };
 
+/// Delinearize the given `laneId` into multiple dimensions, where each
+/// dimension's size is determined by `originalShape` and `distributedShape`
+/// together. This function expects the total number of threads needed for
+/// distribution to be equal to `warpSize`. Returns true and updates
+/// `delinearizedIds` if so.
+bool delinearizeLaneId(OpBuilder &builder, Location loc,
+                       ArrayRef<int64_t> originalShape,
+                       ArrayRef<int64_t> distributedShape, int64_t warpSize,
+                       Value laneId, SmallVectorImpl<Value> &delinearizedIds) {
+  SmallVector<int64_t> sizes;
+  for (auto [large, small] : llvm::zip_equal(originalShape, distributedShape)) {
+    if (large % small != 0)
+      return false;
+    sizes.push_back(large / small);
+  }
+  if (std::accumulate(sizes.begin(), sizes.end(), 1,
+                      std::multiplies<int64_t>()) != warpSize)
+    return false;
+
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+
+  int64_t usedThreads = 1;
+
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  delinearizedIds.assign(sizes.size(), zero);
+
+  for (int i = sizes.size() - 1; i >= 0; --i) {
+    if ((usedThreads *= sizes[i]) == warpSize) {
+      // We've used up all available threads. Don't need to perform modulo
+      // anymore. And we can stop the calculation for further dimensions.
+      delinearizedIds[i] = laneId;
+      break;
+    }
+    delinearizedIds[i] =
+        affine::makeComposedAffineApply(builder, loc, s0 % sizes[i], {laneId});
+    laneId = affine::makeComposedAffineApply(
+        builder, loc, s0.floorDiv(usedThreads), {laneId});
+  }
+  return true;
+}
+
 /// Sink out transfer_read op feeding into a warp op yield.
 /// ```
 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
@@ -743,6 +784,16 @@
     AffineMap indexMap = map.compose(read.getPermutationMap());
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(warpOp);
+
+    // Try to delinearize the lane ID to match the rank expected for
+    // distribution.
+    SmallVector<Value> delinearizedIds;
+    if (!delinearizeLaneId(rewriter, read.getLoc(), sequentialType.getShape(),
+                           distributedType.getShape(), warpOp.getWarpSize(),
+                           warpOp.getLaneid(), delinearizedIds))
+      return rewriter.notifyMatchFailure(
+          read, "cannot delinearize lane ID for distribution");
+
     for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
       AffineExpr d0, d1;
       bindDims(read.getContext(), d0, d1);
@@ -751,11 +802,10 @@
         continue;
       unsigned indexPos = indexExpr.getPosition();
       unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
-      int64_t scale =
-          cast<VectorType>(distributedVal.getType()).getDimSize(vectorPos);
+      int64_t scale = distributedType.getDimSize(vectorPos);
       indices[indexPos] = affine::makeComposedAffineApply(
           rewriter, read.getLoc(), d0 + scale * d1,
-          {indices[indexPos], warpOp.getLaneid()});
+          {indices[indexPos], delinearizedIds[vectorPos]});
     }
     auto newRead = rewriter.create<vector::TransferReadOp>(
         read.getLoc(), distributedVal.getType(), read.getSource(), indices,
@@ -918,6 +968,48 @@
   }
 };
 
+/// Pattern to move a shape cast out of the warp op. A shape cast is basically
+/// a no-op for warp distribution; we only need to handle the shape.
+struct WarpOpShapeCast : public OpRewritePattern<WarpExecuteOnLane0Op> {
+  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(
+        warpOp, [](Operation *op) { return isa<vector::ShapeCastOp>(op); });
+    if (!operand)
+      return failure();
+    auto oldCastOp = operand->get().getDefiningOp<vector::ShapeCastOp>();
+
+    unsigned int operandNumber = operand->getOperandNumber();
+    auto castDistributedType =
+        cast<VectorType>(warpOp->getResultTypes()[operandNumber]);
+    VectorType castOriginalType = oldCastOp.getSourceVectorType();
+    VectorType castResultType = castDistributedType;
+
+    // We expect the distributed type to have a smaller rank than the original
+    // type. Prepend with size-one dimensions to make them the same.
+    unsigned castDistributedRank = castDistributedType.getRank();
+    unsigned castOriginalRank = castOriginalType.getRank();
+    if (castDistributedRank < castOriginalRank) {
+      SmallVector<int64_t> shape(castOriginalRank - castDistributedRank, 1);
+      llvm::append_range(shape, castDistributedType.getShape());
+      castDistributedType =
+          VectorType::get(shape, castDistributedType.getElementType());
+    }
+
+    SmallVector<size_t> newRetIndices;
+    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {oldCastOp.getSource()}, {castDistributedType},
+        newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value newCast = rewriter.create<vector::ShapeCastOp>(
+        oldCastOp.getLoc(), castResultType,
+        newWarpOp->getResult(newRetIndices[0]));
+    rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newCast);
+    return success();
+  }
+};
+
 /// Pattern to move out vector.extract of single element vector. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
@@ -1557,9 +1649,9 @@
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
     const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
-               WarpOpConstant, WarpOpInsertElement, WarpOpInsert>(
-      patterns.getContext(), benefit);
+               WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract,
+               WarpOpForwardOperand, WarpOpConstant, WarpOpInsertElement,
+               WarpOpInsert>(patterns.getContext(), benefit);
   patterns.add<WarpOpExtractElement>(patterns.getContext(),
                                      warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn,
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1593,7 +1593,7 @@
 // -----
 
 func.func @warp_2_distributed_dims(%laneid: index) {
-  // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected only one dimension to be distributed from 'vector<128x128xi32>' to 'vector<4x4xi32>'}}
+  // expected-error@+1 {{incompatible distribution dimensions from 'vector<128x128xi32>' to 'vector<4x4xi32>' with warp size = 32}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
     %0 = arith.constant dense<2>: vector<128x128xi32>
     vector.yield %0 : vector<128x128xi32>
@@ -1603,6 +1603,17 @@
 
 // -----
 
+func.func @warp_2_distributed_dims(%laneid: index) {
+  // expected-error@+1 {{expected expanded vector dimension #1 (8) to be a multiple of the distributed vector dimension (3)}}
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x3xi32>) {
+    %0 = arith.constant dense<2>: vector<4x8xi32>
+    vector.yield %0 : vector<4x8xi32>
+  }
+  return
+}
+
+// -----
+
 func.func @warp_mismatch_rank(%laneid: index) {
   // expected-error@+1 {{'vector.warp_execute_on_lane_0' op expected distributed vectors to have same rank and element type.}}
   %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<4x4xi32>) {
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -849,6 +849,17 @@
   return
 }
 
+// CHECK-LABEL: func.func @warp_execute_on_lane_0_2d
+func.func @warp_execute_on_lane_0_2d(%laneid: index) {
+  // CHECK: vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1x4xi32>)
+  %2 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x4xi32>) {
+    %0 = arith.constant dense<2>: vector<4x32xi32>
+    // CHECK: vector.yield %{{.+}} : vector<4x32xi32>
+    vector.yield %0 : vector<4x32xi32>
+  }
+  return
+}
+
 // CHECK-LABEL: func @warp_operand_result(
 func.func @warp_operand_result(%laneid: index, %v0 : vector<4xi32>) -> (vector<4xi32>) {
   // CHECK-NEXT:  %{{.*}} = vector.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xi32>) -> (vector<4xi32>) {
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -827,6 +827,50 @@
 
 // -----
 
+func.func @warp_propagate_read_3d(%laneid: index, %src: memref<32x4x32xf32>) -> vector<1x1x4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[1024] -> (vector<1x1x4xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0, %c0], %cst : memref<32x4x32xf32>, vector<32x4x32xf32>
+    vector.yield %2 : vector<32x4x32xf32>
+  }
+  return %r : vector<1x1x4xf32>
+}
+
+// CHECK-PROP-DAG: #[[$ID0MAP:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK-PROP-DAG: #[[$ID1MAP:.+]] = affine_map<()[s0] -> ((s0 floordiv 8) mod 4)>
+// CHECK-PROP-DAG: #[[$ID2MAP:.+]] = affine_map<()[s0] -> ((s0 floordiv 8) floordiv 32)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_read_3d
+// CHECK-PROP-SAME: (%[[LANE:.+]]: index, %[[SRC:.+]]: memref<32x4x32xf32>)
+// CHECK-PROP-DAG: %[[ID0:.+]] = affine.apply #[[$ID0MAP]]()[%[[LANE]]]
+// CHECK-PROP-DAG: %[[ID1:.+]] = affine.apply #[[$ID1MAP]]()[%[[LANE]]]
+// CHECK-PROP-DAG: %[[ID2:.+]] = affine.apply #[[$ID2MAP]]()[%[[LANE]]]
+// CHECK-PROP: %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[ID2]], %[[ID1]], %[[ID0]]], %{{.+}} : memref<32x4x32xf32>, vector<1x1x4xf32>
+// CHECK-PROP: return %[[READ]] : vector<1x1x4xf32>
+
+// -----
+
+func.func @warp_propagate_read_broadcast(%laneid: index, %src: memref<32x1xf32>) -> vector<1x4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[512] -> (vector<1x4xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d0, 0)>} : memref<32x1xf32>, vector<32x64xf32>
+    vector.yield %2 : vector<32x64xf32>
+  }
+  return %r : vector<1x4xf32>
+}
+
+// CHECK-PROP-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 16)>
+// CHECK-PROP-DAG: #[[$READMAP:.+]] = affine_map<(d0, d1) -> (d0, 0)>
+// CHECK-PROP-LABEL: func.func @warp_propagate_read_broadcast
+// CHECK-PROP-SAME: (%[[LANE:.+]]: index, %[[SRC:.+]]: memref<32x1xf32>)
+// CHECK-PROP: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-PROP: %[[ID:.+]] = affine.apply #[[$MAP]]()[%[[LANE]]]
+// CHECK-PROP: %[[READ:.+]] = vector.transfer_read %[[SRC]][%[[ID]], %[[C0]]], %{{.+}} {in_bounds = [true, true], permutation_map = #[[$READMAP]]} : memref<32x1xf32>, vector<1x4xf32>
+// CHECK-PROP: return %[[READ]] : vector<1x4xf32>
+
+// -----
+
 // CHECK-PROP: func @dont_duplicate_read
 func.func @dont_duplicate_read(
   %laneid: index, %src: memref<1024xf32>) -> vector<1xf32> {
@@ -1173,3 +1217,22 @@
   vector.print %r : vector<1x2xf32>
   return
 }
+
+// -----
+
+func.func @warp_propagate_shape_cast(%laneid: index, %src: memref<32x4x32xf32>) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %r = vector.warp_execute_on_lane_0(%laneid)[1024] -> (vector<4xf32>) {
+    %2 = vector.transfer_read %src[%c0, %c0, %c0], %cst : memref<32x4x32xf32>, vector<32x4x32xf32>
+    %3 = vector.shape_cast %2 : vector<32x4x32xf32> to vector<4096xf32>
+    vector.yield %3 : vector<4096xf32>
+  }
+  return %r : vector<4xf32>
+}
+
+// CHECK-PROP-LABEL: func.func @warp_propagate_shape_cast
+// CHECK-PROP: %[[READ:.+]] = vector.transfer_read {{.+}} : memref<32x4x32xf32>, vector<1x1x4xf32>
+// CHECK-PROP: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK-PROP: return %[[CAST]] : vector<4xf32>
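For reference, the distribution rule the updated verifier enforces can be stated as: every expanded vector dimension must be a multiple of the matching distributed dimension, and the per-dimension ratios must multiply to the warp size. Below is a minimal standalone C++ sketch of that check, outside of MLIR; the helper name isValidDistribution and the use of std::vector are illustrative assumptions, not part of the patch, and the sample shapes are taken from the tests above.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative helper (not part of the patch): returns true if a vector with
// shape `expanded` can be distributed over `warpSize` lanes so that each lane
// holds a vector with shape `distributed`.
static bool isValidDistribution(const std::vector<int64_t> &expanded,
                                const std::vector<int64_t> &distributed,
                                int64_t warpSize) {
  int64_t product = 1;
  for (size_t i = 0; i < expanded.size(); ++i) {
    if (expanded[i] % distributed[i] != 0)
      return false; // e.g. 8 is not a multiple of 3.
    product *= expanded[i] / distributed[i];
  }
  return product == warpSize;
}

int main() {
  // vector<32x4x32xf32> -> vector<1x1x4xf32> over 1024 lanes: 32 * 4 * 8 == 1024.
  std::cout << isValidDistribution({32, 4, 32}, {1, 1, 4}, 1024) << "\n"; // 1
  // vector<128x128xi32> -> vector<4x4xi32> over 32 lanes: 32 * 32 != 32.
  std::cout << isValidDistribution({128, 128}, {4, 4}, 32) << "\n"; // 0
  return 0;
}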