diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h --- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h +++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h @@ -18,6 +18,7 @@ #include "mlir/IR/Block.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/RegionUtils.h" #include namespace mlir { @@ -293,6 +294,57 @@ separateFullTiles(MutableArrayRef nest, SmallVectorImpl *fullTileNest = nullptr); +/// Walk the perfect loop nest rooted at `op`, which is either an scf.for or +/// an affine.for, bottom-up and coalesce bands of loops whose operands are +/// all defined above the outermost loop of the band. Returns success if at +/// least one band was coalesced and failure otherwise. +template +LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op) { + LogicalResult result(failure()); + SmallVector loops; + getPerfectlyNestedLoops(loops, op); + + // Look for a band of loops that can be coalesced, i.e. perfectly nested + // loops with bounds defined above some loop. + // 1. For each loop, find above which parent loop its operands are + // defined. + SmallVector operandsDefinedAbove(loops.size()); + for (unsigned i = 0, e = loops.size(); i < e; ++i) { + operandsDefinedAbove[i] = i; + for (unsigned j = 0; j < i; ++j) { + if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) { + operandsDefinedAbove[i] = j; + break; + } + } + } + + // 2. Identify bands of loops such that the operands of all of them are + // defined above the first loop in the band. Traverse the nest bottom-up + // so that modifications don't invalidate the inner loops.
+ for (unsigned end = loops.size(); end > 0; --end) { + unsigned start = 0; + for (; start < end - 1; ++start) { + auto maxPos = + *std::max_element(std::next(operandsDefinedAbove.begin(), start), + std::next(operandsDefinedAbove.begin(), end)); + if (maxPos > start) + continue; + assert(maxPos == start && + "expected loop bounds to be known at the start of the band"); + auto band = llvm::makeMutableArrayRef(loops.data() + start, end - start); + if (succeeded(coalesceLoops(band))) + result = success(); + break; + } + // If a band was found and transformed, keep looking at the loops above + // the outermost transformed loop. + if (start != end - 1) + end = start + 1; + } + return result; +} + } // namespace mlir #endif // MLIR_DIALECT_AFFINE_LOOPUTILS_H diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -189,4 +189,31 @@ }]; } +def LoopCoalesceOp : Op { + let summary = "Coalesces the perfect loopnest enclosed by a given loop"; + let description = [{ + Obtain the perfect loopnest where the given loop is the outermost loop of + the nest. Perform loop coalescing of the nest from bottom-up repeatedly. + + #### Return modes + + The return handle points to the coalesced loop if coalescing happens, or + the given input loop, along with a silenceable failure, otherwise.
+ }]; + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = + "$target attr-dict `:` functional-type($target, $transformed)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // SCF_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -121,7 +121,8 @@ /// Replace a perfect nest of "for" loops with a single linearized loop. Assumes /// `loops` contains a list of perfectly nested loops with bounds and steps /// independent of any loop induction variable involved in the nest. -void coalesceLoops(MutableArrayRef loops); +/// Returns failure if `loops` has fewer than two loops and success otherwise. +LogicalResult coalesceLoops(MutableArrayRef loops); /// Take the ParallelLoop and for each set of dimension indices, combine them /// into a single dimension. combinedDimensions must contain each index into diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp @@ -32,73 +32,13 @@ struct LoopCoalescingPass : public impl::LoopCoalescingBase { - /// Walk either an scf.for or an affine.for to find a band to coalesce. - template - static void walkLoop(LoopOpTy op) { - // Ignore nested loops. - if (op->template getParentOfType()) - return; - - SmallVector loops; - getPerfectlyNestedLoops(loops, op); - LLVM_DEBUG(llvm::dbgs() - << "found a perfect nest of depth " << loops.size() << '\n'); - - // Look for a band of loops that can be coalesced, i.e. perfectly nested - // loops with bounds defined above some loop. - // 1.
For each loop, find above which parent loop its operands are - // defined. - SmallVector operandsDefinedAbove(loops.size()); - for (unsigned i = 0, e = loops.size(); i < e; ++i) { - operandsDefinedAbove[i] = i; - for (unsigned j = 0; j < i; ++j) { - if (areValuesDefinedAbove(loops[i].getOperands(), - loops[j].getRegion())) { - operandsDefinedAbove[i] = j; - break; - } - } - LLVM_DEBUG(llvm::dbgs() - << " bounds of loop " << i << " are known above depth " - << operandsDefinedAbove[i] << '\n'); - } - - // 2. Identify bands of loops such that the operands of all of them are - // defined above the first loop in the band. Traverse the nest bottom-up - // so that modifications don't invalidate the inner loops. - for (unsigned end = loops.size(); end > 0; --end) { - unsigned start = 0; - for (; start < end - 1; ++start) { - auto maxPos = - *std::max_element(std::next(operandsDefinedAbove.begin(), start), - std::next(operandsDefinedAbove.begin(), end)); - if (maxPos > start) - continue; - - assert(maxPos == start && - "expected loop bounds to be known at the start of the band"); - LLVM_DEBUG(llvm::dbgs() << " found coalesceable band from " << start - << " to " << end << '\n'); - - auto band = - llvm::makeMutableArrayRef(loops.data() + start, end - start); - (void)coalesceLoops(band); - break; - } - // If a band was found and transformed, keep looking at the loops above - // the outermost transformed loop.
- if (start != end - 1) - end = start + 1; - } - } - void runOnOperation() override { func::FuncOp func = getOperation(); - func.walk([&](Operation *op) { + func.walk([](Operation *op) { if (auto scfForOp = dyn_cast(op)) - walkLoop(scfForOp); + (void)coalescePerfectlyNestedLoops(scfForOp); else if (auto affineForOp = dyn_cast(op)) - walkLoop(affineForOp); + (void)coalescePerfectlyNestedLoops(affineForOp); }); } }; diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -25,7 +25,6 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -226,6 +226,29 @@ return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// LoopCoalesceOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::LoopCoalesceOp::applyToOne(Operation *op, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + LogicalResult result(failure()); + if (scf::ForOp scfForOp = dyn_cast(op)) + result = coalescePerfectlyNestedLoops(scfForOp); + else if (AffineForOp affineForOp = dyn_cast(op)) + result = coalescePerfectlyNestedLoops(affineForOp); + + results.push_back(op); + if (failed(result)) { + Diagnostic diag(op->getLoc(), DiagnosticSeverity::Note); + diag << "Op failed to coalesce"; + return
DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + } + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -656,9 +656,9 @@ loop.setStep(loopPieces.step); } -void mlir::coalesceLoops(MutableArrayRef loops) { +LogicalResult mlir::coalesceLoops(MutableArrayRef loops) { if (loops.size() < 2) - return; + return failure(); scf::ForOp innermost = loops.back(); scf::ForOp outermost = loops.front(); @@ -710,6 +710,7 @@ Block::iterator(second.getOperation()), innermost.getBody()->getOperations()); second.erase(); + return success(); } void mlir::collapseParallelLoops( diff --git a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir @@ -0,0 +1,92 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s + +func.func @coalesce_inner() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: scf.for [[IV0:.*]] + // CHECK: scf.for [[IV1:.*]] + // CHECK: scf.for [[IV2:.*]] + // CHECK-NOT: scf.for [[IV3:.*]] + scf.for %i = %c0 to %c10 step %c1 { + scf.for %j = %c0 to %c10 step %c1 { + scf.for %k = %i to %j step %c1 { + // Inner loop must have been removed.
+ scf.for %l = %i to %j step %c1 { + arith.addi %i, %j : index + } + } {coalesce} + } + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.for"> + %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">) +} + +// ----- + +func.func @coalesce_outer(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} { + // CHECK: affine.for [[IV1:.*]] = 0 to [[UB:.*]] { + // CHECK-NOT: affine.for [[IV2:.*]] + affine.for %arg4 = 0 to 64 { + affine.for %arg5 = 0 to 64 { + // CHECK: %[[IDX0:.*]] = affine.apply #[[MAP0:.+]]([[IV1]])[%{{.*}}] + // CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1:.+]]([[IV1]])[%{{.*}}] + // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1> + %0 = affine.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1> + %1 = affine.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1> + %2 = arith.addf %0, %1 : f32 + affine.store %2, %arg3[%arg4, %arg5] : memref<64x64xf32, 1> + } + } {coalesce} + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["affine.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"affine.for"> + %2 = transform.loop.coalesce %1 : (!transform.op<"affine.for">) -> (!transform.op<"affine.for">) +} + +// ----- + +func.func @coalesce_and_unroll(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} { + // CHECK: scf.for [[IV1:%[a-zA-Z0-9]+]] + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + scf.for %arg4 = %c0 to %c64 step %c1 { + // CHECK-NOT: scf.for [[IV2:%[a-zA-Z0-9]+]] + scf.for %arg5 = %c0 to %c64 step %c1 { + // CHECK: %[[IDX0:.*]] = arith.remsi
[[IV1]] + // CHECK: %[[IDX1:.*]] = arith.divsi [[IV1]] + // CHECK-NEXT: %{{.*}} = memref.load %{{.*}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1> + %0 = memref.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1> + %1 = memref.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1> + %2 = arith.addf %0, %1 : f32 + // CHECK: memref.store + // CHECK: memref.store + // CHECK: memref.store + // Residual loop must have a single store. + // CHECK: memref.store + memref.store %2, %arg3[%arg4, %arg5] : memref<64x64xf32, 1> + } + } {coalesce} + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.for"> + %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">) + transform.loop.unroll %2 {factor = 3} : !transform.op<"scf.for"> +}