diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2278,11 +2278,11 @@ auto targetStructType = typeConverter.convertType(memRefCastOp.getType()); auto loc = op->getLoc(); - // MemRefCastOp reduce to bitcast in the ranked MemRef case. - if (srcType.isa() && dstType.isa()) { - rewriter.replaceOpWithNewOp(op, targetStructType, - transformed.source()); - } else if (srcType.isa() && dstType.isa()) { + // For ranked/ranked case, just keep the original descriptor. + if (srcType.isa() && dstType.isa()) + return rewriter.replaceOp(op, {transformed.source()}); + + if (srcType.isa() && dstType.isa()) { // Casting ranked to unranked memref type // Set the rank in the destination from the memref type // Allocate space on the stack and copy the src memref descriptor diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir @@ -324,49 +324,49 @@ // CHECK-LABEL: func @memref_cast_static_to_dynamic func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) { -// CHECK: llvm.bitcast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NOT: llvm.bitcast %0 = memref_cast %static : memref<10x42xf32> to memref return } // CHECK-LABEL: func @memref_cast_static_to_mixed func @memref_cast_static_to_mixed(%static : memref<10x42xf32>) { -// CHECK: llvm.bitcast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NOT: llvm.bitcast %0 = memref_cast %static : memref<10x42xf32> to memref return } // CHECK-LABEL: func @memref_cast_dynamic_to_static func @memref_cast_dynamic_to_static(%dynamic : memref) { -// CHECK: llvm.bitcast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NOT: llvm.bitcast %0 = memref_cast %dynamic : memref to memref<10x12xf32> return } // CHECK-LABEL: func @memref_cast_dynamic_to_mixed func @memref_cast_dynamic_to_mixed(%dynamic : memref) { -// CHECK: llvm.bitcast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NOT: llvm.bitcast %0 = memref_cast %dynamic : memref to memref return } // CHECK-LABEL: func @memref_cast_mixed_to_dynamic func @memref_cast_mixed_to_dynamic(%mixed : memref<42x?xf32>) { -// CHECK: llvm.bitcast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NOT: llvm.bitcast %0 = memref_cast %mixed : memref<42x?xf32> to memref return } // CHECK-LABEL: func @memref_cast_mixed_to_static func @memref_cast_mixed_to_static(%mixed : memref<42x?xf32>) { -// CHECK: llvm.bitcast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NOT: llvm.bitcast %0 = memref_cast %mixed : memref<42x?xf32> to memref<42x1xf32> return } // CHECK-LABEL: func @memref_cast_mixed_to_mixed func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) { -// CHECK: llvm.bitcast %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NOT: llvm.bitcast %0 = memref_cast %mixed : memref<42x?xf32> to memref return }