Index: include/llvm/InitializePasses.h
===================================================================
--- include/llvm/InitializePasses.h
+++ include/llvm/InitializePasses.h
@@ -74,6 +74,7 @@
 void initializeCallGraphWrapperPassPass(PassRegistry &);
 void initializeBlockExtractorPassPass(PassRegistry&);
 void initializeBlockFrequencyInfoWrapperPassPass(PassRegistry&);
+void initializeBoolRetToIntPass(PassRegistry&);
 void initializeBoundsCheckingPass(PassRegistry&);
 void initializeBranchFolderPassPass(PassRegistry&);
 void initializeBranchProbabilityInfoWrapperPassPass(PassRegistry&);
Index: include/llvm/LinkAllPasses.h
===================================================================
--- include/llvm/LinkAllPasses.h
+++ include/llvm/LinkAllPasses.h
@@ -65,6 +65,7 @@
       (void) llvm::createTypeBasedAAWrapperPass();
       (void) llvm::createScopedNoAliasAAWrapperPass();
       (void) llvm::createBoundsCheckingPass();
+      (void) llvm::createBoolRetToIntPass();
       (void) llvm::createBreakCriticalEdgesPass();
       (void) llvm::createCallGraphPrinterPass();
       (void) llvm::createCallGraphViewerPass();
Index: include/llvm/Transforms/Scalar.h
===================================================================
--- include/llvm/Transforms/Scalar.h
+++ include/llvm/Transforms/Scalar.h
@@ -480,6 +480,13 @@
 //
 FunctionPass *createLoopDistributePass();
 
+//===----------------------------------------------------------------------===//
+//
+// BoolRetToInt - Promote the i1 PHI nodes and constants that compute an i1
+// return value to i32, truncating back to i1 just before the ret.
+//
+FunctionPass *createBoolRetToIntPass();
+
 } // End llvm namespace
 
 #endif
