diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -126,6 +126,17 @@
   assert(SE.getTypeSizeInBits(V->getType()) == SE.getTypeSizeInBits(Ty) &&
          "InsertNoopCastOfTo cannot change sizes!");
 
+  auto *PtrTy = dyn_cast<PointerType>(Ty);
+  // inttoptr only works for integral pointers. For non-integral pointers, we
+  // can create a GEP on i8* null with the integral value as index.
+  if (Op == Instruction::IntToPtr && DL.isNonIntegralPointerType(PtrTy)) {
+    auto *Int8PtrTy = Builder.getInt8PtrTy(PtrTy->getAddressSpace());
+    assert(DL.getTypeAllocSize(Int8PtrTy->getElementType()) == 1 &&
+           "alloc size of i8 must by 1 byte for the GEP to be correct");
+    auto *GEP = Builder.CreateGEP(
+        Builder.getInt8Ty(), Constant::getNullValue(Int8PtrTy), V, "uglygep");
+    return Builder.CreateBitCast(GEP, Ty);
+  }
   // Short-circuit unnecessary bitcasts.
   if (Op == Instruction::BitCast) {
     if (V->getType() == Ty)
diff --git a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
--- a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
+++ b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
@@ -20,11 +20,14 @@
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Verifier.h"
 #include "gtest/gtest.h"
 
 namespace llvm {
 
+using namespace PatternMatch;
+
 // We use this fixture to ensure that we clean up ScalarEvolution before
 // deleting the PassManager.
 class ScalarEvolutionExpanderTest : public testing::Test {
@@ -917,4 +920,53 @@
   TestMatchingCanonicalIV(GetAR5, ARBitWidth);
 }
 
+TEST_F(ScalarEvolutionExpanderTest, ExpandNonIntegralPtrWithNullBase) {
+  LLVMContext C;
+  SMDiagnostic Err;
+
+  std::unique_ptr<Module> M =
+      parseAssemblyString("target datalayout = "
+                          "\"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:"
+                          "128-n8:16:32:64-S128-ni:1-p2:32:8:8:32-ni:2\""
+                          "define float addrspace(1)* @test(i64 %offset) { "
+                          "  %ptr = getelementptr inbounds float, float "
+                          "addrspace(1)* null, i64 %offset"
+                          "  ret float addrspace(1)* %ptr"
+                          "}",
+                          Err, C);
+
+  assert(M && "Could not parse module?");
+  assert(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    auto &I = GetInstByName(F, "ptr");
+    auto PtrPlus1 =
+        SE.getAddExpr(SE.getSCEV(&I), SE.getConstant(I.getType(), 1));
+    SCEVExpander Exp(SE, M->getDataLayout(), "expander");
+
+    Value *V = Exp.expandCodeFor(PtrPlus1, I.getType(), &I);
+    I.replaceAllUsesWith(V);
+
+    // Check the expander created bitcast (gep i8* null, %offset).
+    auto *Cast = dyn_cast<BitCastInst>(V);
+    EXPECT_TRUE(Cast);
+    EXPECT_EQ(Cast->getType(), I.getType());
+    auto *GEP = dyn_cast<GetElementPtrInst>(Cast->getOperand(0));
+    EXPECT_TRUE(GEP);
+    EXPECT_TRUE(cast<Constant>(GEP->getPointerOperand())->isNullValue());
+    EXPECT_EQ(cast<PointerType>(GEP->getPointerOperand()->getType())
+                  ->getAddressSpace(),
+              cast<PointerType>(I.getType())->getAddressSpace());
+
+    // Check the expander created the expected index computation: add (shl
+    // %offset, 2), 1.
+    Value *Arg;
+    EXPECT_TRUE(
+        match(GEP->getOperand(1),
+              m_Add(m_Shl(m_Value(Arg), m_SpecificInt(2)), m_SpecificInt(1))));
+    EXPECT_EQ(Arg, &*F.arg_begin());
+    EXPECT_FALSE(verifyFunction(F, &errs()));
+  });
+}
+
 } // end namespace llvm