diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1296,11 +1296,16 @@
 // Base class for attributes containing types. Example:
 //   def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute">
 // defines a type attribute containing an integer type.
-class TypeAttrBase<string retType, string summary> :
+// `typePred` is an optional extra predicate on the contained type: every
+// `$_self` leaf in it is replaced by the attribute's type value. Defaults to
+// an always-true predicate.
+class TypeAttrBase<string retType, string summary,
+                   Pred typePred = CPred<"true">> :
     Attr<And<[
       CPred<"$_self.isa<::mlir::TypeAttr>()">,
       CPred<"$_self.cast<::mlir::TypeAttr>().getValue().isa<"
-            # retType # ">()">]>,
+            # retType # ">()">,
+      SubstLeaves<"$_self",
+                  "$_self.cast<::mlir::TypeAttr>().getValue()", typePred>]>,
     summary> {
   let storageType = [{ ::mlir::TypeAttr }];
   let returnType = retType;
@@ -1313,7 +1318,9 @@
 }
 
+// A type attribute whose contained type must also satisfy `ty`'s predicate.
 class TypeAttrOf<Type ty>
-   : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary> {
+   : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary,
+                  ty.predicate> {
   let constBuilderCall = "::mlir::TypeAttr::get($0)";
 }
 
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -879,3 +879,18 @@
   "test.default_value_print"(%arg0) {"value_with_default" = 1 : i32} : (i32) -> ()
   return
 }
+
+// -----
+
+func.func @type_attr_of_success() {
+  test.type_attr_of i64
+  return
+}
+
+// -----
+
+func.func @type_attr_of_fail() {
+  // expected-error @below {{failed to satisfy constraint: type attribute of 64-bit signless integer}}
+  test.type_attr_of i32
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -277,6 +277,13 @@
   }];
 }
 
+def TypeAttrOfOp : TEST_Op<"type_attr_of"> {
+  let arguments = (ins TypeAttrOf<I64>:$type);
+  let assemblyFormat = [{
+    attr-dict $type
+  }];
+}
+
 def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
   let arguments = (ins
     DenseBoolArrayAttr:$i1attr,
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -318,10 +318,10 @@
 // DEF: if (tblgen_str_attr && !((tblgen_str_attr.isa<::mlir::StringAttr>())))
 // DEF: if (tblgen_elements_attr && !((tblgen_elements_attr.isa<::mlir::ElementsAttr>())))
 // DEF: if (tblgen_function_attr && !((tblgen_function_attr.isa<::mlir::FlatSymbolRefAttr>())))
-// DEF: if (tblgen_some_type_attr && !(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa<SomeType>()))))
+// DEF: if (tblgen_some_type_attr && !(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa<SomeType>())) && ((true))))
 // DEF: if (tblgen_array_attr && !((tblgen_array_attr.isa<::mlir::ArrayAttr>())))
 // DEF: if (tblgen_some_attr_array && !(((tblgen_some_attr_array.isa<::mlir::ArrayAttr>())) && (::llvm::all_of(tblgen_some_attr_array.cast<::mlir::ArrayAttr>(), [&](::mlir::Attribute attr) { return attr && ((some-condition)); }))))
-// DEF: if (tblgen_type_attr && !(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>()))))
+// DEF: if (tblgen_type_attr && !(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>())) && ((true))))
 
 // Test common attribute kind getters' return types
 // ---