diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -56,22 +56,48 @@
 };
 
 /// Enum to control the lowering of `vector.transpose` operations.
 enum class VectorTransposeLowering {
-  // Lower transpose into element-wise extract and inserts.
+  /// Lower transpose into element-wise extract and inserts.
   EltWise = 0,
   /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
   /// intrinsics.
   Flat = 1,
 };
+/// Enum to control the splitting of `vector.transfer` operations into masked
+/// and unmasked variants.
+enum class VectorTransferSplit {
+  /// Do not split vector transfer operations.
+  None = 0,
+  /// Split using masked + unmasked vector.transfer operations.
+  VectorTransfer = 1,
+  /// Split using an unmasked vector.transfer + linalg.fill + linalg.copy
+  /// operations.
+  LinalgCopy = 2,
+  /// Do not split the vector transfer operation but instead mark it as "unmasked".
+  ForceUnmasked = 3
+};
 /// Structure to control the behavior of vector transform patterns.
 struct VectorTransformsOptions {
+  /// Option to control the lowering of vector.contract.
   VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
-  VectorTransposeLowering vectorTransposeLowering =
-      VectorTransposeLowering::EltWise;
   VectorTransformsOptions &
   setVectorTransformsOptions(VectorContractLowering opt) {
     vectorContractLowering = opt;
     return *this;
   }
+  /// Option to control the lowering of vector.transpose.
+  VectorTransposeLowering vectorTransposeLowering =
+      VectorTransposeLowering::EltWise;
+  VectorTransformsOptions &
+  setVectorTransposeLowering(VectorTransposeLowering opt) {
+    vectorTransposeLowering = opt;
+    return *this;
+  }
+  /// Option to control the splitting of vector transfers.
+  VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None;
+  VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) {
+    vectorTransferSplit = opt;
+    return *this;
+  }
 };
 
 /// Collect a set of transformation patterns that are related to contracting
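The three options above are independent and the setters are chainable. A minimal sketch of configuring them from client code (illustration only, not part of the patch; it assumes the MLIR revision this patch applies to and uses only the declarations added above):

```
#include "mlir/Dialect/Vector/VectorOps.h"

using namespace mlir;
using namespace mlir::vector;

// Build a VectorTransformsOptions that lowers vector.transpose to
// vector.flat_transpose and splits vector.transfer ops with linalg.fill +
// linalg.copy on the slow path. Each setter returns *this, so calls chain.
static VectorTransformsOptions makeVectorTransformsOptions() {
  VectorTransformsOptions options;
  options.setVectorTransposeLowering(VectorTransposeLowering::Flat)
      .setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
  return options;
}
```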
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -109,13 +109,13 @@
   FilterConstraintType filter;
 };
 
-/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
-/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
-/// is `success, the `ifOp` points to the newly created conditional upon
-/// function return. To accomodate for the fact that the original
-/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
-/// in the temporary buffer, the scf.if op returns a view and values of type
-/// index. At this time, only vector.transfer_read is implemented.
+/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
+/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
+/// newly created conditional upon function return.
+/// To accommodate for the fact that the original vector.transfer indexing may
+/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer,
+/// the scf.if op returns a view and values of type index.
+/// At this time, only the vector.transfer_read case is implemented.
 ///
 /// Example (a 2-D vector.transfer_read):
 /// ```
@@ -124,17 +124,17 @@
 /// is transformed into:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
-///       scf.yield %0 : memref<A...>, index, index
-///     } else {
-///       %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
-///       %3 = vector.type_cast %extra_alloc : memref<...> to
-///       memref<vector<...>> store %2, %3[] : memref<vector<...>> %4 =
-///       memref_cast %extra_alloc: memref<B...> to memref<A...> scf.yield %4 :
-///       memref<A...>, index, index
+///      // fastpath, direct cast
+///      memref_cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view : compatibleMemRefType, index, index
+///    } else {
+///      // slowpath, masked vector.transfer or linalg.copy.
+///      memref_cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4 : compatibleMemRefType, index, index
 //     }
 ///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
 /// ```
-/// where `extra_alloc` is a top of the function alloca'ed buffer of one vector.
+/// where `alloc` is a top-of-function alloca'ed buffer of one vector.
 ///
 /// Preconditions:
 ///  1. `xferOp.permutation_map()` must be a minor identity map
@@ -143,9 +143,10 @@
 ///     rank-reducing subviews.
 LogicalResult
 splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp);
-LogicalResult splitFullAndPartialTransfer(OpBuilder &b,
-                                          VectorTransferOpInterface xferOp,
-                                          scf::IfOp *ifOp = nullptr);
+LogicalResult splitFullAndPartialTransfer(
+    OpBuilder &b, VectorTransferOpInterface xferOp,
+    VectorTransformsOptions options = VectorTransformsOptions(),
+    scf::IfOp *ifOp = nullptr);
 
 /// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
 /// may take an extra filter to perform selection at a finer granularity.
@@ -155,16 +156,19 @@
   explicit VectorTransferFullPartialRewriter(
       MLIRContext *context,
+      VectorTransformsOptions options = VectorTransformsOptions(),
       FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
       PatternBenefit benefit = 1)
-      : RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter) {}
+      : RewritePattern(benefit, MatchAnyOpTypeTag()), options(options),
+        filter(filter) {}
 
   /// Performs the rewrite.
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override;
 
 private:
+  VectorTransformsOptions options;
   FilterConstraintType filter;
 };
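Note that a default-constructed `VectorTransformsOptions` uses `VectorTransferSplit::None`, for which the split bails out, so callers that want the previous masked/unmasked split must now pass `VectorTransferSplit::VectorTransfer` explicitly. A hedged sketch of wiring the option into a pattern set, mirroring what the updated test pass at the end of this patch does (the helper name and includes are illustrative assumptions):

```
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::vector;

// Illustrative helper: populate `patterns` so a FuncOp pass can run them with
// applyPatternsAndFoldGreedily, splitting transfers via linalg.fill/copy.
static void populateTransferSplitPatterns(MLIRContext *ctx,
                                          OwningRewritePatternList &patterns) {
  VectorTransformsOptions options;
  options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
  patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
}
```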
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -16,6 +16,7 @@
   MLIRIR
   MLIRStandardOps
   MLIRAffineOps
+  MLIRLinalgOps
   MLIRSCF
   MLIRLoopAnalysis
   MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -2056,7 +2057,16 @@
   return success();
 }
 
-MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
+/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
+/// be cast. If the MemRefTypes don't have the same rank or are not strided,
+/// return null; otherwise:
+///   1. if `aT` and `bT` are cast-compatible, return `aT`.
+///   2. else return a new MemRefType obtained by iterating over the shape and
+///      strides and:
+///      a. keeping the ones that are static and equal across `aT` and `bT`.
+///      b. using a dynamic shape and/or stride for the dimensions that don't
+///         agree.
+static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
   if (MemRefCastOp::areCastCompatible(aT, bT))
     return aT;
   if (aT.getRank() != bT.getRank())
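The merge rule in the comment above is purely element-wise over shapes and strides. A standalone sketch of that rule on plain integers, with -1 standing in for MLIR's dynamic size/stride markers (illustration only, not the MLIR implementation):

```
#include <cstdint>
#include <vector>

// Per-dimension merge used when `aT` and `bT` are not directly cast
// compatible: keep a value only if it is static and equal in both types,
// otherwise fall back to "dynamic" (-1 here).
static std::vector<int64_t> mergeShapesOrStrides(const std::vector<int64_t> &a,
                                                 const std::vector<int64_t> &b) {
  std::vector<int64_t> result(a.size());
  for (size_t i = 0, e = a.size(); i < e; ++i)
    result[i] = (a[i] == b[i] && a[i] != -1) ? a[i] : -1;
  return result;
}

// For example, shapes {7, 8} and {4, 8} merge to {-1, 8}, and strides {-1, 1}
// and {8, 1} merge to {-1, 1}: the memref<?x8xf32, strided [?, 1]> case
// exercised by the strided test below.
```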
@@ -2086,13 +2096,154 @@
                              makeStridedLinearLayoutMap(resStrides, resOffset,
                                                         aT.getContext()));
 }
 
-/// Split a vector.transfer operation into an unmasked fastpath vector.transfer
-/// and a slowpath masked vector.transfer. If `ifOp` is not null and the result
-/// is `success, the `ifOp` points to the newly created conditional upon
-/// function return. To accomodate for the fact that the original
-/// vector.transfer indexing may be arbitrary and the slow path indexes @[0...0]
-/// in the temporary buffer, the scf.if op returns a view and values of type
-/// index. At this time, only vector.transfer_read is implemented.
+/// Operates under a scoped context to build the intersection between the
+/// view `xferOp.memref()` @ `xferOp.indices()` and the view `alloc`.
+// TODO: view intersection/union/differences should be a proper std op.
+static Value createScopedSubViewIntersection(VectorTransferOpInterface xferOp,
+                                             Value alloc) {
+  using namespace edsc::intrinsics;
+  int64_t memrefRank = xferOp.getMemRefType().getRank();
+  // TODO: relax this precondition, will require rank-reducing subviews.
+  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
+         "Expected memref rank to match the alloc rank");
+  Value one = std_constant_index(1);
+  ValueRange leadingIndices =
+      xferOp.indices().take_front(xferOp.getLeadingMemRefRank());
+  SmallVector<Value, 4> sizes;
+  sizes.append(leadingIndices.begin(), leadingIndices.end());
+  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    Value dimMemRef = std_dim(xferOp.memref(), indicesIdx);
+    Value dimAlloc = std_dim(alloc, resultIdx);
+    Value index = xferOp.indices()[indicesIdx];
+    AffineExpr i, j, k;
+    bindDims(xferOp.getContext(), i, j, k);
+    SmallVector<AffineMap, 4> maps =
+        AffineMap::inferFromExprList(MapList{{i - j, k}});
+    // affine_min(%dimMemRef - %index, %dimAlloc)
+    Value affineMin = affine_min(index.getType(), maps[0],
+                                 ValueRange{dimMemRef, index, dimAlloc});
+    sizes.push_back(affineMin);
+  });
+  return std_sub_view(xferOp.memref(), xferOp.indices(), sizes,
+                      SmallVector<Value, 4>(memrefRank, one));
+}
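Per transfer dimension, the subview built above is clamped so the slow path never reads past the end of the source view and never copies more than one vector's worth of data. A standalone sketch of the size computation encoded by the `affine_min` (plain integers, illustration only):

```
#include <algorithm>
#include <cstdint>

// Size of the intersection between the source view starting at `index` in a
// dimension of extent `dimMemRef`, and the staging alloc of extent `dimAlloc`.
// Mirrors affine_min(%dimMemRef - %index, %dimAlloc) in the code above.
static int64_t intersectionSize(int64_t dimMemRef, int64_t index,
                                int64_t dimAlloc) {
  return std::min(dimMemRef - index, dimAlloc);
}

// E.g. for a 4x8 vector read at row 6 of a 7-row memref, only
// intersectionSize(7, 6, 4) == 1 row is copied into the temporary buffer; the
// remaining rows keep the padding value written by linalg.fill.
```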
+
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      memref_cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view, ... : compatibleMemRefType, index, index
+///    } else {
+///      %2 = linalg.fill(%alloc, %pad)
+///      %3 = subview %view [...][...][...]
+///      linalg.copy(%3, %alloc)
+///      memref_cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4, ... : compatibleMemRefType, index, index
+///    }
+/// ```
+/// Return the produced scf::IfOp.
+static scf::IfOp createScopedFullPartialLinalgCopy(
+    vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
+    MemRefType compatibleMemRefType, Value alloc) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+  scf::IfOp fullPartialIfOp;
+  Value zero = std_constant_index(0);
+  Value memref = xferOp.memref();
+  conditionBuilder(
+      returnTypes, inBoundsCond,
+      [&]() -> scf::ValueVector {
+        Value res = memref;
+        if (compatibleMemRefType != xferOp.getMemRefType())
+          res = std_memref_cast(memref, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{res};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+                              xferOp.indices().end());
+        return viewAndIndices;
+      },
+      [&]() -> scf::ValueVector {
+        linalg_fill(alloc, xferOp.padding());
+        // Take partial subview of memref which guarantees no dimension
+        // overflows.
+        Value memRefSubView = createScopedSubViewIntersection(
+            cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
+        linalg_copy(memRefSubView, alloc);
+        Value casted = std_memref_cast(alloc, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{casted};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+                              zero);
+        return viewAndIndices;
+      },
+      &fullPartialIfOp);
+  return fullPartialIfOp;
+}
+
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      memref_cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view, ... : compatibleMemRefType, index, index
+///    } else {
+///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
+///      %3 = vector.type_cast %alloc :
+///        memref<...> to memref<vector<...>>
+///      store %2, %3[] : memref<vector<...>>
+///      %4 = memref_cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4, ... : compatibleMemRefType, index, index
+///    }
+/// ```
+/// Return the produced scf::IfOp.
+static scf::IfOp createScopedFullPartialVectorTransferRead(
+    vector::TransferReadOp xferOp, TypeRange returnTypes, Value inBoundsCond,
+    MemRefType compatibleMemRefType, Value alloc) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+  scf::IfOp fullPartialIfOp;
+  Value zero = std_constant_index(0);
+  Value memref = xferOp.memref();
+  conditionBuilder(
+      returnTypes, inBoundsCond,
+      [&]() -> scf::ValueVector {
+        Value res = memref;
+        if (compatibleMemRefType != xferOp.getMemRefType())
+          res = std_memref_cast(memref, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{res};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+                              xferOp.indices().end());
+        return viewAndIndices;
+      },
+      [&]() -> scf::ValueVector {
+        Operation *newXfer =
+            ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
+        Value vector = cast<vector::TransferReadOp>(newXfer).vector();
+        std_store(vector, vector_type_cast(
+                              MemRefType::get({}, vector.getType()), alloc));
+
+        Value casted = std_memref_cast(alloc, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{casted};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+                              zero);
+
+        return viewAndIndices;
+      },
+      &fullPartialIfOp);
+  return fullPartialIfOp;
+}
+
+/// Split a vector.transfer operation into an unmasked fastpath and a slowpath.
+/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
+/// newly created conditional upon function return.
+/// To accommodate for the fact that the original vector.transfer indexing may
+/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer,
+/// the scf.if op returns a view and values of type index.
+/// At this time, only the vector.transfer_read case is implemented.
 ///
 /// Example (a 2-D vector.transfer_read):
 /// ```
@@ -2101,17 +2252,17 @@
 /// is transformed into:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
-///       scf.yield %0 : memref<A...>, index, index
-///     } else {
-///       %2 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
-///       %3 = vector.type_cast %extra_alloc : memref<...> to
-///       memref<vector<...>> store %2, %3[] : memref<vector<...>> %4 =
-///       memref_cast %extra_alloc: memref<B...> to memref<A...> scf.yield %4 :
-///       memref<A...>, index, index
+///      // fastpath, direct cast
+///      memref_cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view : compatibleMemRefType, index, index
+///    } else {
+///      // slowpath, masked vector.transfer or linalg.copy.
+///      memref_cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4 : compatibleMemRefType, index, index
 //     }
 ///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {masked = [false ... false]}
 /// ```
-/// where `extra_alloc` is a top of the function alloca'ed buffer of one vector.
+/// where `alloc` is a top-of-function alloca'ed buffer of one vector.
 ///
 /// Preconditions:
 ///  1. `xferOp.permutation_map()` must be a minor identity map
@@ -2119,10 +2270,21 @@
 ///     must be equal. This will be relaxed in the future but requires
 ///     rank-reducing subviews.
 LogicalResult mlir::vector::splitFullAndPartialTransfer(
-    OpBuilder &b, VectorTransferOpInterface xferOp, scf::IfOp *ifOp) {
+    OpBuilder &b, VectorTransferOpInterface xferOp,
+    VectorTransformsOptions options, scf::IfOp *ifOp) {
   using namespace edsc;
   using namespace edsc::intrinsics;
 
+  if (options.vectorTransferSplit == VectorTransferSplit::None)
+    return failure();
+
+  SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
+  auto unmaskedAttr = b.getBoolArrayAttr(bools);
+  if (options.vectorTransferSplit == VectorTransferSplit::ForceUnmasked) {
+    xferOp.setAttr(vector::TransferReadOp::getMaskedAttrName(), unmaskedAttr);
+    return success();
+  }
+
   assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
          "Expected splitFullAndPartialTransferPrecondition to hold");
   auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
@@ -2154,45 +2316,21 @@
                          b.getI64IntegerAttr(32));
   }
 
-  Value memref = xferOp.memref();
-  SmallVector<bool, 4> bools(xferOp.getTransferRank(), false);
-  auto unmaskedAttr = b.getBoolArrayAttr(bools);
-
   MemRefType compatibleMemRefType = getCastCompatibleMemRefType(
       xferOp.getMemRefType(), alloc.getType().cast<MemRefType>());
 
   // Read case: full fill + partial copy -> unmasked vector.xfer_read.
-  Value zero = std_constant_index(0);
   SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                    b.getIndexType());
   returnTypes[0] = compatibleMemRefType;
-  scf::IfOp fullPartialIfOp;
-  conditionBuilder(
-      returnTypes, inBoundsCond,
-      [&]() -> scf::ValueVector {
-        Value res = memref;
-        if (compatibleMemRefType != xferOp.getMemRefType())
-          res = std_memref_cast(memref, compatibleMemRefType);
-        scf::ValueVector viewAndIndices{res};
-        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
-                              xferOp.indices().end());
-        return viewAndIndices;
-      },
-      [&]() -> scf::ValueVector {
-        Operation *newXfer =
-            ScopedContext::getBuilderRef().clone(*xferOp.getOperation());
-        Value vector = cast<vector::TransferReadOp>(newXfer).vector();
-        std_store(vector, vector_type_cast(
-                              MemRefType::get({}, vector.getType()), alloc));
-
-        Value casted = std_memref_cast(alloc, compatibleMemRefType);
-        scf::ValueVector viewAndIndices{casted};
-        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
-                              zero);
-
-        return viewAndIndices;
-      },
-      &fullPartialIfOp);
+  scf::IfOp fullPartialIfOp =
+      options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
+          ? createScopedFullPartialVectorTransferRead(
+                xferReadOp, returnTypes, inBoundsCond, compatibleMemRefType,
+                alloc)
+          : createScopedFullPartialLinalgCopy(xferReadOp, returnTypes,
+                                              inBoundsCond,
+                                              compatibleMemRefType, alloc);
   if (ifOp)
     *ifOp = fullPartialIfOp;
 
@@ -2211,7 +2349,7 @@
       failed(filter(xferOp)))
     return failure();
   rewriter.startRootUpdate(xferOp);
-  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp))) {
+  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
     rewriter.finalizeRootUpdate(xferOp);
     return success();
   }
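The new `options` parameter also makes the entry point easy to drive outside the pattern. A hedged sketch of calling it directly (the helper is illustrative; the signatures are the ones declared in this patch). With `ForceUnmasked` the function only re-annotates the op and returns success without creating an `scf.if`:

```
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"

using namespace mlir;
using namespace mlir::vector;

// Illustrative helper: mark a single vector.transfer as unmasked via the
// ForceUnmasked split mode. The optional scf::IfOp* out-parameter is left at
// its default since no conditional is created in this mode.
static LogicalResult forceUnmask(OpBuilder &b,
                                 VectorTransferOpInterface xferOp) {
  VectorTransformsOptions options;
  options.setVectorTransferSplit(VectorTransferSplit::ForceUnmasked);
  return splitFullAndPartialTransfer(b, xferOp, options);
}
```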
diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -1,13 +1,26 @@
 // RUN: mlir-opt %s -test-vector-transfer-full-partial-split | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-full-partial-split=use-linalg-copy | FileCheck %s --check-prefix=LINALG
 
 // CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
 // CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
 // CHECK-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// LINALG-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
+// LINALG-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
+// LINALG-DAG: #[[$map_2d_stride_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// LINALG-DAG: #[[$map_2d_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+// LINALG-DAG: #[[$bounds_map_4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
+// LINALG-DAG: #[[$bounds_map_8:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>
+
 // CHECK-LABEL: split_vector_transfer_read_2d(
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
 // CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
 // CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index
+
+// LINALG-LABEL: split_vector_transfer_read_2d(
+// LINALG-SAME: %[[A:[a-zA-Z0-9]*]]: memref
+// LINALG-SAME: %[[i:[a-zA-Z0-9]*]]: index
+// LINALG-SAME: %[[j:[a-zA-Z0-9]*]]: index
 func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -> vector<4x8xf32> {
   %c0 = constant 0 : index
   %f0 = constant 0.0 : f32
@@ -43,9 +56,45 @@
   // CHECK: }
   // CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]]
   // CHECK_SAME: {masked = [false, false]} : memref<?x8xf32>, vector<4x8xf32>
+
+  // LINALG-DAG: %[[c0:.*]] = constant 0 : index
+  // LINALG-DAG: %[[c1:.*]] = constant 1 : index
+  // LINALG-DAG: %[[c4:.*]] = constant 4 : index
+  // LINALG-DAG: %[[c8:.*]] = constant 8 : index
+  // LINALG-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
+  // alloca for boundary full tile
+  // LINALG: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
+  // %i + 4 <= dim(%A, 0)
+  // LINALG: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
+  // LINALG: %[[d0:.*]] = dim %[[A]], %[[c0]] : memref<?x8xf32>
+  // LINALG: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[d0]] : index
+  // %j + 8 <= dim(%A, 1)
+  // LINALG: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
+  // LINALG: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
+  // are both conds true
+  // LINALG: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
+  // LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32>, index, index) {
+  // inBounds, just yield %A
+  // LINALG: scf.yield %[[A]], %[[i]], %[[j]] : memref<?x8xf32>, index, index
+  // LINALG: } else {
+  // slow path, fill tmp alloc and yield a memref_casted version of it
+  // LINALG: linalg.fill(%[[alloc]], %[[cst]]) : memref<4x8xf32>, f32
+  // LINALG: %[[d0:.*]] = dim %[[A]], %[[c0]] : memref<?x8xf32>
+  // LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[d0]], %[[i]], %[[c4]])
+  // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
+  // LINALG: %[[sv:.*]] = subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [%[[c1]], %[[c1]]]
+  // LINALG-SAME: memref<?x8xf32> to memref<?x?xf32, #[[$map_2d_dynamic]]>
+  // LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_dynamic]]>, memref<4x8xf32>
+  // LINALG: %[[yielded:.*]] = memref_cast %[[alloc]] :
+  // LINALG-SAME: memref<4x8xf32> to memref<?x8xf32>
+  // LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
+  // LINALG-SAME: memref<?x8xf32>, index, index
+  // LINALG: }
+  // LINALG: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %[[cst]]
+  // LINALG-SAME: {masked = [false, false]} : memref<?x8xf32>, vector<4x8xf32>
 
   %1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32>, vector<4x8xf32>
 
-  // CHECK: return %[[res]] : vector<4x8xf32>
+  // LINALG: return %[[res]] : vector<4x8xf32>
   return %1: vector<4x8xf32>
 }
 
@@ -53,6 +102,11 @@
 // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref
 // CHECK-SAME: %[[i:[a-zA-Z0-9]*]]: index
 // CHECK-SAME: %[[j:[a-zA-Z0-9]*]]: index
+
+// LINALG-LABEL: split_vector_transfer_read_strided_2d(
+// LINALG-SAME: %[[A:[a-zA-Z0-9]*]]: memref
+// LINALG-SAME: %[[i:[a-zA-Z0-9]*]]: index
+// LINALG-SAME: %[[j:[a-zA-Z0-9]*]]: index
 func @split_vector_transfer_read_strided_2d(
     %A: memref<7x8xf32, offset:?, strides:[?, 1]>,
     %i: index, %j: index) -> vector<4x8xf32> {
@@ -94,6 +148,44 @@
   // CHECK: }
   // CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} :
   // CHECK-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
+
+  // LINALG-DAG: %[[c0:.*]] = constant 0 : index
+  // LINALG-DAG: %[[c1:.*]] = constant 1 : index
+  // LINALG-DAG: %[[c4:.*]] = constant 4 : index
+  // LINALG-DAG: %[[c7:.*]] = constant 7 : index
+  // LINALG-DAG: %[[c8:.*]] = constant 8 : index
+  // LINALG-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
+  // alloca for boundary full tile
+  // LINALG: %[[alloc:.*]] = alloca() {alignment = 32 : i64} : memref<4x8xf32>
+  // %i + 4 <= dim(%A, 0)
+  // LINALG: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
+  // LINALG: %[[cmp0:.*]] = cmpi "sle", %[[idx0]], %[[c7]] : index
+  // %j + 8 <= dim(%A, 1)
+  // LINALG: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
+  // LINALG: %[[cmp1:.*]] = cmpi "sle", %[[idx1]], %[[c8]] : index
+  // are both conds true
+  // LINALG: %[[cond:.*]] = and %[[cmp0]], %[[cmp1]] : i1
+  // LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index) {
+  // inBounds but not cast-compatible: yield a memref_casted form of %A
+  // LINALG: %[[casted:.*]] = memref_cast %arg0 :
+  // LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x8xf32, #[[$map_2d_stride_1]]>
+  // LINALG: scf.yield %[[casted]], %[[i]], %[[j]] :
+  // LINALG-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
+  // LINALG: } else {
+  // slow path, fill tmp alloc and yield a memref_casted version of it
+  // LINALG: linalg.fill(%[[alloc]], %[[cst]]) : memref<4x8xf32>, f32
+  // LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[c7]], %[[i]], %[[c4]])
+  // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
+  // LINALG: %[[sv:.*]] = subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [%[[c1]], %[[c1]]]
+  // LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x?xf32, #[[$map_2d_dynamic]]>
+  // LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_dynamic]]>, memref<4x8xf32>
+  // LINALG: %[[yielded:.*]] = memref_cast %[[alloc]] :
+  // LINALG-SAME: memref<4x8xf32> to memref<?x8xf32, #[[$map_2d_stride_1]]>
+  // LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
+  // LINALG-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, index, index
+  // LINALG: }
+  // LINALG: %[[res:.*]] = vector.transfer_read {{.*}} {masked = [false, false]} :
+  // LINALG-SAME: memref<?x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32>
 
   %1 = vector.transfer_read %A[%i, %j], %f0 :
     memref<7x8xf32, offset:?, strides:[?, 1]>, vector<4x8xf32>
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -125,10 +125,23 @@
 struct TestVectorTransferFullPartialSplitPatterns
     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                          FunctionPass> {
+  TestVectorTransferFullPartialSplitPatterns() = default;
+  TestVectorTransferFullPartialSplitPatterns(
+      const TestVectorTransferFullPartialSplitPatterns &pass) {}
+  Option<bool> useLinalgOps{
+      *this, "use-linalg-copy",
+      llvm::cl::desc("Split using an unmasked vector.transfer + linalg.fill + "
+                     "linalg.copy operations."),
+      llvm::cl::init(false)};
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
-    patterns.insert<VectorTransferFullPartialRewriter>(ctx);
+    VectorTransformsOptions options;
+    if (useLinalgOps)
+      options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
+    else
+      options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
+    patterns.insert<VectorTransferFullPartialRewriter>(ctx, options);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }
 };