SLPVectorizer: canonicalize the operand pair of a commutative binary operator
by source order before trying to vectorize it.

tryToVectorizePair() gains an optional BinaryOperator argument. When it is
given, is commutative, and both operands are instructions in the same basic
block, the pair handed to tryToVectorizeList() is ordered by
BoUpSLP::getIndex() (the per-block instruction numbering), so the
lower-numbered instruction comes first. The new commute.ll test checks test1
and test2, which differ only in the operand order of the final fadd, against
the same expected vectorized output.

Index: lib/Transforms/Vectorize/SLPVectorizer.cpp
===================================================================
--- lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -427,6 +427,12 @@
   /// \brief Perform LICM and CSE on the newly generated gather sequences.
   void optimizeGatherSequence();
 
+  /// \brief Get the instruction numbering for a given Instruction.
+  int getIndex(Instruction *I) {
+    BlockNumbering &BN = getBlockNumbering(I->getParent());
+    return BN.getIndex(I);
+  }
+
 private:
   struct TreeEntry;
 
@@ -2231,7 +2237,8 @@
   unsigned collectStores(BasicBlock *BB, BoUpSLP &R);
 
   /// \brief Try to vectorize a chain that starts at two arithmetic instrs.
-  bool tryToVectorizePair(Value *A, Value *B, BoUpSLP &R);
+  bool tryToVectorizePair(Value *A, Value *B, BoUpSLP &R,
+                          BinaryOperator *V = nullptr);
 
   /// \brief Try to vectorize a list of operands.
   /// \@param BuildVector A list of users to ignore for the purpose of
@@ -2404,10 +2411,23 @@
   return count;
 }
 
-bool SLPVectorizer::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R) {
+bool SLPVectorizer::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R,
+                                       BinaryOperator *V) {
   if (!A || !B)
     return false;
   Value *VL[] = { A, B };
+
+  // Canonicalize operands based on source order, so that the ordering in the
+  // expression tree more closely matches the ordering of the source.
+  if (V && V->isCommutative() && isa<Instruction>(A) && isa<Instruction>(B) &&
+      cast<Instruction>(A)->getParent() == cast<Instruction>(B)->getParent()) {
+    assert(V->getOperand(0) == A && V->getOperand(1) == B &&
+           "Expected operands in order.");
+    int IndexA = R.getIndex(cast<Instruction>(A));
+    int IndexB = R.getIndex(cast<Instruction>(B));
+    if (IndexA > IndexB)
+      std::swap(VL[0], VL[1]);
+  }
   return tryToVectorizeList(VL, R);
 }
 
@@ -2508,7 +2528,7 @@
     return false;
 
   // Try to vectorize V.
-  if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))
+  if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R, V))
     return true;
 
   BinaryOperator *A = dyn_cast<BinaryOperator>(V->getOperand(0));
@@ -3018,15 +3038,15 @@
       }
 
       for (int i = 0; i < 2; ++i) {
-         if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i))) {
-            if (tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R)) {
-              Changed = true;
-              // We would like to start over since some instructions are deleted
-              // and the iterator may become invalid value.
-              it = BB->begin();
-              e = BB->end();
-            }
-         }
+        if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i))) {
+          if (tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R, BI)) {
+            Changed = true;
+            // We would like to start over since some instructions are deleted
+            // and the iterator may become invalid value.
+            it = BB->begin();
+            e = BB->end();
+          }
+        }
       }
       continue;
     }
Index: test/Transforms/SLPVectorizer/AArch64/commute.ll
===================================================================
--- /dev/null
+++ test/Transforms/SLPVectorizer/AArch64/commute.ll
@@ -0,0 +1,75 @@
+; RUN: opt -S -slp-vectorizer %s | FileCheck %s
+target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64--linux-gnu"
+
+%structA = type { [2 x float] }
+
+define void @test1(%structA* nocapture readonly %J, i32 %xmin, i32 %ymin) {
+; CHECK-LABEL: test1
+; CHECK: %arrayidx4 = getelementptr inbounds %structA* %J, i64 0, i32 0, i64 0
+; CHECK: %arrayidx9 = getelementptr inbounds %structA* %J, i64 0, i32 0, i64 1
+; CHECK: %3 = bitcast float* %arrayidx4 to <2 x float>*
+; CHECK: %4 = load <2 x float>* %3, align 4
+; CHECK: %5 = fsub <2 x float> %2, %4
+; CHECK: %6 = fmul <2 x float> %5, %5
+; CHECK: %7 = extractelement <2 x float> %6, i32 0
+; CHECK: %8 = extractelement <2 x float> %6, i32 1
+; CHECK: %add = fadd fast float %7, %8
+; CHECK: %cmp = fcmp oeq float %add, 0.000000e+00
+
+entry:
+  br label %for.body3.lr.ph
+
+for.body3.lr.ph:
+  %conv5 = sitofp i32 %ymin to float
+  %conv = sitofp i32 %xmin to float
+  %arrayidx4 = getelementptr inbounds %structA* %J, i64 0, i32 0, i64 0
+  %0 = load float* %arrayidx4, align 4
+  %sub = fsub fast float %conv, %0
+  %arrayidx9 = getelementptr inbounds %structA* %J, i64 0, i32 0, i64 1
+  %1 = load float* %arrayidx9, align 4
+  %sub10 = fsub fast float %conv5, %1
+  %mul11 = fmul fast float %sub, %sub
+  %mul12 = fmul fast float %sub10, %sub10
+  %add = fadd fast float %mul11, %mul12
+  %cmp = fcmp oeq float %add, 0.000000e+00
+  br i1 %cmp, label %for.body3.lr.ph, label %for.end27
+
+for.end27:
+  ret void
+}
+
+define void @test2(%structA* nocapture readonly %J, i32 %xmin, i32 %ymin) {
+; CHECK-LABEL: test2
+; CHECK: %arrayidx4 = getelementptr inbounds %structA* %J, i64 0, i32 0, i64 0
+; CHECK: %arrayidx9 = getelementptr inbounds %structA* %J, i64 0, i32 0, i64 1
+; CHECK: %3 = bitcast float* %arrayidx4 to <2 x float>*
+; CHECK: %4 = load <2 x float>* %3, align 4
+; CHECK: %5 = fsub <2 x float> %2, %4
+; CHECK: %6 = fmul <2 x float> %5, %5
+; CHECK: %7 = extractelement <2 x float> %6, i32 0
+; CHECK: %8 = extractelement <2 x float> %6, i32 1
+; CHECK: %add = fadd fast float %7, %8
+; CHECK: %cmp = fcmp oeq float %add, 0.000000e+00
+
+entry:
+  br label %for.body3.lr.ph
+
+for.body3.lr.ph:
+  %conv5 = sitofp i32 %ymin to float
+  %conv = sitofp i32 %xmin to float
+  %arrayidx4 = getelementptr inbounds %structA* %J, i64 0, i32 0, i64 0
+  %0 = load float* %arrayidx4, align 4
+  %sub = fsub fast float %conv, %0
+  %arrayidx9 = getelementptr inbounds %structA* %J, i64 0, i32 0, i64 1
+  %1 = load float* %arrayidx9, align 4
+  %sub10 = fsub fast float %conv5, %1
+  %mul11 = fmul fast float %sub, %sub
+  %mul12 = fmul fast float %sub10, %sub10
+  %add = fadd fast float %mul12, %mul11 ;;;<---- Operands commuted!!
+  %cmp = fcmp oeq float %add, 0.000000e+00
+  br i1 %cmp, label %for.body3.lr.ph, label %for.end27
+
+for.end27:
+  ret void
+}