diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td @@ -161,7 +161,7 @@ def IRDL_OperationOp : IRDL_Op<"operation", [HasParent<"DialectOp">, NoTerminator, NoRegionArguments, - AtMostOneChildOf<"OperandsOp, ResultsOp">, Symbol]> { + AtMostOneChildOf<"OperandsOp, ResultsOp, AttributesOp">, Symbol]> { let summary = "Define a new operation"; let description = [{ `irdl.operation` defines a new operation belonging to the `irdl.dialect` @@ -260,6 +260,38 @@ let assemblyFormat = " `(` $args `)` attr-dict "; } +def IRDL_AttributesOp : IRDL_Op<"attributes", [HasParent<"OperationOp">]> { + let summary = "Define the attributes of an operation"; + + let description = [{ + `irdl.attributes` defines the attributes of the `irdl.operation` parent + operation definition. + + In the following example, `irdl.attributes` defines the attributes of the + `attr_op` operation: + + ```mlir + irdl.dialect @example { + + irdl.operation @attr_op { + %0 = irdl.any + %1 = irdl.is i64 + irdl.attributes { + "attr1" = %0, + "attr2" = %1 + } + } + } + ``` + + The operation will expect attribute "attr1" of any type and attribute "attr2" + of type `i64`. 
+ }]; + + let arguments = (ins Variadic:$attributeValues, StrArrayAttr:$attributeValueNames); + let assemblyFormat = "custom($attributeValues, $attributeValueNames) attr-dict "; +} + //===----------------------------------------------------------------------===// // IRDL Constraint operations //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp --- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp +++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp @@ -71,6 +71,40 @@ return success(); } +static ParseResult +parseAttributesOp(OpAsmParser &p, + SmallVectorImpl &attrOperands, + ArrayAttr &attrNamesAttr) { + Builder &builder = p.getBuilder(); + SmallVector attrNames; + if (succeeded(p.parseOptionalLBrace())) { + auto parseOperands = [&]() { + StringAttr nameAttr; + OpAsmParser::UnresolvedOperand operand; + if (p.parseAttribute(nameAttr) || p.parseEqual() || + p.parseOperand(operand)) + return failure(); + attrNames.push_back(nameAttr); + attrOperands.push_back(operand); + return success(); + }; + if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace()) + return failure(); + } + attrNamesAttr = builder.getArrayAttr(attrNames); + return success(); +} + +static void printAttributesOp(OpAsmPrinter &p, AttributesOp op, + OperandRange attrArgs, ArrayAttr attrNames) { + if (attrNames.empty()) + return; + p << "{"; + interleaveComma(llvm::seq(0, attrNames.size()), p, + [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); + p << '}'; +} + #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc" #define GET_TYPEDEF_CLASSES diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp --- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp +++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp @@ -13,13 +13,19 @@ #include "mlir/Dialect/IRDL/IRDLLoading.h" #include "mlir/Dialect/IRDL/IR/IRDL.h" #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h" +#include "mlir/IR/Attributes.h" 
+#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/ExtensibleDialect.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/SMLoc.h" +#include + using namespace mlir; using namespace mlir::irdl; @@ -52,8 +58,8 @@ /// with IRDL. static LogicalResult irdlOpVerifier(Operation *op, ArrayRef> constraints, - ArrayRef operandConstrs, - ArrayRef resultConstrs) { + ArrayRef operandConstrs, ArrayRef resultConstrs, + const DenseMap &attributeConstrs) { /// Check that we have the right number of operands. unsigned numOperands = op->getNumOperands(); size_t numExpectedOperands = operandConstrs.size(); @@ -68,6 +74,29 @@ return op->emitOpError() << numExpectedResults << " results expected, but got " << numResults; + /// Check that we have the right set of attributes + auto attrs = op->getAttrs(); + /// First, check if we do not have excessive attributes being passed + for (auto attr : attrs) { + if (!attributeConstrs.contains(attr.getName())) { + return op->emitOpError() + << "attribute " << attr.getName() << " is not expected"; + } + } + /// Then, check that all defined attributes are actually passed + for (auto attrDef : attributeConstrs) { + const auto attrDefName = attrDef.getFirst().getValue(); + const auto *actualAttrIt = + std::find_if(attrs.begin(), attrs.end(), + [attrDefName](const NamedAttribute &actualAttr) { + return actualAttr.getName().getValue() == attrDefName; + }); + if (actualAttrIt == attrs.end()) { + return op->emitOpError() << "attribute \"" << attrDefName + << "\" is expected but not provided"; + } + } + auto emitError = [op]() { return op->emitError(); }; ConstraintVerifier verifier(constraints); @@ -84,6 +113,17 @@ resultConstrs[i]))) return failure(); + /// Check that all attributes satisfy the constraints. 
+ for (const NamedAttribute &attr : attrs) { + const auto constraint = attributeConstrs.find(attr.getName())->getSecond(); + Type type; + attr.getValue().walkImmediateSubElements([](auto &&...) {}, + [&](Type t) { type = t; }); + if (failed(verifier.verify({emitError}, TypeAttr::get(type), constraint))) { + return failure(); + } + } + return success(); } @@ -147,6 +187,23 @@ } } + // Gather which constraint slots correspond to attributes constraints + DenseMap attributesContraints; + auto attributesOp = op.getOp(); + if (attributesOp.has_value()) { + const auto values = attributesOp->getAttributeValues(); + const auto names = attributesOp->getAttributeValueNames(); + + for (const auto &[name, value] : llvm::zip(names, values)) { + for (auto [i, constr] : enumerate(constrToValue)) { + if (constr == value) { + attributesContraints[name.cast()] = i; + break; + } + } + } + } + // IRDL does not support defining custom parsers or printers. auto parser = [](OpAsmParser &parser, OperationState &result) { return failure(); @@ -158,9 +215,10 @@ auto verifier = [constraints{std::move(constraints)}, operandConstraints{std::move(operandConstraints)}, - resultConstraints{std::move(resultConstraints)}](Operation *op) { + resultConstraints{std::move(resultConstraints)}, + attributesContraints{std::move(attributesContraints)}](Operation *op) { return irdlOpVerifier(op, constraints, operandConstraints, - resultConstraints); + resultConstraints, attributesContraints); }; // IRDL does not support defining regions. 
diff --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir --- a/mlir/test/Dialect/IRDL/testd.irdl.mlir +++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir @@ -104,4 +104,20 @@ %2 = irdl.any_of(%0, %1) irdl.results(%2, %2) } + + + // CHECK: irdl.operation @attrs { + // CHECK: %[[v0:[^ ]*]] = irdl.is i32 + // CHECK: %[[v1:[^ ]*]] = irdl.is i64 + // CHECK: irdl.attributes {"attr1" = %[[v0]], "attr2" = %[[v1]]} + // CHECK: } + irdl.operation @attrs { + %0 = irdl.is i32 + %1 = irdl.is i64 + + irdl.attributes { + "attr1" = %0, + "attr2" = %1 + } + } } diff --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir --- a/mlir/test/Dialect/IRDL/testd.mlir +++ b/mlir/test/Dialect/IRDL/testd.mlir @@ -198,3 +198,47 @@ "testd.constraint_vars"() : () -> (i64, i32) return } + +// ----- + +//===----------------------------------------------------------------------===// +// Constraint attributes +//===----------------------------------------------------------------------===// + +func.func @succeededAttrs() { + // CHECK: "testd.attrs"() {attr1 = 42 : i32, attr2 = 42 : i64} : () -> () + "testd.attrs"() {attr1 = 42 : i32, attr2 = 42 : i64} : () -> () + return +} + +// ----- + +func.func @failedAttrsExcessiveAttr() { + // expected-error@+1 {{attribute "attr3" is not expected}} + "testd.attrs"() {attr1 = 42 : i32, attr2 = 42 : i64, attr3 = 42.0 : f64} : () -> () + return +} + +// ----- + +func.func @failedAttrsMissingAttr() { + // expected-error@+1 {{attribute "attr2" is expected but not provided}} + "testd.attrs"() {attr1 = 42 : i32} : () -> () + return +} + +// ----- + +func.func @failedAttrsConstraint() { + // expected-error@+1 {{expected 'i32' but got 'i64'}} + "testd.attrs"() {attr1 = 42 : i64, attr2 = 42 : i64} : () -> () + return +} + +// ----- + +func.func @failedAttrsConstraint2() { + // expected-error@+1 {{expected 'i64' but got 'i32'}} + "testd.attrs"() {attr1 = 42 : i32, attr2 = 42 : i32} : () -> () + return +}