Handle compound assignment with mixed floating-point/integral operands in
the clang bytecode expression generator.

* Add ByteCodeExprGen::emitPrimCast() to emit a cast between two PrimTypes.
* VisitFloatCompoundAssignOperator: load/store using the LHS's own type and
  convert to/from the computation type through emitPrimCast().
* VisitCompoundAssignOperator: dispatch to the floating-point path when
  ResultT or RT is PT_Float, after the operand types have been classified.
* Add intPlusDouble/doublePlusInt tests to floats.cpp.

Index: clang/lib/AST/Interp/ByteCodeExprGen.h
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.h
+++ clang/lib/AST/Interp/ByteCodeExprGen.h
@@ -276,6 +276,7 @@
     return FPO.getRoundingMode();
   }
 
+  bool emitPrimCast(PrimType FromT, PrimType ToT, QualType ToQT, const Expr *E);
   bool emitRecordDestruction(const Descriptor *Desc);
   bool emitDerivedToBaseCasts(const RecordType *DerivedType,
                               const RecordType *BaseType, const Expr *E);
Index: clang/lib/AST/Interp/ByteCodeExprGen.cpp
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -986,19 +986,22 @@
 template <class Emitter>
 bool ByteCodeExprGen<Emitter>::VisitFloatCompoundAssignOperator(
     const CompoundAssignOperator *E) {
-  assert(E->getType()->isFloatingType());
 
   const Expr *LHS = E->getLHS();
   const Expr *RHS = E->getRHS();
-  llvm::RoundingMode RM = getRoundingMode(E);
+  QualType LHSType = LHS->getType();
   QualType LHSComputationType = E->getComputationLHSType();
   QualType ResultType = E->getComputationResultType();
   std::optional<PrimType> LT = classify(LHSComputationType);
   std::optional<PrimType> RT = classify(ResultType);
 
+  assert(ResultType->isFloatingType());
+
   if (!LT || !RT)
     return false;
 
+  PrimType LHST = classifyPrim(LHSType);
+
   // C++17 onwards require that we evaluate the RHS first.
   // Compute RHS and save it in a temporary variable so we can
   // load it again later.
@@ -1012,21 +1015,19 @@
   // First, visit LHS.
   if (!visit(LHS))
     return false;
-  if (!this->emitLoad(*LT, E))
+  if (!this->emitLoad(LHST, E))
     return false;
 
   // If necessary, convert LHS to its computation type.
-  if (LHS->getType() != LHSComputationType) {
-    const auto *TargetSemantics = &Ctx.getFloatSemantics(LHSComputationType);
-
-    if (!this->emitCastFP(TargetSemantics, RM, E))
-      return false;
-  }
+  if (!this->emitPrimCast(LHST, classifyPrim(LHSComputationType),
+                          LHSComputationType, E))
+    return false;
 
   // Now load RHS.
   if (!this->emitGetLocal(*RT, TempOffset, E))
     return false;
 
+  llvm::RoundingMode RM = getRoundingMode(E);
   switch (E->getOpcode()) {
   case BO_AddAssign:
     if (!this->emitAddf(RM, E))
@@ -1048,17 +1049,12 @@
     return false;
   }
 
-  // If necessary, convert result to LHS's type.
-  if (LHS->getType() != ResultType) {
-    const auto *TargetSemantics = &Ctx.getFloatSemantics(LHS->getType());
-
-    if (!this->emitCastFP(TargetSemantics, RM, E))
-      return false;
-  }
+  if (!this->emitPrimCast(classifyPrim(ResultType), LHST, LHS->getType(), E))
+    return false;
 
   if (DiscardResult)
-    return this->emitStorePop(*LT, E);
-  return this->emitStore(*LT, E);
+    return this->emitStorePop(LHST, E);
+  return this->emitStore(LHST, E);
 }
 
 template <class Emitter>
@@ -1100,14 +1096,6 @@
 bool ByteCodeExprGen<Emitter>::VisitCompoundAssignOperator(
     const CompoundAssignOperator *E) {
 
-  // Handle floating point operations separately here, since they
-  // require special care.
-  if (E->getType()->isFloatingType())
-    return VisitFloatCompoundAssignOperator(E);
-
-  if (E->getType()->isPointerType())
-    return VisitPointerCompoundAssignOperator(E);
-
   const Expr *LHS = E->getLHS();
   const Expr *RHS = E->getRHS();
   std::optional<PrimType> LHSComputationT =
@@ -1120,6 +1108,15 @@
   if (!LT || !RT || !RHST || !ResultT || !LHSComputationT)
     return false;
 
+  // Handle floating point operations separately here, since they
+  // require special care.
+
+  if (ResultT == PT_Float || RT == PT_Float)
+    return VisitFloatCompoundAssignOperator(E);
+
+  if (E->getType()->isPointerType())
+    return VisitPointerCompoundAssignOperator(E);
+
   assert(!E->getType()->isPointerType() && "Handled above");
   assert(!E->getType()->isFloatingType() && "Handled above");
 
@@ -2654,6 +2651,42 @@
   return OffsetSum;
 }
 
+/// Emit a cast from one primitive type to another.
+///
+/// Supports float->float, float->integral, integral->integral and
+/// integral->float. \p ToQT is only consulted for the target float
+/// semantics. Returns false for any other type combination.
+template <class Emitter>
+bool ByteCodeExprGen<Emitter>::emitPrimCast(PrimType FromT, PrimType ToT,
+                                            QualType ToQT, const Expr *E) {
+
+  if (FromT == PT_Float) {
+    // Floating to floating.
+    if (ToT == PT_Float) {
+      const llvm::fltSemantics *ToSem = &Ctx.getFloatSemantics(ToQT);
+      return this->emitCastFP(ToSem, getRoundingMode(E), E);
+    }
+
+    // Float to integral.
+    if (isIntegralType(ToT))
+      return this->emitCastFloatingIntegral(ToT, E);
+  }
+
+  if (isIntegralType(FromT)) {
+    // Integral to integral.
+    if (isIntegralType(ToT))
+      return FromT != ToT ? this->emitCast(FromT, ToT, E) : true;
+
+    if (ToT == PT_Float) {
+      const llvm::fltSemantics *ToSem = &Ctx.getFloatSemantics(ToQT);
+      return this->emitCastIntegralFloating(FromT, ToSem, getRoundingMode(E),
+                                            E);
+    }
+  }
+
+  return false;
+}
+
 /// When calling this, we have a pointer of the local-to-destroy
 /// on the stack.
 /// Emit destruction of record types (or arrays of record types).
Index: clang/test/AST/Interp/floats.cpp
===================================================================
--- clang/test/AST/Interp/floats.cpp
+++ clang/test/AST/Interp/floats.cpp
@@ -102,6 +102,22 @@
     return a[1];
   }
   static_assert(ff() == 3, "");
+
+  constexpr float intPlusDouble() {
+    int a = 0;
+    a += 2.0;
+
+    return a;
+  }
+  static_assert(intPlusDouble() == 2, "");
+
+  constexpr float doublePlusInt() {
+    double a = 0.0;
+    a += 2;
+
+    return a;
+  }
+  static_assert(doublePlusInt() == 2, "");
 }
 
 namespace unary {