diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -2070,6 +2070,28 @@
   }
 };
 
+/// Lower memref.extract_aligned_pointer_as_index: read the aligned pointer
+/// out of the converted memref descriptor and turn it into an integer of the
+/// converter's index type with llvm.ptrtoint.
+class ConvertExtractAlignedPointerAsIndex
+    : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
+public:
+  using ConvertOpToLLVMPattern<
+      memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The adaptor's source is the converted (LLVM struct) memref descriptor.
+    MemRefDescriptor desc(adaptor.getSource());
+    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
+        extractOp, getTypeConverter()->getIndexType(),
+        desc.alignedPtr(rewriter, extractOp->getLoc()));
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
@@ -2080,6 +2102,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
       AllocaScopeOpLowering,
       AtomicRMWOpLowering,
       AssumeAlignmentOpLowering,
+      ConvertExtractAlignedPointerAsIndex,
       DimOpLowering,
       GenericAtomicRMWOpLowering,
       GlobalMemrefOpLowering,
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -1143,3 +1143,16 @@
   // CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr<struct<(i64, ptr<i8>)>>, !llvm.ptr<struct<(i64, ptr<i8>)>>) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_aligned_pointer_as_index
+func.func @extract_aligned_pointer_as_index(%m: memref<?xf32>) -> index {
+  %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32> -> index
+  // CHECK: %[[E:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[I64:.*]] = llvm.ptrtoint %[[E]] : !llvm.ptr<f32> to i64
+  // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+  // CHECK: return %[[R]] : index
+  return %0: index
+}