diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -30,7 +30,7 @@
 
-/// Creates a pass that canonicalizes affine.min and affine.max operations
-/// inside of scf.for loops with known lower and upper bounds.
-std::unique_ptr<Pass> createSCFAffineOpCanonicalizationPass();
+/// Creates a pass that canonicalizes operations within scf.for loop bodies
+/// (e.g., affine.min/max ops and dim ops of loop-carried values).
+std::unique_ptr<Pass> createSCFForLoopCanonicalizationPass();
 
 /// Creates a loop fusion pass which fuses parallel loops.
 std::unique_ptr<Pass> createParallelLoopFusionPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -17,14 +17,14 @@
   let dependentDialects = ["memref::MemRefDialect"];
 }
 
-// Note: Making this a canonicalization pattern would require a dependency
-// of the SCF dialect on the Affine dialect or vice versa.
-def SCFAffineOpCanonicalization
-    : FunctionPass<"canonicalize-scf-affine-op"> {
-  let summary = "Canonicalize affine.min and affine.max ops in the context of "
-                "SCF loops with known bounds";
-  let constructor = "mlir::createSCFAffineOpCanonicalizationPass()";
-  let dependentDialects = ["AffineDialect"];
+// Note: Making these canonicalization patterns would require a dependency
+// of the SCF dialect on the Affine/Tensor/MemRef dialects or vice versa.
+def SCFForLoopCanonicalization
+    : FunctionPass<"for-loop-canonicalization"> {
+  let summary = "Canonicalize operations within scf.for loop bodies";
+  let constructor = "mlir::createSCFForLoopCanonicalizationPass()";
+  let dependentDialects = ["AffineDialect", "tensor::TensorDialect",
+                           "memref::MemRefDialect"];
 }
 
 def SCFForLoopPeeling
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -122,7 +122,7 @@
   let summary = "for operation";
   let description = [{
     The "scf.for" operation represents a loop taking 3 SSA value as operands
-    that represent the lower bound, upper bound and step respectively.  The
+    that represent the lower bound, upper bound and step respectively. The
     operation defines an SSA value for its induction variable. It has one
     region capturing the loop body. The induction variable is represented as an
     argument of this region. This SSA value always has type index, which is the
@@ -146,14 +146,18 @@
     values after loop termination. The initial values of the variables are
     passed as additional SSA operands to the "scf.for" following the 3 loop
     control SSA values mentioned above (lower bound, upper bound and step). The
-    operation region has equivalent arguments for each variable representing
-    the value of the variable at the current iteration.
-
-    The region must terminate with a "scf.yield" that passes all the current
-    iteration variables to the next iteration, or to the "scf.for" result, if
-    at the last iteration. Note, that when the loop-carried variables are
-    present, calling ForOp::build will not insert the terminator implicitly.
-    The caller must insert "scf.yield" in that case.
+    operation region has an argument for the induction variable, followed by
+    one argument for each loop-carried variable, representing the value of the
+    variable at the current iteration.
+
+    The region must terminate with a "scf.yield" that passes the current
+    values of all loop-carried variables to the next iteration, or to the
+    "scf.for" result, if at the last iteration. The static type of a
+    loop-carried variable may not change with iterations; its runtime type
+    (e.g., the dynamic dimension sizes of a tensor) is allowed to change.
+    Note, that when the loop-carried variables are present, calling
+    ForOp::build will not insert the terminator implicitly. The caller must
+    insert "scf.yield" in that case.
 
     "scf.for" results hold the final values after the last iteration.
     For example, to sum-reduce a memref:
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -179,7 +179,7 @@
 /// Populate patterns for canonicalizing operations inside SCF loop bodies.
-/// At the moment, only affine.min/max computations with iteration variables,
-/// loop bounds and loop steps are canonicalized.
+/// At the moment, affine.min/max computations with iteration variables, loop
+/// bounds and loop steps, and dim ops of scf.for iter_args are canonicalized.
-void populateSCFLoopBodyCanonicalizationPatterns(RewritePatternSet &patterns);
+void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 
 } // namespace scf
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -48,7 +48,7 @@
 
   RewritePatternSet stage2Patterns =
       linalg::getLinalgTilingCanonicalizationPatterns(context);
-  scf::populateSCFLoopBodyCanonicalizationPatterns(stage2Patterns);
+  scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns);
 
   auto stage3Transforms = [&](Operation *op) {
     // Some of these may be too aggressive as a stage 3 that is applied on each
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -537,7 +537,7 @@
   MLIRContext *ctx = funcOp.getContext();
   RewritePatternSet patterns(ctx);
   insertTilingPatterns(patterns, options);
-  scf::populateSCFLoopBodyCanonicalizationPatterns(patterns);
+  scf::populateSCFForLoopCanonicalizationPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
   (void)applyPatternsAndFoldGreedily(
       funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRSCFTransforms
   Bufferize.cpp
+  LoopCanonicalization.cpp
   LoopPipelining.cpp
   LoopRangeFolding.cpp
   LoopSpecialization.cpp
@@ -22,6 +23,7 @@
   MLIRSCF
   MLIRStandard
   MLIRSupport
+  MLIRTensor
   MLIRTransforms
   MLIRTransformUtils
 )
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -0,0 +1,169 @@
+//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains cross-dialect canonicalization patterns that cannot be
+// actual canonicalization patterns due to undesired additional dependencies.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+/// A simple, conservative analysis to determine whether `forOp` preserves the
+/// shape of its `iterArgIdx`-th loop-carried value, i.e., whether the value
+/// yielded for that iter_arg is guaranteed to have the same runtime shape as
+/// the iter_arg itself. (The static type of a loop-carried value cannot change
+/// across iterations, but its dynamic dimension sizes may.)
+/// Note: Only simple cases are handled; expand as needed.
+static bool isShapePreserving(ForOp forOp, unsigned iterArgIdx) {
+  Value value = forOp.getBody()->getTerminator()->getOperand(iterArgIdx);
+  while (value) {
+    if (value == forOp.getRegionIterArgs()[iterArgIdx])
+      return true;
+    OpResult opResult = value.dyn_cast<OpResult>();
+    if (!opResult)
+      return false;
+
+    Operation *owner = opResult.getOwner();
+    if (auto insertOp = dyn_cast<tensor::InsertSliceOp>(owner)) {
+      // The result of an insert_slice has the shape of its destination.
+      value = insertOp.dest();
+    } else if (auto nestedForOp = dyn_cast<ForOp>(owner)) {
+      // A nested loop that preserves the shape of this result yields a value
+      // with the shape of the corresponding init arg.
+      unsigned resultIdx = opResult.getResultNumber();
+      value = isShapePreserving(nestedForOp, resultIdx)
+                  ? nestedForOp.getIterOperands()[resultIdx]
+                  : Value();
+    } else {
+      // Unknown producer: conservatively assume the shape may change.
+      value = Value();
+    }
+  }
+  return false;
+}
+
+namespace {
+/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+///
+/// is folded to:
+///
+/// ```
+/// %0 = ... : tensor<?x?xf32>
+/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
+///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
+///   ...
+/// }
+/// ```
+///
+/// Note: This is only valid if the loop yields a value of the same shape as
+/// the iter_arg, since the dynamic dimension sizes of a loop-carried value
+/// are allowed to change from one iteration to the next.
+template <typename OpTy>
+struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
+    if (!blockArg)
+      return failure();
+    auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
+    if (!forOp)
+      return failure();
+    // Block argument #0 is the induction variable; iter_args follow it.
+    if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
+      return failure();
+
+    Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
+    rewriter.updateRootInPlace(
+        dimOp, [&]() { dimOp.sourceMutable().assign(initArg); });
+
+    return success();
+  }
+};
+
+/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
+/// and scf.parallel loops with a known range.
+template <typename OpTy, bool IsMin>
+struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
+      if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
+        lb = forOp.lowerBound();
+        ub = forOp.upperBound();
+        step = forOp.step();
+        return success();
+      }
+      if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
+        for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
+          if (parOp.getInductionVars()[idx] == iv) {
+            lb = parOp.lowerBound()[idx];
+            ub = parOp.upperBound()[idx];
+            step = parOp.step()[idx];
+            return success();
+          }
+        }
+        return failure();
+      }
+      return failure();
+    };
+
+    return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
+                                           op.operands(), IsMin, loopMatcher);
+  }
+};
+
+/// Applies the scf.for loop canonicalization patterns to a function.
+struct SCFForLoopCanonicalization
+    : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    MLIRContext *ctx = funcOp.getContext();
+    RewritePatternSet patterns(ctx);
+    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns
+      .insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
+              AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
+              DimOfIterArgFolder<tensor::DimOp>,
+              DimOfIterArgFolder<memref::DimOp>>(ctx);
+}
+
+std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
+  return std::make_unique<SCFForLoopCanonicalization>();
+}
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -520,40 +520,6 @@
   /// the direct parent.
   bool skipPartial;
 };
