diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -111,10 +111,11 @@ // Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` - auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context); - auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context); - auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy, - /*isVarArg=*/true); + auto llvmI32Ty = LLVM::LLVMIntegerType::get(context, 32); + auto llvmI8PtrTy = + LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8)); + auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy, + /*isVarArg=*/true); // Insert the printf function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); @@ -133,8 +134,8 @@ if (!(global = module.lookupSymbol(name))) { OpBuilder::InsertionGuard insertGuard(builder); builder.setInsertionPointToStart(module.getBody()); - auto type = LLVM::LLVMType::getArrayTy( - LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size()); + auto type = LLVM::LLVMArrayType::get( + LLVM::LLVMIntegerType::get(builder.getContext(), 8), value.size()); global = builder.create(loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, builder.getStringAttr(value)); @@ -143,11 +144,13 @@ // Get the pointer to the first character in the global string. Value globalPtr = builder.create(loc, global); Value cst0 = builder.create( - loc, LLVM::LLVMType::getInt64Ty(builder.getContext()), + loc, LLVM::LLVMIntegerType::get(builder.getContext(), 64), builder.getIntegerAttr(builder.getIndexType(), 0)); return builder.create( - loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr, - ArrayRef({cst0, cst0})); + loc, + LLVM::LLVMPointerType::get( + LLVM::LLVMIntegerType::get(builder.getContext(), 8)), + globalPtr, ArrayRef({cst0, cst0})); } }; } // end anonymous namespace diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -111,10 +111,11 @@ // Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` - auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(context); - auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(context); - auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy, - /*isVarArg=*/true); + auto llvmI32Ty = LLVM::LLVMIntegerType::get(context, 32); + auto llvmI8PtrTy = + LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8)); + auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy, + /*isVarArg=*/true); // Insert the printf function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); @@ -133,8 +134,8 @@ if (!(global = module.lookupSymbol(name))) { OpBuilder::InsertionGuard insertGuard(builder); builder.setInsertionPointToStart(module.getBody()); - auto type = LLVM::LLVMType::getArrayTy( - LLVM::LLVMType::getInt8Ty(builder.getContext()), value.size()); + auto type = LLVM::LLVMArrayType::get( + LLVM::LLVMIntegerType::get(builder.getContext(), 8), value.size()); global = builder.create(loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, name, builder.getStringAttr(value)); @@ -143,11 +144,13 @@ // Get the pointer to the first character in the global string. Value globalPtr = builder.create(loc, global); Value cst0 = builder.create( - loc, LLVM::LLVMType::getInt64Ty(builder.getContext()), + loc, LLVM::LLVMIntegerType::get(builder.getContext(), 64), builder.getIntegerAttr(builder.getIndexType(), 0)); return builder.create( - loc, LLVM::LLVMType::getInt8PtrTy(builder.getContext()), globalPtr, - ArrayRef({cst0, cst0})); + loc, + LLVM::LLVMPointerType::get( + LLVM::LLVMIntegerType::get(builder.getContext(), 8)), + globalPtr, ArrayRef({cst0, cst0})); } }; } // end anonymous namespace diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -150,7 +150,7 @@ let builders = [ OpBuilderDAG<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ - build($_builder, $_state, LLVMType::getInt1Ty(lhs.getType().getContext()), + build($_builder, $_state, LLVMIntegerType::get(lhs.getType().getContext(), 1), $_builder.getI64IntegerAttr(static_cast(predicate)), lhs, rhs); }]>]; let parser = [{ return parseCmpOp(parser, result); }]; @@ -198,7 +198,7 @@ let builders = [ OpBuilderDAG<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ - build($_builder, $_state, LLVMType::getInt1Ty(lhs.getType().getContext()), + build($_builder, $_state, LLVMIntegerType::get(lhs.getType().getContext(), 1), $_builder.getI64IntegerAttr(static_cast(predicate)), lhs, rhs); }]>]; let parser = [{ return parseCmpOp(parser, result); }]; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -53,9 +53,7 @@ /// /// The LLVM dialect in MLIR fully reflects the LLVM IR type system, prodiving a /// separate MLIR type for each LLVM IR type. All types are represented as -/// separate subclasses and are compatible with the isa/cast infrastructure. For -/// convenience, the base class provides most of the APIs available on -/// llvm::Type in addition to MLIR-compatible APIs. +/// separate subclasses and are compatible with the isa/cast infrastructure. /// /// The LLVM dialect type system is closed: parametric types can only refer to /// other LLVM dialect types. This is consistent with LLVM IR and enables a more @@ -64,6 +62,11 @@ /// Similarly to other MLIR types, LLVM dialect types are owned by the MLIR /// context, have an immutable identifier (for most types except identified /// structs, the entire type is the identifier) and are thread-safe. +/// +/// This class is a thin common base class for different types available in the +/// LLVM dialect. It intentionally does not provide the API similar to +/// llvm::Type to avoid confusion and highlight potentially expensive operations +/// (e.g., type creation in MLIR takes a lock, so it's better to cache types). class LLVMType : public Type { public: /// Inherit base constructors. @@ -79,98 +82,6 @@ static bool classof(Type type); LLVMDialect &getDialect(); - - /// Utilities used to generate floating point types. - static LLVMType getDoubleTy(MLIRContext *context); - static LLVMType getFloatTy(MLIRContext *context); - static LLVMType getBFloatTy(MLIRContext *context); - static LLVMType getHalfTy(MLIRContext *context); - static LLVMType getFP128Ty(MLIRContext *context); - static LLVMType getX86_FP80Ty(MLIRContext *context); - - /// Utilities used to generate integer types. - static LLVMType getIntNTy(MLIRContext *context, unsigned numBits); - static LLVMType getInt1Ty(MLIRContext *context) { - return getIntNTy(context, /*numBits=*/1); - } - static LLVMType getInt8Ty(MLIRContext *context) { - return getIntNTy(context, /*numBits=*/8); - } - static LLVMType getInt8PtrTy(MLIRContext *context); - static LLVMType getInt16Ty(MLIRContext *context) { - return getIntNTy(context, /*numBits=*/16); - } - static LLVMType getInt32Ty(MLIRContext *context) { - return getIntNTy(context, /*numBits=*/32); - } - static LLVMType getInt64Ty(MLIRContext *context) { - return getIntNTy(context, /*numBits=*/64); - } - - /// Utilities used to generate other miscellaneous types. - static LLVMType getArrayTy(LLVMType elementType, uint64_t numElements); - static LLVMType getFunctionTy(LLVMType result, ArrayRef params, - bool isVarArg); - static LLVMType getFunctionTy(LLVMType result, bool isVarArg) { - return getFunctionTy(result, llvm::None, isVarArg); - } - static LLVMType getStructTy(MLIRContext *context, ArrayRef elements, - bool isPacked = false); - static LLVMType getStructTy(MLIRContext *context, bool isPacked = false) { - return getStructTy(context, llvm::None, isPacked); - } - template - static typename std::enable_if::value, - LLVMType>::type - getStructTy(LLVMType elt1, Args... elts) { - SmallVector fields({elt1, elts...}); - return getStructTy(elt1.getContext(), fields); - } - static LLVMType getVectorTy(LLVMType elementType, unsigned numElements); - - /// Void type utilities. - static LLVMType getVoidTy(MLIRContext *context); - - // Creation and setting of LLVM's identified struct types - static LLVMType createStructTy(MLIRContext *context, - ArrayRef elements, - Optional name, - bool isPacked = false); - - static LLVMType createStructTy(MLIRContext *context, - Optional name) { - return createStructTy(context, llvm::None, name); - } - - static LLVMType createStructTy(ArrayRef elements, - Optional name, - bool isPacked = false) { - assert(!elements.empty() && - "This method may not be invoked with an empty list"); - LLVMType ele0 = elements.front(); - return createStructTy(ele0.getContext(), elements, name, isPacked); - } - - template - static typename std::enable_if_t::value, - LLVMType> - createStructTy(StringRef name, LLVMType elt1, Args... elts) { - SmallVector fields({elt1, elts...}); - Optional opt_name(name); - return createStructTy(elt1.getContext(), fields, opt_name); - } - - static LLVMType setStructTyBody(LLVMType structType, - ArrayRef elements, - bool isPacked = false); - - template - static typename std::enable_if_t::value, - LLVMType> - setStructTyBody(LLVMType structType, LLVMType elt1, Args... elts) { - SmallVector fields({elt1, elts...}); - return setStructTyBody(structType, fields); - } }; //===----------------------------------------------------------------------===// @@ -386,6 +297,14 @@ static LLVMStructType getIdentified(MLIRContext *context, StringRef name); static LLVMStructType getIdentifiedChecked(Location loc, StringRef name); + /// Gets a new identified struct with the given body. The body _cannot_ be + /// changed later. If a struct with the given name already exists, renames + /// the struct by appending a `.` followed by a number to the name. Renaming + /// happens even if the existing struct has the same body. + static LLVMStructType getNewIdentified(MLIRContext *context, StringRef name, + ArrayRef elements, + bool isPacked = false); + /// Gets or creates a literal struct with the given body in the provided /// context. static LLVMStructType getLiteral(MLIRContext *context, diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -52,7 +52,7 @@ // Async Runtime API function types. struct AsyncAPI { static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { - auto ref = LLVM::LLVMType::getInt8PtrTy(ctx); + auto ref = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); auto count = IntegerType::get(ctx, 32); return FunctionType::get(ctx, {ref, count}, {}); } @@ -78,7 +78,7 @@ } static FunctionType executeFunctionType(MLIRContext *ctx) { - auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); + auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); return FunctionType::get(ctx, {hdl, resume}, {}); } @@ -90,22 +90,22 @@ } static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { - auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); + auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); } static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { - auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); + auto hdl = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); } // Auxiliary coroutine resume intrinsic wrapper. static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) { - auto voidTy = LLVM::LLVMType::getVoidTy(ctx); - auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); - return LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false); + auto voidTy = LLVM::LLVMVoidType::get(ctx); + auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); + return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); } }; } // namespace @@ -154,7 +154,7 @@ ArrayRef params) { if (module.lookupSymbol(name)) return; - LLVM::LLVMType type = LLVM::LLVMType::getFunctionTy(ret, params, false); + LLVM::LLVMType type = LLVM::LLVMFunctionType::get(ret, params); builder.create(module.getLoc(), name, type); } @@ -166,13 +166,13 @@ OpBuilder builder(module.getBody()->getTerminator()); auto token = LLVMTokenType::get(ctx); - auto voidTy = LLVMType::getVoidTy(ctx); + auto voidTy = LLVMVoidType::get(ctx); - auto i8 = LLVMType::getInt8Ty(ctx); - auto i1 = LLVMType::getInt1Ty(ctx); - auto i32 = LLVMType::getInt32Ty(ctx); - auto i64 = LLVMType::getInt64Ty(ctx); - auto i8Ptr = LLVMType::getInt8PtrTy(ctx); + auto i8 = LLVMIntegerType::get(ctx, 8); + auto i1 = LLVMIntegerType::get(ctx, 1); + auto i32 = LLVMIntegerType::get(ctx, 32); + auto i64 = LLVMIntegerType::get(ctx, 64); + auto i8Ptr = LLVMPointerType::get(i8); addLLVMFuncDecl(module, builder, kCoroId, token, {i32, i8Ptr, i8Ptr, i8Ptr}); addLLVMFuncDecl(module, builder, kCoroSizeI64, i64, {}); @@ -198,9 +198,9 @@ MLIRContext *ctx = module.getContext(); OpBuilder builder(module.getBody()->getTerminator()); - auto voidTy = LLVMType::getVoidTy(ctx); - auto i64 = LLVMType::getInt64Ty(ctx); - auto i8Ptr = LLVMType::getInt8PtrTy(ctx); + auto voidTy = LLVMVoidType::get(ctx); + auto i64 = LLVMIntegerType::get(ctx, 64); + auto i8Ptr = LLVMPointerType::get(LLVMIntegerType::get(ctx, 8)); addLLVMFuncDecl(module, builder, kMalloc, i8Ptr, {i64}); addLLVMFuncDecl(module, builder, kFree, voidTy, {i8Ptr}); @@ -224,11 +224,11 @@ if (module.lookupSymbol(kResume)) return; - auto voidTy = LLVM::LLVMType::getVoidTy(ctx); - auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); + auto voidTy = LLVM::LLVMVoidType::get(ctx); + auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); auto resumeOp = moduleBuilder.create( - loc, kResume, LLVM::LLVMType::getFunctionTy(voidTy, {i8Ptr}, false)); + loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); resumeOp.setPrivate(); auto *block = resumeOp.addEntryBlock(); @@ -294,10 +294,10 @@ MLIRContext *ctx = func.getContext(); auto token = LLVM::LLVMTokenType::get(ctx); - auto i1 = LLVM::LLVMType::getInt1Ty(ctx); - auto i32 = LLVM::LLVMType::getInt32Ty(ctx); - auto i64 = LLVM::LLVMType::getInt64Ty(ctx); - auto i8Ptr = LLVM::LLVMType::getInt8PtrTy(ctx); + auto i1 = LLVM::LLVMIntegerType::get(ctx, 1); + auto i32 = LLVM::LLVMIntegerType::get(ctx, 32); + auto i64 = LLVM::LLVMIntegerType::get(ctx, 64); + auto i8Ptr = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); Block *entryBlock = func.addEntryBlock(); Location loc = func.getBody().getLoc(); @@ -420,8 +420,8 @@ OpBuilder &builder) { Location loc = op->getLoc(); MLIRContext *ctx = op->getContext(); - auto i1 = LLVM::LLVMType::getInt1Ty(ctx); - auto i8 = LLVM::LLVMType::getInt8Ty(ctx); + auto i1 = LLVM::LLVMIntegerType::get(ctx, 1); + auto i8 = LLVM::LLVMIntegerType::get(ctx, 8); // Add a coroutine suspension in place of original `op` in the split block. OpBuilder::InsertionGuard guard(builder); @@ -570,7 +570,7 @@ MLIRContext *ctx = type.getContext(); // Convert async tokens and groups to opaque pointers. if (type.isa()) - return LLVM::LLVMType::getInt8PtrTy(ctx); + return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(ctx, 8)); return type; } }; diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -55,8 +55,7 @@ FunctionCallBuilder(StringRef functionName, LLVM::LLVMType returnType, ArrayRef argumentTypes) : functionName(functionName), - functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes, - /*isVarArg=*/false)) {} + functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {} LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef arguments) const; @@ -74,14 +73,15 @@ protected: MLIRContext *context = &this->getTypeConverter()->getContext(); - LLVM::LLVMType llvmVoidType = LLVM::LLVMType::getVoidTy(context); - LLVM::LLVMType llvmPointerType = LLVM::LLVMType::getInt8PtrTy(context); + LLVM::LLVMType llvmVoidType = LLVM::LLVMVoidType::get(context); + LLVM::LLVMType llvmPointerType = + LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8)); LLVM::LLVMType llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType); - LLVM::LLVMType llvmInt8Type = LLVM::LLVMType::getInt8Ty(context); - LLVM::LLVMType llvmInt32Type = LLVM::LLVMType::getInt32Ty(context); - LLVM::LLVMType llvmInt64Type = LLVM::LLVMType::getInt64Ty(context); - LLVM::LLVMType llvmIntPtrType = LLVM::LLVMType::getIntNTy( + LLVM::LLVMType llvmInt8Type = LLVM::LLVMIntegerType::get(context, 8); + LLVM::LLVMType llvmInt32Type = LLVM::LLVMIntegerType::get(context, 32); + LLVM::LLVMType llvmInt64Type = LLVM::LLVMIntegerType::get(context, 64); + LLVM::LLVMType llvmIntPtrType = LLVM::LLVMIntegerType::get( context, this->getTypeConverter()->getPointerBitwidth(0)); FunctionCallBuilder moduleLoadCallBuilder = { @@ -495,7 +495,8 @@ argumentTypes.reserve(numArguments); for (auto argument : arguments) argumentTypes.push_back(argument.getType().cast()); - auto structType = LLVM::LLVMType::createStructTy(argumentTypes, StringRef()); + auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(), + argumentTypes); auto one = builder.create(loc, llvmInt32Type, builder.getI32IntegerAttr(1)); auto structPtr = builder.create( @@ -652,10 +653,10 @@ void mlir::populateGpuToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns, StringRef gpuBinaryAnnotation) { - converter.addConversion( - [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type { - return LLVM::LLVMType::getInt8PtrTy(context); - }); + converter.addConversion([context = &converter.getContext()]( + gpu::AsyncTokenType type) -> Type { + return LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8)); + }); patterns.insertconvertType(type.getElementType()) .template cast(); - auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements); + auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); std::string name = std::string( llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index())); auto globalOp = rewriter.create( @@ -85,7 +85,7 @@ // Rewrite workgroup memory attributions to addresses of global buffers. rewriter.setInsertionPointToStart(&gpuFuncOp.front()); unsigned numProperArguments = gpuFuncOp.getNumArguments(); - auto i32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext()); + auto i32Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 32); Value zero = nullptr; if (!workgroupBuffers.empty()) @@ -114,7 +114,7 @@ // Rewrite private memory attributions to alloca'ed buffers. unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); - auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); + auto int64Ty = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64); for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { Value attribution = en.value(); auto type = attribution.getType().cast(); diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -48,13 +48,16 @@ Value newOp; switch (dimensionToIndex(op)) { case X: - newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(context)); + newOp = + rewriter.create(loc, LLVM::LLVMIntegerType::get(context, 32)); break; case Y: - newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(context)); + newOp = + rewriter.create(loc, LLVM::LLVMIntegerType::get(context, 32)); break; case Z: - newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(context)); + newOp = + rewriter.create(loc, LLVM::LLVMIntegerType::get(context, 32)); break; default: return failure(); @@ -62,10 +65,10 @@ if (indexBitwidth > 32) { newOp = rewriter.create( - loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp); + loc, LLVM::LLVMIntegerType::get(context, indexBitwidth), newOp); } else if (indexBitwidth < 32) { newOp = rewriter.create( - loc, LLVM::LLVMType::getIntNTy(context, indexBitwidth), newOp); + loc, LLVM::LLVMIntegerType::get(context, indexBitwidth), newOp); } rewriter.replaceOp(op, {newOp}); diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -85,7 +85,7 @@ return operand; return rewriter.create( - operand.getLoc(), LLVM::LLVMType::getFloatTy(rewriter.getContext()), + operand.getLoc(), LLVM::LLVMFloatType::get(rewriter.getContext()), operand); } @@ -96,8 +96,7 @@ for (Value operand : operands) { operandTypes.push_back(operand.getType().cast()); } - return LLVMType::getFunctionTy(resultType, operandTypes, - /*isVarArg=*/false); + return LLVM::LLVMFunctionType::get(resultType, operandTypes); } StringRef getFunctionName(LLVM::LLVMType type) const { diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -57,10 +57,10 @@ gpu::ShuffleOpAdaptor adaptor(operands); auto valueTy = adaptor.value().getType().cast(); - auto int32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext()); - auto predTy = LLVM::LLVMType::getInt1Ty(rewriter.getContext()); - auto resultTy = - LLVM::LLVMType::getStructTy(rewriter.getContext(), {valueTy, predTy}); + auto int32Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 32); + auto predTy = LLVM::LLVMIntegerType::get(rewriter.getContext(), 1); + auto resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), + {valueTy, predTy}); Value one = rewriter.create( loc, int32Type, rewriter.getI32IntegerAttr(1)); diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -57,11 +57,12 @@ VulkanLaunchFuncToVulkanCallsPass> { private: void initializeCachedTypes() { - llvmFloatType = LLVM::LLVMType::getFloatTy(&getContext()); - llvmVoidType = LLVM::LLVMType::getVoidTy(&getContext()); - llvmPointerType = LLVM::LLVMType::getInt8PtrTy(&getContext()); - llvmInt32Type = LLVM::LLVMType::getInt32Ty(&getContext()); - llvmInt64Type = LLVM::LLVMType::getInt64Ty(&getContext()); + llvmFloatType = LLVM::LLVMFloatType::get(&getContext()); + llvmVoidType = LLVM::LLVMVoidType::get(&getContext()); + llvmPointerType = LLVM::LLVMPointerType::get( + LLVM::LLVMIntegerType::get(&getContext(), 8)); + llvmInt32Type = LLVM::LLVMIntegerType::get(&getContext(), 32); + llvmInt64Type = LLVM::LLVMIntegerType::get(&getContext(), 64); } LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) { @@ -77,12 +78,12 @@ // }; auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType); auto llvmArrayRankElementSizeType = - LLVM::LLVMType::getArrayTy(getInt64Type(), rank); + LLVM::LLVMArrayType::get(getInt64Type(), rank); // Create a type // `!llvm<"{ `element-type`*, `element-type`*, i64, // [`rank` x i64], [`rank` x i64]}">`. - return LLVM::LLVMType::getStructTy( + return LLVM::LLVMStructType::getLiteral( &getContext(), {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(), llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); @@ -242,7 +243,7 @@ // int16_t and bitcast the descriptor. if (type.isa()) { auto memRefTy = - getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext())); + getMemRefType(rank, LLVM::LLVMIntegerType::get(&getContext(), 16)); ptrToMemRefDescriptor = builder.create( loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor); } @@ -296,47 +297,45 @@ if (!module.lookupSymbol(kSetEntryPoint)) { builder.create( loc, kSetEntryPoint, - LLVM::LLVMType::getFunctionTy(getVoidType(), - {getPointerType(), getPointerType()}, - /*isVarArg=*/false)); + LLVM::LLVMFunctionType::get(getVoidType(), + {getPointerType(), getPointerType()})); } if (!module.lookupSymbol(kSetNumWorkGroups)) { builder.create( loc, kSetNumWorkGroups, - LLVM::LLVMType::getFunctionTy( - getVoidType(), - {getPointerType(), getInt64Type(), getInt64Type(), getInt64Type()}, - /*isVarArg=*/false)); + LLVM::LLVMFunctionType::get(getVoidType(), + {getPointerType(), getInt64Type(), + getInt64Type(), getInt64Type()})); } if (!module.lookupSymbol(kSetBinaryShader)) { builder.create( loc, kSetBinaryShader, - LLVM::LLVMType::getFunctionTy( - getVoidType(), {getPointerType(), getPointerType(), getInt32Type()}, - /*isVarArg=*/false)); + LLVM::LLVMFunctionType::get( + getVoidType(), + {getPointerType(), getPointerType(), getInt32Type()})); } if (!module.lookupSymbol(kRunOnVulkan)) { builder.create( loc, kRunOnVulkan, - LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, - /*isVarArg=*/false)); + LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()})); } for (unsigned i = 1; i <= 3; i++) { - for (LLVM::LLVMType type : {LLVM::LLVMType::getFloatTy(&getContext()), - LLVM::LLVMType::getInt32Ty(&getContext()), - LLVM::LLVMType::getInt16Ty(&getContext()), - LLVM::LLVMType::getInt8Ty(&getContext()), - LLVM::LLVMType::getHalfTy(&getContext())}) { + for (auto type : std::initializer_list{ + LLVM::LLVMFloatType::get(&getContext()), + LLVM::LLVMIntegerType::get(&getContext(), 32), + LLVM::LLVMIntegerType::get(&getContext(), 16), + LLVM::LLVMIntegerType::get(&getContext(), 8), + LLVM::LLVMHalfType::get(&getContext())}) { std::string fnName = "bindMemRef" + std::to_string(i) + "D" + std::string(stringifyType(type)); if (type.isa()) - type = LLVM::LLVMType::getInt16Ty(&getContext()); + type = LLVM::LLVMIntegerType::get(&getContext(), 16); if (!module.lookupSymbol(fnName)) { - auto fnType = LLVM::LLVMType::getFunctionTy( + auto fnType = LLVM::LLVMFunctionType::get( getVoidType(), {getPointerType(), getInt32Type(), getInt32Type(), LLVM::LLVMPointerType::get(getMemRefType(i, type))}, @@ -348,16 +347,13 @@ if (!module.lookupSymbol(kInitVulkan)) { builder.create( - loc, kInitVulkan, - LLVM::LLVMType::getFunctionTy(getPointerType(), {}, - /*isVarArg=*/false)); + loc, kInitVulkan, LLVM::LLVMFunctionType::get(getPointerType(), {})); } if (!module.lookupSymbol(kDeinitVulkan)) { builder.create( loc, kDeinitVulkan, - LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, - /*isVarArg=*/false)); + LLVM::LLVMFunctionType::get(getVoidType(), {getPointerType()})); } } diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -86,7 +86,7 @@ auto *context = t.getContext(); auto int64Ty = converter.convertType(IntegerType::get(context, 64)) .cast(); - return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty); + return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty}); } namespace { diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -60,7 +60,7 @@ static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder) { MLIRContext *context = builder.getContext(); - auto llvmI1Type = LLVM::LLVMType::getInt1Ty(context); + auto llvmI1Type = LLVM::LLVMIntegerType::get(context, 1); Value isVolatile = builder.create( loc, llvmI1Type, builder.getBoolAttr(false)); builder.create(loc, dst, src, size, isVolatile); @@ -183,9 +183,8 @@ rewriter.setInsertionPointToStart(module.getBody()); kernelFunc = rewriter.create( rewriter.getUnknownLoc(), newKernelFuncName, - LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(context), - ArrayRef(), - /*isVarArg=*/false)); + LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context), + ArrayRef())); rewriter.setInsertionPoint(launchOp); } diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -195,8 +195,8 @@ llvm::map_range(type.getElementTypes(), [&](Type elementType) { return converter.convertType(elementType).cast(); })); - return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector, - /*isPacked=*/false); + return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, + /*isPacked=*/false); } /// Converts SPIR-V struct with no offset to packed LLVM struct. @@ -206,15 +206,15 @@ llvm::map_range(type.getElementTypes(), [&](Type elementType) { return converter.convertType(elementType).cast(); })); - return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector, - /*isPacked=*/true); + return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, + /*isPacked=*/true); } /// Creates LLVM dialect constant with the given value. static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value) { return rewriter.create( - loc, LLVM::LLVMType::getInt32Ty(rewriter.getContext()), + loc, LLVM::LLVMIntegerType::get(rewriter.getContext(), 32), rewriter.getIntegerAttr(rewriter.getI32Type(), value)); } @@ -258,7 +258,7 @@ auto llvmElementType = converter.convertType(elementType).cast(); unsigned numElements = type.getNumElements(); - return LLVM::LLVMType::getArrayTy(llvmElementType, numElements); + return LLVM::LLVMArrayType::get(llvmElementType, numElements); } /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not @@ -279,7 +279,7 @@ return llvm::None; auto elementType = converter.convertType(type.getElementType()).cast(); - return LLVM::LLVMType::getArrayTy(elementType, 0); + return LLVM::LLVMArrayType::get(elementType, 0); } /// Converts SPIR-V struct to LLVM struct. There is no support of structs with @@ -666,15 +666,15 @@ // int32_t executionMode; // int32_t values[]; // optional values // }; - auto llvmI32Type = LLVM::LLVMType::getInt32Ty(context); + auto llvmI32Type = LLVM::LLVMIntegerType::get(context, 32); SmallVector fields; fields.push_back(llvmI32Type); ArrayAttr values = op.values(); if (!values.empty()) { - auto arrayType = LLVM::LLVMType::getArrayTy(llvmI32Type, values.size()); + auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size()); fields.push_back(arrayType); } - auto structType = LLVM::LLVMType::getStructTy(context, fields); + auto structType = LLVM::LLVMStructType::getLiteral(context, fields); // Create `llvm.mlir.global` with initializer region containing one block. auto global = rewriter.create( diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -171,7 +171,7 @@ } LLVM::LLVMType LLVMTypeConverter::getIndexType() { - return LLVM::LLVMType::getIntNTy(&getContext(), getIndexTypeBitwidth()); + return LLVM::LLVMIntegerType::get(&getContext(), getIndexTypeBitwidth()); } unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) { @@ -183,18 +183,18 @@ } Type LLVMTypeConverter::convertIntegerType(IntegerType type) { - return LLVM::LLVMType::getIntNTy(&getContext(), type.getWidth()); + return LLVM::LLVMIntegerType::get(&getContext(), type.getWidth()); } Type LLVMTypeConverter::convertFloatType(FloatType type) { if (type.isa()) - return LLVM::LLVMType::getFloatTy(&getContext()); + return LLVM::LLVMFloatType::get(&getContext()); if (type.isa()) - return LLVM::LLVMType::getDoubleTy(&getContext()); + return LLVM::LLVMDoubleType::get(&getContext()); if (type.isa()) - return LLVM::LLVMType::getHalfTy(&getContext()); + return LLVM::LLVMHalfType::get(&getContext()); if (type.isa()) - return LLVM::LLVMType::getBFloatTy(&getContext()); + return LLVM::LLVMBFloatType::get(&getContext()); llvm_unreachable("non-float type in convertFloatType"); } @@ -206,7 +206,8 @@ static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; Type LLVMTypeConverter::convertComplexType(ComplexType type) { auto elementType = convertType(type.getElementType()).cast(); - return LLVM::LLVMType::getStructTy(&getContext(), {elementType, elementType}); + return LLVM::LLVMStructType::getLiteral(&getContext(), + {elementType, elementType}); } // Except for signatures, MLIR function types are converted into LLVM @@ -249,11 +250,11 @@ // a struct. LLVM::LLVMType resultType = funcTy.getNumResults() == 0 - ? LLVM::LLVMType::getVoidTy(&getContext()) + ? LLVM::LLVMVoidType::get(&getContext()) : unwrap(packFunctionResults(funcTy.getResults())); if (!resultType) return {}; - return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic); + return LLVM::LLVMFunctionType::get(resultType, argTypes, isVariadic); } /// Converts the function type to a C-compatible format, in particular using @@ -273,12 +274,12 @@ LLVM::LLVMType resultType = type.getNumResults() == 0 - ? LLVM::LLVMType::getVoidTy(&getContext()) + ? LLVM::LLVMVoidType::get(&getContext()) : unwrap(packFunctionResults(type.getResults())); if (!resultType) return {}; - return LLVM::LLVMType::getFunctionTy(resultType, inputs, false); + return LLVM::LLVMFunctionType::get(resultType, inputs); } static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0; @@ -335,7 +336,7 @@ if (unpackAggregates) results.insert(results.end(), 2 * rank, indexTy); else - results.insert(results.end(), 2, LLVM::LLVMType::getArrayTy(indexTy, rank)); + results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank)); return results; } @@ -346,7 +347,7 @@ // unpack the `sizes` and `strides` arrays. SmallVector types = getMemRefDescriptorFields(type, /*unpackAggregates=*/false); - return LLVM::LLVMType::getStructTy(&getContext(), types); + return LLVM::LLVMStructType::getLiteral(&getContext(), types); } static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0; @@ -361,12 +362,13 @@ /// be unranked. SmallVector LLVMTypeConverter::getUnrankedMemRefDescriptorFields() { - return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())}; + return {getIndexType(), LLVM::LLVMPointerType::get( + LLVM::LLVMIntegerType::get(&getContext(), 8))}; } Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { - return LLVM::LLVMType::getStructTy(&getContext(), - getUnrankedMemRefDescriptorFields()); + return LLVM::LLVMStructType::getLiteral(&getContext(), + getUnrankedMemRefDescriptorFields()); } /// Convert a memref type to a bare pointer to the memref element type. @@ -407,11 +409,11 @@ auto elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; - auto vectorType = - LLVM::LLVMType::getVectorTy(elementType, type.getShape().back()); + LLVM::LLVMType vectorType = + LLVM::LLVMFixedVectorType::get(elementType, type.getShape().back()); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) - vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]); + vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]); return vectorType; } @@ -620,7 +622,7 @@ int64_t rank) { auto indexTy = indexType.cast(); auto indexPtrTy = LLVM::LLVMPointerType::get(indexTy); - auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank); + auto arrayTy = LLVM::LLVMArrayType::get(indexTy, rank); auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy); // Copy size values to stack-allocated memory. @@ -949,8 +951,9 @@ Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) { LLVM::LLVMType elemPtrTy = elemPtrPtrType.getElementType(); LLVM::LLVMType indexTy = typeConverter.getIndexType(); - LLVM::LLVMType structPtrTy = LLVM::LLVMPointerType::get( - LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy)); + LLVM::LLVMType structPtrTy = + LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral( + indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy})); Value structPtr = builder.create(loc, structPtrTy, memRefDescPtr); @@ -1031,17 +1034,18 @@ LLVM::LLVMType ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { - return LLVM::LLVMType::getIntNTy( + return LLVM::LLVMIntegerType::get( &getTypeConverter()->getContext(), getTypeConverter()->getPointerBitwidth(addressSpace)); } LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const { - return LLVM::LLVMType::getVoidTy(&getTypeConverter()->getContext()); + return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext()); } LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const { - return LLVM::LLVMType::getInt8PtrTy(&getTypeConverter()->getContext()); + return LLVM::LLVMPointerType::get( + LLVM::LLVMIntegerType::get(&getTypeConverter()->getContext(), 8)); } Value ConvertToLLVMPattern::createIndexConstant( @@ -1729,8 +1733,7 @@ if (!abortFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - auto abortFuncTy = - LLVM::LLVMType::getFunctionTy(getVoidType(), {}, /*isVarArg=*/false); + auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); abortFunc = rewriter.create(rewriter.getUnknownLoc(), "abort", abortFuncTy); } @@ -1955,8 +1958,7 @@ for (Value param : params) paramTypes.push_back(param.getType().cast()); auto allocFuncType = - LLVM::LLVMType::getFunctionTy(getVoidPtrType(), paramTypes, - /*isVarArg=*/false); + LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); allocFuncOp = rewriter.create(rewriter.getUnknownLoc(), @@ -2208,9 +2210,10 @@ // Get frequently used types. MLIRContext *context = builder.getContext(); - auto voidType = LLVM::LLVMType::getVoidTy(context); - auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(context); - auto i1Type = LLVM::LLVMType::getInt1Ty(context); + auto voidType = LLVM::LLVMVoidType::get(context); + LLVM::LLVMType voidPtrType = + LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8)); + auto i1Type = LLVM::LLVMIntegerType::get(context, 1); LLVM::LLVMType indexType = typeConverter.getIndexType(); // Find the malloc and free, or declare them if necessary. @@ -2221,8 +2224,8 @@ builder.setInsertionPointToStart(module.getBody()); mallocFunc = builder.create( builder.getUnknownLoc(), "malloc", - LLVM::LLVMType::getFunctionTy( - voidPtrType, llvm::makeArrayRef(indexType), /*isVarArg=*/false)); + LLVM::LLVMFunctionType::get(voidPtrType, llvm::makeArrayRef(indexType), + /*isVarArg=*/false)); } auto freeFunc = module.lookupSymbol("free"); if (!freeFunc && !toDynamic) { @@ -2230,8 +2233,8 @@ builder.setInsertionPointToStart(module.getBody()); freeFunc = builder.create( builder.getUnknownLoc(), "free", - LLVM::LLVMType::getFunctionTy(voidType, llvm::makeArrayRef(voidPtrType), - /*isVarArg=*/false)); + LLVM::LLVMFunctionType::get(voidType, llvm::makeArrayRef(voidPtrType), + /*isVarArg=*/false)); } // Initialize shared constants. @@ -2377,8 +2380,7 @@ op->getParentOfType().getBody()); freeFunc = rewriter.create( rewriter.getUnknownLoc(), "free", - LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(), - /*isVarArg=*/false)); + LLVM::LLVMFunctionType::get(getVoidType(), getVoidPtrType())); } MemRefDescriptor memref(transformed.memref()); @@ -2405,7 +2407,7 @@ LLVM::LLVMType arrayTy = elementType; // Shape has the outermost dim at index 0, so need to walk it backwards for (int64_t dim : llvm::reverse(type.getShape())) - arrayTy = LLVM::LLVMType::getArrayTy(arrayTy, dim); + arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim); return arrayTy; } @@ -2860,7 +2862,7 @@ Value zeroIndex = createIndexConstant(rewriter, loc, 0); Value pred = rewriter.create( - loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()), + loc, LLVM::LLVMIntegerType::get(rewriter.getContext(), 1), LLVM::ICmpPredicate::sge, indexArg, zeroIndex); Block *bodyBlock = @@ -3894,8 +3896,9 @@ // Append the cmpxchg op to the end of the loop block. auto successOrdering = LLVM::AtomicOrdering::acq_rel; auto failureOrdering = LLVM::AtomicOrdering::monotonic; - auto boolType = LLVM::LLVMType::getInt1Ty(rewriter.getContext()); - auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType); + auto boolType = LLVM::LLVMIntegerType::get(rewriter.getContext(), 1); + auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), + {valueType, boolType}); auto cmpxchg = rewriter.create( loc, pairType, dataPtr, loopArgument, result, successOrdering, failureOrdering); @@ -4072,13 +4075,13 @@ resultTypes.push_back(converted); } - return LLVM::LLVMType::getStructTy(&getContext(), resultTypes); + return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes); } Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) { auto *context = builder.getContext(); - auto int64Ty = LLVM::LLVMType::getInt64Ty(builder.getContext()); + auto int64Ty = LLVM::LLVMIntegerType::get(builder.getContext(), 64); auto indexType = IndexType::get(context); // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -211,7 +211,7 @@ if (failed(getBase(rewriter, loc, memref, memRefType, base))) return failure(); auto pType = MemRefDescriptor(memref).getElementPtrType(); - auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0)); + auto ptrsType = LLVM::LLVMFixedVectorType::get(pType, vType.getDimSize(0)); ptrs = rewriter.create(loc, ptrsType, base, indices); return success(); } @@ -742,7 +742,7 @@ // Remaining extraction of element from 1-D LLVM vector auto position = positionAttrs.back().cast(); - auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); + auto i64Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64); auto constant = rewriter.create(loc, i64Type, position); extracted = rewriter.create(loc, extracted, constant); @@ -850,7 +850,7 @@ } // Insertion of an element into a 1-D LLVM vector. - auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); + auto i64Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64); auto constant = rewriter.create(loc, i64Type, position); Value inserted = rewriter.create( loc, typeConverter->convertType(oneDVectorType), extracted, @@ -1117,7 +1117,7 @@ })) return failure(); - auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); + auto int64Ty = LLVM::LLVMIntegerType::get(rewriter.getContext(), 64); // Create descriptor. auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); @@ -1354,11 +1354,11 @@ switch (conversion) { case PrintConversion::ZeroExt64: value = rewriter.create( - loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); + loc, value, LLVM::LLVMIntegerType::get(rewriter.getContext(), 64)); break; case PrintConversion::SignExt64: value = rewriter.create( - loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); + loc, value, LLVM::LLVMIntegerType::get(rewriter.getContext(), 64)); break; case PrintConversion::None: break; @@ -1402,27 +1402,25 @@ OpBuilder moduleBuilder(module.getBodyRegion()); return moduleBuilder.create( op->getLoc(), name, - LLVM::LLVMType::getFunctionTy( - LLVM::LLVMType::getVoidTy(op->getContext()), params, - /*isVarArg=*/false)); + LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(op->getContext()), + params)); } // Helpers for method names. Operation *getPrintI64(Operation *op) const { return getPrint(op, "printI64", - LLVM::LLVMType::getInt64Ty(op->getContext())); + LLVM::LLVMIntegerType::get(op->getContext(), 64)); } Operation *getPrintU64(Operation *op) const { return getPrint(op, "printU64", - LLVM::LLVMType::getInt64Ty(op->getContext())); + LLVM::LLVMIntegerType::get(op->getContext(), 64)); } Operation *getPrintFloat(Operation *op) const { - return getPrint(op, "printF32", - LLVM::LLVMType::getFloatTy(op->getContext())); + return getPrint(op, "printF32", LLVM::LLVMFloatType::get(op->getContext())); } Operation *getPrintDouble(Operation *op) const { return getPrint(op, "printF64", - LLVM::LLVMType::getDoubleTy(op->getContext())); + LLVM::LLVMDoubleType::get(op->getContext())); } Operation *getPrintOpen(Operation *op) const { return getPrint(op, "printOpen", {}); diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -119,7 +119,7 @@ Type i64Ty = rewriter.getIntegerType(64); Value i64x2Ty = rewriter.create( loc, - LLVM::LLVMType::getVectorTy( + LLVM::LLVMFixedVectorType::get( toLLVMTy(i64Ty).template cast(), 2), constConfig); Value dataPtrAsI64 = rewriter.create( @@ -127,7 +127,7 @@ Value zero = this->createIndexConstant(rewriter, loc, 0); Value dwordConfig = rewriter.create( loc, - LLVM::LLVMType::getVectorTy( + LLVM::LLVMFixedVectorType::get( toLLVMTy(i64Ty).template cast(), 2), i64x2Ty, dataPtrAsI64, zero); dwordConfig = diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -101,12 +101,13 @@ // The result type is either i1 or a vector type if the inputs are // vectors. - auto resultType = LLVMType::getInt1Ty(builder.getContext()); + LLVMType resultType = LLVMIntegerType::get(builder.getContext(), 1); auto argType = type.dyn_cast(); if (!argType) return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"); if (auto vecArgType = argType.dyn_cast()) - resultType = LLVMType::getVectorTy(resultType, vecArgType.getNumElements()); + resultType = + LLVMFixedVectorType::get(resultType, vecArgType.getNumElements()); assert(!argType.isa() && "unhandled scalable vector"); @@ -547,7 +548,7 @@ LLVM::LLVMType llvmResultType; if (funcType.getNumResults() == 0) { - llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext()); + llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); } else { llvmResultType = funcType.getResult(0).dyn_cast(); if (!llvmResultType) @@ -565,8 +566,7 @@ "expected LLVM types as inputs"); } - auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, - /*isVarArg=*/false); + auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); auto funcArguments = llvm::makeArrayRef(operands).drop_front(); @@ -827,7 +827,7 @@ Builder &builder = parser.getBuilder(); LLVM::LLVMType llvmResultType; if (funcType.getNumResults() == 0) { - llvmResultType = LLVM::LLVMType::getVoidTy(builder.getContext()); + llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); } else { llvmResultType = funcType.getResult(0).dyn_cast(); if (!llvmResultType) @@ -844,8 +844,7 @@ "expected LLVM types as inputs"); argTypes.push_back(argType); } - auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, - /*isVarArg=*/false); + auto llvmFuncType = LLVM::LLVMFunctionType::get(llvmResultType, argTypes); auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); auto funcArguments = @@ -1477,8 +1476,8 @@ if (types.empty()) { if (auto strAttr = value.dyn_cast_or_null()) { MLIRContext *context = parser.getBuilder().getContext(); - auto arrayType = LLVM::LLVMType::getArrayTy( - LLVM::LLVMType::getInt8Ty(context), strAttr.getValue().size()); + auto arrayType = LLVM::LLVMArrayType::get( + LLVM::LLVMIntegerType::get(context, 8), strAttr.getValue().size()); types.push_back(arrayType); } else { return parser.emitError(parser.getNameLoc(), @@ -1539,7 +1538,7 @@ ArrayRef attrs) { auto containerType = v1.getType().cast(); auto vType = - LLVMType::getVectorTy(containerType.getElementType(), mask.size()); + LLVMFixedVectorType::get(containerType.getElementType(), mask.size()); build(b, result, vType, v1, v2, mask); result.addAttributes(attrs); } @@ -1574,7 +1573,7 @@ return parser.emitError( loc, "expected LLVM IR dialect vector type for operand #1"); auto vType = - LLVMType::getVectorTy(containerType.getElementType(), maskAttr.size()); + LLVMFixedVectorType::get(containerType.getElementType(), maskAttr.size()); result.addTypes(vType); return success(); } @@ -1646,15 +1645,15 @@ } // No output is denoted as "void" in LLVM type system. - LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(b.getContext()) + LLVMType llvmOutput = outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front().dyn_cast(); if (!llvmOutput) { parser.emitError(loc, "failed to construct function type: expected LLVM " "type for function results"); return {}; } - return LLVMType::getFunctionTy(llvmOutput, llvmInputs, - variadicFlag.isVariadic()); + return LLVMFunctionType::get(llvmOutput, llvmInputs, + variadicFlag.isVariadic()); } // Parses an LLVM function. @@ -1970,8 +1969,9 @@ parser.resolveOperand(val, type, result.operands)) return failure(); - auto boolType = LLVMType::getInt1Ty(builder.getContext()); - auto resultType = LLVMType::getStructTy(type, boolType); + auto boolType = LLVMIntegerType::get(builder.getContext(), 1); + auto resultType = + LLVMStructType::getLiteral(builder.getContext(), {type, boolType}); result.addTypes(resultType); return success(); @@ -2159,8 +2159,8 @@ // Create the global at the entry of the module. OpBuilder moduleBuilder(module.getBodyRegion()); MLIRContext *ctx = builder.getContext(); - auto type = - LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(ctx), value.size()); + auto type = LLVM::LLVMArrayType::get(LLVM::LLVMIntegerType::get(ctx, 8), + value.size()); auto global = moduleBuilder.create( loc, type, /*isConstant=*/true, linkage, name, builder.getStringAttr(value)); @@ -2168,10 +2168,11 @@ // Get the pointer to the first character in the global string. Value globalPtr = builder.create(loc, global); Value cst0 = builder.create( - loc, LLVM::LLVMType::getInt64Ty(ctx), + loc, LLVM::LLVMIntegerType::get(ctx, 64), builder.getIntegerAttr(builder.getIndexType(), 0)); - return builder.create(loc, LLVM::LLVMType::getInt8PtrTy(ctx), - globalPtr, ValueRange{cst0, cst0}); + return builder.create( + loc, LLVM::LLVMPointerType::get(LLVMIntegerType::get(ctx, 8)), globalPtr, + ValueRange{cst0, cst0}); } bool mlir::LLVM::satisfiesLLVMModule(Operation *op) { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -36,106 +36,9 @@ return static_cast(Type::getDialect()); } -//----------------------------------------------------------------------------// -// Utilities used to generate floating point types. - -LLVMType LLVMType::getDoubleTy(MLIRContext *context) { - return LLVMDoubleType::get(context); -} - -LLVMType LLVMType::getFloatTy(MLIRContext *context) { - return LLVMFloatType::get(context); -} - -LLVMType LLVMType::getBFloatTy(MLIRContext *context) { - return LLVMBFloatType::get(context); -} - -LLVMType LLVMType::getHalfTy(MLIRContext *context) { - return LLVMHalfType::get(context); -} - -LLVMType LLVMType::getFP128Ty(MLIRContext *context) { - return LLVMFP128Type::get(context); -} - -LLVMType LLVMType::getX86_FP80Ty(MLIRContext *context) { - return LLVMX86FP80Type::get(context); -} - -//----------------------------------------------------------------------------// -// Utilities used to generate integer types. - -LLVMType LLVMType::getIntNTy(MLIRContext *context, unsigned numBits) { - return LLVMIntegerType::get(context, numBits); -} - -LLVMType LLVMType::getInt8PtrTy(MLIRContext *context) { - return LLVMPointerType::get(LLVMIntegerType::get(context, 8)); -} - -//----------------------------------------------------------------------------// -// Utilities used to generate other miscellaneous types. - -LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) { - return LLVMArrayType::get(elementType, numElements); -} - -LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef params, - bool isVarArg) { - return LLVMFunctionType::get(result, params, isVarArg); -} - -LLVMType LLVMType::getStructTy(MLIRContext *context, - ArrayRef elements, bool isPacked) { - return LLVMStructType::getLiteral(context, elements, isPacked); -} - -LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) { - return LLVMFixedVectorType::get(elementType, numElements); -} - -//----------------------------------------------------------------------------// -// Void type utilities. - -LLVMType LLVMType::getVoidTy(MLIRContext *context) { - return LLVMVoidType::get(context); -} - -//----------------------------------------------------------------------------// -// Creation and setting of LLVM's identified struct types - -LLVMType LLVMType::createStructTy(MLIRContext *context, - ArrayRef elements, - Optional name, bool isPacked) { - assert(name.hasValue() && - "identified structs with no identifier not supported"); - StringRef stringNameBase = name.getValueOr(""); - std::string stringName = stringNameBase.str(); - unsigned counter = 0; - do { - auto type = LLVMStructType::getIdentified(context, stringName); - if (type.isInitialized() || failed(type.setBody(elements, isPacked))) { - counter += 1; - stringName = - (Twine(stringNameBase) + "." + std::to_string(counter)).str(); - continue; - } - return type; - } while (true); -} - -LLVMType LLVMType::setStructTyBody(LLVMType structType, - ArrayRef elements, bool isPacked) { - LogicalResult couldSet = - structType.cast().setBody(elements, isPacked); - assert(succeeded(couldSet) && "failed to set the body"); - (void)couldSet; - return structType; -} - //===----------------------------------------------------------------------===// // Array type. +//===----------------------------------------------------------------------===// bool LLVMArrayType::isValidElementType(LLVMType type) { return !type.isa(); @@ -222,6 +126,7 @@ //===----------------------------------------------------------------------===// // Integer type. +//===----------------------------------------------------------------------===// LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) { return Base::get(ctx, bitwidth); @@ -243,6 +148,7 @@ //===----------------------------------------------------------------------===// // Pointer type. +//===----------------------------------------------------------------------===// bool LLVMPointerType::isValidElementType(LLVMType type) { return !type.isa elements, + bool isPacked) { + std::string stringName = name.str(); + unsigned counter = 0; + do { + auto type = LLVMStructType::getIdentified(context, stringName); + if (type.isInitialized() || failed(type.setBody(elements, isPacked))) { + counter += 1; + stringName = (Twine(name) + "." + std::to_string(counter)).str(); + continue; + } + return type; + } while (true); +} + LLVMStructType LLVMStructType::getLiteral(MLIRContext *context, ArrayRef types, bool isPacked) { @@ -346,6 +270,7 @@ //===----------------------------------------------------------------------===// // Vector types. +//===----------------------------------------------------------------------===// bool LLVMVectorType::isValidElementType(LLVMType type) { return type.isa() || diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -63,7 +63,8 @@ break; } - auto int32Ty = LLVM::LLVMType::getInt32Ty(parser.getBuilder().getContext()); + auto int32Ty = + LLVM::LLVMIntegerType::get(parser.getBuilder().getContext(), 32); return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty}, parser.getNameLoc(), result.operands); } @@ -72,8 +73,8 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser, OperationState &result) { MLIRContext *context = parser.getBuilder().getContext(); - auto int32Ty = LLVM::LLVMType::getInt32Ty(context); - auto int1Ty = LLVM::LLVMType::getInt1Ty(context); + auto int32Ty = LLVM::LLVMIntegerType::get(context, 32); + auto int1Ty = LLVM::LLVMIntegerType::get(context, 1); SmallVector ops; Type type; @@ -87,12 +88,12 @@ static LogicalResult verify(MmaOp op) { MLIRContext *context = op.getContext(); - auto f16Ty = LLVM::LLVMType::getHalfTy(context); - auto f16x2Ty = LLVM::LLVMType::getVectorTy(f16Ty, 2); - auto f32Ty = LLVM::LLVMType::getFloatTy(context); - auto f16x2x4StructTy = LLVM::LLVMType::getStructTy( + auto f16Ty = LLVM::LLVMHalfType::get(context); + auto f16x2Ty = LLVM::LLVMFixedVectorType::get(f16Ty, 2); + auto f32Ty = LLVM::LLVMFloatType::get(context); + auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); - auto f32x8StructTy = LLVM::LLVMType::getStructTy( + auto f32x8StructTy = LLVM::LLVMStructType::getLiteral( context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); SmallVector operand_types(op.getOperandTypes().begin(), diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -46,9 +46,9 @@ return failure(); MLIRContext *context = parser.getBuilder().getContext(); - auto int32Ty = LLVM::LLVMType::getInt32Ty(context); - auto int1Ty = LLVM::LLVMType::getInt1Ty(context); - auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4); + auto int32Ty = LLVM::LLVMIntegerType::get(context, 32); + auto int1Ty = LLVM::LLVMIntegerType::get(context, 1); + auto i32x4Ty = LLVM::LLVMFixedVectorType::get(int32Ty, 4); return parser.resolveOperands(ops, {i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty}, parser.getNameLoc(), result.operands); @@ -65,9 +65,9 @@ return failure(); MLIRContext *context = parser.getBuilder().getContext(); - auto int32Ty = LLVM::LLVMType::getInt32Ty(context); - auto int1Ty = LLVM::LLVMType::getInt1Ty(context); - auto i32x4Ty = LLVM::LLVMType::getVectorTy(int32Ty, 4); + auto int32Ty = LLVM::LLVMIntegerType::get(context, 32); + auto int1Ty = LLVM::LLVMIntegerType::get(context, 1); + auto i32x4Ty = LLVM::LLVMFixedVectorType::get(int32Ty, 4); if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty}, diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -642,13 +642,12 @@ LLVM::LLVMType resultType; if (inlineAsmOp.getNumResults() == 0) { - resultType = LLVM::LLVMType::getVoidTy(mlirModule->getContext()); + resultType = LLVM::LLVMVoidType::get(mlirModule->getContext()); } else { assert(inlineAsmOp.getNumResults() == 1); resultType = inlineAsmOp.getResultTypes()[0].cast(); } - auto ft = LLVM::LLVMType::getFunctionTy(resultType, operandTypes, - /*isVarArg=*/false); + auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes); llvm::InlineAsm *inlineAsmInst = inlineAsmOp.asm_dialect().hasValue() ? llvm::InlineAsm::get( diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp --- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp +++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp @@ -45,7 +45,8 @@ // Populate type conversions. LLVMTypeConverter type_converter(m.getContext()); type_converter.addConversion([&](test::TestType type) { - return LLVM::LLVMType::getInt8PtrTy(m.getContext()); + return LLVM::LLVMPointerType::get( + LLVM::LLVMIntegerType::get(m.getContext(), 8)); }); // Populate patterns.