diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -65,6 +65,14 @@
   /// `argIndices` is allowed to have duplicates and can be in any order.
   void eraseArguments(ArrayRef<unsigned> argIndices);
 
+  /// Erase a single result at `resultIndex`.
+  void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
+  /// Erases the results listed in `resultIndices`.
+  /// `resultIndices` is allowed to have duplicates and can be in any order.
+  /// Only the function type and result attributes are updated; the caller is
+  /// responsible for updating the function's return ops and any call sites.
+  void eraseResults(ArrayRef<unsigned> resultIndices);
+
   /// Create a deep copy of this function and all of its blocks, remapping
   /// any operands that use values outside of the function using the map that is
   /// provided (leaving them alone if no entry is present). If the mapper
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -132,6 +132,31 @@
       entry.eraseArgument(originalNumArgs - i - 1);
 }
 
+void FuncOp::eraseResults(ArrayRef<unsigned> resultIndices) {
+  auto oldType = getType();
+  int originalNumResults = oldType.getNumResults();
+  llvm::BitVector eraseIndices(originalNumResults);
+  for (auto index : resultIndices)
+    eraseIndices.set(index);
+  auto shouldEraseResult = [&](int i) { return eraseIndices.test(i); };
+
+  // There are 2 things that need to be updated:
+  // - Function type.
+  // - Result attrs.
+
+  // Update the function type and result attrs.
+  SmallVector<Type, 4> newResultTypes;
+  SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
+  for (int i = 0; i < originalNumResults; i++) {
+    if (shouldEraseResult(i))
+      continue;
+    newResultTypes.emplace_back(oldType.getResult(i));
+    newResultAttrs.emplace_back(getResultAttrDict(i));
+  }
+  setType(FunctionType::get(oldType.getInputs(), newResultTypes, getContext()));
+  setAllResultAttrs(newResultAttrs);
+}
+
 /// Clone the internal blocks from this function into dest and all attributes
 /// from this function to dest.
 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
diff --git a/mlir/test/IR/test-func-erase-result.mlir b/mlir/test/IR/test-func-erase-result.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/test-func-erase-result.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s -test-func-erase-result -split-input-file | FileCheck %s
+
+// CHECK: func @f(){{$}}
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (f32 {test.erase_this_result})
+
+// -----
+
+// CHECK: func @f() -> (f32 {test.A})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+  f32 {test.erase_this_result},
+  f32 {test.A}
+)
+
+// -----
+
+// CHECK: func @f() -> (f32 {test.A})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+  f32 {test.A},
+  f32 {test.erase_this_result}
+)
+
+// -----
+
+// CHECK: func @f() -> (f32 {test.A}, f32 {test.B})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+  f32 {test.A},
+  f32 {test.erase_this_result},
+  f32 {test.B}
+)
+
+// -----
+
+// CHECK: func @f() -> (f32 {test.A}, f32 {test.B})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+  f32 {test.A},
+  f32 {test.erase_this_result},
+  f32 {test.erase_this_result},
+  f32 {test.B}
+)
+
+// -----
+
+// CHECK: func @f() -> (f32 {test.A}, f32 {test.B}, f32 {test.C})
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+  f32 {test.A},
+  f32 {test.erase_this_result},
+  f32 {test.B},
+  f32 {test.erase_this_result},
+  f32 {test.C}
+)
+
+// -----
+
+// CHECK: func @f() -> (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>)
+// CHECK-NOT: attributes{{.*}}result
+func @f() -> (
+  tensor<1xf32>,
+  f32 {test.erase_this_result},
+  tensor<2xf32>,
+  f32 {test.erase_this_result},
+  tensor<3xf32>
+)
diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -36,6 +36,30 @@
   }
 };
 
+/// This is a test pass for verifying FuncOp's eraseResult method.
+struct TestFuncEraseResult
+    : public PassWrapper<TestFuncEraseResult, OperationPass<ModuleOp>> {
+  void runOnOperation() override {
+    auto module = getOperation();
+
+    for (FuncOp func : module.getOps<FuncOp>()) {
+      SmallVector<unsigned, 4> indicesToErase;
+      for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
+        if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
+          // Push back twice to test that duplicate indices are handled
+          // correctly.
+          indicesToErase.push_back(resultIndex);
+          indicesToErase.push_back(resultIndex);
+        }
+      }
+      // Reverse the order to test that unsorted index lists are handled
+      // correctly.
+      std::reverse(indicesToErase.begin(), indicesToErase.end());
+      func.eraseResults(indicesToErase);
+    }
+  }
+};
+
 /// This is a test pass for verifying FuncOp's setType method.
 struct TestFuncSetType
     : public PassWrapper<TestFuncSetType, OperationPass<ModuleOp>> {
@@ -55,10 +79,13 @@
 
 namespace mlir {
 void registerTestFunc() {
-  PassRegistration<TestFuncEraseArg> pass("test-func-erase-arg",
-                                          "Test erasing func args.");
+  PassRegistration<TestFuncEraseArg>("test-func-erase-arg",
+                                     "Test erasing func args.");
 
-  PassRegistration<TestFuncSetType> pass2("test-func-set-type",
-                                          "Test FuncOp::setType.");
+  PassRegistration<TestFuncEraseResult>("test-func-erase-result",
+                                        "Test erasing func results.");
+
+  PassRegistration<TestFuncSetType>("test-func-set-type",
+                                    "Test FuncOp::setType.");
 }
 } // namespace mlir