diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1419,13 +1419,29 @@
 // Note: All derived attributes should be materializable as an Attribute. E.g.,
 // do not use DerivedAttr for things that could not have been stored as
 // Attribute.
-class DerivedAttr<code ret, code b> : Attr<CPred<"true">, "derived attribute"> {
+//
+class DerivedAttr<code ret, code b, code convert = ""> :
+    Attr<CPred<"true">, "derived attribute"> {
   let returnType = ret;
   code body = b;
+
+  // Specify how to convert from the derived attribute to an attribute.
+  //
+  // ## Special placeholders
+  //
+  // Special placeholders can be used to refer to entities during conversion:
+  //
+  //   * `$_builder` will be replaced by a mlir::Builder instance.
+  //   * `$_ctx` will be replaced by the MLIRContext* instance.
+  //   * `$_self` will be replaced with the derived attribute (a value of
+  //     type `returnType`).
+  let convertFromStorage = convert;
 }
 
 // Derived attribute that returns a mlir::Type.
-class DerivedTypeAttr<code body> : DerivedAttr<"Type", body>;
+class DerivedTypeAttr<code body> : DerivedAttr<"Type", body> {
+  let convertFromStorage = "TypeAttr::get($_self)";
+}
 
 //===----------------------------------------------------------------------===//
 // Constant attribute kinds
diff --git a/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td b/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td
--- a/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.td
@@ -31,6 +31,14 @@
       /*methodName=*/"isDerivedAttribute",
       /*args=*/(ins "StringRef":$name)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Materializes the derived attributes. Returns a null attribute if any
+        derived attribute cannot be materialized as an attribute.
+      }],
+      /*retTy=*/"DictionaryAttr",
+      /*methodName=*/"materializeDerivedAttributes"
+    >,
   ];
 }
 
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -21,15 +21,16 @@
 )
 target_link_libraries(MLIRTestDialect
   PUBLIC
+  LLVMSupport
   MLIRControlFlowInterfaces
+  MLIRDerivedAttributeOpInterface
   MLIRDialect
   MLIRIR
+  MLIRInferTypeOpInterface
   MLIRLinalgTransforms
   MLIRPass
   MLIRStandardOps
   MLIRStandardToStandard
-  MLIRTransforms
   MLIRTransformUtils
-  MLIRInferTypeOpInterface
-  LLVMSupport
+  MLIRTransforms
 )
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -22,6 +22,7 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffects.h"
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -236,6 +236,15 @@
   );
 }
 
+def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> {
+  let results = (outs AnyTensor:$output);
+  DerivedTypeAttr element_dtype =
+    DerivedTypeAttr<"return getElementTypeOrSelf(output().getType());">;
+  DerivedAttr size = DerivedAttr<"int",
+    "return output().getType().cast<ShapedType>().getSizeInBits();",
+    "$_builder.getI32IntegerAttr($_self)">;
+}
+
 //===----------------------------------------------------------------------===//
 // Test Attribute Constraints
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+
 using namespace mlir;
 
 // Native function for testing NativeCodeCall
@@ -129,6 +130,23 @@
 };
 } // end anonymous namespace
 