-
-/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
-/// and scf.parallel loops with a known range.
-template <typename OpTy, bool IsMin>
-struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy op,
-                                PatternRewriter &rewriter) const override {
-    auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
-      if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
-        lb = forOp.lowerBound();
-        ub = forOp.upperBound();
-        step = forOp.step();
-        return success();
-      }
-      if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
-        for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
-          if (parOp.getInductionVars()[idx] == iv) {
-            lb = parOp.lowerBound()[idx];
-            ub = parOp.upperBound()[idx];
-            step = parOp.step()[idx];
-            return success();
-          }
-        }
-        return failure();
-      }
-      return failure();
-    };
-
-    return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
-                                           op.operands(), IsMin, loopMatcher);
-  }
-};
 } // namespace
 
 namespace {
@@ -587,24 +553,8 @@
     });
   }
 };
-
-struct SCFAffineOpCanonicalization
-    : public SCFAffineOpCanonicalizationBase<SCFAffineOpCanonicalization> {
-  void runOnFunction() override {
-    FuncOp funcOp = getFunction();
-    MLIRContext *ctx = funcOp.getContext();
-    RewritePatternSet patterns(ctx);
-    scf::populateSCFLoopBodyCanonicalizationPatterns(patterns);
-    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
-      signalPassFailure();
-  }
-};
 } // namespace
 
