diff --git a/mlir/include/mlir/Transforms/OperationUtils.h b/mlir/include/mlir/Transforms/OperationUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Transforms/OperationUtils.h
@@ -0,0 +1,32 @@
+//===- OperationUtils.h - Operation transformation utilities ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_OPERATIONUTILS_H
+#define MLIR_TRANSFORMS_OPERATIONUTILS_H
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class BlockAndValueMapping;
+class Operation;
+
+/// Clones `original`, including all of its nested regions, like
+/// `Operation::clone`. In addition to the block and value mappings that
+/// are recorded in `mapper`, every cloned operation (`original` itself and
+/// each operation nested inside it) is recorded in `operationMap` as an
+/// original -> clone entry. This matters because operations without results
+/// cannot be found through the value mapping.
+///
+/// Operands referring to values defined outside of `original` are remapped
+/// through `mapper`, or left unchanged if `mapper` has no entry for them. The
+/// returned operation is not inserted into any block.
+Operation *cloneOperation(Operation *original, BlockAndValueMapping &mapper,
+                          DenseMap<Operation *, Operation *> &operationMap);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_OPERATIONUTILS_H
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -6,6 +6,7 @@
   GreedyPatternRewriteDriver.cpp
   InliningUtils.cpp
   LoopInvariantCodeMotionUtils.cpp
+  OperationUtils.cpp
   RegionUtils.cpp
   TopologicalSortUtils.cpp
 
