diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1926,6 +1926,12 @@
     return getResult();
   }
 
+  // reinterpret_cast(x) -> x when shape is static, types match, offset is 0.
+  if (!ShapedType::isDynamicShape(getType().getShape()) &&
+      src.getType() == getType() && getStaticOffsets().front() == 0) {
+    return src;
+  }
+
   return nullptr;
 }
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -719,6 +719,16 @@
 
 // -----
 
+// CHECK-LABEL: func @reinterpret_noop
+// CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
+// CHECK-NEXT: return %[[ARG]]
+func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
+  %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [2, 3, 4], strides: [12, 4, 1] : memref<2x3x4xf32> to memref<2x3x4xf32>
+  return %0 : memref<2x3x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_of_reinterpret
 // CHECK-SAME: (%[[ARG:.*]]: memref, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
 // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]