diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h --- a/mlir/include/mlir/Dialect/Affine/Passes.h +++ b/mlir/include/mlir/Dialect/Affine/Passes.h @@ -18,6 +18,7 @@ #include namespace mlir { + namespace func { class FuncOp; } // namespace func @@ -110,10 +111,6 @@ /// memory hierarchy. std::unique_ptr> createPipelineDataTransferPass(); -/// Populate patterns that expand affine index operations into more fundamental -/// operations (not necessarily restricted to Affine dialect). -void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns); - /// Creates a pass to expand affine index operations into more fundamental /// operations (not necessarily restricted to Affine dialect). std::unique_ptr createAffineExpandIndexOpsPass(); diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h @@ -0,0 +1,49 @@ +//===- Transforms.h - Transforms Entrypoints --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines a set of transforms specific for the AffineOps +// dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H +#define MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H + +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +class RewritePatternSet; +class RewriterBase; +class AffineApplyOp; + +/// Populate patterns that expand affine index operations into more fundamental +/// operations (not necessarily restricted to Affine dialect). 
+void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
+
+/// Helper function to rewrite `op`'s affine map and reorder its operands such
+/// that they are in decreasing order of hoistability (i.e. the most hoistable
+/// operands come first in the operand list).
+void reorderOperandsByHoistability(RewriterBase &rewriter, AffineApplyOp op);
+
+/// Split an "affine.apply" operation into smaller ops.
+/// This reassociates a large AffineApplyOp into an ordered list of smaller
+/// AffineApplyOps. This can be used right before lowering affine ops to arith
+/// to exhibit more opportunities for CSE and LICM.
+/// Return the sink AffineApplyOp on success, or failure if `op` does not
+/// decompose (e.g. its map has dims or its top-level expr is not add/mul).
+/// Note that this can be undone by canonicalization which tries to
+/// maximally compose chains of AffineApplyOps.
+FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Affine/Passes.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@
   AffineLoopNormalize.cpp
