diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -67,11 +67,21 @@
 /// region.
 void moveScalarUniformCode(WarpExecuteOnLane0Op op);
 
+/// Lambda signature to broadcast a value from one lane to all lanes of a warp.
+/// Arguments: location, builder, the value to shuffle, the index of the source
+/// lane, and the warp size. Returns the value as seen by every lane after the
+/// shuffle.
+using WarpShuffleFromIdxFn =
+    std::function<Value(Location, OpBuilder &, Value, Value, int64_t)>;
+
 /// Collect patterns to propagate warp distribution. `distributionMapFn` is used
 /// to decide how a value should be distributed when this cannot be inferred
 /// from its uses.
+/// `warpShuffleFromIdxFn` is used to broadcast a value from one lane to all
+/// other lanes, e.g. when distributing vector.extractelement of a 1-D vector.
 void populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
+    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
     PatternBenefit benefit = 1);
 
 /// Lambda signature to compute a reduction of a distributed value for the given
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -915,7 +915,10 @@
-/// Pattern to move out vector.extractelement of 0-D tensors. Those don't
-/// need to be distributed and can just be propagated outside of the region.
+/// Pattern to move vector.extractelement out of the warp region. 0-D sources
+/// are yielded as-is; 1-D sources are distributed and the element is shuffled.
struct WarpOpExtractElement : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn, + PatternBenefit b = 1) + : OpRewritePattern(ctx, b), + warpShuffleFromIdxFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { OpOperand *operand = getWarpResult(warpOp, [](Operation *op) { @@ -925,19 +928,60 @@ return failure(); unsigned int operandNumber = operand->getOperandNumber(); auto extractOp = operand->get().getDefiningOp(); - if (extractOp.getVectorType().getRank() != 0) - return failure(); + VectorType extractSrcType = extractOp.getVectorType(); + bool is0dExtract = extractSrcType.getRank() == 0; + Type elType = extractSrcType.getElementType(); + VectorType distributedVecType; + if (!is0dExtract) { + assert(extractSrcType.getRank() == 1 && + "expected that extractelement src rank is 0 or 1"); + int64_t elementsPerLane = + extractSrcType.getShape()[0] / warpOp.getWarpSize(); + distributedVecType = VectorType::get({elementsPerLane}, elType); + } else { + distributedVecType = extractSrcType; + } + + // Yield source vector from warp op. Location loc = extractOp.getLoc(); SmallVector newRetIndices; WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( - rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()}, + rewriter, warpOp, {extractOp.getVector()}, {distributedVecType}, newRetIndices); rewriter.setInsertionPointAfter(newWarpOp); - Value newExtract = rewriter.create( - loc, newWarpOp->getResult(newRetIndices[0])); - newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract); + Value distributedVec = newWarpOp->getResult(newRetIndices[0]); + + // 0d extract: The new warp op broadcasts the source vector to all lanes. + // All lanes extract the scalar. 
+ if (is0dExtract) { + Value newExtract = + rewriter.create(loc, distributedVec); + newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract); + return success(); + } + + // 1d extract: Distribute the source vector. One lane extracts and shuffles + // the value to all other lanes. + int64_t elementsPerLane = distributedVecType.getShape()[0]; + AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext()); + // tid of extracting thread: pos / elementsPerLane + Value broadcastFromTid = rewriter.create( + loc, sym0.ceilDiv(elementsPerLane), extractOp.getPosition()); + // Extract at position: pos % elementsPerLane + Value pos = rewriter.create(loc, sym0 % elementsPerLane, + extractOp.getPosition()); + Value extracted = + rewriter.create(loc, distributedVec, pos); + + // Shuffle the extracted value to all lanes. + Value shuffled = warpShuffleFromIdxFn( + loc, rewriter, extracted, broadcastFromTid, newWarpOp.getWarpSize()); + newWarpOp->getResult(operandNumber).replaceAllUsesWith(shuffled); return success(); } + +private: + WarpShuffleFromIdxFn warpShuffleFromIdxFn; }; /// Sink scf.for region out of WarpExecuteOnLane0Op. 
This can be done only if @@ -1194,11 +1238,12 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn, - PatternBenefit benefit) { + const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) { patterns.add(patterns.getContext(), - benefit); + WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand, + WarpOpConstant>(patterns.getContext(), benefit); + patterns.add(patterns.getContext(), + warpShuffleFromIdxFn, benefit); patterns.add(patterns.getContext(), distributionMapFn, benefit); } diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -666,14 +666,14 @@ // ----- -// CHECK-PROP-LABEL: func.func @vector_extractelement_simple( +// CHECK-PROP-LABEL: func.func @vector_extractelement_0d( // CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector) { // CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector // CHECK-PROP: vector.yield %[[V]] : vector // CHECK-PROP: } // CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][] : vector // CHECK-PROP: return %[[E]] : f32 -func.func @vector_extractelement_simple(%laneid: index) -> (f32) { +func.func @vector_extractelement_0d(%laneid: index) -> (f32) { %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) { %0 = "some_def"() : () -> (vector) %1 = vector.extractelement %0[] : vector @@ -684,6 +684,32 @@ // ----- +// CHECK-PROP: #[[$map:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)> +// CHECK-PROP: #[[$map1:.*]] = affine_map<()[s0] -> (s0 mod 3)> +// CHECK-PROP-LABEL: func.func @vector_extractelement_1d( +// CHECK-PROP-SAME: %[[LANEID:.*]]: index, %[[POS:.*]]: index +// CHECK-PROP-DAG: %[[C32:.*]] = arith.constant 32 : i32 +// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<3xf32>) { +// 
CHECK-PROP: %[[V:.*]] = "some_def" +// CHECK-PROP: vector.yield %[[V]] : vector<96xf32> +// CHECK-PROP: } +// CHECK-PROP: %[[FROM_LANE:.*]] = affine.apply #[[$map]]()[%[[POS]]] +// CHECK-PROP: %[[DISTR_POS:.*]] = affine.apply #[[$map1]]()[%[[POS]]] +// CHECK-PROP: %[[EXTRACTED:.*]] = vector.extractelement %[[W]][%[[DISTR_POS]] : index] : vector<3xf32> +// CHECK-PROP: %[[FROM_LANE_I32:.*]] = arith.index_cast %[[FROM_LANE]] : index to i32 +// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[EXTRACTED]], %[[FROM_LANE_I32]], %[[C32]] : f32 +// CHECK-PROP: return %[[SHUFFLED]] +func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) { + %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) { + %0 = "some_def"() : () -> (vector<96xf32>) + %1 = vector.extractelement %0[%pos : index] : vector<96xf32> + vector.yield %1 : f32 + } + return %r : f32 +} + +// ----- + // CHECK-PROP: func @lane_dependent_warp_propagate_read // CHECK-PROP-SAME: %[[ID:.*]]: index func.func @lane_dependent_warp_propagate_read( diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -759,6 +759,21 @@ return AffineMap::get(val.getContext()); return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1)); }; + auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, + Value srcIdx, int64_t warpSz) { + assert((val.getType().isF32() || val.getType().isInteger(32)) && + "unsupported shuffle type"); + Type i32Type = builder.getIntegerType(32); + Value srcIdxI32 = + builder.create(loc, i32Type, srcIdx); + Value warpSzI32 = builder.create( + loc, builder.getIntegerAttr(i32Type, warpSz)); + Value result = builder + .create(loc, val, srcIdxI32, warpSzI32, + gpu::ShuffleMode::IDX) + .getResult(0); + return result; + }; if (distributeTransferWriteOps) { RewritePatternSet 
patterns(ctx); populateDistributeTransferWriteOpPatterns(patterns, distributionFn); @@ -766,8 +781,8 @@ } if (propagateDistribution) { RewritePatternSet patterns(ctx); - vector::populatePropagateWarpVectorDistributionPatterns(patterns, - distributionFn); + vector::populatePropagateWarpVectorDistributionPatterns( + patterns, distributionFn, shuffleFn); vector::populateDistributeReduction(patterns, warpReduction); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); }