diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -294,6 +294,30 @@
 separateFullTiles(MutableArrayRef<AffineForOp> nest,
                   SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
 
+/// Creates a clone of `loop` with `newIterOperands` added as new
+/// initialization values and `newYieldedValues` added as new yielded values.
+/// The returned ForOp has `newYieldedValues.size()` new result values.
+/// The `loop` induction variable and `newIterOperands` are remapped to the new
+/// induction variable and the new entry block arguments respectively.
+///
+/// Additionally, if `replaceLoopResults` is true, all uses of
+/// `loop.getResults()` are replaced with the first `loop.getNumResults()`
+/// return values respectively. This additional replacement is provided as a
+/// convenience to update the consumers of `loop`, in the case e.g. when `loop`
+/// is soon to be deleted.
+///
+/// Returns the cloned loop.
+///
+/// This convenience function is useful to factorize common mechanisms related
+/// to hoisting roundtrips to memory into yields. It does not perform any
+/// legality checks.
+///
+/// Prerequisite: `newIterOperands.size() == newYieldedValues.size()`.
+scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
+                              ValueRange newIterOperands,
+                              ValueRange newYieldedValues,
+                              bool replaceLoopResults = true);
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_LOOP_UTILS_H
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1096,7 +1096,7 @@
     auto begin = t.getBody()->begin();
     auto nOps = t.getBody()->getOperations().size();
 
-    // Insert newForOp before the terminator of `t`.
+    // Insert newLoop before the terminator of `t`.
     auto b = OpBuilder::atBlockTerminator((t.getBody()));
     Value stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
     Value less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt,
@@ -1104,15 +1104,14 @@
     Value ub =
         b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped);
 
-    // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
-    auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
-    newForOp.getBody()->getOperations().splice(
-        newForOp.getBody()->getOperations().begin(),
+    // Splice [begin, begin + nOps - 1) into `newLoop` and replace uses.
+    auto newLoop = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
+    newLoop.getBody()->getOperations().splice(
+        newLoop.getBody()->getOperations().begin(),
         t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
-    replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
-                               newForOp.region());
+    replaceAllUsesInRegionWith(iv, newLoop.getInductionVar(), newLoop.region());
 
-    innerLoops.push_back(newForOp);
+    innerLoops.push_back(newLoop);
   }
 
   return innerLoops;
@@ -2507,3 +2506,63 @@
 
   return success();
 }
