diff --git a/mlir/include/mlir/Transforms/MaxSSAUtils.h b/mlir/include/mlir/Transforms/MaxSSAUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Transforms/MaxSSAUtils.h @@ -0,0 +1,83 @@ +//===- MaxSSAUtils.h - Maximal SSA form utilities ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes for various utilities for converting a +// function using standard control flow into having maximal SSA form. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_MAX_SSA_UTILS_H +#define MLIR_TRANSFORMS_MAX_SSA_UTILS_H + +#include "mlir/IR/Block.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include + +namespace mlir { +class FuncOp; + +/// A filter callback function that can be used to determine if a given value +/// should be ignored during maximal SSA form conversion. If the function +/// returns true, usages of the value will remain unmodified. +using ValueFilterCallbackFn = llvm::function_ref; + +/// Converts a FuncOp into maximal SSA form. In this form, any value referenced +/// within a block is also defined within the block. This is done by adding +/// block arguments to all basic block dominance chains which may lead to an +/// operation that relied on referencing a Value based on basic block dominance.
+/// +/// This utility is useful in dataflow-style programming models since it renders +/// all data flow within the program explicit (through block arguments) instead +/// of implicit (through block dominance). +/// +/// This utility only works on standard control flow, in that it expects all +/// operations (and blocks) within a FuncOp to be nested immediately within the +/// same region. Furthermore, it is assumed that any value referenced by any +/// operation is eligible to be passed around as a block argument. +/// +/// An optional ValueFilterCallbackFn may be provided to filter out values from +/// being considered during SSA maximization. +/// +/// Returns failure, after emitting a diagnostic, if a value is referenced from +/// a different region than the one it is defined in, or if a terminator that +/// would have to forward a value is not a branch-like op. +/// +/// Example: +/// func @simple(%arg0 : index, %arg1 : index) -> (index, index, index) { +/// br ^bb1(%arg0 : index) +/// ^bb1(%0 : index): +/// return %0, %arg0, %arg1 : index, index, index +/// } +/// -> convertToMaximalSSA(@simple) +/// func @simple(%arg0 : index, %arg1 : index) -> (index, index, index) { +/// br ^bb1(%arg0, %arg0, %arg1 : index, index, index) +/// ^bb1(%0: index, %1: index, %2: index): +/// return %0, %1, %2 : index, index, index +/// } +LogicalResult convertToMaximalSSA(FuncOp func, + ValueFilterCallbackFn filterFn = nullptr); + +/// Like convertToMaximalSSA(FuncOp) but restricted to converting references to +/// a single 'value' into maximal SSA form. +/// 'value' is expected to be defined directly within the body of a FuncOp. All +/// prerequisites of convertToMaximalSSA(FuncOp) also apply to this variant.
+/// +/// Example: +/// func @simple(%arg0 : index, %arg1 : index) -> (index, index, index) { +/// br ^bb1(%arg0 : index) +/// ^bb1(%0 : index): +/// return %0, %arg0, %arg1 : index, index, index +/// } +/// -> convertToMaximalSSA(%arg1) +/// func @simple(%arg0 : index, %arg1 : index) -> (index, index, index) { +/// br ^bb1(%arg0, %arg1 : index, index) +/// ^bb1(%0: index, %1: index): +/// return %0, %arg0, %1 : index, index, index +/// } +LogicalResult convertToMaximalSSA(Value value); + +} // end namespace mlir + +#endif // MLIR_TRANSFORMS_MAX_SSA_UTILS_H diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -129,6 +129,10 @@ /// (identity) layout map. std::unique_ptr> createNormalizeMemRefsPass(); +/// Creates a pass which converts a program into maximal SSA form. In this form, +/// any value referenced within a block is also defined within the same block. +std::unique_ptr createMaxSSAFormPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -629,4 +629,25 @@ let constructor = "mlir::createPrintOpGraphPass()"; } +def MaxSSAForm : FunctionPass<"convert-to-max-ssa"> { + let summary = "Converts a program into maximal SSA form"; + let description = [{ + This pass converts a program into maximal SSA form. In this form, any value + referenced within a block is also defined within the block. This is done by + adding block arguments to all basic block dominance chains which may lead + to an operation that relied on referencing a Value based on basic block + dominance.
+ + This pass is useful in dataflow-style programming models since it renders + all data flow within the program explicit (through block arguments) instead + of implicit (through block dominance). + + This pass only works on standard control flow, in that it expects all operations + (and blocks) within a FuncOp to be nested immediately within the same region. + Furthermore, it is assumed that any value referenced by any operation is eligible + to be passed around as a block argument. + }]; + let constructor = "mlir::createMaxSSAFormPass()"; +} + #endif // MLIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ LoopCoalescing.cpp LoopFusion.cpp LoopInvariantCodeMotion.cpp + MaxSSA.cpp NormalizeMemRefs.cpp OpStats.cpp ParallelLoopCollapsing.cpp diff --git a/mlir/lib/Transforms/MaxSSA.cpp b/mlir/lib/Transforms/MaxSSA.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/MaxSSA.cpp @@ -0,0 +1,40 @@ +//===- MaxSSA.cpp - Maximal SSA form conversion ------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains the definitions of the maximal SSA form conversion pass. 
+// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Transforms/MaxSSAUtils.h" +#include "llvm/ADT/TypeSwitch.h" + +#include + +using namespace mlir; + +namespace { + +/// Function pass that converts the FuncOp it runs on into maximal SSA form by +/// calling convertToMaximalSSA without a value filter. The pass signals failure +/// if the conversion fails. +struct MaxSSAFormPass : public MaxSSAFormBase { +public: + void runOnFunction() override { + FuncOp func = getOperation(); + if (convertToMaximalSSA(func).failed()) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +/// Creates an instance of MaxSSAFormPass; referenced as the `constructor` of +/// the MaxSSAForm pass definition in Passes.td. +std::unique_ptr createMaxSSAFormPass() { + return std::make_unique(); +} +} // namespace mlir diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt --- a/mlir/lib/Transforms/Utils/CMakeLists.txt +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -5,6 +5,7 @@ InliningUtils.cpp LoopFusionUtils.cpp LoopUtils.cpp + MaxSSAUtils.cpp RegionUtils.cpp Utils.cpp diff --git a/mlir/lib/Transforms/Utils/MaxSSAUtils.cpp b/mlir/lib/Transforms/Utils/MaxSSAUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/Utils/MaxSSAUtils.cpp @@ -0,0 +1,216 @@ +//===- MaxSSAUtils.cpp - Maximal SSA form utilities --------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains various utilities for converting a function using standard control +// flow into having maximal SSA form.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/MaxSSAUtils.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +/// Rewrites the terminator in 'from' to pass an additional argument 'v' when +/// passing control flow to 'to'. +static LogicalResult rewriteControlFlowToBlock(Block *from, Block *to, + Value v) { + auto termOp = from->getTerminator(); + auto branchOp = dyn_cast(termOp); + if (!branchOp) + return termOp->emitOpError() << "expected terminator op within control " + "flow to be a branch-like op"; + + // Find 'to' successor in the branch op. + auto successors = branchOp->getSuccessors(); + auto succIt = llvm::find(successors, to); + assert(succIt != successors.end()); + unsigned succIdx = std::distance(successors.begin(), succIt); + + // Add the value 'v' as a block argument. + branchOp.getMutableSuccessorOperands(succIdx)->append(v); + return success(); +} + +/// Rewrites uses of 'oldV' in 'b' to 'newV'. +static void rewriteUsageInBlock(Block *b, Value oldV, Value newV) { + for (auto &use : llvm::make_early_inc_range(oldV.getUses())) + if (use.getOwner()->getBlock() == b) + use.set(newV); +} + +/// Perform a depth-first search backwards through the CFG graph of a program, +/// starting from 'use', add a new block argument of type(v) to the block and +/// replaces all uses of 'v' with the new block argument. +/// +/// Arguments: 'use': ablock where a value v flows through. 'succ': a successor +/// block of the 'use' block. Notably this conversion backtracks through the BB +/// CFG graph, so 'succ' will be a basic block that called backtrackAndConvert +/// on 'use'. 'inBlockValues': A mapping containing the appended block argument +/// when backtracking through a basic block. 
'convertedControlFlow': A mapping +/// containing, for a given block (key) which successor operand range in the +/// terminator have been rewritten to the new block argument signature. +static LogicalResult +backtrackAndConvert(Block *use, Block *succ, Value v, + DenseMap> &inBlockValues, + DenseMap>> + &convertedControlFlow) { + // The base case is when we've backtracked to the Block which defines the + // value. In these cases, set the actual value as the converted value. + if (v.getParentBlock() == use) + inBlockValues[v][use] = v; + + if (inBlockValues[v].count(use) == 0) { + // Rewrite this blocks' block arguments to take in a new value of 'v' type. + use->addArgument(v.getType()); + Value newBarg = use->getArguments().back(); + + // Register the converted block argument in case other branches in the CFG + // arrive here later. + inBlockValues[v][use] = newBarg; + rewriteUsageInBlock(use, v, newBarg); + + // Recurse through the predecessors of this block. + for (auto pred : use->getPredecessors()) + if (failed(backtrackAndConvert(pred, use, v, inBlockValues, + convertedControlFlow))) + return failure(); + } + + // Rewrite control flow to the 'succ' block through the terminator, if not + // already done. + if (succ && convertedControlFlow[v][use].count(succ) == 0) { + auto alreadyinBlockValues = inBlockValues[v].find(use); + assert(alreadyinBlockValues != inBlockValues[v].end()); + if (failed( + rewriteControlFlowToBlock(use, succ, alreadyinBlockValues->second))) + return failure(); + convertedControlFlow[v][use].insert(succ); + } + + return success(); +} + +namespace { + +struct MaxSSAFormConverter { +public: + /// An optional filterFn may be provided to dynamically filter out values + /// from being converted. + MaxSSAFormConverter(ValueFilterCallbackFn filterFn = nullptr) + : filterFn(filterFn) {} + + LogicalResult convertFunction(FuncOp function) { + auto walkRes = function.walk([&](Operation *op) { + // Run on operation results. 
+ if (llvm::any_of(op->getResults(), + [&](Value res) { return failed(runOnValue(res)); })) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (walkRes.wasInterrupted()) + return failure(); + + // Run on block arguments. + for (auto &block : function) { + if (llvm::any_of(block.getArguments(), + [&](Value barg) { return failed(runOnValue(barg)); })) { + return failure(); + } + } + + assert( + succeeded(verifyFunction(function)) && + "Some values were still referenced outside of their defining block!"); + + return success(); + } + + LogicalResult convertValue(Value v) { + auto *defOp = v.getDefiningOp(); + auto funcOp = dyn_cast(defOp->getParentOp()); + assert(funcOp && "Expected parent operation to be a FuncOp"); + + return runOnValue(v); + } + +private: + /// Verifies that all values which are not filtered indeed are only referenced + /// within their defining block. + LogicalResult verifyFunction(FuncOp f) const; + + /// Driver which will run backtrackAndConvert on values referenced outside + /// their defining block. Returns failure in case the pass failed to apply. + /// This may happen when nested regions exist within the FuncOp which this + /// pass is applied to, or if non branch-like control flow is used. + LogicalResult runOnValue(Value v); + + /// A mapping {original value : {block : replaced value}} representing + /// 'original value' has been replaced in 'block' with 'replaced value'". + DenseMap> inBlockValues; + + /// A mapping {original value : {block : succ block}} representing + /// 'original value' has already been passed from 'block' to 'succ block' + /// through the terminator of 'block'. + DenseMap>> convertedControlFlow; + + /// An optional filter function to dynamically determine whether a value + /// should be considered for SSA maximization. 
+ ValueFilterCallbackFn filterFn; +}; + +LogicalResult MaxSSAFormConverter::verifyFunction(FuncOp f) const { + for (auto &op : f.getOps()) { + auto isValid = [&](Value v) { + if (filterFn && filterFn(v)) + return true; + return v.getParentBlock() == op.getBlock(); + }; + + if (!llvm::all_of(op.getOperands(), isValid)) + return op.emitOpError() + << "has operands that are not defined within its block"; + } + return success(); +} + +LogicalResult MaxSSAFormConverter::runOnValue(Value v) { + if (filterFn && filterFn(v)) + return success(); + Block *definingBlock = v.getParentBlock(); + for (Operation *user : llvm::make_early_inc_range(v.getUsers())) { + Block *userBlock = user->getBlock(); + if (definingBlock != userBlock) { + // This is a case of using an SSA value through basic block dominance. + if (userBlock->getParent() != definingBlock->getParent()) + return user->emitOpError() << "can only convert SSA usage across " + "blocks in the same region."; + + if (failed(backtrackAndConvert(userBlock, /*succ=*/nullptr, v, + inBlockValues, convertedControlFlow))) + return failure(); + } + } + return success(); +} + +} // namespace + +LogicalResult mlir::convertToMaximalSSA(FuncOp func, + ValueFilterCallbackFn filterFn) { + return MaxSSAFormConverter(filterFn).convertFunction(func); +} + +LogicalResult mlir::convertToMaximalSSA(Value value) { + return MaxSSAFormConverter().convertValue(value); +} diff --git a/mlir/test/Transforms/max-ssa-errors.mlir b/mlir/test/Transforms/max-ssa-errors.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/max-ssa-errors.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt -split-input-file --convert-to-max-ssa %s -verify-diagnostics + +func @scf(%i : index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + %m = memref.alloc() : memref<4xi32> + scf.for %0 = %c0 to %c4 step %c1 { + // expected-error @+1 {{'memref.store' op can only convert SSA usage 
across blocks in the same region.}} + memref.store %c0_i32, %m[%i] : memref<4xi32> + } + return +} diff --git a/mlir/test/Transforms/max-ssa-ignore.mlir b/mlir/test/Transforms/max-ssa-ignore.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/max-ssa-ignore.mlir @@ -0,0 +1,62 @@ +// RUN: mlir-opt -split-input-file --test-max-ssa-filtered %s | FileCheck %s + +// CHECK-LABEL: func @memory_loop() -> i32 { +// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 5 : i32 +// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<64xi32> +// CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<64xi32> +// CHECK: br ^bb1(%[[VAL_0]], %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]] : index, index, index, index, i32) +// CHECK: ^bb1(%[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index, %[[VAL_10:.*]]: i32): +// CHECK: %[[VAL_11:.*]] = arith.cmpi slt, %[[VAL_6]], %[[VAL_9]] : index +// CHECK: cond_br %[[VAL_11]], ^bb2(%[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_6]] : index, index, index, i32, index), ^bb3(%[[VAL_7]], %[[VAL_8]], %[[VAL_9]] : index, index, index) +// CHECK: ^bb2(%[[VAL_12:.*]]: index, %[[VAL_13:.*]]: index, %[[VAL_14:.*]]: index, %[[VAL_15:.*]]: i32, %[[VAL_16:.*]]: index): +// CHECK: memref.store %[[VAL_15]], %[[VAL_4]]{{\[}}%[[VAL_16]]] : memref<64xi32> +// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_13]] : index +// CHECK: br ^bb1(%[[VAL_17]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]] : index, index, index, index, i32) +// CHECK: ^bb3(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index): +// CHECK: br ^bb4(%[[VAL_18]], %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] : index, index, index, index) +// CHECK: ^bb4(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index): +// CHECK: %[[VAL_25:.*]] = arith.cmpi slt, 
%[[VAL_21]], %[[VAL_24]] : index +// CHECK: cond_br %[[VAL_25]], ^bb5(%[[VAL_22]], %[[VAL_23]], %[[VAL_24]], %[[VAL_21]] : index, index, index, index), ^bb6(%[[VAL_22]] : index) +// CHECK: ^bb5(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index, %[[VAL_28:.*]]: index, %[[VAL_29:.*]]: index): +// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_29]]] : memref<64xi32> +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_30]], %[[VAL_30]] : i32 +// CHECK: memref.store %[[VAL_31]], %[[VAL_5]]{{\[}}%[[VAL_29]]] : memref<64xi32> +// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_29]], %[[VAL_27]] : index +// CHECK: br ^bb4(%[[VAL_32]], %[[VAL_26]], %[[VAL_27]], %[[VAL_28]] : index, index, index, index) +// CHECK: ^bb6(%[[VAL_33:.*]]: index): +// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_33]]] : memref<64xi32> +// CHECK: return %[[VAL_34]] : i32 +// CHECK: } +func @memory_loop() -> i32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c5_i32 = arith.constant 5 : i32 + %0 = memref.alloc() : memref<64xi32> + %1 = memref.alloc() : memref<64xi32> + br ^bb1(%c0 : index) +^bb1(%2: index): + %3 = arith.cmpi slt, %2, %c4 : index + cond_br %3, ^bb2, ^bb3 +^bb2: + memref.store %c5_i32, %0[%2] : memref<64xi32> + %4 = arith.addi %2, %c1 : index + br ^bb1(%4 : index) +^bb3: + br ^bb4(%c0 : index) +^bb4(%5: index): + %6 = arith.cmpi slt, %5, %c4 : index + cond_br %6, ^bb5, ^bb6 +^bb5: + %7 = memref.load %0[%5] : memref<64xi32> + %8 = arith.addi %7, %7 : i32 + memref.store %8, %1[%5] : memref<64xi32> + %9 = arith.addi %5, %c1 : index + br ^bb4(%9 : index) +^bb6: + %10 = memref.load %1[%c0] : memref<64xi32> + return %10 : i32 +} diff --git a/mlir/test/Transforms/max-ssa.mlir b/mlir/test/Transforms/max-ssa.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/max-ssa.mlir @@ -0,0 +1,137 @@ +// RUN: mlir-opt -split-input-file --convert-to-max-ssa %s | FileCheck %s + +// CHECK-LABEL: func @simple( +// CHECK: 
%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index) -> (index, index) { +// CHECK: br ^bb1(%[[VAL_0]], %[[VAL_1]] : index, index) +// CHECK: ^bb1(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index): +// CHECK: return %[[VAL_2]], %[[VAL_3]] : index, index +// CHECK: } +func @simple(%arg0 : index, %arg1 : index) -> (index, index) { + br ^bb1(%arg0 : index) +^bb1(%0 : index): + return %0, %arg1 : index, index +} + +// ----- + +// CHECK-LABEL: func @backedge() -> index { +// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_1:.*]] = arith.constant 42 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: br ^bb1(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : index, index, index) +// CHECK: ^bb1(%[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index): +// CHECK: %[[VAL_6:.*]] = arith.cmpi slt, %[[VAL_3]], %[[VAL_4]] : index +// CHECK: cond_br %[[VAL_6]], ^bb2(%[[VAL_4]], %[[VAL_5]], %[[VAL_3]] : index, index, index), ^bb3(%[[VAL_3]] : index) +// CHECK: ^bb2(%[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index): +// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_9]], %[[VAL_8]] : index +// CHECK: br ^bb1(%[[VAL_10]], %[[VAL_7]], %[[VAL_8]] : index, index, index) +// CHECK: ^bb3(%[[VAL_11:.*]]: index): +// CHECK: return %[[VAL_11]] : index +// CHECK: } +func @backedge() -> index { + %c1 = arith.constant 1 : index + %c42 = arith.constant 42 : index + %c1_0 = arith.constant 1 : index + br ^bb1(%c1 : index) +^bb1(%0: index): + %1 = arith.cmpi slt, %0, %c42 : index + cond_br %1, ^bb2, ^bb3 +^bb2: + %2 = arith.addi %0, %c1_0 : index + br ^bb1(%2 : index) +^bb3: + return %0 : index +} + +// ----- + +// CHECK-LABEL: func @complex_1( +// CHECK: %[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i1) { +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: cond_br %[[VAL_1]], ^bb1(%[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : index, i32, i1), ^bb9 +// CHECK: ^bb1(%[[VAL_3:.*]]: index, %[[VAL_4:.*]]: i32, %[[VAL_5:.*]]: i1): +// CHECK: %[[VAL_6:.*]] = arith.index_cast 
%[[VAL_4]] : i32 to index +// CHECK: br ^bb2(%[[VAL_3]], %[[VAL_3]], %[[VAL_6]], %[[VAL_4]], %[[VAL_5]] : index, index, index, i32, i1) +// CHECK: ^bb2(%[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index, %[[VAL_10:.*]]: i32, %[[VAL_11:.*]]: i1): +// CHECK: %[[VAL_12:.*]] = arith.cmpi slt, %[[VAL_7]], %[[VAL_9]] : index +// CHECK: cond_br %[[VAL_12]], ^bb3(%[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] : index, index, index, i32, i1), ^bb9 +// CHECK: ^bb3(%[[VAL_13:.*]]: index, %[[VAL_14:.*]]: index, %[[VAL_15:.*]]: index, %[[VAL_16:.*]]: i32, %[[VAL_17:.*]]: i1): +// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index +// CHECK: br ^bb4(%[[VAL_16]], %[[VAL_14]], %[[VAL_15]], %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : i32, index, index, index, i32, i1) +// CHECK: ^bb4(%[[VAL_19:.*]]: i32, %[[VAL_20:.*]]: index, %[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index, %[[VAL_23:.*]]: i32, %[[VAL_24:.*]]: i1): +// CHECK: cond_br %[[VAL_24]], ^bb5(%[[VAL_20]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_24]], %[[VAL_19]] : index, index, index, i32, i1, i32), ^bb6(%[[VAL_24]], %[[VAL_20]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_24]], %[[VAL_19]] : i1, index, index, index, i32, i1, i32) +// CHECK: ^bb5(%[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index, %[[VAL_28:.*]]: i32, %[[VAL_29:.*]]: i1, %[[VAL_30:.*]]: i32): +// CHECK: br ^bb6(%[[VAL_29]], %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_28]], %[[VAL_29]], %[[VAL_30]] : i1, index, index, index, i32, i1, i32) +// CHECK: ^bb6(%[[VAL_31:.*]]: i1, %[[VAL_32:.*]]: index, %[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index, %[[VAL_35:.*]]: i32, %[[VAL_36:.*]]: i1, %[[VAL_37:.*]]: i32): +// CHECK: cond_br %[[VAL_31]], ^bb7(%[[VAL_37]], %[[VAL_32]], %[[VAL_33]], %[[VAL_34]], %[[VAL_35]], %[[VAL_36]] : i32, index, index, index, i32, i1), ^bb8(%[[VAL_32]], %[[VAL_33]], %[[VAL_34]], %[[VAL_35]], %[[VAL_36]] : index, index, index, i32, i1) +// CHECK: ^bb7(%[[VAL_38:.*]]: i32, 
%[[VAL_39:.*]]: index, %[[VAL_40:.*]]: index, %[[VAL_41:.*]]: index, %[[VAL_42:.*]]: i32, %[[VAL_43:.*]]: i1): +// CHECK: br ^bb4(%[[VAL_38]], %[[VAL_39]], %[[VAL_40]], %[[VAL_41]], %[[VAL_42]], %[[VAL_43]] : i32, index, index, index, i32, i1) +// CHECK: ^bb8(%[[VAL_44:.*]]: index, %[[VAL_45:.*]]: index, %[[VAL_46:.*]]: index, %[[VAL_47:.*]]: i32, %[[VAL_48:.*]]: i1): +// CHECK: br ^bb2(%[[VAL_46]], %[[VAL_44]], %[[VAL_45]], %[[VAL_47]], %[[VAL_48]] : index, index, index, i32, i1) +// CHECK: ^bb9: +// CHECK: return +// CHECK: } +func @complex_1(%arg0: i32, %arg1: i1) { + %c1 = arith.constant 1 : index + cond_br %arg1, ^bb1, ^bb9 +^bb1: // pred: ^bb0 + %0 = arith.index_cast %arg0 : i32 to index + // %c1 is passed as argument 1 here but that does not mean that the first + // argument going into ^bb2 will always be %c1!. + br ^bb2(%c1 : index) +^bb2(%1: index): // 2 preds: ^bb1, ^bb8 + %2 = arith.cmpi slt, %1, %0 : index + cond_br %2, ^bb3(%1 : index), ^bb9 +^bb3(%3: index): // pred: ^bb2 + %4 = arith.addi %3, %c1 : index + br ^bb4(%arg0 : i32) +^bb4(%5: i32): // 2 preds: ^bb3, ^bb7 + cond_br %arg1, ^bb5, ^bb6(%arg1 : i1) +^bb5: // pred: ^bb4 + br ^bb6(%arg1 : i1) +^bb6(%6: i1): // 2 preds: ^bb4, ^bb5 + cond_br %6, ^bb7(%5 : i32), ^bb8 +^bb7(%7: i32): // pred: ^bb6 + br ^bb4(%7 : i32) +^bb8: // pred: ^bb6 + br ^bb2(%4 : index) +^bb9: // 2 preds: ^bb0, ^bb2 + return +} + + +// ----- + +// CHECK-LABEL: func @complex_2( +// CHECK: %[[VAL_0:.*]]: i1) { +// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index +// CHECK: br ^bb1(%[[VAL_1]], %[[VAL_1]], %[[VAL_0]] : index, index, i1) +// CHECK: ^bb1(%[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: i1): +// CHECK: cond_br %[[VAL_4]], ^bb2(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]] : index, i1, index), ^bb3(%[[VAL_3]], %[[VAL_4]] : index, i1) +// CHECK: ^bb2(%[[VAL_5:.*]]: index, %[[VAL_6:.*]]: i1, %[[VAL_7:.*]]: index): +// CHECK: br ^bb1(%[[VAL_7]], %[[VAL_5]], %[[VAL_6]] : index, index, i1) +// CHECK: ^bb3(%[[VAL_8:.*]]: 
index, %[[VAL_9:.*]]: i1): +// CHECK: br ^bb4(%[[VAL_8]], %[[VAL_9]] : index, i1) +// CHECK: ^bb4(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: i1): +// CHECK: cond_br %[[VAL_11]], ^bb5(%[[VAL_11]], %[[VAL_10]] : i1, index), ^bb6 +// CHECK: ^bb5(%[[VAL_12:.*]]: i1, %[[VAL_13:.*]]: index): +// CHECK: br ^bb4(%[[VAL_13]], %[[VAL_12]] : index, i1) +// CHECK: ^bb6: +// CHECK: return +// CHECK: } +func @complex_2(%arg0: i1) { + %c0 = arith.constant 0 : index + br ^bb1(%c0 : index) +^bb1(%0: index): // 2 preds: ^bb0, ^bb2 + cond_br %arg0, ^bb2, ^bb3 +^bb2: // pred: ^bb1 + br ^bb1(%0 : index) +^bb3: // pred: ^bb1 + br ^bb4(%c0 : index) +^bb4(%1: index): // 2 preds: ^bb3, ^bb5 + cond_br %arg0, ^bb5, ^bb6 +^bb5: // pred: ^bb4 + br ^bb4(%1 : index) +^bb6: // pred: ^bb4 + return +} diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -6,6 +6,7 @@ TestLoopMapping.cpp TestLoopParametricTiling.cpp TestLoopUnrolling.cpp + TestMaxSSAFiltered.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Transforms/TestMaxSSAFiltered.cpp b/mlir/test/lib/Transforms/TestMaxSSAFiltered.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestMaxSSAFiltered.cpp @@ -0,0 +1,51 @@ +//===- TestMaxSSAFiltered.cpp --- Max SSA with filtering test pass --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to test the max SSA utility with use of a +// filtering function. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/MaxSSAUtils.h" +#include "mlir/Transforms/Passes.h" + +#include "llvm/ADT/SetVector.h" + +using namespace mlir; + +namespace { +/// Test pass that runs convertToMaximalSSA on each function with a value +/// filter. Per the pass description, the filter selects memref-typed values, so +/// uses of those values are left unmodified while all other values are +/// threaded through block arguments. +class TestMaxSSAFilteredPass + : public PassWrapper { +public: + StringRef getArgument() const final { return "test-max-ssa-filtered"; } + StringRef getDescription() const final { + return "test the max SSA utility while filtering out certain values. The " + "test pass is configured to filter out memref values."; + } + explicit TestMaxSSAFilteredPass() {} + + // NOTE(review): this override is empty; the pass declares no dependent + // dialects, so it could be dropped. + void getDependentDialects(DialectRegistry & /*registry*/) const override {} + + void runOnFunction() override { + FuncOp func = getFunction(); + // Values for which the filter returns true are skipped by the conversion. + auto filterFn = [](Value v) { return v.getType().isa(); }; + + if (convertToMaximalSSA(func, filterFn).failed()) + return signalPassFailure(); + } +}; +} // namespace + +namespace mlir { +namespace test { +/// Registers the pass under the "test-max-ssa-filtered" argument; invoked from +/// mlir-opt's test pass registration. +void registerTestMaxSSAFilteredPass() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -97,6 +97,7 @@ void registerTestMatchReductionPass(); void registerTestMathAlgebraicSimplificationPass(); void registerTestMathPolynomialApproximationPass(); +void registerTestMaxSSAFilteredPass(); void registerTestMemRefDependenceCheck(); void registerTestMemRefStrideCalculation(); void registerTestNumberOfBlockExecutionsPass(); @@ -188,6 +189,7 @@ mlir::test::registerTestMatchReductionPass(); mlir::test::registerTestMathAlgebraicSimplificationPass(); mlir::test::registerTestMathPolynomialApproximationPass(); + mlir::test::registerTestMaxSSAFilteredPass(); mlir::test::registerTestMemRefDependenceCheck(); mlir::test::registerTestMemRefStrideCalculation(); mlir::test::registerTestNumberOfBlockExecutionsPass();