diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorDistribution.h
@@ -67,11 +67,17 @@
 /// region.
 void moveScalarUniformCode(WarpExecuteOnLane0Op op);
 
+/// Lambda signature to compute a warp shuffle of a given value of a given lane
+/// within a given warp size.
+using WarpShuffleFromIdxFn =
+    std::function<Value(Location, OpBuilder &, Value, Value, Value)>;
+
 /// Collect patterns to propagate warp distribution. `distributionMapFn` is used
 /// to decide how a value should be distributed when this cannot be inferred
 /// from its uses.
 void populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
+    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
     PatternBenefit benefit = 1);
 
 /// Lambda signature to compute a reduction of a distributed value for the given
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -915,7 +915,10 @@
 /// Pattern to move out vector.extractelement of 0-D tensors. Those don't
 /// need to be distributed and can just be propagated outside of the region.
 struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
-  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
+  WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
+                       PatternBenefit b = 1)
+      : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
+        warpShuffleFromIdxFn(std::move(fn)) {}
   LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
@@ -925,19 +928,75 @@
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
-    if (extractOp.getVectorType().getRank() != 0)
-      return failure();
+    VectorType extractSrcType = extractOp.getVectorType();
+    bool is0dExtract = extractSrcType.getRank() == 0;
+    Type elType = extractSrcType.getElementType();
+    VectorType distributedVecType;
+    if (!is0dExtract) {
+      assert(extractSrcType.getRank() == 1 &&
+             "expected that extractelement src rank is 0 or 1");
+      int64_t elementsPerLane =
+          extractSrcType.getShape()[0] / warpOp.getWarpSize();
+      distributedVecType = VectorType::get({elementsPerLane}, elType);
+    } else {
+      distributedVecType = extractSrcType;
+    }
+
+    // Yield source vector from warp op.
     Location loc = extractOp.getLoc();
     SmallVector<size_t> newRetIndices;
     WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {extractOp.getVector()}, {extractOp.getVectorType()},
+        rewriter, warpOp, {extractOp.getVector()}, {distributedVecType},
         newRetIndices);
     rewriter.setInsertionPointAfter(newWarpOp);
-    Value newExtract = rewriter.create<vector::ExtractElementOp>(
-        loc, newWarpOp->getResult(newRetIndices[0]));
-    newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+    Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
+
+    // 0d extract: The new warp op broadcasts the source vector to all lanes.
+    // All lanes extract the scalar.
+    if (is0dExtract) {
+      Value newExtract =
+          rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+      newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
+      return success();
+    }
+
+    // 1d extract: Distribute the source vector. One lane extracts and shuffles
+    // the value to all other lanes.
+    int64_t elementsPerLane = distributedVecType.getShape()[0];
+    AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
+    AffineExpr sym1 = getAffineSymbolExpr(1, rewriter.getContext());
+    // tid of extracting thread: pos / elementsPerLane
+    Value broadcastFromTid = rewriter.create<AffineApplyOp>(
+        loc, sym0.floorDiv(elementsPerLane), extractOp.getPosition());
+    Value isBroadcastSrc = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, broadcastFromTid, newWarpOp.getLaneid());
+    auto thenBuilder = [&](OpBuilder &b, Location loc) {
+      // This is the extracting lane.
+      Value pos = rewriter.create<AffineApplyOp>(
+          loc, sym0 - sym1 * elementsPerLane,
+          ArrayRef<Value>{extractOp.getPosition(), newWarpOp.getLaneid()});
+      Value extracted =
+          rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
+      rewriter.create<scf::YieldOp>(loc, extracted);
+    };
+    auto elseBuilder = [&](OpBuilder &b, Location loc) {
+      // This lane does nothing.
+      Value zero = b.create<arith::ConstantOp>(loc, elType,
+                                               rewriter.getZeroAttr(elType));
+      rewriter.create<scf::YieldOp>(loc, zero);
+    };
+    scf::IfOp extracted = rewriter.create<scf::IfOp>(
+        loc, elType, isBroadcastSrc, thenBuilder, elseBuilder);
+    // Shuffle the extracted value to all lanes.
+    Value shuffled = warpShuffleFromIdxFn(
+        loc, rewriter, extracted.getResult(0), broadcastFromTid,
+        rewriter.create<arith::ConstantIndexOp>(loc, newWarpOp.getWarpSize()));
+    newWarpOp->getResult(operandNumber).replaceAllUsesWith(shuffled);
     return success();
   }
+
+private:
+  WarpShuffleFromIdxFn warpShuffleFromIdxFn;
 };
 
 /// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
@@ -1195,11 +1254,12 @@
 
 void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
     RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
