diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md --- a/mlir/docs/Traits.md +++ b/mlir/docs/Traits.md @@ -253,34 +253,6 @@ trait. In particular, broadcasting behavior is not allowed. See the comments on `OpTrait::ElementwiseMappable` for the precise requirements. -### Function-Like - -* `OpTrait::FunctionLike` - -This trait provides APIs for operations that behave like functions. In -particular: - -- Ops must be symbols, i.e. also have the `Symbol` trait; -- Ops have a single region with multiple blocks that corresponds to the body - of the function; -- An op with a single empty region corresponds to an external function; -- arguments of the first block of the region are treated as function - arguments; -- they can have argument and result attributes that are stored in dictionary - attributes on the operation itself. - -This trait provides limited type support for the declared or defined functions. -The convenience function `getTypeAttrName()` returns the name of an attribute -that can be used to store the function type. In addition, this trait provides -`getType` and `setType` helpers to store a `FunctionType` in the attribute named -by `getTypeAttrName()`. - -In general, this trait assumes concrete ops use `FunctionType` under the hood. -If this is not the case, in order to use the function type support, concrete ops -must define the following methods, using the same name, to hide the ones defined -for `FunctionType`: `addBodyBlock`, `getType`, `getTypeWithoutArgsAndResults` -and `setType`. 
- ### HasParent * `OpTrait::HasParent` -- `HasParent` or diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h --- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h @@ -19,7 +19,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -17,6 +17,7 @@ include "mlir/Dialect/GPU/GPUBase.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/IR/EnumAttr.td" +include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/DataLayoutInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -145,9 +146,10 @@ let verifier = [{ return success(); }]; } -def GPU_GPUFuncOp : GPU_Op<"func", [HasParent<"GPUModuleOp">, - AutomaticAllocationScope, FunctionLike, - IsolatedFromAbove, Symbol]> { +def GPU_GPUFuncOp : GPU_Op<"func", [ + HasParent<"GPUModuleOp">, AutomaticAllocationScope, FunctionOpInterface, + IsolatedFromAbove, Symbol + ]> { let summary = "Function executable on a GPU"; let description = [{ @@ -273,19 +275,24 @@ return "workgroup_attributions"; } - // FunctionLike trait needs access to the functions below. - friend class OpTrait::FunctionLike; + /// Returns the type of this function. + /// FIXME: We should drive this via the ODS `type` param. + FunctionType getType() { + return getTypeAttr().getValue().cast(); + } + + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getType().getInputs(); } - /// Hooks for the input/output type enumeration in FunctionLike . 
- unsigned getNumFuncArguments() { return getType().getNumInputs(); } - unsigned getNumFuncResults() { return getType().getNumResults(); } + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getType().getResults(); } /// Returns the keywords used in the custom syntax for this Op. static StringRef getWorkgroupKeyword() { return "workgroup"; } static StringRef getPrivateKeyword() { return "private"; } static StringRef getKernelKeyword() { return "kernel"; } - /// Hook for FunctionLike verifier. + /// Hook for FunctionOpInterface verifier. LogicalResult verifyType(); /// Verifies the body of the function. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -15,6 +15,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td" +include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -1237,8 +1238,9 @@ let assemblyFormat = "attr-dict"; } -def LLVM_LLVMFuncOp : LLVM_Op<"func", - [AutomaticAllocationScope, IsolatedFromAbove, FunctionLike, Symbol]> { +def LLVM_LLVMFuncOp : LLVM_Op<"func", [ + AutomaticAllocationScope, IsolatedFromAbove, FunctionOpInterface, Symbol + ]> { let summary = "LLVM dialect function."; let description = [{ @@ -1292,24 +1294,19 @@ Block *addEntryBlock(); LLVMFunctionType getType() { - return (*this)->getAttrOfType(getTypeAttrName()) - .getValue().cast(); - } - bool isVarArg() { - return getType().isVarArg(); + return getTypeAttr().getValue().cast(); } + bool isVarArg() { return getType().isVarArg(); } - // Hook for OpTrait::FunctionLike, returns the number of function arguments`. - // Depends on the type attribute being correct as checked by verifyType. 
- unsigned getNumFuncArguments(); + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getType().getParams(); } - // Hook for OpTrait::FunctionLike, returns the number of function results. - // Depends on the type attribute being correct as checked by verifyType. - unsigned getNumFuncResults(); + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getType().getReturnTypes(); } - // Hook for OpTrait::FunctionLike, called after verifying that the 'type' - // attribute is present. This can check for preconditions of the - // getNumArguments hook not failing. + /// Hook for FunctionOpInterface, called after verifying that the 'type' + /// attribute is present. This can check for preconditions of the + /// getNumArguments hook not failing. LogicalResult verifyType(); }]; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -135,7 +135,7 @@ static bool isValidResultType(Type type); /// Returns whether the function is variadic. - bool isVarArg(); + bool isVarArg() const; /// Gets or creates an instance of LLVM dialect function in the same context /// as the `result` type. @@ -145,9 +145,17 @@ getChecked(function_ref emitError, Type result, ArrayRef arguments, bool isVarArg = false); + /// Returns a clone of this function type with the given argument + /// and result types. + LLVMFunctionType clone(TypeRange inputs, TypeRange results) const; + /// Returns the result type of the function. Type getReturnType(); + /// Returns the result type of the function as an ArrayRef, enabling better + /// integration with generic MLIR utilities. + ArrayRef getReturnTypes(); + /// Returns the number of arguments to the function. 
unsigned getNumParams(); diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -17,7 +17,8 @@ Convert ops with the `ElementwiseMappable` trait to linalg parallel loops. This pass only converts ops that operate on ranked tensors. It can be - run on op which contains linalg ops (most commonly a FunctionLike op). + run on an op which contains linalg ops (most commonly a + FunctionOpInterface op). }]; let constructor = "mlir::createConvertElementwiseToLinalgPass()"; let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"]; @@ -240,9 +241,9 @@ control flow inside a function. All blocks except for the entry block are detensored by converting their arguments whenever possible. - This can be run on any op with the FunctionLike trait and must not be + This can be run on any FunctionOpInterface op and must not be run on others. This is because it performs specific legalization of the - blocks that make up the body, which it assumes has a FunctionLike trait. + blocks making up the body, and assumes the op is a FunctionOpInterface. }]; let options = [ Option<"aggressiveMode", "aggressive-mode", "bool", /*default=*/"false", diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3964,7 +3964,7 @@ // Check that an op can only be used within the scope of a function-like op. def InFunctionScope : PredOpTrait< "op must appear in a function-like op's block", - CPred<"isNestedInFunctionLikeOp($_op.getParentOp())">>; + CPred<"isNestedInFunctionOpInterface($_op.getParentOp())">>; // Check that an op can only be used within the scope of a module-like op.
def InModuleScope : PredOpTrait< diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -16,6 +16,7 @@ #define MLIR_DIALECT_SPIRV_IR_STRUCTURE_OPS include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" +include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" @@ -253,7 +254,7 @@ def SPV_FuncOp : SPV_Op<"func", [ AutomaticAllocationScope, DeclareOpInterfaceMethods, - FunctionLike, InModuleScope, IsolatedFromAbove, Symbol + FunctionOpInterface, InModuleScope, IsolatedFromAbove, Symbol ]> { let summary = "Declare or define a function"; @@ -307,22 +308,24 @@ let autogenSerialization = 0; let extraClassDeclaration = [{ - private: - // This trait needs access to the hooks defined below. - friend class OpTrait::FunctionLike; + /// Returns the type of this function. + /// FIXME: We should drive this via the ODS `type` param. + FunctionType getType() { + return getTypeAttr().getValue().cast(); + } - /// Returns the number of arguments. Hook for OpTrait::FunctionLike. - unsigned getNumFuncArguments() { return getType().getNumInputs(); } + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getType().getInputs(); } - /// Returns the number of results. Hook for OpTrait::FunctionLike. - unsigned getNumFuncResults() { return getType().getNumResults(); } + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getType().getResults(); } - /// Hook for OpTrait::FunctionLike, called after verifying that the 'type' + /// Hook for FunctionOpInterface, called after verifying that the 'type' /// attribute is present and checks if it holds a function type. 
Ensures - /// getType, getNumFuncArguments, and getNumFuncResults can be called safely + /// getType, getNumArguments, and getNumResults can be called safely LogicalResult verifyType(); - /// Hook for OpTrait::FunctionLike, called after verifying the function + /// Hook for FunctionOpInterface, called after verifying the function /// type and the presence of the (potentially empty) function body. /// Ensures SPIR-V specific semantics. LogicalResult verifyBody(); diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -13,7 +13,7 @@ #ifndef MLIR_IR_BUILTINOPS_H_ #define MLIR_IR_BUILTINOPS_H_ -#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/IR/RegionKindInterface.h" diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -15,6 +15,7 @@ #define BUILTIN_OPS include "mlir/IR/BuiltinDialect.td" +include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/RegionKindInterface.td" include "mlir/IR/SymbolInterfaces.td" @@ -32,8 +33,8 @@ //===----------------------------------------------------------------------===// def FuncOp : Builtin_Op<"func", [ - AffineScope, AutomaticAllocationScope, CallableOpInterface, FunctionLike, - IsolatedFromAbove, Symbol + AffineScope, AutomaticAllocationScope, CallableOpInterface, + FunctionOpInterface, IsolatedFromAbove, Symbol ]> { let summary = "An operation with a name containing a single `SSACFG` region"; let description = [{ @@ -109,6 +110,12 @@ /// compatible. void cloneInto(FuncOp dest, BlockAndValueMapping &mapper); + /// Returns the type of this function. + /// FIXME: We should drive this via the ODS `type` param. 
+ FunctionType getType() { + return getTypeAttr().getValue().cast(); + } + //===------------------------------------------------------------------===// // CallableOpInterface //===------------------------------------------------------------------===// @@ -123,26 +130,17 @@ ArrayRef getCallableResults() { return getType().getResults(); } //===------------------------------------------------------------------===// - // SymbolOpInterface Methods + // FunctionOpInterface Methods //===------------------------------------------------------------------===// - bool isDeclaration() { return isExternal(); } - - private: - // This trait needs access to the hooks defined below. - friend class OpTrait::FunctionLike; - - /// Returns the number of arguments. This is a hook for - /// OpTrait::FunctionLike. - unsigned getNumFuncArguments() { return getType().getInputs().size(); } + /// Returns the argument types of this function. + ArrayRef getArgumentTypes() { return getType().getInputs(); } - /// Returns the number of results. This is a hook for OpTrait::FunctionLike. - unsigned getNumFuncResults() { return getType().getResults().size(); } + /// Returns the result types of this function. + ArrayRef getResultTypes() { return getType().getResults(); } - /// Hook for OpTrait::FunctionLike, called after verifying that the 'type' - /// attribute is present and checks if it holds a function type. Ensures - /// getType, getNumFuncArguments, and getNumFuncResults can be called - /// safely. + /// Verify the type attribute of this function. Returns failure and emits + /// an error if the attribute is invalid. 
LogicalResult verifyType() { auto type = getTypeAttr().getValue(); if (!type.isa()) @@ -150,6 +148,12 @@ "' attribute of function type"); return success(); } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return isExternal(); } }]; let parser = [{ return ::parseFuncOp(parser, result); }]; let printer = [{ return ::print(*this, p); }]; diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -154,6 +154,10 @@ unsigned getNumResults() const; Type getResult(unsigned i) const { return getResults()[i]; } + /// Returns a clone of this function type with the given argument + /// and result types. + FunctionType clone(TypeRange inputs, TypeRange results) const; + /// Returns a new function type with the specified arguments and results /// inserted. 
FunctionType getWithArgsAndResults(ArrayRef argIndices, diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -37,6 +37,12 @@ mlir_tablegen(BuiltinTypeInterfaces.cpp.inc -gen-type-interface-defs) add_public_tablegen_target(MLIRBuiltinTypeInterfacesIncGen) +set(LLVM_TARGET_DEFINITIONS FunctionInterfaces.td) +mlir_tablegen(FunctionOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(FunctionOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRFunctionInterfacesIncGen) +add_dependencies(mlir-generic-headers MLIRFunctionInterfacesIncGen) + set(LLVM_TARGET_DEFINITIONS SubElementInterfaces.td) mlir_tablegen(SubElementAttrInterfaces.h.inc -gen-attr-interface-decls) mlir_tablegen(SubElementAttrInterfaces.cpp.inc -gen-attr-interface-defs) diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -15,12 +15,12 @@ #ifndef MLIR_IR_FUNCTIONIMPLEMENTATION_H_ #define MLIR_IR_FUNCTIONIMPLEMENTATION_H_ -#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/OpImplementation.h" namespace mlir { -namespace function_like_impl { +namespace function_interface_impl { /// A named class for passing around the variadic flag. class VariadicFlag { @@ -44,7 +44,7 @@ ArrayRef argAttrs, ArrayRef resultAttrs); -/// Callback type for `parseFunctionLikeOp`, the callback should produce the +/// Callback type for `parseFunctionOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of /// function arguments and results, VariadicFlag indicates whether the function /// should have variadic arguments; in case of error, it may populate the last @@ -81,18 +81,17 @@ /// whether the function is variadic. 
If the builder returns a null type, /// `result` will not contain the `type` attribute. The caller can then add a /// type, report the error or delegate the reporting to the op's verifier. -ParseResult parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, - bool allowVariadic, - FuncTypeBuilder funcTypeBuilder); +ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, + bool allowVariadic, + FuncTypeBuilder funcTypeBuilder); /// Printer implementation for function-like operations. Accepts lists of /// argument and result types to use while printing. -void printFunctionLikeOp(OpAsmPrinter &p, Operation *op, - ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes); +void printFunctionOp(OpAsmPrinter &p, Operation *op, ArrayRef argTypes, + bool isVariadic, ArrayRef resultTypes); -/// Prints the signature of the function-like operation `op`. Assumes `op` has -/// the FunctionLike trait and passed the verification. +/// Prints the signature of the function-like operation `op`. Assumes `op` +/// is a FunctionOpInterface and has passed verification. void printFunctionSignature(OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, ArrayRef resultTypes); @@ -100,13 +99,13 @@ /// Prints the list of function prefixed with the "attributes" keyword. The /// attributes with names listed in "elided" as well as those used by the /// function-like operation internally are not printed. Nothing is printed -/// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and -/// passed the verification. +/// if all attributes are elided. Assumes `op` is a FunctionOpInterface and +/// has passed verification.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, ArrayRef elided = {}); -} // namespace function_like_impl +} // namespace function_interface_impl } // namespace mlir diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/FunctionInterfaces.h @@ -0,0 +1,295 @@ +//===- FunctionInterfaces.h - Function Interfaces ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the FunctionOpInterface and support utilities for +// operations that represent function-like constructs. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_FUNCTIONINTERFACES_H +#define MLIR_IR_FUNCTIONINTERFACES_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/SmallString.h" + +namespace mlir { + +namespace function_interface_impl { + +/// Return the name of the attribute used for function types. +inline StringRef getTypeAttrName() { return "type"; } + +/// Return the name of the attribute used for function argument attributes. +inline StringRef getArgDictAttrName() { return "arg_attrs"; } + +/// Return the name of the attribute used for function result attributes. +inline StringRef getResultDictAttrName() { return "res_attrs"; } + +/// Returns the dictionary attribute corresponding to the argument at 'index'. +/// If there are no argument attributes at 'index', a null attribute is +/// returned. +DictionaryAttr getArgAttrDict(Operation *op, unsigned index); + +/// Returns the dictionary attribute corresponding to the result at 'index'.
+/// If there are no result attributes at 'index', a null attribute is +/// returned. +DictionaryAttr getResultAttrDict(Operation *op, unsigned index); + +namespace detail { +/// Update the given index into an argument or result attribute dictionary. +void setArgResAttrDict(Operation *op, StringRef attrName, + unsigned numTotalIndices, unsigned index, + DictionaryAttr attrs); +} // namespace detail + +/// Set all of the argument or result attribute dictionaries for a function. The +/// size of `attrs` is expected to match the number of arguments/results of the +/// given `op`. +void setAllArgAttrDicts(Operation *op, ArrayRef attrs); +void setAllArgAttrDicts(Operation *op, ArrayRef attrs); +void setAllResultAttrDicts(Operation *op, ArrayRef attrs); +void setAllResultAttrDicts(Operation *op, ArrayRef attrs); + +/// Return all of the attributes for the argument at 'index'. +inline ArrayRef getArgAttrs(Operation *op, unsigned index) { + auto argDict = getArgAttrDict(op, index); + return argDict ? argDict.getValue() : llvm::None; +} + +/// Return all of the attributes for the result at 'index'. +inline ArrayRef getResultAttrs(Operation *op, unsigned index) { + auto resultDict = getResultAttrDict(op, index); + return resultDict ? resultDict.getValue() : llvm::None; +} + +/// Insert the specified arguments and update the function type attribute. +void insertFunctionArguments(Operation *op, ArrayRef argIndices, + TypeRange argTypes, + ArrayRef argAttrs, + ArrayRef> argLocs, + unsigned originalNumArgs, Type newType); + +/// Insert the specified results and update the function type attribute. +void insertFunctionResults(Operation *op, ArrayRef resultIndices, + TypeRange resultTypes, + ArrayRef resultAttrs, + unsigned originalNumResults, Type newType); + +/// Erase the specified arguments and update the function type attribute. 
+void eraseFunctionArguments(Operation *op, ArrayRef argIndices, + unsigned originalNumArgs, Type newType); + +/// Erase the specified results and update the function type attribute. +void eraseFunctionResults(Operation *op, ArrayRef resultIndices, + unsigned originalNumResults, Type newType); + +/// Set a FunctionOpInterface operation's type signature. +void setFunctionType(Operation *op, Type newType); + +/// Insert a set of `newTypes` into `oldTypes` at the given `indices`. If any +/// types are inserted, `storage` is used to hold the new type list. The new +/// type list is returned. `indices` must be sorted by increasing index. +TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef indices, + TypeRange newTypes, SmallVectorImpl &storage); + +/// Filters out any elements referenced by `indices`. If any types are removed, +/// `storage` is used to hold the new type list. Returns the new type list. +TypeRange filterTypesOut(TypeRange types, ArrayRef indices, + SmallVectorImpl &storage); + +//===----------------------------------------------------------------------===// +// Function Argument Attribute. +//===----------------------------------------------------------------------===// + +/// Set the attributes held by the argument at 'index'. +template +void setArgAttrs(ConcreteType op, unsigned index, + ArrayRef attributes) { + assert(index < op.getNumArguments() && "invalid argument number"); + return detail::setArgResAttrDict( + op, getArgDictAttrName(), op.getNumArguments(), index, + DictionaryAttr::get(op->getContext(), attributes)); +} +template +void setArgAttrs(ConcreteType op, unsigned index, DictionaryAttr attributes) { + return detail::setArgResAttrDict( + op, getArgDictAttrName(), op.getNumArguments(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); +} + +/// If an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value.
+template +void setArgAttr(ConcreteType op, unsigned index, StringAttr name, + Attribute value) { + NamedAttrList attributes(op.getArgAttrDict(index)); + Attribute oldValue = attributes.set(name, value); + + // If the attribute changed, then set the new arg attribute list. + if (value != oldValue) + op.setArgAttrs(index, attributes.getDictionary(value.getContext())); +} + +/// Remove the attribute 'name' from the argument at 'index'. Returns the +/// removed attribute, or nullptr if no attribute named `name` was present. +template +Attribute removeArgAttr(ConcreteType op, unsigned index, StringAttr name) { + // Build an attribute list and remove the attribute at 'name'. + NamedAttrList attributes(op.getArgAttrDict(index)); + Attribute removedAttr = attributes.erase(name); + + // If the attribute was removed, then update the argument dictionary. + if (removedAttr) + op.setArgAttrs(index, attributes.getDictionary(removedAttr.getContext())); + return removedAttr; +} + +//===----------------------------------------------------------------------===// +// Function Result Attribute. +//===----------------------------------------------------------------------===// + +/// Set the attributes held by the result at 'index'. +template +void setResultAttrs(ConcreteType op, unsigned index, + ArrayRef attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return detail::setArgResAttrDict( + op, getResultDictAttrName(), op.getNumResults(), index, + DictionaryAttr::get(op->getContext(), attributes)); +} + +template +void setResultAttrs(ConcreteType op, unsigned index, + DictionaryAttr attributes) { + assert(index < op.getNumResults() && "invalid result number"); + return detail::setArgResAttrDict( + op, getResultDictAttrName(), op.getNumResults(), index, + attributes ? attributes : DictionaryAttr::get(op->getContext())); +} + +/// If an attribute exists with the specified name, change it to the new +/// value.
Otherwise, add a new attribute with the specified name/value. +template +void setResultAttr(ConcreteType op, unsigned index, StringAttr name, + Attribute value) { + NamedAttrList attributes(op.getResultAttrDict(index)); + Attribute oldAttr = attributes.set(name, value); + + // If the attribute changed, then set the new result attribute list. + if (oldAttr != value) + op.setResultAttrs(index, attributes.getDictionary(value.getContext())); +} + +/// Remove the attribute 'name' from the result at 'index'. +template +Attribute removeResultAttr(ConcreteType op, unsigned index, StringAttr name) { + // Build an attribute list and remove the attribute at 'name'. + NamedAttrList attributes(op.getResultAttrDict(index)); + Attribute removedAttr = attributes.erase(name); + + // If the attribute was removed, then update the result dictionary. + if (removedAttr) + op.setResultAttrs(index, + attributes.getDictionary(removedAttr.getContext())); + return removedAttr; +} + +/// This function defines the internal implementation of the `verifyTrait` +/// method on FunctionOpInterface::Trait.
+template +LogicalResult verifyTrait(ConcreteOp op) { + if (!op.getTypeAttr()) + return op.emitOpError("requires a type attribute '") + << ConcreteOp::getTypeAttrName() << '\''; + + if (failed(op.verifyType())) + return failure(); + + if (ArrayAttr allArgAttrs = op.getAllArgAttrs()) { + unsigned numArgs = op.getNumArguments(); + if (allArgAttrs.size() != numArgs) { + return op.emitOpError() + << "expects argument attribute array `" << getArgDictAttrName() + << "` to have the same number of elements as the number of " + "function arguments, got " + << allArgAttrs.size() << ", but expected " << numArgs; + } + for (unsigned i = 0; i != numArgs; ++i) { + DictionaryAttr argAttrs = + allArgAttrs[i].dyn_cast_or_null(); + if (!argAttrs) { + return op.emitOpError() << "expects argument attribute dictionary " + "to be a DictionaryAttr, but got `" + << allArgAttrs[i] << "`"; + } + + // Verify that all of the argument attributes are dialect attributes, i.e. + // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. 
+ for (auto attr : argAttrs) { + if (!attr.getName().strref().contains('.')) + return op.emitOpError("arguments may only have dialect attributes"); + if (Dialect *dialect = attr.getNameDialect()) { + if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, + /*argIndex=*/i, attr))) + return failure(); + } + } + } + } + if (ArrayAttr allResultAttrs = op.getAllResultAttrs()) { + unsigned numResults = op.getNumResults(); + if (allResultAttrs.size() != numResults) { + return op.emitOpError() + << "expects result attribute array `" << getResultDictAttrName() + << "` to have the same number of elements as the number of " + "function results, got " + << allResultAttrs.size() << ", but expected " << numResults; + } + for (unsigned i = 0; i != numResults; ++i) { + DictionaryAttr resultAttrs = + allResultAttrs[i].dyn_cast_or_null(); + if (!resultAttrs) { + return op.emitOpError() << "expects result attribute dictionary " + "to be a DictionaryAttr, but got `" + << allResultAttrs[i] << "`"; + } + + // Verify that all of the result attributes are dialect attributes, i.e. + // that they contain a dialect prefix in their name. Call the dialect, if + // registered, to verify the attributes themselves. + for (auto attr : resultAttrs) { + if (!attr.getName().strref().contains('.')) + return op.emitOpError("results may only have dialect attributes"); + if (Dialect *dialect = attr.getNameDialect()) { + if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, + /*resultIndex=*/i, + attr))) + return failure(); + } + } + } + } + + // Check that the op has exactly one region for the body. 
+ if (op->getNumRegions() != 1) + return op.emitOpError("expects one region"); + + return op.verifyBody(); +} +} // namespace function_interface_impl +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Tablegen Interface Declarations +//===----------------------------------------------------------------------===// + +#include "mlir/IR/FunctionOpInterfaces.h.inc" + +#endif // MLIR_IR_FUNCTIONINTERFACES_H diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/FunctionInterfaces.td @@ -0,0 +1,529 @@ +//===- FunctionInterfaces.td - Function interfaces ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions for interfaces that support the definition of +// "function-like" operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_FUNCTIONINTERFACES_TD_ +#define MLIR_IR_FUNCTIONINTERFACES_TD_ + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// FunctionOpInterface +//===----------------------------------------------------------------------===// + +def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + This interface provides support for interacting with operations that + behave like functions. In particular, these operations: + + - must be symbols, i.e. have the `Symbol` trait. + - must have a single region, possibly composed of multiple blocks, + that corresponds to the function body.
+ * when this region is empty, the operation corresponds to an external + function. + * leading arguments of the first block of the region are treated as + function arguments. + + The function, aside from implementing the various interface methods, + should have the following ODS arguments: + + - `type` (required) + * A TypeAttr that holds the signature type of the function. + * TODO: this field will soon be renamed to something less generic. + + - `arg_attrs` (optional) + * An ArrayAttr of DictionaryAttr that contains attribute dictionaries + for each of the function arguments. + + - `res_attrs` (optional) + * An ArrayAttr of DictionaryAttr that contains attribute dictionaries + for each of the function results. + }]; + let methods = [ + InterfaceMethod<[{ + Returns the function argument types based exclusively on + the type (so that this method may be called on function + declarations). + }], + "ArrayRef", "getArgumentTypes">, + InterfaceMethod<[{ + Returns the function result types based exclusively on + the type (so that this method may be called on function + declarations). + }], + "ArrayRef", "getResultTypes">, + InterfaceMethod<[{ + Returns a clone of the function type with the given argument and + result types. + + Note: The default implementation assumes the function type has + an appropriate clone method: + `Type clone(ArrayRef inputs, ArrayRef results)` + }], + "Type", "cloneTypeWith", (ins + "::mlir::TypeRange":$inputs, "::mlir::TypeRange":$results + ), /*methodBody=*/[{}], /*defaultImplementation=*/[{ + return $_op.getType().clone(inputs, results); + }]>, + + InterfaceMethod<[{ + Verify the contents of the body of this function. + + Note: The default implementation merely checks that if the entry block + exists, it has the same number of arguments as the function type.
+ }], + "::mlir::LogicalResult", "verifyBody", (ins), + /*methodBody=*/[{}], /*defaultImplementation=*/[{ + if ($_op.isExternal()) + return success(); + + unsigned numArguments = $_op.getNumArguments(); + if ($_op.front().getNumArguments() != numArguments) + return $_op.emitOpError("entry block must have ") + << numArguments << " arguments to match function signature"; + return success(); + }]>, + InterfaceMethod<[{ + Verify the type attribute of the function for derived op-specific + invariants. + }], + "::mlir::LogicalResult", "verifyType">, + ]; + + let extraSharedClassDeclaration = [{ + /// Block list iterator types. + using BlockListType = Region::BlockListType; + using iterator = BlockListType::iterator; + using reverse_iterator = BlockListType::reverse_iterator; + + /// Block argument iterator types. + using BlockArgListType = Region::BlockArgListType; + using args_iterator = BlockArgListType::iterator; + + //===------------------------------------------------------------------===// + // Body Handling + //===------------------------------------------------------------------===// + + /// Returns true if this function is external, i.e. it has no body. + bool isExternal() { return empty(); } + + /// Return the region containing the body of this function. + Region &getBody() { return $_op->getRegion(0); } + + /// Delete all blocks from this function. + void eraseBody() { + getBody().dropAllReferences(); + getBody().getBlocks().clear(); + } + + /// Return the list of blocks within the function body. + BlockListType &getBlocks() { return getBody().getBlocks(); } + + iterator begin() { return getBody().begin(); } + iterator end() { return getBody().end(); } + reverse_iterator rbegin() { return getBody().rbegin(); } + reverse_iterator rend() { return getBody().rend(); } + + /// Returns true if this function has no blocks within the body. + bool empty() { return getBody().empty(); } + + /// Push a new block to the back of the body region. 
+ void push_back(Block *block) { getBody().push_back(block); } + + /// Push a new block to the front of the body region. + void push_front(Block *block) { getBody().push_front(block); } + + /// Return the last block in the body region. + Block &back() { return getBody().back(); } + + /// Return the first block in the body region. + Block &front() { return getBody().front(); } + + /// Add an entry block to an empty function, and set up the block arguments + /// to match the signature of the function. The newly inserted entry block + /// is returned. + Block *addEntryBlock() { + assert(empty() && "function already has an entry block"); + Block *entry = new Block(); + push_back(entry); + entry->addArguments($_op.getArgumentTypes()); + return entry; + } + + /// Add a normal block to the end of the function's block list. The function + /// should at least already have an entry block. + Block *addBlock() { + assert(!empty() && "function should at least have an entry block"); + push_back(new Block()); + return &back(); + } + + //===------------------------------------------------------------------===// + // Type Attribute Handling + //===------------------------------------------------------------------===// + + /// Change the type of this function in place. This is an extremely dangerous + /// operation and it is up to the caller to ensure that this is legal for + /// this function, and to restore invariants: + /// - the entry block args must be updated to match the function params. + /// - the argument/result attributes may need an update: if the new type + /// has less parameters we drop the extra attributes, if there are more + /// parameters they won't have any attributes. + void setType(Type newType) { + function_interface_impl::setFunctionType(this->getOperation(), newType); + } + + // FIXME: These functions should be removed in favor of just forwarding to + // the derived operation, which should already have these defined + // (via ODS). 
+ + /// Returns the name of the attribute used for function types. + static StringRef getTypeAttrName() { + return function_interface_impl::getTypeAttrName(); + } + + /// Returns the name of the attribute used for function argument attributes. + static StringRef getArgDictAttrName() { + return function_interface_impl::getArgDictAttrName(); + } + + /// Returns the name of the attribute used for function result attributes. + static StringRef getResultDictAttrName() { + return function_interface_impl::getResultDictAttrName(); + } + + /// Return the attribute containing the type of this function. + TypeAttr getTypeAttr() { + return this->getOperation()->template getAttrOfType( + getTypeAttrName()); + } + + /// Return the type of this function. + Type getType() { return getTypeAttr().getValue(); } + + //===------------------------------------------------------------------===// + // Argument and Result Handling + //===------------------------------------------------------------------===// + + /// Returns the number of function arguments. + unsigned getNumArguments() { return $_op.getArgumentTypes().size(); } + + /// Returns the number of function results. + unsigned getNumResults() { return $_op.getResultTypes().size(); } + + /// Returns the entry block function argument at the given index. + BlockArgument getArgument(unsigned idx) { + return getBody().getArgument(idx); + } + + /// Support argument iteration. + args_iterator args_begin() { return getBody().args_begin(); } + args_iterator args_end() { return getBody().args_end(); } + BlockArgListType getArguments() { return getBody().getArguments(); } + + /// Insert a single argument of type `argType` with attributes `argAttrs` and + /// location `argLoc` at `argIndex`.
+ void insertArgument(unsigned argIndex, Type argType, DictionaryAttr argAttrs, + Optional argLoc = {}) { + insertArguments({argIndex}, {argType}, {argAttrs}, {argLoc}); + } + + /// Inserts arguments with the listed types, attributes, and locations at the + /// listed indices. `argIndices` must be sorted. Arguments are inserted in the + /// order they are listed, such that arguments with identical index will + /// appear in the same order that they were listed here. + void insertArguments(ArrayRef argIndices, TypeRange argTypes, + ArrayRef argAttrs, + ArrayRef> argLocs) { + unsigned originalNumArgs = $_op.getNumArguments(); + Type newType = $_op.getTypeWithArgsAndResults( + argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{}); + function_interface_impl::insertFunctionArguments( + this->getOperation(), argIndices, argTypes, argAttrs, argLocs, + originalNumArgs, newType); + } + + /// Insert a single result of type `resultType` at `resultIndex`. + void insertResult(unsigned resultIndex, Type resultType, + DictionaryAttr resultAttrs) { + insertResults({resultIndex}, {resultType}, {resultAttrs}); + } + + /// Inserts results with the listed types at the listed indices. + /// `resultIndices` must be sorted. Results are inserted in the order they are + /// listed, such that results with identical index will appear in the same + /// order that they were listed here. + void insertResults(ArrayRef resultIndices, TypeRange resultTypes, + ArrayRef resultAttrs) { + unsigned originalNumResults = $_op.getNumResults(); + Type newType = $_op.getTypeWithArgsAndResults( + /*argIndices=*/{}, /*argTypes=*/{}, resultIndices, resultTypes); + function_interface_impl::insertFunctionResults( + this->getOperation(), resultIndices, resultTypes, resultAttrs, + originalNumResults, newType); + } + + /// Erase a single argument at `argIndex`. + void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); } + + /// Erases the arguments listed in `argIndices`. 
+ /// `argIndices` is allowed to have duplicates and can be in any order. + void eraseArguments(ArrayRef argIndices) { + unsigned originalNumArgs = $_op.getNumArguments(); + Type newType = $_op.getTypeWithoutArgsAndResults(argIndices, {}); + function_interface_impl::eraseFunctionArguments(this->getOperation(), argIndices, + originalNumArgs, newType); + } + + /// Erase a single result at `resultIndex`. + void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); } + + /// Erases the results listed in `resultIndices`. + /// `resultIndices` is allowed to have duplicates and can be in any order. + void eraseResults(ArrayRef resultIndices) { + unsigned originalNumResults = $_op.getNumResults(); + Type newType = $_op.getTypeWithoutArgsAndResults({}, resultIndices); + function_interface_impl::eraseFunctionResults( + this->getOperation(), resultIndices, originalNumResults, newType); + } + + /// Return the type of this function with the specified arguments and + /// results inserted. This is used to update the function's signature in + /// the `insertArguments` and `insertResults` methods. The arrays must be + /// sorted by increasing index. + Type getTypeWithArgsAndResults( + ArrayRef argIndices, TypeRange argTypes, + ArrayRef resultIndices, TypeRange resultTypes) { + SmallVector argStorage, resultStorage; + TypeRange newArgTypes = function_interface_impl::insertTypesInto( + $_op.getArgumentTypes(), argIndices, argTypes, argStorage); + TypeRange newResultTypes = function_interface_impl::insertTypesInto( + $_op.getResultTypes(), resultIndices, resultTypes, resultStorage); + return $_op.cloneTypeWith(newArgTypes, newResultTypes); + } + + /// Return the type of this function without the specified arguments and + /// results. This is used to update the function's signature in the + /// `eraseArguments` and `eraseResults` methods. The arrays of indices are + /// allowed to have duplicates and can be in any order. 
+ Type getTypeWithoutArgsAndResults( + ArrayRef argIndices, ArrayRef resultIndices) { + SmallVector argStorage, resultStorage; + TypeRange newArgTypes = function_interface_impl::filterTypesOut( + $_op.getArgumentTypes(), argIndices, argStorage); + TypeRange newResultTypes = function_interface_impl::filterTypesOut( + $_op.getResultTypes(), resultIndices, resultStorage); + return $_op.cloneTypeWith(newArgTypes, newResultTypes); + } + + //===------------------------------------------------------------------===// + // Argument Attributes + //===------------------------------------------------------------------===// + + /// Return all of the attributes for the argument at 'index'. + ArrayRef getArgAttrs(unsigned index) { + return function_interface_impl::getArgAttrs(this->getOperation(), index); + } + + /// Return an ArrayAttr containing all argument attribute dictionaries of + /// this function, or nullptr if no arguments have attributes. + ArrayAttr getAllArgAttrs() { + return this->getOperation()->template getAttrOfType( + getArgDictAttrName()); + } + /// Return all argument attributes of this function. + void getAllArgAttrs(SmallVectorImpl &result) { + if (ArrayAttr argAttrs = getAllArgAttrs()) { + auto argAttrRange = argAttrs.template getAsRange(); + result.append(argAttrRange.begin(), argAttrRange.end()); + } else { + result.append($_op.getNumArguments(), + DictionaryAttr::get(this->getOperation()->getContext())); + } + } + + /// Return the specified attribute, if present, for the argument at 'index', + /// null otherwise. + Attribute getArgAttr(unsigned index, StringAttr name) { + auto argDict = getArgAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + Attribute getArgAttr(unsigned index, StringRef name) { + auto argDict = getArgAttrDict(index); + return argDict ? 
argDict.get(name) : nullptr; + } + + template + AttrClass getArgAttrOfType(unsigned index, StringAttr name) { + return getArgAttr(index, name).template dyn_cast_or_null(); + } + template + AttrClass getArgAttrOfType(unsigned index, StringRef name) { + return getArgAttr(index, name).template dyn_cast_or_null(); + } + + /// Set the attributes held by the argument at 'index'. + void setArgAttrs(unsigned index, ArrayRef attributes) { + function_interface_impl::setArgAttrs($_op, index, attributes); + } + + /// Set the attributes held by the argument at 'index'. `attributes` may be + /// null, in which case any existing argument attributes are removed. + void setArgAttrs(unsigned index, DictionaryAttr attributes) { + function_interface_impl::setArgAttrs($_op, index, attributes); + } + void setAllArgAttrs(ArrayRef attributes) { + assert(attributes.size() == $_op.getNumArguments()); + function_interface_impl::setAllArgAttrDicts(this->getOperation(), attributes); + } + void setAllArgAttrs(ArrayRef attributes) { + assert(attributes.size() == $_op.getNumArguments()); + function_interface_impl::setAllArgAttrDicts(this->getOperation(), attributes); + } + void setAllArgAttrs(ArrayAttr attributes) { + assert(attributes.size() == $_op.getNumArguments()); + this->getOperation()->setAttr(getArgDictAttrName(), attributes); + } + + /// If an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void setArgAttr(unsigned index, StringAttr name, Attribute value) { + function_interface_impl::setArgAttr($_op, index, name, value); + } + void setArgAttr(unsigned index, StringRef name, Attribute value) { + setArgAttr(index, + StringAttr::get(this->getOperation()->getContext(), name), + value); + } + + /// Remove the attribute 'name' from the argument at 'index'. Return the + /// attribute that was erased, or nullptr if there was no attribute with + /// such name.
+ Attribute removeArgAttr(unsigned index, StringAttr name) { + return function_interface_impl::removeArgAttr($_op, index, name); + } + Attribute removeArgAttr(unsigned index, StringRef name) { + return removeArgAttr( + index, StringAttr::get(this->getOperation()->getContext(), name)); + } + + //===------------------------------------------------------------------===// + // Result Attributes + //===------------------------------------------------------------------===// + + /// Return all of the attributes for the result at 'index'. + ArrayRef getResultAttrs(unsigned index) { + return function_interface_impl::getResultAttrs(this->getOperation(), index); + } + + /// Return an ArrayAttr containing all result attribute dictionaries of this + /// function, or nullptr if no result have attributes. + ArrayAttr getAllResultAttrs() { + return this->getOperation()->template getAttrOfType( + getResultDictAttrName()); + } + /// Return all result attributes of this function. + void getAllResultAttrs(SmallVectorImpl &result) { + if (ArrayAttr argAttrs = getAllResultAttrs()) { + auto argAttrRange = argAttrs.template getAsRange(); + result.append(argAttrRange.begin(), argAttrRange.end()); + } else { + result.append($_op.getNumResults(), + DictionaryAttr::get(this->getOperation()->getContext())); + } + } + + /// Return the specified attribute, if present, for the result at 'index', + /// null otherwise. + Attribute getResultAttr(unsigned index, StringAttr name) { + auto argDict = getResultAttrDict(index); + return argDict ? argDict.get(name) : nullptr; + } + Attribute getResultAttr(unsigned index, StringRef name) { + auto argDict = getResultAttrDict(index); + return argDict ? 
argDict.get(name) : nullptr; + } + + template + AttrClass getResultAttrOfType(unsigned index, StringAttr name) { + return getResultAttr(index, name).template dyn_cast_or_null(); + } + template + AttrClass getResultAttrOfType(unsigned index, StringRef name) { + return getResultAttr(index, name).template dyn_cast_or_null(); + } + + /// Set the attributes held by the result at 'index'. + void setResultAttrs(unsigned index, ArrayRef attributes) { + function_interface_impl::setResultAttrs($_op, index, attributes); + } + + /// Set the attributes held by the result at 'index'. `attributes` may be + /// null, in which case any existing result attributes are removed. + void setResultAttrs(unsigned index, DictionaryAttr attributes) { + function_interface_impl::setResultAttrs($_op, index, attributes); + } + void setAllResultAttrs(ArrayRef attributes) { + assert(attributes.size() == $_op.getNumResults()); + function_interface_impl::setAllResultAttrDicts( + this->getOperation(), attributes); + } + void setAllResultAttrs(ArrayRef attributes) { + assert(attributes.size() == $_op.getNumResults()); + function_interface_impl::setAllResultAttrDicts( + this->getOperation(), attributes); + } + void setAllResultAttrs(ArrayAttr attributes) { + assert(attributes.size() == $_op.getNumResults()); + this->getOperation()->setAttr(getResultDictAttrName(), attributes); + } + + /// If an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void setResultAttr(unsigned index, StringAttr name, Attribute value) { + function_interface_impl::setResultAttr($_op, index, name, value); + } + void setResultAttr(unsigned index, StringRef name, Attribute value) { + setResultAttr(index, + StringAttr::get(this->getOperation()->getContext(), name), + value); + } + + /// Remove the attribute 'name' from the result at 'index'.
Return the + /// attribute that was erased, or nullptr if there was no attribute with + /// such name. + Attribute removeResultAttr(unsigned index, StringAttr name) { + return function_interface_impl::removeResultAttr($_op, index, name); + } + + /// Returns the dictionary attribute corresponding to the argument at + /// 'index'. If there are no argument attributes at 'index', a null + /// attribute is returned. + DictionaryAttr getArgAttrDict(unsigned index) { + assert(index < $_op.getNumArguments() && "invalid argument number"); + return function_interface_impl::getArgAttrDict(this->getOperation(), index); + } + + /// Returns the dictionary attribute corresponding to the result at 'index'. + /// If there are no result attributes at 'index', a null attribute is + /// returned. + DictionaryAttr getResultAttrDict(unsigned index) { + assert(index < $_op.getNumResults() && "invalid result number"); + return function_interface_impl::getResultAttrDict(this->getOperation(), index); + } + }]; + + let verify = "return function_interface_impl::verifyTrait(cast($_op));"; +} + +#endif // MLIR_IR_FUNCTIONINTERFACES_TD_ diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h deleted file mode 100644 --- a/mlir/include/mlir/IR/FunctionSupport.h +++ /dev/null @@ -1,803 +0,0 @@ -//===- FunctionSupport.h - Utility types for function-like ops --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines support types for Operations that represent function-like -// constructs to use. 
-// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_IR_FUNCTIONSUPPORT_H -#define MLIR_IR_FUNCTIONSUPPORT_H - -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/OpDefinition.h" -#include "llvm/ADT/SmallString.h" - -namespace mlir { - -namespace function_like_impl { - -/// Return the name of the attribute used for function types. -inline StringRef getTypeAttrName() { return "type"; } - -/// Return the name of the attribute used for function argument attributes. -inline StringRef getArgDictAttrName() { return "arg_attrs"; } - -/// Return the name of the attribute used for function argument attributes. -inline StringRef getResultDictAttrName() { return "res_attrs"; } - -/// Returns the dictionary attribute corresponding to the argument at 'index'. -/// If there are no argument attributes at 'index', a null attribute is -/// returned. -DictionaryAttr getArgAttrDict(Operation *op, unsigned index); - -/// Returns the dictionary attribute corresponding to the result at 'index'. -/// If there are no result attributes at 'index', a null attribute is -/// returned. -DictionaryAttr getResultAttrDict(Operation *op, unsigned index); - -namespace detail { -/// Update the given index into an argument or result attribute dictionary. -void setArgResAttrDict(Operation *op, StringRef attrName, - unsigned numTotalIndices, unsigned index, - DictionaryAttr attrs); -} // namespace detail - -/// Set all of the argument or result attribute dictionaries for a function. The -/// size of `attrs` is expected to match the number of arguments/results of the -/// given `op`. -void setAllArgAttrDicts(Operation *op, ArrayRef attrs); -void setAllArgAttrDicts(Operation *op, ArrayRef attrs); -void setAllResultAttrDicts(Operation *op, ArrayRef attrs); -void setAllResultAttrDicts(Operation *op, ArrayRef attrs); - -/// Return all of the attributes for the argument at 'index'. 
-inline ArrayRef getArgAttrs(Operation *op, unsigned index) { - auto argDict = getArgAttrDict(op, index); - return argDict ? argDict.getValue() : llvm::None; -} - -/// Return all of the attributes for the result at 'index'. -inline ArrayRef getResultAttrs(Operation *op, unsigned index) { - auto resultDict = getResultAttrDict(op, index); - return resultDict ? resultDict.getValue() : llvm::None; -} - -/// Insert the specified arguments and update the function type attribute. -void insertFunctionArguments(Operation *op, ArrayRef argIndices, - TypeRange argTypes, - ArrayRef argAttrs, - ArrayRef> argLocs, - unsigned originalNumArgs, Type newType); - -/// Insert the specified results and update the function type attribute. -void insertFunctionResults(Operation *op, ArrayRef resultIndices, - TypeRange resultTypes, - ArrayRef resultAttrs, - unsigned originalNumResults, Type newType); - -/// Erase the specified arguments and update the function type attribute. -void eraseFunctionArguments(Operation *op, ArrayRef argIndices, - unsigned originalNumArgs, Type newType); - -/// Erase the specified results and update the function type attribute. -void eraseFunctionResults(Operation *op, ArrayRef resultIndices, - unsigned originalNumResults, Type newType); - -/// Get and set a FunctionLike operation's type signature. -FunctionType getFunctionType(Operation *op); -void setFunctionType(Operation *op, FunctionType newType); - -/// Get a FunctionLike operation's body. -Region &getFunctionBody(Operation *op); - -} // namespace function_like_impl - -namespace OpTrait { - -/// This trait provides APIs for Ops that behave like functions. In particular: -/// - Ops must be symbols, i.e. 
also have the `Symbol` trait; -/// - Ops have a single region with multiple blocks that corresponds to the body -/// of the function; -/// - An op with a single empty region corresponds to an external function; -/// - leading arguments of the first block of the region are treated as function -/// arguments; -/// - they can have argument attributes that are stored in a dictionary -/// attribute on the Op itself. -/// -/// This trait provides limited type support for the declared or defined -/// functions. The convenience function `getTypeAttrName()` returns the name of -/// an attribute that can be used to store the function type. In addition, this -/// trait provides `getType` and `setType` helpers to store a `FunctionType` in -/// the attribute named by `getTypeAttrName()`. -/// -/// In general, this trait assumes concrete ops use `FunctionType` under the -/// hood. If this is not the case, in order to use the function type support, -/// concrete ops must define the following methods, using the same name, to hide -/// the ones defined for `FunctionType`: `addBodyBlock`, `getType`, -/// `getTypeWithoutArgsAndResults` and `setType`. -/// -/// Besides the requirements above, concrete ops must interact with this trait -/// using the following functions: -/// - Concrete ops *must* define a member function `getNumFuncArguments()` that -/// returns the number of function arguments based exclusively on type (so -/// that it can be called on function declarations). -/// - Concrete ops *must* define a member function `getNumFuncResults()` that -/// returns the number of function results based exclusively on type (so that -/// it can be called on function declarations). -/// - To verify that the type respects op-specific invariants, concrete ops may -/// redefine the `verifyType()` hook that will be called after verifying the -/// presence of the `type` attribute and before any call to -/// `getNumFuncArguments`/`getNumFuncResults` from the verifier. 
-/// - To verify that the body respects op-specific invariants, concrete ops may -/// redefine the `verifyBody()` hook that will be called after verifying the -/// function type and the presence of the (potentially empty) body region. -template -class FunctionLike : public OpTrait::TraitBase { -public: - /// Verify that all of the argument attributes are dialect attributes. - static LogicalResult verifyTrait(Operation *op); - - //===--------------------------------------------------------------------===// - // Body Handling - //===--------------------------------------------------------------------===// - - /// Returns true if this function is external, i.e. it has no body. - bool isExternal() { return empty(); } - - Region &getBody() { - return function_like_impl::getFunctionBody(this->getOperation()); - } - - /// Delete all blocks from this function. - void eraseBody() { - getBody().dropAllReferences(); - getBody().getBlocks().clear(); - } - - /// This is the list of blocks in the function. - using BlockListType = Region::BlockListType; - BlockListType &getBlocks() { return getBody().getBlocks(); } - - // Iteration over the block in the function. - using iterator = BlockListType::iterator; - using reverse_iterator = BlockListType::reverse_iterator; - - iterator begin() { return getBody().begin(); } - iterator end() { return getBody().end(); } - reverse_iterator rbegin() { return getBody().rbegin(); } - reverse_iterator rend() { return getBody().rend(); } - - bool empty() { return getBody().empty(); } - void push_back(Block *block) { getBody().push_back(block); } - void push_front(Block *block) { getBody().push_front(block); } - - Block &back() { return getBody().back(); } - Block &front() { return getBody().front(); } - - /// Add an entry block to an empty function, and set up the block arguments - /// to match the signature of the function. The newly inserted entry block - /// is returned. 
- /// - /// Note that the concrete class must define a method with the same name to - /// hide this one if the concrete class does not use FunctionType for the - /// function type under the hood. - Block *addEntryBlock(); - - /// Add a normal block to the end of the function's block list. The function - /// should at least already have an entry block. - Block *addBlock(); - - /// Hook for concrete ops to verify the contents of the body. Called as a - /// part of trait verification, after type verification and ensuring that a - /// region exists. - LogicalResult verifyBody(); - - //===--------------------------------------------------------------------===// - // Type Attribute Handling - //===--------------------------------------------------------------------===// - - /// Return the name of the attribute used for function types. - static StringRef getTypeAttrName() { - return function_like_impl::getTypeAttrName(); - } - - TypeAttr getTypeAttr() { - return this->getOperation()->template getAttrOfType( - getTypeAttrName()); - } - - /// Return the type of this function. - /// - /// Note that the concrete class must define a method with the same name to - /// hide this one if the concrete class does not use FunctionType for the - /// function type under the hood. - FunctionType getType() { - return function_like_impl::getFunctionType(this->getOperation()); - } - - /// Return the type of this function with the specified arguments and results - /// inserted. This is used to update the function's signature in the - /// `insertArguments` and `insertResults` methods. The arrays must be sorted - /// by increasing index. - /// - /// Note that the concrete class must define a method with the same name to - /// hide this one if the concrete class does not use FunctionType for the - /// function type under the hood. 
- FunctionType getTypeWithArgsAndResults(ArrayRef argIndices, - TypeRange argTypes, - ArrayRef resultIndices, - TypeRange resultTypes) { - return getType().getWithArgsAndResults(argIndices, argTypes, resultIndices, - resultTypes); - } - - /// Return the type of this function without the specified arguments and - /// results. This is used to update the function's signature in the - /// `eraseArguments` and `eraseResults` methods. The arrays of indices are - /// allowed to have duplicates and can be in any order. - /// - /// Note that the concrete class must define a method with the same name to - /// hide this one if the concrete class does not use FunctionType for the - /// function type under the hood. - FunctionType getTypeWithoutArgsAndResults(ArrayRef argIndices, - ArrayRef resultIndices) { - return getType().getWithoutArgsAndResults(argIndices, resultIndices); - } - - bool isTypeAttrValid() { - auto typeAttr = getTypeAttr(); - if (!typeAttr) - return false; - return typeAttr.getValue() != Type{}; - } - - /// Change the type of this function in place. This is an extremely dangerous - /// operation and it is up to the caller to ensure that this is legal for this - /// function, and to restore invariants: - /// - the entry block args must be updated to match the function params. - /// - the argument/result attributes may need an update: if the new type - /// has less parameters we drop the extra attributes, if there are more - /// parameters they won't have any attributes. - /// - /// Note that the concrete class must define a method with the same name to - /// hide this one if the concrete class does not use FunctionType for the - /// function type under the hood. 
- void setType(FunctionType newType); - - //===--------------------------------------------------------------------===// - // Argument and Result Handling - //===--------------------------------------------------------------------===// - using BlockArgListType = Region::BlockArgListType; - - unsigned getNumArguments() { - return static_cast(this)->getNumFuncArguments(); - } - - unsigned getNumResults() { - return static_cast(this)->getNumFuncResults(); - } - - /// Gets argument. - BlockArgument getArgument(unsigned idx) { return getBody().getArgument(idx); } - - /// Support argument iteration. - using args_iterator = Region::args_iterator; - args_iterator args_begin() { return getBody().args_begin(); } - args_iterator args_end() { return getBody().args_end(); } - Block::BlockArgListType getArguments() { return getBody().getArguments(); } - - ValueTypeRange getArgumentTypes() { - return getBody().getArgumentTypes(); - } - - /// Insert a single argument of type `argType` with attributes `argAttrs` and - /// location `argLoc` at `argIndex`. - void insertArgument(unsigned argIndex, Type argType, DictionaryAttr argAttrs, - Optional argLoc = {}) { - insertArguments({argIndex}, {argType}, {argAttrs}, {argLoc}); - } - - /// Inserts arguments with the listed types, attributes, and locations at the - /// listed indices. `argIndices` must be sorted. Arguments are inserted in the - /// order they are listed, such that arguments with identical index will - /// appear in the same order that they were listed here. 
- void insertArguments(ArrayRef argIndices, TypeRange argTypes, - ArrayRef argAttrs, - ArrayRef> argLocs) { - unsigned originalNumArgs = getNumArguments(); - Type newType = getTypeWithArgsAndResults( - argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{}); - function_like_impl::insertFunctionArguments( - this->getOperation(), argIndices, argTypes, argAttrs, argLocs, - originalNumArgs, newType); - } - - /// Insert a single result of type `resultType` at `resultIndex`. - void insertResult(unsigned resultIndex, Type resultType, - DictionaryAttr resultAttrs) { - insertResults({resultIndex}, {resultType}, {resultAttrs}); - } - - /// Inserts results with the listed types at the listed indices. - /// `resultIndices` must be sorted. Results are inserted in the order they are - /// listed, such that results with identical index will appear in the same - /// order that they were listed here. - void insertResults(ArrayRef resultIndices, TypeRange resultTypes, - ArrayRef resultAttrs) { - unsigned originalNumResults = getNumResults(); - Type newType = getTypeWithArgsAndResults(/*argIndices=*/{}, /*argTypes=*/{}, - resultIndices, resultTypes); - function_like_impl::insertFunctionResults( - this->getOperation(), resultIndices, resultTypes, resultAttrs, - originalNumResults, newType); - } - - /// Erase a single argument at `argIndex`. - void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); } - - /// Erases the arguments listed in `argIndices`. - /// `argIndices` is allowed to have duplicates and can be in any order. - void eraseArguments(ArrayRef argIndices) { - unsigned originalNumArgs = getNumArguments(); - Type newType = getTypeWithoutArgsAndResults(argIndices, {}); - function_like_impl::eraseFunctionArguments(this->getOperation(), argIndices, - originalNumArgs, newType); - } - - /// Erase a single result at `resultIndex`. - void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); } - - /// Erases the results listed in `resultIndices`. 
- /// `resultIndices` is allowed to have duplicates and can be in any order. - void eraseResults(ArrayRef resultIndices) { - unsigned originalNumResults = getNumResults(); - Type newType = getTypeWithoutArgsAndResults({}, resultIndices); - function_like_impl::eraseFunctionResults( - this->getOperation(), resultIndices, originalNumResults, newType); - } - - //===--------------------------------------------------------------------===// - // Argument Attributes - //===--------------------------------------------------------------------===// - - /// FunctionLike operations allow for attaching attributes to each of the - /// respective function arguments. These argument attributes are stored as - /// DictionaryAttrs in the main operation attribute dictionary. The name of - /// these entries is `arg` followed by the index of the argument. These - /// argument attribute dictionaries are optional, and will generally only - /// exist if they are non-empty. - - /// Return all of the attributes for the argument at 'index'. - ArrayRef getArgAttrs(unsigned index) { - return function_like_impl::getArgAttrs(this->getOperation(), index); - } - - /// Return an ArrayAttr containing all argument attribute dictionaries of this - /// function, or nullptr if no arguments have attributes. - ArrayAttr getAllArgAttrs() { - return this->getOperation()->template getAttrOfType( - function_like_impl::getArgDictAttrName()); - } - /// Return all argument attributes of this function. - void getAllArgAttrs(SmallVectorImpl &result) { - if (ArrayAttr argAttrs = getAllArgAttrs()) { - auto argAttrRange = argAttrs.template getAsRange(); - result.append(argAttrRange.begin(), argAttrRange.end()); - } else { - result.append(getNumArguments(), - DictionaryAttr::get(this->getOperation()->getContext())); - } - } - - /// Return the specified attribute, if present, for the argument at 'index', - /// null otherwise. 
- Attribute getArgAttr(unsigned index, StringAttr name) { - auto argDict = getArgAttrDict(index); - return argDict ? argDict.get(name) : nullptr; - } - Attribute getArgAttr(unsigned index, StringRef name) { - auto argDict = getArgAttrDict(index); - return argDict ? argDict.get(name) : nullptr; - } - - template - AttrClass getArgAttrOfType(unsigned index, StringAttr name) { - return getArgAttr(index, name).template dyn_cast_or_null(); - } - template - AttrClass getArgAttrOfType(unsigned index, StringRef name) { - return getArgAttr(index, name).template dyn_cast_or_null(); - } - - /// Set the attributes held by the argument at 'index'. - void setArgAttrs(unsigned index, ArrayRef attributes); - - /// Set the attributes held by the argument at 'index'. `attributes` may be - /// null, in which case any existing argument attributes are removed. - void setArgAttrs(unsigned index, DictionaryAttr attributes); - void setAllArgAttrs(ArrayRef attributes) { - assert(attributes.size() == getNumArguments()); - function_like_impl::setAllArgAttrDicts(this->getOperation(), attributes); - } - void setAllArgAttrs(ArrayRef attributes) { - assert(attributes.size() == getNumArguments()); - function_like_impl::setAllArgAttrDicts(this->getOperation(), attributes); - } - void setAllArgAttrs(ArrayAttr attributes) { - assert(attributes.size() == getNumArguments()); - this->getOperation()->setAttr(function_like_impl::getArgDictAttrName(), - attributes); - } - - /// If the an attribute exists with the specified name, change it to the new - /// value. Otherwise, add a new attribute with the specified name/value. - void setArgAttr(unsigned index, StringAttr name, Attribute value); - void setArgAttr(unsigned index, StringRef name, Attribute value) { - setArgAttr(index, StringAttr::get(this->getOperation()->getContext(), name), - value); - } - - /// Remove the attribute 'name' from the argument at 'index'. 
Return the - /// attribute that was erased, or nullptr if there was no attribute with such - /// name. - Attribute removeArgAttr(unsigned index, StringAttr name); - Attribute removeArgAttr(unsigned index, StringRef name) { - return removeArgAttr( - index, StringAttr::get(this->getOperation()->getContext(), name)); - } - - //===--------------------------------------------------------------------===// - // Result Attributes - //===--------------------------------------------------------------------===// - - /// FunctionLike operations allow for attaching attributes to each of the - /// respective function results. These result attributes are stored as - /// DictionaryAttrs in the main operation attribute dictionary. The name of - /// these entries is `result` followed by the index of the result. These - /// result attribute dictionaries are optional, and will generally only - /// exist if they are non-empty. - - /// Return all of the attributes for the result at 'index'. - ArrayRef getResultAttrs(unsigned index) { - return function_like_impl::getResultAttrs(this->getOperation(), index); - } - - /// Return an ArrayAttr containing all result attribute dictionaries of this - /// function, or nullptr if no result have attributes. - ArrayAttr getAllResultAttrs() { - return this->getOperation()->template getAttrOfType( - function_like_impl::getResultDictAttrName()); - } - /// Return all result attributes of this function. - void getAllResultAttrs(SmallVectorImpl &result) { - if (ArrayAttr argAttrs = getAllResultAttrs()) { - auto argAttrRange = argAttrs.template getAsRange(); - result.append(argAttrRange.begin(), argAttrRange.end()); - } else { - result.append(getNumResults(), - DictionaryAttr::get(this->getOperation()->getContext())); - } - } - - /// Return the specified attribute, if present, for the result at 'index', - /// null otherwise. - Attribute getResultAttr(unsigned index, StringAttr name) { - auto argDict = getResultAttrDict(index); - return argDict ? 
argDict.get(name) : nullptr; - } - Attribute getResultAttr(unsigned index, StringRef name) { - auto argDict = getResultAttrDict(index); - return argDict ? argDict.get(name) : nullptr; - } - - template - AttrClass getResultAttrOfType(unsigned index, StringAttr name) { - return getResultAttr(index, name).template dyn_cast_or_null(); - } - template - AttrClass getResultAttrOfType(unsigned index, StringRef name) { - return getResultAttr(index, name).template dyn_cast_or_null(); - } - - /// Set the attributes held by the result at 'index'. - void setResultAttrs(unsigned index, ArrayRef attributes); - - /// Set the attributes held by the result at 'index'. `attributes` may be - /// null, in which case any existing argument attributes are removed. - void setResultAttrs(unsigned index, DictionaryAttr attributes); - void setAllResultAttrs(ArrayRef attributes) { - assert(attributes.size() == getNumResults()); - function_like_impl::setAllResultAttrDicts(this->getOperation(), attributes); - } - void setAllResultAttrs(ArrayRef attributes) { - assert(attributes.size() == getNumResults()); - function_like_impl::setAllResultAttrDicts(this->getOperation(), attributes); - } - void setAllResultAttrs(ArrayAttr attributes) { - assert(attributes.size() == getNumResults()); - this->getOperation()->setAttr(function_like_impl::getResultDictAttrName(), - attributes); - } - - /// If the an attribute exists with the specified name, change it to the new - /// value. Otherwise, add a new attribute with the specified name/value. - void setResultAttr(unsigned index, StringAttr name, Attribute value); - void setResultAttr(unsigned index, StringRef name, Attribute value) { - setResultAttr(index, - StringAttr::get(this->getOperation()->getContext(), name), - value); - } - - /// Remove the attribute 'name' from the result at 'index'. Return the - /// attribute that was erased, or nullptr if there was no attribute with such - /// name. 
- Attribute removeResultAttr(unsigned index, StringAttr name); - -protected: - /// Returns the dictionary attribute corresponding to the argument at 'index'. - /// If there are no argument attributes at 'index', a null attribute is - /// returned. - DictionaryAttr getArgAttrDict(unsigned index) { - assert(index < getNumArguments() && "invalid argument number"); - return function_like_impl::getArgAttrDict(this->getOperation(), index); - } - - /// Returns the dictionary attribute corresponding to the result at 'index'. - /// If there are no result attributes at 'index', a null attribute is - /// returned. - DictionaryAttr getResultAttrDict(unsigned index) { - assert(index < getNumResults() && "invalid result number"); - return function_like_impl::getResultAttrDict(this->getOperation(), index); - } - - /// Hook for concrete classes to verify that the type attribute respects - /// op-specific invariants. Default implementation always succeeds. - LogicalResult verifyType() { return success(); } -}; - -/// Default verifier checks that if the entry block exists, it has the same -/// number of arguments as the function-like operation. 
-template -LogicalResult FunctionLike::verifyBody() { - auto funcOp = cast(this->getOperation()); - - if (funcOp.isExternal()) - return success(); - - unsigned numArguments = funcOp.getNumArguments(); - if (funcOp.front().getNumArguments() != numArguments) - return funcOp.emitOpError("entry block must have ") - << numArguments << " arguments to match function signature"; - - return success(); -} - -template -LogicalResult FunctionLike::verifyTrait(Operation *op) { - auto funcOp = cast(op); - if (!funcOp.isTypeAttrValid()) - return funcOp.emitOpError("requires a type attribute '") - << getTypeAttrName() << '\''; - - if (failed(funcOp.verifyType())) - return failure(); - - if (ArrayAttr allArgAttrs = funcOp.getAllArgAttrs()) { - unsigned numArgs = funcOp.getNumArguments(); - if (allArgAttrs.size() != numArgs) { - return funcOp.emitOpError() - << "expects argument attribute array `" - << function_like_impl::getArgDictAttrName() - << "` to have the same number of elements as the number of " - "function arguments, got " - << allArgAttrs.size() << ", but expected " << numArgs; - } - for (unsigned i = 0; i != numArgs; ++i) { - DictionaryAttr argAttrs = - allArgAttrs[i].dyn_cast_or_null(); - if (!argAttrs) { - return funcOp.emitOpError() << "expects argument attribute dictionary " - "to be a DictionaryAttr, but got `" - << allArgAttrs[i] << "`"; - } - - // Verify that all of the argument attributes are dialect attributes, i.e. - // that they contain a dialect prefix in their name. Call the dialect, if - // registered, to verify the attributes themselves. 
- for (auto attr : argAttrs) { - if (!attr.getName().strref().contains('.')) - return funcOp.emitOpError( - "arguments may only have dialect attributes"); - if (Dialect *dialect = attr.getNameDialect()) { - if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, - /*argIndex=*/i, attr))) - return failure(); - } - } - } - } - if (ArrayAttr allResultAttrs = funcOp.getAllResultAttrs()) { - unsigned numResults = funcOp.getNumResults(); - if (allResultAttrs.size() != numResults) { - return funcOp.emitOpError() - << "expects result attribute array `" - << function_like_impl::getResultDictAttrName() - << "` to have the same number of elements as the number of " - "function results, got " - << allResultAttrs.size() << ", but expected " << numResults; - } - for (unsigned i = 0; i != numResults; ++i) { - DictionaryAttr resultAttrs = - allResultAttrs[i].dyn_cast_or_null(); - if (!resultAttrs) { - return funcOp.emitOpError() << "expects result attribute dictionary " - "to be a DictionaryAttr, but got `" - << allResultAttrs[i] << "`"; - } - - // Verify that all of the result attributes are dialect attributes, i.e. - // that they contain a dialect prefix in their name. Call the dialect, if - // registered, to verify the attributes themselves. - for (auto attr : resultAttrs) { - if (!attr.getName().strref().contains('.')) - return funcOp.emitOpError("results may only have dialect attributes"); - if (Dialect *dialect = attr.getNameDialect()) { - if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, - /*resultIndex=*/i, - attr))) - return failure(); - } - } - } - } - - // Check that the op has exactly one region for the body. - if (op->getNumRegions() != 1) - return funcOp.emitOpError("expects one region"); - - return funcOp.verifyBody(); -} - -//===----------------------------------------------------------------------===// -// Function Body. 
-//===----------------------------------------------------------------------===// - -template -Block *FunctionLike::addEntryBlock() { - assert(empty() && "function already has an entry block"); - auto *entry = new Block(); - push_back(entry); - entry->addArguments(getType().getInputs()); - return entry; -} - -template -Block *FunctionLike::addBlock() { - assert(!empty() && "function should at least have an entry block"); - push_back(new Block()); - return &back(); -} - -//===----------------------------------------------------------------------===// -// Function Type Attribute. -//===----------------------------------------------------------------------===// - -template -void FunctionLike::setType(FunctionType newType) { - function_like_impl::setFunctionType(this->getOperation(), newType); -} - -//===----------------------------------------------------------------------===// -// Function Argument Attribute. -//===----------------------------------------------------------------------===// - -/// Set the attributes held by the argument at 'index'. -template -void FunctionLike::setArgAttrs( - unsigned index, ArrayRef attributes) { - assert(index < getNumArguments() && "invalid argument number"); - Operation *op = this->getOperation(); - return function_like_impl::detail::setArgResAttrDict( - op, function_like_impl::getArgDictAttrName(), getNumArguments(), index, - DictionaryAttr::get(op->getContext(), attributes)); -} - -template -void FunctionLike::setArgAttrs(unsigned index, - DictionaryAttr attributes) { - Operation *op = this->getOperation(); - return function_like_impl::detail::setArgResAttrDict( - op, function_like_impl::getArgDictAttrName(), getNumArguments(), index, - attributes ? attributes : DictionaryAttr::get(op->getContext())); -} - -/// If the an attribute exists with the specified name, change it to the new -/// value. Otherwise, add a new attribute with the specified name/value. 
-template -void FunctionLike::setArgAttr(unsigned index, StringAttr name, - Attribute value) { - NamedAttrList attributes(getArgAttrDict(index)); - Attribute oldValue = attributes.set(name, value); - - // If the attribute changed, then set the new arg attribute list. - if (value != oldValue) - setArgAttrs(index, attributes.getDictionary(value.getContext())); -} - -/// Remove the attribute 'name' from the argument at 'index'. -template -Attribute FunctionLike::removeArgAttr(unsigned index, - StringAttr name) { - // Build an attribute list and remove the attribute at 'name'. - NamedAttrList attributes(getArgAttrDict(index)); - Attribute removedAttr = attributes.erase(name); - - // If the attribute was removed, then update the argument dictionary. - if (removedAttr) - setArgAttrs(index, attributes.getDictionary(removedAttr.getContext())); - return removedAttr; -} - -//===----------------------------------------------------------------------===// -// Function Result Attribute. -//===----------------------------------------------------------------------===// - -/// Set the attributes held by the result at 'index'. -template -void FunctionLike::setResultAttrs( - unsigned index, ArrayRef attributes) { - assert(index < getNumResults() && "invalid result number"); - Operation *op = this->getOperation(); - return function_like_impl::detail::setArgResAttrDict( - op, function_like_impl::getResultDictAttrName(), getNumResults(), index, - DictionaryAttr::get(op->getContext(), attributes)); -} - -template -void FunctionLike::setResultAttrs(unsigned index, - DictionaryAttr attributes) { - assert(index < getNumResults() && "invalid result number"); - Operation *op = this->getOperation(); - return function_like_impl::detail::setArgResAttrDict( - op, function_like_impl::getResultDictAttrName(), getNumResults(), index, - attributes ? attributes : DictionaryAttr::get(op->getContext())); -} - -/// If the an attribute exists with the specified name, change it to the new -/// value. 
Otherwise, add a new attribute with the specified name/value. -template -void FunctionLike::setResultAttr(unsigned index, StringAttr name, - Attribute value) { - NamedAttrList attributes(getResultAttrDict(index)); - Attribute oldAttr = attributes.set(name, value); - - // If the attribute changed, then set the new arg attribute list. - if (oldAttr != value) - setResultAttrs(index, attributes.getDictionary(value.getContext())); -} - -/// Remove the attribute 'name' from the result at 'index'. -template -Attribute FunctionLike::removeResultAttr(unsigned index, - StringAttr name) { - // Build an attribute list and remove the attribute at 'name'. - NamedAttrList attributes(getResultAttrDict(index)); - Attribute removedAttr = attributes.erase(name); - - // If the attribute was removed, then update the result dictionary. - if (removedAttr) - setResultAttrs(index, attributes.getDictionary(removedAttr.getContext())); - return removedAttr; -} - -} // namespace OpTrait - -} // namespace mlir - -#endif // MLIR_IR_FUNCTIONSUPPORT_H diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2044,8 +2044,6 @@ def Involution : NativeOpTrait<"IsInvolution">; // Op behaves like a constant. def ConstantLike : NativeOpTrait<"ConstantLike">; -// Op behaves like a function. -def FunctionLike : NativeOpTrait<"FunctionLike">; // Op is isolated from above. def IsolatedFromAbove : NativeOpTrait<"IsIsolatedFromAbove">; // Op results are float or vectors/tensors thereof. diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -493,17 +493,17 @@ }; /// Add a pattern to the given pattern list to convert the signature of a -/// FunctionLike op with the given type converter. 
This only supports -/// FunctionLike ops which use FunctionType to represent their type. -void populateFunctionLikeTypeConversionPattern(StringRef functionLikeOpName, - RewritePatternSet &patterns, - TypeConverter &converter); +/// FunctionOpInterface op with the given type converter. This only supports +/// ops which use FunctionType to represent their type. +void populateFunctionOpInterfaceTypeConversionPattern( + StringRef functionLikeOpName, RewritePatternSet &patterns, + TypeConverter &converter); template -void populateFunctionLikeTypeConversionPattern(RewritePatternSet &patterns, - TypeConverter &converter) { - populateFunctionLikeTypeConversionPattern(FuncOpT::getOperationName(), - patterns, converter); +void populateFunctionOpInterfaceTypeConversionPattern( + RewritePatternSet &patterns, TypeConverter &converter) { + populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(), + patterns, converter); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -8,7 +8,7 @@ #include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h" -#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -240,7 +240,7 @@ // freed on all paths within the region, or is just not captured by anything. // For now assume allocation scope to the function scope (we don't care if // pointer escape outside function). 
- allocScopeOp = op->getParentWithTrait(); + allocScopeOp = op->getParentOfType(); return success(); } diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -970,7 +970,7 @@ llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes); // Convert async types in function signatures and function calls. - populateFunctionLikeTypeConversionPattern(patterns, converter); + populateFunctionOpInterfaceTypeConversionPattern(patterns, converter); populateCallOpTypeConversionPattern(patterns, converter); // Convert return operations inside async.execute regions. diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -57,7 +57,7 @@ SmallVector attributes; for (const auto &attr : gpuFuncOp->getAttrs()) { if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == function_like_impl::getTypeAttrName() || + attr.getName() == FunctionOpInterface::getTypeAttrName() || attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) continue; attributes.push_back(attr); diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -197,7 +197,7 @@ rewriter.getFunctionType(signatureConverter.getConvertedTypes(), llvm::None)); for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.getName() == function_like_impl::getTypeAttrName() || + if (namedAttr.getName() == FunctionOpInterface::getTypeAttrName() || namedAttr.getName() == SymbolTable::getSymbolAttrName()) continue; newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); diff --git 
a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -104,8 +104,7 @@ rewriter.create(rewriter.getUnknownLoc(), name, opFunctionTy); opFunc.setPrivate(); } - assert(SymbolTable::lookupSymbolIn(module, name) - ->template hasTrait()); + assert(isa(SymbolTable::lookupSymbolIn(module, name))); rewriter.replaceOpWithNewOp(op, name, op.getType(), op->getOperands()); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -54,10 +54,10 @@ SmallVectorImpl &result) { for (const auto &attr : attrs) { if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == function_like_impl::getTypeAttrName() || + attr.getName() == FunctionOpInterface::getTypeAttrName() || attr.getName() == "std.varargs" || (filterArgAttrs && - attr.getName() == function_like_impl::getArgDictAttrName())) + attr.getName() == FunctionOpInterface::getArgDictAttrName())) continue; result.push_back(attr); } @@ -251,7 +251,7 @@ newArgAttrs[mapping->inputNo + j] = argAttrDicts[i]; } attributes.push_back( - rewriter.getNamedAttr(function_like_impl::getArgDictAttrName(), + rewriter.getNamedAttr(FunctionOpInterface::getArgDictAttrName(), rewriter.getArrayAttr(newArgAttrs))); } for (const auto &pair : llvm::enumerate(attributes)) { diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -672,7 +672,7 @@ return success(); SmallVector argAttrs; bool isVariadic = false; - return function_like_impl::parseFunctionArgumentList( + return function_interface_impl::parseFunctionArgumentList( parser, /*allowAttributes=*/false, 
/*allowVariadic=*/false, argNames, argTypes, argAttrs, isVariadic); } @@ -790,7 +790,7 @@ return failure(); auto signatureLocation = parser.getCurrentLocation(); - if (failed(function_like_impl::parseFunctionSignature( + if (failed(function_interface_impl::parseFunctionSignature( parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, isVariadic, resultTypes, resultAttrs))) return failure(); @@ -829,8 +829,8 @@ // Parse attributes. if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - function_like_impl::addArgAndResultAttrs(builder, result, argAttrs, - resultAttrs); + function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, + resultAttrs); // Parse the region. If no argument names were provided, take all names // (including those of attributions) from the entry block. @@ -855,7 +855,7 @@ p.printSymbolName(op.getName()); FunctionType type = op.getType(); - function_like_impl::printFunctionSignature( + function_interface_impl::printFunctionSignature( p, op.getOperation(), type.getInputs(), /*isVariadic=*/false, type.getResults()); @@ -864,7 +864,7 @@ if (op.isKernel()) p << ' ' << op.getKernelKeyword(); - function_like_impl::printFunctionAttributes( + function_interface_impl::printFunctionAttributes( p, op.getOperation(), type.getNumInputs(), type.getNumResults(), {op.getNumWorkgroupAttributionsAttrName(), GPUDialect::getKernelFuncAttrName()}); @@ -872,7 +872,6 @@ p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); } -/// Hook for FunctionLike verifier. 
LogicalResult GPUFuncOp::verifyType() { Type type = getTypeAttr().getValue(); if (!type.isa()) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1976,8 +1976,8 @@ assert(type.cast().getNumParams() == argAttrs.size() && "expected as many argument attribute lists as arguments"); - function_like_impl::addArgAndResultAttrs(builder, result, argAttrs, - /*resultAttrs=*/llvm::None); + function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, + /*resultAttrs=*/llvm::None); } // Builds an LLVM function type from the given lists of input and output types. @@ -1986,7 +1986,7 @@ static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, ArrayRef inputs, ArrayRef outputs, - function_like_impl::VariadicFlag variadicFlag) { + function_interface_impl::VariadicFlag variadicFlag) { Builder &b = parser.getBuilder(); if (outputs.size() > 1) { parser.emitError(loc, "failed to construct function type: expected zero or " @@ -2043,23 +2043,23 @@ auto signatureLocation = parser.getCurrentLocation(); if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), result.attributes) || - function_like_impl::parseFunctionSignature( + function_interface_impl::parseFunctionSignature( parser, /*allowVariadic=*/true, entryArgs, argTypes, argAttrs, isVariadic, resultTypes, resultAttrs)) return failure(); auto type = buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, - function_like_impl::VariadicFlag(isVariadic)); + function_interface_impl::VariadicFlag(isVariadic)); if (!type) return failure(); - result.addAttribute(function_like_impl::getTypeAttrName(), + result.addAttribute(FunctionOpInterface::getTypeAttrName(), TypeAttr::get(type)); if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - function_like_impl::addArgAndResultAttrs(parser.getBuilder(), 
result, - argAttrs, resultAttrs); + function_interface_impl::addArgAndResultAttrs(parser.getBuilder(), result, + argAttrs, resultAttrs); auto *body = result.addRegion(); OptionalParseResult parseResult = parser.parseOptionalRegion( @@ -2087,9 +2087,9 @@ if (!returnType.isa()) resTypes.push_back(returnType); - function_like_impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), - resTypes); - function_like_impl::printFunctionAttributes( + function_interface_impl::printFunctionSignature(p, op, argTypes, + op.isVarArg(), resTypes); + function_interface_impl::printFunctionAttributes( p, op, argTypes.size(), resTypes.size(), {getLinkageAttrName()}); // Print the body if this is not an external function. @@ -2101,9 +2101,6 @@ } } -// Hook for OpTrait::FunctionLike, called after verifying that the 'type' -// attribute is present. This can check for preconditions of the -// getNumArguments hook not failing. LogicalResult LLVMFuncOp::verifyType() { auto llvmType = getTypeAttr().getValue().dyn_cast_or_null(); if (!llvmType) @@ -2113,23 +2110,6 @@ return success(); } -// Hook for OpTrait::FunctionLike, returns the number of function arguments. -// Depends on the type attribute being correct as checked by verifyType -unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getNumParams(); } - -// Hook for OpTrait::FunctionLike, returns the number of function results. -// Depends on the type attribute being correct as checked by verifyType -unsigned LLVMFuncOp::getNumFuncResults() { - // We model LLVM functions that return void as having zero results, - // and all others as having one result. - // If we modeled a void return as one result, then it would be possible to - // attach an MLIR result attribute to it, and it isn't clear what semantics we - // would assign to that. 
- if (getType().getReturnType().isa()) - return 0; - return 1; -} - // Verifies LLVM- and implementation-specific properties of the LLVM func Op: // - functions don't have 'common' linkage // - external functions have 'external' or 'extern_weak' linkage; @@ -2141,6 +2121,14 @@ << "functions cannot have '" << stringifyLinkage(LLVM::Linkage::Common) << "' linkage"; + // Check to see if this function has a void return with a result attribute to + // it. It isn't clear what semantics we would assign to that. + if (op.getType().getReturnType().isa() && + !op.getResultAttrs(0).empty()) { + return op.emitOpError() + << "cannot attach result attributes to functions with a void return"; + } + if (op.isExternal()) { if (op.getLinkage() != LLVM::Linkage::External && op.getLinkage() != LLVM::Linkage::ExternWeak) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -113,7 +113,16 @@ isVarArg); } +LLVMFunctionType LLVMFunctionType::clone(TypeRange inputs, + TypeRange results) const { + assert(results.size() == 1 && "expected a single result type"); + return get(results[0], llvm::to_vector(inputs), isVarArg()); +} + Type LLVMFunctionType::getReturnType() { return getImpl()->getReturnType(); } +ArrayRef LLVMFunctionType::getReturnTypes() { + return getImpl()->getReturnType(); +} unsigned LLVMFunctionType::getNumParams() { return getImpl()->getArgumentTypes().size(); @@ -123,7 +132,7 @@ return getImpl()->getArgumentTypes()[i]; } -bool LLVMFunctionType::isVarArg() { return getImpl()->isVariadic(); } +bool LLVMFunctionType::isVarArg() const { return getImpl()->isVariadic(); } ArrayRef LLVMFunctionType::getParams() { return getImpl()->getArgumentTypes(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h --- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h +++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h @@ 
-333,9 +333,8 @@ /// Construct a storage from the given components. The list is expected to be /// allocated in the context. LLVMFunctionTypeStorage(Type result, ArrayRef arguments, bool variadic) - : argumentTypes(arguments) { - returnTypeAndVariadic.setPointerAndInt(result, variadic); - } + : resultType(result), isVariadicFlag(variadic), + numArguments(arguments.size()), argumentTypes(arguments.data()) {} /// Hook into the type uniquing infrastructure. static LLVMFunctionTypeStorage *construct(TypeStorageAllocator &allocator, @@ -358,19 +357,24 @@ } /// Returns the list of function argument types. - ArrayRef getArgumentTypes() const { return argumentTypes; } + ArrayRef getArgumentTypes() const { + return ArrayRef(argumentTypes, numArguments); + } /// Checks whether the function type is variadic. - bool isVariadic() const { return returnTypeAndVariadic.getInt(); } + bool isVariadic() const { return isVariadicFlag; } /// Returns the function result type. - Type getReturnType() const { return returnTypeAndVariadic.getPointer(); } + const Type &getReturnType() const { return resultType; } private: - /// Function result type packed with the variadic bit. - llvm::PointerIntPair returnTypeAndVariadic; - /// Argument types. - ArrayRef argumentTypes; + /// The result type of the function. + Type resultType; + /// Flag indicating if the function is variadic. + bool isVariadicFlag; + /// The argument types of the function. 
+ unsigned numArguments; + const Type *argumentTypes; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp @@ -15,7 +15,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/Parser.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/InliningUtils.h" @@ -133,7 +133,7 @@ << "'" << BufferizableOpInterface::kInplaceableAttrName << "' is expected to be a boolean attribute"; } - if (!op->hasTrait()) + if (!isa(op)) return op->emitError() << "expected " << attr.getName() << " to be used on function-like operations"; return success(); @@ -144,7 +144,7 @@ << "'" << BufferizableOpInterface::kBufferLayoutAttrName << "' is expected to be a affine map attribute"; } - if (!op->hasTrait()) + if (!isa(op)) return op->emitError() << "expected " << attr.getName() << " to be used on function-like operations"; return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -88,19 +88,18 @@ /// A conversion pattern for detensoring internal (non-entry) blocks within a /// function. 
-struct FunctionNonEntryBlockConversion : public ConversionPattern { +struct FunctionNonEntryBlockConversion + : public OpInterfaceConversionPattern { FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter, DenseSet blockArgsToDetensor) - : ConversionPattern(converter, MatchTraitOpTypeTag(), - TypeID::get(), /*benefit=*/1, - ctx), + : OpInterfaceConversionPattern(converter, ctx), blockArgsToDetensor(std::move(blockArgsToDetensor)) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(FunctionOpInterface op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.startRootUpdate(op); - Region ®ion = function_like_impl::getFunctionBody(op); + Region ®ion = op.getBody(); SmallVector conversions; for (Block &block : llvm::drop_begin(region, 1)) { @@ -197,7 +196,7 @@ /// detensored, then: /// - opsToDetensor should be = {linalg.generic{add}}. /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}. - virtual void compute(Operation *func, + virtual void compute(FunctionOpInterface func, DetensorizeTypeConverter typeConverter, DenseSet &opsToDetensor, DenseSet &blockArgsToDetensor) = 0; @@ -249,7 +248,8 @@ /// AND can be detensored. class ControlFlowDetectionModel : public CostModel { public: - void compute(Operation *func, DetensorizeTypeConverter typeConverter, + void compute(FunctionOpInterface func, + DetensorizeTypeConverter typeConverter, DenseSet &opsToDetensor, DenseSet &blockArgsToDetensor) override { SmallVector workList; @@ -454,7 +454,8 @@ /// Detensorize everything that can detensored. 
class AggressiveDetensoringModel : public CostModel { public: - void compute(Operation *func, DetensorizeTypeConverter typeConverter, + void compute(FunctionOpInterface func, + DetensorizeTypeConverter typeConverter, DenseSet &opsToDetensor, DenseSet &blockArgsToDetensor) override { func->walk([&](GenericOp genericOp) { @@ -462,16 +463,13 @@ opsToDetensor.insert(genericOp); }); - for (Block &block : - llvm::drop_begin(function_like_impl::getFunctionBody(func), 1)) + for (Block &block : llvm::drop_begin(func.getBody(), 1)) for (BlockArgument blockArgument : block.getArguments()) blockArgsToDetensor.insert(blockArgument); } }; void runOnOperation() override { - assert(getOperation()->hasTrait() && - "DetensorizePass can only be run on FunctionLike operations"); MLIRContext *context = &getContext(); DetensorizeTypeConverter typeConverter; RewritePatternSet patterns(context); @@ -479,15 +477,15 @@ DenseSet opsToDetensor; DenseMap> detensorableBranchOps; DenseSet blockArgsToDetensor; + FunctionOpInterface funcOp = cast(getOperation()); if (aggressiveMode.getValue()) { AggressiveDetensoringModel costModel; - costModel.compute(getOperation(), typeConverter, opsToDetensor, + costModel.compute(funcOp, typeConverter, opsToDetensor, blockArgsToDetensor); - } else { ControlFlowDetectionModel costModel; - costModel.compute(getOperation(), typeConverter, opsToDetensor, + costModel.compute(funcOp, typeConverter, opsToDetensor, blockArgsToDetensor); } @@ -503,8 +501,8 @@ // since detensoring can't happen along external calling convention // boundaries, which we conservatively approximate as all function // signatures. 
- if (op->hasTrait()) { - auto &body = function_like_impl::getFunctionBody(op); + if (auto funcOp = dyn_cast(op)) { + Region &body = funcOp.getBody(); return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) { return !llvm::any_of( blockArgsToDetensor, [&](BlockArgument blockArgument) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -66,14 +66,14 @@ /// Returns true if the given op is a function-like op or nested in a /// function-like op without a module-like op in the middle. -static bool isNestedInFunctionLikeOp(Operation *op) { +static bool isNestedInFunctionOpInterface(Operation *op) { if (!op) return false; if (op->hasTrait()) return false; - if (op->hasTrait()) + if (isa(op)) return true; - return isNestedInFunctionLikeOp(op->getParentOp()); + return isNestedInFunctionOpInterface(op->getParentOp()); } /// Returns true if the given op is an module-like op that maintains a symbol @@ -1957,13 +1957,13 @@ // Parse the function signature. bool isVariadic = false; - if (function_like_impl::parseFunctionSignature( + if (function_interface_impl::parseFunctionSignature( parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, isVariadic, resultTypes, resultAttrs)) return failure(); auto fnType = builder.getFunctionType(argTypes, resultTypes); - state.addAttribute(function_like_impl::getTypeAttrName(), + state.addAttribute(FunctionOpInterface::getTypeAttrName(), TypeAttr::get(fnType)); // Parse the optional function control keyword. @@ -1978,8 +1978,8 @@ // Add the attributes to the function arguments. assert(argAttrs.size() == argTypes.size()); assert(resultAttrs.size() == resultTypes.size()); - function_like_impl::addArgAndResultAttrs(builder, state, argAttrs, - resultAttrs); + function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, + resultAttrs); // Parse the optional function body. 
auto *body = state.addRegion(); @@ -1993,12 +1993,12 @@ printer << " "; printer.printSymbolName(fnOp.sym_name()); auto fnType = fnOp.getType(); - function_like_impl::printFunctionSignature(printer, fnOp, fnType.getInputs(), - /*isVariadic=*/false, - fnType.getResults()); + function_interface_impl::printFunctionSignature( + printer, fnOp, fnType.getInputs(), + /*isVariadic=*/false, fnType.getResults()); printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control()) << "\""; - function_like_impl::printFunctionAttributes( + function_interface_impl::printFunctionAttributes( printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(), {spirv::attributeName()}); diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp @@ -9,7 +9,7 @@ #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" @@ -129,7 +129,7 @@ } spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) { - while (op && !op->hasTrait()) + while (op && !isa(op)) op = op->getParentOp(); if (!op) return {}; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -580,7 +580,7 @@ // Copy over all attributes other than the function name and type. 
for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.getName() != function_like_impl::getTypeAttrName() && + if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() && namedAttr.getName() != SymbolTable::getSymbolAttrName()) newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -124,7 +124,8 @@ .addLegalDialect(); // Populate with rules and apply rewriting rules. - populateFunctionLikeTypeConversionPattern(patterns, converter); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); populateCallOpTypeConversionPattern(patterns, converter); populateSparseTensorConversionPatterns(converter, patterns); if (failed(applyPartialConversion(getOperation(), target, diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp @@ -31,7 +31,8 @@ RewritePatternSet patterns(context); ConversionTarget target(*context); - populateFunctionLikeTypeConversionPattern(patterns, typeConverter); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); target.addDynamicallyLegalOp([&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()) && typeConverter.isLegal(&op.getBody()); diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -108,24 +108,23 @@ if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_like_impl::addArgAndResultAttrs(builder, state, argAttrs, - /*resultAttrs=*/llvm::None); + 
function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, + /*resultAttrs=*/llvm::None); } static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &result) { - auto buildFuncType = [](Builder &builder, ArrayRef argTypes, - ArrayRef results, - function_like_impl::VariadicFlag, std::string &) { - return builder.getFunctionType(argTypes, results); - }; + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; - return function_like_impl::parseFunctionLikeOp( + return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, buildFuncType); } static void print(FuncOp op, OpAsmPrinter &p) { FunctionType fnType = op.getType(); - function_like_impl::printFunctionLikeOp( + function_interface_impl::printFunctionOp( p, op, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TensorEncoding.h" #include "llvm/ADT/APFloat.h" @@ -152,19 +153,8 @@ return getImpl()->getResults(); } -/// Helper to call a callback once on each index in the range -/// [0, `totalIndices`), *except* for the indices given in `indices`. -/// `indices` is allowed to have duplicates and can be in any order. 
-inline void iterateIndicesExcept(unsigned totalIndices, - ArrayRef indices, - function_ref callback) { - llvm::BitVector skipIndices(totalIndices); - for (unsigned i : indices) - skipIndices.set(i); - - for (unsigned i = 0; i < totalIndices; ++i) - if (!skipIndices.test(i)) - callback(i); +FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const { + return get(getContext(), inputs, results); } /// Returns a new function type with the specified arguments and results @@ -172,65 +162,24 @@ FunctionType FunctionType::getWithArgsAndResults( ArrayRef argIndices, TypeRange argTypes, ArrayRef resultIndices, TypeRange resultTypes) { - assert(argIndices.size() == argTypes.size()); - assert(resultIndices.size() == resultTypes.size()); - - ArrayRef newInputTypes = getInputs(); - SmallVector newInputTypesBuffer; - if (!argIndices.empty()) { - const auto *fromIt = newInputTypes.begin(); - for (auto it : llvm::zip(argIndices, argTypes)) { - const auto *toIt = newInputTypes.begin() + std::get<0>(it); - newInputTypesBuffer.append(fromIt, toIt); - newInputTypesBuffer.push_back(std::get<1>(it)); - fromIt = toIt; - } - newInputTypesBuffer.append(fromIt, newInputTypes.end()); - newInputTypes = newInputTypesBuffer; - } - - ArrayRef newResultTypes = getResults(); - SmallVector newResultTypesBuffer; - if (!resultIndices.empty()) { - const auto *fromIt = newResultTypes.begin(); - for (auto it : llvm::zip(resultIndices, resultTypes)) { - const auto *toIt = newResultTypes.begin() + std::get<0>(it); - newResultTypesBuffer.append(fromIt, toIt); - newResultTypesBuffer.push_back(std::get<1>(it)); - fromIt = toIt; - } - newResultTypesBuffer.append(fromIt, newResultTypes.end()); - newResultTypes = newResultTypesBuffer; - } - - return FunctionType::get(getContext(), newInputTypes, newResultTypes); + SmallVector argStorage, resultStorage; + TypeRange newArgTypes = function_interface_impl::insertTypesInto( + getInputs(), argIndices, argTypes, argStorage); + TypeRange 
newResultTypes = function_interface_impl::insertTypesInto( + getResults(), resultIndices, resultTypes, resultStorage); + return clone(newArgTypes, newResultTypes); } /// Returns a new function type without the specified arguments and results. FunctionType FunctionType::getWithoutArgsAndResults(ArrayRef argIndices, ArrayRef resultIndices) { - ArrayRef newInputTypes = getInputs(); - SmallVector newInputTypesBuffer; - if (!argIndices.empty()) { - unsigned originalNumArgs = getNumInputs(); - iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { - newInputTypesBuffer.emplace_back(getInput(i)); - }); - newInputTypes = newInputTypesBuffer; - } - - ArrayRef newResultTypes = getResults(); - SmallVector newResultTypesBuffer; - if (!resultIndices.empty()) { - unsigned originalNumResults = getNumResults(); - iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { - newResultTypesBuffer.emplace_back(getResult(i)); - }); - newResultTypes = newResultTypesBuffer; - } - - return get(getContext(), newInputTypes, newResultTypes); + SmallVector argStorage, resultStorage; + TypeRange newArgTypes = function_interface_impl::filterTypesOut( + getInputs(), argIndices, argStorage); + TypeRange newResultTypes = function_interface_impl::filterTypesOut( + getResults(), resultIndices, resultStorage); + return clone(newArgTypes, newResultTypes); } void FunctionType::walkImmediateSubElements( diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -14,7 +14,7 @@ Dialect.cpp Dominance.cpp FunctionImplementation.cpp - FunctionSupport.cpp + FunctionInterfaces.cpp IntegerSet.cpp Location.cpp MLIRContext.cpp diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -8,12 +8,12 @@ #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/Builders.h" -#include 
"mlir/IR/FunctionSupport.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/SymbolTable.h" using namespace mlir; -ParseResult mlir::function_like_impl::parseFunctionArgumentList( +ParseResult mlir::function_interface_impl::parseFunctionArgumentList( OpAsmParser &parser, bool allowAttributes, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, @@ -128,11 +128,7 @@ return parser.parseRParen(); } -/// Parses a function signature using `parser`. The `allowVariadic` argument -/// indicates whether functions with variadic arguments are supported. The -/// trailing arguments are populated by this function with names, types and -/// attributes of the arguments and those of the results. -ParseResult mlir::function_like_impl::parseFunctionSignature( +ParseResult mlir::function_interface_impl::parseFunctionSignature( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &argNames, SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, @@ -159,16 +155,18 @@ // Add the attributes to the function arguments. if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) { ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs)); - result.addAttribute(function_like_impl::getArgDictAttrName(), attrDicts); + result.addAttribute(function_interface_impl::getArgDictAttrName(), + attrDicts); } // Add the attributes to the function results. 
if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) { ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs)); - result.addAttribute(function_like_impl::getResultDictAttrName(), attrDicts); + result.addAttribute(function_interface_impl::getResultDictAttrName(), + attrDicts); } } -void mlir::function_like_impl::addArgAndResultAttrs( +void mlir::function_interface_impl::addArgAndResultAttrs( Builder &builder, OperationState &result, ArrayRef argAttrs, ArrayRef resultAttrs) { auto buildFn = [](ArrayRef attrs) { @@ -176,7 +174,7 @@ }; addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn); } -void mlir::function_like_impl::addArgAndResultAttrs( +void mlir::function_interface_impl::addArgAndResultAttrs( Builder &builder, OperationState &result, ArrayRef argAttrs, ArrayRef resultAttrs) { MLIRContext *context = builder.getContext(); @@ -189,9 +187,7 @@ addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn); } -/// Parser implementation for function-like operations. Uses `funcTypeBuilder` -/// to construct the custom function type given lists of input and output types. -ParseResult mlir::function_like_impl::parseFunctionLikeOp( +ParseResult mlir::function_interface_impl::parseFunctionOp( OpAsmParser &parser, OperationState &result, bool allowVariadic, FuncTypeBuilder funcTypeBuilder) { SmallVector entryArgs; @@ -290,9 +286,7 @@ os << ')'; } -/// Print the signature of the function-like operation `op`. Assumes `op` has -/// the FunctionLike trait and passed the verification. -void mlir::function_like_impl::printFunctionSignature( +void mlir::function_interface_impl::printFunctionSignature( OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, ArrayRef resultTypes) { Region &body = op->getRegion(0); @@ -331,12 +325,7 @@ } } -/// Prints the list of function prefixed with the "attributes" keyword. 
The -/// attributes with names listed in "elided" as well as those used by the -/// function-like operation internally are not printed. Nothing is printed -/// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and -/// passed the verification. -void mlir::function_like_impl::printFunctionAttributes( +void mlir::function_interface_impl::printFunctionAttributes( OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, ArrayRef elided) { // Print out function attributes, if present. @@ -348,13 +337,9 @@ p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); } -/// Printer implementation for function-like operations. Accepts lists of -/// argument and result types to use while printing. -void mlir::function_like_impl::printFunctionLikeOp(OpAsmPrinter &p, - Operation *op, - ArrayRef argTypes, - bool isVariadic, - ArrayRef resultTypes) { +void mlir::function_interface_impl::printFunctionOp( + OpAsmPrinter &p, Operation *op, ArrayRef argTypes, bool isVariadic, + ArrayRef resultTypes) { // Print the operation and the function name. auto funcName = op->getAttrOfType(SymbolTable::getSymbolAttrName()) diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionInterfaces.cpp rename from mlir/lib/IR/FunctionSupport.cpp rename to mlir/lib/IR/FunctionInterfaces.cpp --- a/mlir/lib/IR/FunctionSupport.cpp +++ b/mlir/lib/IR/FunctionInterfaces.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/BitVector.h" @@ -15,9 +15,9 @@ /// Helper to call a callback once on each index in the range /// [0, `totalIndices`), *except* for the indices given in `indices`. /// `indices` is allowed to have duplicates and can be in any order. 
-inline void iterateIndicesExcept(unsigned totalIndices, - ArrayRef indices, - function_ref callback) { +inline static void iterateIndicesExcept(unsigned totalIndices, + ArrayRef indices, + function_ref callback) { llvm::BitVector skipIndices(totalIndices); for (unsigned i : indices) skipIndices.set(i); @@ -27,6 +27,12 @@ callback(i); } +//===----------------------------------------------------------------------===// +// Tablegen Interface Definitions +//===----------------------------------------------------------------------===// + +#include "mlir/IR/FunctionOpInterfaces.cpp.inc" + //===----------------------------------------------------------------------===// // Function Arguments and Results. //===----------------------------------------------------------------------===// @@ -35,23 +41,24 @@ return attr.cast().empty(); } -DictionaryAttr mlir::function_like_impl::getArgAttrDict(Operation *op, - unsigned index) { +DictionaryAttr mlir::function_interface_impl::getArgAttrDict(Operation *op, + unsigned index) { ArrayAttr attrs = op->getAttrOfType(getArgDictAttrName()); DictionaryAttr argAttrs = attrs ? attrs[index].cast() : DictionaryAttr(); return argAttrs; } -DictionaryAttr mlir::function_like_impl::getResultAttrDict(Operation *op, - unsigned index) { +DictionaryAttr +mlir::function_interface_impl::getResultAttrDict(Operation *op, + unsigned index) { ArrayAttr attrs = op->getAttrOfType(getResultDictAttrName()); DictionaryAttr resAttrs = attrs ? 
attrs[index].cast() : DictionaryAttr(); return resAttrs; } -void mlir::function_like_impl::detail::setArgResAttrDict( +void mlir::function_interface_impl::detail::setArgResAttrDict( Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index, DictionaryAttr attrs) { ArrayAttr allAttrs = op->getAttrOfType(attrName); @@ -95,12 +102,12 @@ op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs)); } -void mlir::function_like_impl::setAllArgAttrDicts( +void mlir::function_interface_impl::setAllArgAttrDicts( Operation *op, ArrayRef attrs) { setAllArgAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); } -void mlir::function_like_impl::setAllArgAttrDicts(Operation *op, - ArrayRef attrs) { +void mlir::function_interface_impl::setAllArgAttrDicts( + Operation *op, ArrayRef attrs) { auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { return !attr ? DictionaryAttr::get(op->getContext()) : attr; }); @@ -108,11 +115,11 @@ llvm::to_vector<8>(wrappedAttrs)); } -void mlir::function_like_impl::setAllResultAttrDicts( +void mlir::function_interface_impl::setAllResultAttrDicts( Operation *op, ArrayRef attrs) { setAllResultAttrDicts(op, ArrayRef(attrs.data(), attrs.size())); } -void mlir::function_like_impl::setAllResultAttrDicts( +void mlir::function_interface_impl::setAllResultAttrDicts( Operation *op, ArrayRef attrs) { auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute { return !attr ? DictionaryAttr::get(op->getContext()) : attr; @@ -121,7 +128,7 @@ llvm::to_vector<8>(wrappedAttrs)); } -void mlir::function_like_impl::insertFunctionArguments( +void mlir::function_interface_impl::insertFunctionArguments( Operation *op, ArrayRef argIndices, TypeRange argTypes, ArrayRef argAttrs, ArrayRef> argLocs, unsigned originalNumArgs, Type newType) { @@ -168,7 +175,7 @@ argLocs.empty() ? 
Optional{} : argLocs[i]); } -void mlir::function_like_impl::insertFunctionResults( +void mlir::function_interface_impl::insertFunctionResults( Operation *op, ArrayRef resultIndices, TypeRange resultTypes, ArrayRef resultAttrs, unsigned originalNumResults, Type newType) { @@ -210,7 +217,7 @@ op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); } -void mlir::function_like_impl::eraseFunctionArguments( +void mlir::function_interface_impl::eraseFunctionArguments( Operation *op, ArrayRef argIndices, unsigned originalNumArgs, Type newType) { // There are 3 things that need to be updated: @@ -234,7 +241,7 @@ entry.eraseArguments(argIndices); } -void mlir::function_like_impl::eraseFunctionResults( +void mlir::function_interface_impl::eraseFunctionResults( Operation *op, ArrayRef resultIndices, unsigned originalNumResults, Type newType) { // There are 2 things that need to be updated: @@ -255,22 +262,48 @@ op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); } +TypeRange mlir::function_interface_impl::insertTypesInto( + TypeRange oldTypes, ArrayRef indices, TypeRange newTypes, + SmallVectorImpl &storage) { + assert(indices.size() == newTypes.size() && + "mismatch between indice and type count"); + if (indices.empty()) + return oldTypes; + + auto fromIt = oldTypes.begin(); + for (auto it : llvm::zip(indices, newTypes)) { + const auto toIt = oldTypes.begin() + std::get<0>(it); + storage.append(fromIt, toIt); + storage.push_back(std::get<1>(it)); + fromIt = toIt; + } + storage.append(fromIt, oldTypes.end()); + return storage; +} + +TypeRange +mlir::function_interface_impl::filterTypesOut(TypeRange types, + ArrayRef indices, + SmallVectorImpl &storage) { + if (indices.empty()) + return types; + iterateIndicesExcept(types.size(), indices, + [&](unsigned i) { storage.emplace_back(types[i]); }); + return storage; +} + //===----------------------------------------------------------------------===// // Function type signature. 
//===----------------------------------------------------------------------===// -FunctionType mlir::function_like_impl::getFunctionType(Operation *op) { - assert(op->hasTrait()); - return op->getAttrOfType(getTypeAttrName()) - .getValue() - .cast(); -} - -void mlir::function_like_impl::setFunctionType(Operation *op, - FunctionType newType) { - assert(op->hasTrait()); - FunctionType oldType = getFunctionType(op); +void mlir::function_interface_impl::setFunctionType(Operation *op, + Type newType) { + FunctionOpInterface funcOp = cast(op); + unsigned oldNumArgs = funcOp.getNumArguments(); + unsigned oldNumResults = funcOp.getNumResults(); op->setAttr(getTypeAttrName(), TypeAttr::get(newType)); + unsigned newNumArgs = funcOp.getNumArguments(); + unsigned newNumResults = funcOp.getNumResults(); // Functor used to update the argument and result attributes of the function. auto updateAttrFn = [&](StringRef attrName, unsigned oldCount, @@ -298,21 +331,10 @@ }; // Update the argument and result attributes. - updateAttrFn(function_like_impl::getArgDictAttrName(), oldType.getNumInputs(), - newType.getNumInputs(), [&](Operation *op, auto &&attrs) { - setAllArgAttrDicts(op, attrs); - }); updateAttrFn( - function_like_impl::getResultDictAttrName(), oldType.getNumResults(), - newType.getNumResults(), + getArgDictAttrName(), oldNumArgs, newNumArgs, + [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); }); + updateAttrFn( + getResultDictAttrName(), oldNumResults, newNumResults, [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); }); } - -//===----------------------------------------------------------------------===// -// Function body. 
-//===----------------------------------------------------------------------===// - -Region &mlir::function_like_impl::getFunctionBody(Operation *op) { - assert(op->hasTrait()); - return op->getRegion(0); -} diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -619,7 +619,7 @@ // If this operation defines a symbol, record it. if (SymbolOpInterface symbol = dyn_cast(op)) { symbols.emplace_back(symbol.getName(), - op->hasTrait() + isa(op) ? lsp::SymbolKind::Function : lsp::SymbolKind::Class, getRangeFromLoc(sourceMgr, def->scopeLoc), diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp --- a/mlir/lib/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp @@ -111,8 +111,8 @@ // Caller of the function. Optional symbolUses = funcOp.getSymbolUses(moduleOp); for (SymbolTable::SymbolUse symbolUse : *symbolUses) { - // TODO: Extend this for ops that are FunctionLike. This would require - // creating an OpInterface for FunctionLike ops. + // TODO: Extend this for ops that are FunctionOpInterface. This would + // require creating an OpInterface for FunctionOpInterface ops. FuncOp parentFuncOp = symbolUse.getUser()->getParentOfType(); for (FuncOp &funcOp : normalizableFuncs) { if (parentFuncOp == funcOp) { @@ -297,8 +297,8 @@ // TODO: Further optimization - Check if the memref is indeed part of // ReturnOp at the parentFuncOp and only then updation of signature is // required. - // TODO: Extend this for ops that are FunctionLike. This would require - // creating an OpInterface for FunctionLike ops. + // TODO: Extend this for ops that are FunctionOpInterface. This would + // require creating an OpInterface for FunctionOpInterface ops. 
FuncOp parentFuncOp = newCallOp->getParentOfType(); funcOpsToUpdate.insert(parentFuncOp); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -11,7 +11,7 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/ScopeExit.h" @@ -3049,49 +3049,49 @@ } //===----------------------------------------------------------------------===// -// FunctionLikeSignatureConversion +// FunctionOpInterfaceSignatureConversion //===----------------------------------------------------------------------===// /// Create a default conversion pattern that rewrites the type signature of a -/// FunctionLike op. This only supports FunctionLike ops which use FunctionType -/// to represent their type. +/// FunctionOpInterface op. This only supports ops which use FunctionType to +/// represent their type. namespace { -struct FunctionLikeSignatureConversion : public ConversionPattern { - FunctionLikeSignatureConversion(StringRef functionLikeOpName, - MLIRContext *ctx, TypeConverter &converter) +struct FunctionOpInterfaceSignatureConversion : public ConversionPattern { + FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName, + MLIRContext *ctx, + TypeConverter &converter) : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {} - /// Hook to implement combined matching and rewriting for FunctionLike ops. 
LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - FunctionType type = function_like_impl::getFunctionType(op); + FunctionOpInterface funcOp = cast(op); + FunctionType type = funcOp.getType().cast(); // Convert the original function types. TypeConverter::SignatureConversion result(type.getNumInputs()); SmallVector newResults; if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) || failed(typeConverter->convertTypes(type.getResults(), newResults)) || - failed(rewriter.convertRegionTypes( - &function_like_impl::getFunctionBody(op), *typeConverter, &result))) + failed(rewriter.convertRegionTypes(&funcOp.getBody(), *typeConverter, + &result))) return failure(); // Update the function signature in-place. auto newType = FunctionType::get(rewriter.getContext(), result.getConvertedTypes(), newResults); - rewriter.updateRootInPlace( - op, [&] { function_like_impl::setFunctionType(op, newType); }); + rewriter.updateRootInPlace(op, [&] { funcOp.setType(newType); }); return success(); } }; } // namespace -void mlir::populateFunctionLikeTypeConversionPattern( +void mlir::populateFunctionOpInterfaceTypeConversionPattern( StringRef functionLikeOpName, RewritePatternSet &patterns, TypeConverter &converter) { - patterns.add( + patterns.add( functionLikeOpName, patterns.getContext(), converter); } diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -215,6 +215,13 @@ // ----- +module { + // expected-error@+1 {{cannot attach result attributes to functions with a void return}} + llvm.func @variadic_def() -> (!llvm.void {llvm.noalias}) +} + +// ----- + module { // expected-error@+1 {{variadic arguments must be in the end of the argument list}} llvm.func @variadic_inside(%arg0: i32, ..., %arg1: i32) diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp 
b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -656,7 +656,8 @@ TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, TestCreateUnregisteredOp>(&getContext()); patterns.add(&getContext(), converter); - mlir::populateFunctionLikeTypeConversionPattern(patterns, converter); + mlir::populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); mlir::populateCallOpTypeConversionPattern(patterns, converter); // Define the conversion target used for the test. @@ -1121,7 +1122,8 @@ TestTestSignatureConversionNoConverter>(converter, &getContext()); patterns.add(&getContext()); - mlir::populateFunctionLikeTypeConversionPattern(patterns, converter); + mlir::populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp --- a/mlir/test/lib/IR/TestFunc.cpp +++ b/mlir/test/lib/IR/TestFunc.cpp @@ -1,4 +1,4 @@ -//===- TestFunctionLike.cpp - Pass to test helpers on FunctionLike --------===// +//===- TestFunc.cpp - Pass to test helpers on function utilities ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information.