diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -98,6 +98,10 @@
 
   /// Erase the argument at 'index' and remove it from the argument list.
   void eraseArgument(unsigned index);
+  /// Erases the arguments listed in `argIndices` and removes them from the
+  /// argument list.
+  /// `argIndices` is allowed to have duplicates and can be in any order.
+  void eraseArguments(ArrayRef<unsigned> argIndices);
 
   unsigned getNumArguments() { return arguments.size(); }
   BlockArgument getArgument(unsigned i) { return arguments[i]; }
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -59,18 +59,6 @@
   void print(OpAsmPrinter &p);
   LogicalResult verify();
 
-  /// Erase a single argument at `argIndex`.
-  void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
-  /// Erases the arguments listed in `argIndices`.
-  /// `argIndices` is allowed to have duplicates and can be in any order.
-  void eraseArguments(ArrayRef<unsigned> argIndices);
-
-  /// Erase a single result at `resultIndex`.
-  void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
-  /// Erases the results listed in `resultIndices`.
-  /// `resultIndices` is allowed to have duplicates and can be in any order.
-  void eraseResults(ArrayRef<unsigned> resultIndices);
-
   /// Create a deep copy of this function and all of its blocks, remapping
   /// any operands that use values outside of the function using the map that is
   /// provided (leaving them alone if no entry is present). If the mapper
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -84,12 +84,21 @@
 /// arguments;
 /// - they can have argument attributes that are stored in a dictionary
 /// attribute on the Op itself.
-/// This trait does *NOT* provide type support for the functions, meaning that
-/// concrete Ops must handle the type of the declared or defined function.
-/// `getTypeAttrName()` is a convenience function that returns the name of the
-/// attribute that can be used to store the function type, but the trait makes
-/// no assumption based on it.
 ///
+/// This trait provides limited type support for the declared or defined
+/// functions. The convenience function `getTypeAttrName()` returns the name of
+/// an attribute that can be used to store the function type. In addition, this
+/// trait provides `getType` and `setType` helpers to store a `FunctionType` in
+/// the attribute named by `getTypeAttrName()`.
+///
+/// In general, this trait assumes concrete ops use `FunctionType` under the
+/// hood. If this is not the case, in order to use the function type support,
+/// concrete ops must define the following methods, using the same name, to hide
+/// the ones defined for `FunctionType`: `addBodyBlock`, `getType`,
+/// `getTypeWithoutArgsAndResults` and `setType`.
+///
+/// Besides the requirements above, concrete ops must interact with this trait
+/// using the following functions:
 /// - Concrete ops *must* define a member function `getNumFuncArguments()` that
 /// returns the number of function arguments based exclusively on type (so
 /// that it can be called on function declarations).
@@ -183,6 +192,17 @@
     return getTypeAttr().getValue().template cast<FunctionType>();
   }
 
+  /// Return the type of this function without the specified arguments and
+  /// results. This is used to update the function's signature in the
+  /// `eraseArguments` and `eraseResults` methods. The arrays of indices are
+  /// allowed to have duplicates and can be in any order.
+  ///
+  /// Note that the concrete class must define a method with the same name to
+  /// hide this one if the concrete class does not use FunctionType for the
+  /// function type under the hood.
+  FunctionType getTypeWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
+                                            ArrayRef<unsigned> resultIndices);
+
   bool isTypeAttrValid() {
     auto typeAttr = getTypeAttr();
     if (!typeAttr)
@@ -204,7 +224,7 @@
   void setType(FunctionType newType);
 
   //===--------------------------------------------------------------------===//
-  // Argument Handling
+  // Argument and Result Handling
   //===--------------------------------------------------------------------===//
 
   using BlockArgListType = Region::BlockArgListType;
@@ -229,6 +249,22 @@
     return getBody().getArgumentTypes();
   }
 
