diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -77,6 +77,10 @@ code commonExtraClassDeclaration = [{ static StringRef getReassociationAttrName() { return "reassociation"; } + SmallVector getReassociationMaps() { + return llvm::to_vector<4>(llvm::map_range(reassociation(), [ + ](Attribute a) { return a.cast().getValue(); })); + } }]; let assemblyFormat = [{ $src $reassociation attr-dict `:` type($src) `into` type(results) @@ -137,6 +141,7 @@ MemRefType getResultType() { return result().getType().cast(); } }]; let hasFolder = 1; + let hasCanonicalizer = 1; } def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">, @@ -187,11 +192,9 @@ RankedTensorType getResultType() { return result().getType().cast(); } - SmallVector getReassociationMaps() { - return llvm::to_vector<4>(llvm::map_range(reassociation(), - [](Attribute a) { return a.cast().getValue(); })); - } }]; + let hasFolder = 1; + let hasCanonicalizer = 1; } def Linalg_SliceOp : Linalg_Op<"slice", [ diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -246,6 +246,108 @@ // ReshapeOp //===----------------------------------------------------------------------===// +/// Collapse reassociation maps that are used in a pair of reshape ops where +/// one is the producer and the other is the consumer. Only valid to use this +/// method when both the producer and consumer are collapsing dimensions or +/// both are expanding dimensions. 
+/// +/// For example, +/// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>, +/// affine_map<(d0, d1, d2, d3, d4) -> (d2)>, +/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] +/// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>, +/// affine_map<(d0, d1, d2) -> (d2)>] +/// +/// is folded into +/// +/// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, +/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] +static ArrayAttr collapseReassociationMaps(ArrayRef mapsProducer, + ArrayRef mapsConsumer, + MLIRContext *context) { + if (mapsProducer.size() == 0 || mapsConsumer.size() == 0 || + mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() || + mapsProducer.size() != mapsConsumer[0].getNumDims()) + return nullptr; + unsigned numLhsDims = mapsProducer[0].getNumDims(); + unsigned currDim = 0; + SmallVector reassociations; + SmallVector reassociationMaps; + for (AffineMap rhs : mapsConsumer) { + for (AffineExpr rhsExpr : rhs.getResults()) { + AffineDimExpr dimExpr = rhsExpr.cast(); + for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults(); + i != e; ++i) { + reassociations.push_back(getAffineDimExpr(currDim++, context)); + } + } + reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get( + numLhsDims, /*numSymbols =*/0, reassociations, context))); + reassociations.clear(); + } + return ArrayAttr::get(reassociationMaps, context); +} + +namespace { +/// Pattern to collapse producer/consumer reshape ops that are both collapsing +/// dimensions or are both expanding dimensions. 
+template +struct CollapseReshapeOps : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, + PatternRewriter &rewriter) const override { + auto srcReshapeOp = + dyn_cast_or_null(reshapeOp.src().getDefiningOp()); + if (!srcReshapeOp) + return failure(); + + auto areReshapeOpsFoldable = [](ShapedType largerType, + ShapedType intermediateType, + ShapedType smallerType) -> bool { + return largerType.getRank() > intermediateType.getRank() && + intermediateType.getRank() > smallerType.getRank() && + smallerType.getRank() > 0; + }; + // Check if producer and consumer are both expanding dims. + if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(), + srcReshapeOp.getSrcType())) { + rewriter.replaceOpWithNewOp( + reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(), + collapseReassociationMaps(reshapeOp.getReassociationMaps(), + srcReshapeOp.getReassociationMaps(), + rewriter.getContext())); + return success(); + } + // Check if producer and consumer are both collapsing dims. + else if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(), + reshapeOp.getSrcType(), + reshapeOp.getResultType())) { + rewriter.replaceOpWithNewOp( + reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(), + collapseReassociationMaps(srcReshapeOp.getReassociationMaps(), + reshapeOp.getReassociationMaps(), + rewriter.getContext())); + return success(); + } + return failure(); + } +}; +} // namespace + +template +static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp) { + // Fold producer-consumer reshape ops where the operand type of the + // producer is the same as the return type of the consumer. This can only be + // verified if the shapes in question are static. 
+ ReshapeOpTy reshapeSrcOp = + dyn_cast_or_null(reshapeOp.src().getDefiningOp()); + if (reshapeSrcOp && reshapeSrcOp.getSrcType().hasStaticShape() && + reshapeOp.getResultType().hasStaticShape() && + reshapeSrcOp.getSrcType() == reshapeOp.getResultType()) + return reshapeSrcOp.src(); + return nullptr; +}; + /// Return true if the reassociation specification is valid, false otherwise. /// When false, the `invalidIndex` integer pointer is optionally filled with the /// index of the offending reassociation map. @@ -482,6 +584,11 @@ return success(); } +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert>(context); +} + //===----------------------------------------------------------------------===// // TensorReshapeOp //===----------------------------------------------------------------------===// @@ -551,6 +658,11 @@ return success(); } +void TensorReshapeOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert>(context); +} + //===----------------------------------------------------------------------===// // SliceOp //===----------------------------------------------------------------------===// @@ -1010,13 +1122,18 @@ OpFoldResult ReshapeOp::fold(ArrayRef) { if (succeeded(foldMemRefCast(*this))) return getResult(); - return {}; + return foldReshapeOp(*this); } OpFoldResult SliceOp::fold(ArrayRef) { if (succeeded(foldMemRefCast(*this))) return getResult(); return {}; } +OpFoldResult TensorReshapeOp::fold(ArrayRef) { + if (succeeded(foldMemRefCast(*this))) + return getResult(); + return foldReshapeOp(*this); +} OpFoldResult TransposeOp::fold(ArrayRef) { if (succeeded(foldMemRefCast(*this))) return getResult(); diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s 
-canonicalize | FileCheck %s +// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s // CHECK-LABEL: func @memref_cast( func @memref_cast(%a: index, %b: index) -> memref { @@ -18,3 +18,157 @@ linalg.matmul(%3, %3, %3) : memref, memref, memref return %4: memref } + +// ----- + +func @collapsing_tensor_reshapes(%arg0 : tensor) -> tensor +{ + %0 = linalg.tensor_reshape %arg0 + [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>, + affine_map<(d0, d1, d2, d3, d4) -> (d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] : + tensor into tensor + %1 = linalg.tensor_reshape %0 + [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] : + tensor into tensor + return %1 : tensor +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)> +// CHECK-LABEL: collapsing_tensor_reshapes +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]] +// CHECK-NOT: linalg.tensor_reshape + +// ----- + +func @expanding_tensor_reshapes(%arg0 : tensor) -> tensor +{ + %0 = linalg.tensor_reshape %arg0 + [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] : + tensor into tensor + %1 = linalg.tensor_reshape %0 + [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>, + affine_map<(d0, d1, d2, d3, d4) -> (d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] : + tensor into tensor + return %1 : tensor +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)> +// CHECK-LABEL: expanding_tensor_reshapes +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]] +// CHECK-NOT: linalg.tensor_reshape + +// ----- + +func @collapsing_memref_reshapes(%arg0 : memref) -> memref +{ + %0 = linalg.reshape %arg0 + [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>, + affine_map<(d0, d1, d2, d3, d4) -> (d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] : + memref into 
memref + %1 = linalg.reshape %0 + [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] : + memref into memref + return %1 : memref +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)> +// CHECK-LABEL: collapsing_memref_reshapes +// CHECK: linalg.reshape %{{.*}} [#[[MAP0]], #[[MAP1]]] +// CHECK-NOT: linalg.reshape + +// ----- + +func @expanding_memref_reshapes(%arg0 : memref) -> memref +{ + %0 = linalg.reshape %arg0 + [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] : + memref into memref + %1 = linalg.reshape %0 + [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>, + affine_map<(d0, d1, d2, d3, d4) -> (d2)>, + affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] : + memref into memref + return %1 : memref +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)> +// CHECK-LABEL: expanding_memref_reshapes +// CHECK: linalg.reshape %{{.*}} [#[[MAP0]], #[[MAP1]]] +// CHECK-NOT: linalg.reshape + +// ----- + +func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> +{ + %0 = linalg.tensor_reshape %arg0 + [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] : + tensor<12x4xf32> into tensor<3x4x4xf32> + %1 = linalg.tensor_reshape %0 + [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] : + tensor<3x4x4xf32> into tensor<12x4xf32> + return %1 : tensor<12x4xf32> +} +// CHECK-LABEL: @fold_tensor_reshape +// CHECK-NOT: linalg.tensor_reshape + +// ----- + +func @no_fold_tensor_reshape(%arg0 : tensor) -> tensor +{ + %0 = linalg.tensor_reshape %arg0 + [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] : + tensor into tensor + %1 = linalg.tensor_reshape %0 + [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] : + tensor into tensor + return 
%1 : tensor +} +// CHECK-LABEL: @no_fold_tensor_reshape +// CHECK: linalg.tensor_reshape +// CHECK: linalg.tensor_reshape + +// ----- + +func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> +{ + %0 = linalg.reshape %arg0 + [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] : + memref<12x4xf32> into memref<3x4x4xf32> + %1 = linalg.reshape %0 + [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] : + memref<3x4x4xf32> into memref<12x4xf32> + return %1 : memref<12x4xf32> +} +// CHECK-LABEL: @fold_memref_reshape +// CHECK-NOT: linalg.reshape + +// ----- + +func @no_fold_memref_reshape(%arg0 : memref) -> memref +{ + %0 = linalg.reshape %arg0 + [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] : + memref into memref + %1 = linalg.reshape %0 + [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] : + memref into memref + return %1 : memref +} +// CHECK-LABEL: @no_fold_memref_reshape +// CHECK: linalg.reshape +// CHECK: linalg.reshape diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir --- a/mlir/test/Dialect/Linalg/llvm.mlir +++ b/mlir/test/Dialect/Linalg/llvm.mlir @@ -214,72 +214,76 @@ // CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64 // CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64 -func @reshape_static(%arg0: memref<3x4x5xf32>) { - // Reshapes that expand and collapse back a contiguous tensor with some 1's. +func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> { + // Reshapes that expand a contiguous tensor with some 1's. 
%0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>, affine_map<(i, j, k, l, m) -> (k)>, affine_map<(i, j, k, l, m) -> (l, m)>] : memref<3x4x5xf32> into memref<1x3x4x1x5xf32> - %r0 = linalg.reshape %0 [affine_map<(i, j, k, l, m) -> (i, j)>, - affine_map<(i, j, k, l, m) -> (k)>, - affine_map<(i, j, k, l, m) -> (l, m)>] : + return %0 : memref<1x3x4x1x5xf32> +} +// CHECK-LABEL: func @reshape_static_expand +// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.mlir.constant(60 : index) : !llvm.i64 +// CHECK: llvm.insertvalue 
%{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> + +func @reshape_static_collapse(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> { + %0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>, + affine_map<(i, j, k, l, m) -> (k)>, + affine_map<(i, j, k, l, m) -> (l, m)>] : memref<1x3x4x1x5xf32> into memref<3x4x5xf32> - return + return %0 : memref<3x4x5xf32> } -// CHECK-LABEL: func @reshape_static( -// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ 
float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.mlir.constant(60 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> -// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64 -// CHECK: llvm.insertvalue 
{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK-LABEL: func @reshape_static_collapse +// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }"> +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ 
float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> +// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }"> -func @reshape_zero_dim(%arg0 : memref<1x1xf32>) { +func @reshape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref { %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref - %1 = linalg.reshape %0 [] : memref into memref<1x1xf32> - return + return %0 : memref } -// CHECK-LABEL: func @reshape_zero_dim +// CHECK-LABEL: func @reshape_fold_zero_dim // CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }"> // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64 }"> @@ -287,6 +291,12 @@ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64 }"> // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64 }"> + +func @reshape_expand_zero_dim(%arg0 : memref) -> memref<1x1xf32> { + %0 = linalg.reshape %arg0 [] : memref into memref<1x1xf32> + return %0 : memref<1x1xf32> +} +// CHECK-LABEL: func @reshape_expand_zero_dim // CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }"> // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">