diff --git a/mlir/include/mlir/Dialect/SCF/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SCF/Utils.h
@@ -0,0 +1,50 @@
+//===- Utils.h - SCF dialect utilities --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes for various SCF utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_UTILS_H_
+#define MLIR_DIALECT_SCF_UTILS_H_
+
+namespace mlir {
+class OpBuilder;
+class ValueRange;
+
+namespace scf {
+class ForOp;
+class ParallelOp;
+} // end namespace scf
+
+/// Create a clone of `loop` with `newIterOperands` added as new initialization
+/// values and `newYieldedValues` added as new yielded values. The returned
+/// ForOp has `newYieldedValues.size()` new result values. The `loop` induction
+/// variable and `newIterOperands` are remapped to the new induction variable
+/// and the new entry block arguments respectively.
+///
+/// Additionally, if `replaceLoopResults` is true, all uses of
+/// `loop.getResults()` are replaced with the first `loop.getNumResults()`
+/// results of the cloned loop. This additional replacement is provided as a
+/// convenience to update the consumers of `loop`, in the case e.g. when `loop`
+/// is soon to be deleted.
+///
+/// Return the cloned loop, inserted right before `loop`; `loop` is not erased.
+///
+/// This convenience function is useful to factorize common mechanisms related
+/// to hoisting roundtrips to memory into yields. It does not perform any
+/// legality checks.
+///
+/// Prerequisite: `newIterOperands.size() == newYieldedValues.size()`.
+scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
+                              ValueRange newIterOperands,
+                              ValueRange newYieldedValues,
+                              bool replaceLoopResults = true);
+
+} // end namespace mlir
+#endif // MLIR_DIALECT_SCF_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   ParallelLoopFusion.cpp
   ParallelLoopSpecialization.cpp
   ParallelLoopTiling.cpp