AffineParallelize.cpp AffineScalarReplacement.cpp + DecomposeAffineOps.cpp LoopCoalescing.cpp LoopFusion.cpp LoopTiling.cpp diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp @@ -0,0 +1,173 @@ +//===- DecomposeAffineOps.cpp - Decompose affine ops into finer-grained ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements functionality to progressively decompose coarse-grained +// affine ops into finer-grained ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; + +#define DEBUG_TYPE "decompose-affine-ops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define DBGSNL() (llvm::dbgs() << "\n") + +/// Count the number of loops surrounding `operand` such that operand could be +/// hoisted above. +/// Stop counting at the first loop over which the operand cannot be hoisted. 
+static int64_t numEnclosingInvariantLoops(OpOperand &operand) {
+  int64_t count = 0;
+  Operation *currentOp = operand.getOwner();
+  while (auto loopOp = currentOp->getParentOfType<LoopLikeOpInterface>()) {
+    if (!loopOp.isDefinedOutsideOfLoop(operand.get()))
+      break;
+    currentOp = loopOp;
+    count++;
+  }
+  return count;
+}
+
+void mlir::reorderOperandsByHoistability(RewriterBase &rewriter,
+                                         AffineApplyOp op) {
+  SmallVector<int64_t, 4> numInvariant = llvm::to_vector<4>(
+      llvm::map_range(op->getOpOperands(), [&](OpOperand &operand) {
+        return numEnclosingInvariantLoops(operand);
+      }));
+
+  int64_t numOperands = op.getNumOperands();
+  SmallVector<int64_t, 4> operandPositions =
+      llvm::to_vector<4>(llvm::seq<int64_t>(0, numOperands));
+  // Stable: ties keep their original relative order (deterministic output).
+  llvm::stable_sort(operandPositions, [&numInvariant](int64_t i1, int64_t i2) {
+    return numInvariant[i1] > numInvariant[i2];
+  });
+
+  SmallVector<AffineExpr> replacements(numOperands);
+  SmallVector<Value> operands(numOperands);
+  for (int64_t i = 0; i < numOperands; ++i) {
+    operands[i] = op.getOperand(operandPositions[i]);
+    replacements[operandPositions[i]] = getAffineSymbolExpr(i, op.getContext());
+  }
+
+  AffineMap map = op.getAffineMap();
+  ArrayRef<AffineExpr> repls{replacements};
+  map = map.replaceDimsAndSymbols(repls.take_front(map.getNumDims()),
+                                  repls.drop_front(map.getNumDims()),
+                                  /*numResultDims=*/0,
+                                  /*numResultSyms=*/numOperands);
+  map = AffineMap::get(0, numOperands,
+                       simplifyAffineExpr(map.getResult(0), 0, numOperands),
+                       op->getContext());
+  canonicalizeMapAndOperands(&map, &operands);
+
+  rewriter.startRootUpdate(op);
+  op.setMap(map);
+  op->setOperands(operands);
+  rewriter.finalizeRootUpdate(op);
+}
+
+/// Build an affine.apply that is a subexpression `expr` of `originalOp`s affine
+/// map and with the same operands.
+/// Canonicalize the map and operands to deduplicate and drop dead operands
+/// before returning but do not perform maximal composition of AffineApplyOp
+/// which would defeat the purpose.
+static AffineApplyOp createSubApply(RewriterBase &rewriter,
+                                    AffineApplyOp originalOp, AffineExpr expr) {
+  MLIRContext *ctx = originalOp->getContext();
+  AffineMap m = originalOp.getAffineMap();
+  auto rhsMap = AffineMap::get(m.getNumDims(), m.getNumSymbols(), expr, ctx);
+  SmallVector<Value> rhsOperands = originalOp->getOperands();
+  canonicalizeMapAndOperands(&rhsMap, &rhsOperands);
+  return rewriter.create<AffineApplyOp>(originalOp.getLoc(), rhsMap,
+                                        rhsOperands);
+}
+
+FailureOr<AffineApplyOp> mlir::decompose(RewriterBase &rewriter,
+                                         AffineApplyOp op) {
+  // 1. Preconditions: only handle dimensionless AffineApplyOp maps with a
+  // top-level binary expression that we can reassociate (i.e. add or mul).
+  AffineMap m = op.getAffineMap();
+  if (m.getNumDims() > 0)
+    return rewriter.notifyMatchFailure(op, "expected no dims");
+
+  AffineExpr remainingExp = m.getResult(0);
+  auto binExpr = remainingExp.dyn_cast<AffineBinaryOpExpr>();
+  if (!binExpr)
+    return rewriter.notifyMatchFailure(op, "terminal affine.apply");
+
+  if (!binExpr.getLHS().isa<AffineBinaryOpExpr>() &&
+      !binExpr.getRHS().isa<AffineBinaryOpExpr>())
+    return rewriter.notifyMatchFailure(op, "terminal affine.apply");
+
+  bool supportedKind = ((binExpr.getKind() == AffineExprKind::Add) ||
+                        (binExpr.getKind() == AffineExprKind::Mul));
+  if (!supportedKind)
+    return rewriter.notifyMatchFailure(
+        op, "only add or mul binary expr can be reassociated");
+
+  LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n");
+
+  // 2. Iteratively extract the RHS subexpressions while the top-level binary
+  // expr kind remains the same.
+  MLIRContext *ctx = op->getContext();
+  SmallVector<AffineExpr> subExpressions;
+  while (true) {
+    auto currentBinExpr = remainingExp.dyn_cast<AffineBinaryOpExpr>();
+    if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
+      subExpressions.push_back(remainingExp);
+      LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
+      break;
+    }
+    subExpressions.push_back(currentBinExpr.getRHS());
+    LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n");
+    remainingExp = currentBinExpr.getLHS();
+  }
+
+  // 3. Reorder subExpressions by the max symbol they are a function of. When
+  // symbols are sorted by hoistability, the most hoistable terms come first.
+  // A stable sort keeps ties in extraction order, for deterministic output.
+  // This however won't be able to split expression that cannot be reassociated
+  // such as ones that involve divs and multiple symbols.
+  auto getMaxSymbol = [&](AffineExpr e) -> int64_t {
+    for (int64_t i = static_cast<int64_t>(m.getNumSymbols()) - 1; i >= 0; --i)
+      if (e.isFunctionOfSymbol(i))
+        return i;
+    return -1;
+  };
+  llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) {
+    return getMaxSymbol(e1) < getMaxSymbol(e2);
+  });
+  LLVM_DEBUG(
+      llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: ");
+      llvm::dbgs() << "\n");
+
+  // 4. Merge sorted subExpressions iteratively, thus achieving reassociation.
+  auto s0 = getAffineSymbolExpr(0, ctx);
+  auto s1 = getAffineSymbolExpr(1, ctx);
+  AffineMap binMap = AffineMap::get(
+      /*dimCount=*/0, /*symbolCount=*/2,
+      getAffineBinaryOpExpr(binExpr.getKind(), s0, s1), ctx);
+
+  auto current = createSubApply(rewriter, op, subExpressions[0]);
+  for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) {
+    Value tmp = createSubApply(rewriter, op, subExpressions[i]);
+    current = rewriter.create<AffineApplyOp>(op.getLoc(), binMap,
+                                             ValueRange{current, tmp});
+    LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
+  }
+
+  // 5. Replace original op.
+  rewriter.replaceOp(op, current.getResult());
+  return current;
+}
diff --git a/mlir/test/Dialect/Affine/decompose-affine-ops.mlir b/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Affine/decompose-affine-ops.mlir
@@ -0,0 +1,156 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -test-decompose-affine-ops -split-input-file -loop-invariant-code-motion -cse | FileCheck %s
+
+// CHECK-DAG: #[[$c42:.*]] = affine_map<() -> (42)>
+// CHECK-DAG: #[[$div32mod4:.*]] = affine_map<()[s0] -> ((s0 floordiv 32) mod 4)>
+// CHECK-DAG: #[[$add:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+// CHECK-LABEL: func.func @simple_test_1
+// CHECK-SAME: %[[I0:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[I1:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[I2:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[LB:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[UB:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[STEP:[0-9a-zA-Z]+]]: index
+func.func @simple_test_1(%0: index, %1: index, %2: index, %lb: index, %ub: index, %step: index) {
+  // CHECK: %[[c42:.*]] = affine.apply #[[$c42]]()
+  // CHECK: %[[R1:.*]] = affine.apply #[[$div32mod4]]()[%[[I1]]]
+  // CHECK: %[[a:.*]] = affine.apply #[[$add]]()[%[[c42]], %[[R1]]]
+  %a = affine.apply affine_map<(d0) -> ((d0 floordiv 32) mod 4 + 42)>(%1)
+
+  // CHECK: "some_side_effecting_consumer"(%[[a]]) : (index) -> ()
+  "some_side_effecting_consumer"(%a) : (index) -> ()
+  return
+}
+
+// -----
+
+// CHECK-DAG: #[[$c42:.*]] = affine_map<() -> (42)>
+// CHECK-DAG: #[[$id:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK-DAG: #[[$add:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$div32div4timesm4:.*]] = affine_map<()[s0] -> (((s0 floordiv 32) floordiv 4) * -4)>
+// CHECK-DAG: #[[$div32:.*]] = affine_map<()[s0] -> (s0 floordiv 32)>
+
+// CHECK-LABEL: func.func @simple_test_2
+// CHECK-SAME: %[[I0:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[I1:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[I2:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[LB:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[UB:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[STEP:[0-9a-zA-Z]+]]: index
+func.func @simple_test_2(%0: index, %1: index, %2: index, %lb: index, %ub: index, %step: index) {
+  // CHECK: %[[c42:.*]] = affine.apply #[[$c42]]()
+  // CHECK: scf.for %[[i:.*]] =
+  scf.for %i = %lb to %ub step %step {
+    // CHECK: %[[R1:.*]] = affine.apply #[[$id]]()[%[[i]]]
+    // CHECK: %[[R2:.*]] = affine.apply #[[$add]]()[%[[c42]], %[[R1]]]
+    // CHECK: scf.for %[[j:.*]] =
+    scf.for %j = %lb to %ub step %step {
+      // CHECK: %[[R3:.*]] = affine.apply #[[$div32div4timesm4]]()[%[[j]]]
+      // CHECK: %[[R4:.*]] = affine.apply #[[$add]]()[%[[R2]], %[[R3]]]
+      // CHECK: %[[R5:.*]] = affine.apply #[[$div32]]()[%[[j]]]
+      // CHECK: %[[a:.*]] = affine.apply #[[$add]]()[%[[R4]], %[[R5]]]
+      %a = affine.apply affine_map<(d0)[s0] -> ((d0 floordiv 32) mod 4 + s0 + 42)>(%j)[%i]
+
+      // CHECK: "some_side_effecting_consumer"(%[[a]]) : (index) -> ()
+      "some_side_effecting_consumer"(%a) : (index) -> ()
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-DAG: #[[$div4:.*]] = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: #[[$times32:.*]] = affine_map<()[s0] -> (s0 * 32)>
+// CHECK-DAG: #[[$times16:.*]] = affine_map<()[s0] -> (s0 * 16)>
+// CHECK-DAG: #[[$add:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$div4timesm32:.*]] = affine_map<()[s0] -> ((s0 floordiv 4) * -32)>
+// CHECK-DAG: #[[$times8:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK-DAG: #[[$id:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK-DAG: #[[$div32div4timesm4:.*]] = affine_map<()[s0] -> (((s0 floordiv 32) floordiv 4) * -4)>
+// CHECK-DAG: #[[$div32:.*]] = affine_map<()[s0] -> (s0 floordiv 32)>
+
+// CHECK-LABEL: func.func @larger_test
+// CHECK-SAME: %[[I0:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[I1:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[I2:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[LB:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[UB:[0-9a-zA-Z]+]]: index,
+// CHECK-SAME: %[[STEP:[0-9a-zA-Z]+]]: index
+func.func @larger_test(%0: index, %1: index, %2: index, %lb: index, %ub: index, %step: index) {
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
+
+  // CHECK: %[[R0:.*]] = affine.apply #[[$div4]]()[%[[I0]]]
+  // CHECK-NEXT: %[[R1:.*]] = affine.apply #[[$times16]]()[%[[I1]]]
+  // CHECK-NEXT: %[[R2:.*]] = affine.apply #[[$add]]()[%[[R0]], %[[R1]]]
+  // CHECK-NEXT: %[[R3:.*]] = affine.apply #[[$times32]]()[%[[I2]]]
+
+  // I1 * 16 + I2 * 32 + I0 floordiv 4
+  // CHECK-NEXT: %[[b:.*]] = affine.apply #[[$add]]()[%[[R2]], %[[R3]]]
+
+  // -(I0 floordiv 4) * 32
+  // CHECK-NEXT: %[[R5:.*]] = affine.apply #[[$div4timesm32]]()[%[[I0]]]
+  // 8 * I0
+  // CHECK-NEXT: %[[R6:.*]] = affine.apply #[[$times8]]()[%[[I0]]]
+  // 8 * I0 - (I0 floordiv 4) * 32
+  // CHECK-NEXT: %[[c:.*]] = affine.apply #[[$add]]()[%[[R5]], %[[R6]]]
+
+  // CHECK-NEXT: scf.for %[[i:.*]] =
+  scf.for %i = %lb to %ub step %step {
+    // remainder from %a not hoisted above %i.
+    // CHECK-NEXT: %[[R8:.*]] = affine.apply #[[$times32]]()[%[[i]]]
+    // CHECK-NEXT: %[[a:.*]] = affine.apply #[[$add]]()[%[[b]], %[[R8]]]
+
+    // CHECK-NEXT: scf.for %[[j:.*]] =
+    scf.for %j = %lb to %ub step %step {
+      // Gets hoisted partially to i and rest outermost.
+      // The hoisted part is %b.
+      %a = affine.apply affine_map<()[s0, s1, s2, s3] -> (s1 * 16 + s2 * 32 + s3 * 32 + s0 floordiv 4)>()[%0, %1, %2, %i]
+
+      // Gets completely hoisted
+      %b = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + s0 floordiv 4)>()[%0, %1, %2]
+
+      // Gets completely hoisted
+      %c = affine.apply affine_map<()[s0] -> (s0 * 8 - (s0 floordiv 4) * 32)>()[%0]
+
+      // 32 * %j + %c remains here, the rest is hoisted.
+      // CHECK-DAG: %[[R10:.*]] = affine.apply #[[$times32]]()[%[[j]]]
+      // CHECK-DAG: %[[d:.*]] = affine.apply #[[$add]]()[%[[c]], %[[R10]]]
+      %d = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1 * 32 - (s0 floordiv 4) * 32)>()[%0, %j]
+
+      // CHECK-DAG: %[[idj:.*]] = affine.apply #[[$id]]()[%[[j]]]
+      // CHECK-NEXT: scf.for %[[k:.*]] =
+      scf.for %k = %lb to %ub step %step {
+        // CHECK-NEXT: %[[idk:.*]] = affine.apply #[[$id]]()[%[[k]]]
+        // CHECK-NEXT: %[[e:.*]] = affine.apply #[[$add]]()[%[[c]], %[[idk]]]
+        %e = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 8 - (s1 floordiv 4) * 32)>()[%k, %0]
+
+        // CHECK-NEXT: %[[R15:.*]] = affine.apply #[[$div32div4timesm4]]()[%[[k]]]
+        // CHECK-NEXT: %[[R16:.*]] = affine.apply #[[$add]]()[%[[idj]], %[[R15]]]
+        // CHECK-NEXT: %[[R17:.*]] = affine.apply #[[$div32]]()[%[[k]]]
+        // CHECK-NEXT: %[[f:.*]] = affine.apply #[[$add]]()[%[[R16]], %[[R17]]]
+        %f = affine.apply affine_map<(d0)[s0] -> ((d0 floordiv 32) mod 4 + s0)>(%k)[%j]
+
+        // CHECK-NEXT: %[[g:.*]] = affine.apply #[[$add]]()[%[[b]], %[[idk]]]
+        %g = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s2 * 16 + s3 * 32 + s1 floordiv 4)>()[%k, %0, %1, %2]
+
+        // CHECK-NEXT: "some_side_effecting_consumer"(%[[a]]) : (index) -> ()
+        "some_side_effecting_consumer"(%a) : (index) -> ()
+        // CHECK-NEXT: "some_side_effecting_consumer"(%[[b]]) : (index) -> ()
+        "some_side_effecting_consumer"(%b) : (index) -> ()
+        // CHECK-NEXT: "some_side_effecting_consumer"(%[[c]]) : (index) -> ()
+        "some_side_effecting_consumer"(%c) : (index) -> ()
+        // CHECK-NEXT: "some_side_effecting_consumer"(%[[d]]) : (index) -> ()
+        "some_side_effecting_consumer"(%d) : (index) -> ()
+        // CHECK-NEXT: "some_side_effecting_consumer"(%[[e]]) : (index) -> ()
+        "some_side_effecting_consumer"(%e) : (index) -> ()
+        // CHECK-NEXT: "some_side_effecting_consumer"(%[[f]]) : (index) -> ()
+        "some_side_effecting_consumer"(%f) : (index) -> ()
+        // CHECK-NEXT: "some_side_effecting_consumer"(%[[g]]) : (index) -> ()
+        "some_side_effecting_consumer"(%g) : (index) -> ()
+      }
+    }
+  }
+  return
+}
diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
@@ -3,6 +3,7 @@
   TestAffineDataCopy.cpp
   TestAffineLoopUnswitching.cpp
   TestAffineLoopParametricTiling.cpp
