NOTE(review): this patch was recovered from a copy whose line breaks and
angle-bracketed tokens had been stripped. The dyn_cast template arguments
below were restored from how the results are used. The vector constant
operands in the *Vec test hunks (lines ending in a bare comma) are still
missing and must be restored from the original review before applying.
Leading whitespace and CHECK-NEXT padding are reconstructed conventionally.

Index: lib/Transforms/InstCombine/InstructionCombining.cpp
===================================================================
--- lib/Transforms/InstCombine/InstructionCombining.cpp
+++ lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -162,6 +162,49 @@
   I.setFastMathFlags(FMF);
 }
 
+/// Combine constant operands of associative operations either before or after a
+/// cast to eliminate one of the associative operations:
+/// (op (cast (op X, C2)), C1) --> (cast (op X, op (C1, C2)))
+/// (op (cast (op X, C2)), C1) --> (op (cast X), op (C1, C2))
+static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1) {
+  auto *Cast = dyn_cast<CastInst>(BinOp1->getOperand(0));
+  if (!Cast || !Cast->hasOneUse())
+    return false;
+
+  // TODO: Enhance logic for other casts and remove this check.
+  auto CastOpcode = Cast->getOpcode();
+  if (CastOpcode != Instruction::ZExt)
+    return false;
+
+  // TODO: Enhance logic for other BinOps and remove this check.
+  auto AssocOpcode = BinOp1->getOpcode();
+  if (AssocOpcode != Instruction::Xor && AssocOpcode != Instruction::And &&
+      AssocOpcode != Instruction::Or)
+    return false;
+
+  auto *BinOp2 = dyn_cast<BinaryOperator>(Cast->getOperand(0));
+  if (!BinOp2 || !BinOp2->hasOneUse() || BinOp2->getOpcode() != AssocOpcode)
+    return false;
+
+  Constant *C1, *C2;
+  if (!match(BinOp1->getOperand(1), m_Constant(C1)) ||
+      !match(BinOp2->getOperand(1), m_Constant(C2)))
+    return false;
+
+  // TODO: This assumes a zext cast.
+  // Eg, if it was a trunc, we'd cast C1 to the source type because casting C2
+  // to the destination type might lose bits.
+
+  // Fold the constants together in the destination type:
+  // (op (cast (op X, C2)), C1) --> (op (cast X), FoldedC)
+  Type *DestTy = C1->getType();
+  Constant *CastC2 = ConstantExpr::getCast(CastOpcode, C2, DestTy);
+  Constant *FoldedC = ConstantExpr::get(AssocOpcode, C1, CastC2);
+  Cast->setOperand(0, BinOp2->getOperand(0));
+  BinOp1->setOperand(1, FoldedC);
+  return true;
+}
+
 /// This performs a few simplifications for operators that are associative or
 /// commutative:
 ///
@@ -249,6 +292,12 @@
     }
 
     if (I.isAssociative() && I.isCommutative()) {
+      if (simplifyAssocCastAssoc(&I)) {
+        Changed = true;
+        ++NumReassoc;
+        continue;
+      }
+
       // Transform: "(A op B) op C" ==> "(C op A) op B" if "C op A" simplifies.
       if (Op0 && Op0->getOpcode() == Opcode) {
         Value *A = Op0->getOperand(0);
Index: test/Transforms/InstCombine/assoc-cast-assoc.ll
===================================================================
--- test/Transforms/InstCombine/assoc-cast-assoc.ll
+++ test/Transforms/InstCombine/assoc-cast-assoc.ll
@@ -3,9 +3,8 @@
 
 define i5 @XorZextXor(i3 %a) {
 ; CHECK-LABEL: @XorZextXor(
-; CHECK-NEXT:    [[OP1:%.*]] = xor i3 %a, 3
-; CHECK-NEXT:    [[CAST:%.*]] = zext i3 [[OP1]] to i5
-; CHECK-NEXT:    [[OP2:%.*]] = xor i5 [[CAST]], 12
+; CHECK-NEXT:    [[CAST:%.*]] = zext i3 %a to i5
+; CHECK-NEXT:    [[OP2:%.*]] = xor i5 [[CAST]], 15
 ; CHECK-NEXT:    ret i5 [[OP2]]
 ;
   %op1 = xor i3 %a, 3
@@ -16,9 +15,8 @@
 
 define <2 x i32> @XorZextXorVec(<2 x i1> %a) {
 ; CHECK-LABEL: @XorZextXorVec(
-; CHECK-NEXT:    [[OP1:%.*]] = xor <2 x i1> %a,
-; CHECK-NEXT:    [[CAST:%.*]] = zext <2 x i1> [[OP1]] to <2 x i32>
-; CHECK-NEXT:    [[OP2:%.*]] = xor <2 x i32> [[CAST]],
+; CHECK-NEXT:    [[CAST:%.*]] = zext <2 x i1> %a to <2 x i32>
+; CHECK-NEXT:    [[OP2:%.*]] = xor <2 x i32> [[CAST]],
 ; CHECK-NEXT:    ret <2 x i32> [[OP2]]
 ;
   %op1 = xor <2 x i1> %a,
@@ -29,9 +27,8 @@
 
 define i5 @OrZextOr(i3 %a) {
 ; CHECK-LABEL: @OrZextOr(
-; CHECK-NEXT:    [[OP1:%.*]] = or i3 %a, 3
-; CHECK-NEXT:    [[CAST:%.*]] = zext i3 [[OP1]] to i5
-; CHECK-NEXT:    [[OP2:%.*]] = or i5 [[CAST]], 8
+; CHECK-NEXT:    [[CAST:%.*]] = zext i3 %a to i5
+; CHECK-NEXT:    [[OP2:%.*]] = or i5 [[CAST]], 11
 ; CHECK-NEXT:    ret i5 [[OP2]]
 ;
   %op1 = or i3 %a, 3
@@ -42,9 +39,8 @@
 
 define <2 x i32> @OrZextOrVec(<2 x i2> %a) {
 ; CHECK-LABEL: @OrZextOrVec(
-; CHECK-NEXT:    [[OP1:%.*]] = or <2 x i2> %a,
-; CHECK-NEXT:    [[CAST:%.*]] = zext <2 x i2> [[OP1]] to <2 x i32>
-; CHECK-NEXT:    [[OP2:%.*]] = or <2 x i32> [[CAST]],
+; CHECK-NEXT:    [[CAST:%.*]] = zext <2 x i2> %a to <2 x i32>
+; CHECK-NEXT:    [[OP2:%.*]] = or <2 x i32> [[CAST]],
 ; CHECK-NEXT:    ret <2 x i32> [[OP2]]
 ;
   %op1 = or <2 x i2> %a,
@@ -69,9 +65,8 @@
 
 define <2 x i32> @AndZextAndVec(<2 x i8> %a) {
 ; CHECK-LABEL: @AndZextAndVec(
-; CHECK-NEXT:    [[OP1:%.*]] = and <2 x i8> %a,
-; CHECK-NEXT:    [[CAST:%.*]] = zext <2 x i8> [[OP1]] to <2 x i32>
-; CHECK-NEXT:    [[OP2:%.*]] = and <2 x i32> [[CAST]],
+; CHECK-NEXT:    [[CAST:%.*]] = zext <2 x i8> %a to <2 x i32>
+; CHECK-NEXT:    [[OP2:%.*]] = and <2 x i32> [[CAST]],
 ; CHECK-NEXT:    ret <2 x i32> [[OP2]]
 ;
   %op1 = and <2 x i8> %a,