[mlir][irdl] Add `irdl.attributes` to constrain operation attributes

`irdl.attributes` names the attributes an IRDL-defined operation expects and
pairs each name with a constraint value. Operations loaded from IRDL are
verified to carry every listed attribute, and each attribute value must
satisfy its constraint. Attribute constraints are checked in declaration
order (not hash-map order) so that the reported diagnostic is deterministic.

diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -161,7 +161,7 @@
 
 def IRDL_OperationOp : IRDL_Op<"operation",
     [HasParent<"DialectOp">, NoTerminator, NoRegionArguments,
-    AtMostOneChildOf<"OperandsOp, ResultsOp">, Symbol]> {
+    AtMostOneChildOf<"OperandsOp, ResultsOp, AttributesOp">, Symbol]> {
   let summary = "Define a new operation";
   let description = [{
     `irdl.operation` defines a new operation belonging to the `irdl.dialect`
@@ -260,6 +260,43 @@
   let assemblyFormat = " `(` $args `)` attr-dict ";
 }
 
+def IRDL_AttributesOp : IRDL_Op<"attributes", [HasParent<"OperationOp">]> {
+  let summary = "Define the attributes of an operation";
+
+  let description = [{
+    `irdl.attributes` defines the attributes of the `irdl.operation` parent
+    operation definition.
+
+    In the following example, `irdl.attributes` defines the attributes of the
+    `attr_op` operation:
+
+    ```mlir
+    irdl.dialect @example {
+
+      irdl.operation @attr_op {
+        %0 = irdl.any
+        %1 = irdl.is i64
+        irdl.attributes {
+          "attr1" = %0,
+          "attr2" = %1
+        }
+      }
+    }
+    ```
+
+    The operation will expect an arbitrary attribute "attr1" and an
+    attribute "attr2" with value `i64`.
+  }];
+
+  let arguments = (ins Variadic<IRDL_AttributeType>:$attributeValues,
+                       StrArrayAttr:$attributeValueNames);
+  let assemblyFormat = [{
+    custom<AttributesOp>($attributeValues, $attributeValueNames) attr-dict
+  }];
+
+  let hasVerifier = true;
+}
+
 //===----------------------------------------------------------------------===//
 // IRDL Constraint operations
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -71,6 +71,55 @@
   return success();
 }
 
+/// Verifies that every attribute name is paired with exactly one constraint
+/// operand, i.e. that both lists have the same length.
+LogicalResult AttributesOp::verify() {
+  const size_t namesSize = getAttributeValueNames().size();
+  const size_t valuesSize = getAttributeValues().size();
+
+  if (namesSize != valuesSize)
+    return emitOpError()
+           << "the number of attribute names and their constraints must be "
+              "the same but got "
+           << namesSize << " and " << valuesSize << " respectively";
+
+  return success();
+}
+
+/// Parses the optional `{"name" = %constraint, ...}` entry list of
+/// `irdl.attributes`; an absent list declares no attributes.
+static ParseResult
+parseAttributesOp(OpAsmParser &p,
+                  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
+                  ArrayAttr &attrNamesAttr) {
+  Builder &builder = p.getBuilder();
+  SmallVector<Attribute> attrNames;
+  if (succeeded(p.parseOptionalLBrace())) {
+    auto parseOperands = [&]() {
+      if (p.parseAttribute(attrNames.emplace_back()) || p.parseEqual() ||
+          p.parseOperand(attrOperands.emplace_back()))
+        return failure();
+      return success();
+    };
+    if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
+      return failure();
+  }
+  attrNamesAttr = builder.getArrayAttr(attrNames);
+  return success();
+}
+
+/// Prints the entry list of `irdl.attributes` in the form accepted by
+/// `parseAttributesOp`; prints nothing when no attributes are declared.
+static void printAttributesOp(OpAsmPrinter &p, AttributesOp op,
+                              OperandRange attrArgs, ArrayAttr attrNames) {
+  if (attrNames.empty())
+    return;
+  p << '{';
+  interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
+                  [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
+  p << '}';
+}
+
 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
 
 #define GET_TYPEDEF_CLASSES
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -13,8 +13,10 @@
 #include "mlir/Dialect/IRDL/IRDLLoading.h"
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -52,8 +54,8 @@
 /// with IRDL.
 static LogicalResult
 irdlOpVerifier(Operation *op, ArrayRef<std::unique_ptr<Constraint>> constraints,
-               ArrayRef<size_t> operandConstrs,
-               ArrayRef<size_t> resultConstrs) {
+               ArrayRef<size_t> operandConstrs, ArrayRef<size_t> resultConstrs,
+               ArrayRef<std::pair<StringAttr, size_t>> attributeConstrs) {
   /// Check that we have the right number of operands.
   unsigned numOperands = op->getNumOperands();
   size_t numExpectedOperands = operandConstrs.size();
@@ -68,10 +70,26 @@
     return op->emitOpError() << numExpectedResults
                              << " results expected, but got " << numResults;
 
-  auto emitError = [op]() { return op->emitError(); };
+  auto emitError = [op] { return op->emitError(); };
 
   ConstraintVerifier verifier(constraints);
 
+  /// Check, in declaration order, that all required attributes are present
+  /// and satisfy their constraints.
+  DictionaryAttr actualAttrs = op->getAttrDictionary();
+
+  for (auto [name, constraint] : attributeConstrs) {
+    /// First, check that the attribute was actually provided.
+    std::optional<NamedAttribute> actual = actualAttrs.getNamed(name);
+    if (!actual.has_value())
+      return op->emitOpError()
+             << "attribute " << name << " is expected but not provided";
+
+    /// Then, check if the attribute value satisfies the constraint.
+    if (failed(verifier.verify({emitError}, actual->getValue(), constraint)))
+      return failure();
+  }
+
   /// Check that all operands satisfy the constraints.
   for (auto [i, operandType] : enumerate(op->getOperandTypes()))
     if (failed(verifier.verify({emitError}, TypeAttr::get(operandType),
@@ -147,6 +165,23 @@
     }
   }
 
+  // Gather the constraint slot of each attribute, in declaration order.
+  SmallVector<std::pair<StringAttr, size_t>> attributesConstraints;
+  auto attributesOp = op.getOp<AttributesOp>();
+  if (attributesOp.has_value()) {
+    const Operation::operand_range values = attributesOp->getAttributeValues();
+    const ArrayAttr names = attributesOp->getAttributeValueNames();
+
+    for (const auto &[name, value] : llvm::zip(names, values)) {
+      for (auto [i, constr] : enumerate(constrToValue)) {
+        if (constr == value) {
+          attributesConstraints.emplace_back(name.cast<StringAttr>(), i);
+          break;
+        }
+      }
+    }
+  }
+
   // IRDL does not support defining custom parsers or printers.
   auto parser = [](OpAsmParser &parser, OperationState &result) {
     return failure();
@@ -158,9 +193,10 @@
 
   auto verifier = [constraints{std::move(constraints)},
                    operandConstraints{std::move(operandConstraints)},
-                   resultConstraints{std::move(resultConstraints)}](Operation *op) {
+                   resultConstraints{std::move(resultConstraints)},
+                   attributesConstraints{std::move(attributesConstraints)}](Operation *op) {
     return irdlOpVerifier(op, constraints, operandConstraints,
-                          resultConstraints);
+                          resultConstraints, attributesConstraints);
   };
 
   // IRDL does not support defining regions.
diff --git a/mlir/test/Dialect/IRDL/attributes-op.irdl.mlir b/mlir/test/Dialect/IRDL/attributes-op.irdl.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/attributes-op.irdl.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file
+// Check that irdl.attributes rejects mismatched name and constraint counts.
+irdl.dialect @errors {
+  irdl.operation @attrs1 {
+    %0 = irdl.is i32
+    %1 = irdl.is i64
+
+    // expected-error@+1 {{'irdl.attributes' op the number of attribute names and their constraints must be the same but got 1 and 2 respectively}}
+    "irdl.attributes"(%0, %1) <{attributeValueNames = ["attr1"]}> : (!irdl.attribute, !irdl.attribute) -> ()
+  }
+}
+
+// -----
+
+irdl.dialect @errors {
+  irdl.operation @attrs2 {
+    %0 = irdl.is i32
+    %1 = irdl.is i64
+
+    // expected-error@+1 {{'irdl.attributes' op the number of attribute names and their constraints must be the same but got 2 and 1 respectively}}
+    "irdl.attributes"(%0) <{attributeValueNames = ["attr1", "attr2"]}> : (!irdl.attribute) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir
--- a/mlir/test/Dialect/IRDL/testd.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir
@@ -104,4 +104,19 @@
     %2 = irdl.any_of(%0, %1)
     irdl.results(%2, %2)
   }
+
+  // CHECK: irdl.operation @attrs {
+  // CHECK:   %[[v0:[^ ]*]] = irdl.is i32
+  // CHECK:   %[[v1:[^ ]*]] = irdl.is i64
+  // CHECK:   irdl.attributes {"attr1" = %[[v0]], "attr2" = %[[v1]]}
+  // CHECK: }
+  irdl.operation @attrs {
+    %0 = irdl.is i32
+    %1 = irdl.is i64
+
+    irdl.attributes {
+      "attr1" = %0,
+      "attr2" = %1
+    }
+  }
 }
diff --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir
--- a/mlir/test/Dialect/IRDL/testd.mlir
+++ b/mlir/test/Dialect/IRDL/testd.mlir
@@ -198,3 +198,39 @@
   "testd.constraint_vars"() : () -> (i64, i32)
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Constraint attributes
+//===----------------------------------------------------------------------===//
+
+func.func @succeededAttrs() {
+  // CHECK: "testd.attrs"() {attr1 = i32, attr2 = i64} : () -> ()
+  "testd.attrs"() {attr1 = i32, attr2 = i64} : () -> ()
+  return
+}
+
+// -----
+
+func.func @failedAttrsMissingAttr() {
+  // expected-error@+1 {{attribute "attr2" is expected but not provided}}
+  "testd.attrs"() {attr1 = i32} : () -> ()
+  return
+}
+
+// -----
+
+func.func @failedAttrsConstraint() {
+  // expected-error@+1 {{expected 'i32' but got 'i64'}}
+  "testd.attrs"() {attr1 = i64, attr2 = i64} : () -> ()
+  return
+}
+
+// -----
+
+func.func @failedAttrsConstraint2() {
+  // expected-error@+1 {{expected 'i64' but got 'i32'}}
+  "testd.attrs"() {attr1 = i32, attr2 = i32} : () -> ()
+  return
+}