diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -2070,6 +2070,41 @@
   }
 };
 
+// NOTE(review): this copy of the patch appears to have lost the template
+// arguments of the base class, isa() and replaceOpWithNewOp() - restore them
+// from the original revision before applying.
+/// Lower memref.extract_aligned_pointer_as_index by reading the aligned
+/// pointer out of the converted source memref descriptor and replacing the op
+/// with a value of the type converter's index type built from that pointer.
+/// The pattern only applies when the source is a block argument or passes the
+/// AllocLikeOp check; otherwise it reports a match failure.
+class ConvertExtractAlignedPointerAsIndex
+    : public ConvertOpToLLVMPattern {
+public:
+  using ConvertOpToLLVMPattern<
+      memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // A value that is not a block argument always has a defining op, so the
+    // right-hand side of && never inspects a null operation.
+    if (!extractOp.getSource().isa() &&
+        !isa(
+            extractOp.getSource().getDefiningOp()))
+      return rewriter.notifyMatchFailure(
+          extractOp, "not an AllocLikeOp or a block argument");
+
+    // Wrap the adaptor's (converted) source operand as a memref descriptor.
+    MemRefDescriptor desc(adaptor.getSource());
+    rewriter.replaceOpWithNewOp(
+        extractOp, getTypeConverter()->getIndexType(),
+        desc.alignedPtr(rewriter, extractOp->getLoc()));
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
@@ -2080,6 +2115,7 @@
       AllocaScopeOpLowering,
       AtomicRMWOpLowering,
       AssumeAlignmentOpLowering,
+      ConvertExtractAlignedPointerAsIndex,
       DimOpLowering,
       GenericAtomicRMWOpLowering,
       GlobalMemrefOpLowering,
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -1143,3 +1143,32 @@
   // CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr)>>, !llvm.ptr)>>) -> ()
   return
 }
+
+// -----
+
+// Exercise the lowering for each accepted kind of source: the result of
+// memref.alloc, the result of memref.alloca, and a function argument.
+// CHECK-LABEL: func @extract_aligned_pointer_as_index
+func.func @extract_aligned_pointer_as_index(%m: memref) -> (index, index, index) {
+  %0 = memref.alloc() : memref<2xi8>
+  %1 = memref.extract_aligned_pointer_as_index %0: memref<2xi8> -> index
+  // CHECK: %[[E_0:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[I64_0:.*]] = llvm.ptrtoint %[[E_0]] : !llvm.ptr to i64
+  // CHECK: %[[R0:.*]] = builtin.unrealized_conversion_cast %[[I64_0]] : i64 to index
+
+  %2 = memref.alloca() : memref<4xi16>
+  %3 = memref.extract_aligned_pointer_as_index %2: memref<4xi16> -> index
+  // CHECK: %[[E_1:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[I64_1:.*]] = llvm.ptrtoint %[[E_1]] : !llvm.ptr to i64
+  // CHECK: %[[R1:.*]] = builtin.unrealized_conversion_cast %[[I64_1]] : i64 to index
+
+  %4 = memref.extract_aligned_pointer_as_index %m: memref -> index
+  // CHECK: %[[E_2:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[I64_2:.*]] = llvm.ptrtoint %[[E_2]] : !llvm.ptr to i64
+  // CHECK: %[[R2:.*]] = builtin.unrealized_conversion_cast %[[I64_2]] : i64 to index
+
+  // Use the R0/R1/R2 values captured above instead of redefining them, so the
+  // return is verified to consume the three converted results.
+  // CHECK: return %[[R0]], %[[R1]], %[[R2]] : index, index, index
+  return %1, %3, %4: index, index, index
+}