diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -41,6 +41,10 @@
 /// constant. Returns failure otherwise.
 LogicalResult loopUnrollFull(AffineForOp forOp);
 
+/// Walk the perfect loop nest whose outermost loop is `op` and coalesce its
+/// bands bottom-up. Explicitly instantiated for scf::ForOp and AffineForOp.
+template <typename LoopOpTy> void coalescePerfectlyNestedLoops(LoopOpTy op);
+
 /// Unrolls this for operation by the specified unroll factor. Returns failure
 /// if the loop cannot be unrolled either due to restrictions or due to invalid
 /// unroll factors. Requires positive loop bounds and step. If specified,
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
 #define MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
 
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/OpImplementation.h"
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -9,6 +9,7 @@
 #ifndef SCF_TRANSFORM_OPS
 #define SCF_TRANSFORM_OPS
 
+include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
@@ -189,4 +190,32 @@
   }];
 }
 
+def LoopCoalesceOp : Op<Transform_Dialect, "loop.coalesce",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Coalesces the perfect loopnest enclosed by a given loop";
+  let description = [{
+    Obtain the perfect loopnest where the given loop is the outermost loop of
+    the nest. Perform loop coalescing of the nest from bottom-up repeatedly.
+
+    #### Return modes
+
+    The return handle points to the coalesced loop if coalescing happens, or
+    the given input loop if coalescing does not happen. Produces a silenceable
+    failure if the target is neither an `scf.for` nor an `affine.for`.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs PDL_Operation:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` type($target) `->` type($transformed)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+      ::mlir::Operation *target,
+      ::mlir::transform::ApplyToEachResultList &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // SCF_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
@@ -32,73 +32,18 @@
 struct LoopCoalescingPass
     : public impl::LoopCoalescingBase<LoopCoalescingPass> {
 
-  /// Walk either an scf.for or an affine.for to find a band to coalesce.
-  template <typename LoopOpTy>
-  static void walkLoop(LoopOpTy op) {
-    // Ignore nested loops.
-    if (op->template getParentOfType<LoopOpTy>())
-      return;
-
-    SmallVector<LoopOpTy, 4> loops;
-    getPerfectlyNestedLoops(loops, op);
-    LLVM_DEBUG(llvm::dbgs()
-               << "found a perfect nest of depth " << loops.size() << '\n');
-
-    // Look for a band of loops that can be coalesced, i.e. perfectly nested
-    // loops with bounds defined above some loop.
-    // 1. For each loop, find above which parent loop its operands are
-    // defined.
-    SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
-    for (unsigned i = 0, e = loops.size(); i < e; ++i) {
-      operandsDefinedAbove[i] = i;
-      for (unsigned j = 0; j < i; ++j) {
-        if (areValuesDefinedAbove(loops[i].getOperands(),
-                                  loops[j].getRegion())) {
-          operandsDefinedAbove[i] = j;
-          break;
-        }
-      }
-      LLVM_DEBUG(llvm::dbgs()
-                 << "  bounds of loop " << i << " are known above depth "
-                 << operandsDefinedAbove[i] << '\n');
-    }
-
-    // 2. Identify bands of loops such that the operands of all of them are
-    // defined above the first loop in the band. Traverse the nest bottom-up
-    // so that modifications don't invalidate the inner loops.
-    for (unsigned end = loops.size(); end > 0; --end) {
-      unsigned start = 0;
-      for (; start < end - 1; ++start) {
-        auto maxPos =
-            *std::max_element(std::next(operandsDefinedAbove.begin(), start),
-                              std::next(operandsDefinedAbove.begin(), end));
-        if (maxPos > start)
-          continue;
-
-        assert(maxPos == start &&
-               "expected loop bounds to be known at the start of the band");
-        LLVM_DEBUG(llvm::dbgs() << "  found coalesceable band from " << start
-                                << " to " << end << '\n');
-
-        auto band =
-            llvm::makeMutableArrayRef(loops.data() + start, end - start);
-        (void)coalesceLoops(band);
-        break;
-      }
-      // If a band was found and transformed, keep looking at the loops above
-      // the outermost transformed loop.
-      if (start != end - 1)
-        end = start + 1;
-    }
-  }
-
   void runOnOperation() override {
     func::FuncOp func = getOperation();
     func.walk([&](Operation *op) {
-      if (auto scfForOp = dyn_cast<scf::ForOp>(op))
-        walkLoop(scfForOp);
-      else if (auto affineForOp = dyn_cast<AffineForOp>(op))
-        walkLoop(affineForOp);
+      if (auto scfForOp = dyn_cast<scf::ForOp>(op)) {
+        // Ignore nested loops: only start from the outermost scf.for.
+        if (!scfForOp->getParentOfType<scf::ForOp>())
+          coalescePerfectlyNestedLoops(scfForOp);
+      } else if (auto affineForOp = dyn_cast<AffineForOp>(op)) {
+        // Ignore nested loops: only start from the outermost affine.for.
+        if (!affineForOp->getParentOfType<AffineForOp>())
+          coalescePerfectlyNestedLoops(affineForOp);
+      }
     });
   }
 };
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Support/MathExtras.h"
@@ -2841,3 +2842,61 @@
 
   return success();
 }
+
+template <typename LoopOpTy>
+void mlir::coalescePerfectlyNestedLoops(LoopOpTy op) {
+  SmallVector<LoopOpTy, 4> loops;
+  getPerfectlyNestedLoops(loops, op);
+  LLVM_DEBUG(llvm::dbgs() << "found a perfect nest of depth " << loops.size()
+                          << '\n');
+
+  // Look for a band of loops that can be coalesced, i.e. perfectly nested
+  // loops with bounds defined above some loop.
+  // 1. For each loop, find above which parent loop its operands are
+  // defined.
+  SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
+  for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+    operandsDefinedAbove[i] = i;
+    for (unsigned j = 0; j < i; ++j) {
+      if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
+        operandsDefinedAbove[i] = j;
+        break;
+      }
+    }
+    LLVM_DEBUG(llvm::dbgs()
+               << "  bounds of loop " << i << " are known above depth "
+               << operandsDefinedAbove[i] << '\n');
+  }
+
+  // 2. Identify bands of loops such that the operands of all of them are
+  // defined above the first loop in the band. Traverse the nest bottom-up
+  // so that modifications don't invalidate the inner loops.
+  for (unsigned end = loops.size(); end > 0; --end) {
+    unsigned start = 0;
+    for (; start < end - 1; ++start) {
+      auto maxPos =
+          *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+                            std::next(operandsDefinedAbove.begin(), end));
+      if (maxPos > start)
+        continue;
+
+      assert(maxPos == start &&
+             "expected loop bounds to be known at the start of the band");
+      LLVM_DEBUG(llvm::dbgs() << "  found coalesceable band from " << start
+                              << " to " << end << '\n');
+
+      auto band = llvm::makeMutableArrayRef(loops.data() + start, end - start);
+      (void)coalesceLoops(band);
+      break;
+    }
+    // If a band was found and transformed, keep looking at the loops above
+    // the outermost transformed loop.
+    if (start != end - 1)
+      end = start + 1;
+  }
+}
+
+// Explicitly instantiate the supported loop types: LoopUtils.h only declares
+// the template, so callers link against these instantiations.
+template void mlir::coalescePerfectlyNestedLoops<scf::ForOp>(scf::ForOp op);
+template void mlir::coalescePerfectlyNestedLoops<AffineForOp>(AffineForOp op);
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -226,6 +226,29 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// LoopCoalesceOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::LoopCoalesceOp::applyToOne(Operation *op,
+                                      transform::ApplyToEachResultList &results,
+                                      transform::TransformState &state) {
+  // Always map the input op to the result so the result list stays aligned
+  // with the targets, even when reporting an unsupported target below.
+  results.push_back(op);
+  if (auto scfForOp = dyn_cast<scf::ForOp>(op)) {
+    coalescePerfectlyNestedLoops(scfForOp);
+  } else if (auto affineForOp = dyn_cast<AffineForOp>(op)) {
+    coalescePerfectlyNestedLoops(affineForOp);
+  } else {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected an scf.for or affine.for target";
+    return diag;
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir @@ -0,0 +1,92 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s + +func.func @coalesce_inner() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + // CHECK: scf.for [[IV0:%[a-zA-Z0-9]+]] + scf.for %i = %c0 to %c10 step %c1 { + // CHECK: scf.for [[IV1:%[a-zA-Z0-9]+]] + scf.for %j = %c0 to %c10 step %c1 { + // CHECK: scf.for [[IV2:%[a-zA-Z0-9]+]] + scf.for %k = %i to %j step %c1 { + // Inner loop must have been removed. + // CHECK-NOT: scf.for [[IV3:%[a-zA-Z0-9]+]] + scf.for %l = %i to %j step %c1 { + "foo"() : () -> () + } + } {coalesce} + } + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.for"> + %2 = transform.loop.coalesce %1: !transform.op<"scf.for"> -> !pdl.operation +} + +// ----- + +func.func @coalesce_outer(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} { + // CHECK: affine.for [[IV1:%[a-zA-Z0-9]+]] = 0 to [[UB:%[a-zA-Z0-9]+]] { + affine.for %arg4 = 0 to 64 { + // CHECK-NOT: affine.for [[IV2:%[a-zA-Z0-9]+]] + affine.for %arg5 = 0 to 64 { + // CHECK: %[[IDX0:.*]] = affine.apply #[[MAP0:.+]]([[IV1]])[%{{.*}}] + // CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1:.+]]([[IV1]])[%{{.*}}] + // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1> + %0 = affine.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1> + %1 = affine.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1> + %2 = 
arith.addf %0, %1 : f32 + affine.store %2, %arg3[%arg4, %arg5] : memref<64x64xf32, 1> + } + } {coalesce} + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["affine.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"affine.for"> + %2 = transform.loop.coalesce %1 : !transform.op<"affine.for"> -> !pdl.operation +} + +// ----- + +func.func @coalesce_and_unroll(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} { + // CHECK: scf.for [[IV1:%[a-zA-Z0-9]+]] + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + scf.for %arg4 = %c0 to %c64 step %c1 { + // CHECK-NOT: scf.for [[IV2:%[a-zA-Z0-9]+]] + scf.for %arg5 = %c0 to %c64 step %c1 { + // CHECK: %[[IDX0:.*]] = arith.remsi [[IV1]] + // CHECK: %[[IDX1:.*]] = arith.divsi [[IV1]] + // CHECK-NEXT: %{{.*}} = memref.load %{{.*}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1> + %0 = memref.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1> + %1 = memref.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1> + %2 = arith.addf %0, %1 : f32 + // CHECK: memref.store + // CHECK: memref.store + // CHECK: memref.store + // Residual loop must have a single store. + // CHECK: memref.store + memref.store %2, %arg3[%arg4, %arg5] : memref<64x64xf32, 1> + } + } {coalesce} + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 + %1 = transform.cast %0 : !pdl.operation to !transform.op<"scf.for"> + %2 = transform.loop.coalesce %1 : !transform.op<"scf.for"> -> !pdl.operation + %loop = transform.cast %2 : !pdl.operation to !transform.op<"scf.for"> + transform.loop.unroll %loop {factor = 3} : !transform.op<"scf.for"> +}