diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -85,7 +85,10 @@
   /// original one, but they will be left empty.
   /// Operands are remapped using `mapper` (if present), and `mapper` is updated
   /// to contain the results.
-  Operation *cloneWithoutRegions(BlockAndValueMapping &mapper);
+  /// The `mapResults` argument specifies whether the results of the operation
+  /// should also be mapped.
+  Operation *cloneWithoutRegions(BlockAndValueMapping &mapper,
+                                 bool mapResults = true);
 
   /// Create a partial copy of this operation without traversing into attached
   /// regions. The new operation will have the same number of regions as the
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -526,7 +526,8 @@
 /// Create a deep copy of this operation but keep the operation regions empty.
 /// Operands are remapped using `mapper` (if present), and `mapper` is updated
-/// to contain the results.
-Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
+/// to contain the results if `mapResults` is true.
+Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper,
+                                          bool mapResults) {
   SmallVector<Value, 8> operands;
   SmallVector<Block *, 2> successors;
 
@@ -545,8 +546,10 @@
                        successors, getNumRegions());
 
   // Remember the mapping of any results.
-  for (unsigned i = 0, e = getNumResults(); i != e; ++i)
-    mapper.map(getResult(i), newOp->getResult(i));
+  if (mapResults) {
+    for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+      mapper.map(getResult(i), newOp->getResult(i));
+  }
 
   return newOp;
 }
@@ -562,12 +565,18 @@
 /// sub-operations to the corresponding operation that is copied, and adds
 /// those mappings to the map.
 Operation *Operation::clone(BlockAndValueMapping &mapper) {
-  auto *newOp = cloneWithoutRegions(mapper);
+  auto *newOp = cloneWithoutRegions(mapper, /*mapResults=*/false);
 
   // Clone the regions.
   for (unsigned i = 0; i != numRegions; ++i)
     getRegion(i).cloneInto(&newOp->getRegion(i), mapper);
 
+  // Remember the mapping of any results. This is deliberately done after the
+  // regions are cloned so that values used inside the cloned regions are not
+  // remapped to the results of `newOp` itself.
+  for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+    mapper.map(getResult(i), newOp->getResult(i));
+
   return newOp;
 }
 
diff --git a/mlir/test/IR/test-clone.mlir b/mlir/test/IR/test-clone.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/test-clone.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="func.func(test-clone)" -split-input-file | FileCheck %s
+
+module {
+  func @fixpoint(%arg1 : i32) -> i32 {
+    %r = "test.use"(%arg1) ({
+      "test.yield"(%arg1) : (i32) -> ()
+    }) : (i32) -> i32
+    return %r : i32
+  }
+}
+
+// CHECK: func @fixpoint(%[[arg0:.+]]: i32) -> i32 {
+// CHECK-NEXT:   %[[i0:.+]] = "test.use"(%[[arg0]]) ({
+// CHECK-NEXT:     "test.yield"(%[[arg0]]) : (i32) -> ()
+// CHECK-NEXT:   }) : (i32) -> i32
+// CHECK-NEXT:   %[[i1:.+]] = "test.use"(%[[i0]]) ({
+// CHECK-NEXT:     "test.yield"(%[[i0]]) : (i32) -> ()
+// CHECK-NEXT:   }) : (i32) -> i32
+// CHECK-NEXT:   return %[[i1]] : i32
+// CHECK-NEXT: }
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestIR
   TestBuiltinAttributeInterfaces.cpp
+  TestClone.cpp
   TestDiagnostics.cpp
   TestDominance.cpp
   TestFunc.cpp
diff --git a/mlir/test/lib/IR/TestClone.cpp b/mlir/test/lib/IR/TestClone.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/IR/TestClone.cpp
@@ -0,0 +1,64 @@
+//===- TestClone.cpp - Pass to test operation cloning ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+/// This is a test pass which clones the body of a function. Specifically
+/// this pass replaces f(x) to instead return f(f(x)), in which the cloned body
+/// takes the values returned by the original body as its inputs.
+struct ClonePass
+    : public PassWrapper<ClonePass, InterfacePass<FunctionOpInterface>> {
+  StringRef getArgument() const final { return "test-clone"; }
+  StringRef getDescription() const final { return "Test clone of op"; }
+  void runOnOperation() override {
+    FunctionOpInterface op = getOperation();
+
+    // Limit testing to ops with only one region.
+    if (op->getNumRegions() != 1)
+      return;
+
+    Region &region = op->getRegion(0);
+    if (!region.hasOneBlock())
+      return;
+
+    Block &regionEntry = region.front();
+    Operation *terminator = regionEntry.getTerminator();
+
+    // Only handle functions whose returns match the inputs.
+    if (terminator->getNumOperands() != regionEntry.getNumArguments())
+      return;
+
+    BlockAndValueMapping map;
+    for (auto tup :
+         llvm::zip(terminator->getOperands(), regionEntry.getArguments())) {
+      if (std::get<0>(tup).getType() != std::get<1>(tup).getType())
+        return;
+      map.map(std::get<1>(tup), std::get<0>(tup));
+    }
+
+    OpBuilder builder(op->getContext());
+    builder.setInsertionPointToEnd(&regionEntry);
+    SmallVector<Operation *> toClone;
+    for (Operation &inst : regionEntry)
+      toClone.push_back(&inst);
+    for (Operation *inst : toClone)
+      builder.clone(*inst, map);
+    terminator->erase();
+  }
+};
+} // namespace
+
+namespace mlir {
+void registerCloneTestPasses() { PassRegistration<ClonePass>(); }
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -30,6 +30,7 @@
 // Defined in the test directory, no public header.
namespace mlir { void registerConvertToTargetEnvPass(); +void registerCloneTestPasses(); void registerPassManagerTestPass(); void registerPrintSpirvAvailabilityPass(); void registerShapeFunctionTestPasses(); @@ -119,6 +120,7 @@ #ifdef MLIR_INCLUDE_TESTS void registerTestPasses() { + registerCloneTestPasses(); registerConvertToTargetEnvPass(); registerPassManagerTestPass(); registerPrintSpirvAvailabilityPass();