diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -211,7 +211,8 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(mlir::OpAsmPrinter &p) { diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -198,7 +198,8 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(mlir::OpAsmPrinter &p) { diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -287,7 +287,8 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(mlir::OpAsmPrinter &p) { diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -287,7 +287,8 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(mlir::OpAsmPrinter &p) { diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -287,7 +287,8 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(mlir::OpAsmPrinter &p) { diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -314,7 +314,8 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(mlir::OpAsmPrinter &p) { diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -69,17 +69,19 @@ /// Parser implementation for function-like operations. Uses /// `funcTypeBuilder` to construct the custom function type given lists of -/// input and output types. If `allowVariadic` is set, the parser will accept +/// input and output types. The parser sets the `typeAttrName` attribute to the +/// resulting function type. If `allowVariadic` is set, the parser will accept /// trailing ellipsis in the function signature and indicate to the builder /// whether the function is variadic. If the builder returns a null type, /// `result` will not contain the `type` attribute. The caller can then add a /// type, report the error or delegate the reporting to the op's verifier. ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, - bool allowVariadic, + bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder); /// Printer implementation for function-like operations. -void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic); +void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, + StringRef typeAttrName); /// Prints the signature of the function-like operation `op`. Assumes `op` has /// is a FunctionOpInterface and has passed verification. @@ -92,8 +94,7 @@ /// function-like operation internally are not printed. Nothing is printed /// if all attributes are elided. Assumes `op` is a FunctionOpInterface and /// has passed verification. -void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, - unsigned numResults, +void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef elided = {}); } // namespace function_interface_impl diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h --- a/mlir/include/mlir/IR/FunctionInterfaces.h +++ b/mlir/include/mlir/IR/FunctionInterfaces.h @@ -22,12 +22,10 @@ #include "llvm/ADT/SmallString.h" namespace mlir { +class FunctionOpInterface; namespace function_interface_impl { -/// Return the name of the attribute used for function types. -inline StringRef getTypeAttrName() { return "function_type"; } - /// Return the name of the attribute used for function argument attributes. inline StringRef getArgDictAttrName() { return "arg_attrs"; } @@ -72,28 +70,29 @@ } /// Insert the specified arguments and update the function type attribute. -void insertFunctionArguments(Operation *op, ArrayRef argIndices, - TypeRange argTypes, +void insertFunctionArguments(FunctionOpInterface op, + ArrayRef argIndices, TypeRange argTypes, ArrayRef argAttrs, ArrayRef argLocs, unsigned originalNumArgs, Type newType); /// Insert the specified results and update the function type attribute. -void insertFunctionResults(Operation *op, ArrayRef resultIndices, +void insertFunctionResults(FunctionOpInterface op, + ArrayRef resultIndices, TypeRange resultTypes, ArrayRef resultAttrs, unsigned originalNumResults, Type newType); /// Erase the specified arguments and update the function type attribute. -void eraseFunctionArguments(Operation *op, const BitVector &argIndices, +void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices, Type newType); /// Erase the specified results and update the function type attribute. -void eraseFunctionResults(Operation *op, const BitVector &resultIndices, - Type newType); +void eraseFunctionResults(FunctionOpInterface op, + const BitVector &resultIndices, Type newType); /// Set a FunctionOpInterface operation's type signature. -void setFunctionType(Operation *op, Type newType); +void setFunctionType(FunctionOpInterface op, Type newType); /// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any /// types are inserted, `storage` is used to hold the new type list. The new @@ -207,10 +206,6 @@ /// method on FunctionOpInterface::Trait. template LogicalResult verifyTrait(ConcreteOp op) { - if (!op.getFunctionTypeAttr()) - return op.emitOpError("requires a type attribute '") - << function_interface_impl::getTypeAttrName() << '\''; - if (failed(op.verifyType())) return failure(); diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td --- a/mlir/include/mlir/IR/FunctionInterfaces.td +++ b/mlir/include/mlir/IR/FunctionInterfaces.td @@ -49,6 +49,16 @@ for each of the function results. }]; let methods = [ + InterfaceMethod<[{ + Returns the type of the function. + }], + "::mlir::Type", "getFunctionType">, + InterfaceMethod<[{ + Set the type of the function. This method should perform an unsafe + modification to the function type; it should not update argument or + result attributes. + }], + "void", "setFunctionTypeAttr", (ins "::mlir::TypeAttr":$type)>, InterfaceMethod<[{ Returns the function argument types based exclusively on the type (to allow for this method may be called on function @@ -139,7 +149,7 @@ ArrayRef attrs, TypeRange inputTypes) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(function_interface_impl::getTypeAttrName(), + state.addAttribute(ConcreteOp::getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); @@ -244,11 +254,6 @@ // the derived operation, which should already have these defined // (via ODS). - /// Returns the name of the attribute used for function types. - static StringRef getTypeAttrName() { - return function_interface_impl::getTypeAttrName(); - } - /// Returns the name of the attribute used for function argument attributes. static StringRef getArgDictAttrName() { return function_interface_impl::getArgDictAttrName(); @@ -259,15 +264,6 @@ return function_interface_impl::getResultDictAttrName(); } - /// Return the attribute containing the type of this function. - TypeAttr getFunctionTypeAttr() { - return this->getOperation()->template getAttrOfType( - getTypeAttrName()); - } - - /// Return the type of this function. - Type getFunctionType() { return getFunctionTypeAttr().getValue(); } - //===------------------------------------------------------------------===// // Argument and Result Handling //===------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -59,12 +59,11 @@ /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. -static void filterFuncAttributes(ArrayRef attrs, - bool filterArgAndResAttrs, +static void filterFuncAttributes(func::FuncOp func, bool filterArgAndResAttrs, SmallVectorImpl &result) { - for (const auto &attr : attrs) { + for (const NamedAttribute &attr : func->getAttrs()) { if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == FunctionOpInterface::getTypeAttrName() || + attr.getName() == func.getFunctionTypeAttrName() || attr.getName() == "func.varargs" || (filterArgAndResAttrs && (attr.getName() == FunctionOpInterface::getArgDictAttrName() || @@ -138,8 +137,7 @@ LLVM::LLVMFuncOp newFuncOp) { auto type = funcOp.getFunctionType(); SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, - attributes); + filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes); auto [wrapperFuncType, resultIsNowArg] = typeConverter.convertFunctionTypeCWrapper(type); if (resultIsNowArg) @@ -204,8 +202,7 @@ assert(wrapperType && "unexpected type conversion failure"); SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/false, - attributes); + filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/false, attributes); if (resultIsNowArg) prependResAttrsToArgAttrs(builder, attributes, funcOp.getNumArguments()); @@ -304,8 +301,7 @@ // Propagate argument/result attributes to all converted arguments/result // obtained after converting a given original argument/result. SmallVector attributes; - filterFuncAttributes(funcOp->getAttrs(), /*filterArgAndResAttrs=*/true, - attributes); + filterFuncAttributes(funcOp, /*filterArgAndResAttrs=*/true, attributes); if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) { assert(!resAttrDicts.empty() && "expected array to be non-empty"); auto newResAttrDicts = diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -60,7 +60,7 @@ SmallVector attributes; for (const auto &attr : gpuFuncOp->getAttrs()) { if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == FunctionOpInterface::getTypeAttrName() || + attr.getName() == gpuFuncOp.getFunctionTypeAttrName() || attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) continue; attributes.push_back(attr); diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -226,7 +226,7 @@ rewriter.getFunctionType(signatureConverter.getConvertedTypes(), std::nullopt)); for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() || + if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() || namedAttr.getName() == SymbolTable::getSymbolAttrName()) continue; newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -332,8 +332,7 @@ ArrayRef argAttrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(FunctionOpInterface::getTypeAttrName(), - TypeAttr::get(type)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); @@ -352,11 +351,13 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, + getFunctionTypeAttrName()); } /// Check that the result type of async.func is not void and must be diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -244,8 +244,7 @@ ArrayRef argAttrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(FunctionOpInterface::getTypeAttrName(), - TypeAttr::get(type)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); @@ -263,11 +262,13 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, + getFunctionTypeAttrName()); } /// Clone the internal blocks from this function into dest and all attributes diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -859,7 +859,8 @@ ArrayRef attrs) { result.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(type)); result.addAttribute(getNumWorkgroupAttributionsAttrName(), builder.getI64IntegerAttr(workgroupAttributions.size())); result.addAttributes(attrs); @@ -930,7 +931,8 @@ for (auto &arg : entryArgs) argTypes.push_back(arg.type); auto type = builder.getFunctionType(argTypes, resultTypes); - result.addAttribute(GPUFuncOp::getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(getFunctionTypeAttrName(result.name), + TypeAttr::get(type)); function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs, resultAttrs); @@ -992,19 +994,14 @@ p << ' ' << getKernelKeyword(); function_interface_impl::printFunctionAttributes( - p, *this, type.getNumInputs(), type.getNumResults(), + p, *this, {getNumWorkgroupAttributionsAttrName(), - GPUDialect::getKernelFuncAttrName()}); + GPUDialect::getKernelFuncAttrName(), getFunctionTypeAttrName()}); p << ' '; p.printRegion(getBody(), /*printEntryBlockArgs=*/false); } LogicalResult GPUFuncOp::verifyType() { - Type type = getFunctionTypeAttr().getValue(); - if (!type.isa()) - return emitOpError("requires '" + getTypeAttrName() + - "' attribute of function type"); - if (isKernel() && getFunctionType().getNumResults() != 0) return emitOpError() << "expected void return type for kernel function"; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2090,7 +2090,7 @@ function_interface_impl::VariadicFlag(isVariadic)); if (!type) return failure(); - result.addAttribute(FunctionOpInterface::getTypeAttrName(), + result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(type)); if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) @@ -2130,8 +2130,8 @@ function_interface_impl::printFunctionSignature(p, *this, argTypes, isVarArg(), resTypes); function_interface_impl::printFunctionAttributes( - p, *this, argTypes.size(), resTypes.size(), - {getLinkageAttrName(), getCConvAttrName()}); + p, *this, + {getFunctionTypeAttrName(), getLinkageAttrName(), getCConvAttrName()}); // Print the body if this is not an external function. Region &body = getBody(); diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp @@ -152,11 +152,13 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, + getFunctionTypeAttrName()); } //===----------------------------------------------------------------------===// @@ -313,11 +315,13 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void SubgraphOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, + getFunctionTypeAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -220,11 +220,13 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, + getFunctionTypeAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2382,7 +2382,7 @@ for (auto &arg : entryArgs) argTypes.push_back(arg.type); auto fnType = builder.getFunctionType(argTypes, resultTypes); - result.addAttribute(FunctionOpInterface::getTypeAttrName(), + result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(fnType)); // Parse the optional function control keyword. @@ -2417,8 +2417,9 @@ printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl()) << "\""; function_interface_impl::printFunctionAttributes( - printer, *this, fnType.getNumInputs(), fnType.getNumResults(), - {spirv::attributeName()}); + printer, *this, + {spirv::attributeName(), + getFunctionTypeAttrName(), getFunctionControlAttrName()}); // Print the body if this is not an external function. Region &body = this->getBody(); @@ -2430,10 +2431,6 @@ } LogicalResult spirv::FuncOp::verifyType() { - auto type = getFunctionTypeAttr().getValue(); - if (!type.isa()) - return emitOpError("requires '" + getTypeAttrName() + - "' attribute of function type"); if (getFunctionType().getNumResults() > 1) return emitOpError("cannot have more than one result"); return success(); @@ -2473,7 +2470,7 @@ ArrayRef attrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - state.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.addAttribute(spirv::attributeName(), builder.getAttr(control)); state.attributes.append(attrs.begin(), attrs.end()); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -531,7 +531,7 @@ // Copy over all attributes other than the function name and type. for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() && + if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && namedAttr.getName() != SymbolTable::getSymbolAttrName()) newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1311,11 +1311,13 @@ std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, buildFuncType); + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); + function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false, + getFunctionTypeAttrName()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -163,7 +163,7 @@ ParseResult mlir::function_interface_impl::parseFunctionOp( OpAsmParser &parser, OperationState &result, bool allowVariadic, - FuncTypeBuilder funcTypeBuilder) { + StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder) { SmallVector entryArgs; SmallVector resultAttrs; SmallVector resultTypes; @@ -197,7 +197,7 @@ << "failed to construct function type" << (errorMessage.empty() ? "" : ": ") << errorMessage; } - result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(typeAttrName, TypeAttr::get(type)); // If function attributes are present, parse them. NamedAttrList parsedAttributes; @@ -209,7 +209,7 @@ // dictionary. for (StringRef disallowed : {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), - getTypeAttrName()}) { + typeAttrName.getValue()}) { if (parsedAttributes.get(disallowed)) return parser.emitError(attributeDictLocation, "'") << disallowed @@ -301,12 +301,11 @@ } void mlir::function_interface_impl::printFunctionAttributes( - OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, - ArrayRef elided) { + OpAsmPrinter &p, Operation *op, ArrayRef elided) { // Print out function attributes, if present. - SmallVector ignoredAttrs = { - ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName(), - getArgDictAttrName(), getResultDictAttrName()}; + SmallVector ignoredAttrs = {SymbolTable::getSymbolAttrName(), + getArgDictAttrName(), + getResultDictAttrName()}; ignoredAttrs.append(elided.begin(), elided.end()); p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); @@ -314,7 +313,8 @@ void mlir::function_interface_impl::printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, - bool isVariadic) { + bool isVariadic, + StringRef typeAttrName) { // Print the operation and the function name. auto funcName = op->getAttrOfType(SymbolTable::getSymbolAttrName()) @@ -329,8 +329,7 @@ ArrayRef argTypes = op.getArgumentTypes(); ArrayRef resultTypes = op.getResultTypes(); printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); - printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(), - {visibilityAttrName}); + printFunctionAttributes(p, op, {visibilityAttrName, typeAttrName}); // Print the body if this is not an external function. Region &body = op->getRegion(0); if (!body.empty()) { diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp --- a/mlir/lib/IR/FunctionInterfaces.cpp +++ b/mlir/lib/IR/FunctionInterfaces.cpp @@ -112,7 +112,7 @@ } void mlir::function_interface_impl::insertFunctionArguments( - Operation *op, ArrayRef argIndices, TypeRange argTypes, + FunctionOpInterface op, ArrayRef argIndices, TypeRange argTypes, ArrayRef argAttrs, ArrayRef argLocs, unsigned originalNumArgs, Type newType) { assert(argIndices.size() == argTypes.size()); @@ -152,15 +152,15 @@ } // Update the function type and any entry block arguments. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); for (unsigned i = 0, e = argIndices.size(); i < e; ++i) entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]); } void mlir::function_interface_impl::insertFunctionResults( - Operation *op, ArrayRef resultIndices, TypeRange resultTypes, - ArrayRef resultAttrs, unsigned originalNumResults, - Type newType) { + FunctionOpInterface op, ArrayRef resultIndices, + TypeRange resultTypes, ArrayRef resultAttrs, + unsigned originalNumResults, Type newType) { assert(resultIndices.size() == resultTypes.size()); assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty()); if (resultIndices.empty()) @@ -196,11 +196,11 @@ } // Update the function type. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); } void mlir::function_interface_impl::eraseFunctionArguments( - Operation *op, const BitVector &argIndices, Type newType) { + FunctionOpInterface op, const BitVector &argIndices, Type newType) { // There are 3 things that need to be updated: // - Function type. // - Arg attrs. @@ -218,12 +218,12 @@ } // Update the function type and any entry block arguments. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); entry.eraseArguments(argIndices); } void mlir::function_interface_impl::eraseFunctionResults( - Operation *op, const BitVector &resultIndices, Type newType) { + FunctionOpInterface op, const BitVector &resultIndices, Type newType) { // There are 2 things that need to be updated: // - Function type. // - Result attrs. @@ -239,7 +239,7 @@ } // Update the function type. - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + op.setFunctionTypeAttr(TypeAttr::get(newType)); } TypeRange mlir::function_interface_impl::insertTypesInto( @@ -276,14 +276,13 @@ // Function type signature. //===----------------------------------------------------------------------===// -void mlir::function_interface_impl::setFunctionType(Operation *op, +void mlir::function_interface_impl::setFunctionType(FunctionOpInterface op, Type newType) { - FunctionOpInterface funcOp = cast(op); - unsigned oldNumArgs = funcOp.getNumArguments(); - unsigned oldNumResults = funcOp.getNumResults(); - op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); - unsigned newNumArgs = funcOp.getNumArguments(); - unsigned newNumResults = funcOp.getNumResults(); + unsigned oldNumArgs = op.getNumArguments(); + unsigned oldNumResults = op.getNumResults(); + op.setFunctionTypeAttr(TypeAttr::get(newType)); + unsigned newNumArgs = op.getNumArguments(); + unsigned newNumResults = op.getNumResults(); // Functor used to update the argument and result attributes of the function. auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,