diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -129,6 +129,10 @@
 /// (identity) layout map.
 std::unique_ptr<OperationPass<FuncOp>> createNormalizeMemRefsPass();
 
+/// Creates a pass which converts a program into maximal SSA form. In this form,
+/// any value referenced within a block is also defined within the same block.
+std::unique_ptr<Pass> createMaxSSAFormPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -629,4 +629,33 @@
   let constructor = "mlir::createPrintOpGraphPass()";
 }
 
+def MaxSSAForm : FunctionPass<"max-ssa"> {
+  let summary = "Converts a program into maximal SSA form";
+  let description = [{
+    This pass converts a program into maximal SSA form. In this form, any value
+    referenced within a block is also defined within the block. This is done by
+    adding block arguments to all basic block dominance chains which may lead
+    to an operation that relied on referencing a Value based on basic block
+    dominance.
+
+    This pass is useful in dataflow-style programming models since it renders
+    all data flow within the program explicit (through block arguments) instead
+    of implicit (through block dominance).
+
+    This pass only works on Standard-level IR, in that it expects all operations
+    (and blocks) within a FuncOp to be within the same region. Furthermore, it
+    is assumed that any value referenced by any operation is eligible to be
+    passed around as a block argument.
+  }];
+  let constructor = "mlir::createMaxSSAFormPass()";
+  let options = [
+    ListOption<"ignoredDialects", "ignore-dialects", "std::string",
+               "List of ignored dialects. If a value's type is defined by an ignored "
+               "dialect, the value will be ignored during SSA maximization.",
+               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+    Option<"ignoreMemref", "ignore-memref", "bool", "false",
+           "Ignore memref values in SSA maximization.">
+  ];
+}
+
 #endif // MLIR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@
   LoopCoalescing.cpp
   LoopFusion.cpp
   LoopInvariantCodeMotion.cpp
+  MaxSSA.cpp
   NormalizeMemRefs.cpp
   OpStats.cpp
   ParallelLoopCollapsing.cpp
