diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -39,6 +39,7 @@
 #include "llvm/IR/GlobalAlias.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/ValueHandle.h"
@@ -599,6 +600,71 @@
   return CommonValue;
 }
 
+/// Return the NVPTX block-size intrinsic (%ntid.{x,y,z}) that bounds the
+/// given thread-index intrinsic (%tid.{x,y,z}) in the same dimension, or
+/// Intrinsic::not_intrinsic if \p ID is not a thread-index intrinsic.
+static Intrinsic::ID getNVPTXThreadIdxBound(Intrinsic::ID ID) {
+  switch (ID) {
+  case Intrinsic::nvvm_read_ptx_sreg_tid_x:
+    return Intrinsic::nvvm_read_ptx_sreg_ntid_x;
+  case Intrinsic::nvvm_read_ptx_sreg_tid_y:
+    return Intrinsic::nvvm_read_ptx_sreg_ntid_y;
+  case Intrinsic::nvvm_read_ptx_sreg_tid_z:
+    return Intrinsic::nvvm_read_ptx_sreg_ntid_z;
+  default:
+    return Intrinsic::not_intrinsic;
+  }
+}
+
+/// Try to fold "icmp Pred Op0, Op1" where both operands are intrinsic calls,
+/// using target domain knowledge that a per-value range cannot express.
+///
+/// Currently handles the PTX guarantee 0 <= %tid.d < %ntid.d for each
+/// dimension d in {x, y, z}. Both values are non-negative i32s well below
+/// 2^31, so signed and unsigned predicates give the same answer.
+///
+/// \returns the constant of type \p RetTy the comparison folds to, or nullptr
+/// if nothing is known about it.
+static Constant *foldIntrinsicConstant(ICmpInst::Predicate Pred, Value *Op0,
+                                       Value *Op1, Type *RetTy) {
+  auto *Inst0 = dyn_cast<IntrinsicInst>(Op0);
+  auto *Inst1 = dyn_cast<IntrinsicInst>(Op1);
+  if (!Inst0 || !Inst1)
+    return nullptr;
+
+  // Canonicalize to "tid Pred ntid" so both operand orders are handled.
+  if (getNVPTXThreadIdxBound(Inst0->getIntrinsicID()) ==
+      Intrinsic::not_intrinsic) {
+    std::swap(Inst0, Inst1);
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+
+  // Only the block size of the *same* dimension bounds the thread index;
+  // e.g. %tid.x and %ntid.y are unrelated.
+  Intrinsic::ID BoundID = getNVPTXThreadIdxBound(Inst0->getIntrinsicID());
+  if (BoundID == Intrinsic::not_intrinsic ||
+      Inst1->getIntrinsicID() != BoundID)
+    return nullptr;
+
+  // tid < ntid is known, which decides every integer predicate.
+  switch (Pred) {
+  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_SLE:
+  case ICmpInst::ICMP_ULE:
+  case ICmpInst::ICMP_NE:
+    return ConstantInt::getTrue(RetTy);
+  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_SGE:
+  case ICmpInst::ICMP_UGE:
+  case ICmpInst::ICMP_EQ:
+    return ConstantInt::getFalse(RetTy);
+  default:
+    return nullptr;
+  }
+}
+
 static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode,
                                        Value *&Op0, Value *&Op1,
                                        const SimplifyQuery &Q) {
@@ -3703,6 +3769,12 @@
   if (Value *V = ThreadCmpOverPHI(Pred, LHS, RHS, Q, MaxRecurse))
     return V;
 
+  // If both operands are intrinsic calls, try to fold the comparison using
+  // target domain knowledge (see foldIntrinsicConstant).
+  if (isa<IntrinsicInst>(LHS) && isa<IntrinsicInst>(RHS))
+    if (Constant *C = foldIntrinsicConstant(Pred, LHS, RHS, ITy))
+      return C;
+
   return nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstSimplify/intrinsic.ll b/llvm/test/Transforms/InstSimplify/intrinsic.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/intrinsic.ll
@@ -0,0 +1,97 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -instsimplify -S | FileCheck %s
+
+define i32 @compare() {
+; CHECK-LABEL: @compare(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 true, label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
+; CHECK:       if.then:
+; CHECK-NEXT:    br label [[RETURN:%.*]]
+; CHECK:       if.else:
+; CHECK-NEXT:    br label [[RETURN]]
+; CHECK:       return:
+; CHECK-NEXT:    [[RETVAL:%.*]] = phi i32 [ 1, [[IF_THEN]] ], [ 0, [[IF_ELSE]] ]
+; CHECK-NEXT:    ret i32 [[RETVAL]]
+;
+entry:
+  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %ntid = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+  %cmp = icmp slt i32 %tid, %ntid
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  br label %return
+
+if.else:                                          ; preds = %entry
+  br label %return
+
+return:                                           ; preds = %if.else, %if.then
+  %retval = phi i32 [ 1, %if.then ], [ 0, %if.else ]
+  ret i32 %retval
+}
+
+define i1 @tid_x_ule_ntid_x() {
+; CHECK-LABEL: @tid_x_ule_ntid_x(
+; CHECK-NEXT:    ret i1 true
+;
+  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %ntid = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+  %cmp = icmp ule i32 %tid, %ntid
+  ret i1 %cmp
+}
+
+define i1 @ntid_x_sgt_tid_x() {
+; CHECK-LABEL: @ntid_x_sgt_tid_x(
+; CHECK-NEXT:    ret i1 true
+;
+  %ntid = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %cmp = icmp sgt i32 %ntid, %tid
+  ret i1 %cmp
+}
+
+define i1 @tid_y_sge_ntid_y() {
+; CHECK-LABEL: @tid_y_sge_ntid_y(
+; CHECK-NEXT:    ret i1 false
+;
+  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
+  %ntid = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+  %cmp = icmp sge i32 %tid, %ntid
+  ret i1 %cmp
+}
+
+define i1 @tid_z_eq_ntid_z() {
+; CHECK-LABEL: @tid_z_eq_ntid_z(
+; CHECK-NEXT:    ret i1 false
+;
+  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+  %ntid = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
+  %cmp = icmp eq i32 %tid, %ntid
+  ret i1 %cmp
+}
+
+; Different dimensions are unrelated: no fold.
+define i1 @tid_x_slt_ntid_y() {
+; CHECK-LABEL: @tid_x_slt_ntid_y(
+; CHECK-NEXT:    [[TID:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+; CHECK-NEXT:    [[NTID:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[TID]], [[NTID]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %ntid = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+  %cmp = icmp slt i32 %tid, %ntid
+  ret i1 %cmp
+}
+
+declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+
+declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+
+declare i32 @llvm.nvvm.read.ptx.sreg.tid.y()
+
+declare i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+
+declare i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+
+declare i32 @llvm.nvvm.read.ptx.sreg.ntid.z()