diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -273,6 +273,11 @@
     /// Hooks for the input/output type enumeration in FunctionLike .
     unsigned getNumFuncArguments() { return getType().getNumInputs(); }
     unsigned getNumFuncResults() { return getType().getNumResults(); }
+    /// Returns the signature type without the specified arguments. This is a
+    /// hook for OpTrait::FunctionLike.
+    FunctionType getTypeWithoutArguments(ArrayRef<unsigned> argIndices) {
+      return getType().getWithoutArguments(argIndices);
+    }
 
     /// Returns the keywords used in the custom syntax for this Op.
     static StringRef getWorkgroupKeyword() { return "workgroup"; }
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -757,6 +757,11 @@
       return getAttrOfType<TypeAttr>(getTypeAttrName())
           .getValue().cast<LLVMType>();
     }
+
+    // Replaces the function type, removing the attributes of arguments that
+    // no longer exist in `newType`.
+    void setType(LLVMType newType);
+
     bool isVarArg() {
       return getType().isFunctionVarArg();
     }
@@ -769,6 +774,11 @@
     // Depends on the type attribute being correct as checked by verifyType.
     unsigned getNumFuncResults();
 
+    // Hook for OpTrait::FunctionLike, returns the signature type without the
+    // specified arguments. Depends on the type attribute being correct as
+    // checked by verifyType.
+    LLVMType getTypeWithoutArguments(ArrayRef<unsigned> argIndices);
+
     // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
     // attribute is present. This can check for preconditions of the
     // getNumArguments hook not failing.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -119,6 +119,7 @@
   LLVMType getFunctionParamType(unsigned argIdx);
   unsigned getFunctionNumParams();
   LLVMType getFunctionResultType();
+  LLVMType getFunctionTypeWithoutArguments(ArrayRef<unsigned> argIndices);
   bool isFunctionTy();
   bool isFunctionVarArg();
 
@@ -332,6 +333,9 @@
   ArrayRef<LLVMType> getParams();
   ArrayRef<LLVMType> params() { return getParams(); }
 
+  /// Returns this function type with the parameters at `argIndices` removed.
+  LLVMType getWithoutArguments(ArrayRef<unsigned> argIndices);
+
   /// Verifies that the type about to be constructed is well-formed.
   static LogicalResult verifyConstructionInvariants(Location loc,
                                                     LLVMType result,
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -262,6 +262,12 @@
     /// Returns the number of results. Hook for OpTrait::FunctionLike.
     unsigned getNumFuncResults() { return getType().getNumResults(); }
 
+    /// Returns the signature type without the specified arguments. This is a
+    /// hook for OpTrait::FunctionLike.
+    FunctionType getTypeWithoutArguments(ArrayRef<unsigned> argIndices) {
+      return getType().getWithoutArguments(argIndices);
+    }
+
     /// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
     /// attribute is present and checks if it holds a function type. Ensures
     /// getType, getNumFuncArguments, and getNumFuncResults can be called safely
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -59,12 +59,6 @@
   void print(OpAsmPrinter &p);
   LogicalResult verify();
 
-  /// Erase a single argument at `argIndex`.
-  void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
-  /// Erases the arguments listed in `argIndices`.
-  /// `argIndices` is allowed to have duplicates and can be in any order.
-  void eraseArguments(ArrayRef<unsigned> argIndices);
-
   /// Erase a single result at `resultIndex`.
   void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
   /// Erases the results listed in `resultIndices`.
@@ -107,6 +101,12 @@
   /// Returns the number of results. This is a hook for OpTrait::FunctionLike.
   unsigned getNumFuncResults() { return getType().getResults().size(); }
 
+  /// Returns the signature type without the specified arguments. This is a hook
+  /// for OpTrait::FunctionLike.
+  FunctionType getTypeWithoutArguments(ArrayRef<unsigned> argIndices) {
+    return getType().getWithoutArguments(argIndices);
+  }
+
   /// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
   /// attribute is present and checks if it holds a function type. Ensures
   /// getType, getNumFuncArguments, and getNumFuncResults can be called safely.
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -15,6 +15,7 @@
 #define MLIR_IR_FUNCTIONSUPPORT_H
 
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/SmallString.h"
 
 namespace mlir {
@@ -229,6 +230,18 @@
     return getBody().getArgumentTypes();
   }
 
+  /// Erase a single argument at `argIndex`.
+  void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
+  /// Erases the arguments listed in `argIndices`.
+  /// `argIndices` is allowed to have duplicates and can be in any order.
+  ///
+  /// The concrete op must provide `getTypeWithoutArguments(argIndices)`,
+  /// returning its signature type with those arguments removed, and a
+  /// `setType` overload accepting that type. Both are invoked on the concrete
+  /// op, so ops whose signature is not a builtin FunctionType (e.g. the LLVM
+  /// dialect's function op) are supported.
+  void eraseArguments(ArrayRef<unsigned> argIndices);
+
   //===--------------------------------------------------------------------===//
   // Argument Attributes
   //===--------------------------------------------------------------------===//
@@ -497,6 +510,50 @@
   concreteOp->setAttr(getTypeAttrName(), TypeAttr::get(newType));
 }
 