+namespace {
+struct TestDerivedAttributeDriver
+    : public PassWrapper<TestDerivedAttributeDriver, FunctionPass> {
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void TestDerivedAttributeDriver::runOnFunction() {
+  getFunction().walk([](DerivedAttributeOpInterface dOp) {
+    auto dAttr = dOp.materializeDerivedAttributes();
+    if (!dAttr)
+      return;
+    for (auto d : dAttr)
+      dOp.emitRemark() << d.first << " = " << d.second;
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // Legalization Driver.
 //===----------------------------------------------------------------------===//
@@ -589,6 +607,9 @@
   mlir::PassRegistration<TestReturnTypeDriver>("test-return-type",
                                                "Run return type functions");
 
+  mlir::PassRegistration<TestDerivedAttributeDriver>(
+      "test-derived-attr", "Run test derived attributes");
+
   mlir::PassRegistration<TestPatternDriver>("test-patterns",
                                             "Run test dialect patterns");
 
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -215,6 +215,7 @@
 // DEF: if (name == "element_dtype") return true;
 // DEF: return false;
 // DEF: }
+// DEF: DerivedTypeAttrOp::materializeDerivedAttributes
 
 // Test that only default valued attributes at the end of the arguments
 // list get default values in the builder signature
diff --git a/mlir/test/mlir-tblgen/op-derived-attribute.mlir b/mlir/test/mlir-tblgen/op-derived-attribute.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-tblgen/op-derived-attribute.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt -test-derived-attr -verify-diagnostics %s | FileCheck %s --dump-input-on-failure
+
+// CHECK-LABEL: verifyDerivedAttributes
+func @verifyDerivedAttributes() {
+  // expected-remark @+2 {{element_dtype = f32}}
+  // expected-remark @+1 {{size = 320}}
+  %0 = "test.derived_type_attr"() : () -> tensor<10xf32>
+  // expected-remark @+2 {{element_dtype = i79}}
+  // expected-remark @+1 {{size = 948}}
+  %1 = "test.derived_type_attr"() : () -> tensor<12xi79>
+
+  return
+}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -396,21 +396,66 @@
     }
   }
 
-  // Generate helper method to query whether a named attribute is a derived
-  // attribute. This enables, for example, avoiding adding an attribute that
-  // overlaps with a derived attribute.
-  auto derivedAttr = make_filter_range(op.getAttributes(),
-                                       [](const NamedAttribute &namedAttr) {
-                                         return namedAttr.attr.isDerivedAttr();
-                                       });
-  if (!derivedAttr.empty()) {
+  auto derivedAttrs = make_filter_range(op.getAttributes(),
+                                        [](const NamedAttribute &namedAttr) {
+                                          return namedAttr.attr.isDerivedAttr();
+                                        });
+  if (!derivedAttrs.empty()) {
     opClass.addTrait("DerivedAttributeOpInterface::Trait");
-    auto &method = opClass.newMethod("bool", "isDerivedAttribute",
-                                     "StringRef name", OpMethod::MP_Static);
-    auto &body = method.body();
-    for (auto namedAttr : derivedAttr)
-      body << " if (name == \"" << namedAttr.name << "\") return true;\n";
-    body << " return false;";
+    // Generate helper method to query whether a named attribute is a derived
+    // attribute. This enables, for example, avoiding adding an attribute that
+    // overlaps with a derived attribute.
+    {
+      auto &method = opClass.newMethod("bool", "isDerivedAttribute",
+                                       "StringRef name", OpMethod::MP_Static);
+      auto &body = method.body();
+      for (auto namedAttr : derivedAttrs)
+        body << " if (name == \"" << namedAttr.name << "\") return true;\n";
+      body << " return false;";
+    }
+    // Generate method to materialize derived attributes as a DictionaryAttr.
+    {
+      OpMethod &method =
+          opClass.newMethod("DictionaryAttr", "materializeDerivedAttributes");
+      auto &body = method.body();
+
+      auto nonMaterializable =
+          make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
+            return namedAttr.attr.getConvertFromStorageCall().empty();
+          });
+      if (!nonMaterializable.empty()) {
+        std::string attrs;
+        llvm::raw_string_ostream os(attrs);
+        interleaveComma(nonMaterializable, os,
+                        [&](const NamedAttribute &attr) { os << attr.name; });
+        PrintWarning(
+            op.getLoc(),
+            formatv(
+                "op has non-materializable derived attributes '{0}', skipping",
+                os.str()));
+        body << formatv(" emitOpError(\"op has non-materializable derived "
+                        "attributes '{0}'\");\n",
+                        attrs);
+        body << " return nullptr;";
+        return;
+      }
+
+      body << " MLIRContext* ctx = getContext();\n";
+      body << " Builder odsBuilder(ctx); (void)odsBuilder;\n";
+      body << " return DictionaryAttr::get({\n";
+      interleave(
+          derivedAttrs, body,
+          [&](const NamedAttribute &namedAttr) {
+            auto tmpl = namedAttr.attr.getConvertFromStorageCall();
+            body << " {Identifier::get(\"" << namedAttr.name << "\", ctx),\n"
+                 << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
+                                      .withBuilder("odsBuilder")
+                                      .addSubst("_ctx", "ctx"))
+                 << "}";
+          },
+          ",\n");
+      body << "\n }, ctx);";
+    }
   }
 }
 
@@ -1115,16 +1160,14 @@
     body << " " << builderOpState
          << ".addAttribute(\"operand_segment_sizes\", "
             "odsBuilder->getI32VectorAttr({";
-    llvm::interleaveComma(
-        llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
-          if (op.getOperand(i).isOptional())
-            body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
-          else if (op.getOperand(i).isVariadic())
-            body << "static_cast<int32_t>(" << getArgumentName(op, i)
-                 << ".size())";
-          else
-            body << "1";
-        });
+    interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
+      if (op.getOperand(i).isOptional())
+        body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
+      else if (op.getOperand(i).isVariadic())
+        body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
+      else
+        body << "1";
+    });
     body << "}));\n";
   }
 
@@ -1222,10 +1265,10 @@
       continue;
     std::string args;
     llvm::raw_string_ostream os(args);
-    llvm::interleaveComma(method.getArguments(), os,
-                          [&](const OpInterfaceMethod::Argument &arg) {
-                            os << arg.type << " " << arg.name;
-                          });
+    interleaveComma(method.getArguments(), os,
+                    [&](const OpInterfaceMethod::Argument &arg) {
+                      os << arg.type << " " << arg.name;
+                    });
     opClass.newMethod(method.getReturnType(), method.getName(), os.str(),
                       method.isStatic() ? OpMethod::MP_Static
                                         : OpMethod::MP_None,
@@ -1776,7 +1819,7 @@
 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
   IfDefScope scope("GET_OP_LIST", os);
 
-  llvm::interleave(
+  interleave(
       // TODO: We are constructing the Operator wrapper instance just for
       // getting it's qualified class name here. Reduce the overhead by having a
       // lightweight version of Operator class just for that purpose.