diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h --- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h +++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h @@ -18,6 +18,7 @@ #include "mlir/IR/Block.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/RegionUtils.h" #include namespace mlir { @@ -293,6 +294,54 @@ separateFullTiles(MutableArrayRef nest, SmallVectorImpl *fullTileNest = nullptr); +/// Walk either an scf.for or an affine.for to find a band to coalesce. +template +LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op) { + LogicalResult result(failure()); + SmallVector loops; + getPerfectlyNestedLoops(loops, op); + + // Look for a band of loops that can be coalesced, i.e. perfectly nested + // loops with bounds defined above some loop. + // 1. For each loop, find above which parent loop its operands are + // defined. + SmallVector operandsDefinedAbove(loops.size()); + for (unsigned i = 0, e = loops.size(); i < e; ++i) { + operandsDefinedAbove[i] = i; + for (unsigned j = 0; j < i; ++j) { + if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) { + operandsDefinedAbove[i] = j; + break; + } + } + } + + // 2. Identify bands of loops such that the operands of all of them are + // defined above the first loop in the band. Traverse the nest bottom-up + // so that modifications don't invalidate the inner loops. + for (unsigned end = loops.size(); end > 0; --end) { + unsigned start = 0; + for (; start < end - 1; ++start) { + auto maxPos = + *std::max_element(std::next(operandsDefinedAbove.begin(), start), + std::next(operandsDefinedAbove.begin(), end)); + if (maxPos > start) + continue; + assert(maxPos == start && + "expected loop bounds to be known at the start of the band"); + auto band = llvm::makeMutableArrayRef(loops.data() + start, end - start); + if (succeeded(coalesceLoops(band))) + result = success(); + break; + } + // If a band was found and transformed, keep looking at the loops above + // the outermost transformed loop. + if (start != end - 1) + end = start + 1; + } + return result; +} + } // namespace mlir #endif // MLIR_DIALECT_AFFINE_LOOPUTILS_H diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -189,4 +189,31 @@ }]; } +def LoopCoalesceOp : Op { + let summary = "Coalesces the perfect loop nest enclosed by a given loop"; + let description = [{ + Given a perfect loop nest identified by the outermost loop, + perform loop coalescing in a bottom-up one-by-one manner. + + #### Return modes + + The return handle points to the coalesced loop if coalescing happens, or + the given input loop if coalescing does not happen. + }]; + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = + "$target attr-dict `:` functional-type($target, $transformed)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // SCF_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -121,7 +121,7 @@ /// Replace a perfect nest of "for" loops with a single linearized loop. Assumes /// `loops` contains a list of perfectly nested loops with bounds and steps /// independent of any loop induction variable involved in the nest. -void coalesceLoops(MutableArrayRef loops); +LogicalResult coalesceLoops(MutableArrayRef loops); /// Take the ParallelLoop and for each set of dimension indices, combine them /// into a single dimension. combinedDimensions must contain each index into diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp @@ -32,72 +32,13 @@ struct LoopCoalescingPass : public impl::LoopCoalescingBase { - /// Walk either an scf.for or an affine.for to find a band to coalesce. - template - static void walkLoop(LoopOpTy op) { - // Ignore nested loops. - if (op->template getParentOfType()) - return; - - SmallVector loops; - getPerfectlyNestedLoops(loops, op); - LLVM_DEBUG(llvm::dbgs() - << "found a perfect nest of depth " << loops.size() << '\n'); - - // Look for a band of loops that can be coalesced, i.e. perfectly nested - // loops with bounds defined above some loop. - // 1. For each loop, find above which parent loop its operands are - // defined. - SmallVector operandsDefinedAbove(loops.size()); - for (unsigned i = 0, e = loops.size(); i < e; ++i) { - operandsDefinedAbove[i] = i; - for (unsigned j = 0; j < i; ++j) { - if (areValuesDefinedAbove(loops[i].getOperands(), - loops[j].getRegion())) { - operandsDefinedAbove[i] = j; - break; - } - } - LLVM_DEBUG(llvm::dbgs() - << " bounds of loop " << i << " are known above depth " - << operandsDefinedAbove[i] << '\n'); - } - - // 2. Identify bands of loops such that the operands of all of them are - // defined above the first loop in the band. Traverse the nest bottom-up - // so that modifications don't invalidate the inner loops. - for (unsigned end = loops.size(); end > 0; --end) { - unsigned start = 0; - for (; start < end - 1; ++start) { - auto maxPos = - *std::max_element(std::next(operandsDefinedAbove.begin(), start), - std::next(operandsDefinedAbove.begin(), end)); - if (maxPos > start) - continue; - - assert(maxPos == start && - "expected loop bounds to be known at the start of the band"); - LLVM_DEBUG(llvm::dbgs() << " found coalesceable band from " << start - << " to " << end << '\n'); - - auto band = llvm::MutableArrayRef(loops.data() + start, end - start); - (void)coalesceLoops(band); - break; - } - // If a band was found and transformed, keep looking at the loops above - // the outermost transformed loop. - if (start != end - 1) - end = start + 1; - } - } - void runOnOperation() override { func::FuncOp func = getOperation(); - func.walk([&](Operation *op) { + func.walk([](Operation *op) { if (auto scfForOp = dyn_cast(op)) - walkLoop(scfForOp); + (void)coalescePerfectlyNestedLoops(scfForOp); else if (auto affineForOp = dyn_cast(op)) - walkLoop(affineForOp); + (void)coalescePerfectlyNestedLoops(affineForOp); }); } }; diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -25,7 +25,6 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -219,9 +219,32 @@ result = loopUnrollByFactor(affineFor, getFactor()); if (failed(result)) { - Diagnostic diag(op->getLoc(), DiagnosticSeverity::Note); - diag << "Op failed to unroll"; - return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + DiagnosedSilenceableFailure diag = emitSilenceableError() + << "failed to unroll"; + return diag; + } + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LoopCoalesceOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::LoopCoalesceOp::applyToOne(Operation *op, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + LogicalResult result(failure()); + if (scf::ForOp scfForOp = dyn_cast(op)) + result = coalescePerfectlyNestedLoops(scfForOp); + else if (AffineForOp affineForOp = dyn_cast(op)) + result = coalescePerfectlyNestedLoops(affineForOp); + + results.push_back(op); + if (failed(result)) { + DiagnosedSilenceableFailure diag = emitSilenceableError() + << "failed to coalesce"; + return diag; } return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -656,9 +656,9 @@ loop.setStep(loopPieces.step); } -void mlir::coalesceLoops(MutableArrayRef loops) { +LogicalResult mlir::coalesceLoops(MutableArrayRef loops) { if (loops.size() < 2) - return; + return failure(); scf::ForOp innermost = loops.back(); scf::ForOp outermost = loops.front(); @@ -710,6 +710,7 @@ Block::iterator(second.getOperation()), innermost.getBody()->getOperations()); second.erase(); + return success(); } void mlir::collapseParallelLoops( diff --git a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir @@ -0,0 +1,92 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s + +func.func @coalesce_inner() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: scf.for %[[IV0:.+]] + // CHECK: scf.for %[[IV1:.+]] + // CHECK: scf.for %[[IV2:.+]] + // CHECK-NOT: scf.for %[[IV3:.+]] + scf.for %i = %c0 to %c10 step %c1 { + scf.for %j = %c0 to %c10 step %c1 { + scf.for %k = %i to %j step %c1 { + // Inner loop must have been removed. + scf.for %l = %i to %j step %c1 { + arith.addi %i, %j : index + } + } {coalesce} + } + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.for"> + %2 = transform.loop.coalesce %1: (!transform.op<"scf.for">) -> (!transform.op<"scf.for">) +} + +// ----- + +func.func @coalesce_outer(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} { + // CHECK: affine.for %[[IV1:.+]] = 0 to %[[UB:.+]] { + // CHECK-NOT: affine.for %[[IV2:.+]] + affine.for %arg4 = 0 to 64 { + affine.for %arg5 = 0 to 64 { + // CHECK: %[[IDX0:.+]] = affine.apply #[[MAP0:.+]](%[[IV1]])[%{{.+}}] + // CHECK: %[[IDX1:.+]] = affine.apply #[[MAP1:.+]](%[[IV1]])[%{{.+}}] + // CHECK-NEXT: %{{.+}} = affine.load %{{.+}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1> + %0 = affine.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1> + %1 = affine.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1> + %2 = arith.addf %0, %1 : f32 + affine.store %2, %arg3[%arg4, %arg5] : memref<64x64xf32, 1> + } + } {coalesce} + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["affine.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"affine.for"> + %2 = transform.loop.coalesce %1 : (!transform.op<"affine.for">) -> (!transform.op<"affine.for">) +} + +// ----- + +func.func @coalesce_and_unroll(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} { + // CHECK: scf.for %[[IV1:.+]] = + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + scf.for %arg4 = %c0 to %c64 step %c1 { + // CHECK-NOT: scf.for + scf.for %arg5 = %c0 to %c64 step %c1 { + // CHECK: %[[IDX0:.+]] = arith.remsi %[[IV1]] + // CHECK: %[[IDX1:.+]] = arith.divsi %[[IV1]] + // CHECK-NEXT: %{{.+}} = memref.load %{{.+}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1> + %0 = memref.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1> + %1 = memref.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1> + %2 = arith.addf %0, %1 : f32 + // CHECK: memref.store + // CHECK: memref.store + // CHECK: memref.store + // Residual loop must have a single store. + // CHECK: memref.store + memref.store %2, %arg3[%arg4, %arg5] : memref<64x64xf32, 1> + } + } {coalesce} + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.for"> + %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">) + transform.loop.unroll %2 {factor = 3} : !transform.op<"scf.for"> +} diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir @@ -0,0 +1,61 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file --verify-diagnostics + +#map0 = affine_map<(d0) -> (d0 * 110)> +#map1 = affine_map<(d0) -> (696, d0 * 110 + 110)> +func.func @test_loops_do_not_get_coalesced() { + affine.for %i = 0 to 7 { + affine.for %j = #map0(%i) to min #map1(%i) { + } + } {coalesce} + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["affine.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"affine.for"> + // expected-error @below {{failed to coalesce}} + %2 = transform.loop.coalesce %1: (!transform.op<"affine.for">) -> (!transform.op<"affine.for">) +} + +// ----- + +func.func @test_loops_do_not_get_unrolled() { + affine.for %i = 0 to 7 { + arith.addi %i, %i : index + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["arith.addi"]} in %arg1 + %1 = transform.loop.get_parent_for %0 { affine = true } : (!pdl.operation) -> !transform.op<"affine.for"> + // expected-error @below {{failed to unroll}} + transform.loop.unroll %1 { factor = 8 } : !transform.op<"affine.for"> +} + +// ----- + +func.func private @cond() -> i1 +func.func private @body() + +func.func @loop_outline_op_multi_region() { + // expected-note @below {{target op}} + scf.while : () -> () { + %0 = func.call @cond() : () -> i1 + scf.condition(%0) + } do { + ^bb0: + func.call @body() : () -> () + scf.yield + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["scf.while"]} in %arg1 + // expected-error @below {{failed to outline}} + transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir --- a/mlir/test/Dialect/SCF/transform-ops.mlir +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -84,31 +84,6 @@ // ----- -func.func private @cond() -> i1 -func.func private @body() - -func.func @loop_outline_op_multi_region() { - // expected-note @below {{target op}} - scf.while : () -> () { - %0 = func.call @cond() : () -> i1 - scf.condition(%0) - } do { - ^bb0: - func.call @body() : () -> () - scf.yield - } - return -} - -transform.sequence failures(propagate) { -^bb1(%arg1: !pdl.operation): - %0 = transform.structured.match ops{["scf.while"]} in %arg1 - // expected-error @below {{failed to outline}} - transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation -} - -// ----- - // CHECK-LABEL: @loop_peel_op func.func @loop_peel_op() { // CHECK: %[[C0:.+]] = arith.constant 0