+//===----------------------------------------------------------------------===//
+// Function Arguments.
+//===----------------------------------------------------------------------===//
+
+template <typename ConcreteType>
+void FunctionLike<ConcreteType>::eraseArguments(ArrayRef<unsigned> argIndices) {
+  size_t originalNumArgs = getNumArguments();
+  llvm::BitVector eraseIndices(originalNumArgs);
+  for (auto index : argIndices)
+    eraseIndices.set(index);
+  auto shouldEraseArg = [&](size_t i) { return eraseIndices.test(i); };
+
+  // There are 3 things that need to be updated:
+  // - Function type.
+  // - Arg attrs.
+  // - Block arguments of entry block.
+
+  // Collect the attribute dictionaries of the arguments that are kept.
+  SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
+  for (size_t i = 0; i < originalNumArgs; ++i) {
+    if (shouldEraseArg(i))
+      continue;
+    newArgAttrs.emplace_back(getArgAttrDict(i));
+  }
+
+  // Update the function type. Dispatch through the concrete op: an
+  // unqualified `setType` here would bind to this trait's FunctionType-only
+  // overload and bypass concrete versions such as LLVMFuncOp's, whose
+  // signature is an LLVM function type.
+  auto *concreteOp = static_cast<ConcreteType *>(this);
+  concreteOp->setType(concreteOp->getTypeWithoutArguments(argIndices));
+
+  // Update the arg attrs.
+  setAllArgAttrs(newArgAttrs);
+
+  // Update the entry block's arguments.
+  // We do this in reverse so that we erase later indices before earlier
+  // indices, to avoid shifting the later indices.
+  Block &entry = front();
+  for (size_t i = originalNumArgs; i > 0; --i)
+    if (shouldEraseArg(i - 1))
+      entry.eraseArgument(i - 1);
+}
+
 //===----------------------------------------------------------------------===//
 // Function Argument Attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -247,6 +247,9 @@
   unsigned getNumResults() const;
   Type getResult(unsigned i) const { return getResults()[i]; }
   ArrayRef<Type> getResults() const;
+
+  /// Returns this function type with the inputs at `argIndices` removed.
+  FunctionType getWithoutArguments(ArrayRef<unsigned> argIndices);
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1453,6 +1453,26 @@
   return 1;
 }
 
+// Hook for OpTrait::FunctionLike, returns the signature type without the
+// specified arguments. Depends on the type attribute being correct as checked
+// by verifyType.
+LLVMType LLVMFuncOp::getTypeWithoutArguments(ArrayRef<unsigned> argIndices) {
+  return getType().getFunctionTypeWithoutArguments(argIndices);
+}
+
+void LLVMFuncOp::setType(LLVMType newType) {
+  assert(newType.isFunctionTy() && "expected an LLVM function type");
+  SmallVector<char, 16> nameBuf;
+  auto oldType = getType();
+
+  for (size_t i = newType.getFunctionNumParams(),
+              e = oldType.getFunctionNumParams();
+       i < e; ++i)
+    removeAttr(getArgAttrName(i, nameBuf));
+
+  setAttr(getTypeAttrName(), TypeAttr::get(newType));
+}
+
 // Verifies LLVM- and implementation-specific properties of the LLVM func Op:
 // - functions don't have 'common' linkage
 // - external functions have 'external' or 'extern_weak' linkage;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/TypeSupport.h"
 
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/TypeSize.h"
 
