Make MemRefDescriptor return its aligned-pointer type as LLVMPointerType.

MemRefDescriptor::getElementType() is renamed to getElementPtrType() and now
returns LLVM::LLVMPointerType: the type of the descriptor's struct element at
kAlignedPtrPosInMemRefDescriptor is cast to LLVM::LLVMPointerType instead of
being returned as a generic LLVM::LLVMType. LLVMPointerType is forward-declared
in ConvertStandardToLLVM.h, and the callers in StandardToLLVM.cpp
(ConvertToLLVMPattern::getDataPtr) and ConvertVectorToLLVM.cpp are updated to
the new name.

diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -34,6 +34,7 @@
 namespace LLVM {
 class LLVMDialect;
 class LLVMType;
+class LLVMPointerType;
 } // namespace LLVM
 
 /// Callback to convert function argument types. It converts a MemRef function
@@ -281,8 +282,8 @@
   void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
                          uint64_t stride);
 
-  /// Returns the (LLVM) type this descriptor points to.
-  LLVM::LLVMType getElementType();
+  /// Returns the (LLVM) pointer type this descriptor contains.
+  LLVM::LLVMPointerType getElementPtrType();
 
   /// Builds IR populating a MemRef descriptor structure from a list of
   /// individual values composing that descriptor, in the following order:
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -642,9 +642,11 @@
             createIndexAttrConstant(builder, loc, indexType, stride));
 }
 
-LLVM::LLVMType MemRefDescriptor::getElementType() {
-  return value.getType().cast<LLVM::LLVMType>().getStructElementType(
-      kAlignedPtrPosInMemRefDescriptor);
+LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
+  return value.getType()
+      .cast<LLVM::LLVMType>()
+      .getStructElementType(kAlignedPtrPosInMemRefDescriptor)
+      .cast<LLVM::LLVMPointerType>();
 }
 
 /// Creates a MemRef descriptor structure from a list of individual values
@@ -894,7 +896,7 @@
 Value ConvertToLLVMPattern::getDataPtr(
     Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
     ConversionPatternRewriter &rewriter) const {
-  LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
+  LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementPtrType();
   int64_t offset;
   SmallVector<int64_t, 4> strides;
   auto successStrides = getStridesAndOffset(type, strides, offset);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -198,7 +198,7 @@
   Value base;
   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
     return failure();
-  auto pType = MemRefDescriptor(memref).getElementType();
+  auto pType = MemRefDescriptor(memref).getElementPtrType();
   ptr = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
   return success();
 }
@@ -225,7 +225,7 @@
   Value base;
   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
     return failure();
-  auto pType = MemRefDescriptor(memref).getElementType();
+  auto pType = MemRefDescriptor(memref).getElementPtrType();
   auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0));
   ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
   return success();
@@ -1151,7 +1151,7 @@
 
     // Create descriptor.
     auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
-    Type llvmTargetElementTy = desc.getElementType();
+    Type llvmTargetElementTy = desc.getElementPtrType();
     // Set allocated ptr.
     Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
     allocated =