+  TestDecomposeAffineOps.cpp
   TestLoopFusion.cpp
   TestLoopMapping.cpp
   TestLoopPermutation.cpp
diff --git a/mlir/test/lib/Dialect/Affine/TestDecomposeAffineOps.cpp b/mlir/test/lib/Dialect/Affine/TestDecomposeAffineOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Affine/TestDecomposeAffineOps.cpp
@@ -0,0 +1,57 @@
+//===- TestDecomposeAffineOps.cpp - Test affine ops decomposition utility -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the affine ops decomposition utility
+// functions (reorderOperandsByHoistability and decompose).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+#define PASS_NAME "test-decompose-affine-ops"
+
+using namespace mlir;
+
+namespace {
+
+struct TestDecomposeAffineOps
+    : PassWrapper<TestDecomposeAffineOps, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDecomposeAffineOps)
+
+  TestDecomposeAffineOps() = default;
+  TestDecomposeAffineOps(const TestDecomposeAffineOps &) = default;
+
+  StringRef getArgument() const final { return PASS_NAME; }
+  StringRef getDescription() const final {
+    return "Tests affine ops decomposition utility functions.";
+  }
+  void runOnOperation() override;
+};
+
+} // namespace
+
+// For each affine.apply: reorder operands by hoistability, then decompose.
+void TestDecomposeAffineOps::runOnOperation() {
+  IRRewriter rewriter(&getContext());
+  getOperation().walk([&](AffineApplyOp applyOp) {
+    rewriter.setInsertionPoint(applyOp);
+    reorderOperandsByHoistability(rewriter, applyOp);
+    (void)decompose(rewriter, applyOp);
+  });
+}
+
+namespace mlir {
+void registerTestDecomposeAffineOpPass() {
+  PassRegistration<TestDecomposeAffineOps>();
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -40,6 +40,7 @@
 void registerSymbolTestPasses();
 void registerRegionTestPasses();
 void registerTestAffineDataCopyPass();
+void registerTestDecomposeAffineOpPass();
 void registerTestAffineLoopUnswitchingPass();
 void registerTestAllReduceLoweringPass();
 void registerTestFunc();
@@ -148,6 +149,7 @@
   registerSymbolTestPasses();
   registerRegionTestPasses();
   registerTestAffineDataCopyPass();
+  registerTestDecomposeAffineOpPass();
   registerTestAffineLoopUnswitchingPass();
   registerTestAllReduceLoweringPass();
   registerTestFunc();