diff --git a/llvm/docs/TableGen/ProgRef.rst b/llvm/docs/TableGen/ProgRef.rst --- a/llvm/docs/TableGen/ProgRef.rst +++ b/llvm/docs/TableGen/ProgRef.rst @@ -1560,8 +1560,8 @@ ``!eq(`` *a*\ `,` *b*\ ``)`` This operator produces 1 if *a* is equal to *b*; 0 otherwise. - The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values. - Use ``!cast`` to compare other types of objects. + The arguments must be ``bit``, ``bits``, ``int``, ``string``, or + record values. Use ``!cast`` to compare other types of objects. ``!filter(``\ *var*\ ``,`` *list*\ ``,`` *predicate*\ ``)`` @@ -1603,7 +1603,7 @@ ``!ge(``\ *a*\ `,` *b*\ ``)`` This operator produces 1 if *a* is greater than or equal to *b*; 0 otherwise. - The arguments must be ``bit``, ``bits``, or ``int`` values. + The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values. ``!getdagop(``\ *dag*\ ``)`` --or-- ``!getdagop<``\ *type*\ ``>(``\ *dag*\ ``)`` This operator produces the operator of the given *dag* node. @@ -1629,7 +1629,7 @@ ``!gt(``\ *a*\ `,` *b*\ ``)`` This operator produces 1 if *a* is greater than *b*; 0 otherwise. - The arguments must be ``bit``, ``bits``, or ``int`` values. + The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values. ``!head(``\ *a*\ ``)`` This operator produces the zeroth element of the list *a*. @@ -1652,7 +1652,7 @@ ``!le(``\ *a*\ ``,`` *b*\ ``)`` This operator produces 1 if *a* is less than or equal to *b*; 0 otherwise. - The arguments must be ``bit``, ``bits``, or ``int`` values. + The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values. ``!listconcat(``\ *list1*\ ``,`` *list2*\ ``, ...)`` This operator concatenates the list arguments *list1*, *list2*, etc., and @@ -1665,15 +1665,15 @@ ``!lt(``\ *a*\ `,` *b*\ ``)`` This operator produces 1 if *a* is less than *b*; 0 otherwise. - The arguments must be ``bit``, ``bits``, or ``int`` values. + The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values. 
``!mul(``\ *a*\ ``,`` *b*\ ``, ...)`` This operator multiplies *a*, *b*, etc., and produces the product. ``!ne(``\ *a*\ `,` *b*\ ``)`` This operator produces 1 if *a* is not equal to *b*; 0 otherwise. - The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values. - Use ``!cast`` to compare other types of objects. + The arguments must be ``bit``, ``bits``, ``int``, ``string``, + or record values. Use ``!cast`` to compare other types of objects. ``!not(``\ *a*\ ``)`` This operator performs a logical NOT on *a*, which must be diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp --- a/llvm/lib/TableGen/Record.cpp +++ b/llvm/lib/TableGen/Record.cpp @@ -1010,36 +1010,51 @@ case LT: case GE: case GT: { - // try to fold eq comparison for 'bit' and 'int', otherwise fallback - // to string objects. - IntInit *L = + // First see if we have two bit, bits, or int. + IntInit *LHSi = dyn_cast_or_null<IntInit>(LHS->convertInitializerTo(IntRecTy::get())); - IntInit *R = + IntInit *RHSi = dyn_cast_or_null<IntInit>(RHS->convertInitializerTo(IntRecTy::get())); - if (L && R) { + if (LHSi && RHSi) { bool Result; switch (getOpcode()) { - case EQ: Result = L->getValue() == R->getValue(); break; - case NE: Result = L->getValue() != R->getValue(); break; - case LE: Result = L->getValue() <= R->getValue(); break; - case LT: Result = L->getValue() < R->getValue(); break; - case GE: Result = L->getValue() >= R->getValue(); break; - case GT: Result = L->getValue() > R->getValue(); break; + case EQ: Result = LHSi->getValue() == RHSi->getValue(); break; + case NE: Result = LHSi->getValue() != RHSi->getValue(); break; + case LE: Result = LHSi->getValue() <= RHSi->getValue(); break; + case LT: Result = LHSi->getValue() < RHSi->getValue(); break; + case GE: Result = LHSi->getValue() >= RHSi->getValue(); break; + case GT: Result = LHSi->getValue() > RHSi->getValue(); break; default: llvm_unreachable("unhandled comparison"); } return BitInit::get(Result); } - if (getOpcode() == EQ || 
getOpcode() == NE) { - StringInit *LHSs = dyn_cast<StringInit>(LHS); - StringInit *RHSs = dyn_cast<StringInit>(RHS); + // Next try strings. + StringInit *LHSs = dyn_cast<StringInit>(LHS); + StringInit *RHSs = dyn_cast<StringInit>(RHS); - // Make sure we've resolved - if (LHSs && RHSs) { - bool Equal = LHSs->getValue() == RHSs->getValue(); - return BitInit::get(getOpcode() == EQ ? Equal : !Equal); + if (LHSs && RHSs) { + bool Result; + switch (getOpcode()) { + case EQ: Result = LHSs->getValue() == RHSs->getValue(); break; + case NE: Result = LHSs->getValue() != RHSs->getValue(); break; + case LE: Result = LHSs->getValue() <= RHSs->getValue(); break; + case LT: Result = LHSs->getValue() < RHSs->getValue(); break; + case GE: Result = LHSs->getValue() >= RHSs->getValue(); break; + case GT: Result = LHSs->getValue() > RHSs->getValue(); break; + default: llvm_unreachable("unhandled comparison"); } + return BitInit::get(Result); + } + + // Finally, !eq and !ne can be used with records. + if (getOpcode() == EQ || getOpcode() == NE) { + DefInit *LHSd = dyn_cast<DefInit>(LHS); + DefInit *RHSd = dyn_cast<DefInit>(RHS); + if (LHSd && RHSd) + return BitInit::get((getOpcode() == EQ) ? LHSd == RHSd + : LHSd != RHSd); } break; diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp --- a/llvm/lib/TableGen/TGParser.cpp +++ b/llvm/lib/TableGen/TGParser.cpp @@ -1148,15 +1148,12 @@ break; case tgtok::XEq: case tgtok::XNe: - Type = BitRecTy::get(); - // ArgType for Eq / Ne is not known at this point - break; case tgtok::XLe: case tgtok::XLt: case tgtok::XGe: case tgtok::XGt: Type = BitRecTy::get(); - ArgType = IntRecTy::get(); + // ArgType for the comparison operators is not yet known. 
break; case tgtok::XListConcat: // We don't know the list type until we parse the first argument @@ -1244,10 +1241,24 @@ break; case BinOpInit::EQ: case BinOpInit::NE: + if (!ArgType->typeIsConvertibleTo(IntRecTy::get()) && + !ArgType->typeIsConvertibleTo(StringRecTy::get()) && + !ArgType->typeIsConvertibleTo(RecordRecTy::get({}))) { + Error(InitLoc, Twine("expected bit, bits, int, string, or record; " + "got value of type '") + ArgType->getAsString() + + "'"); + return nullptr; + } + break; + case BinOpInit::LE: + case BinOpInit::LT: + case BinOpInit::GE: + case BinOpInit::GT: if (!ArgType->typeIsConvertibleTo(IntRecTy::get()) && !ArgType->typeIsConvertibleTo(StringRecTy::get())) { - Error(InitLoc, Twine("expected int, bits, or string; got value of " - "type '") + ArgType->getAsString() + "'"); + Error(InitLoc, Twine("expected bit, bits, int, or string; " + "got value of type '") + ArgType->getAsString() + + "'"); return nullptr; } break; diff --git a/llvm/test/TableGen/compare.td b/llvm/test/TableGen/compare.td --- a/llvm/test/TableGen/compare.td +++ b/llvm/test/TableGen/compare.td @@ -1,54 +1,117 @@ // RUN: llvm-tblgen %s | FileCheck %s -// XFAIL: vg_leak - -// CHECK: --- Defs --- - -// CHECK: def A0 { -// CHECK: bit eq = 1; -// CHECK: bit ne = 0; -// CHECK: bit le = 1; -// CHECK: bit lt = 0; -// CHECK: bit ge = 1; -// CHECK: bit gt = 0; -// CHECK: } - -// CHECK: def A1 { -// CHECK: bit eq = 0; -// CHECK: bit ne = 1; -// CHECK: bit le = 1; -// CHECK: bit lt = 1; -// CHECK: bit ge = 0; -// CHECK: bit gt = 0; -// CHECK: } - -// CHECK: def A2 { -// CHECK: bit eq = 0; -// CHECK: bit ne = 1; -// CHECK: bit le = 0; -// CHECK: bit lt = 0; -// CHECK: bit ge = 1; -// CHECK: bit gt = 1; -// CHECK: } - -// CHECK: def A3 { -// CHECK: bit eq = 0; -// CHECK: bit ne = 1; -// CHECK: bit le = 0; -// CHECK: bit lt = 0; -// CHECK: bit ge = 1; -// CHECK: bit gt = 1; -// CHECK: } - -class A<int x, int y> { - bit eq = !eq(x, y); - bit ne = !ne(x, y); - bit le = !le(x, y); - bit lt = !lt(x, y); - 
bit ge = !ge(x, y); - bit gt = !gt(x, y); -} - -def A0 : A<-3, -3>; -def A1 : A<-1, 4>; -def A2 : A<3, -2>; -def A3 : A<4, 2>; +// RUN: not llvm-tblgen -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s +// RUN: not llvm-tblgen -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s + +// This file tests the comparison bang operators. + +class BitCompare<bit a, bit b> { + list<bit> compare = [!eq(a, b), !ne(a, b), + !lt(a, b), !le(a, b), + !gt(a, b), !ge(a, b)]; +} + +class BitsCompare<bits<3> a, bits<3> b> { + list<bit> compare = [!eq(a, b), !ne(a, b), + !lt(a, b), !le(a, b), + !gt(a, b), !ge(a, b)]; +} + +class IntCompare<int a, int b> { + list<bit> compare = [!eq(a, b), !ne(a, b), + !lt(a, b), !le(a, b), + !gt(a, b), !ge(a, b)]; +} + +class StringCompare<string a, string b> { + list<bit> compare = [!eq(a, b), !ne(a, b), + !lt(a, b), !le(a, b), + !gt(a, b), !ge(a, b)]; +} + +multiclass MC { + def _MC; +} + +// CHECK: def Bit00 +// CHECK: compare = [1, 0, 0, 1, 0, 1]; +// CHECK: def Bit01 +// CHECK: compare = [0, 1, 1, 1, 0, 0]; +// CHECK: def Bit10 +// CHECK: compare = [0, 1, 0, 0, 1, 1]; +// CHECK: def Bit11 +// CHECK: compare = [1, 0, 0, 1, 0, 1]; + +def Bit00 : BitCompare<0, 0>; +def Bit01 : BitCompare<0, 1>; +def Bit10 : BitCompare<1, 0>; +def Bit11 : BitCompare<1, 1>; + +// CHECK: def Bits1 +// CHECK: compare = [0, 1, 1, 1, 0, 0]; +// CHECK: def Bits2 +// CHECK: compare = [1, 0, 0, 1, 0, 1]; +// CHECK: def Bits3 +// CHECK: compare = [0, 1, 0, 0, 1, 1]; + +def Bits1 : BitsCompare<{0, 1, 0}, {1, 0, 1}>; +def Bits2 : BitsCompare<{0, 1, 1}, {0, 1, 1}>; +def Bits3 : BitsCompare<{1, 1, 1}, {0, 1, 1}>; + +// CHECK: def Int1 +// CHECK: compare = [0, 1, 1, 1, 0, 0]; +// CHECK: def Int2 +// CHECK: compare = [1, 0, 0, 1, 0, 1]; +// CHECK: def Int3 +// CHECK: compare = [0, 1, 0, 0, 1, 1]; + +def Int1 : IntCompare<-7, 13>; +def Int2 : IntCompare<42, 42>; +def Int3 : IntCompare<108, 42>; + +// CHECK: def Record1 +// CHECK: compare1 = [1, 0]; +// CHECK: compare2 = [0, 1]; +// CHECK: compare3 = [1, 1]; + +defm foo : MC; +defm bar : MC; + +def 
Record1 { + list<bit> compare1 = [!eq(Bit00, Bit00), !eq(Bit00, Bit01)]; + list<bit> compare2 = [!ne(Bit00, Bit00), !ne(Bit00, Int1)]; + list<bit> compare3 = [!eq(bar_MC, bar_MC), !ne(bar_MC, foo_MC)]; +} + +// CHECK: def String1 +// CHECK: compare = [0, 1, 1, 1, 0, 0]; +// CHECK: def String2 +// CHECK: compare = [1, 0, 0, 1, 0, 1]; +// CHECK: def String3 +// CHECK: compare = [0, 1, 0, 0, 1, 1]; +// CHECK: def String4 +// CHECK: compare = [0, 1, 0, 0, 1, 1]; +def String1 : StringCompare<"bar", "foo">; +def String2 : StringCompare<"foo", "foo">; +def String3 : StringCompare<"foo", "bar">; +def String4 : StringCompare<"foo", "Foo">; + +#ifdef ERROR1 + +// ERROR1: expected bit, bits, int, string, or record; got value + +def Zerror1 { + bit compare1 = !eq([0, 1, 2], [0, 1, 2]); +} + +#endif + +#ifdef ERROR2 + +// ERROR2: expected bit, bits, int, or string; got value + +def Zerror2 { + bit compare1 = !lt(Bit00, Bit00); +} + +#endif +