diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -56,19 +56,15 @@ auto memRefType = (*op->operand_type_begin()).cast(); auto memRefShape = memRefType.getShape(); auto loc = op->getLoc(); - auto *llvmDialect = - op->getContext()->getRegisteredDialect(); - assert(llvmDialect && "expected llvm dialect to be registered"); ModuleOp parentModule = op->getParentOfType(); // Get a symbol reference to the printf function, inserting it if necessary. - auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect); + auto printfRef = getOrInsertPrintf(rewriter, parentModule); Value formatSpecifierCst = getOrCreateGlobalString( - loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule, - llvmDialect); + loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule); Value newLineCst = getOrCreateGlobalString( - loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect); + loc, rewriter, "nl", StringRef("\n\0", 2), parentModule); // Create a loop for each of the dimensions within the shape. SmallVector loopIvs; @@ -108,16 +104,15 @@ /// Return a symbol reference to the printf function, inserting it into the /// module if necessary. static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { + ModuleOp module) { auto *context = module.getContext(); if (module.lookupSymbol("printf")) return SymbolRefAttr::get("printf", context); // Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` - auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect); - auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context); + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context); auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy, /*isVarArg=*/true); @@ -132,15 +127,14 @@ /// name, creating the string if necessary. static Value getOrCreateGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { + ModuleOp module) { // Create the global at the entry of the module. LLVM::GlobalOp global; if (!(global = module.lookupSymbol(name))) { OpBuilder::InsertionGuard insertGuard(builder); builder.setInsertionPointToStart(module.getBody()); auto type = LLVM::LLVMType::getArrayTy( - LLVM::LLVMType::getInt8Ty(llvmDialect), value.size()); + LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size()); global = builder.create(loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, builder.getStringAttr(value)); @@ -149,10 +143,10 @@ // Get the pointer to the first character in the global string. Value globalPtr = builder.create(loc, global); Value cst0 = builder.create( - loc, LLVM::LLVMType::getInt64Ty(llvmDialect), + loc, LLVM::LLVMType::getInt64Ty(builder.getContext()), builder.getIntegerAttr(builder.getIndexType(), 0)); return builder.create( - loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, + loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr, ArrayRef({cst0, cst0})); } }; diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -56,19 +56,15 @@ auto memRefType = (*op->operand_type_begin()).cast(); auto memRefShape = memRefType.getShape(); auto loc = op->getLoc(); - auto *llvmDialect = - op->getContext()->getRegisteredDialect(); - assert(llvmDialect && "expected llvm dialect to be registered"); ModuleOp parentModule = op->getParentOfType(); // Get a symbol reference to the printf function, inserting it if necessary. - auto printfRef = getOrInsertPrintf(rewriter, parentModule, llvmDialect); + auto printfRef = getOrInsertPrintf(rewriter, parentModule); Value formatSpecifierCst = getOrCreateGlobalString( - loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule, - llvmDialect); + loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule); Value newLineCst = getOrCreateGlobalString( - loc, rewriter, "nl", StringRef("\n\0", 2), parentModule, llvmDialect); + loc, rewriter, "nl", StringRef("\n\0", 2), parentModule); // Create a loop for each of the dimensions within the shape. SmallVector loopIvs; @@ -108,16 +104,15 @@ /// Return a symbol reference to the printf function, inserting it into the /// module if necessary. static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { + ModuleOp module) { auto *context = module.getContext(); if (module.lookupSymbol("printf")) return SymbolRefAttr::get("printf", context); // Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` - auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect); - auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context); + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context); auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy, /*isVarArg=*/true); @@ -132,15 +127,14 @@ /// name, creating the string if necessary. static Value getOrCreateGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { + ModuleOp module) { // Create the global at the entry of the module. LLVM::GlobalOp global; if (!(global = module.lookupSymbol(name))) { OpBuilder::InsertionGuard insertGuard(builder); builder.setInsertionPointToStart(module.getBody()); auto type = LLVM::LLVMType::getArrayTy( - LLVM::LLVMType::getInt8Ty(llvmDialect), value.size()); + LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size()); global = builder.create(loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, builder.getStringAttr(value)); @@ -149,10 +143,10 @@ // Get the pointer to the first character in the global string. Value globalPtr = builder.create(loc, global); Value cst0 = builder.create( - loc, LLVM::LLVMType::getInt64Ty(llvmDialect), + loc, LLVM::LLVMType::getInt64Ty(builder.getContext()), builder.getIntegerAttr(builder.getIndexType(), 0)); return builder.create( - loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), globalPtr, + loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr, ArrayRef({cst0, cst0})); } }; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -59,8 +59,7 @@ /// global and use it to compute the address of the first character in the /// string (operations inserted at the builder insertion point). Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, - StringRef value, LLVM::Linkage linkage, - LLVM::LLVMDialect *llvmDialect); + StringRef value, LLVM::Linkage linkage); /// LLVM requires some operations to be inside of a Module operation. This /// function confirms that the Operation has the desired properties. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -58,8 +58,7 @@ "$_self.cast<::mlir::LLVM::LLVMType>().isIntegerTy(" # width # ")">]>, "LLVM dialect " # width # "-bit integer">, BuildableType< - "::mlir::LLVM::LLVMType::getIntNTy(" - "$_builder.getContext()->getRegisteredDialect()," + "::mlir::LLVM::LLVMType::getIntNTy($_builder.getContext()," # width # ")">; def LLVMI1 : LLVMI<1>; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -151,8 +151,7 @@ let builders = [OpBuilder< "OpBuilder &b, OperationState &result, ICmpPredicate predicate, Value lhs, " "Value rhs", [{ - LLVMDialect *dialect = &lhs.getType().cast().getDialect(); - build(b, result, LLVMType::getInt1Ty(dialect), + build(b, result, LLVMType::getInt1Ty(lhs.getType().getContext()), b.getI64IntegerAttr(static_cast(predicate)), lhs, rhs); }]>]; let parser = [{ return parseCmpOp(parser, result); }]; @@ -198,8 +197,7 @@ let builders = [OpBuilder< "OpBuilder &b, OperationState &result, FCmpPredicate predicate, Value lhs, " "Value rhs", [{ - LLVMDialect *dialect = &lhs.getType().cast().getDialect(); - build(b, result, LLVMType::getInt1Ty(dialect), + build(b, result, LLVMType::getInt1Ty(lhs.getType().getContext()), b.getI64IntegerAttr(static_cast(predicate)), lhs, rhs); }]>]; let parser = [{ return parseCmpOp(parser, result); }]; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -152,32 +152,32 @@ bool isStructTy(); /// Utilities used to generate floating point types. - static LLVMType getDoubleTy(LLVMDialect *dialect); - static LLVMType getFloatTy(LLVMDialect *dialect); - static LLVMType getBFloatTy(LLVMDialect *dialect); - static LLVMType getHalfTy(LLVMDialect *dialect); - static LLVMType getFP128Ty(LLVMDialect *dialect); - static LLVMType getX86_FP80Ty(LLVMDialect *dialect); + static LLVMType getDoubleTy(MLIRContext *context); + static LLVMType getFloatTy(MLIRContext *context); + static LLVMType getBFloatTy(MLIRContext *context); + static LLVMType getHalfTy(MLIRContext *context); + static LLVMType getFP128Ty(MLIRContext *context); + static LLVMType getX86_FP80Ty(MLIRContext *context); /// Utilities used to generate integer types. - static LLVMType getIntNTy(LLVMDialect *dialect, unsigned numBits); - static LLVMType getInt1Ty(LLVMDialect *dialect) { - return getIntNTy(dialect, /*numBits=*/1); + static LLVMType getIntNTy(MLIRContext *context, unsigned numBits); + static LLVMType getInt1Ty(MLIRContext *context) { + return getIntNTy(context, /*numBits=*/1); } - static LLVMType getInt8Ty(LLVMDialect *dialect) { - return getIntNTy(dialect, /*numBits=*/8); + static LLVMType getInt8Ty(MLIRContext *context) { + return getIntNTy(context, /*numBits=*/8); } - static LLVMType getInt8PtrTy(LLVMDialect *dialect) { - return getInt8Ty(dialect).getPointerTo(); + static LLVMType getInt8PtrTy(MLIRContext *context) { + return getInt8Ty(context).getPointerTo(); } - static LLVMType getInt16Ty(LLVMDialect *dialect) { - return getIntNTy(dialect, /*numBits=*/16); + static LLVMType getInt16Ty(MLIRContext *context) { + return getIntNTy(context, /*numBits=*/16); } - static LLVMType getInt32Ty(LLVMDialect *dialect) { - return getIntNTy(dialect, /*numBits=*/32); + static LLVMType getInt32Ty(MLIRContext *context) { + return getIntNTy(context, /*numBits=*/32); } - static LLVMType getInt64Ty(LLVMDialect *dialect) { - return getIntNTy(dialect, /*numBits=*/64); + static LLVMType getInt64Ty(MLIRContext *context) { + return getIntNTy(context, /*numBits=*/64); } /// Utilities used to generate other miscellaneous types. @@ -187,33 +187,33 @@ static LLVMType getFunctionTy(LLVMType result, bool isVarArg) { return getFunctionTy(result, llvm::None, isVarArg); } - static LLVMType getStructTy(LLVMDialect *dialect, ArrayRef elements, + static LLVMType getStructTy(MLIRContext *context, ArrayRef elements, bool isPacked = false); - static LLVMType getStructTy(LLVMDialect *dialect, bool isPacked = false) { - return getStructTy(dialect, llvm::None, isPacked); + static LLVMType getStructTy(MLIRContext *context, bool isPacked = false) { + return getStructTy(context, llvm::None, isPacked); } template static typename std::enable_if::value, LLVMType>::type getStructTy(LLVMType elt1, Args... elts) { SmallVector fields({elt1, elts...}); - return getStructTy(&elt1.getDialect(), fields); + return getStructTy(elt1.getContext(), fields); } static LLVMType getVectorTy(LLVMType elementType, unsigned numElements); /// Void type utilities. - static LLVMType getVoidTy(LLVMDialect *dialect); + static LLVMType getVoidTy(MLIRContext *context); bool isVoidTy(); // Creation and setting of LLVM's identified struct types - static LLVMType createStructTy(LLVMDialect *dialect, + static LLVMType createStructTy(MLIRContext *context, ArrayRef elements, Optional name, bool isPacked = false); - static LLVMType createStructTy(LLVMDialect *dialect, + static LLVMType createStructTy(MLIRContext *context, Optional name) { - return createStructTy(dialect, llvm::None, name); + return createStructTy(context, llvm::None, name); } static LLVMType createStructTy(ArrayRef elements, @@ -222,7 +222,7 @@ assert(!elements.empty() && "This method may not be invoked with an empty list"); LLVMType ele0 = elements.front(); - return createStructTy(&ele0.getDialect(), elements, name, isPacked); + return createStructTy(ele0.getContext(), elements, name, isPacked); } template @@ -231,7 +231,7 @@ createStructTy(StringRef name, LLVMType elt1, Args... elts) { SmallVector fields({elt1, elts...}); Optional opt_name(name); - return createStructTy(&elt1.getDialect(), fields, opt_name); + return createStructTy(elt1.getContext(), fields, opt_name); } static LLVMType setStructTyBody(LLVMType structType, diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -67,14 +67,14 @@ LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } void initializeCachedTypes() { - llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); - llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext()); + llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext()); llvmPointerPointerType = llvmPointerType.getPointerTo(); - llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect); - llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); - llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); + llvmInt8Type = LLVM::LLVMType::getInt8Ty(&getContext()); + llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext()); + llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext()); llvmIntPtrType = LLVM::LLVMType::getIntNTy( - llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits()); + &getContext(), llvmDialect->getDataLayout().getPointerSizeInBits()); } LLVM::LLVMType getVoidType() { return llvmVoidType; } @@ -91,7 +91,7 @@ LLVM::LLVMType getIntPtrType() { return LLVM::LLVMType::getIntNTy( - getLLVMDialect(), + &getContext(), getLLVMDialect()->getDataLayout().getPointerSizeInBits()); } @@ -340,7 +340,7 @@ std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name)); return LLVM::createGlobalString( loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()), - LLVM::Linkage::Internal, llvmDialect); + LLVM::Linkage::Internal); } // Emits LLVM IR to launch a kernel function. Expects the module that contains @@ -378,9 +378,9 @@ SmallString<128> nameBuffer(kernelModule.getName()); nameBuffer.append(kGpuBinaryStorageSuffix); - Value data = LLVM::createGlobalString( - loc, builder, nameBuffer.str(), binaryAttr.getValue(), - LLVM::Linkage::Internal, getLLVMDialect()); + Value data = + LLVM::createGlobalString(loc, builder, nameBuffer.str(), + binaryAttr.getValue(), LLVM::Linkage::Internal); // Emit the load module call to load the module data. Error checking is done // in the called helper function. diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -89,7 +89,7 @@ // Rewrite workgroup memory attributions to addresses of global buffers. rewriter.setInsertionPointToStart(&gpuFuncOp.front()); unsigned numProperArguments = gpuFuncOp.getNumArguments(); - auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect()); + auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext()); Value zero = nullptr; if (!workgroupBuffers.empty()) @@ -117,7 +117,7 @@ // Rewrite private memory attributions to alloca'ed buffers. unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); - auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); + auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { Value attribution = en.value(); auto type = attribution.getType().cast(); diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -46,17 +46,17 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - auto dialect = typeConverter.getDialect(); + MLIRContext *context = rewriter.getContext(); Value newOp; switch (dimensionToIndex(cast(op))) { case X: - newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(dialect)); + newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(context)); break; case Y: - newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(dialect)); + newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(context)); break; case Z: - newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(dialect)); + newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(context)); break; default: return failure(); @@ -64,10 +64,10 @@ if (indexBitwidth > 32) { newOp = rewriter.create( - loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp); + loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp); } else if (indexBitwidth < 32) { newOp = rewriter.create( - loc, LLVM::LLVMType::getIntNTy(dialect, indexBitwidth), newOp); + loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp); } rewriter.replaceOp(op, {newOp}); diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -85,7 +85,7 @@ return operand; return rewriter.create( - operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()), + operand.getLoc(), LLVM::LLVMType::getFloatTy(rewriter.getContext()), operand); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -57,11 +57,11 @@ Location loc = op->getLoc(); gpu::ShuffleOpAdaptor adaptor(operands); - auto dialect = typeConverter.getDialect(); auto valueTy = adaptor.value().getType().cast(); - auto int32Type = LLVM::LLVMType::getInt32Ty(dialect); - auto predTy = LLVM::LLVMType::getInt1Ty(dialect); - auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy}); + auto int32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext()); + auto predTy = LLVM::LLVMType::getInt1Ty(rewriter.getContext()); + auto resultTy = + LLVM::LLVMType::getStructTy(rewriter.getContext(), {valueTy, predTy}); Value one = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(1)); diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -57,15 +57,12 @@ : public ConvertVulkanLaunchFuncToVulkanCallsBase< VulkanLaunchFuncToVulkanCallsPass> { private: - LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } - void initializeCachedTypes() { - llvmDialect = getContext().getRegisteredDialect(); - llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); - llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); - llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); - llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); - llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); + llvmFloatType = LLVM::LLVMType::getFloatTy(&getContext()); + llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext()); + llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext()); + llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext()); + llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext()); } LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) { @@ -87,7 +84,7 @@ // `!llvm<"{ `element-type`*, `element-type`*, i64, // [`rank` x i64], [`rank` x i64]}">`. return LLVM::LLVMType::getStructTy( - llvmDialect, + &getContext(), {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(), llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); } @@ -153,7 +150,6 @@ void runOnOperation() override; private: - LLVM::LLVMDialect *llvmDialect; LLVM::LLVMType llvmFloatType; LLVM::LLVMType llvmVoidType; LLVM::LLVMType llvmPointerType; @@ -245,7 +241,7 @@ // int16_t and bitcast the descriptor. if (type.isHalfTy()) { auto memRefTy = - getMemRefType(rank, LLVM::LLVMType::getInt16Ty(llvmDialect)); + getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext())); ptrToMemRefDescriptor = builder.create( loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor); } @@ -324,15 +320,15 @@ } for (unsigned i = 1; i <= 3; i++) { - for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(llvmDialect), - LLVM::LLVMType::getInt32Ty(llvmDialect), - LLVM::LLVMType::getInt16Ty(llvmDialect), - LLVM::LLVMType::getInt8Ty(llvmDialect), - LLVM::LLVMType::getHalfTy(llvmDialect)}) { + for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(&getContext()), + LLVM::LLVMType::getInt32Ty(&getContext()), + LLVM::LLVMType::getInt16Ty(&getContext()), + LLVM::LLVMType::getInt8Ty(&getContext()), + LLVM::LLVMType::getHalfTy(&getContext())}) { std::string fnName = "bindMemRef" + std::to_string(i) + "D" + std::string(stringifyType(type)); if (type.isHalfTy()) - type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(llvmDialect)); + type = getMemRefType(i, LLVM::LLVMType::getInt16Ty(&getContext())); if (!module.lookupSymbol(fnName)) { auto fnType = LLVM::LLVMType::getFunctionTy( getVoidType(), @@ -368,8 +364,7 @@ std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); return LLVM::createGlobalString(loc, builder, entryPointGlobalName, - shaderName, LLVM::Linkage::Internal, - getLLVMDialect()); + shaderName, LLVM::Linkage::Internal); } void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( @@ -388,7 +383,7 @@ // that data to runtime call. Value ptrToSPIRVBinary = LLVM::createGlobalString( loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), - LLVM::Linkage::Internal, getLLVMDialect()); + LLVM::Linkage::Internal); // Create LLVM constant for the size of SPIR-V binary shader. Value binarySize = builder.create( diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -186,15 +186,15 @@ llvm::map_range(type.getElementTypes(), [&](Type elementType) { return converter.convertType(elementType).cast(); })); - return LLVM::LLVMType::getStructTy(converter.getDialect(), elementsVector, + return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector, /*isPacked=*/true); } /// Creates LLVM dialect constant with the given value. static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, - LLVMTypeConverter &converter, unsigned value) { + unsigned value) { return rewriter.create( - loc, LLVM::LLVMType::getInt32Ty(converter.getDialect()), + loc, LLVM::LLVMType::getInt32Ty(rewriter.getContext()), rewriter.getIntegerAttr(rewriter.getI32Type(), value)); } @@ -1002,7 +1002,7 @@ return failure(); Location loc = varOp.getLoc(); - Value size = createI32ConstantOf(loc, rewriter, typeConverter, 1); + Value size = createI32ConstantOf(loc, rewriter, 1); if (!init) { rewriter.replaceOpWithNewOp(varOp, dstType, size); return success(); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -199,7 +199,7 @@ } LLVM::LLVMType LLVMTypeConverter::getIndexType() { - return LLVM::LLVMType::getIntNTy(llvmDialect, getIndexTypeBitwidth()); + return LLVM::LLVMType::getIntNTy(&getContext(), getIndexTypeBitwidth()); } unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) { @@ -211,19 +211,19 @@ } Type LLVMTypeConverter::convertIntegerType(IntegerType type) { - return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth()); + return LLVM::LLVMType::getIntNTy(&getContext(), type.getWidth()); } Type LLVMTypeConverter::convertFloatType(FloatType type) { switch (type.getKind()) { case mlir::StandardTypes::F32: - return LLVM::LLVMType::getFloatTy(llvmDialect); + return LLVM::LLVMType::getFloatTy(&getContext()); case mlir::StandardTypes::F64: - return LLVM::LLVMType::getDoubleTy(llvmDialect); + return LLVM::LLVMType::getDoubleTy(&getContext()); case mlir::StandardTypes::F16: - return LLVM::LLVMType::getHalfTy(llvmDialect); + return LLVM::LLVMType::getHalfTy(&getContext()); case mlir::StandardTypes::BF16: { - return LLVM::LLVMType::getBFloatTy(llvmDialect); + return LLVM::LLVMType::getBFloatTy(&getContext()); } default: llvm_unreachable("non-float type in convertFloatType"); @@ -238,7 +238,7 @@ static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; Type LLVMTypeConverter::convertComplexType(ComplexType type) { auto elementType = convertType(type.getElementType()).cast(); - return LLVM::LLVMType::getStructTy(llvmDialect, {elementType, elementType}); + return LLVM::LLVMType::getStructTy(&getContext(), {elementType, elementType}); } // Except for signatures, MLIR function types are converted into LLVM @@ -274,7 +274,7 @@ /// In signatures, unranked MemRef descriptors are expanded into a pair "rank, /// pointer to descriptor". SmallVector LLVMTypeConverter::convertUnrankedMemRefSignature() { - return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(llvmDialect)}; + return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())}; } // Function types are converted to LLVM Function types by recursively converting @@ -307,7 +307,7 @@ // a struct. LLVM::LLVMType resultType = type.getNumResults() == 0 - ? LLVM::LLVMType::getVoidTy(llvmDialect) + ? LLVM::LLVMType::getVoidTy(&getContext()) : unwrap(packFunctionResults(type.getResults())); if (!resultType) return {}; @@ -331,7 +331,7 @@ LLVM::LLVMType resultType = type.getNumResults() == 0 - ? LLVM::LLVMType::getVoidTy(llvmDialect) + ? LLVM::LLVMType::getVoidTy(&getContext()) : unwrap(packFunctionResults(type.getResults())); if (!resultType) return {}; @@ -400,7 +400,7 @@ Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { auto rankTy = getIndexType(); - auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + auto ptrTy = LLVM::LLVMType::getInt8PtrTy(&getContext()); return LLVM::LLVMType::getStructTy(rankTy, ptrTy); } @@ -853,11 +853,11 @@ } LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const { - return LLVM::LLVMType::getVoidTy(&getDialect()); + return LLVM::LLVMType::getVoidTy(&typeConverter.getContext()); } LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const { - return LLVM::LLVMType::getInt8PtrTy(&getDialect()); + return LLVM::LLVMType::getInt8PtrTy(&typeConverter.getContext()); } Value ConvertToLLVMPattern::createIndexConstant( @@ -2025,9 +2025,10 @@ unrankedMemrefs, sizes); // Get frequently used types. - auto voidType = LLVM::LLVMType::getVoidTy(typeConverter.getDialect()); - auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect()); - auto i1Type = LLVM::LLVMType::getInt1Ty(typeConverter.getDialect()); + MLIRContext *context = builder.getContext(); + auto voidType = LLVM::LLVMType::getVoidTy(context); + auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(context); + auto i1Type = LLVM::LLVMType::getInt1Ty(context); LLVM::LLVMType indexType = typeConverter.getIndexType(); // Find the malloc and free, or declare them if necessary. @@ -3168,7 +3169,7 @@ // Append the cmpxchg op to the end of the loop block. auto successOrdering = LLVM::AtomicOrdering::acq_rel; auto failureOrdering = LLVM::AtomicOrdering::monotonic; - auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect()); + auto boolType = LLVM::LLVMType::getInt1Ty(rewriter.getContext()); auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType); auto cmpxchg = rewriter.create( loc, pairType, dataPtr, loopArgument, result, successOrdering, @@ -3330,13 +3331,13 @@ resultTypes.push_back(converted); } - return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes); + return LLVM::LLVMType::getStructTy(&getContext(), resultTypes); } Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) { auto *context = builder.getContext(); - auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect()); + auto int64Ty = LLVM::LLVMType::getInt64Ty(builder.getContext()); auto indexType = IndexType::get(context); // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -715,7 +715,7 @@ // Remaining extraction of element from 1-D LLVM vector auto position = positionAttrs.back().cast(); - auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); + auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); auto constant = rewriter.create(loc, i64Type, position); extracted = rewriter.create(loc, extracted, constant); @@ -832,7 +832,7 @@ } // Insertion of an element into a 1-D LLVM vector. - auto i64Type = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); + auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); auto constant = rewriter.create(loc, i64Type, position); Value inserted = rewriter.create( loc, typeConverter.convertType(oneDVectorType), extracted, @@ -1074,7 +1074,7 @@ if (failed(successStrides) || !isContiguous) return failure(); - auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); + auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); // Create descriptor. auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); @@ -1263,11 +1263,10 @@ int64_t rank) const { Location loc = op->getLoc(); if (rank == 0) { - if (value.getType() == - LLVM::LLVMType::getInt1Ty(typeConverter.getDialect())) { + if (value.getType() == LLVM::LLVMType::getInt1Ty(rewriter.getContext())) { // Convert i1 (bool) to i32 so we can use the print_i32 method. // This avoids the need for a print_i1 method with an unclear ABI. - auto i32Type = LLVM::LLVMType::getInt32Ty(typeConverter.getDialect()); + auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext()); auto trueVal = rewriter.create( loc, i32Type, rewriter.getI32IntegerAttr(1)); auto falseVal = rewriter.create( @@ -1303,8 +1302,8 @@ } // Helper for printer method declaration (first hit) and lookup. - static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect, - StringRef name, ArrayRef params) { + static Operation *getPrint(Operation *op, StringRef name, + ArrayRef params) { auto module = op->getParentOfType(); auto func = module.lookupSymbol(name); if (func) @@ -1312,42 +1311,39 @@ OpBuilder moduleBuilder(module.getBodyRegion()); return moduleBuilder.create( op->getLoc(), name, - LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect), - params, /*isVarArg=*/false)); + LLVM::LLVMType::getFunctionTy( + LLVM::LLVMType::getVoidTy(op->getContext()), params, + /*isVarArg=*/false)); } // Helpers for method names. Operation *getPrintI32(Operation *op) const { - LLVM::LLVMDialect *dialect = typeConverter.getDialect(); - return getPrint(op, dialect, "print_i32", - LLVM::LLVMType::getInt32Ty(dialect)); + return getPrint(op, "print_i32", + LLVM::LLVMType::getInt32Ty(op->getContext())); } Operation *getPrintI64(Operation *op) const { - LLVM::LLVMDialect *dialect = typeConverter.getDialect(); - return getPrint(op, dialect, "print_i64", - LLVM::LLVMType::getInt64Ty(dialect)); + return getPrint(op, "print_i64", + LLVM::LLVMType::getInt64Ty(op->getContext())); } Operation *getPrintFloat(Operation *op) const { - LLVM::LLVMDialect *dialect = typeConverter.getDialect(); - return getPrint(op, dialect, "print_f32", - LLVM::LLVMType::getFloatTy(dialect)); + return getPrint(op, "print_f32", + LLVM::LLVMType::getFloatTy(op->getContext())); } Operation *getPrintDouble(Operation *op) const { - LLVM::LLVMDialect *dialect = typeConverter.getDialect(); - return getPrint(op, dialect, "print_f64", - LLVM::LLVMType::getDoubleTy(dialect)); + return getPrint(op, "print_f64", + LLVM::LLVMType::getDoubleTy(op->getContext())); } Operation *getPrintOpen(Operation *op) const { - return getPrint(op, typeConverter.getDialect(), "print_open", {}); + return getPrint(op, "print_open", {}); } Operation *getPrintClose(Operation *op) const { - return getPrint(op, typeConverter.getDialect(), "print_close", {}); + return getPrint(op, "print_close", {}); } Operation *getPrintComma(Operation *op) const { - return getPrint(op, typeConverter.getDialect(), "print_comma", {}); + return getPrint(op, "print_comma", {}); } Operation *getPrintNewline(Operation *op) const { - return getPrint(op, typeConverter.getDialect(), "print_newline", {}); + return getPrint(op, "print_newline", {}); } }; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -101,8 +101,7 @@ // The result type is either i1 or a vector type if the inputs are // vectors. - auto *dialect = builder.getContext()->getRegisteredDialect(); - auto resultType = LLVMType::getInt1Ty(dialect); + auto resultType = LLVMType::getInt1Ty(builder.getContext()); auto argType = type.dyn_cast(); if (!argType) return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"); @@ -393,11 +392,9 @@ return parser.emitError(trailingTypeLoc, "expected function with 0 or 1 result"); - auto *llvmDialect = - builder.getContext()->getRegisteredDialect(); LLVM::LLVMType llvmResultType; if (funcType.getNumResults() == 0) { - llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect); + llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext()); } else { llvmResultType = funcType.getResult(0).dyn_cast(); if (!llvmResultType) @@ -601,11 +598,9 @@ "expected function with 0 or 1 result"); Builder &builder = parser.getBuilder(); - auto *llvmDialect = - builder.getContext()->getRegisteredDialect(); LLVM::LLVMType llvmResultType; if (funcType.getNumResults() == 0) { - llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect); + llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext()); } else { llvmResultType = funcType.getResult(0).dyn_cast(); if (!llvmResultType) @@ -1101,9 +1096,8 @@ if (types.empty()) { if (auto strAttr = value.dyn_cast_or_null()) { MLIRContext *context = parser.getBuilder().getContext(); - auto *dialect = context->getRegisteredDialect(); auto arrayType = LLVM::LLVMType::getArrayTy( - LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size()); + LLVM::LLVMType::getInt8Ty(context), strAttr.getValue().size()); types.push_back(arrayType); } else { return parser.emitError(parser.getNameLoc(), @@ -1265,14 +1259,8 @@ llvmInputs.push_back(llvmTy); } - // Get the dialect from the input type, if any exist. Look it up in the - // context otherwise. - LLVMDialect *dialect = - llvmInputs.empty() ? b.getContext()->getRegisteredDialect() - : &llvmInputs.front().getDialect(); - // No output is denoted as "void" in LLVM type system. - LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect) + LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(b.getContext()) : outputs.front().dyn_cast(); if (!llvmOutput) { parser.emitError(loc, "failed to construct function type: expected LLVM " @@ -1605,8 +1593,7 @@ parser.resolveOperand(val, type, result.operands)) return failure(); - auto *dialect = builder.getContext()->getRegisteredDialect(); - auto boolType = LLVMType::getInt1Ty(dialect); + auto boolType = LLVMType::getInt1Ty(builder.getContext()); auto resultType = LLVMType::getStructTy(type, boolType); result.addTypes(resultType); @@ -1777,8 +1764,7 @@ Value mlir::LLVM::createGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, - LLVM::Linkage linkage, - LLVM::LLVMDialect *llvmDialect) { + LLVM::Linkage linkage) { assert(builder.getInsertionBlock() && builder.getInsertionBlock()->getParentOp() && "expected builder to point to a block constrained in an op"); @@ -1788,8 +1774,9 @@ // Create the global at the entry of the module. OpBuilder moduleBuilder(module.getBodyRegion()); - auto type = LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect), - value.size()); + MLIRContext *ctx = builder.getContext(); + auto type = + LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(ctx), value.size()); auto global = moduleBuilder.create( loc, type, /*isConstant=*/true, linkage, name, builder.getStringAttr(value)); @@ -1797,10 +1784,9 @@ // Get the pointer to the first character in the global string. Value globalPtr = builder.create(loc, global); Value cst0 = builder.create( - loc, LLVM::LLVMType::getInt64Ty(llvmDialect), + loc, LLVM::LLVMType::getInt64Ty(ctx), builder.getIntegerAttr(builder.getIndexType(), 0)); - return builder.create(loc, - LLVM::LLVMType::getInt8PtrTy(llvmDialect), + return builder.create(loc, LLVM::LLVMType::getInt8PtrTy(ctx), globalPtr, ArrayRef({cst0, cst0})); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -127,35 +127,35 @@ //----------------------------------------------------------------------------// // Utilities used to generate floating point types. -LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) { - return LLVMDoubleType::get(dialect->getContext()); +LLVMType LLVMType::getDoubleTy(MLIRContext *context) { + return LLVMDoubleType::get(context); } -LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) { - return LLVMFloatType::get(dialect->getContext()); +LLVMType LLVMType::getFloatTy(MLIRContext *context) { + return LLVMFloatType::get(context); } -LLVMType LLVMType::getBFloatTy(LLVMDialect *dialect) { - return LLVMBFloatType::get(dialect->getContext()); +LLVMType LLVMType::getBFloatTy(MLIRContext *context) { + return LLVMBFloatType::get(context); } -LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) { - return LLVMHalfType::get(dialect->getContext()); +LLVMType LLVMType::getHalfTy(MLIRContext *context) { + return LLVMHalfType::get(context); } -LLVMType LLVMType::getFP128Ty(LLVMDialect *dialect) { - return LLVMFP128Type::get(dialect->getContext()); +LLVMType LLVMType::getFP128Ty(MLIRContext *context) { + return LLVMFP128Type::get(context); } -LLVMType LLVMType::getX86_FP80Ty(LLVMDialect *dialect) { - return LLVMX86FP80Type::get(dialect->getContext()); +LLVMType LLVMType::getX86_FP80Ty(MLIRContext *context) { + return LLVMX86FP80Type::get(context); } //----------------------------------------------------------------------------// // Utilities used to generate integer types. -LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) { - return LLVMIntegerType::get(dialect->getContext(), numBits); +LLVMType LLVMType::getIntNTy(MLIRContext *context, unsigned numBits) { + return LLVMIntegerType::get(context, numBits); } //----------------------------------------------------------------------------// @@ -170,9 +170,9 @@ return LLVMFunctionType::get(result, params, isVarArg); } -LLVMType LLVMType::getStructTy(LLVMDialect *dialect, +LLVMType LLVMType::getStructTy(MLIRContext *context, ArrayRef elements, bool isPacked) { - return LLVMStructType::getLiteral(dialect->getContext(), elements, isPacked); + return LLVMStructType::getLiteral(context, elements, isPacked); } LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) { @@ -182,8 +182,8 @@ //----------------------------------------------------------------------------// // Void type utilities. -LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) { - return LLVMVoidType::get(dialect->getContext()); +LLVMType LLVMType::getVoidTy(MLIRContext *context) { + return LLVMVoidType::get(context); } bool LLVMType::isVoidTy() { return isa(); } @@ -191,7 +191,7 @@ //----------------------------------------------------------------------------// // Creation and setting of LLVM's identified struct types -LLVMType LLVMType::createStructTy(LLVMDialect *dialect, +LLVMType LLVMType::createStructTy(MLIRContext *context, ArrayRef elements, Optional name, bool isPacked) { assert(name.hasValue() && @@ -200,8 +200,7 @@ std::string stringName = stringNameBase.str(); unsigned counter = 0; do { - auto type = - LLVMStructType::getIdentified(dialect->getContext(), stringName); + auto type = LLVMStructType::getIdentified(context, stringName); if (type.isInitialized() || failed(type.setBody(elements, isPacked))) { counter += 1; stringName = diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -41,12 +41,6 @@ p << " : " << op->getResultTypes(); } -static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) { - return parser.getBuilder() - .getContext() - ->getRegisteredDialect(); -} - // ::= // `llvm.nvvm.shfl.sync.bfly %dst, %val, %offset, %clamp_and_mask` // ({return_value_and_is_valid})? : result_type @@ -69,7 +63,7 @@ break; } - auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser)); + auto int32Ty = LLVM::LLVMType::getInt32Ty(parser.getBuilder().getContext()); return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty}, parser.getNameLoc(), result.operands); } @@ -77,9 +71,9 @@ // ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser, OperationState &result) { - auto llvmDialect = getLlvmDialect(parser); - auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect); - auto int1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect); + MLIRContext *context = parser.getBuilder().getContext(); + auto int32Ty = LLVM::LLVMType::getInt32Ty(context); + auto int1Ty = LLVM::LLVMType::getInt1Ty(context); SmallVector ops; Type type; @@ -92,14 +86,14 @@ } static LogicalResult verify(MmaOp op) { - auto dialect = op.getContext()->getRegisteredDialect(); - auto f16Ty = LLVM::LLVMType::getHalfTy(dialect); + MLIRContext *context = op.getContext(); + auto f16Ty = LLVM::LLVMType::getHalfTy(context); auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2); - auto f32Ty = LLVM::LLVMType::getFloatTy(dialect); + auto f32Ty = LLVM::LLVMType::getFloatTy(context); auto f16x2x4StructTy = LLVM::LLVMType::getStructTy( - dialect, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); auto f32x8StructTy = LLVM::LLVMType::getStructTy( - dialect, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); + context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); SmallVector operand_types(op.getOperandTypes().begin(), op.getOperandTypes().end()); diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -34,12 +34,6 @@ // Parsing for ROCDL ops //===----------------------------------------------------------------------===// -static LLVM::LLVMDialect *getLlvmDialect(OpAsmParser &parser) { - return parser.getBuilder() - .getContext() - ->getRegisteredDialect(); -} - // ::= // `llvm.amdgcn.buffer.load.* %rsrc, %vindex, %offset, %glc, %slc : // result_type` @@ -51,8 +45,9 @@ parser.addTypeToList(type, result.types)) return failure(); - auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser)); - auto int1Ty = LLVM::LLVMType::getInt1Ty(getLlvmDialect(parser)); + MLIRContext *context = parser.getBuilder().getContext(); + auto int32Ty = LLVM::LLVMType::getInt32Ty(context); + auto int1Ty = LLVM::LLVMType::getInt1Ty(context); auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4); return parser.resolveOperands(ops, {i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty}, @@ -69,8 +64,9 @@ if (parser.parseOperandList(ops, 6) || parser.parseColonType(type)) return failure(); - auto int32Ty = LLVM::LLVMType::getInt32Ty(getLlvmDialect(parser)); - auto int1Ty = LLVM::LLVMType::getInt1Ty(getLlvmDialect(parser)); + MLIRContext *context = parser.getBuilder().getContext(); + auto int32Ty = LLVM::LLVMType::getInt32Ty(context); + auto int1Ty = LLVM::LLVMType::getInt1Ty(context); auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4); if (parser.resolveOperands(ops,