NOTE(review): This patch was recovered from a copy whose line breaks had all
been collapsed onto one line; the line structure below is reconstructed.
 - In the normalize-memrefs-ops.mlir hunk, the leading blank context line is
   inferred from the hunk header's old-line count of 3, and the indentation of
   added lines follows normal MLIR test style.
 - The MemRefOps.td hunk is incomplete: its header declares 7 old/new lines but
   only two context lines survived and no -/+ lines remain, so patch tools will
   reject it as corrupt. The missing change presumably makes
   MemRef_ReinterpretCastOp normalizable (e.g. by adding the
   MemRefsNormalizable trait), which the new test below relies on -- confirm
   against the original commit and restore the hunk before applying.

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -778,7 +778,7 @@ def MemRef_ReinterpretCastOp: BaseOpWithOffsetSizesAndStrides {
   let summary = "memref reinterpret cast operation";
   let description = [{
diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
--- a/mlir/test/Transforms/normalize-memrefs-ops.mlir
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -112,3 +112,22 @@
 
 // Test with an arbitrary op that references the function symbol.
 "test.op_funcref"() {func = @test_norm_mix} : () -> ()
+
+
+// -----
+
+#map_1d_tile = affine_map<(d0) -> (d0 floordiv 32, d0 mod 32)>
+
+// Test with memref.reinterpret_cast
+
+// CHECK-LABEL: test_norm_reinterpret_cast
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1x32xf32>) -> memref<3x1x1xf32> {
+func @test_norm_reinterpret_cast(%arg0 : memref<3xf32, #map_1d_tile>) -> (memref<3x1x1xf32>) {
+  %0 = memref.alloc() : memref<3xf32>
+  "test.op_norm"(%arg0, %0) : (memref<3xf32, #map_1d_tile>, memref<3xf32>) -> ()
+  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [3, 1, 1], strides: [1, 1, 1] : memref<3xf32> to memref<3x1x1xf32>
+  // CHECK: %[[v0:.*]] = memref.alloc() : memref<3xf32>
+  // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x32xf32>, memref<3xf32>) -> ()
+  // CHECK: memref.reinterpret_cast %[[v0]] to offset: [0], sizes: [3, 1, 1], strides: [1, 1, 1] : memref<3xf32> to memref<3x1x1xf32>
+  return %1 : memref<3x1x1xf32>
+}