diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td @@ -15,20 +15,16 @@ include "mlir/IR/OpBase.td" -def VerifyConstraintInterface : OpInterface<"VerifyConstraintInterface"> { +class VerifyInterface : OpInterface<"Verify" # name # "Interface"> { let cppNamespace = "::mlir::irdl"; - let description = [{ - Interface to get an IRDL constraint verifier from an operation. - }]; - let methods = [ InterfaceMethod< [{ Get an instance of a constraint verifier for the associated operation." Returns `nullptr` upon failure. }], - "std::unique_ptr<::mlir::irdl::Constraint>", + "std::unique_ptr<::mlir::irdl::" # return_type # ">", "getVerifier", (ins "::mlir::ArrayRef":$valueToConstr, "::mlir::DenseMap<::mlir::irdl::TypeOp, std::unique_ptr<::mlir::DynamicTypeDefinition>> const&":$types, @@ -37,4 +33,20 @@ ]; } +def VerifyConstraintInterface : VerifyInterface<"Constraint", "Constraint"> { + let cppNamespace = "::mlir::irdl"; + + let description = [{ + Interface to get an IRDL constraint verifier from an operation. + }]; +} + +def VerifyRegionInterface : VerifyInterface<"Region", "RegionConstraint"> { + let cppNamespace = "::mlir::irdl"; + + let description = [{ + Interface to get an IRDL region verifier from an operation. 
+ }]; +} + #endif // MLIR_DIALECT_IRDL_IR_IRDLINTERFACES diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td @@ -161,7 +161,8 @@ def IRDL_OperationOp : IRDL_Op<"operation", [HasParent<"DialectOp">, NoTerminator, NoRegionArguments, - AtMostOneChildOf<"OperandsOp, ResultsOp, AttributesOp">, Symbol]> { + AtMostOneChildOf<"OperandsOp, ResultsOp, AttributesOp, RegionsOp">, + Symbol]> { let summary = "Define a new operation"; let description = [{ `irdl.operation` defines a new operation belonging to the `irdl.dialect` @@ -297,6 +298,79 @@ let hasVerifier = true; } +def IRDL_RegionOp : IRDL_Op<"region", + [HasParent<"OperationOp">, VerifyRegionInterface, + DeclareOpInterfaceMethods]> { + let summary = "Define a region of an operation"; + let description = [{ + `irdl.region` defines a set of characteristics that + a region of an operation should have. + The characteristics are the set of constraints for the entry block of + a region and the total number of blocks. The number of blocks must be + a positive integer. The number of blocks is + optional and defaults to 1. + + + Example: + + ```mlir + irdl.dialect @example { + irdl.operation @op_with_regions { + %r1 = irdl.region with size 3 + %0 = irdl.any + %r2 = irdl.region(%0) + irdl.regions(%r1, %r2) + } + } + ``` + + In the snippet above the operation is constrained to have two regions. + The first region should contain three blocks with no arguments + in the first one. The second region should have one block + with one argument. + }]; + let arguments = (ins Variadic:$entryBlockArgs, + OptionalAttr:$numberOfBlocks, + UnitAttr:$constrainedArguments); + let results = (outs IRDL_RegionType:$output); + + let assemblyFormat = [{ + ``(`(` $entryBlockArgs $constrainedArguments^ `)`)? + ``(` ` `with` `size` $numberOfBlocks^)? 
attr-dict + }]; + + let hasVerifier = true; +} + +def IRDL_RegionsOp : IRDL_Op<"regions", [HasParent<"OperationOp">]> { + let summary = "Define the regions of an operation"; + let description = [{ + `irdl.regions` defines the regions of an operation by accepting + values produced by `irdl.region` operations as arguments. + + Example: + + ```mlir + irdl.dialect @example { + irdl.operation @op_with_regions { + %r1 = irdl.region with size 3 + %0 = irdl.any + %r2 = irdl.region(%0) + irdl.regions(%r1, %r2) + } + } + ``` + + In the snippet above the operation is constrained to have two regions. + The first region should contain three blocks with no arguments + in the first one. The second region should have one block + with one argument. + }]; + + let arguments = (ins Variadic:$args); + let assemblyFormat = " `(` $args `)` attr-dict "; +} + //===----------------------------------------------------------------------===// // IRDL Constraint operations //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td @@ -51,4 +51,11 @@ }]; } +def IRDL_RegionType : IRDL_Type<"Region", "region"> { + let summary = "IRDL handle to a region definition"; + let description = [{ + Type of the values produced by `irdl.region`; it is written `!irdl.region`. + }]; +} + #endif // MLIR_DIALECT_IRDL_IR_IRDLTYPES diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h --- a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h +++ b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h @@ -14,8 +14,10 @@ #define MLIR_DIALECT_IRDL_IRDLVERIFIERS_H #include "mlir/IR/Attributes.h" +#include "mlir/IR/Region.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" #include namespace mlir { @@ -178,6 +180,20 @@ ConstraintVerifier &context) const override; }; 
+struct RegionConstraint { + explicit RegionConstraint(bool constrainArguments, + SmallVector argumentConstraints, + std::optional blockCount); + + LogicalResult verify(function_ref emitError, + mlir::Region ®ion, + ConstraintVerifier &constraintContext); + +private: + bool constrainArguments; + SmallVector argumentConstraints; + std::optional blockCount; +}; } // namespace irdl } // namespace mlir diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp --- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp +++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp @@ -9,10 +9,13 @@ #include "mlir/Dialect/IRDL/IR/IRDL.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" @@ -114,6 +117,16 @@ p << '}'; } +LogicalResult RegionOp::verify() { + if (IntegerAttr numberOfBlocks = getNumberOfBlocksAttr()) { + if (int64_t number = numberOfBlocks.getInt(); number <= 0) { + return emitOpError("the number of blocks is expected to be >= 1 but got ") + << number; + } + } + return success(); +} + #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc" #define GET_TYPEDEF_CLASSES diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp --- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp +++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp @@ -7,10 +7,26 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/IRDL/IR/IRDL.h" +#include "mlir/IR/ValueRange.h" using namespace mlir; using namespace mlir::irdl; +static SmallVector +getConstraintIndicesForArgs(mlir::OperandRange args, + ArrayRef valueToConstr) { + SmallVector constraints; + for (Value arg : args) { 
+ for (auto [i, value] : enumerate(valueToConstr)) { + if (value == arg) { + constraints.push_back(i); + break; + } + } + } + return constraints; +} + std::unique_ptr IsOp::getVerifier( ArrayRef valueToConstr, DenseMap> const &types, @@ -24,15 +40,8 @@ DenseMap> const &types, DenseMap> const &attrs) { - SmallVector constraints; - for (Value arg : getArgs()) { - for (auto [i, value] : enumerate(valueToConstr)) { - if (value == arg) { - constraints.push_back(i); - break; - } - } - } + SmallVector constraints = + getConstraintIndicesForArgs(getArgs(), valueToConstr); // Symbol reference case for the base SymbolRefAttr symRef = getBaseType(); @@ -60,17 +69,8 @@ DenseMap> const &types, DenseMap> const &attrs) { - SmallVector constraints; - for (Value arg : getArgs()) { - for (auto [i, value] : enumerate(valueToConstr)) { - if (value == arg) { - constraints.push_back(i); - break; - } - } - } - - return std::make_unique(constraints); + return std::make_unique( + getConstraintIndicesForArgs(getArgs(), valueToConstr)); } std::unique_ptr AllOfOp::getVerifier( @@ -78,17 +78,8 @@ DenseMap> const &types, DenseMap> const &attrs) { - SmallVector constraints; - for (Value arg : getArgs()) { - for (auto [i, value] : enumerate(valueToConstr)) { - if (value == arg) { - constraints.push_back(i); - break; - } - } - } - - return std::make_unique(constraints); + return std::make_unique( + getConstraintIndicesForArgs(getArgs(), valueToConstr)); } std::unique_ptr AnyOp::getVerifier( @@ -98,3 +89,14 @@ &attrs) { return std::make_unique(); } + +std::unique_ptr RegionOp::getVerifier( + ArrayRef valueToConstr, + DenseMap> const &types, + DenseMap> const + &attrs) { + return std::make_unique( + getConstrainedArguments(), + getConstraintIndicesForArgs(getEntryBlockArgs(), valueToConstr), + getNumberOfBlocks()); +} diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp --- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp +++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp @@ 
-13,6 +13,7 @@ #include "mlir/Dialect/IRDL/IRDLLoading.h" #include "mlir/Dialect/IRDL/IR/IRDL.h" #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h" +#include "mlir/Dialect/IRDL/IRDLVerifiers.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/ExtensibleDialect.h" @@ -105,6 +106,29 @@ return success(); } +static LogicalResult irdlRegionVerifier( + Operation *op, ArrayRef> constraints, + ArrayRef> regionsConstraints) { + if (op->getNumRegions() != regionsConstraints.size()) { + return op->emitOpError() + << "unexpected number of regions: expected: " + << regionsConstraints.size() << " but got: " << op->getNumRegions(); + } + + auto emitError = [op] { return op->emitError(); }; + + ConstraintVerifier verifier(constraints); + + for (auto [constraint, region] : + llvm::zip(regionsConstraints, op->getRegions())) { + if (failed(constraint->verify(emitError, region, verifier))) { + return failure(); + } + } + + return success(); +} + /// Define and load an operation represented by a `irdl.operation` /// operation. 
static WalkResult loadOperation( @@ -113,6 +137,7 @@ DenseMap> &attrs) { // Resolve SSA values to verifier constraint slots SmallVector constrToValue; + SmallVector regionToValue; for (Operation &op : op->getRegion(0).getOps()) { if (isa(op)) { if (op.getNumResults() != 1) @@ -120,18 +145,35 @@ << "IRDL constraint operations must have exactly one result"; constrToValue.push_back(op.getResult(0)); } + if (isa(op)) { + if (op.getNumResults() != 1) + return op.emitError() + << "IRDL constraint operations must have exactly one result"; + regionToValue.push_back(op.getResult(0)); + } } - // Build the verifiers for each constraint slot - SmallVector> constraints; - for (Value v : constrToValue) { - VerifyConstraintInterface op = - cast(v.getDefiningOp()); - std::unique_ptr verifier = + // Function to build the verifiers for each constraint slot + using Constraints = SmallVector>; + const auto getConstraints = [&] { + Constraints constraints; + for (Value v : constrToValue) { + VerifyConstraintInterface op = + cast(v.getDefiningOp()); + std::unique_ptr verifier = + op.getVerifier(constrToValue, types, attrs); + constraints.push_back(std::move(verifier)); + } + return constraints; + }; + + // Build region constraints + SmallVector> regionConstraints; + for (Value v : regionToValue) { + VerifyRegionInterface op = cast(v.getDefiningOp()); + std::unique_ptr verifier = op.getVerifier(constrToValue, types, attrs); - if (!verifier) - return WalkResult::interrupt(); - constraints.push_back(std::move(verifier)); + regionConstraints.push_back(std::move(verifier)); } SmallVector operandConstraints; @@ -190,6 +232,13 @@ printer.printGenericOp(op); }; + Constraints constraints = getConstraints(); + // We need to check all the constraints once to be valid + for (const auto &constraint : constraints) { + if (!constraint) { + return WalkResult::interrupt(); + } + } auto verifier = [constraints{std::move(constraints)}, operandConstraints{std::move(operandConstraints)}, @@ -199,8 +248,11 
@@ resultConstraints, attributesContraints); }; - // IRDL does not support defining regions. - auto regionVerifier = [](Operation *op) { return success(); }; + auto regionVerifier = + [constraints{getConstraints()}, + regionConstraints{std::move(regionConstraints)}](Operation *op) { + return irdlRegionVerifier(op, constraints, regionConstraints); + }; auto opDef = DynamicOpDefinition::get( op.getName(), dialect, std::move(verifier), std::move(regionVerifier), diff --git a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp --- a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp +++ b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp @@ -175,3 +175,18 @@ ConstraintVerifier &context) const { return success(); } + +RegionConstraint::RegionConstraint(bool constrainArguments, + SmallVector argumentConstraints, + std::optional blockCount) + : constrainArguments(constrainArguments), + argumentConstraints(std::move(argumentConstraints)), + blockCount(blockCount) {} + +LogicalResult +RegionConstraint::verify(function_ref emitError, + mlir::Region ®ion, + ConstraintVerifier &constraintContext) { + + return success(); +} diff --git a/mlir/test/Dialect/IRDL/regions.irdl.mlir b/mlir/test/Dialect/IRDL/regions.irdl.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/IRDL/regions.irdl.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-opt %s -verify-diagnostics -split-input-file + +irdl.dialect @testRegionOpNegativeNumber { + irdl.operation @op { + // expected-error @below {{'irdl.region' op the number of blocks is expected to be >= 1 but got -42}} + %r1 = irdl.region with size -42 + } +} + +// ----- + +irdl.dialect @testRegionsOpWrongOperation { + irdl.operation @op { + // expected-note @below {{prior use here}} + %r1 = irdl.any + // expected-error @below {{use of value '%r1' expects different type than prior uses: '!irdl.region' vs '!irdl.attribute'}} + irdl.regions(%r1) + } +} diff --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir 
--- a/mlir/test/Dialect/IRDL/testd.irdl.mlir +++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir @@ -119,4 +119,19 @@ "attr2" = %1 } } + // CHECK: irdl.operation @regions { + // CHECK: %[[r0:[^ ]*]] = irdl.region with size 1 + // CHECK: %[[v0:[^ ]*]] = irdl.any + // CHECK: %[[r1:[^ ]*]] = irdl.region(%[[v0]]) + // CHECK: %[[r2:[^ ]*]] = irdl.region with size 3 + // CHECK: irdl.regions(%[[r0]], %[[r1]], %[[r2]]) + // CHECK: } + irdl.operation @regions { + %r0 = irdl.region with size 1 + %v0 = irdl.any + %r1 = irdl.region(%v0) + %r2 = irdl.region with size 3 + + irdl.regions(%r0, %r1, %r2) + } } diff --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir --- a/mlir/test/Dialect/IRDL/testd.mlir +++ b/mlir/test/Dialect/IRDL/testd.mlir @@ -234,3 +234,30 @@ "testd.attrs"() {attr1 = i32, attr2 = i32} : () -> () return } + +// ----- + +//===----------------------------------------------------------------------===// +// Regions +//===----------------------------------------------------------------------===// + +func.func @succeededRegions() { + "testd.regions"() ({ + ^bb1: + llvm.unreachable + }, + { + ^bb1(%arg0: i32): + llvm.unreachable + }, + { + ^bb1: + cf.br ^bb3 + ^bb2: + cf.br ^bb3 + ^bb3: + llvm.unreachable + }) : () -> () + + return +}