+
+scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
+                                    ValueRange newIterOperands,
+                                    ValueRange newYieldedValues,
+                                    bool replaceLoopResults) {
+  assert(newIterOperands.size() == newYieldedValues.size() &&
+         "newIterOperands must be of the same size as newYieldedValues");
+
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(loop);
+  auto operands = llvm::to_vector<4>(loop.getIterOperands());
+  operands.append(newIterOperands.begin(), newIterOperands.end());
+  scf::ForOp newLoop =
+      b.create<scf::ForOp>(loop.getLoc(), loop.lowerBound(), loop.upperBound(),
+                           loop.step(), operands);
+
+  // Clone / erase the yield inside the original loop to both:
+  //   1. augment its operands with the newYieldedValues.
+  //   2. automatically apply the BlockAndValueMapping on its operands.
+  auto yield = cast<scf::YieldOp>(loop.region().front().getTerminator());
+  b.setInsertionPoint(yield);
+  auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
+  yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
+  auto newYield = b.create<scf::YieldOp>(yield.getLoc(), yieldOperands);
+
+  // Clone the loop body with remaps.
+  BlockAndValueMapping bvm;
+  // a. remap the induction variable.
+  bvm.map(loop.getInductionVar(), newLoop.getInductionVar());
+  // b. remap the BB args.
+  for (auto it : llvm::zip(loop.region().front().getArguments(),
+                           newLoop.region().front().getArguments().take_front(
+                               loop.region().front().getNumArguments())))
+    bvm.map(std::get<0>(it), std::get<1>(it));
+  // c. remap the new iter operands to the new loop's trailing iter args.
+  for (auto it :
+       llvm::zip(newIterOperands,
+                 newLoop.getRegionIterArgs().take_back(newIterOperands.size())))
+    bvm.map(std::get<0>(it), std::get<1>(it));
+  b.setInsertionPointToStart(&newLoop.region().front());
+  for (auto &o : loop.region().front().getOperations()) {
+    // Skip the original yield terminator, which does not have enough operands;
+    // the augmented `newYield` created right before it is cloned instead.
+    if (&o == yield.getOperation())
+      continue;
+    // `clone` already records the result mapping in `bvm`.
+    b.clone(o, bvm);
+  }
+
+  // Perform `loop` results replacement if requested.
+  if (replaceLoopResults) {
+    for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
+                                                    loop.getNumResults())))
+      std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+  }
+
+  newYield.erase();
+  return newLoop;
+}
diff --git a/mlir/test/Transforms/loop-utils.mlir b/mlir/test/Transforms/loop-utils.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/loop-utils.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -allow-unregistered-dialect -test-loop-utils -mlir-disable-threading %s | FileCheck %s
+
+// CHECK-LABEL: @hoist
+// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
+// CHECK-SAME: %[[ub:[a-zA-Z0-9]*]]: index,
+// CHECK-SAME: %[[step:[a-zA-Z0-9]*]]: index
+func @hoist(%lb: index, %ub: index, %step: index) {
+  // CHECK: %[[A:.*]] = "fake_read"() : () -> index
+  // CHECK: %[[RES:.*]] = scf.for %[[I:.*]] = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[VAL:.*]] = %[[A]]) -> (index)
+  // CHECK: %[[YIELD:.*]] = "fake_compute"(%[[VAL]]) : (index) -> index
+  // CHECK: scf.yield %[[YIELD]] : index
+  // CHECK: "fake_write"(%[[RES]]) : (index) -> ()
+  scf.for %i = %lb to %ub step %step {
+    %0 = "fake_read"() : () -> (index)
+    %1 = "fake_compute"(%0) : (index) -> (index)
+    "fake_write"(%1) : (index) -> ()
+  }
+  return
+}
+
+// CHECK-LABEL: @hoist2
+// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
+// CHECK-SAME: %[[ub:[a-zA-Z0-9]*]]: index,
+// CHECK-SAME: %[[step:[a-zA-Z0-9]*]]: index
+// CHECK-SAME: %[[extra_arg:[a-zA-Z0-9]*]]: f32
+func @hoist2(%lb: index, %ub: index, %step: index, %extra_arg: f32) -> f32 {
+  // CHECK: %[[A:.*]] = "fake_read"() : () -> index
+  // CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[VAL0:.*]] = %[[extra_arg]], %[[VAL1:.*]] = %[[A]]) -> (f32, index)
+  // CHECK: %[[YIELD:.*]] = "fake_compute"(%[[VAL1]]) : (index) -> index
+  // CHECK: scf.yield %[[VAL0]], %[[YIELD]] : f32, index
+  // CHECK: "fake_write"(%[[RES]]#1) : (index) -> ()
+  // CHECK: return %[[RES]]#0 : f32
+  %0 = scf.for %i = %lb to %ub step %step iter_args(%iter = %extra_arg) -> (f32) {
+    %0 = "fake_read"() : () -> (index)
+    %1 = "fake_compute"(%0) : (index) -> (index)
+    "fake_write"(%1) : (index) -> ()
+    scf.yield %iter: f32
+  }
+  return %0: f32
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@
   TestLoopMapping.cpp
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
+  TestLoopUtils.cpp
   TestOpaqueLoc.cpp
   TestMemRefBoundCheck.cpp
   TestMemRefDependenceCheck.cpp
diff --git a/mlir/test/lib/Transforms/TestLoopUtils.cpp b/mlir/test/lib/Transforms/TestLoopUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestLoopUtils.cpp
@@ -0,0 +1,65 @@
+//===- TestLoopUtils.cpp --- Pass to test independent loop utils ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test loop utils.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/SetVector.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass exercising `cloneWithNewYields`. For every "fake_read" op nested
+/// in an scf.for, it hoists the read above the loop, sinks the "fake_write"
+/// consuming the computed value below the loop, and threads the value through
+/// the loop as an extra iter operand / yielded value.
+class TestLoopUtilsPass : public PassWrapper<TestLoopUtilsPass, FunctionPass> {
+public:
+  explicit TestLoopUtilsPass() {}
+
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+    SmallVector<scf::ForOp, 4> toErase;
+
+    func.walk([&](Operation *fakeRead) {
+      if (fakeRead->getName().getStringRef() != "fake_read")
+        return;
+      // Follow the first use of each result: read -> compute -> write.
+      auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
+      auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
+      auto loop = fakeRead->getParentOfType<scf::ForOp>();
+
+      OpBuilder b(loop);
+      loop.moveOutOfLoop({fakeRead});
+      fakeWrite->moveAfter(loop);
+      auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
+                                        fakeCompute->getResult(0));
+      // Users of the old compute result now read the new loop's last result.
+      fakeCompute->getResult(0).replaceAllUsesWith(
+          newLoop.getResults().take_back()[0]);
+      toErase.push_back(loop);
+    });
+    // The original loops are erased only once the walk is over.
+    for (auto loop : llvm::reverse(toErase))
+      loop.erase();
+  }
+};
+} // end namespace
+
+namespace mlir {
+void registerTestLoopUtilsPass() {
+  PassRegistration<TestLoopUtilsPass>("test-loop-utils", "test loop utils");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -54,6 +54,7 @@
 void registerTestLoopFusion();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
+void registerTestLoopUtilsPass();
 void registerTestMatchers();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
@@ -122,6 +123,7 @@
   registerTestLoopFusion();
   registerTestLoopMappingPass();
   registerTestLoopUnrollingPass();
+  registerTestLoopUtilsPass();
   registerTestMatchers();
   registerTestMemRefDependenceCheck();
   registerTestMemRefStrideCalculation();