diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h
--- a/mlir/include/mlir/Transforms/LoopUtils.h
+++ b/mlir/include/mlir/Transforms/LoopUtils.h
@@ -115,7 +115,7 @@
 /// to inner. Returns the position in `inputNest` of the AffineForOp that
 /// becomes the new outermost loop of this nest. This method always succeeds,
 /// asserts out on invalid input / specifications.
-unsigned permuteLoops(ArrayRef<AffineForOp> inputNest,
+unsigned permuteLoops(MutableArrayRef<AffineForOp> inputNest,
                       ArrayRef<unsigned> permMap);
 
 // Sinks all sequential loops to the innermost levels (while preserving
@@ -124,11 +124,6 @@
 // Returns AffineForOp of the root of the new loop nest after loop interchanges.
 AffineForOp sinkSequentialLoops(AffineForOp forOp);
 
-/// Sinks 'forOp' by 'loopDepth' levels by performing a series of loop
-/// interchanges. Requires that 'forOp' is part of a perfect nest with
-/// 'loopDepth' AffineForOps consecutively nested under it.
-void sinkLoop(AffineForOp forOp, unsigned loopDepth);
-
 /// Performs tiling fo imperfectly nested loops (with interchange) by
 /// strip-mining the `forOps` by `sizes` and sinking them, in their order of
 /// occurrence in `forOps`, under each of the `targets`.
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -716,7 +716,7 @@
 
 // input[i] should move from position i -> permMap[i]. Returns the position in
 // `input` that becomes the new outermost loop.
-unsigned mlir::permuteLoops(ArrayRef<AffineForOp> input,
+unsigned mlir::permuteLoops(MutableArrayRef<AffineForOp> input,
                             ArrayRef<unsigned> permMap) {
   assert(input.size() == permMap.size() && "invalid permutation map size");
   // Check whether the permutation spec is valid. This is a small vector - we'll
@@ -733,19 +733,55 @@
 
   assert(isPerfectlyNested(input) && "input not perfectly nested");
 
-  Optional<unsigned> loopNestRootIndex;
+  // Compute the inverse mapping, invPermMap: since input[i] goes to position
+  // permMap[i], position i of the permuted nest is at input[invPermMap[i]].
+  SmallVector<std::pair<unsigned, unsigned>, 4> invPermMap;
+  for (unsigned i = 0, e = input.size(); i < e; ++i)
+    invPermMap.push_back({permMap[i], i});
+  llvm::sort(invPermMap);
+
+  // Move the innermost loop body to the loop that would be the innermost in the
+  // permuted nest (only if the innermost loop is going to change).
+  if (permMap.back() != input.size() - 1) {
+    auto *destBody = input[invPermMap.back().second].getBody();
+    auto *srcBody = input.back().getBody();
+    destBody->getOperations().splice(destBody->begin(),
+                                     srcBody->getOperations(), srcBody->begin(),
+                                     std::prev(srcBody->end()));
+  }
+
+  // We'll move each loop in `input` in the reverse order so that its body is
+  // empty when we are moving it; this incurs zero copies and no erasing.
   for (int i = input.size() - 1; i >= 0; --i) {
-    int permIndex = static_cast<int>(permMap[i]);
-    // Store the index of the for loop which will be the new loop nest root.
-    if (permIndex == 0)
-      loopNestRootIndex = i;
-    if (permIndex > i) {
-      // Sink loop 'i' by 'permIndex - i' levels deeper into the loop nest.
-      sinkLoop(input[i], permIndex - i);
+    // If this has to become the outermost loop after permutation, add it to the
+    // parent block of the original root.
+    if (permMap[i] == 0) {
+      // If the root remains the same, nothing to do.
+      if (i == 0)
+        continue;
+      // Make input[i] the new outermost loop moving it into parentBlock.
+      auto *parentBlock = input[0].getOperation()->getBlock();
+      parentBlock->getOperations().splice(
+          Block::iterator(input[0]),
+          input[i].getOperation()->getBlock()->getOperations(),
+          Block::iterator(input[i]));
+      continue;
     }
+
+    // If the parent in the permuted order is the same as in the original,
+    // nothing to do.
+    unsigned parentPosInInput = invPermMap[permMap[i] - 1].second;
+    if (i > 0 && static_cast<unsigned>(i - 1) == parentPosInInput)
+      continue;
+
+    // Move input[i] to its surrounding loop in the transformed nest.
+    auto *destBody = input[parentPosInInput].getBody();
+    destBody->getOperations().splice(
+        destBody->begin(), input[i].getOperation()->getBlock()->getOperations(),
+        Block::iterator(input[i]));
   }
-  assert(loopNestRootIndex.hasValue());
-  return loopNestRootIndex.getValue();
+
+  return invPermMap[0].second;
 }
 
 // Sinks all sequential loops to the innermost levels (while preserving
@@ -803,15 +839,6 @@
   return loops[loopNestRootIndex];
 }
 
-/// Performs a series of loop interchanges to sink 'forOp' 'loopDepth' levels
-/// deeper in the loop nest.
-void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) {
-  for (unsigned i = 0; i < loopDepth; ++i) {
-    AffineForOp nextForOp = cast<AffineForOp>(forOp.getBody()->front());
-    interchangeLoops(forOp, nextForOp);
-  }
-}
-
 // Factors out common behavior to add a new `iv` (resp. `iv` + `offset`) to the
 // lower (resp. upper) loop bound. When called for both the lower and upper
 // bounds, the resulting IR resembles:
diff --git a/mlir/test/Dialect/Affine/loop-permute.mlir b/mlir/test/Dialect/Affine/loop-permute.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Affine/loop-permute.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s -test-loop-permutation="permutation-map=1,2,0" | FileCheck %s --check-prefix=CHECK-120
+// RUN: mlir-opt %s -test-loop-permutation="permutation-map=1,0,2" | FileCheck %s --check-prefix=CHECK-102
+// RUN: mlir-opt %s -test-loop-permutation="permutation-map=0,1,2" | FileCheck %s --check-prefix=CHECK-012
+// RUN: mlir-opt %s -test-loop-permutation="permutation-map=0,2,1" | FileCheck %s --check-prefix=CHECK-021
+// RUN: mlir-opt %s -test-loop-permutation="permutation-map=2,0,1" | FileCheck %s --check-prefix=CHECK-201
+// RUN: mlir-opt %s -test-loop-permutation="permutation-map=2,1,0" | FileCheck %s --check-prefix=CHECK-210
+
+// CHECK-120-LABEL: func @permute
+func @permute(%U0 : index, %U1 : index, %U2 : index) {
+  "abc"() : () -> ()
+  affine.for %arg0 = 0 to %U0 {
+    affine.for %arg1 = 0 to %U1 {
+      affine.for %arg2 = 0 to %U2 {
+        "foo"(%arg0, %arg1) : (index, index) -> ()
+        "bar"(%arg2) : (index) -> ()
+      }
+    }
+  }
+  "xyz"() : () -> ()
+  return
+}
+// CHECK-120:      "abc"
+// CHECK-120-NEXT: affine.for
+// CHECK-120-NEXT:   affine.for
+// CHECK-120-NEXT:     affine.for
+// CHECK-120-NEXT:       "foo"(%arg4, %arg5)
+// CHECK-120-NEXT:       "bar"(%arg3)
+// CHECK-120-NEXT:     }
+// CHECK-120-NEXT:   }
+// CHECK-120-NEXT: }
+// CHECK-120-NEXT: "xyz"
+// CHECK-120-NEXT: return
+
+// CHECK-102:      "foo"(%arg4, %arg3)
+// CHECK-102-NEXT: "bar"(%arg5)
+
+// CHECK-012:      "foo"(%arg3, %arg4)
+// CHECK-012-NEXT: "bar"(%arg5)
+
+// CHECK-021:      "foo"(%arg3, %arg5)
+// CHECK-021-NEXT: "bar"(%arg4)
+
+// CHECK-210:      "foo"(%arg5, %arg4)
+// CHECK-210-NEXT: "bar"(%arg3)
+
+// CHECK-201:      "foo"(%arg5, %arg3)
+// CHECK-201-NEXT: "bar"(%arg4)
diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_library(MLIRAffineTransformsTestPasses
   TestAffineDataCopy.cpp
+  TestLoopPermutation.cpp
   TestParallelismDetection.cpp
   TestVectorizationUtils.cpp
 
diff --git a/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Affine/TestLoopPermutation.cpp
@@ -0,0 +1,67 @@
+//===- TestLoopPermutation.cpp - Test affine loop permutation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the affine for op permutation utility.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+#define PASS_NAME "test-loop-permutation"
+
+using namespace mlir;
+
+static llvm::cl::OptionCategory clOptionsCategory(PASS_NAME " options");
+
+namespace {
+
+/// This pass applies the permutation on the first maximal perfect nest.
+struct TestLoopPermutation : public FunctionPass<TestLoopPermutation> {
+  TestLoopPermutation() = default;
+  TestLoopPermutation(const TestLoopPermutation &pass){};
+
+  void runOnFunction() override;
+
+private:
+  /// Permutation specifying loop i is mapped to permList[i] in
+  /// transformed nest (with i going from outermost to innermost).
+  ListOption<unsigned> permList{*this, "permutation-map",
+                                llvm::cl::desc("Specify the loop permutation"),
+                                llvm::cl::OneOrMore, llvm::cl::CommaSeparated};
+};
+
+} // end anonymous namespace
+
+void TestLoopPermutation::runOnFunction() {
+  // Get the first maximal perfect nest.
+  SmallVector<AffineForOp, 6> nest;
+  for (auto &op : getFunction().front()) {
+    if (auto forOp = dyn_cast<AffineForOp>(op)) {
+      getPerfectlyNestedLoops(nest, forOp);
+      break;
+    }
+  }
+
+  // Skip trivial nests and nests whose depth doesn't match the permutation.
+  if (nest.size() < 2 || nest.size() != permList.size())
+    return;
+
+  SmallVector<unsigned, 4> permMap(permList.begin(), permList.end());
+  permuteLoops(nest, permMap);
+}
+
+namespace mlir {
+void registerTestLoopPermutationPass() {
+  PassRegistration<TestLoopPermutation>(
+      PASS_NAME, "Tests affine loop permutation utility");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -39,6 +39,7 @@
 void registerSymbolTestPasses();
 void registerTestAffineDataCopyPass();
 void registerTestAllReduceLoweringPass();
+void registerTestLoopPermutationPass();
 void registerTestCallGraphPass();
 void registerTestConstantFold();
 void registerTestConvertGPUKernelToCubinPass();
@@ -96,6 +97,7 @@
   registerSymbolTestPasses();
   registerTestAffineDataCopyPass();
   registerTestAllReduceLoweringPass();
+  registerTestLoopPermutationPass();
   registerTestCallGraphPass();
   registerTestConstantFold();
 #if MLIR_CUDA_CONVERSIONS_ENABLED