diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/EnumAttr.td @@ -0,0 +1,69 @@ +//===-- EnumAttr.td - Enum attributes ----------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef ENUM_ATTR +#define ENUM_ATTR + +include "mlir/IR/OpBase.td" + +// A C++ enum as an attribute parameter. The parameter implements a parser and +// printer for the enum by dispatching calls to `stringToSymbol` and +// `symbolToString`. +class EnumParameter + : AttrParameter { + // Parse a keyword and pass it to `stringToSymbol`. Emit an error if the + // keyword is not a valid enumerant; the error lists all enumerant strings. + let parser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> { + auto loc = $_parser.getCurrentLocation(); + ::llvm::StringRef enumKeyword; + if (::mlir::failed($_parser.parseKeyword(&enumKeyword))) + return ::mlir::failure(); + auto maybeEnum = }] # enumInfo.cppNamespace # "::" # + enumInfo.stringToSymbolFnName # [{(enumKeyword); + if (maybeEnum) + return *maybeEnum; + return {$_parser.emitError(loc, "expected }] # + cppType # [{ to be one of: }] # + !interleave(!foreach(enum, enumInfo.enumerants, enum.str), ", ") # [{")}; + }()}]; + // Print the enum via `symbolToString` (unqualified, unlike the parser). + let printer = [{$_printer << }] # enumInfo.symbolToStringFnName # + [{($_self)}]; +} + +// An attribute backed by a C++ enum. The attribute contains a single +// parameter `value` whose type is the C++ enum class. +class EnumAttr traits = []> + : AttrDef { + let summary = enumInfo.summary; + + // Inherit the C++ namespace from the enum. + let cppNamespace = enumInfo.cppNamespace; + + // Define a constant builder for the attribute to convert from C++ enums. 
+ let constBuilderCall = cppNamespace # "::" # cppClassName # + "::get($_builder.getContext(), $0)"; + + // Op attribute getters should return the underlying C++ enum type. + let returnType = enumInfo.cppNamespace # "::" # enumInfo.className; + + // Convert from attribute to the underlying C++ type in op getters. + let convertFromStorage = "$_self.getValue()"; + + // The enum attribute has one parameter: the C++ enum value. + let parameters = (ins EnumParameter:$value); + + // If a mnemonic was provided, use it to generate a custom assembly format. + let mnemonic = name; + + let assemblyFormat = "$value"; +} + +#endif // ENUM_ATTR diff --git a/mlir/test/IR/enum-attr-invalid.mlir b/mlir/test/IR/enum-attr-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/enum-attr-invalid.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt -verify-diagnostics -split-input-file %s + +func @test_invalid_enum_case() -> () { + // expected-error@+2 {{expected test::TestEnum to be one of: first, second, third}} + // expected-error@+1 {{failed to parse TestEnumAttr}} + "test.op"() {value = #test<"enum fourth">} : () -> () +} + +// ----- + +func @test_invalid_attr() -> () { + // expected-error@+1 {{op attribute 'value' failed to satisfy constraint: a test enum}} + "test.op_with_enum"() {value = 1 : index} : () -> () +} diff --git a/mlir/test/IR/enum-attr-roundtrip.mlir b/mlir/test/IR/enum-attr-roundtrip.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/enum-attr-roundtrip.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s + +// CHECK-LABEL: @test_enum_attr_roundtrip +func @test_enum_attr_roundtrip() -> () { + // CHECK: value = #test<"enum first"> + "test.op"() {value = #test<"enum first">} : () -> () + // CHECK: value = #test<"enum second"> + "test.op"() {value = #test<"enum second">} : () -> () + // CHECK: value = #test<"enum third"> + "test.op"() {value = #test<"enum third">} : () -> () + return +} + +// CHECK-LABEL: @test_op_with_enum +func 
@test_op_with_enum() -> () { + // CHECK: "test.op_with_enum"() {value = #test<"enum third">} + "test.op_with_enum"() {value = #test<"enum third">} : () -> () + return +} + +// CHECK-LABEL: @test_match_op_with_enum +func @test_match_op_with_enum() -> () { + // CHECK: "test.op_with_enum"() {tag = 0 : i32, value = #test<"enum third">} + "test.op_with_enum"() {tag = 0 : i32, value = #test<"enum third">} : () -> () + // CHECK: "test.op_with_enum"() {tag = 1 : i32, value = #test<"enum second">} + "test.op_with_enum"() {tag = 0 : i32, value = #test<"enum first">} : () -> () + return +} diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h --- a/mlir/test/lib/Dialect/Test/TestAttributes.h +++ b/mlir/test/lib/Dialect/Test/TestAttributes.h @@ -23,6 +23,7 @@ #include "mlir/IR/DialectImplementation.h" #include "TestAttrInterfaces.h.inc" +#include "TestOpEnums.h.inc" #define GET_ATTRDEF_CLASSES #include "TestAttrDefs.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -39,7 +39,6 @@ class RewritePatternSet; } // namespace mlir -#include "TestOpEnums.h.inc" #include "TestOpInterfaces.h.inc" #include "TestOpStructs.h.inc" #include "TestOpsDialect.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -10,6 +10,7 @@ #define TEST_OPS include "mlir/Dialect/DLTI/DLTIBase.td" +include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/RegionKindInterface.td" @@ -315,6 +316,37 @@ ); } +//===----------------------------------------------------------------------===// +// Test Enum Attributes +//===----------------------------------------------------------------------===// + +// Define the C++ enum. 
+def TestEnum + : I32EnumAttr<"TestEnum", "a test enum", [ + I32EnumAttrCase<"First", 0, "first">, + I32EnumAttrCase<"Second", 1, "second">, + I32EnumAttrCase<"Third", 2, "third">, + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "test"; +} + +// Define the enum attribute. +def TestEnumAttr : EnumAttr; + +// Define an op that takes the enum attribute and an optional `tag` attribute. +def OpWithEnum : TEST_Op<"op_with_enum"> { + let arguments = (ins TestEnumAttr:$value, OptionalAttr:$tag); +} + +// Match OpWithEnum with constant `value`/`tag` and recreate it with new ones. +def : Pat<(OpWithEnum ConstantAttr:$value, + ConstantAttr:$tag), + (OpWithEnum ConstantAttr, + ConstantAttr)>; + //===----------------------------------------------------------------------===// // Test Attribute Constraints //===----------------------------------------------------------------------===// diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -824,6 +824,7 @@ td_library( name = "OpBaseTdFiles", srcs = [ + "include/mlir/IR/EnumAttr.td", "include/mlir/IR/OpAsmInterface.td", "include/mlir/IR/OpBase.td", "include/mlir/IR/RegionKindInterface.td",