diff --git a/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h --- a/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h +++ b/mlir/include/mlir/Conversion/VectorToSCF/ProgressiveVectorToSCF.h @@ -47,12 +47,24 @@ /// When applying the pattern a second time, the existing alloca() operation /// is reused and only a second vector.type_cast is added. +struct ProgressiveVectorTransferToSCFOptions { + bool unroll = false; + ProgressiveVectorTransferToSCFOptions &setUnroll(bool u) { + unroll = u; + return *this; + } +}; + /// Collect a set of patterns to convert from the Vector dialect to SCF + std. void populateProgressiveVectorToSCFConversionPatterns( - RewritePatternSet &patterns); + RewritePatternSet &patterns, + const ProgressiveVectorTransferToSCFOptions &options = + ProgressiveVectorTransferToSCFOptions()); /// Create a pass to convert a subset of vector ops to SCF. -std::unique_ptr createProgressiveConvertVectorToSCFPass(); +std::unique_ptr createProgressiveConvertVectorToSCFPass( + const ProgressiveVectorTransferToSCFOptions &options = + ProgressiveVectorTransferToSCFOptions()); } // namespace mlir diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp @@ -266,6 +266,14 @@ return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front()); } +/// Add the pass label to a vector transfer op if its rank is greater than the +/// target rank. +template +static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp) { + if (newXferOp.getVectorType().getRank() > kTargetRank) + newXferOp->setAttr(kPassLabel, builder.getUnitAttr()); +} + /// Given a transfer op, find the memref from which the mask is loaded. This /// is similar to Strategy::getBuffer. 
template @@ -356,8 +364,8 @@ AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), xferOp.padding(), Value(), inBoundsAttr).value; - if (vecType.getRank() > kTargetRank) - newXfer.getDefiningOp()->setAttr(kPassLabel, builder.getUnitAttr()); + maybeApplyPassLabel(builder, + dyn_cast(newXfer.getDefiningOp())); memref_store(newXfer, buffer, storeIndices); return newXfer.getDefiningOp(); @@ -428,15 +436,13 @@ getXferIndices(xferOp, iv, xferIndices); auto vec = memref_load(buffer, loadIndices); - auto vecType = vec.value.getType().dyn_cast(); auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr()); auto newXfer = vector_transfer_write( Type(), vec, xferOp.source(), xferIndices, AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(), inBoundsAttr); - if (vecType.getRank() > kTargetRank) - newXfer.op->setAttr(kPassLabel, builder.getUnitAttr()); + maybeApplyPassLabel(builder, newXfer.op); return newXfer; } @@ -668,6 +674,256 @@ } }; +/// If the original transfer op has a mask, compute the mask of the new transfer +/// op (for the current iteration `i`) and assign it. +template +static void maybeAssignMask(OpBuilder &builder, OpTy xferOp, OpTy newXferOp, + int64_t i) { + // If old transfer op has a mask: Set mask on new transfer op. + if (xferOp.mask()) { + if (isOutermostDimBroadcast(xferOp)) { + newXferOp.maskMutable().assign(xferOp.mask()); + } else if (xferOp.getMaskType()->getRank() > 1) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(newXferOp); // Insert extract before newXfer. + + llvm::SmallVector indices({i}); + auto newMask = vector_extract(xferOp.mask(), indices).value; + newXferOp.maskMutable().assign(newMask); + } + // Else: If the mask of the old transfer op is 1D and the unpacked dim is + // not a broadcast, no mask is needed on the new transfer op. + } +} + +/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one +/// dimension. 
This is similar to TransferOpConversion, but no +/// memref buffer is allocated and the SCF loop is fully unrolled. +/// +/// E.g.: +/// +/// ``` +/// %vec = vector.transfer_read %A[%a, %b, %c], %padding +/// : memref<...>, vector<5x4xf32> +/// ``` +/// is rewritten to IR such as (simplified): +/// ``` +/// %v_init = splat %padding : vector<5x4xf32> +/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding +/// : memref<...>, vector<4xf32> +/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32> +/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding +/// : memref<...>, vector<4xf32> +/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32> +/// ... +/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding +/// : memref<...>, vector<4xf32> +/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32> +/// ``` +/// +/// Note: A pass label is attached to new TransferReadOps, so that subsequent +/// applications of this pattern do not create an additional %v_init vector. +struct UnrollTransferReadConversion : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + /// Find the result vector %v_init or create a new vector if this is the + /// first application of the pattern. + Value getResultVector(TransferReadOp xferOp, + PatternRewriter &rewriter) const { + if (xferOp->hasAttr(kPassLabel)) { + return getInsertOp(xferOp).dest(); + } + return std_splat(xferOp.getVectorType(), xferOp.padding()).value; + } + + /// Assuming that this is not the first application of the pattern, return + /// the vector.insert op in which the result of this transfer op is used. + vector::InsertOp getInsertOp(TransferReadOp xferOp) const { + Operation *xferOpUser = *xferOp->getUsers().begin(); + return dyn_cast(xferOpUser); + } + + /// Assuming that this is not the first application of the pattern, return + /// the indices of the vector.insert op in which the result of this transfer + /// op is used. 
+ void getInsertionIndices(TransferReadOp xferOp, + SmallVector &indices) const { + if (xferOp->hasAttr(kPassLabel)) { + llvm::for_each(getInsertOp(xferOp).position(), [&](Attribute attr) { + indices.push_back(attr.dyn_cast().getInt()); + }); + } + } + + /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds + /// accesses, and broadcasts and transposes in permutation maps. + LogicalResult matchAndRewrite(TransferReadOp xferOp, + PatternRewriter &rewriter) const override { + if (xferOp.getVectorType().getRank() <= kTargetRank) + return failure(); + + ScopedContext scope(rewriter, xferOp.getLoc()); + auto vec = getResultVector(xferOp, rewriter); + auto vecType = vec.getType().dyn_cast(); + auto xferVecType = xferOp.getVectorType(); + auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(), + xferVecType.getElementType()); + int64_t dimSize = xferVecType.getShape()[0]; + + // Generate fully unrolled loop of transfer ops. + for (int64_t i = 0; i < dimSize; ++i) { + Value iv = std_constant_index(i); + + vec = generateInBoundsCheck( + xferOp, iv, rewriter, unpackedDim(xferOp), TypeRange(vecType), + /*inBoundsCase=*/ + [&](OpBuilder &b, Location loc) { + ScopedContext scope(b, loc); + + // Indices for the new transfer op. + SmallVector xferIndices; + getXferIndices(xferOp, iv, xferIndices); + + // Indices for the new vector.insert op. 
+ SmallVector insertionIndices; + getInsertionIndices(xferOp, insertionIndices); + insertionIndices.push_back(i); + + auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr()); + auto newXferOpVal = + vector_transfer_read( + newXferVecType, xferOp.source(), xferIndices, + AffineMapAttr::get(unpackedPermutationMap(xferOp, b)), + xferOp.padding(), Value(), inBoundsAttr) + .value; + auto newXferOp = + dyn_cast(newXferOpVal.getDefiningOp()); + + maybeAssignMask(b, xferOp, newXferOp, i); + maybeApplyPassLabel(b, newXferOp); + + return vector_insert(newXferOp, vec, insertionIndices).value; + }, + /*outOfBoundsCase=*/ + [&](OpBuilder &b, Location loc) { + // Pass through original (unmodified) vector. + return vec; + }); + } + + if (xferOp->hasAttr(kPassLabel)) { + rewriter.replaceOp(getInsertOp(xferOp), vec); + rewriter.eraseOp(xferOp); + } else { + rewriter.replaceOp(xferOp, vec); + } + + return success(); + } +}; + +/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one +/// dimension. This is similar to TransferOpConversion, but no +/// memref buffer is allocated and the SCF loop is fully unrolled. +/// +/// E.g.: +/// +/// ``` +/// vector.transfer_write %vec, %A[%a, %b, %c] +/// : vector<5x4xf32>, memref<...> +/// ``` +/// is rewritten to IR such as (simplified): +/// ``` +/// %v0 = vector.extract %vec[0] : vector<5x4xf32> +/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...> +/// %v1 = vector.extract %vec[1] : vector<5x4xf32> +/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...> +/// ... +/// %v4 = vector.extract %vec[4] : vector<5x4xf32> +/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...> +/// ``` +/// +/// Note: A pass label is attached to new TransferWriteOps, so that subsequent +/// applications of this pattern can read the indices of previously generated +/// vector.extract ops. 
+struct UnrollTransferWriteConversion + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + /// If this is not the first application of the pattern, find the original + /// vector %vec that is written by this transfer op. Otherwise, return the + /// vector of this transfer op. + Value getDataVector(TransferWriteOp xferOp) const { + if (xferOp->hasAttr(kPassLabel)) + return getExtractOp(xferOp).vector(); + return xferOp.vector(); + } + + /// Assuming that this is not the first application of the pattern, find the + /// vector.extract op whose result is written by this transfer op. + vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const { + return dyn_cast(xferOp.vector().getDefiningOp()); + } + + void getExtractionIndices(TransferWriteOp xferOp, + SmallVector &indices) const { + if (xferOp->hasAttr(kPassLabel)) { + llvm::for_each(getExtractOp(xferOp).position(), [&](Attribute attr) { + indices.push_back(attr.dyn_cast().getInt()); + }); + } + } + + /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds + /// accesses, and broadcasts and transposes in permutation maps. + LogicalResult matchAndRewrite(TransferWriteOp xferOp, + PatternRewriter &rewriter) const override { + if (xferOp.getVectorType().getRank() <= kTargetRank) + return failure(); + + ScopedContext scope(rewriter, xferOp.getLoc()); + auto vec = getDataVector(xferOp); + auto xferVecType = xferOp.getVectorType(); + int64_t dimSize = xferVecType.getShape()[0]; + + // Generate fully unrolled loop of transfer ops. + for (int64_t i = 0; i < dimSize; ++i) { + Value iv = std_constant_index(i); + + generateInBoundsCheck( + xferOp, iv, rewriter, unpackedDim(xferOp), + /*inBoundsCase=*/[&](OpBuilder &b, Location loc) { + ScopedContext scope(b, loc); + + // Indices for the new transfer op. + SmallVector xferIndices; + getXferIndices(xferOp, iv, xferIndices); + + // Indices for the new vector.extract op. 
+ SmallVector extractionIndices; + getExtractionIndices(xferOp, extractionIndices); + extractionIndices.push_back(i); + + auto extracted = vector_extract(vec, extractionIndices).value; + auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr()); + + auto newXferOp = + vector_transfer_write( + Type(), extracted, xferOp.source(), xferIndices, + AffineMapAttr::get(unpackedPermutationMap(xferOp, b)), + Value(), inBoundsAttr) + .op; + + maybeAssignMask(b, xferOp, newXferOp, i); + maybeApplyPassLabel(b, newXferOp); + }); + } + + rewriter.eraseOp(xferOp); + return success(); + } +}; + /// Compute the indices into the memref for the LoadOp/StoreOp generated as /// part of TransferOp1dConversion. Return the memref dimension on which /// the transfer is operating. A return value of None indicates a broadcast. @@ -824,11 +1080,16 @@ namespace mlir { void populateProgressiveVectorToSCFConversionPatterns( - RewritePatternSet &patterns) { - patterns.add, - TransferOpConversion>(patterns.getContext()); + RewritePatternSet &patterns, + const ProgressiveVectorTransferToSCFOptions &options) { + if (options.unroll) { + patterns.add( + patterns.getContext()); + } else { + patterns.add, + TransferOpConversion>(patterns.getContext()); + } if (kTargetRank == 1) { patterns.add, @@ -839,16 +1100,22 @@ struct ConvertProgressiveVectorToSCFPass : public ConvertVectorToSCFBase { + ConvertProgressiveVectorToSCFPass( + const ProgressiveVectorTransferToSCFOptions &opt) + : options(opt) {} + void runOnFunction() override { RewritePatternSet patterns(getFunction().getContext()); - populateProgressiveVectorToSCFConversionPatterns(patterns); + populateProgressiveVectorToSCFConversionPatterns(patterns, options); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } + + ProgressiveVectorTransferToSCFOptions options; }; } // namespace mlir -std::unique_ptr -mlir::createProgressiveConvertVectorToSCFPass() { - return std::make_unique(); +std::unique_ptr 
mlir::createProgressiveConvertVectorToSCFPass( + const ProgressiveVectorTransferToSCFOptions &options) { + return std::make_unique(options); } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir @@ -1,5 +1,10 @@ // RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ -// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir @@ -3,6 +3,11 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. 
], [10., 11., 12., 13.], [20., 21., 22., 23.]]> diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir @@ -3,6 +3,11 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + func @transfer_read_3d(%A : memref, %o: index, %a: index, %b: index, %c: index) { %fm42 = constant -42.0: f32 diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -390,16 +390,20 @@ } }; +template struct TestProgressiveVectorToSCFLoweringPatterns - : public PassWrapper, FunctionPass> { void getDependentDialects(DialectRegistry &registry) const override { registry.insert(); } void runOnFunction() override { - RewritePatternSet patterns(&getContext()); - populateProgressiveVectorToSCFConversionPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); + RewritePatternSet patterns(&this->getContext()); + ProgressiveVectorTransferToSCFOptions options; + options.unroll = Unroll; + populateProgressiveVectorToSCFConversionPatterns(patterns, options); + (void)applyPatternsAndFoldGreedily(this->getFunction(), + std::move(patterns)); } }; @@ -450,9 +454,18 @@ "test-vector-transfer-lowering-patterns", "Test conversion patterns to lower transfer ops to other vector ops"); - PassRegistration 
transferOpToSCF( - "test-progressive-convert-vector-to-scf", - "Test conversion patterns to progressively lower transfer ops to SCF"); + PassRegistration> + transferOpToSCF("test-progressive-convert-vector-to-scf", + "Test conversion patterns to progressively lower " + "transfer ops to SCF"); + + PassRegistration> + transferOpToSCFUnrolled( + "test-unrolled-progressive-convert-vector-to-scf", + "Test conversion patterns to progressively lower transfer ops to SCF" + " (unrolled variant)"); PassRegistration multiDimReductionOpLoweringPass(