diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1168,10 +1168,14 @@ ConversionPatternRewriter &rewriter) const override { MemRefType dstType = reshapeOp.getResultType(); MemRefType srcType = reshapeOp.getSrcType(); - if (!srcType.getLayout().isIdentity() || - !dstType.getLayout().isIdentity()) { - return rewriter.notifyMatchFailure(reshapeOp, - "only empty layout map is supported"); + + // The condition on the layouts can be ignored when all shapes are static. + if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) { + if (!srcType.getLayout().isIdentity() || + !dstType.getLayout().isIdentity()) { + return rewriter.notifyMatchFailure( + reshapeOp, "only empty layout map is supported"); + } } int64_t offset; diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -883,3 +883,12 @@ // CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel return } + +// ----- + +// CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout +func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>>) -> memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> { +// CHECK-NOT: memref.collapse_shape + %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>> into memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> + return %1 : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> +}