-    PatternBenefit benefit) {
+    const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit) {
   patterns.add<WarpOpElementwise, WarpOpDeadResult,
-               WarpOpBroadcast, WarpOpExtract, WarpOpExtractElement,
-               WarpOpForwardOperand, WarpOpConstant>(patterns.getContext(),
-                                                     benefit);
+               WarpOpBroadcast, WarpOpExtract, WarpOpForwardOperand,
+               WarpOpConstant>(patterns.getContext(), benefit);
+  patterns.add<WarpOpExtractElement>(patterns.getContext(),
+                                     warpShuffleFromIdxFn, benefit);
   patterns.add<WarpOpTransferRead>(patterns.getContext(), distributionMapFn,
                                    benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -666,14 +666,14 @@
 
 // -----
 
-// CHECK-PROP-LABEL: func.func @vector_extractelement_simple(
+// CHECK-PROP-LABEL: func.func @vector_extractelement_0d(
 // CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<f32>) {
 // CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<f32>
 // CHECK-PROP:     vector.yield %[[V]] : vector<f32>
 // CHECK-PROP:   }
 // CHECK-PROP:   %[[E:.*]] = vector.extractelement %[[R]][] : vector<f32>
 // CHECK-PROP:   return %[[E]] : f32
-func.func @vector_extractelement_simple(%laneid: index) -> (f32) {
+func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
   %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
     %0 = "some_def"() : () -> (vector<f32>)
     %1 = vector.extractelement %0[] : vector<f32>
@@ -684,6 +684,39 @@
 
 // -----
 
+// CHECK-PROP: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 3)>
+// CHECK-PROP: #[[$map1:.*]] = affine_map<()[s0, s1] -> (s0 - s1 * 3)>
+// CHECK-PROP-LABEL: func.func @vector_extractelement_1d(
+// CHECK-PROP-SAME: %[[LANEID:.*]]: index, %[[POS:.*]]: index
+// CHECK-PROP-DAG: %[[ZERO:.*]] = arith.constant 0.0
+// CHECK-PROP-DAG: %[[C32:.*]] = arith.constant 32 : i32
+// CHECK-PROP: %[[W:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<3xf32>) {
+// CHECK-PROP: %[[V:.*]] = "some_def"
+// CHECK-PROP: vector.yield %[[V]] : vector<96xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[FROM_IDX:.*]] = affine.apply #[[$map]]()[%[[POS]]]
+// CHECK-PROP: %[[IS_SHUFFLE_SRC:.*]] = arith.cmpi eq, %[[FROM_IDX]], %[[LANEID]]
+// CHECK-PROP: %[[EXTRACTED:.*]] = scf.if %[[IS_SHUFFLE_SRC]] -> (f32) {
+// CHECK-PROP: %[[DISTR_POS:.*]] = affine.apply #[[$map1]]()[%[[POS]], %[[LANEID]]]
+// CHECK-PROP: %[[EXTR:.*]] = vector.extractelement %[[W]][%[[DISTR_POS]] : index] : vector<3xf32>
+// CHECK-PROP: scf.yield %[[EXTR]] : f32
+// CHECK-PROP: } else {
+// CHECK-PROP: scf.yield %[[ZERO]]
+// CHECK-PROP: }
+// CHECK-PROP: %[[FROM_IDX_I32:.*]] = arith.index_cast %[[FROM_IDX]] : index to i32
+// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[EXTRACTED]], %[[FROM_IDX_I32]], %[[C32]] : f32
+// CHECK-PROP: return %[[SHUFFLED]]
+func.func @vector_extractelement_1d(%laneid: index, %pos: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<96xf32>)
+    %1 = vector.extractelement %0[%pos : index] : vector<96xf32>
+    vector.yield %1 : f32
+  }
+  return %r : f32
+}
+
+// -----
+
 // CHECK-PROP: func @lane_dependent_warp_propagate_read
 // CHECK-PROP-SAME: %[[ID:.*]]: index
 func.func @lane_dependent_warp_propagate_read(
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -757,6 +757,21 @@
         return AffineMap::get(val.getContext());
       return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
     };
+    auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
+                        Value srcIdx, Value warpSz) {
+      assert((val.getType().isF32() || val.getType().isInteger(32)) &&
+             "unsupported shuffle type");
+      Type i32Type = builder.getIntegerType(32);
+      Value srcIdxI32 =
+          builder.create<arith::IndexCastOp>(loc, i32Type, srcIdx);
+      Value warpSzI32 =
+          builder.create<arith::IndexCastOp>(loc, i32Type, warpSz);
+      Value result = builder
+                         .create<gpu::ShuffleOp>(loc, val, srcIdxI32, warpSzI32,
+                                                 gpu::ShuffleMode::IDX)
+                         .getResult(0);
+      return result;
+    };
     if (distributeTransferWriteOps) {
       RewritePatternSet patterns(ctx);
       populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
@@ -764,8 +779,8 @@
     }
     if (propagateDistribution) {
       RewritePatternSet patterns(ctx);
-      vector::populatePropagateWarpVectorDistributionPatterns(patterns,
-                                                              distributionFn);
+      vector::populatePropagateWarpVectorDistributionPatterns(
+          patterns, distributionFn, shuffleFn);
       vector::populateDistributeReduction(patterns, warpReduction);
       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     }