diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -240,6 +240,8 @@
   /// Takes body of another region (that region will have no body after this
   /// operation completes). The current body of this region is cleared.
   void takeBody(Region &other) {
+    // Drop all uses first so clear() does not destroy still-used values.
+    dropAllReferences();
     blocks.clear();
     blocks.splice(blocks.end(), other.getBlocks());
   }
diff --git a/mlir/test/IR/test-take-body.mlir b/mlir/test/IR/test-take-body.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/test-take-body.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s --test-take-body | FileCheck %s
+
+// @foo's body has a value used across blocks (%0 in ^body) and a branch cycle
+// (^header <-> ^body); replacing it requires dropping references first.
+func @foo() {
+  %0 = "test.foo"() : () -> i32
+  cf.br ^header
+
+^header:
+  cf.br ^body
+
+^body:
+  "test.use"(%0) : (i32) -> ()
+  cf.br ^header
+}
+
+func private @bar() {
+  return
+}
+
+// CHECK-LABEL: func @foo
+// CHECK-NEXT: return
+
+// CHECK-LABEL: func private @bar()
+// CHECK-NOT: {
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -15,6 +15,7 @@
   TestSideEffects.cpp
   TestSlicing.cpp
   TestSymbolUses.cpp
+  TestRegions.cpp
   TestTypes.cpp
   TestVisitors.cpp
   TestVisitorsGeneric.cpp
diff --git a/mlir/test/lib/IR/TestRegions.cpp b/mlir/test/lib/IR/TestRegions.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/IR/TestRegions.cpp
@@ -0,0 +1,46 @@
+//===- TestRegions.cpp - Pass to test Region's methods --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// This is a test pass that tests Region's takeBody method by making the first
+/// function take the body of the second.
+struct TakeBodyPass
+    : public PassWrapper<TakeBodyPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TakeBodyPass)
+
+  StringRef getArgument() const final { return "test-take-body"; }
+  StringRef getDescription() const final { return "Test Region's takeBody"; }
+
+  void runOnOperation() override {
+    auto module = getOperation();
+
+    SmallVector<func::FuncOp> functions =
+        llvm::to_vector(module.getOps<func::FuncOp>());
+    if (functions.size() != 2) {
+      module.emitError("expected exactly two functions in test");
+      signalPassFailure();
+      return;
+    }
+
+    // Move the second function's blocks into the first, replacing its body.
+    functions[0].getBody().takeBody(functions[1].getBody());
+  }
+};
+
+} // namespace
+
+namespace mlir {
+void registerRegionTestPasses() { PassRegistration<TakeBodyPass>(); }
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -37,6 +37,7 @@
 void registerSideEffectTestPasses();
 void registerSliceAnalysisTestPass();
 void registerSymbolTestPasses();
+void registerRegionTestPasses();
 void registerTestAffineDataCopyPass();
 void registerTestAffineLoopUnswitchingPass();
 void registerTestAllReduceLoweringPass();
@@ -128,6 +129,7 @@
   registerSideEffectTestPasses();
   registerSliceAnalysisTestPass();
   registerSymbolTestPasses();
+  registerRegionTestPasses();
   registerTestAffineDataCopyPass();
   registerTestAffineLoopUnswitchingPass();
   registerTestAllReduceLoweringPass();