+  // The erase* methods below (and `getTypeWithoutArgsAndResults` above) are
+  // defined in "mlir/IR/FunctionSupportImplementation.h"; any file that calls
+  // them must include that header.
+
+  /// Erase a single argument at `argIndex`.
+  void eraseArgument(unsigned argIndex);
+  /// Erases the arguments listed in `argIndices`.
+  /// `argIndices` is allowed to have duplicates and can be in any order.
+  void eraseArguments(ArrayRef<unsigned> argIndices);
+
+  /// Erase a single result at `resultIndex`.
+  void eraseResult(unsigned resultIndex);
+  /// Erases the results listed in `resultIndices`.
+  /// `resultIndices` is allowed to have duplicates and can be in any order.
+  void eraseResults(ArrayRef<unsigned> resultIndices);
+
   //===--------------------------------------------------------------------===//
   // Argument Attributes
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/FunctionSupportImplementation.h b/mlir/include/mlir/IR/FunctionSupportImplementation.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/FunctionSupportImplementation.h
@@ -0,0 +1,124 @@
+//===- FunctionSupportImplementation.h - FunctionLike impl ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Out-of-line definitions of the FunctionLike trait members that compute the
+// reduced function type and erase arguments/results. They are members of a
+// class template, so they are only instantiated where this header is visible:
+// any file that calls them must include it, otherwise linking fails with
+// undefined references.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_FUNCTIONSUPPORTIMPLEMENTATION_H
+#define MLIR_IR_FUNCTIONSUPPORTIMPLEMENTATION_H
+
+#include "mlir/IR/FunctionSupport.h"
+#include "llvm/ADT/BitVector.h"
+
+namespace mlir {
+
+namespace OpTrait {
+
+//===----------------------------------------------------------------------===//
+// Function Type Attribute.
+//===----------------------------------------------------------------------===//
+
+template <typename ConcreteType>
+FunctionType FunctionLike<ConcreteType>::getTypeWithoutArgsAndResults(
+    ArrayRef<unsigned> argIndices, ArrayRef<unsigned> resultIndices) {
+  return getType().getWithoutArgsAndResults(argIndices, resultIndices);
+}
+
+//===----------------------------------------------------------------------===//
+// Function Arguments and Results.
+//===----------------------------------------------------------------------===//
+
+template <typename ConcreteType>
+void FunctionLike<ConcreteType>::eraseArgument(unsigned argIndex) {
+  eraseArguments({argIndex});
+}
+
+template <typename ConcreteType>
+void FunctionLike<ConcreteType>::eraseArguments(ArrayRef<unsigned> argIndices) {
+  // Route the type query/update through the concrete op: an unqualified call
+  // from this base class would always bind to the FunctionLike members and
+  // ignore the versions that non-FunctionType ops define to hide them.
+  ConcreteType &op = *static_cast<ConcreteType *>(this);
+
+  size_t originalNumArgs = getNumArguments();
+  llvm::BitVector eraseIndices(originalNumArgs);
+  for (unsigned index : argIndices)
+    eraseIndices.set(index);
+
+  // There are 3 things that need to be updated:
+  // - Function type.
+  // - Arg attrs.
+  // - Block arguments of the entry block, if the function has a body.
+
+  // Collect the attributes of the arguments that are kept.
+  SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
+  for (size_t i = 0; i < originalNumArgs; ++i) {
+    if (eraseIndices.test(i))
+      continue;
+    newArgAttrs.emplace_back(getArgAttrDict(i));
+  }
+
+  // Update the function type first so that the argument count, which is
+  // derived from the type, matches the new attribute list.
+  op.setType(op.getTypeWithoutArgsAndResults(argIndices, {}));
+
+  // Update the arg attrs.
+  setAllArgAttrs(newArgAttrs);
+
+  // Update the entry block's arguments. An external function has no body and
+  // therefore no entry block.
+  if (!isExternal())
+    front().eraseArguments(argIndices);
+}
+
+template <typename ConcreteType>
+void FunctionLike<ConcreteType>::eraseResult(unsigned resultIndex) {
+  eraseResults({resultIndex});
+}
+
+template <typename ConcreteType>
+void FunctionLike<ConcreteType>::eraseResults(
+    ArrayRef<unsigned> resultIndices) {
+  // See `eraseArguments` for why the type goes through the concrete op.
+  ConcreteType &op = *static_cast<ConcreteType *>(this);
+
+  size_t originalNumResults = getNumResults();
+  llvm::BitVector eraseIndices(originalNumResults);
+  for (unsigned index : resultIndices)
+    eraseIndices.set(index);
+
+  // There are 2 things that need to be updated:
+  // - Function type.
+  // - Result attrs.
+
+  // Collect the attributes of the results that are kept.
+  SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
+  for (size_t i = 0; i < originalNumResults; ++i) {
+    if (eraseIndices.test(i))
+      continue;
+    newResultAttrs.emplace_back(getResultAttrDict(i));
+  }
+
+  // Update the function type before the result attrs, mirroring
+  // `eraseArguments`.
+  op.setType(op.getTypeWithoutArgsAndResults({}, resultIndices));
+
+  // Update the result attrs.
+  setAllResultAttrs(newResultAttrs);
+}
+
+} // end namespace OpTrait
+
+} // end namespace mlir
+
+#endif // MLIR_IR_FUNCTIONSUPPORTIMPLEMENTATION_H
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -238,15 +238,20 @@
   static FunctionType get(TypeRange inputs, TypeRange results,
                           MLIRContext *context);
 
