NormalizeMemRefs: pass the layout map's symbol count to normalizeMemRefType

When normalizing memref result types, the pass called normalizeMemRefType()
with numSymbolicOperands hard-coded to 0, even when the memref's layout map
has symbols (e.g. a layout such as strided<[?, ?], offset: ?>). Forward
layoutMap.getNumSymbols() instead.

Add a regression test that builds a subview/cast chain on memrefs with
dynamic strides and offset. Its CHECK line expects the memref.cast to keep
its original (non-identity) types.

diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -513,9 +513,12 @@
       resultTypes.push_back(resultType);
       continue;
     }
+
+    AffineMap layoutMap = memrefType.getLayout().getAffineMap();
     // Fetch a new memref type after normalizing the old memref.
-    MemRefType newMemRefType = normalizeMemRefType(memrefType,
-                                                   /*numSymbolicOperands=*/0);
+    MemRefType newMemRefType =
+        normalizeMemRefType(memrefType,
+                            /*numSymbolicOperands=*/layoutMap.getNumSymbols());
     if (newMemRefType == memrefType) {
       // Either memrefType already had an identity map or the map couldn't
       // be transformed to an identity map.
diff --git a/mlir/test/Transforms/normalize-memrefs.mlir b/mlir/test/Transforms/normalize-memrefs.mlir
--- a/mlir/test/Transforms/normalize-memrefs.mlir
+++ b/mlir/test/Transforms/normalize-memrefs.mlir
@@ -352,3 +352,14 @@
   %0 = memref.alloc() : memref<2x3xf32, #neg>
   return %0 : memref<2x3xf32, #neg>
 }
+
+// CHECK-LABEL: func @memref_with_strided_offset
+func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index, %arg2: index) -> tensor<16x512xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = bufferization.to_memref %arg0 : memref<128x512xf32, strided<[?, ?], offset: ?>>
+  %subview = memref.subview %0[%arg2, 0] [%arg1, 512] [1, 1] : memref<128x512xf32, strided<[?, ?], offset: ?>> to memref<?x512xf32, strided<[?, ?], offset: ?>>
+  // CHECK: %{{.*}} = memref.cast %{{.*}} : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
+  %cast = memref.cast %subview : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
+  %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
+  return %1 : tensor<16x512xf32>
+}