[RISCV] Implement getInverseOpcode for scalar FADD/FSUB

Add RISCVInstrInfo::getInverseOpcode, which maps FADD_H/S/D to FSUB_H/S/D
and back (returning std::nullopt for every other opcode), and use it from
isAssociativeAndCommutative when Invert is set: an FSUB is now run through
the existing isFADD/isFMUL fast-math-flag check under its inverse opcode
instead of being rejected outright. The test_reassoc_big2 CHECK lines in
machine-combiner.ll are updated for the resulting instruction sequence.

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -208,6 +208,8 @@
   bool isAssociativeAndCommutative(const MachineInstr &Inst,
                                    bool Invert) const override;
 
+  std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;
+
 protected:
   const RISCVSubtarget &STI;
 };
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -1215,8 +1215,12 @@
 bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
                                                  bool Invert) const {
   unsigned Opc = Inst.getOpcode();
-  if (Invert)
-    return false;
+  if (Invert) {
+    auto InverseOpcode = getInverseOpcode(Opc);
+    if (!InverseOpcode)
+      return false;
+    Opc = *InverseOpcode;
+  }
 
   if (isFADD(Opc) || isFMUL(Opc))
     return Inst.getFlag(MachineInstr::MIFlag::FmReassoc) &&
@@ -1224,6 +1228,26 @@
   return false;
 }
 
+std::optional<unsigned>
+RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
+  switch (Opcode) {
+  default:
+    return std::nullopt;
+  case RISCV::FADD_H:
+    return RISCV::FSUB_H;
+  case RISCV::FADD_S:
+    return RISCV::FSUB_S;
+  case RISCV::FADD_D:
+    return RISCV::FSUB_D;
+  case RISCV::FSUB_H:
+    return RISCV::FADD_H;
+  case RISCV::FSUB_S:
+    return RISCV::FADD_S;
+  case RISCV::FSUB_D:
+    return RISCV::FADD_D;
+  }
+}
+
 static bool canCombineFPFusedMultiply(const MachineInstr &Root,
                                       const MachineOperand &MO,
                                       bool DoRegPressureReduce) {
diff --git a/llvm/test/CodeGen/RISCV/machine-combiner.ll b/llvm/test/CodeGen/RISCV/machine-combiner.ll
--- a/llvm/test/CodeGen/RISCV/machine-combiner.ll
+++ b/llvm/test/CodeGen/RISCV/machine-combiner.ll
@@ -129,18 +129,18 @@
 define double @test_reassoc_big2(double %a0, double %a1, i32 %a2, double %a3, i32 %a4, double %a5) {
 ; CHECK-LABEL: test_reassoc_big2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    fcvt.d.w ft0, a0
-; CHECK-NEXT:    fcvt.d.w ft1, a1
-; CHECK-NEXT:    fmul.d ft0, fa2, ft0
-; CHECK-NEXT:    fmul.d ft1, ft1, fa1
-; CHECK-NEXT:    fadd.d ft2, fa0, fa1
-; CHECK-NEXT:    fadd.d ft3, fa2, fa1
-; CHECK-NEXT:    fmul.d ft0, ft1, ft0
-; CHECK-NEXT:    fadd.d ft1, fa2, ft2
-; CHECK-NEXT:    fmul.d ft2, fa0, ft3
-; CHECK-NEXT:    fsub.d ft1, fa3, ft1
-; CHECK-NEXT:    fmul.d ft0, ft0, ft2
-; CHECK-NEXT:    fmul.d fa0, ft1, ft0
+; CHECK-NEXT:    fadd.d ft0, fa0, fa1
+; CHECK-NEXT:    fsub.d ft1, fa3, fa2
+; CHECK-NEXT:    fadd.d ft2, fa2, fa1
+; CHECK-NEXT:    fcvt.d.w ft3, a0
+; CHECK-NEXT:    fcvt.d.w ft4, a1
+; CHECK-NEXT:    fmul.d ft3, fa2, ft3
+; CHECK-NEXT:    fmul.d ft4, ft4, fa1
+; CHECK-NEXT:    fsub.d ft0, ft1, ft0
+; CHECK-NEXT:    fmul.d ft1, fa0, ft2
+; CHECK-NEXT:    fmul.d ft2, ft4, ft3
+; CHECK-NEXT:    fmul.d ft0, ft0, ft1
+; CHECK-NEXT:    fmul.d fa0, ft0, ft2
 ; CHECK-NEXT:    ret
   %cvt1 = sitofp i32 %a2 to double
   %cvt2 = sitofp i32 %a4 to double