-  // Input types.
+  /// Input types.
   unsigned getNumInputs() const;
   Type getInput(unsigned i) const { return getInputs()[i]; }
   ArrayRef<Type> getInputs() const;
 
-  // Result types.
+  /// Result types.
   unsigned getNumResults() const;
   Type getResult(unsigned i) const { return getResults()[i]; }
   ArrayRef<Type> getResults() const;
+
+  /// Returns a new function type without the specified arguments and results.
+  /// The index lists may contain duplicates and can be in any order.
+  FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
+                                        ArrayRef<unsigned> resultIndices) const;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -25,8 +25,6 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
-#include <set>
-
 #define DEBUG_TYPE "linalg-drop-unit-dims"
 
 using namespace mlir;
@@ -166,9 +164,8 @@
   for (unsigned unitDimLoop : unitDims) {
     entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
   }
-  std::set<unsigned> orderedUnitDims(unitDims.begin(), unitDims.end());
-  for (unsigned i : llvm::reverse(orderedUnitDims))
-    entryBlock->eraseArgument(i);
+  SmallVector<unsigned, 4> unitDimsToErase(unitDims.begin(), unitDims.end());
+  entryBlock->eraseArguments(unitDimsToErase);
   return success();
 }
 
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -9,6 +9,7 @@
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
+#include "llvm/ADT/BitVector.h"
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -176,6 +177,19 @@
   arguments.erase(arguments.begin() + index);
 }
 
+void Block::eraseArguments(ArrayRef<unsigned> argIndices) {
+  size_t originalNumArgs = getNumArguments();
+  llvm::BitVector eraseIndices(originalNumArgs);
+  for (unsigned index : argIndices)
+    eraseIndices.set(index);
+
+  // Erase from the back so that removing an argument never shifts the
+  // position of an argument that is still pending erasure.
+  for (size_t i = originalNumArgs; i > 0; --i)
+    if (eraseIndices.test(i - 1))
+      eraseArgument(i - 1);
+}
+
 /// Insert one value to the given position of the argument list. The existing
 /// arguments are shifted. The block is expected not to have predecessors.
 BlockArgument Block::insertArgument(args_iterator it, Type type) {
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -98,65 +98,6 @@
   return success();
 }
 
