diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1196,7 +1196,9 @@ code commonExtraClassDeclaration = [{ SmallVector getReassociationMaps(); + SmallVector getReassociationExprs(); + SmallVector getReassociationIndices() { SmallVector reassociationIndices; for (auto attr : reassociation()) @@ -1206,8 +1208,11 @@ }))); return reassociationIndices; }; + MemRefType getSrcType() { return src().getType().cast(); } + MemRefType getResultType() { return result().getType().cast(); } + Value getViewSource() { return src(); } }]; @@ -1224,36 +1229,37 @@ let summary = "operation to produce a memref with a higher rank."; let description = [{ The `memref.expand_shape` op produces a new view with a higher rank whose - sizes are a reassociation of the original `view`. Depending on whether or - not the reassociated MemRefType is contiguous, the resulting memref may - require explicit alloc and copies. + sizes are a reassociation of the original `view`. The operation is limited + to such reassociations, where a dimension is expanded into one or multiple + contiguous dimensions. Such reassociations never require additional allocs + or copies. - A reassociation is defined as a continuous grouping of dimensions and is - represented with an array of I64ArrayAttr attribute. - - For now, it is assumed that either: - 1. a reassociation produces and consumes contiguous MemRefType or, - 2. the reshape op will be folded into its consumers (by changing the shape - of the computations). - All other cases are undefined behavior and a reshape op may not lower to - LLVM if it cannot be proven statically that it does not require alloc+copy. - - The operand memref type when dimensions can be zero-ranked if the result - memref type is statically shaped with all dimensions being unit extent. 
In - such case the reassociation map is empty. - - The verification rule is that the reassociation maps are applied to the - result memref with the larger rank to obtain the operand memref with the - smaller rank. + A reassociation is defined as a grouping of dimensions and is represented + with an array of I64ArrayAttr attributes. Example: ```mlir - // Dimension expansion i -> (i', j') and (k) -> (k') - %1 = memref.expand_shape %0 [[0, 1], [2]] : - memref into memref + %r = memref.expand_shape %0 [[0, 1], [2]] + : memref into memref ``` + + At most one dimension of a reassociation group (e.g., [0, 1] above) may be + dynamic in the result type. Otherwise, the op would be ambiguous, as it + would not be clear how the source dimension is expanded. + + If an op can be statically proven to be invalid (e.g., an expansion from + `memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If + it cannot be statically proven invalid (e.g., the full example above; it is + unclear whether the first source dimension is divisible by 5), the op is + accepted by the verifier. However, if the op is in fact invalid at runtime, + the behavior is undefined. + + The source memref can be zero-ranked. In that case, the reassociation + indices must be empty and the result shape may only consist of unit + dimensions. }]; + let builders = [ // Builders using ReassociationIndices. OpBuilder<(ins "Type":$resultType, "Value":$src, @@ -1264,6 +1270,8 @@ $_state.addAttribute("reassociation", getReassociationIndicesAttribute($_builder, reassociation)); }]>, + + // Builder using ReassociationExprs. OpBuilder<(ins "Type":$resultType, "Value":$src, "ArrayRef":$reassociation, CArg<"ArrayRef", "{}">:$attrs), @@ -1271,8 +1279,14 @@ auto reassociationMaps = convertReassociationMapsToIndices($_builder, reassociation); build($_builder, $_state, resultType, src, reassociationMaps, attrs); - }]> + }]>, + + // Builder that infers the result layout map.
The result shape must be + // specified. Otherwise, the op may be ambiguous. + OpBuilder<(ins "ArrayRef":$resultShape, "Value":$src, + "ArrayRef":$reassociation)> ]; + let extraClassDeclaration = commonExtraClassDeclaration; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1558,9 +1558,28 @@ // Reassociative reshape ops //===----------------------------------------------------------------------===// +/// Helper function that computes a stride based on the size/stride of the +/// previous dimension. +/// +/// E.g., memref<20x10x5xf32, offset: 0, strides: [50, 5, 1]> +/// ^^ +/// compute this one +/// prevStride = 5, prevDimSize = 10 +/// nextStride = 5 * 10 = 50 +static int64_t computeNextStride(int64_t prevStride, int64_t prevDimSize) { + if (ShapedType::isDynamicStrideOrOffset(prevStride)) + return ShapedType::kDynamicStrideOrOffset; + + if (ShapedType::isDynamic(prevDimSize)) + return ShapedType::kDynamicStrideOrOffset; + + return prevStride * prevDimSize; +} + SmallVector CollapseShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } + SmallVector CollapseShapeOp::getReassociationExprs() { return convertReassociationIndicesToExprs(getContext(), getReassociationIndices()); @@ -1569,6 +1588,7 @@ SmallVector ExpandShapeOp::getReassociationMaps() { return getSymbolLessAffineMaps(getReassociationExprs()); } + SmallVector ExpandShapeOp::getReassociationExprs() { return convertReassociationIndicesToExprs(getContext(), getReassociationIndices()); @@ -1702,8 +1722,169 @@ return success(); } +/// Compute the layout map after expanding a given source MemRef type with the +/// specified reassociation indices.
+static FailureOr +computeExpandedLayoutMap(MemRefType srcType, ArrayRef resultShape, + ArrayRef reassociation) { + SmallVector srcStrides, resultStrides(resultShape.size(), 0); + int64_t srcOffset; + if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) + return failure(); + assert(srcStrides.size() == reassociation.size() && "invalid reassociation"); + + // Iterate over all reassociation groups from the back. Example: + // strides = [1000, ?, 2] + // source shape = [20, 10, 5] + // result shape = [ 2, 10, 2, 5, 5] + // reassociation = [[0, 1], [2, 3], [4]] + for (const auto &it : llvm::reverse(llvm::zip(reassociation, srcStrides))) { + ReassociationIndices indices = std::get<0>(it); + int64_t srcGroupStride = std::get<1>(it); + + // An empty group contributes no result dimensions; skip it (popping from + // an empty group below would be invalid). + if (indices.empty()) + continue; + + // The last (least significant) result dimension of each reassociation + // group has the same stride as the corresponding source dimension. E.g.: + // reassociation = [[0, 1], [2, 3], [4]] + // | | | + // v v v + // 1000 ? 2 + resultStrides[indices.pop_back_val()] = srcGroupStride; + + // Compute the strides for the remaining dims in the reassociation group. + for (int64_t resultDim : llvm::reverse(indices)) { + // E.g.: + // reassociation = [[0, 1], [2, 3], [4]] + // | + // v + // 1000 * 10 = 10000 + // + // If the previous stride or the previous dimension was dynamic, then this + // stride will also be dynamic. + resultStrides[resultDim] = computeNextStride(resultStrides[resultDim + 1], + resultShape[resultDim + 1]); + } + } + + return makeStridedLinearLayoutMap(resultStrides, srcOffset, + srcType.getContext()); +} + +/// Compute the type obtained by expanding `srcType` into `resultShape` with +/// the given reassociation. An identity-layout source yields an +/// identity-layout result; otherwise, the strided layout is derived from the +/// source strides and offset. Fails if those cannot be determined. +static FailureOr +computeExpandedType(MemRefType srcType, ArrayRef resultShape, + ArrayRef reassociation) { + if (srcType.getLayout().isIdentity()) { + // If the source is contiguous (i.e., no layout map specified), so is the + // result.
+ MemRefLayoutAttrInterface layout; + return MemRefType::get(resultShape, srcType.getElementType(), layout, + srcType.getMemorySpace()); + } + + // Source may not be contiguous. Compute the layout map. + FailureOr computedLayout = + computeExpandedLayoutMap(srcType, resultShape, reassociation); + if (failed(computedLayout)) + return failure(); + auto computedType = + MemRefType::get(resultShape, srcType.getElementType(), *computedLayout, + srcType.getMemorySpace()); + return canonicalizeStridedLayout(computedType); +} + +void ExpandShapeOp::build(OpBuilder &builder, OperationState &result, + ArrayRef resultShape, Value src, + ArrayRef reassociation) { + // Only ranked memref source values are supported. + auto srcType = src.getType().cast(); + FailureOr resultType = + computeExpandedType(srcType, resultShape, reassociation); + // Failure of this assertion usually indicates a problem with the source + // type, e.g., could not get strides/offset. + assert(succeeded(resultType) && "could not compute layout"); + build(builder, result, *resultType, src, reassociation); +} + LogicalResult ExpandShapeOp::verify() { - return verifyReshapeOp(*this, getResultType(), getSrcType()); + MemRefType srcType = getSrcType(); + MemRefType resultType = getResultType(); + // getReassociationIndices() rebuilds the indices from the attribute on + // every call, so materialize them once. + auto reassociationIndices = getReassociationIndices(); + + // There must be one reassociation group per source dimension.
+ int64_t numGroups = reassociationIndices.size(); + if (srcType.getRank() != numGroups) + return emitOpError("invalid number of reassociation groups: found ") + << numGroups << ", expected " << srcType.getRank(); + + // The next expected dimension (while iterating over reassociation indices). + int64_t nextDim = 0; + for (const auto &it : llvm::enumerate(reassociationIndices)) { + const ReassociationIndices &group = it.value(); + int64_t srcDim = it.index(); + + bool foundDynamic = false; + for (int64_t resultDim : group) { + if (resultDim != nextDim++) + return emitOpError("reassociation indices must be contiguous"); + + if (resultDim >= resultType.getRank()) + return emitOpError("reassociation index ") + << resultDim << " is out of bounds"; + + // There may be at most one dynamic result dim in a reassociation group. + if (resultType.isDynamicDim(resultDim)) { + if (foundDynamic) + return emitOpError( + "at most one dimension in a reassociation group may be dynamic"); + foundDynamic = true; + } + } + + if (srcType.isDynamicDim(srcDim) != foundDynamic) + return emitOpError( + "result dim must be dynamic if and only if source dim is dynamic"); + + if (!foundDynamic) { + int64_t groupSize = 1; + for (int64_t resultDim : group) + groupSize *= resultType.getDimSize(resultDim); + if (groupSize != srcType.getDimSize(srcDim)) + return emitOpError("source dim must be a product of expanded dims"); + } + } + + if (srcType.getRank() == 0) { + // Rank 0: All result dimensions must be 1. + for (int64_t d : resultType.getShape()) + if (d != 1) + return emitOpError("rank 0 memrefs can only be extended with ones"); + } else if (nextDim != resultType.getRank()) { + // Rank >= 1: Number of dimensions among all reassociation groups must match + // the result memref rank.
+ return emitOpError("result type inconsistent with reassociation"); + } + + FailureOr expectedResultType = computeExpandedType( + srcType, resultType.getShape(), reassociationIndices); + if (failed(expectedResultType)) + return emitOpError("invalid source layout map"); + + auto canonicalizedResultType = canonicalizeStridedLayout(resultType); + if (*expectedResultType != canonicalizedResultType) + return emitOpError("expected expanded type to be ") + << *expectedResultType << " but found " << canonicalizedResultType; + + return success(); } void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -393,8 +393,17 @@ // ----- func @expand_shape(%arg0: memref) { - // expected-error @+1 {{expected non-zero memref ranks}} + // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 0}} %0 = memref.expand_shape %arg0 [[0]] : memref into memref + return +} + +// ----- + +func @expand_shape(%arg0: memref) { + // expected-error @+1 {{rank 0 memrefs can only be extended with ones}} + %0 = memref.expand_shape %arg0 [] : memref into memref<1x2xf32> + return } // ----- @@ -407,12 +416,22 @@ // ----- func @expand_shape_to_smaller_rank(%arg0: memref<1xf32>) { - // expected-error @+1 {{expected the type 'memref' to have higher rank than the type = 'memref<1xf32>'}} + // expected-error @+1 {{op reassociation index 0 is out of bounds}} %0 = memref.expand_shape %arg0 [[0]] : memref<1xf32> into memref } // ----- +func @expand_shape_invalid_result_layout( + %arg0: memref<30x20xf32, offset : 100, strides : [4000, 2]>) { + // expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 60000 + d1 * 4000 + d2 * 2 + 100)>>' but found 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 5000 + d1 * 4000 + d2 * 2 + 100)>>'}}
+ %0 = memref.expand_shape %arg0 [[0, 1], [2]] : + memref<30x20xf32, offset : 100, strides : [4000, 2]> + into memref<2x15x20xf32, offset : 100, strides : [5000, 4000, 2]> +} + +// ----- + func @collapse_shape(%arg0: memref) { // expected-error @+1 {{expected to collapse or expand dims}} %0 = memref.collapse_shape %arg0 [[0]] : memref into memref @@ -446,7 +465,7 @@ func @expand_shape_illegal_dynamic_memref (%arg0: memref) -> memref { - // expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}} + // expected-error @+1 {{at most one dimension in a reassociation group may be dynamic}} %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]] : memref into memref return %0 : memref @@ -456,7 +475,7 @@ func @expand_shape_illegal_static_memref (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> { - // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} + // expected-error @+1 {{source dim must be a product of expanded dims}} %0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]] : memref<2x3x20xf32> into memref<2x3x2x4x5xf32> return %0 : memref<2x3x2x4x5xf32> @@ -476,7 +495,7 @@ func @expand_shape_illegal_mixed_memref(%arg0 : memref) -> memref { - // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} + // expected-error @+1 {{result dim must be dynamic if and only if source dim is dynamic}} %0 = memref.expand_shape %arg0 [[0, 1], [2]] : memref into memref return %0 : memref @@ -486,7 +505,7 @@ func @expand_shape_illegal_mixed_memref_2(%arg0 : memref) -> memref { - // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} + // expected-error @+1 {{result dim must be dynamic if and only if source dim is dynamic}} %0 = memref.expand_shape %arg0 [[0], [1, 2]] : memref into memref return %0 : memref @@ -494,6 +513,16 @@ // ----- +func @expand_shape_invalid_static_dim_size(%arg0 : memref) + -> memref { + // expected-error @+1 {{source 
dim must be a product of expanded dims}} + %0 = memref.expand_shape %arg0 [[0], [1, 2]] + : memref into memref + return %0 : memref +} + +// ----- + func @collapse_shape_illegal_mixed_memref(%arg0 : memref) -> memref { // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -104,107 +104,147 @@ return } -func @expand_collapse_shape_static(%arg0: memref<3x4x5xf32>, - %arg1: tensor<3x4x5xf32>, - %arg2: tensor<3x?x5xf32>) { +// CHECK-LABEL: func @expand_collapse_shape_static +func @expand_collapse_shape_static( + %arg0: memref<3x4x5xf32>, + %arg1: tensor<3x4x5xf32>, + %arg2: tensor<3x?x5xf32>, + %arg3: memref<30x20xf32, offset : 100, strides : [4000, 2]>, + %arg4: memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>>, + %arg5: memref) { // Reshapes that collapse and expand back a contiguous buffer. 
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] +// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32> %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<3x4x5xf32> into memref<12x5xf32> + +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] +// CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32> %r0 = memref.expand_shape %0 [[0, 1], [2]] : memref<12x5xf32> into memref<3x4x5xf32> + +// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]] +// CHECK-SAME: memref<3x4x5xf32> into memref<3x20xf32> %1 = memref.collapse_shape %arg0 [[0], [1, 2]] : memref<3x4x5xf32> into memref<3x20xf32> + +// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] +// CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32> %r1 = memref.expand_shape %1 [[0], [1, 2]] : memref<3x20xf32> into memref<3x4x5xf32> + +// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]] +// CHECK-SAME: memref<3x4x5xf32> into memref<60xf32> %2 = memref.collapse_shape %arg0 [[0, 1, 2]] : memref<3x4x5xf32> into memref<60xf32> + +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1, 2]] +// CHECK-SAME: memref<60xf32> into memref<3x4x5xf32> %r2 = memref.expand_shape %2 [[0, 1, 2]] : - memref<60xf32> into memref<3x4x5xf32> + memref<60xf32> into memref<3x4x5xf32> + +// CHECK: memref.expand_shape {{.*}} [] +// CHECK-SAME: memref into memref<1x1xf32> + %r5 = memref.expand_shape %arg5 [] : + memref into memref<1x1xf32> + +// Reshapes with a custom layout map. 
+// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] + %l0 = memref.expand_shape %arg3 [[0], [1, 2]] : + memref<30x20xf32, offset : 100, strides : [4000, 2]> + into memref<30x4x5xf32, offset : 100, strides : [4000, 10, 2]> + +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] + %l1 = memref.expand_shape %arg3 [[0, 1], [2]] : + memref<30x20xf32, offset : 100, strides : [4000, 2]> + into memref<2x15x20xf32, offset : 100, strides : [60000, 4000, 2]> + +// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] + %r4 = memref.expand_shape %arg4 [[0], [1, 2]] : + memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>> into + memref<1x1x5xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 5 + s0 + d2 + d1 * 5)>> + // Reshapes that expand and collapse back a contiguous buffer with some 1's. +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] +// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32> %3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] : memref<3x4x5xf32> into memref<1x3x4x1x5xf32> + +// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] +// CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32> %r3 = memref.collapse_shape %3 [[0, 1], [2], [3, 4]] : memref<1x3x4x1x5xf32> into memref<3x4x5xf32> + // Reshapes on tensors. 
+// CHECK: tensor.expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32> %t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] : tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32> + +// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32> %rt0 = tensor.collapse_shape %t0 [[0, 1], [2], [3, 4]] : tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32> + +// CHECK: tensor.expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32> %t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] : tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32> + +// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32> %rt1 = tensor.collapse_shape %t1 [[0], [1, 2], [3, 4]] : tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32> return } -// CHECK-LABEL: func @expand_collapse_shape_static -// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] -// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32> -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] -// CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32> -// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]] -// CHECK-SAME: memref<3x4x5xf32> into memref<3x20xf32> -// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] -// CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32> -// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]] -// CHECK-SAME: memref<3x4x5xf32> into memref<60xf32> -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1, 2]] -// CHECK-SAME: memref<60xf32> into memref<3x4x5xf32> -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] -// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32> -// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] -// CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32> -// -// CHECK: tensor.expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32> -// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32> -// CHECK: tensor.expand_shape {{.*}}: tensor<3x?x5xf32> into 
tensor<1x3x?x1x5xf32> -// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32> - +// CHECK-LABEL: func @expand_collapse_shape_dynamic func @expand_collapse_shape_dynamic(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref) { +// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] +// CHECK-SAME: memref into memref %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref into memref + +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] +// CHECK-SAME: memref into memref %r0 = memref.expand_shape %0 [[0, 1], [2]] : memref into memref + +// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] +// CHECK-SAME: memref into memref %1 = memref.collapse_shape %arg1 [[0, 1], [2]] : memref into memref + +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] +// CHECK-SAME: memref into memref %r1 = memref.expand_shape %1 [[0, 1], [2]] : memref into memref + +// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] +// CHECK-SAME: memref into memref %2 = memref.collapse_shape %arg2 [[0, 1], [2]] : memref into memref + +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] +// CHECK-SAME: memref into memref %r2 = memref.expand_shape %2 [[0, 1], [2]] : memref into memref + +// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1]] +// CHECK-SAME: memref into memref %3 = memref.collapse_shape %arg3 [[0, 1]] : memref into memref + +// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]] +// CHECK-SAME: memref into memref %r3 = memref.expand_shape %3 [[0, 1]] : - memref into - memref + memref into memref return } -// CHECK-LABEL: func @expand_collapse_shape_dynamic -// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] -// CHECK-SAME: memref into memref -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] -// CHECK-SAME: memref into memref -// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] -// CHECK-SAME: memref into memref -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] -// CHECK-SAME: memref into memref -// CHECK: 
memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] -// CHECK-SAME: memref into memref -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] -// CHECK-SAME: memref into memref -// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1]] -// CHECK-SAME: memref into memref -// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]] -// CHECK-SAME: memref into memref func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref) -> (memref, memref<1x1xf32>) {