diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -249,8 +249,7 @@ /// transformed to an identity map with a new shape being computed for the -/// normalized memref type and returns it. The old memref type is simplify +/// normalized memref type and returns it. The old memref type is simply /// returned if the normalization failed. -MemRefType normalizeMemRefType(MemRefType memrefType, - unsigned numSymbolicOperands); +MemRefType normalizeMemRefType(MemRefType memrefType); /// Given an operation, inserts one or more single result affine apply /// operations, results of which are exclusively used by this operation. diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1720,8 +1720,7 @@ // Fetch a new memref type after normalizing the old memref to have an // identity map layout. - MemRefType newMemRefType = - normalizeMemRefType(memrefType, allocOp->getSymbolOperands().size()); + MemRefType newMemRefType = normalizeMemRefType(memrefType); if (newMemRefType == memrefType) // Either memrefType already had an identity map or the map couldn't be // transformed to an identity map. @@ -1772,8 +1771,7 @@ return success(); } -MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType, - unsigned numSymbolicOperands) { +MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) { unsigned rank = memrefType.getRank(); if (rank == 0) return memrefType; @@ -1784,6 +1782,7 @@ return memrefType; } AffineMap layoutMap = memrefType.getLayout().getAffineMap(); + unsigned numSymbolicOperands = layoutMap.getNumSymbols(); // We don't do any checks for one-to-one'ness; we assume that it is // one-to-one.
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -367,8 +367,7 @@ } // Fetch a new memref type after normalizing the old memref to have an // identity map layout. - MemRefType newMemRefType = normalizeMemRefType(memrefType, - /*numSymbolicOperands=*/0); + MemRefType newMemRefType = normalizeMemRefType(memrefType); if (newMemRefType == memrefType || funcOp.isExternal()) { // Either memrefType already had an identity map or the map couldn't be // transformed to an identity map. @@ -475,8 +474,7 @@ } // Computing a new memref type after normalizing the old memref to have an // identity map layout. - MemRefType newMemRefType = normalizeMemRefType(memrefType, - /*numSymbolicOperands=*/0); + MemRefType newMemRefType = normalizeMemRefType(memrefType); resultTypes.push_back(newMemRefType); } @@ -513,9 +511,8 @@ resultTypes.push_back(resultType); continue; } // Fetch a new memref type after normalizing the old memref. - MemRefType newMemRefType = normalizeMemRefType(memrefType, - /*numSymbolicOperands=*/0); + MemRefType newMemRefType = normalizeMemRefType(memrefType); if (newMemRefType == memrefType) { // Either memrefType already had an identity map or the map couldn't // be transformed to an identity map.
diff --git a/mlir/test/Transforms/normalize-memrefs.mlir b/mlir/test/Transforms/normalize-memrefs.mlir
--- a/mlir/test/Transforms/normalize-memrefs.mlir
+++ b/mlir/test/Transforms/normalize-memrefs.mlir
@@ -352,3 +352,15 @@
   %0 = memref.alloc() : memref<2x3xf32, #neg>
   return %0 : memref<2x3xf32, #neg>
 }
+
+// The memref.cast result has a static shape but a strided layout with dynamic
+// strides and offset (a layout map with symbols); the CHECK verifies that the
+// cast keeps its original strided types.
+// CHECK-LABEL: func @memref_with_strided_offset
+func.func @memref_with_strided_offset(%arg0: tensor<128x512xf32>, %arg1: index, %arg2: index) -> tensor<16x512xf32> {
+  %0 = bufferization.to_memref %arg0 : memref<128x512xf32, strided<[?, ?], offset: ?>>
+  %subview = memref.subview %0[%arg2, 0] [%arg1, 512] [1, 1] : memref<128x512xf32, strided<[?, ?], offset: ?>> to memref<?x512xf32, strided<[?, ?], offset: ?>>
+  // CHECK: %{{.*}} = memref.cast %{{.*}} : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
+  %cast = memref.cast %subview : memref<?x512xf32, strided<[?, ?], offset: ?>> to memref<16x512xf32, strided<[?, ?], offset: ?>>
+  %1 = bufferization.to_tensor %cast : memref<16x512xf32, strided<[?, ?], offset: ?>>
+  return %1 : tensor<16x512xf32>
+}