diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h --- a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h @@ -43,6 +43,12 @@ static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory); + /// Same as the overload above, except that `memory` is only used as the + /// allocated pointer and `alignedMemory` is used as the aligned pointer. + static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + MemRefType type, Value memory, + Value alignedMemory); /// Builds IR extracting the allocated pointer from the descriptor. Value allocatedPtr(OpBuilder &builder, Location loc); diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -43,6 +43,12 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory) { + return fromStaticShape(builder, loc, typeConverter, type, memory, memory); +} + +MemRefDescriptor MemRefDescriptor::fromStaticShape( + OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, + MemRefType type, Value memory, Value alignedMemory) { assert(type.hasStaticShape() && "unexpected dynamic shape"); // Extract all strides and offsets and verify they are static. 
@@ -61,7 +67,7 @@ auto descr = MemRefDescriptor::undef(builder, loc, convertedType); descr.setAllocatedPtr(builder, loc, memory); - descr.setAlignedPtr(builder, loc, memory); + descr.setAlignedPtr(builder, loc, alignedMemory); descr.setConstantOffset(builder, loc, offset); // Fill in sizes and strides diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -2115,7 +2115,7 @@ return failure(); // Create the descriptor. - MemRefDescriptor sourceMemRef(adaptor.getOperands().front()); + MemRefDescriptor sourceMemRef(adaptor.getSource()); Location loc = extractStridedMetadataOp.getLoc(); Value source = extractStridedMetadataOp.getSource(); @@ -2125,7 +2125,15 @@ results.reserve(2 + rank * 2); // Base buffer. - results.push_back(sourceMemRef.allocatedPtr(rewriter, loc)); + // Build a descriptor of the statically shaped base buffer type that carries + // both the allocated and the aligned pointer of the source descriptor. + Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc); + Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc); + MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape( + rewriter, loc, *getTypeConverter(), + extractStridedMetadataOp.getBaseBuffer().getType().cast<MemRefType>(), + baseBuffer, alignedBuffer); + results.push_back(static_cast<Value>(dstMemRef)); // Offset. 
results.push_back(sourceMemRef.offset(rewriter, loc)); diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -1169,6 +1169,12 @@ // CHECK-SAME: %[[ARG:.*]]: memref // CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BASE]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)> +// CHECK: %[[OFF0:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[BASE_BUFFER_DESC:.*]] = llvm.insertvalue %[[OFF0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64)> // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[MEM_DESC]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[SIZE0:.*]] = llvm.extractvalue %[[MEM_DESC]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM_DESC]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>