diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -16,6 +16,8 @@
 class MLIRContext;
 class VectorTransferOpInterface;
 class RewritePatternSet;
+class RewriterBase;
+
 using OwningRewritePatternList = RewritePatternSet;
 
 namespace scf {
@@ -61,7 +63,7 @@
 /// must be equal. This will be relaxed in the future but requires
 /// rank-reducing subviews.
 LogicalResult splitFullAndPartialTransfer(
-    OpBuilder &b, VectorTransferOpInterface xferOp,
+    RewriterBase &b, VectorTransferOpInterface xferOp,
     VectorTransformsOptions options = VectorTransformsOptions(),
     scf::IfOp *ifOp = nullptr);
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -4,6 +4,7 @@
   VectorMultiDimReductionTransforms.cpp
   VectorOps.cpp
   VectorTransferOpTransforms.cpp
+  VectorTransferSplitRewritePatterns.cpp
   VectorTransferPermutationMapRewritePatterns.cpp
   VectorTransforms.cpp
   VectorUnrollDistribute.cpp
diff --git a/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
@@ -0,0 +1,625 @@
+//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent patterns to rewrite a vector.transfer
+// op into a fully in-bounds part and a partial part.
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "vector-transfer-split"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+static Optional<int64_t> extractConstantIndex(Value v) {
+  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
+    return cstOp.value();
+  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
+    if (affineApplyOp.getAffineMap().isSingleConstant())
+      return affineApplyOp.getAffineMap().getSingleConstantResult();
+  return None;
+}
+
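An aside to make the folding concrete — a minimal standalone C++ sketch, not part of the patch; `foldableSLE` is a hypothetical stand-in for the `extractConstantIndex`/`createFoldedSLE` pair:

```
#include <cstdint>
#include <cstdio>
#include <optional>

// Returns true when `v <= ub` can be decided statically, i.e. when both
// operands are known constants. Mirrors the conservative `<` test used by
// createFoldedSLE: only a provably-true comparison is folded away.
static bool foldableSLE(std::optional<int64_t> v, std::optional<int64_t> ub) {
  return v && ub && *v < *ub;
}

int main() {
  std::printf("%d\n", foldableSLE(4, 8));            // 1: fold, no cmpi needed
  std::printf("%d\n", foldableSLE(std::nullopt, 8)); // 0: emit a runtime cmpi
}
```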
+// Missing foldings of scf.if make it necessary to perform poor man's folding
+// eagerly, especially in the case of unrolling. In the future, this should go
+// away once scf.if folds properly.
+static Value createFoldedSLE(RewriterBase &b, Value v, Value ub) {
+  auto maybeCstV = extractConstantIndex(v);
+  auto maybeCstUb = extractConstantIndex(ub);
+  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
+    return Value();
+  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
+}
+
+/// Build the condition to ensure that a particular VectorTransferOpInterface
+/// is in-bounds.
+static Value createInBoundsCond(RewriterBase &b,
+                                VectorTransferOpInterface xferOp) {
+  assert(xferOp.permutation_map().isMinorIdentity() &&
+         "Expected minor identity map");
+  Value inBoundsCond;
+  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
+    // Zip over the resulting vector shape and memref indices.
+    // If the dimension is known to be in-bounds, it does not participate in
+    // the construction of `inBoundsCond`.
+    if (xferOp.isDimInBounds(resultIdx))
+      return;
+    // Fold or create the check that `index + vector_size` <= `memref_size`.
+    Location loc = xferOp.getLoc();
+    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
+    auto d0 = getAffineDimExpr(0, xferOp.getContext());
+    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
+    Value sum =
+        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
+    Value cond = createFoldedSLE(
+        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
+    if (!cond)
+      return;
+    // Conjunction over all dims for which we are in-bounds.
+    if (inBoundsCond)
+      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
+    else
+      inBoundsCond = cond;
+  });
+  return inBoundsCond;
+}
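For intuition, a standalone sketch (illustrative names only, not MLIR code) of the predicate that `createInBoundsCond` materializes for a rank-matched, minor-identity transfer:

```
#include <cstdint>
#include <cstdio>
#include <vector>

// inBoundsCond: every vector dimension must fit between its start index and
// the corresponding memref extent.
static bool inBoundsCond(const std::vector<int64_t> &indices,
                         const std::vector<int64_t> &vectorShape,
                         const std::vector<int64_t> &memrefShape) {
  bool cond = true;
  for (size_t d = 0; d < vectorShape.size(); ++d)
    cond = cond && (indices[d] + vectorShape[d] <= memrefShape[d]);
  return cond;
}

int main() {
  // A 4x8 vector read at [2, 3] of a 7x9 memref: 3 + 8 > 9, so the slow
  // path would be taken.
  std::printf("%d\n", inBoundsCond({2, 3}, {4, 8}, {7, 9})); // prints 0
}
```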
+
+/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
+/// masking) fastpath and a slowpath.
+/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
+/// newly created conditional upon function return.
+/// To accommodate the fact that the original vector.transfer indexing may be
+/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// scf.if op returns a view and values of type index.
+/// At this time, only the vector.transfer_read case is implemented.
+///
+/// Example (a 2-D vector.transfer_read):
+/// ```
+///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      // fastpath, direct cast
+///      memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view : compatibleMemRefType, index, index
+///    } else {
+///      // slowpath, not in-bounds vector.transfer or linalg.copy.
+///      memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4 : compatibleMemRefType, index, index
+///    }
+///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
+/// ```
+/// where `alloc` is a buffer for one vector, alloca'ed at the top of the
+/// function.
+///
+/// Preconditions:
+///  1. `xferOp.permutation_map()` must be a minor identity map
+///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
+///  must be equal. This will be relaxed in the future but requires
+///  rank-reducing subviews.
+static LogicalResult
+splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
+  // TODO: support 0-d corner case.
+  if (xferOp.getTransferRank() == 0)
+    return failure();
+
+  // TODO: expand support to these 2 cases.
+  if (!xferOp.permutation_map().isMinorIdentity())
+    return failure();
+  // Must have some out-of-bounds dimension to be a candidate for splitting.
+  if (!xferOp.hasOutOfBoundsDim())
+    return failure();
+  // Don't split transfer operations directly under IfOp, this avoids applying
+  // the pattern recursively.
+  // TODO: improve the filtering condition to make it more applicable.
+  if (isa<scf::IfOp>(xferOp->getParentOp()))
+    return failure();
+  return success();
+}
+
+/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
+/// be cast. If the MemRefTypes don't have the same rank or are not strided,
+/// return null; otherwise:
+///   1. if `aT` and `bT` are cast-compatible, return `aT`.
+///   2. else return a new MemRefType obtained by iterating over the shape and
+///      strides and:
+///     a. keeping the ones that are static and equal across `aT` and `bT`.
+///     b. using a dynamic shape and/or stride for the dimensions that don't
+///        agree.
+static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
+  if (memref::CastOp::areCastCompatible(aT, bT))
+    return aT;
+  if (aT.getRank() != bT.getRank())
+    return MemRefType();
+  int64_t aOffset, bOffset;
+  SmallVector<int64_t, 4> aStrides, bStrides;
+  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
+      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
+      aStrides.size() != bStrides.size())
+    return MemRefType();
+
+  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
+  int64_t resOffset;
+  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
+      resStrides(bT.getRank(), 0);
+  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
+    resShape[idx] =
+        (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
+    resStrides[idx] = (aStrides[idx] == bStrides[idx])
+                          ? aStrides[idx]
+                          : MemRefType::kDynamicStrideOrOffset;
+  }
+  resOffset =
+      (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
+  return MemRefType::get(
+      resShape, aT.getElementType(),
+      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
+}
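The shape/stride "join" above is easy to see in isolation — a standalone sketch with `-1` standing in for the `kDynamicSize`/`kDynamicStrideOrOffset` sentinels:

```
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int64_t kDynamic = -1;

// Element-wise join: keep dimensions on which both types agree, make the
// rest dynamic so that both types can cast to the result.
static std::vector<int64_t> joinShapes(const std::vector<int64_t> &a,
                                       const std::vector<int64_t> &b) {
  std::vector<int64_t> res(a.size());
  for (size_t i = 0; i < a.size(); ++i)
    res[i] = (a[i] == b[i]) ? a[i] : kDynamic;
  return res;
}

int main() {
  // memref<4x8> joined with memref<4x16> => memref<4x?>.
  for (int64_t d : joinShapes({4, 8}, {4, 16}))
    std::printf("%lld ", (long long)d); // prints: 4 -1
}
```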
+
+/// Operates under a scoped context to build the intersection between the
+/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
+// TODO: view intersection/union/differences should be a proper std op.
+static std::pair<Value, Value>
+createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
+                          Value alloc) {
+  Location loc = xferOp.getLoc();
+  int64_t memrefRank = xferOp.getShapedType().getRank();
+  // TODO: relax this precondition, will require rank-reducing subviews.
+  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
+         "Expected memref rank to match the alloc rank");
+  ValueRange leadingIndices =
+      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
+  SmallVector<OpFoldResult, 4> sizes;
+  sizes.append(leadingIndices.begin(), leadingIndices.end());
+  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
+  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
+                                                xferOp.source(), indicesIdx);
+    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
+    Value index = xferOp.indices()[indicesIdx];
+    AffineExpr i, j, k;
+    bindDims(xferOp.getContext(), i, j, k);
+    SmallVector<AffineMap, 4> maps =
+        AffineMap::inferFromExprList(MapList{{i - j, k}});
+    // affine_min(%dimMemRef - %index, %dimAlloc)
+    Value affineMin = b.create<AffineMinOp>(
+        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
+    sizes.push_back(affineMin);
+  });
+
+  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
+      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
+  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
+  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
+  auto copySrc = b.create<memref::SubViewOp>(
+      loc, isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
+  auto copyDest = b.create<memref::SubViewOp>(
+      loc, isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
+  return std::make_pair(copySrc, copyDest);
+}
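The per-dimension intersection size reduces to a simple min — a standalone sketch (illustrative only) of `affine_min(%dimMemRef - %index, %dimAlloc)`:

```
#include <algorithm>
#include <cstdint>
#include <cstdio>

// How many elements of a dimension can be copied: the in-bounds remainder of
// the source view, capped by the extent of the staging buffer.
static int64_t intersectionSize(int64_t dimMemRef, int64_t index,
                                int64_t dimAlloc) {
  return std::min(dimMemRef - index, dimAlloc);
}

int main() {
  // Reading an 8-wide vector at index 5 of a 9-element dim copies only the
  // 4 remaining in-bounds elements into the staging buffer.
  std::printf("%lld\n", (long long)intersectionSize(9, 5, 8)); // prints 4
}
```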
+
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view, ... : compatibleMemRefType, index, index
+///    } else {
+///      %2 = linalg.fill(%pad, %alloc)
+///      %3 = subview %view [...][...][...]
+///      %4 = subview %alloc [0, 0] [...] [...]
+///      linalg.copy(%3, %4)
+///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %5, ... : compatibleMemRefType, index, index
+///    }
+/// ```
+/// Return the produced scf::IfOp.
+static scf::IfOp
+createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
+                            TypeRange returnTypes, Value inBoundsCond,
+                            MemRefType compatibleMemRefType, Value alloc) {
+  Location loc = xferOp.getLoc();
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value memref = xferOp.source();
+  return b.create<scf::IfOp>(
+      loc, returnTypes, inBoundsCond,
+      [&](OpBuilder &b, Location loc) {
+        Value res = memref;
+        if (compatibleMemRefType != xferOp.getShapedType())
+          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{res};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+                              xferOp.indices().end());
+        b.create<scf::YieldOp>(loc, viewAndIndices);
+      },
+      [&](OpBuilder &b, Location loc) {
+        b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
+        // Take partial subview of memref which guarantees no dimension
+        // overflows.
+        IRRewriter rewriter(b);
+        std::pair<Value, Value> copyArgs = createSubViewIntersection(
+            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
+            alloc);
+        b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
+        Value casted =
+            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{casted};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+                              zero);
+        b.create<scf::YieldOp>(loc, viewAndIndices);
+      });
+}
+
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view, ... : compatibleMemRefType, index, index
+///    } else {
+///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
+///      %3 = vector.type_cast %extra_alloc :
+///        memref<...> to memref<vector<...>>
+///      store %2, %3[] : memref<vector<...>>
+///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4, ... : compatibleMemRefType, index, index
+///    }
+/// ```
+/// Return the produced scf::IfOp.
+static scf::IfOp createFullPartialVectorTransferRead(
+    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
+    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
+  Location loc = xferOp.getLoc();
+  scf::IfOp fullPartialIfOp;
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value memref = xferOp.source();
+  return b.create<scf::IfOp>(
+      loc, returnTypes, inBoundsCond,
+      [&](OpBuilder &b, Location loc) {
+        Value res = memref;
+        if (compatibleMemRefType != xferOp.getShapedType())
+          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{res};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
+                              xferOp.indices().end());
+        b.create<scf::YieldOp>(loc, viewAndIndices);
+      },
+      [&](OpBuilder &b, Location loc) {
+        Operation *newXfer = b.clone(*xferOp.getOperation());
+        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
+        b.create<memref::StoreOp>(
+            loc, vector,
+            b.create<vector::TypeCastOp>(
+                loc, MemRefType::get({}, vector.getType()), alloc));
+
+        Value casted =
+            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+        scf::ValueVector viewAndIndices{casted};
+        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
+                              zero);
+        b.create<scf::YieldOp>(loc, viewAndIndices);
+      });
+}
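A standalone control-flow sketch (plain C++, illustrative only) of the read rewrite produced by the two factories above: the fast path aliases the original buffer, the slow path pad-fills and partially copies into the staging buffer, and the subsequent read is always in-bounds:

```
#include <algorithm>
#include <cstdio>
#include <vector>

// Returns a base/offset pair from which `vecSize` elements can always be
// read safely, mirroring the (view, indices...) yielded by the scf.if.
static const float *readBase(const std::vector<float> &mem, size_t index,
                             size_t vecSize, std::vector<float> &staging,
                             size_t &outOffset) {
  if (index + vecSize <= mem.size()) { // fastpath: use the original buffer.
    outOffset = index;
    return mem.data();
  }
  // slowpath: pad-fill the staging buffer, copy the in-bounds remainder,
  // and index it at 0.
  size_t n = mem.size() - index;
  std::fill(staging.begin(), staging.end(), 0.0f);
  std::copy(mem.begin() + index, mem.begin() + index + n, staging.begin());
  outOffset = 0;
  return staging.data();
}

int main() {
  std::vector<float> mem{1, 2, 3}, staging(4, 0.0f);
  size_t off;
  const float *base = readBase(mem, 2, 4, staging, off);
  std::printf("%g %g\n", base[off], base[off + 1]); // prints: 3 0
}
```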
+
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+/// Produce IR resembling:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view, ... : compatibleMemRefType, index, index
+///    } else {
+///      %3 = vector.type_cast %extra_alloc :
+///        memref<...> to memref<vector<...>>
+///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4, ... : compatibleMemRefType, index, index
+///    }
+/// ```
+static ValueRange
+getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
+                          TypeRange returnTypes, Value inBoundsCond,
+                          MemRefType compatibleMemRefType, Value alloc) {
+  Location loc = xferOp.getLoc();
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value memref = xferOp.source();
+  return b
+      .create<scf::IfOp>(
+          loc, returnTypes, inBoundsCond,
+          [&](OpBuilder &b, Location loc) {
+            Value res = memref;
+            if (compatibleMemRefType != xferOp.getShapedType())
+              res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+            scf::ValueVector viewAndIndices{res};
+            viewAndIndices.insert(viewAndIndices.end(),
+                                  xferOp.indices().begin(),
+                                  xferOp.indices().end());
+            b.create<scf::YieldOp>(loc, viewAndIndices);
+          },
+          [&](OpBuilder &b, Location loc) {
+            Value casted =
+                b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+            scf::ValueVector viewAndIndices{casted};
+            viewAndIndices.insert(viewAndIndices.end(),
+                                  xferOp.getTransferRank(), zero);
+            b.create<scf::YieldOp>(loc, viewAndIndices);
+          })
+      ->getResults();
+}
+
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` has been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+///   3. it originally wrote to %view
+/// Produce IR resembling:
+/// ```
+///    %notInBounds = arith.xori %inBounds, %true
+///    scf.if (%notInBounds) {
+///      %3 = subview %alloc [...][...][...]
+///      %4 = subview %view [0, 0][...][...]
+///      linalg.copy(%3, %4)
+///    }
+/// ```
+static void createFullPartialLinalgCopy(RewriterBase &b,
+                                        vector::TransferWriteOp xferOp,
+                                        Value inBoundsCond, Value alloc) {
+  Location loc = xferOp.getLoc();
+  auto notInBounds = b.create<arith::XOrIOp>(
+      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
+  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
+    IRRewriter rewriter(b);
+    std::pair<Value, Value> copyArgs = createSubViewIntersection(
+        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
+        alloc);
+    b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
+    b.create<scf::YieldOp>(loc, ValueRange{});
+  });
+}
+
+/// Given an `xferOp` for which:
+///   1. `inBoundsCond` has been computed.
+///   2. a memref of single vector `alloc` has been allocated.
+///   3. it originally wrote to %view
+/// Produce IR resembling:
+/// ```
+///    %notInBounds = arith.xori %inBounds, %true
+///    scf.if (%notInBounds) {
+///      %2 = load %alloc : memref<vector<...>>
+///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
+///    }
+/// ```
+static void createFullPartialVectorTransferWrite(RewriterBase &b,
+                                                 vector::TransferWriteOp xferOp,
+                                                 Value inBoundsCond,
+                                                 Value alloc) {
+  Location loc = xferOp.getLoc();
+  auto notInBounds = b.create<arith::XOrIOp>(
+      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
+  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
+    BlockAndValueMapping mapping;
+    Value load = b.create<memref::LoadOp>(
+        loc, b.create<vector::TypeCastOp>(
+                 loc, MemRefType::get({}, xferOp.vector().getType()), alloc));
+    mapping.map(xferOp.vector(), load);
+    b.clone(*xferOp.getOperation(), mapping);
+    b.create<scf::YieldOp>(loc, ValueRange{});
+  });
+}
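Both write-side slow paths are guarded by negating the i1 condition with an XOR against constant true, as in the two functions above — a standalone sketch of just that negation:

```
#include <cstdio>

int main() {
  for (bool inBounds : {false, true}) {
    // arith.xori %inBounds, %true
    bool notInBounds = inBounds ^ true;
    std::printf("inBounds=%d -> slow path taken=%d\n", inBounds, notInBounds);
  }
}
```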
+
+/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
+/// masking) fastpath and a slowpath.
+///
+/// For vector.transfer_read:
+/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
+/// newly created conditional upon function return.
+/// To accommodate the fact that the original vector.transfer indexing may be
+/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// scf.if op returns a view and values of type index.
+///
+/// Example (a 2-D vector.transfer_read):
+/// ```
+///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      // fastpath, direct cast
+///      memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view : compatibleMemRefType, index, index
+///    } else {
+///      // slowpath, not in-bounds vector.transfer or linalg.copy.
+///      memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4 : compatibleMemRefType, index, index
+///    }
+///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
+/// ```
+/// where `alloc` is a buffer for one vector, alloca'ed at the top of the
+/// function.
+///
+/// For vector.transfer_write:
+/// There are 2 conditional blocks. First a block to decide which memref and
+/// indices to use for an unmasked, in-bounds write. Then a conditional block
+/// to further copy a partial buffer into the final result in the slow path
+/// case.
+///
+/// Example (a 2-D vector.transfer_write):
+/// ```
+///    vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
+/// ```
+/// is transformed into:
+/// ```
+///    %1:3 = scf.if (%inBounds) {
+///      memref.cast %A: memref<A...> to compatibleMemRefType
+///      scf.yield %view : compatibleMemRefType, index, index
+///    } else {
+///      memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %4 : compatibleMemRefType, index, index
+///    }
+///    %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
+///    true]}
+///    scf.if (%notInBounds) {
+///      // slowpath: not in-bounds vector.transfer or linalg.copy.
+///    }
+/// ```
+/// where `alloc` is a buffer for one vector, alloca'ed at the top of the
+/// function.
+///
+/// Preconditions:
+///  1. `xferOp.permutation_map()` must be a minor identity map
+///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
+///  must be equal. This will be relaxed in the future but requires
+///  rank-reducing subviews.
+LogicalResult mlir::vector::splitFullAndPartialTransfer(
+    RewriterBase &b, VectorTransferOpInterface xferOp,
+    VectorTransformsOptions options, scf::IfOp *ifOp) {
+  if (options.vectorTransferSplit == VectorTransferSplit::None)
+    return failure();
+
+  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
+  auto inBoundsAttr = b.getBoolArrayAttr(bools);
+  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
+    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
+    return success();
+  }
+
+  // Assert preconditions. Additionally, keep the variables in an inner scope
+  // to ensure they aren't used in the wrong scopes further down.
+  {
+    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
+           "Expected splitFullAndPartialTransferPrecondition to hold");
+
+    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
+    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
+
+    if (!(xferReadOp || xferWriteOp))
+      return failure();
+    if (xferWriteOp && xferWriteOp.mask())
+      return failure();
+    if (xferReadOp && xferReadOp.mask())
+      return failure();
+  }
+
+  RewriterBase::InsertionGuard guard(b);
+  b.setInsertionPoint(xferOp);
+  Value inBoundsCond = createInBoundsCond(
+      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
+  if (!inBoundsCond)
+    return failure();
+
+  // Top of the function `alloc` for transient storage.
+  Value alloc;
+  {
+    FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
+    RewriterBase::InsertionGuard guard(b);
+    b.setInsertionPointToStart(&funcOp.getRegion().front());
+    auto shape = xferOp.getVectorType().getShape();
+    Type elementType = xferOp.getVectorType().getElementType();
+    alloc = b.create<memref::AllocaOp>(funcOp.getLoc(),
+                                       MemRefType::get(shape, elementType),
+                                       ValueRange{}, b.getI64IntegerAttr(32));
+  }
+
+  MemRefType compatibleMemRefType =
+      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
+                                  alloc.getType().cast<MemRefType>());
+  if (!compatibleMemRefType)
+    return failure();
+
+  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
+                                   b.getIndexType());
+  returnTypes[0] = compatibleMemRefType;
+
+  if (auto xferReadOp =
+          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
+    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
+    scf::IfOp fullPartialIfOp =
+        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
+            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
+                                                  inBoundsCond,
+                                                  compatibleMemRefType, alloc)
+            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
+                                          inBoundsCond, compatibleMemRefType,
+                                          alloc);
+    if (ifOp)
+      *ifOp = fullPartialIfOp;
+
+    // Set the existing read op to in-bounds; it always reads from a full
+    // buffer.
+    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
+      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
+
+    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
+
+    return success();
+  }
+
+  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
+
+  // Decide which location to write the entire vector to.
+  auto memrefAndIndices = getLocationToWriteFullVec(
+      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
+
+  // Do an in-bounds write to either the output or the extra allocated buffer.
+  // The operation is cloned to prevent deleting information needed for the
+  // later IR creation.
+  BlockAndValueMapping mapping;
+  mapping.map(xferWriteOp.source(), memrefAndIndices.front());
+  mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
+  auto *clone = b.clone(*xferWriteOp, mapping);
+  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
+
+  // Create a potential copy from the allocated buffer to the final output in
+  // the slow path case.
+  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
+    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
+  else
+    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
+
+  xferOp->erase();
+
+  return success();
+}
+
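A hypothetical caller of the entry point above, assuming the API of this revision; `splitOneTransfer` and its option choice are illustrative, not part of the patch:

```
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"

using namespace mlir;

// Requests the vector.transfer-based slow path; LinalgCopy and ForceInBounds
// are the other strategies the entry point handles.
static LogicalResult splitOneTransfer(RewriterBase &rewriter,
                                      VectorTransferOpInterface xferOp) {
  vector::VectorTransformsOptions options;
  options.vectorTransferSplit = vector::VectorTransferSplit::VectorTransfer;
  scf::IfOp ifOp; // On success, points at the created fast/slow conditional.
  return vector::splitFullAndPartialTransfer(rewriter, xferOp, options, &ifOp);
}
```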
+LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
+  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
+      failed(filter(xferOp)))
+    return failure();
+  rewriter.startRootUpdate(xferOp);
+  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
+    rewriter.finalizeRootUpdate(xferOp);
+    return success();
+  }
+  rewriter.cancelRootUpdate(xferOp);
+  return failure();
+}
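A hypothetical way to drive the pattern over a function with the greedy rewriter; the boilerplate (`runTransferSplit`, the `LinalgCopy` choice) is illustrative, not part of the patch:

```
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Registers the pattern with the greedy driver; the pattern's default filter
// accepts every candidate transfer op.
static LogicalResult runTransferSplit(FuncOp funcOp) {
  vector::VectorTransformsOptions options;
  options.vectorTransferSplit = vector::VectorTransferSplit::LinalgCopy;
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<vector::VectorTransferFullPartialRewriter>(funcOp.getContext(),
                                                          options);
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
```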
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -144,7 +144,6 @@
 
 namespace {
 
-
 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
 //
 // Example:
@@ -1642,587 +1641,6 @@
 
 } // namespace mlir
 
-static Optional<int64_t> extractConstantIndex(Value v) {
-  if (auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>())
-    return cstOp.value();
-  if (auto affineApplyOp = v.getDefiningOp<AffineApplyOp>())
-    if (affineApplyOp.getAffineMap().isSingleConstant())
-      return affineApplyOp.getAffineMap().getSingleConstantResult();
-  return None;
-}
-
-// Missing foldings of scf.if make it necessary to perform poor man's folding
-// eagerly, especially in the case of unrolling. In the future, this should go
-// away once scf.if folds properly.
-static Value createFoldedSLE(OpBuilder &b, Value v, Value ub) {
-  auto maybeCstV = extractConstantIndex(v);
-  auto maybeCstUb = extractConstantIndex(ub);
-  if (maybeCstV && maybeCstUb && *maybeCstV < *maybeCstUb)
-    return Value();
-  return b.create<arith::CmpIOp>(v.getLoc(), arith::CmpIPredicate::sle, v, ub);
-}
-
-// Operates under a scoped context to build the condition to ensure that a
-// particular VectorTransferOpInterface is in-bounds.
-static Value createInBoundsCond(OpBuilder &b,
-                                VectorTransferOpInterface xferOp) {
-  assert(xferOp.permutation_map().isMinorIdentity() &&
-         "Expected minor identity map");
-  Value inBoundsCond;
-  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
-    // Zip over the resulting vector shape and memref indices.
-    // If the dimension is known to be in-bounds, it does not participate in
-    // the construction of `inBoundsCond`.
-    if (xferOp.isDimInBounds(resultIdx))
-      return;
-    // Fold or create the check that `index + vector_size` <= `memref_size`.
-    Location loc = xferOp.getLoc();
-    ImplicitLocOpBuilder lb(loc, b);
-    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
-    auto d0 = getAffineDimExpr(0, xferOp.getContext());
-    auto vs = getAffineConstantExpr(vectorSize, xferOp.getContext());
-    Value sum =
-        makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]);
-    Value cond = createFoldedSLE(
-        b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx));
-    if (!cond)
-      return;
-    // Conjunction over all dims for which we are in-bounds.
-    if (inBoundsCond)
-      inBoundsCond = lb.create<arith::AndIOp>(inBoundsCond, cond);
-    else
-      inBoundsCond = cond;
-  });
-  return inBoundsCond;
-}
-
-/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
-/// masking) fastpath and a slowpath.
-/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
-/// newly created conditional upon function return.
-/// To accomodate for the fact that the original vector.transfer indexing may be
-/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
-/// scf.if op returns a view and values of type index.
-/// At this time, only vector.transfer_read case is implemented.
-///
-/// Example (a 2-D vector.transfer_read):
-/// ```
-///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
-/// ```
-/// is transformed into:
-/// ```
-///    %1:3 = scf.if (%inBounds) {
-///      // fastpath, direct cast
-///      memref.cast %A: memref<A...> to compatibleMemRefType
-///      scf.yield %view : compatibleMemRefType, index, index
-///    } else {
-///      // slowpath, not in-bounds vector.transfer or linalg.copy.
-///      memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %4 : compatibleMemRefType, index, index
-//     }
-///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
-/// ```
-/// where `alloc` is a top of the function alloca'ed buffer of one vector.
-///
-/// Preconditions:
-///  1. `xferOp.permutation_map()` must be a minor identity map
-///  2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()`
-///  must be equal. This will be relaxed in the future but requires
-///  rank-reducing subviews.
-static LogicalResult
-splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
-  // TODO: support 0-d corner case.
-  if (xferOp.getTransferRank() == 0)
-    return failure();
-
-  // TODO: expand support to these 2 cases.
-  if (!xferOp.permutation_map().isMinorIdentity())
-    return failure();
-  // Must have some out-of-bounds dimension to be a candidate for splitting.
-  if (!xferOp.hasOutOfBoundsDim())
-    return failure();
-  // Don't split transfer operations directly under IfOp, this avoids applying
-  // the pattern recursively.
-  // TODO: improve the filtering condition to make it more applicable.
-  if (isa<scf::IfOp>(xferOp->getParentOp()))
-    return failure();
-  return success();
-}
-
-/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
-/// be cast. If the MemRefTypes don't have the same rank or are not strided,
-/// return null; otherwise:
-///   1. if `aT` and `bT` are cast-compatible, return `aT`.
-///   2. else return a new MemRefType obtained by iterating over the shape and
-///      strides and:
-///     a. keeping the ones that are static and equal across `aT` and `bT`.
-///     b. using a dynamic shape and/or stride for the dimensions that don't
-///        agree.
-static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
-  if (memref::CastOp::areCastCompatible(aT, bT))
-    return aT;
-  if (aT.getRank() != bT.getRank())
-    return MemRefType();
-  int64_t aOffset, bOffset;
-  SmallVector<int64_t, 4> aStrides, bStrides;
-  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
-      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
-      aStrides.size() != bStrides.size())
-    return MemRefType();
-
-  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
-  int64_t resOffset;
-  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
-      resStrides(bT.getRank(), 0);
-  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
-    resShape[idx] =
-        (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
-    resStrides[idx] = (aStrides[idx] == bStrides[idx])
-                          ? aStrides[idx]
-                          : MemRefType::kDynamicStrideOrOffset;
-  }
-  resOffset =
-      (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
-  return MemRefType::get(
-      resShape, aT.getElementType(),
-      makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
-}
-
-/// Operates under a scoped context to build the intersection between the
-/// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
-// TODO: view intersection/union/differences should be a proper std op.
-static std::pair<Value, Value>
-createSubViewIntersection(OpBuilder &b, VectorTransferOpInterface xferOp,
-                          Value alloc) {
-  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
-  int64_t memrefRank = xferOp.getShapedType().getRank();
-  // TODO: relax this precondition, will require rank-reducing subviews.
-  assert(memrefRank == alloc.getType().cast<MemRefType>().getRank() &&
-         "Expected memref rank to match the alloc rank");
-  ValueRange leadingIndices =
-      xferOp.indices().take_front(xferOp.getLeadingShapedRank());
-  SmallVector<OpFoldResult, 4> sizes;
-  sizes.append(leadingIndices.begin(), leadingIndices.end());
-  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
-  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
-    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-    Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(),
-                                                xferOp.source(), indicesIdx);
-    Value dimAlloc = lb.create<memref::DimOp>(alloc, resultIdx);
-    Value index = xferOp.indices()[indicesIdx];
-    AffineExpr i, j, k;
-    bindDims(xferOp.getContext(), i, j, k);
-    SmallVector<AffineMap, 4> maps =
-        AffineMap::inferFromExprList(MapList{{i - j, k}});
-    // affine_min(%dimMemRef - %index, %dimAlloc)
-    Value affineMin = lb.create<AffineMinOp>(
-        index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
-    sizes.push_back(affineMin);
-  });
-
-  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
-      xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
-  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
-  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
-  auto copySrc = lb.create<memref::SubViewOp>(
-      isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
-  auto copyDest = lb.create<memref::SubViewOp>(
-      isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
-  return std::make_pair(copySrc, copyDest);
-}
-
-/// Given an `xferOp` for which:
-///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
-///   2. a memref of single vector `alloc` has been allocated.
-/// Produce IR resembling:
-/// ```
-///    %1:3 = scf.if (%inBounds) {
-///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
-///      scf.yield %view, ... : compatibleMemRefType, index, index
-///    } else {
-///      %2 = linalg.fill(%pad, %alloc)
-///      %3 = subview %view [...][...][...]
-///      %4 = subview %alloc [0, 0] [...] [...]
-///      linalg.copy(%3, %4)
-///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %5, ... : compatibleMemRefType, index, index
-///    }
-/// ```
-/// Return the produced scf::IfOp.
-static scf::IfOp
-createFullPartialLinalgCopy(OpBuilder &b, vector::TransferReadOp xferOp,
-                            TypeRange returnTypes, Value inBoundsCond,
-                            MemRefType compatibleMemRefType, Value alloc) {
-  Location loc = xferOp.getLoc();
-  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value memref = xferOp.source();
-  return b.create<scf::IfOp>(
-      loc, returnTypes, inBoundsCond,
-      [&](OpBuilder &b, Location loc) {
-        Value res = memref;
-        if (compatibleMemRefType != xferOp.getShapedType())
-          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
-        scf::ValueVector viewAndIndices{res};
-        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
-                              xferOp.indices().end());
-        b.create<scf::YieldOp>(loc, viewAndIndices);
-      },
-      [&](OpBuilder &b, Location loc) {
-        b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
-        // Take partial subview of memref which guarantees no dimension
-        // overflows.
-        std::pair<Value, Value> copyArgs = createSubViewIntersection(
-            b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
-        b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
-        Value casted =
-            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
-        scf::ValueVector viewAndIndices{casted};
-        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
-                              zero);
-        b.create<scf::YieldOp>(loc, viewAndIndices);
-      });
-}
-
-/// Given an `xferOp` for which:
-///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
-///   2. a memref of single vector `alloc` has been allocated.
-/// Produce IR resembling:
-/// ```
-///    %1:3 = scf.if (%inBounds) {
-///      memref.cast %A: memref<A...> to compatibleMemRefType
-///      scf.yield %view, ... : compatibleMemRefType, index, index
-///    } else {
-///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
-///      %3 = vector.type_cast %extra_alloc :
-///        memref<...> to memref<vector<...>>
-///      store %2, %3[] : memref<vector<...>>
-///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %4, ... : compatibleMemRefType, index, index
-///    }
-/// ```
-/// Return the produced scf::IfOp.
-static scf::IfOp createFullPartialVectorTransferRead(
-    OpBuilder &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
-    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
-  Location loc = xferOp.getLoc();
-  scf::IfOp fullPartialIfOp;
-  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value memref = xferOp.source();
-  return b.create<scf::IfOp>(
-      loc, returnTypes, inBoundsCond,
-      [&](OpBuilder &b, Location loc) {
-        Value res = memref;
-        if (compatibleMemRefType != xferOp.getShapedType())
-          res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
-        scf::ValueVector viewAndIndices{res};
-        viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
-                              xferOp.indices().end());
-        b.create<scf::YieldOp>(loc, viewAndIndices);
-      },
-      [&](OpBuilder &b, Location loc) {
-        Operation *newXfer = b.clone(*xferOp.getOperation());
-        Value vector = cast<VectorTransferOpInterface>(newXfer).vector();
-        b.create<memref::StoreOp>(
-            loc, vector,
-            b.create<vector::TypeCastOp>(
-                loc, MemRefType::get({}, vector.getType()), alloc));
-
-        Value casted =
-            b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
-        scf::ValueVector viewAndIndices{casted};
-        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
-                              zero);
-        b.create<scf::YieldOp>(loc, viewAndIndices);
-      });
-}
-
-/// Given an `xferOp` for which:
-///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
-///   2. a memref of single vector `alloc` has been allocated.
-/// Produce IR resembling:
-/// ```
-///    %1:3 = scf.if (%inBounds) {
-///      memref.cast %A: memref<A...> to compatibleMemRefType
-///      scf.yield %view, ... : compatibleMemRefType, index, index
-///    } else {
-///      %3 = vector.type_cast %extra_alloc :
-///        memref<...> to memref<vector<...>>
-///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %4, ... : compatibleMemRefType, index, index
-///    }
-/// ```
-static ValueRange
-getLocationToWriteFullVec(OpBuilder &b, vector::TransferWriteOp xferOp,
-                          TypeRange returnTypes, Value inBoundsCond,
-                          MemRefType compatibleMemRefType, Value alloc) {
-  Location loc = xferOp.getLoc();
-  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value memref = xferOp.source();
-  return b
-      .create<scf::IfOp>(
-          loc, returnTypes, inBoundsCond,
-          [&](OpBuilder &b, Location loc) {
-            Value res = memref;
-            if (compatibleMemRefType != xferOp.getShapedType())
-              res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
-            scf::ValueVector viewAndIndices{res};
-            viewAndIndices.insert(viewAndIndices.end(),
-                                  xferOp.indices().begin(),
-                                  xferOp.indices().end());
-            b.create<scf::YieldOp>(loc, viewAndIndices);
-          },
-          [&](OpBuilder &b, Location loc) {
-            Value casted =
-                b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
-            scf::ValueVector viewAndIndices{casted};
-            viewAndIndices.insert(viewAndIndices.end(),
-                                  xferOp.getTransferRank(), zero);
-            b.create<scf::YieldOp>(loc, viewAndIndices);
-          })
-      ->getResults();
-}
-
-/// Given an `xferOp` for which:
-///   1. `inBoundsCond` has been computed.
-///   2. a memref of single vector `alloc` has been allocated.
-///   3. it originally wrote to %view
-/// Produce IR resembling:
-/// ```
-///    %notInBounds = arith.xori %inBounds, %true
-///    scf.if (%notInBounds) {
-///      %3 = subview %alloc [...][...][...]
-///      %4 = subview %view [0, 0][...][...]
-///      linalg.copy(%3, %4)
-///    }
-/// ```
-static void createFullPartialLinalgCopy(OpBuilder &b,
-                                        vector::TransferWriteOp xferOp,
-                                        Value inBoundsCond, Value alloc) {
-  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
-  auto notInBounds = lb.create<arith::XOrIOp>(
-      inBoundsCond, lb.create<arith::ConstantIntOp>(true, 1));
-  lb.create<scf::IfOp>(notInBounds, [&](OpBuilder &b, Location loc) {
-    std::pair<Value, Value> copyArgs = createSubViewIntersection(
-        b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
-    b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
-    b.create<scf::YieldOp>(loc, ValueRange{});
-  });
-}
-
-/// Given an `xferOp` for which:
-///   1. `inBoundsCond` has been computed.
-///   2. a memref of single vector `alloc` has been allocated.
-///   3. it originally wrote to %view
-/// Produce IR resembling:
-/// ```
-///    %notInBounds = arith.xori %inBounds, %true
-///    scf.if (%notInBounds) {
-///      %2 = load %alloc : memref<vector<...>>
-///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
-///    }
-/// ```
-static void createFullPartialVectorTransferWrite(OpBuilder &b,
-                                                 vector::TransferWriteOp xferOp,
-                                                 Value inBoundsCond,
-                                                 Value alloc) {
-  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
-  auto notInBounds = lb.create<arith::XOrIOp>(
-      inBoundsCond, lb.create<arith::ConstantIntOp>(true, 1));
-  lb.create<scf::IfOp>(notInBounds, [&](OpBuilder &b, Location loc) {
-    BlockAndValueMapping mapping;
-    Value load = b.create<memref::LoadOp>(
-        loc, b.create<vector::TypeCastOp>(
-                 loc, MemRefType::get({}, xferOp.vector().getType()), alloc));
-    mapping.map(xferOp.vector(), load);
-    b.clone(*xferOp.getOperation(), mapping);
-    b.create<scf::YieldOp>(loc, ValueRange{});
-  });
-}
-
-/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
-/// masking) fastpath and a slowpath.
-///
-/// For vector.transfer_read:
-/// If `ifOp` is not null and the result is `success, the `ifOp` points to the
-/// newly created conditional upon function return.
-/// To accomodate for the fact that the original vector.transfer indexing may be
-/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
-/// scf.if op returns a view and values of type index.
-///
-/// Example (a 2-D vector.transfer_read):
-/// ```
-///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
-/// ```
-/// is transformed into:
-/// ```
-///    %1:3 = scf.if (%inBounds) {
-///      // fastpath, direct cast
-///      memref.cast %A: memref<A...> to compatibleMemRefType
-///      scf.yield %view : compatibleMemRefType, index, index
-///    } else {
-///      // slowpath, not in-bounds vector.transfer or linalg.copy.
-///      memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %4 : compatibleMemRefType, index, index
-//     }
-///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
-/// ```
-/// where `alloc` is a top of the function alloca'ed buffer of one vector.
-///
-/// For vector.transfer_write:
-/// There are 2 conditional blocks. First a block to decide which memref and
-/// indices to use for an unmasked, inbounds write. Then a conditional block to
-/// further copy a partial buffer into the final result in the slow path case.
-///
-/// Example (a 2-D vector.transfer_write):
-/// ```
-///    vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
-/// ```
-/// is transformed into:
-/// ```
-///    %1:3 = scf.if (%inBounds) {
-///      memref.cast %A: memref<A...> to compatibleMemRefType
-///      scf.yield %view : compatibleMemRefType, index, index
-///    } else {
-///      memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %4 : compatibleMemRefType, index, index
-///    }
-///    %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
-///    true]}
-///    scf.if (%notInBounds) {
-///      // slowpath: not in-bounds vector.transfer or linalg.copy.
-///    }
-/// ```
-/// where `alloc` is a top of the function alloca'ed buffer of one vector.
-///
-/// Preconditions:
-///  1. `xferOp.permutation_map()` must be a minor identity map
-///  2. the rank of the `xferOp.source()` and the rank of the `xferOp.vector()`
-///  must be equal. This will be relaxed in the future but requires
-///  rank-reducing subviews.
-LogicalResult mlir::vector::splitFullAndPartialTransfer(
-    OpBuilder &b, VectorTransferOpInterface xferOp,
-    VectorTransformsOptions options, scf::IfOp *ifOp) {
-  if (options.vectorTransferSplit == VectorTransferSplit::None)
-    return failure();
-
-  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
-  auto inBoundsAttr = b.getBoolArrayAttr(bools);
-  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
-    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
-    return success();
-  }
-
-  // Assert preconditions. Additionally, keep the variables in an inner scope
-  // to ensure they aren't used in the wrong scopes further down.
-  {
-    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
-           "Expected splitFullAndPartialTransferPrecondition to hold");
-
-    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
-    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());
-
-    if (!(xferReadOp || xferWriteOp))
-      return failure();
-    if (xferWriteOp && xferWriteOp.mask())
-      return failure();
-    if (xferReadOp && xferReadOp.mask())
-      return failure();
-  }
-
-  OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPoint(xferOp);
-  Value inBoundsCond = createInBoundsCond(
-      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
-  if (!inBoundsCond)
-    return failure();
-
-  // Top of the function `alloc` for transient storage.
-  Value alloc;
-  {
-    FuncOp funcOp = xferOp->getParentOfType<FuncOp>();
-    OpBuilder::InsertionGuard guard(b);
-    b.setInsertionPointToStart(&funcOp.getRegion().front());
-    auto shape = xferOp.getVectorType().getShape();
-    Type elementType = xferOp.getVectorType().getElementType();
-    alloc = b.create<memref::AllocaOp>(funcOp.getLoc(),
-                                       MemRefType::get(shape, elementType),
-                                       ValueRange{}, b.getI64IntegerAttr(32));
-  }
-
-  MemRefType compatibleMemRefType =
-      getCastCompatibleMemRefType(xferOp.getShapedType().cast<MemRefType>(),
-                                  alloc.getType().cast<MemRefType>());
-  if (!compatibleMemRefType)
-    return failure();
-
-  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
-                                   b.getIndexType());
-  returnTypes[0] = compatibleMemRefType;
-
-  if (auto xferReadOp =
-          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
-    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
-    scf::IfOp fullPartialIfOp =
-        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
-            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
-                                                  inBoundsCond,
-                                                  compatibleMemRefType, alloc)
-            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
-                                          inBoundsCond, compatibleMemRefType,
-                                          alloc);
-    if (ifOp)
-      *ifOp = fullPartialIfOp;
-
-    // Set existing read op to in-bounds, it always reads from a full buffer.
-    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
-      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
-
-    xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
-
-    return success();
-  }
-
-  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());
-
-  // Decide which location to write the entire vector to.
-  auto memrefAndIndices = getLocationToWriteFullVec(
-      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);
-
-  // Do an in bounds write to either the output or the extra allocated buffer.
-  // The operation is cloned to prevent deleting information needed for the
-  // later IR creation.
-  BlockAndValueMapping mapping;
-  mapping.map(xferWriteOp.source(), memrefAndIndices.front());
-  mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front());
-  auto *clone = b.clone(*xferWriteOp, mapping);
-  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
-
-  // Create a potential copy from the allocated buffer to the final output in
-  // the slow path case.
-  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
-    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
-  else
-    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);
-
-  xferOp->erase();
-
-  return success();
-}
-
-LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
-  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
-  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
-      failed(filter(xferOp)))
-    return failure();
-  rewriter.startRootUpdate(xferOp);
-  if (succeeded(splitFullAndPartialTransfer(rewriter, xferOp, options))) {
-    rewriter.finalizeRootUpdate(xferOp);
-    return success();
-  }
-  rewriter.cancelRootUpdate(xferOp);
-  return failure();
-}
-
 Optional<DistributeOps> mlir::vector::distributPointwiseVectorOp(
     OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
     ArrayRef<int64_t> multiplicity, const AffineMap &map) {