diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
--- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h
@@ -17,6 +17,7 @@
 namespace intrinsics {

 using linalg_fill = OperationBuilder<linalg::FillOp>;
+using linalg_reshape = OperationBuilder<linalg::ReshapeOp>;
 using linalg_yield = OperationBuilder<linalg::YieldOp>;

 } // namespace intrinsics
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -58,6 +58,58 @@
   let verifier = ?;
 }

+def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
+    Arguments<(ins AnyStridedMemRef:$view, AffineMapArrayAttr:$reassociation)>,
+    Results<(outs AnyStridedMemRef)> {
+  let summary = "linalg.reshape produces a new view into the operand view";
+  let description = [{
+    The `linalg.reshape` op produces a new view whose sizes are a reassociation
+    of the original `view`. Depending on whether or not the reassociated
+    MemRefType is contiguous, the resulting memref may require explicit alloc
+    and copies.
+
+    A reassociation is defined as a contiguous grouping of dimensions and is
+    represented with an affine map array attribute. In the future,
+    non-contiguous groupings may be allowed (i.e. permutations, reindexings
+    etc).
+
+    For now, it is assumed that either:
+      1. a reassociation produces and consumes a contiguous MemRefType or,
+      2. the reshape op will be folded into its consumers (by changing the
+         shape of the computations).
+    All other cases are undefined behavior and a reshape op may not lower to
+    LLVM if it cannot be proven statically that it does not require alloc+copy.
+
+    A reshape may either collapse or expand dimensions, depending on the
+    relationship between source and target memref ranks. The verification rule
+    is that the reassociation maps are applied to the memref with the larger
+    rank to obtain the memref with the smaller rank. In the case of a dimension
+    expansion, the reassociation maps can be interpreted as inverse maps.
+
+    Examples:
+
+    ```mlir
+    // Dimension collapse (i, j) -> i' and k -> k'
+    %1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+      memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
+    ```
+
+    ```mlir
+    // Dimension expansion i -> (i', j') and (k) -> (k')
+    %1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+      memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
+    ```
+  }];
+
+  let builders = [OpBuilder<
+    "Builder *b, OperationState &result, Value view, "
+    "ArrayAttr reassociation, ArrayRef<NamedAttribute> attrs = {}">];
+
+  let extraClassDeclaration = [{
+    static StringRef getReassociationAttrName() { return "reassociation"; }
+    MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
+  }];
+}
+
 def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
     Arguments<(ins AnyStridedMemRef:$view,
                    Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -87,6 +87,7 @@

   template <typename U> bool isa() const;
   template <typename U> U dyn_cast() const;
+  template <typename U> U dyn_cast_or_null() const;
   template <typename U> U cast() const;

   MLIRContext *getContext() const;
@@ -226,25 +227,23 @@
 raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr);

 template <typename U> bool AffineExpr::isa() const {
-  if (std::is_same<U, AffineBinaryOpExpr>::value) {
+  if (std::is_same<U, AffineBinaryOpExpr>::value)
     return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP;
-  }
-  if (std::is_same<U, AffineDimExpr>::value) {
+  if (std::is_same<U, AffineDimExpr>::value)
     return getKind() == AffineExprKind::DimId;
-  }
-  if (std::is_same<U, AffineSymbolExpr>::value) {
+  if (std::is_same<U, AffineSymbolExpr>::value)
     return getKind() == AffineExprKind::SymbolId;
-  }
-  if (std::is_same<U, AffineConstantExpr>::value) {
+  if (std::is_same<U, AffineConstantExpr>::value)
     return getKind() == AffineExprKind::Constant;
-  }
 }
 template <typename U> U AffineExpr::dyn_cast() const {
-  if (isa<U>()) {
+  if (isa<U>())
     return U(expr);
-  }
   return U(nullptr);
 }
+template <typename U> U AffineExpr::dyn_cast_or_null() const {
+  return (!*this || !isa<U>()) ? U(nullptr) : U(expr);
+}
 template <typename U> U AffineExpr::cast() const {
   assert(isa<U>());
   return U(expr);
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -16,6 +16,7 @@
 } // namespace llvm

 namespace mlir {
+class AffineExpr;
 class AffineMap;
 class FloatType;
 class IndexType;
@@ -245,6 +246,9 @@

   /// Whether the given dimension size indicates a dynamic dimension.
   static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; }
+  static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
+    return dStrideOrOffset == kDynamicStrideOrOffset;
+  }
 };

 /// Vector types represent multi-dimensional SIMD vectors, and have a fixed
@@ -548,6 +552,9 @@
 LogicalResult getStridesAndOffset(MemRefType t,
                                   SmallVectorImpl<int64_t> &strides,
                                   int64_t &offset);
+LogicalResult getStridesAndOffset(MemRefType t,
+                                  SmallVectorImpl<AffineExpr> &strides,
+                                  AffineExpr &offset);

 /// Given a list of strides (in which MemRefType::getDynamicStrideOrOffset()
 /// represents a dynamic value), return the single result AffineMap which
@@ -569,6 +576,13 @@
 AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
                                      MLIRContext *context);

+/// Return a version of `t` with identity layout if it can be determined
+/// statically that the layout is the canonical contiguous strided layout.
+/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
+/// `t` with simplified layout.
+MemRefType canonicalizeStridedLayout(MemRefType t);
+
+/// Return true if the layout for `t` is compatible with strided semantics.
 bool isStrided(MemRefType t);

 } // end namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements a the Linalg operations.
+// This file implements the Linalg operations.
 //
 //===----------------------------------------------------------------------===//

@@ -23,6 +23,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/Functional.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
 #include "mlir/Transforms/FoldUtils.h"
@@ -332,6 +333,206 @@
                  parser.addTypeToList(type, result.types));
 }

+//===----------------------------------------------------------------------===//
+// ReshapeOp
+//===----------------------------------------------------------------------===//
+
+/// Return true if the reassociation specification is valid, false otherwise.
+/// When false, the `invalidIndex` integer pointer is optionally filled with
+/// the index of the offending reassociation map.
+static bool isReassociationValid(ArrayRef<AffineMap> reassociation,
+                                 int *invalidIndex = nullptr) {
+  if (reassociation.empty())
+    return true;
+  unsigned nDims = reassociation[0].getNumDims();
+  unsigned nextExpectedDim = 0;
+  for (auto it : llvm::enumerate(reassociation)) {
+    auto m = it.value();
+    if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
+      if (invalidIndex)
+        *invalidIndex = it.index();
+      return false;
+    }
+    for (auto e : m.getResults()) {
+      auto d = e.dyn_cast<AffineDimExpr>();
+      if (!d || d.getPosition() != nextExpectedDim++) {
+        if (invalidIndex)
+          *invalidIndex = it.index();
+        return false;
+      }
+    }
+  }
+  if (nextExpectedDim != nDims) {
+    if (invalidIndex)
+      *invalidIndex = reassociation.size() - 1;
+    return false;
+  }
+  return true;
+}
+
+/// Detect whether memref dims [dim, dim + extent) can be reshaped without
+/// copies.
+static bool isReshapableDimBand(unsigned dim, unsigned extent,
+                                ArrayRef<int64_t> sizes,
+                                ArrayRef<AffineExpr> strides) {
+  assert(sizes.size() == strides.size() && "mismatched ranks");
+  // off by 1 indexing to avoid out of bounds
+  //                       V
+  for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
+    // Only bands of static shapes are reshapable. This is due to the fact that
+    // there is no relation between dynamic sizes and dynamic strides: we do
+    // not have enough information to know whether a "-1" size corresponds to
+    // the proper symbol in the AffineExpr of a stride.
+    if (ShapedType::isDynamic(sizes[dim + 1]))
+      return false;
+    // TODO(ntv) Refine this by passing the proper nDims and nSymbols so we can
+    // simplify on the fly and catch more reshapable cases.
+    if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
+      return false;
+  }
+  return true;
+}
+
+/// Compute the MemRefType obtained by applying the `reassociation` (which is
+/// expected to be valid) to `type`.
+/// If `type` is a contiguous MemRefType, this always produces a contiguous
+/// MemRefType.
+static MemRefType
+computeReshapeCollapsedType(MemRefType type,
+                            ArrayRef<AffineMap> reassociation) {
+  auto sizes = type.getShape();
+  AffineExpr offset;
+  SmallVector<AffineExpr, 4> strides;
+  auto status = getStridesAndOffset(type, strides, offset);
+  (void)status;
+  assert(succeeded(status) && "expected strided memref");
+
+  SmallVector<int64_t, 4> newSizes;
+  newSizes.reserve(reassociation.size());
+  SmallVector<AffineExpr, 4> newStrides;
+  newStrides.reserve(reassociation.size());
+
+  // Use the fact that reassociation is valid to simplify the logic: only use
+  // each map's rank.
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    int64_t size = 1;
+    AffineExpr stride = strides[currentDim + dim - 1];
+    if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
+      size = ShapedType::kDynamicSize;
+      stride = AffineExpr();
+    } else {
+      for (unsigned d = 0; d < dim; ++d)
+        size *= sizes[currentDim + d];
+    }
+    newSizes.push_back(size);
+    newStrides.push_back(stride);
+    currentDim += dim;
+  }
+
+  // Early-exit: if `type` is contiguous, the result must be contiguous.
+  if (canonicalizeStridedLayout(type).getAffineMaps().empty())
+    return MemRefType::get(newSizes, type.getElementType(), {});
+
+  // Convert back to int64_t because we don't have enough information to create
+  // new strided layouts from AffineExpr only. This corresponds to a case where
+  // copies may be necessary.
+  int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
+  if (auto o = offset.dyn_cast<AffineConstantExpr>())
+    intOffset = o.getValue();
+  SmallVector<int64_t, 4> intStrides;
+  intStrides.reserve(strides.size());
+  for (auto stride : newStrides) {
+    if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
+      intStrides.push_back(cst.getValue());
+    else
+      intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
+  }
+  auto layout =
+      makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
+  return canonicalizeStridedLayout(
+      MemRefType::get(newSizes, type.getElementType(), {layout}));
+}
+
+/// Helper function that asserts the `attrs` array contains Attributes of the
+/// proper type and returns the corresponding vector of AffineMaps.
+/// TODO(rridle,ntv) this should be evolved into a generic
+/// `getRangeOfType<T>(ArrayAttr attrs)` that does not copy.
+static SmallVector<AffineMap, 4> getAffineMaps(ArrayAttr attrs) {
+  return functional::map(
+      [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }, attrs);
+}
+
+void mlir::linalg::ReshapeOp::build(Builder *b, OperationState &result,
+                                    Value view, ArrayAttr reassociation,
+                                    ArrayRef<NamedAttribute> attrs) {
+  auto maps = getAffineMaps(reassociation);
+  auto memRefType = view.getType().cast<MemRefType>();
+  auto resultType = computeReshapeCollapsedType(memRefType, maps);
+  build(b, result, resultType, view, attrs);
+  result.addAttribute(ReshapeOp::getReassociationAttrName(), reassociation);
+}
+
+static void print(OpAsmPrinter &p, ReshapeOp op) {
+  p << op.getOperationName() << " " << op.view() << " " << op.reassociation();
+  p.printOptionalAttrDict(op.getAttrs(),
+                          {ReshapeOp::getReassociationAttrName()});
+  p << " : " << op.getViewType() << " into " << op.getResult().getType();
+}
+
+static ParseResult parseReshapeOp(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::OperandType view;
+  ArrayAttr reassociation;
+  MemRefType type, resultType;
+  return failure(parser.parseOperand(view) ||
+                 parser.parseAttribute(reassociation,
+                                       ReshapeOp::getReassociationAttrName(),
+                                       result.attributes) ||
+                 parser.parseOptionalAttrDict(result.attributes) ||
+                 parser.parseColonType(type) ||
+                 parser.parseKeywordType("into", resultType) ||
+                 parser.resolveOperand(view, type, result.operands) ||
+                 parser.addTypeToList(resultType, result.types));
+}
+
+static LogicalResult verify(ReshapeOp op) {
+  MemRefType expandedType = op.getViewType();
+  MemRefType collapsedType = op.getResult().getType().cast<MemRefType>();
+  unsigned expandedRank = expandedType.getRank();
+  unsigned collapsedRank = collapsedType.getRank();
+  bool isCollapse = expandedRank > collapsedRank;
+  if (!isCollapse) {
+    std::swap(expandedRank, collapsedRank);
+    std::swap(expandedType, collapsedType);
+  }
+  if (expandedRank == 0 || collapsedRank == 0)
+    return op.emitOpError("expected non-zero memref ranks");
+  if (expandedRank == collapsedRank)
+    return op.emitOpError("expected to collapse or expand dims");
+
+  if (collapsedRank != op.reassociation().size())
+    return op.emitOpError("expected rank of the collapsed view(")
+           << collapsedRank << ") to be the number of reassociation maps("
+           << op.reassociation().size() << ")";
+  auto maps = getAffineMaps(op.reassociation());
+  for (auto it : llvm::enumerate(maps))
+    if (it.value().getNumDims() != expandedRank)
+      return op.emitOpError("expected reassociation map #")
+             << it.index() << " of same rank as expanded memref("
+             << expandedRank << "), but got " << it.value().getNumDims();
+  int invalidIdx = 0;
+  if (!isReassociationValid(maps, &invalidIdx))
+    return op.emitOpError("expected reassociation map #")
+           << invalidIdx << " to be valid and contiguous";
+  MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
+  if (collapsedType != expectedType)
+    return op.emitOpError("expected collapsed type to be ")
+           << expectedType << ", but got " << collapsedType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SliceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -520,9 +520,9 @@
   llvm_unreachable("unexpected binary operation");
 }

-static LogicalResult getStridesAndOffset(MemRefType t,
-                                         SmallVectorImpl<AffineExpr> &strides,
-                                         AffineExpr &offset) {
+LogicalResult mlir::getStridesAndOffset(MemRefType t,
+                                        SmallVectorImpl<AffineExpr> &strides,
+                                        AffineExpr &offset) {
   auto affineMaps = t.getAffineMaps();
   // For now strides are only computed on a single affine map with a single
   // result (i.e. the closed subset of linearization maps that are compatible
@@ -699,6 +699,38 @@
   return AffineMap::get(strides.size(), nSymbols, expr);
 }

+/// Return a version of `t` with identity layout if it can be determined
+/// statically that the layout is the canonical contiguous strided layout.
+/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
+/// `t` with simplified layout.
+/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
+MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
+  auto affineMaps = t.getAffineMaps();
+  // Already in canonical form.
+  if (affineMaps.empty())
+    return t;
+
+  // Can't reduce to canonical identity form, return in canonical form.
+  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
+    return t;
+
+  // If the canonical strided layout for the sizes of `t` is equal to the
+  // simplified layout of `t` we can just return an empty layout. Otherwise,
+  // just simplify the existing layout.
+  AffineExpr expr =
+      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
+  auto m = affineMaps[0];
+  auto simplifiedLayoutExpr =
+      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
+  if (expr != simplifiedLayoutExpr)
+    return MemRefType::get(t.getShape(), t.getElementType(),
+                           {AffineMap::get(m.getNumDims(), m.getNumSymbols(),
+                                           {simplifiedLayoutExpr})});
+
+  return MemRefType::get(t.getShape(), t.getElementType(), {});
+}
+
+/// Return true if the layout for `t` is compatible with strided semantics.
 bool mlir::isStrided(MemRefType t) {
   int64_t offset;
   SmallVector<int64_t, 4> stridesAndOffset;
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -482,3 +482,49 @@

 // expected-error @+1 {{expected valid keyword}}
 !invalid_type = type !linalg<"?">
+
+// -----
+
+func @reshape(%arg0: memref<f32>) {
+  // expected-error @+1 {{expected non-zero memref ranks}}
+  %0 = linalg.reshape %arg0 [()->(0)] : memref<f32> into memref<f32>
+}
+
+// -----
+
+func @reshape(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{expected to collapse or expand dims}}
+  %0 = linalg.reshape %arg0 [(i)->(i)] : memref<?xf32> into memref<?xf32>
+}
+
+// -----
+
+func @reshape(%arg0: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expected rank of the collapsed view(2) to be the number of reassociation maps(1)}}
+  %0 = linalg.reshape %arg0 [(i, j, k) -> (i, j)] :
+    memref<?x?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+func @reshape(%arg0: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expected reassociation map #0 of same rank as expanded memref(3), but got 1}}
+  %0 = linalg.reshape %arg0 [(i) -> (i), (i, j, k) -> (k)] :
+    memref<?x?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+func @reshape(%arg0: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}}
+  %0 = linalg.reshape %arg0 [(i, j, k) -> (i, j), (i, j, k) -> (k, j)] :
+    memref<?x?x?xf32> into memref<?x?xf32>
+}
+
+// -----
+
+func @reshape(%arg0: memref<?x?x?xf32>) {
+  // expected-error @+1 {{expected collapsed type to be 'memref<?x?xf32>', but got 'memref<?x?xf32, (d0, d1)[s0] -> (d0 * s0 + d1)>'}}
+  %0 = linalg.reshape %arg0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
+    memref<?x?x?xf32> into memref<?x?xf32, (d0, d1)[s0] -> (d0 * s0 + d1)>
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -7,12 +7,23 @@
 // CHECK-DAG: #[[strided1D:.*]] = (d0)[s0] -> (d0 + s0)
 // CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+// CHECK-DAG: #[[strided2DOFF0:.*]] = (d0, d1)[s0] -> (d0 * s0 + d1)
 // CHECK-DAG: #[[strided3D:.*]] = (d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)
+// CHECK-DAG: #[[strided3DOFF0:.*]] = (d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)
 // CHECK-DAG: #[[strided6D:.*]] = (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5)

 // CHECK-DAG: #[[map0:.*]] = (d0, d1, d2) -> (d0, d2, d1)
 // CHECK-DAG: #[[map1:.*]] = (d0, d1, d2) -> (d2, d1, d0)
+// CHECK-DAG: #[[reshapeD01:.*]] = (d0, d1, d2) -> (d0, d1)
+// CHECK-DAG: #[[reshapeD2:.*]] = (d0, d1, d2) -> (d2)
+// CHECK-DAG: #[[reshapeD0:.*]] = (d0, d1, d2) -> (d0)
+// CHECK-DAG: #[[reshapeD12:.*]] = (d0, d1, d2) -> (d1, d2)
+// CHECK-DAG: #[[reshapeD012:.*]] = (d0, d1, d2) -> (d0, d1, d2)
+// CHECK-DAG: #[[reshape5D01:.*]] = (d0, d1, d2, d3, d4) -> (d0, d1)
+// CHECK-DAG: #[[reshape5D2:.*]] = (d0, d1, d2, d3, d4) -> (d2)
+// CHECK-DAG: #[[reshape5D34:.*]] = (d0, d1, d2, d3, d4) -> (d3, d4)

 func @range(%arg0: index, %arg1: index, %arg2: index) {
   %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
   return
@@ -181,7 +192,6 @@
 //       CHECK:   ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):   // no predecessors
 //       CHECK:     linalg.yield %{{.*}} : f32
 //       CHECK:   } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>
-
 func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
                       %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
   linalg.indexed_generic #trait2 %arg0, %arg1 {
@@ -195,3 +205,91 @@
 //       CHECK:   ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
 //       CHECK:     linalg.yield %{{.*}} : f32
 //       CHECK:   } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>
+
+func @reshape_static(%arg0: memref<3x4x5xf32>) {
+  // Reshapes that collapse and expand back a contiguous tensor.
+  %0 = linalg.reshape %arg0 [(i, j, k) -> (i, j),
+                             (i, j, k) -> (k)] :
+    memref<3x4x5xf32> into memref<12x5xf32>
+  %r0 = linalg.reshape %0 [(i, j, k) -> (i, j),
+                           (i, j, k) -> (k)] :
+    memref<12x5xf32> into memref<3x4x5xf32>
+  %1 = linalg.reshape %arg0 [(i, j, k) -> (i),
+                             (i, j, k) -> (j, k)] :
+    memref<3x4x5xf32> into memref<3x20xf32>
+  %r1 = linalg.reshape %1 [(i, j, k) -> (i),
+                           (i, j, k) -> (j, k)] :
+    memref<3x20xf32> into memref<3x4x5xf32>
+  %2 = linalg.reshape %arg0 [(i, j, k) -> (i, j, k)] :
+    memref<3x4x5xf32> into memref<60xf32>
+  %r2 = linalg.reshape %2 [(i, j, k) -> (i, j, k)] :
+    memref<60xf32> into memref<3x4x5xf32>
+  // Reshapes that expand and collapse back a contiguous tensor with some 1's.
+  %3 = linalg.reshape %arg0 [(i, j, k, l, m) -> (i, j),
+                             (i, j, k, l, m) -> (k),
+                             (i, j, k, l, m) -> (l, m)] :
+    memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+  %r3 = linalg.reshape %3 [(i, j, k, l, m) -> (i, j),
+                           (i, j, k, l, m) -> (k),
+                           (i, j, k, l, m) -> (l, m)] :
+    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+  return
+}
+// CHECK-LABEL: func @reshape_static
+//       CHECK:   linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
+//  CHECK-SAME:     memref<3x4x5xf32> into memref<12x5xf32>
+//       CHECK:   linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
+//  CHECK-SAME:     memref<12x5xf32> into memref<3x4x5xf32>
+//       CHECK:   linalg.reshape {{.*}} [#[[reshapeD0]], #[[reshapeD12]]]
+//  CHECK-SAME:     memref<3x4x5xf32> into memref<3x20xf32>
+//       CHECK:   linalg.reshape {{.*}} [#[[reshapeD0]], #[[reshapeD12]]]
+//  CHECK-SAME:     memref<3x20xf32> into memref<3x4x5xf32>
+//       CHECK:   linalg.reshape {{.*}} [#[[reshapeD012]]]
+//  CHECK-SAME:     memref<3x4x5xf32> into memref<60xf32>
+//       CHECK:   linalg.reshape {{.*}} [#[[reshapeD012]]]
+//  CHECK-SAME:     memref<60xf32> into memref<3x4x5xf32>
+//       CHECK:   linalg.reshape {{.*}} [#[[reshape5D01]], #[[reshape5D2]], #[[reshape5D34]]]
+//  CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
+//       CHECK:   linalg.reshape {{.*}} [#[[reshape5D01]], #[[reshape5D2]], #[[reshape5D34]]]
+//  CHECK-SAME:     memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
+
+func @reshape_dynamic(%arg0: memref<?x?x?xf32>,
+                      %arg1: memref<?x?x?xf32, offset: 0, strides: [?, ?, 1]>,
+                      %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+  %0 = linalg.reshape %arg0 [(i, j, k) -> (i, j),
+                             (i, j, k) -> (k)] :
+    memref<?x?x?xf32> into memref<?x?xf32>
+  %r0 = linalg.reshape %0 [(i, j, k) -> (i, j),
+                           (i, j, k) -> (k)] :
+    memref<?x?xf32> into memref<?x?x?xf32>
+  %1 = linalg.reshape %arg1 [(i, j, k) -> (i, j),
+                             (i, j, k) -> (k)] :
+    memref<?x?x?xf32, offset: 0, strides: [?, ?, 1]> into
+    memref<?x?xf32, offset: 0, strides: [?, 1]>
+  %r1 = linalg.reshape %1 [(i, j, k) -> (i, j),
+                           (i, j, k) -> (k)] :
+    memref<?x?xf32, offset: 0, strides: [?, 1]> into
+    memref<?x?x?xf32, offset: 0, strides: [?, ?, 1]>
+  %2 = linalg.reshape %arg2 [(i, j, k) -> (i, j),
+                             (i, j, k) -> (k)] :
+    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]> into
+    memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %r2 = linalg.reshape %2 [(i, j, k) -> (i, j),
+                           (i, j, k) -> (k)] :
+    memref<?x?xf32, offset: ?, strides: [?, 1]> into
+    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  return
+}
+// CHECK-LABEL: func @reshape
+//       CHECK:   linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
+//  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
+//       CHECK:   linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
+//  CHECK-SAME:     memref<?x?xf32> into memref<?x?x?xf32>
+//       CHECK:   linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
+//  CHECK-SAME:     memref<?x?x?xf32, #[[strided3DOFF0]]> into memref<?x?xf32, #[[strided2DOFF0]]>
+//       CHECK:   linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
+//  CHECK-SAME:     memref<?x?xf32, #[[strided2DOFF0]]> into memref<?x?x?xf32, #[[strided3DOFF0]]>
+//       CHECK:   linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
+//  CHECK-SAME:     memref<?x?x?xf32, #[[strided3D]]> into memref<?x?xf32, #[[strided2D]]>
+//       CHECK:   linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]]
+//  CHECK-SAME:     memref<?x?xf32, #[[strided2D]]> into memref<?x?x?xf32, #[[strided3D]]>
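
Note (not part of the patch): the condition that `isReshapableDimBand` checks above is that, within a reassociation band, each outer stride equals the next inner stride times the next inner size, and `computeReshapeCollapsedType` then gives the collapsed dimension the product of the band's sizes and the band's innermost stride. The standalone sketch below restates that check on plain `int64_t` strides; the function name `isCollapsibleBand` and the integer-only representation are illustrative assumptions, since the patch itself works on `AffineExpr` strides.

```cpp
// Standalone illustration of the band-contiguity check used by the reshape
// type computation; not MLIR code.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// A band [dim, dim + extent) of static sizes/strides can be collapsed into a
// single dimension iff each outer stride equals the next inner stride times
// the next inner size, i.e. the band is contiguous in memory.
static bool isCollapsibleBand(unsigned dim, unsigned extent,
                              const std::vector<int64_t> &sizes,
                              const std::vector<int64_t> &strides) {
  assert(sizes.size() == strides.size() && "mismatched ranks");
  for (unsigned idx = dim, e = dim + extent; idx + 1 < e; ++idx)
    if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
      return false;
  return true;
}

int main() {
  // memref<3x4x5xf32> with canonical strides [20, 5, 1].
  std::vector<int64_t> sizes = {3, 4, 5}, strides = {20, 5, 1};
  // Collapsing (d0, d1) is fine: 20 == 5 * 4. The collapsed dim gets size
  // 3 * 4 = 12 and the band's innermost stride, 5.
  std::cout << isCollapsibleBand(0, 2, sizes, strides) << "\n"; // 1
  // A non-contiguous view with strides [40, 5, 1] cannot collapse (d0, d1).
  std::vector<int64_t> padded = {40, 5, 1};
  std::cout << isCollapsibleBand(0, 2, sizes, padded) << "\n"; // 0
}
```

For a contiguous `memref<3x4x5xf32>`, collapsing (d0, d1) therefore yields `memref<12x5xf32>`, which is what the `@reshape_static` roundtrip test above checks.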