diff --git a/mlir/lib/Transforms/MaxSSA.cpp b/mlir/lib/Transforms/MaxSSA.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/MaxSSA.cpp
@@ -0,0 +1,275 @@
+//===- MaxSSA.cpp - Maximal SSA form conversion ------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Contains the definitions of the maximal SSA form conversion pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <set>
+
+using namespace mlir;
+
+/// Rewrites the terminator in 'from' to pass an additional argument 'v' when
+/// passing control flow to 'to'. If 'to' occurs more than once among the
+/// successors of the terminator, 'v' is appended on every such edge so that
+/// all edges agree with the extended argument list of 'to'. Fails if the
+/// terminator is not a branch-like op or if the operands of an edge cannot be
+/// modified.
+static LogicalResult rewriteControlFlowToBlock(Block *from, Block *to,
+                                               Value v) {
+  Operation *termOp = from->getTerminator();
+  auto branchOp = dyn_cast<BranchOpInterface>(termOp);
+  if (!branchOp)
+    return termOp->emitOpError() << "expected terminator op within control "
+                                    "flow to be a branch-like op";
+
+  // Visit every edge from 'from' to 'to', not just the first one.
+  bool foundSuccessor = false;
+  for (unsigned succIdx = 0, e = termOp->getNumSuccessors(); succIdx != e;
+       ++succIdx) {
+    if (termOp->getSuccessor(succIdx) != to)
+      continue;
+    foundSuccessor = true;
+
+    // The operand range is returned as an optional; do not dereference blindly.
+    auto succOperands = branchOp.getMutableSuccessorOperands(succIdx);
+    if (!succOperands)
+      return termOp->emitOpError()
+             << "successor operands cannot be extended with a new value";
+
+    // Add the value 'v' as a block argument.
+    succOperands->append(v);
+  }
+  assert(foundSuccessor && "'to' must be a successor of 'from'");
+  (void)foundSuccessor;
+  return success();
+}
+
+/// Rewrites uses of 'oldV' in 'b' to 'newV'.
+static void rewriteUsageInBlock(Block *b, Value oldV, Value newV) {
+  for (auto &use : llvm::make_early_inc_range(oldV.getUses()))
+    if (use.getOwner()->getBlock() == b)
+      use.set(newV);
+}
+
+/// Perform a depth-first search backwards through the CFG of a program,
+/// starting from 'use'. Every visited block which does not define 'v' gets a
+/// new block argument of type(v), and all uses of 'v' within that block are
+/// replaced with the new block argument.
+///
+/// Arguments:
+/// 'use': a block where a value 'v' flows through.
+/// 'succ': a successor block of the 'use' block. Notably this conversion
+/// backtracks through the CFG, so 'succ' will be a basic block that called
+/// backtrackAndConvert on 'use'. Null for the block the search starts from.
+/// 'inBlockValues': A mapping containing the appended block argument when
+/// backtracking through a basic block.
+/// 'convertedControlFlow': A mapping containing, for a given block (key),
+/// the successors which already receive the value through the terminator of
+/// that block.
+static LogicalResult
+backtrackAndConvert(Block *use, Block *succ, Value v,
+                    DenseMap<Value, DenseMap<Block *, Value>> &inBlockValues,
+                    DenseMap<Value, DenseMap<Block *, std::set<Block *>>>
+                        &convertedControlFlow) {
+  // The base case is when we've backtracked to the Block which defines the
+  // value. In these cases, set the actual value as the converted value.
+  if (v.getParentBlock() == use)
+    inBlockValues[v][use] = v;
+
+  if (inBlockValues[v].count(use) == 0) {
+    // Rewrite this block's arguments to take in a new value of 'v' type.
+    use->addArgument(v.getType());
+    Value newBarg = use->getArguments().back();
+
+    // Register the converted block argument in case other branches in the CFG
+    // arrive here later.
+    inBlockValues[v][use] = newBarg;
+    rewriteUsageInBlock(use, v, newBarg);
+
+    // Recurse through the predecessors of this block.
+    for (Block *pred : use->getPredecessors())
+      if (failed(backtrackAndConvert(pred, use, v, inBlockValues,
+                                     convertedControlFlow)))
+        return failure();
+  }
+
+  // Rewrite control flow to the 'succ' block through the terminator, if not
+  // already done.
+  if (succ && convertedControlFlow[v][use].count(succ) == 0) {
+    // Look the value up only now: the recursion above may have grown the map.
+    auto inBlockValue = inBlockValues[v].find(use);
+    assert(inBlockValue != inBlockValues[v].end());
+    if (failed(rewriteControlFlowToBlock(use, succ, inBlockValue->second)))
+      return failure();
+    convertedControlFlow[v][use].insert(succ);
+  }
+
+  return success();
+}
+
+namespace {
+
+struct MaxSSAFormPass : public MaxSSAFormBase<MaxSSAFormPass> {
+public:
+  void runOnFunction() override {
+    FuncOp function = getOperation();
+
+    // The bookkeeping below is keyed by values of the function being
+    // converted; drop whatever a previously processed function left behind.
+    inBlockValues.clear();
+    convertedControlFlow.clear();
+
+    // Incremental conversion. Every visit converts at most one user block per
+    // value (see runOnValue) and the block arguments are revisited for every
+    // operation. The order in which block arguments get appended, which the
+    // regression tests check, follows from this schedule; it is kept as is.
+    WalkResult incremental = function.walk([&](Operation *op) {
+      // Run on operation results.
+      if (llvm::any_of(op->getResults(),
+                       [&](Value res) { return failed(runOnValue(res)); }))
+        return WalkResult::interrupt();
+      // Run on block arguments.
+      for (auto &block : function)
+        if (llvm::any_of(block.getArguments(), [&](Value barg) {
+              return failed(runOnValue(barg));
+            }))
+          return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    if (incremental.wasInterrupted())
+      return signalPassFailure();
+
+    // Completion. The walk above may leave users unconverted, e.g. when an
+    // operation result is used in two sibling blocks; convert whatever is
+    // left. New block arguments are only used locally, so one sweep suffices.
+    auto finishFails = [&](Value v) { return failed(runOnAllUsers(v)); };
+    WalkResult completion = function.walk([&](Operation *op) {
+      return llvm::any_of(op->getResults(), finishFails)
+                 ? WalkResult::interrupt()
+                 : WalkResult::advance();
+    });
+    if (completion.wasInterrupted())
+      return signalPassFailure();
+    for (auto &block : function)
+      if (llvm::any_of(block.getArguments(), finishFails))
+        return signalPassFailure();
+
+    // Always check the result, also in builds without assertions.
+    if (failed(verifyFunction(function)))
+      signalPassFailure();
+  }
+
+private:
+  /// Returns true if this value is ignored in SSA maximisation.
+  bool isIgnored(Value v) const;
+
+  /// Verifies that all values indeed are only referenced within their defining
+  /// block.
+  LogicalResult verifyFunction(FuncOp f) const;
+
+  /// Passes 'v' on to the block of 'user' through block arguments. Returns
+  /// failure if 'user' is nested in another region than 'v', or if non
+  /// branch-like control flow is used.
+  LogicalResult convertUser(Value v, Operation *user);
+
+  /// Incremental driver: runs convertUser on the first user (in use-list
+  /// order) of 'v' that lives outside the defining block of 'v', if any.
+  LogicalResult runOnValue(Value v);
+
+  /// Completion driver: runs convertUser on every user of 'v' that lives
+  /// outside the defining block of 'v'.
+  LogicalResult runOnAllUsers(Value v);
+
+  /// A mapping {original value : {block : replaced value}} representing
+  /// 'original value' has been replaced in 'block' with 'replaced value'.
+  DenseMap<Value, DenseMap<Block *, Value>> inBlockValues;
+
+  /// A mapping {original value : {block : succ block}} representing
+  /// 'original value' has already been passed from 'block' to 'succ block'
+  /// through the terminator of 'block'.
+  DenseMap<Value, DenseMap<Block *, std::set<Block *>>> convertedControlFlow;
+};
+
+/// Returns true if this value is ignored in SSA maximisation.
+bool MaxSSAFormPass::isIgnored(Value v) const {
+  Type t = v.getType();
+
+  return llvm::TypeSwitch<Type, bool>(t)
+      .Case<MemRefType>([&](auto) { return static_cast<bool>(ignoreMemref); })
+      .Default([&](auto) {
+        return llvm::find(ignoredDialects, t.getDialect().getNamespace()) !=
+               ignoredDialects.end();
+      });
+}
+
+LogicalResult MaxSSAFormPass::verifyFunction(FuncOp f) const {
+  for (auto &op : f.getOps()) {
+    auto isValid = [&](Value v) {
+      return isIgnored(v) || v.getParentBlock() == op.getBlock();
+    };
+
+    if (!llvm::all_of(op.getOperands(), isValid))
+      return op.emitOpError()
+             << "has operands that are not defined within its block";
+  }
+  return success();
+}
+
+LogicalResult MaxSSAFormPass::convertUser(Value v, Operation *user) {
+  // This is a case of using an SSA value through basic block dominance.
+  Block *userBlock = user->getBlock();
+  if (userBlock->getParent() != v.getParentBlock()->getParent())
+    return user->emitOpError() << "can only convert SSA usage across "
+                                  "blocks in the same region.";
+
+  return backtrackAndConvert(userBlock, /*succ=*/nullptr, v, inBlockValues,
+                             convertedControlFlow);
+}
+
+LogicalResult MaxSSAFormPass::runOnValue(Value v) {
+  if (isIgnored(v))
+    return success();
+
+  // Converting a user block rewrites uses of 'v', i.e. it edits the very list
+  // which is searched here. Hence only the first match is handled per call.
+  Block *definingBlock = v.getParentBlock();
+  auto users = v.getUsers();
+  auto userIt = llvm::find_if(users, [&](Operation *user) {
+    return user->getBlock() != definingBlock;
+  });
+  if (userIt == users.end())
+    return success();
+  return convertUser(v, *userIt);
+}
+
+LogicalResult MaxSSAFormPass::runOnAllUsers(Value v) {
+  if (isIgnored(v))
+    return success();
+
+  // Iterate over a snapshot: conversion rewrites the use list of 'v'.
+  Block *definingBlock = v.getParentBlock();
+  SmallVector<Operation *, 8> users(v.user_begin(), v.user_end());
+  for (Operation *user : users)
+    if (user->getBlock() != definingBlock && failed(convertUser(v, user)))
+      return failure();
+  return success();
+}
+
+} // namespace
+
+namespace mlir {
+std::unique_ptr<Pass> createMaxSSAFormPass() {
+  return std::make_unique<MaxSSAFormPass>();
+}
+} // namespace mlir
diff --git a/mlir/test/Transforms/max-ssa-errors.mlir b/mlir/test/Transforms/max-ssa-errors.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/max-ssa-errors.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt -split-input-file --max-ssa %s -verify-diagnostics
+
+func @scf(%i : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0_i32 = arith.constant 0 : i32
+  %m = memref.alloc() : memref<4xi32>
+  scf.for %0 = %c0 to %c4 step %c1 {
+    // expected-error @+1 {{'memref.store' op can only convert SSA usage across blocks in the same region.}}
+    memref.store %c0_i32, %m[%i] : memref<4xi32>
+  }
+  return
+}
diff --git a/mlir/test/Transforms/max-ssa-ignore.mlir b/mlir/test/Transforms/max-ssa-ignore.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/max-ssa-ignore.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt -split-input-file --max-ssa="ignore-memref" %s | FileCheck %s
+
+// CHECK-LABEL: func @memory_loop() -> i32 {
+// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 5 : i32
+// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<64xi32>
+// CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<64xi32>
+// CHECK: br ^bb1(%[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]] : index, index, index, index, i32)
+// CHECK: ^bb1(%[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index, 
%[[VAL_10:.*]]: i32): +// CHECK: %[[VAL_11:.*]] = arith.cmpi slt, %[[VAL_6]], %[[VAL_9]] : index +// CHECK: cond_br %[[VAL_11]], ^bb2(%[[VAL_7]], %[[VAL_6]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] : index, index, index, index, i32), ^bb3(%[[VAL_7]], %[[VAL_8]], %[[VAL_9]] : index, index, index) +// CHECK: ^bb2(%[[VAL_12:.*]]: index, %[[VAL_13:.*]]: index, %[[VAL_14:.*]]: index, %[[VAL_15:.*]]: index, %[[VAL_16:.*]]: i32): +// CHECK: memref.store %[[VAL_16]], %[[VAL_4]]{{\[}}%[[VAL_13]]] : memref<64xi32> +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index +// CHECK: br ^bb1(%[[VAL_17]], %[[VAL_12]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] : index, index, index, index, i32) +// CHECK: ^bb3(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index): +// CHECK: br ^bb4(%[[VAL_18]], %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] : index, index, index, index) +// CHECK: ^bb4(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index): +// CHECK: %[[VAL_25:.*]] = arith.cmpi slt, %[[VAL_21]], %[[VAL_24]] : index +// CHECK: cond_br %[[VAL_25]], ^bb5(%[[VAL_22]], %[[VAL_21]], %[[VAL_23]], %[[VAL_24]] : index, index, index, index), ^bb6(%[[VAL_22]] : index) +// CHECK: ^bb5(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index, %[[VAL_28:.*]]: index, %[[VAL_29:.*]]: index): +// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_27]]] : memref<64xi32> +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_30]] : i32 +// CHECK: memref.store %[[VAL_31]], %[[VAL_5]]{{\[}}%[[VAL_27]]] : memref<64xi32> +// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_27]], %[[VAL_28]] : index +// CHECK: br ^bb4(%[[VAL_32]], %[[VAL_26]], %[[VAL_28]], %[[VAL_29]] : index, index, index, index) +// CHECK: ^bb6(%[[VAL_33:.*]]: index): +// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_33]]] : memref<64xi32> +// CHECK: return %[[VAL_34]] : i32 +// CHECK: } +func @memory_loop() -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : 
index + %c4 = arith.constant 4 : index + %c5_i32 = arith.constant 5 : i32 + %0 = memref.alloc() : memref<64xi32> + %1 = memref.alloc() : memref<64xi32> + br ^bb1(%c0 : index) +^bb1(%2: index): + %3 = arith.cmpi slt, %2, %c4 : index + cond_br %3, ^bb2, ^bb3 +^bb2: + memref.store %c5_i32, %0[%2] : memref<64xi32> + %4 = arith.addi %2, %c1 : index + br ^bb1(%4 : index) +^bb3: + br ^bb4(%c0 : index) +^bb4(%5: index): + %6 = arith.cmpi slt, %5, %c4 : index + cond_br %6, ^bb5, ^bb6 +^bb5: + %7 = memref.load %0[%5] : memref<64xi32> + %8 = arith.addi %7, %7 : i32 + memref.store %8, %1[%5] : memref<64xi32> + %9 = arith.addi %5, %c1 : index + br ^bb4(%9 : index) +^bb6: + %10 = memref.load %1[%c0] : memref<64xi32> + return %10 : i32 +} diff --git a/mlir/test/Transforms/max-ssa.mlir b/mlir/test/Transforms/max-ssa.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/max-ssa.mlir @@ -0,0 +1,196 @@ +// RUN: mlir-opt -split-input-file --max-ssa %s | FileCheck %s + +// CHECK-LABEL: func @simple( +// CHECK-SAME: %[[VAL_0:.*]]: index, +// CHECK-SAME: %[[VAL_1:.*]]: index) -> (index, index) { +// CHECK: br ^bb1(%[[VAL_0]], %[[VAL_1]] : index, index) +// CHECK: ^bb1(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index): +// CHECK: return %[[VAL_2]], %[[VAL_3]] : index, index +// CHECK: } +func @simple(%arg0 : index, %arg1 : index) -> (index, index) { + br ^bb1(%arg0 : index) +^bb1(%0 : index): + return %0, %arg1 : index, index +} + +// ----- + +// CHECK-LABEL: func @backedge() -> index { +// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_1:.*]] = arith.constant 42 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: br ^bb1(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : index, index, index) +// CHECK: ^bb1(%[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index): +// CHECK: %[[VAL_6:.*]] = arith.cmpi slt, %[[VAL_3]], %[[VAL_4]] : index +// CHECK: cond_br %[[VAL_6]], ^bb2(%[[VAL_4]], %[[VAL_3]], %[[VAL_5]] : index, index, index), ^bb3(%[[VAL_3]] : 
index) +// CHECK: ^bb2(%[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index): +// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_8]], %[[VAL_9]] : index +// CHECK: br ^bb1(%[[VAL_10]], %[[VAL_7]], %[[VAL_9]] : index, index, index) +// CHECK: ^bb3(%[[VAL_11:.*]]: index): +// CHECK: return %[[VAL_11]] : index +// CHECK: } +func @backedge() -> index { + %c1 = arith.constant 1 : index + %c42 = arith.constant 42 : index + %c1_0 = arith.constant 1 : index + br ^bb1(%c1 : index) +^bb1(%0: index): + %1 = arith.cmpi slt, %0, %c42 : index + cond_br %1, ^bb2, ^bb3 +^bb2: + %2 = arith.addi %0, %c1_0 : index + br ^bb1(%2 : index) +^bb3: + return %0 : index +} + +// ----- + +// CHECK-LABEL: func @memory_loop() -> i32 { +// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 5 : i32 +// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<64xi32> +// CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<64xi32> +// CHECK: br ^bb1(%[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : index, index, index, index, i32, memref<64xi32>, memref<64xi32>) +// CHECK: ^bb1(%[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index, %[[VAL_10:.*]]: i32, %[[VAL_11:.*]]: memref<64xi32>, %[[VAL_12:.*]]: memref<64xi32>): +// CHECK: %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_6]], %[[VAL_9]] : index +// CHECK: cond_br %[[VAL_13]], ^bb2(%[[VAL_7]], %[[VAL_6]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] : index, index, index, index, i32, memref<64xi32>, memref<64xi32>), ^bb3(%[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_11]], %[[VAL_12]] : index, index, index, memref<64xi32>, memref<64xi32>) +// CHECK: ^bb2(%[[VAL_14:.*]]: index, %[[VAL_15:.*]]: index, %[[VAL_16:.*]]: index, %[[VAL_17:.*]]: index, %[[VAL_18:.*]]: i32, %[[VAL_19:.*]]: memref<64xi32>, %[[VAL_20:.*]]: memref<64xi32>): +// CHECK: 
memref.store %[[VAL_18]], %[[VAL_19]]{{\[}}%[[VAL_15]]] : memref<64xi32> +// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : index +// CHECK: br ^bb1(%[[VAL_21]], %[[VAL_14]], %[[VAL_16]], %[[VAL_17]], %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] : index, index, index, index, i32, memref<64xi32>, memref<64xi32>) +// CHECK: ^bb3(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index, %[[VAL_25:.*]]: memref<64xi32>, %[[VAL_26:.*]]: memref<64xi32>): +// CHECK: br ^bb4(%[[VAL_22]], %[[VAL_22]], %[[VAL_23]], %[[VAL_24]], %[[VAL_25]], %[[VAL_26]] : index, index, index, index, memref<64xi32>, memref<64xi32>) +// CHECK: ^bb4(%[[VAL_27:.*]]: index, %[[VAL_28:.*]]: index, %[[VAL_29:.*]]: index, %[[VAL_30:.*]]: index, %[[VAL_31:.*]]: memref<64xi32>, %[[VAL_32:.*]]: memref<64xi32>): +// CHECK: %[[VAL_33:.*]] = arith.cmpi slt, %[[VAL_27]], %[[VAL_30]] : index +// CHECK: cond_br %[[VAL_33]], ^bb5(%[[VAL_28]], %[[VAL_27]], %[[VAL_29]], %[[VAL_30]], %[[VAL_31]], %[[VAL_32]] : index, index, index, index, memref<64xi32>, memref<64xi32>), ^bb6(%[[VAL_28]], %[[VAL_32]] : index, memref<64xi32>) +// CHECK: ^bb5(%[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index, %[[VAL_36:.*]]: index, %[[VAL_37:.*]]: index, %[[VAL_38:.*]]: memref<64xi32>, %[[VAL_39:.*]]: memref<64xi32>): +// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_38]]{{\[}}%[[VAL_35]]] : memref<64xi32> +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_40]], %[[VAL_40]] : i32 +// CHECK: memref.store %[[VAL_41]], %[[VAL_39]]{{\[}}%[[VAL_35]]] : memref<64xi32> +// CHECK: %[[VAL_42:.*]] = arith.addi %[[VAL_35]], %[[VAL_36]] : index +// CHECK: br ^bb4(%[[VAL_42]], %[[VAL_34]], %[[VAL_36]], %[[VAL_37]], %[[VAL_38]], %[[VAL_39]] : index, index, index, index, memref<64xi32>, memref<64xi32>) +// CHECK: ^bb6(%[[VAL_43:.*]]: index, %[[VAL_44:.*]]: memref<64xi32>): +// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_44]]{{\[}}%[[VAL_43]]] : memref<64xi32> +// CHECK: return %[[VAL_45]] : i32 +// CHECK: } +func @memory_loop() -> i32 { + %c0 
= arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c5_i32 = arith.constant 5 : i32 + %0 = memref.alloc() : memref<64xi32> + %1 = memref.alloc() : memref<64xi32> + br ^bb1(%c0 : index) +^bb1(%2: index): + %3 = arith.cmpi slt, %2, %c4 : index + cond_br %3, ^bb2, ^bb3 +^bb2: + memref.store %c5_i32, %0[%2] : memref<64xi32> + %4 = arith.addi %2, %c1 : index + br ^bb1(%4 : index) +^bb3: + br ^bb4(%c0 : index) +^bb4(%5: index): + %6 = arith.cmpi slt, %5, %c4 : index + cond_br %6, ^bb5, ^bb6 +^bb5: + %7 = memref.load %0[%5] : memref<64xi32> + %8 = arith.addi %7, %7 : i32 + memref.store %8, %1[%5] : memref<64xi32> + %9 = arith.addi %5, %c1 : index + br ^bb4(%9 : index) +^bb6: + %10 = memref.load %1[%c0] : memref<64xi32> + return %10 : i32 +} + +// ----- + +// CHECK-LABEL: func @matrix_power( +// CHECK-SAME: %[[VAL_0:.*]]: memref<400xi32>, %[[VAL_1:.*]]: memref<20xi32>, %[[VAL_2:.*]]: memref<20xi32>, %[[VAL_3:.*]]: memref<20xi32>) { +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_7:.*]] = arith.constant 20 : index +// CHECK: br ^bb1(%[[VAL_5]], %[[VAL_4]], %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : index, i32, memref<400xi32>, memref<20xi32>, memref<20xi32>, memref<20xi32>, index, index, index) +// CHECK: ^bb1(%[[VAL_8:.*]]: index, %[[VAL_9:.*]]: i32, %[[VAL_10:.*]]: memref<400xi32>, %[[VAL_11:.*]]: memref<20xi32>, %[[VAL_12:.*]]: memref<20xi32>, %[[VAL_13:.*]]: memref<20xi32>, %[[VAL_14:.*]]: index, %[[VAL_15:.*]]: index, %[[VAL_16:.*]]: index): +// CHECK: %[[VAL_17:.*]] = arith.cmpi slt, %[[VAL_8]], %[[VAL_16]] : index +// CHECK: cond_br %[[VAL_17]], ^bb2(%[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] : index, i32, memref<400xi32>, memref<20xi32>, memref<20xi32>, memref<20xi32>, index, index, index), 
^bb5 +// CHECK: ^bb2(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: i32, %[[VAL_20:.*]]: memref<400xi32>, %[[VAL_21:.*]]: memref<20xi32>, %[[VAL_22:.*]]: memref<20xi32>, %[[VAL_23:.*]]: memref<20xi32>, %[[VAL_24:.*]]: index, %[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index): +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_18]], %[[VAL_24]] : index +// CHECK: %[[VAL_28:.*]] = arith.subi %[[VAL_18]], %[[VAL_24]] : index +// CHECK: %[[VAL_29:.*]] = arith.index_cast %[[VAL_28]] : index to i32 +// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_19]] : i32 +// CHECK: %[[VAL_31:.*]] = arith.index_cast %[[VAL_30]] : i32 to index +// CHECK: br ^bb3(%[[VAL_25]], %[[VAL_19]], %[[VAL_20]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_31]] : index, i32, memref<400xi32>, memref<20xi32>, memref<20xi32>, memref<20xi32>, index, index, index, index, index) +// CHECK: ^bb3(%[[VAL_32:.*]]: index, %[[VAL_33:.*]]: i32, %[[VAL_34:.*]]: memref<400xi32>, %[[VAL_35:.*]]: memref<20xi32>, %[[VAL_36:.*]]: memref<20xi32>, %[[VAL_37:.*]]: memref<20xi32>, %[[VAL_38:.*]]: index, %[[VAL_39:.*]]: index, %[[VAL_40:.*]]: index, %[[VAL_41:.*]]: index, %[[VAL_42:.*]]: index): +// CHECK: %[[VAL_43:.*]] = arith.cmpi slt, %[[VAL_32]], %[[VAL_40]] : index +// CHECK: cond_br %[[VAL_43]], ^bb4(%[[VAL_32]], %[[VAL_33]], %[[VAL_34]], %[[VAL_35]], %[[VAL_36]], %[[VAL_37]], %[[VAL_38]], %[[VAL_39]], %[[VAL_40]], %[[VAL_41]], %[[VAL_42]] : index, i32, memref<400xi32>, memref<20xi32>, memref<20xi32>, memref<20xi32>, index, index, index, index, index), ^bb1(%[[VAL_41]], %[[VAL_33]], %[[VAL_34]], %[[VAL_35]], %[[VAL_36]], %[[VAL_37]], %[[VAL_38]], %[[VAL_39]], %[[VAL_40]] : index, i32, memref<400xi32>, memref<20xi32>, memref<20xi32>, memref<20xi32>, index, index, index) +// CHECK: ^bb4(%[[VAL_44:.*]]: index, %[[VAL_45:.*]]: i32, %[[VAL_46:.*]]: memref<400xi32>, %[[VAL_47:.*]]: memref<20xi32>, %[[VAL_48:.*]]: memref<20xi32>, %[[VAL_49:.*]]: memref<20xi32>, %[[VAL_50:.*]]: 
index, %[[VAL_51:.*]]: index, %[[VAL_52:.*]]: index, %[[VAL_53:.*]]: index, %[[VAL_54:.*]]: index): +// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_44]], %[[VAL_50]] : index +// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_47]]{{\[}}%[[VAL_44]]] : memref<20xi32> +// CHECK: %[[VAL_57:.*]] = arith.index_cast %[[VAL_56]] : i32 to index +// CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_49]]{{\[}}%[[VAL_44]]] : memref<20xi32> +// CHECK: %[[VAL_59:.*]] = arith.subi %[[VAL_54]], %[[VAL_50]] : index +// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_48]]{{\[}}%[[VAL_44]]] : memref<20xi32> +// CHECK: %[[VAL_61:.*]] = arith.index_cast %[[VAL_60]] : i32 to index +// CHECK: %[[VAL_62:.*]] = arith.muli %[[VAL_61]], %[[VAL_52]] : index +// CHECK: %[[VAL_63:.*]] = arith.addi %[[VAL_59]], %[[VAL_62]] : index +// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_46]]{{\[}}%[[VAL_63]]] : memref<400xi32> +// CHECK: %[[VAL_65:.*]] = arith.muli %[[VAL_58]], %[[VAL_64]] : i32 +// CHECK: %[[VAL_66:.*]] = arith.muli %[[VAL_57]], %[[VAL_52]] : index +// CHECK: %[[VAL_67:.*]] = arith.addi %[[VAL_54]], %[[VAL_66]] : index +// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_46]]{{\[}}%[[VAL_67]]] : memref<400xi32> +// CHECK: %[[VAL_69:.*]] = arith.addi %[[VAL_68]], %[[VAL_65]] : i32 +// CHECK: %[[VAL_70:.*]] = arith.muli %[[VAL_57]], %[[VAL_52]] : index +// CHECK: %[[VAL_71:.*]] = arith.addi %[[VAL_54]], %[[VAL_70]] : index +// CHECK: memref.store %[[VAL_69]], %[[VAL_46]]{{\[}}%[[VAL_71]]] : memref<400xi32> +// CHECK: br ^bb3(%[[VAL_55]], %[[VAL_45]], %[[VAL_46]], %[[VAL_47]], %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_51]], %[[VAL_52]], %[[VAL_53]], %[[VAL_54]] : index, i32, memref<400xi32>, memref<20xi32>, memref<20xi32>, memref<20xi32>, index, index, index, index, index) +// CHECK: ^bb5: +// CHECK: return +// CHECK: } +func @matrix_power(%arg0: memref<400xi32>, %arg1: memref<20xi32>, %arg2: memref<20xi32>, %arg3: memref<20xi32>) { + %c1_i32 = arith.constant 1 : i32 + %c1 = arith.constant 1 : index + %c0 = 
arith.constant 0 : index + %c20 = arith.constant 20 : index + br ^bb1(%c1 : index) +^bb1(%0: index): // 2 preds: ^bb0, ^bb3 + %1 = arith.cmpi slt, %0, %c20 : index + cond_br %1, ^bb2(%0 : index), ^bb5 +^bb2(%2: index): // pred: ^bb1 + %3 = arith.addi %2, %c1 : index + %4 = arith.subi %2, %c1 : index + %5 = arith.index_cast %4 : index to i32 + %6 = arith.addi %5, %c1_i32 : i32 + %7 = arith.index_cast %6 : i32 to index + br ^bb3(%c0 : index) +^bb3(%8: index): // 2 preds: ^bb2, ^bb4 + %9 = arith.cmpi slt, %8, %c20 : index + cond_br %9, ^bb4(%8 : index), ^bb1(%3 : index) +^bb4(%10: index): // pred: ^bb3 + %11 = arith.addi %10, %c1 : index + %12 = memref.load %arg1[%10] : memref<20xi32> + %13 = arith.index_cast %12 : i32 to index + %14 = memref.load %arg3[%10] : memref<20xi32> + %15 = arith.subi %7, %c1 : index + %16 = memref.load %arg2[%10] : memref<20xi32> + %17 = arith.index_cast %16 : i32 to index + %18 = arith.muli %17, %c20 : index + %19 = arith.addi %15, %18 : index + %20 = memref.load %arg0[%19] : memref<400xi32> + %21 = arith.muli %14, %20 : i32 + %22 = arith.muli %13, %c20 : index + %23 = arith.addi %7, %22 : index + %24 = memref.load %arg0[%23] : memref<400xi32> + %25 = arith.addi %24, %21 : i32 + %26 = arith.muli %13, %c20 : index + %27 = arith.addi %7, %26 : index + memref.store %25, %arg0[%27] : memref<400xi32> + br ^bb3(%11 : index) +^bb5: // pred: ^bb1 + return +}