Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -913,7 +913,6 @@ // Expr by Denominator for the following functions with empty implementation. void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} - void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} void visitUDivExpr(const SCEVUDivExpr *Numerator) {} void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} @@ -941,17 +940,61 @@ } } + // Divide the operand of the cast CENumerator by Denominator, also stripping + // the SExt from Denominator when it is one (Trunc, etc. are future work). + void treatCastsOperators(const SCEVCastExpr *CENumerator, const SCEV **Q, + const SCEV **R) { + if (const SCEVSignExtendExpr *SExtDenominator = + dyn_cast<SCEVSignExtendExpr>(Denominator)) { + // Dropping SExt in Numerator and Denominator + divide(SE, CENumerator->getOperand(), SExtDenominator->getOperand(), + Q, R); + } + else + divide(SE, CENumerator->getOperand(), Denominator, Q, R); + } + + void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) { + const SCEV *Q, *R; + + // visitSignExtendExpr(SCEVSignExtendExpr *Numerator) means we have: + // div (sext Numerator to...), b. + // Therefore, the division operation we have here is signed - because of + // the (sext Numerator) cast done by LLVM. If SCEVDivision::divide() + // were to be unsigned division, LLVM would do conversion with zext like: + // div (zext Numerator), b + // NOTE(review): LLVM IR integers are signless (sdiv/udiv carry the sign), + // so treating SCEVDivision::divide() as signed here needs confirmation. 
+ // So we strip the SExt from the Numerator and, when present, from the + // Denominator before performing the division operation - otherwise the + // symbolic computation would fail since it does not know how to handle + // casts - and take care at the end to sign extend the Quotient and the + // Remainder because they should have the type of the Numerator, i.e. the + // type of the operands of the SCEVDivision::divide() operator + // (see, for example, http://llvm.org/docs/LangRef.html#sdiv-instruction). + treatCastsOperators(Numerator, &Q, &R); + + Quotient = SE.getSignExtendExpr(Q, Numerator->getType()); + Remainder = SE.getSignExtendExpr(R, Numerator->getType()); + } + void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { const SCEV *StartQ, *StartR, *StepQ, *StepR; if (!Numerator->isAffine()) return cannotDivide(Numerator); divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); - // Bail out if the types do not match. - Type *Ty = Denominator->getType(); - if (Ty != StartQ->getType() || Ty != StartR->getType() || - Ty != StepQ->getType() || Ty != StepR->getType()) - return cannotDivide(Numerator); + + // This is from Manuel Selva's patch (https://reviews.llvm.org/D35478). + // We have to use the assertions below instead of the conditionals with + // cannotDivide() in order to work with sext expressions. 
+ assert(Numerator->getStart()->getType() == StartQ->getType() && StartQ->getType() == StartR->getType() && "Expected matching types"); + assert(Numerator->getStepRecurrence(SE)->getType() == StepQ->getType() && StepQ->getType() == StepR->getType() && "Expected matching types"); + Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), Numerator->getNoWrapFlags()); Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), @@ -960,7 +1003,7 @@ void visitAddExpr(const SCEVAddExpr *Numerator) { SmallVector<const SCEV *, 2> Qs, Rs; - Type *Ty = Denominator->getType(); + Type *Ty = Numerator->getType(); for (const SCEV *Op : Numerator->operands()) { const SCEV *Q, *R; @@ -970,6 +1013,9 @@ if (Ty != Q->getType() || Ty != R->getType()) return cannotDivide(Numerator); + assert(Ty == Q->getType() && Ty == R->getType() && + "Expected matching types"); + Qs.push_back(Q); Rs.push_back(R); } @@ -1030,13 +1076,13 @@ // The Remainder is obtained by replacing Denominator by 0 in Numerator. ValueToValueMap RewriteMap; RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = - cast<SCEVConstant>(Zero)->getValue(); + cast<SCEVConstant>(SE.getZero(Denominator->getType()))->getValue(); Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); if (Remainder->isZero()) { // The Quotient is obtained by replacing Denominator by 1 in Numerator. RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = - cast<SCEVConstant>(One)->getValue(); + cast<SCEVConstant>(SE.getOne(Denominator->getType()))->getValue(); Quotient = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); return; @@ -1058,8 +1104,8 @@ SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, const SCEV *Denominator) : SE(S), Denominator(Denominator) { - Zero = SE.getZero(Denominator->getType()); - One = SE.getOne(Denominator->getType()); + Zero = SE.getZero(Numerator->getType()); + One = SE.getOne(Numerator->getType()); // We generally do not know how to divide Expr by Denominator. 
We // initialize the division to a "cannot divide" state to simplify the rest @@ -10667,6 +10713,9 @@ // Return the number of product terms in S. static inline int numberOfTerms(const SCEV *S) { + if (const SCEVSignExtendExpr *SExtS = dyn_cast<SCEVSignExtendExpr>(S)) + return numberOfTerms(SExtS->getOperand()); + if (const SCEVMulExpr *Expr = dyn_cast<SCEVMulExpr>(S)) return Expr->getNumOperands(); return 1; Index: test/Analysis/Delinearization/test_sext.ll =================================================================== --- test/Analysis/Delinearization/test_sext.ll +++ test/Analysis/Delinearization/test_sext.ll @@ -0,0 +1,153 @@ +; RUN: opt -delinearize -analyze < %s | FileCheck %s + + +; We check below that opt -delinearize delinearizes access a[i * N*N + j * N + k]. +; In essence, what we want to check is contained in the 2 lines with CHECK below. +; Note that the number of bytes (4 or 16 bytes) was also taken out - this depends +; on the vectorization factor (4 by default now), and maybe this default is +; also going to change in the future. +; ArrayRef is not checked either - it may again depend on different builds +; (or future changes of LLVM). +; Note: We also have an AddRec, which looks something like this +; (labels might differ): +; AccessFunction: {(4 * (sext i32 {{0,+,(%N * %N)}<%for.cond4.preheader.us.us.preheader>,+,%N}<%for.cond4.preheader.us.us> to i64)),+,32}<%vector.body> +; CHECK: Base offset: %a +; CHECK-NEXT: ArrayDecl[UnknownSize][%N][%N] with elements of + +; The LLVM program below is obtained by +; running: clang -O3 -mllvm -disable-llvm-optzns -emit-llvm -S test_trunc.c +; (and then manually cleaning up the debug information from the obtained .ll file) +; +; /* For a 64-bit CPU system sext is generated from i32 to i64 for the 64-bit target. +; I presume it will not generate sext on a 32-bit CPU target +; - if you want to make it generate sext, use short instead of int. 
+; This also implies that with scalars of type long (i64) instead of int, +; delinearization also works without the sext patch. +; +; The delinearization algorithm needs the sext patch because it performs: +; - division of sext_i32_to_i64(N * N) by sext_i32_to_i64(N), which yields the +; quotient sext_i32_to_i64(N) and the remainder sext_i32_to_i64(0). +; - the dividend and the divisor are created by the third step of the delinearization +; algorithm (it seems the SCEV expressions sext(N * N), sext(N), etc. +; are created by the 1st step of the delinearization algorithm, which extracts the +; terms from the sum of products of the array index i * N*N + j * N + k) +; presented in Section 4.1 of the paper +; Grosser et al., "On recovering Multi-dimensional Arrays in Polly", IMPACT 2015 +; - http://impact.gforge.inria.fr/impact2015/papers/impact2015-grosser.pdf ) +; */ +; +; +; void Test(int *a, int N) { +; int i, j, k; +; +; /* Note that here we initialize a linearized version of the array a. +; Note: there is no need for the ScalarEvolution patch when using a +; loop nest with only the i and j variables and accessing the vector a as a +; 2D array. 
+; */ +; for (i = 0; i < N; i++) { +; for (j = 0; j < N; j++) { +; for (k = 0; k < N; k++) { +; a[i * N*N + j * N + k] = i + j; +; } +; } +; } +; } + + + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +define void @Test(i32* %a, i32 %N) { +entry: + %a.addr = alloca i32*, align 8 + %N.addr = alloca i32, align 4 + %i = alloca i32, align 4 + %j = alloca i32, align 4 + %k = alloca i32, align 4 + store i32* %a, i32** %a.addr, align 8 + store i32 %N, i32* %N.addr, align 4 + %0 = bitcast i32* %i to i8* + %1 = bitcast i32* %j to i8* + %2 = bitcast i32* %k to i8* + store i32 0, i32* %i, align 4 + br label %for.cond + +for.cond: ; preds = %for.inc14, %entry + %3 = load i32, i32* %i, align 4 + %4 = load i32, i32* %N.addr, align 4 + %cmp = icmp slt i32 %3, %4 + br i1 %cmp, label %for.body, label %for.end16 + +for.body: ; preds = %for.cond + store i32 0, i32* %j, align 4 + br label %for.cond1 + +for.cond1: ; preds = %for.inc11, %for.body + %5 = load i32, i32* %j, align 4 + %6 = load i32, i32* %N.addr, align 4 + %cmp2 = icmp slt i32 %5, %6 + br i1 %cmp2, label %for.body3, label %for.end13 + +for.body3: ; preds = %for.cond1 + store i32 0, i32* %k, align 4 + br label %for.cond4 + +for.cond4: ; preds = %for.inc, %for.body3 + %7 = load i32, i32* %k, align 4 + %8 = load i32, i32* %N.addr, align 4 + %cmp5 = icmp slt i32 %7, %8 + br i1 %cmp5, label %for.body6, label %for.end + +for.body6: ; preds = %for.cond4 + %9 = load i32, i32* %i, align 4 + %10 = load i32, i32* %j, align 4 + %add = add nsw i32 %9, %10 + %11 = load i32*, i32** %a.addr, align 8 + %12 = load i32, i32* %i, align 4 + %13 = load i32, i32* %N.addr, align 4 + %mul = mul nsw i32 %12, %13 + %14 = load i32, i32* %N.addr, align 4 + %mul7 = mul nsw i32 %mul, %14 + %15 = load i32, i32* %j, align 4 + %16 = load i32, i32* %N.addr, align 4 + %mul8 = mul nsw i32 %15, %16 + %add9 = add nsw i32 %mul7, %mul8 + %17 = load i32, i32* %k, align 4 + %add10 = add nsw i32 %add9, %17 
+ %idxprom = sext i32 %add10 to i64 + %arrayidx = getelementptr inbounds i32, i32* %11, i64 %idxprom + store i32 %add, i32* %arrayidx, align 4 + br label %for.inc + +for.inc: ; preds = %for.body6 + %18 = load i32, i32* %k, align 4 + %inc = add nsw i32 %18, 1 + store i32 %inc, i32* %k, align 4 + br label %for.cond4 + +for.end: ; preds = %for.cond4 + br label %for.inc11 + +for.inc11: ; preds = %for.end + %19 = load i32, i32* %j, align 4 + %inc12 = add nsw i32 %19, 1 + store i32 %inc12, i32* %j, align 4 + br label %for.cond1 + +for.end13: ; preds = %for.cond1 + br label %for.inc14 + +for.inc14: ; preds = %for.end13 + %20 = load i32, i32* %i, align 4 + %inc15 = add nsw i32 %20, 1 + store i32 %inc15, i32* %i, align 4 + br label %for.cond + +for.end16: ; preds = %for.cond + %21 = bitcast i32* %k to i8* + %22 = bitcast i32* %j to i8* + %23 = bitcast i32* %i to i8* + ret void +}