diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -26,14 +26,15 @@ #include "llvm/ADT/SmallSet.h" namespace mlir { + +class FrozenRewritePatternSet; +class RewriterBase; + namespace bufferization { class BufferizeTypeConverter; } // namespace bufferization -class FrozenRewritePatternSet; - namespace linalg { - struct LinalgElementwiseFusionOptions; struct LinalgFusionOptions; struct LinalgTilingOptions; @@ -364,7 +365,7 @@ LinalgPromotionOptions options); /// Return success if the operation can be vectorized. -LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp); +LogicalResult vectorizeLinalgOpPrecondition(RewriterBase &b, LinalgOp linalgOp); //===----------------------------------------------------------------------===// // Transformations exposed as rewrite patterns. diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -27,7 +27,7 @@ namespace mlir { -class OpBuilder; +class RewriterBase; /// Tests whether the given maps describe a row major matmul. The test is /// permutation-invariant. 
Note that this only checks the affine maps from an @@ -79,7 +79,7 @@ Red() : IteratorType(IteratorTypeT::reduction) {} }; - StructuredGenerator(OpBuilder &builder, StructuredOpInterface op) + StructuredGenerator(RewriterBase &builder, StructuredOpInterface op) : builder(builder), ctx(op.getContext()), loc(op.getLoc()), iterators(op.getIteratorTypesArray()), maps(op.getIndexingMapsArray()), op(op) {} @@ -100,7 +100,7 @@ } protected: - OpBuilder &builder; + RewriterBase &builder; MLIRContext *ctx; Location loc; SmallVector iterators; diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1174,13 +1174,14 @@ OptionalAttr:$in_bounds)>, Results<(outs AnyVectorOfAnyRank:$vector)> { - let summary = "Reads a supervector from memory into an SSA vector value."; + let summary = "Reads an n-D slice from a tensor or memref into an SSA vector value."; let description = [{ - The `vector.transfer_read` op performs a read from a slice within a + The `vector.transfer_read` op performs a read from a slice of a [MemRef](../LangRef.md#memref-type) or a Ranked [Tensor](../LangRef.md#tensor-type) supplied as its first operand - into a [vector](../LangRef.md#vector-type) of the same base elemental type. + into an n-D [vector](../LangRef.md#vector-type) of the same base elemental + type. A memref/tensor operand with vector element type, must have its vector element type match a suffix (shape and element type) of the vector (e.g. @@ -1190,11 +1191,11 @@ supplied as the operands `[1 .. 1 + rank(memref/tensor))` that defines the starting point of the transfer (e.g. `%A[%i0, %i1, %i2]`).
- The permutation_map [attribute](../LangRef.md#attributes) is an + The permutation_map [attribute](../LangRef.md#attributes) is an optional [affine-map](Affine.md#affine-maps) which specifies the transposition on the slice to match the vector shape. The permutation map may be implicit and omitted from parsing and printing if it is the canonical minor identity map - (i.e. if it does not permute or broadcast any dimension). + (i.e. if it does not permute any dimension). The size of the slice is specified by the size of the vector, given as the return type. @@ -1213,78 +1214,50 @@ run out-of-bounds as indices increase. Broadcast dimensions must always be in-bounds. If specified, the `in_bounds` array length has to be equal to the vector rank. In absence of the attribute, accesses along all dimensions - (except for broadcasts) may run out-of-bounds. A `vector.transfer_read` can - be lowered to a simple load if all dimensions are specified to be within - bounds and no `mask` was specified. + may run out-of-bounds. + A `vector.transfer_read` can be lowered to a simple load if all dimensions + are specified to be within bounds and no `mask` was specified. + + This operation is called 'read' as opposed to 'load' because the n-D + vector granularity is generally not representable with a single hardware + register. - This operation is called 'read' by opposition to 'load' because the - super-vector granularity is generally not representable with a single - hardware register. A `vector.transfer_read` is thus a mid-level abstraction - that supports super-vectorization with non-effecting padding for full-tile - only operations. + It is the responsibility of `vector.transfer_read`'s implementation to + ensure the memory loads are valid. Different lowerings may be pertinent + depending on the hardware support.
More precisely, let's dive deeper into the permutation_map for the following - MLIR: + MLIR snippets: ```mlir vector.transfer_read %A[%expr1, %expr2, %expr3, %expr4] - { permutation_map : (d0,d1,d2,d3) -> (d2,0,d0) } : - memref, vector<3x4x5xf32> + { permutation_map : (d0,d1,d2,d3) -> (d2, d0) } : + memref, vector<3x5xf32> ``` - This operation always reads a slice starting at `%A[%expr1, %expr2, %expr3, - %expr4]`. The size of the slice is 3 along d2 and 5 along d0, so the slice - is: `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]` + This operation always reads a slice starting at + `%A[%expr1, %expr2, %expr3, %expr4]`. + The size of the slice is 3 along d2 and 5 along d0, so the slice is: + `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]` - That slice needs to be read into a `vector<3x4x5xf32>`. Since the - permutation map is not full rank, there must be a broadcast along vector - dimension `1`. + That slice needs to be read into a `vector<3x5xf32>`. A notional lowering of vector.transfer_read could generate code resembling: ```mlir // %expr1, %expr2, %expr3, %expr4 defined before this point - %tmp = alloc() : vector<3x4x5xf32> - %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>> + %tmp = alloc() : memref<3x5xf32> + %view_in_tmp = "element_type_cast"(%tmp) : memref> for %i = 0 to 3 { - affine.for %j = 0 to 4 { - affine.for %k = 0 to 5 { - %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : - memref - store %tmp[%i, %j, %k] : vector<3x4x5xf32> - }}} - %c0 = arith.constant 0 : index - %vec = load %view_in_tmp[%c0] : vector<3x4x5xf32> - ``` - - On a GPU one could then map `i`, `j`, `k` to blocks and threads. Notice that - the temporary storage footprint is `3 * 5` values but `3 * 4 * 5` values are - actually transferred between `%A` and `%tmp`.
- - Alternatively, if a notional vector broadcast operation were available, the - lowered code would resemble: - - ```mlir - // %expr1, %expr2, %expr3, %expr4 defined before this point - %tmp = alloc() : vector<3x4x5xf32> - %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>> - for %i = 0 to 3 { - affine.for %k = 0 to 5 { - %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : + affine.for %j = 0 to 5 { + %a = load %A[%expr1 + %j, %expr2, %expr3 + %i, %expr4] : memref - store %tmp[%i, 0, %k] : vector<3x4x5xf32> + store %a, %tmp[%i, %j] : memref<3x5xf32> }} %c0 = arith.constant 0 : index - %tmpvec = load %view_in_tmp[%c0] : vector<3x4x5xf32> - %vec = broadcast %tmpvec, 1 : vector<3x4x5xf32> + %vec = load %view_in_tmp[] : vector<3x5xf32> ``` - where `broadcast` broadcasts from element 0 to all others along the - specified dimension. This time, the temporary storage footprint is `3 * 5` - values which is the same amount of data as the `3 * 5` values transferred. - An additional `1` broadcast is required. On a GPU this broadcast could be - implemented using a warp-shuffle if loop `j` were mapped to `threadIdx.x`. - Syntax ``` operation ::= ssa-id `=` `vector.transfer_read` ssa-use-list @@ -1305,17 +1278,6 @@ memref, vector<32x256xf32> }}} - // Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into - // vector<128xf32>. The underlying implementation will require a 1-D vector - // broadcast: - for %i0 = 0 to %0 { - affine.for %i1 = 0 to %1 { - %3 = vector.transfer_read %A[%i0, %i1] - {permutation_map: (d0, d1) -> (0)} : - memref, vector<128xf32> - } - } - // Read from a memref with vector element type. 
%4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} @@ -1325,11 +1287,6 @@ %4 = vector.transfer_read %arg1[%c3, %c3], %vf0 {permutation_map = (d0, d1)->(d0, d1)} : tensor>, vector<1x1x4x3xf32> - - // Special encoding for 0-d transfer with 0-d tensor/memref, vector shape - // {1} and permutation_map () -> (0). - %0 = vector.transfer_read %arg0[], %f0 {permutation_map = affine_map<()->(0)>} : - tensor, vector<1xf32> ``` }]; @@ -1388,10 +1345,10 @@ OptionalAttr:$in_bounds)>, Results<(outs Optional:$result)> { - let summary = "The vector.transfer_write op writes a supervector to memory."; + let summary = "The vector.transfer_write op writes an n-D vector to a tensor or memref."; let description = [{ - The `vector.transfer_write` op performs a write from a + The `vector.transfer_write` op performs a write from an n-D [vector](../LangRef.md#vector-type), supplied as its first operand, into a slice within a [MemRef](../LangRef.md#memref-type) or a Ranked [Tensor](../LangRef.md#tensor-type) of the same base elemental type, @@ -1399,8 +1356,11 @@ A vector memref/tensor operand must have its vector element type match a suffix (shape and element type) of the vector (e.g. memref<3x2x6x4x3xf32>, - vector<1x1x4x3xf32>). If the operand is a tensor, the operation returns a - new tensor of the same type. + vector<1x1x4x3xf32>). + + If the operand is a tensor, the operation returns a new tensor of the same + type with the corresponding elements updated to the values contained in the + vector. The slice is further defined by a full-rank index within the MemRef/Tensor, supplied as the operands `[2 .. 2 + rank(memref/tensor))` that defines the @@ -1410,8 +1370,7 @@ [affine-map](Affine.md#affine-maps) which specifies the transposition on the slice to match the vector shape. The permutation map may be implicit and omitted from parsing and printing if it is the canonical minor identity map - (i.e. if it does not permute any dimension). 
In contrast to `transfer_read`, - write ops cannot have broadcast dimensions. + (i.e. if it does not permute any dimension). The size of the slice is specified by the size of the vector. @@ -1424,18 +1383,20 @@ While the starting point of the transfer has to be in-bounds, accesses may run out-of-bounds as indices increase. If specified, the `in_bounds` array length has to be equal to the vector rank. In absence of the attribute, - accesses along all dimensions may run out-of-bounds. A - `vector.transfer_write` can be lowered to a simple store if all dimensions + accesses along all dimensions may run out-of-bounds. + A `vector.transfer_write` can be lowered to a simple store if all dimensions are specified to be within bounds and no `mask` was specified. - This operation is called 'write' by opposition to 'store' because the - super-vector granularity is generally not representable with a single - hardware register. A `vector.transfer_write` is thus a - mid-level abstraction that supports super-vectorization with non-effecting - padding for full-tile-only code. It is the responsibility of - `vector.transfer_write`'s implementation to ensure the memory writes are - valid. Different lowerings may be pertinent depending on the hardware - support. + This operation is called 'write' as opposed to 'store' because the n-D + vector granularity is generally not representable with a single hardware + register. + + It is the responsibility of `vector.transfer_write`'s implementation to + ensure the memory stores are valid. Different lowerings may be pertinent + depending on the hardware support. + + The permutation_map has the same semantics as for `vector.transfer_read`; + see the examples below. Example: @@ -1461,11 +1422,6 @@ %5 = vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = (d0, d1)->(d0, d1)} : vector<1x1x4x3xf32>, tensor> - - // Special encoding for 0-d transfer with 0-d tensor/memref, vector shape - // {1} and permutation_map () -> (0).
- %1 = vector.transfer_write %0, %arg0[] {permutation_map = affine_map<()->(0)>} : - vector<1xf32>, tensor ``` }]; diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -16,6 +16,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/SmallBitVector.h" @@ -117,6 +118,11 @@ bool isMinorIdentityWithBroadcasting( SmallVectorImpl *broadcastedDims = nullptr) const; + /// Return a permutation vector encoding the permutation of the map's results. + /// This is computed once the unused dims and symbols are compressed away. + /// Return failure if the compressed map is not exactly a permutation. + FailureOr> getDimPermutationVector() const; + /// Return true if this affine map can be converted to a minor identity with /// broadcast by doing a permute. Return a permutation (there may be /// several) to apply to get to a minor identity with broadcasts. @@ -125,13 +131,14 @@ /// perm = [1, 0] and broadcast d2 /// * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by /// permutation + broadcast - /// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3) - /// with perm = [1, 0, 2] and broadcast d2 + /// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, + /// d3) with perm = [1, 0, 2] and broadcast d2 /// * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra /// leading broadcat dimensions. The map returned would be (0, 0, d0, d1) /// with perm = [3, 0, 1, 2] bool isPermutationOfMinorIdentityWithBroadcasting( - SmallVectorImpl &permutedDims) const; + SmallVectorImpl &permutedDims, + bool allowBroadcast = false) const; /// Returns true if this affine map is an empty map, i.e., () -> (). 
bool isEmpty() const; @@ -142,12 +149,12 @@ /// Returns true if this affine map has only constant results. bool isConstant() const; - /// Returns the constant result of this map. This methods asserts that the map - /// has a single constant result. + /// Returns the constant result of this map. This method asserts that the + /// map has a single constant result. int64_t getSingleConstantResult() const; - /// Returns the constant results of this map. This method asserts that the map - /// has all constant results. + /// Returns the constant results of this map. This method asserts that the + /// map has all constant results. SmallVector getConstantResults() const; // Prints affine map to 'os'. @@ -177,7 +184,8 @@ }); } - /// Return true if any affine expression involves AffineSymbolExpr `position`. + /// Return true if any affine expression involves AffineSymbolExpr + /// `position`. bool isFunctionOfSymbol(unsigned position) const { return llvm::any_of(getResults(), [&](AffineExpr e) { return e.isFunctionOfSymbol(position); @@ -189,24 +197,24 @@ void walkExprs(llvm::function_ref callback) const; /// This method substitutes any uses of dimensions and symbols (e.g. - /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified - /// expression mapping. Because this can be used to eliminate dims and - /// symbols, the client needs to specify the number of dims and symbols in - /// the result. The returned map always has the same number of results. + /// dim#0 with dimReplacements[0]) in subexpressions and returns the + /// modified expression mapping. Because this can be used to eliminate dims + /// and symbols, the client needs to specify the number of dims and symbols + /// in the result. The returned map always has the same number of results. AffineMap replaceDimsAndSymbols(ArrayRef dimReplacements, ArrayRef symReplacements, unsigned numResultDims, unsigned numResultSyms) const; - /// Sparse replace method.
Apply AffineExpr::replace(`expr`, `replacement`) to - /// each of the results and return a new AffineMap with the new results and - /// with the specified number of dims and symbols. + /// Sparse replace method. Apply AffineExpr::replace(`expr`, `replacement`) + /// to each of the results and return a new AffineMap with the new results + /// and with the specified number of dims and symbols. AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const; /// Sparse replace method. Apply AffineExpr::replace(`map`) to each of the - /// results and return a new AffineMap with the new results and with inferred - /// number of dims and symbols. + /// results and return a new AffineMap with the new results and with + /// inferred number of dims and symbols. AffineMap replace(const DenseMap &map) const; /// Sparse replace method. Apply AffineExpr::replace(`map`) to each of the @@ -271,11 +279,11 @@ SmallVectorImpl &results) const; /// Propagates the constant operands into this affine map. Operands are - /// allowed to be null, at which point they are treated as non-constant. This - /// does not change the number of symbols and dimensions. Returns a new map, - /// which may be equal to the old map if no folding happened. If `results` is - /// provided and if all expressions in the map were folded to constants, - /// `results` will contain the values of these constants. + /// allowed to be null, at which point they are treated as non-constant. + /// This does not change the number of symbols and dimensions. Returns a new + /// map, which may be equal to the old map if no folding happened. If + /// `results` is provided and if all expressions in the map were folded to + /// constants, `results` will contain the values of these constants. AffineMap partialConstantFold(ArrayRef operandConstants, SmallVectorImpl *results = nullptr) const; @@ -300,8 +308,8 @@ /// returns the resulting values. `this` must be symbol-less. 
SmallVector compose(ArrayRef values) const; - /// Returns true if the AffineMap represents a subset (i.e. a projection) of a - /// symbol-less permutation map. `allowZeroInResults` allows projected + /// Returns true if the AffineMap represents a subset (i.e. a projection) of + /// a symbol-less permutation map. `allowZeroInResults` allows projected /// permutation maps with constant zero result expressions. /// TODO: Remove `allowZeroInResults` when constant zero result expressions /// are broadly supported. @@ -313,7 +321,8 @@ /// Returns the map consisting of the `resultPos` subset. AffineMap getSubMap(ArrayRef resultPos) const; - /// Returns the map consisting of `length` expressions starting from `start`. + /// Returns the map consisting of `length` expressions starting from + /// `start`. AffineMap getSliceMap(unsigned start, unsigned length) const; /// Returns the map consisting of the most major `numResults` results. @@ -328,8 +337,9 @@ /// Get the largest known divisor of all map expressions. /// For eg: for (d0, d1) -> (8*d0 + 4, 4*d1 + 2), the result is 2. - /// In the case of maps with no expressions or all zero constant expressions, - /// the largest known divisor is trivially the max uint64_t value. + /// In the case of maps with no expressions or all zero constant + /// expressions, the largest known divisor is trivially the max uint64_t + /// value. 
uint64_t getLargestKnownDivisorOfMapExprs(); friend ::llvm::hash_code hash_value(AffineMap arg); diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -76,10 +76,9 @@ /*args=*/(ins "unsigned":$dim), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.isBroadcastDim(dim) - || ($_op.getInBounds() - && $_op.getInBounds()->template cast<::mlir::ArrayAttr>()[dim] - .template cast<::mlir::BoolAttr>().getValue()); + return $_op.getInBounds() && + $_op.getInBounds()->template cast<::mlir::ArrayAttr>()[dim] + .template cast<::mlir::BoolAttr>().getValue(); }] >, InterfaceMethod< @@ -114,33 +113,6 @@ /*methodBody=*/"return $_op.getPermutationMap();" /*defaultImplementation=*/ >, - InterfaceMethod< - /*desc=*/[{ Returns true if the specified dimension is a broadcast. }], - /*retTy=*/"bool", - /*methodName=*/"isBroadcastDim", - /*args=*/(ins "unsigned":$idx), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto expr = $_op.getPermutationMap().getResult(idx); - return expr.template isa<::mlir::AffineConstantExpr>() && - expr.template dyn_cast<::mlir::AffineConstantExpr>().getValue() == 0; - }] - >, - InterfaceMethod< - /*desc=*/[{ Returns true if at least one of the dimensions in the - permutation map is a broadcast.}], - /*retTy=*/"bool", - /*methodName=*/"hasBroadcastDim", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - for (unsigned i = 0, rank = getTransferRank(); i < rank; ++i) { - if ($_op.isBroadcastDim(i)) - return true; - } - return false; - }] - >, InterfaceMethod< /*desc=*/"Return the `in_bounds` boolean ArrayAttr.", /*retTy=*/"::llvm::Optional<::mlir::ArrayAttr>", @@ -265,9 +237,6 @@ $_op.getVectorType().getShape())) { AffineExpr dim = std::get<0>(vecDims); int64_t size = std::get<1>(vecDims); - // Skip broadcast. 
- if (dim.isa()) - continue; dimSizes[dim.cast().getPosition()] = size; } return dimSizes; diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -52,18 +53,12 @@ /// Given a vector transfer op, calculate which dimension of the `source` /// memref should be unpacked in the next application of TransferOpConversion. -/// A return value of None indicates a broadcast. template -static Optional unpackedDim(OpTy xferOp) { +static int64_t unpackedDim(OpTy xferOp) { // TODO: support 0-d corner case. assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); auto map = xferOp.getPermutationMap(); - if (auto expr = map.getResult(0).template dyn_cast()) { - return expr.getPosition(); - } - assert(xferOp.isBroadcastDim(0) && - "Expected AffineDimExpr or AffineConstantExpr"); - return None; + return map.getResult(0).template cast().getPosition(); } /// Compute the permutation map for the new (N-1)-D vector transfer op. This @@ -89,19 +84,15 @@ SmallVector &indices) { typename OpTy::Adaptor adaptor(xferOp); // Corresponding memref dim of the vector dim that is unpacked. 
- auto dim = unpackedDim(xferOp); + int64_t dim = unpackedDim(xferOp); auto prevIndices = adaptor.getIndices(); indices.append(prevIndices.begin(), prevIndices.end()); Location loc = xferOp.getLoc(); - bool isBroadcast = !dim.has_value(); - if (!isBroadcast) { - AffineExpr d0, d1; - bindDims(xferOp.getContext(), d0, d1); - Value offset = adaptor.getIndices()[dim.value()]; - indices[dim.value()] = - makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); - } + AffineExpr d0, d1; + bindDims(xferOp.getContext(), d0, d1); + Value offset = adaptor.getIndices()[dim]; + indices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); } static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal, @@ -119,16 +110,12 @@ /// * xferOp does not have a mask. /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is /// computed and attached to the new transfer op in the pattern.) -/// * The to-be-unpacked dim of xferOp is a broadcast. template static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) { if (!xferOp.getMask()) return Value(); if (xferOp.getMaskType().getRank() != 1) return Value(); - if (xferOp.isBroadcastDim(0)) - return Value(); - Location loc = xferOp.getLoc(); return b.create(loc, xferOp.getMask(), iv); } @@ -159,23 +146,21 @@ /// `resultTypes`. template static Value generateInBoundsCheck( - OpBuilder &b, OpTy xferOp, Value iv, Optional dim, - TypeRange resultTypes, + OpBuilder &b, OpTy xferOp, Value iv, int64_t dim, TypeRange resultTypes, function_ref inBoundsCase, function_ref outOfBoundsCase = nullptr) { bool hasRetVal = !resultTypes.empty(); Value cond; // Condition to be built... // Condition check 1: Access in-bounds? - bool isBroadcast = !dim; // No in-bounds check for broadcasts. 
Location loc = xferOp.getLoc(); ImplicitLocOpBuilder lb(xferOp.getLoc(), b); - if (!xferOp.isDimInBounds(0) && !isBroadcast) { + if (!xferOp.isDimInBounds(0)) { Value memrefDim = - vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim); + vector::createOrFoldDimOp(b, loc, xferOp.getSource(), dim); AffineExpr d0, d1; bindDims(xferOp.getContext(), d0, d1); - Value base = xferOp.getIndices()[*dim]; + Value base = xferOp.getIndices()[dim]; Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); cond = lb.create(arith::CmpIPredicate::sgt, memrefDim, memrefIdx); @@ -217,7 +202,7 @@ /// a return value. Consequently, this function does not have a return value. template static void generateInBoundsCheck( - OpBuilder &b, OpTy xferOp, Value iv, Optional dim, + OpBuilder &b, OpTy xferOp, Value iv, int64_t dim, function_ref inBoundsCase, function_ref outOfBoundsCase = nullptr) { generateInBoundsCheck( @@ -695,7 +680,7 @@ ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter); auto dataBuffer = Strategy::getBuffer(xferOp); auto dataBufferType = dataBuffer.getType().template dyn_cast(); - auto castedDataType = unpackOneDim(dataBufferType); + MemRefType castedDataType = unpackOneDim(dataBufferType); auto castedDataBuffer = locB.create(castedDataType, dataBuffer); @@ -705,15 +690,12 @@ auto maskBuffer = getMaskBuffer(xferOp); auto maskBufferType = maskBuffer.getType().template dyn_cast(); - if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) { + if (xferOp.getMaskType().getRank() == 1) { // Do not unpack a dimension of the mask, if: - // * To-be-unpacked transfer op dimension is a broadcast. // * Mask is 1D, i.e., the mask cannot be further unpacked. - // (That means that all remaining dimensions of the transfer op must - // be broadcasted.) 
castedMaskBuffer = maskBuffer; } else { - auto castedMaskType = unpackOneDim(maskBufferType); + MemRefType castedMaskType = unpackOneDim(maskBufferType); castedMaskBuffer = locB.create(castedMaskType, maskBuffer); } @@ -743,22 +725,16 @@ OpTy newXfer = Strategy::rewriteOp( b, this->options, xferOp, castedDataBuffer, iv, loopState); - // If old transfer op has a mask: Set mask on new transfer op. - // Special case: If the mask of the old transfer op is 1D and - // the - // unpacked dim is not a broadcast, no mask is - // needed on the new transfer op. - if (xferOp.getMask() && (xferOp.isBroadcastDim(0) || - xferOp.getMaskType().getRank() > 1)) { + // If old transfer op has a mask, set mask on new transfer op. + // Special case: If the mask of the old transfer op is 1D, no + // mask is needed on the new transfer op. + if (xferOp.getMask() && xferOp.getMaskType().getRank() > 1) { OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(newXfer); // Insert load before newXfer. SmallVector loadIndices; Strategy::getBufferIndices(xferOp, loadIndices); - // In case of broadcast: Use same indices to load from memref - // as before. - if (!xferOp.isBroadcastDim(0)) - loadIndices.push_back(iv); + loadIndices.push_back(iv); auto mask = b.create(loc, castedMaskBuffer, loadIndices); @@ -795,13 +771,6 @@ if (!xferOp.getMask()) return; - if (xferOp.isBroadcastDim(0)) { - // To-be-unpacked dimension is a broadcast, which does not have a - // corresponding mask dimension. Mask attribute remains unchanged. - newXferOp.getMaskMutable().assign(xferOp.getMask()); - return; - } - if (xferOp.getMaskType().getRank() > 1) { // Unpack one dimension of the mask. OpBuilder::InsertionGuard guard(b); @@ -813,8 +782,8 @@ newXferOp.getMaskMutable().assign(newMask); } - // If we end up here: The mask of the old transfer op is 1D and the unpacked - // dim is not a broadcast, so no mask is needed on the new transfer op. 
+ // If we end up here, then the mask of the old transfer op is 1D, so no mask + // is needed on the new transfer op. // `generateInBoundsCheck` will have evaluated the mask already. } @@ -889,8 +858,8 @@ } } - /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds - /// accesses, and broadcasts and transposes in permutation maps. + /// Rewrite the op by unpacking one dimension. + /// This handles masks, out-of-bounds accesses and permutation map transposes. LogicalResult matchAndRewrite(TransferReadOp xferOp, PatternRewriter &rewriter) const override { if (xferOp.getVectorType().getRank() <= options.targetRank) @@ -1016,8 +985,8 @@ } } - /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds - /// accesses, and broadcasts and transposes in permutation maps. + /// Rewrite the op by unpacking one dimension. + /// This handles masks, out-of-bounds accesses and permutation map transposes. LogicalResult matchAndRewrite(TransferWriteOp xferOp, PatternRewriter &rewriter) const override { if (xferOp.getVectorType().getRank() <= options.targetRank) @@ -1089,12 +1058,11 @@ namespace lowering_1_d { /// Compute the indices into the memref for the LoadOp/StoreOp generated as -/// part of TransferOp1dConversion. Return the memref dimension on which -/// the transfer is operating. A return value of None indicates a broadcast. +/// part of TransferOp1dConversion. +/// Return the memref dimension on which the transfer is operating. 
template -static Optional -get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv, - SmallVector &memrefIndices) { +static int64_t get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv, + SmallVector &memrefIndices) { auto indices = xferOp.getIndices(); auto map = xferOp.getPermutationMap(); assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); @@ -1102,19 +1070,14 @@ memrefIndices.append(indices.begin(), indices.end()); assert(map.getNumResults() == 1 && "Expected 1 permutation map result for 1D transfer"); - if (auto expr = map.getResult(0).template dyn_cast()) { - Location loc = xferOp.getLoc(); - auto dim = expr.getPosition(); - AffineExpr d0, d1; - bindDims(xferOp.getContext(), d0, d1); - Value offset = memrefIndices[dim]; - memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); - return dim; - } - - assert(xferOp.isBroadcastDim(0) && - "Expected AffineDimExpr or AffineConstantExpr"); - return None; + auto expr = map.getResult(0).template cast(); + Location loc = xferOp.getLoc(); + auto dim = expr.getPosition(); + AffineExpr d0, d1; + bindDims(xferOp.getContext(), d0, d1); + Value offset = memrefIndices[dim]; + memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); + return dim; } /// Codegen strategy for TransferOp1dConversion, depending on the @@ -1188,12 +1151,11 @@ return succeeded(successStrides) && (strides.empty() || strides.back() == 1); } -/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is -/// necessary in cases where a 1D vector transfer op cannot be lowered into -/// vector load/stores due to non-unit strides or broadcasts: +/// Lower a 1D vector transfer op to SCF using scalar loads/stores. +/// This is necessary in cases where a 1D vector transfer op cannot be lowered +/// into vector load/stores due to non-unit strides. 
/// /// * Transfer dimension is not the last memref dimension -/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast) /// * Memref has a layout map with non-unit stride on the last dimension /// /// This pattern generates IR as follows: diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1072,8 +1072,7 @@ } LogicalResult SplitOp::verify() { - if ((static_cast(getStaticSplitPoint()) != - ShapedType::kDynamic) ^ + if ((static_cast(getStaticSplitPoint()) != ShapedType::kDynamic) ^ (getDynamicSplitPoint() == nullptr)) { return emitOpError() << "expects either a dynamic or a static split " "point to be provided"; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -24,18 +24,21 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include using namespace mlir; using namespace mlir::linalg; @@ -46,7 +49,7 @@ #define LDBG(X) LLVM_DEBUG(DBGS() << X) /// Try to vectorize `convOp` as a convolution. 
-static FailureOr vectorizeConvolution(OpBuilder &b, +static FailureOr vectorizeConvolution(RewriterBase &b, LinalgOp convOp); /// Return the unique instance of OpType in `block` if it is indeed unique. @@ -68,6 +71,7 @@ /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a /// projectedPermutation, compress the unused dimensions to serve as a /// permutation_map for a vector transfer operation. +/// Set the `unused` bit vector to those dims that are not used. /// For example, given a linalg op such as: /// /// ``` @@ -79,13 +83,25 @@ /// outs(%1 : tensor<5x6xf32>) /// ``` /// -/// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine -/// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second -/// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`. -static AffineMap reindexIndexingMap(AffineMap map) { - assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) && - "expected projected permutation"); - auto res = compressUnusedDims(map); +/// the iteration domain size of the linalg op is 3x5x4x6x2. +/// When the first map is passed: +/// - the map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`. +/// - the bit vector is set to (1, 3) +/// When the second map is passed: +/// - the map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`. 
+/// - the bit vector is set to (0, 2, 4) +static AffineMap compress(AffineMap map, llvm::SmallBitVector &unusedDims) { + assert(map.isProjectedPermutation() && "expected projected permutation"); + + unusedDims = getUnusedDimsBitVector(map); + AffineMap res = compressDims(map, unusedDims); + uint64_t backingStore; + LLVM_DEBUG({ + llvm::interleaveComma(unusedDims.getData(backingStore), + DBGS() << "compress: " << map << " to " << res + << " by dropping: "); + llvm::dbgs() << "\n"; + }); assert(res.getNumDims() == res.getNumResults() && "expected reindexed map with same number of dims and results"); return res; @@ -161,26 +177,101 @@ return combinerOps[0]; } -/// Broadcast `value` to a vector of `shape` if possible. Return value -/// otherwise. -static Value broadcastIfNeeded(OpBuilder &b, Value value, - ArrayRef shape) { - // If no shape to broadcast to, just return `value`. - if (shape.empty()) +/// Broadcast `value` to a vector of `targetShape`. +/// Since vector.broadcast only allows expanding leading dimensions, this may +/// additionally insert a vector.transpose to make the broadcast possble. +/// To this end, an optional `bcastDims` is passed to capture which dimensions +/// in the `targetShape` come from a broadcast. +/// When present, this is used to derive the extra transpose. +static FailureOr +broadcastIfNeeded(RewriterBase &b, Value value, ArrayRef targetShape, + llvm::SmallBitVector *bcastDims = nullptr) { + // Case 1. If no targetShape to broadcast to, just return `value`. + if (targetShape.empty()) return value; - VectorType targetVectorType = - VectorType::get(shape, getElementTypeOrSelf(value)); - if (vector::isBroadcastableTo(value.getType(), targetVectorType) != + + Location loc = value.getLoc(); + Type elementType = getElementTypeOrSelf(value.getType()); + VectorType sourceVectorType = value.getType().dyn_cast(); + VectorType targetVectorType = VectorType::get(targetShape, elementType); + // Case 2. 
If scalar -> targetShape broadcast, just do it. + if (!sourceVectorType) { + return b.createOrFold(loc, targetVectorType, value); + } + + // Case 3. We can directly broadcast to the target shape. + if (vector::isBroadcastableTo(value.getType(), targetVectorType) == vector::BroadcastableToResult::Success) - return value; - Location loc = b.getInsertionPoint()->getLoc(); - return b.createOrFold(loc, targetVectorType, value); + return b.createOrFold(loc, targetVectorType, value); + + LDBG(value << " not broadcastable to: " << targetVectorType << "\n"); + + // Case 4. vector -> targetShape broadcast requires a transpose because + // vector.broadcast only allows creating leading dims. + assert(bcastDims && "must specify which dims of the target shape come from " + "broadcast to lift ambiguities"); + SmallVector broadcastShape, permutation(targetShape.size(), -1); + broadcastShape.reserve(targetShape.size()); + int64_t sourceShapeDim = bcastDims->count(); + for (int64_t i = 0, e = targetShape.size(); i < e; ++i) { + LLVM_DEBUG({ + llvm::interleaveComma(permutation, DBGS() << "permutation: "); + llvm::dbgs() << "\n"; + LDBG("i = " << i << "\n"); + LDBG("sourceShapeDim = " << sourceShapeDim << "\n"); + }); + if (bcastDims->test(i)) { + // For each dim in the target shape, if it comes from a broadcast bring + // it to the leading part of the targetShape. + broadcastShape.push_back(targetShape[i]); + // It will need to be permuted back from broadcastShape.size() - 1 into + // position `i`. + // permutation[broadcastShape.size() - 1] = i; + permutation[i] = broadcastShape.size() - 1; + } else { + // Otherwise, the dim comes from the source shape and needs to be + // permuted into position `i`. 
+ // permutation[sourceShapeDim++] = i; + permutation[i] = sourceShapeDim++; + } + LLVM_DEBUG({ + llvm::interleaveComma(permutation, DBGS() << "permutation: "); + llvm::dbgs() << "\n"; + }); + } + llvm::append_range(broadcastShape, sourceVectorType.getShape()); + + VectorType broadcastType = VectorType::get(broadcastShape, elementType); + assert(vector::isBroadcastableTo(value.getType(), broadcastType) == + vector::BroadcastableToResult::Success && + "must be broadcastable"); + if (vector::isBroadcastableTo(value.getType(), broadcastType) != + vector::BroadcastableToResult::Success) { + // llvm_unreachable("broadcast must succeed"); + return b.notifyMatchFailure(value.getLoc(), [&](Diagnostic &diag) { + diag << value << " not broadcastable to: " << broadcastType; + if (bcastDims) { + uint64_t backingStore; + llvm::interleaveComma(bcastDims->getData(backingStore), + diag << " given bcastDims: "); + } + llvm::interleaveComma(permutation, diag << " and permutation: "); + }); + } + Value bcast = b.createOrFold(loc, broadcastType, value); + Value res = b.createOrFold(loc, bcast, permutation); + if (res.getType() != targetVectorType) { + res.dump(); + targetVectorType.dump(); + llvm_unreachable("unexpected transpose type"); + } + return res; } -/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This -/// assumes that `reductionOp` has two operands and one of them is the reduction -/// initial value. -static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, +/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. +/// This assumes that `reductionOp` has two operands and one of them is the +/// reduction initial value. 
+static Operation *buildMultiDimReduce(RewriterBase &b, Operation *reduceOp, Value valueToReduce, Value acc, const SmallVector &reductionMask) { auto maybeKind = getCombinerOpKind(reduceOp); @@ -194,65 +285,68 @@ llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator)); } -/// Build a vector.transfer_write of `value` into `outputOperand` at indices set -/// to all `0`; where `outputOperand` is an output operand of the LinalgOp -/// currently being vectorized. If `dest` has null rank, build an memref.store. -/// Return the produced value or null if no value is produced. -static Value buildVectorWrite(OpBuilder &b, Value value, - OpOperand *outputOperand) { - Operation *write; +/// Build a vector.transfer_write of `value` into `outputOperand` at indices +/// set to all `0`; where `outputOperand` is an output operand of the LinalgOp +/// currently being vectorized. If `dest` has null rank, build an +/// memref.store. Return the produced value or null if no value is produced. +static FailureOr +buildVectorWrite(RewriterBase &b, Value value, OpOperand *outputOperand) { Location loc = value.getLoc(); auto linalgOp = cast(outputOperand->getOwner()); ArrayRef shape = linalgOp.getShape(outputOperand); auto vectorType = VectorType::get( shape, getElementTypeOrSelf(outputOperand->get().getType())); - if (vectorType.getRank() > 0) { - // 0-d case is still special: do not invert the reindexing map. 
- AffineMap map = - reindexIndexingMap(linalgOp.getMatchingIndexingMap(outputOperand)); - SmallVector transposeShape = - applyPermutationMap(inversePermutation(map), vectorType.getShape()); - assert(!transposeShape.empty() && "unexpected empty transpose shape"); - vectorType = VectorType::get(transposeShape, vectorType.getElementType()); - SmallVector indices(linalgOp.getRank(outputOperand), - b.create(loc, 0)); - value = broadcastIfNeeded(b, value, vectorType.getShape()); - write = b.create(loc, value, outputOperand->get(), - indices, map); - } else { + + // 0-d case is still special: no need invert the reindexing map, we just + // need to broadcast and exit early. + if (vectorType.getRank() == 0) { if (!value.getType().isa()) value = b.create(loc, vectorType, value); assert(value.getType() == vectorType && "incorrect type"); - write = b.create(loc, value, outputOperand->get(), - ValueRange{}); + return b.create(loc, value, outputOperand->get(), + ValueRange{}); + } + // >=1-D case. + llvm::SmallBitVector unusedDims; + AffineMap map = + compress(linalgOp.getMatchingIndexingMap(outputOperand), unusedDims); + SmallVector transposeShape = + applyPermutationMap(inversePermutation(map), vectorType.getShape()); + assert(!transposeShape.empty() && "unexpected empty transpose shape"); + vectorType = VectorType::get(transposeShape, vectorType.getElementType()); + SmallVector indices(linalgOp.getRank(outputOperand), + b.create(loc, 0)); + auto maybeBroadcasted = broadcastIfNeeded(b, value, vectorType.getShape()); + if (failed(maybeBroadcasted)) { + return b.notifyMatchFailure(loc, [&](Diagnostic &diag) { + diag << value << " could not be broadcast to: " << vectorType; + }); } - LDBG("vectorized op: " << *write); - if (!write->getResults().empty()) - return write->getResult(0); - return Value(); + return b.create(loc, *maybeBroadcasted, + outputOperand->get(), indices, map); } -// Custom vectorization precondition function type. 
This is intented to be used -// with CustomVectorizationHook. Returns success if the correpsonding custom -// hook can vectorize the op. +// Custom vectorization precondition function type. This is intented to be +// used with CustomVectorizationHook. Returns success if the correpsonding +// custom hook can vectorize the op. using CustomVectorizationPrecondition = - std::function; + std::function; // Custom vectorization function type. Produce a vector form of Operation* -// assuming all its vectorized operands are already in the BlockAndValueMapping. -// Return nullptr if the Operation cannot be vectorized. -using CustomVectorizationHook = std::function( Operation *, const BlockAndValueMapping &)>; /// Helper function to vectorize the terminator of a `linalgOp`. New result /// vector values are appended to `newResults`. Return -/// VectorizationStatus::NoReplace to signal the vectorization algorithm that it -/// should not try to map produced operations and instead return the results -/// using the `newResults` vector making them available to the +/// VectorizationStatus::NoReplace to signal the vectorization algorithm that +/// it should not try to map produced operations and instead return the +/// results using the `newResults` vector making them available to the /// vectorization algorithm for RAUW. This function is meant to be used as a /// CustomVectorizationHook. -static VectorizationResult -vectorizeLinalgYield(OpBuilder &b, Operation *op, +static FailureOr +vectorizeLinalgYield(RewriterBase &b, Operation *op, const BlockAndValueMapping &bvm, LinalgOp linalgOp, SmallVectorImpl &newResults) { auto yieldOp = dyn_cast(op); @@ -262,11 +356,19 @@ // TODO: Scan for an opportunity for reuse. // TODO: use a map. 
Value vectorValue = bvm.lookup(outputs.value()); - Value newResult = buildVectorWrite( + auto maybeVectorWrite = buildVectorWrite( b, vectorValue, linalgOp.getDpsInitOperand(outputs.index())); - if (newResult) - newResults.push_back(newResult); + if (failed(maybeVectorWrite)) + return b.notifyMatchFailure(op, "failed to vectorize"); + vector::TransferWriteOp vectorWrite = *maybeVectorWrite; + LDBG("vector.transfer_write: " << maybeVectorWrite << "\n"); + if (vectorWrite->getNumResults() > 0) + newResults.push_back(vectorWrite->getResult(0)); } + LLVM_DEBUG({ + llvm::interleaveComma(newResults, DBGS() << "results: "); + llvm::dbgs() << "\n"; + }); return VectorizationResult{VectorizationStatus::NoReplace, nullptr}; } @@ -274,8 +376,8 @@ /// VectorizationStatus::NewOp to signal the vectorization algorithm that it /// should map the produced operations. This function is meant to be used as a /// CustomVectorizationHook. -static VectorizationResult vectorizeLinalgIndex(OpBuilder &b, Operation *op, - LinalgOp linalgOp) { +static FailureOr +vectorizeLinalgIndex(RewriterBase &b, Operation *op, LinalgOp linalgOp) { IndexOp indexOp = dyn_cast(op); if (!indexOp) return VectorizationResult{VectorizationStatus::Failure, nullptr}; @@ -288,8 +390,8 @@ auto constantOp = b.create(loc, b.getIndexVectorAttr(constantSeq)); // Return the one-dimensional index vector if it lives in the trailing - // dimension of the iteration space since the vectorization algorithm in this - // case can handle the broadcast. + // dimension of the iteration space since the vectorization algorithm in + // this case can handle the broadcast. if (indexOp.getDim() == targetShape.size() - 1) return VectorizationResult{VectorizationStatus::NewOp, constantOp}; // Otherwise permute the targetShape to move the index dimension last, @@ -308,24 +410,19 @@ /// Helper function to check if the tensor.extract can be vectorized by the /// custom hook vectorizeTensorExtract. 
-static LogicalResult tensorExtractVectorizationPrecondition(Operation *op) { +static LogicalResult tensorExtractVectorizationPrecondition(RewriterBase &b, + Operation *op) { tensor::ExtractOp extractOp = dyn_cast(op); + // Avoid reporting trivial information. if (!extractOp) return failure(); - - // Currently only supports extraction with an 1-D index. if (extractOp.getIndices().size() != 1) - return failure(); - - if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType())) - return failure(); - - if (llvm::any_of(extractOp->getResultTypes(), [](Type type) { - return !VectorType::isValidElementType(type); - })) { - return failure(); - } - + return b.notifyMatchFailure(op, "only 1 index supported atm"); + auto validElementType = VectorType::isValidElementType; + if (!validElementType(extractOp.getIndices()[0].getType())) + return b.notifyMatchFailure(op, "not a valid operand vector element type"); + if (!llvm::all_of(extractOp->getResultTypes(), validElementType)) + return b.notifyMatchFailure(op, "not a valid result vector element type"); return success(); } @@ -333,8 +430,8 @@ /// VectorizationStatus::NewOp to signal the vectorization algorithm that it /// should map the produced operations. This function is meant to be used as a /// CustomVectorizationHook. -static VectorizationResult -vectorizeTensorExtract(OpBuilder &b, Operation *op, LinalgOp linalgOp, +static FailureOr +vectorizeTensorExtract(RewriterBase &b, Operation *op, LinalgOp linalgOp, const BlockAndValueMapping &bvm) { tensor::ExtractOp extractOp = dyn_cast(op); if (!extractOp) @@ -369,10 +466,11 @@ return VectorizationResult{VectorizationStatus::NewOp, gatherOp}; } -/// Emit reduction operations if the shapes of the value to reduce is different -/// that the result shape. 
-static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, - Value reduceValue, Value initialValue, +/// Emit reduction operations if the shapes of the value to reduce is +/// different that the result shape. +static Operation *reduceIfNeeded(RewriterBase &b, LinalgOp linalgOp, + Operation *op, Value reduceValue, + Value initialValue, const BlockAndValueMapping &bvm) { Value reduceVec = bvm.lookup(reduceValue); Value outputVec = bvm.lookup(initialValue); @@ -387,8 +485,8 @@ return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask); } -/// Generic vectorization for a single operation `op`, given already vectorized -/// operands carried by `bvm`. Vectorization occurs as follows: +/// Generic vectorization for a single operation `op`, given already +/// vectorized operands carried by `bvm`. Vectorization occurs as follows: /// 1. Try to apply any of the `customVectorizationHooks` and return its /// result on success. /// 2. Clone any constant in the current scope without vectorization: each @@ -396,34 +494,34 @@ /// constant needs to be broadcast to. /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose /// of the `customVectorizationHooks` to cover such cases. -/// 4. Clone `op` in vector form to a vector of shape prescribed by the first -/// operand of maximal rank. Other operands have smaller rank and are +/// 4. Clone `op` in vector form to a vector of shape prescribed by the +/// first operand of maximal rank. Other operands have smaller rank and are /// broadcast accordingly. It is assumed this broadcast is always legal, /// otherwise, it means one of the `customVectorizationHooks` is incorrect. /// /// This function assumes all operands of `op` have been vectorized and are in -/// the `bvm` mapping. As a consequence, this function is meant to be called on -/// a topologically-sorted list of ops. 
-/// This function does not update `bvm` but returns a VectorizationStatus that -/// instructs the caller what `bvm` update needs to occur. -static VectorizationResult -vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op, +/// the `bvm` mapping. As a consequence, this function is meant to be called +/// on a topologically-sorted list of ops. This function does not update `bvm` +/// but returns a VectorizationStatus that instructs the caller what `bvm` +/// update needs to occur. +static FailureOr +vectorizeOneOp(RewriterBase &b, LinalgOp linalgOp, Operation *op, const BlockAndValueMapping &bvm, ArrayRef customVectorizationHooks) { - LDBG("vectorize op " << *op); - - // 1. Try to apply any CustomVectorizationHook. + // 1. Try to apply any CustomVectorizationHook, if succeeds exit early. if (!customVectorizationHooks.empty()) { for (auto &customFunc : customVectorizationHooks) { - VectorizationResult result = customFunc(op, bvm); - if (result.status == VectorizationStatus::Failure) + auto maybeResult = customFunc(op, bvm); + if (failed(maybeResult)) + return b.notifyMatchFailure(op, "custom vectorization hook failed"); + if (maybeResult->status == VectorizationStatus::Failure) continue; - return result; + return *maybeResult; } } - // 2. Constant ops don't get vectorized but rather broadcasted at their users. - // Clone so that the constant is not confined to the linalgOp block . + // 2. Constant ops don't get vectorized but rather broadcasted at their + // users. Clone so that the constant is not confined to the linalgOp block . if (isa(op)) return VectorizationResult{VectorizationStatus::NewOp, b.clone(*op)}; @@ -463,18 +561,28 @@ firstMaxRankedShape.assign(vt.getShape().begin(), vt.getShape().end()); } // b. broadcast each op if needed. - auto vectorizedOperands = llvm::map_range(op->getOperands(), [&](Value v) { - return firstMaxRankedShape.empty() - ? bvm.lookup(v) - : broadcastIfNeeded(b, bvm.lookup(v), firstMaxRankedShape); - }); - // c. 
for elementwise, the result is the vector with the firstMaxRankedShape + auto maybeVectorizedOperands = + llvm::map_range(op->getOperands(), [&](Value v) { + return firstMaxRankedShape.empty() + ? bvm.lookup(v) + : broadcastIfNeeded(b, bvm.lookup(v), firstMaxRankedShape); + }); + // c. for elementwise, the result is the vector with the + // firstMaxRankedShape auto returnTypes = llvm::map_range(op->getResultTypes(), [&](Type t) { return firstMaxRankedShape.empty() ? t : VectorType::get(firstMaxRankedShape, t); }); + SmallVector vectorizedOperands; + vectorizedOperands.reserve(8); + for (auto maybeOperand : maybeVectorizedOperands) { + if (failed(maybeOperand)) + return b.notifyMatchFailure(op, "operands failed to vectorize"); + vectorizedOperands.push_back(*maybeOperand); + } + // Build and return the new op. return VectorizationResult{ VectorizationStatus::NewOp, @@ -488,8 +596,8 @@ /// 1. Verify the `linalgOp` has one non-empty region. /// 2. Values defined above the region are mapped to themselves and will be /// broadcasted on a per-need basis by their consumers. -/// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d -/// load). +/// 3. Each region argument is vectorized into a vector.transfer_read (or +/// 0-d load). /// TODO: Reuse opportunities for RAR dependencies. /// 4a. Register CustomVectorizationHook for YieldOp to capture the results. /// 4b. Register CustomVectorizationHook for IndexOp to access the iteration @@ -499,33 +607,34 @@ /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is /// performed to the maximal common vector size implied by the `linalgOp` /// iteration space. This eager broadcasting is introduced in the -/// permutation_map of the vector.transfer_read operations. The eager -/// broadcasting makes it trivial to detrmine where broadcast, transposes and -/// reductions should occur, without any bookkeeping. 
The tradeoff is that, in -/// the absence of good canonicalizations, the amount of work increases. -/// This is not deemed a problem as we expect canonicalizations and foldings to -/// aggressively clean up the useless work. +/// permutation_map of the vector.transfer_read operations. +/// The eager broadcasting makes it trivial to determine where broadcast, +/// transposes and reductions should occur, without any bookkeeping. +/// The tradeoff is that, in the absence of good canonicalizations, the amount +/// of work can increase (a lot). +/// This is not deemed a problem as we expect canonicalizations and foldings +/// to aggressively clean up the useless work. static LogicalResult -vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp, +vectorizeAsLinalgGeneric(RewriterBase &b, LinalgOp linalgOp, SmallVectorImpl &newResults) { Block *block = linalgOp.getBlock(); - // 2. Values defined above the region can only be broadcast for now. Make them - // map to themselves. + // 2. Values defined above the region can only be broadcast for now. Make + // them map to themselves. BlockAndValueMapping bvm; SetVector valuesSet; mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet); bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef()); if (linalgOp.getNumDpsInits() == 0) - return failure(); + return b.notifyMatchFailure(linalgOp, "no DPS inits"); // TODO: the common vector shape is equal to the static loop sizes only when // all indexing maps are projected permutations. For convs and stencils the // logic will need to evolve. SmallVector commonVectorShape = linalgOp.computeStaticLoopSizes(); - // 3. Turn all BBArgs into vector.transfer_read / load. + // 3. Turn all BBArgs into vector.transfer_read. 
Location loc = linalgOp.getLoc(); Value zero = b.create(loc, 0); for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) { @@ -534,35 +643,91 @@ bvm.map(bbarg, opOperand->get()); continue; } - VectorType readType; - AffineMap map; - // TODO: can we keep this simplification? - // if (linalgOp.getShape(&opOperand).empty()) { - // readType = VectorType::get({}, bbarg.getType()); - // } else { - if (opOperand->getOperandNumber() < linalgOp.getNumDpsInputs()) { - map = inverseAndBroadcastProjectedPermutation( - linalgOp.getMatchingIndexingMap(opOperand)); - readType = VectorType::get(commonVectorShape, - getElementTypeOrSelf(opOperand->get())); - } else { - map = inversePermutation( - reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand))); - readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), - getElementTypeOrSelf(opOperand->get())); - } - // } - auto shape = linalgOp.getShape(opOperand); - SmallVector indices(shape.size(), zero); + /// Consider the following linalg op as an illustration: + /// + /// ``` + /// %0 = linalg.generic { + /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, + /// d2)>, indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, + /// d3)> + /// } + /// ins(%0 : tensor<2x3x4xf32>) + /// outs(%1 : tensor<5x6xf32>) + /// ``` + /// + /// the iteration domain size of the linalg op is 3x5x4x6x2. + /// The common normalized vector size on which all operations will occur + /// is `vector<3x5x4x6x2xf32>`. + /// + /// Illustration with map `affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, + /// d2)>` follows. + + /// 3.1. Get the map from the iteration domain to the indexed data: + /// `affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>` + AffineMap iterationToIndexedDataMap = + linalgOp.getMatchingIndexingMap(opOperand); + assert(iterationToIndexedDataMap.isProjectedPermutation( + /*allowZeroInResults=*/true) && + "expected projected permutation"); + + // 3.1.b. 
Linalg rank-reducing corner case: drop all constants, known to be + // zeros. + SmallVector constantPositions; + constantPositions.reserve(iterationToIndexedDataMap.getNumResults()); + for (int64_t i = 0, e = iterationToIndexedDataMap.getNumResults(); i < e; + ++i) + if (iterationToIndexedDataMap.getResult(i).isa()) + constantPositions.push_back(i); + iterationToIndexedDataMap = + iterationToIndexedDataMap.dropResults(constantPositions); + SmallVector operandShape{linalgOp.getShape(opOperand)}; + for (int64_t pos : llvm::reverse(constantPositions)) + operandShape.erase(operandShape.begin() + pos); + + /// 3.2. Compress out the unused dimensions (e.g. d1, d3), these will just + /// be broadcasts to go from the data read to the common normalized + /// vector: + /// `affine_map<(d0, d1, d2) -> (d2, d0, d1)>` + llvm::SmallBitVector bcastDims; + AffineMap permutationMap = compress(iterationToIndexedDataMap, bcastDims); + assert(permutationMap.isPermutation() && "expected a permutation"); + + /// 3.3. Invert the permutation to obtain the readMap for the + /// transfer_read: + /// `affine_map<(d0, d1, d2) -> (d1, d2, d0)>` + /// The data will then be in a form + AffineMap readMap = inversePermutation(permutationMap); + auto readShape = readMap.compose(operandShape); + Type elementType = getElementTypeOrSelf(opOperand->get()); + VectorType readType = VectorType::get(readShape, elementType); + + /// 3.4 Compute the target type: + /// - for an output operand we just use the read type. + /// - for an input operand, we may need to broadcast to the common + /// vector shape. + VectorType targetType = readType; + if (linalgOp.isDpsInput(opOperand)) + targetType = VectorType::get(commonVectorShape, elementType); + + // 3.4.1. Linalg rank-reducing corner case: reinject indexing for zeros. 
+ for (int64_t pos : constantPositions) + readMap = readMap.shiftDims(1, /*offset=*/pos); + + SmallVector indices(readMap.getNumDims(), zero); Value readValue = b.create( - loc, readType, opOperand->get(), indices, map); + loc, readType, opOperand->get(), indices, readMap); + FailureOr maybeReadValue = + broadcastIfNeeded(b, readValue, targetType.getShape(), &bcastDims); + if (failed(maybeReadValue)) + return b.notifyMatchFailure(linalgOp, "operand could not be broadcast"); + readValue = *maybeReadValue; + // Not all ops support 0-d vectors, extract the scalar for now. // TODO: remove this. if (readValue.getType().cast().getRank() == 0) readValue = b.create(loc, readValue); - LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue); bvm.map(bbarg, readValue); bvm.map(opOperand->get(), readValue); } @@ -571,7 +736,7 @@ // 4a. Register CustomVectorizationHook for yieldOp. CustomVectorizationHook vectorizeYield = [&](Operation *op, - const BlockAndValueMapping &bvm) -> VectorizationResult { + const BlockAndValueMapping &bvm) -> FailureOr { return vectorizeLinalgYield(b, op, bvm, linalgOp, newResults); }; hooks.push_back(vectorizeYield); @@ -579,7 +744,7 @@ // 4b. Register CustomVectorizationHook for indexOp. CustomVectorizationHook vectorizeIndex = [&](Operation *op, - const BlockAndValueMapping &bvm) -> VectorizationResult { + const BlockAndValueMapping &bvm) -> FailureOr { return vectorizeLinalgIndex(b, op, linalgOp); }; hooks.push_back(vectorizeIndex); @@ -587,22 +752,20 @@ // 4c. Register CustomVectorizationHook for extractOp. CustomVectorizationHook vectorizeExtract = [&](Operation *op, - const BlockAndValueMapping &bvm) -> VectorizationResult { + const BlockAndValueMapping &bvm) -> FailureOr { return vectorizeTensorExtract(b, op, linalgOp, bvm); }; hooks.push_back(vectorizeExtract); // 5. Iteratively call `vectorizeOneOp` to each op in the slice. 
for (Operation &op : block->getOperations()) { - VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks); - if (result.status == VectorizationStatus::Failure) { - LDBG("failed to vectorize: " << op); - return failure(); - } - if (result.status == VectorizationStatus::NewOp) { - LDBG("new vector op: " << *result.newOp;); - bvm.map(op.getResults(), result.newOp->getResults()); - } + FailureOr maybeResult = + vectorizeOneOp(b, linalgOp, &op, bvm, hooks); + if (failed(maybeResult) || + maybeResult->status == VectorizationStatus::Failure) + return b.notifyMatchFailure(&op, "failed to vectorize"); + if (maybeResult->status == VectorizationStatus::NewOp) + bvm.map(op.getResults(), maybeResult->newOp->getResults()); } return success(); @@ -610,27 +773,23 @@ // TODO: probably need some extra checks for reduction followed by consumer // ops that may not commute (e.g. linear reduction + non-linear instructions). -static LogicalResult reductionPreconditions(LinalgOp op) { - if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) { - LDBG("reduction precondition failed: no reduction iterator"); - return failure(); - } +static LogicalResult reductionPreconditions(RewriterBase &b, LinalgOp op) { + if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) + return b.notifyMatchFailure(op, "no reduction iterator"); for (OpOperand *opOperand : op.getDpsInitOperands()) { AffineMap indexingMap = op.getMatchingIndexingMap(opOperand); if (indexingMap.isPermutation()) continue; Operation *reduceOp = matchLinalgReduction(opOperand); - if (!reduceOp || !getCombinerOpKind(reduceOp)) { - LDBG("reduction precondition failed: reduction detection failed"); - return failure(); - } + if (!reduceOp || !getCombinerOpKind(reduceOp)) + return b.notifyMatchFailure(op, "reduction combiner detection failed"); } return success(); } static LogicalResult vectorizeStaticLinalgOpPrecondition( - linalg::LinalgOp op, + RewriterBase &b, linalg::LinalgOp op, ArrayRef 
customPreconditions) { // All types in the body should be a supported element type for VectorType. @@ -639,91 +798,72 @@ if (llvm::any_of( customPreconditions, [&](const CustomVectorizationPrecondition &customPrecondition) { - return succeeded(customPrecondition(&innerOp)); + return succeeded(customPrecondition(b, &innerOp)); })) { continue; } - if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) { - return !VectorType::isValidElementType(type); - })) { - return failure(); - } - if (llvm::any_of(innerOp.getResultTypes(), [](Type type) { - return !VectorType::isValidElementType(type); - })) { - return failure(); - } + auto validElementType = VectorType::isValidElementType; + if (!llvm::all_of(innerOp.getOperandTypes(), validElementType)) + return b.notifyMatchFailure(op, "invalid operand vector element type"); + if (!llvm::all_of(innerOp.getResultTypes(), validElementType)) + return b.notifyMatchFailure(op, "invalid result vector element type"); } if (isElementwise(op)) return success(); - // TODO: isaConvolutionOpInterface that can also infer from generic features. - // But we will still need stride/dilation attributes that will be annoying to - // reverse-engineer... + // TODO: isaConvolutionOpInterface that can also infer from generic + // features. But we will still need stride/dilation attributes that will be + // annoying to reverse-engineer... if (isa(op.getOperation())) return success(); // TODO: the common vector shape is equal to the static loop sizes only when // all indexing maps are projected permutations. For convs and stencils the // logic will need to evolve. 
- if (!allIndexingsAreProjectedPermutation(op)) { - LDBG("precondition failed: not projected permutations"); - return failure(); - } - if (failed(reductionPreconditions(op))) { - LDBG("precondition failed: reduction preconditions"); - return failure(); - } + if (!allIndexingsAreProjectedPermutation(op)) + return b.notifyMatchFailure(op, "indexing is not a projected permutation"); + if (failed(reductionPreconditions(b, op))) + return b.notifyMatchFailure(op, "reduction preconditions failed"); return success(); } -LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) { +LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(RewriterBase &b, + LinalgOp linalgOp) { // All types must be static shape to go to vector. - if (linalgOp.hasDynamicShape()) { - LDBG("precondition failed: dynamic shape"); - return failure(); - } + if (linalgOp.hasDynamicShape()) + return b.notifyMatchFailure(linalgOp, "op has dynamic shape"); SmallVector customPreconditions; // Register CustomVectorizationPrecondition for extractOp. customPreconditions.push_back(tensorExtractVectorizationPrecondition); - return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions); + return vectorizeStaticLinalgOpPrecondition(b, linalgOp, customPreconditions); } -LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, - LinalgOp linalgOp) { - if (failed(vectorizeLinalgOpPrecondition(linalgOp))) - return failure(); +LogicalResult mlir::linalg::vectorize(RewriterBase &b, LinalgOp op) { + if (failed(vectorizeLinalgOpPrecondition(b, op))) + return b.notifyMatchFailure(op, "linalg preconditions failed"); SmallVector results; // TODO: isaConvolutionOpInterface that can also infer from generic // features. Will require stride/dilation attributes inference. 
- FailureOr convOr = vectorizeConvolution(rewriter, linalgOp); - if (succeeded(convOr)) { + FailureOr convOr = vectorizeConvolution(b, op); + if (succeeded(convOr)) llvm::append_range(results, (*convOr)->getResults()); - } else { - if (failed(vectorizeLinalgOpPrecondition(linalgOp))) - return failure(); - LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp); - if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results))) - return failure(); - } + else if (failed(vectorizeAsLinalgGeneric(b, op, results))) + return b.notifyMatchFailure(op, "failed to vectorize like linalg.generic"); - if (!results.empty()) - rewriter.replaceOp(linalgOp, results); - else - rewriter.eraseOp(linalgOp); + b.replaceOp(op, results); return success(); } -LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, +LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &b, memref::CopyOp copyOp) { auto srcType = copyOp.getSource().getType().cast(); auto dstType = copyOp.getTarget().getType().cast(); if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) - return failure(); + return b.notifyMatchFailure(copyOp, "not a static shape"); auto readType = VectorType::get(srcType.getShape(), getElementTypeOrSelf(srcType)); @@ -731,20 +871,20 @@ VectorType::get(dstType.getShape(), getElementTypeOrSelf(dstType)); Location loc = copyOp->getLoc(); - Value zero = rewriter.create(loc, 0); + Value zero = b.create(loc, 0); SmallVector indices(srcType.getRank(), zero); - Value readValue = rewriter.create( + Value readValue = b.create( loc, readType, copyOp.getSource(), indices, - rewriter.getMultiDimIdentityMap(srcType.getRank())); + b.getMultiDimIdentityMap(srcType.getRank())); if (readValue.getType().cast().getRank() == 0) { - readValue = rewriter.create(loc, readValue); - readValue = rewriter.create(loc, writeType, readValue); + readValue = b.create(loc, readValue); + readValue = b.create(loc, writeType, readValue); } - Operation *writeValue = rewriter.create( + 
Operation *writeValue = b.create( loc, readValue, copyOp.getTarget(), indices, - rewriter.getMultiDimIdentityMap(srcType.getRank())); - rewriter.replaceOp(copyOp, writeValue->getResults()); + b.getMultiDimIdentityMap(srcType.getRank())); + b.replaceOp(copyOp, writeValue->getResults()); return success(); } @@ -760,7 +900,7 @@ /// Given an ArrayRef of OpFoldResults, return a vector of Values. /// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are /// not supported. -static SmallVector ofrToIndexValues(OpBuilder &builder, Location loc, +static SmallVector ofrToIndexValues(RewriterBase &builder, Location loc, ArrayRef ofrs) { SmallVector result; for (auto o : ofrs) { @@ -785,8 +925,8 @@ /// Vectorize the copying of a tensor::PadOp's source. This is possible if /// each dimension size is statically know in the source type or the result /// type (or both). - static LogicalResult tryVectorizeCopy(PatternRewriter &rewriter, - tensor::PadOp padOp, Value dest) { + static LogicalResult tryVectorizeCopy(PatternRewriter &b, tensor::PadOp padOp, + Value dest) { auto sourceType = padOp.getSourceType(); auto resultType = padOp.getResultType(); @@ -796,11 +936,11 @@ auto padValue = padOp.getConstantPaddingValue(); if (!padValue) { if (!sourceType.hasStaticShape()) - return failure(); + return b.notifyMatchFailure(padOp, "source does not have static shape"); // Create dummy padding value. auto elemType = sourceType.getElementType(); - padValue = rewriter.create( - padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType)); + padValue = b.create(padOp.getLoc(), elemType, + b.getZeroAttr(elemType)); } SmallVector vecShape; @@ -828,16 +968,15 @@ } else { // Neither source nor result dim of padOp is static. Cannot vectorize // the copy. - return failure(); + return b.notifyMatchFailure(padOp, "not static"); } } auto vecType = VectorType::get(vecShape, sourceType.getElementType()); // Generate TransferReadOp. 
SmallVector readIndices( - vecType.getRank(), - rewriter.create(padOp.getLoc(), 0)); - auto read = rewriter.create( + vecType.getRank(), b.create(padOp.getLoc(), 0)); + auto read = b.create( padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue, ArrayRef{readInBounds}); @@ -850,8 +989,8 @@ // Generate TransferWriteOp. auto writeIndices = - ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad()); - rewriter.replaceOpWithNewOp( + ofrToIndexValues(b, padOp.getLoc(), padOp.getMixedLowPad()); + b.replaceOpWithNewOp( padOp, read, dest, writeIndices, ArrayRef{writeInBounds}); return success(); @@ -865,18 +1004,18 @@ using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tensor::PadOp padOp, - PatternRewriter &rewriter) const final { + PatternRewriter &b) const final { bool changed = false; // Insert users in vector, because some users may be replaced/removed. for (auto *user : llvm::to_vector<4>(padOp->getUsers())) if (auto op = dyn_cast(user)) - changed |= rewriteUser(rewriter, padOp, op).succeeded(); + changed |= rewriteUser(b, padOp, op).succeeded(); return success(changed); } protected: - virtual LogicalResult rewriteUser(PatternRewriter &rewriter, - tensor::PadOp padOp, OpTy op) const = 0; + virtual LogicalResult rewriteUser(PatternRewriter &b, tensor::PadOp padOp, + OpTy op) const = 0; }; /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.: @@ -903,23 +1042,25 @@ using VectorizePadOpUserPattern< vector::TransferReadOp>::VectorizePadOpUserPattern; - LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, + LogicalResult rewriteUser(PatternRewriter &b, tensor::PadOp padOp, vector::TransferReadOp xferOp) const override { // Low padding must be static 0. if (!padOp.hasZeroLowPad()) - return failure(); + return b.notifyMatchFailure(padOp, "non-zero low pad"); // Pad value must be a constant. 
auto padValue = padOp.getConstantPaddingValue(); if (!padValue) - return failure(); + return b.notifyMatchFailure(padOp, "non-constant padding value"); // Padding value of existing `xferOp` is unused. - if (xferOp.hasOutOfBoundsDim() || xferOp.getMask()) - return failure(); + if (xferOp.hasOutOfBoundsDim()) + return b.notifyMatchFailure(padOp, "unsupported out-of-bounds dim"); + if (xferOp.getMask()) + return b.notifyMatchFailure(padOp, "unsupported mask"); - rewriter.updateRootInPlace(xferOp, [&]() { + b.updateRootInPlace(xferOp, [&]() { SmallVector inBounds(xferOp.getVectorType().getRank(), false); xferOp->setAttr(xferOp.getInBoundsAttrName(), - rewriter.getBoolArrayAttr(inBounds)); + b.getBoolArrayAttr(inBounds)); xferOp.getSourceMutable().assign(padOp.getSource()); xferOp.getPaddingMutable().assign(padValue); }); @@ -965,41 +1106,41 @@ using VectorizePadOpUserPattern< vector::TransferWriteOp>::VectorizePadOpUserPattern; - LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, + LogicalResult rewriteUser(PatternRewriter &b, tensor::PadOp padOp, vector::TransferWriteOp xferOp) const override { // TODO: support 0-d corner case. if (xferOp.getTransferRank() == 0) - return failure(); + return b.notifyMatchFailure(padOp, "unsupported 0-d"); // Low padding must be static 0. if (!padOp.hasZeroLowPad()) - return failure(); + return b.notifyMatchFailure(padOp, "non-zero low pad"); // Pad value must be a constant. auto padValue = padOp.getConstantPaddingValue(); if (!padValue) - return failure(); - // TransferWriteOp result must be directly consumed by an ExtractSliceOp. + return b.notifyMatchFailure(padOp, "non-constant padding value"); + // TransferWriteOp result must be consumed by a single ExtractSliceOp. 
if (!xferOp->hasOneUse()) - return failure(); + return b.notifyMatchFailure(padOp, "not a single use"); auto trimPadding = dyn_cast(*xferOp->user_begin()); if (!trimPadding) - return failure(); + return b.notifyMatchFailure(padOp, "not an extract_slice use"); // Only static zero offsets supported when trimming padding. if (!trimPadding.hasZeroOffset()) - return failure(); + return b.notifyMatchFailure(padOp, "non-zero offset"); // trimPadding must remove the amount of padding that was added earlier. if (!hasSameTensorSize(padOp.getSource(), trimPadding)) - return failure(); + return b.notifyMatchFailure(padOp, "different tensor sizes"); // Insert the new TransferWriteOp at position of the old TransferWriteOp. - rewriter.setInsertionPoint(xferOp); + b.setInsertionPoint(xferOp); SmallVector inBounds(xferOp.getVectorType().getRank(), false); - auto newXferOp = rewriter.replaceOpWithNewOp( + auto newXferOp = b.replaceOpWithNewOp( xferOp, padOp.getSource().getType(), xferOp.getVector(), padOp.getSource(), xferOp.getIndices(), xferOp.getPermutationMapAttr(), - xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds)); - rewriter.replaceOp(trimPadding, newXferOp->getResult(0)); + xferOp.getMask(), b.getBoolArrayAttr(inBounds)); + b.replaceOp(trimPadding, newXferOp->getResult(0)); return success(); } @@ -1119,24 +1260,24 @@ using VectorizePadOpUserPattern< tensor::InsertSliceOp>::VectorizePadOpUserPattern; - LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, + LogicalResult rewriteUser(PatternRewriter &b, tensor::PadOp padOp, tensor::InsertSliceOp insertOp) const override { // Low padding must be static 0. if (!padOp.hasZeroLowPad()) - return failure(); + return b.notifyMatchFailure(padOp, "non-zero low pad"); // Only unit stride supported. if (!insertOp.hasUnitStride()) - return failure(); + return b.notifyMatchFailure(padOp, "non-unit stride"); // Pad value must be a constant. 
auto padValue = padOp.getConstantPaddingValue(); if (!padValue) - return failure(); + return b.notifyMatchFailure(padOp, "non-constant padding value"); // Dynamic shapes not supported. if (!padOp.getResult().getType().cast().hasStaticShape()) - return failure(); + return b.notifyMatchFailure(padOp, "non-static shape"); // Pad result not used as destination. if (insertOp.getDest() == padOp.getResult()) - return failure(); + return b.notifyMatchFailure(padOp, "pad result used as destination"); auto vecType = VectorType::get(padOp.getType().getShape(), padOp.getType().getElementType()); @@ -1151,26 +1292,26 @@ llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) { return getConstantIntValue(std::get<0>(it)) == std::get<1>(it); })) - return failure(); + return b.notifyMatchFailure(padOp, "sizes don't match"); // Insert the TransferReadOp and TransferWriteOp at the position of the // InsertSliceOp. - rewriter.setInsertionPoint(insertOp); + b.setInsertionPoint(insertOp); // Generate TransferReadOp: Read entire source tensor and add high // padding. SmallVector readIndices( - vecRank, rewriter.create(padOp.getLoc(), 0)); - auto read = rewriter.create( + vecRank, b.create(padOp.getLoc(), 0)); + auto read = b.create( padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue); // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at // specified offsets. Write is fully in-bounds because a InsertSliceOp's // source must fit into the destination at the specified offsets.
auto writeIndices = - ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets()); + ofrToIndexValues(b, padOp.getLoc(), insertOp.getMixedOffsets()); SmallVector inBounds(vecRank, true); - rewriter.replaceOpWithNewOp( + b.replaceOpWithNewOp( insertOp, read, insertOp.getDest(), writeIndices, ArrayRef{inBounds}); @@ -1200,8 +1341,6 @@ ValueRange values) { if (firstOp->getBlock() != secondOp->getBlock() || !firstOp->isBeforeInBlock(secondOp)) { - LDBG("interleavedUses precondition failed, firstOp: " - << *firstOp << ", second op: " << *secondOp); return true; } for (auto v : values) { @@ -1213,8 +1352,6 @@ if (owner->getBlock() == firstOp->getBlock() && (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner))) continue; - LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp - << ", second op: " << *secondOp); return true; } } @@ -1237,27 +1374,25 @@ /// TODO: use interfaces, side-effects and aliasing analysis as appropriate, /// when available. -LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( - vector::TransferReadOp xferOp, PatternRewriter &rewriter) const { +LogicalResult +LinalgCopyVTRForwardingPattern::matchAndRewrite(vector::TransferReadOp xferOp, + PatternRewriter &b) const { // TODO: support mask. if (xferOp.getMask()) - return failure(); + return b.notifyMatchFailure(xferOp, "unsupported mask"); // Transfer into `view`. Value viewOrAlloc = xferOp.getSource(); if (!viewOrAlloc.getDefiningOp() && !viewOrAlloc.getDefiningOp()) - return failure(); - - LDBG(viewOrAlloc); + return b.notifyMatchFailure(xferOp, "source not a view or alloc"); // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); if (!subViewOp) - return failure(); + return b.notifyMatchFailure(xferOp, "no subview found"); Value subView = subViewOp.getResult(); - LDBG("with subView " << subView); // Find the copy into `subView` without interleaved uses. 
memref::CopyOp copyOp; @@ -1266,7 +1401,6 @@ assert(newCopyOp.getTarget().getType().isa()); if (newCopyOp.getTarget() != subView) continue; - LDBG("copy candidate " << *newCopyOp); if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView})) continue; copyOp = newCopyOp; @@ -1274,8 +1408,7 @@ } } if (!copyOp) - return failure(); - LDBG("with copy " << *copyOp); + return b.notifyMatchFailure(xferOp, "no copy found"); // Find the fill into `viewOrAlloc` without interleaved uses before the // copy. @@ -1285,7 +1418,6 @@ assert(newFillOp.output().getType().isa()); if (newFillOp.output() != viewOrAlloc) continue; - LDBG("fill candidate " << *newFillOp); if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView})) continue; maybeFillOp = newFillOp; @@ -1294,9 +1426,7 @@ } // Ensure padding matches. if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value()) - return failure(); - if (maybeFillOp) - LDBG("with maybeFillOp " << *maybeFillOp); + return b.notifyMatchFailure(xferOp, "padding value does not match fill"); // `in` is the subview that memref.copy reads. Replace it. Value in = copyOp.getSource(); @@ -1305,38 +1435,39 @@ // The `masked` attribute is only valid on this padded buffer. // When forwarding to vector.transfer_read, the attribute must be reset // conservatively. - Value res = rewriter.create( + Value res = b.create( xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(), xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(), // in_bounds is explicitly reset /*inBoundsAttr=*/ArrayAttr()); if (maybeFillOp) - rewriter.eraseOp(maybeFillOp); - rewriter.eraseOp(copyOp); - rewriter.replaceOp(xferOp, res); + b.eraseOp(maybeFillOp); + b.eraseOp(copyOp); + b.replaceOp(xferOp, res); return success(); } /// TODO: use interfaces, side-effects and aliasing analysis as appropriate, /// when available. 
-LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( - vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const { +LogicalResult +LinalgCopyVTWForwardingPattern::matchAndRewrite(vector::TransferWriteOp xferOp, + PatternRewriter &b) const { // TODO: support mask. if (xferOp.getMask()) - return failure(); + return b.notifyMatchFailure(xferOp, "unsupported mask"); // Transfer into `viewOrAlloc`. Value viewOrAlloc = xferOp.getSource(); if (!viewOrAlloc.getDefiningOp() && !viewOrAlloc.getDefiningOp()) - return failure(); + return b.notifyMatchFailure(xferOp, "source not a view or alloc"); // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. memref::SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc); if (!subViewOp) - return failure(); + return b.notifyMatchFailure(xferOp, "no subview found"); Value subView = subViewOp.getResult(); // Find the copy from `subView` without interleaved uses. @@ -1352,7 +1483,7 @@ } } if (!copyOp) - return failure(); + return b.notifyMatchFailure(xferOp, "no copy found"); // `out` is the subview copied into that we replace. assert(copyOp.getTarget().getType().isa()); @@ -1363,14 +1494,14 @@ // The `masked` attribute is only valid on this padded buffer. // When forwarding to vector.transfer_write, the attribute must be reset // conservatively. - rewriter.create( + b.create( xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(), xferOp.getPermutationMapAttr(), xferOp.getMask(), // in_bounds is explicitly reset /*inBoundsAttr=*/ArrayAttr()); - rewriter.eraseOp(copyOp); - rewriter.eraseOp(xferOp); + b.eraseOp(copyOp); + b.eraseOp(xferOp); return success(); } @@ -1422,7 +1553,7 @@ /// kw is unrolled, w is unrolled iff dilationW > 1. 
struct Conv1DGenerator : public StructuredGenerator { - Conv1DGenerator(OpBuilder &builder, LinalgOp linalgOp, int strideW, + Conv1DGenerator(RewriterBase &builder, LinalgOp linalgOp, int strideW, int dilationW) : StructuredGenerator(builder, linalgOp), strideW(strideW), dilationW(dilationW) { @@ -1488,7 +1619,7 @@ /// > 1. FailureOr conv(Conv1DOpOrder conv1DOpOrder) { if (!valid) - return failure(); + return builder.notifyMatchFailure(op, "unvectorizable 1-D conv"); int64_t nSize, wSize, cSize, kwSize, fSize; SmallVector lhsShape, rhsShape, resShape; @@ -1549,9 +1680,9 @@ Value res = builder.create( loc, resType, resShaped, ValueRange{zero, zero, zero}); - // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output: - // {n,w,f}. To reuse the base pattern vectorization case, we do pre - // transpose on input, weight, and output. + // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, + // output: {n,w,f}. To reuse the base pattern vectorization case, we do + // pre transpose on input, weight, and output. switch (conv1DOpOrder) { case Conv1DOpOrder::Nwc: // Base case, so no transposes necessary. @@ -1647,7 +1778,7 @@ } // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f} - Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs, + Value conv1dSliceAsContraction(RewriterBase &b, Location loc, Value lhs, Value rhs, Value res) { vector::IteratorType par = vector::IteratorType::parallel; vector::IteratorType red = vector::IteratorType::reduction; @@ -1670,7 +1801,7 @@ /// > 1. 
FailureOr depthwiseConv() { if (!valid) - return failure(); + return builder.notifyMatchFailure(op, "unvectorizable depthwise conv"); int64_t nSize, wSize, cSize, kwSize; // kernel{kw, c} @@ -1753,10 +1884,8 @@ } // Its possible we failed to create the Fma - for (auto v : resVals) { - if (!v) - return failure(); - } + if (!llvm::all_of(resVals, [](Value v) { return v; })) + return builder.notifyMatchFailure(op, "failed to create FMA"); // Write back res slice: {n, wSizeStep, c} @ [0, w, 0]. // This does not depend on kw. @@ -1778,7 +1907,7 @@ } // Take a value of element type T and widen to the destination type. - Value promote(OpBuilder &b, Location loc, Value val, Type ty) { + Value promote(RewriterBase &b, Location loc, Value val, Type ty) { if (val.getType() == ty) return val; @@ -1796,7 +1925,7 @@ } /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc - Value depthwiseConv1dSliceAsMulAcc(OpBuilder &b, Location loc, Value lhs, + Value depthwiseConv1dSliceAsMulAcc(RewriterBase &b, Location loc, Value lhs, Value rhs, Value res) { auto rhsTy = rhs.getType().cast(); auto resTy = res.getType().cast(); @@ -1823,15 +1952,18 @@ FailureOr generateNwcConv() { AffineExpr n, w, f, kw, c; bindDims(ctx, n, w, f, kw, c); - if (!iters({Par(), Par(), Par(), Red(), Red()})) - return failure(); + if (!iters({Par(), Par(), Par(), Red(), Red()})) { + return builder.notifyMatchFailure( + op, "failed to match conv::Nwc 3-par 2-red"); + } // No transposition needed. 
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, /*rhsIndex*/ {kw, c, f}, /*resIndex*/ {n, w, f}})) return conv(Conv1DOpOrder::Nwc); - return failure(); + + return builder.notifyMatchFailure(op, "not a conv::Nwc layout"); } /// Entry point that transposes into the common form: @@ -1839,15 +1971,17 @@ FailureOr generateNcwConv() { AffineExpr n, w, f, kw, c; bindDims(ctx, n, f, w, c, kw); - if (!iters({Par(), Par(), Par(), Red(), Red()})) - return failure(); + if (!iters({Par(), Par(), Par(), Red(), Red()})) { + return builder.notifyMatchFailure( + op, "failed to match conv::Ncw 3-par 2-red"); + } if (layout({/*lhsIndex*/ {n, c, strideW * w + dilationW * kw}, /*rhsIndex*/ {f, c, kw}, /*resIndex*/ {n, f, w}})) return conv(Conv1DOpOrder::Ncw); - return failure(); + return builder.notifyMatchFailure(op, "not a conv::Ncw layout"); } /// Entry point that transposes into the common form: @@ -1855,15 +1989,18 @@ FailureOr generateDilatedConv() { AffineExpr n, w, c, kw; bindDims(ctx, n, w, c, kw); - if (!iters({Par(), Par(), Par(), Red()})) - return failure(); + if (!iters({Par(), Par(), Par(), Red()})) { + return builder.notifyMatchFailure( + op, "failed to match depthwise::Nwc conv 3-par 1-red"); + } // No transposition needed. if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, /*rhsIndex*/ {kw, c}, /*resIndex*/ {n, w, c}})) return depthwiseConv(); - return failure(); + + return builder.notifyMatchFailure(op, "not a depthwise::Nwc layout"); } private: @@ -1876,11 +2013,12 @@ /// Helper function to vectorize a LinalgOp with convolution semantics. // TODO: extend the generic vectorization to support windows and drop this. -static FailureOr vectorizeConvolution(OpBuilder &b, LinalgOp op) { +static FailureOr vectorizeConvolution(RewriterBase &b, + LinalgOp op) { // The ConvolutionOpInterface gives us guarantees of existence for - // strides/dilations. 
However, we do not need to rely on those, we can simply - // use them if present, otherwise use the default and let the generic conv. - // matcher in the ConvGenerator succeed or fail. + // strides/dilations. However, we do not need to rely on those, we can + // simply use them if present, otherwise use the default and let the generic + // conv. matcher in the ConvGenerator succeed or fail. auto strides = op->getAttrOfType("strides"); auto dilations = op->getAttrOfType("dilations"); auto stride = strides ? *strides.getValues().begin() : 1; @@ -1899,17 +2037,18 @@ using OpInterfaceRewritePattern::OpInterfaceRewritePattern; LogicalResult matchAndRewrite(LinalgOp op, - PatternRewriter &rewriter) const override { - FailureOr resultOrFail = vectorizeConvolution(rewriter, op); + PatternRewriter &b) const override { + FailureOr resultOrFail = vectorizeConvolution(b, op); + // Nothing meaningful to report here, children have already reported. if (failed(resultOrFail)) return failure(); Operation *newOp = *resultOrFail; if (newOp->getNumResults() == 0) { - rewriter.eraseOp(op.getOperation()); + b.eraseOp(op.getOperation()); return success(); } assert(newOp->getNumResults() == 1 && "expected single result"); - rewriter.replaceOp(op.getOperation(), newOp->getResult(0)); + b.replaceOp(op.getOperation(), newOp->getResult(0)); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -30,6 +30,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" @@ -2847,18 +2848,9 @@ SmallVector seen(permutationMap.getNumInputs(), false); for (auto expr : permutationMap.getResults()) { auto dim = expr.dyn_cast(); - auto zero = expr.dyn_cast(); - if (zero) { - if 
(zero.getValue() != 0) { - return emitOpError( - "requires a projected permutation_map (at most one dim or the zero " - "constant can appear in each result)"); - } - continue; - } if (!dim) { - return emitOpError("requires a projected permutation_map (at most one " - "dim or the zero constant can appear in each result)"); + return emitOpError("requires a projected permutation_map (exactly one " + "dim can appear in each result)"); } if (seen[dim.getPosition()]) { return emitOpError( @@ -2948,10 +2940,6 @@ "as permutation_map results: ") << AffineMapAttr::get(permutationMap) << " vs inBounds of size: " << inBounds.size(); - for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i) - if (permutationMap.getResult(i).isa() && - !inBounds.getValue()[i].cast().getValue()) - return op->emitOpError("requires broadcast dimensions to be in-bounds"); } return success(); @@ -3119,8 +3107,7 @@ } // Currently out-of-bounds, check whether we can statically determine it is // inBounds. - auto dimExpr = permutationMap.getResult(i).dyn_cast(); - assert(dimExpr && "Broadcast dims must be in-bounds"); + auto dimExpr = permutationMap.getResult(i).cast(); auto inBounds = isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition()); newInBounds.push_back(inBounds); @@ -3287,11 +3274,10 @@ } }; -/// Store to load forwarding for transfer operations with permuation maps. +/// Store to load forwarding for transfer operations with permutation maps. /// Even if the permutation maps are different we can still propagate the store /// into the load if the size of the dimensions read and written match. Then we -/// can replace the transfer_read + transfer_write by vector.broadcast and -/// vector.transpose. +/// can replace the transfer_read + transfer_write by vector.transpose. 
/// Example: /// ``` /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0] @@ -3300,67 +3286,59 @@ /// vector<4x1xf32>, tensor<4x4x4xf32> /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0 /// {in_bounds = [true, true, true, true], -/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} : -/// tensor<4x4x4xf32>, vector<1x100x4x5xf32> +/// permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>} : +/// tensor<4x4x4xf32>, vector<1x4xf32> /// ``` /// To: /// ``` -/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32> -/// %r = vector.transpose %0, [3, 0, 2, 1] : -/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32> +/// %r = vector.transpose %v0, [1, 0] : +/// vector<4x1xf32> to vector<1x4xf32> /// ``` -struct TransferReadAfterWriteToBroadcast +struct TransferReadAfterWriteToTranspose : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TransferReadOp readOp, PatternRewriter &rewriter) const override { if (readOp.hasOutOfBoundsDim() || - !readOp.getShapedType().isa()) - return failure(); + !readOp.getShapedType().isa()) { + return rewriter.notifyMatchFailure( + readOp, "not a read into ranked tensor with all dims inbounds"); + } auto defWrite = readOp.getSource().getDefiningOp(); if (!defWrite) - return failure(); + return rewriter.notifyMatchFailure(readOp, "not defined by a write"); SmallVector readDims = readOp.getTransferChunkAccessed(); - Value vec; - if (readOp.getIndices() == defWrite.getIndices() && - readOp.getMask() == defWrite.getMask()) { - SmallVector writeDims = defWrite.getTransferChunkAccessed(); - // TODO: If the writeDim is a superset of the read dims we could do an - // extract_strided_slice.
- if (writeDims == readDims) - vec = defWrite.getVector(); - } + if (readOp.getIndices() != defWrite.getIndices()) + return rewriter.notifyMatchFailure(readOp, "non-matching indices"); + if (readOp.getMask() != defWrite.getMask()) + return rewriter.notifyMatchFailure(readOp, "non-matching masks"); + // TODO: If the writeDim is a superset of the read dims we could do an + // extract_strided_slice. + SmallVector writeDims = defWrite.getTransferChunkAccessed(); + if (writeDims != readDims) + return rewriter.notifyMatchFailure( + readOp, "non-matching read/write transfer chunks accessed"); // TODO: loop through the chain of transfer_write if we can prove that they // don't overlap with the transfer_read. This requires improving // `isDisjointTransferIndices` helper. - if (!vec) - return failure(); - SmallVector permutation; + Value vec = defWrite.getVector(); AffineMap readMap = compressUnusedDims(readOp.getPermutationMap()); AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap()); AffineMap map = readMap.compose(writeMap); if (map.getNumResults() == 0) - return failure(); + return rewriter.notifyMatchFailure(readOp, + "0-result composed permutation map"); - // Calculate the permuation to apply to go from the vector stored to the + // Calculate the permutation to apply to go from the vector stored to the // vector read. - if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) - return failure(); - - Location loc = readOp.getLoc(); - // Calculate the broadcast shape by applying the reverse permuation to the - // final shape we want.
- ArrayRef destShape = readOp.getVectorType().getShape(); - SmallVector broadcastShape(destShape.size()); - for (const auto &pos : llvm::enumerate(permutation)) - broadcastShape[pos.value()] = destShape[pos.index()]; - VectorType broadcastedType = VectorType::get( - broadcastShape, defWrite.getVectorType().getElementType()); - vec = rewriter.create(loc, broadcastedType, vec); - SmallVector transposePerm(permutation.begin(), permutation.end()); + auto maybePermutation = map.getDimPermutationVector(); + if (failed(maybePermutation)) { + return rewriter.notifyMatchFailure( + readOp, "no permutation map exists between write and read map"); + } rewriter.replaceOpWithNewOp(readOp, vec, - transposePerm); + *maybePermutation); return success(); } }; @@ -3369,7 +3347,7 @@ void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results - .add( + .add( context); } @@ -3493,11 +3471,6 @@ if (llvm::size(getIndices()) != shapedType.getRank()) return emitOpError("requires ") << shapedType.getRank() << " indices"; - // We do not allow broadcast dimensions on TransferWriteOps for the moment, - // as the semantics is unclear. This can be revisited later if necessary. - if (hasBroadcastDim()) - return emitOpError("should not have broadcast dimensions"); - if (failed(verifyTransferOp(cast(getOperation()), shapedType, vectorType, maskType, permutationMap, getInBounds() ? *getInBounds() : ArrayAttr()))) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1522,7 +1522,7 @@ /// This unrolls outer-products along the reduction dimension. 
struct UnrolledOuterProductGenerator : public StructuredGenerator { - UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op) + UnrolledOuterProductGenerator(RewriterBase &builder, vector::ContractionOp op) : StructuredGenerator( builder, op), kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()), diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/MathExtras.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" @@ -142,9 +143,32 @@ return true; } +/// Return a permutation vector encoding the permutation of the map's results. +/// This is computed once the unused dims and symbols are compressed away. +/// Return failure if the compressed map is not exactly a permutation. +FailureOr> AffineMap::getDimPermutationVector() const { + AffineMap compressedMap = compressUnusedSymbols(compressUnusedDims(*this)); + SmallVector res(compressedMap.getNumResults(), -1); + if (!compressedMap.isPermutation()) + return failure(); + int64_t idx = 0; + for (AffineExpr expr : compressedMap.getResults()) { + auto dimExpr = expr.dyn_cast(); + if (!dimExpr) + return failure(); + res[idx++] = dimExpr.getPosition(); + } + // Should be guaranteed by isPermutation. + assert(llvm::all_of(res, [](int64_t v) { return v >= 0; }) && + "expected nonnegative"); + return res; +} + /// Return true if this affine map can be converted to a minor identity with -/// broadcast by doing a permute. Return a permutation (there may be -/// several) to apply to get to a minor identity with broadcasts. +/// by a permutation. If `allowBroadcast` is specified, additionally treat `0` +/// as broadcasts. +/// Return a permutation (not guarateed unique) to apply to get to a minor +/// identity. 
/// Ex: /// * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with /// perm = [1, 0] and broadcast d2 @@ -156,7 +180,7 @@ /// leading broadcat dimensions. The map returned would be (0, 0, d0, d1) with /// perm = [3, 0, 1, 2] bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting( - SmallVectorImpl &permutedDims) const { + SmallVectorImpl &permutedDims, bool allowBroadcast) const { unsigned projectionStart = getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0; permutedDims.clear(); @@ -175,7 +199,7 @@ // Each result may be either a constant 0 (broadcast dimension) or a // dimension. if (auto constExpr = expr.dyn_cast()) { - if (constExpr.getValue() != 0) + if (!allowBroadcast || constExpr.getValue() != 0) return false; broadcastDims.push_back(resIdx); } else if (auto dimExpr = expr.dyn_cast()) { @@ -189,7 +213,7 @@ return false; } } - // Find a permuation for the broadcast dimension. Since they are broadcasted + // Find a permutation for the broadcast dimension. Since they are broadcasted // any valid permutation is acceptable. We just permute the dim into a slot // without an existing dimension. 
unsigned pos = 0; diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir --- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir +++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir @@ -4,7 +4,6 @@ #map1 = affine_map<(d0, d1, d2) -> (d0, d2)> #map2 = affine_map<(d0, d1, d2) -> (d1, d2)> #map3 = affine_map<(d0, d1, d2) -> (d0, d1)> -#map4 = affine_map<(d0) -> (d0, 0)> // CHECK-LABEL: func @matmul // CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> @@ -107,66 +106,3 @@ vector.transfer_write %E, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> return } - -// CHECK-LABEL: func @matmul_fused_broadcast -// CHECK-DAG: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16 -// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> -// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp"> -// CHECK-DAG: %[[C0:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_0]] : !gpu.mma_matrix<16x16xf16, "COp"> -// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C0]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> -// CHECK: %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp"> -// CHECK: %[[F:.+]] = gpu.subgroup_mma_elementwise divf %[[D]], %[[E]] : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> -// CHECK: gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : 
index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16> -func.func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, - %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) { - %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16> - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f16 - %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> - %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> - %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst - {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>} - : memref<16x16x16x16xf16>, vector<16x16xf16> - %F = arith.divf %D, %E : vector<16x16xf16> - vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16> - return -} - -// CHECK-LABEL: func @matmul_3Dmemref -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> -// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]]] {leadDimension = 0 : index} : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp"> -// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp"> -// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> -// CHECK: 
gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16> -func.func @matmul_3Dmemref(%arg0: memref<2x16x16xf16>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) { - %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16> - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f16 - %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16> - %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16> - %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> - vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16> - return -} - -// CHECK-LABEL: func @matmul_memref_strided -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 32 : index} : memref<2x16x16xf16, #{{.*}}> -> !gpu.mma_matrix<16x16xf16, "AOp"> -// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]]] {leadDimension = 0 : index} : memref<16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp"> -// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {leadDimension = 16 : index} : memref<2x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp"> -// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp"> -// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%[[C0]], 
%[[C0]], %[[C0]]] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<2x16x16xf16> -func.func @matmul_memref_strided(%arg0: memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, %arg1: memref<16xf16>, %arg2: memref<2x16x16xf16>) { - %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16> - %c0 = arith.constant 0 : index - %cst = arith.constant 0.000000e+00 : f16 - %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16, affine_map<(d0, d1, d2) -> (d0 * 512 + d1 * 32 + d2)>>, vector<16x16xf16> - %B = vector.transfer_read %arg1[%c0], %cst {permutation_map = #map4, in_bounds = [true, true]} : memref<16xf16>, vector<16x16xf16> - %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true]} : memref<2x16x16xf16>, vector<16x16xf16> - %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16> - vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<2x16x16xf16> - return -} diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir --- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir +++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir @@ -80,73 +80,6 @@ // CHECK: #[[$ADD:map.*]] = affine_map<(d0, d1) -> (d0 + d1)> -// CHECK-LABEL: func @materialize_read(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { -func.func @materialize_read(%M: index, %N: index, %O: index, %P: index) { - %f0 = arith.constant 0.0: f32 - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index - // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index - // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index - // 
CHECK: %{{.*}} = memref.alloc(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : memref - // CHECK-NEXT: affine.for %[[I0:.*]] = 0 to %{{.*}} step 3 { - // CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %{{.*}} { - // CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %{{.*}} { - // CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %{{.*}} step 5 { - // CHECK: %[[ALLOC:.*]] = memref.alloca() : memref> - // CHECK: scf.for %[[I4:.*]] = %[[C0]] to %[[C5]] step %[[C1]] { - // CHECK: scf.if - // CHECK: %[[L3:.*]] = affine.apply #[[$ADD]](%[[I3]], %[[I4]]) - // CHECK: scf.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] { - // CHECK: %[[VEC:.*]] = scf.for %[[I6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {{.*}} -> (vector<3xf32>) { - // CHECK: %[[L0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]]) - // CHECK: scf.if {{.*}} -> (vector<3xf32>) { - // CHECK-NEXT: %[[SCAL:.*]] = memref.load %{{.*}}[%[[L0]], %[[I1]], %[[I2]], %[[L3]]] : memref - // CHECK-NEXT: %[[RVEC:.*]] = vector.insertelement %[[SCAL]], %{{.*}}[%[[I6]] : index] : vector<3xf32> - // CHECK-NEXT: scf.yield - // CHECK-NEXT: } else { - // CHECK-NEXT: scf.yield - // CHECK-NEXT: } - // CHECK-NEXT: scf.yield - // CHECK-NEXT: } - // CHECK-NEXT: memref.store %[[VEC]], {{.*}} : memref<5x4xvector<3xf32>> - // CHECK-NEXT: } - // CHECK-NEXT: } else { - // CHECK-NEXT: memref.store {{.*}} : memref<5xvector<4x3xf32>> - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: %[[LD:.*]] = memref.load %[[ALLOC]][] : memref> - // CHECK-NEXT: "dummy_use"(%[[LD]]) : (vector<5x4x3xf32>) -> () - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: } - // CHECK-NEXT: return - // CHECK-NEXT:} - - // Check that I0 + I4 (of size 3) read from first index load(L0, ...) and write into last index store(..., I4) - // Check that I3 + I6 (of size 5) read from last index load(..., L3) and write into first index store(I6, ...) - // Other dimensions are just accessed with I1, I2 resp. 
- %A = memref.alloc (%M, %N, %O, %P) : memref - affine.for %i0 = 0 to %M step 3 { - affine.for %i1 = 0 to %N { - affine.for %i2 = 0 to %O { - affine.for %i3 = 0 to %P step 5 { - %f = vector.transfer_read %A[%i0, %i1, %i2, %i3], %f0 {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, 0, d0)>} : memref, vector<5x4x3xf32> - // Add a dummy use to prevent dead code elimination from removing - // transfer read ops. - "dummy_use"(%f) : (vector<5x4x3xf32>) -> () - } - } - } - } - return -} - -// ----- - -// CHECK: #[[$ADD:map.*]] = affine_map<(d0, d1) -> (d0 + d1)> - // CHECK-LABEL:func @materialize_write(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { func.func @materialize_write(%M: index, %N: index, %O: index, %P: index) { // CHECK-DAG: %{{.*}} = arith.constant dense<1.000000e+00> : vector<5x4x3xf32> diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -91,8 +91,11 @@ // CHECK-LABEL: func @vectorization_test func.func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) { - // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32> - // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32> + // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32> + // CHECK: vector.broadcast {{.*}} : vector<8x16xf32> to vector<32x8x16xf32> + // CHECK: vector.transpose {{.*}}, [1, 0, 2] : vector<32x8x16xf32> to vector<8x32x16xf32> + // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32> + // CHECK: vector.broadcast {{.*}} : vector<32x16xf32> to vector<8x32x16xf32> // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32> // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]], %[[ACC]] [2] : 
vector<8x32x16xf32> to vector<8x32xf32> @@ -131,8 +134,11 @@ // CHECK-LABEL: func @generic_output_transpose func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<32x8xf32>) { - // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32> - // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32> + // CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32> + // CHECK: vector.broadcast {{.*}} : vector<8x16xf32> to vector<32x8x16xf32> + // CHECK: vector.transpose {{.*}}, [1, 0, 2] : vector<32x8x16xf32> to vector<8x32x16xf32> + // CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<32x16xf32> + // CHECK: vector.broadcast {{.*}} : vector<32x16xf32> to vector<8x32x16xf32> // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32> // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]], %[[ACC]] [2] : vector<8x32x16xf32> to vector<8x32xf32> @@ -198,8 +204,11 @@ // CHECK-LABEL: func @vectorization_test_integer func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>, %C: memref<8x32xi32>) { - // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x32x16xi32> - // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<8x32x16xi32> + // CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32> + // CHECK: vector.broadcast {{.*}} : vector<8x16xi32> to vector<32x8x16xi32> + // CHECK: vector.transpose {{.*}}, [1, 0, 2] : vector<32x8x16xi32> to vector<8x32x16xi32> + // CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<32x16xi32> + // CHECK: vector.broadcast {{.*}} : vector<32x16xi32> to vector<8x32x16xi32> // CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> // CHECK: %[[MUL:.*]] = arith.muli %{{.*}}, %{{.*}} : vector<8x32x16xi32> // CHECK: 
vector.multi_reduction , %[[MUL]], %[[ACC]] [2] : vector<8x32x16xi32> to vector<8x32xi32> @@ -455,7 +464,8 @@ memref<4x256xf32>, memref<4x256xf32>) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5: f32, %arg6: f32, %arg7: f32, %arg8: f32, // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> - // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<4x256xf32> + // CHECK: %[[V0tmp:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : memref<256xf32>, vector<256xf32> + // CHECK: %[[V0:.*]] = vector.broadcast %[[V0tmp]] : vector<256xf32> to vector<4x256xf32> // CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x256xf32>, vector<4x256xf32> %arg9 : f32, %arg10 : f32, %arg11 : f32, %arg12 : f32, %arg13 : f32, @@ -542,7 +552,8 @@ // CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : vector<4x256xf32> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> - // CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<4x256xf32> + // CHECK: %[[V0tmp:.*]] = vector.transfer_read %[[ARG2]][%[[C0]]], {{.*}} : tensor<256xf32>, vector<256xf32> + // CHECK: %[[V0:.*]] = vector.broadcast %[[V0tmp]] : vector<256xf32> to vector<4x256xf32> // CHECK: %[[V3:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> // CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x256xf32>, vector<4x256xf32> // CHECK: %[[ADD:.*]] = arith.addf %[[V0]], %[[V1]] : vector<4x256xf32> @@ -596,17 +607,19 @@ // ----- -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, 0, 0, d1)> -// CHECK-DAG: #[[$MAP1:.*]] = 
affine_map<(d0) -> (d0, 0, 0, 0)> -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (0, 0, d0, 0)> -// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1, 0, d0, 0)> +// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d1, d0)> + // CHECK: func @generic_vectorize_broadcast_transpose // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[CF:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[V0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {in_bounds = [true, true, true, true], permutation_map = #[[$MAP0]]} : memref<4x4xf32>, vector<4x4x4x4xf32> -// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {in_bounds = [true, true, true, true], permutation_map = #[[$MAP1]]} : memref<4xf32>, vector<4x4x4x4xf32> -// CHECK: %[[V2:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {in_bounds = [true, true, true, true], permutation_map = #[[$MAP2]]} : memref<4xf32>, vector<4x4x4x4xf32> -// CHECK: %[[V3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {in_bounds = [true, true, true, true], permutation_map = #[[$MAP3]]} : memref<4x4xf32>, vector<4x4x4x4xf32> +// CHECK: %[[V0tmp:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {in_bounds = [true, true]} : memref<4x4xf32>, vector<4x4xf32> +// CHECK: %[[V0:.*]] = vector.broadcast %[[V0tmp]] : vector<4x4xf32> to vector<4x4x4x4xf32> +// CHECK: %[[V1tmp:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {in_bounds = [true]} : memref<4xf32>, vector<4xf32> +// CHECK: %[[V1:.*]] = vector.broadcast %[[V1tmp]] : vector<4xf32> to vector<4x4x4x4xf32> +// CHECK: %[[V2tmp:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %[[CF]] {in_bounds = [true]} : memref<4xf32>, vector<4xf32> +// CHECK: %[[V2:.*]] = vector.broadcast %[[V2tmp]] : vector<4xf32> to vector<4x4x4x4xf32> +// CHECK: %[[V3tmp:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF]] {in_bounds = [true, true], permutation_map = #[[$MAP]]} : memref<4x4xf32>, vector<4x4xf32> +// CHECK: %[[V3:.*]] 
= vector.broadcast %[[V3tmp]] : vector<4x4xf32> to vector<4x4x4x4xf32> // CHECK: %[[SUB:.*]] = arith.subf %[[V0]], %[[V1]] : vector<4x4x4x4xf32> // CHECK: %[[ADD0:.*]] = arith.addf %[[V2]], %[[SUB]] : vector<4x4x4x4xf32> // CHECK: %[[ADD1:.*]] = arith.addf %[[V3]], %[[ADD0]] : vector<4x4x4x4xf32> @@ -651,13 +664,16 @@ iterator_types = ["parallel", "parallel", "parallel", "parallel"] } -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (0, d1, 0, d0)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)> +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)> // CHECK: func @vectorization_transpose -// CHECK: vector.transfer_read {{.*}}{in_bounds = [true, true, true, true], permutation_map = #[[MAP0]]} : memref<14x7xf32>, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read {{.*}}{in_bounds = [true, true, true, true], permutation_map = #[[MAP1]]} : memref<16x14xf32>, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read {{.*}}{in_bounds = [true, true, true, true], permutation_map = #[[MAP2]]} : memref<16x14x7x8xf32>, vector<7x14x8x16xf32> +// CHECK: vector.transfer_read {{.*}}{in_bounds = [true, true], permutation_map = #[[MAP0]]} : memref<14x7xf32>, vector<7x14xf32> +// CHECK: vector.broadcast {{.*}} : vector<7x14xf32> to vector<8x16x7x14xf32> +// CHECK: vector.transpose %{{.*}}, [2, 3, 0, 1] : vector<8x16x7x14xf32> to vector<7x14x8x16xf32> +// CHECK: vector.transfer_read {{.*}}{in_bounds = [true, true], permutation_map = #[[MAP0]]} : memref<16x14xf32>, vector<14x16xf32> +// CHECK: vector.broadcast {{.*}} : vector<14x16xf32> to vector<7x8x14x16xf32> +// CHECK: vector.transpose %{{.*}}, [0, 2, 1, 3] : vector<7x8x14x16xf32> to vector<7x14x8x16xf32> +// CHECK: vector.transfer_read {{.*}}{in_bounds = [true, true, true, true], permutation_map = #[[MAP1]]} : memref<16x14x7x8xf32>, 
vector<7x14x8x16xf32> // CHECK: arith.addf {{.*}} : vector<7x14x8x16xf32> // CHECK: arith.addf {{.*}} : vector<7x14x8x16xf32> // CHECK: vector.transfer_write {{.*}} : vector<7x14x8x16xf32>, memref<7x14x8x16xf32> @@ -689,10 +705,13 @@ func.func @matmul_tensors( %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>) -> tensor<8x12xf32> { - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x12x4xf32> - // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<8x12x4xf32> - // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[V0tmp:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x4xf32> + // CHECK: %[[V0tmp2:.*]] = vector.broadcast %[[V0tmp]] : vector<8x4xf32> to vector<12x8x4xf32> + // CHECK: %[[V0:.*]] = vector.transpose %[[V0tmp2]], [1, 0, 2] : vector<12x8x4xf32> to vector<8x12x4xf32> + // CHECK: %[[V1tmp:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<12x4xf32> + // CHECK: %[[V1:.*]] = vector.broadcast %[[V1tmp]] : vector<12x4xf32> to vector<8x12x4xf32> + // CHECK: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> // // linalg matmul lowers gets expanded to a 3D reduction, canonicalization later // convert it to a 2D contract. 
@@ -1045,22 +1064,23 @@ // ----- -// CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)> -// CHECK-DAG: #[[$M2:.*]] = affine_map<(d0, d1) -> (0, 0, d1, d0)> -// CHECK-DAG: #[[$M3:.*]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d1, d0)> // CHECK-LABEL: func @sum_exp_2 func.func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output: tensor<5x2xf32>) -> tensor<5x2xf32> { - // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M1]]} : tensor<3x2xf32>, vector<2x3x4x5xf32> - // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M2]]} : tensor<5x4xf32>, vector<2x3x4x5xf32> - // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32> + // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$MAP]]} : tensor<3x2xf32>, vector<2x3xf32> + // CHECK: vector.broadcast %{{.*}} : vector<2x3xf32> to vector<4x5x2x3xf32> + // CHECK: vector.transpose %{{.*}}, [2, 3, 0, 1] : vector<4x5x2x3xf32> to vector<2x3x4x5xf32> + // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$MAP]]} : tensor<5x4xf32>, vector<4x5xf32> + // CHECK: vector.broadcast %{{.*}} : vector<4x5xf32> to vector<2x3x4x5xf32> + // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$MAP]]} : tensor<5x2xf32>, vector<2x5xf32> // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32> // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32> // CHECK: addf {{.*}} : vector<2x3x4x5xf32> // CHECK: vector.multi_reduction , {{.*}}, %{{.*}} [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> - // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32> + // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$MAP]]} : vector<2x5xf32>, 
tensor<5x2xf32> // CHECK: return {{.*}} : tensor<5x2xf32> %0 = linalg.generic { indexing_maps = [ @@ -1271,12 +1291,13 @@ // ----- -// CHECK-DAG: #[[$M5:.*]] = affine_map<(d0, d1) -> (d0, 0)> +// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-LABEL: func @explicit_broadcast( func.func @explicit_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4x4xf32> { // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32> - // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M5]]} : tensor<4x1xf32>, vector<4x4xf32> + // CHECK: vector.transfer_read {{.*}} {in_bounds = [true], permutation_map = #[[$MAP]]} : tensor<4x1xf32>, vector<4xf32> + // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<4x4xf32> // CHECK: subf {{.*}} : vector<4x4xf32> // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32> %c0 = arith.constant 0.0 : f32 @@ -1305,12 +1326,13 @@ // ----- -// CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)> +// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-LABEL: func @fused_broadcast_red_2d func.func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<4xf32> { // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32> - // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M6]]} : tensor<4x1xf32>, vector<4x4xf32> + // CHECK: vector.transfer_read {{.*}} {in_bounds = [true], permutation_map = #[[$MAP]]} : tensor<4x1xf32>, vector<4xf32> + // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<4x4xf32> // CHECK: subf {{.*}} : vector<4x4xf32> // CHECK: math.exp {{.*}} : vector<4x4xf32> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} : vector<4x4xf32> to vector<4xf32> @@ -1444,7 +1466,9 @@ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<2x4x8xf32> // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: 
tensor<2x4xf32> // CHECK-DAG: %[[V0:.+]] = vector.transfer_read %[[ARG0]] -// CHECK-DAG: %[[V1:.+]] = vector.transfer_read %[[ARG1]] +// CHECK-DAG: %[[V1tmp0:.+]] = vector.transfer_read %[[ARG1]] +// CHECK-DAG: %[[V1tmp1:.+]] = vector.broadcast %[[V1tmp0]] +// CHECK-DAG: %[[V1:.+]] = vector.transpose %[[V1tmp1]] // CHECK-DAG: %[[V2:.+]] = vector.transfer_read %[[ARG3]] // CHECK-DAG: %[[MUL:.+]] = arith.mulf %[[V0]], %[[V1]] // CHECK-DAG: %[[ADD:.+]] = vector.multi_reduction , %[[MUL]], %[[V2]] diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1058,41 +1058,23 @@ // ----- -// CHECK-LABEL: func @store_to_load_tensor_broadcast -// CHECK-SAME: (%[[ARG:.*]]: tensor<4x4xf32>, %[[V0:.*]]: vector<4x2xf32>) -// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<4x2xf32> to vector<6x4x2xf32> -// CHECK: %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<6x4x2xf32> to vector<4x2x6xf32> -// CHECK: return %[[T]] : vector<4x2x6xf32> -func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>, - %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> { - %c0 = arith.constant 0 : index - %cf0 = arith.constant 0.0 : f32 - %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] {in_bounds = [true, true]} : - vector<4x2xf32>, tensor<4x4xf32> - %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true], - permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} : - tensor<4x4xf32>, vector<4x2x6xf32> - return %0 : vector<4x2x6xf32> -} - -// ----- - -// CHECK-LABEL: func @store_to_load_tensor_perm_broadcast +// CHECK-LABEL: func @store_to_load_tensor_perm // CHECK-SAME: (%[[ARG:.*]]: tensor<4x4x4xf32>, %[[V0:.*]]: vector<4x1xf32>) -// CHECK: %[[B:.*]] = vector.broadcast %[[V0]] : vector<4x1xf32> to vector<100x5x4x1xf32> -// CHECK: %[[T:.*]] = vector.transpose %[[B]], [3, 0, 2, 1] : vector<100x5x4x1xf32> to 
vector<1x100x4x5xf32> -// CHECK: return %[[T]] : vector<1x100x4x5xf32> +// CHECK-NEXT: %[[T:.*]] = vector.transpose %[[V0]], [1, 0] : vector<4x1xf32> to vector<1x4xf32> +// CHECK: return %[[T]] : vector<1x4xf32> func.func @store_to_load_tensor_perm_broadcast(%arg0 : tensor<4x4x4xf32>, - %v0 : vector<4x1xf32>) -> vector<1x100x4x5xf32> { + %v0 : vector<4x1xf32>) -> vector<1x4xf32> { %c0 = arith.constant 0 : index %cf0 = arith.constant 0.0 : f32 - %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0] {in_bounds = [true, true], - permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} : + %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0] + {in_bounds = [true, true], + permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} : vector<4x1xf32>, tensor<4x4x4xf32> - %0 = vector.transfer_read %w0[%c0, %c0, %c0], %cf0 {in_bounds = [true, true, true, true], - permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} : - tensor<4x4x4xf32>, vector<1x100x4x5xf32> - return %0 : vector<1x100x4x5xf32> + %0 = vector.transfer_read %w0[%c0, %c0, %c0], %cf0 + {in_bounds = [true, true], + permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>} : + tensor<4x4x4xf32>, vector<1x4xf32> + return %0 : vector<1x4xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -354,7 +354,7 @@ func.func @test_vector.transfer_read(%arg0: memref) { %c3 = arith.constant 3 : index %cst = arith.constant 3.0 : f32 - // expected-error@+1 {{requires a projected permutation_map (at most one dim or the zero constant can appear in each result)}} + // expected-error@+1 {{requires a projected permutation_map (exactly one dim can appear in each result)}} %0 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d0 + d1)>} : memref, vector<128xf32> } @@ -363,7 +363,7 @@ func.func @test_vector.transfer_read(%arg0: memref) { %c3 = 
arith.constant 3 : index %cst = arith.constant 3.0 : f32 - // expected-error@+1 {{requires a projected permutation_map (at most one dim or the zero constant can appear in each result)}} + // expected-error@+1 {{requires a projected permutation_map (exactly one dim can appear in each result)}} %0 = vector.transfer_read %arg0[%c3, %c3], %cst {permutation_map = affine_map<(d0, d1)->(d0 + 1)>} : memref, vector<128xf32> } @@ -420,16 +420,6 @@ // ----- -func.func @test_vector.transfer_read(%arg0: memref>) { - %c3 = arith.constant 3 : index - %f0 = arith.constant 0.0 : f32 - %vf0 = vector.splat %f0 : vector<2x3xf32> - // expected-error@+1 {{requires broadcast dimensions to be in-bounds}} - %0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [false, true], permutation_map = affine_map<(d0, d1)->(0, d1)>} : memref>, vector<1x1x2x3xf32> -} - -// ----- - func.func @test_vector.transfer_read(%arg0: memref>) { %c3 = arith.constant 3 : index %f0 = arith.constant 0.0 : f32 @@ -509,7 +499,7 @@ func.func @test_vector.transfer_write(%arg0: memref) { %c3 = arith.constant 3 : index %cst = arith.constant dense<3.0> : vector<128 x f32> - // expected-error@+1 {{requires a projected permutation_map (at most one dim or the zero constant can appear in each result)}} + // expected-error@+1 {{requires a projected permutation_map (exactly one dim can appear in each result)}} vector.transfer_write %cst, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0 + d1)>} : vector<128xf32>, memref } @@ -518,7 +508,7 @@ func.func @test_vector.transfer_write(%arg0: memref) { %c3 = arith.constant 3 : index %cst = arith.constant dense<3.0> : vector<128 x f32> - // expected-error@+1 {{requires a projected permutation_map (at most one dim or the zero constant can appear in each result)}} + // expected-error@+1 {{requires a projected permutation_map (exactly one dim can appear in each result)}} vector.transfer_write %cst, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0 + 1)>} : 
vector<128xf32>, memref } @@ -536,7 +526,7 @@ func.func @test_vector.transfer_write(%arg0: memref, %arg1: vector<7xf32>) { %c3 = arith.constant 3 : index %cst = arith.constant 3.0 : f32 - // expected-error@+1 {{should not have broadcast dimensions}} + // expected-error@+1 {{requires a projected permutation_map (exactly one dim can appear in each result)}} vector.transfer_write %arg1, %arg0[%c3] {permutation_map = affine_map<(d0) -> (0)>} : vector<7xf32>, memref diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -69,8 +69,6 @@ %7 = vector.transfer_read %arg3[%c3, %c3], %vi0 : memref>, vector<5x48xi8> // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref, vector<5xf32> %8 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref, vector<5xf32> - // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref, vector<5x4x8xf32> - %9 = vector.transfer_read %arg4[%c3, %c3, %c3], %f0, %m2 {permutation_map = affine_map<(d0, d1, d2)->(d1, d0, 0)>} : memref, vector<5x4x8xf32> // CHECK: vector.transfer_write vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -25,16 +25,15 @@ // CHECK-LABEL: func @vector_transfer_ops_0d_tensor( // CHECK-SAME: %[[SOURCE:.*]]: tensor -func.func @vector_transfer_ops_0d_tensor(%M: tensor) -> vector<1xf32> { +func.func @vector_transfer_ops_0d_tensor(%M: tensor) -> vector { %f0 = arith.constant 0.0 : f32 -// CHECK-NEXT: %[[S:.*]] = tensor.extract %[[SOURCE]][] : tensor -// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : 
f32 to vector<1xf32> - %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} : - tensor, vector<1xf32> +// Does not need to lower frther, 0-D is naturally supported. +// CHECK: %[[V:.*]] = vector.transfer_read %[[SOURCE]][]{{.*}} : tensor, vector + %0 = vector.transfer_read %M[], %f0 : tensor, vector // CHECK-NEXT: return %[[V]] - return %0: vector<1xf32> + return %0: vector } // ----- @@ -191,25 +190,6 @@ // ----- -// Lowering of transfer_read with broadcasting is supported (note that a `load` -// is generated instead of a `vector.load`). -// CHECK-LABEL: func @transfer_broadcasting( -// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, -// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> { -// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32> -// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32> -// CHECK-NEXT: return %[[RES]] : vector<4xf32> -// CHECK-NEXT: } - -#broadcast = affine_map<(d0, d1) -> (0)> -func.func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> { - %cf0 = arith.constant 0.0 : f32 - %res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true], permutation_map = #broadcast} : memref<8x8xf32>, vector<4xf32> - return %res : vector<4xf32> -} - -// ----- - // CHECK-LABEL: func @transfer_scalar( // CHECK-SAME: %[[MEM:.*]]: memref, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<1xf32> { @@ -225,102 +205,20 @@ // ----- -// An example with two broadcasted dimensions. 
-// CHECK-LABEL: func @transfer_broadcasting_2D( -// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, -// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4xf32> { -// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32> -// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32> -// CHECK-NEXT: return %[[RES]] : vector<4x4xf32> -// CHECK-NEXT: } - -#broadcast = affine_map<(d0, d1) -> (0, 0)> -func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vector<4x4xf32> { - %cf0 = arith.constant 0.0 : f32 - %res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true, true], permutation_map = #broadcast} : memref<8x8xf32>, vector<4x4xf32> - return %res : vector<4x4xf32> -} - -// ----- - -// More complex broadcasting case (here a `vector.load` is generated). -// CHECK-LABEL: func @transfer_broadcasting_complex( -// CHECK-SAME: %[[MEM:.*]]: memref<10x20x30x8x8xf32>, -// CHECK-SAME: %[[IDX:.*]]: index) -> vector<3x2x4x5xf32> { -// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : memref<10x20x30x8x8xf32>, vector<3x1x1x5xf32> -// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<3x1x1x5xf32> to vector<3x2x4x5xf32> -// CHECK-NEXT: return %[[RES]] : vector<3x2x4x5xf32> -// CHECK-NEXT: } - -#broadcast = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)> -func.func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index) -> vector<3x2x4x5xf32> { - %cf0 = arith.constant 0.0 : f32 - %res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 {in_bounds = [true, true, true, true], permutation_map = #broadcast} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32> - return %res : vector<3x2x4x5xf32> -} - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, 0, 0)> -#map1 = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d0)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d3, d1, 0, 0)> -#map3 = affine_map<(d0, d1) -> (d1, d0, 0, 0)> -#map4 = 
affine_map<(d0, d1) -> (0, d1, 0, d0)> -#map5 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)> -#map6 = affine_map<(d0, d1) -> (0)> - -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, 0, 0)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)> +#map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)> // CHECK-LABEL: func @transfer_read_permutations -func.func @transfer_read_permutations(%arg0 : memref, %arg1 : memref, %m: i1) - -> (vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, - vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<8xf32>) { +func.func @transfer_read_permutations(%arg0 : memref, %arg1 : memref, %m: i1) -> vector<7x14x8x16xf32> { // CHECK-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index -// CHECK: %[[MASK0:.*]] = vector.splat %{{.*}} : vector<14x7xi1> - %mask0 = vector.splat %m : vector<7x14xi1> - %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref, vector<14x7x8x16xf32> -// CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32> - -// CHECK: %[[MASK1:.*]] = vector.splat %{{.*}} : vector<16x14xi1> - %mask1 = vector.splat %m : vector<14x16xi1> - %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask1 {permutation_map = #map1} : memref, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read {{.*}} %[[MASK1]] {permutation_map = #[[$MAP0]]} : memref, vector<16x14x7x8xf32> -// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32> - -// CHECK: %[[MASK3:.*]] = vector.splat %{{.*}} : vector<14x7xi1> - %mask2 = 
vector.splat %m : vector<7x14xi1> - %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read {{.*}} %[[MASK3]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref, vector<14x16x7xf32> -// CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32> -// CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32> - - %3 = vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map3} : memref, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF0]] : memref, vector<14x7xf32> -// CHECK: vector.broadcast %{{.*}} : vector<14x7xf32> to vector<8x16x14x7xf32> -// CHECK: vector.transpose %{{.*}}, [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32> - - %4 = vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map4} : memref, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF0]] : memref, vector<16x14xf32> -// CHECK: vector.broadcast %{{.*}} : vector<16x14xf32> to vector<7x8x16x14xf32> -// CHECK: vector.transpose %{{.*}}, [0, 3, 1, 2] : vector<7x8x16x14xf32> to vector<7x14x8x16xf32> - - %5 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map5} : memref, vector<7x14x8x16xf32> + %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map} : memref, vector<7x14x8x16xf32> // CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[CF0]] : memref, vector<16x14x7x8xf32> // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32> - %6 = vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map6} : memref, vector<8xf32> -// CHECK: memref.load %{{.*}}[%[[C0]], %[[C0]]] : memref -// CHECK: vector.broadcast %{{.*}} : f32 to vector<8xf32> - - return %0, %1, %2, %3, %4, %5, 
%6 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, - vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, - vector<7x14x8x16xf32>, vector<8xf32> + return %0 : vector<7x14x8x16xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir @@ -179,59 +179,6 @@ // ----- -// CHECK-LABEL: func @transfer_read_unroll_broadcast -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> -// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> -// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> -// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> -// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [4, 0], strides = [1, 1]} : vector<2x2xf32> into 
vector<6x4xf32> -// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [4, 2], strides = [1, 1]} : vector<2x2xf32> into vector<6x4xf32> -// CHECK-NEXT: return %[[VEC5]] : vector<6x4xf32> -#map0 = affine_map<(d0, d1) -> (0, d1)> -func.func @transfer_read_unroll_broadcast(%arg0 : memref<6x4xf32>) -> vector<6x4xf32> { - %c0 = arith.constant 0 : index - %cf0 = arith.constant 0.0 : f32 - %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<6x4xf32> - return %0 : vector<6x4xf32> -} - -// ----- - -// CHECK-LABEL: func @transfer_read_unroll_broadcast_permuation -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[VTR0]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> -// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[VTR1]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> -// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[VTR2]], %[[VEC1]] {offsets = [0, 4], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> -// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[VTR3]], %[[VEC2]] {offsets = [2, 0], strides = [1, 1]} : 
vector<2x2xf32> into vector<4x6xf32> -// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[VTR4]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> -// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read {{.*}}[%[[C4]], %[[C0]]], %{{.*}} : memref<6x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[VEC5:.*]] = vector.insert_strided_slice %[[VTR5]], %[[VEC4]] {offsets = [2, 4], strides = [1, 1]} : vector<2x2xf32> into vector<4x6xf32> -// CHECK-NEXT: return %[[VEC5]] : vector<4x6xf32> -#map0 = affine_map<(d0, d1) -> (0, d0)> -func.func @transfer_read_unroll_broadcast_permuation(%arg0 : memref<6x4xf32>) -> vector<4x6xf32> { - %c0 = arith.constant 0 : index - %cf0 = arith.constant 0.0 : f32 - %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 {permutation_map = #map0} : memref<6x4xf32>, vector<4x6xf32> - return %0 : vector<4x6xf32> -} - -// ----- - // CHECK-LABEL: func @transfer_read_unroll_different_rank // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir @@ -86,17 +86,6 @@ return } -// Broadcast. -func.func @transfer_read_1d_broadcast( - %A : memref, %base1 : index, %base2 : index) { - %fm42 = arith.constant -42.0: f32 - %f = vector.transfer_read %A[%base1, %base2], %fm42 - {permutation_map = affine_map<(d0, d1) -> (0)>} - : memref, vector<9xf32> - vector.print %f: vector<9xf32> - return -} - // Non-contiguous, strided load. 
func.func @transfer_read_1d_in_bounds( %A : memref, %base1 : index, %base2 : index) { @@ -190,34 +179,28 @@ call @transfer_read_1d(%A, %c0, %c2) : (memref, index, index) -> () // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 ) - // 6. Read a scalar from a 2D memref and broadcast the value to a 1D vector. - // Generates a loop with vector.insertelement. - call @transfer_read_1d_broadcast(%A, %c1, %c2) - : (memref, index, index) -> () - // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ) - - // 7. Read from 2D memref on first dimension. Accesses are in-bounds, so no + // 6. Read from 2D memref on first dimension. Accesses are in-bounds, so no // if-check is generated inside the generated loop. call @transfer_read_1d_in_bounds(%A, %c1, %c2) : (memref, index, index) -> () // CHECK: ( 12, 22, -1 ) - // 8. Optional mask attribute is specified and, in addition, there may be + // 7. Optional mask attribute is specified and, in addition, there may be // out-of-bounds accesses. call @transfer_read_1d_mask(%A, %c1, %c2) : (memref, index, index) -> () // CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 ) - // 9. Same as 8, but accesses are in-bounds. + // 8. Same as 7, but accesses are in-bounds. call @transfer_read_1d_mask_in_bounds(%A, %c1, %c2) : (memref, index, index) -> () // CHECK: ( 12, -42, -1 ) - // 10. Write to 2D memref on first dimension with a mask. + // 9. Write to 2D memref on first dimension with a mask. call @transfer_write_1d_mask(%A, %c1, %c0) : (memref, index, index) -> () - // 11. (Same as 1. To check if 10 works correctly.) + // 10. (Same as 1. To check if 9 works correctly.) 
call @transfer_read_1d(%A, %c0, %c0) : (memref, index, index) -> () // CHECK: ( 0, -2, 20, -2, 40, -42, -42, -42, -42 ) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir @@ -62,30 +62,6 @@ return } -// Vector load with mask + broadcast. -func.func @transfer_read_2d_mask_broadcast( - %A : memref, %base1: index, %base2: index) { - %fm42 = arith.constant -42.0: f32 - %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1, 0, 1]> : vector<9xi1> - %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask - {permutation_map = affine_map<(d0, d1) -> (0, d1)>} : - memref, vector<4x9xf32> - vector.print %f: vector<4x9xf32> - return -} - -// Transpose + vector load with mask + broadcast. -func.func @transfer_read_2d_mask_transpose_broadcast_last_dim( - %A : memref, %base1: index, %base2: index) { - %fm42 = arith.constant -42.0: f32 - %mask = arith.constant dense<[1, 0, 1, 1]> : vector<4xi1> - %f = vector.transfer_read %A[%base1, %base2], %fm42, %mask - {permutation_map = affine_map<(d0, d1) -> (d1, 0)>} : - memref, vector<4x9xf32> - vector.print %f: vector<4x9xf32> - return -} - // Load + transpose. func.func @transfer_read_2d_transposed( %A : memref, %base1: index, %base2: index) { @@ -97,17 +73,6 @@ return } -// Load 1D + broadcast to 2D. -func.func @transfer_read_2d_broadcast( - %A : memref, %base1: index, %base2: index) { - %fm42 = arith.constant -42.0: f32 - %f = vector.transfer_read %A[%base1, %base2], %fm42 - {permutation_map = affine_map<(d0, d1) -> (d1, 0)>} : - memref, vector<4x9xf32> - vector.print %f: vector<4x9xf32> - return -} - // Vector store. 
func.func @transfer_write_2d(%A : memref, %base1: index, %base2: index) { %fn1 = arith.constant -1.0 : f32 @@ -158,36 +123,17 @@ : (memref, index, index) -> () // CHECK: ( ( 0, -42, 20, -42 ), ( -42, -42, 21, -42 ), ( 2, 12, 22, -42 ), ( -42, 13, 23, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ), ( -42, -42, -42, -42 ) ) - // 5. Read 1D vector from 2D memref at specified location and broadcast the - // result to 2D. - call @transfer_read_2d_broadcast(%A, %c1, %c2) - : (memref, index, index) -> () - // CHECK: ( ( 12, 12, 12, 12, 12, 12, 12, 12, 12 ), ( 13, 13, 13, 13, 13, 13, 13, 13, 13 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) ) - - // 6. Read 1D vector from 2D memref at specified location with mask and - // broadcast the result to 2D. - call @transfer_read_2d_mask_broadcast(%A, %c2, %c1) - : (memref, index, index) -> () - // CHECK: ( ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ), ( 21, -42, 23, -42, -42, -42, -42, -42, -42 ) ) - - // 7. Read 1D vector from 2D memref (second dimension) at specified location - // with mask and broadcast the result to 2D. In this test case, mask - // elements must be evaluated before lowering to an (N>1)-D transfer. - call @transfer_read_2d_mask_transpose_broadcast_last_dim(%A, %c0, %c1) - : (memref, index, index) -> () - // CHECK: ( ( 1, 1, 1, 1, 1, 1, 1, 1, 1 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( 3, 3, 3, 3, 3, 3, 3, 3, 3 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) ) - - // 8. Write 2D vector into 2D memref at specified location. + // 5. Write 2D vector into 2D memref at specified location. call @transfer_write_2d(%A, %c1, %c2) : (memref, index, index) -> () - // 9. Read memref to verify step 8. + // 6. Read memref to verify step 5. 
call @transfer_read_2d(%A, %c0, %c0) : (memref, index, index) -> () // CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, -1, -1, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) ) - // 10. Write 2D vector into 2D memref at specified location with mask. + // 7. Write 2D vector into 2D memref at specified location with mask. call @transfer_write_2d_mask(%A, %c0, %c2) : (memref, index, index) -> () - // 11. Read memref to verify step 10. + // 8. Read memref to verify step 7. call @transfer_read_2d(%A, %c0, %c0) : (memref, index, index) -> () // CHECK: ( ( 0, 1, -2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, -1, -1, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) ) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir @@ -38,27 +38,6 @@ return } -func.func @transfer_read_3d_broadcast(%A : memref, - %o: index, %a: index, %b: index, %c: index) { - %fm42 = arith.constant -42.0: f32 - %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42 - {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)>} - : memref, vector<2x5x3xf32> - vector.print %f: vector<2x5x3xf32> - return -} - -func.func @transfer_read_3d_mask_broadcast( - %A : memref, %o: index, %a: index, %b: index, %c: index) { - %fm42 = arith.constant -42.0: f32 - %mask = arith.constant dense<[0, 1]> : vector<2xi1> - %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42, %mask - {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, 0, 0)>} - : memref, vector<2x5x3xf32> - vector.print %f: vector<2x5x3xf32> - return -} - func.func @transfer_read_3d_transposed(%A : memref, %o: index, %a: index, %b: index, %c: index) { %fm42 = 
arith.constant -42.0: f32 @@ -134,16 +113,6 @@ : (memref, index, index, index, index) -> () // CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) ) - // 6. Read 1D vector from 4D memref and broadcast vector to 3D. - call @transfer_read_3d_broadcast(%A, %c0, %c0, %c0, %c0) - : (memref, index, index, index, index) -> () - // CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) ) - - // 7. Read 1D vector from 4D memref with mask and broadcast vector to 3D. - call @transfer_read_3d_mask_broadcast(%A, %c0, %c0, %c0, %c0) - : (memref, index, index, index, index) -> () - // CHECK: ( ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ), ( ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ) ) ) - memref.dealloc %A : memref return }