diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -501,7 +501,8 @@ // correctly. for (auto e : llvm::enumerate(funcTy.getInputs())) { unsigned index = e.index(); - llvm::ArrayRef attrs = func.getArgAttrs(index); + llvm::ArrayRef attrs = + mlir::function_interface_impl::getArgAttrs(func, index); for (mlir::NamedAttribute attr : attrs) { savedAttrs.push_back({index, attr}); } diff --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td --- a/mlir/examples/toy/Ch2/include/toy/Ops.td +++ b/mlir/examples/toy/Ch2/include/toy/Ops.td @@ -134,7 +134,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -211,14 +211,17 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. 
- mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td --- a/mlir/examples/toy/Ch3/include/toy/Ops.td +++ b/mlir/examples/toy/Ch3/include/toy/Ops.td @@ -133,7 +133,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -198,14 +198,17 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. 
- mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -163,7 +163,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -287,14 +287,17 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. 
diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -163,7 +163,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -287,14 +287,17 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. 
diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -163,7 +163,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -287,14 +287,17 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. 
diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -186,7 +186,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -314,14 +314,17 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(mlir::OpAsmPrinter &p) { // Dispatch to the FunctionOpInterface provided utility method that prints the // function operation. - mlir::function_interface_impl::printFunctionOp(p, *this, - /*isVariadic=*/false); + mlir::function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Returns the region on the function operation that is callable. 
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -140,7 +140,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, - OptionalAttr:$sym_visibility); + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td @@ -251,7 +251,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, - OptionalAttr:$sym_visibility); + OptionalAttr:$sym_visibility, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let regions = (region AnyRegion:$body); let builders = [OpBuilder<(ins diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -242,7 +242,9 @@ attribution. 
}]; - let arguments = (ins TypeAttrOf:$function_type); + let arguments = (ins TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let regions = (region AnyRegion:$body); let skipDefaultBuilders = 1; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1311,7 +1311,9 @@ DefaultValuedAttr:$CConv, OptionalAttr:$personality, OptionalAttr:$garbageCollector, - OptionalAttr:$passthrough + OptionalAttr:$passthrough, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td @@ -52,6 +52,8 @@ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$sym_visibility); let regions = (region AnyRegion:$body); @@ -401,6 +403,8 @@ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$sym_visibility); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -652,7 +652,9 @@ let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type + TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let regions = (region MinSizedRegion<1>:$body); diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td --- 
a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -291,6 +291,8 @@ let arguments = (ins TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, StrAttr:$sym_name, SPIRV_FunctionControlAttr:$function_control ); diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -1107,6 +1107,8 @@ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$sym_visibility); let regions = (region AnyRegion:$body); diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -39,10 +39,12 @@ -/// with special names given by getResultAttrName, getArgumentAttrName. +/// with the attribute names given by `argAttrsName` and `resAttrsName`. void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef argAttrs, - ArrayRef resultAttrs); + ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName); void addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef argAttrs, - ArrayRef resultAttrs); + ArrayRef args, + ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName); /// Callback type for `parseFunctionOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of @@ -69,21 +71,25 @@ /// Parser implementation for function-like operations. Uses /// `funcTypeBuilder` to construct the custom function type given lists of -/// input and output types. If `allowVariadic` is set, the parser will accept +/// input and output types. The parser sets the `typeAttrName` attribute to the +/// resulting function type.
If `allowVariadic` is set, the parser will accept /// trailing ellipsis in the function signature and indicate to the builder /// whether the function is variadic. If the builder returns a null type, -/// `result` will not contain the `type` attribute. The caller can then add a +/// `result` will not contain the function type attribute. The caller can add a /// type, report the error or delegate the reporting to the op's verifier. ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, - bool allowVariadic, - FuncTypeBuilder funcTypeBuilder); + bool allowVariadic, StringAttr typeAttrName, + FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName); /// Printer implementation for function-like operations. -void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic); +void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, + StringRef typeAttrName, StringAttr argAttrsName, + StringAttr resAttrsName); /// Prints the signature of the function-like operation `op`. Assumes `op` has -/// is a FunctionOpInterface and has passed verification. +/// passed verification. -void printFunctionSignature(OpAsmPrinter &p, Operation *op, +void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef argTypes, bool isVariadic, ArrayRef resultTypes); @@ -92,8 +98,7 @@ /// function-like operation internally are not printed. Nothing is printed /// if all attributes are elided. Assumes `op` is a FunctionOpInterface and /// has passed verification.
-void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, - unsigned numResults, +void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef elided = {}); } // namespace function_interface_impl diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h --- a/mlir/include/mlir/IR/FunctionInterfaces.h +++ b/mlir/include/mlir/IR/FunctionInterfaces.h @@ -22,78 +22,59 @@ #include "llvm/ADT/SmallString.h" namespace mlir { +class FunctionOpInterface; namespace function_interface_impl { -/// Return the name of the attribute used for function types. -inline StringRef getTypeAttrName() { return "function_type"; } - -/// Return the name of the attribute used for function argument attributes. -inline StringRef getArgDictAttrName() { return "arg_attrs"; } - -/// Return the name of the attribute used for function argument attributes. -inline StringRef getResultDictAttrName() { return "res_attrs"; } - /// Returns the dictionary attribute corresponding to the argument at 'index'. /// If there are no argument attributes at 'index', a null attribute is /// returned. -DictionaryAttr getArgAttrDict(Operation *op, unsigned index); +DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index); /// Returns the dictionary attribute corresponding to the result at 'index'. /// If there are no result attributes at 'index', a null attribute is /// returned. -DictionaryAttr getResultAttrDict(Operation *op, unsigned index); +DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index); -namespace detail { -/// Update the given index into an argument or result attribute dictionary. -void setArgResAttrDict(Operation *op, StringRef attrName, - unsigned numTotalIndices, unsigned index, - DictionaryAttr attrs); -} // namespace detail +/// Return all of the attributes for the argument at 'index'. 
+ArrayRef getArgAttrs(FunctionOpInterface op, unsigned index); + +/// Return all of the attributes for the result at 'index'. +ArrayRef getResultAttrs(FunctionOpInterface op, unsigned index); /// Set all of the argument or result attribute dictionaries for a function. The /// size of `attrs` is expected to match the number of arguments/results of the /// given `op`. -void setAllArgAttrDicts(Operation *op, ArrayRef attrs); -void setAllArgAttrDicts(Operation *op, ArrayRef attrs); -void setAllResultAttrDicts(Operation *op, ArrayRef attrs); -void setAllResultAttrDicts(Operation *op, ArrayRef attrs); - -/// Return all of the attributes for the argument at 'index'. -inline ArrayRef getArgAttrs(Operation *op, unsigned index) { - auto argDict = getArgAttrDict(op, index); - return argDict ? argDict.getValue() : std::nullopt; -} - -/// Return all of the attributes for the result at 'index'. -inline ArrayRef getResultAttrs(Operation *op, unsigned index) { - auto resultDict = getResultAttrDict(op, index); - return resultDict ? resultDict.getValue() : std::nullopt; -} +void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef attrs); +void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef attrs); +void setAllResultAttrDicts(FunctionOpInterface op, + ArrayRef attrs); +void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef attrs); /// Insert the specified arguments and update the function type attribute. -void insertFunctionArguments(Operation *op, ArrayRef argIndices, - TypeRange argTypes, +void insertFunctionArguments(FunctionOpInterface op, + ArrayRef argIndices, TypeRange argTypes, ArrayRef argAttrs, ArrayRef argLocs, unsigned originalNumArgs, Type newType); /// Insert the specified results and update the function type attribute. 
-void insertFunctionResults(Operation *op, ArrayRef resultIndices, +void insertFunctionResults(FunctionOpInterface op, + ArrayRef resultIndices, TypeRange resultTypes, ArrayRef resultAttrs, unsigned originalNumResults, Type newType); /// Erase the specified arguments and update the function type attribute. -void eraseFunctionArguments(Operation *op, const BitVector &argIndices, +void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices, Type newType); /// Erase the specified results and update the function type attribute. -void eraseFunctionResults(Operation *op, const BitVector &resultIndices, - Type newType); +void eraseFunctionResults(FunctionOpInterface op, + const BitVector &resultIndices, Type newType); /// Set a FunctionOpInterface operation's type signature. -void setFunctionType(Operation *op, Type newType); +void setFunctionType(FunctionOpInterface op, Type newType); /// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any /// types are inserted, `storage` is used to hold the new type list. The new @@ -111,20 +92,10 @@ //===----------------------------------------------------------------------===// /// Set the attributes held by the argument at 'index'. -template -void setArgAttrs(ConcreteType op, unsigned index, - ArrayRef attributes) { - assert(index < op.getNumArguments() && "invalid argument number"); - return detail::setArgResAttrDict( - op, getArgDictAttrName(), op.getNumArguments(), index, - DictionaryAttr::get(op->getContext(), attributes)); -} -template -void setArgAttrs(ConcreteType op, unsigned index, DictionaryAttr attributes) { - return detail::setArgResAttrDict( - op, getArgDictAttrName(), op.getNumArguments(), index, - attributes ? 
attributes : DictionaryAttr::get(op->getContext())); -} +void setArgAttrs(FunctionOpInterface op, unsigned index, + ArrayRef attributes); +void setArgAttrs(FunctionOpInterface op, unsigned index, + DictionaryAttr attributes); /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. @@ -158,23 +129,10 @@ //===----------------------------------------------------------------------===// /// Set the attributes held by the result at 'index'. -template -void setResultAttrs(ConcreteType op, unsigned index, - ArrayRef attributes) { - assert(index < op.getNumResults() && "invalid result number"); - return detail::setArgResAttrDict( - op, getResultDictAttrName(), op.getNumResults(), index, - DictionaryAttr::get(op->getContext(), attributes)); -} - -template -void setResultAttrs(ConcreteType op, unsigned index, - DictionaryAttr attributes) { - assert(index < op.getNumResults() && "invalid result number"); - return detail::setArgResAttrDict( - op, getResultDictAttrName(), op.getNumResults(), index, - attributes ? attributes : DictionaryAttr::get(op->getContext())); -} +void setResultAttrs(FunctionOpInterface op, unsigned index, + ArrayRef attributes); +void setResultAttrs(FunctionOpInterface op, unsigned index, + DictionaryAttr attributes); /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. @@ -207,10 +165,6 @@ /// method on FunctionOpInterface::Trait. 
template LogicalResult verifyTrait(ConcreteOp op) { - if (!op.getFunctionTypeAttr()) - return op.emitOpError("requires a type attribute '") - << function_interface_impl::getTypeAttrName() << '\''; - if (failed(op.verifyType())) return failure(); @@ -218,9 +172,8 @@ unsigned numArgs = op.getNumArguments(); if (allArgAttrs.size() != numArgs) { return op.emitOpError() - << "expects argument attribute array `" << getArgDictAttrName() - << "` to have the same number of elements as the number of " - "function arguments, got " + << "expects argument attribute array to have the same number of " + "elements as the number of function arguments, got " << allArgAttrs.size() << ", but expected " << numArgs; } for (unsigned i = 0; i != numArgs; ++i) { @@ -250,9 +203,8 @@ unsigned numResults = op.getNumResults(); if (allResultAttrs.size() != numResults) { return op.emitOpError() - << "expects result attribute array `" << getResultDictAttrName() - << "` to have the same number of elements as the number of " - "function results, got " + << "expects result attribute array to have the same number of " + "elements as the number of function results, got " << allResultAttrs.size() << ", but expected " << numResults; } for (unsigned i = 0; i != numResults; ++i) { diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td --- a/mlir/include/mlir/IR/FunctionInterfaces.td +++ b/mlir/include/mlir/IR/FunctionInterfaces.td @@ -49,6 +49,52 @@ for each of the function results. }]; let methods = [ + InterfaceMethod<[{ + Returns the type of the function. + }], + "::mlir::Type", "getFunctionType">, + InterfaceMethod<[{ + Set the type of the function. This method should perform an unsafe + modification to the function type; it should not update argument or + result attributes. + }], + "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>, + + InterfaceMethod<[{ + Get the array of argument attribute dictionaries. 
The method should return + an array attribute containing only dictionary attributes equal in number + to the number of function arguments. Alternatively, the method can return + null to indicate that the function has no argument attributes. + }], + "::mlir::ArrayAttr", "getArgAttrsAttr">, + InterfaceMethod<[{ + Get the array of result attribute dictionaries. The method should return + an array attribute containing only dictionary attributes equal in number + to the number of function results. Alternatively, the method can return + null to indicate that the function has no result attributes. + }], + "::mlir::ArrayAttr", "getResAttrsAttr">, + InterfaceMethod<[{ + Set the array of argument attribute dictionaries. + }], + "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, + InterfaceMethod<[{ + Set the array of result attribute dictionaries. + }], + "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, + InterfaceMethod<[{ + Remove the array of argument attribute dictionaries. This is the same as + setting all argument attributes to an empty dictionary. The method should + return the removed attribute. + }], + "::mlir::Attribute", "removeArgAttrsAttr">, + InterfaceMethod<[{ + Remove the array of result attribute dictionaries. This is the same as + setting all result attributes to an empty dictionary. The method should + return the removed attribute. 
+ }], + "::mlir::Attribute", "removeResAttrsAttr">, + InterfaceMethod<[{ Returns the function argument types based exclusively on the type (to allow for this method may be called on function @@ -139,7 +185,7 @@ ArrayRef attrs, TypeRange inputTypes) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(function_interface_impl::getTypeAttrName(), + state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); @@ -240,34 +286,6 @@ function_interface_impl::setFunctionType(this->getOperation(), newType); } - // FIXME: These functions should be removed in favor of just forwarding to - // the derived operation, which should already have these defined - // (via ODS). - - /// Returns the name of the attribute used for function types. - static StringRef getTypeAttrName() { - return function_interface_impl::getTypeAttrName(); - } - - /// Returns the name of the attribute used for function argument attributes. - static StringRef getArgDictAttrName() { - return function_interface_impl::getArgDictAttrName(); - } - - /// Returns the name of the attribute used for function argument attributes. - static StringRef getResultDictAttrName() { - return function_interface_impl::getResultDictAttrName(); - } - - /// Return the attribute containing the type of this function. - TypeAttr getFunctionTypeAttr() { - return this->getOperation()->template getAttrOfType( - getTypeAttrName()); - } - - /// Return the type of this function. - Type getFunctionType() { return getFunctionTypeAttr().getValue(); } - //===------------------------------------------------------------------===// // Argument and Result Handling //===------------------------------------------------------------------===// @@ -409,10 +427,8 @@ /// Return an ArrayAttr containing all argument attribute dictionaries of /// this function, or nullptr if no arguments have attributes. 
- ArrayAttr getAllArgAttrs() { - return this->getOperation()->template getAttrOfType( - getArgDictAttrName()); - } + ArrayAttr getAllArgAttrs() { return $_op.getArgAttrsAttr(); } + /// Return all argument attributes of this function. void getAllArgAttrs(SmallVectorImpl &result) { if (ArrayAttr argAttrs = getAllArgAttrs()) { @@ -464,7 +480,7 @@ } void setAllArgAttrs(ArrayAttr attributes) { assert(attributes.size() == $_op.getNumArguments()); - this->getOperation()->setAttr(getArgDictAttrName(), attributes); + $_op.setArgAttrsAttr(attributes); } /// If the an attribute exists with the specified name, change it to the new @@ -500,10 +516,8 @@ /// Return an ArrayAttr containing all result attribute dictionaries of this /// function, or nullptr if no result have attributes. - ArrayAttr getAllResultAttrs() { - return this->getOperation()->template getAttrOfType( - getResultDictAttrName()); - } + ArrayAttr getAllResultAttrs() { return $_op.getResAttrsAttr(); } + /// Return all result attributes of this function. void getAllResultAttrs(SmallVectorImpl &result) { if (ArrayAttr argAttrs = getAllResultAttrs()) { @@ -557,7 +571,7 @@ } void setAllResultAttrs(ArrayAttr attributes) { assert(attributes.size() == $_op.getNumResults()); - this->getOperation()->setAttr(getResultDictAttrName(), attributes); + $_op.setResAttrsAttr(attributes); } /// If the an attribute exists with the specified name, change it to the new diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1524,6 +1524,8 @@ } def IndexListArrayAttr : TypedArrayAttrBase; +def DictArrayAttr : + TypedArrayAttrBase; // Attributes containing symbol references. 
def SymbolRefAttr : Attr()">, diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -59,16 +59,15 @@ /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. -static void filterFuncAttributes(ArrayRef attrs, - bool filterArgAndResAttrs, +static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs, SmallVectorImpl &result) { - for (const auto &attr : attrs) { + for (const NamedAttribute &attr : func->getAttrs()) { if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == FunctionOpInterface::getTypeAttrName() || + attr.getName() == func.getFunctionTypeAttrName() || attr.getName() == "func.varargs" || (filterArgAndResAttrs && - (attr.getName() == FunctionOpInterface::getArgDictAttrName() || - attr.getName() == FunctionOpInterface::getResultDictAttrName()))) + (attr.getName() == func.getArgAttrsAttrName() || + attr.getName() == func.getResAttrsAttrName()))) continue; result.push_back(attr); } @@ -91,18 +90,19 @@ static void prependResAttrsToArgAttrs(OpBuilder &builder, SmallVectorImpl &attributes, - size_t numArguments) { + func::FuncOp func) { + size_t numArguments = func.getNumArguments(); auto allAttrs = SmallVector( numArguments + 1, DictionaryAttr::get(builder.getContext())); NamedAttribute *argAttrs = nullptr; for (auto *it = attributes.begin(); it != attributes.end();) { - if (it->getName() == FunctionOpInterface::getArgDictAttrName()) { + if (it->getName() == func.getArgAttrsAttrName()) { auto arrayAttrs = it->getValue().cast(); assert(arrayAttrs.size() == numArguments && "Number of arg attrs and args should match"); std::copy(arrayAttrs.begin(), arrayAttrs.end(), allAttrs.begin() + 1); argAttrs = it; - } else if (it->getName() == 
FunctionOpInterface::getResultDictAttrName()) { + } else if (it->getName() == func.getResAttrsAttrName()) { auto arrayAttrs = it->getValue().cast(); assert(!arrayAttrs.empty() && "expected array to be non-empty"); allAttrs[0] = (arrayAttrs.size() == 1) @@ -114,9 +114,8 @@ it++; } - auto newArgAttrs = - builder.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), - builder.getArrayAttr(allAttrs)); + auto newArgAttrs = builder.getNamedAttr(func.getArgAttrsAttrName(), + builder.getArrayAttr(allAttrs)); if (!argAttrs) { attributes.emplace_back(newArgAttrs); return; @@ -138,12 +137,11 @@ LLVM::LLVMFuncOp newFuncOp) { auto type = funcOp.getFunctionType(); SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, - attributes); + filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes); auto [wrapperFuncType, resultIsNowArg] = typeConverter.convertFunctionTypeCWrapper(type); if (resultIsNowArg) - prependResAttrsToArgAttrs(rewriter, attributes, funcOp.getNumArguments()); + prependResAttrsToArgAttrs(rewriter, attributes, funcOp); auto wrapperFuncOp = rewriter.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperFuncType, LLVM::Linkage::External, /*dsoLocal*/ false, @@ -204,11 +202,10 @@ assert(wrapperType && "unexpected type conversion failure"); SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, - attributes); + filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes); if (resultIsNowArg) - prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments()); + prependResAttrsToArgAttrs(builder, attributes, funcOp); // Create the auxiliary function. auto wrapperFunc = builder.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), @@ -304,8 +301,7 @@ // Propagate argument/result attributes to all converted arguments/result // obtained after converting a given original argument/result. 
SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true, - attributes); + filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/true, attributes); if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) { assert(!resAttrDicts.empty() && "expected array to be non-empty"); auto newResAttrDicts = @@ -313,8 +309,8 @@ ? resAttrDicts : rewriter.getArrayAttr( {wrapAsStructAttrs(rewriter, resAttrDicts)}); - attributes.push_back(rewriter.getNamedAttr( - FunctionOpInterface::getResultDictAttrName(), newResAttrDicts)); + attributes.push_back( + rewriter.getNamedAttr(funcOp.getResAttrsAttrName(), newResAttrDicts)); } if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) { SmallVector newArgAttrs( @@ -357,9 +353,8 @@ newArgAttrs[mapping->inputNo + j] = DictionaryAttr::get(rewriter.getContext(), convertedAttrs); } - attributes.push_back( - rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), - rewriter.getArrayAttr(newArgAttrs))); + attributes.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(newArgAttrs))); } for (const auto &pair : llvm::enumerate(attributes)) { if (pair.value().getName() == "llvm.linkage") { diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -60,7 +60,7 @@ SmallVector attributes; for (const auto &attr : gpuFuncOp->getAttrs()) { if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == FunctionOpInterface::getTypeAttrName() || + attr.getName() == gpuFuncOp.getFunctionTypeAttrName() || attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) continue; attributes.push_back(attr); diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ 
b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -226,7 +226,7 @@ rewriter.getFunctionType(signatureConverter.getConvertedTypes(), std::nullopt)); for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() || + if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() || namedAttr.getName() == SymbolTable::getSymbolAttrName()) continue; newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -332,8 +332,7 @@ ArrayRef argAttrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(FunctionOpInterface::getTypeAttrName(), - TypeAttr::get(type)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); @@ -341,8 +340,9 @@ if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -352,11 +352,15 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, 
/*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Check that the result type of async.func is not void and must be diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -244,16 +244,16 @@ ArrayRef argAttrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(FunctionOpInterface::getTypeAttrName(), - TypeAttr::get(type)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -263,11 +263,15 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } /// Clone the internal blocks from this function into dest and all attributes diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- 
a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -859,7 +859,8 @@ ArrayRef attrs) { result.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(type)); result.addAttribute(getNumWorkgroupAttributionsAttrName(), builder.getI64IntegerAttr(workgroupAttributions.size())); result.addAttributes(attrs); @@ -930,10 +931,12 @@ for (auto &arg : entryArgs) argTypes.push_back(arg.type); auto type = builder.getFunctionType(argTypes, resultTypes); - result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(type)); - function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs, - resultAttrs); + function_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); // Parse workgroup memory attributions. 
if (failed(parseAttributions(parser, GPUFuncOp::getWorkgroupKeyword(), @@ -992,19 +995,15 @@ p << ' ' << getKernelKeyword(); function_interface_impl::printFunctionAttributes( - p, *this, type.getNumInputs(), type.getNumResults(), + p, *this, {getNumWorkgroupAttributionsAttrName(), - GPUDialect::getKernelFuncAttrName()}); + GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()}); p << ' '; p.printRegion(getBody(), /*printEntryBlockArgs=*/false); } LogicalResult GPUFuncOp::verifyType() { - Type type = getFunctionTypeAttr().getValue(); - if (!type.isa()) - return emitOpError("requires '" + getTypeAttrName() + - "' attribute of function type"); - if (isKernel() && getFunctionType().getNumResults() != 0) return emitOpError() << "expected void return type for kernel function"; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2006,8 +2006,9 @@ assert(type.cast().getNumParams() == argAttrs.size() && "expected as many argument attribute lists as arguments"); - function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, result, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } // Builds an LLVM function type from the given lists of input and output types. 
@@ -2090,13 +2091,14 @@ function_interface_impl::VariadicFlag(isVariadic)); if (!type) return failure(); - result.addAttribute(FunctionOpInterface::getTypeAttrName(), + result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(type)); if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, - entryArgs, resultAttrs); + function_interface_impl::addArgAndResultAttrs( + parser.getBuilder(), result, entryArgs, resultAttrs, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); auto *body = result.addRegion(); OptionalParseResult parseResult = @@ -2130,8 +2132,9 @@ function_interface_impl::printFunctionSignature(p, *this, argTypes, isVarArg(), resTypes); function_interface_impl::printFunctionAttributes( - p, *this, argTypes.size(), resTypes.size(), - {getLinkageAttrName(), getCConvAttrName()}); + p, *this, + {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), + getLinkageAttrName(), getCConvAttrName()}); // Print the body if this is not an external function. 
Region &body = getBody(); diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp @@ -152,11 +152,15 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// @@ -313,11 +317,15 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void SubgraphOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -220,11 +220,15 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return 
function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2382,7 +2382,7 @@ for (auto &arg : entryArgs) argTypes.push_back(arg.type); auto fnType = builder.getFunctionType(argTypes, resultTypes); - result.addAttribute(FunctionOpInterface::getTypeAttrName(), + result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(fnType)); // Parse the optional function control keyword. @@ -2396,8 +2396,9 @@ // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs, - resultAttrs); + function_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); // Parse the optional function body. 
auto *body = result.addRegion(); @@ -2417,8 +2418,10 @@ printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl()) << "\""; function_interface_impl::printFunctionAttributes( - printer, *this, fnType.getNumInputs(), fnType.getNumResults(), - {spirv::attributeName()}); + printer, *this, + {spirv::attributeName(), + getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), + getFunctionControlAttrName()}); // Print the body if this is not an external function. Region &body = this->getBody(); @@ -2430,10 +2433,6 @@ } LogicalResult spirv::FuncOp::verifyType() { - auto type = getFunctionTypeAttr().getValue(); - if (!type.isa()) - return emitOpError("requires '" + getTypeAttrName() + - "' attribute of function type"); if (getFunctionType().getNumResults() > 1) return emitOpError("cannot have more than one result"); return success(); @@ -2473,7 +2472,7 @@ ArrayRef attrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.addAttribute(spirv::attributeName(), builder.getAttr(control)); state.attributes.append(attrs.begin(), attrs.end()); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -531,7 +531,7 @@ // Copy over all attributes other than the function name and type. 
for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() && + if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && namedAttr.getName() != SymbolTable::getSymbolAttrName()) newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1300,8 +1300,9 @@ if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/std::nullopt); + function_interface_impl::addArgAndResultAttrs( + builder, state, argAttrs, /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { @@ -1311,11 +1312,15 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -113,7 +113,7 @@ return parser.parseRParen(); } -ParseResult mlir::function_interface_impl::parseFunctionSignature( +ParseResult 
function_interface_impl::parseFunctionSignature( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &arguments, bool &isVariadic, SmallVectorImpl &resultTypes, @@ -125,9 +125,10 @@ return success(); } -void mlir::function_interface_impl::addArgAndResultAttrs( +void function_interface_impl::addArgAndResultAttrs( Builder &builder, OperationState &result, ArrayRef argAttrs, - ArrayRef resultAttrs) { + ArrayRef resultAttrs, StringAttr argAttrsName, + StringAttr resAttrsName) { auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { return attrs && !attrs.empty(); }; @@ -142,28 +143,28 @@ // Add the attributes to the function arguments. if (llvm::any_of(argAttrs, nonEmptyAttrsFn)) - result.addAttribute(function_interface_impl::getArgDictAttrName(), - getArrayAttr(argAttrs)); + result.addAttribute(argAttrsName, getArrayAttr(argAttrs)); // Add the attributes to the function results. if (llvm::any_of(resultAttrs, nonEmptyAttrsFn)) - result.addAttribute(function_interface_impl::getResultDictAttrName(), - getArrayAttr(resultAttrs)); + result.addAttribute(resAttrsName, getArrayAttr(resultAttrs)); } -void mlir::function_interface_impl::addArgAndResultAttrs( +void function_interface_impl::addArgAndResultAttrs( Builder &builder, OperationState &result, - ArrayRef args, - ArrayRef resultAttrs) { + ArrayRef args, ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName) { SmallVector argAttrs; for (const auto &arg : args) argAttrs.push_back(arg.attrs); - addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); + addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName, + resAttrsName); } -ParseResult mlir::function_interface_impl::parseFunctionOp( +ParseResult function_interface_impl::parseFunctionOp( OpAsmParser &parser, OperationState &result, bool allowVariadic, - FuncTypeBuilder funcTypeBuilder) { + StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName) { SmallVector entryArgs; 
SmallVector resultAttrs; SmallVector resultTypes; @@ -197,7 +198,7 @@ << "failed to construct function type" << (errorMessage.empty() ? "" : ": ") << errorMessage; } - result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(typeAttrName, TypeAttr::get(type)); // If function attributes are present, parse them. NamedAttrList parsedAttributes; @@ -209,7 +210,7 @@ // dictionary. for (StringRef disallowed : {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), - getTypeAttrName()}) { + typeAttrName.getValue()}) { if (parsedAttributes.get(disallowed)) return parser.emitError(attributeDictLocation, "'") << disallowed @@ -220,7 +221,8 @@ // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - addArgAndResultAttrs(builder, result, entryArgs, resultAttrs); + addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName, + resAttrsName); // Parse the optional function body. The printer will not print the body if // its empty, so disallow parsing of empty body in the parser. 
@@ -261,14 +263,14 @@ os << ')'; } -void mlir::function_interface_impl::printFunctionSignature( - OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes) { +void function_interface_impl::printFunctionSignature( + OpAsmPrinter &p, FunctionOpInterface op, ArrayRef argTypes, + bool isVariadic, ArrayRef resultTypes) { Region &body = op->getRegion(0); bool isExternal = body.empty(); p << '('; - ArrayAttr argAttrs = op->getAttrOfType(getArgDictAttrName()); + ArrayAttr argAttrs = op.getArgAttrsAttr(); for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { if (i > 0) p << ", "; @@ -295,26 +297,23 @@ if (!resultTypes.empty()) { p.getStream() << " -> "; - auto resultAttrs = op->getAttrOfType(getResultDictAttrName()); + auto resultAttrs = op.getResAttrsAttr(); printFunctionResultList(p, resultTypes, resultAttrs); } } -void mlir::function_interface_impl::printFunctionAttributes( - OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, - ArrayRef elided) { +void function_interface_impl::printFunctionAttributes( + OpAsmPrinter &p, Operation *op, ArrayRef elided) { // Print out function attributes, if present. - SmallVector ignoredAttrs = { - ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(), - getArgDictAttrName(), getResultDictAttrName()}; + SmallVector ignoredAttrs = {SymbolTable::getSymbolAttrName()}; ignoredAttrs.append(elided.begin(), elided.end()); p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); } -void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, - FunctionOpInterface op, - bool isVariadic) { +void function_interface_impl::printFunctionOp( + OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, + StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) { // Print the operation and the function name. 
auto funcName = op->getAttrOfType(SymbolTable::getSymbolAttrName()) @@ -329,8 +328,8 @@ ArrayRef argTypes = op.getArgumentTypes(); ArrayRef resultTypes = op.getResultTypes(); printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); - printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(), - {visibilityAttrName}); + printFunctionAttributes( + p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName}); // Print the body if this is not an external function. Region &body = op->getRegion(0); if (!body.empty()) { diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp --- a/mlir/lib/IR/FunctionInterfaces.cpp +++ b/mlir/lib/IR/FunctionInterfaces.cpp @@ -24,27 +24,104 @@ return attr.cast().empty(); } -DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op, - unsigned index) { - ArrayAttr attrs = op->getAttrOfType(getArgDictAttrName()); +DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, + unsigned index) { + ArrayAttr attrs = op.getArgAttrsAttr(); DictionaryAttr argAttrs = attrs ? attrs[index].cast() : DictionaryAttr(); return argAttrs; } DictionaryAttr -mlir::function_interface_impl::getResultAttrDict(Operation *op, - unsigned index) { - ArrayAttr attrs = op->getAttrOfType(getResultDictAttrName()); +function_interface_impl::getResultAttrDict(FunctionOpInterface op, + unsigned index) { + ArrayAttr attrs = op.getResAttrsAttr(); DictionaryAttr resAttrs = attrs ? attrs[index].cast() : DictionaryAttr(); return resAttrs; } -void mlir::function_interface_impl::detail::setArgResAttrDict( - Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index, - DictionaryAttr attrs) { - ArrayAttr allAttrs = op->getAttrOfType(attrName); +ArrayRef +function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) { + auto argDict = getArgAttrDict(op, index); + return argDict ? 
argDict.getValue() : std::nullopt; +} + +ArrayRef +function_interface_impl::getResultAttrs(FunctionOpInterface op, + unsigned index) { + auto resultDict = getResultAttrDict(op, index); + return resultDict ? resultDict.getValue() : std::nullopt; +} + +/// Get either the argument or result attributes array. +template +static ArrayAttr getArgResAttrs(FunctionOpInterface op) { + if constexpr (isArg) + return op.getArgAttrsAttr(); + else + return op.getResAttrsAttr(); +} + +/// Set either the argument or result attributes array. +template +static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) { + if constexpr (isArg) + op.setArgAttrsAttr(attrs); + else + op.setResAttrsAttr(attrs); +} + +/// Erase either the argument or result attributes array. +template +static void removeArgResAttrs(FunctionOpInterface op) { + if constexpr (isArg) + op.removeArgAttrsAttr(); + else + op.removeResAttrsAttr(); +} + +/// Set all of the argument or result attribute dictionaries for a function. +template +static void setAllArgResAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + if (llvm::all_of(attrs, isEmptyAttrDict)) + removeArgResAttrs(op); + else + setArgResAttrs(op, ArrayAttr::get(op->getContext(), attrs)); +} + +void function_interface_impl::setAllArgAttrDicts( + FunctionOpInterface op, ArrayRef attrs) { + setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} + +void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? 
DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); +} + +void function_interface_impl::setAllResultAttrDicts( + FunctionOpInterface op, ArrayRef attrs) { + setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +} + +void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op, + ArrayRef attrs) { + auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { + return !attr ? DictionaryAttr::get(op->getContext()) : attr; + }); + setAllArgResAttrDicts(op, llvm::to_vector<8>(wrappedAttrs)); +} + +/// Update the given index into an argument or result attribute dictionary. +template +static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, + unsigned index, DictionaryAttr attrs) { + ArrayAttr allAttrs = getArgResAttrs(op); if (!allAttrs) { if (attrs.empty()) return; @@ -53,7 +130,7 @@ SmallVector newAttrs(numTotalIndices, DictionaryAttr::get(op->getContext())); newAttrs[index] = attrs; - op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); + setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); return; } // Check to see if the attribute is different from what we already have. @@ -65,54 +142,52 @@ ArrayRef rawAttrArray = allAttrs.getValue(); if (attrs.empty() && llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) && - llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) { - op->removeAttr(attrName); - return; - } + llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) + return removeArgResAttrs(op); // Otherwise, create a new attribute array with the updated dictionary. SmallVector newAttrs(rawAttrArray.begin(), rawAttrArray.end()); newAttrs[index] = attrs; - op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs)); + setArgResAttrs(op, ArrayAttr::get(op->getContext(), newAttrs)); } -/// Set all of the argument or result attribute dictionaries for a function. 
-static void setAllArgResAttrDicts(Operation *op, StringRef attrName, - ArrayRef attrs) { - if (llvm::all_of(attrs, isEmptyAttrDict)) - op->removeAttr(attrName); - else - op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs)); +void function_interface_impl::setArgAttrs(FunctionOpInterface op, + unsigned index, + ArrayRef attributes) { + assert(index < op.getNumArguments() && "invalid argument number"); + return setArgResAttrDict( + op, op.getNumArguments(), index, + DictionaryAttr::get(op->getContext(), attributes)); } -void mlir::function_interface_impl::setAllArgAttrDicts( - Operation *op, ArrayRef attrs) { - setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); -} -void mlir::function_interface_impl::setAllArgAttrDicts( - Operation *op, ArrayRef attrs) { - auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { - return !attr ? DictionaryAttr::get(op->getContext()) : attr; - }); - setAllArgResAttrDicts(op, getArgDictAttrName(), - llvm::to_vector<8>(wrappedAttrs)); +void function_interface_impl::setArgAttrs(FunctionOpInterface op, + unsigned index, + DictionaryAttr attributes) { + return setArgResAttrDict( + op, op.getNumArguments(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); } -void mlir::function_interface_impl::setAllResultAttrDicts( - Operation *op, ArrayRef attrs) { - setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); +void function_interface_impl::setResultAttrs( + FunctionOpInterface op, unsigned index, + ArrayRef attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return setArgResAttrDict( + op, op.getNumResults(), index, + DictionaryAttr::get(op->getContext(), attributes)); } -void mlir::function_interface_impl::setAllResultAttrDicts( - Operation *op, ArrayRef attrs) { - auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { - return !attr ? 
DictionaryAttr::get(op->getContext()) : attr; - }); - setAllArgResAttrDicts(op, getResultDictAttrName(), - llvm::to_vector<8>(wrappedAttrs)); + +void function_interface_impl::setResultAttrs(FunctionOpInterface op, + unsigned index, + DictionaryAttr attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return setArgResAttrDict( + op, op.getNumResults(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); } -void mlir::function_interface_impl::insertFunctionArguments( - Operation *op, ArrayRef argIndices, TypeRange argTypes, +void function_interface_impl::insertFunctionArguments( + FunctionOpInterface op, ArrayRef argIndices, TypeRange argTypes, ArrayRef argAttrs, ArrayRef argLocs, unsigned originalNumArgs, Type newType) { assert(argIndices.size() == argTypes.size()); @@ -128,7 +203,7 @@ Block &entry = op->getRegion(0).front(); // Update the argument attributes of the function. - auto oldArgAttrs = op->getAttrOfType(getArgDictAttrName()); + ArrayAttr oldArgAttrs = op.getArgAttrsAttr(); if (oldArgAttrs || !argAttrs.empty()) { SmallVector newArgAttrs; newArgAttrs.reserve(originalNumArgs + argIndices.size()); @@ -152,15 +227,15 @@ } // Update the function type and any entry block arguments. 
- op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); for (unsigned i = 0, e = argIndices.size(); i < e; ++i) entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]); } -void mlir::function_interface_impl::insertFunctionResults( - Operation *op, ArrayRef resultIndices, TypeRange resultTypes, - ArrayRef resultAttrs, unsigned originalNumResults, - Type newType) { +void function_interface_impl::insertFunctionResults( + FunctionOpInterface op, ArrayRef resultIndices, + TypeRange resultTypes, ArrayRef resultAttrs, + unsigned originalNumResults, Type newType) { assert(resultIndices.size() == resultTypes.size()); assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty()); if (resultIndices.empty()) @@ -171,7 +246,7 @@ // - Result attrs. // Update the result attributes of the function. - auto oldResultAttrs = op->getAttrOfType(getResultDictAttrName()); + ArrayAttr oldResultAttrs = op.getResAttrsAttr(); if (oldResultAttrs || !resultAttrs.empty()) { SmallVector newResultAttrs; newResultAttrs.reserve(originalNumResults + resultIndices.size()); @@ -196,11 +271,11 @@ } // Update the function type. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); } -void mlir::function_interface_impl::eraseFunctionArguments( - Operation *op, const BitVector &argIndices, Type newType) { +void function_interface_impl::eraseFunctionArguments( + FunctionOpInterface op, const BitVector &argIndices, Type newType) { // There are 3 things that need to be updated: // - Function type. // - Arg attrs. @@ -208,7 +283,7 @@ Block &entry = op->getRegion(0).front(); // Update the argument attributes of the function. 
- if (auto argAttrs = op->getAttrOfType(getArgDictAttrName())) { + if (ArrayAttr argAttrs = op.getArgAttrsAttr()) { SmallVector newArgAttrs; newArgAttrs.reserve(argAttrs.size()); for (unsigned i = 0, e = argIndices.size(); i < e; ++i) @@ -218,18 +293,18 @@ } // Update the function type and any entry block arguments. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); entry.eraseArguments(argIndices); } -void mlir::function_interface_impl::eraseFunctionResults( - Operation *op, const BitVector &resultIndices, Type newType) { +void function_interface_impl::eraseFunctionResults( + FunctionOpInterface op, const BitVector &resultIndices, Type newType) { // There are 2 things that need to be updated: // - Function type. // - Result attrs. // Update the result attributes of the function. - if (auto resAttrs = op->getAttrOfType(getResultDictAttrName())) { + if (ArrayAttr resAttrs = op.getResAttrsAttr()) { SmallVector newResultAttrs; newResultAttrs.reserve(resAttrs.size()); for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) @@ -239,10 +314,10 @@ } // Update the function type. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); } -TypeRange mlir::function_interface_impl::insertTypesInto( +TypeRange function_interface_impl::insertTypesInto( TypeRange oldTypes, ArrayRef indices, TypeRange newTypes, SmallVectorImpl &storage) { assert(indices.size() == newTypes.size() && @@ -261,7 +336,7 @@ return storage; } -TypeRange mlir::function_interface_impl::filterTypesOut( +TypeRange function_interface_impl::filterTypesOut( TypeRange types, const BitVector &indices, SmallVectorImpl &storage) { if (indices.none()) return types; @@ -276,45 +351,41 @@ // Function type signature. 
//===----------------------------------------------------------------------===// -void mlir::function_interface_impl::setFunctionType(Operation *op, - Type newType) { - FunctionOpInterface funcOp = cast(op); - unsigned oldNumArgs = funcOp.getNumArguments(); - unsigned oldNumResults = funcOp.getNumResults(); - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); - unsigned newNumArgs = funcOp.getNumArguments(); - unsigned newNumResults = funcOp.getNumResults(); +void function_interface_impl::setFunctionType(FunctionOpInterface op, + Type newType) { + unsigned oldNumArgs = op.getNumArguments(); + unsigned oldNumResults = op.getNumResults(); + op.setFunctionTypeAttr(TypeAttr::get(newType)); + unsigned newNumArgs = op.getNumArguments(); + unsigned newNumResults = op.getNumResults(); // Functor used to update the argument and result attributes of the function. - auto updateAttrFn = [&](StringRef attrName, unsigned oldCount, - unsigned newCount, auto setAttrFn) { + auto emptyDict = DictionaryAttr::get(op.getContext()); + auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) { + constexpr bool isArgVal = std::is_same_v; + if (oldCount == newCount) return; // The new type has no arguments/results, just drop the attribute. - if (newCount == 0) { - op->removeAttr(attrName); - return; - } - ArrayAttr attrs = op->getAttrOfType(attrName); + if (newCount == 0) + return removeArgResAttrs(op); + ArrayAttr attrs = getArgResAttrs(op); if (!attrs) return; // The new type has less arguments/results, take the first N attributes. if (newCount < oldCount) - return setAttrFn(op, attrs.getValue().take_front(newCount)); + return setAllArgResAttrDicts( + op, attrs.getValue().take_front(newCount)); // Otherwise, the new type has more arguments/results. Initialize the new - // arguments/results with empty attributes. + // arguments/results with empty dictionary attributes. 
SmallVector newAttrs(attrs.begin(), attrs.end()); - newAttrs.resize(newCount); - setAttrFn(op, newAttrs); + newAttrs.resize(newCount, emptyDict); + setAllArgResAttrDicts(op, newAttrs); }; // Update the argument and result attributes. - updateAttrFn( - getArgDictAttrName(), oldNumArgs, newNumArgs, - [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); }); - updateAttrFn( - getResultDictAttrName(), oldNumResults, newNumResults, - [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); }); + updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs); + updateAttrFn(std::false_type{}, oldNumResults, newNumResults); } diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir --- a/mlir/test/IR/invalid-func-op.mlir +++ b/mlir/test/IR/invalid-func-op.mlir @@ -96,20 +96,11 @@ // ----- -// expected-error@+1 {{argument attribute array `arg_attrs` to have the same number of elements as the number of function arguments}} +// expected-error@+1 {{argument attribute array to have the same number of elements as the number of function arguments}} func.func private @invalid_arg_attrs() attributes { arg_attrs = [{}] } // ----- -// expected-error@+1 {{expects argument attribute dictionary to be a DictionaryAttr, but got `10 : i64`}} -func.func private @invalid_arg_attrs(i32) attributes { arg_attrs = [10] } -// ----- - -// expected-error@+1 {{result attribute array `res_attrs` to have the same number of elements as the number of function results}} +// expected-error@+1 {{result attribute array to have the same number of elements as the number of function results}} func.func private @invalid_res_attrs() attributes { res_attrs = [{}] } - -// ----- - -// expected-error@+1 {{expects result attribute dictionary to be a DictionaryAttr, but got `10 : i64`}} -func.func private @invalid_res_attrs() -> i32 attributes { res_attrs = [10] }