# Reviewer commentary (not part of any hunk; patch(1) skips leading text that
# precedes the first file header).
#
# 1. DAGCombiner.cpp, hunk @@ -2157: the (mul (shl X, c1), c2) -> (mul X, c2 << c1)
#    combine now passes NoOpaques=true to both isConstantOrConstantVector() checks.
#    It only returns the new MUL when the folded shift C3 is itself a constant or
#    constant vector. The AddToWorklist(C3.getNode()) call is removed.
# 2. DAGCombiner.cpp, hunk @@ -4714: the combine that builds
#    SHL(N0.getOperand(1), N1) and returns MUL(N0.getOperand(0), Shl) is changed the
#    same way. It only fires when Shl is a constant or constant vector, and the
#    AddToWorklist(Shl.getNode()) call is removed. The enclosing function is not
#    visible in this hunk; presumably this is the (shl (mul x, c1), c2) fold.
#    TODO confirm.
# 3. SelectionDAG.cpp, hunk @@ -3498: the per-element loop over the BV1/BV2 operands
#    no longer requires ConstantSDNode operands and no longer checks isOpaque().
#    Each element is now folded with getNode(Opcode, DL, SVT, V1, V2). The loop bails
#    out unless the result is UNDEF, ISD::Constant or ISD::ConstantFP. The
#    implicit-truncation guard (operand value type != SVT) is kept; only its FIXME
#    text changes.
#    NOTE(review): opaque-constant elements are no longer rejected explicitly; this
#    now relies on getNode() declining to fold them. Confirm that getNode() cannot
#    hand back an opaque ISD::Constant operand unchanged (e.g. via an identity
#    simplification), because such a result would pass the Constant check above.
#    NOTE(review): when folding fails, the scalar node returned by getNode() is
#    simply dropped. Presumably it is left for dead-node cleanup; confirm.
# 4. New test AArch64/dag-combine-mul-shl.ll covers both mul/shl orderings
#    (fn1_* = shl then mul, fn2_* = mul then shl) in four variants:
#      - scalar: expects the single constant 1664 (= 13 << 7) feeding one mul;
#      - <16 x i8> vector: expects a constant-pool load feeding one mul;
#      - undef operand: scalars expect just a `mov w0`;
#      - opaque constant (`bitcast i32 13 to i32`): expects mul and lsl to stay
#        separate.
# 5. X86/shift-pcmp.ll, function bar: the expected psrlw $15 / psllw $5 pair (and
#    the AVX vpsrlw / vpsllw pair) becomes a single pand / vpand with a
#    RIP-relative memory operand.
#
# NOTE(review): this copy of the patch looks damaged by extraction, so it will not
# apply or build as shown. Restore it from the original commit before use.
#   - Line breaks are collapsed.
#   - Template arguments are missing: "SmallVector Outputs;",
#     "dyn_cast(BV1->getOperand(I))", "std::pair Folded".
#   - The <16 x i8> constant operands in the new .ll test are missing,
#     e.g. "%shl = shl <16 x i8> %arg," has nothing after the comma.
Index: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -2157,11 +2157,11 @@ // (mul (shl X, c1), c2) -> (mul X, c2 << c1) if (N0.getOpcode() == ISD::SHL && - isConstantOrConstantVector(N1) && - isConstantOrConstantVector(N0.getOperand(1))) { + isConstantOrConstantVector(N1, /* NoOpaques */ true) && + isConstantOrConstantVector(N0.getOperand(1), /* NoOpaques */ true)) { SDValue C3 = DAG.getNode(ISD::SHL, SDLoc(N), VT, N1, N0.getOperand(1)); - AddToWorklist(C3.getNode()); - return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), C3); + if (isConstantOrConstantVector(C3)) + return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), C3); } // Change (mul (shl X, C), Y) -> (shl (mul X, Y), C) when the shift has one @@ -4714,8 +4714,8 @@ isConstantOrConstantVector(N1, /* No Opaques */ true) && isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true)) { SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1); - AddToWorklist(Shl.getNode()); - return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl); + if (isConstantOrConstantVector(Shl)) + return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Shl); } if (N1C && !N1C->isOpaque()) Index: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -3498,25 +3498,22 @@ EVT SVT = VT.getScalarType(); SmallVector Outputs; for (unsigned I = 0, E = BV1->getNumOperands(); I != E; ++I) { - ConstantSDNode *V1 = dyn_cast(BV1->getOperand(I)); - ConstantSDNode *V2 = dyn_cast(BV2->getOperand(I)); - if (!V1 || !V2) // Not a constant, bail. 
- return SDValue(); - - if (V1->isOpaque() || V2->isOpaque()) - return SDValue(); + SDValue V1 = BV1->getOperand(I); + SDValue V2 = BV2->getOperand(I); // Avoid BUILD_VECTOR nodes that perform implicit truncation. - // FIXME: This is valid and could be handled by truncating the APInts. + // FIXME: This is valid and could be handled by truncation. if (V1->getValueType(0) != SVT || V2->getValueType(0) != SVT) return SDValue(); // Fold one vector element. - std::pair Folded = FoldValue(Opcode, V1->getAPIntValue(), - V2->getAPIntValue()); - if (!Folded.second) + SDValue ScalarResult = getNode(Opcode, DL, SVT, V1, V2); + + // Scalar folding only succeeded if the result is a constant or UNDEF. + if (!ScalarResult.isUndef() && ScalarResult.getOpcode() != ISD::Constant && + ScalarResult.getOpcode() != ISD::ConstantFP) return SDValue(); - Outputs.push_back(getConstant(Folded.first, DL, SVT)); + Outputs.push_back(ScalarResult); } assert(VT.getVectorNumElements() == Outputs.size() && Index: llvm/trunk/test/CodeGen/AArch64/dag-combine-mul-shl.ll =================================================================== --- llvm/trunk/test/CodeGen/AArch64/dag-combine-mul-shl.ll +++ llvm/trunk/test/CodeGen/AArch64/dag-combine-mul-shl.ll @@ -0,0 +1,117 @@ +; RUN: llc -mtriple=aarch64 < %s | FileCheck %s + +; CHECK-LABEL: fn1_vector: +; CHECK: adrp x[[BASE:[0-9]+]], .LCP +; CHECK-NEXT: ldr q[[NUM:[0-9]+]], [x[[BASE]], +; CHECK-NEXT: mul v0.16b, v0.16b, v[[NUM]].16b +; CHECK-NEXT: ret +define <16 x i8> @fn1_vector(<16 x i8> %arg) { +entry: + %shl = shl <16 x i8> %arg, + %mul = mul <16 x i8> %shl, + ret <16 x i8> %mul +} + +; CHECK-LABEL: fn2_vector: +; CHECK: adrp x[[BASE:[0-9]+]], .LCP +; CHECK-NEXT: ldr q[[NUM:[0-9]+]], [x[[BASE]], +; CHECK-NEXT: mul v0.16b, v0.16b, v[[NUM]].16b +; CHECK-NEXT: ret +define <16 x i8> @fn2_vector(<16 x i8> %arg) { +entry: + %mul = mul <16 x i8> %arg, + %shl = shl <16 x i8> %mul, + ret <16 x i8> %shl +} + +; CHECK-LABEL: fn1_vector_undef: +; CHECK: adrp 
x[[BASE:[0-9]+]], .LCP +; CHECK-NEXT: ldr q[[NUM:[0-9]+]], [x[[BASE]], +; CHECK-NEXT: mul v0.16b, v0.16b, v[[NUM]].16b +; CHECK-NEXT: ret +define <16 x i8> @fn1_vector_undef(<16 x i8> %arg) { +entry: + %shl = shl <16 x i8> %arg, + %mul = mul <16 x i8> %shl, + ret <16 x i8> %mul +} + +; CHECK-LABEL: fn2_vector_undef: +; CHECK: adrp x[[BASE:[0-9]+]], .LCP +; CHECK-NEXT: ldr q[[NUM:[0-9]+]], [x[[BASE]], +; CHECK-NEXT: mul v0.16b, v0.16b, v[[NUM]].16b +; CHECK-NEXT: ret +define <16 x i8> @fn2_vector_undef(<16 x i8> %arg) { +entry: + %mul = mul <16 x i8> %arg, + %shl = shl <16 x i8> %mul, + ret <16 x i8> %shl +} + +; CHECK-LABEL: fn1_scalar: +; CHECK: mov w[[REG:[0-9]+]], #1664 +; CHECK-NEXT: mul w0, w0, w[[REG]] +; CHECK-NEXT: ret +define i32 @fn1_scalar(i32 %arg) { +entry: + %shl = shl i32 %arg, 7 + %mul = mul i32 %shl, 13 + ret i32 %mul +} + +; CHECK-LABEL: fn2_scalar: +; CHECK: mov w[[REG:[0-9]+]], #1664 +; CHECK-NEXT: mul w0, w0, w[[REG]] +; CHECK-NEXT: ret +define i32 @fn2_scalar(i32 %arg) { +entry: + %mul = mul i32 %arg, 13 + %shl = shl i32 %mul, 7 + ret i32 %shl +} + +; CHECK-LABEL: fn1_scalar_undef: +; CHECK: mov w0 +; CHECK-NEXT: ret +define i32 @fn1_scalar_undef(i32 %arg) { +entry: + %shl = shl i32 %arg, 7 + %mul = mul i32 %shl, undef + ret i32 %mul +} + +; CHECK-LABEL: fn2_scalar_undef: +; CHECK: mov w0 +; CHECK-NEXT: ret +define i32 @fn2_scalar_undef(i32 %arg) { +entry: + %mul = mul i32 %arg, undef + %shl = shl i32 %mul, 7 + ret i32 %shl +} + +; CHECK-LABEL: fn1_scalar_opaque: +; CHECK: mov w[[REG:[0-9]+]], #13 +; CHECK-NEXT: mul w[[REG]], w0, w[[REG]] +; CHECK-NEXT: lsl w0, w[[REG]], #7 +; CHECK-NEXT: ret +define i32 @fn1_scalar_opaque(i32 %arg) { +entry: + %bitcast = bitcast i32 13 to i32 + %shl = shl i32 %arg, 7 + %mul = mul i32 %shl, %bitcast + ret i32 %mul +} + +; CHECK-LABEL: fn2_scalar_opaque: +; CHECK: mov w[[REG:[0-9]+]], #13 +; CHECK-NEXT: mul w[[REG]], w0, w[[REG]] +; CHECK-NEXT: lsl w0, w[[REG]], #7 +; CHECK-NEXT: ret +define i32 
@fn2_scalar_opaque(i32 %arg) { +entry: + %bitcast = bitcast i32 13 to i32 + %mul = mul i32 %arg, %bitcast + %shl = shl i32 %mul, 7 + ret i32 %shl +} Index: llvm/trunk/test/CodeGen/X86/shift-pcmp.ll =================================================================== --- llvm/trunk/test/CodeGen/X86/shift-pcmp.ll +++ llvm/trunk/test/CodeGen/X86/shift-pcmp.ll @@ -26,15 +26,13 @@ ; SSE-LABEL: bar: ; SSE: # BB#0: ; SSE-NEXT: pcmpeqw %xmm1, %xmm0 -; SSE-NEXT: psrlw $15, %xmm0 -; SSE-NEXT: psllw $5, %xmm0 +; SSE-NEXT: pand {{.*}}(%rip), %xmm0 ; SSE-NEXT: retq ; ; AVX-LABEL: bar: ; AVX: # BB#0: ; AVX-NEXT: vpcmpeqw %xmm1, %xmm0, %xmm0 -; AVX-NEXT: vpsrlw $15, %xmm0, %xmm0 -; AVX-NEXT: vpsllw $5, %xmm0, %xmm0 +; AVX-NEXT: vpand {{.*}}(%rip), %xmm0, %xmm0 ; AVX-NEXT: retq ; %icmp = icmp eq <8 x i16> %a, %b