diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -40,37 +40,6 @@ Location loc) override; }; -/// The predicate indicates the type of the comparison to perform: -/// (un)orderedness, (in)equality and less/greater than (or equal to) as -/// well as predicates that are always true or false. -enum class CmpFPredicate { - FirstValidValue, - // Always false - AlwaysFalse = FirstValidValue, - // Ordered comparisons - OEQ, - OGT, - OGE, - OLT, - OLE, - ONE, - // Both ordered - ORD, - // Unordered comparisons - UEQ, - UGT, - UGE, - ULT, - ULE, - UNE, - // Any unordered - UNO, - // Always true - AlwaysTrue, - // Number of predicates. - NumPredicates -}; - #define GET_OP_CLASSES #include "mlir/Dialect/StandardOps/IR/Ops.h.inc" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -433,6 +433,34 @@ }]; } +// The predicate indicates the type of the comparison to perform: +// (un)orderedness, (in)equality and less/greater than (or equal to) as +// well as predicates that are always true or false. 
+def CMPF_P_FALSE : I64EnumAttrCase<"AlwaysFalse", 0, "false">; +def CMPF_P_OEQ : I64EnumAttrCase<"OEQ", 1, "oeq">; +def CMPF_P_OGT : I64EnumAttrCase<"OGT", 2, "ogt">; +def CMPF_P_OGE : I64EnumAttrCase<"OGE", 3, "oge">; +def CMPF_P_OLT : I64EnumAttrCase<"OLT", 4, "olt">; +def CMPF_P_OLE : I64EnumAttrCase<"OLE", 5, "ole">; +def CMPF_P_ONE : I64EnumAttrCase<"ONE", 6, "one">; +def CMPF_P_ORD : I64EnumAttrCase<"ORD", 7, "ord">; +def CMPF_P_UEQ : I64EnumAttrCase<"UEQ", 8, "ueq">; +def CMPF_P_UGT : I64EnumAttrCase<"UGT", 9, "ugt">; +def CMPF_P_UGE : I64EnumAttrCase<"UGE", 10, "uge">; +def CMPF_P_ULT : I64EnumAttrCase<"ULT", 11, "ult">; +def CMPF_P_ULE : I64EnumAttrCase<"ULE", 12, "ule">; +def CMPF_P_UNE : I64EnumAttrCase<"UNE", 13, "une">; +def CMPF_P_UNO : I64EnumAttrCase<"UNO", 14, "uno">; +def CMPF_P_TRUE : I64EnumAttrCase<"AlwaysTrue", 15, "true">; + +def CmpFPredicateAttr : I64EnumAttr< + "CmpFPredicate", "", + [CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE, + CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT, + CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> { + let cppNamespace = "::mlir"; +} + def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape, TypesMatchWith< @@ -461,7 +489,11 @@ %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1 }]; - let arguments = (ins FloatLike:$lhs, FloatLike:$rhs); + let arguments = (ins + CmpFPredicateAttr:$predicate, + FloatLike:$lhs, + FloatLike:$rhs + ); let results = (outs BoolLike:$result); let builders = [OpBuilder< @@ -480,7 +512,11 @@ } }]; + let verifier = [{ return success(); }]; + let hasFolder = 1; + + let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; } def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>; diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -580,55 +580,6 @@ // 
CmpFOp //===----------------------------------------------------------------------===// -// Returns an array of mnemonics for CmpFPredicates indexed by values thereof. -static inline const char *const *getCmpFPredicateNames() { - static const char *predicateNames[] = { - /*AlwaysFalse*/ "false", - /*OEQ*/ "oeq", - /*OGT*/ "ogt", - /*OGE*/ "oge", - /*OLT*/ "olt", - /*OLE*/ "ole", - /*ONE*/ "one", - /*ORD*/ "ord", - /*UEQ*/ "ueq", - /*UGT*/ "ugt", - /*UGE*/ "uge", - /*ULT*/ "ult", - /*ULE*/ "ule", - /*UNE*/ "une", - /*UNO*/ "uno", - /*AlwaysTrue*/ "true", - }; - static_assert(std::extent<decltype(predicateNames)>::value == - (size_t)CmpFPredicate::NumPredicates, - "wrong number of predicate names"); - return predicateNames; -} - -// Returns a value of the predicate corresponding to the given mnemonic. -// Returns NumPredicates (one-past-end) if there is no such mnemonic. -CmpFPredicate CmpFOp::getPredicateByName(StringRef name) { - return llvm::StringSwitch<CmpFPredicate>(name) - .Case("false", CmpFPredicate::AlwaysFalse) - .Case("oeq", CmpFPredicate::OEQ) - .Case("ogt", CmpFPredicate::OGT) - .Case("oge", CmpFPredicate::OGE) - .Case("olt", CmpFPredicate::OLT) - .Case("ole", CmpFPredicate::OLE) - .Case("one", CmpFPredicate::ONE) - .Case("ord", CmpFPredicate::ORD) - .Case("ueq", CmpFPredicate::UEQ) - .Case("ugt", CmpFPredicate::UGT) - .Case("uge", CmpFPredicate::UGE) - .Case("ult", CmpFPredicate::ULT) - .Case("ule", CmpFPredicate::ULE) - .Case("une", CmpFPredicate::UNE) - .Case("uno", CmpFPredicate::UNO) - .Case("true", CmpFPredicate::AlwaysTrue) - .Default(CmpFPredicate::NumPredicates); -} - static void buildCmpFOp(Builder *build, OperationState &result, CmpFPredicate predicate, Value lhs, Value rhs) { result.addOperands({lhs, rhs}); @@ -638,73 +589,8 @@ build->getI64IntegerAttr(static_cast<int64_t>(predicate))); } -static ParseResult parseCmpFOp(OpAsmParser &parser, OperationState &result) { - SmallVector<OpAsmParser::OperandType, 2> ops; - SmallVector<NamedAttribute, 4> attrs; - Attribute predicateNameAttr; - Type type; - if 
(parser.parseAttribute(predicateNameAttr, CmpFOp::getPredicateAttrName(), - attrs) || - parser.parseComma() || parser.parseOperandList(ops, 2) || - parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || - parser.resolveOperands(ops, type, result.operands)) - return failure(); - - if (!predicateNameAttr.isa<StringAttr>()) - return parser.emitError(parser.getNameLoc(), - "expected string comparison predicate attribute"); - - // Rewrite string attribute to an enum value. - StringRef predicateName = predicateNameAttr.cast<StringAttr>().getValue(); - auto predicate = CmpFOp::getPredicateByName(predicateName); - if (predicate == CmpFPredicate::NumPredicates) - return parser.emitError(parser.getNameLoc(), - "unknown comparison predicate \"" + predicateName + - "\""); - - auto builder = parser.getBuilder(); - Type i1Type = getCheckedI1SameShape(type); - if (!i1Type) - return parser.emitError(parser.getNameLoc(), - "expected type with valid i1 shape"); - - attrs[0].second = builder.getI64IntegerAttr(static_cast<int64_t>(predicate)); - result.attributes = attrs; - - result.addTypes({i1Type}); - return success(); -} - -static void print(OpAsmPrinter &p, CmpFOp op) { - p << "cmpf "; - - auto predicateValue = - op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()).getInt(); - assert(predicateValue >= static_cast<int>(CmpFPredicate::FirstValidValue) && - predicateValue < static_cast<int>(CmpFPredicate::NumPredicates) && - "unknown predicate index"); - p << '"' << getCmpFPredicateNames()[predicateValue] << '"' << ", " << op.lhs() - << ", " << op.rhs(); - p.printOptionalAttrDict(op.getAttrs(), - /*elidedAttrs=*/{CmpFOp::getPredicateAttrName()}); - p << " : " << op.lhs().getType(); -} - -static LogicalResult verify(CmpFOp op) { - auto predicateAttr = - op.getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()); - if (!predicateAttr) - return op.emitOpError("requires an integer attribute named 'predicate'"); - auto predicate = predicateAttr.getInt(); - if (predicate < (int64_t)CmpFPredicate::FirstValidValue || - predicate >= 
(int64_t)CmpFPredicate::NumPredicates) - return op.emitOpError("'predicate' attribute value out of range"); - - return success(); -} - -// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point -// comparison predicates. +/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point +/// comparison predicates. static bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, const APFloat &rhs) { auto cmpResult = lhs.compare(rhs); diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -346,28 +346,28 @@ // ----- func @cmpf_generic_invalid_predicate_value(%a : f32) { - // expected-error@+1 {{'predicate' attribute value out of range}} + // expected-error@+1 {{attribute 'predicate' failed to satisfy constraint: allowed 64-bit integer cases}} %r = "std.cmpf"(%a, %a) {predicate = 42} : (f32, f32) -> i1 } // ----- func @cmpf_canonical_invalid_predicate_value(%a : f32) { - // expected-error@+1 {{unknown comparison predicate "foo"}} + // expected-error@+1 {{invalid predicate attribute specification: "foo"}} %r = cmpf "foo", %a, %a : f32 } // ----- func @cmpf_canonical_invalid_predicate_value_signed(%a : f32) { - // expected-error@+1 {{unknown comparison predicate "sge"}} + // expected-error@+1 {{invalid predicate attribute specification: "sge"}} %r = cmpf "sge", %a, %a : f32 } // ----- func @cmpf_canonical_invalid_predicate_value_no_order(%a : f32) { - // expected-error@+1 {{unknown comparison predicate "eq"}} + // expected-error@+1 {{invalid predicate attribute specification: "eq"}} %r = cmpf "eq", %a, %a : f32 } @@ -380,14 +380,14 @@ // ----- func @cmpf_generic_no_predicate_attr(%a : f32, %b : f32) { - // expected-error@+1 {{requires an integer attribute named 'predicate'}} + // expected-error@+1 {{requires attribute 'predicate'}} %r = "std.cmpf"(%a, %b) {foo = 1} : (f32, f32) -> i1 } // ----- func @cmpf_wrong_type(%a : i32, %b : 
i32) { - %r = cmpf "oeq", %a, %b : i32 // expected-error {{operand #0 must be floating-point-like}} + %r = cmpf "oeq", %a, %b : i32 // expected-error {{must be floating-point-like}} } // -----