diff --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td --- a/mlir/examples/toy/Ch2/include/toy/Ops.td +++ b/mlir/examples/toy/Ch2/include/toy/Ops.td @@ -134,7 +134,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -218,8 +218,9 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td --- a/mlir/examples/toy/Ch3/include/toy/Ops.td +++ b/mlir/examples/toy/Ch3/include/toy/Ops.td @@ -133,7 +133,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -205,8 +205,9 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -163,7 +163,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -294,8 +294,9 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -163,7 +163,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -294,8 +294,9 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -163,7 +163,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -294,8 +294,9 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -186,7 +186,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -321,8 +321,9 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -140,7 +140,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, - OptionalAttr:$sym_visibility); + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td @@ -251,7 +251,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, - OptionalAttr:$sym_visibility); + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let regions = (region AnyRegion:$body); let builders = [OpBuilder<(ins diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -242,7 +242,9 @@ attribution. }]; - let arguments = (ins TypeAttrOf:$function_type); + let arguments = (ins TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let regions = (region AnyRegion:$body); let skipDefaultBuilders = 1; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1308,7 +1308,9 @@ DefaultValuedAttr:$CConv, OptionalAttr:$personality, OptionalAttr:$garbageCollector, - OptionalAttr:$passthrough + OptionalAttr:$passthrough, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td @@ -52,6 +52,8 @@ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$sym_visibility); let regions = (region AnyRegion:$body); @@ -401,6 +403,8 @@ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$sym_visibility); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -652,7 +652,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region MinSizedRegion<1>:$body); diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -291,6 +291,8 @@ let arguments = (ins TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, StrAttr:$sym_name, SPIRV_FunctionControlAttr:$function_control ); diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -1107,6 +1107,8 @@ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$sym_visibility); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -39,10 +39,12 @@ /// with special names given by getResultAttrName, getArgumentAttrName. void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef argAttrs, - ArrayRef resultAttrs); + ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName); void addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef argAttrs, - ArrayRef resultAttrs); + ArrayRef args, + ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName); /// Callback type for `parseFunctionOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of @@ -77,15 +79,17 @@ /// type, report the error or delegate the reporting to the op's verifier. ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, - FuncTypeBuilder funcTypeBuilder); + FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName); /// Printer implementation for function-like operations. void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, - StringRef typeAttrName); + StringRef typeAttrName, StringAttr argAttrsName, + StringAttr resAttrsName); /// Prints the signature of the function-like operation `op`. Assumes `op` has /// is a FunctionOpInterface and has passed verification. -void printFunctionSignature(OpAsmPrinter &p, Operation *op, +void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef argTypes, bool isVariadic, ArrayRef resultTypes); diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h --- a/mlir/include/mlir/IR/FunctionInterfaces.h +++ b/mlir/include/mlir/IR/FunctionInterfaces.h @@ -26,48 +26,30 @@ namespace function_interface_impl { -/// Return the name of the attribute used for function argument attributes. -inline StringRef getArgDictAttrName() { return "arg_attrs"; } - -/// Return the name of the attribute used for function argument attributes. -inline StringRef getResultDictAttrName() { return "res_attrs"; } - /// Returns the dictionary attribute corresponding to the argument at 'index'. /// If there are no argument attributes at 'index', a null attribute is /// returned. -DictionaryAttr getArgAttrDict(Operation *op, unsigned index); +DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index); /// Returns the dictionary attribute corresponding to the result at 'index'. /// If there are no result attributes at 'index', a null attribute is /// returned. -DictionaryAttr getResultAttrDict(Operation *op, unsigned index); +DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index); -namespace detail { -/// Update the given index into an argument or result attribute dictionary. -void setArgResAttrDict(Operation *op, StringRef attrName, - unsigned numTotalIndices, unsigned index, - DictionaryAttr attrs); -} // namespace detail +/// Return all of the attributes for the argument at 'index'. +ArrayRef getArgAttrs(FunctionOpInterface op, unsigned index); + +/// Return all of the attributes for the result at 'index'. +ArrayRef getResultAttrs(FunctionOpInterface op, unsigned index); /// Set all of the argument or result attribute dictionaries for a function. The /// size of `attrs` is expected to match the number of arguments/results of the /// given `op`. -void setAllArgAttrDicts(Operation *op, ArrayRef attrs); -void setAllArgAttrDicts(Operation *op, ArrayRef attrs); -void setAllResultAttrDicts(Operation *op, ArrayRef attrs); -void setAllResultAttrDicts(Operation *op, ArrayRef attrs); - -/// Return all of the attributes for the argument at 'index'. -inline ArrayRef getArgAttrs(Operation *op, unsigned index) { - auto argDict = getArgAttrDict(op, index); - return argDict ? argDict.getValue() : std::nullopt; -} - -/// Return all of the attributes for the result at 'index'. -inline ArrayRef getResultAttrs(Operation *op, unsigned index) { - auto resultDict = getResultAttrDict(op, index); - return resultDict ? resultDict.getValue() : std::nullopt; -} +void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef attrs); +void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef attrs); +void setAllResultAttrDicts(FunctionOpInterface op, + ArrayRef attrs); +void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef attrs); /// Insert the specified arguments and update the function type attribute. void insertFunctionArguments(FunctionOpInterface op, @@ -110,20 +92,10 @@ //===----------------------------------------------------------------------===// /// Set the attributes held by the argument at 'index'. -template -void setArgAttrs(ConcreteType op, unsigned index, - ArrayRef attributes) { - assert(index < op.getNumArguments() && "invalid argument number"); - return detail::setArgResAttrDict( - op, getArgDictAttrName(), op.getNumArguments(), index, - DictionaryAttr::get(op->getContext(), attributes)); -} -template -void setArgAttrs(ConcreteType op, unsigned index, DictionaryAttr attributes) { - return detail::setArgResAttrDict( - op, getArgDictAttrName(), op.getNumArguments(), index, - attributes ? attributes : DictionaryAttr::get(op->getContext())); -} +void setArgAttrs(FunctionOpInterface op, unsigned index, + ArrayRef attributes); +void setArgAttrs(FunctionOpInterface op, unsigned index, + DictionaryAttr attributes); /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. @@ -157,23 +129,10 @@ //===----------------------------------------------------------------------===// /// Set the attributes held by the result at 'index'. -template -void setResultAttrs(ConcreteType op, unsigned index, - ArrayRef attributes) { - assert(index < op.getNumResults() && "invalid result number"); - return detail::setArgResAttrDict( - op, getResultDictAttrName(), op.getNumResults(), index, - DictionaryAttr::get(op->getContext(), attributes)); -} - -template -void setResultAttrs(ConcreteType op, unsigned index, - DictionaryAttr attributes) { - assert(index < op.getNumResults() && "invalid result number"); - return detail::setArgResAttrDict( - op, getResultDictAttrName(), op.getNumResults(), index, - attributes ? attributes : DictionaryAttr::get(op->getContext())); -} +void setResultAttrs(FunctionOpInterface op, unsigned index, + ArrayRef attributes); +void setResultAttrs(FunctionOpInterface op, unsigned index, + DictionaryAttr attributes); /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. @@ -213,9 +172,8 @@ unsigned numArgs = op.getNumArguments(); if (allArgAttrs.size() != numArgs) { return op.emitOpError() - << "expects argument attribute array `" << getArgDictAttrName() - << "` to have the same number of elements as the number of " - "function arguments, got " + << "expects argument attribute array to have the same number of " + "elements as the number of function arguments, got " << allArgAttrs.size() << ", but expected " << numArgs; } for (unsigned i = 0; i != numArgs; ++i) { @@ -245,9 +203,8 @@ unsigned numResults = op.getNumResults(); if (allResultAttrs.size() != numResults) { return op.emitOpError() - << "expects result attribute array `" << getResultDictAttrName() - << "` to have the same number of elements as the number of " - "function results, got " + << "expects result attribute array to have the same number of " + "elements as the number of function results, got " << allResultAttrs.size() << ", but expected " << numResults; } for (unsigned i = 0; i != numResults; ++i) { diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td --- a/mlir/include/mlir/IR/FunctionInterfaces.td +++ b/mlir/include/mlir/IR/FunctionInterfaces.td @@ -59,6 +59,42 @@ result attributes. }], "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>, + + InterfaceMethod<[{ + Get the array of argument attribute dictionaries. The method should return + an array attribute containing only dictionary attributes equal in number + to the number of function arguments. Alternatively, the method can return + null to indicate that the function has no argument attributes. + }], + "::mlir::ArrayAttr", "getArgAttrsAttr">, + InterfaceMethod<[{ + Get the array of result attribute dictionaries. The method should return + an array attribute containing only dictionary attributes equal in number + to the number of function results. Alternatively, the method can return + null to indicate that the function has no result attributes. + }], + "::mlir::ArrayAttr", "getResAttrsAttr">, + InterfaceMethod<[{ + Set the array of argument attribute dictionaries. + }], + "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, + InterfaceMethod<[{ + Set the array of result attribute dictionaries. + }], + "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, + InterfaceMethod<[{ + Remove the array of argument attribute dictionaries. This is the same as + setting all argument attributes to an empty dictionary. The method should + return the removed attribute. + }], + "::mlir::Attribute", "removeArgAttrsAttr">, + InterfaceMethod<[{ + Remove the array of result attribute dictionaries. This is the same as + setting all result attributes to an empty dictionary. The method should + return the removed attribute. + }], + "::mlir::Attribute", "removeResAttrsAttr">, + InterfaceMethod<[{ Returns the function argument types based exclusively on the type (to allow for this method may be called on function @@ -250,20 +286,6 @@ function_interface_impl::setFunctionType(this->getOperation(), newType); } - // FIXME: These functions should be removed in favor of just forwarding to - // the derived operation, which should already have these defined - // (via ODS). - - /// Returns the name of the attribute used for function argument attributes. - static StringRef getArgDictAttrName() { - return function_interface_impl::getArgDictAttrName(); - } - - /// Returns the name of the attribute used for function argument attributes. - static StringRef getResultDictAttrName() { - return function_interface_impl::getResultDictAttrName(); - } - //===------------------------------------------------------------------===// // Argument and Result Handling //===------------------------------------------------------------------===// @@ -405,10 +427,8 @@ /// Return an ArrayAttr containing all argument attribute dictionaries of /// this function, or nullptr if no arguments have attributes. - ArrayAttr getAllArgAttrs() { - return this->getOperation()->template getAttrOfType( - getArgDictAttrName()); - } + ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); } + /// Return all argument attributes of this function. void getAllArgAttrs(SmallVectorImpl &result) { if (ArrayAttr argAttrs = getAllArgAttrs()) { @@ -460,7 +480,7 @@ } void setAllArgAttrs(ArrayAttr attributes) { assert(attributes.size() == $_op.getNumArguments()); - this->getOperation()->setAttr(getArgDictAttrName(), attributes); + $_op.setArgAttrsAttr(attributes); } /// If the an attribute exists with the specified name, change it to the new @@ -496,10 +516,8 @@ /// Return an ArrayAttr containing all result attribute dictionaries of this /// function, or nullptr if no result have attributes. - ArrayAttr getAllResultAttrs() { - return this->getOperation()->template getAttrOfType( - getResultDictAttrName()); - } + ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); } + /// Return all result attributes of this function. void getAllResultAttrs(SmallVectorImpl &result) { if (ArrayAttr argAttrs = getAllResultAttrs()) { @@ -553,7 +571,7 @@ } void setAllResultAttrs(ArrayAttr attributes) { assert(attributes.size() == $_op.getNumResults()); - this->getOperation()->setAttr(getResultDictAttrName(), attributes); + $_op.setResAttrsAttr(attributes); } /// If the an attribute exists with the specified name, change it to the new diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1524,6 +1524,8 @@ } def IndexListArrayAttr : TypedArrayAttrBase; +def DictArrayAttr : + TypedArrayAttrBase; // Attributes containing symbol references. def SymbolRefAttr : Attr()">, diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -66,8 +66,8 @@ attr.getName() == func.getFunctionTypeAttrName() || attr.getName() == "func.varargs" || (filterArgAndResAttrs && - (attr.getName() == FunctionOpInterface::getArgDictAttrName() || - attr.getName() == FunctionOpInterface::getResultDictAttrName()))) + (attr.getName() == func.getArgAttrsAttrName() || + attr.getName() == func.getResAttrsAttrName()))) continue; result.push_back(attr); } @@ -90,18 +90,19 @@ static void prependResAttrsToArgAttrs(OpBuilder &builder, SmallVectorImpl &attributes, - size_t numArguments) { + func::FuncOp func) { + size_t numArguments = func.getNumArguments(); auto allAttrs = SmallVector( numArguments + 1, DictionaryAttr::get(builder.getContext())); NamedAttribute *argAttrs = nullptr; for (auto *it = attributes.begin(); it != attributes.end();) { - if (it->getName() == FunctionOpInterface::getArgDictAttrName()) { + if (it->getName() == func.getArgAttrsAttrName()) { auto arrayAttrs = it->getValue().cast(); assert(arrayAttrs.size() == numArguments && "Number of arg attrs and args should match"); std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1); argAttrs = it; - } else if (it->getName() == FunctionOpInterface::getResultDictAttrName()) { + } else if (it->getName() == func.getResAttrsAttrName()) { auto arrayAttrs = it->getValue().cast(); assert(!arrayAttrs.empty() && "expected array to be non-empty"); allAttrs[0] = (arrayAttrs.size() == 1) @@ -113,9 +114,8 @@ it++; } - auto newArgAttrs = - builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), - builder.getArrayAttr(allAttrs)); + auto newArgAttrs = builder.getNamedAttr(func.getArgAttrsAttrName(), + builder.getArrayAttr(allAttrs)); if (!argAttrs) { attributes.emplace_back(newArgAttrs); return; @@ -141,7 +141,7 @@ auto [wrapperFuncType, resultIsNowArg] = typeConverter.convertFunctionTypeCWrapper(type); if (resultIsNowArg) - prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments()); + prependResAttrsToArgAttrs(rewriter, attributes, funcOp); auto wrapperFuncOp = rewriter.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, @@ -205,7 +205,7 @@ filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes); if (resultIsNowArg) - prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments()); + prependResAttrsToArgAttrs(builder, attributes, funcOp); // Create the auxiliary function. auto wrapperFunc = builder.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), @@ -309,8 +309,8 @@ ? resAttrDicts : rewriter.getArrayAttr( {wrapAsStructAttrs(rewriter, resAttrDicts)}); - attributes.push_back(rewriter.getNamedAttr( - FunctionOpInterface::getResultDictAttrName(), newResAttrDicts)); + attributes.push_back( + rewriter.getNamedAttr(funcOp.getResAttrsAttrName(), newResAttrDicts)); } if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { SmallVector newArgAttrs( @@ -353,9 +353,8 @@ newArgAttrs[mapping->inputNo + j] = DictionaryAttr::get(rewriter.getContext(), convertedAttrs); } - attributes.push_back( - rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), - rewriter.getArrayAttr(newArgAttrs))); + attributes.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(newArgAttrs))); } for (const auto &pair : llvm::enumerate(attributes)) { if (pair.value().getName() == "llvm.linkage") { diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -340,8 +340,9 @@ if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -352,12 +353,14 @@ return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType); + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, - getFunctionTypeAttrName()); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Check that the result type of async.func is not void and must be diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -251,8 +251,9 @@ if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -263,12 +264,14 @@ return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType); + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, - getFunctionTypeAttrName()); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Clone the internal blocks from this function into dest and all attributes diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -934,8 +934,9 @@ result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(type)); - function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs, - resultAttrs); + function_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); // Parse workgroup memory attributions. if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(), @@ -996,7 +997,8 @@ function_interface_impl::printFunctionAttributes( p, *this, {getNumWorkgroupAttributionsAttrName(), - GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName()}); + GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()}); p << ' '; p.printRegion(getBody(), /*printEntryBlockArgs=*/false); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2006,8 +2006,9 @@ assert(type.cast().getNumParams() == argAttrs.size() && "expected as many argument attribute lists as arguments"); - function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, result, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } // Builds an LLVM function type from the given lists of input and output types. @@ -2095,8 +2096,9 @@ if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, - entryArgs, resultAttrs); + function_interface_impl::addArgAndResultAttrs( + parser.getBuilder(), result, entryArgs, resultAttrs, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); auto *body = result.addRegion(); OptionalParseResult parseResult = @@ -2131,7 +2133,8 @@ isVarArg(), resTypes); function_interface_impl::printFunctionAttributes( p, *this, - {getFunctionTypeAttrName(), getLinkageAttrName(), getCConvAttrName()}); + {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), + getLinkageAttrName(), getCConvAttrName()}); // Print the body if this is not an external function. Region &body = getBody(); diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp @@ -153,12 +153,14 @@ return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType); + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, - getFunctionTypeAttrName()); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// @@ -316,12 +318,14 @@ return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType); + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void SubgraphOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, - getFunctionTypeAttrName()); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -221,12 +221,14 @@ return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType); + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, - getFunctionTypeAttrName()); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2396,8 +2396,9 @@ // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs, - resultAttrs); + function_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); // Parse the optional function body. auto *body = result.addRegion(); @@ -2419,7 +2420,8 @@ function_interface_impl::printFunctionAttributes( printer, *this, {spirv::attributeName(), - getFunctionTypeAttrName(), getFunctionControlAttrName()}); + getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), + getFunctionControlAttrName()}); // Print the body if this is not an external function. Region &body = this->getBody(); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1300,8 +1300,9 @@ if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -1312,12 +1313,14 @@ return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType); + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, - getFunctionTypeAttrName()); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -113,7 +113,7 @@ return parser.parseRParen(); } -ParseResult mlir::function_interface_impl::parseFunctionSignature( +ParseResult function_interface_impl::parseFunctionSignature( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &arguments, bool &isVariadic, SmallVectorImpl &resultTypes, @@ -125,9 +125,10 @@ return success(); } -void mlir::function_interface_impl::addArgAndResultAttrs( +void function_interface_impl::addArgAndResultAttrs( Builder &builder, OperationState &result, ArrayRef argAttrs, - ArrayRef resultAttrs) { + ArrayRef resultAttrs, StringAttr argAttrsName, + StringAttr resAttrsName) { auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { return attrs && !attrs.empty(); }; @@ -142,28 +143,28 @@ // Add the attributes to the function arguments. if (llvm::any_of(argAttrs, nonEmptyAttrsFn)) - result.addAttribute(function_interface_impl::getArgDictAttrName(), - getArrayAttr(argAttrs)); + result.addAttribute(argAttrsName, getArrayAttr(argAttrs)); // Add the attributes to the function results. if (llvm::any_of(resultAttrs, nonEmptyAttrsFn)) - result.addAttribute(function_interface_impl::getResultDictAttrName(), - getArrayAttr(resultAttrs)); + result.addAttribute(resAttrsName, getArrayAttr(resultAttrs)); } -void mlir::function_interface_impl::addArgAndResultAttrs( +void function_interface_impl::addArgAndResultAttrs( Builder &builder, OperationState &result, - ArrayRef args, - ArrayRef resultAttrs) { + ArrayRef args, ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName) { SmallVector argAttrs; for (const auto &arg : args) argAttrs.push_back(arg.attrs); - addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); + addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName, + resAttrsName); } -ParseResult mlir::function_interface_impl::parseFunctionOp( +ParseResult function_interface_impl::parseFunctionOp( OpAsmParser &parser, OperationState &result, bool allowVariadic, - StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder) { + StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName) { SmallVector entryArgs; SmallVector resultAttrs; SmallVector resultTypes; @@ -220,7 +221,8 @@ // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - addArgAndResultAttrs(builder, result, entryArgs, resultAttrs); + addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName, + resAttrsName); // Parse the optional function body. The printer will not print the body if // its empty, so disallow parsing of empty body in the parser. @@ -261,14 +263,14 @@ os << ')'; } -void mlir::function_interface_impl::printFunctionSignature( - OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes) { +void function_interface_impl::printFunctionSignature( + OpAsmPrinter &p, FunctionOpInterface op, ArrayRef argTypes, + bool isVariadic, ArrayRef resultTypes) { Region &body = op->getRegion(0); bool isExternal = body.empty(); p << '('; - ArrayAttr argAttrs = op->getAttrOfType(getArgDictAttrName()); + ArrayAttr argAttrs = op.getArgAttrsAttr(); for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { if (i > 0) p << ", "; @@ -295,26 +297,23 @@ if (!resultTypes.empty()) { p.getStream() << " -> "; - auto resultAttrs = op->getAttrOfType(getResultDictAttrName()); + auto resultAttrs = op.getResAttrsAttr(); printFunctionResultList(p, resultTypes, resultAttrs); } } -void mlir::function_interface_impl::printFunctionAttributes( +void function_interface_impl::printFunctionAttributes( OpAsmPrinter &p, Operation *op, ArrayRef elided) { // Print out function attributes, if present. - SmallVector ignoredAttrs = {SymbolTable::getSymbolAttrName(), - getArgDictAttrName(), - getResultDictAttrName()}; + SmallVector ignoredAttrs = {SymbolTable::getSymbolAttrName()}; ignoredAttrs.append(elided.begin(), elided.end()); p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); } -void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, - FunctionOpInterface op, - bool isVariadic, - StringRef typeAttrName) { +void function_interface_impl::printFunctionOp( + OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, + StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) { // Print the operation and the function name. auto funcName = op->getAttrOfType(SymbolTable::getSymbolAttrName()) @@ -329,7 +328,8 @@ ArrayRef argTypes = op.getArgumentTypes(); ArrayRef resultTypes = op.getResultTypes(); printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); - printFunctionAttributes(p, op, {visibilityAttrName, typeAttrName}); + printFunctionAttributes( + p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName}); // Print the body if this is not an external function. Region &body = op->getRegion(0); if (!body.empty()) { diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp --- a/mlir/lib/IR/FunctionInterfaces.cpp +++ b/mlir/lib/IR/FunctionInterfaces.cpp @@ -24,27 +24,104 @@ return attr.cast().empty(); } -DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op, - unsigned index) { - ArrayAttr attrs = op->getAttrOfType(getArgDictAttrName()); +DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, + unsigned index) { + ArrayAttr attrs = op.getArgAttrsAttr(); DictionaryAttr argAttrs = attrs ? attrs[index].cast() : DictionaryAttr(); return argAttrs; } DictionaryAttr -mlir::function_interface_impl::getResultAttrDict(Operation *op, - unsigned index) { - ArrayAttr attrs = op->getAttrOfType(getResultDictAttrName()); +function_interface_impl::getResultAttrDict(FunctionOpInterface op, + unsigned index) { + ArrayAttr attrs = op.getResAttrsAttr(); DictionaryAttr resAttrs = attrs ? attrs[index].cast() : DictionaryAttr(); return resAttrs; } -void mlir::function_interface_impl::detail::setArgResAttrDict( - Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index, - DictionaryAttr attrs) { - ArrayAttr allAttrs = op->getAttrOfType(attrName); +ArrayRef +function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) { + auto argDict = getArgAttrDict(op, index); + return argDict ? argDict.getValue() : std::nullopt; +} + +ArrayRef +function_interface_impl::getResultAttrs(FunctionOpInterface op, + unsigned index) { + auto resultDict = getResultAttrDict(op, index); + return resultDict ? resultDict.getValue() : std::nullopt; +} + +/// Get either the argument or result attributes array. +template +static ArrayAttr getArgResAttrs(FunctionOpInterface op) { + if constexpr (isArg) + return op.getArgAttrsAttr(); + else + return op.getResAttrsAttr(); +} + +/// Set either the argument or result attributes array. +template +static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) { + if constexpr (isArg) + op.setArgAttrsAttr(attrs); + else + op.setResAttrsAttr(attrs); +} + +/// Erase either the argument or result attributes array. +template +static void removeArgResAttrs(FunctionOpInterface op) { + if constexpr (isArg) + op.removeArgAttrsAttr(); + else + op.removeResAttrsAttr(); +} + +/// Set all of the argument or result attribute dictionaries for a function. +template +static void setAllArgResAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + if (llvm::all_of(attrs, isEmptyAttrDict)) + removeArgResAttrs(op); + else + setArgResAttrs(op, ArrayAttr::get(op->getContext(), attrs)); +} + +void function_interface_impl::setAllArgAttrDicts( + FunctionOpInterface op, ArrayRef attrs) { + setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} + +void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); +} + +void function_interface_impl::setAllResultAttrDicts( + FunctionOpInterface op, ArrayRef attrs) { + setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} + +void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); +} + +/// Update the given index into an argument or result attribute dictionary. +template +static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, + unsigned index, DictionaryAttr attrs) { + ArrayAttr allAttrs = getArgResAttrs(op); if (!allAttrs) { if (attrs.empty()) return; @@ -53,7 +130,7 @@ SmallVector newAttrs(numTotalIndices, DictionaryAttr::get(op->getContext())); newAttrs[index] = attrs; - op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); + setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); return; } // Check to see if the attribute is different from what we already have. @@ -65,53 +142,51 @@ ArrayRef rawAttrArray = allAttrs.getValue(); if (attrs.empty() && llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && - llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) { - op->removeAttr(attrName); - return; - } + llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) + return removeArgResAttrs(op); // Otherwise, create a new attribute array with the updated dictionary. SmallVector newAttrs(rawAttrArray.begin(), rawAttrArray.end()); newAttrs[index] = attrs; - op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); + setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); } -/// Set all of the argument or result attribute dictionaries for a function. -static void setAllArgResAttrDicts(Operation *op, StringRef attrName, - ArrayRef attrs) { - if (llvm::all_of(attrs, isEmptyAttrDict)) - op->removeAttr(attrName); - else - op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs)); +void function_interface_impl::setArgAttrs(FunctionOpInterface op, + unsigned index, + ArrayRef attributes) { + assert(index < op.getNumArguments() && "invalid argument number"); + return setArgResAttrDict( + op, op.getNumArguments(), index, + DictionaryAttr::get(op->getContext(), attributes)); } -void mlir::function_interface_impl::setAllArgAttrDicts( - Operation *op, ArrayRef attrs) { - setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); -} -void mlir::function_interface_impl::setAllArgAttrDicts( - Operation *op, ArrayRef attrs) { - auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { - return !attr ? DictionaryAttr::get(op->getContext()) : attr; - }); - setAllArgResAttrDicts(op, getArgDictAttrName(), - llvm::to_vector<8>(wrappedAttrs)); +void function_interface_impl::setArgAttrs(FunctionOpInterface op, + unsigned index, + DictionaryAttr attributes) { + return setArgResAttrDict( + op, op.getNumArguments(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); } -void mlir::function_interface_impl::setAllResultAttrDicts( - Operation *op, ArrayRef attrs) { - setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +void function_interface_impl::setResultAttrs( + FunctionOpInterface op, unsigned index, + ArrayRef attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return setArgResAttrDict( + op, op.getNumResults(), index, + DictionaryAttr::get(op->getContext(), attributes)); } -void mlir::function_interface_impl::setAllResultAttrDicts( - Operation *op, ArrayRef attrs) { - auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { - return !attr ? DictionaryAttr::get(op->getContext()) : attr; - }); - setAllArgResAttrDicts(op, getResultDictAttrName(), - llvm::to_vector<8>(wrappedAttrs)); + +void function_interface_impl::setResultAttrs(FunctionOpInterface op, + unsigned index, + DictionaryAttr attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return setArgResAttrDict( + op, op.getNumResults(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); } -void mlir::function_interface_impl::insertFunctionArguments( +void function_interface_impl::insertFunctionArguments( FunctionOpInterface op, ArrayRef argIndices, TypeRange argTypes, ArrayRef argAttrs, ArrayRef argLocs, unsigned originalNumArgs, Type newType) { @@ -128,7 +203,7 @@ Block &entry = op->getRegion(0).front(); // Update the argument attributes of the function. - auto oldArgAttrs = op->getAttrOfType(getArgDictAttrName()); + ArrayAttr oldArgAttrs = op.getArgAttrsAttr(); if (oldArgAttrs || !argAttrs.empty()) { SmallVector newArgAttrs; newArgAttrs.reserve(originalNumArgs + argIndices.size()); @@ -157,7 +232,7 @@ entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]); } -void mlir::function_interface_impl::insertFunctionResults( +void function_interface_impl::insertFunctionResults( FunctionOpInterface op, ArrayRef resultIndices, TypeRange resultTypes, ArrayRef resultAttrs, unsigned originalNumResults, Type newType) { @@ -171,7 +246,7 @@ // - Result attrs. // Update the result attributes of the function. - auto oldResultAttrs = op->getAttrOfType(getResultDictAttrName()); + ArrayAttr oldResultAttrs = op.getResAttrsAttr(); if (oldResultAttrs || !resultAttrs.empty()) { SmallVector newResultAttrs; newResultAttrs.reserve(originalNumResults + resultIndices.size()); @@ -199,7 +274,7 @@ op.setFunctionTypeAttr(TypeAttr::get(newType)); } -void mlir::function_interface_impl::eraseFunctionArguments( +void function_interface_impl::eraseFunctionArguments( FunctionOpInterface op, const BitVector &argIndices, Type newType) { // There are 3 things that need to be updated: // - Function type. @@ -208,7 +283,7 @@ Block &entry = op->getRegion(0).front(); // Update the argument attributes of the function. - if (auto argAttrs = op->getAttrOfType(getArgDictAttrName())) { + if (ArrayAttr argAttrs = op.getArgAttrsAttr()) { SmallVector newArgAttrs; newArgAttrs.reserve(argAttrs.size()); for (unsigned i = 0, e = argIndices.size(); i < e; ++i) @@ -222,14 +297,14 @@ entry.eraseArguments(argIndices); } -void mlir::function_interface_impl::eraseFunctionResults( +void function_interface_impl::eraseFunctionResults( FunctionOpInterface op, const BitVector &resultIndices, Type newType) { // There are 2 things that need to be updated: // - Function type. // - Result attrs. // Update the result attributes of the function. - if (auto resAttrs = op->getAttrOfType(getResultDictAttrName())) { + if (ArrayAttr resAttrs = op.getResAttrsAttr()) { SmallVector newResultAttrs; newResultAttrs.reserve(resAttrs.size()); for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) @@ -242,7 +317,7 @@ op.setFunctionTypeAttr(TypeAttr::get(newType)); } -TypeRange mlir::function_interface_impl::insertTypesInto( +TypeRange function_interface_impl::insertTypesInto( TypeRange oldTypes, ArrayRef indices, TypeRange newTypes, SmallVectorImpl &storage) { assert(indices.size() == newTypes.size() && @@ -261,7 +336,7 @@ return storage; } -TypeRange mlir::function_interface_impl::filterTypesOut( +TypeRange function_interface_impl::filterTypesOut( TypeRange types, const BitVector &indices, SmallVectorImpl &storage) { if (indices.none()) return types; @@ -276,8 +351,8 @@ // Function type signature. //===----------------------------------------------------------------------===// -void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op, - Type newType) { +void function_interface_impl::setFunctionType(FunctionOpInterface op, + Type newType) { unsigned oldNumArgs = op.getNumArguments(); unsigned oldNumResults = op.getNumResults(); op.setFunctionTypeAttr(TypeAttr::get(newType)); @@ -285,35 +360,31 @@ unsigned newNumResults = op.getNumResults(); // Functor used to update the argument and result attributes of the function. - auto updateAttrFn = [&](StringRef attrName, unsigned oldCount, - unsigned newCount, auto setAttrFn) { + auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) { + constexpr bool isArgVal = std::is_same_v; + if (oldCount == newCount) return; // The new type has no arguments/results, just drop the attribute. - if (newCount == 0) { - op->removeAttr(attrName); - return; - } - ArrayAttr attrs = op->getAttrOfType(attrName); + if (newCount == 0) + return removeArgResAttrs(op); + ArrayAttr attrs = getArgResAttrs(op); if (!attrs) return; // The new type has less arguments/results, take the first N attributes. if (newCount < oldCount) - return setAttrFn(op, attrs.getValue().take_front(newCount)); + return setAllArgResAttrDicts( + op, attrs.getValue().take_front(newCount)); // Otherwise, the new type has more arguments/results. Initialize the new // arguments/results with empty attributes. SmallVector newAttrs(attrs.begin(), attrs.end()); newAttrs.resize(newCount); - setAttrFn(op, newAttrs); + setAllArgResAttrDicts(op, newAttrs); }; // Update the argument and result attributes. - updateAttrFn( - getArgDictAttrName(), oldNumArgs, newNumArgs, - [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); }); - updateAttrFn( - getResultDictAttrName(), oldNumResults, newNumResults, - [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); }); + updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs); + updateAttrFn(std::false_type{}, oldNumResults, newNumResults); } diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir --- a/mlir/test/IR/invalid-func-op.mlir +++ b/mlir/test/IR/invalid-func-op.mlir @@ -96,20 +96,11 @@ // ----- -// expected-error@+1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}} +// expected-error@+1 {{argument attribute array to have the same number of elements as the number of function arguments}} func.func private @invalid_arg_attrs() attributes { arg_attrs = [{}] } // ----- -// expected-error@+1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}} -func.func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] } -// ----- - -// expected-error@+1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}} +// expected-error@+1 {{result attribute array to have the same number of elements as the number of function results}} func.func private @invalid_res_attrs() attributes { res_attrs = [{}] } - -// ----- - -// expected-error@+1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}} -func.func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] }