-std::unique_ptr<Pass> mlir::createSCFAffineOpCanonicalizationPass() {
-  return std::make_unique<SCFAffineOpCanonicalization>();
-}
-
 std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
   return std::make_unique<ParallelLoopSpecialization>();
 }
@@ -616,12 +566,3 @@
 std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
   return std::make_unique<ForLoopPeeling>();
 }
-
-void mlir::scf::populateSCFLoopBodyCanonicalizationPatterns(
-    RewritePatternSet &patterns) {
-  MLIRContext *ctx = patterns.getContext();
-  patterns
-      .insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
-              AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>>(
-          ctx);
-}
diff --git a/mlir/lib/Dialect/SCF/Transforms/PassDetail.h b/mlir/lib/Dialect/SCF/Transforms/PassDetail.h
--- a/mlir/lib/Dialect/SCF/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/SCF/Transforms/PassDetail.h
@@ -22,6 +22,10 @@
 class MemRefDialect;
 } // end namespace memref
 
+namespace tensor {
+class TensorDialect;
+} // end namespace tensor
+
 #define GEN_PASS_CLASSES
 #include "mlir/Dialect/SCF/Passes.h.inc"
 
diff --git a/mlir/test/Dialect/SCF/canonicalize-affine-op.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
rename from mlir/test/Dialect/SCF/canonicalize-affine-op.mlir
rename to mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
--- a/mlir/test/Dialect/SCF/canonicalize-affine-op.mlir
+++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize-scf-affine-op -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -for-loop-canonicalization -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @scf_for_canonicalize_min
 // CHECK: %[[C2:.*]] = constant 2 : i64
@@ -224,3 +224,41 @@
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_iter_arg(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
+//       CHECK:   scf.for
+//       CHECK:     tensor.dim %[[t]]
+func @tensor_dim_of_iter_arg(%t : tensor<?x?xf32>) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
+      -> (tensor<?x?xf32>, index) {
+    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+    scf.yield %arg0, %dim : tensor<?x?xf32>, index
+  }
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_dim_of_iter_arg_no_canonicalize(
+//  CHECK-SAME:     %[[t:[a-zA-Z0-9_]+]]: tensor<?x?xf32>,
+//  CHECK-SAME:     %[[t2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//       CHECK:   scf.for {{.*}} iter_args(%[[arg0:[a-zA-Z0-9_]+]] = %[[t]]
+//       CHECK:     tensor.dim %[[arg0]]
+func @tensor_dim_of_iter_arg_no_canonicalize(%t : tensor<?x?xf32>,
+                                             %t2 : tensor<?x?xf32>) -> index {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0)
+      -> (tensor<?x?xf32>, index) {
+    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+    scf.yield %t2, %dim : tensor<?x?xf32>, index
+  }
+  return %1 : index
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
@@ -71,7 +71,7 @@
 
   RewritePatternSet stage2Patterns =
       linalg::getLinalgTilingCanonicalizationPatterns(context);
-  scf::populateSCFLoopBodyCanonicalizationPatterns(stage2Patterns);
+  scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns);
 
   auto stage3Transforms = [](Operation *op) {
     PassManager pm(op->getContext());
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -237,7 +237,7 @@
     RewritePatternSet patterns =
         linalg::getLinalgTilingCanonicalizationPatterns(context);
     patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
-    scf::populateSCFLoopBodyCanonicalizationPatterns(patterns);
+    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
     FrozenRewritePatternSet frozenPatterns(std::move(patterns));
     do {
       (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1491,6 +1491,7 @@
         ":SCFPassIncGen",
         ":StandardOps",
         ":Support",
+        ":TensorDialect",
         ":Transforms",
         "//llvm:Support",
     ],