diff --git a/mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h b/mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h
--- a/mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h
+++ b/mlir/include/mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h
@@ -15,9 +15,60 @@
 
 #include <memory>
 
+#include "mlir/Transforms/CFGToSCF.h"
+
 namespace mlir {
 class Pass;
 
+/// Implementation of `CFGToSCFInterface` used to lift Control Flow Dialect
+/// operations to SCF Dialect operations.
+class ControlFlowToSCFTransformation : public CFGToSCFInterface {
+public:
+  /// Creates an `scf.if` op if `controlFlowCondOp` is a `cf.cond_br` op or
+  /// an `scf.index_switch` if `controlFlowCondOp` is a `cf.switch`.
+  /// Returns failure otherwise.
+  FailureOr<Operation *> createStructuredBranchRegionOp(
+      OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes,
+      MutableArrayRef<Region> regions) override;
+
+  /// Creates an `scf.yield` op returning the given results.
+  LogicalResult
+  createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder,
+                                           Operation *branchRegionOp,
+                                           ValueRange results) override;
+
+  /// Creates an `scf.while` op. The loop body is made the before-region of the
+  /// while op and terminated with an `scf.condition` op. The after-region does
+  /// nothing but forward the iteration variables.
+  FailureOr<Operation *>
+  createStructuredDoWhileLoopOp(OpBuilder &builder, Operation *replacedOp,
+                                ValueRange loopVariablesInit, Value condition,
+                                ValueRange loopVariablesNextIter,
+                                Region &&loopBody) override;
+
+  /// Creates an `arith.constant` with an i32 attribute of the given value.
+  Value getCFGSwitchValue(Location loc, OpBuilder &builder,
+                          unsigned value) override;
+
+  /// Creates a `cf.switch` op with the given cases and flag.
+  void createCFGSwitchOp(Location loc, OpBuilder &builder, Value flag,
+                         ArrayRef<unsigned> caseValues,
+                         BlockRange caseDestinations,
+                         ArrayRef<ValueRange> caseArguments, Block *defaultDest,
+                         ValueRange defaultArgs) override;
+
+  /// Creates a `ub.poison` op of the given type.
+  Value getUndefValue(Location loc, OpBuilder &builder, Type type) override;
+
+  /// Creates a `func.return` op with poison for each of the return values of
+  /// the function. It is guaranteed to be directly within the function body.
+  /// TODO: This can be made independent of the `func` dialect once the UB
+  ///       dialect has a `ub.unreachable` op.
+  FailureOr<Operation *> createUnreachableTerminator(Location loc,
+                                                     OpBuilder &builder,
+                                                     Region &region) override;
+};
+
 #define GEN_PASS_DECL_LIFTCONTROLFLOWTOSCFPASS
 #include "mlir/Conversion/Passes.h.inc"
 
diff --git a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
--- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp
@@ -29,132 +29,132 @@
 
 using namespace mlir;
 