@@ -127,6 +128,11 @@
   return cast<LLVMFunctionType>().getReturnType();
 }
 
+LLVMType
+LLVMType::getFunctionTypeWithoutArguments(ArrayRef<unsigned> argIndices) {
+  return cast<LLVMFunctionType>().getWithoutArguments(argIndices);
+}
+
 bool LLVMType::isFunctionTy() { return isa<LLVMFunctionType>(); }
 
 bool LLVMType::isFunctionVarArg() {
@@ -329,6 +335,23 @@
   return getImpl()->getArgumentTypes();
 }
 
+LLVMType LLVMFunctionType::getWithoutArguments(ArrayRef<unsigned> argIndices) {
+  unsigned originalNumArgs = getNumParams();
+  llvm::BitVector skipIndices(originalNumArgs);
+  for (auto index : argIndices)
+    skipIndices.set(index);
+
+  SmallVector<LLVMType, 4> newInputTypes;
+  for (unsigned i = 0; i < originalNumArgs; ++i) {
+    if (skipIndices.test(i))
+      continue;
+
+    newInputTypes.emplace_back(getParamType(i));
+  }
+
+  return get(getReturnType(), newInputTypes, isVarArg());
+}
+
 LogicalResult LLVMFunctionType::verifyConstructionInvariants(
     Location loc, LLVMType result, ArrayRef<LLVMType> arguments, bool) {
   if (!isValidResultType(result))
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -98,40 +98,6 @@
   return success();
 }
 
-void FuncOp::eraseArguments(ArrayRef<unsigned> argIndices) {
-  auto oldType = getType();
-  int originalNumArgs = oldType.getNumInputs();
-  llvm::BitVector eraseIndices(originalNumArgs);
-  for (auto index : argIndices)
-    eraseIndices.set(index);
-  auto shouldEraseArg = [&](int i) { return eraseIndices.test(i); };
-
-  // There are 3 things that need to be updated:
-  // - Function type.
-  // - Arg attrs.
-  // - Block arguments of entry block.
-
-  // Update the function type and arg attrs.
-  SmallVector<Type, 4> newInputTypes;
-  SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
-  for (int i = 0; i < originalNumArgs; i++) {
-    if (shouldEraseArg(i))
-      continue;
-    newInputTypes.emplace_back(oldType.getInput(i));
-    newArgAttrs.emplace_back(getArgAttrDict(i));
-  }
-  setType(FunctionType::get(newInputTypes, oldType.getResults(), getContext()));
-  setAllArgAttrs(newArgAttrs);
-
-  // Update the entry block's arguments.
-  // We do this in reverse so that we erase later indices before earlier
-  // indices, to avoid shifting the later indices.
-  Block &entry = front();
-  for (int i = 0; i < originalNumArgs; i++)
-    if (shouldEraseArg(originalNumArgs - i - 1))
-      entry.eraseArgument(originalNumArgs - i - 1);
-}
-
 void FuncOp::eraseResults(ArrayRef<unsigned> resultIndices) {
   auto oldType = getType();
   int originalNumResults = oldType.getNumResults();
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -10,6 +10,7 @@
 #include "TypeDetail.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/Twine.h"
 
 using namespace mlir;
@@ -46,6 +47,24 @@
   return getImpl()->getResults();
 }
 
+// Returns this function type with the inputs at `argIndices` removed.
+FunctionType FunctionType::getWithoutArguments(ArrayRef<unsigned> argIndices) {
+  unsigned originalNumArgs = getNumInputs();
+  llvm::BitVector skipIndices(originalNumArgs);
+  for (auto index : argIndices)
+    skipIndices.set(index);
+
+  SmallVector<Type, 4> newInputTypes;
+  for (unsigned i = 0; i < originalNumArgs; ++i) {
+    if (skipIndices.test(i))
+      continue;
+
+    newInputTypes.emplace_back(getInput(i));
+  }
+
+  return get(newInputTypes, getResults(), getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // OpaqueType
 //===----------------------------------------------------------------------===//