diff --git a/mlir/lib/Transforms/Utils/OperationUtils.cpp b/mlir/lib/Transforms/Utils/OperationUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/OperationUtils.cpp
@@ -0,0 +1,149 @@
+//===- OperationUtils.cpp - Operation transformation utilities --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/OperationUtils.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+
+static Operation *
+cloneOperation(Operation *original, BlockAndValueMapping &mapper,
+               DenseMap<Operation *, Operation *> &operationMap,
+               Operation::CloneOptions options);
+
+/// Clone all blocks of `original` and append them at the end of `dest`.
+/// Block, argument and result mappings are recorded in `mapper`, and every
+/// cloned operation (including nested ones) is recorded in `operationMap`.
+static void cloneRegion(Region &original, Region &dest,
+                        BlockAndValueMapping &mapper,
+                        DenseMap<Operation *, Operation *> &operationMap) {
+  assert(&original != &dest && "cannot clone region into itself");
+
+  // If the list is empty there is nothing to clone.
+  if (original.empty())
+    return;
+
+  Region::iterator destPos = dest.end();
+
+  // The below clone implementation takes special care to be read only for the
+  // sake of multi threading. That essentially means not adding any uses to any
+  // of the blocks or operation results contained within this region as that
+  // would lead to a write in their use-def list. This is unavoidable for
+  // 'Value's from outside the region however, in which case it is not read
+  // only. Using the BlockAndValueMapper it is possible to remap such 'Value's
+  // to ones owned by the calling thread however, making it read only once
+  // again.
+
+  // First clone all the blocks and block arguments and map them, but don't yet
+  // clone the operations, as they may otherwise add a use to a block that has
+  // not yet been mapped.
+  for (Block &block : original) {
+    Block *newBlock = new Block();
+    mapper.map(&block, newBlock);
+
+    // Clone the block arguments. The user might be deleting arguments to the
+    // block by specifying them in the mapper. If so, we don't add the
+    // argument to the cloned block.
+    for (auto arg : block.getArguments())
+      if (!mapper.contains(arg))
+        mapper.map(arg, newBlock->addArgument(arg.getType(), arg.getLoc()));
+
+    dest.getBlocks().insert(destPos, newBlock);
+  }
+
+  // The cloned blocks were appended in order, so they are exactly the blocks
+  // from the clone of the first source block up to the end of `dest`.
+  auto newBlocksRange = llvm::make_range(
+      Region::iterator(mapper.lookup(&original.front())), destPos);
+
+  // Now follow up with creating the operations, but don't yet clone their
+  // regions, nor set their operands. Setting the successors is safe as all have
+  // already been mapped. We are essentially just creating the operation results
+  // to be able to map them.
+  // Cloning the operands and region as well would lead to uses of operations
+  // not yet mapped.
+  auto cloneOptions =
+      Operation::CloneOptions::all().cloneRegions(false).cloneOperands(false);
+  for (auto [sourceBlock, clonedBlock] : llvm::zip(original, newBlocksRange)) {
+    // Clone and remap the operations within this block.
+    for (Operation &op : sourceBlock)
+      clonedBlock.push_back(
+          cloneOperation(&op, mapper, operationMap, cloneOptions));
+  }
+
+  // Finally now that all operation results have been mapped, set the operands
+  // and clone the regions.
+  SmallVector<Value> operands;
+  for (auto [sourceBlock, clonedBlock] : llvm::zip(original, newBlocksRange)) {
+    for (auto [source, clone] : llvm::zip(sourceBlock, clonedBlock)) {
+      operands.resize(source.getNumOperands());
+      llvm::transform(
+          source.getOperands(), operands.begin(),
+          [&](Value operand) { return mapper.lookupOrDefault(operand); });
+      clone.setOperands(operands);
+
+      for (auto [sourceRegion, destRegion] :
+           llvm::zip(source.getRegions(), clone.getRegions()))
+        cloneRegion(sourceRegion, destRegion, mapper, operationMap);
+    }
+  }
+}
+
+/// Create a copy of `original` as configured by `options`. Operands and
+/// successors are remapped through `mapper` (left unchanged if it has no
+/// entry). Records `original -> clone` in `operationMap` (plus every nested
+/// operation when regions are cloned) and maps each result in `mapper`.
+/// The returned operation is not inserted into any block.
+static Operation *
+cloneOperation(Operation *original, BlockAndValueMapping &mapper,
+               DenseMap<Operation *, Operation *> &operationMap,
+               Operation::CloneOptions options) {
+  SmallVector<Value, 8> operands;
+  SmallVector<Block *, 2> successors;
+
+  // Remap the operands.
+  if (options.shouldCloneOperands()) {
+    operands.reserve(original->getNumOperands());
+    for (auto opValue : original->getOperands())
+      operands.push_back(mapper.lookupOrDefault(opValue));
+  }
+
+  // Remap the successors.
+  successors.reserve(original->getNumSuccessors());
+  for (Block *successor : original->getSuccessors())
+    successors.push_back(mapper.lookupOrDefault(successor));
+
+  // Create the new operation. Passing the existing attribute dictionary avoids
+  // rebuilding and re-uniquing it from the attribute list.
+  auto *newOp = Operation::create(
+      original->getLoc(), original->getName(), original->getResultTypes(),
+      operands, original->getAttrDictionary(), successors,
+      original->getNumRegions());
+  operationMap[original] = newOp;
+
+  // Clone the regions.
+  if (options.shouldCloneRegions()) {
+    for (unsigned i = 0, e = original->getNumRegions(); i != e; ++i)
+      cloneRegion(original->getRegion(i), newOp->getRegion(i), mapper,
+                  operationMap);
+  }
+
+  // Remember the mapping of any results.
+  for (unsigned i = 0, e = original->getNumResults(); i != e; ++i)
+    mapper.map(original->getResult(i), newOp->getResult(i));
+
+  return newOp;
+}
+
+Operation *
+mlir::cloneOperation(Operation *original, BlockAndValueMapping &mapper,
+                     DenseMap<Operation *, Operation *> &operationMap) {
+  return ::cloneOperation(original, mapper, operationMap,
+                          Operation::CloneOptions::all());
+}
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
--- a/mlir/unittests/Transforms/CMakeLists.txt
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_unittest(MLIRTransformsTests
   Canonicalizer.cpp
   DialectConversion.cpp
+  OperationUtils.cpp
 )
 target_link_libraries(MLIRTransformsTests
   PRIVATE
diff --git a/mlir/unittests/Transforms/OperationUtils.cpp b/mlir/unittests/Transforms/OperationUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Transforms/OperationUtils.cpp
@@ -0,0 +1,41 @@
+//===- OperationUtils.cpp - Operation utils unit tests ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/OperationUtils.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+TEST(OperationUtilsTest, CloneNestedRemap) {
+  MLIRContext ctx;
+  ctx.allowUnregisteredDialects();
+
+  // Build `owner { no_results }`. The nested op has no results, so its clone
+  // can only be found through the operation map, not the value mapping.
+  OperationState state(UnknownLoc::get(&ctx), "no_results");
+  Operation *noResultsOp = Operation::create(state);
+
+  OperationState owner(UnknownLoc::get(&ctx), "owner");
+  owner.addRegion()->emplaceBlock().push_back(noResultsOp);
+  OwningOpRef<Operation *> ownerOp = Operation::create(owner);
+
+  BlockAndValueMapping bvMap;
+  DenseMap<Operation *, Operation *> opMap;
+  OwningOpRef<Operation *> clonedOwnerOp =
+      cloneOperation(*ownerOp, bvMap, opMap);
+
+  // Exactly the owner and its nested op are recorded; use `lookup` so the
+  // checks themselves never insert entries.
+  EXPECT_EQ(opMap.size(), 2u);
+  EXPECT_EQ(opMap.lookup(*ownerOp), *clonedOwnerOp);
+  EXPECT_EQ(opMap.lookup(noResultsOp),
+            &(*clonedOwnerOp)->getRegion(0).front().front());
+}