diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -118,8 +118,7 @@ unsigned getPointerBitwidth(unsigned addressSpace = 0); protected: - /// LLVM IR module used to parse/create types. - llvm::Module *module; + /// Pointer to the LLVM dialect. LLVM::LLVMDialect *llvmDialect; private: @@ -400,9 +399,6 @@ /// Returns the LLVM IR context. llvm::LLVMContext &getContext() const; - /// Returns the LLVM IR module associated with the LLVM dialect. - llvm::Module &getModule() const; - /// Gets the MLIR type wrapping the LLVM integer type whose bit width is /// defined by the used type converter. LLVM::LLVMType getIndexType() const; @@ -437,8 +433,8 @@ ConversionPatternRewriter &rewriter) const; Value getDataPtr(Location loc, MemRefType type, Value memRefDesc, - ValueRange indices, ConversionPatternRewriter &rewriter, - llvm::Module &module) const; + ValueRange indices, + ConversionPatternRewriter &rewriter) const; /// Returns the type of a pointer to an element of the memref. Type getElementPtrType(MemRefType type) const; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -25,6 +25,7 @@ llvm::LLVMContext &getLLVMContext(); llvm::Module &getLLVMModule(); llvm::sys::SmartMutex &getLLVMContextMutex(); + const llvm::DataLayout &getDataLayout(); private: friend LLVMType; diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -66,12 +66,7 @@ private: LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } - llvm::LLVMContext &getLLVMContext() { - return getLLVMDialect()->getLLVMContext(); - } - void initializeCachedTypes() { - const llvm::Module &module = llvmDialect->getLLVMModule(); llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); llvmPointerPointerType = llvmPointerType.getPointerTo(); @@ -79,7 +74,7 @@ llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); llvmIntPtrType = LLVM::LLVMType::getIntNTy( - llvmDialect, module.getDataLayout().getPointerSizeInBits()); + llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits()); } LLVM::LLVMType getVoidType() { return llvmVoidType; } @@ -95,9 +90,9 @@ LLVM::LLVMType getInt64Type() { return llvmInt64Type; } LLVM::LLVMType getIntPtrType() { - const llvm::Module &module = getLLVMDialect()->getLLVMModule(); return LLVM::LLVMType::getIntNTy( - getLLVMDialect(), module.getDataLayout().getPointerSizeInBits()); + getLLVMDialect(), + getLLVMDialect()->getDataLayout().getPointerSizeInBits()); } // Allocate a void pointer on the stack. diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -59,10 +59,6 @@ private: LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } - llvm::LLVMContext &getLLVMContext() { - return getLLVMDialect()->getLLVMContext(); - } - void initializeCachedTypes() { llvmDialect = getContext().getRegisteredDialect(); llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -128,10 +128,9 @@ : llvmDialect(ctx->getRegisteredDialect()), options(options) { assert(llvmDialect && "LLVM IR dialect is not registered"); - module = &llvmDialect->getLLVMModule(); if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout) this->options.indexBitwidth = - module->getDataLayout().getPointerSizeInBits(); + llvmDialect->getDataLayout().getPointerSizeInBits(); // Register conversions for the standard types. addConversion([&](ComplexType type) { return convertComplexType(type); }); @@ -196,7 +195,7 @@ /// Get the LLVM context. llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() { - return module->getContext(); + return llvmDialect->getLLVMContext(); } LLVM::LLVMType LLVMTypeConverter::getIndexType() { @@ -204,7 +203,7 @@ } unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) { - return module->getDataLayout().getPointerSizeInBits(addressSpace); + return llvmDialect->getDataLayout().getPointerSizeInBits(addressSpace); } Type LLVMTypeConverter::convertIndexType(IndexType type) { @@ -849,10 +848,6 @@ return typeConverter.getLLVMContext(); } -llvm::Module &ConvertToLLVMPattern::getModule() const { - return getDialect().getLLVMModule(); -} - LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const { return typeConverter.getIndexType(); } @@ -910,10 +905,9 @@ return rewriter.create(loc, elementTypePtr, base, offsetValue); } -Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type, - Value memRefDesc, ValueRange indices, - ConversionPatternRewriter &rewriter, - llvm::Module &module) const { +Value ConvertToLLVMPattern::getDataPtr( + Location loc, MemRefType type, Value memRefDesc, ValueRange indices, + ConversionPatternRewriter &rewriter) const { LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType(); int64_t offset; SmallVector strides; @@ -2451,7 +2445,7 @@ auto type = loadOp.getMemRefType(); Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + transformed.indices(), rewriter); rewriter.replaceOpWithNewOp(op, dataPtr); return success(); } @@ -2469,7 +2463,7 @@ StoreOp::Adaptor transformed(operands); Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + transformed.indices(), rewriter); rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr); return success(); @@ -2489,7 +2483,7 @@ auto type = prefetchOp.getMemRefType(); Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + transformed.indices(), rewriter); // Replace with llvm.prefetch. auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); @@ -3086,7 +3080,7 @@ auto resultType = adaptor.value().getType(); auto memRefType = atomicOp.getMemRefType(); auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(), - adaptor.indices(), rewriter, getModule()); + adaptor.indices(), rewriter); rewriter.replaceOpWithNewOp( op, resultType, *maybeKind, dataPtr, adaptor.value(), LLVM::AtomicOrdering::acq_rel); @@ -3152,7 +3146,7 @@ rewriter.setInsertionPointToEnd(initBlock); auto memRefType = atomicOp.memref().getType().cast(); auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter, getModule()); + adaptor.indices(), rewriter); Value init = rewriter.create(loc, dataPtr); rewriter.create(loc, init, loopBlock); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -131,7 +131,7 @@ LLVM::LLVMDialect *dialect = typeConverter.getDialect(); align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext()) .getPreferredAlignment(elementTy.cast(), - dialect->getLLVMModule().getDataLayout()); + dialect->getDataLayout()); return success(); } @@ -1152,7 +1152,7 @@ // address space 0. // TODO: support alignment when possible. Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter, getModule()); + adaptor.indices(), rewriter); auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast(); Value vectorDataPtr; diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -103,7 +103,7 @@ // indices, so no need to calculat offset size in bytes again in // the MUBUF instruction. Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter, getModule()); + adaptor.indices(), rewriter); // 1. Create and fill a <4 x i32> dwordConfig with: // 1st two elements holding the address of dataPtr. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1741,6 +1741,9 @@ llvm::sys::SmartMutex &LLVMDialect::getLLVMContextMutex() { return impl->mutex; } +const llvm::DataLayout &LLVMDialect::getDataLayout() { + return impl->module.getDataLayout(); +} /// Parse a type registered to this dialect. Type LLVMDialect::parseType(DialectAsmParser &parser) const {