+  Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/Utils.cpp b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/Utils.cpp
@@ -0,0 +1,81 @@
+//===- Utils.cpp ---- Misc utilities for loop transformation --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements miscellaneous loop transformation routines.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Utils.h"
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+
+using namespace mlir;
+
+scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
+                                    ValueRange newIterOperands,
+                                    ValueRange newYieldedValues,
+                                    bool replaceLoopResults) {
+  assert(newIterOperands.size() == newYieldedValues.size() &&
+         "newIterOperands must be of the same size as newYieldedValues");
+
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(loop);
+  auto operands = llvm::to_vector<4>(loop.getIterOperands());
+  operands.append(newIterOperands.begin(), newIterOperands.end());
+  scf::ForOp newLoop =
+      b.create<scf::ForOp>(loop.getLoc(), loop.lowerBound(), loop.upperBound(),
+                           loop.step(), operands);
+
+  auto &loopBody = *loop.getBody();
+  auto &newLoopBody = *newLoop.getBody();
+  // The ForOp builder may already have placed an implicit `scf.yield` in the
+  // new body (e.g. when `operands` is empty). Drop it: the augmented yield
+  // cloned below must be the only terminator of the new body.
+  if (!newLoopBody.empty())
+    newLoopBody.back().erase();
+
+  // Clone / erase the yield inside the original loop to both:
+  //   1. augment its operands with the newYieldedValues.
+  //   2. automatically apply the BlockAndValueMapping on its operands.
+  auto yield = cast<scf::YieldOp>(loopBody.getTerminator());
+  b.setInsertionPoint(yield);
+  auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
+  yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
+  auto newYield = b.create<scf::YieldOp>(yield.getLoc(), yieldOperands);
+
+  // Clone the loop body with remaps.
+  BlockAndValueMapping bvm;
+  // a. remap the induction variable.
+  bvm.map(loop.getInductionVar(), newLoop.getInductionVar());
+  // b. remap the BB args.
+  bvm.map(loopBody.getArguments(),
+          newLoopBody.getArguments().take_front(loopBody.getNumArguments()));
+  // c. remap the iter args.
+  bvm.map(newIterOperands,
+          newLoop.getRegionIterArgs().take_back(newIterOperands.size()));
+  b.setInsertionPointToStart(&newLoopBody);
+  // Skip the original yield terminator which does not have enough operands;
+  // `newYield` sits right before it and is cloned as the new terminator.
+  for (auto &o : loopBody.without_terminator())
+    b.clone(o, bvm);
+
+  // Replace `loop`'s results if requested.
+  if (replaceLoopResults) {
+    for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
+                                                    loop.getNumResults())))
+      std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+  }
+
+  // Erase the temporary augmented yield from the original loop body.
+  // TODO: this is unsafe in the context of a PatternRewrite.
+  newYield.erase();
+
+  return newLoop;
+}
diff --git a/mlir/test/Transforms/loop-utils.mlir b/mlir/test/Transforms/loop-utils.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/loop-utils.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-utils -mlir-disable-threading %s | FileCheck %s
+
+// CHECK-LABEL: @hoist
+// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
+// CHECK-SAME: %[[ub:[a-zA-Z0-9]*]]: index,
+// CHECK-SAME: %[[step:[a-zA-Z0-9]*]]: index
+func @hoist(%lb: index, %ub: index, %step: index) {
+  // CHECK: %[[A:.*]] = "fake_read"() : () -> index
+  // CHECK: %[[RES:.*]] = scf.for %[[I:.*]] = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[VAL:.*]] = %[[A]]) -> (index)
+  // CHECK:   %[[YIELD:.*]] = "fake_compute"(%[[VAL]]) : (index) -> index
+  // CHECK:   scf.yield %[[YIELD]] : index
+  // CHECK: "fake_write"(%[[RES]]) : (index) -> ()
+  scf.for %i = %lb to %ub step %step {
+    %0 = "fake_read"() : () -> (index)
+    %1 = "fake_compute"(%0) : (index) -> (index)
+    "fake_write"(%1) : (index) -> ()
+  }
+  return
+}
+
+// CHECK-LABEL: @hoist2
+// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
+// CHECK-SAME: %[[ub:[a-zA-Z0-9]*]]: index,
+// CHECK-SAME: %[[step:[a-zA-Z0-9]*]]: index
+// CHECK-SAME: %[[extra_arg:[a-zA-Z0-9]*]]: f32
+func @hoist2(%lb: index, %ub: index, %step: index, %extra_arg: f32) -> f32 {
+  // CHECK: %[[A:.*]] = "fake_read"() : () -> index
+  // CHECK: %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[VAL0:.*]] = %[[extra_arg]], %[[VAL1:.*]] = %[[A]]) -> (f32, index)
+  // CHECK:   %[[YIELD:.*]] = "fake_compute"(%[[VAL1]]) : (index) -> index
+  // CHECK:   scf.yield %[[VAL0]], %[[YIELD]] : f32, index
+  // CHECK: "fake_write"(%[[RES]]#1) : (index) -> ()
+  // CHECK: return %[[RES]]#0 : f32
+  %0 = scf.for %i = %lb to %ub step %step iter_args(%iter = %extra_arg) -> (f32) {
+    %0 = "fake_read"() : () -> (index)
+    %1 = "fake_compute"(%0) : (index) -> (index)
+    "fake_write"(%1) : (index) -> ()
+    scf.yield %iter: f32
+  }
+  return %0: f32
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@
   TestMemRefBoundCheck.cpp
   TestMemRefDependenceCheck.cpp
   TestMemRefStrideCalculation.cpp
+  TestSCFUtils.cpp
   TestVectorTransforms.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Transforms/TestSCFUtils.cpp b/mlir/test/lib/Transforms/TestSCFUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestSCFUtils.cpp
@@ -0,0 +1,58 @@
+//===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test SCF dialect utils.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/ADT/SetVector.h"
+
+using namespace mlir;
+
+namespace {
+/// Hoists every "fake_read" out of its scf.for via cloneWithNewYields.
+class TestSCFUtilsPass : public PassWrapper<TestSCFUtilsPass, FunctionPass> {
+public:
+  void runOnFunction() override {
+    FuncOp func = getFunction();
+    SmallVector<scf::ForOp, 4> toErase;
+
+    func.walk([&](Operation *fakeRead) {
+      if (fakeRead->getName().getStringRef() != "fake_read")
+        return;
+      auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
+      auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
+      auto loop = fakeRead->getParentOfType<scf::ForOp>();
+
+      // Hoist read/write around the loop; carry the value in a new iter_arg.
+      OpBuilder b(loop);
+      loop.moveOutOfLoop({fakeRead});
+      fakeWrite->moveAfter(loop);
+      auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
+                                        fakeCompute->getResult(0));
+      fakeCompute->getResult(0).replaceAllUsesWith(
+          newLoop.getResults().take_back()[0]);
+      toErase.push_back(loop);
+    });
+    for (auto loop : llvm::reverse(toErase))
+      loop.erase();
+  }
+};
+} // end namespace
+
+namespace mlir {
+void registerTestSCFUtilsPass() {
+  PassRegistration<TestSCFUtilsPass>("test-scf-utils", "test scf utils");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -60,6 +60,7 @@
 void registerTestOpaqueLoc();
 void registerTestParallelismDetection();
 void registerTestGpuParallelLoopMappingPass();
+void registerTestSCFUtilsPass();
 void registerTestVectorConversions();
 void registerVectorizerTestPass();
 } // namespace mlir
@@ -127,6 +128,7 @@
   registerTestOpaqueLoc();
   registerTestParallelismDetection();
   registerTestGpuParallelLoopMappingPass();
+  registerTestSCFUtilsPass();
   registerTestVectorConversions();
   registerVectorizerTestPass();
 }