diff --git a/mlir/docs/Dialects/Linalg/_index.md b/mlir/docs/Dialects/Linalg/_index.md --- a/mlir/docs/Dialects/Linalg/_index.md +++ b/mlir/docs/Dialects/Linalg/_index.md @@ -103,10 +103,10 @@ #identity = affine_map<(d0) -> (d0)> func.func @example(%A: memref, - %B: memref, offset: 1, strides: [2]>) { + %B: memref, strided<[2], offset: 1>>) { linalg.generic #attrs ins(%A: memref) - outs(%B: memref, offset: 1, strides: [2]>) { + outs(%B: memref, strided<[2], offset: 1>>) { ^bb0(%a: f32, %b: vector<4xf32>): %c = "some_compute"(%a, %b): (f32, vector<4xf32>) -> (vector<4xf32>) linalg.yield %c: vector<4xf32> @@ -186,10 +186,10 @@ iterator_types = ["parallel", "parallel"] } -func.func @example(%A: memref<8x?xf32, offset: 0, strides: [2, 2]>, +func.func @example(%A: memref<8x?xf32, strided<[2, 2], offset: 0>>, %B: memref>) { linalg.generic #attrs - ins(%A: memref<8x?xf32, offset: 0, strides: [2, 2]>) + ins(%A: memref<8x?xf32, strided<[2, 2], offset: 0>>) outs(%B: memref>) { ^bb0(%a: f32, %b: vector<4xf32>): %c = "some_compute"(%a, %b): (f32, vector<4xf32>) -> (vector<4xf32>) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -343,12 +343,12 @@ // The same holds true for offsets and strides. // Assert that the input dynamic shape matches the destination static stride. - %4 = memref.cast %1 : memref<12x4xf32, offset:?, strides: [?, ?]> to - memref<12x4xf32, offset:5, strides: [4, 1]> + %4 = memref.cast %1 : memref<12x4xf32, strided<[?, ?], offset: ?>> to + memref<12x4xf32, strided<[4, 1], offset: 5>> // Erase static offset and stride information, replacing it with // dynamic information. - %5 = memref.cast %1 : memref<12x4xf32, offset:5, strides: [4, 1]> to - memref<12x4xf32, offset:?, strides: [?, ?]> + %5 = memref.cast %1 : memref<12x4xf32, strided<[4, 1], offset: 5>> to + memref<12x4xf32, strided<[?, ?], offset: ?>> ``` b. Either or both memref types are unranked with the same element type, and @@ -1067,13 +1067,13 @@ offset: [0], sizes: [%size0, 10], strides: [1, %stride1] - : memref to memref + : memref to memref> memref.reinterpret_cast %unranked to offset: [%offset], sizes: [%size0, %size1], strides: [%stride0, %stride1] - : memref<*xf32> to memref + : memref<*xf32> to memref> ``` }]; @@ -1639,7 +1639,7 @@ // After rank reducing: // (d0, d1) -> (4 * d0 + d1 + 210) %3 = memref.subview %2[3, 4, 2][1, 6, 3][1, 1, 1] : - memref<8x16x4xf32> to memref<6x3xf32, offset: 210, strides: [4, 1]> + memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>> ``` }]; diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -1074,6 +1074,30 @@ // Attribute Utilities //===----------------------------------------------------------------------===// +namespace mlir { + +/// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset() +/// represents a dynamic value), return the single result AffineMap which +/// represents the linearized strided layout map. Dimensions correspond to the +/// offset followed by the strides in order. Symbols are inserted for each +/// dynamic dimension in order. A stride cannot take value `0`. +/// +/// Examples: +/// ========= +/// +/// 1. For offset: 0 strides: ?, ?, 1 return +/// (i, j, k)[M, N]->(M * i + N * j + k) +/// +/// 2. 
For offset: 3 strides: 32, ?, 16 return
+///   (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k)
+///
+/// 3. For offset: ? strides: ?, ?, ? return
+///   (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k)
+AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
+                                     MLIRContext *context);
+
+} // namespace mlir
+
 namespace llvm {
 template <>
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -966,6 +966,59 @@
   let skipDefaultBuilders = 1;
 }
+//===----------------------------------------------------------------------===//
+// StridedLayoutAttr
+//===----------------------------------------------------------------------===//
+
+def StridedLayoutAttr : Builtin_Attr<"StridedLayout",
+    [DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>]> {
+  let summary = "An Attribute representing a strided layout of a shaped type";
+  let description = [{
+    Syntax:
+
+    ```
+    strided-layout-attribute ::= `strided` `<` `[` stride-list `]`
+                                 (`,` `offset` `:` dimension)? `>`
+    stride-list ::= /*empty*/
+                  | dimension (`,` dimension)*
+    dimension ::= decimal-literal | `?`
+    ```
+
+    A strided layout attribute captures layout information of the memref type in
+    the canonical form. Specifically, it contains a list of _strides_ along each
+    dimension. A stride is the number of elements in the linear storage one must
+    step over to reflect an increment in the given dimension. For example, a
+    `MxN` row-major contiguous shaped type would have the strides `[N, 1]`. The
+    layout attribute also contains the _offset_ from the base pointer of the
+    shaped type to the first effectively accessed element, expressed in terms of
+    the number of contiguously stored elements.
+
+    Both the strides and the offset may be _dynamic_, i.e. their value may not
+    be known at compile time. This is expressed as a `?` in the assembly syntax
+    and as `ShapedType::kDynamicStrideOrOffset` in the code.
+
+    See [MemRef type](Dialects/Builtin.md#memreftype) for more information.
+  }];
+
+  let mnemonic = "strided";
+  let parameters = (ins
+    "int64_t":$offset,
+    ArrayRefParameter<
+      "int64_t",
+      "array of strides (64-bit integer)"
+    >:$strides
+  );
+  let hasCustomAssemblyFormat = 1;
+  let genVerifyDecl = 1;
+
+  let extraClassDeclaration = [{
+    /// Print the attribute to the given output stream.
+    void print(raw_ostream &os) const;
+  }];
+}
+
+
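For illustration, a few memref types carrying this attribute as their layout, written in the syntax defined above (shapes and values are arbitrary examples, not taken from the patch; a zero offset is elided when the attribute is printed):

```mlir
// Static strides with an explicit offset of 8 elements.
memref<16x4xf32, strided<[4, 1], offset: 8>>
// Offset 0 is the default; the printer elides it, so the following two
// spellings denote the same type.
memref<16x4xf32, strided<[4, 1], offset: 0>>
memref<16x4xf32, strided<[4, 1]>>
// Dynamic strides and offsets are written as `?`.
memref<?x?xf32, strided<[?, 1], offset: ?>>
```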
 //===----------------------------------------------------------------------===//
 // StringAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -436,26 +436,6 @@
                                   SmallVectorImpl<AffineExpr> &strides,
                                   AffineExpr &offset);
 
-/// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset()
-/// represents a dynamic value), return the single result AffineMap which
-/// represents the linearized strided layout map. Dimensions correspond to the
-/// offset followed by the strides in order. Symbols are inserted for each
-/// dynamic dimension in order. A stride cannot take value `0`.
-///
-/// Examples:
-/// =========
-///
-/// 1. For offset: 0 strides: ?, ?, 1 return
-///   (i, j, k)[M, N]->(M * i + N * j + k)
-///
-/// 2. For offset: 3 strides: 32, ?, 16 return
-///   (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k)
-///
-/// 3. For offset: ? strides: ?, ?, ? return
-///   (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k)
-AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
-                                     MLIRContext *context);
-
 /// Return a version of `t` with identity layout if it can be determined
 /// statically that the layout is the canonical contiguous strided layout.
 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -280,10 +280,7 @@
     ```
     memref-type ::= `memref` `<` dimension-list-ranked type
                     (`,` layout-specification)? (`,` memory-space)? `>`
-
-    stride-list ::= `[` (dimension (`,` dimension)*)? `]`
-    strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
-    layout-specification ::= semi-affine-map | strided-layout | attribute-value
+    layout-specification ::= attribute-value
     memory-space ::= attribute-value
     ```
@@ -310,6 +307,84 @@
     - another memref type;
     - any other type implementing `MemRefElementTypeInterface`.
+    ##### Layout
+
+    A memref may optionally have a layout that indicates how the indices are
+    transformed from the multi-dimensional form into a linear address. The
+    layout must avoid internal aliasing, i.e., two distinct tuples of
+    _in-bounds_ indices must point to different elements in memory. The layout
+    is an attribute that implements `MemRefLayoutAttrInterface`. The builtin
+    dialect offers two kinds of layouts: strided and affine map, each of which
+    is available as an attribute. Other attributes may be used to represent
+    the layout as long as they can be converted to a
+    [semi-affine map](Affine.md/#semi-affine-maps) and implement the required
+    interface. Users of memref are expected to fall back to the affine
+    representation when handling unknown memref layouts. Multi-dimensional
+    affine forms are interpreted in _row-major_ fashion.
+
+    In the absence of an explicit layout, a memref is considered to have a
+    multi-dimensional identity affine map layout. Identity layout maps do not
+    contribute to the MemRef type identification and are discarded on
+    construction. That is, a type with an explicit identity map,
+    `memref<?x?xf32, (i,j)->(i,j)>`, is strictly the same as the one without a
+    layout, `memref<?x?xf32>`.
+
+    ##### Affine Map Layout
+
+    The layout may be represented directly as an affine map from the index space
+    to the storage space. For example, the following figure shows an index map
+    which maps a 2-dimensional index from a 2x2 index space to a 3x3 index
+    space, using symbols `S0` and `S1` as offsets.
+
+    ![Index Map Example](/includes/img/index-map.svg)
+
+    Semi-affine maps are sufficiently flexible to represent a wide variety of
+    dense storage layouts, including row- and column-major and tiled:
+
+    ```mlir
+    // MxN matrix stored in row major layout in memory:
+    #layout_map_row_major = (i, j) -> (i, j)
+
+    // MxN matrix stored in column major layout in memory:
+    #layout_map_col_major = (i, j) -> (j, i)
+
+    // MxN matrix stored in a 2-d blocked/tiled layout with 64x64 tiles.
+    #layout_tiled = (i, j) -> (i floordiv 64, j floordiv 64, i mod 64, j mod 64)
+    ```
+
+    ##### Strided Layout
+
+    Memref layout can be expressed using strides to encode the distance, in
+    number of elements, in (linear) memory between successive entries along a
+    particular dimension. For example, a row-major strided layout for
+    `memref<2x3x4xf32>` is `strided<[12, 4, 1]>`, where the last dimension is
+    contiguous as indicated by the unit stride and the remaining strides are
+    products of the sizes of faster-varying dimensions. Strided layout can also
+    express non-contiguity, e.g., `memref<2x3xf32, strided<[6, 2]>>` only
+    accesses even elements of the dense consecutive storage along the innermost
+    dimension.
+
+    The strided layout supports an optional _offset_ that indicates the
+    distance, in the number of elements, between the beginning of the memref
+    and the first accessed element. When omitted, the offset is considered to
+    be zero. That is, `memref<2xf32, strided<[2], offset: 0>>` and
+    `memref<2xf32, strided<[2]>>` are strictly the same type.
+
+    Both offsets and strides may be _dynamic_, that is, unknown at compile time.
+    This is represented by using a question mark (`?`) instead of the value in
+    the textual form of the IR.
+
+    The strided layout converts into the following canonical one-dimensional
+    affine form through explicit linearization:
+
+    ```mlir
+    affine_map<(d0, ..., dN)[offset, stride0, ..., strideN] ->
+               (offset + d0 * stride0 + ... + dN * strideN)>
+    ```
+
+    Therefore, it is never subject to the implicit row-major layout
+    interpretation.
+
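To make the linearization concrete, here is a small worked example; the shape, strides, and offset are arbitrary values chosen only for illustration:

```mlir
// strided<[8, 1], offset: 2> linearizes, per the form above, to
// offset + d0 * stride0 + d1 * stride1 = 2 + d0 * 8 + d1:
//   memref<4x8xf32, strided<[8, 1], offset: 2>>
//   ~ memref<4x8xf32, affine_map<(d0, d1) -> (d0 * 8 + d1 + 2)>>
// Dynamic values become symbols, the offset symbol first, then the strides:
//   memref<?x?xf32, strided<[?, 1], offset: ?>>
//   ~ memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (s0 + d0 * s1 + d1)>>
```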
     ##### Codegen of Unranked Memref
 
     Using unranked memref in codegen besides the case mentioned above is highly
@@ -425,104 +500,6 @@
     %o = ...
     %A = alloc (%n)[%o] : <16x?xf32, #imapS>
     ```
-
-    ##### Index Space
-
-    A memref dimension list defines an index space within which the memref can
-    be indexed to access data.
-
-    ##### Index
-
-    Data is accessed through a memref type using a multidimensional index into
-    the multidimensional index space defined by the memref's dimension list.
-
-    Examples
-
-    ```mlir
-    // Allocates a memref with 2D index space:
-    // { (i, j) : 0 <= i < 16, 0 <= j < 32 }
-    %A = alloc() : memref<16x32xf32, #imapA>
-
-    // Loads data from memref '%A' using a 2D index: (%i, %j)
-    %v = load %A[%i, %j] : memref<16x32xf32, #imapA>
-    ```
-
-    ##### Index Map
-
-    An index map is a one-to-one
-    [semi-affine map](Affine.md/#semi-affine-maps) that transforms a
-    multidimensional index from one index space to another. For example, the
-    following figure shows an index map which maps a 2-dimensional index from a
-    2x2 index space to a 3x3 index space, using symbols `S0` and `S1` as
-    offsets.
-
-    ![Index Map Example](/includes/img/index-map.svg)
-
-    The number of domain dimensions and range dimensions of an index map can be
-    different, but must match the number of dimensions of the input and output
-    index spaces on which the map operates. The index space is always
-    non-negative and integral. In addition, an index map must specify the size
-    of each of its range dimensions onto which it maps. Index map symbols must
-    be listed in order with symbols for dynamic dimension sizes first, followed
-    by other required symbols.
-
-    ##### Layout Map
-
-    A layout map is a [semi-affine map](Affine.md/#semi-affine-maps)
-    which encodes logical to physical index space mapping, by mapping input
-    dimensions to their ordering from most-major (slowest varying) to most-minor
-    (fastest varying). Therefore, an identity layout map corresponds to a
-    row-major layout. Identity layout maps do not contribute to the MemRef type
-    identification and are discarded on construction. That is, a type with an
-    explicit identity map is `memref<?x?xf32, (i,j)->(i,j)>` is strictly the
-    same as the one without layout maps, `memref<?x?xf32>`.
-
-    Layout map examples:
-
-    ```mlir
-    // MxN matrix stored in row major layout in memory:
-    #layout_map_row_major = (i, j) -> (i, j)
-
-    // MxN matrix stored in column major layout in memory:
-    #layout_map_col_major = (i, j) -> (j, i)
-
-    // MxN matrix stored in a 2-d blocked/tiled layout with 64x64 tiles.
-    #layout_tiled = (i, j) -> (i floordiv 64, j floordiv 64, i mod 64, j mod 64)
-    ```
-
-    ##### Strided MemRef
-
-    A memref may specify a strided layout as part of its type. A stride
-    specification is a list of integer values that are either static or `?`
-    (dynamic case).
-    Strides encode the distance, in number of elements, in (linear) memory
-    between successive entries along a particular dimension. A stride
-    specification is syntactic sugar for an equivalent strided memref
-    representation with a *single* semi-affine map.
-
-    For example, `memref<42x16xf32, offset: 33, strides: [1, 64]>` specifies a
-    non-contiguous memory region of `42` by `16` `f32` elements such that:
-
-    1. the minimal size of the enclosing memory region must be
-       `33 + 42 * 1 + 16 * 64 = 1066` elements;
-    2. the address calculation for accessing element `(i, j)` computes
-       `33 + i + 64 * j`
-    3. the distance between two consecutive elements along the inner dimension
-       is `1` element and the distance between two consecutive elements along
-       the outer dimension is `64` elements.
-
-    This corresponds to a column major view of the memory region and is
-    internally represented as the type
-    `memref<42x16xf32, (i, j) -> (33 + i + 64 * j)>`.
-
-    The specification of strides must not alias: given an n-D strided memref,
-    indices `(i1, ..., in)` and `(j1, ..., jn)` may not refer to the same memory
-    address unless `i1 == j1, ..., in == jn`.
-
-    Strided memrefs represent a view abstraction over preallocated data. They
-    are constructed with special ops, yet to be introduced. Strided memrefs are
-    a special subclass of memrefs with generic semi-affine map and correspond to
-    a normalized memref descriptor when lowering to LLVM.
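The strided-layout sugar in the removed text above maps directly onto the attribute-based syntax this patch introduces; the following sketch restates the `42x16` example in the new spelling (same type, new notation):

```mlir
// Old (removed) syntax:
//   memref<42x16xf32, offset: 33, strides: [1, 64]>
// New syntax using the strided layout attribute:
//   memref<42x16xf32, strided<[1, 64], offset: 33>>
```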
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -15,6 +15,7 @@
 #include "AsmParserImpl.h"
 #include "mlir/AsmParser/AsmParserState.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -42,6 +43,8 @@
 ///                    (tensor-type | vector-type)
 ///                  | `sparse` `<` attribute-value `,` attribute-value `>`
 ///                      `:` (tensor-type | vector-type)
+///                  | `strided` `<` `[` comma-separated-int-or-question `]`
+///                      (`,` `offset` `:` integer-literal)? `>`
 ///                  | extended-attribute
 ///
 Attribute Parser::parseAttribute(Type type) {
@@ -147,6 +150,10 @@
   case Token::kw_sparse:
     return parseSparseElementsAttr(type);
 
+  // Parse a strided layout attribute.
+  case Token::kw_strided:
+    return parseStridedLayoutAttr();
+
   // Parse a string attribute.
   case Token::string: {
     auto val = getToken().getStringValue();
@@ -1072,3 +1079,74 @@
   // Build the sparse elements attribute by the indices and values.
   return getChecked<SparseElementsAttr>(loc, type, indices, values);
 }
+
+Attribute Parser::parseStridedLayoutAttr() {
+  // Callback for emitting errors at the keyword token location.
+  llvm::SMLoc loc = getToken().getLoc();
+  auto errorEmitter = [&] { return emitError(loc); };
+
+  consumeToken(Token::kw_strided);
+  if (failed(parseToken(Token::less, "expected '<' after 'strided'")) ||
+      failed(parseToken(Token::l_square, "expected '['")))
+    return nullptr;
+
+  // Parses either an integer token or a question mark token. Reports an error
+  // and returns None if the current token is neither. The integer token must
+  // fit into int64_t limits.
+  auto parseStrideOrOffset = [&]() -> Optional<int64_t> {
+    if (consumeIf(Token::question))
+      return ShapedType::kDynamicStrideOrOffset;
+
+    SMLoc loc = getToken().getLoc();
+    auto emitWrongTokenError = [&] {
+      emitError(loc, "expected a non-negative 64-bit signed integer or '?'");
+      return llvm::None;
+    };
+
+    if (getToken().is(Token::integer)) {
+      Optional<uint64_t> value = getToken().getUInt64IntegerValue();
+      if (!value || *value > std::numeric_limits<int64_t>::max())
+        return emitWrongTokenError();
+      consumeToken();
+      return static_cast<int64_t>(*value);
+    }
+
+    return emitWrongTokenError();
+  };
+
+  // Parse strides.
+  SmallVector<int64_t> strides;
+  if (!getToken().is(Token::r_square)) {
+    do {
+      Optional<int64_t> stride = parseStrideOrOffset();
+      if (!stride)
+        return nullptr;
+      strides.push_back(*stride);
+    } while (consumeIf(Token::comma));
+  }
+
+  if (failed(parseToken(Token::r_square, "expected ']'")))
+    return nullptr;
+
+  // Fast path in absence of offset.
+  if (consumeIf(Token::greater)) {
+    if (failed(StridedLayoutAttr::verify(errorEmitter,
+                                         /*offset=*/0, strides)))
+      return nullptr;
+    return StridedLayoutAttr::get(getContext(), /*offset=*/0, strides);
+  }
+
+  if (failed(parseToken(Token::comma, "expected ','")) ||
+      failed(parseToken(Token::kw_offset, "expected 'offset' after comma")) ||
+      failed(parseToken(Token::colon, "expected ':' after 'offset'")))
+    return nullptr;
+
+  Optional<int64_t> offset = parseStrideOrOffset();
+  if (!offset || failed(parseToken(Token::greater, "expected '>'")))
+    return nullptr;
+
+  if (failed(StridedLayoutAttr::verify(errorEmitter, *offset, strides)))
+    return nullptr;
+  return StridedLayoutAttr::get(getContext(), *offset, strides);
+}
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -217,14 +217,6 @@
   ParseResult parseIntegerInDimensionList(int64_t &value);
   ParseResult parseXInDimensionList();
 
-  /// Parse strided layout specification.
-  ParseResult parseStridedLayout(int64_t &offset,
-                                 SmallVectorImpl<int64_t> &strides);
-
-  // Parse a brace-delimiter list of comma-separated integers with `?` as an
-  // unknown marker.
-  ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions);
-
   //===--------------------------------------------------------------------===//
   // Attribute Parsing
   //===--------------------------------------------------------------------===//
@@ -279,6 +271,9 @@
   /// Parse a sparse elements attribute.
   Attribute parseSparseElementsAttr(Type attrType);
 
+  /// Parse a strided layout attribute.
+ Attribute parseStridedLayoutAttr(); + //===--------------------------------------------------------------------===// // Location Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -109,7 +109,7 @@ TOK_KEYWORD(size) TOK_KEYWORD(sparse) TOK_KEYWORD(step) -TOK_KEYWORD(strides) +TOK_KEYWORD(strided) TOK_KEYWORD(symbol) TOK_KEYWORD(tensor) TOK_KEYWORD(to) diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -146,35 +146,6 @@ return builder.getFunctionType(arguments, results); } -/// Parse the offset and strides from a strided layout specification. -/// -/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list -/// -ParseResult Parser::parseStridedLayout(int64_t &offset, - SmallVectorImpl &strides) { - // Parse offset. - consumeToken(Token::kw_offset); - if (parseToken(Token::colon, "expected colon after `offset` keyword")) - return failure(); - - auto maybeOffset = getToken().getUnsignedIntegerValue(); - bool question = getToken().is(Token::question); - if (!maybeOffset && !question) - return emitWrongTokenError("invalid offset"); - offset = maybeOffset ? static_cast(*maybeOffset) - : MemRefType::getDynamicStrideOrOffset(); - consumeToken(); - - // Parse stride list. - if (parseToken(Token::comma, "expected comma after offset value") || - parseToken(Token::kw_strides, - "expected `strides` keyword after offset specification") || - parseToken(Token::colon, "expected colon after `strides` keyword") || - parseStrideList(strides)) - return failure(); - return success(); -} - /// Parse a memref type. /// /// memref-type ::= ranked-memref-type | unranked-memref-type @@ -225,29 +196,18 @@ Attribute memorySpace; auto parseElt = [&]() -> ParseResult { - // Check for AffineMap as offset/strides. - if (getToken().is(Token::kw_offset)) { - int64_t offset; - SmallVector strides; - if (failed(parseStridedLayout(offset, strides))) - return failure(); - // Construct strided affine map. - AffineMap map = makeStridedLinearLayoutMap(strides, offset, getContext()); - layout = AffineMapAttr::get(map); - } else { - // Either it is MemRefLayoutAttrInterface or memory space attribute. - Attribute attr = parseAttribute(); - if (!attr) - return failure(); + // Either it is MemRefLayoutAttrInterface or memory space attribute. + Attribute attr = parseAttribute(); + if (!attr) + return failure(); - if (attr.isa()) { - layout = attr.cast(); - } else if (memorySpace) { - return emitError("multiple memory spaces specified in memref type"); - } else { - memorySpace = attr; - return success(); - } + if (attr.isa()) { + layout = attr.cast(); + } else if (memorySpace) { + return emitError("multiple memory spaces specified in memref type"); + } else { + memorySpace = attr; + return success(); } if (isUnranked) @@ -617,34 +577,3 @@ return success(); } - -// Parse a comma-separated list of dimensions, possibly empty: -// stride-list ::= `[` (dimension (`,` dimension)*)? `]` -ParseResult Parser::parseStrideList(SmallVectorImpl &dimensions) { - return parseCommaSeparatedList( - Delimiter::Square, - [&]() -> ParseResult { - if (consumeIf(Token::question)) { - dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); - } else { - // This must be an integer value. 
- int64_t val; - if (getToken().getSpelling().getAsInteger(10, val)) - return emitError("invalid integer value: ") - << getToken().getSpelling(); - // Make sure it is not the one value for `?`. - if (ShapedType::isDynamic(val)) - return emitError("invalid integer value: ") - << getToken().getSpelling() - << ", use `?` to specify a dynamic dimension"; - - if (val == 0) - return emitError("invalid memref stride"); - - dimensions.push_back(val); - consumeToken(Token::integer); - } - return success(); - }, - " in stride list"); -} diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2577,13 +2577,13 @@ /// ``` /// %0 = memref.cast %V : memref<16x16xf32> to memref /// %1 = memref.subview %0[0, 0][3, 4][1, 1] : -/// memref to memref<3x4xf32, offset:?, strides:[?, 1]> +/// memref to memref<3x4xf32, strided<[?, 1], offset: ?>> /// ``` /// is rewritten into: /// ``` /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> -/// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to -/// memref<3x4xf32, offset:?, strides:[?, 1]> +/// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to +/// memref<3x4xf32, strided<[?, 1], offset: ?>> /// ``` class SubViewOpMemRefCastFolder final : public OpRewritePattern { public: diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1858,6 +1858,8 @@ } os << '>'; } + } else if (auto stridedLayoutAttr = attr.dyn_cast()) { + stridedLayoutAttr.print(os); } else if (auto denseArrayAttr = attr.dyn_cast()) { typeElision = AttrTypeElision::Must; os << "array<" << denseArrayAttr.getType().getElementType(); diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -236,6 +236,84 @@ return getWithSorted(getContext(), vec); } +//===----------------------------------------------------------------------===// +// StridedLayoutAttr +//===----------------------------------------------------------------------===// + +/// Parses an integer or a question mark. Sets `value` to the integer value or +/// to ShapedType::kDynamicStrideOrOffset if the question mark was parsed. +static ParseResult parseIntOrQuestion(AsmParser &parser, int64_t &value) { + if (succeeded(parser.parseOptionalQuestion())) { + value = ShapedType::kDynamicStrideOrOffset; + return success(); + } + + return parser.parseInteger(value); +} + +/// Parses a strided layout attribute. +Attribute StridedLayoutAttr::parse(AsmParser &parser, Type odsType) { + StridedLayoutAttr attr; + if (failed(parser.parseAttribute(attr))) + return nullptr; + return attr; +} + +/// Prints a strided layout attribute. +void StridedLayoutAttr::print(AsmPrinter &printer) const { + print(printer.getStream()); +} + +/// Prints a strided layout attribute. +void StridedLayoutAttr::print(llvm::raw_ostream &os) const { + auto printIntOrQuestion = [&](int64_t value) { + if (value == ShapedType::kDynamicStrideOrOffset) + os << "?"; + else + os << value; + }; + + os << "strided<["; + llvm::interleaveComma(getStrides(), os, printIntOrQuestion); + os << "]"; + + if (getOffset() != 0) { + os << ", offset: "; + printIntOrQuestion(getOffset()); + } + os << ">"; +} + +/// Returns the strided layout as an affine map. 
+AffineMap StridedLayoutAttr::getAffineMap() const { + return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext()); +} + +/// Checks that the type-agnostic strided layout invariants are satisfied. +LogicalResult +StridedLayoutAttr::verify(function_ref emitError, + int64_t offset, ArrayRef strides) { + if (offset < 0 && offset != ShapedType::kDynamicStrideOrOffset) + return emitError() << "offset must be non-negative or dynamic"; + + if (llvm::any_of(strides, [&](int64_t stride) { + return stride <= 0 && stride != ShapedType::kDynamicStrideOrOffset; + })) { + return emitError() << "strides must be positive or dynamic"; + } + return success(); +} + +/// Checks that the type-specific strided layout invariants are satisfied. +LogicalResult StridedLayoutAttr::verifyLayout( + ArrayRef shape, + function_ref emitError) const { + if (shape.size() != getStrides().size()) + return emitError() << "expected the number of strides to match the rank"; + + return success(); +} + //===----------------------------------------------------------------------===// // StringAttr //===----------------------------------------------------------------------===// @@ -1783,3 +1861,43 @@ ArrayRef replTypes) const { return get(replTypes[0]); } + +//===----------------------------------------------------------------------===// +// Attribute Utilities +//===----------------------------------------------------------------------===// + +AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef strides, + int64_t offset, + MLIRContext *context) { + AffineExpr expr; + unsigned nSymbols = 0; + + // AffineExpr for offset. + // Static case. + if (offset != MemRefType::getDynamicStrideOrOffset()) { + auto cst = getAffineConstantExpr(offset, context); + expr = cst; + } else { + // Dynamic case, new symbol for the offset. + auto sym = getAffineSymbolExpr(nSymbols++, context); + expr = sym; + } + + // AffineExpr for strides. + for (const auto &en : llvm::enumerate(strides)) { + auto dim = en.index(); + auto stride = en.value(); + assert(stride != 0 && "Invalid stride specification"); + auto d = getAffineDimExpr(dim, context); + AffineExpr mult; + // Static case. + if (stride != MemRefType::getDynamicStrideOrOffset()) + mult = getAffineConstantExpr(stride, context); + else + // Dynamic case, new symbol for each new stride. + mult = getAffineSymbolExpr(nSymbols++, context); + expr = expr + d * mult; + } + + return AffineMap::get(strides.size(), nSymbols, expr); +} diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -895,42 +895,6 @@ // Type Utilities //===----------------------------------------------------------------------===// -AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef strides, - int64_t offset, - MLIRContext *context) { - AffineExpr expr; - unsigned nSymbols = 0; - - // AffineExpr for offset. - // Static case. - if (offset != MemRefType::getDynamicStrideOrOffset()) { - auto cst = getAffineConstantExpr(offset, context); - expr = cst; - } else { - // Dynamic case, new symbol for the offset. - auto sym = getAffineSymbolExpr(nSymbols++, context); - expr = sym; - } - - // AffineExpr for strides. - for (const auto &en : llvm::enumerate(strides)) { - auto dim = en.index(); - auto stride = en.value(); - assert(stride != 0 && "Invalid stride specification"); - auto d = getAffineDimExpr(dim, context); - AffineExpr mult; - // Static case. 
- if (stride != MemRefType::getDynamicStrideOrOffset()) - mult = getAffineConstantExpr(stride, context); - else - // Dynamic case, new symbol for each new stride. - mult = getAffineSymbolExpr(nSymbols++, context); - expr = expr + d * mult; - } - - return AffineMap::get(strides.size(), nSymbols, expr); -} - /// Return a version of `t` with identity layout if it can be determined /// statically that the layout is the canonical contiguous strided layout. /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of diff --git a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir --- a/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir +++ b/mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir @@ -40,7 +40,7 @@ // CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-LABEL: func @check_static_return_with_offset // BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr) -> !llvm.ptr { -func.func @check_static_return_with_offset(%static : memref<32x18xf32, offset:7, strides:[22,1]>) -> memref<32x18xf32, offset:7, strides:[22,1]> { +func.func @check_static_return_with_offset(%static : memref<32x18xf32, strided<[22,1], offset: 7>>) -> memref<32x18xf32, strided<[22,1], offset: 7>> { // CHECK: llvm.return %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -58,7 +58,7 @@ // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr - return %static : memref<32x18xf32, offset:7, strides:[22,1]> + return %static : memref<32x18xf32, strided<[22,1], offset: 7>> } // ----- diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -597,22 +597,22 @@ // CHECK-SAME: !spv.array<256 x f32, stride=4> [0])>, StorageBuffer> // CHECK-SAME: !spv.array<64 x f32, stride=4> [0])>, StorageBuffer> // CHECK-SAME: !spv.array<88 x f32, stride=4> [0])>, StorageBuffer> - %arg0: memref<16x4xf32, offset: 0, strides: [4, 1], #spv.storage_class>, // tightly packed; row major - %arg1: memref<16x4xf32, offset: 8, strides: [4, 1], #spv.storage_class>, // offset 8 - %arg2: memref<16x4xf32, offset: 0, strides: [16, 1], #spv.storage_class>, // pad 12 after each row - %arg3: memref<16x4xf32, offset: 0, strides: [1, 16], #spv.storage_class>, // tightly packed; col major - %arg4: memref<16x4xf32, offset: 0, strides: [1, 22], #spv.storage_class>, // pad 4 after each col + %arg0: memref<16x4xf32, strided<[4, 1], offset: 0>, #spv.storage_class>, // tightly packed; row major + %arg1: memref<16x4xf32, strided<[4, 1], offset: 8>, #spv.storage_class>, // offset 8 + %arg2: memref<16x4xf32, strided<[16, 1], offset: 0>, #spv.storage_class>, // pad 12 after each row + %arg3: memref<16x4xf32, strided<[1, 16], offset: 0>, #spv.storage_class>, // tightly packed; col major + %arg4: memref<16x4xf32, strided<[1, 22], offset: 0>, #spv.storage_class>, // pad 4 after each col // CHECK-SAME: !spv.array<64 x f16, stride=2> [0])>, StorageBuffer> // CHECK-SAME: !spv.array<72 x f16, stride=2> [0])>, 
StorageBuffer> // CHECK-SAME: !spv.array<256 x f16, stride=2> [0])>, StorageBuffer> // CHECK-SAME: !spv.array<64 x f16, stride=2> [0])>, StorageBuffer> // CHECK-SAME: !spv.array<88 x f16, stride=2> [0])>, StorageBuffer> - %arg5: memref<16x4xf16, offset: 0, strides: [4, 1], #spv.storage_class>, - %arg6: memref<16x4xf16, offset: 8, strides: [4, 1], #spv.storage_class>, - %arg7: memref<16x4xf16, offset: 0, strides: [16, 1], #spv.storage_class>, - %arg8: memref<16x4xf16, offset: 0, strides: [1, 16], #spv.storage_class>, - %arg9: memref<16x4xf16, offset: 0, strides: [1, 22], #spv.storage_class> + %arg5: memref<16x4xf16, strided<[4, 1], offset: 0>, #spv.storage_class>, + %arg6: memref<16x4xf16, strided<[4, 1], offset: 8>, #spv.storage_class>, + %arg7: memref<16x4xf16, strided<[16, 1], offset: 0>, #spv.storage_class>, + %arg8: memref<16x4xf16, strided<[1, 16], offset: 0>, #spv.storage_class>, + %arg9: memref<16x4xf16, strided<[1, 22], offset: 0>, #spv.storage_class> ) { return } } // end module diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir @@ -436,7 +436,7 @@ %output = memref.reinterpret_cast %input to offset: [%offset], sizes: [%size_0, %size_1], strides: [%stride_0, %stride_1] - : memref<*xf32> to memref + : memref<*xf32> to memref> return } // CHECK-SAME: ([[OFFSETarg:%[a-z,0-9]+]]: index, diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -122,7 +122,7 @@ // CHECK32: %[[ARG0f:[a-zA-Z0-9]*]]: index, // CHECK32: %[[ARG1f:[a-zA-Z0-9]*]]: index, // CHECK32: %[[ARG2f:.*]]: index) -func.func @subview(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { +func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : index, %arg1 : index, %arg2 : index) { // CHECK-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] @@ -170,8 +170,8 @@ // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = memref.subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] : - memref<64x4xf32, offset: 0, strides: [4, 1]> - to memref + memref<64x4xf32, strided<[4, 1], offset: 0>> + to memref> return } @@ -187,7 +187,7 @@ // CHECK32: %[[ARG0f:[a-zA-Z0-9]*]]: index, // CHECK32: %[[ARG1f:[a-zA-Z0-9]*]]: index, // CHECK32: %[[ARG2f:.*]]: index) -func.func @subview_non_zero_addrspace(%0 : memref<64x4xf32, offset: 0, strides: [4, 1], 3>, %arg0 : index, %arg1 : index, %arg2 : index) { +func.func @subview_non_zero_addrspace(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>, 3>, %arg0 : index, %arg1 : index, %arg2 : index) { // CHECK-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] @@ -234,8 +234,8 @@ // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> 
%1 = memref.subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] : - memref<64x4xf32, offset: 0, strides: [4, 1], 3> - to memref + memref<64x4xf32, strided<[4, 1], offset: 0>, 3> + to memref, 3> return } @@ -251,7 +251,7 @@ // CHECK32-SAME: %[[ARG0f:[a-zA-Z0-9]*]]: index // CHECK32-SAME: %[[ARG1f:[a-zA-Z0-9]*]]: index // CHECK32-SAME: %[[ARG2f:[a-zA-Z0-9]*]]: index -func.func @subview_const_size(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { +func.func @subview_const_size(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : index, %arg1 : index, %arg2 : index) { // CHECK-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] @@ -302,8 +302,8 @@ // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = memref.subview %0[%arg0, %arg1][4, 2][%arg0, %arg1] : - memref<64x4xf32, offset: 0, strides: [4, 1]> - to memref<4x2xf32, offset: ?, strides: [?, ?]> + memref<64x4xf32, strided<[4, 1], offset: 0>> + to memref<4x2xf32, strided<[?, ?], offset: ?>> return } @@ -319,7 +319,7 @@ // CHECK32-SAME: %[[ARG0f:[a-zA-Z0-9]*]]: index // CHECK32-SAME: %[[ARG1f:[a-zA-Z0-9]*]]: index // CHECK32-SAME: %[[ARG2f:[a-zA-Z0-9]*]]: index -func.func @subview_const_stride(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { +func.func @subview_const_stride(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : index, %arg1 : index, %arg2 : index) { // CHECK-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] @@ -366,8 +366,8 @@ // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = memref.subview %0[%arg0, %arg1][%arg0, %arg1][1, 2] : - memref<64x4xf32, offset: 0, strides: [4, 1]> - to memref + memref<64x4xf32, strided<[4, 1], offset: 0>> + to memref> return } @@ -375,7 +375,7 @@ // CHECK-LABEL: func @subview_const_stride_and_offset( // CHECK32-LABEL: func @subview_const_stride_and_offset( -func.func @subview_const_stride_and_offset(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>) { +func.func @subview_const_stride_and_offset(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>) { // The last "insertvalue" that populates the memref descriptor from the function arguments. 
// CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast // CHECK32: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast @@ -398,8 +398,8 @@ // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = memref.subview %0[0, 8][62, 3][1, 1] : - memref<64x4xf32, offset: 0, strides: [4, 1]> - to memref<62x3xf32, offset: 8, strides: [4, 1]> + memref<64x4xf32, strided<[4, 1], offset: 0>> + to memref<62x3xf32, strided<[4, 1], offset: 8>> return } @@ -415,7 +415,7 @@ // CHECK32: %[[ARG0f:[a-zA-Z0-9]*]]: index, // CHECK32: %[[ARG1f:[a-zA-Z0-9]*]]: index, // CHECK32: %[[ARG2f:.*]]: index) -func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { +func.func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : index, %arg1 : index, %arg2 : index) { // CHECK32-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] // CHECK32-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] // CHECK32-DAG: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2f]] @@ -444,8 +444,8 @@ // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = memref.subview %0[%arg1, 8][62, %arg2][%arg0, 1] : - memref<64x4xf32, offset: 0, strides: [4, 1]> - to memref<62x?xf32, offset: ?, strides: [?, 1]> + memref<64x4xf32, strided<[4, 1], offset: 0>> + to memref<62x?xf32, strided<[?, 1], offset: ?>> return } @@ -471,7 +471,7 @@ // CHECK: %[[C3_3:.*]] = llvm.mlir.constant(3 : i64) : i64 // CHECK: llvm.insertvalue %[[C3_2]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %[[C3_3]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - %2 = memref.subview %0[2, 0][3, 3][1, 1]: memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]> + %2 = memref.subview %0[2, 0][3, 3][1, 1]: memref<5x3xf32> to memref<3x3xf32, strided<[3, 1], offset: 6>> return } @@ -509,7 +509,7 @@ // CHECK: llvm.insertvalue %[[MUL]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %c0 = arith.constant 1 : index %d0 = memref.dim %0, %c0 : memref<5x?xf32> - %1 = memref.subview %0[2, 0][3, %d0][1, 1]: memref<5x?xf32> to memref<3x?xf32, offset: ?, strides: [?, 1]> + %1 = memref.subview %0[2, 0][3, %d0][1, 1]: memref<5x?xf32> to memref<3x?xf32, strided<[?, 1], offset: ?>> return } @@ -533,7 +533,7 @@ // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 // CHECK: llvm.insertvalue %[[C3]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> - %1 = memref.subview %0[1, 0][1, 3][1, 1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]> + %1 = memref.subview %0[1, 0][1, 3][1, 1]: memref<5x3xf32> to memref<3xf32, strided<[1], offset: 3>> return } @@ -608,8 +608,8 @@ // CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, 
array<3 x i64>)> // CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> -func.func @transpose(%arg0: memref) { - %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref to memref (d2 * s1 + s0 + d0 * s2 + d1)>> +func.func @transpose(%arg0: memref>) { + %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref> to memref (d2 * s1 + s0 + d0 * s2 + d1)>> return } diff --git a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir --- a/mlir/test/Dialect/Affine/memref-stride-calculation.mlir +++ b/mlir/test/Dialect/Affine/memref-stride-calculation.mlir @@ -17,7 +17,7 @@ %11 = memref.alloc() : memref<3x4x5xf32, affine_map<(i, j, k)->(i, j, k)>> // CHECK: MemRefType offset: 0 strides: 20, 5, 1 - %b11 = memref.alloc() : memref<3x4x5xf32, offset: 0, strides: [20, 5, 1]> + %b11 = memref.alloc() : memref<3x4x5xf32, strided<[20, 5, 1], offset: 0>> // CHECK: MemRefType offset: 0 strides: 20, 5, 1 %12 = memref.alloc(%0) : memref<3x4x?xf32, affine_map<(i, j, k)->(i, j, k)>> // CHECK: MemRefType offset: 0 strides: ?, ?, 1 @@ -34,19 +34,19 @@ // CHECK: MemRefType offset: 1 strides: 32, 16, ? %22 = memref.alloc()[%0] : memref<3x4x5xf32, affine_map<(i, j, k)[M]->(32 * i + M * j + 16 * k + 3)>> // CHECK: MemRefType offset: 3 strides: 32, ?, 16 - %b22 = memref.alloc(%0)[%0, %0] : memref<3x4x?xf32, offset: 0, strides: [?, ?, 1]> + %b22 = memref.alloc(%0)[%0, %0] : memref<3x4x?xf32, strided<[?, ?, 1], offset: 0>> // CHECK: MemRefType offset: 0 strides: ?, ?, 1 %23 = memref.alloc(%0)[%0] : memref<3x?x5xf32, affine_map<(i, j, k)[M]->(M * i + 32 * j + 16 * k + 7)>> // CHECK: MemRefType offset: 7 strides: ?, 32, 16 - %b23 = memref.alloc(%0)[%0] : memref<3x?x5xf32, offset: 0, strides: [?, 5, 1]> + %b23 = memref.alloc(%0)[%0] : memref<3x?x5xf32, strided<[?, 5, 1], offset: 0>> // CHECK: MemRefType offset: 0 strides: ?, 5, 1 %24 = memref.alloc(%0)[%0] : memref<3x?x5xf32, affine_map<(i, j, k)[M]->(M * i + 32 * j + 16 * k + M)>> // CHECK: MemRefType offset: ? strides: ?, 32, 16 - %b24 = memref.alloc(%0)[%0, %0] : memref<3x?x5xf32, offset: ?, strides: [?, 32, 16]> + %b24 = memref.alloc(%0)[%0, %0] : memref<3x?x5xf32, strided<[?, 32, 16], offset: ?>> // CHECK: MemRefType offset: ? 
strides: ?, 32, 16 %25 = memref.alloc(%0, %0)[%0, %0] : memref(M * i + N * j + k + 1)>> // CHECK: MemRefType offset: 1 strides: ?, ?, 1 - %b25 = memref.alloc(%0, %0)[%0, %0] : memref + %b25 = memref.alloc(%0, %0)[%0, %0] : memref> // CHECK: MemRefType offset: 1 strides: ?, ?, 1 %26 = memref.alloc(%0)[] : memref(i)>> // CHECK: MemRefType offset: 0 strides: 1 diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir --- a/mlir/test/Dialect/Affine/ops.mlir +++ b/mlir/test/Dialect/Affine/ops.mlir @@ -103,8 +103,8 @@ affine.for %arg4 = 0 to %13 step 264 { %18 = memref.dim %0, %c0 : memref %20 = memref.subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref - to memref - %24 = memref.dim %20, %c0 : memref + to memref> + %24 = memref.dim %20, %c0 : memref> affine.for %arg5 = 0 to %24 step 768 { "foo"() : () -> () } diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir @@ -706,12 +706,12 @@ // CHECK-LABEL: func @subview func.func @subview(%arg0 : index, %arg1 : index, %arg2 : memref) { - %0 = memref.alloc() : memref<64x4xf32, offset: 0, strides: [4, 1]> + %0 = memref.alloc() : memref<64x4xf32, strided<[4, 1], offset: 0>> %1 = memref.subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] : - memref<64x4xf32, offset: 0, strides: [4, 1]> - to memref + memref<64x4xf32, strided<[4, 1], offset: 0>> + to memref> test.copy(%1, %arg2) : - (memref, memref) + (memref>, memref) return } diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir --- a/mlir/test/Dialect/Bufferization/canonicalize.mlir +++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir @@ -46,52 +46,46 @@ // ----- -// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)> -// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)> - // If the memrefs are definitely cast-compatible, canonicalize to // cast. // CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load( -// CHECK-SAME: %[[M:.*]]: memref) -// CHECK-SAME: -> memref { +// CHECK-SAME: %[[M:.*]]: memref>) +// CHECK-SAME: -> memref> { // CHECK-NOT: bufferization.to_tensor // CHECK-NOT: bufferization.to_memref // CHECK: %[[R:.*]] = memref.cast %[[M]] -// CHECK-SAME: memref to memref +// CHECK-SAME: memref> to memref> // CHECK: return %[[R]] func.func @canonicalize_buffer_cast_of_tensor_load( - %arg0: memref) - -> memref + %arg0: memref>) + -> memref> { - %0 = bufferization.to_tensor %arg0 : memref - %1 = bufferization.to_memref %0 : memref - return %1 : memref + %0 = bufferization.to_tensor %arg0 : memref> + %1 = bufferization.to_memref %0 : memref> + return %1 : memref> } // ----- -// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)> - // If the memrefs are potentially cast-compatible, canonicalize to // copy. 
// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_to_copy( func.func @canonicalize_buffer_cast_of_tensor_load_to_copy( - %arg0: memref) - -> memref { - %0 = bufferization.to_tensor %arg0 : memref - %1 = bufferization.to_memref %0 : memref - return %1 : memref + %arg0: memref>) + -> memref> { + %0 = bufferization.to_tensor %arg0 : memref> + %1 = bufferization.to_memref %0 : memref> + return %1 : memref> } -// CHECK-SAME: %[[M:.*]]: memref) -// CHECK-SAME: -> memref { +// CHECK-SAME: %[[M:.*]]: memref>) +// CHECK-SAME: -> memref> { // CHECK-NOT: bufferization.to_tensor // CHECK-NOT: bufferization.to_memref // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[DIM:.*]] = memref.dim %[[M]], %[[C0]] : memref -// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref +// CHECK: %[[DIM:.*]] = memref.dim %[[M]], %[[C0]] : memref> +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref> // CHECK: memref.copy %[[M]], %[[ALLOC]] -// CHECK-SAME: memref to memref +// CHECK-SAME: memref> to memref> // CHECK: return %[[ALLOC]] // ----- diff --git a/mlir/test/Dialect/Builtin/types.mlir b/mlir/test/Dialect/Builtin/types.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Builtin/types.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// CHECK: memref> +func.func private @f1() -> memref> +// CHECK: memref> +func.func private @f2() -> memref> +// CHECK: memref> +func.func private @f3() -> memref> +// CHECK: memref> +func.func private @f4() -> memref> +// CHECK: memref> +func.func private @f5() -> memref> +// CHECK: memref> +func.func private @f6() -> memref> +// CHECK: memref> +func.func private @f7() -> memref> +// CHECK: memref> +func.func private @f8() -> memref> diff --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir --- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir +++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s -test-linalg-greedy-fusion | FileCheck %s -func.func @f1(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { +func.func @f1(%A: memref>, %B: memref>, %C: memref>, %D: memref>, %E: memref>) -> memref> { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index @@ -9,35 +9,35 @@ %c40 = arith.constant 40 : index %c30 = arith.constant 30 : index %c20 = arith.constant 20 : index - %0 = memref.dim %C, %c0 : memref - %1 = memref.dim %C, %c1 : memref - %2 = memref.dim %D, %c1 : memref - linalg.matmul ins(%A, %B: memref, memref) - outs(%C: memref) + %0 = memref.dim %C, %c0 : memref> + %1 = memref.dim %C, %c1 : memref> + %2 = memref.dim %D, %c1 : memref> + linalg.matmul ins(%A, %B: memref>, memref>) + outs(%C: memref>) scf.for %arg5 = %c0 to %0 step %c20 { scf.for %arg6 = %c0 to %2 step %c30 { scf.for %arg7 = %c0 to %1 step %c40 { - %5 = memref.subview %C[%arg5, %arg7][%c20, %c40][%c1, %c1] : memref to memref - %7 = memref.subview %D[%arg7, %arg6][%c40, %c30][%c1, %c1]: memref to memref - %8 = memref.subview %E[%arg5, %arg6][%c20, %c40][%c1, %c1] : memref to memref - %9 = memref.dim %5, %c0 : memref - %10 = memref.dim %5, %c1 : memref - %11 = memref.dim %7, %c1 : memref + %5 = memref.subview %C[%arg5, %arg7][%c20, %c40][%c1, %c1] : memref> to memref> + %7 = memref.subview %D[%arg7, %arg6][%c40, %c30][%c1, %c1]: memref> to memref> + %8 = memref.subview %E[%arg5, %arg6][%c20, %c40][%c1, %c1] : memref> to memref> + %9 = memref.dim %5, %c0 : memref> + %10 = memref.dim %5, %c1 : memref> + %11 = memref.dim %7, %c1 
: memref> scf.for %arg8 = %c0 to %9 step %c2 { scf.for %arg9 = %c0 to %11 step %c3 { scf.for %arg10 = %c0 to %10 step %c4 { - %14 = memref.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref to memref - %16 = memref.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref to memref - %17 = memref.subview %8[%arg8, %arg9][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul ins(%14, %16: memref, memref) - outs(%17: memref) + %14 = memref.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref> to memref> + %16 = memref.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref> to memref> + %17 = memref.subview %8[%arg8, %arg9][%c2, %c3][%c1, %c1] : memref> to memref> + linalg.matmul ins(%14, %16: memref>, memref>) + outs(%17: memref>) } } } } } } - return %E : memref + return %E : memref> } // CHECK-LABEL: func @f1 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -1,41 +1,41 @@ // RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s -func.func @f1(%A: memref, - %B: memref, - %C: memref, - %D: memref, - %E: memref - ) -> memref { +func.func @f1(%A: memref>, + %B: memref>, + %C: memref>, + %D: memref>, + %E: memref> + ) -> memref> { %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %c3 = arith.constant 3 : index %c2 = arith.constant 2 : index %c1 = arith.constant 1 : index - %0 = memref.dim %A, %c0 : memref - %1 = memref.dim %A, %c1 : memref - %2 = memref.dim %B, %c1 : memref - linalg.matmul ins(%A, %B : memref, - memref) - outs(%C : memref) + %0 = memref.dim %A, %c0 : memref> + %1 = memref.dim %A, %c1 : memref> + %2 = memref.dim %B, %c1 : memref> + linalg.matmul ins(%A, %B : memref>, + memref>) + outs(%C : memref>) scf.for %arg5 = %c0 to %0 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { scf.for %arg7 = %c0 to %1 step %c4 { %5 = memref.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref to - memref + memref> to + memref> %7 = memref.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref to - memref + memref> to + memref> %8 = memref.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref to - memref - linalg.matmul ins(%5, %7 : memref, - memref) - outs(%8: memref) + memref> to + memref> + linalg.matmul ins(%5, %7 : memref>, + memref>) + outs(%8: memref>) } } } - return %E : memref + return %E : memref> } // CHECK-LABEL: func @f1 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) @@ -47,49 +47,48 @@ // ----- -// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)> -func.func @f2(%A: memref, - %B: memref, - %C: memref, - %D: memref, - %E: memref - ) -> memref { +func.func @f2(%A: memref>, + %B: memref>, + %C: memref>, + %D: memref>, + %E: memref> + ) -> memref> { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %c3 = arith.constant 3 : index %c2 = arith.constant 2 : index - linalg.matmul ins(%A, %B : memref, - memref) - outs(%C: memref) - %0 = memref.dim %C, %c0 : memref - %1 = memref.dim %C, %c1 : memref - %2 = memref.dim %D, %c1 : memref + linalg.matmul ins(%A, %B : memref>, + memref>) + outs(%C: memref>) + %0 = memref.dim %C, %c0 : memref> + %1 = memref.dim %C, %c1 : memref> + %2 = memref.dim %D, %c1 : memref> scf.for %arg5 = %c0 to %0 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { scf.for %arg7 = %c0 to 
%1 step %c4 { %5 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref to - memref + memref> to + memref> %7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref to - memref + memref> to + memref> %8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref to - memref - linalg.matmul ins(%5, %7 : memref, - memref) - outs(%8 : memref) + memref> to + memref> + linalg.matmul ins(%5, %7 : memref>, + memref>) + outs(%8 : memref>) } } } - return %E : memref + return %E : memref> } // CHECK-LABEL: func @f2 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %c0{{[_0-9]*}} : memref -// CHECK-DAG: %[[C_1:.*]] = memref.dim %[[C]], %c1{{[_0-9]*}} : memref -// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %c1{{[_0-9]*}} : memref +// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %c0{{[_0-9]*}} : memref> +// CHECK-DAG: %[[C_1:.*]] = memref.dim %[[C]], %c1{{[_0-9]*}} : memref> +// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %c1{{[_0-9]*}} : memref> // CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { // CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { @@ -98,52 +97,50 @@ // ----- -// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)> - -func.func @f3(%A: memref, - %B: memref, - %C: memref, - %D: memref, - %E: memref - ) -> memref { +func.func @f3(%A: memref>, + %B: memref>, + %C: memref>, + %D: memref>, + %E: memref> + ) -> memref> { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %c3 = arith.constant 3 : index %c2 = arith.constant 2 : index - linalg.matmul ins(%A, %B : memref, - memref) - outs(%C : memref) - %0 = memref.dim %D, %c0 : memref - %1 = memref.dim %D, %c1 : memref - %2 = memref.dim %C, %c1 : memref + linalg.matmul ins(%A, %B : memref>, + memref>) + outs(%C : memref>) + %0 = memref.dim %D, %c0 : memref> + %1 = memref.dim %D, %c1 : memref> + %2 = memref.dim %C, %c1 : memref> scf.for %arg5 = %c0 to %0 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { scf.for %arg7 = %c0 to %1 step %c4 { %5 = memref.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref to - memref + memref> to + memref> %7 = memref.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref to - memref + memref> to + memref> %8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref to - memref - linalg.matmul ins(%5, %7 : memref, - memref) - outs(%8 : memref) + memref> to + memref> + linalg.matmul ins(%5, %7 : memref>, + memref>) + outs(%8 : memref>) } } } - return %E : memref + return %E : memref> } // CHECK-LABEL: func @f3 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref -// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref -// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1]] : memref +// CHECK: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref> +// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref> +// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1]] : memref> // CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { // CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { @@ -152,55 +149,53 @@ // ----- -// CHECK-DAG: #[[$strided2D:.*]] = 
affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)> - -func.func @f4(%A: memref, - %B: memref, - %C: memref, - %D: memref, - %E: memref - ) -> memref { +func.func @f4(%A: memref>, + %B: memref>, + %C: memref>, + %D: memref>, + %E: memref> + ) -> memref> { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %c3 = arith.constant 3 : index %c2 = arith.constant 2 : index - linalg.matmul ins(%A, %B : memref, - memref) - outs(%C : memref) - linalg.matmul ins(%A, %B : memref, - memref) - outs(%D : memref) - %0 = memref.dim %C, %c0 : memref - %1 = memref.dim %C, %c1 : memref - %2 = memref.dim %D, %c1 : memref + linalg.matmul ins(%A, %B : memref>, + memref>) + outs(%C : memref>) + linalg.matmul ins(%A, %B : memref>, + memref>) + outs(%D : memref>) + %0 = memref.dim %C, %c0 : memref> + %1 = memref.dim %C, %c1 : memref> + %2 = memref.dim %D, %c1 : memref> scf.for %arg5 = %c0 to %0 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { scf.for %arg7 = %c0 to %1 step %c4 { %5 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref to - memref + memref> to + memref> %7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref to - memref + memref> to + memref> %8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref to - memref - linalg.matmul ins(%5, %7 : memref, - memref) - outs(%8 : memref) + memref> to + memref> + linalg.matmul ins(%5, %7 : memref>, + memref>) + outs(%8 : memref>) } } } - return %E : memref + return %E : memref> } // CHECK-LABEL: func @f4 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[C_0:.*]] = memref.dim %[[C]], %[[C0:.*]] : memref -// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1:.*]] : memref -// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref +// CHECK: %[[C_0:.*]] = memref.dim %[[C]], %[[C0:.*]] : memref> +// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1:.*]] : memref> +// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref> // CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { // CHECK: scf.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { @@ -211,46 +206,45 @@ // ----- -// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)> -func.func @f5(%A: memref, - %B: memref, - %C: memref, - %D: memref, - %E: memref - ) -> memref { +func.func @f5(%A: memref>, + %B: memref>, + %C: memref>, + %D: memref>, + %E: memref> + ) -> memref> { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %c3 = arith.constant 3 : index %c2 = arith.constant 2 : index - %0 = memref.dim %B, %c1 : memref - %1 = memref.dim %D, %c0 : memref - %2 = memref.dim %D, %c1 : memref - linalg.matmul ins(%A, %B : memref, - memref) - outs(%C : memref) - linalg.matmul ins(%C, %B : memref, - memref) - outs(%D : memref) + %0 = memref.dim %B, %c1 : memref> + %1 = memref.dim %D, %c0 : memref> + %2 = memref.dim %D, %c1 : memref> + linalg.matmul ins(%A, %B : memref>, + memref>) + outs(%C : memref>) + linalg.matmul ins(%C, %B : memref>, + memref>) + outs(%D : memref>) scf.for %arg5 = %c0 to %1 step %c2 { scf.for %arg6 = %c0 to %0 step %c3 { scf.for %arg7 = %c0 to %2 step %c4 { %5 = memref.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref to - memref + memref> to + memref> %7 = memref.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] : - 
memref to - memref + memref> to + memref> %8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref to - memref - linalg.matmul ins(%5, %7 : memref, - memref) - outs(%8 : memref) + memref> to + memref> + linalg.matmul ins(%5, %7 : memref>, + memref>) + outs(%8 : memref>) } } } - return %E : memref + return %E : memref> } // CHECK-DAG: #[[BOUND_2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)> @@ -260,11 +254,11 @@ // CHECK-SAME: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[A_0:.*]] = memref.dim %[[A]], %[[C0]] : memref -// CHECK-DAG: %[[B_1:.*]] = memref.dim %[[B]], %[[C1]] : memref -// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %[[C0]] : memref -// CHECK-DAG: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref -// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref +// CHECK-DAG: %[[A_0:.*]] = memref.dim %[[A]], %[[C0]] : memref> +// CHECK-DAG: %[[B_1:.*]] = memref.dim %[[B]], %[[C1]] : memref> +// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %[[C0]] : memref> +// CHECK-DAG: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref> +// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref> // CHECK-DAG: %[[B_00:.*]] = memref.subview %[[B]][0, 0]{{.*}} // CHECK: scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} { // CHECK: %[[BOUND_2_C0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[C_0]]] @@ -290,48 +284,48 @@ #map1 = affine_map<(d0) -> (d0 + 4)> #map2 = affine_map<(d0) -> (d0 + 3)> -func.func @f6(%A: memref, - %B: memref, - %C: memref, - %D: memref, - %E: memref - ) -> memref { +func.func @f6(%A: memref>, + %B: memref>, + %C: memref>, + %D: memref>, + %E: memref> + ) -> memref> { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %c3 = arith.constant 3 : index %c2 = arith.constant 2 : index - %0 = memref.dim %C, %c1 : memref - linalg.matmul ins(%A, %B : memref, - memref) - outs(%C : memref) - linalg.matmul ins(%A, %C : memref, - memref) - outs(%E : memref) - %1 = memref.dim %C, %c0 : memref - %2 = memref.dim %D, %c1 : memref + %0 = memref.dim %C, %c1 : memref> + linalg.matmul ins(%A, %B : memref>, + memref>) + outs(%C : memref>) + linalg.matmul ins(%A, %C : memref>, + memref>) + outs(%E : memref>) + %1 = memref.dim %C, %c0 : memref> + %2 = memref.dim %D, %c1 : memref> scf.for %arg5 = %c0 to %1 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { scf.for %arg7 = %c0 to %0 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) %5 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref to - memref + memref> to + memref> %6 = affine.apply #map2(%arg6) %7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref to - memref + memref> to + memref> %8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref to - memref - linalg.matmul ins(%5, %7 : memref, - memref) - outs(%8 : memref) + memref> to + memref> + linalg.matmul ins(%5, %7 : memref>, + memref>) + outs(%8 : memref>) } } } - return %E : memref + return %E : memref> } // CHECK-LABEL: func @f6 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) @@ -345,43 +339,43 @@ // ----- -func.func @f7(%A: memref, - %B: memref, - %C: memref, - %D: memref, - %E: memref - ) -> memref { +func.func @f7(%A: memref>, + %B: memref>, + %C: memref>, + %D: memref>, + %E: memref> + ) -> memref> { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : 
index %c4 = arith.constant 4 : index %c3 = arith.constant 3 : index %c2 = arith.constant 2 : index - %0 = memref.dim %A, %c0 : memref - %1 = memref.dim %A, %c1 : memref - %2 = memref.dim %C, %c1 : memref - %3 = memref.dim %C, %c0 : memref - %4 = memref.dim %D, %c1 : memref - linalg.matmul ins(%A, %C : memref, - memref) - outs(%E : memref) - linalg.matmul ins(%A, %B : memref, - memref) - outs(%C : memref) + %0 = memref.dim %A, %c0 : memref> + %1 = memref.dim %A, %c1 : memref> + %2 = memref.dim %C, %c1 : memref> + %3 = memref.dim %C, %c0 : memref> + %4 = memref.dim %D, %c1 : memref> + linalg.matmul ins(%A, %C : memref>, + memref>) + outs(%E : memref>) + linalg.matmul ins(%A, %B : memref>, + memref>) + outs(%C : memref>) scf.for %arg5 = %c0 to %0 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { scf.for %arg7 = %c0 to %1 step %c4 { %7 = memref.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref to - memref + memref> to + memref> %9 = memref.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref to - memref + memref> to + memref> %10 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref to - memref - linalg.matmul ins(%7, %9 : memref, - memref) - outs(%10 : memref) + memref> to + memref> + linalg.matmul ins(%7, %9 : memref>, + memref>) + outs(%10 : memref>) } } } @@ -389,31 +383,31 @@ scf.for %arg6 = %c0 to %4 step %c3 { scf.for %arg7 = %c0 to %2 step %c4 { %7 = memref.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref to - memref + memref> to + memref> %9 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref to - memref + memref> to + memref> %10 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref to - memref - linalg.matmul ins(%7, %9 : memref, - memref) - outs(%10 : memref) + memref> to + memref> + linalg.matmul ins(%7, %9 : memref>, + memref>) + outs(%10 : memref>) } } } - return %E : memref + return %E : memref> } // CHECK-LABEL: func @f7 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[A_0:.*]] = memref.dim %[[A]], %[[C0:.*]] : memref -// CHECK: %[[A_1:.*]] = memref.dim %[[A]], %[[C1:.*]] : memref -// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1:.*]] : memref -// CHECK: %[[C_0:.*]] = memref.dim %[[C]], %[[C0:.*]] : memref -// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref +// CHECK: %[[A_0:.*]] = memref.dim %[[A]], %[[C0:.*]] : memref> +// CHECK: %[[A_1:.*]] = memref.dim %[[A]], %[[C1:.*]] : memref> +// CHECK: %[[C_1:.*]] = memref.dim %[[C]], %[[C1:.*]] : memref> +// CHECK: %[[C_0:.*]] = memref.dim %[[C]], %[[C0:.*]] : memref> +// CHECK: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref> // CHECK: linalg.matmul ins(%[[A]], %[[C]]{{.*}} outs(%[[E]] // CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} { // CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { @@ -432,48 +426,48 @@ #map1 = affine_map<(d0) -> (d0 + 4)> #map2 = affine_map<(d0) -> (d0 + 3)> -func.func @f8(%A: memref, - %B: memref, - %C: memref, - %D: memref, - %E: memref - ) -> memref { +func.func @f8(%A: memref>, + %B: memref>, + %C: memref>, + %D: memref>, + %E: memref> + ) -> memref> { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %c3 = arith.constant 3 : index %c2 = arith.constant 2 : index - %0 = memref.dim %A, %c0 : memref - %1 = memref.dim %A, %c1 : memref - linalg.matmul ins(%A, %C : memref, - memref) - outs(%D : memref) - 
linalg.matmul ins(%A, %B : memref, - memref) - outs(%C : memref) - %2 = memref.dim %D, %c1 : memref + %0 = memref.dim %A, %c0 : memref> + %1 = memref.dim %A, %c1 : memref> + linalg.matmul ins(%A, %C : memref>, + memref>) + outs(%D : memref>) + linalg.matmul ins(%A, %B : memref>, + memref>) + outs(%C : memref>) + %2 = memref.dim %D, %c1 : memref> scf.for %arg5 = %c0 to %0 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { scf.for %arg7 = %c0 to %1 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) %5 = memref.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : - memref to - memref + memref> to + memref> %6 = affine.apply #map2(%arg6) %7 = memref.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : - memref to - memref + memref> to + memref> %8 = memref.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : - memref to - memref - linalg.matmul ins(%5, %7 : memref, - memref) - outs(%8 : memref) + memref> to + memref> + linalg.matmul ins(%5, %7 : memref>, + memref>) + outs(%8 : memref>) } } } - return %E : memref + return %E : memref> } // CHECK-LABEL: func @f8 // CHECK: (%[[A:.*]]: memref{{.*}}, %[[B:.*]]: memref{{.*}}, %[[C:.*]]: memref{{.*}}, %[[D:.*]]: memref{{.*}}, %[[E:.*]]: memref{{.*}}) @@ -492,39 +486,39 @@ indexing_maps = [#id_2d, #id_2d, #id_2d], iterator_types = ["parallel", "parallel"] } -func.func @pointwise(%A: memref, - %B: memref, - %C: memref, - %D: memref) { +func.func @pointwise(%A: memref>, + %B: memref>, + %C: memref>, + %D: memref>) { %c1 = arith.constant 1 : index %c0 = arith.constant 0 : index %c3 = arith.constant 3 : index %c2 = arith.constant 2 : index linalg.generic #pointwise_2d_trait - ins(%A, %A: memref, - memref) - outs(%B : memref) { + ins(%A, %A: memref>, + memref>) + outs(%B : memref>) { ^bb0(%E: f32, %arg5: f32, %arg6: f32): %2 = arith.addf %E, %arg5 : f32 linalg.yield %2 : f32 } - %0 = memref.dim %B, %c0 : memref - %1 = memref.dim %B, %c1 : memref + %0 = memref.dim %B, %c0 : memref> + %1 = memref.dim %B, %c1 : memref> scf.for %arg4 = %c0 to %0 step %c2 { scf.for %arg5 = %c0 to %1 step %c3 { %4 = memref.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] : - memref to - memref + memref> to + memref> %5 = memref.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] : - memref to - memref + memref> to + memref> %6 = memref.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : - memref to - memref + memref> to + memref> linalg.generic #pointwise_2d_trait - ins(%4, %5: memref, - memref) - outs(%6 : memref) { + ins(%4, %5: memref>, + memref>) + outs(%6 : memref>) { ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): %7 = arith.mulf %arg6, %arg7 : f32 linalg.yield %7 : f32 @@ -572,17 +566,17 @@ scf.for %arg5 = %c0 to %1 step %c3 { %4 = memref.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to - memref + memref> %5 = memref.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to - memref + memref> %6 = memref.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to - memref + memref> linalg.generic #pointwise_2d_trait - ins(%4, %5: memref, - memref) - outs(%6 : memref) { + ins(%4, %5: memref>, + memref>) + outs(%6 : memref>) { ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): %7 = arith.mulf %arg6, %arg7 : f32 linalg.yield %7 : f32 @@ -719,29 +713,29 @@ %c3 = arith.constant 3 : index %c4 = arith.constant 4 : index - %A = memref.alloca(%dim, %dim)[%s0, %s1] : memref - %B = memref.alloca(%dim, %dim)[%s0, %s1] : memref - %C = memref.alloc(%dim, %dim)[%s0, %s1] : memref + %A = memref.alloca(%dim, %dim)[%s0, %s1] : memref> + %B = memref.alloca(%dim, %dim)[%s0, %s1] : memref> + %C = memref.alloc(%dim, 
%dim)[%s0, %s1] : memref> - linalg.matmul ins(%A, %B : memref, - memref) - outs(%C : memref) + linalg.matmul ins(%A, %B : memref>, + memref>) + outs(%C : memref>) scf.for %i = %c0 to %dim step %c2 { scf.for %j = %c0 to %dim step %c3 { scf.for %k = %c0 to %dim step %c4 { %0 = memref.subview %A[%i, %k][%c2, %c4][%c1, %c1] : - memref to - memref + memref> to + memref> %1 = memref.subview %B[%k, %j][%c4, %c3][%c1, %c1] : - memref to - memref + memref> to + memref> %2 = memref.subview %C[%i, %j][%c2, %c3][%c1, %c1] : - memref to - memref - linalg.matmul ins(%0, %1 : memref, - memref) - outs(%2 : memref) + memref> to + memref> + linalg.matmul ins(%0, %1 : memref>, + memref>) + outs(%2 : memref>) } } } diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -4,15 +4,9 @@ // Test that we can lower all the way to LLVM without crashing, don't check results here. // RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -o=/dev/null 2>&1 -// CHECK-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> -// CHECK-DAG: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)> +// CHECK: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)> -// CHECKPARALLEL-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECKPARALLEL-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECKPARALLEL-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> -// CHECKPARALLEL-DAG: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)> +// CHECKPARALLEL: #[[$stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)> func.func @matmul(%arg0: memref, %M: index, %N: index, %K: index) { %c0 = arith.constant 0 : index @@ -163,47 +157,47 @@ // CHECK-NEXT: store %[[res]], {{.*}} : memref -func.func @dot_view(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.dot ins(%arg0, %arg1 : memref, - memref) +func.func @dot_view(%arg0: memref>, %arg1: memref>, %arg2: memref) { + linalg.dot ins(%arg0, %arg1 : memref>, + memref>) outs(%arg2: memref) return } // CHECK-LABEL: func @dot_view( -// CHECK: %{{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref) { -// CHECK: %[[K:.*]] = memref.dim %arg0, %c0 : memref +// CHECK: %{{.*}}: memref>, %{{.*}}: memref>, %{{.*}}: memref) { +// CHECK: %[[K:.*]] = memref.dim %arg0, %c0 : memref> // CHECK: scf.for {{.*}} to %[[K]] -// CHECK-DAG: %[[a:.*]] = memref.load %arg0[%{{.*}}] : memref -// CHECK-DAG: %[[b:.*]] = memref.load %{{.*}}[%{{.*}}] : memref +// CHECK-DAG: %[[a:.*]] = memref.load %arg0[%{{.*}}] : memref> +// CHECK-DAG: %[[b:.*]] = memref.load %{{.*}}[%{{.*}}] : memref> // CHECK-DAG: %[[inc:.*]] = arith.mulf %[[a]], %[[b]] : f32 // CHECK-DAG: %[[c:.*]] = memref.load %{{.*}}[] : memref // CHECK-DAG: %[[res:.*]] = arith.addf %[[c]], %[[inc]] : f32 // CHECK: store %[[res]], %{{.*}}[] : memref // CHECKPARALLEL-LABEL: func @dot_view( -// CHECKPARALLEL: %{{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref) { -// CHECKPARALLEL: %[[K:.*]] = memref.dim %arg0, %c0 : memref +// CHECKPARALLEL: %{{.*}}: memref>, %{{.*}}: memref>, %{{.*}}: memref) { +// CHECKPARALLEL: %[[K:.*]] = memref.dim %arg0, %c0 : memref> // CHECKPARALLEL: scf.for {{.*}} to %[[K]] -// CHECKPARALLEL-DAG: %[[a:.*]] = 
memref.load %arg0[%{{.*}}] : memref -// CHECKPARALLEL-DAG: %[[b:.*]] = memref.load %{{.*}}[%{{.*}}] : memref +// CHECKPARALLEL-DAG: %[[a:.*]] = memref.load %arg0[%{{.*}}] : memref> +// CHECKPARALLEL-DAG: %[[b:.*]] = memref.load %{{.*}}[%{{.*}}] : memref> // CHECKPARALLEL-DAG: %[[inc:.*]] = arith.mulf %[[a]], %[[b]] : f32 // CHECKPARALLEL-DAG: %[[c:.*]] = memref.load %{{.*}}[] : memref // CHECKPARALLEL-DAG: %[[res:.*]] = arith.addf %[[c]], %[[inc]] : f32 // CHECKPARALLEL: store %[[res]], %{{.*}}[] : memref -func.func @fill_view(%arg0: memref, %arg1: f32) { - linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) +func.func @fill_view(%arg0: memref>, %arg1: f32) { + linalg.fill ins(%arg1 : f32) outs(%arg0 : memref>) return } // CHECK-LABEL: func @fill_view( -// CHECK: %{{.*}}: memref, %{{.*}}: f32) { +// CHECK: %{{.*}}: memref>, %{{.*}}: f32) { // CHECK: scf.for {{.*}} to %{{.*}} -// CHECK: store %{{.*}}, %{{.*}}[%{{.*}}] : memref +// CHECK: store %{{.*}}, %{{.*}}[%{{.*}}] : memref> // CHECKPARALLEL-LABEL: func @fill_view( -// CHECKPARALLEL: %{{.*}}: memref, %{{.*}}: f32) { +// CHECKPARALLEL: %{{.*}}: memref>, %{{.*}}: f32) { // CHECKPARALLEL: scf.parallel (%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) { -// CHECKPARALLEL: store %{{.*}}, %{{.*}}[%{{.*}}] : memref +// CHECKPARALLEL: store %{{.*}}, %{{.*}}[%{{.*}}] : memref> func.func @fill_view0(%arg0: memref, %arg1: f32) { linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) @@ -215,44 +209,44 @@ // CHECKPARALLEL-LABEL: func @fill_view0(%{{.*}}: memref, %{{.*}}: f32) { // CHECKPARALLEL: store %{{.*}}, %{{.*}}[] : memref -func.func @fill_view3(%arg0: memref, %arg1: f32) { - linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) +func.func @fill_view3(%arg0: memref>, %arg1: f32) { + linalg.fill ins(%arg1 : f32) outs(%arg0 : memref>) return } // CHECK-LABEL: func @fill_view3( -// CHECK: %{{.*}}: memref, %{{.*}}: f32) { +// CHECK: %{{.*}}: memref>, %{{.*}}: f32) { // CHECK: scf.for {{.*}} to %{{.*}} // CHECK: scf.for {{.*}} to %{{.*}} // CHECK: scf.for {{.*}} to %{{.*}} -// CHECK: store %{{.*}}, {{.*}} : memref +// CHECK: store %{{.*}}, {{.*}} : memref> // CHECKPARALLEL-LABEL: func @fill_view3( -// CHECKPARALLEL: %{{.*}}: memref, %{{.*}}: f32) { +// CHECKPARALLEL: %{{.*}}: memref>, %{{.*}}: f32) { // CHECKPARALLEL: scf.parallel (%{{.*}}, %{{.*}}, %{{.*}}) = (%{{.*}}, %{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}}, %{{.*}}) { -// CHECKPARALLEL: store %{{.*}}, {{.*}} : memref +// CHECKPARALLEL: store %{{.*}}, {{.*}} : memref> -func.func @copy_view(%arg0: memref, %arg1: memref) { +func.func @copy_view(%arg0: memref>, %arg1: memref>) { linalg.generic { iterator_types = ["parallel"], indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)>] } - ins(%arg0: memref) - outs(%arg1: memref) { + ins(%arg0: memref>) + outs(%arg1: memref>) { ^bb0(%a: f32, %b: f32): linalg.yield %a : f32 } return } // CHECK-LABEL: func @copy_view( -// CHECK: %{{.*}}: memref, %{{.*}}: memref) { +// CHECK: %{{.*}}: memref>, %{{.*}}: memref>) { // CHECK: scf.for {{.*}} to %{{.*}} -// CHECK: %[[L:.*]] = memref.load %{{.*}}[%{{.*}}] : memref -// CHECK: store %[[L]], %{{.*}}[%{{.*}}] : memref +// CHECK: %[[L:.*]] = memref.load %{{.*}}[%{{.*}}] : memref> +// CHECK: store %[[L]], %{{.*}}[%{{.*}}] : memref> // CHECKPARALLEL-LABEL: func @copy_view( -// CHECKPARALLEL: %{{.*}}: memref, %{{.*}}: memref) { +// CHECKPARALLEL: %{{.*}}: memref>, %{{.*}}: memref>) { // CHECKPARALLEL: scf.parallel (%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) { -// CHECKPARALLEL: 
%[[L:.*]] = memref.load %{{.*}}[%{{.*}}] : memref -// CHECKPARALLEL: store %[[L]], %{{.*}}[%{{.*}}] : memref +// CHECKPARALLEL: %[[L:.*]] = memref.load %{{.*}}[%{{.*}}] : memref> +// CHECKPARALLEL: store %[[L]], %{{.*}}[%{{.*}}] : memref> #accesses = [ affine_map<(i, j, k) -> (i, j)>, @@ -267,11 +261,11 @@ library_call = "some_external_function_name_2", doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))" } -func.func @generic_region(%arg0: memref, %arg1: memref, %arg2: memref) { +func.func @generic_region(%arg0: memref>, %arg1: memref>, %arg2: memref>) { linalg.generic #trait2 - ins(%arg0: memref) - outs(%arg1, %arg2 : memref, - memref) { + ins(%arg0: memref>) + outs(%arg1, %arg2 : memref>, + memref>) { ^bb0(%a: f32, %b: f32, %c: f32): %d = arith.mulf %a, %b : f32 %e = arith.addf %c, %d : f32 @@ -283,23 +277,23 @@ // CHECK: scf.for %[[i:.*]] = {{.*}} // CHECK: scf.for %[[j:.*]] = {{.*}} // CHECK: scf.for %[[k:.*]] = {{.*}} -// CHECK: %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]] : memref -// CHECK: %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref -// CHECK: %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref +// CHECK: %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]] : memref> +// CHECK: %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref> +// CHECK: %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref> // CHECK: %[[d:.*]] = arith.mulf %[[a]], %[[b]] : f32 // CHECK: %[[e:.*]] = arith.addf %[[c]], %[[d]] : f32 -// CHECK: store %[[d]], %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref -// CHECK: store %[[e]], %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref +// CHECK: store %[[d]], %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref> +// CHECK: store %[[e]], %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref> // CHECKPARALLEL-LABEL: @generic_region // CHECKPARALLEL: scf.parallel (%[[i:[a-zA-Z0-9_]*]], %[[j:[a-zA-Z0-9_]*]], %[[k:[a-zA-Z0-9_]*]]) -// CHECKPARALLEL: %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]] : memref -// CHECKPARALLEL: %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref -// CHECKPARALLEL: %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref +// CHECKPARALLEL: %[[a:.*]] = memref.load %{{.*}}[%[[i]], %[[j]]] : memref> +// CHECKPARALLEL: %[[b:.*]] = memref.load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref> +// CHECKPARALLEL: %[[c:.*]] = memref.load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref> // CHECKPARALLEL: %[[d:.*]] = arith.mulf %[[a]], %[[b]] : f32 // CHECKPARALLEL: %[[e:.*]] = arith.addf %[[c]], %[[d]] : f32 -// CHECKPARALLEL: store %[[d]], %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref -// CHECKPARALLEL: store %[[e]], %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref +// CHECKPARALLEL: store %[[d]], %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref> +// CHECKPARALLEL: store %[[e]], %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref> #trait4 = { args_in = 1, @@ -310,13 +304,13 @@ doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))" } func.func @generic_index_region( - %arg0: memref, - %arg1: memref, - %arg2: memref) { + %arg0: memref>, + %arg1: memref>, + %arg2: memref>) { linalg.generic #trait4 - ins(%arg0 : memref) - outs(%arg1, %arg2 : memref, - memref) { + ins(%arg0 : memref>) + outs(%arg1, %arg2 : memref>, + memref>) { ^bb0(%a: f32, %b: f32, %c: f32): %i = linalg.index 0 : index %j = linalg.index 1 : index @@ -855,14 +849,14 @@ %arg0 : memref, %arg1 : memref, %arg2 : index, %arg3 : index, %arg4 : index) { %0 = memref.subview %arg0[%arg2] [%arg3] [1] - : memref to memref + : memref to memref> %1 = memref.subview %arg1[0, %arg4] [1, 
%arg3] [1, 1] - : memref to memref + : memref to memref> linalg.generic { iterator_types = ["parallel"], indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>]} - ins(%0: memref) - outs(%1: memref) { + ins(%0: memref>) + outs(%1: memref>) { ^bb0(%a: i32, %b: i32): linalg.yield %a : i32 } diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -21,13 +21,13 @@ scf.for %arg4 = %c0 to %6 step %c2 { scf.for %arg5 = %c0 to %8 step %c3 { scf.for %arg6 = %c0 to %7 step %c4 { - %11 = memref.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref - %14 = memref.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref - %17 = memref.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref + %11 = memref.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref> + %14 = memref.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref> + %17 = memref.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref> linalg.matmul - ins(%11, %14: memref, - memref) - outs(%17: memref) + ins(%11, %14: memref>, + memref>) + outs(%17: memref>) } } } @@ -54,15 +54,15 @@ // CHECK: %[[fullC:.*]] = memref.view %[[tmpC]][{{.*}}][{{.*}}] : memref<24xi8> to memref // CHECK: %[[partialC:.*]] = memref.subview %[[fullC]]{{.*}} : memref to memref -// CHECK: memref.copy %[[vA]], %[[partialA]] : memref to memref -// CHECK: memref.copy %[[vB]], %[[partialB]] : memref to memref -// CHECK: memref.copy %[[vC]], %[[partialC]] : memref to memref +// CHECK: memref.copy %[[vA]], %[[partialA]] : memref> to memref +// CHECK: memref.copy %[[vB]], %[[partialB]] : memref> to memref +// CHECK: memref.copy %[[vC]], %[[partialC]] : memref> to memref // // CHECK: linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]] // // CHECK: memref.copy %[[partialC]], %[[vC]] : // CHECK: memref to -// CHECK: memref +// CHECK: memref> // // CHECK-NOT: memref.dealloc %[[tmpA]] : memref<32xi8> // CHECK-NOT: memref.dealloc %[[tmpB]] : memref<48xi8> @@ -94,13 +94,13 @@ scf.for %arg4 = %c0 to %6 step %c2 { scf.for %arg5 = %c0 to %8 step %c3 { scf.for %arg6 = %c0 to %7 step %c4 { - %11 = memref.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref - %14 = memref.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref - %17 = memref.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref + %11 = memref.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref> + %14 = memref.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref> + %17 = memref.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref> linalg.matmul - ins(%11, %14: memref, - memref) - outs(%17: memref) + ins(%11, %14: memref>, + memref>) + outs(%17: memref>) } } } @@ -127,15 +127,15 @@ // CHECK: %[[fullC_f64:.*]] = memref.view %[[tmpC_f64]][{{.*}}][{{.*}}] : memref<48xi8> to memref // CHECK: %[[partialC_f64:.*]] = memref.subview %[[fullC_f64]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref to memref -// CHECK: memref.copy %[[vA_f64]], %[[partialA_f64]] : memref to memref -// CHECK: memref.copy %[[vB_f64]], %[[partialB_f64]] : memref to memref -// CHECK: memref.copy %[[vC_f64]], %[[partialC_f64]] : memref to memref +// CHECK: memref.copy %[[vA_f64]], %[[partialA_f64]] : memref> to memref +// CHECK: memref.copy %[[vB_f64]], %[[partialB_f64]] : memref> to memref +// CHECK: memref.copy %[[vC_f64]], %[[partialC_f64]] : memref> to memref // // CHECK: linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} 
outs(%[[partialC_f64]] // // CHECK: memref.copy %[[partialC_f64]], %[[vC_f64]] : // CHECK: memref to -// CHECK: memref +// CHECK: memref> // // CHECK: memref.dealloc %[[tmpA_f64]] : memref<64xi8> // CHECK: memref.dealloc %[[tmpB_f64]] : memref<96xi8> diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -6,10 +6,7 @@ // Test that we can lower all the way to LLVM without crashing, don't check results here. // DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1 -// CHECK-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> -// CHECK-DAG: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)> +// CHECK: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)> func.func @views(%arg0: index) { %c0 = arith.constant 0 : index @@ -31,65 +28,65 @@ // ----- -func.func @ops(%arg0: memref, - %arg1: memref, - %arg2: memref, +func.func @ops(%arg0: memref>, + %arg1: memref>, + %arg2: memref>, %arg3: memref) { - linalg.matmul ins(%arg0, %arg0 : memref, - memref) - outs(%arg0 : memref) - linalg.matvec ins(%arg0, %arg1: memref, - memref) - outs(%arg2: memref) - linalg.dot ins(%arg1, %arg2: memref, - memref) + linalg.matmul ins(%arg0, %arg0 : memref>, + memref>) + outs(%arg0 : memref>) + linalg.matvec ins(%arg0, %arg1: memref>, + memref>) + outs(%arg2: memref>) + linalg.dot ins(%arg1, %arg2: memref>, + memref>) outs(%arg3: memref) return } // CHECK-LABEL: func @ops(% // CHECK: linalg.matmul -// CHECK-SAME: ins(%{{.*}}, %{{.*}} : memref, -// CHECK-SAME: memref) -// CHECK-SAME: outs(%{{.*}} : memref) +// CHECK-SAME: ins(%{{.*}}, %{{.*}} : memref>, +// CHECK-SAME: memref>) +// CHECK-SAME: outs(%{{.*}} : memref>) // CHECK: linalg.matvec -// CHECK-SAME: ins(%{{.*}}, %{{.*}}: memref, -// CHECK-SAME: memref) -// CHECK-SAME: outs(%{{.*}}: memref) +// CHECK-SAME: ins(%{{.*}}, %{{.*}}: memref>, +// CHECK-SAME: memref>) +// CHECK-SAME: outs(%{{.*}}: memref>) // CHECK: linalg.dot -// CHECK-SAME: ins(%{{.*}}, %{{.*}}: memref, -// CHECK-SAME: memref) +// CHECK-SAME: ins(%{{.*}}, %{{.*}}: memref>, +// CHECK-SAME: memref>) // CHECK-SAME: outs(%{{.*}}: memref) // ----- -func.func @fill_view(%arg0: memref, %arg1: f32) { - linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) +func.func @fill_view(%arg0: memref>, %arg1: f32) { + linalg.fill ins(%arg1 : f32) outs(%arg0 : memref>) return } // CHECK-LABEL: func @fill_view( -// CHECK: %{{.*}}: memref, %{{.*}}: f32) { -// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref) +// CHECK: %{{.*}}: memref>, %{{.*}}: f32) { +// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref>) // ----- -func.func @transpose(%arg0: memref) { - %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref to memref (d2 * s1 + s0 + d1 * s2 + d0)>> +func.func @transpose(%arg0: memref>) { + %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref> to memref (d2 * s1 + s0 + d1 * s2 + d0)>> return } // CHECK-LABEL: func @transpose // CHECK: memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : -// CHECK-SAME: memref to memref +// CHECK-SAME: memref> to memref // ----- -func.func @fill_view3(%arg0: memref, %arg1: f32) { - linalg.fill ins(%arg1 : f32) 
outs(%arg0 : memref) +func.func @fill_view3(%arg0: memref>, %arg1: f32) { + linalg.fill ins(%arg1 : f32) outs(%arg0 : memref>) return } // CHECK-LABEL: func @fill_view3( -// CHECK: %{{.*}}: memref, %{{.*}}: f32) { -// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref) +// CHECK: %{{.*}}: memref>, %{{.*}}: f32) { +// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref>) // ----- @@ -105,12 +102,12 @@ library_call = "some_external_function_name_1" } -func.func @generic(%arg0: memref, offset: ?, strides: [?, 1]>, - %arg1: memref) { +func.func @generic(%arg0: memref, strided<[?, 1], offset: ?>>, + %arg1: memref>) { %cst = arith.constant 0.0 : f32 linalg.generic #trait_0 - ins(%arg0, %cst : memref, offset: ?, strides: [?, 1]>, f32) - outs(%arg1 : memref) + ins(%arg0, %cst : memref, strided<[?, 1], offset: ?>>, f32) + outs(%arg1 : memref>) attrs = {foo = 1} { ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) : linalg.yield %1 : f32 @@ -122,16 +119,16 @@ // CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}, #{{[0-9a-z]*}}], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"], // CHECK-SAME: library_call = "some_external_function_name_1"} -// CHECK-SAME: ins({{.*}}, {{.*}} : memref, #[[$strided2D]]>, f32) -// CHECK-SAME: outs({{.*}} : memref) +// CHECK-SAME: ins({{.*}}, {{.*}} : memref, strided<[?, 1], offset: ?>>, f32) +// CHECK-SAME: outs({{.*}} : memref>) // CHECK-SAME: {foo = 1 : i64} func.func @generic_with_tensor_input(%arg0: tensor>, - %arg1: memref) { + %arg1: memref>) { %cst = arith.constant 0.0 : f32 linalg.generic #trait_0 ins(%arg0, %cst : tensor>, f32) - outs(%arg1 : memref) + outs(%arg1 : memref>) attrs = {foo = 1} { ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) : linalg.yield %1 : f32 @@ -143,7 +140,7 @@ // CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], // CHECK-SAME: library_call = "some_external_function_name_1"} // CHECK-SAME: ins({{.*}}, {{.*}} : tensor>, f32) -// CHECK-SAME: outs({{.*}} : memref) +// CHECK-SAME: outs({{.*}} : memref>) // CHECK-SAME: {foo = 1 : i64} // ----- @@ -272,11 +269,11 @@ library_call = "some_external_function_name_2" } -func.func @generic_region(%arg0: memref, offset: ?, strides: [?, 1]>, - %arg1: memref) { +func.func @generic_region(%arg0: memref, strided<[?, 1], offset: ?>>, + %arg1: memref>) { linalg.generic #trait_3 - ins(%arg0 : memref, offset: ?, strides: [?, 1]>) - outs(%arg1 : memref) + ins(%arg0 : memref, strided<[?, 1], offset: ?>>) + outs(%arg1 : memref>) attrs = {foo = 1} { ^bb(%a: vector<3x4xi4>, %b: f32) : %0 = linalg.index 0 : index @@ -291,8 +288,8 @@ // CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}], // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"], // CHECK-SAME: library_call = "some_external_function_name_2" -// CHECK-SAME: ins({{.*}} : memref, #[[$strided2D]]>) -// CHECK-SAME: outs({{.*}} : memref) +// CHECK-SAME: ins({{.*}} : memref, strided<[?, 1], offset: ?>>) +// CHECK-SAME: outs({{.*}} : memref>) // CHECK-SAME: attrs = {foo = 1 : i64} { // CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // CHECK: %{{.*}} = linalg.index 0 : index diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir --- a/mlir/test/Dialect/Linalg/standard.mlir +++ b/mlir/test/Dialect/Linalg/standard.mlir @@ -1,25 +1,24 @@ // RUN: mlir-opt %s -convert-linalg-to-std | FileCheck %s -// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK-DAG: #[[$map6:.*]] = affine_map<(d0)[s0, s1] 
-> (d0 * s1 + s0)> // CHECK-DAG: #[[$map7:.*]] = affine_map<()[s0] -> (s0)> -func.func @dot(%arg0: memref, - %arg1: memref, +func.func @dot(%arg0: memref>, + %arg1: memref>, %arg2: memref) { - linalg.dot ins(%arg0, %arg1: memref, - memref) + linalg.dot ins(%arg0, %arg1: memref>, + memref>) outs(%arg2: memref) return } // CHECK-LABEL: func @dot( -// CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref, -// CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref, +// CHECK-SAME: %[[arg0:[a-zA-z0-9]*]]: memref>, +// CHECK-SAME: %[[arg1:[a-zA-z0-9]*]]: memref>, // CHECK-SAME: %[[arg2:[a-zA-z0-9]*]]: memref) { // CHECK: %[[o0:.*]] = memref.cast %[[arg0]] : -// CHECK-SAME: memref to memref +// CHECK-SAME: memref> to memref // CHECK: %[[o1:.*]] = memref.cast %[[arg1]] : -// CHECK-SAME: memref to memref +// CHECK-SAME: memref> to memref // CHECK: %[[o2:.*]] = memref.cast %[[arg2]] : // CHECK-SAME: memref to memref // CHECK: call @linalg_dot_viewsxf32_viewsxf32_viewf32( diff --git a/mlir/test/Dialect/Linalg/tile-parallel.mlir b/mlir/test/Dialect/Linalg/tile-parallel.mlir --- a/mlir/test/Dialect/Linalg/tile-parallel.mlir +++ b/mlir/test/Dialect/Linalg/tile-parallel.mlir @@ -11,13 +11,13 @@ iterator_types = ["parallel", "parallel"] } -func.func @sum(%lhs: memref, - %rhs: memref, - %sum: memref) { +func.func @sum(%lhs: memref>, + %rhs: memref>, + %sum: memref>) { linalg.generic #pointwise_2d_trait - ins(%lhs, %rhs: memref, - memref) - outs(%sum : memref) { + ins(%lhs, %rhs: memref>, + memref>) + outs(%sum : memref>) { ^bb0(%lhs_in: f32, %rhs_in: f32, %sum_out: f32): %result = arith.addf %lhs_in, %rhs_in : f32 linalg.yield %result : f32 @@ -25,7 +25,7 @@ return } // TILE-2-LABEL: func @sum( -// TILE-2-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { +// TILE-2-SAME: [[LHS:%.*]]: memref{{.*}}, [[RHS:%.*]]: memref{{.*}}, [[SUM:%.*]]: memref{{.*}}) { // TILE-2-DAG: [[C0:%.*]] = arith.constant 0 : index // TILE-2-DAG: [[C2:%.*]] = arith.constant 2 : index // TILE-2: [[LHS_ROWS:%.*]] = memref.dim [[LHS]], %c0 @@ -37,7 +37,7 @@ // TILE-2: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]] // TILE-02-LABEL: func @sum( -// TILE-02-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { +// TILE-02-SAME: [[LHS:%.*]]: memref{{.*}}, [[RHS:%.*]]: memref{{.*}}, [[SUM:%.*]]: memref{{.*}}) { // TILE-02-DAG: [[C0:%.*]] = arith.constant 0 : index // TILE-02-DAG: [[C2:%.*]] = arith.constant 2 : index // TILE-02: [[LHS_COLS:%.*]] = memref.dim [[LHS]], %c1 @@ -49,12 +49,12 @@ // TILE-02: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]] // TILE-002-LABEL: func @sum( -// TILE-002-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { +// TILE-002-SAME: [[LHS:%.*]]: memref{{.*}}, [[RHS:%.*]]: memref{{.*}}, [[SUM:%.*]]: memref{{.*}}) { // TILE-002-NO: scf.parallel // TILE-002: linalg.generic {{.*}} ins([[LHS]], [[RHS]]{{.*}} outs([[SUM]] // TILE-234-LABEL: func @sum( -// TILE-234-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) { +// TILE-234-SAME: [[LHS:%.*]]: memref{{.*}}, [[RHS:%.*]]: memref{{.*}}, [[SUM:%.*]]: memref{{.*}}) { // TILE-234-DAG: [[C0:%.*]] = arith.constant 0 : index // TILE-234-DAG: [[C2:%.*]] = arith.constant 2 : index // TILE-234-DAG: [[C3:%.*]] = arith.constant 3 : index diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir --- a/mlir/test/Dialect/Linalg/tile.mlir +++ b/mlir/test/Dialect/Linalg/tile.mlir @@ -5,7 +5,6 @@ // TILE-2-DAG: #[[$strided1D:.*]] = 
affine_map<(d0)[s0] -> (d0 + s0)> // TILE-02-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> -// TILE-002-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // TILE-234-DAG: #[[$strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // TILE-2-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> @@ -24,52 +23,52 @@ // TILE-02-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> // TILE-234-DAG: #[[$stride_99_1_layout_map:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 99 + s0 + d1)> -func.func @matmul(%arg0: memref, - %arg1: memref, - %arg2: memref) { +func.func @matmul(%arg0: memref>, + %arg1: memref>, + %arg2: memref>) { linalg.matmul - ins(%arg0, %arg1: memref, - memref) - outs(%arg2: memref) + ins(%arg0, %arg1: memref>, + memref>) + outs(%arg2: memref>) return } // TILE-2-LABEL: func @matmul( // TILE-2-DAG: %[[C0:.*]] = arith.constant 0 : index // TILE-2-DAG: %[[C2:.*]] = arith.constant 2 : index -// TILE-2: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref +// TILE-2: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref> // TILE-2: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { // TILE-2: %[[szM:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[M]]] -// TILE-2: %[[K:.*]] = memref.dim %{{.*}}, %c1 : memref +// TILE-2: %[[K:.*]] = memref.dim %{{.*}}, %c1 : memref> // TILE-2: %[[szK:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[M]]] -// TILE-2: %[[N:.*]] = memref.dim %{{.*}}, %c1 : memref -// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[szM]], %[[K]]] [1, 1] : memref to memref -// TILE-2: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[szK]], %[[N]]] [1, 1] : memref to memref +// TILE-2: %[[N:.*]] = memref.dim %{{.*}}, %c1 : memref> +// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[szM]], %[[K]]] [1, 1] : memref> to memref +// TILE-2: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[szK]], %[[N]]] [1, 1] : memref> to memref // TILE-2: linalg.matmul ins(%[[sAi]]{{.*}} outs(%[[sCi]] // TILE-02-LABEL: func @matmul( // TILE-02-DAG: %[[C0:.*]] = arith.constant 0 : index // TILE-02-DAG: %[[C2:.*]] = arith.constant 2 : index -// TILE-02: %[[N:.*]] = memref.dim %arg1, %c1 : memref +// TILE-02: %[[N:.*]] = memref.dim %arg1, %c1 : memref> // TILE-02: scf.for %[[J:.*]] = %{{.*}} to %[[N]] step %{{.*}} { -// TILE-02: %[[K:.*]] = memref.dim %{{.*}}, %c0 : memref +// TILE-02: %[[K:.*]] = memref.dim %{{.*}}, %c0 : memref> // TILE-02: %[[szN:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[N]]] -// TILE-02: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref +// TILE-02: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref> // TILE-02: %[[szK:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[N]]] -// TILE-02: %[[sBj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [%[[K]], %[[szN]]] [1, 1] : memref to memref -// TILE-02: %[[sCj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [%[[M]], %[[szK]]] [1, 1] : memref to memref +// TILE-02: %[[sBj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [%[[K]], %[[szN]]] [1, 1] : memref> to memref +// TILE-02: %[[sCj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [%[[M]], %[[szK]]] [1, 1] : memref> to memref // TILE-02: linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]] // TILE-002-LABEL: func @matmul( // TILE-002-DAG: %[[C0:.*]] = arith.constant 0 : index // TILE-002-DAG: %[[C2:.*]] = arith.constant 2 : index -// TILE-002: %[[ubK:.*]] = memref.dim %{{.*}}, %c1 : memref +// TILE-002: %[[ubK:.*]] = memref.dim %{{.*}}, %c1 : memref> // TILE-002: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[ubK]] step %{{.*}} { -// TILE-002: 
%[[M:.*]] = memref.dim %{{.*}}, %c0 : memref +// TILE-002: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref> // TILE-002: %[[szK:.*]] = affine.min #[[$bound_map]](%[[K]])[%[[ubK]]] // TILE-002: %[[szK_1:.*]] = affine.min #[[$bound_map]](%[[K]])[%[[ubK]]] -// TILE-002: %[[N:.*]] = memref.dim %{{.*}}, %c1 : memref -// TILE-002: %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[K]]] [%[[M]], %[[szK]]] [1, 1] : memref to memref -// TILE-002: %[[sBj:.*]] = memref.subview %{{.*}}[%[[K]], 0] [%[[szK_1]], %[[N]]] [1, 1] : memref to memref +// TILE-002: %[[N:.*]] = memref.dim %{{.*}}, %c1 : memref> +// TILE-002: %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[K]]] [%[[M]], %[[szK]]] [1, 1] : memref> to memref +// TILE-002: %[[sBj:.*]] = memref.subview %{{.*}}[%[[K]], 0] [%[[szK_1]], %[[N]]] [1, 1] : memref> to memref // TILE-002: linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}} // TILE-234-LABEL: func @matmul( @@ -77,9 +76,9 @@ // TILE-234-DAG: %[[C2:.*]] = arith.constant 2 : index // TILE-234-DAG: %[[C3:.*]] = arith.constant 3 : index // TILE-234-DAG: %[[C4:.*]] = arith.constant 4 : index -// TILE-234: %[[ubM:.*]] = memref.dim %{{.*}}, %c0 : memref -// TILE-234: %[[ubK:.*]] = memref.dim %{{.*}}, %c1 : memref -// TILE-234: %[[ubN:.*]] = memref.dim %{{.*}}, %c1 : memref +// TILE-234: %[[ubM:.*]] = memref.dim %{{.*}}, %c0 : memref> +// TILE-234: %[[ubK:.*]] = memref.dim %{{.*}}, %c1 : memref> +// TILE-234: %[[ubN:.*]] = memref.dim %{{.*}}, %c1 : memref> // TILE-234: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[ubM]] step %{{.*}} { // TILE-234: scf.for %[[J:.*]] = %{{.*}}{{.*}} to %[[ubN]] step %{{.*}} { // TILE-234: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[ubK]] step %{{.*}} { @@ -89,9 +88,9 @@ // TILE-234: %[[szN:.*]] = affine.min #[[$bound_map_3]](%[[J]])[%[[ubN]]] // TILE-234: %[[szM_1:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[ubM]]] // TILE-234: %[[szN_1:.*]] = affine.min #[[$bound_map_3]](%[[J]])[%[[ubN]]] -// TILE-234: %[[sAik:.*]] = memref.subview %{{.*}}[%[[I]], %[[K]]] [%[[szM]], %[[szK]]] [1, 1] : memref to memref -// TILE-234: %[[sBkj:.*]] = memref.subview %{{.*}}[%[[K]], %[[J]]] [%[[szK_1]], %[[szN]]] [1, 1] : memref to memref -// TILE-234: %[[sCij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [%[[szM_1]], %[[szN_1]]] [1, 1] : memref to memref +// TILE-234: %[[sAik:.*]] = memref.subview %{{.*}}[%[[I]], %[[K]]] [%[[szM]], %[[szK]]] [1, 1] : memref> to memref +// TILE-234: %[[sBkj:.*]] = memref.subview %{{.*}}[%[[K]], %[[J]]] [%[[szK_1]], %[[szN]]] [1, 1] : memref> to memref +// TILE-234: %[[sCij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [%[[szM_1]], %[[szN_1]]] [1, 1] : memref> to memref // // TILE-234: linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]] @@ -99,13 +98,13 @@ // the "min" in subview size computation. This test uses buffer sizes divisible // by respective tile sizes (M=10 divisble by 2, N=12 divisible by 2 and 3, // K=16 divisble by 2 and 4). 
-func.func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
-                         %arg1: memref<16x12xf32, offset: ?, strides: [?, 1]>,
-                         %arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>) {
+func.func @matmul_static(%arg0: memref<10x16xf32, strided<[?, 1], offset: ?>>,
+                         %arg1: memref<16x12xf32, strided<[?, 1], offset: ?>>,
+                         %arg2: memref<10x12xf32, strided<[?, 1], offset: ?>>) {
   linalg.matmul
-    ins(%arg0, %arg1: memref<10x16xf32, offset: ?, strides: [?, 1]>,
-                      memref<16x12xf32, offset: ?, strides: [?, 1]>)
-    outs(%arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>)
+    ins(%arg0, %arg1: memref<10x16xf32, strided<[?, 1], offset: ?>>,
+                      memref<16x12xf32, strided<[?, 1], offset: ?>>)
+    outs(%arg2: memref<10x12xf32, strided<[?, 1], offset: ?>>)
   return
 }
 // TILE-2-LABEL: func @matmul_static(
@@ -116,8 +115,8 @@
 // TILE-2-DAG: %[[C2:.*]] = arith.constant 2 : index
 // TILE-2-DAG: %[[M:.*]] = arith.constant 10 : index
 // TILE-2: scf.for %[[I:.*]] = %{{.*}} to %[[M]] step %{{.*}} {
-// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [2, 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<2x16xf32, #[[$strided2D]]>
-// TILE-2: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [2, 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<2x12xf32, #[[$strided2D]]>
+// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [2, 16] [1, 1] : memref<10x16xf32, strided<[?, 1], offset: ?>> to memref<2x16xf32, #[[$strided2D]]>
+// TILE-2: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [2, 12] [1, 1] : memref<10x12xf32, strided<[?, 1], offset: ?>> to memref<2x12xf32, #[[$strided2D]]>
 // TILE-2: linalg.matmul ins(%[[sAi]], %{{.*}}{{.*}} outs(%[[sCi]]

 // TILE-02-LABEL: func @matmul_static(
@@ -125,8 +124,8 @@
 // TILE-02-DAG: %[[C2:.*]] = arith.constant 2 : index
 // TILE-02-DAG: %[[N:.*]] = arith.constant 12 : index
 // TILE-02: scf.for %[[J:.*]] = %{{.*}} to %[[N]] step %{{.*}} {
-// TILE-02: %[[sBj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [16, 2] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x2xf32, #[[$strided2D]]>
-// TILE-02: %[[sCj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [10, 2] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x2xf32, #[[$strided2D]]>
+// TILE-02: %[[sBj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [16, 2] [1, 1] : memref<16x12xf32, strided<[?, 1], offset: ?>> to memref<16x2xf32, #[[$strided2D]]>
+// TILE-02: %[[sCj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [10, 2] [1, 1] : memref<10x12xf32, strided<[?, 1], offset: ?>> to memref<10x2xf32, #[[$strided2D]]>
 // TILE-02: linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]]

 // TILE-002-LABEL: func @matmul_static(
@@ -134,8 +133,8 @@
 // TILE-002-DAG: %[[C2:.*]] = arith.constant 2 : index
 // TILE-002-DAG: %[[C16:.*]] = arith.constant 16 : index
 // TILE-002: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} {
-// TILE-002: %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[K]]] [10, 2] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x2xf32, #[[$strided2D]]>
-// TILE-002: %[[sBj:.*]] = memref.subview %{{.*}}[%[[K]], 0] [2, 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<2x12xf32, #[[$strided2D]]>
+// TILE-002: %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[K]]] [10, 2] [1, 1] : memref<10x16xf32, strided<[?, 1], offset: ?>> to memref<10x2xf32, #[[$strided2D]]>
+// TILE-002: %[[sBj:.*]] = memref.subview %{{.*}}[%[[K]], 0] [2, 12] [1, 1] : memref<16x12xf32, strided<[?, 1], offset: ?>> to memref<2x12xf32, #[[$strided2D]]>
 // TILE-002: linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}

 //
TILE-234-LABEL: func @matmul_static( @@ -149,17 +148,17 @@ // TILE-234: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[C10]] step %{{.*}} { // TILE-234: scf.for %[[J:.*]] = %{{.*}}{{.*}} to %[[C12]] step %{{.*}} { // TILE-234: scf.for %[[K:.*]] = %{{.*}}{{.*}} to %[[C16]] step %{{.*}} { -// TILE-234: %[[sAik:.*]] = memref.subview %{{.*}}[%[[I]], %[[K]]] [2, 4] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<2x4xf32, #[[$strided2D]]> -// TILE-234: %[[sBkj:.*]] = memref.subview %{{.*}}[%[[K]], %[[J]]] [4, 3] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<4x3xf32, #[[$strided2D]]> -// TILE-234: %[[sCij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [2, 3] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<2x3xf32, #[[$strided2D]]> +// TILE-234: %[[sAik:.*]] = memref.subview %{{.*}}[%[[I]], %[[K]]] [2, 4] [1, 1] : memref<10x16xf32, strided<[?, 1], offset: ?>> to memref<2x4xf32, #[[$strided2D]]> +// TILE-234: %[[sBkj:.*]] = memref.subview %{{.*}}[%[[K]], %[[J]]] [4, 3] [1, 1] : memref<16x12xf32, strided<[?, 1], offset: ?>> to memref<4x3xf32, #[[$strided2D]]> +// TILE-234: %[[sCij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [2, 3] [1, 1] : memref<10x12xf32, strided<[?, 1], offset: ?>> to memref<2x3xf32, #[[$strided2D]]> // // TILE-234: linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]] -func.func @matvec(%arg0: memref, %arg1: memref, %arg2: memref) { +func.func @matvec(%arg0: memref>, %arg1: memref>, %arg2: memref>) { linalg.matvec - ins(%arg0, %arg1: memref, - memref) - outs(%arg2: memref) + ins(%arg0, %arg1: memref>, + memref>) + outs(%arg2: memref>) return } // TILE-2-LABEL: func @matvec( @@ -168,13 +167,13 @@ // TILE-2-SAME: %[[ARG2:[0-9a-zA-Z]*]]: memref // TILE-2-DAG: %[[C0:.*]] = arith.constant 0 : index // TILE-2-DAG: %[[C2:.*]] = arith.constant 2 : index -// TILE-2: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref +// TILE-2: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref> // TILE-2: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { // TILE-2: %[[szM:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[M]]] -// TILE-2: %[[N:.*]] = memref.dim %{{.*}}, %c1 : memref +// TILE-2: %[[N:.*]] = memref.dim %{{.*}}, %c1 : memref> // TILE-2: %[[szN:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[M]]] -// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[szM]], %[[N]]] [1, 1] : memref to memref -// TILE-2: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]]] [%[[szN]]] [1] : memref to memref +// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]], 0] [%[[szM]], %[[N]]] [1, 1] : memref> to memref +// TILE-2: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]]] [%[[szN]]] [1] : memref> to memref // TILE-2: linalg.matvec ins(%[[sAi]], %{{.*}} outs(%[[sCi]] // TILE-02-LABEL: func @matvec( @@ -183,13 +182,13 @@ // TILE-02-SAME: %[[ARG2:[0-9a-zA-Z]*]]: memref // TILE-02-DAG: %[[C0:.*]] = arith.constant 0 : index // TILE-02-DAG: %[[C2:.*]] = arith.constant 2 : index -// TILE-02: %[[K:.*]] = memref.dim %{{.*}}, %c1 : memref +// TILE-02: %[[K:.*]] = memref.dim %{{.*}}, %c1 : memref> // TILE-02: scf.for %[[J:.*]] = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { -// TILE-02: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref +// TILE-02: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref> // TILE-02: %[[szN:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[K]]] // TILE-02: %[[szN_1:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[K]]] -// TILE-02: %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [%[[M]], %[[szN]]] [1, 1] : memref to memref -// TILE-02: %[[sBj:.*]] = memref.subview %{{.*}}[%[[J]]] [%[[szN_1]]] [1] : memref to memref 
+// TILE-02: %[[sAj:.*]] = memref.subview %{{.*}}[0, %[[J]]] [%[[M]], %[[szN]]] [1, 1] : memref> to memref +// TILE-02: %[[sBj:.*]] = memref.subview %{{.*}}[%[[J]]] [%[[szN_1]]] [1] : memref> to memref // TILE-02: linalg.matvec ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}} // TILE-002-LABEL: func @matvec( @@ -205,35 +204,35 @@ // TILE-234-DAG: %[[C0:.*]] = arith.constant 0 : index // TILE-234-DAG: %[[C2:.*]] = arith.constant 2 : index // TILE-234-DAG: %[[C3:.*]] = arith.constant 3 : index -// TILE-234: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref -// TILE-234: %[[K:.*]] = memref.dim %{{.*}}, %c1 : memref +// TILE-234: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref> +// TILE-234: %[[K:.*]] = memref.dim %{{.*}}, %c1 : memref> // TILE-234: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { // TILE-234: scf.for %[[J:.*]] = %{{.*}}{{.*}} to %[[K]] step %{{.*}} { // TILE-234: %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[M]]] // TILE-234: %[[szN:.*]] = affine.min #[[$bound_map_3]](%[[J]])[%[[K]]] // TILE-234: %[[szN_1:.*]] = affine.min #[[$bound_map_3]](%[[J]])[%[[K]]] // TILE-234: %[[szM_1:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[M]]] -// TILE-234: %[[sAij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [1, 1] : memref to memref -// TILE-234: %[[sBj:.*]] = memref.subview %{{.*}}[%[[J]]] [%[[szN_1]]] [1] : memref to memref -// TILE-234: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]]] [%[[szM_1]]] [1] : memref to memref +// TILE-234: %[[sAij:.*]] = memref.subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [1, 1] : memref> to memref +// TILE-234: %[[sBj:.*]] = memref.subview %{{.*}}[%[[J]]] [%[[szN_1]]] [1] : memref> to memref +// TILE-234: %[[sCi:.*]] = memref.subview %{{.*}}[%[[I]]] [%[[szM_1]]] [1] : memref> to memref // // TILE-234: linalg.matvec ins(%[[sAij]], %[[sBj]]{{.*}} outs(%[[sCi]] -func.func @dot(%arg0: memref, %arg1: memref, %arg2: memref) { +func.func @dot(%arg0: memref>, %arg1: memref>, %arg2: memref) { linalg.dot - ins(%arg0, %arg1: memref, memref) + ins(%arg0, %arg1: memref>, memref>) outs(%arg2: memref) return } // TILE-2-LABEL: func @dot( // TILE-2-DAG: %[[C0:.*]] = arith.constant 0 : index // TILE-2-DAG: %[[C2:.*]] = arith.constant 2 : index -// TILE-2: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref +// TILE-2: %[[M:.*]] = memref.dim %{{.*}}, %c0 : memref> // TILE-2: scf.for %[[I:.*]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { // TILE-2: %[[szM:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[M]]] // TILE-2: %[[szM_1:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[M]]] -// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref -// TILE-2: %[[sBi:.*]] = memref.subview %{{.*}}[%[[I]]] [%[[szM_1]]] [1] : memref to memref +// TILE-2: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref> to memref +// TILE-2: %[[sBi:.*]] = memref.subview %{{.*}}[%[[I]]] [%[[szM_1]]] [1] : memref> to memref // TILE-2: linalg.dot ins(%[[sAi]], %[[sBi]]{{.*}} outs( // TILE-02-LABEL: func @dot( @@ -245,12 +244,12 @@ // TILE-234-LABEL: func @dot( // TILE-234-DAG: %[[C0:.*]] = arith.constant 0 : index // TILE-234-DAG: %[[C2:.*]] = arith.constant 2 : index -// TILE-234: %[[ubK:.*]] = memref.dim %{{.*}}, %c0 : memref +// TILE-234: %[[ubK:.*]] = memref.dim %{{.*}}, %c0 : memref> // TILE-234: scf.for %[[I:.*]] = %{{.*}} to %[[ubK]] step %{{.*}} { // TILE-234: %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[ubK]]] // TILE-234: %[[szM_1:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[ubK]]] -// TILE-234: %[[sAi:.*]] = memref.subview 
%{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref -// TILE-234: %[[sBi:.*]] = memref.subview %{{.*}}[%[[I]]] [%[[szM_1]]] [1] : memref to memref +// TILE-234: %[[sAi:.*]] = memref.subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref> to memref +// TILE-234: %[[sBi:.*]] = memref.subview %{{.*}}[%[[I]]] [%[[szM_1]]] [1] : memref> to memref // TILE-234: linalg.dot ins(%[[sAi]], %[[sBi]]{{.*}} outs( func.func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) { @@ -281,8 +280,8 @@ // TILE-234: linalg.fill{{.*}} : memref -func.func @fill(%arg0: memref, %arg1: f32) { - linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) +func.func @fill(%arg0: memref>, %arg1: f32) { + linalg.fill ins(%arg1 : f32) outs(%arg0 : memref>) return } // TILE-2-LABEL: func @fill @@ -313,11 +312,11 @@ iterator_types = ["parallel", "parallel"] } -func.func @pointwise(%arg0: memref, %arg1: memref, - %arg2: memref) { +func.func @pointwise(%arg0: memref>, %arg1: memref>, + %arg2: memref>) { linalg.generic #pointwise_2d_trait - ins(%arg0, %arg1 : memref, memref) - outs(%arg2 : memref) { + ins(%arg0, %arg1 : memref>, memref>) + outs(%arg2 : memref>) { ^bb0(%arg4: f32, %arg5: f32, %arg6: f32): %4 = arith.addf %arg4, %arg5 : f32 linalg.yield %4 : f32 diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -7,12 +7,12 @@ // CHECK-DAG: #[[$nm:.*]] = affine_map<(d0, d1, d2) -> (d1, d0)> // CHECK-DAG: #[[$km:.*]] = affine_map<(d0, d1, d2) -> (d2, d0)> -func.func @dot(%x: memref, - %y: memref, +func.func @dot(%x: memref>, + %y: memref>, %v: memref) { linalg.dot { __internal_linalg_transform__ = "MEM" } - ins(%x, %y: memref, - memref) + ins(%x, %y: memref>, + memref>) outs(%v: memref) return @@ -30,13 +30,13 @@ // CHECK: arith.addf // CHECK: store -func.func @matvec(%A: memref, - %x: memref, - %y: memref) { +func.func @matvec(%A: memref>, + %x: memref>, + %y: memref>) { linalg.matvec - ins(%A, %x: memref, - memref) - outs(%y: memref) + ins(%A, %x: memref>, + memref>) + outs(%y: memref>) return } // CHECK-LABEL: func @matvec @@ -49,13 +49,13 @@ // CHECK: ins({{.*}}: memref, memref) // CHECK: outs({{.*}}: memref) -func.func @matmul(%A: memref, - %B: memref, - %C: memref) { +func.func @matmul(%A: memref>, + %B: memref>, + %C: memref>) { linalg.matmul { __internal_linalg_transform__ = "MEM" } - ins(%A, %B: memref, - memref) - outs(%C: memref) + ins(%A, %B: memref>, + memref>) + outs(%C: memref>) return } // CHECK-LABEL: func @matmul @@ -100,13 +100,13 @@ library_call = "linalg_matmul", iterator_types = ["parallel", "parallel", "reduction"] } -func.func @permute_generic(%A: memref, - %B: memref, - %C: memref) { +func.func @permute_generic(%A: memref>, + %B: memref>, + %C: memref>) { linalg.generic #generic_matmul_trait - ins(%A, %B : memref, - memref) - outs(%C : memref) { + ins(%A, %B : memref>, + memref>) + outs(%C : memref>) { ^bb(%a: f32, %b: f32, %c: f32): %d = arith.mulf %a, %b: f32 %e = arith.addf %c, %d: f32 @@ -127,17 +127,17 @@ // CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]], // CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], // CHECK-SAME: library_call = "linalg_matmul"} -// CHECK: memref, -// CHECK-SAME: memref -// CHECK-SAME: memref +// CHECK: memref>, +// CHECK-SAME: memref> +// CHECK-SAME: memref> -func.func @matvec_perm(%A: memref, - %x: memref, - %y: memref) { +func.func @matvec_perm(%A: memref>, + %x: memref>, + %y: memref>) { 
linalg.matvec {__internal_linalg_transform__ = "__with_perm__"} - ins(%A, %x: memref, - memref) - outs(%y: memref) + ins(%A, %x: memref>, + memref>) + outs(%y: memref>) return } // CHECK-LABEL: func @matvec_perm @@ -150,13 +150,13 @@ // CHECK: ins({{.*}}: memref, memref) // CHECK: outs({{.*}}: memref) -func.func @matmul_perm(%A: memref, - %B: memref, - %C: memref) { +func.func @matmul_perm(%A: memref>, + %B: memref>, + %C: memref>) { linalg.matmul {__internal_linalg_transform__ = "__with_perm__"} - ins(%A, %B: memref, - memref) - outs(%C : memref) + ins(%A, %B: memref>, + memref>) + outs(%C : memref>) return } // CHECK-LABEL: func @matmul_perm diff --git a/mlir/test/Dialect/Linalg/transform-promotion.mlir b/mlir/test/Dialect/Linalg/transform-promotion.mlir --- a/mlir/test/Dialect/Linalg/transform-promotion.mlir +++ b/mlir/test/Dialect/Linalg/transform-promotion.mlir @@ -2,31 +2,29 @@ // Map corresponding to a 2D memory access where the stride along the last dim is known to be 1. // CHECK-DAG: #[[$STRIDED_2D_u_1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// Map corresponding to a 2D memory access where the stride along all dims are unknown. -// CHECK-DAG: #[[$STRIDED_2D:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> -func.func @promote_subview_matmul(%arg0: memref, - %arg1: memref, - %arg2: memref) { +func.func @promote_subview_matmul(%arg0: memref>, + %arg1: memref>, + %arg2: memref>) { %c2000 = arith.constant 2000 : index %c3000 = arith.constant 3000 : index %c4000 = arith.constant 4000 : index %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.dim %arg0, %c1 : memref - %2 = memref.dim %arg1, %c1 : memref + %0 = memref.dim %arg0, %c0 : memref> + %1 = memref.dim %arg0, %c1 : memref> + %2 = memref.dim %arg1, %c1 : memref> scf.for %arg3 = %c0 to %0 step %c2000 { scf.for %arg4 = %c0 to %2 step %c3000 { scf.for %arg5 = %c0 to %1 step %c4000 { %3 = memref.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] : - memref to memref + memref> to memref> %4 = memref.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] : - memref to memref + memref> to memref> %5 = memref.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : - memref to memref - linalg.matmul ins(%3, %4: memref, - memref) - outs(%5: memref) + memref> to memref> + linalg.matmul ins(%3, %4: memref>, + memref>) + outs(%5: memref>) } } } @@ -40,9 +38,9 @@ // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { -// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref to memref -// CHECK: %[[s1:.*]] = memref.subview {{.*}}: memref to memref -// CHECK: %[[s2:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[s1:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[s2:.*]] = memref.subview {{.*}}: memref to memref // CHECK: %[[a0:.*]] = memref.alloc() : memref<32000000xi8> // CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref // CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] @@ -55,9 +53,9 @@ // CHECK: %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref // CHECK: %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] // CHECK-SAME: memref to memref -// CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref -// CHECK: memref.copy %[[s1]], %[[l1]] : 
memref to memref -// CHECK: memref.copy %[[s2]], %[[l2]] : memref to memref +// CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref +// CHECK: memref.copy %[[s1]], %[[l1]] : memref to memref +// CHECK: memref.copy %[[s2]], %[[l2]] : memref to memref // CHECK: linalg.matmul // CHECK-SAME: ins(%[[v0]], %[[v1]] : memref, memref) // CHECK-SAME: outs(%[[v2]] : memref) @@ -73,30 +71,30 @@ // ----- -func.func @promote_first_subview_matmul(%arg0: memref, - %arg1: memref, - %arg2: memref) { +func.func @promote_first_subview_matmul(%arg0: memref>, + %arg1: memref>, + %arg2: memref>) { %c2000 = arith.constant 2000 : index %c3000 = arith.constant 3000 : index %c4000 = arith.constant 4000 : index %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.dim %arg0, %c1 : memref - %2 = memref.dim %arg1, %c1 : memref + %0 = memref.dim %arg0, %c0 : memref> + %1 = memref.dim %arg0, %c1 : memref> + %2 = memref.dim %arg1, %c1 : memref> scf.for %arg3 = %c0 to %0 step %c2000 { scf.for %arg4 = %c0 to %2 step %c3000 { scf.for %arg5 = %c0 to %1 step %c4000 { %3 = memref.subview %arg0[%arg3, %arg5][%c2000, %c4000][%c1, %c1] : - memref to memref + memref> to memref> %4 = memref.subview %arg1[%arg5, %arg4][%c4000, %c3000][%c1, %c1] : - memref to memref + memref> to memref> %5 = memref.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : - memref to memref + memref> to memref> linalg.matmul {__internal_linalg_transform__ = "_promote_first_view_"} - ins(%3, %4: memref, - memref) - outs(%5: memref) + ins(%3, %4: memref>, + memref>) + outs(%5: memref>) } } } @@ -110,20 +108,20 @@ // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2000]] { // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3000]] { // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4000]] { -// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref to memref -// CHECK: %[[s1:.*]] = memref.subview {{.*}}: memref to memref -// CHECK: %[[s2:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[s1:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[s2:.*]] = memref.subview {{.*}}: memref to memref // CHECK: %[[a0:.*]] = memref.alloc() : memref<32000000xi8> // CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref // CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref to memref // CHECK-NOT: memref.alloc // CHECK-NOT: memref.view // CHECK-NOT: memref.subview -// CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref +// CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref // CHECK-NOT: memref.copy // CHECK: linalg.matmul -// CHECK-SAME: ins(%[[v0]], %[[s1]] : memref, memref) -// CHECK-SAME: outs(%[[s2]] : memref) +// CHECK-SAME: ins(%[[v0]], %[[s1]] : memref, memref>) +// CHECK-SAME: outs(%[[s2]] : memref>) transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): @@ -136,26 +134,26 @@ // ----- -func.func @aligned_promote_fill(%arg0: memref) { +func.func @aligned_promote_fill(%arg0: memref>) { %c2000 = arith.constant 2000 : index %c4000 = arith.constant 4000 : index %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %cf = arith.constant 1.0 : f32 %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] : - memref to memref + memref> to memref> linalg.fill - ins(%cf : f32) outs(%3 : memref) + ins(%cf : f32) outs(%3 : memref>) return } // CHECK-LABEL: func @aligned_promote_fill // CHECK: %[[cf:.*]] = arith.constant 1.{{.*}} : f32 -// CHECK: 
%[[s0:.*]] = memref.subview {{.*}}: memref to memref +// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref to memref // CHECK: %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<32000000xi8> // CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<32000000xi8> to memref // CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref to memref // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[v0]] : memref) -// CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref +// CHECK: memref.copy %[[s0]], %[[l0]] : memref to memref // CHECK: linalg.fill ins(%[[cf]] : f32) outs(%[[v0]] : memref) transform.with_pdl_patterns { @@ -169,7 +167,7 @@ // ----- -func.func @aligned_promote_fill_complex(%arg0: memref, offset: ?, strides: [?, 1]>) { +func.func @aligned_promote_fill_complex(%arg0: memref, strided<[?, 1], offset: ?>>) { %c2000 = arith.constant 2000 : index %c4000 = arith.constant 4000 : index %c0 = arith.constant 0 : index @@ -177,19 +175,19 @@ %cf = arith.constant 1.0 : f32 %cc = complex.create %cf, %cf : complex %3 = memref.subview %arg0[%c0, %c0][%c2000, %c4000][%c1, %c1] : - memref, offset: ?, strides: [?, 1]> to memref, offset: ?, strides: [?, ?]> + memref, strided<[?, 1], offset: ?>> to memref, strided<[?, ?], offset: ?>> linalg.fill ins(%cc : complex) - outs(%3 : memref, offset: ?, strides: [?, ?]>) + outs(%3 : memref, strided<[?, ?], offset: ?>>) return } // CHECK-LABEL: func @aligned_promote_fill_complex // CHECK: %[[cc:.*]] = complex.create {{.*}} : complex -// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref, #map{{.*}}> to memref, #map{{.*}}> +// CHECK: %[[s0:.*]] = memref.subview {{.*}}: memref, strided{{.*}}> to memref, strided{{.*}}> // CHECK: %[[a0:.*]] = memref.alloc() {alignment = 32 : i64} : memref<64000000xi8> // CHECK: %[[v0:.*]] = memref.view %[[a0]]{{.*}} : memref<64000000xi8> to memref> // CHECK: %[[l0:.*]] = memref.subview %[[v0]][0, 0] [%{{.*}}, %{{.*}}] [1, 1] : memref> to memref, #[[$STRIDED_2D_u_1]]> // CHECK: linalg.fill ins({{.*}} : complex) outs(%[[v0]] : memref>) -// CHECK: memref.copy %[[s0]], %[[l0]] : memref, #map{{.*}}> to memref, #map{{.*}}> +// CHECK: memref.copy %[[s0]], %[[l0]] : memref, strided{{.*}}> to memref, #map{{.*}}> // CHECK: linalg.fill ins(%[[cc]] : complex) outs(%[[v0]] : memref>) transform.with_pdl_patterns { diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -88,59 +88,61 @@ // ----- +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 384 + s0 + d1)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> + func.func @multiple_reducing_dims(%arg0 : memref<1x384x384xf32>, - %arg1 : index, %arg2 : index, %arg3 : index) -> memref + %arg1 : index, %arg2 : index, %arg3 : index) -> memref> { %c1 = arith.constant 1 : index - %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<1x384x384xf32> to memref - %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref to memref - return %1 : memref + %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<1x384x384xf32> to memref> + %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref> to memref> + return %1 : memref> } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 384 + s0 + d1)> // CHECK: func @multiple_reducing_dims // CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, 
%{{.+}}] [1, 1, %{{.+}}] [1, 1, 1] -// CHECK-SAME: : memref<1x384x384xf32> to memref<1x?xf32, #[[MAP1]]> +// CHECK-SAME: : memref<1x384x384xf32> to memref<1x?xf32, #[[MAP0]]> // CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1] -// CHECK-SAME: : memref<1x?xf32, #[[MAP1]]> to memref +// CHECK-SAME: : memref<1x?xf32, #[[MAP0]]> to memref // ----- +// CHECK-DAG: #[[MAP0]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// CHECK-DAG: #[[MAP1]] = affine_map<(d0)[s0] -> (d0 + s0)> + func.func @multiple_reducing_dims_dynamic(%arg0 : memref, - %arg1 : index, %arg2 : index, %arg3 : index) -> memref + %arg1 : index, %arg2 : index, %arg3 : index) -> memref> { %c1 = arith.constant 1 : index - %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref to memref - %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref to memref - return %1 : memref + %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref to memref> + %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref> to memref> + return %1 : memref> } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK: func @multiple_reducing_dims_dynamic // CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1] -// CHECK-SAME: : memref to memref<1x?xf32, #[[MAP1]]> +// CHECK-SAME: : memref to memref<1x?xf32, #[[MAP0]]> // CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1] -// CHECK-SAME: : memref<1x?xf32, #[[MAP1]]> to memref +// CHECK-SAME: : memref<1x?xf32, #[[MAP0]]> to memref // ----- -func.func @multiple_reducing_dims_all_dynamic(%arg0 : memref, - %arg1 : index, %arg2 : index, %arg3 : index) -> memref +// CHECK-DAG: #[[MAP0]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +// CHECK-DAG: #[[MAP1]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> + +func.func @multiple_reducing_dims_all_dynamic(%arg0 : memref>, + %arg1 : index, %arg2 : index, %arg3 : index) -> memref> { %c1 = arith.constant 1 : index %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] - : memref to memref - %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref to memref - return %1 : memref + : memref> to memref> + %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref> to memref> + return %1 : memref> } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> // CHECK: func @multiple_reducing_dims_all_dynamic // CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1] -// CHECK-SAME: : memref to memref<1x?xf32, #[[MAP1]]> +// CHECK-SAME: : memref> to memref<1x?xf32, #[[MAP0]]> // CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1] -// CHECK-SAME: : memref<1x?xf32, #[[MAP1]]> to memref +// CHECK-SAME: : memref<1x?xf32, #[[MAP0]]> to memref // ----- @@ -330,15 +332,15 @@ // ----- func.func @do_not_compose_collapse_of_expand_non_identity_layout( - %arg0: memref) - -> memref { + %arg0: memref>) + -> memref> { %1 = memref.expand_shape %arg0 [[0, 1], [2]] : - memref into - memref + memref> into + memref> %2 = memref.collapse_shape %1 [[0, 1, 2]] : - memref into - memref - return %2 : memref + memref> into + memref> + return %2 : memref> 
} // CHECK-LABEL: func @do_not_compose_collapse_of_expand_non_identity_layout // CHECK: expand @@ -747,10 +749,10 @@ // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) // CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref) -> memref { +func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index - %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref - return %m2 : memref + %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref> + return %m2 : memref> } // ----- @@ -775,10 +777,10 @@ // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [4, 2, 2], strides: [1, 1, %[[STRIDES]]#1] // CHECK: return %[[RES]] -func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref { +func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index - %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref to memref - return %m2 : memref + %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref to memref> + return %m2 : memref> } // ----- @@ -789,20 +791,20 @@ // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[SIZES]]#0, %[[SIZES]]#1], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1] // CHECK: return %[[RES]] -func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref { +func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index - %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref - return %m2 : memref + %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref> + return %m2 : memref> } // ----- func.func @canonicalize_rank_reduced_subview(%arg0 : memref<8x?xf32>, - %arg1 : index) -> memref { + %arg1 : index) -> memref> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %0 = memref.subview %arg0[%c0, %c0] [1, %arg1] [%c1, %c1] : memref<8x?xf32> to memref - return %0 : memref + %0 = memref.subview %arg0[%c0, %c0] [1, %arg1] [%c1, %c1] : memref<8x?xf32> to memref> + return %0 : memref> } // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK: func @canonicalize_rank_reduced_subview diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ 
b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -1,8 +1,8 @@ // RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s -o - | FileCheck %s func.func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 { - %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> - %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> + %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>> + %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>> return %1 : f32 } // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)> @@ -21,8 +21,8 @@ func.func @fold_dynamic_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> f32 { %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : - memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]> - %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]> + memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> + %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[?, ?], offset: ?>> return %1 : f32 } // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)> @@ -42,8 +42,8 @@ func.func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) { %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : - memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> - memref.store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> + memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>> + memref.store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>> return } // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)> @@ -62,8 +62,8 @@ func.func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : f32) { %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : - memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]> - memref.store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]> + memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> + memref.store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, strided<[?, ?], offset: ?>> return } // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)> @@ -83,8 +83,8 @@ func.func @fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> vector<4xf32> { %f1 = arith.constant 1.0 : f32 - %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]> - %1 = vector.transfer_read %0[%arg3, %arg4], %f1 {in_bounds = [true]} : memref<4x4xf32, offset:?, strides: [?, ?]>, vector<4xf32> + %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> + %1 = vector.transfer_read %0[%arg3, %arg4], %f1 {in_bounds = [true]} : memref<4x4xf32, strided<[?, ?], offset: ?>>, vector<4xf32> return %1 : vector<4xf32> } // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)> @@ -104,8 +104,8 @@ func.func 
@fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5: index, %arg6 : index, %arg7 : vector<4xf32>) { %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : - memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]> - vector.transfer_write %arg7, %0[%arg3, %arg4] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32, offset:?, strides: [?, ?]> + memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> + vector.transfer_write %arg7, %0[%arg3, %arg4] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32, strided<[?, ?], offset: ?>> return } // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)> @@ -129,8 +129,8 @@ %arg7 : index, %arg8 : index, %arg9 : index, %arg10: index, %arg11 : index, %arg12 : index, %arg13 : index, %arg14: index, %arg15 : index, %arg16 : index) -> f32 { - %0 = memref.subview %arg0[%arg1, %arg2, %arg3, %arg4, %arg5, %arg6][4, 1, 1, 4, 1, 1][%arg7, %arg8, %arg9, %arg10, %arg11, %arg12] : memref to memref<4x1x4x1xf32, offset:?, strides: [?, ?, ?, ?]> - %1 = memref.load %0[%arg13, %arg14, %arg15, %arg16] : memref<4x1x4x1xf32, offset:?, strides: [?, ?, ?, ?]> + %0 = memref.subview %arg0[%arg1, %arg2, %arg3, %arg4, %arg5, %arg6][4, 1, 1, 4, 1, 1][%arg7, %arg8, %arg9, %arg10, %arg11, %arg12] : memref to memref<4x1x4x1xf32, strided<[?, ?, ?, ?], offset: ?>> + %1 = memref.load %0[%arg13, %arg14, %arg15, %arg16] : memref<4x1x4x1xf32, strided<[?, ?, ?, ?], offset: ?>> return %1 : f32 } // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)> @@ -164,21 +164,20 @@ // ----- func.func @fold_vector_transfer_read_with_rank_reduced_subview( - %arg0 : memref, + %arg0 : memref>, %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index, %arg6 : index) -> vector<4xf32> { %cst = arith.constant 0.0 : f32 %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %arg3, %arg4] [1, 1, 1] - : memref to - memref + : memref> to + memref> %1 = vector.transfer_read %0[%arg5, %arg6], %cst {in_bounds = [true]} - : memref, vector<4xf32> + : memref>, vector<4xf32> return %1 : vector<4xf32> } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK: func @fold_vector_transfer_read_with_rank_reduced_subview -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index @@ -193,21 +192,20 @@ // ----- func.func @fold_vector_transfer_write_with_rank_reduced_subview( - %arg0 : memref, + %arg0 : memref>, %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index, %arg5: index, %arg6 : index, %arg7 : index) { %cst = arith.constant 0.0 : f32 %0 = memref.subview %arg0[0, %arg2, %arg3] [1, %arg4, %arg5] [1, 1, 1] - : memref to - memref + : memref> to + memref> vector.transfer_write %arg1, %0[%arg6, %arg7] {in_bounds = [true]} - : vector<4xf32>, memref + : vector<4xf32>, memref> return } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK: func @fold_vector_transfer_write_with_rank_reduced_subview -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32> // CHECK-SAME: 
%[[ARG2:[a-zA-Z0-9]+]]: index // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index @@ -223,22 +221,21 @@ // ----- func.func @fold_vector_transfer_write_with_inner_rank_reduced_subview( - %arg0 : memref, + %arg0 : memref>, %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index, %arg5: index, %arg6 : index, %arg7 : index) { %cst = arith.constant 0.0 : f32 %0 = memref.subview %arg0[%arg2, %arg3, 0] [%arg4, %arg5, 1] [1, 1, 1] - : memref to - memref + : memref> to + memref> vector.transfer_write %arg1, %0[%arg6, %arg7] {in_bounds = [true]} - : vector<4xf32>, memref + : vector<4xf32>, memref> return } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)> // CHECK: func @fold_vector_transfer_write_with_inner_rank_reduced_subview -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index @@ -260,12 +257,12 @@ // CHECK-LABEL: func @fold_static_stride_subview_with_affine_load_store func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 { - %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> - %1 = affine.load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> + %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>> + %1 = affine.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>> // CHECK-NEXT: affine.apply // CHECK-NEXT: affine.apply // CHECK-NEXT: affine.load - affine.store %1, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> + affine.store %1, %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>> // CHECK-NEXT: affine.apply // CHECK-NEXT: affine.apply // CHECK-NEXT: affine.store diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -152,7 +152,7 @@ // expected-error @+1 {{expected 1 offset values}} %out = memref.reinterpret_cast %in to offset: [0, 0], sizes: [10, 10], strides: [10, 1] - : memref to memref<10x10xf32, offset: 0, strides: [10, 1]> + : memref to memref<10x10xf32, strided<[10, 1], offset: 0>> return } @@ -162,7 +162,7 @@ // expected-error @+1 {{different element types specified}} %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [1] - : memref<*xf32> to memref<10xi32, offset: 0, strides: [1]> + : memref<*xf32> to memref<10xi32, strided<[1], offset: 0>> return } @@ -172,7 +172,7 @@ // expected-error @+1 {{different memory spaces specified}} %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [1] - : memref<*xf32> to memref<10xi32, offset: 0, strides: [1], 2> + : memref<*xf32> to memref<10xi32, strided<[1], offset: 0>, 2> return } @@ -182,7 +182,7 @@ // expected-error @+1 {{expected result type with offset = 2 instead of 1}} %out = memref.reinterpret_cast %in to offset: [1], sizes: [10], strides: [1] - : memref to memref<10xf32, offset: 2, strides: [1]> + : memref to memref<10xf32, strided<[1], offset: 2>> return } @@ -192,7 +192,7 @@ // expected-error @+1 {{expected result type 
with size = 10 instead of 1 in dim = 0}} %out = memref.reinterpret_cast %in to offset: [0], sizes: [10], strides: [1] - : memref<*xf32> to memref<1xf32, offset: 0, strides: [1]> + : memref<*xf32> to memref<1xf32, strided<[1], offset: 0>> return } @@ -202,7 +202,7 @@ // expected-error @+1 {{expected result type with stride = 2 instead of 1 in dim = 0}} %out = memref.reinterpret_cast %in to offset: [2], sizes: [10], strides: [2] - : memref to memref<10xf32, offset: 2, strides: [1]> + : memref to memref<10xf32, strided<[1], offset: 2>> return } @@ -272,11 +272,11 @@ // ----- func.func @memref_reshape_src_affine_map_is_not_identity( - %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>, + %buf: memref<4x4xf32, strided<[3, 2], offset: 0>>, %shape: memref<1xi32>) { // expected-error @+1 {{source memref type should have identity affine map}} memref.reshape %buf(%shape) - : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>) + : (memref<4x4xf32, strided<[3, 2], offset: 0>>, memref<1xi32>) -> memref<8xf32> } @@ -286,7 +286,7 @@ %buf: memref<4x4xf32>, %shape: memref<1xi32>) { // expected-error @+1 {{result memref type should have identity affine map}} memref.reshape %buf(%shape) - : (memref<4x4xf32>, memref<1xi32>) -> memref<8xf32, offset: 0, strides: [2]> + : (memref<4x4xf32>, memref<1xi32>) -> memref<8xf32, strided<[2], offset: 0>> } // ----- @@ -423,11 +423,11 @@ // ----- func.func @expand_shape_invalid_result_layout( - %arg0: memref<30x20xf32, offset : 100, strides : [4000, 2]>) { + %arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) { // expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 60000 + d1 * 4000 + d2 * 2 + 100)>>' but found 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 5000 + d1 * 4000 + d2 * 2 + 100)>>'}} %0 = memref.expand_shape %arg0 [[0, 1], [2]] : - memref<30x20xf32, offset : 100, strides : [4000, 2]> - into memref<2x15x20xf32, offset : 100, strides : [5000, 4000, 2]> + memref<30x20xf32, strided<[4000, 2], offset: 100>> + into memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>> } // ----- @@ -435,7 +435,7 @@ func.func @collapse_shape_mismatch_indices_num(%arg0: memref) { // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 2}} %0 = memref.collapse_shape %arg0 [[0, 1]] : - memref into memref + memref into memref> } // ----- @@ -443,17 +443,17 @@ func.func @collapse_shape_invalid_reassociation(%arg0: memref) { // expected-error @+1 {{reassociation indices must be contiguous}} %0 = memref.collapse_shape %arg0 [[0, 1], [1, 2]] : - memref into memref + memref into memref> } // ----- func.func @collapse_shape_reshaping_non_contiguous( - %arg0: memref<3x4x5xf32, offset: 0, strides: [270, 50, 10]>) { + %arg0: memref<3x4x5xf32, strided<[270, 50, 10], offset: 0>>) { // expected-error @+1 {{invalid source layout map or collapsing non-contiguous dims}} %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : - memref<3x4x5xf32, offset: 0, strides: [270, 50, 10]> - into memref<12x5xf32, offset: 0, strides: [50, 1]> + memref<3x4x5xf32, strided<[270, 50, 10], offset: 0>> + into memref<12x5xf32, strided<[50, 1], offset: 0>> return } @@ -628,10 +628,10 @@ // ----- func.func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { - %0 = memref.alloc() : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> + %0 = memref.alloc() : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>, 2> // expected-error@+1 {{different memory spaces}} %1 = memref.subview %0[0, 0, 0][%arg2, %arg2, %arg2][1, 1, 
1] - : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> to + : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>, 2> to memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>> return } @@ -643,7 +643,7 @@ // expected-error@+1 {{is not strided}} %1 = memref.subview %0[0, 0, 0][%arg2, %arg2, %arg2][1, 1, 1] : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> to - memref<8x?x4xf32, offset: 0, strides: [?, 4, 1]> + memref<8x?x4xf32, strided<[?, 4, 1], offset: 0>> return } @@ -654,7 +654,7 @@ // expected-error@+1 {{expected 3 offset values}} %1 = memref.subview %0[%arg0, %arg1, 0, 0][%arg2, 0, 0, 0][1, 1, 1, 1] : memref<8x16x4xf32> to - memref<8x?x4xf32, offset: 0, strides:[?, ?, 4]> + memref<8x?x4xf32, strided<[?, ?, 4], offset: 0>> return } @@ -746,17 +746,17 @@ // ----- -func.func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) { - // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}} - %0 = memref.cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]> +func.func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, strided<[64, 16, 1], offset: 0>>) { + // expected-error@+1{{operand type 'memref<12x4x16xf32, strided<[64, 16, 1]>>' and result type 'memref<12x4x16xf32, strided<[128, 32, 2]>>' are cast incompatible}} + %0 = memref.cast %arg0 : memref<12x4x16xf32, strided<[64, 16, 1], offset: 0>> to memref<12x4x16xf32, strided<[128, 32, 2], offset: 0>> return } // ----- -func.func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) { - // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2 + 16)>>' are cast incompatible}} - %0 = memref.cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:16, strides:[64, 16, 1]> +func.func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, strided<[64, 16, 1], offset: 0>>) { + // expected-error@+1{{operand type 'memref<12x4x16xf32, strided<[64, 16, 1]>>' and result type 'memref<12x4x16xf32, strided<[64, 16, 1], offset: 16>>' are cast incompatible}} + %0 = memref.cast %arg0 : memref<12x4x16xf32, strided<[64, 16, 1], offset: 0>> to memref<12x4x16xf32, strided<[64, 16, 1], offset: 16>> return } diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -1,39 +1,33 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s // RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s -// CHECK-DAG: #[[$strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[$strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> -// CHECK-DAG: #[[$strided2DOFF0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)> -// CHECK-DAG: #[[$strided3DOFF0:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)> -// CHECK-DAG: #[[$strided2D42:.*]] = affine_map<(d0, d1) -> (d0 * 42 + d1)> - // CHECK-LABEL: func @memref_reinterpret_cast func.func @memref_reinterpret_cast(%in: memref) - -> memref<10x?xf32, offset: ?, strides: [?, 1]> { + -> memref<10x?xf32, strided<[?, 1], offset: ?>> { %c0 = arith.constant 
0 : index %c10 = arith.constant 10 : index %out = memref.reinterpret_cast %in to offset: [%c0], sizes: [10, %c10], strides: [%c10, 1] - : memref to memref<10x?xf32, offset: ?, strides: [?, 1]> - return %out : memref<10x?xf32, offset: ?, strides: [?, 1]> + : memref to memref<10x?xf32, strided<[?, 1], offset: ?>> + return %out : memref<10x?xf32, strided<[?, 1], offset: ?>> } // CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes func.func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref) - -> memref<10x?xf32, offset: ?, strides: [?, 1]> { + -> memref<10x?xf32, strided<[?, 1], offset: ?>> { %out = memref.reinterpret_cast %in to offset: [1], sizes: [10, 10], strides: [1, 1] - : memref to memref<10x?xf32, offset: ?, strides: [?, 1]> - return %out : memref<10x?xf32, offset: ?, strides: [?, 1]> + : memref to memref<10x?xf32, strided<[?, 1], offset: ?>> + return %out : memref<10x?xf32, strided<[?, 1], offset: ?>> } // CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset func.func @memref_reinterpret_cast_dynamic_offset(%in: memref, %offset: index) - -> memref<10x?xf32, offset: ?, strides: [?, 1]> { + -> memref<10x?xf32, strided<[?, 1], offset: ?>> { %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10, 10], strides: [1, 1] - : memref to memref<10x?xf32, offset: ?, strides: [?, 1]> - return %out : memref<10x?xf32, offset: ?, strides: [?, 1]> + : memref to memref<10x?xf32, strided<[?, 1], offset: ?>> + return %out : memref<10x?xf32, strided<[?, 1], offset: ?>> } // CHECK-LABEL: func @memref_reshape( @@ -109,10 +103,10 @@ %arg0: memref<3x4x5xf32>, %arg1: tensor<3x4x5xf32>, %arg2: tensor<3x?x5xf32>, - %arg3: memref<30x20xf32, offset : 100, strides : [4000, 2]>, + %arg3: memref<30x20xf32, strided<[4000, 2], offset: 100>>, %arg4: memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>>, %arg5: memref, - %arg6: memref<3x4x5xf32, offset: 0, strides: [240, 60, 10]>, + %arg6: memref<3x4x5xf32, strided<[240, 60, 10], offset: 0>>, %arg7: memref<1x2049xi64, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>) { // Reshapes that collapse and expand back a contiguous buffer. // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] @@ -153,13 +147,13 @@ // Reshapes with a custom layout map. // CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] %l0 = memref.expand_shape %arg3 [[0], [1, 2]] : - memref<30x20xf32, offset : 100, strides : [4000, 2]> - into memref<30x4x5xf32, offset : 100, strides : [4000, 10, 2]> + memref<30x20xf32, strided<[4000, 2], offset: 100>> + into memref<30x4x5xf32, strided<[4000, 10, 2], offset: 100>> // CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] %l1 = memref.expand_shape %arg3 [[0, 1], [2]] : - memref<30x20xf32, offset : 100, strides : [4000, 2]> - into memref<2x15x20xf32, offset : 100, strides : [60000, 4000, 2]> + memref<30x20xf32, strided<[4000, 2], offset: 100>> + into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>> // CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] %r4 = memref.expand_shape %arg4 [[0], [1, 2]] : @@ -169,8 +163,8 @@ // Note: Only the collapsed two shapes are contiguous in the follow test case. 
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] %r6 = memref.collapse_shape %arg6 [[0, 1], [2]] : - memref<3x4x5xf32, offset: 0, strides: [240, 60, 10]> into - memref<12x5xf32, offset: 0, strides: [60, 10]> + memref<3x4x5xf32, strided<[240, 60, 10], offset: 0>> into + memref<12x5xf32, strided<[60, 10], offset: 0>> // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1]] %r7 = memref.collapse_shape %arg7 [[0, 1]] : @@ -209,9 +203,9 @@ // CHECK-LABEL: func @expand_collapse_shape_dynamic func.func @expand_collapse_shape_dynamic(%arg0: memref, - %arg1: memref, - %arg2: memref, - %arg3: memref) { + %arg1: memref>, + %arg2: memref>, + %arg3: memref>) { // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref into memref %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : @@ -223,39 +217,39 @@ memref into memref // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] -// CHECK-SAME: memref into memref +// CHECK-SAME: memref> into memref> %1 = memref.collapse_shape %arg1 [[0, 1], [2]] : - memref into - memref + memref> into + memref> // CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] -// CHECK-SAME: memref into memref +// CHECK-SAME: memref> into memref> %r1 = memref.expand_shape %1 [[0, 1], [2]] : - memref into - memref + memref> into + memref> // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] -// CHECK-SAME: memref into memref +// CHECK-SAME: memref> into memref> %2 = memref.collapse_shape %arg2 [[0, 1], [2]] : - memref into - memref + memref> into + memref> // CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] -// CHECK-SAME: memref into memref +// CHECK-SAME: memref> into memref> %r2 = memref.expand_shape %2 [[0, 1], [2]] : - memref into - memref + memref> into + memref> // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1]] -// CHECK-SAME: memref into memref +// CHECK-SAME: memref> into memref %3 = memref.collapse_shape %arg3 [[0, 1]] : - memref into - memref + memref> into + memref // CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]] // CHECK-SAME: memref into memref %r3 = memref.expand_shape %3 [[0, 1]] : - memref into memref + memref into memref return } @@ -283,22 +277,22 @@ // CHECK-LABEL: func @expand_collapse_shape_transposed_layout func.func @expand_collapse_shape_transposed_layout( - %m0: memref, - %m1: memref<4x5x6xf32, offset : 0, strides : [1, ?, 1000]>) { + %m0: memref>, + %m1: memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>>) { %r0 = memref.expand_shape %m0 [[0], [1, 2]] : - memref into - memref + memref> into + memref> %rr0 = memref.collapse_shape %r0 [[0], [1, 2]] : - memref into - memref + memref> into + memref> %r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] : - memref<4x5x6xf32, offset : 0, strides : [1, ?, 1000]> into - memref<2x2x5x2x3xf32, offset : 0, strides : [2, 1, ?, 3000, 1000]> + memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>> into + memref<2x2x5x2x3xf32, strided<[2, 1, ?, 3000, 1000], offset: 0>> %rr1 = memref.collapse_shape %r1 [[0, 1], [2], [3, 4]] : - memref<2x2x5x2x3xf32, offset : 0, strides : [2, 1, ?, 3000, 1000]> into - memref<4x5x6xf32, offset : 0, strides : [1, ?, 1000]> + memref<2x2x5x2x3xf32, strided<[2, 1, ?, 3000, 1000], offset: 0>> into + memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>> return } @@ -340,7 +334,7 @@ // ----- func.func @extract_strided_metadata(%memref : memref<10x?xf32>) - -> memref { + -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %memref : memref<10x?xf32> -> memref, index, index, index, index, index @@ -349,7 +343,7 @@ offset: [%offset], 
sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] - : memref to memref + : memref to memref> - return %m2: memref + return %m2: memref> } diff --git a/mlir/test/Dialect/MemRef/subview.mlir b/mlir/test/Dialect/MemRef/subview.mlir --- a/mlir/test/Dialect/MemRef/subview.mlir +++ b/mlir/test/Dialect/MemRef/subview.mlir @@ -1,22 +1,8 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s // RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s -// CHECK-DAG: #[[$BASE_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> -// CHECK-DAG: #[[$BASE_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> - // CHECK-DAG: #[[$BASE_MAP1:map[0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK-DAG: #[[$SUBVIEW_MAP1:map[0-9]+]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> - -// CHECK-DAG: #[[$BASE_MAP2:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 22 + d1)> -// CHECK-DAG: #[[$SUBVIEW_MAP2:map[0-9]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> -// CHECK-DAG: #[[$SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)> -// CHECK-DAG: #[[$SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[$SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1 * 2)> -// CHECK-DAG: #[[$SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0 * 36 + d1 * 36 + d2 * 4 + d3 * 4 + d4)> -// CHECK-DAG: #[[$SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5 * s6)> -// CHECK-DAG: #[[$SUBVIEW_MAP8:map[0-9]+]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)> -// CHECK-DAG: #[[$SUBVIEW_MAP9:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 3 + d1 + 6)> -// CHECK-DAG: #[[$SUBVIEW_MAP10:map[0-9]+]] = affine_map<(d0) -> (d0 + 3)> // CHECK-DAG: #[[$SUBVIEW_MAP11:map[0-9]+]] = affine_map<() -> (4)> // CHECK-DAG: #[[$SUBVIEW_MAP12:map[0-9]+]] = affine_map<()[s0] -> (s0)> @@ -25,13 +11,13 @@ %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %0 = memref.alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> + %0 = memref.alloc() : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> // CHECK: subview %0[%c0, %c0, %c0] [%arg0, %arg1, %arg2] [%c1, %c1, %c1] : - // CHECK-SAME: memref<8x16x4xf32, #[[$BASE_MAP0]]> - // CHECK-SAME: to memref + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> + // CHECK-SAME: to memref> %1 = memref.subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1] - : memref<8x16x4xf32, offset:0, strides: [64, 4, 1]> to - memref + : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + memref> %2 = memref.alloc()[%arg2] : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> // CHECK: memref.subview %2[%c1] [%arg0] [%c1] : @@ -41,58 +27,58 @@ : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref (d0 * s1 + s0)>> - %4 = memref.alloc() : memref<64x22xf32, affine_map<(d0, d1) -> (d0 * 22 + d1)>> + %4 = memref.alloc() : memref<64x22xf32, strided<[22, 1]>> // CHECK: memref.subview %4[%c0, %c1] [%arg0, %arg1] [%c1, %c0] : - // CHECK-SAME: memref<64x22xf32, #[[$BASE_MAP2]]> - // CHECK-SAME: to memref + // CHECK-SAME: memref<64x22xf32, strided<[22, 1]>> + // CHECK-SAME: to memref> %5 = memref.subview %4[%c0, %c1][%arg0, %arg1][%c1, %c0] - : memref<64x22xf32, offset:0, strides: [22, 1]> to - memref + : memref<64x22xf32, strided<[22, 1], offset: 0>> 
to + memref> // CHECK: memref.subview %0[0, 2, 0] [4, 4, 4] [1, 1, 1] : - // CHECK-SAME: memref<8x16x4xf32, #[[$BASE_MAP0]]> - // CHECK-SAME: to memref<4x4x4xf32, #[[$SUBVIEW_MAP3]]> + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> + // CHECK-SAME: to memref<4x4x4xf32, strided<[64, 4, 1], offset: 8>> %6 = memref.subview %0[0, 2, 0][4, 4, 4][1, 1, 1] - : memref<8x16x4xf32, offset:0, strides: [64, 4, 1]> to - memref<4x4x4xf32, offset:8, strides: [64, 4, 1]> + : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + memref<4x4x4xf32, strided<[64, 4, 1], offset: 8>> %7 = memref.alloc(%arg1, %arg2) : memref // CHECK: memref.subview {{%.*}}[0, 0] [4, 4] [1, 1] : // CHECK-SAME: memref - // CHECK-SAME: to memref<4x4xf32, #[[$SUBVIEW_MAP4]]> + // CHECK-SAME: to memref<4x4xf32, strided<[?, 1], offset: ?>> %8 = memref.subview %7[0, 0][4, 4][1, 1] - : memref to memref<4x4xf32, offset: ?, strides:[?, 1]> + : memref to memref<4x4xf32, strided<[?, 1], offset: ?>> %9 = memref.alloc() : memref<16x4xf32> // CHECK: memref.subview {{%.*}}[{{%.*}}, {{%.*}}] [4, 4] [{{%.*}}, {{%.*}}] : // CHECK-SAME: memref<16x4xf32> - // CHECK-SAME: to memref<4x4xf32, #[[$SUBVIEW_MAP2]] + // CHECK-SAME: to memref<4x4xf32, strided<[?, ?], offset: ?>> %10 = memref.subview %9[%arg1, %arg1][4, 4][%arg2, %arg2] - : memref<16x4xf32> to memref<4x4xf32, offset: ?, strides:[?, ?]> + : memref<16x4xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> // CHECK: memref.subview {{%.*}}[{{%.*}}, {{%.*}}] [4, 4] [2, 2] : // CHECK-SAME: memref<16x4xf32> - // CHECK-SAME: to memref<4x4xf32, #[[$SUBVIEW_MAP5]] + // CHECK-SAME: to memref<4x4xf32, strided<[8, 2], offset: ?>> %11 = memref.subview %9[%arg1, %arg2][4, 4][2, 2] - : memref<16x4xf32> to memref<4x4xf32, offset: ?, strides:[8, 2]> + : memref<16x4xf32> to memref<4x4xf32, strided<[8, 2], offset: ?>> - %12 = memref.alloc() : memref<1x9x1x4x1xf32, affine_map<(d0, d1, d2, d3, d4) -> (36 * d0 + 36 * d1 + 4 * d2 + 4 * d3 + d4)>> + %12 = memref.alloc() : memref<1x9x1x4x1xf32, strided<[36, 36, 4, 4, 1]>> // CHECK: memref.subview %12[%arg1, %arg1, %arg1, %arg1, %arg1] // CHECK-SAME: [1, 9, 1, 4, 1] [%arg2, %arg2, %arg2, %arg2, %arg2] : - // CHECK-SAME: memref<1x9x1x4x1xf32, #[[$SUBVIEW_MAP6]]> to memref<9x4xf32, #[[$SUBVIEW_MAP2]]> - %13 = memref.subview %12[%arg1, %arg1, %arg1, %arg1, %arg1][1, 9, 1, 4, 1][%arg2, %arg2, %arg2, %arg2, %arg2] : memref<1x9x1x4x1xf32, offset: 0, strides: [36, 36, 4, 4, 1]> to memref<9x4xf32, offset: ?, strides: [?, ?]> + // CHECK-SAME: memref<1x9x1x4x1xf32, strided<[36, 36, 4, 4, 1]>> to memref<9x4xf32, strided<[?, ?], offset: ?>> + %13 = memref.subview %12[%arg1, %arg1, %arg1, %arg1, %arg1][1, 9, 1, 4, 1][%arg2, %arg2, %arg2, %arg2, %arg2] : memref<1x9x1x4x1xf32, strided<[36, 36, 4, 4, 1], offset: 0>> to memref<9x4xf32, strided<[?, ?], offset: ?>> // CHECK: memref.subview %12[%arg1, %arg1, %arg1, %arg1, %arg1] // CHECK-SAME: [1, 9, 1, 4, 1] [%arg2, %arg2, %arg2, %arg2, %arg2] : - // CHECK-SAME: memref<1x9x1x4x1xf32, #[[$SUBVIEW_MAP6]]> to memref<1x9x4xf32, #[[$BASE_MAP3]]> - %14 = memref.subview %12[%arg1, %arg1, %arg1, %arg1, %arg1][1, 9, 1, 4, 1][%arg2, %arg2, %arg2, %arg2, %arg2] : memref<1x9x1x4x1xf32, offset: 0, strides: [36, 36, 4, 4, 1]> to memref<1x9x4xf32, offset: ?, strides: [?, ?, ?]> + // CHECK-SAME: memref<1x9x1x4x1xf32, strided<[36, 36, 4, 4, 1]>> to memref<1x9x4xf32, strided<[?, ?, ?], offset: ?>> + %14 = memref.subview %12[%arg1, %arg1, %arg1, %arg1, %arg1][1, 9, 1, 4, 1][%arg2, %arg2, %arg2, %arg2, %arg2] : memref<1x9x1x4x1xf32, strided<[36, 
36, 4, 4, 1], offset: 0>> to memref<1x9x4xf32, strided<[?, ?, ?], offset: ?>> - %15 = memref.alloc(%arg1, %arg2)[%c0, %c1, %arg1, %arg0, %arg0, %arg2, %arg2] : memref<1x?x5x1x?x1xf32, affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * d0 + s2 * d1 + s3 * d2 + s4 * d3 + s5 * d4 + s6 * d5)>> + %15 = memref.alloc(%arg1, %arg2)[%c0, %c1, %arg1, %arg0, %arg0, %arg2, %arg2] : memref<1x?x5x1x?x1xf32, strided<[?, ?, ?, ?, ?, ?], offset: ?>> // CHECK: memref.subview %15[0, 0, 0, 0, 0, 0] [1, %arg1, 5, 1, %arg2, 1] [1, 1, 1, 1, 1, 1] : - // CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref - %16 = memref.subview %15[0, 0, 0, 0, 0, 0][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref + // CHECK-SAME: memref<1x?x5x1x?x1xf32, strided<[?, ?, ?, ?, ?, ?], offset: ?>> to memref> + %16 = memref.subview %15[0, 0, 0, 0, 0, 0][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, strided<[?, ?, ?, ?, ?, ?], offset: ?>> to memref> // CHECK: memref.subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1] [1, %arg1, 5, 1, %arg2, 1] [1, 1, 1, 1, 1, 1] : - // CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref - %17 = memref.subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref + // CHECK-SAME: memref<1x?x5x1x?x1xf32, strided<[?, ?, ?, ?, ?, ?], offset: ?>> to memref> + %17 = memref.subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, strided<[?, ?, ?, ?, ?, ?], offset: ?>> to memref> %18 = memref.alloc() : memref<1x8xf32> // CHECK: memref.subview %18[0, 0] [1, 8] [1, 1] : memref<1x8xf32> to memref<8xf32> @@ -102,19 +88,19 @@ // CHECK: memref.subview %20[0, 0, 0] [1, 16, 4] [1, 1, 1] : memref<8x16x4xf32> to memref<16x4xf32> %21 = memref.subview %20[0, 0, 0][1, 16, 4][1, 1, 1] : memref<8x16x4xf32> to memref<16x4xf32> - %22 = memref.subview %20[3, 4, 2][1, 6, 3][1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, offset: 210, strides: [4, 1]> + %22 = memref.subview %20[3, 4, 2][1, 6, 3][1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>> %23 = memref.alloc() : memref %78 = memref.subview %23[] [] [] : memref to memref /// Subview with only leading operands. %24 = memref.alloc() : memref<5x3xf32> - // CHECK: memref.subview %{{.*}}[2, 0] [3, 3] [1, 1] : memref<5x3xf32> to memref<3x3xf32, #[[$SUBVIEW_MAP9]]> - %25 = memref.subview %24[2, 0][3, 3][1, 1]: memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]> + // CHECK: memref.subview %{{.*}}[2, 0] [3, 3] [1, 1] : memref<5x3xf32> to memref<3x3xf32, strided<[3, 1], offset: 6>> + %25 = memref.subview %24[2, 0][3, 3][1, 1]: memref<5x3xf32> to memref<3x3xf32, strided<[3, 1], offset: 6>> /// Rank-reducing subview with only leading operands. - // CHECK: memref.subview %{{.*}}[1, 0] [1, 3] [1, 1] : memref<5x3xf32> to memref<3xf32, #[[$SUBVIEW_MAP10]]> - %26 = memref.subview %24[1, 0][1, 3][1, 1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]> + // CHECK: memref.subview %{{.*}}[1, 0] [1, 3] [1, 1] : memref<5x3xf32> to memref<3xf32, strided<[1], offset: 3>> + %26 = memref.subview %24[1, 0][1, 3][1, 1]: memref<5x3xf32> to memref<3xf32, strided<[1], offset: 3>> // Corner-case of 0-D rank-reducing subview with an offset. 
// CHECK: memref.subview %{{.*}}[1, 1] [1, 1] [1, 1] : memref<5x3xf32> to memref diff --git a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir --- a/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-fusion.mlir @@ -298,8 +298,8 @@ } scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { %A = memref.subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1] - : memref<2x2xf32> to memref - %A_elem = memref.load %A[%i, %j] : memref + : memref<2x2xf32> to memref> + %A_elem = memref.load %A[%i, %j] : memref> scf.yield } return diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir @@ -1,11 +1,11 @@ // RUN: mlir-opt %s -test-vector-transfer-drop-unit-dims-patterns -split-input-file | FileCheck %s func.func @transfer_read_rank_reducing( - %arg : memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>) -> vector<3x2xi8> { + %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> { %c0 = arith.constant 0 : index %cst = arith.constant 0 : i8 %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>, vector<3x2xi8> + memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8> return %v : vector<3x2xi8> } @@ -17,10 +17,10 @@ // ----- -func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]>, %vec : vector<3x2xi8>) { +func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) { %c0 = arith.constant 0 : index vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : - vector<3x2xi8>, memref<1x1x3x2xi8, offset:?, strides:[6, 6, 2, 1]> + vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>> return } diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -1,11 +1,11 @@ // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s func.func @transfer_read_flattenable_with_offset( - %arg : memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>) -> vector<5x4x3x2xi8> { + %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> { %c0 = arith.constant 0 : index %cst = arith.constant 0 : i8 %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>, vector<5x4x3x2xi8> + memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8> return %v : vector<5x4x3x2xi8> } @@ -19,10 +19,10 @@ // ----- func.func @transfer_write_flattenable_with_offset( - %arg : memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>, %vec : vector<5x4x3x2xi8>) { + %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) { %c0 = arith.constant 0 : index vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : - vector<5x4x3x2xi8>, memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]> + vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> return } diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir 
b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir @@ -106,7 +106,7 @@ // LINALG-SAME: %[[i:[a-zA-Z0-9]*]]: index // LINALG-SAME: %[[j:[a-zA-Z0-9]*]]: index func.func @split_vector_transfer_read_strided_2d( - %A: memref<7x8xf32, offset:?, strides:[?, 1]>, + %A: memref<7x8xf32, strided<[?, 1], offset: ?>>, %i: index, %j: index) -> vector<4x8xf32> { %c0 = arith.constant 0 : index %f0 = arith.constant 0.0 : f32 @@ -127,13 +127,13 @@ // CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref, index, index) { // inBounds but not cast-compatible: yield a memref_casted form of %A // CHECK: %[[casted:.*]] = memref.cast %arg0 : - // CHECK-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref + // CHECK-SAME: memref<7x8xf32, strided<[?, 1], offset: ?>> to memref // CHECK: scf.yield %[[casted]], %[[i]], %[[j]] : // CHECK-SAME: memref, index, index // CHECK: } else { // slow path, fill tmp alloc and yield a memref_casted version of it // CHECK: %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %cst : - // CHECK-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]>, vector<4x8xf32> + // CHECK-SAME: memref<7x8xf32, strided<[?, 1], offset: ?>>, vector<4x8xf32> // CHECK: %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] : // CHECK-SAME: memref<4x8xf32> to memref> // CHECK: store %[[slow]], %[[cast_alloc]][] : @@ -163,7 +163,7 @@ // LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref, index, index) { // inBounds but not cast-compatible: yield a memref_casted form of %A // LINALG: %[[casted:.*]] = memref.cast %arg0 : - // LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref + // LINALG-SAME: memref<7x8xf32, strided<[?, 1], offset: ?>> to memref // LINALG: scf.yield %[[casted]], %[[i]], %[[j]] : // LINALG-SAME: memref, index, index // LINALG: } else { @@ -172,7 +172,7 @@ // LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[c7]], %[[i]], %[[c4]]) // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]]) // LINALG: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1] - // LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref + // LINALG-SAME: memref<7x8xf32, strided<[?, 1], offset: ?>> to memref // LINALG: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1] // LINALG: memref.copy %[[sv]], %[[alloc_view]] : memref to memref // LINALG: %[[yielded:.*]] = memref.cast %[[alloc]] : @@ -183,7 +183,7 @@ // LINALG: %[[res:.*]] = vector.transfer_read {{.*}} {in_bounds = [true, true]} : // LINALG-SAME: memref, vector<4x8xf32> %1 = vector.transfer_read %A[%i, %j], %f0 : - memref<7x8xf32, offset:?, strides:[?, 1]>, vector<4x8xf32> + memref<7x8xf32, strided<[?, 1], offset: ?>>, vector<4x8xf32> // CHECK: return %[[res]] : vector<4x8xf32> return %1 : vector<4x8xf32> @@ -288,10 +288,10 @@ // ----- func.func @split_vector_transfer_write_strided_2d( - %V: vector<4x8xf32>, %A: memref<7x8xf32, offset:?, strides:[?, 1]>, + %V: vector<4x8xf32>, %A: memref<7x8xf32, strided<[?, 1], offset: ?>>, %i: index, %j: index) { vector.transfer_write %V, %A[%i, %j] : - vector<4x8xf32>, memref<7x8xf32, offset:?, strides:[?, 1]> + vector<4x8xf32>, memref<7x8xf32, strided<[?, 1], offset: ?>> return } @@ -300,7 +300,7 @@ // CHECK-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 + 8)> // CHECK: func @split_vector_transfer_write_strided_2d( // CHECK-SAME: %[[VEC:.*]]: vector<4x8xf32>, -// CHECK-SAME: 
%[[DEST:.*]]: memref<7x8xf32, #[[MAP0]]>, +// CHECK-SAME: %[[DEST:.*]]: memref<7x8xf32, strided<[?, 1], offset: ?>>, // CHECK-SAME: %[[I:.*]]: index, // CHECK-SAME: %[[J:.*]]: index) { // CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index @@ -316,7 +316,7 @@ // CHECK: %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]] // CHECK-SAME: -> (memref, index, index) { // CHECK: %[[VAL_15:.*]] = memref.cast %[[DEST]] -// CHECK-SAME: : memref<7x8xf32, #[[MAP0]]> to memref +// CHECK-SAME: : memref<7x8xf32, strided<[?, 1], offset: ?>> to memref // CHECK: scf.yield %[[VAL_15]], %[[I]], %[[J]] // CHECK-SAME: : memref, index, index // CHECK: } else { @@ -336,7 +336,7 @@ // CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_19]][] // CHECK-SAME: : memref> // CHECK: vector.transfer_write %[[VAL_20]], %[[DEST]][%[[I]], %[[J]]] -// CHECK-SAME: : vector<4x8xf32>, memref<7x8xf32, #[[MAP0]]> +// CHECK-SAME: : vector<4x8xf32>, memref<7x8xf32, strided<[?, 1], offset: ?>> // CHECK: } // CHECK: return // CHECK: } @@ -349,7 +349,7 @@ // LINALG-DAG: #[[MAP5:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1)> // LINALG: func @split_vector_transfer_write_strided_2d( // LINALG-SAME: %[[VEC:.*]]: vector<4x8xf32>, -// LINALG-SAME: %[[DEST:.*]]: memref<7x8xf32, #[[MAP0]]>, +// LINALG-SAME: %[[DEST:.*]]: memref<7x8xf32, strided<[?, 1], offset: ?>>, // LINALG-SAME: %[[I:.*]]: index, // LINALG-SAME: %[[J:.*]]: index) { // LINALG-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -366,7 +366,7 @@ // LINALG: %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]] // LINALG-SAME: -> (memref, index, index) { // LINALG: %[[VAL_16:.*]] = memref.cast %[[DEST]] -// LINALG-SAME: : memref<7x8xf32, #[[MAP0]]> to memref +// LINALG-SAME: : memref<7x8xf32, strided<[?, 1], offset: ?>> to memref // LINALG: scf.yield %[[VAL_16]], %[[I]], %[[J]] // LINALG-SAME: : memref, index, index // LINALG: } else { diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -8,9 +8,6 @@ // CHECK: #map1 = affine_map<()[s0] -> (s0 + 1)> -// CHECK-DAG: #[[$BASE_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> -// CHECK-DAG: #[[$BASE_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> - // CHECK-LABEL: func @func_with_ops // CHECK-SAME: %[[ARG:.*]]: f32 func.func @func_with_ops(f32) { @@ -236,18 +233,18 @@ } // CHECK-LABEL: func @memref_cast(%arg0 -func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref, %arg2 : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>) { +func.func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref, %arg2 : memref<64x16x4xf32, strided<[64, 4, 1], offset: 0>>) { // CHECK: %0 = memref.cast %arg0 : memref<4xf32> to memref %0 = memref.cast %arg0 : memref<4xf32> to memref // CHECK: %1 = memref.cast %arg1 : memref to memref<4xf32> %1 = memref.cast %arg1 : memref to memref<4xf32> - // CHECK: {{%.*}} = memref.cast %arg2 : memref<64x16x4xf32, #[[$BASE_MAP0]]> to memref<64x16x4xf32, #[[$BASE_MAP3]]> - %2 = memref.cast %arg2 : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]> to memref<64x16x4xf32, offset: ?, strides: [?, ?, ?]> + // CHECK: {{%.*}} = memref.cast %arg2 : memref<64x16x4xf32, strided<[64, 4, 1]>> to memref<64x16x4xf32, strided<[?, ?, ?], offset: ?>> + %2 = memref.cast %arg2 : memref<64x16x4xf32, strided<[64, 4, 1], offset: 0>> to memref<64x16x4xf32, strided<[?, ?, ?], offset: ?>> - // CHECK: {{%.*}} = memref.cast {{%.*}} : memref<64x16x4xf32, #[[$BASE_MAP3]]> to memref<64x16x4xf32, #[[$BASE_MAP0]]> - %3 
= memref.cast %2 : memref<64x16x4xf32, offset: ?, strides: [?, ?, ?]> to memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>
+  // CHECK: {{%.*}} = memref.cast {{%.*}} : memref<64x16x4xf32, strided<[?, ?, ?], offset: ?>> to memref<64x16x4xf32, strided<[64, 4, 1]>>
+  %3 = memref.cast %2 : memref<64x16x4xf32, strided<[?, ?, ?], offset: ?>> to memref<64x16x4xf32, strided<[64, 4, 1], offset: 0>>
   // CHECK: memref.cast %{{.*}} : memref<4xf32> to memref<*xf32>
   %4 = memref.cast %1 : memref<4xf32> to memref<*xf32>
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -64,36 +64,58 @@
 // -----
-func.func @memref_space_after_strides(memref<42x42xi8, 0, offset: ?, strides: [?, ?]>) // expected-error {{expected memory space to be last in memref type}}
+// expected-error @below {{expected '<' after 'strided'}}
+func.func private @memref_unfinished_strided() -> memref
 // -----
-func.func @memref_stride_missing_colon(memref<42x42xi8, offset ?, strides: [?, ?]>) // expected-error {{expected colon after `offset` keyword}}
+// expected-error @below {{expected '['}}
+func.func private @memref_unfinished_strided() -> memref>
 // -----
-func.func @memref_stride_invalid_offset(memref<42x42xi8, offset: [], strides: [?, ?]>) // expected-error {{invalid offset}}
+// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
+func.func private @memref_unfinished_stride_list() -> memref>
 // -----
-func.func @memref_stride_missing_strides(memref<42x42xi8, offset: 0 [?, ?]>) // expected-error {{expected comma after offset value}}
+// expected-error @below {{expected 'offset' after comma}}
+func.func private @memref_missing_offset() -> memref>
 // -----
-func.func @memref_stride_missing_strides(memref<42x42xi8, offset: 0, [?, ?]>) // expected-error {{expected `strides` keyword after offset specification}}
+// expected-error @below {{expected ':' after 'offset'}}
+func.func private @memref_missing_offset_colon() -> memref>
 // -----
-func.func @memref_stride_missing_colon_2(memref<42x42xi8, offset: 0, strides [?, ?]>) // expected-error {{expected colon after `strides` keyword}}
+// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
+func.func private @memref_missing_offset_value() -> memref>
 // -----
-// expected-error @+1 {{expected '['}}
-func.func @memref_stride_invalid_strides(memref<42x42xi8, offset: 0, strides: ()>)
+// expected-error @below {{expected '>'}}
+func.func private @memref_incorrect_strided_ending() -> memref
 // -----
-func.func @memref_zero_stride(memref<42x42xi8, offset: ?, strides: [0, ?]>) // expected-error {{invalid memref stride}}
+// expected-error @below {{strides must be positive or dynamic}}
+func.func private @memref_zero_stride() -> memref>
+
+// -----
+
+// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
+func.func private @memref_negative_stride() -> memref>
+
+// -----
+
+// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
+func.func private @memref_negative_offset() -> memref>
+
+// -----
+
+// expected-error @below {{expected the number of strides to match the rank}}
+func.func private @memref_strided_rank_mismatch() -> memref>
 // -----
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
--- a/mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/matmul-vs-matvec.mlir
@@ -28,10 +28,10 @@
   %C = memref.alloc(%m, %n) : memref
   linalg.fill ins(%f0 : f32) outs(%C : memref)
   scf.for %i = %c0 to %n step %c1 {
-    %b = memref.subview %B[0, %i][%x, 1][1, 1] : memref to memref
-    %c = memref.subview %C[0, %i][%m, 1][1, 1] : memref to memref
-    linalg.matvec ins(%A, %b: memref, memref)
-      outs(%c: memref)
+    %b = memref.subview %B[0, %i][%x, 1][1, 1] : memref to memref>
+    %c = memref.subview %C[0, %i][%m, 1][1, 1] : memref to memref>
+    linalg.matvec ins(%A, %b: memref, memref>)
+      outs(%c: memref>)
   }
   return %C : memref
 }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/rank-reducing-subview.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/rank-reducing-subview.mlir
--- a/mlir/test/Integration/Dialect/Linalg/CPU/rank-reducing-subview.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/rank-reducing-subview.mlir
@@ -18,13 +18,13 @@
   memref.store %f1, %A[%c0, %c1] : memref
   memref.store %f2, %A[%c1, %c0] : memref
   memref.store %f3, %A[%c1, %c1] : memref
-  %B = memref.subview %A[%c1, 0][1, %c2][1, 1] : memref to memref
-  %C = memref.subview %A[0, %c1][%c2, 1][1, 1] : memref to memref
+  %B = memref.subview %A[%c1, 0][1, %c2][1, 1] : memref to memref>
+  %C = memref.subview %A[0, %c1][%c2, 1][1, 1] : memref to memref>
   %A_ = memref.cast %A : memref to memref<*xf32>
   call @printMemrefF32(%A_) : (memref<*xf32>) -> ()
-  %B_ = memref.cast %B : memref to memref<*xf32>
+  %B_ = memref.cast %B : memref> to memref<*xf32>
   call @printMemrefF32(%B_) : (memref<*xf32>) -> ()
-  %C_ = memref.cast %C : memref to memref<*xf32>
+  %C_ = memref.cast %C : memref> to memref<*xf32>
   call @printMemrefF32(%C_) : (memref<*xf32>) -> ()
   memref.dealloc %A : memref
   return
diff --git a/mlir/test/Integration/Dialect/Standard/CPU/test_subview.mlir b/mlir/test/Integration/Dialect/Standard/CPU/test_subview.mlir
--- a/mlir/test/Integration/Dialect/Standard/CPU/test_subview.mlir
+++ b/mlir/test/Integration/Dialect/Standard/CPU/test_subview.mlir
@@ -13,8 +13,8 @@
   %0 = memref.get_global @__constant_5x3xf32 : memref<5x3xf32>
   /// Subview with only leading operands.
-  %1 = memref.subview %0[2, 0][3, 3][1, 1]: memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]>
-  %unranked = memref.cast %1 : memref<3x3xf32, offset: 6, strides: [3, 1]> to memref<*xf32>
+  %1 = memref.subview %0[2, 0][3, 3][1, 1]: memref<5x3xf32> to memref<3x3xf32, strided<[3, 1], offset: 6>>
+  %unranked = memref.cast %1 : memref<3x3xf32, strided<[3, 1], offset: 6>> to memref<*xf32>
   call @printMemrefF32(%unranked) : (memref<*xf32>) -> ()
   // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
@@ -26,8 +26,8 @@
   // CHECK-SAME: ]
   /// Regular subview.
-  %2 = memref.subview %0[0, 2][5, 1][1, 1]: memref<5x3xf32> to memref<5x1xf32, offset: 2, strides: [3, 1]>
-  %unranked2 = memref.cast %2 : memref<5x1xf32, offset: 2, strides: [3, 1]> to memref<*xf32>
+  %2 = memref.subview %0[0, 2][5, 1][1, 1]: memref<5x3xf32> to memref<5x1xf32, strided<[3, 1], offset: 2>>
+  %unranked2 = memref.cast %2 : memref<5x1xf32, strided<[3, 1], offset: 2>> to memref<*xf32>
   call @printMemrefF32(%unranked2) : (memref<*xf32>) -> ()
   // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
@@ -41,8 +41,8 @@
   // CHECK-SAME: ]
   /// Rank-reducing subview.
- %3 = memref.subview %0[0, 2][5, 1][1, 1]: memref<5x3xf32> to memref<5xf32, offset: 2, strides: [3]> - %unranked3 = memref.cast %3 : memref<5xf32, offset: 2, strides: [3]> to memref<*xf32> + %3 = memref.subview %0[0, 2][5, 1][1, 1]: memref<5x3xf32> to memref<5xf32, strided<[3], offset: 2>> + %unranked3 = memref.cast %3 : memref<5xf32, strided<[3], offset: 2>> to memref<*xf32> call @printMemrefF32(%unranked3) : (memref<*xf32>) -> () // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}} @@ -50,8 +50,8 @@ // CHECK-NEXT: [2, 5, 8, 11, 14] /// Rank-reducing subview with only leading operands. - %4 = memref.subview %0[1, 0][1, 3][1, 1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]> - %unranked4 = memref.cast %4 : memref<3xf32, offset: 3, strides: [1]> to memref<*xf32> + %4 = memref.subview %0[1, 0][1, 3][1, 1]: memref<5x3xf32> to memref<3xf32, strided<[1], offset: 3>> + %unranked4 = memref.cast %4 : memref<3xf32, strided<[1], offset: 3>> to memref<*xf32> call @printMemrefF32(%unranked4) : (memref<*xf32>) -> () // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}} // CHECK-SAME: rank = 1 offset = 3 sizes = [3] strides = [1] data = diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir @@ -70,9 +70,9 @@ %c6 = arith.constant 6 : index %fm42 = arith.constant -42.0: f32 %1 = memref.reinterpret_cast %A to offset: [%c6], sizes: [%c1, %c2], strides: [%c6, %c1] - : memref to memref + : memref to memref> %2 = vector.transfer_read %1[%c2, %c1], %fm42 {in_bounds=[true]} - : memref, vector<4xf32> + : memref>, vector<4xf32> vector.print %2 : vector<4xf32> return } diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -764,86 +764,86 @@ %c15 = arith.constant 15 : index // CHECK: %[[ALLOC0:.*]] = memref.alloc() - %0 = memref.alloc() : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> + %0 = memref.alloc() : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> // Test: subview with constant base memref and constant operands is folded. // Note that the subview uses the base memrefs layout map because it used // zero offset and unit stride arguments. // CHECK: memref.subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [1, 1, 1] : - // CHECK-SAME: memref<8x16x4xf32, #[[$BASE_MAP0]]> + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> // CHECK-SAME: to memref<7x11x2xf32, #[[$BASE_MAP0]]> %1 = memref.subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1] - : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to - memref - %v0 = memref.load %1[%c0, %c0, %c0] : memref + : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + memref> + %v0 = memref.load %1[%c0, %c0, %c0] : memref> // Test: subview with one dynamic operand can also be folded. 
// CHECK: memref.subview %[[ALLOC0]][0, %[[ARG0]], 0] [7, 11, 15] [1, 1, 1] : - // CHECK-SAME: memref<8x16x4xf32, #[[$BASE_MAP0]]> + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> // CHECK-SAME: to memref<7x11x15xf32, #[[$SUBVIEW_MAP0]]> %2 = memref.subview %0[%c0, %arg0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1] - : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to - memref - memref.store %v0, %2[%c0, %c0, %c0] : memref + : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + memref> + memref.store %v0, %2[%c0, %c0, %c0] : memref> // CHECK: %[[ALLOC1:.*]] = memref.alloc(%[[ARG0]]) - %3 = memref.alloc(%arg0) : memref + %3 = memref.alloc(%arg0) : memref> // Test: subview with constant operands but dynamic base memref is folded as long as the strides and offset of the base memref are static. // CHECK: memref.subview %[[ALLOC1]][0, 0, 0] [7, 11, 15] [1, 1, 1] : - // CHECK-SAME: memref + // CHECK-SAME: memref> // CHECK-SAME: to memref<7x11x15xf32, #[[$BASE_MAP0]]> %4 = memref.subview %3[%c0, %c0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1] - : memref to - memref - memref.store %v0, %4[%c0, %c0, %c0] : memref + : memref> to + memref> + memref.store %v0, %4[%c0, %c0, %c0] : memref> // Test: subview offset operands are folded correctly w.r.t. base strides. // CHECK: memref.subview %[[ALLOC0]][1, 2, 7] [7, 11, 2] [1, 1, 1] : - // CHECK-SAME: memref<8x16x4xf32, #[[$BASE_MAP0]]> to + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> to // CHECK-SAME: memref<7x11x2xf32, #[[$SUBVIEW_MAP1]]> %5 = memref.subview %0[%c1, %c2, %c7] [%c7, %c11, %c2] [%c1, %c1, %c1] - : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to - memref - memref.store %v0, %5[%c0, %c0, %c0] : memref + : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + memref> + memref.store %v0, %5[%c0, %c0, %c0] : memref> // Test: subview stride operands are folded correctly w.r.t. base strides. 
// CHECK: memref.subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [2, 7, 11] : - // CHECK-SAME: memref<8x16x4xf32, #[[$BASE_MAP0]]> + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> // CHECK-SAME: to memref<7x11x2xf32, #[[$SUBVIEW_MAP2]]> %6 = memref.subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c2, %c7, %c11] - : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to - memref - memref.store %v0, %6[%c0, %c0, %c0] : memref + : memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + memref> + memref.store %v0, %6[%c0, %c0, %c0] : memref> // Test: subview shape are folded, but offsets and strides are not even if base memref is static // CHECK: memref.subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [7, 11, 2] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : - // CHECK-SAME: memref<8x16x4xf32, #[[$BASE_MAP0]]> to + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> to // CHECK-SAME: memref<7x11x2xf32, #[[$SUBVIEW_MAP3]]> %10 = memref.subview %0[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : - memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to - memref + memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + memref> memref.store %v0, %10[%arg1, %arg1, %arg1] : - memref + memref> // Test: subview strides are folded, but offsets and shape are not even if base memref is static // CHECK: memref.subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [2, 7, 11] : - // CHECK-SAME: memref<8x16x4xf32, #[[$BASE_MAP0]]> to + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> to // CHECK-SAME: memref to - memref + memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + memref> memref.store %v0, %11[%arg0, %arg0, %arg0] : - memref + memref> // Test: subview offsets are folded, but strides and shape are not even if base memref is static // CHECK: memref.subview %[[ALLOC0]][1, 2, 7] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] : - // CHECK-SAME: memref<8x16x4xf32, #[[$BASE_MAP0]]> to + // CHECK-SAME: memref<8x16x4xf32, strided<[64, 4, 1]>> to // CHECK-SAME: memref to - memref + memref<8x16x4xf32, strided<[64, 4, 1], offset: 0>> to + memref> memref.store %v0, %13[%arg1, %arg1, %arg1] : - memref + memref> // CHECK: %[[ALLOC2:.*]] = memref.alloc(%[[ARG0]], %[[ARG0]], %[[ARG1]]) %14 = memref.alloc(%arg0, %arg0, %arg1) : memref @@ -853,8 +853,8 @@ // CHECK-SAME: memref<7x11x2xf32, #[[$SUBVIEW_MAP3]]> %15 = memref.subview %14[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : memref to - memref - memref.store %v0, %15[%arg1, %arg1, %arg1] : memref + memref> + memref.store %v0, %15[%arg1, %arg1, %arg1] : memref> // TEST: subview strides are folded, in the type only the most minor stride is folded. 
// CHECK: memref.subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [2, 2, 2] : @@ -862,8 +862,8 @@ // CHECK-SAME: memref to - memref - memref.store %v0, %16[%arg0, %arg0, %arg0] : memref + memref> + memref.store %v0, %16[%arg0, %arg0, %arg0] : memref> // TEST: subview offsets are folded but the type offset remains dynamic, when the base memref is not static // CHECK: memref.subview %[[ALLOC2]][1, 1, 1] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : @@ -871,8 +871,8 @@ // CHECK-SAME: memref to - memref - memref.store %v0, %17[%arg0, %arg0, %arg0] : memref + memref> + memref.store %v0, %17[%arg0, %arg0, %arg0] : memref> // CHECK: %[[ALLOC3:.*]] = memref.alloc() : memref<12x4xf32> %18 = memref.alloc() : memref<12x4xf32> @@ -884,8 +884,8 @@ // CHECK-SAME: memref<2x4xf32, #[[$SUBVIEW_MAP7]]> %19 = memref.subview %18[%arg1, %arg1] [%c2, %c4] [1, 1] : memref<12x4xf32> to - memref - memref.store %v0, %19[%arg1, %arg1] : memref + memref> + memref.store %v0, %19[%arg1, %arg1] : memref> // TEST: subview strides and sizes are maintained when offsets are folded // CHECK: memref.subview %[[ALLOC3]][2, 4] [12, 4] [1, 1] : @@ -893,12 +893,12 @@ // CHECK-SAME: memref<12x4xf32, #[[$SUBVIEW_MAP8]]> %20 = memref.subview %18[%c2, %c4] [12, 4] [1, 1] : memref<12x4xf32> to - memref<12x4xf32, offset: ?, strides:[4, 1]> - memref.store %v0, %20[%arg1, %arg1] : memref<12x4xf32, offset: ?, strides:[4, 1]> + memref<12x4xf32, strided<[4, 1], offset: ?>> + memref.store %v0, %20[%arg1, %arg1] : memref<12x4xf32, strided<[4, 1], offset: ?>> // Test: dim on subview is rewritten to size operand. - %7 = memref.dim %4, %c0 : memref - %8 = memref.dim %4, %c1 : memref + %7 = memref.dim %4, %c0 : memref> + %8 = memref.dim %4, %c1 : memref> // CHECK: return %[[C7]], %[[C11]] return %7, %8 : index, index @@ -1049,28 +1049,28 @@ // ----- // CHECK-LABEL: func @memref_cast_folding_subview -func.func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref) { +func.func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref>) { %0 = memref.cast %arg0 : memref<4x5xf32> to memref // CHECK-NEXT: memref.subview %{{.*}}: memref<4x5xf32> - %1 = memref.subview %0[%i, %i][%i, %i][%i, %i]: memref to memref + %1 = memref.subview %0[%i, %i][%i, %i][%i, %i]: memref to memref> + // CHECK-NEXT: memref.cast // CHECK-NEXT: return %{{.*}} - return %1: memref + return %1: memref> } // ----- // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d0 * 16 + d1)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-LABEL: func @memref_cast_folding_subview_static( func.func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: index) - -> memref<3x4xf32, offset:?, strides:[?, 1]> + -> memref<3x4xf32, strided<[?, 1], offset: ?>> { %0 = memref.cast %V : memref<16x16xf32> to memref - %1 = memref.subview %0[0, 0][3, 4][1, 1] : memref to memref<3x4xf32, offset:?, strides:[?, 1]> + %1 = memref.subview %0[0, 0][3, 4][1, 1] : memref to memref<3x4xf32, strided<[?, 1], offset: ?>> // CHECK: memref.subview{{.*}}: memref<16x16xf32> to memref<3x4xf32, #[[$map0]]> - return %1: memref<3x4xf32, offset:?, strides:[?, 1]> + return %1: memref<3x4xf32, strided<[?, 1], offset: ?>> } // ----- diff --git a/mlir/test/mlir-cpu-runner/copy.mlir b/mlir/test/mlir-cpu-runner/copy.mlir --- a/mlir/test/mlir-cpu-runner/copy.mlir +++ b/mlir/test/mlir-cpu-runner/copy.mlir @@ -37,8 +37,8 @@ %copy_two = memref.alloc() : memref<3x2xf32> 
  %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides: [1, 2]
-    : memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
-  memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
+    : memref<3x2xf32> to memref<2x3xf32, strided<[1, 2], offset: 0>>
+  memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, strided<[1, 2], offset: 0>>
   %unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
   call @printMemrefF32(%unranked_copy_two) : (memref<*xf32>) -> ()
   // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1]
@@ -52,10 +52,10 @@
   memref.copy %input_empty, %copy_empty : memref<3x0x1xf32> to memref<3x0x1xf32>
   %input_empty_casted = memref.reinterpret_cast %input_empty to offset: [0], sizes: [0, 3, 1], strides: [3, 1, 1]
-    : memref<3x0x1xf32> to memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]>
+    : memref<3x0x1xf32> to memref<0x3x1xf32, strided<[3, 1, 1], offset: 0>>
   %copy_empty_casted = memref.alloc() : memref<0x3x1xf32>
   // Copying a casted empty shape should do nothing (and should not crash).
-  memref.copy %input_empty_casted, %copy_empty_casted : memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]> to memref<0x3x1xf32>
+  memref.copy %input_empty_casted, %copy_empty_casted : memref<0x3x1xf32, strided<[3, 1, 1], offset: 0>> to memref<0x3x1xf32>
   %scalar = memref.alloc() : memref
   memref.store %c42, %scalar[] : memref
diff --git a/mlir/test/mlir-cpu-runner/memref-reinterpret-cast.mlir b/mlir/test/mlir-cpu-runner/memref-reinterpret-cast.mlir
--- a/mlir/test/mlir-cpu-runner/memref-reinterpret-cast.mlir
+++ b/mlir/test/mlir-cpu-runner/memref-reinterpret-cast.mlir
@@ -59,10 +59,10 @@
   %c6 = arith.constant 6 : index
   %output = memref.reinterpret_cast %input to offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
-    : memref<2x3xf32> to memref
+    : memref<2x3xf32> to memref>
   %unranked_output = memref.cast %output
-    : memref to memref<*xf32>
+    : memref> to memref<*xf32>
   call @printMemrefF32(%unranked_output) : (memref<*xf32>) -> ()
   // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
   // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
@@ -95,10 +95,10 @@
   %c6 = arith.constant 6 : index
   %output = memref.reinterpret_cast %unranked_input to offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
-    : memref<*xf32> to memref
+    : memref<*xf32> to memref>
   %unranked_output = memref.cast %output
-    : memref to memref<*xf32>
+    : memref> to memref<*xf32>
   call @printMemrefF32(%unranked_output) : (memref<*xf32>) -> ()
   // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
   // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
diff --git a/mlir/test/python/dialects/memref.py b/mlir/test/python/dialects/memref.py
--- a/mlir/test/python/dialects/memref.py
+++ b/mlir/test/python/dialects/memref.py
@@ -24,7 +24,7 @@
     %3 = arith.constant 3 : index
     %4 = arith.constant 4 : index
     %5 = arith.constant 5 : index
-    memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref to memref
+    memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref to memref>
     return
   }
 """, ctx)