diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -2098,6 +2098,71 @@
   }
 };
 
+/// Materialize the MemRef descriptor represented by the results of
+/// ExtractStridedMetadataOp.
+///
+/// The replacement values are, in order: the base buffer (a 0-D memref
+/// descriptor reusing the allocated and aligned pointers of the source),
+/// the offset, the sizes, and the strides of the source descriptor.
+class ExtractStridedMetadataOpLowering
+    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
+public:
+  using ConvertOpToLLVMPattern<
+      memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!LLVM::isCompatibleType(adaptor.getSource().getType()))
+      return failure();
+
+    // The base buffer result is a 0-D memref: it must be replaced by a value
+    // of its converted descriptor type, not by a bare pointer.
+    auto baseBufferType =
+        extractStridedMetadataOp.getBaseBuffer().getType().cast<MemRefType>();
+    Type baseBufferDescType = getTypeConverter()->convertType(baseBufferType);
+    if (!baseBufferDescType)
+      return failure();
+
+    // Wrap the already-converted source descriptor.
+    MemRefDescriptor sourceMemRef(adaptor.getSource());
+    Location loc = extractStridedMetadataOp.getLoc();
+    Value source = extractStridedMetadataOp.getSource();
+
+    auto sourceMemRefType = source.getType().cast<MemRefType>();
+    int64_t rank = sourceMemRefType.getRank();
+    SmallVector<Value> results;
+    results.reserve(2 + rank * 2);
+
+    // Base buffer: a 0-D descriptor sharing the source's allocated and
+    // aligned pointers. Its offset is 0 because the source offset is
+    // returned as a separate result.
+    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
+    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
+    MemRefDescriptor baseBufferDesc =
+        MemRefDescriptor::undef(rewriter, loc, baseBufferDescType);
+    baseBufferDesc.setAllocatedPtr(rewriter, loc, allocatedPtr);
+    baseBufferDesc.setAlignedPtr(rewriter, loc, alignedPtr);
+    baseBufferDesc.setConstantOffset(rewriter, loc, 0);
+    results.push_back(static_cast<Value>(baseBufferDesc));
+
+    // Offset.
+    results.push_back(sourceMemRef.offset(rewriter, loc));
+
+    // Sizes.
+    for (unsigned i = 0; i < rank; ++i)
+      results.push_back(sourceMemRef.size(rewriter, loc, i));
+    // Strides.
+    for (unsigned i = 0; i < rank; ++i)
+      results.push_back(sourceMemRef.stride(rewriter, loc, i));
+
+    rewriter.replaceOp(extractStridedMetadataOp, results);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
@@ -2110,6 +2175,7 @@
       AssumeAlignmentOpLowering,
       ConvertExtractAlignedPointerAsIndex,
       DimOpLowering,
+      ExtractStridedMetadataOpLowering,
       GenericAtomicRMWOpLowering,
       GlobalMemrefOpLowering,
       GetGlobalMemrefOpLowering,
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -1162,3 +1162,32 @@
   // CHECK: return %[[R:.*]] : index
   return %0: index
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_strided_metadata(
+// CHECK-SAME: %[[ARG:.*]]: memref
+// CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BASE]], %[[DESC]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BASE]], %[[DESC0]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[OFF0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[BASE_BUFFER_DESC:.*]] = llvm.insertvalue %[[OFF0]], %[[DESC1]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64)>
+// CHECK: %[[OFFSET:.*]] = llvm.extractvalue %[[MEM_DESC]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[SIZE0:.*]] = llvm.extractvalue %[[MEM_DESC]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[SIZE1:.*]] = llvm.extractvalue %[[MEM_DESC]][3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+func.func @extract_strided_metadata(
+    %ref: memref<?x?xf32, strided<[?,?], offset: ?>>) {
+
+  %base, %offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %ref : memref<?x?xf32, strided<[?,?], offset: ?>>
+    -> memref<f32>, index,
+       index, index,
+       index, index
+
+  return
+}