diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -265,27 +265,27 @@ static constexpr llvm::StringRef inType() { return "in_type"; } static constexpr llvm::StringRef lenpName() { return "len_param_count"; } mlir::Type getAllocatedType(); - + bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; } - + unsigned numLenParams() { if (auto val = (*this)->getAttrOfType(lenpName())) return val.getInt(); return 0; } - + operand_range getLenParams() { return {operand_begin(), operand_begin() + numLenParams()}; } - + unsigned numShapeOperands() { return operand_end() - operand_begin() + numLenParams(); } - + operand_range getShapeOperands() { return {operand_begin() + numLenParams(), operand_end()}; } - + static mlir::Type getRefTy(mlir::Type ty); /// Get the input type of the allocation @@ -1129,7 +1129,7 @@ }]; let arguments = (ins AnyReferenceLike:$memref, AnyIntegerLike:$len); - + let results = (outs fir_BoxCharType); let assemblyFormat = [{ @@ -1561,7 +1561,7 @@ p.printFunctionalType((*this)->getOperandTypes(), (*this)->getResultTypes()); }]; - + let verifier = [{ auto refTy = ref().getType(); if (fir::isa_ref_type(refTy)) { @@ -1596,7 +1596,7 @@ CArg<"ArrayRef", "{}">:$attrs)>, OpBuilderDAG<(ins "Type":$type, "ValueRange":$operands, CArg<"ArrayRef", "{}">:$attrs)>]; - + let extraClassDeclaration = [{ static constexpr llvm::StringRef baseType() { return "base_type"; } mlir::Type getBaseType(); @@ -1684,7 +1684,7 @@ let printer = [{ p << getOperationName() << ' ' - << (*this)->getAttrOfType(fieldAttrName()).getValue() + << (*this)->getAttrOfType(fieldAttrName()).getValue() << ", " << (*this)->getAttr(typeAttrName()); if (getNumOperands()) { p << '('; @@ -2005,7 +2005,7 @@ CArg<"ValueRange", "llvm::None">:$iterArgs, CArg<"ArrayRef", "{}">:$attributes)> ]; - + let extraClassDeclaration = [{ mlir::Block *getBody() { return ®ion().front(); } mlir::Value getIterateVar() { return getBody()->getArgument(1); } @@ -2274,11 +2274,11 @@ }]; let arguments = (ins FirRealAttr:$constant); - + let results = (outs fir_RealType:$res); let assemblyFormat = "`(` $constant `)` attr-dict `:` type($res)"; - + let verifier = [{ if (!getType().isa()) return emitOpError("must be a !fir.real type"); @@ -2355,7 +2355,7 @@ }]; let results = (outs fir_ComplexType); - + let parser = [{ fir::RealAttr realp; fir::RealAttr imagp; @@ -2453,7 +2453,7 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> { let summary = "convert a symbol to an SSA value"; - + let description = [{ Convert a symbol (a function or global reference) to an SSA-value to be used in other Operations. @@ -2472,7 +2472,7 @@ def fir_ConvertOp : fir_OneResultOp<"convert", [NoSideEffect]> { let summary = "encapsulates all Fortran scalar type conversions"; - + let description = [{ Generalized type conversion. Convert the ssa value from type T to type U. Not all pairs of types have conversions. When types T and U are the same @@ -2703,7 +2703,7 @@ mlir::Type resultType() { return fir::AllocaOp::wrapResultType(getType()); } - + /// Return the initializer attribute if it exists, or a null attribute. Attribute getValueOrNull() { return initVal().getValueOr(Attribute()); } @@ -2726,9 +2726,9 @@ } mlir::FlatSymbolRefAttr getSymbol() { - return mlir::FlatSymbolRefAttr::get( + return mlir::FlatSymbolRefAttr::get(getContext(), (*this)->getAttrOfType( - mlir::SymbolTable::getSymbolAttrName()).getValue(), getContext()); + mlir::SymbolTable::getSymbolAttrName()).getValue()); } }]; } @@ -2770,7 +2770,7 @@ }]; let printer = [{ - p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName()) + p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName()) << ", " << (*this)->getAttr(intAttrName()); }]; diff --git a/flang/lib/Lower/FIRBuilder.cpp b/flang/lib/Lower/FIRBuilder.cpp --- a/flang/lib/Lower/FIRBuilder.cpp +++ b/flang/lib/Lower/FIRBuilder.cpp @@ -173,7 +173,7 @@ fir::StringLitOp Fortran::lower::FirOpBuilder::createStringLit( mlir::Location loc, mlir::Type eleTy, llvm::StringRef data) { - auto strAttr = mlir::StringAttr::get(data, getContext()); + auto strAttr = mlir::StringAttr::get(getContext(), data); auto valTag = mlir::Identifier::get(fir::StringLitOp::value(), getContext()); mlir::NamedAttribute dataAttr(valTag, strAttr); auto sizeTag = mlir::Identifier::get(fir::StringLitOp::size(), getContext()); diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -107,7 +107,7 @@ ModuleOp module) { auto *context = module.getContext(); if (module.lookupSymbol("printf")) - return SymbolRefAttr::get("printf", context); + return SymbolRefAttr::get(context, "printf"); // Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` @@ -120,7 +120,7 @@ PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); rewriter.create(module.getLoc(), "printf", llvmFnType); - return SymbolRefAttr::get("printf", context); + return SymbolRefAttr::get(context, "printf"); } /// Return a value representing an access into a global string with the given diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -107,7 +107,7 @@ ModuleOp module) { auto *context = module.getContext(); if (module.lookupSymbol("printf")) - return SymbolRefAttr::get("printf", context); + return SymbolRefAttr::get(context, "printf"); // Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` @@ -120,7 +120,7 @@ PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); rewriter.create(module.getLoc(), "printf", llvmFnType); - return SymbolRefAttr::get("printf", context); + return SymbolRefAttr::get(context, "printf"); } /// Return a value representing an access into a global string with the given diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -31,7 +31,7 @@ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context)); auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context)); auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context)); - auto maps = ArrayAttr::get({mapA, mapB, mapC}, context); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); return indexingMaps == maps; } @@ -42,7 +42,7 @@ auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context)); auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context)); auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context)); - auto maps = ArrayAttr::get({mapA, mapB, mapC}, context); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); return indexingMaps == maps; } diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -69,7 +69,7 @@ using Base::Base; using ValueType = ArrayRef; - static ArrayAttr get(ArrayRef value, MLIRContext *context); + static ArrayAttr get(MLIRContext *context, ArrayRef value); ArrayRef getValue() const; Attribute operator[](unsigned idx) const; @@ -126,8 +126,8 @@ /// attributes. This method assumes that the provided list is unordered. If /// the caller can guarantee that the attributes are ordered by name, /// getWithSorted should be used instead. - static DictionaryAttr get(ArrayRef value, - MLIRContext *context); + static DictionaryAttr get(MLIRContext *context, + ArrayRef value); /// Construct a dictionary with an array of values that is known to already be /// sorted by name and uniqued. @@ -250,7 +250,7 @@ using Attribute::Attribute; using ValueType = bool; - static BoolAttr get(bool value, MLIRContext *context); + static BoolAttr get(MLIRContext *context, bool value); /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to /// avoid bringing in all of IntegerAttrs methods. @@ -292,8 +292,8 @@ using Base::Base; /// Get or create a new OpaqueAttr with the provided dialect and string data. - static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type, - MLIRContext *context); + static OpaqueAttr get(MLIRContext *context, Identifier dialect, + StringRef attrData, Type type); /// Get or create a new OpaqueAttr with the provided dialect and string data. /// If the given identifier is not a valid namespace for a dialect, then a @@ -325,7 +325,7 @@ using ValueType = StringRef; /// Get an instance of a StringAttr with the given string. - static StringAttr get(StringRef bytes, MLIRContext *context); + static StringAttr get(MLIRContext *context, StringRef bytes); /// Get an instance of a StringAttr with the given string and Type. static StringAttr get(StringRef bytes, Type type); @@ -348,13 +348,12 @@ using Base::Base; /// Construct a symbol reference for the given value name. - static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx); + static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value); /// Construct a symbol reference for the given value name, and a set of nested /// references that are further resolve to a nested symbol. - static SymbolRefAttr get(StringRef value, - ArrayRef references, - MLIRContext *ctx); + static SymbolRefAttr get(MLIRContext *ctx, StringRef value, + ArrayRef references); /// Returns the name of the top level symbol reference, i.e. the root of the /// reference path. @@ -377,8 +376,8 @@ using ValueType = StringRef; /// Construct a symbol reference for the given value name. - static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) { - return SymbolRefAttr::get(value, ctx); + static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) { + return SymbolRefAttr::get(ctx, value); } /// Returns the name of the held symbol reference. diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -569,7 +569,7 @@ if (attributes.empty()) return (void)static_cast(this)->removeAttr(nameOut); Operation *op = this->getOperation(); - op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext())); + op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes)); } template diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -315,7 +315,7 @@ attrs = newAttrs; } void setAttrs(ArrayRef newAttrs) { - setAttrs(DictionaryAttr::get(newAttrs, getContext())); + setAttrs(DictionaryAttr::get(getContext(), newAttrs)); } /// Return the specified attribute if present, null otherwise. diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -44,7 +44,7 @@ /*defaultImplementation=*/[{ this->getOperation()->setAttr( mlir::SymbolTable::getSymbolAttrName(), - StringAttr::get(name, this->getOperation()->getContext())); + StringAttr::get(this->getOperation()->getContext(), name)); }] >, InterfaceMethod<"Gets the visibility of this symbol.", diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -42,9 +42,9 @@ MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, MlirAttribute const *elements) { SmallVector attrs; - return wrap(ArrayAttr::get( - unwrapList(static_cast(numElements), elements, attrs), - unwrap(ctx))); + return wrap( + ArrayAttr::get(unwrap(ctx), unwrapList(static_cast(numElements), + elements, attrs))); } intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { @@ -71,7 +71,7 @@ attributes.emplace_back( Identifier::get(unwrap(elements[i].name), unwrap(ctx)), unwrap(elements[i].attribute)); - return wrap(DictionaryAttr::get(attributes, unwrap(ctx))); + return wrap(DictionaryAttr::get(unwrap(ctx), attributes)); } intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { @@ -137,7 +137,7 @@ } MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { - return wrap(BoolAttr::get(value, unwrap(ctx))); + return wrap(BoolAttr::get(unwrap(ctx), value)); } bool mlirBoolAttrGetValue(MlirAttribute attr) { @@ -163,9 +163,9 @@ MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, intptr_t dataLength, const char *data, MlirType type) { - return wrap( - OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), - StringRef(data, dataLength), unwrap(type), unwrap(ctx))); + return wrap(OpaqueAttr::get( + unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), + StringRef(data, dataLength), unwrap(type))); } MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { @@ -185,7 +185,7 @@ } MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { - return wrap(StringAttr::get(unwrap(str), unwrap(ctx))); + return wrap(StringAttr::get(unwrap(ctx), unwrap(str))); } MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { @@ -211,7 +211,7 @@ refs.reserve(numReferences); for (intptr_t i = 0; i < numReferences; ++i) refs.push_back(unwrap(references[i]).cast()); - return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx))); + return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs)); } MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { @@ -241,7 +241,7 @@ } MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { - return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx))); + return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol))); } MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { diff --git a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp @@ -148,7 +148,7 @@ auto blob = convertModuleToBlob(llvmModule, loc, name); if (!blob) return {}; - return StringAttr::get({blob->data(), blob->size()}, loc->getContext()); + return StringAttr::get(loc->getContext(), {blob->data(), blob->size()}); } std::unique_ptr> diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -177,12 +177,12 @@ // Set SPIR-V binary shader data as an attribute. vulkanLaunchCallOp->setAttr( kSPIRVBlobAttrName, - StringAttr::get({binary.data(), binary.size()}, loc->getContext())); + StringAttr::get(loc->getContext(), {binary.data(), binary.size()})); // Set entry point name as an attribute. vulkanLaunchCallOp->setAttr( kSPIRVEntryPointAttrName, - StringAttr::get(launchOp.getKernelName(), loc->getContext())); + StringAttr::get(loc->getContext(), launchOp.getKernelName())); launchOp.erase(); } diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -687,8 +687,8 @@ rewriter.create(loc, llvmI32Type, executionModeAttr); structValue = rewriter.create( loc, structType, structValue, executionMode, - ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}, - context)); + ArrayAttr::get(context, + {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)})); // Insert extra operands if they exist into execution mode info struct. for (unsigned i = 0, e = values.size(); i < e; ++i) { @@ -696,9 +696,9 @@ Value entry = rewriter.create(loc, llvmI32Type, attr); structValue = rewriter.create( loc, structType, structValue, entry, - ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1), - rewriter.getIntegerAttr(rewriter.getI32Type(), i)}, - context)); + ArrayAttr::get(context, + {rewriter.getIntegerAttr(rewriter.getI32Type(), 1), + rewriter.getIntegerAttr(rewriter.getI32Type(), i)})); } rewriter.create(loc, ArrayRef({structValue})); rewriter.eraseOp(op); @@ -1297,17 +1297,17 @@ switch (funcOp.function_control()) { #define DISPATCH(functionControl, llvmAttr) \ case functionControl: \ - newFuncOp->setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \ + newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \ break; DISPATCH(spirv::FunctionControl::Inline, - StringAttr::get("alwaysinline", context)); + StringAttr::get(context, "alwaysinline")); DISPATCH(spirv::FunctionControl::DontInline, - StringAttr::get("noinline", context)); + StringAttr::get(context, "noinline")); DISPATCH(spirv::FunctionControl::Pure, - StringAttr::get("readonly", context)); + StringAttr::get(context, "readonly")); DISPATCH(spirv::FunctionControl::Const, - StringAttr::get("readnone", context)); + StringAttr::get(context, "readnone")); #undef DISPATCH diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -4016,7 +4016,7 @@ if (failed(applyPartialConversion(m, target, std::move(patterns)))) signalPassFailure(); m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), - StringAttr::get(this->dataLayout, m.getContext())); + StringAttr::get(m.getContext(), this->dataLayout)); } }; } // end namespace diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -762,7 +762,7 @@ if (positionAttrs.size() > 1) { auto oneDVectorType = reducedVectorTypeBack(vectorType); auto nMinusOnePositionAttrs = - ArrayAttr::get(positionAttrs.drop_back(), context); + ArrayAttr::get(context, positionAttrs.drop_back()); extracted = rewriter.create( loc, typeConverter->convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); @@ -871,7 +871,7 @@ if (positionAttrs.size() > 1) { oneDVectorType = reducedVectorTypeBack(destVectorType); auto nMinusOnePositionAttrs = - ArrayAttr::get(positionAttrs.drop_back(), context); + ArrayAttr::get(context, positionAttrs.drop_back()); extracted = rewriter.create( loc, typeConverter->convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); @@ -887,7 +887,7 @@ // Potential insertion of resulting 1-D vector into array. if (positionAttrs.size() > 1) { auto nMinusOnePositionAttrs = - ArrayAttr::get(positionAttrs.drop_back(), context); + ArrayAttr::get(context, positionAttrs.drop_back()); inserted = rewriter.create(loc, llvmResultType, adaptor.dest(), inserted, nMinusOnePositionAttrs); diff --git a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp --- a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp @@ -53,7 +53,7 @@ } ArrayRef mappingAsAttrs(mapping.data(), mapping.size()); ploopOp->setAttr(getMappingAttrName(), - ArrayAttr::get(mappingAsAttrs, ploopOp.getContext())); + ArrayAttr::get(ploopOp.getContext(), mappingAsAttrs)); return success(); } } // namespace gpu diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -225,7 +225,7 @@ if (genericAttrNamesSet.count(attr.first.strref()) > 0) genericAttrs.push_back(attr); if (!genericAttrs.empty()) { - auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext()); + auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs); p << genericDictAttr; } @@ -833,7 +833,7 @@ // Handle the corner case of the result being a rank 0 shaped type. Return an // emtpy ArrayAttr. if (mapsConsumer.empty() && !mapsProducer.empty()) - return ArrayAttr::get(ArrayRef(), context); + return ArrayAttr::get(context, ArrayRef()); if (mapsProducer.empty() || mapsConsumer.empty() || mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() || mapsProducer.size() != mapsConsumer[0].getNumDims()) @@ -854,7 +854,7 @@ numLhsDims, /*numSymbols =*/0, reassociations, context))); reassociations.clear(); } - return ArrayAttr::get(reassociationMaps, context); + return ArrayAttr::get(context, reassociationMaps); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -137,11 +137,11 @@ // wrong, so abort. if (!inversePermutation(concatAffineMaps(newIndexingMaps))) return nullptr; - return ArrayAttr::get( - llvm::to_vector<4>(llvm::map_range( - newIndexingMaps, - [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })), - context); + return ArrayAttr::get(context, + llvm::to_vector<4>(llvm::map_range( + newIndexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + }))); } /// Modify the region of indexed generic op to drop arguments corresponding to @@ -220,7 +220,7 @@ rewriter.startRootUpdate(op); op.indexing_mapsAttr(newIndexingMapAttr); - op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context)); + op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes)); (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter); rewriter.finalizeRootUpdate(op); return success(); @@ -282,7 +282,7 @@ RankedTensorType::get(newShape, type.getElementType()), AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(), newIndexExprs, context), - ArrayAttr::get(reassociationMaps, context)}; + ArrayAttr::get(context, reassociationMaps)}; return info; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -77,9 +77,9 @@ applyPermutationToVector(itTypesVector, interchangeVector); op->setAttr(getIndexingMapsAttrName(), - ArrayAttr::get(newIndexingMaps, context)); + ArrayAttr::get(context, newIndexingMaps)); op->setAttr(getIteratorTypesAttrName(), - ArrayAttr::get(itTypesVector, context)); + ArrayAttr::get(context, itTypesVector)); return op; } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -98,7 +98,7 @@ }); for (auto &var : interfaceVarSet) { interfaceVars.push_back(SymbolRefAttr::get( - cast(var).sym_name(), funcOp.getContext())); + funcOp.getContext(), cast(var).sym_name())); } return success(); } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -338,7 +338,7 @@ return a; } // If this is reached, all inputs were statically known passing. - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); } static LogicalResult verify(AssumingAllOp op) { @@ -482,10 +482,10 @@ // Both operands are not needed if one is a scalar. if (operands[0] && operands[0].cast().getNumElements() == 0) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); if (operands[1] && operands[1].cast().getNumElements() == 0) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); if (operands[0] && operands[1]) { auto lhsShape = llvm::to_vector<6>( @@ -494,7 +494,7 @@ operands[1].cast().getValues()); SmallVector resultShape; if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); } // Lastly, see if folding can be completed based on what constraints are known @@ -506,7 +506,7 @@ return nullptr; if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); // Because a failing witness result here represents an eventual assertion // failure, we do not replace it with a constant witness. @@ -526,7 +526,7 @@ OpFoldResult CstrEqOp::fold(ArrayRef operands) { if (llvm::all_of(operands, [&](Attribute a) { return a && a == operands[0]; })) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); // Because a failing witness result here represents an eventual assertion // failure, we do not try to replace it with a constant witness. Similarly, we @@ -573,14 +573,14 @@ OpFoldResult ShapeEqOp::fold(ArrayRef operands) { if (lhs() == rhs()) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); auto lhs = operands[0].dyn_cast_or_null(); if (lhs == nullptr) return {}; auto rhs = operands[1].dyn_cast_or_null(); if (rhs == nullptr) return {}; - return BoolAttr::get(lhs == rhs, getContext()); + return BoolAttr::get(getContext(), lhs == rhs); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -844,7 +844,7 @@ if (lhs() == rhs()) { auto val = applyCmpPredicateToEqualOperands(getPredicate()); - return BoolAttr::get(val, getContext()); + return BoolAttr::get(getContext(), val); } auto lhs = operands.front().dyn_cast_or_null(); @@ -853,7 +853,7 @@ return {}; auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); - return BoolAttr::get(val, getContext()); + return BoolAttr::get(getContext(), val); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -246,7 +246,7 @@ if (traitAttrsSet.count(attr.first.strref()) > 0) attrs.push_back(attr); - auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); + auto dictAttr = DictionaryAttr::get(op.getContext(), attrs); p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", "; p << op.rhs() << ", " << op.acc(); if (op.masks().size() == 2) @@ -1444,7 +1444,7 @@ auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v)); }); - return ArrayAttr::get(llvm::to_vector<8>(attrs), context); + return ArrayAttr::get(context, llvm::to_vector<8>(attrs)); } static LogicalResult verify(InsertStridedSliceOp op) { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -92,11 +92,11 @@ UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); } BoolAttr Builder::getBoolAttr(bool value) { - return BoolAttr::get(value, context); + return BoolAttr::get(context, value); } DictionaryAttr Builder::getDictionaryAttr(ArrayRef value) { - return DictionaryAttr::get(value, context); + return DictionaryAttr::get(context, value); } IntegerAttr Builder::getIndexAttr(int64_t value) { @@ -200,11 +200,11 @@ } StringAttr Builder::getStringAttr(StringRef bytes) { - return StringAttr::get(bytes, context); + return StringAttr::get(context, bytes); } ArrayAttr Builder::getArrayAttr(ArrayRef value) { - return ArrayAttr::get(value, context); + return ArrayAttr::get(context, value); } FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { @@ -214,12 +214,12 @@ return getSymbolRefAttr(symName.getValue()); } FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) { - return SymbolRefAttr::get(value, getContext()); + return SymbolRefAttr::get(getContext(), value); } SymbolRefAttr Builder::getSymbolRefAttr(StringRef value, ArrayRef nestedReferences) { - return SymbolRefAttr::get(value, nestedReferences, getContext()); + return SymbolRefAttr::get(getContext(), value, nestedReferences); } ArrayAttr Builder::getBoolArrayAttr(ArrayRef values) { diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -35,7 +35,7 @@ // ArrayAttr //===----------------------------------------------------------------------===// -ArrayAttr ArrayAttr::get(ArrayRef value, MLIRContext *context) { +ArrayAttr ArrayAttr::get(MLIRContext *context, ArrayRef value) { return Base::get(context, value); } @@ -134,8 +134,8 @@ return findDuplicateElement(array); } -DictionaryAttr DictionaryAttr::get(ArrayRef value, - MLIRContext *context) { +DictionaryAttr DictionaryAttr::get(MLIRContext *context, + ArrayRef value) { if (value.empty()) return DictionaryAttr::getEmpty(context); assert(llvm::all_of(value, @@ -267,13 +267,12 @@ // SymbolRefAttr //===----------------------------------------------------------------------===// -FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { +FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { return Base::get(ctx, value, llvm::None).cast(); } -SymbolRefAttr SymbolRefAttr::get(StringRef value, - ArrayRef nestedReferences, - MLIRContext *ctx) { +SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, + ArrayRef nestedReferences) { return Base::get(ctx, value, nestedReferences); } @@ -294,7 +293,7 @@ IntegerAttr IntegerAttr::get(Type type, const APInt &value) { if (type.isSignlessInteger(1)) - return BoolAttr::get(value.getBoolValue(), type.getContext()); + return BoolAttr::get(type.getContext(), value.getBoolValue()); return Base::get(type.getContext(), type, value); } @@ -377,8 +376,8 @@ // OpaqueAttr //===----------------------------------------------------------------------===// -OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, - MLIRContext *context) { +OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect, + StringRef attrData, Type type) { return Base::get(context, dialect, attrData, type); } @@ -409,7 +408,7 @@ // StringAttr //===----------------------------------------------------------------------===// -StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { +StringAttr StringAttr::get(MLIRContext *context, StringRef bytes) { return get(bytes, NoneType::get(context)); } diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -166,7 +166,7 @@ newAttrs.insert(attr); for (auto &attr : getAttrs()) newAttrs.insert(attr); - dest->setAttrs(DictionaryAttr::get(newAttrs.takeVector(), getContext())); + dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector())); // Clone the body. getBody().cloneInto(&dest.getBody(), mapper); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -872,7 +872,7 @@ storage->setType(NoneType::get(ctx)); } -BoolAttr BoolAttr::get(bool value, MLIRContext *context) { +BoolAttr BoolAttr::get(MLIRContext *context, bool value) { return value ? context->getImpl().trueAttr : context->getImpl().falseAttr; } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -76,7 +76,7 @@ ArrayRef attributes, BlockRange successors, unsigned numRegions) { return create(location, name, resultTypes, operands, - DictionaryAttr::get(attributes, location.getContext()), + DictionaryAttr::get(location.getContext(), attributes), successors, numRegions); } diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -46,7 +46,7 @@ assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor"); MLIRContext *ctx = symbol->getContext(); - auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx); + auto leafRef = FlatSymbolRefAttr::get(ctx, symbolName); results.push_back(leafRef); // Early exit for when 'within' is the parent of 'symbol'. @@ -67,13 +67,13 @@ getNameIfSymbol(symbolTableOp, symbolNameId); if (!symbolTableName) return failure(); - results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx)); + results.push_back(SymbolRefAttr::get(ctx, *symbolTableName, nestedRefs)); symbolTableOp = symbolTableOp->getParentOp(); if (symbolTableOp == within) break; nestedRefs.insert(nestedRefs.begin(), - FlatSymbolRefAttr::get(*symbolTableName, ctx)); + FlatSymbolRefAttr::get(ctx, *symbolTableName)); } while (true); return success(); } @@ -203,7 +203,7 @@ /// Sets the name of the given symbol operation. void SymbolTable::setSymbolName(Operation *symbol, StringRef name) { symbol->setAttr(getSymbolAttrName(), - StringAttr::get(name, symbol->getContext())); + StringAttr::get(symbol->getContext(), name)); } /// Returns the visibility of the given symbol operation. @@ -235,7 +235,7 @@ "unknown symbol visibility kind"); StringRef visName = vis == Visibility::Private ? "private" : "nested"; - symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx)); + symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName)); } /// Returns the nearest symbol table from a given operation `from`. Returns @@ -603,7 +603,7 @@ // doesn't support parent references. if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) == symbol->getParentOp()) - return {{SymbolRefAttr::get(symName, symbol->getContext()), limit}}; + return {{SymbolRefAttr::get(symbol->getContext(), symName), limit}}; return {}; } @@ -659,7 +659,7 @@ template static SmallVector collectSymbolScopes(StringRef symbol, IRUnit *limit) { - return {{SymbolRefAttr::get(symbol, limit->getContext()), limit}}; + return {{SymbolRefAttr::get(limit->getContext(), symbol), limit}}; } /// Returns true if the given reference 'SubRef' is a sub reference of the @@ -825,11 +825,11 @@ if (auto dictAttr = container.dyn_cast()) { auto newAttrs = llvm::to_vector<4>(dictAttr.getValue()); updateAttrs(make_second_range(newAttrs)); - return DictionaryAttr::get(newAttrs, dictAttr.getContext()); + return DictionaryAttr::get(dictAttr.getContext(), newAttrs); } auto newAttrs = llvm::to_vector<4>(container.cast().getValue()); updateAttrs(newAttrs); - return ArrayAttr::get(newAttrs, container.getContext()); + return ArrayAttr::get(container.getContext(), newAttrs); } /// Generates a new symbol reference attribute with a new leaf reference. @@ -839,8 +839,8 @@ return newLeafAttr; auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); nestedRefs.back() = newLeafAttr; - return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs, - oldAttr.getContext()); + return SymbolRefAttr::get(oldAttr.getContext(), oldAttr.getRootReference(), + nestedRefs); } /// The implementation of SymbolTable::replaceAllSymbolUses below. @@ -867,7 +867,7 @@ // Generate a new attribute to replace the given attribute. MLIRContext *ctx = limit->getContext(); - FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx); + FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(ctx, newSymbol); for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr); auto walkFn = [&](SymbolTable::SymbolUse symbolUse, @@ -883,13 +883,13 @@ if (useRef != scope.symbol) { if (scope.symbol.isa()) { replacementRef = - SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx); + SymbolRefAttr::get(ctx, newSymbol, useRef.getNestedReferences()); } else { auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences()); nestedRefs[scope.symbol.getNestedReferences().size() - 1] = newLeafAttr; replacementRef = - SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx); + SymbolRefAttr::get(ctx, useRef.getRootReference(), nestedRefs); } } diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -148,7 +148,7 @@ return Attribute(); return type ? StringAttr::get(val, type) - : StringAttr::get(val, getContext()); + : StringAttr::get(getContext(), val); } // Parse a symbol reference attribute. @@ -176,7 +176,7 @@ std::string nameStr = getToken().getSymbolReference(); consumeToken(Token::at_identifier); - nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); + nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr)); } return builder.getSymbolRefAttr(nameStr, nestedRefs); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -742,7 +742,8 @@ body << " ::mlir::MLIRContext* ctx = getContext();\n"; body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n"; - body << " return ::mlir::DictionaryAttr::get({\n"; + body << " return ::mlir::DictionaryAttr::get("; + body << " ctx, {\n"; interleave( derivedAttrs, body, [&](const NamedAttribute &namedAttr) { @@ -755,7 +756,7 @@ << "}"; }, ",\n"); - body << "\n }, ctx);"; + body << "});"; } } } diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp --- a/mlir/tools/mlir-tblgen/StructsGen.cpp +++ b/mlir/tools/mlir-tblgen/StructsGen.cpp @@ -150,7 +150,7 @@ } const char *getEndInfo = R"( - ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(fields, context); + ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(context, fields); return dict.dyn_cast<{0}>(); } )"; diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp --- a/mlir/unittests/TableGen/StructsGenTest.cpp +++ b/mlir/unittests/TableGen/StructsGenTest.cpp @@ -67,7 +67,7 @@ newValues.push_back(wrongAttr); // Make a new DictionaryAttr and validate. - auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); } @@ -88,7 +88,7 @@ auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second); newValues.push_back(wrongAttr); - auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); } @@ -113,7 +113,7 @@ auto wrongAttr = mlir::NamedAttribute(id, elementsAttr); newValues.push_back(wrongAttr); - auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); } @@ -130,7 +130,7 @@ expectedValues.begin() + 1, expectedValues.end()); // Make a new DictionaryAttr and validate it is not a validate TestStruct. - auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); }