Index: lib/Target/PowerPC/PPCTargetMachine.cpp
===================================================================
--- lib/Target/PowerPC/PPCTargetMachine.cpp
+++ lib/Target/PowerPC/PPCTargetMachine.cpp
@@ -282,6 +282,8 @@
 }
 
 void PPCPassConfig::addIRPasses() {
+  if (getOptLevel() != CodeGenOpt::None)
+    addPass(createBoolRetToIntPass());
   addPass(createAtomicExpandPass(&getPPCTargetMachine()));
 
   // For the BG/Q (or if explicitly requested), add explicit data prefetch
Index: lib/Transforms/Scalar/BoolRetToInt.cpp
===================================================================
--- lib/Transforms/Scalar/BoolRetToInt.cpp
+++ lib/Transforms/Scalar/BoolRetToInt.cpp
@@ -0,0 +1,174 @@
+//===- BoolRetToInt.cpp - Promote returned i1 values to i32 ---------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// For a function that returns i1 from a single ret instruction, this pass
+// rewrites the PHI nodes and constants that compute the returned value so
+// they operate on i32, then truncates back to i1 immediately before the ret.
+// The rewrite is applied only when the returned value is computed purely
+// from PHI nodes and constants, at least one PHI node is involved, and no
+// instruction other than the ret and those PHI nodes uses any of them.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace llvm;
+
+namespace {
+
+class BoolRetToInt : public FunctionPass {
+  /// Return the only ret instruction in \p F, or null if \p F has none or
+  /// more than one.
+  static ReturnInst *getUniqueRetInst(Function &F) {
+    ReturnInst *RetInst = nullptr;
+    for (auto &BB : F) {
+      if (auto *CurrRetInst = dyn_cast<ReturnInst>(BB.getTerminator())) {
+        if (RetInst)
+          return nullptr;
+        RetInst = CurrRetInst;
+      }
+    }
+    return RetInst;
+  }
+
+  /// Collect \p V plus every value reachable from it through PHI operands,
+  /// in discovery order. Only PHI nodes are traversed: operands of other
+  /// values (e.g. constant expressions) need not be i1. The caller rejects
+  /// any collected value that is neither a PHI node nor a constant.
+  static SmallSetVector<Value *, 8> findAllDefs(Value *V) {
+    SmallSetVector<Value *, 8> Defs;
+    Defs.insert(V);
+    // Defs doubles as the worklist: it grows while it is being walked.
+    for (unsigned Idx = 0; Idx != Defs.size(); ++Idx)
+      if (auto *P = dyn_cast<PHINode>(Defs[Idx]))
+        for (Value *Op : P->operands())
+          Defs.insert(Op);
+    return Defs;
+  }
+
+  /// Create the i32 counterpart of the i1 constant or PHI node \p V. A new
+  /// PHI node is inserted before the original one with placeholder zero
+  /// incoming values; runOnFunction patches them once every value in the
+  /// web has been translated.
+  static Value *translate(Value *V) {
+    Type *Int32Ty = Type::getInt32Ty(V->getContext());
+    if (auto *C = dyn_cast<Constant>(V))
+      return ConstantExpr::getZExt(C, Int32Ty);
+
+    if (auto *P = dyn_cast<PHINode>(V)) {
+      Value *Zero = Constant::getNullValue(Int32Ty);
+      PHINode *Q = PHINode::Create(Int32Ty, P->getNumIncomingValues(),
+                                   P->getName(), P);
+      for (unsigned I = 0, E = P->getNumIncomingValues(); I != E; ++I)
+        Q->addIncoming(Zero, P->getIncomingBlock(I));
+      return Q;
+    }
+
+    llvm_unreachable("translate: expected a constant or PHI node");
+  }
+
+public:
+  static char ID;
+  BoolRetToInt() : FunctionPass(ID) {
+    initializeBoolRetToIntPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    if (skipOptnoneFunction(F))
+      return false;
+
+    if (!F.getReturnType()->isIntegerTy(1))
+      return false;
+
+    ReturnInst *RetInst = getUniqueRetInst(F);
+    if (!RetInst)
+      return false;
+
+    Value *RetVal = RetInst->getReturnValue();
+    SmallSetVector<Value *, 8> Defs = findAllDefs(RetVal);
+
+    // Only PHI nodes and constants can be translated, and a web made only
+    // of constants gains nothing from promotion.
+    bool HasPHI = false;
+    for (Value *V : Defs) {
+      if (isa<PHINode>(V))
+        HasPHI = true;
+      else if (!isa<Constant>(V))
+        return false;
+    }
+    if (!HasPHI)
+      return false;
+
+    // Each PHI node may be used only by the ret or by other PHI nodes in
+    // the web; any other user still needs the i1 value. Constants are
+    // uniqued and can have unrelated users anywhere in the module, so their
+    // users are not checked.
+    for (Value *V : Defs) {
+      if (isa<Constant>(V))
+        continue;
+      for (User *U : V->users())
+        if (U != RetInst && !Defs.count(U))
+          return false;
+    }
+
+    DenseMap<Value *, Value *> BoolToIntMap;
+    for (Value *V : Defs)
+      BoolToIntMap[V] = translate(V);
+
+    // Replace the placeholder incoming values of the new PHI nodes. Every
+    // incoming value of an original PHI node is in Defs, so lookup() always
+    // finds its translation (and, unlike operator[], never inserts).
+    for (Value *V : Defs) {
+      auto *P = dyn_cast<PHINode>(V);
+      if (!P)
+        continue;
+      auto *Q = cast<PHINode>(BoolToIntMap.lookup(P));
+      for (unsigned I = 0, E = P->getNumIncomingValues(); I != E; ++I)
+        Q->setIncomingValue(I, BoolToIntMap.lookup(P->getIncomingValue(I)));
+    }
+
+    Type *Int1Ty = Type::getInt1Ty(F.getContext());
+    Value *BackToBool = new TruncInst(BoolToIntMap.lookup(RetVal), Int1Ty,
+                                      "backToBool", RetInst);
+    RetInst->setOperand(0, BackToBool);
+
+    // The original i1 PHI nodes are now used only by each other. Drop all
+    // of their operands first so that cycles among them can be erased.
+    for (Value *V : Defs)
+      if (auto *P = dyn_cast<PHINode>(V))
+        P->dropAllReferences();
+    for (Value *V : Defs)
+      if (auto *P = dyn_cast<PHINode>(V))
+        P->eraseFromParent();
+
+    return true;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+  }
+};
+
+} // end anonymous namespace
+
+char BoolRetToInt::ID = 0;
+INITIALIZE_PASS(BoolRetToInt, "bool-ret-to-int",
+                "Promote returned i1 values to i32", false, false)
+
+FunctionPass *llvm::createBoolRetToIntPass() {
+  return new BoolRetToInt();
+}
Index: lib/Transforms/Scalar/CMakeLists.txt
===================================================================
--- lib/Transforms/Scalar/CMakeLists.txt
+++ lib/Transforms/Scalar/CMakeLists.txt
@@ -2,6 +2,7 @@
   ADCE.cpp
   AlignmentFromAssumptions.cpp
   BDCE.cpp
+  BoolRetToInt.cpp
   ConstantHoisting.cpp
   ConstantProp.cpp
   CorrelatedValuePropagation.cpp
Index: lib/Transforms/Scalar/Scalar.cpp
===================================================================
--- lib/Transforms/Scalar/Scalar.cpp
+++ lib/Transforms/Scalar/Scalar.cpp
@@ -32,6 +32,7 @@
 void llvm::initializeScalarOpts(PassRegistry &Registry) {
   initializeADCEPass(Registry);
   initializeBDCEPass(Registry);
+  initializeBoolRetToIntPass(Registry);
   initializeAlignmentFromAssumptionsPass(Registry);
   initializeConstantHoistingPass(Registry);
   initializeConstantPropagationPass(Registry);