-namespace {
-
-class ControlFlowToSCFTransformation : public CFGToSCFInterface {
-public:
-  FailureOr<Operation *> createStructuredBranchRegionOp(
-      OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes,
-      MutableArrayRef<Region> regions) override {
-    if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
-      assert(regions.size() == 2);
-      auto ifOp = builder.create<scf::IfOp>(
-          controlFlowCondOp->getLoc(), resultTypes, condBrOp.getCondition());
-      ifOp.getThenRegion().takeBody(regions[0]);
-      ifOp.getElseRegion().takeBody(regions[1]);
-      return ifOp.getOperation();
-    }
-
-    if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
-      // `getCFGSwitchValue` returns an i32 that we need to convert to index
-      // fist.
-      auto cast = builder.create<arith::IndexCastUIOp>(
-          controlFlowCondOp->getLoc(), builder.getIndexType(),
-          switchOp.getFlag());
-      SmallVector<int64_t> cases;
-      if (auto caseValues = switchOp.getCaseValues())
-        llvm::append_range(
-            cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) {
-              return apInt.getZExtValue();
-            }));
-
-      assert(regions.size() == cases.size() + 1);
-
-      auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
-          controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
-
-      indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
-      for (auto &&[targetRegion, sourceRegion] :
-           llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions)))
-        targetRegion.takeBody(sourceRegion);
-
-      return indexSwitchOp.getOperation();
-    }
-
-    controlFlowCondOp->emitOpError(
-        "Cannot convert unknown control flow op to structured control flow");
-    return failure();
-  }
-
-  LogicalResult
-  createStructuredBranchRegionTerminatorOp(Location loc, OpBuilder &builder,
-                                           Operation *branchRegionOp,
-                                           ValueRange results) override {
-    builder.create<scf::YieldOp>(loc, results);
-    return success();
-  }
-
-  FailureOr<Operation *>
-  createStructuredDoWhileLoopOp(OpBuilder &builder, Operation *replacedOp,
-                                ValueRange loopVariablesInit, Value condition,
-                                ValueRange loopVariablesNextIter,
-                                Region &&loopBody) override {
-    Location loc = replacedOp->getLoc();
-    auto whileOp = builder.create<scf::WhileOp>(
-        loc, loopVariablesInit.getTypes(), loopVariablesInit);
-
-    whileOp.getBefore().takeBody(loopBody);
-
-    builder.setInsertionPointToEnd(&whileOp.getBefore().back());
-    // `getCFGSwitchValue` returns a i32. We therefore need to truncate the
-    // condition to i1 first. It is guaranteed to be either 0 or 1 already.
-    builder.create<scf::ConditionOp>(
-        loc,
-        builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
-        loopVariablesNextIter);
-
-    auto *afterBlock = new Block;
-    whileOp.getAfter().push_back(afterBlock);
-    afterBlock->addArguments(
-        loopVariablesInit.getTypes(),
-        SmallVector<Location>(loopVariablesInit.size(), loc));
-    builder.setInsertionPointToEnd(afterBlock);
-    builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
-
-    return whileOp.getOperation();
+FailureOr<Operation *>
+ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
+    OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes,
+    MutableArrayRef<Region> regions) {
+  if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
+    assert(regions.size() == 2);
+    auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(),
+                                          resultTypes, condBrOp.getCondition());
+    ifOp.getThenRegion().takeBody(regions[0]);
+    ifOp.getElseRegion().takeBody(regions[1]);
+    return ifOp.getOperation();
   }
 
-  Value getCFGSwitchValue(Location loc, OpBuilder &builder,
-                          unsigned int value) override {
-    return builder.create<arith::ConstantOp>(loc,
-                                             builder.getI32IntegerAttr(value));
+  if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
+    // `getCFGSwitchValue` returns an i32 that we need to convert to index
+    // first.
+    auto cast = builder.create<arith::IndexCastUIOp>(
+        controlFlowCondOp->getLoc(), builder.getIndexType(),
+        switchOp.getFlag());
+    SmallVector<int64_t> cases;
+    if (auto caseValues = switchOp.getCaseValues())
+      llvm::append_range(
+          cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) {
+            return apInt.getZExtValue();
+          }));
+
+    assert(regions.size() == cases.size() + 1);
+
+    auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
+        controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
+
+    indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
+    for (auto &&[targetRegion, sourceRegion] :
+         llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions)))
+      targetRegion.takeBody(sourceRegion);
+
+    return indexSwitchOp.getOperation();
   }
 