-void FuncOp::eraseArguments(ArrayRef<unsigned> argIndices) {
-  auto oldType = getType();
-  int originalNumArgs = oldType.getNumInputs();
-  llvm::BitVector eraseIndices(originalNumArgs);
-  for (auto index : argIndices)
-    eraseIndices.set(index);
-  auto shouldEraseArg = [&](int i) { return eraseIndices.test(i); };
-
-  // There are 3 things that need to be updated:
-  // - Function type.
-  // - Arg attrs.
-  // - Block arguments of entry block.
-
-  // Update the function type and arg attrs.
-  SmallVector<Type, 4> newInputTypes;
-  SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
-  for (int i = 0; i < originalNumArgs; i++) {
-    if (shouldEraseArg(i))
-      continue;
-    newInputTypes.emplace_back(oldType.getInput(i));
-    newArgAttrs.emplace_back(getArgAttrDict(i));
-  }
-  setType(FunctionType::get(newInputTypes, oldType.getResults(), getContext()));
-  setAllArgAttrs(newArgAttrs);
-
-  // Update the entry block's arguments.
-  // We do this in reverse so that we erase later indices before earlier
-  // indices, to avoid shifting the later indices.
-  Block &entry = front();
-  for (int i = 0; i < originalNumArgs; i++)
-    if (shouldEraseArg(originalNumArgs - i - 1))
-      entry.eraseArgument(originalNumArgs - i - 1);
-}
-
-void FuncOp::eraseResults(ArrayRef<unsigned> resultIndices) {
-  auto oldType = getType();
-  int originalNumResults = oldType.getNumResults();
-  llvm::BitVector eraseIndices(originalNumResults);
-  for (auto index : resultIndices)
-    eraseIndices.set(index);
-  auto shouldEraseResult = [&](int i) { return eraseIndices.test(i); };
-
-  // There are 2 things that need to be updated:
-  // - Function type.
-  // - Result attrs.
-
-  // Update the function type and result attrs.
-  SmallVector<Type, 4> newResultTypes;
-  SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
-  for (int i = 0; i < originalNumResults; i++) {
-    if (shouldEraseResult(i))
-      continue;
-    newResultTypes.emplace_back(oldType.getResult(i));
-    newResultAttrs.emplace_back(getResultAttrDict(i));
-  }
-  setType(FunctionType::get(oldType.getInputs(), newResultTypes, getContext()));
-  setAllResultAttrs(newResultAttrs);
-}
-
 /// Clone the internal blocks from this function into dest and all attributes
 /// from this function to dest.
 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -10,6 +10,7 @@
 #include "TypeDetail.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/Twine.h"
 
 using namespace mlir;
@@ -46,6 +47,32 @@
   return getImpl()->getResults();
 }
 
+/// Returns `types` without the entries at the positions listed in `indices`,
+/// which may contain duplicates and can be in any order.
+static SmallVector<Type, 4> dropTypesAtIndices(ArrayRef<Type> types,
+                                               ArrayRef<unsigned> indices) {
+  llvm::BitVector skipIndices(types.size());
+  for (unsigned index : indices)
+    skipIndices.set(index);
+
+  SmallVector<Type, 4> keptTypes;
+  for (unsigned i = 0, e = types.size(); i < e; ++i)
+    if (!skipIndices.test(i))
+      keptTypes.push_back(types[i]);
+  return keptTypes;
+}
+
+/// Returns a new function type without the specified arguments and results.
+FunctionType
+FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
+                                       ArrayRef<unsigned> resultIndices) const {
+  SmallVector<Type, 4> newInputTypes =
+      dropTypesAtIndices(getInputs(), argIndices);
+  SmallVector<Type, 4> newResultTypes =
+      dropTypesAtIndices(getResults(), resultIndices);
+  return get(newInputTypes, newResultTypes, getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // OpaqueType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/Function.h"
+#include "mlir/IR/FunctionSupportImplementation.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;