Generate Python enum bindings for the python_test dialect.

Runs `mlir_tablegen(... -gen-python-enum-bindings)` to produce
`dialects/_python_test_enum_gen.py`, adds it to the SOURCES of
MLIRPythonTestSources.Dialects.PythonTest.ops_gen next to
`_python_test_ops_gen.py`, and re-exports it from `mlir.dialects.python_test`.
The test library additionally runs -gen-enum-decls / -gen-enum-defs to emit
PythonTestEnums.h.inc / PythonTestEnums.cpp.inc. A bit enum with a " | "
separator (TestBitEnumVerticalBar) and an op using it
(OpWithBareBitEnumVerticalBar) are added, along with a lit test that prints
the op for bit values 1, 2 and 1 | 2.

diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -576,12 +576,18 @@ "dialects/_python_test_ops_gen.py" -gen-python-op-bindings -bind-dialect=python_test) + mlir_tablegen( + "dialects/_python_test_enum_gen.py" + -gen-python-enum-bindings) add_public_tablegen_target(PythonTestDialectPyIncGen) declare_mlir_python_sources( MLIRPythonTestSources.Dialects.PythonTest.ops_gen ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest - SOURCES "dialects/_python_test_ops_gen.py") + SOURCES + "dialects/_python_test_ops_gen.py" + "dialects/_python_test_enum_gen.py" + ) declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtension MODULE_NAME _mlirPythonTest diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * +from ._python_test_enum_gen import * from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType diff --git a/mlir/test/python/CMakeLists.txt b/mlir/test/python/CMakeLists.txt --- a/mlir/test/python/CMakeLists.txt +++ b/mlir/test/python/CMakeLists.txt @@ -7,6 +7,8 @@ mlir_tablegen(lib/PythonTestAttributes.cpp.inc -gen-attrdef-defs) mlir_tablegen(lib/PythonTestTypes.h.inc -gen-typedef-decls) mlir_tablegen(lib/PythonTestTypes.cpp.inc -gen-typedef-defs) +mlir_tablegen(lib/PythonTestEnums.h.inc -gen-enum-decls) +mlir_tablegen(lib/PythonTestEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRPythonTestIncGen) add_subdirectory(lib) diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -487,3 +487,22 @@ 
two_operands = test.InferResultsVariadicInputsOp(single=zero, doubled=zero) # CHECK: f32 print(two_operands.result.type) + + +# CHECK-LABEL: TEST: testBitEnum +@run +def testBitEnum(): + with Context() as ctx, Location.unknown(ctx): + test.register_python_test_dialect(ctx) + module = Module.create() + with InsertionPoint(module.body): + op = test.OpWithBareBitEnumVerticalBar(1) + # CHECK: python_test.op_with_bit_bar_enum_vbar user + print(op) + op = test.OpWithBareBitEnumVerticalBar(2) + # CHECK: python_test.op_with_bit_bar_enum_vbar group + print(op) + op = test.OpWithBareBitEnumVerticalBar(1 | 2) + # CHECK: python_test.op_with_bit_bar_enum_vbar user | group + print(op) + print(module.body) diff --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h --- a/mlir/test/python/lib/PythonTestDialect.h +++ b/mlir/test/python/lib/PythonTestDialect.h @@ -12,6 +12,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "llvm/ADT/StringExtras.h" #include "PythonTestDialect.h.inc" @@ -24,4 +25,29 @@ #define GET_TYPEDEF_CLASSES #include "PythonTestTypes.h.inc" +namespace python_test { +/// Returns the " | "-joined names ("user", "group", "other") of the bits set +/// in `symbol`; asserts that no bits outside bits 0-2 are set. NOTE(review): +/// this hand-copies the TestBitEnumVerticalBar cases from python_test_ops.td +/// and must be kept in sync with that definition. +inline std::string stringifyTestBitEnumVerticalBar(llvm::APInt symbol) { + auto val = static_cast<uint32_t>(symbol.getLimitedValue()); + assert(7u == (7u | val) && "invalid bits set in bit enum"); + ::llvm::SmallVector<::llvm::StringRef, 2> strs; + + if (1u == (1u & val)) + strs.push_back("user"); + + if (2u == (2u & val)) + strs.push_back("group"); + + if (4u == (4u & val)) + strs.push_back("other"); + return llvm::join(strs, " | "); +}
+ +} // namespace python_test + +#include "PythonTestEnums.h.inc" + #endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H diff --git a/mlir/test/python/lib/PythonTestDialect.cpp b/mlir/test/python/lib/PythonTestDialect.cpp --- a/mlir/test/python/lib/PythonTestDialect.cpp +++ b/mlir/test/python/lib/PythonTestDialect.cpp @@ -12,6 +12,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "PythonTestDialect.cpp.inc" +#include "PythonTestEnums.cpp.inc" #define GET_ATTRDEF_CLASSES #include "PythonTestAttributes.cpp.inc" diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td --- a/mlir/test/python/python_test_ops.td +++ b/mlir/test/python/python_test_ops.td @@ -12,6 +12,7 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/IR/EnumAttr.td" def Python_Test_Dialect : Dialect { let name = "python_test"; @@ -181,4 +182,23 @@ let results = (outs I32:$result); } +// Define an enum with a different separator +def TestBitEnumVerticalBar + : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [ + I32BitEnumAttrCaseBit<"User", 0, "user">, + I32BitEnumAttrCaseBit<"Group", 1, "group">, + I32BitEnumAttrCaseBit<"Other", 2, "other">, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "python_test"; + let separator = " | "; +} + +// Define an op that contains the bare enum attribute. +def OpWithBareBitEnumVerticalBar : TestOp<"op_with_bit_bar_enum_vbar"> { + let arguments = (ins TestBitEnumVerticalBar:$value, + OptionalAttr<I32Attr>:$tag); + let assemblyFormat = "$value (`tag` $tag^)? attr-dict"; +} + #endif // PYTHON_TEST_OPS