diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.h b/flang/include/flang/Optimizer/Dialect/FIROps.h --- a/flang/include/flang/Optimizer/Dialect/FIROps.h +++ b/flang/include/flang/Optimizer/Dialect/FIROps.h @@ -22,9 +22,6 @@ class DoLoopOp; class RealAttr; -void buildCmpFOp(mlir::OpBuilder &builder, mlir::OperationState &result, - mlir::CmpFPredicate predicate, mlir::Value lhs, - mlir::Value rhs); void buildCmpCOp(mlir::OpBuilder &builder, mlir::OperationState &result, mlir::CmpFPredicate predicate, mlir::Value lhs, mlir::Value rhs); @@ -33,8 +30,6 @@ DoLoopOp getForInductionVarOwner(mlir::Value val); bool isReferenceLike(mlir::Type type); mlir::ParseResult isValidCaseAttr(mlir::Attribute attr); -mlir::ParseResult parseCmpfOp(mlir::OpAsmParser &parser, - mlir::OperationState &result); mlir::ParseResult parseCmpcOp(mlir::OpAsmParser &parser, mlir::OperationState &result); mlir::ParseResult parseSelector(mlir::OpAsmParser &parser, diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2783,41 +2783,6 @@ def fir_ModfOp : RealArithmeticOp<"modf">; -def fir_CmpfOp : fir_Op<"cmpf", - [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> { - let summary = "floating-point comparison operator"; - - let description = [{ - Extends the standard floating-point comparison to handle the extended - floating-point types found in FIR. - }]; - - let arguments = (ins AnyRealLike:$lhs, AnyRealLike:$rhs); - - let results = (outs AnyLogicalLike); - - let builders = [OpBuilder<(ins "mlir::CmpFPredicate":$predicate, - "mlir::Value":$lhs, "mlir::Value":$rhs), [{ - buildCmpFOp($_builder, $_state, predicate, lhs, rhs); - }]>]; - - let parser = [{ return parseCmpfOp(parser, result); }]; - - let printer = [{ printCmpfOp(p, *this); }]; - - let extraClassDeclaration = [{ - static constexpr llvm::StringRef getPredicateAttrName() { - return "predicate"; - } - static CmpFPredicate getPredicateByName(llvm::StringRef name); - - CmpFPredicate getPredicate() { - return (CmpFPredicate)(*this)->getAttrOfType( - getPredicateAttrName()).getInt(); - } - }]; -} - def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> { let summary = "create a complex constant"; @@ -2918,6 +2883,8 @@ return (CmpFPredicate)(*this)->getAttrOfType( getPredicateAttrName()).getInt(); } + + static CmpFPredicate getPredicateByName(llvm::StringRef name); }]; } diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp --- a/flang/lib/Lower/IntrinsicCall.cpp +++ b/flang/lib/Lower/IntrinsicCall.cpp @@ -1041,7 +1041,7 @@ auto zero = builder.createRealZeroConstant(loc, resultType); auto diff = builder.create(loc, args[0], args[1]); auto cmp = - builder.create(loc, mlir::CmpFPredicate::OGT, diff, zero); + builder.create(loc, mlir::CmpFPredicate::OGT, diff, zero); return builder.create(loc, cmp, diff, zero); } @@ -1188,8 +1188,8 @@ auto zeroAttr = builder.getZeroAttr(resultType); auto zero = builder.create(loc, resultType, zeroAttr); auto neg = builder.create(loc, abs); - auto cmp = - builder.create(loc, mlir::CmpFPredicate::OLT, args[1], zero); + auto cmp = builder.create(loc, mlir::CmpFPredicate::OLT, + args[1], zero); return builder.create(loc, cmp, neg, abs); } @@ -1213,26 +1213,26 @@ // Return the number if one of the inputs is NaN and the other is // a number. auto leftIsResult = - builder.create(loc, orderedCmp, left, right); - auto rightIsNan = builder.create( + builder.create(loc, orderedCmp, left, right); + auto rightIsNan = builder.create( loc, mlir::CmpFPredicate::UNE, right, right); result = builder.create(loc, leftIsResult, rightIsNan); } else if constexpr (behavior == ExtremumBehavior::IeeeMinMaximum) { // Always return NaNs if one the input is NaNs auto leftIsResult = - builder.create(loc, orderedCmp, left, right); - auto leftIsNan = builder.create( + builder.create(loc, orderedCmp, left, right); + auto leftIsNan = builder.create( loc, mlir::CmpFPredicate::UNE, left, left); result = builder.create(loc, leftIsResult, leftIsNan); } else if constexpr (behavior == ExtremumBehavior::MinMaxss) { // If the left is a NaN, return the right whatever it is. - result = builder.create(loc, orderedCmp, left, right); + result = builder.create(loc, orderedCmp, left, right); } else if constexpr (behavior == ExtremumBehavior::PgfortranLlvm) { // If one of the operand is a NaN, return left whatever it is. static constexpr auto unorderedCmp = extremum == Extremum::Max ? mlir::CmpFPredicate::UGT : mlir::CmpFPredicate::ULT; - result = builder.create(loc, unorderedCmp, left, right); + result = builder.create(loc, unorderedCmp, left, right); } else { // TODO: ieeeMinNum/ieeeMaxNum static_assert(behavior == ExtremumBehavior::IeeeMinMaxNum, diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -299,25 +299,9 @@ } //===----------------------------------------------------------------------===// -// CmpfOp +// CmpOp //===----------------------------------------------------------------------===// -// Note: getCmpFPredicateNames() is inline static in StandardOps/IR/Ops.cpp -mlir::CmpFPredicate fir::CmpfOp::getPredicateByName(llvm::StringRef name) { - auto pred = mlir::symbolizeCmpFPredicate(name); - assert(pred.hasValue() && "invalid predicate name"); - return pred.getValue(); -} - -void fir::buildCmpFOp(OpBuilder &builder, OperationState &result, - CmpFPredicate predicate, Value lhs, Value rhs) { - result.addOperands({lhs, rhs}); - result.types.push_back(builder.getI1Type()); - result.addAttribute( - CmpfOp::getPredicateAttrName(), - builder.getI64IntegerAttr(static_cast(predicate))); -} - template static void printCmpOp(OpAsmPrinter &p, OPTY op) { p << ' '; @@ -335,8 +319,6 @@ p << " : " << op.lhs().getType(); } -static void printCmpfOp(OpAsmPrinter &p, CmpfOp op) { printCmpOp(p, op); } - template static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { @@ -358,7 +340,7 @@ // Rewrite string attribute to an enum value. llvm::StringRef predicateName = predicateNameAttr.cast().getValue(); - auto predicate = fir::CmpfOp::getPredicateByName(predicateName); + auto predicate = fir::CmpcOp::getPredicateByName(predicateName); auto builder = parser.getBuilder(); mlir::Type i1Type = builder.getI1Type(); attrs.set(OPTY::getPredicateAttrName(), @@ -368,11 +350,6 @@ return success(); } -mlir::ParseResult fir::parseCmpfOp(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - return parseCmpOp(parser, result); -} - //===----------------------------------------------------------------------===// // CmpcOp //===----------------------------------------------------------------------===// @@ -386,6 +363,12 @@ builder.getI64IntegerAttr(static_cast(predicate))); } +mlir::CmpFPredicate fir::CmpcOp::getPredicateByName(llvm::StringRef name) { + auto pred = mlir::symbolizeCmpFPredicate(name); + assert(pred.hasValue() && "invalid predicate name"); + return pred.getValue(); +} + static void printCmpcOp(OpAsmPrinter &p, fir::CmpcOp op) { printCmpOp(p, op); } mlir::ParseResult fir::parseCmpcOp(mlir::OpAsmParser &parser, diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir --- a/flang/test/Fir/fir-ops.fir +++ b/flang/test/Fir/fir-ops.fir @@ -441,51 +441,6 @@ fir.dt_entry "method", @method_impl } -// CHECK-LABEL: func @compare_real( -// CHECK-SAME: [[VAL_133:%.*]]: f128, [[VAL_134:%.*]]: f128) { -func @compare_real(%a : f128, %b : f128) { - -// CHECK: [[VAL_135:%.*]] = fir.cmpf "false", [[VAL_133]], [[VAL_134]] : f128 -// CHECK: [[VAL_136:%.*]] = fir.cmpf "oeq", [[VAL_133]], [[VAL_134]] : f128 -// CHECK: [[VAL_137:%.*]] = fir.cmpf "ogt", [[VAL_133]], [[VAL_134]] : f128 -// CHECK: [[VAL_138:%.*]] = fir.cmpf "oge", [[VAL_133]], [[VAL_134]] : f128 - %d0 = fir.cmpf "false", %a, %b : f128 - %d1 = fir.cmpf "oeq", %a, %b : f128 - %d2 = fir.cmpf "ogt", %a, %b : f128 - %d3 = fir.cmpf "oge", %a, %b : f128 - -// CHECK: [[VAL_139:%.*]] = fir.cmpf "olt", [[VAL_133]], [[VAL_134]] : f128 -// CHECK: [[VAL_140:%.*]] = fir.cmpf "ole", [[VAL_133]], [[VAL_134]] : f128 -// CHECK: [[VAL_141:%.*]] = fir.cmpf "one", [[VAL_133]], [[VAL_134]] : f128 -// CHECK: [[VAL_142:%.*]] = fir.cmpf "ord", [[VAL_133]], [[VAL_134]] : f128 - %a0 = fir.cmpf "olt", %a, %b : f128 - %a1 = fir.cmpf "ole", %a, %b : f128 - %a2 = fir.cmpf "one", %a, %b : f128 - %a3 = fir.cmpf "ord", %a, %b : f128 - -// CHECK: [[VAL_143:%.*]] = fir.cmpf "ueq", [[VAL_133]], [[VAL_134]] : f128 -// CHECK: [[VAL_144:%.*]] = fir.cmpf "ugt", [[VAL_133]], [[VAL_134]] : f128 -// CHECK: [[VAL_145:%.*]] = fir.cmpf "uge", [[VAL_133]], [[VAL_134]] : f128 -// CHECK: [[VAL_146:%.*]] = fir.cmpf "ult", [[VAL_133]], [[VAL_134]] : f128 - %b0 = fir.cmpf "ueq", %a, %b : f128 - %b1 = fir.cmpf "ugt", %a, %b : f128 - %b2 = fir.cmpf "uge", %a, %b : f128 - %b3 = fir.cmpf "ult", %a, %b : f128 - -// CHECK: [[VAL_147:%.*]] = fir.cmpf "ule", [[VAL_133]], [[VAL_134]] : f128 -// CHECK: [[VAL_148:%.*]] = fir.cmpf "une", [[VAL_133]], [[VAL_134]] : f128 -// CHECK: [[VAL_149:%.*]] = fir.cmpf "uno", [[VAL_133]], [[VAL_134]] : f128 -// CHECK: [[VAL_150:%.*]] = fir.cmpf "true", [[VAL_133]], [[VAL_134]] : f128 - %c0 = fir.cmpf "ule", %a, %b : f128 - %c1 = fir.cmpf "une", %a, %b : f128 - %c2 = fir.cmpf "uno", %a, %b : f128 - %c3 = fir.cmpf "true", %a, %b : f128 - -// CHECK: return -// CHECK: } - return -} - // CHECK-LABEL: func @compare_complex( // CHECK-SAME: [[VAL_151:%.*]]: !fir.complex<16>, [[VAL_152:%.*]]: !fir.complex<16>) { func @compare_complex(%a : !fir.complex<16>, %b : !fir.complex<16>) {