diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -346,6 +346,8 @@
     unsigned getNumLoops() { return step().size(); }
     unsigned getNumReductions() { return initVals().size(); }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def ReduceOp : SCF_Op<"reduce", [HasParent<"ParallelOp">]> {
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -8,17 +8,8 @@
 
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
 
 using namespace mlir;
@@ -742,6 +733,76 @@
   return dyn_cast<ParallelOp>(containingOp);
 }
 
+namespace {
+// Collapse loop dimensions that perform a single iteration. A dimension is
+// collapsed when its lower bound, upper bound, and step are all defined by
+// ConstantIndexOps and 0 < upperBound - lowerBound <= step; uses of its
+// induction variable in the cloned body are replaced by the lower bound. The
+// pattern does not apply if none or all of the loop dimensions qualify.
+struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
+  using OpRewritePattern<ParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ParallelOp op,
+                                PatternRewriter &rewriter) const override {
+    BlockAndValueMapping mapping;
+    // Compute new loop bounds that omit all single-iteration loop dimensions.
+    SmallVector<Value, 2> newLowerBounds;
+    SmallVector<Value, 2> newUpperBounds;
+    SmallVector<Value, 2> newSteps;
+    newLowerBounds.reserve(op.lowerBound().size());
+    newUpperBounds.reserve(op.upperBound().size());
+    newSteps.reserve(op.step().size());
+    for (auto dim : llvm::zip(op.lowerBound(), op.upperBound(), op.step(),
+                              op.getInductionVars())) {
+      Value lowerBound, upperBound, step, iv;
+      std::tie(lowerBound, upperBound, step, iv) = dim;
+      // Collect the statically known loop bounds.
+      auto lowerBoundConstant =
+          dyn_cast_or_null<ConstantIndexOp>(lowerBound.getDefiningOp());
+      auto upperBoundConstant =
+          dyn_cast_or_null<ConstantIndexOp>(upperBound.getDefiningOp());
+      auto stepConstant =
+          dyn_cast_or_null<ConstantIndexOp>(step.getDefiningOp());
+      // Replace the loop induction variable by the lower bound if the loop
+      // performs a single iteration. Otherwise, copy the loop bounds.
+      if (lowerBoundConstant && upperBoundConstant && stepConstant &&
+          (upperBoundConstant.getValue() - lowerBoundConstant.getValue()) > 0 &&
+          (upperBoundConstant.getValue() - lowerBoundConstant.getValue()) <=
+              stepConstant.getValue()) {
+        mapping.map(iv, lowerBound);
+      } else {
+        newLowerBounds.push_back(lowerBound);
+        newUpperBounds.push_back(upperBound);
+        newSteps.push_back(step);
+      }
+    }
+    // Exit if all or none of the loop dimensions perform a single iteration.
+    if (newLowerBounds.empty() ||
+        newLowerBounds.size() == op.lowerBound().size())
+      return failure();
+    // Replace the parallel loop by lower-dimensional parallel loop.
+    auto newOp =
+        rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
+                                    newSteps, op.initVals(), nullptr);
+    // Erase the body block (if any) that the builder created, so that the
+    // cloned body below becomes the only block of the new loop's region.
+    if (!newOp.region().empty())
+      rewriter.eraseBlock(&newOp.region().front());
+    // Clone the loop body and remap the block arguments of the collapsed loops
+    // (inlining does not support a cancellable block argument mapping).
+    rewriter.cloneRegionBefore(op.region(), newOp.region(),
+                               newOp.region().begin(), mapping);
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+} // namespace
+
+void ParallelOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                             MLIRContext *context) {
+  results.insert<CollapseSingleIterationLoops>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ReduceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+
+func @single_iteration(%A: memref<?x?x?xi32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c6 = constant 6 : index
+  %c7 = constant 7 : index
+  %c10 = constant 10 : index
+  scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c2, %c3) {
+    %c42 = constant 42 : i32
+    store %c42, %A[%i0, %i1, %i2] : memref<?x?x?xi32>
+    scf.yield
+  }
+  return
+}
+
+// CHECK-LABEL:   func @single_iteration(
+// CHECK-SAME:                        [[ARG0:%.*]]: memref<?x?x?xi32>) {
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C2:%.*]] = constant 2 : index
+// CHECK:           [[C3:%.*]] = constant 3 : index
+// CHECK:           [[C6:%.*]] = constant 6 : index
+// CHECK:           [[C7:%.*]] = constant 7 : index
+// CHECK:           [[C42:%.*]] = constant 42 : i32
+// CHECK:           scf.parallel ([[V0:%.*]]) = ([[C3]]) to ([[C6]]) step ([[C2]]) {
+// CHECK:             store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V0]], [[C7]]] : memref<?x?x?xi32>
+// CHECK:             scf.yield
+// CHECK:           }
+// CHECK:           return
+
+// -----
+
+func @no_iteration(%A: memref<?x?xi32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  scf.parallel (%i0, %i1) = (%c0, %c0) to (%c1, %c0) step (%c1, %c1) {
+    %c42 = constant 42 : i32
+    store %c42, %A[%i0, %i1] : memref<?x?xi32>
+    scf.yield
+  }
+  return
+}
+
+// CHECK-LABEL:   func @no_iteration(
+// CHECK-SAME:                       [[ARG0:%.*]]: memref<?x?xi32>) {
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C1:%.*]] = constant 1 : index
+// CHECK:           [[C42:%.*]] = constant 42 : i32
+// CHECK:           scf.parallel ([[V1:%.*]]) = ([[C0]]) to ([[C0]]) step ([[C1]]) {
+// CHECK:             store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V1]]] : memref<?x?xi32>
+// CHECK:             scf.yield
+// CHECK:           }
+// CHECK:           return
diff --git a/mlir/test/Transforms/parallel-loop-collapsing.mlir b/mlir/test/Transforms/parallel-loop-collapsing.mlir
--- a/mlir/test/Transforms/parallel-loop-collapsing.mlir
+++ b/mlir/test/Transforms/parallel-loop-collapsing.mlir
@@ -26,24 +26,18 @@
 
 // CHECK-LABEL: func @parallel_many_dims() {
 // CHECK:         [[C6:%.*]] = constant 6 : index
