diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2176,7 +2176,7 @@ p.printOptionalAttrDict(getAttrs(), {"fn_type", "method"}); auto resTy{getResultTypes()}; llvm::SmallVector argTy(getOperandTypes()); - p << " : " << mlir::FunctionType::get(argTy, resTy, getContext()); + p << " : " << mlir::FunctionType::get(getContext(), argTy, resTy); }]; let extraClassDeclaration = [{ diff --git a/flang/lib/Lower/ConvertType.cpp b/flang/lib/Lower/ConvertType.cpp --- a/flang/lib/Lower/ConvertType.cpp +++ b/flang/lib/Lower/ConvertType.cpp @@ -49,7 +49,7 @@ if constexpr (TC == Fortran::common::TypeCategory::Integer) { auto bits{Fortran::evaluate::Type::Scalar::bits}; - return mlir::IntegerType::get(bits, context); + return mlir::IntegerType::get(context, bits); } else if constexpr (TC == Fortran::common::TypeCategory::Logical || TC == Fortran::common::TypeCategory::Character || TC == Fortran::common::TypeCategory::Complex) { @@ -278,7 +278,7 @@ // some sequence of `n` bytes mlir::Type gen(const Fortran::evaluate::StaticDataObject::Pointer &ptr) { - mlir::Type byteTy{mlir::IntegerType::get(8, context)}; + mlir::Type byteTy{mlir::IntegerType::get(context, 8)}; return fir::SequenceType::get(trivialShape(ptr->itemBytes()), byteTy); } diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp --- a/flang/lib/Lower/IntrinsicCall.cpp +++ b/flang/lib/Lower/IntrinsicCall.cpp @@ -298,26 +298,26 @@ static mlir::FunctionType genF32F32FuncType(mlir::MLIRContext *context) { auto t = mlir::FloatType::getF32(context); - return mlir::FunctionType::get({t}, {t}, context); + return mlir::FunctionType::get(context, {t}, {t}); } static mlir::FunctionType genF64F64FuncType(mlir::MLIRContext *context) { auto t = mlir::FloatType::getF64(context); - return mlir::FunctionType::get({t}, {t}, context); + return mlir::FunctionType::get(context, {t}, {t}); } template static mlir::FunctionType genIntF64FuncType(mlir::MLIRContext *context) { auto t = mlir::FloatType::getF64(context); - auto r = mlir::IntegerType::get(Bits, context); - return mlir::FunctionType::get({t}, {r}, context); + auto r = mlir::IntegerType::get(context, Bits); + return mlir::FunctionType::get(context, {t}, {r}); } template static mlir::FunctionType genIntF32FuncType(mlir::MLIRContext *context) { auto t = mlir::FloatType::getF32(context); - auto r = mlir::IntegerType::get(Bits, context); - return mlir::FunctionType::get({t}, {r}, context); + auto r = mlir::IntegerType::get(context, Bits); + return mlir::FunctionType::get(context, {t}, {r}); } // TODO : Fill-up this table with more intrinsic. @@ -585,8 +585,8 @@ llvm::SmallVector argumentTypes; for (auto &arg : arguments) argumentTypes.push_back(arg.getType()); - return mlir::FunctionType::get(argumentTypes, resultType, - builder.getModule().getContext()); + return mlir::FunctionType::get(builder.getModule().getContext(), + argumentTypes, resultType); } /// fir::ExtendedValue to mlir::Value translation layer @@ -1144,7 +1144,7 @@ llvm::ArrayRef args) { assert(args.size() == 3); - auto i1Type = mlir::IntegerType::get(1, builder.getContext()); + auto i1Type = mlir::IntegerType::get(builder.getContext(), 1); auto mask = builder.createConvert(loc, i1Type, args[2]); return builder.create(loc, mask, args[0], args[1]); } diff --git a/flang/lib/Lower/RTBuilder.h b/flang/lib/Lower/RTBuilder.h --- a/flang/lib/Lower/RTBuilder.h +++ b/flang/lib/Lower/RTBuilder.h @@ -48,7 +48,7 @@ template <> constexpr TypeBuilderFunc getModel() { return [](mlir::MLIRContext *context) -> mlir::Type { - return mlir::IntegerType::get(8 * sizeof(int), context); + return mlir::IntegerType::get(context, 8 * sizeof(int)); }; } template <> @@ -61,14 +61,14 @@ template <> constexpr TypeBuilderFunc getModel() { return [](mlir::MLIRContext *context) -> mlir::Type { - return mlir::IntegerType::get(8 * sizeof(Fortran::runtime::io::Iostat), - context); + return mlir::IntegerType::get(context, + 8 * sizeof(Fortran::runtime::io::Iostat)); }; } template <> constexpr TypeBuilderFunc getModel() { return [](mlir::MLIRContext *context) -> mlir::Type { - return fir::ReferenceType::get(mlir::IntegerType::get(8, context)); + return fir::ReferenceType::get(mlir::IntegerType::get(context, 8)); }; } template <> @@ -78,26 +78,26 @@ template <> constexpr TypeBuilderFunc getModel() { return [](mlir::MLIRContext *context) -> mlir::Type { - return fir::ReferenceType::get(mlir::IntegerType::get(16, context)); + return fir::ReferenceType::get(mlir::IntegerType::get(context, 16)); }; } template <> constexpr TypeBuilderFunc getModel() { return [](mlir::MLIRContext *context) -> mlir::Type { - return fir::ReferenceType::get(mlir::IntegerType::get(32, context)); + return fir::ReferenceType::get(mlir::IntegerType::get(context, 32)); }; } template <> constexpr TypeBuilderFunc getModel() { return [](mlir::MLIRContext *context) -> mlir::Type { return fir::ReferenceType::get( - fir::PointerType::get(mlir::IntegerType::get(8, context))); + fir::PointerType::get(mlir::IntegerType::get(context, 8))); }; } template <> constexpr TypeBuilderFunc getModel() { return [](mlir::MLIRContext *context) -> mlir::Type { - return mlir::IntegerType::get(64, context); + return mlir::IntegerType::get(context, 64); }; } template <> @@ -110,7 +110,7 @@ template <> constexpr TypeBuilderFunc getModel() { return [](mlir::MLIRContext *context) -> mlir::Type { - return mlir::IntegerType::get(8 * sizeof(std::size_t), context); + return mlir::IntegerType::get(context, 8 * sizeof(std::size_t)); }; } template <> @@ -146,7 +146,7 @@ template <> constexpr TypeBuilderFunc getModel() { return [](mlir::MLIRContext *context) -> mlir::Type { - return mlir::IntegerType::get(1, context); + return mlir::IntegerType::get(context, 1); }; } template <> @@ -190,7 +190,7 @@ llvm::SmallVector argTys; for (auto f : args) argTys.push_back(f(ctxt)); - return mlir::FunctionType::get(argTys, {retTy}, ctxt); + return mlir::FunctionType::get(ctxt, argTys, {retTy}); }; } }; diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -151,7 +151,7 @@ // note: triple, but 4 is nearest power of 2 llvm::SmallVector triple{ getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; - return mlir::TupleType::get(triple, getContext()); + return mlir::TupleType::get(getContext(), triple); } //===----------------------------------------------------------------------===// @@ -171,7 +171,7 @@ auto resultTypes{op.getResultTypes()}; llvm::SmallVector argTypes( llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); - p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); + p << " : " << FunctionType::get(op.getContext(), argTypes, resultTypes); } static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser, @@ -1565,4 +1565,3 @@ #define GET_OP_CLASSES #include "flang/Optimizer/Dialect/FIROps.cpp.inc" - diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512.td b/mlir/include/mlir/Dialect/AVX512/AVX512.td --- a/mlir/include/mlir/Dialect/AVX512/AVX512.td +++ b/mlir/include/mlir/Dialect/AVX512/AVX512.td @@ -35,8 +35,8 @@ AllTypesMatch<["src", "a", "dst"]>, TypesMatchWith<"imm has the same number of bits as elements in dst", "dst", "imm", - "IntegerType::get(($_self.cast().getShape()[0])," - " $_self.getContext())">]> { + "IntegerType::get($_self.getContext(), " + "($_self.cast().getShape()[0]))">]> { let summary = "Masked roundscale op"; let description = [{ The mask.rndscale op is an AVX512 specific op that can lower to the proper @@ -67,8 +67,8 @@ AllTypesMatch<["src", "a", "b", "dst"]>, TypesMatchWith<"k has the same number of bits as elements in dst", "dst", "k", - "IntegerType::get(($_self.cast().getShape()[0])," - " $_self.getContext())">]> { + "IntegerType::get($_self.getContext(), " + "($_self.cast().getShape()[0]))">]> { let summary = "ScaleF op"; let description = [{ The `mask.scalef` op is an AVX512 specific op that can lower to the proper diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -911,7 +911,7 @@ auto attr = (*this)->getAttr("operand_segment_sizes") .cast(); unsigned i = 0; - auto newAttr = attr.mapValues(IntegerType::get(32, getContext()), + auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32), [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; }); getOperation()->setAttr("operand_segment_sizes", newAttr); } diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -63,7 +63,7 @@ /// Get or create a ComplexType with the provided element type. This emits /// and error at the specified location and returns null if the element type /// isn't supported. - static ComplexType getChecked(Type elementType, Location location); + static ComplexType getChecked(Location location, Type elementType); /// Verify the construction of an integer type. static LogicalResult verifyConstructionInvariants(Location loc, @@ -93,27 +93,27 @@ /// The created IntegerType is signless (i.e., no signedness semantics). /// Assume the width is within the allowed range and assert on failures. Use /// getChecked to handle failures gracefully. - static IntegerType get(unsigned width, MLIRContext *context); + static IntegerType get(MLIRContext *context, unsigned width); /// Get or create a new IntegerType of the given width within the context. /// The created IntegerType has signedness semantics as indicated via /// `signedness`. Assume the width is within the allowed range and assert on /// failures. Use getChecked to handle failures gracefully. - static IntegerType get(unsigned width, SignednessSemantics signedness, - MLIRContext *context); + static IntegerType get(MLIRContext *context, unsigned width, + SignednessSemantics signedness); /// Get or create a new IntegerType of the given width within the context, /// defined at the given, potentially unknown, location. The created /// IntegerType is signless (i.e., no signedness semantics). If the width is /// outside the allowed range, emit errors and return a null type. - static IntegerType getChecked(unsigned width, Location location); + static IntegerType getChecked(Location location, unsigned width); /// Get or create a new IntegerType of the given width within the context, /// defined at the given, potentially unknown, location. The created /// IntegerType has signedness semantics as indicated via `signedness`. If the /// width is outside the allowed range, emit errors and return a null type. - static IntegerType getChecked(unsigned width, SignednessSemantics signedness, - Location location); + static IntegerType getChecked(Location location, unsigned width, + SignednessSemantics signedness); /// Verify the construction of an integer type. static LogicalResult @@ -180,8 +180,8 @@ public: using Base::Base; - static FunctionType get(TypeRange inputs, TypeRange results, - MLIRContext *context); + static FunctionType get(MLIRContext *context, TypeRange inputs, + TypeRange results); /// Input types. unsigned getNumInputs() const; @@ -211,14 +211,14 @@ using Base::Base; /// Get or create a new OpaqueType with the provided dialect and string data. - static OpaqueType get(Identifier dialect, StringRef typeData, - MLIRContext *context); + static OpaqueType get(MLIRContext *context, Identifier dialect, + StringRef typeData); /// Get or create a new OpaqueType with the provided dialect and string data. /// If the given identifier is not a valid namespace for a dialect, then a /// null type is returned. - static OpaqueType getChecked(Identifier dialect, StringRef typeData, - MLIRContext *context, Location location); + static OpaqueType getChecked(Location location, Identifier dialect, + StringRef typeData); /// Returns the dialect namespace of the opaque type. Identifier getDialectNamespace() const; @@ -335,8 +335,8 @@ /// declared at the given, potentially unknown, location. If the VectorType /// defined by the arguments would be ill-formed, emit errors and return /// nullptr-wrapping type. - static VectorType getChecked(ArrayRef shape, Type elementType, - Location location); + static VectorType getChecked(Location location, ArrayRef shape, + Type elementType); /// Verify the construction of a vector type. static LogicalResult verifyConstructionInvariants(Location loc, @@ -394,8 +394,8 @@ /// type declared at the given, potentially unknown, location. If the /// RankedTensorType defined by the arguments would be ill-formed, emit errors /// and return a nullptr-wrapping type. - static RankedTensorType getChecked(ArrayRef shape, Type elementType, - Location location); + static RankedTensorType getChecked(Location location, ArrayRef shape, + Type elementType); /// Verify the construction of a ranked tensor type. static LogicalResult verifyConstructionInvariants(Location loc, @@ -424,7 +424,7 @@ /// type declared at the given, potentially unknown, location. If the /// UnrankedTensorType defined by the arguments would be ill-formed, emit /// errors and return a nullptr-wrapping type. - static UnrankedTensorType getChecked(Type elementType, Location location); + static UnrankedTensorType getChecked(Location location, Type elementType); /// Verify the construction of a unranked tensor type. static LogicalResult verifyConstructionInvariants(Location loc, @@ -527,9 +527,10 @@ /// UnknownLoc. If the MemRefType defined by the arguments would be /// ill-formed, emits errors (to the handler registered with the context or to /// the error stream) and returns nullptr. - static MemRefType getChecked(ArrayRef shape, Type elementType, + static MemRefType getChecked(Location location, ArrayRef shape, + Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, Location location); + unsigned memorySpace); ArrayRef getShape() const; @@ -573,8 +574,8 @@ /// type and memory space declared at the given, potentially unknown, /// location. If the UnrankedMemRefType defined by the arguments would be /// ill-formed, emit errors and return a nullptr-wrapping type. - static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace, - Location location); + static UnrankedMemRefType getChecked(Location location, Type elementType, + unsigned memorySpace); /// Verify the construction of a unranked memref type. static LogicalResult verifyConstructionInvariants(Location loc, @@ -600,7 +601,7 @@ /// Get or create a new TupleType with the provided element types. Assumes the /// arguments define a well-formed type. - static TupleType get(TypeRange elementTypes, MLIRContext *context); + static TupleType get(MLIRContext *context, TypeRange elementTypes); /// Get or create an empty tuple type. static TupleType get(MLIRContext *context); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -475,8 +475,9 @@ class OpaqueType : Type, description>, - BuildableType<"::mlir::OpaqueType::get($_builder.getIdentifier(\"" - # dialect # "\"), \"" # name # "\", $_builder.getContext())">; + BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), " + "$_builder.getIdentifier(\"" # dialect # "\"), \"" + # name # "\")">; // Function Type diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -26,15 +26,15 @@ } MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) { - return wrap(IntegerType::get(bitwidth, unwrap(ctx))); + return wrap(IntegerType::get(unwrap(ctx), bitwidth)); } MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) { - return wrap(IntegerType::get(bitwidth, IntegerType::Signed, unwrap(ctx))); + return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed)); } MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) { - return wrap(IntegerType::get(bitwidth, IntegerType::Unsigned, unwrap(ctx))); + return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Unsigned)); } unsigned mlirIntegerTypeGetWidth(MlirType type) { @@ -172,8 +172,8 @@ MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape, MlirType elementType, MlirLocation loc) { return wrap(VectorType::getChecked( - llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), - unwrap(loc))); + unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), + unwrap(elementType))); } //===----------------------------------------------------------------------===// @@ -201,8 +201,8 @@ MlirType elementType, MlirLocation loc) { return wrap(RankedTensorType::getChecked( - llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), - unwrap(loc))); + unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), + unwrap(elementType))); } MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { @@ -211,7 +211,7 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType, MlirLocation loc) { - return wrap(UnrankedTensorType::getChecked(unwrap(elementType), unwrap(loc))); + return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType))); } //===----------------------------------------------------------------------===// @@ -244,8 +244,8 @@ unsigned memorySpace, MlirLocation loc) { return wrap(MemRefType::getChecked( - llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), - llvm::None, memorySpace, unwrap(loc))); + unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), + unwrap(elementType), llvm::None, memorySpace)); } intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) { @@ -272,8 +272,8 @@ MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType, unsigned memorySpace, MlirLocation loc) { - return wrap(UnrankedMemRefType::getChecked(unwrap(elementType), memorySpace, - unwrap(loc))); + return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType), + memorySpace)); } unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) { @@ -290,7 +290,7 @@ MlirType const *elements) { SmallVector types; ArrayRef typeRef = unwrapList(numElements, elements, types); - return wrap(TupleType::get(typeRef, unwrap(ctx))); + return wrap(TupleType::get(unwrap(ctx), typeRef)); } intptr_t mlirTupleTypeGetNumTypes(MlirType type) { @@ -316,7 +316,7 @@ SmallVector resultsList; (void)unwrapList(numInputs, inputs, inputsList); (void)unwrapList(numResults, results, resultsList); - return wrap(FunctionType::get(inputsList, resultsList, unwrap(ctx))); + return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList)); } intptr_t mlirFunctionTypeGetNumInputs(MlirType type) { diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -53,52 +53,52 @@ struct AsyncAPI { static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { auto ref = LLVM::LLVMType::getInt8PtrTy(ctx); - auto count = IntegerType::get(32, ctx); - return FunctionType::get({ref, count}, {}, ctx); + auto count = IntegerType::get(ctx, 32); + return FunctionType::get(ctx, {ref, count}, {}); } static FunctionType createTokenFunctionType(MLIRContext *ctx) { - return FunctionType::get({}, {TokenType::get(ctx)}, ctx); + return FunctionType::get(ctx, {}, {TokenType::get(ctx)}); } static FunctionType createGroupFunctionType(MLIRContext *ctx) { - return FunctionType::get({}, {GroupType::get(ctx)}, ctx); + return FunctionType::get(ctx, {}, {GroupType::get(ctx)}); } static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { - return FunctionType::get({TokenType::get(ctx)}, {}, ctx); + return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); } static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { - return FunctionType::get({TokenType::get(ctx)}, {}, ctx); + return FunctionType::get(ctx, {TokenType::get(ctx)}, {}); } static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { - return FunctionType::get({GroupType::get(ctx)}, {}, ctx); + return FunctionType::get(ctx, {GroupType::get(ctx)}, {}); } static FunctionType executeFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto resume = resumeFunctionType(ctx).getPointerTo(); - return FunctionType::get({hdl, resume}, {}, ctx); + return FunctionType::get(ctx, {hdl, resume}, {}); } static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { - auto i64 = IntegerType::get(64, ctx); - return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64}, - ctx); + auto i64 = IntegerType::get(ctx, 64); + return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)}, + {i64}); } static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto resume = resumeFunctionType(ctx).getPointerTo(); - return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx); + return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); } static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto resume = resumeFunctionType(ctx).getPointerTo(); - return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx); + return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); } // Auxiliary coroutine resume intrinsic wrapper. @@ -690,7 +690,7 @@ if (!addToGroup.operand().getType().isa()) return failure(); - auto i64 = IntegerType::get(64, op->getContext()); + auto i64 = IntegerType::get(op->getContext(), 64); rewriter.replaceOpWithNewOp(op, kAddTokenToGroup, i64, operands); return success(); } diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -122,7 +122,7 @@ } // Declare vulkan launch function. - auto funcType = FunctionType::get(vulkanLaunchTypes, {}, loc->getContext()); + auto funcType = builder.getFunctionType(vulkanLaunchTypes, {}); builder.create(loc, kVulkanLaunch, funcType).setPrivate(); return success(); diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -84,7 +84,7 @@ /// }; static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) { auto *context = t.getContext(); - auto int64Ty = converter.convertType(IntegerType::get(64, context)) + auto int64Ty = converter.convertType(IntegerType::get(context, 64)) .cast(); return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty); } diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -65,7 +65,7 @@ assert(op->getNumResults() == 0 && "Library call for linalg operation can be generated only for ops that " "have void return types"); - auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext()); + auto libFnType = rewriter.getFunctionType(inputTypes, {}); OpBuilder::InsertionGuard guard(rewriter); // Insert before module terminator. diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -407,8 +407,7 @@ // cover all possible corner cases. if (isSignedIntegerOrVector(srcType) || isUnsignedIntegerOrVector(srcType)) { - auto *context = rewriter.getContext(); - auto signlessType = IntegerType::get(getBitWidth(srcType), context); + auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); if (srcType.isa()) { auto dstElementsAttr = constOp.value().cast(); diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -584,7 +584,7 @@ std::swap(ivsStorage.back(), ivsStorage[coalescedIdx]); ArrayRef ivs(ivsStorage); - Value pos = std_index_cast(IntegerType::get(32, ctx), ivs.back()); + Value pos = std_index_cast(IntegerType::get(ctx, 32), ivs.back()); Value inVector = local(ivs.drop_back()); auto loadValue = [&](ArrayRef indices) { Value vector = vector_insert_element(remote(indices), inVector, pos); @@ -671,7 +671,7 @@ ArrayRef ivs(ivsStorage); Value pos = - std_index_cast(IntegerType::get(32, op->getContext()), ivs.back()); + std_index_cast(IntegerType::get(op->getContext(), 32), ivs.back()); auto storeValue = [&](ArrayRef indices) { Value scalar = vector_extract_element(local(ivs.drop_back()), pos); remote(indices) = scalar; diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -152,7 +152,7 @@ int32_t numDependencies = dependencies.size(); int32_t numOperands = operands.size(); auto operandSegmentSizes = DenseIntElementsAttr::get( - VectorType::get({2}, IntegerType::get(32, result.getContext())), + VectorType::get({2}, builder.getIntegerType(32)), {numDependencies, numOperands}); result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRefCounting.cpp @@ -118,7 +118,7 @@ builder.setInsertionPointToStart(value.getParentBlock()); Location loc = value.getLoc(); - auto i32 = IntegerType::get(32, ctx); + auto i32 = IntegerType::get(ctx, 32); // Drop the reference count immediately if the value has no uses. if (value.getUses().empty()) { diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -31,7 +31,7 @@ : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_), loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()), indexType(IndexType::get(reduceOp.getContext())), - int32Type(IntegerType::get(/*width=*/32, reduceOp.getContext())) {} + int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {} /// Creates an all_reduce across the workgroup. /// diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -155,7 +155,7 @@ kernelOperandTypes.push_back(operand.getType()); } FunctionType type = - FunctionType::get(kernelOperandTypes, {}, launchOp.getContext()); + FunctionType::get(launchOp.getContext(), kernelOperandTypes, {}); auto outlinedFunc = builder.create(loc, kernelFnName, type); outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), builder.getUnitAttr()); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -120,8 +120,8 @@ static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) { auto elemTy = op.getType().cast().getPointerElementTy(); - auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()}, - op.getContext()); + auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()}, + {op.getType()}); p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy; if (op.alignment().hasValue() && *op.alignment() != 0) @@ -781,7 +781,7 @@ // Reconstruct the function MLIR function type from operand and result types. p << " : " - << FunctionType::get(args.getTypes(), op.getResultTypes(), op.getContext()); + << FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes()); } // ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -76,25 +76,25 @@ IntegerAttr alignment_attr; if (alignment.hasValue()) alignment_attr = - IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue()); + IntegerAttr::get(IntegerType::get(ctx, 64), alignment.getValue()); if (!dynamicBuffers) if (auto cst = size.getDefiningOp()) return options.useAlloca ? std_alloca(MemRefType::get(width * cst.getValue(), - IntegerType::get(8, ctx)), + IntegerType::get(ctx, 8)), ValueRange{}, alignment_attr) .value : std_alloc(MemRefType::get(width * cst.getValue(), - IntegerType::get(8, ctx)), + IntegerType::get(ctx, 8)), ValueRange{}, alignment_attr) .value; Value mul = folded_std_muli(folder, folded_std_constant_index(folder, width), size); return options.useAlloca - ? std_alloca(MemRefType::get(-1, IntegerType::get(8, ctx)), mul, + ? std_alloca(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul, alignment_attr) .value - : std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul, + : std_alloc(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul, alignment_attr) .value; } diff --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp --- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp @@ -18,7 +18,7 @@ int64_t &qmax) { // Hard-coded type mapping from TFLite. if (numBits <= 8) { - storageType = IntegerType::get(8, ctx); + storageType = IntegerType::get(ctx, 8); if (isSigned) { qmin = -128; qmax = 127; @@ -27,7 +27,7 @@ qmax = 255; } } else if (numBits <= 16) { - storageType = IntegerType::get(16, ctx); + storageType = IntegerType::get(ctx, 16); if (isSigned) { qmin = -32768; qmax = 32767; @@ -36,7 +36,7 @@ qmax = 65535; } } else if (numBits <= 32) { - storageType = IntegerType::get(32, ctx); + storageType = IntegerType::get(ctx, 32); if (isSigned) { qmin = std::numeric_limits::min(); qmax = std::numeric_limits::max(); diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp --- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp @@ -79,7 +79,7 @@ int64_t chunkSize = std::accumulate(std::next(shape.begin(), quantizationDim + 1), shape.end(), 1, std::multiplies()); - Type newElementType = IntegerType::get(storageBitWidth, attr.getContext()); + Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth); return attr.mapValues(newElementType, [&](const APFloat &old) { int chunkIndex = (flattenIndex++) / chunkSize; return converters[chunkIndex % dimSize].quantizeFloatToInt(old); diff --git a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp --- a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp @@ -96,7 +96,7 @@ ValueRange values(captures.getArrayRef()); FunctionType type = - FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx); + FunctionType::get(ctx, values.getTypes(), ifOp.getResultTypes()); auto outlinedFunc = b.create(loc, funcName, type); b.setInsertionPointToStart(outlinedFunc.addEntryBlock()); BlockAndValueMapping bvm; diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp @@ -123,7 +123,7 @@ assert(localSize.size() == 3); return spirv::EntryPointABIAttr::get( DenseElementsAttr::get( - VectorType::get(3, IntegerType::get(32, context)), localSize) + VectorType::get(3, IntegerType::get(context, 32)), localSize) .cast(), context); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -93,7 +93,7 @@ // instructions. The Vulkan spec requires the builtins like // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be // SExtended to 64-bit for index computations. - return IntegerType::get(32, context); + return IntegerType::get(context, 32); } /// Mapping between SPIR-V storage classes to memref memory spaces. @@ -260,8 +260,8 @@ auto intType = type.cast(); LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); - return IntegerType::get(/*width=*/32, intType.getSignedness(), - targetEnv.getContext()); + return IntegerType::get(targetEnv.getContext(), /*width=*/32, + intType.getSignedness()); } /// Converts a vector `type` to a suitable type under the given `targetEnv`. diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -714,7 +714,7 @@ } FunctionType CallOp::getCalleeType() { - return FunctionType::get(getOperandTypes(), getResultTypes(), getContext()); + return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); } //===----------------------------------------------------------------------===// @@ -753,7 +753,7 @@ // Return the type of the same shape (scalar, vector or tensor) containing i1. static Type getI1SameShape(Type type) { - auto i1Type = IntegerType::get(1, type.getContext()); + auto i1Type = IntegerType::get(type.getContext(), 1); if (auto tensorType = type.dyn_cast()) return RankedTensorType::get(tensorType.getShape(), i1Type); if (type.isa()) @@ -914,7 +914,7 @@ return {}; auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); - return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); + return IntegerAttr::get(IntegerType::get(getContext(), 1), APInt(1, val)); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1426,7 +1426,7 @@ static ArrayAttr makeI64ArrayAttr(ArrayRef values, MLIRContext *context) { auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { - return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v)); + return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v)); }); return ArrayAttr::get(llvm::to_vector<8>(attrs), context); } @@ -2767,7 +2767,7 @@ parser.parseOptionalAttrDict(result.attributes) || parser.parseColonTypeList(types) || parser.resolveOperands(operandInfos, types, loc, result.operands) || - parser.addTypeToList(TupleType::get(types, ctx), result.types)); + parser.addTypeToList(TupleType::get(ctx, types), result.types)); } static void print(OpAsmPrinter &p, TupleOp op) { diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -215,7 +215,7 @@ // Create Vector type and add to 'vectorTypes[i]'. vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType()); } - return TupleType::get(vectorTypes, builder.getContext()); + return builder.getTupleType(vectorTypes); } // UnrolledVectorState aggregates per-operand/result vector state required for diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -52,27 +52,27 @@ IndexType Builder::getIndexType() { return IndexType::get(context); } -IntegerType Builder::getI1Type() { return IntegerType::get(1, context); } +IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); } -IntegerType Builder::getI32Type() { return IntegerType::get(32, context); } +IntegerType Builder::getI32Type() { return IntegerType::get(context, 32); } -IntegerType Builder::getI64Type() { return IntegerType::get(64, context); } +IntegerType Builder::getI64Type() { return IntegerType::get(context, 64); } IntegerType Builder::getIntegerType(unsigned width) { - return IntegerType::get(width, context); + return IntegerType::get(context, width); } IntegerType Builder::getIntegerType(unsigned width, bool isSigned) { return IntegerType::get( - width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context); + context, width, isSigned ? IntegerType::Signed : IntegerType::Unsigned); } FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) { - return FunctionType::get(inputs, results, context); + return FunctionType::get(context, inputs, results); } TupleType Builder::getTupleType(TypeRange elementTypes) { - return TupleType::get(elementTypes, context); + return TupleType::get(context, elementTypes); } NoneType Builder::getNoneType() { return NoneType::get(context); } diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -179,7 +179,7 @@ for (unsigned i = 0, e = getNumArguments(); i != e; ++i) if (!mapper.contains(getArgument(i))) inputTypes.push_back(newType.getInput(i)); - newType = FunctionType::get(inputTypes, newType.getResults(), getContext()); + newType = FunctionType::get(getContext(), inputTypes, newType.getResults()); } // Create the new function. diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -35,7 +35,7 @@ return Base::get(elementType.getContext(), elementType); } -ComplexType ComplexType::getChecked(Type elementType, Location location) { +ComplexType ComplexType::getChecked(Location location, Type elementType) { return Base::getChecked(location, elementType); } @@ -76,7 +76,7 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { if (!scale) return IntegerType(); - return IntegerType::get(scale * getWidth(), getSignedness(), getContext()); + return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); } //===----------------------------------------------------------------------===// @@ -126,8 +126,8 @@ // FunctionType //===----------------------------------------------------------------------===// -FunctionType FunctionType::get(TypeRange inputs, TypeRange results, - MLIRContext *context) { +FunctionType FunctionType::get(MLIRContext *context, TypeRange inputs, + TypeRange results) { return Base::get(context, inputs, results); } @@ -182,20 +182,20 @@ newResultTypes = newResultTypesBuffer; } - return get(newInputTypes, newResultTypes, getContext()); + return get(getContext(), newInputTypes, newResultTypes); } //===----------------------------------------------------------------------===// // OpaqueType //===----------------------------------------------------------------------===// -OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData, - MLIRContext *context) { +OpaqueType OpaqueType::get(MLIRContext *context, Identifier dialect, + StringRef typeData) { return Base::get(context, dialect, typeData); } -OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData, - MLIRContext *context, Location location) { +OpaqueType OpaqueType::getChecked(Location location, Identifier dialect, + StringRef typeData) { return Base::getChecked(location, dialect, typeData); } @@ -313,8 +313,8 @@ return Base::get(elementType.getContext(), shape, elementType); } -VectorType VectorType::getChecked(ArrayRef shape, Type elementType, - Location location) { +VectorType VectorType::getChecked(Location location, ArrayRef shape, + Type elementType) { return Base::getChecked(location, shape, elementType); } @@ -379,9 +379,9 @@ return Base::get(elementType.getContext(), shape, elementType); } -RankedTensorType RankedTensorType::getChecked(ArrayRef shape, - Type elementType, - Location location) { +RankedTensorType RankedTensorType::getChecked(Location location, + ArrayRef shape, + Type elementType) { return Base::getChecked(location, shape, elementType); } @@ -406,8 +406,8 @@ return Base::get(elementType.getContext(), elementType); } -UnrankedTensorType UnrankedTensorType::getChecked(Type elementType, - Location location) { +UnrankedTensorType UnrankedTensorType::getChecked(Location location, + Type elementType) { return Base::getChecked(location, elementType); } @@ -448,9 +448,10 @@ /// UnknownLoc. If the MemRefType defined by the arguments would be /// ill-formed, emits errors (to the handler registered with the context or to /// the error stream) and returns nullptr. -MemRefType MemRefType::getChecked(ArrayRef shape, Type elementType, +MemRefType MemRefType::getChecked(Location location, ArrayRef shape, + Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, Location location) { + unsigned memorySpace) { return getImpl(shape, elementType, affineMapComposition, memorySpace, location); } @@ -524,9 +525,9 @@ return Base::get(elementType.getContext(), elementType, memorySpace); } -UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType, - unsigned memorySpace, - Location location) { +UnrankedMemRefType UnrankedMemRefType::getChecked(Location location, + Type elementType, + unsigned memorySpace) { return Base::getChecked(location, elementType, memorySpace); } @@ -694,12 +695,12 @@ /// Get or create a new TupleType with the provided element types. Assumes the /// arguments define a well-formed type. -TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) { +TupleType TupleType::get(MLIRContext *context, TypeRange elementTypes) { return Base::get(context, elementTypes); } /// Get or create an empty tuple type. -TupleType TupleType::get(MLIRContext *context) { return get({}, context); } +TupleType TupleType::get(MLIRContext *context) { return get(context, {}); } /// Return the elements types for this tuple. ArrayRef TupleType::getTypes() const { return getImpl()->getTypes(); } diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -82,7 +82,7 @@ // If this dialect allows unknown types, then represent this with OpaqueType. if (allowsUnknownTypes()) { auto ns = Identifier::get(getNamespace(), getContext()); - return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext()); + return OpaqueType::get(getContext(), ns, parser.getFullSymbolSpec()); } parser.emitError(parser.getNameLoc()) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -772,25 +772,23 @@ } } -IntegerType IntegerType::get(unsigned width, MLIRContext *context) { - return get(width, IntegerType::Signless, context); +IntegerType IntegerType::get(MLIRContext *context, unsigned width) { + return get(context, width, IntegerType::Signless); } -IntegerType IntegerType::get(unsigned width, - IntegerType::SignednessSemantics signedness, - MLIRContext *context) { +IntegerType IntegerType::get(MLIRContext *context, unsigned width, + IntegerType::SignednessSemantics signedness) { if (auto cached = getCachedIntegerType(width, signedness, context)) return cached; return Base::get(context, width, signedness); } -IntegerType IntegerType::getChecked(unsigned width, Location location) { - return getChecked(width, IntegerType::Signless, location); +IntegerType IntegerType::getChecked(Location location, unsigned width) { + return getChecked(location, width, IntegerType::Signless); } -IntegerType IntegerType::getChecked(unsigned width, - SignednessSemantics signedness, - Location location) { +IntegerType IntegerType::getChecked(Location location, unsigned width, + SignednessSemantics signedness) { if (auto cached = getCachedIntegerType(width, signedness, location->getContext())) return cached; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -178,7 +178,7 @@ if (hasSingleResult) resultType = resultTypes.front(); else - resultType = TupleType::get(resultTypes, location->getContext()); + resultType = TupleType::get(location->getContext(), resultTypes); } } diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -63,7 +63,7 @@ return; auto newTypes = llvm::to_vector<4>(curTypes); newTypes[resultNo] = newType; - owner->resultType = TupleType::get(newTypes, newType.getContext()); + owner->resultType = TupleType::get(newType.getContext(), newTypes); } /// If this value is the result of an Operation, return the operation that diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -563,8 +563,8 @@ // Otherwise, form a new opaque type. return OpaqueType::getChecked( - Identifier::get(dialectName, state.context), symbolData, - state.context, getEncodedSourceLocation(loc)); + getEncodedSourceLocation(loc), + Identifier::get(dialectName, state.context), symbolData); }); } diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -338,7 +338,7 @@ signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned; consumeToken(Token::inttype); - return IntegerType::get(width.getValue(), signSemantics, getContext()); + return IntegerType::get(getContext(), width.getValue(), signSemantics); } // float-type @@ -432,7 +432,7 @@ parseToken(Token::greater, "expected '>' in tuple type")) return nullptr; - return TupleType::get(types, getContext()); + return TupleType::get(getContext(), types); } /// Parse a vector type. diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -236,7 +236,7 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) { if (auto *ci = dyn_cast(value)) return b.getIntegerAttr( - IntegerType::get(ci->getType()->getBitWidth(), context), + IntegerType::get(context, ci->getType()->getBitWidth()), ci->getValue()); if (auto *c = dyn_cast(value)) if (c->isString()) diff --git a/mlir/lib/Target/SPIRV/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization.cpp --- a/mlir/lib/Target/SPIRV/Deserialization.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization.cpp @@ -1182,7 +1182,7 @@ // signless semantics for such cases. auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed : IntegerType::SignednessSemantics::Signless; - typeMap[operands[0]] = IntegerType::get(operands[1], sign, context); + typeMap[operands[0]] = IntegerType::get(context, operands[1], sign); } break; case spirv::Opcode::OpTypeFloat: { if (operands.size() != 2) @@ -1345,7 +1345,7 @@ if (!isVoidType(returnType)) { returnTypes = llvm::makeArrayRef(returnType); } - typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context); + typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes); return success(); } diff --git a/mlir/lib/Target/SPIRV/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization.cpp --- a/mlir/lib/Target/SPIRV/Serialization.cpp +++ b/mlir/lib/Target/SPIRV/Serialization.cpp @@ -1267,7 +1267,7 @@ } typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; auto getConstantOp = [&](uint32_t id) { - auto attr = IntegerAttr::get(IntegerType::get(32, type.getContext()), id); + auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); return prepareConstantInt(loc, attr); }; operands.push_back(elementTypeID); diff --git a/mlir/lib/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Transforms/BufferResultsToOutParams.cpp --- a/mlir/lib/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Transforms/BufferResultsToOutParams.cpp @@ -35,8 +35,8 @@ // Add the new arguments to the function type. auto newArgTypes = llvm::to_vector<6>( llvm::concat(functionType.getInputs(), erasedResultTypes)); - auto newFunctionType = FunctionType::get( - newArgTypes, functionType.getResults(), func.getContext()); + auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes, + functionType.getResults()); func.setType(newFunctionType); // Transfer the result attributes to arg attributes. diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp --- a/mlir/lib/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp @@ -230,9 +230,8 @@ // We create a new function type and modify the function signature with this // new type. - newFuncType = FunctionType::get(/*inputs=*/argTypes, - /*results=*/resultTypes, - /*context=*/&getContext()); + newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes, + /*results=*/resultTypes); } // Since we update the function signature, it might affect the result types at @@ -463,9 +462,9 @@ continue; } - FunctionType newFuncType = FunctionType::get(/*inputs=*/inputTypes, - /*results=*/resultTypes, - /*context=*/&getContext()); + FunctionType newFuncType = + FunctionType::get(&getContext(), /*inputs=*/inputTypes, + /*results=*/resultTypes); // Setting the new function signature for this external function. funcOp.setType(newFuncType); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2522,8 +2522,8 @@ // Update the function signature in-place. rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults, - funcOp.getContext())); + funcOp.setType(FunctionType::get(funcOp.getContext(), + result.getConvertedTypes(), newResults)); }); return success(); } diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -56,7 +56,7 @@ ArrayRef args = {}) { auto &ctx = globalContext(); auto function = FuncOp::create(UnknownLoc::get(&ctx), name, - FunctionType::get(args, results, &ctx)); + FunctionType::get(&ctx, args, results)); function.addEntryBlock(); return function; } @@ -277,7 +277,7 @@ TEST_FUNC(builder_cond_branch) { auto f = makeFunction("builder_cond_branch", {}, - {IntegerType::get(1, &globalContext())}); + {IntegerType::get(&globalContext(), 1)}); OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); @@ -390,8 +390,8 @@ TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) { using namespace edsc::op; - auto i1Type = IntegerType::get(1, &globalContext()); - auto i8Type = IntegerType::get(8, &globalContext()); + auto i1Type = IntegerType::get(&globalContext(), 1); + auto i8Type = IntegerType::get(&globalContext(), 8); auto memrefType = MemRefType::get({}, i1Type, {}, 0); auto f = makeFunction("zero_and_std_sign_extendi_op", {}, {memrefType, memrefType}); @@ -414,7 +414,7 @@ } TEST_FUNC(operator_or) { - auto i1Type = IntegerType::get(/*width=*/1, &globalContext()); + auto i1Type = IntegerType::get(&globalContext(), /*width=*/1); auto f = makeFunction("operator_or", {}, {i1Type, i1Type}); OpBuilder builder(f.getBody()); @@ -435,7 +435,7 @@ } TEST_FUNC(operator_and) { - auto i1Type = IntegerType::get(/*width=*/1, &globalContext()); + auto i1Type = IntegerType::get(&globalContext(), /*width=*/1); auto f = makeFunction("operator_and", {}, {i1Type, i1Type}); OpBuilder builder(f.getBody()); @@ -536,7 +536,7 @@ TEST_FUNC(select_op_i32) { using namespace edsc::op; - auto i32Type = IntegerType::get(32, &globalContext()); + auto i32Type = IntegerType::get(&globalContext(), 32); auto memrefType = MemRefType::get( {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, i32Type, {}, 0); auto f = makeFunction("select_op", {}, {memrefType}); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -653,7 +653,7 @@ } int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; - auto type = IntegerType::get(17, context); + auto type = IntegerType::get(context, 17); inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); return success(); } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -509,7 +509,7 @@ // Convert I42 to I43. if (t.isInteger(42)) { - results.push_back(IntegerType::get(43, t.getContext())); + results.push_back(IntegerType::get(t.getContext(), 43)); return success(); } diff --git a/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp --- a/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp @@ -69,9 +69,7 @@ Location loc) -> Optional { if (inputs.size() == 1) return llvm::None; - TypeRange TypeRange = inputs.getTypes(); - SmallVector types(TypeRange.begin(), TypeRange.end()); - TupleType tuple = TupleType::get(types, builder.getContext()); + TupleType tuple = builder.getTupleType(inputs.getTypes()); Value value = builder.create(loc, tuple, inputs); return value; }); diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp --- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp +++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp @@ -59,7 +59,7 @@ } else { tensorType = RankedTensorType::get(shape, eleType); } - auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx)); + auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(ctx, 64)); auto indices = DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); auto valuesType = RankedTensorType::get({1}, eleType); @@ -77,7 +77,7 @@ TEST(QuantizationUtilsTest, convertFloatAttrUniform) { MLIRContext ctx; ctx.getOrLoadDialect(); - IntegerType convertedType = IntegerType::get(8, &ctx); + IntegerType convertedType = IntegerType::get(&ctx, 8); auto quantizedType = getTestQuantizedType(convertedType, &ctx); TestUniformQuantizedValueConverter converter(quantizedType); @@ -95,7 +95,7 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) { MLIRContext ctx; ctx.getOrLoadDialect(); - IntegerType convertedType = IntegerType::get(8, &ctx); + IntegerType convertedType = IntegerType::get(&ctx, 8); auto quantizedType = getTestQuantizedType(convertedType, &ctx); TestUniformQuantizedValueConverter converter(quantizedType); auto realValue = getTestElementsAttr>( @@ -120,7 +120,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) { MLIRContext ctx; ctx.getOrLoadDialect(); - IntegerType convertedType = IntegerType::get(8, &ctx); + IntegerType convertedType = IntegerType::get(&ctx, 8); auto quantizedType = getTestQuantizedType(convertedType, &ctx); TestUniformQuantizedValueConverter converter(quantizedType); auto realValue = getTestElementsAttr( @@ -145,7 +145,7 @@ TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) { MLIRContext ctx; ctx.getOrLoadDialect(); - IntegerType convertedType = IntegerType::get(8, &ctx); + IntegerType convertedType = IntegerType::get(&ctx, 8); auto quantizedType = getTestQuantizedType(convertedType, &ctx); TestUniformQuantizedValueConverter converter(quantizedType); auto realValue = getTestSparseElementsAttr(&ctx, {1, 2}); diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -33,7 +33,7 @@ namespace { TEST(DenseSplatTest, BoolSplat) { MLIRContext context; - IntegerType boolTy = IntegerType::get(1, &context); + IntegerType boolTy = IntegerType::get(&context, 1); RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); // Check that splat is automatically detected for boolean values. @@ -58,7 +58,7 @@ constexpr int64_t boolCount = 56; MLIRContext context; - IntegerType boolTy = IntegerType::get(1, &context); + IntegerType boolTy = IntegerType::get(&context, 1); RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); // Check that splat is automatically detected for boolean values. @@ -81,7 +81,7 @@ TEST(DenseSplatTest, BoolNonSplat) { MLIRContext context; - IntegerType boolTy = IntegerType::get(1, &context); + IntegerType boolTy = IntegerType::get(&context, 1); RankedTensorType shape = RankedTensorType::get({6}, boolTy); // Check that we properly handle non-splat values. @@ -94,7 +94,7 @@ // Test detecting a splat with an odd(non 8-bit) integer bitwidth. MLIRContext context; constexpr size_t intWidth = 19; - IntegerType intTy = IntegerType::get(intWidth, &context); + IntegerType intTy = IntegerType::get(&context, intWidth); APInt value(intWidth, 10); testSplat(intTy, value); @@ -102,7 +102,7 @@ TEST(DenseSplatTest, Int32Splat) { MLIRContext context; - IntegerType intTy = IntegerType::get(32, &context); + IntegerType intTy = IntegerType::get(&context, 32); int value = 64; testSplat(intTy, value); @@ -110,7 +110,7 @@ TEST(DenseSplatTest, IntAttrSplat) { MLIRContext context; - IntegerType intTy = IntegerType::get(85, &context); + IntegerType intTy = IntegerType::get(&context, 85); Attribute value = IntegerAttr::get(intTy, 109); testSplat(intTy, value); @@ -151,7 +151,7 @@ TEST(DenseSplatTest, StringSplat) { MLIRContext context; Type stringType = - OpaqueType::get(Identifier::get("test", &context), "string", &context); + OpaqueType::get(&context, Identifier::get("test", &context), "string"); StringRef value = "test-string"; testSplat(stringType, value); } @@ -159,7 +159,7 @@ TEST(DenseSplatTest, StringAttrSplat) { MLIRContext context; Type stringType = - OpaqueType::get(Identifier::get("test", &context), "string", &context); + OpaqueType::get(&context, Identifier::get("test", &context), "string"); Attribute stringAttr = StringAttr::get("test-string", stringType); testSplat(stringType, stringAttr); } @@ -173,7 +173,7 @@ TEST(DenseComplexTest, ComplexIntSplat) { MLIRContext context; - ComplexType complexType = ComplexType::get(IntegerType::get(64, &context)); + ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); std::complex value(10, 15); testSplat(complexType, value); } @@ -187,7 +187,7 @@ TEST(DenseComplexTest, ComplexAPIntSplat) { MLIRContext context; - ComplexType complexType = ComplexType::get(IntegerType::get(64, &context)); + ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); std::complex value(APInt(64, 10), APInt(64, 15)); testSplat(complexType, value); } diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp --- a/mlir/unittests/TableGen/StructsGenTest.cpp +++ b/mlir/unittests/TableGen/StructsGenTest.cpp @@ -25,7 +25,7 @@ /// Helper that returns an example test::TestStruct for testing its /// implementation. static test::TestStruct getTestStruct(mlir::MLIRContext *context) { - auto integerType = mlir::IntegerType::get(32, context); + auto integerType = mlir::IntegerType::get(context, 32); auto integerAttr = mlir::IntegerAttr::get(integerType, 127); auto floatType = mlir::FloatType::getF32(context); @@ -105,7 +105,7 @@ expectedValues.begin(), expectedValues.end() - 1); // Add a copy of the last attribute with the wrong type. - auto i64Type = mlir::IntegerType::get(64, &context); + auto i64Type = mlir::IntegerType::get(&context, 64); auto elementsType = mlir::RankedTensorType::get({3}, i64Type); auto elementsAttr = mlir::DenseIntElementsAttr::get(elementsType, ArrayRef{1, 2, 3});