-  void createCFGSwitchOp(Location loc, OpBuilder &builder, Value flag,
-                         ArrayRef<unsigned int> caseValues,
-                         BlockRange caseDestinations,
-                         ArrayRef<ValueRange> caseArguments, Block *defaultDest,
-                         ValueRange defaultArgs) override {
-    builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
-                                 llvm::to_vector_of<int32_t>(caseValues),
-                                 caseDestinations, caseArguments);
-  }
-
-  Value getUndefValue(Location loc, OpBuilder &builder, Type type) override {
-    return builder.create<ub::PoisonOp>(loc, type, nullptr);
-  }
+  controlFlowCondOp->emitOpError(
+      "Cannot convert unknown control flow op to structured control flow");
+  return failure();
+}
+
+LogicalResult
+ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
+    Location loc, OpBuilder &builder, Operation *branchRegionOp,
+    ValueRange results) {
+  builder.create<scf::YieldOp>(loc, results);
+  return success();
+}
+
+FailureOr<Operation *>
+ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
+    OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
+    Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
+  Location loc = replacedOp->getLoc();
+  auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(),
+                                              loopVariablesInit);
+
+  whileOp.getBefore().takeBody(loopBody);
+
+  builder.setInsertionPointToEnd(&whileOp.getBefore().back());
+  // `getCFGSwitchValue` returns an i32. We therefore need to truncate the
+  // condition to i1 first. It is guaranteed to be either 0 or 1 already.
+  builder.create<scf::ConditionOp>(
+      loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
+      loopVariablesNextIter);
+
+  auto *afterBlock = new Block;
+  whileOp.getAfter().push_back(afterBlock);
+  afterBlock->addArguments(
+      loopVariablesInit.getTypes(),
+      SmallVector<Location>(loopVariablesInit.size(), loc));
+  builder.setInsertionPointToEnd(afterBlock);
+  builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
+
+  return whileOp.getOperation();
+}
+
+Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc,
+                                                        OpBuilder &builder,
+                                                        unsigned int value) {
+  return builder.create<arith::ConstantOp>(loc,
+                                           builder.getI32IntegerAttr(value));
+}
+
+void ControlFlowToSCFTransformation::createCFGSwitchOp(
+    Location loc, OpBuilder &builder, Value flag,
+    ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
+    ArrayRef<ValueRange> caseArguments, Block *defaultDest,
+    ValueRange defaultArgs) {
+  builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
+                               llvm::to_vector_of<int32_t>(caseValues),
+                               caseDestinations, caseArguments);
+}
+
+Value ControlFlowToSCFTransformation::getUndefValue(Location loc,
+                                                    OpBuilder &builder,
+                                                    Type type) {
+  return builder.create<ub::PoisonOp>(loc, type, nullptr);
+}
+
+FailureOr<Operation *>
+ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc,
+                                                           OpBuilder &builder,
+                                                           Region &region) {
+
+  // TODO: This should create a `ub.unreachable` op once such an operation
+  // exists, which would make the pass independent of the func dialect. For
+  // now just return poison values.
+  auto funcOp = dyn_cast<func::FuncOp>(region.getParentOp());
+  if (!funcOp)
+    return emitError(loc, "Expected '")
+           << func::FuncOp::getOperationName() << "' as top level operation";
+
+  return builder
+      .create<func::ReturnOp>(
+          loc, llvm::map_to_vector(funcOp.getResultTypes(),
+                                   [&](Type type) {
+                                     return getUndefValue(loc, builder, type);
+                                   }))
+      .getOperation();
+}
 
-  FailureOr<Operation *> createUnreachableTerminator(Location loc,
-                                                     OpBuilder &builder,
-                                                     Region &region) override {
-
-    // TODO: This should create a `ub.unreachable` op. Once such an operation
-    // exists to make the pass can be made independent of the func
-    // dialect. For now just return poison values.
-    auto funcOp = dyn_cast<func::FuncOp>(region.getParentOp());
-    if (!funcOp)
-      return emitError(loc, "Expected '")
-             << func::FuncOp::getOperationName() << "' as top level operation";
-
-    return builder
-        .create<func::ReturnOp>(
-            loc, llvm::map_to_vector(funcOp.getResultTypes(),
-                                     [&](Type type) {
-                                       return getUndefValue(loc, builder, type);
-                                     }))
-        .getOperation();
-  }
-};
+namespace {
 
 struct LiftControlFlowToSCF
     : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> {