-// CHECK:         [[C7:%.*]] = constant 7 : index
 // CHECK:         [[C9:%.*]] = constant 9 : index
 // CHECK:         [[C10:%.*]] = constant 10 : index
 // CHECK:         [[C12:%.*]] = constant 12 : index
-// CHECK:         [[C13:%.*]] = constant 13 : index
-// CHECK:         [[C3:%.*]] = constant 3 : index
 // CHECK:         [[C0:%.*]] = constant 0 : index
 // CHECK:         [[C1:%.*]] = constant 1 : index
 // CHECK:         [[C2:%.*]] = constant 2 : index
-// CHECK:         scf.parallel ([[NEW_I0:%.*]], [[NEW_I1:%.*]], [[NEW_I2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[C2]], [[C1]], [[C1]]) step ([[C1]], [[C1]], [[C1]]) {
+// CHECK:         [[C3:%.*]] = constant 3 : index
+// CHECK:         scf.parallel ([[NEW_I0:%.*]]) = ([[C0]]) to ([[C2]]) step ([[C1]]) {
 // CHECK:           [[I0:%.*]] = remi_signed [[NEW_I0]], [[C2]] : index
-// CHECK:           [[VAL_16:%.*]] = muli [[NEW_I1]], [[C13]] : index
-// CHECK:           [[I4:%.*]] = addi [[VAL_16]], [[C12]] : index
-// CHECK:           [[VAL_18:%.*]] = muli [[NEW_I0]], [[C10]] : index
-// CHECK:           [[I3:%.*]] = addi [[VAL_18]], [[C9]] : index
-// CHECK:           [[VAL_20:%.*]] = muli [[NEW_I2]], [[C7]] : index
-// CHECK:           [[I2:%.*]] = addi [[VAL_20]], [[C6]] : index
-// CHECK:           "magic.op"([[I0]], [[C3]], [[I2]], [[I3]], [[I4]]) : (index, index, index, index, index) -> index
+// CHECK:           [[V18:%.*]] = muli [[NEW_I0]], [[C10]] : index
+// CHECK:           [[I3:%.*]] = addi [[V18]], [[C9]] : index
+// CHECK:           "magic.op"([[I0]], [[C3]], [[C6]], [[I3]], [[C12]]) : (index, index, index, index, index) -> index
 // CHECK:           scf.yield
 // CHECK-NEXT:    }
 // CHECK-NEXT:    return