diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h --- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h +++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h @@ -9,13 +9,152 @@ #ifndef MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_ #define MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_ +#include "mlir/IR/PatternMatch.h" + namespace mlir { class MLIRContext; class OwningRewritePatternList; -/// Collect a set of patterns to convert from the Vector dialect to loops + std. -void populateVectorToSCFConversionPatterns(OwningRewritePatternList &patterns, - MLIRContext *context); +/// Control whether unrolling is used when lowering vector transfer ops to SCF. +/// +/// Case 1: +/// ======= +/// When `unroll` is false, a temporary buffer is created through which +/// individual 1-D vectors are staged. This is consistent with the lack of an +/// LLVM instruction to dynamically index into an aggregate (see the Vector +/// dialect lowering to LLVM deep dive). +/// An instruction such as: +/// ``` +/// vector.transfer_write %vec, %A[%base, %base] : +/// vector<17x15xf32>, memref +/// ``` +/// Lowers to pseudo-IR resembling: +/// ``` +/// %0 = alloc() : memref<17xvector<15xf32>> +/// %1 = vector.type_cast %0 : +/// memref<17xvector<15xf32>> to memref<vector<17x15xf32>> +/// store %vec, %1[] : memref<vector<17x15xf32>> +/// %dim = dim %A, 0 : memref +/// affine.for %I = 0 to 17 { +/// %add = affine.apply %I + %base +/// %cmp = cmpi "slt", %add, %dim : index +/// scf.if %cmp { +/// %vec_1d = load %0[%I] : memref<17xvector<15xf32>> +/// vector.transfer_write %vec_1d, %A[%add, %base] : +/// vector<15xf32>, memref +/// ``` +/// +/// Case 2: +/// ======= +/// When `unroll` is true, the temporary buffer is skipped and static indices +/// into aggregates can be used (see the Vector dialect lowering to LLVM deep +/// dive).
+/// An instruction such as: +/// ``` +/// vector.transfer_write %vec, %A[%base, %base] : +/// vector<17x15xf32>, memref +/// ``` +/// Lowers to pseudo-IR resembling: +struct VectorTransferToSCFOptions { + /// If true, unroll the lowering and skip the temporary staging buffer. + bool unroll = false; + VectorTransferToSCFOptions &setUnroll(bool u) { + unroll = u; + return *this; + } +}; + +/// Implements lowering of TransferReadOp and TransferWriteOp to a +/// proper abstraction for the hardware. +/// +/// There are multiple cases. +/// +/// Case A: Permutation Map does not permute or broadcast. +/// ====================================================== +/// +/// Progressive lowering occurs to 1-D vector transfer ops according to the +/// description in `VectorTransferToSCFOptions`. +/// +/// Case B: Permutation Map permutes and/or broadcasts. +/// ====================================================== +/// +/// This path will be progressively deprecated and folded into the case above by +/// using vector broadcast and transpose operations. +/// +/// This path only emits a simple loop nest that performs clipped pointwise +/// copies from a remote to a locally allocated memory.
+/// +/// Consider the case: +/// +/// ```mlir +/// // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into +/// // vector<32x256xf32> and pad with %f0 to handle the boundary case: +/// %f0 = constant 0.0f : f32 +/// scf.for %i0 = 0 to %0 { +/// scf.for %i1 = 0 to %1 step %c256 { +/// scf.for %i2 = 0 to %2 step %c32 { +/// %v = vector.transfer_read %A[%i0, %i1, %i2], %f0 +/// {permutation_map: (d0, d1, d2) -> (d2, d1)} : +/// memref, vector<32x256xf32> +/// }}} +/// ``` +/// +/// The rewriters construct loops and indices that access MemRef A in a pattern +/// resembling the following (while guaranteeing an always full-tile +/// abstraction): +/// +/// ```mlir +/// scf.for %d2 = 0 to %c256 { +/// scf.for %d1 = 0 to %c32 { +/// %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32 +/// %tmp[%d2, %d1] = %s +/// } +/// } +/// ``` +/// +/// In the current state, only a clipping transfer is implemented by `clip`, +/// which creates individual indexing expressions of the form: +/// +/// ```mlir-dsc +/// auto condMax = i + ii < N; +/// auto max = std_select(condMax, i + ii, N - one) +/// auto cond = i + ii < zero; +/// std_select(cond, zero, max); +/// ``` +/// +/// In the future, clipping should not be the only way and instead we should +/// load vectors + mask them. Similarly on the write side, load/mask/store for +/// implementing RMW behavior. +/// +/// Lowers TransferOp into a combination of: +/// 1. local memory allocation; +/// 2. perfect loop nest over: +/// a. scalar load/stores from local buffers (viewed as a scalar memref); +/// b. scalar store/load to original memref (with clipping). +/// 3. vector_load/store +/// 4. local memory deallocation. +/// Minor variations occur depending on whether a TransferReadOp or +/// a TransferWriteOp is rewritten. +template +struct VectorTransferRewriter : public RewritePattern { + explicit VectorTransferRewriter(VectorTransferToSCFOptions options, + MLIRContext *context); + + /// Used for staging the transfer in a local buffer.
+ MemRefType tmpMemRefType(TransferOpTy transfer) const; + + /// Performs the rewrite. + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + + /// See description of `VectorTransferToSCFOptions`. + VectorTransferToSCFOptions options; +}; + +/// Collect a set of patterns to convert from the Vector dialect to SCF + std. +void populateVectorToSCFConversionPatterns( + OwningRewritePatternList &patterns, MLIRContext *context, + const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions()); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -373,7 +373,10 @@ }]; let builders = [OpBuilder< "OpBuilder &builder, OperationState &result, Value source," - "ArrayRef">]; + "ArrayRef position">, + // Convenience builder which assumes the values are constant indices. + OpBuilder<"OpBuilder &builder, OperationState &result, Value source," + "ValueRange position">]; let extraClassDeclaration = [{ static StringRef getPositionAttrName() { return "position"; } VectorType getVectorType() { @@ -536,7 +539,11 @@ let builders = [OpBuilder< "OpBuilder &builder, OperationState &result, Value source, " # - "Value dest, ArrayRef">]; + "Value dest, ArrayRef position">, + OpBuilder< + // Convenience builder which assumes all values are constant indices. 
+ "OpBuilder &builder, OperationState &result, Value source, " # + "Value dest, ValueRange position">]; let extraClassDeclaration = [{ static StringRef getPositionAttrName() { return "position"; } Type getSourceType() { return source().getType(); } diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h --- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h @@ -1,4 +1,4 @@ -//===- VectorUtils.h - VectorOps Utilities ------------------*- C++ -*-=======// +//===- VectorUtils.h - Vector Utilities -------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -26,18 +26,28 @@ class Value; class VectorType; +/// Return the number of elements of basis, `0` if empty. +int64_t computeMaxLinearIndex(ArrayRef basis); + +/// Given a shape with sizes greater than 0 along all dimensions, +/// return the distance, in number of elements, between a slice in a dimension +/// and the next slice in the same dimension. +/// e.g. shape[3, 4, 5] -> linearization_basis[20, 5, 1] +SmallVector computeStrides(ArrayRef shape); + /// Given the shape and sizes of a vector, returns the corresponding /// strides for each dimension. +/// TODO: needs better doc of how it is used. SmallVector computeStrides(ArrayRef shape, ArrayRef sizes); /// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'. int64_t linearize(ArrayRef offsets, ArrayRef basis); -/// Given the slice strides together with a linear index in the dimension +/// Given the strides together with a linear index in the dimension /// space, returns the vector-space offsets in each dimension for a /// de-linearized index. 
-SmallVector delinearize(ArrayRef sliceStrides, +SmallVector delinearize(ArrayRef strides, int64_t linearIndex); /// Given the target sizes of a vector, together with vector-space offsets, diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -73,8 +74,9 @@ /// template class NDTransferOpHelper { public: - NDTransferOpHelper(PatternRewriter &rewriter, ConcreteOp xferOp) - : rewriter(rewriter), loc(xferOp.getLoc()), + NDTransferOpHelper(PatternRewriter &rewriter, ConcreteOp xferOp, + const VectorTransferToSCFOptions &options) + : rewriter(rewriter), options(options), loc(xferOp.getLoc()), scope(std::make_unique(rewriter, loc)), xferOp(xferOp), op(xferOp.getOperation()) { vectorType = xferOp.getVectorType(); @@ -102,19 +104,17 @@ template void emitLoops(Lambda loopBodyBuilder); /// Operate within the body of `emitLoops` to: - /// 1. Compute the indexings `majorIvs + majorOffsets`. - /// 2. Compute a boolean that determines whether the first `majorIvs.rank()` + /// 1. Compute the indexings `majorIvs + majorOffsets` and save them in + /// `majorIvsPlusOffsets`. + /// 2. Return a boolean that determines whether the first `majorIvs.rank()` /// dimensions `majorIvs + majorOffsets` are all within `memrefBounds`. - /// 3. Create an IfOp conditioned on the boolean in step 2. - /// 4. Call a `thenBlockBuilder` and an `elseBlockBuilder` to append - /// operations to the IfOp blocks as appropriate. 
- template - void emitInBounds(ValueRange majorIvs, ValueRange majorOffsets, - MemRefBoundsCapture &memrefBounds, - LambdaThen thenBlockBuilder, LambdaElse elseBlockBuilder); + Value emitInBoundsCondition(ValueRange majorIvs, ValueRange majorOffsets, + MemRefBoundsCapture &memrefBounds, + SmallVectorImpl &majorIvsPlusOffsets); /// Common state to lower vector transfer ops. PatternRewriter &rewriter; + const VectorTransferToSCFOptions &options; Location loc; std::unique_ptr scope; ConcreteOp xferOp; @@ -136,27 +136,43 @@ void NDTransferOpHelper::emitLoops(Lambda loopBodyBuilder) { /// Loop nest operates on the major dimensions MemRefBoundsCapture memrefBoundsCapture(xferOp.memref()); - VectorBoundsCapture vectorBoundsCapture(majorVectorType); - auto majorLbs = vectorBoundsCapture.getLbs(); - auto majorUbs = vectorBoundsCapture.getUbs(); - auto majorSteps = vectorBoundsCapture.getSteps(); - SmallVector majorIvs(vectorBoundsCapture.rank()); - AffineLoopNestBuilder(majorIvs, majorLbs, majorUbs, majorSteps)([&] { + + if (options.unroll) { + auto shape = majorVectorType.getShape(); + auto strides = computeStrides(shape); + unsigned numUnrolledInstances = computeMaxLinearIndex(shape); ValueRange indices(xferOp.indices()); - loopBodyBuilder(majorIvs, indices.take_front(leadingRank), - indices.drop_front(leadingRank).take_front(majorRank), - indices.take_back(minorRank), memrefBoundsCapture); - }); + for (unsigned idx = 0; idx < numUnrolledInstances; ++idx) { + auto offsets = delinearize(strides, idx); + auto offsetValues = + llvm::to_vector<4>(llvm::map_range(offsets, [](int64_t off) -> Value { + return std_constant_index(off); + })); + loopBodyBuilder(offsetValues, indices.take_front(leadingRank), + indices.drop_front(leadingRank).take_front(majorRank), + indices.take_back(minorRank), memrefBoundsCapture); + } + } else { + VectorBoundsCapture vectorBoundsCapture(majorVectorType); + auto majorLbs = vectorBoundsCapture.getLbs(); + auto majorUbs = 
vectorBoundsCapture.getUbs(); + auto majorSteps = vectorBoundsCapture.getSteps(); + SmallVector majorIvs(vectorBoundsCapture.rank()); + AffineLoopNestBuilder(majorIvs, majorLbs, majorUbs, majorSteps)([&] { + ValueRange indices(xferOp.indices()); + loopBodyBuilder(majorIvs, indices.take_front(leadingRank), + indices.drop_front(leadingRank).take_front(majorRank), + indices.take_back(minorRank), memrefBoundsCapture); + }); + } } template -template -void NDTransferOpHelper::emitInBounds( +Value NDTransferOpHelper::emitInBoundsCondition( ValueRange majorIvs, ValueRange majorOffsets, - MemRefBoundsCapture &memrefBounds, LambdaThen thenBlockBuilder, - LambdaElse elseBlockBuilder) { - Value inBounds; - SmallVector majorIvsPlusOffsets; + MemRefBoundsCapture &memrefBounds, + SmallVectorImpl &majorIvsPlusOffsets) { + Value inBoundsCondition; majorIvsPlusOffsets.reserve(majorIvs.size()); unsigned idx = 0; for (auto it : llvm::zip(majorIvs, majorOffsets, memrefBounds.getUbs())) { @@ -164,41 +180,32 @@ using namespace mlir::edsc::op; majorIvsPlusOffsets.push_back(iv + off); if (xferOp.isMaskedDim(leadingRank + idx)) { - Value inBounds2 = majorIvsPlusOffsets.back() < ub; - inBounds = (inBounds) ? (inBounds && inBounds2) : inBounds2; + Value inBounds = majorIvsPlusOffsets.back() < ub; + inBoundsCondition = + (inBoundsCondition) ? (inBoundsCondition && inBounds) : inBounds; } ++idx; } - - if (inBounds) { - auto ifOp = ScopedContext::getBuilderRef().create( - ScopedContext::getLocation(), TypeRange{}, inBounds, - /*withElseRegion=*/std::is_same()); - BlockBuilder(&ifOp.thenRegion().front(), - Append())([&] { thenBlockBuilder(majorIvsPlusOffsets); }); - if (std::is_same()) - BlockBuilder(&ifOp.elseRegion().front(), - Append())([&] { elseBlockBuilder(majorIvsPlusOffsets); }); - } else { - // Just build the body of the then block right here. 
- thenBlockBuilder(majorIvsPlusOffsets); - } + return inBoundsCondition; } template <> LogicalResult NDTransferOpHelper::doReplace() { - Value alloc = std_alloc(memRefMinorVectorType); + Value alloc, result; + if (options.unroll) + result = std_splat(vectorType, xferOp.padding()); + else + alloc = std_alloc(memRefMinorVectorType); emitLoops([&](ValueRange majorIvs, ValueRange leadingOffsets, ValueRange majorOffsets, ValueRange minorOffsets, MemRefBoundsCapture &memrefBounds) { - // If in-bounds, index into memref and lower to 1-D transfer read. - auto thenBlockBuilder = [&](ValueRange majorIvsPlusOffsets) { + /// Lambda to load 1-D vector in the current loop ivs + offset context. + auto load1DVector = [&](ValueRange majorIvsPlusOffsets) -> Value { SmallVector indexing; indexing.reserve(leadingRank + majorRank + minorRank); indexing.append(leadingOffsets.begin(), leadingOffsets.end()); indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end()); indexing.append(minorOffsets.begin(), minorOffsets.end()); - Value memref = xferOp.memref(); auto map = TransferReadOp::getTransferMinorIdentityMap( xferOp.getMemRefType(), minorVectorType); @@ -207,45 +214,100 @@ OpBuilder &b = ScopedContext::getBuilderRef(); masked = b.getBoolArrayAttr({true}); } - auto loaded1D = vector_transfer_read(minorVectorType, memref, indexing, - AffineMapAttr::get(map), - xferOp.padding(), masked); - // Store the 1-D vector. - std_store(loaded1D, alloc, majorIvs); - }; - // If out-of-bounds, just store a splatted vector. - auto elseBlockBuilder = [&](ValueRange majorIvsPlusOffsets) { - auto vector = std_splat(minorVectorType, xferOp.padding()); - std_store(vector, alloc, majorIvs); + return vector_transfer_read(minorVectorType, memref, indexing, + AffineMapAttr::get(map), xferOp.padding(), + masked); }; - emitInBounds(majorIvs, majorOffsets, memrefBounds, thenBlockBuilder, - elseBlockBuilder); + + // 1. Compute the inBoundsCondition in the current loop ivs + offset + // context.
+ SmallVector majorIvsPlusOffsets; + Value inBoundsCondition = emitInBoundsCondition( + majorIvs, majorOffsets, memrefBounds, majorIvsPlusOffsets); + + if (inBoundsCondition) { + // 2. If the condition is not null, we need an IfOp, which may yield + // if `options.unroll` is true. + SmallVector resultType; + if (options.unroll) + resultType.push_back(vectorType); + auto ifOp = ScopedContext::getBuilderRef().create( + ScopedContext::getLocation(), resultType, inBoundsCondition, + /*withElseRegion=*/true); + + // 3.a. If in-bounds, progressively lower to a 1-D transfer read. + BlockBuilder(&ifOp.thenRegion().front(), Append())([&] { + Value loaded1D = load1DVector(majorIvsPlusOffsets); + // 3.a.i. If `options.unroll` is true, insert the 1-D vector in the + // aggregate. We must yield and merge with the `else` branch. + if (options.unroll) { + loop_yield(Value{vector_insert(loaded1D, result, majorIvs)}); + return; + } + // 3.a.ii. Otherwise, just go through the temporary `alloc`. + std_store(loaded1D, alloc, majorIvs); + }); + + // 3.b. If not in-bounds, splat a 1-D vector. + BlockBuilder(&ifOp.elseRegion().front(), Append())([&] { + auto vector = std_splat(minorVectorType, xferOp.padding()); + // 3.b.i. If `options.unroll` is true, insert the 1-D vector in the + // aggregate. We must yield and merge with the `then` branch. + if (options.unroll) { + loop_yield(Value{vector_insert(vector, result, majorIvs)}); + return; + } + // 3.b.ii. Otherwise, just go through the temporary `alloc`. + std_store(vector, alloc, majorIvs); + }); + if (!resultType.empty()) + result = *ifOp.results().begin(); + } else { + // 4. Guaranteed in-bounds, progressively lower to a 1-D transfer read. + Value loaded1D = load1DVector(majorIvsPlusOffsets); + // 5.a. If `options.unroll` is true, insert the 1-D vector in the + // aggregate. + if (options.unroll) + result = vector_insert(loaded1D, result, majorIvs); + // 5.b. Otherwise, just go through the temporary `alloc`.
+ else + std_store(loaded1D, alloc, majorIvs); + } }); - Value loaded = - std_load(vector_type_cast(MemRefType::get({}, vectorType), alloc)); - rewriter.replaceOp(op, loaded); + assert((!options.unroll ^ result) && "Expected resulting Value iff unroll"); + if (!result) + result = std_load(vector_type_cast(MemRefType::get({}, vectorType), alloc)); + rewriter.replaceOp(op, result); return success(); } template <> LogicalResult NDTransferOpHelper::doReplace() { - Value alloc = std_alloc(memRefMinorVectorType); - - std_store(xferOp.vector(), - vector_type_cast(MemRefType::get({}, vectorType), alloc)); + Value alloc; + if (!options.unroll) { + alloc = std_alloc(memRefMinorVectorType); + std_store(xferOp.vector(), + vector_type_cast(MemRefType::get({}, vectorType), alloc)); + } emitLoops([&](ValueRange majorIvs, ValueRange leadingOffsets, ValueRange majorOffsets, ValueRange minorOffsets, MemRefBoundsCapture &memrefBounds) { - auto thenBlockBuilder = [&](ValueRange majorIvsPlusOffsets) { + // Lower to 1-D vector_transfer_write and let recursion handle it. + auto emitTransferWrite = [&](ValueRange majorIvsPlusOffsets) { SmallVector indexing; indexing.reserve(leadingRank + majorRank + minorRank); indexing.append(leadingOffsets.begin(), leadingOffsets.end()); indexing.append(majorIvsPlusOffsets.begin(), majorIvsPlusOffsets.end()); indexing.append(minorOffsets.begin(), minorOffsets.end()); - // Lower to 1-D vector_transfer_write and let recursion handle it. - Value loaded1D = std_load(alloc, majorIvs); + Value result; + // If `options.unroll` is true, extract the 1-D vector from the + // aggregate. 
+ if (options.unroll) + result = vector_extract(xferOp.vector(), majorIvs); + else + result = std_load(alloc, majorIvs); auto map = TransferWriteOp::getTransferMinorIdentityMap( xferOp.getMemRefType(), minorVectorType); ArrayAttr masked; @@ -253,13 +315,28 @@ OpBuilder &b = ScopedContext::getBuilderRef(); masked = b.getBoolArrayAttr({true}); } - vector_transfer_write(loaded1D, xferOp.memref(), indexing, + vector_transfer_write(result, xferOp.memref(), indexing, AffineMapAttr::get(map), masked); }; - // Don't write anything when out of bounds. - auto elseBlockBuilder = [&](ValueRange majorIvsPlusOffsets) {}; - emitInBounds(majorIvs, majorOffsets, memrefBounds, thenBlockBuilder, - elseBlockBuilder); + + // 1. Compute the inBoundsCondition in the current loop ivs + offset + // context. + SmallVector majorIvsPlusOffsets; + Value inBoundsCondition = emitInBoundsCondition( + majorIvs, majorOffsets, memrefBounds, majorIvsPlusOffsets); + + if (inBoundsCondition) { + // 2.a. If the condition is not null, we need an IfOp, to write + // conditionally. Progressively lower to a 1-D transfer write. + auto ifOp = ScopedContext::getBuilderRef().create( + ScopedContext::getLocation(), TypeRange{}, inBoundsCondition, + /*withElseRegion=*/false); + BlockBuilder(&ifOp.thenRegion().front(), + Append())([&] { emitTransferWrite(majorIvsPlusOffsets); }); + } else { + // 2.b. Guaranteed in-bounds. Progressively lower to a 1-D transfer write. + emitTransferWrite(majorIvsPlusOffsets); + } }); rewriter.eraseOp(op); @@ -351,81 +428,20 @@ return clippedScalarAccessExprs; } -namespace { - -/// Implements lowering of TransferReadOp and TransferWriteOp to a -/// proper abstraction for the hardware. -/// -/// For now, we only emit a simple loop nest that performs clipped pointwise -/// copies from a remote to a locally allocated memory.
-/// -/// Consider the case: -/// -/// ```mlir -/// // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into -/// // vector<32x256xf32> and pad with %f0 to handle the boundary case: -/// %f0 = constant 0.0f : f32 -/// scf.for %i0 = 0 to %0 { -/// scf.for %i1 = 0 to %1 step %c256 { -/// scf.for %i2 = 0 to %2 step %c32 { -/// %v = vector.transfer_read %A[%i0, %i1, %i2], %f0 -/// {permutation_map: (d0, d1, d2) -> (d2, d1)} : -/// memref, vector<32x256xf32> -/// }}} -/// ``` -/// -/// The rewriters construct loop and indices that access MemRef A in a pattern -/// resembling the following (while guaranteeing an always full-tile -/// abstraction): -/// -/// ```mlir -/// scf.for %d2 = 0 to %c256 { -/// scf.for %d1 = 0 to %c32 { -/// %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32 -/// %tmp[%d2, %d1] = %s -/// } -/// } -/// ``` -/// -/// In the current state, only a clipping transfer is implemented by `clip`, -/// which creates individual indexing expressions of the form: -/// -/// ```mlir-dsc -/// auto condMax = i + ii < N; -/// auto max = std_select(condMax, i + ii, N - one) -/// auto cond = i + ii < zero; -/// std_select(cond, zero, max); -/// ``` -/// -/// In the future, clipping should not be the only way and instead we should -/// load vectors + mask them. Similarly on the write side, load/mask/store for -/// implementing RMW behavior. -/// -/// Lowers TransferOp into a combination of: -/// 1. local memory allocation; -/// 2. perfect loop nest over: -/// a. scalar load/stores from local buffers (viewed as a scalar memref); -/// a. scalar store/load to original memref (with clipping). -/// 3. vector_load/store -/// 4. local memory deallocation. -/// Minor variations occur depending on whether a TransferReadOp or -/// a TransferWriteOp is rewritten. 
template -struct VectorTransferRewriter : public RewritePattern { - explicit VectorTransferRewriter(MLIRContext *context) - : RewritePattern(TransferOpTy::getOperationName(), 1, context) {} - - /// Used for staging the transfer in a local scalar buffer. - MemRefType tmpMemRefType(TransferOpTy transfer) const { - auto vectorType = transfer.getVectorType(); - return MemRefType::get(vectorType.getShape(), vectorType.getElementType(), - {}, 0); - } +VectorTransferRewriter::VectorTransferRewriter( + VectorTransferToSCFOptions options, MLIRContext *context) + : RewritePattern(TransferOpTy::getOperationName(), 1, context), + options(options) {} - /// Performs the rewrite. - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; -}; +/// Used for staging the transfer in a local buffer. +template +MemRefType VectorTransferRewriter::tmpMemRefType( + TransferOpTy transfer) const { + auto vectorType = transfer.getVectorType(); + return MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, + 0); +} /// Lowers TransferReadOp into a combination of: /// 1. local memory allocation; @@ -479,7 +495,8 @@ if (AffineMap::isMinorIdentity(transfer.permutation_map())) { // If > 1D, emit a bunch of loops around 1-D vector transfers. if (transfer.getVectorType().getRank() > 1) - return NDTransferOpHelper(rewriter, transfer).doReplace(); + return NDTransferOpHelper(rewriter, transfer, options) + .doReplace(); // If 1-D this is now handled by the target-specific lowering. if (transfer.getVectorType().getRank() == 1) return failure(); @@ -551,7 +568,7 @@ if (AffineMap::isMinorIdentity(transfer.permutation_map())) { // If > 1D, emit a bunch of loops around 1-D vector transfers. if (transfer.getVectorType().getRank() > 1) - return NDTransferOpHelper(rewriter, transfer) + return NDTransferOpHelper(rewriter, transfer, options) .doReplace(); // If 1-D this is now handled by the target-specific lowering. 
if (transfer.getVectorType().getRank() == 1) @@ -596,10 +613,10 @@ return success(); } -} // namespace - void mlir::populateVectorToSCFConversionPatterns( - OwningRewritePatternList &patterns, MLIRContext *context) { + OwningRewritePatternList &patterns, MLIRContext *context, + const VectorTransferToSCFOptions &options) { patterns.insert, - VectorTransferRewriter>(context); + VectorTransferRewriter>(options, + context); } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -470,6 +470,16 @@ result.addAttribute(getPositionAttrName(), positionAttr); } +// Convenience builder which assumes the values are constant indices. +void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, + Value source, ValueRange position) { + SmallVector positionConstants = + llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { + return pos.getDefiningOp().getValue(); + })); + build(builder, result, source, positionConstants); +} + static void print(OpAsmPrinter &p, vector::ExtractOp op) { p << op.getOperationName() << " " << op.vector() << op.position(); p.printOptionalAttrDict(op.getAttrs(), {"position"}); @@ -739,6 +749,16 @@ result.addAttribute(getPositionAttrName(), positionAttr); } +// Convenience builder which assumes the values are constant indices. 
+void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, + Value dest, ValueRange position) { + SmallVector positionConstants = + llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { + return pos.getDefiningOp().getValue(); + })); + build(builder, result, source, dest, positionConstants); +} + static LogicalResult verify(InsertOp op) { auto positionAttr = op.position().getValue(); if (positionAttr.empty()) diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -39,33 +39,6 @@ using namespace mlir; using llvm::dbgs; -/// Given a shape with sizes greater than 0 along all dimensions, -/// returns the distance, in number of elements, between a slice in a dimension -/// and the next slice in the same dimension. -/// e.g. shape[3, 4, 5] -> linearization_basis[20, 5, 1] -static SmallVector computeStrides(ArrayRef shape) { - if (shape.empty()) - return {}; - SmallVector tmp; - tmp.reserve(shape.size()); - int64_t running = 1; - for (auto size : llvm::reverse(shape)) { - assert(size > 0 && "size must be nonnegative"); - tmp.push_back(running); - running *= size; - } - return SmallVector(tmp.rbegin(), tmp.rend()); -} - -static int64_t computeMaxLinearIndex(ArrayRef basis) { - if (basis.empty()) - return 0; - int64_t res = 1; - for (auto b : basis) - res *= b; - return res; -} - // Clones `op` into a new operations that takes `operands` and returns // `resultTypes`. static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp @@ -28,6 +28,34 @@ using namespace mlir; +/// Return the number of elements of basis, `0` if empty. 
+int64_t mlir::computeMaxLinearIndex(ArrayRef basis) { + if (basis.empty()) + return 0; + int64_t res = 1; + for (auto b : basis) + res *= b; + return res; +} + +/// Given a shape with sizes greater than 0 along all dimensions, +/// return the distance, in number of elements, between a slice in a dimension +/// and the next slice in the same dimension. +/// e.g. shape[3, 4, 5] -> linearization_basis[20, 5, 1] +SmallVector mlir::computeStrides(ArrayRef shape) { + if (shape.empty()) + return {}; + SmallVector tmp; + tmp.reserve(shape.size()); + int64_t running = 1; + for (auto size : llvm::reverse(shape)) { + assert(size > 0 && "size must be nonnegative"); + tmp.push_back(running); + running *= size; + } + return SmallVector(tmp.rbegin(), tmp.rend()); +} + SmallVector mlir::computeStrides(ArrayRef shape, ArrayRef sizes) { int64_t rank = shape.size(); diff --git a/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir b/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir --- a/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir +++ b/mlir/test/Conversion/VectorToLoops/vector-to-loops.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -test-convert-vector-to-scf -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-convert-vector-to-scf=full-unroll=true -split-input-file | FileCheck %s --check-prefix=FULL-UNROLL // CHECK-LABEL: func @materialize_read_1d() { func @materialize_read_1d() { @@ -213,32 +214,74 @@ // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1)> +// FULL-UNROLL-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1)> +// FULL-UNROLL-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 1)> +// FULL-UNROLL-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 + 2)> + + // CHECK-LABEL: transfer_read_progressive( // CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: memref, // CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index -func @transfer_read_progressive(%A : memref, %base: index) -> vector<17x15xf32> { +// FULL-UNROLL-LABEL: 
transfer_read_progressive( +// FULL-UNROLL-SAME: %[[A:[a-zA-Z0-9]+]]: memref, +// FULL-UNROLL-SAME: %[[base:[a-zA-Z0-9]+]]: index +func @transfer_read_progressive(%A : memref, %base: index) -> vector<3x15xf32> { // CHECK: %[[cst:.*]] = constant 7.000000e+00 : f32 %f7 = constant 7.0: f32 // CHECK-DAG: %[[splat:.*]] = constant dense<7.000000e+00> : vector<15xf32> - // CHECK-DAG: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>> + // CHECK-DAG: %[[alloc:.*]] = alloc() : memref<3xvector<15xf32>> // CHECK-DAG: %[[dim:.*]] = dim %[[A]], 0 : memref - // CHECK: affine.for %[[I:.*]] = 0 to 17 { + // CHECK: affine.for %[[I:.*]] = 0 to 3 { // CHECK: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]] // CHECK: %[[cond1:.*]] = cmpi "slt", %[[add]], %[[dim]] : index // CHECK: scf.if %[[cond1]] { // CHECK: %[[vec_1d:.*]] = vector.transfer_read %[[A]][%[[add]], %[[base]]], %[[cst]] : memref, vector<15xf32> - // CHECK: store %[[vec_1d]], %[[alloc]][%[[I]]] : memref<17xvector<15xf32>> + // CHECK: store %[[vec_1d]], %[[alloc]][%[[I]]] : memref<3xvector<15xf32>> // CHECK: } else { - // CHECK: store %[[splat]], %[[alloc]][%[[I]]] : memref<17xvector<15xf32>> + // CHECK: store %[[splat]], %[[alloc]][%[[I]]] : memref<3xvector<15xf32>> // CHECK: } - // CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref> - // CHECK: %[[cst:.*]] = load %[[vmemref]][] : memref> - %f = vector.transfer_read %A[%base, %base], %f7 - {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : - memref, vector<17x15xf32> + // CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref> + // CHECK: %[[res:.*]] = load %[[vmemref]][] : memref> + + // FULL-UNROLL: %[[pad:.*]] = constant 7.000000e+00 : f32 + // FULL-UNROLL: %[[VEC0:.*]] = constant dense<7.000000e+00> : vector<3x15xf32> + // FULL-UNROLL: %[[SPLAT:.*]] = constant dense<7.000000e+00> : vector<15xf32> + // FULL-UNROLL: %[[DIM:.*]] = dim %[[A]], 0 : memref + // FULL-UNROLL: cmpi "slt", 
%[[base]], %[[DIM]] : index + // FULL-UNROLL: %[[VEC1:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) { + // FULL-UNROLL: vector.transfer_read %[[A]][%[[base]], %[[base]]], %[[pad]] : memref, vector<15xf32> + // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC0]] [0] : vector<15xf32> into vector<3x15xf32> + // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32> + // FULL-UNROLL: } else { + // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC0]] [0] : vector<15xf32> into vector<3x15xf32> + // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32> + // FULL-UNROLL: } + // FULL-UNROLL: affine.apply #[[MAP1]]()[%[[base]]] + // FULL-UNROLL: cmpi "slt", %{{.*}}, %[[DIM]] : index + // FULL-UNROLL: %[[VEC2:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) { + // FULL-UNROLL: vector.transfer_read %[[A]][%{{.*}}, %[[base]]], %[[pad]] : memref, vector<15xf32> + // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC1]] [1] : vector<15xf32> into vector<3x15xf32> + // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32> + // FULL-UNROLL: } else { + // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC1]] [1] : vector<15xf32> into vector<3x15xf32> + // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32> + // FULL-UNROLL: } + // FULL-UNROLL: affine.apply #[[MAP2]]()[%[[base]]] + // FULL-UNROLL: cmpi "slt", %{{.*}}, %[[DIM]] : index + // FULL-UNROLL: %[[VEC3:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) { + // FULL-UNROLL: vector.transfer_read %[[A]][%{{.*}}, %[[base]]], %[[pad]] : memref, vector<15xf32> + // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC2]] [2] : vector<15xf32> into vector<3x15xf32> + // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32> + // FULL-UNROLL: } else { + // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC2]] [2] : vector<15xf32> into vector<3x15xf32> + // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32> + // FULL-UNROLL: } - return %f: vector<17x15xf32> + %f = vector.transfer_read %A[%base, %base], %f7 : + memref, vector<3x15xf32> + + return %f: vector<3x15xf32> } // ----- @@ -246,25 +289,52 @@ // CHECK-DAG: 
#[[MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1)> +// FULL-UNROLL-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1)> +// FULL-UNROLL-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 1)> +// FULL-UNROLL-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 + 2)> + // CHECK-LABEL: transfer_write_progressive( // CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: memref, // CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index, -// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32> -func @transfer_write_progressive(%A : memref, %base: index, %vec: vector<17x15xf32>) { - // CHECK: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>> - // CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref> - // CHECK: store %[[vec]], %[[vmemref]][] : memref> +// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32> +// FULL-UNROLL-LABEL: transfer_write_progressive( +// FULL-UNROLL-SAME: %[[A:[a-zA-Z0-9]+]]: memref, +// FULL-UNROLL-SAME: %[[base:[a-zA-Z0-9]+]]: index, +// FULL-UNROLL-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32> +func @transfer_write_progressive(%A : memref, %base: index, %vec: vector<3x15xf32>) { + // CHECK: %[[alloc:.*]] = alloc() : memref<3xvector<15xf32>> + // CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref> + // CHECK: store %[[vec]], %[[vmemref]][] : memref> // CHECK: %[[dim:.*]] = dim %[[A]], 0 : memref - // CHECK: affine.for %[[I:.*]] = 0 to 17 { + // CHECK: affine.for %[[I:.*]] = 0 to 3 { // CHECK: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]] // CHECK: %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index // CHECK: scf.if %[[cmp]] { - // CHECK: %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>> + // CHECK: %[[vec_1d:.*]] = load %[[alloc]][%[[I]]] : memref<3xvector<15xf32>> // CHECK: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref // CHECK: } - vector.transfer_write %vec, %A[%base, %base] - {permutation_map = 
affine_map<(d0, d1) -> (d0, d1)>} : - vector<17x15xf32>, memref + + // FULL-UNROLL: %[[DIM:.*]] = dim %[[A]], 0 : memref + // FULL-UNROLL: %[[CMP0:.*]] = cmpi "slt", %[[base]], %[[DIM]] : index + // FULL-UNROLL: scf.if %[[CMP0]] { + // FULL-UNROLL: %[[V0:.*]] = vector.extract %[[vec]][0] : vector<3x15xf32> + // FULL-UNROLL: vector.transfer_write %[[V0]], %[[A]][%[[base]], %[[base]]] : vector<15xf32>, memref + // FULL-UNROLL: } + // FULL-UNROLL: %[[I1:.*]] = affine.apply #[[MAP1]]()[%[[base]]] + // FULL-UNROLL: %[[CMP1:.*]] = cmpi "slt", %[[I1]], %[[DIM]] : index + // FULL-UNROLL: scf.if %[[CMP1]] { + // FULL-UNROLL: %[[V1:.*]] = vector.extract %[[vec]][1] : vector<3x15xf32> + // FULL-UNROLL: vector.transfer_write %[[V1]], %[[A]][%[[I1]], %[[base]]] : vector<15xf32>, memref + // FULL-UNROLL: } + // FULL-UNROLL: %[[I2:.*]] = affine.apply #[[MAP2]]()[%[[base]]] + // FULL-UNROLL: %[[CMP2:.*]] = cmpi "slt", %[[I2]], %[[DIM]] : index + // FULL-UNROLL: scf.if %[[CMP2]] { + // FULL-UNROLL: %[[V2:.*]] = vector.extract %[[vec]][2] : vector<3x15xf32> + // FULL-UNROLL: vector.transfer_write %[[V2]], %[[A]][%[[I2]], %[[base]]] : vector<15xf32>, memref + // FULL-UNROLL: } + + vector.transfer_write %vec, %A[%base, %base] : + vector<3x15xf32>, memref return } @@ -273,20 +343,37 @@ // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1)> +// FULL-UNROLL-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1)> +// FULL-UNROLL-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 1)> +// FULL-UNROLL-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 + 2)> + // CHECK-LABEL: transfer_write_progressive_not_masked( // CHECK-SAME: %[[A:[a-zA-Z0-9]+]]: memref, // CHECK-SAME: %[[base:[a-zA-Z0-9]+]]: index, -// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<17x15xf32> -func @transfer_write_progressive_not_masked(%A : memref, %base: index, %vec: vector<17x15xf32>) { +// CHECK-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32> +// FULL-UNROLL-LABEL: 
transfer_write_progressive_not_masked( +// FULL-UNROLL-SAME: %[[A:[a-zA-Z0-9]+]]: memref, +// FULL-UNROLL-SAME: %[[base:[a-zA-Z0-9]+]]: index, +// FULL-UNROLL-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32> +func @transfer_write_progressive_not_masked(%A : memref, %base: index, %vec: vector<3x15xf32>) { // CHECK-NOT: scf.if - // CHECK-NEXT: %[[alloc:.*]] = alloc() : memref<17xvector<15xf32>> - // CHECK-NEXT: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<17xvector<15xf32>> to memref> - // CHECK-NEXT: store %[[vec]], %[[vmemref]][] : memref> - // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 17 { + // CHECK-NEXT: %[[alloc:.*]] = alloc() : memref<3xvector<15xf32>> + // CHECK-NEXT: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref> + // CHECK-NEXT: store %[[vec]], %[[vmemref]][] : memref> + // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 3 { // CHECK-NEXT: %[[add:.*]] = affine.apply #[[MAP0]](%[[I]])[%[[base]]] - // CHECK-NEXT: %[[vec_1d:.*]] = load %0[%[[I]]] : memref<17xvector<15xf32>> + // CHECK-NEXT: %[[vec_1d:.*]] = load %[[alloc]][%[[I]]] : memref<3xvector<15xf32>> // CHECK-NEXT: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref + + // FULL-UNROLL: %[[VEC0:.*]] = vector.extract %[[vec]][0] : vector<3x15xf32> + // FULL-UNROLL: vector.transfer_write %[[VEC0]], %[[A]][%[[base]], %[[base]]] : vector<15xf32>, memref + // FULL-UNROLL: %[[I1:.*]] = affine.apply #[[MAP1]]()[%[[base]]] + // FULL-UNROLL: %[[VEC1:.*]] = vector.extract %[[vec]][1] : vector<3x15xf32> + // FULL-UNROLL: vector.transfer_write %[[VEC1]], %[[A]][%[[I1]], %[[base]]] : vector<15xf32>, memref + // FULL-UNROLL: %[[I2:.*]] = affine.apply #[[MAP2]]()[%[[base]]] + // FULL-UNROLL: %[[VEC2:.*]] = vector.extract %[[vec]][2] : vector<3x15xf32> + // FULL-UNROLL: vector.transfer_write %[[VEC2]], %[[A]][%[[I2]], %[[base]]] : vector<15xf32>, memref vector.transfer_write %vec, %A[%base, %base] {masked = [false, false]} : - vector<17x15xf32>, memref + 
vector<3x15xf32>, memref return } diff --git a/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp b/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp --- a/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp +++ b/mlir/test/lib/Transforms/TestVectorToSCFConversion.cpp @@ -19,10 +19,20 @@ struct TestVectorToSCFPass : public PassWrapper { + TestVectorToSCFPass() = default; + TestVectorToSCFPass(const TestVectorToSCFPass &pass) {} + + Option fullUnroll{ + *this, "full-unroll", + llvm::cl::desc( + "Perform full unrolling when converting vector transfers to SCF"), + llvm::cl::init(false)}; + void runOnFunction() override { OwningRewritePatternList patterns; auto *context = &getContext(); - populateVectorToSCFConversionPatterns(patterns, context); + populateVectorToSCFConversionPatterns( + patterns, context, VectorTransferToSCFOptions().setUnroll(fullUnroll)); applyPatternsAndFoldGreedily(getFunction(), patterns); } };