diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -372,5 +372,74 @@
   let assemblyFormat = [{ ` ` attr-dict }];
 }
 
+def IRDL_AnyOf : IRDL_ConstraintOp<"any_of",
+                  [ParentOneOf<["TypeOp", "AttributeOp", "OperationOp"]>,
+                   SameOperandsAndResultType]> {
+  let summary = "Constraints to the union of the provided constraints";
+  let description = [{
+    `irdl.any_of` defines a constraint that accepts any type or attribute that
+    satisfies at least one of its provided constraints.
+
+    Example:
+
+    ```mlir
+    irdl.dialect @cmath {
+      irdl.type @complex {
+        %0 = irdl.is i32
+        %1 = irdl.is i64
+        %2 = irdl.is f32
+        %3 = irdl.is f64
+        %4 = irdl.any_of(%0, %1, %2, %3)
+        irdl.parameters(%4)
+      }
+    }
+    ```
+
+    The above program defines a type `complex` inside the dialect `cmath` that
+    can have a single type parameter that can be either `i32`, `i64`, `f32` or
+    `f64`.
+  }];
+
+  let arguments = (ins Variadic<IRDL_AttributeType>:$args);
+  let results = (outs IRDL_AttributeType:$output);
+  let assemblyFormat = [{ `(` $args `)` ` ` attr-dict }];
+}
+
+def IRDL_AllOf : IRDL_ConstraintOp<"all_of",
+                  [ParentOneOf<["TypeOp", "AttributeOp", "OperationOp"]>,
+                   SameOperandsAndResultType]> {
+  let summary = "Constraints to the intersection of the provided constraints";
+  let description = [{
+    `irdl.all_of` defines a constraint that accepts any type or attribute that
+    satisfies all of its provided constraints.
+
+    Example:
+
+    ```mlir
+    irdl.dialect @cmath {
+      irdl.type @complex_f32 {
+        %0 = irdl.is i32
+        %1 = irdl.is f32
+        %2 = irdl.any_of(%0, %1) // is 32-bit
+
+        %3 = irdl.is f32
+        %4 = irdl.is f64
+        %5 = irdl.any_of(%3, %4) // is a float
+
+        %6 = irdl.all_of(%2, %5) // is a 32-bit float
+        irdl.parameters(%6)
+      }
+    }
+    ```
+
+    The above program defines a type `complex_f32` inside the dialect `cmath`
+    that can have one parameter that must be 32-bit long and a float (in other
+    words, that must be `f32`).
+  }];
+
+  let arguments = (ins Variadic<IRDL_AttributeType>:$args);
+  let results = (outs IRDL_AttributeType:$output);
+  let assemblyFormat = [{ `(` $args `)` ` ` attr-dict }];
+}
 
 #endif // MLIR_DIALECT_IRDL_IR_IRDLOPS
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -65,6 +65,40 @@
                      "either a type or an attribute definition");
 }
 
+std::unique_ptr AnyOf::getVerifier(
+    SmallVector const &valueToConstr,
+    DenseMap> &types,
+    DenseMap> &attrs) {
+  SmallVector constraints;
+  for (Value arg : getArgs()) {
+    for (auto [i, value] : enumerate(valueToConstr)) {
+      if (value == arg) {
+        constraints.push_back(i);
+        break;
+      }
+    }
+  }
+
+  return std::make_unique(constraints);
+}
+
+std::unique_ptr AllOf::getVerifier(
+    SmallVector const &valueToConstr,
+    DenseMap> &types,
+    DenseMap> &attrs) {
+  SmallVector constraints;
+  for (Value arg : getArgs()) {
+    for (auto [i, value] : enumerate(valueToConstr)) {
+      if (value == arg) {
+        constraints.push_back(i);
+        break;
+      }
+    }
+  }
+
+  return std::make_unique(constraints);
+}
+
 std::unique_ptr Any::getVerifier(
     SmallVector const &valueToConstr,
     DenseMap> &types,
diff --git a/mlir/lib/Dialect/IRDL/IRDLRegistration.cpp b/mlir/lib/Dialect/IRDL/IRDLRegistration.cpp
--- a/mlir/lib/Dialect/IRDL/IRDLRegistration.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLRegistration.cpp
@@ -246,6 +246,116 @@
   return verifier;
 }
 
+/// Get the possible bases of a constraint.
+/// Return `true` if all bases can potentially be matched.
+/// A base is a type or an attribute definition. For instance, the base of
+/// `irdl.parametric "!builtin.complex"(...)` is `builtin.complex`.
+/// This function returns the following information through arguments:
+/// - `paramIds`: the set of type or attribute IDs that are used as bases.
+/// - `paramIrdlOps`: the set of IRDL operations that are used as bases.
+/// - `isIds`: the set of type or attribute IDs that are used in `irdl.is`
+///   constraints.
+/// The sets are only ever inserted into, so the bases of several constraints
+/// can be accumulated into the same sets.
+static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
+                     SmallPtrSet<Operation *, 4> &paramIrdlOps,
+                     SmallPtrSet<TypeID, 4> &isIds) {
+  // For `irdl.any_of`, we get the bases from all its arguments. The union can
+  // match any base as soon as one of its arguments can. `|=` (rather than a
+  // short-circuiting `||`) keeps visiting every argument so that all of their
+  // bases are collected.
+  if (auto anyOf = dyn_cast<AnyOf>(op)) {
+    bool hasAny = false;
+    for (Value arg : anyOf.getArgs())
+      hasAny |= getBases(arg.getDefiningOp(), paramIds, paramIrdlOps, isIds);
+    return hasAny;
+  }
+
+  // For `irdl.all_of`, we get the bases from the first argument only: the
+  // intersection cannot accept more bases than any one of its arguments.
+  // This is restrictive, but we can relax it later if needed.
+  // An `irdl.all_of` without arguments constrains nothing, so like `irdl.any`
+  // it can match any base (and must not be indexed).
+  if (auto allOf = dyn_cast<AllOf>(op)) {
+    if (allOf.getArgs().empty())
+      return true;
+    return getBases(allOf.getArgs().front().getDefiningOp(), paramIds,
+                    paramIrdlOps, isIds);
+  }
+
+  // For `irdl.parametric`, we get directly the base from the operation.
+  if (auto params = dyn_cast<Parametric>(op)) {
+    SymbolRefAttr symRef = params.getBaseType();
+    Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
+    assert(defOp && "symbol reference should refer to an existing operation");
+    paramIrdlOps.insert(defOp);
+    return false;
+  }
+
+  // For `irdl.is`, we get the base TypeID directly.
+  // NOTE(review): this records the TypeID of the attribute held in
+  // `expected`; if a type is stored wrapped in an attribute, it is the
+  // wrapper's TypeID that lands in `isIds` -- confirm this is intended.
+  if (auto is = dyn_cast<Is>(op)) {
+    Attribute expected = is.getExpected();
+    isIds.insert(expected.getTypeID());
+    return false;
+  }
+
+  // For `irdl.any`, we return `true` since it can match any type or
+  // attribute base.
+  if (isa<Any>(op))
+    return true;
+
+  llvm_unreachable("unknown IRDL constraint");
+}
+
+/// Check that an any_of is in the subset IRDL can handle.
+/// IRDL uses a greedy algorithm to match constraints.
+/// This means that if we encounter an `any_of` with multiple constraints, we
+/// will match the first constraint that is satisfied. Thus, the order of
+/// constraints matters in `any_of` with our current algorithm.
+/// In order to make the order of constraints irrelevant, we require that
+/// all `any_of` constraint parameters are disjoint. For this, we check that
+/// the base parameters are all disjoint between `parametric` operations, and
+/// that they are disjoint between `parametric` and `is` operations. An
+/// argument that can match any base (see `getBases`) is rejected outright.
+/// This restriction will be relaxed in the future, when we will change our
+/// algorithm to be non-greedy.
+/// Returns failure if `anyOf` is outside of this subset.
+static LogicalResult checkCorrectAnyOf(AnyOf anyOf) {
+  SmallPtrSet<TypeID, 4> paramIds;
+  SmallPtrSet<Operation *, 4> paramIrdlOps;
+  SmallPtrSet<TypeID, 4> isIds;
+
+  for (Value arg : anyOf.getArgs()) {
+    Operation *argOp = arg.getDefiningOp();
+    SmallPtrSet<TypeID, 4> argParamIds;
+    SmallPtrSet<Operation *, 4> argParamIrdlOps;
+    SmallPtrSet<TypeID, 4> argIsIds;
+
+    // Get the bases of this argument. If it can match any type or attribute,
+    // then our `any_of` should not be allowed.
+    if (getBases(argOp, argParamIds, argParamIrdlOps, argIsIds))
+      return failure();
+
+    // We check that the base parameters are all disjoint between `parametric`
+    // operations, and that they are disjoint between `parametric` and `is`
+    // operations.
+    for (TypeID id : argParamIds) {
+      if (isIds.count(id))
+        return failure();
+      bool inserted = paramIds.insert(id).second;
+      if (!inserted)
+        return failure();
+    }
+
+    // We check that the `irdl.is` bases of *this argument* are disjoint from
+    // the `parametric` bases, then record them. Iterating over `isIds` itself
+    // would visit the accumulator while inserting into it.
+    for (TypeID id : argIsIds) {
+      if (paramIds.count(id))
+        return failure();
+      isIds.insert(id);
+    }
+
+    // We check that all `parametric` operations are disjoint. We do not
+    // need to check that they are disjoint with `is` operations, since
+    // `is` operations cannot refer to attributes defined with `irdl.parametric`
+    // operations.
+    for (Operation *op : argParamIrdlOps) {
+      bool inserted = paramIrdlOps.insert(op).second;
+      if (!inserted)
+        return failure();
+    }
+  }
+
+  return success();
+}
+
 /// Register all dialects in the given module, without registering any
 /// operation, type or attribute definitions.
 static DenseMap
@@ -300,6 +410,16 @@
 }
 
 LogicalResult mlir::irdl::registerDialects(ModuleOp op) {
+
+  // First, check that all any_of constraints are in a correct form.
+  // This is to ensure we can do the verification correctly.
+  WalkResult anyOfCorrects =
+      op.walk([](AnyOf anyOf) { return WalkResult(checkCorrectAnyOf(anyOf)); });
+  if (anyOfCorrects.wasInterrupted()) {
+    op.emitError("any_of constraints are not in the correct form");
+    return failure();
+  }
+
   // Preallocate all dialects, and type and attribute definitions.
   // In particular, this allocates TypeIDs so type and attributes can have
   // verifiers that refer to each other.
diff --git a/mlir/test/Dialect/IRDL/cmath.irdl.mlir b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
--- a/mlir/test/Dialect/IRDL/cmath.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
@@ -6,11 +6,15 @@
 
   // CHECK: irdl.type @complex {
   // CHECK:   %[[v0:[^ ]*]] = irdl.is f32
-  // CHECK:   irdl.parameters(%[[v0]])
+  // CHECK:   %[[v1:[^ ]*]] = irdl.is f64
+  // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
+  // CHECK:   irdl.parameters(%[[v2]])
   // CHECK: }
   irdl.type @complex {
     %0 = irdl.is f32
-    irdl.parameters(%0)
+    %1 = irdl.is f64
+    %2 = irdl.any_of(%0, %1)
+    irdl.parameters(%2)
   }
 
   // CHECK: irdl.operation @norm {
@@ -28,13 +32,17 @@
 
   // CHECK: irdl.operation @mul {
   // CHECK:   %[[v0:[^ ]*]] = irdl.is f32
-  // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v0]]>
+  // CHECK:   %[[v1:[^ ]*]] = irdl.is f64
+  // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
+  // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v2]]>
   // CHECK:   irdl.operands(%[[v3]], %[[v3]])
   // CHECK:   irdl.results(%[[v3]])
   // CHECK: }
   irdl.operation @mul {
     %0 = irdl.is f32
-    %3 =
irdl.parametric @complex<%0> + %1 = irdl.is f64 + %2 = irdl.any_of(%0, %1) + %3 = irdl.parametric @complex<%2> irdl.operands(%3, %3) irdl.results(%3) } diff --git a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +// Types with cyclic references; the innermost parameter must be `i32`. + +// CHECK: irdl.dialect @testd { +irdl.dialect @testd { + // CHECK: irdl.type @self_referencing { + // CHECK: %[[v0:[^ ]*]] = irdl.any + // CHECK: %[[v1:[^ ]*]] = irdl.parametric @self_referencing<%[[v0]]> + // CHECK: %[[v2:[^ ]*]] = irdl.is i32 + // CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]]) + // CHECK: irdl.parameters(%[[v3]]) + // CHECK: } + irdl.type @self_referencing { + %0 = irdl.any + %1 = irdl.parametric @self_referencing<%0> + %2 = irdl.is i32 + %3 = irdl.any_of(%1, %2) + irdl.parameters(%3) + } + + // CHECK: irdl.type @type1 { + // CHECK: %[[v0:[^ ]*]] = irdl.any + // CHECK: %[[v1:[^ ]*]] = irdl.parametric @type2<%[[v0]]> + // CHECK: %[[v2:[^ ]*]] = irdl.is i32 + // CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]]) + // CHECK: irdl.parameters(%[[v3]]) + // CHECK: } + irdl.type @type1 { + %0 = irdl.any + %1 = irdl.parametric @type2<%0> + %2 = irdl.is i32 + %3 = irdl.any_of(%1, %2) + irdl.parameters(%3) + } + + // CHECK: irdl.type @type2 { + // CHECK: %[[v0:[^ ]*]] = irdl.any + // CHECK: %[[v1:[^ ]*]] = irdl.parametric @type1<%[[v0]]> + // CHECK: %[[v2:[^ ]*]] = irdl.is i32 + // CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]]) + // CHECK: irdl.parameters(%[[v3]]) + // CHECK: } + irdl.type @type2 { + %0 = irdl.any + %1 = irdl.parametric @type1<%0> + %2 = irdl.is i32 + %3 = irdl.any_of(%1, %2) + irdl.parameters(%3) + } +} diff --git a/mlir/test/Dialect/IRDL/cyclic-types.mlir b/mlir/test/Dialect/IRDL/cyclic-types.mlir new file mode 100644 --- /dev/null +++
b/mlir/test/Dialect/IRDL/cyclic-types.mlir @@ -0,0 +1,57 @@ +// RUN: mlir-opt %s --irdl-file=%S/cyclic-types.irdl.mlir -split-input-file -verify-diagnostics | FileCheck %s + +// Uses of the cyclic types defined in cyclic-types.irdl.mlir. + +// CHECK: !testd.self_referencing +func.func @no_references(%v: !testd.self_referencing) { + return +} + +// ----- + +// CHECK: !testd.self_referencing> +func.func @one_reference(%v: !testd.self_referencing>) { + return +} + +// ----- + +// expected-error@+1 {{'i64' does not satisfy the constraint}} +func.func @wrong_parameter(%v: !testd.self_referencing) { + return +} + +// ----- + +// CHECK: !testd.type1 +func.func @type1_no_references(%v: !testd.type1) { + return +} + +// ----- + +// CHECK: !testd.type1> +func.func @type1_one_reference(%v: !testd.type1>) { + return +} + +// ----- + +// CHECK: !testd.type1>> +func.func @type1_two_references(%v: !testd.type1>>) { + return +} + +// ----- + +// expected-error@+1 {{'i64' does not satisfy the constraint}} +func.func @wrong_parameter_type1(%v: !testd.type1) { + return +} + +// ----- + +// expected-error@+1 {{'i64' does not satisfy the constraint}} +func.func @wrong_parameter_type2(%v: !testd.type2) { + return +} diff --git a/mlir/test/Dialect/IRDL/test-type.irdl.mlir b/mlir/test/Dialect/IRDL/test-type.irdl.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/IRDL/test-type.irdl.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s + +module { + // CHECK-LABEL: irdl.dialect @testd { + irdl.dialect @testd { + // CHECK: irdl.type @singleton + irdl.type @singleton + + // CHECK: irdl.type @parametrized { + // CHECK: %[[v0:[^ ]*]] = irdl.any + // CHECK: %[[v1:[^ ]*]] = irdl.is i32 + // CHECK: %[[v2:[^ ]*]] = irdl.is i64 + // CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]]) + // CHECK: irdl.parameters(%[[v0]], %[[v3]]) + // CHECK: } + irdl.type @parametrized { + %0 = irdl.any + %1 = irdl.is i32 + %2 = irdl.is i64 + %3 = irdl.any_of(%1, %2) + irdl.parameters(%0, %3) + } + + // CHECK:
irdl.operation @any { + // CHECK: %[[v0:[^ ]*]] = irdl.any + // CHECK: irdl.results(%[[v0]]) + // CHECK: } + irdl.operation @any { + %0 = irdl.any + irdl.results(%0) + } + } +} diff --git a/mlir/test/Dialect/IRDL/test-type.mlir b/mlir/test/Dialect/IRDL/test-type.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/IRDL/test-type.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt %s --irdl-file=%S/test-type.irdl.mlir -split-input-file -verify-diagnostics | FileCheck %s + +func.func @succeededTypeVerifier() { + // CHECK: "testd.any"() : () -> !testd.singleton + "testd.any"() : () -> !testd.singleton + + // CHECK-NEXT: "testd.any"() : () -> !testd.parametrized + "testd.any"() : () -> !testd.parametrized + + // CHECK: "testd.any"() : () -> !testd.parametrized + "testd.any"() : () -> !testd.parametrized + + return +} + +// ----- + +func.func @failedSingletonVerifier() { + // expected-error@+1 {{expected 0 type arguments, but had 1}} + "testd.any"() : () -> !testd.singleton +} + +// ----- + +func.func @failedParametrizedVerifierWrongNumOfArgs() { + // expected-error@+1 {{expected 2 type arguments, but had 1}} + "testd.any"() : () -> !testd.parametrized +} + +// ----- + +func.func @failedParametrizedVerifierWrongArgument() { + // expected-error@+1 {{'i1' does not satisfy the constraint}} + "testd.any"() : () -> !testd.parametrized +} diff --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir --- a/mlir/test/Dialect/IRDL/testd.irdl.mlir +++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir @@ -29,6 +29,34 @@ irdl.results(%0) } + // CHECK: irdl.operation @anyof { + // CHECK: %[[v0:[^ ]*]] = irdl.is i32 + // CHECK: %[[v1:[^ ]*]] = irdl.is i64 + // CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]]) + // CHECK: irdl.results(%[[v2]]) + // CHECK: } + irdl.operation @anyof { + %0 = irdl.is i32 + %1 = irdl.is i64 + %2 = irdl.any_of(%0, %1) + irdl.results(%2) + } + + // CHECK: irdl.operation @all_of { + // CHECK: %[[v0:[^ ]*]] = irdl.is i32 + // CHECK:
%[[v1:[^ ]*]] = irdl.is i64 + // CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]]) + // CHECK: %[[v3:[^ ]*]] = irdl.all_of(%[[v2]], %[[v1]]) + // CHECK: irdl.results(%[[v3]]) + // CHECK: } + irdl.operation @all_of { + %0 = irdl.is i32 + %1 = irdl.is i64 + %2 = irdl.any_of(%0, %1) + %3 = irdl.all_of(%2, %1) // only i64 satisfies both constraints + irdl.results(%3) + } + // CHECK: irdl.operation @any { // CHECK: %[[v0:[^ ]*]] = irdl.any // CHECK: irdl.results(%[[v0]]) @@ -51,21 +79,29 @@ // CHECK: irdl.operation @dynparams { // CHECK: %[[v0:[^ ]*]] = irdl.is i32 - // CHECK: %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v0]]> + // CHECK: %[[v1:[^ ]*]] = irdl.is i64 + // CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]]) + // CHECK: %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v2]]> // CHECK: irdl.results(%[[v3]]) // CHECK: } irdl.operation @dynparams { %0 = irdl.is i32 - %3 = irdl.parametric @parametric<%0> + %1 = irdl.is i64 + %2 = irdl.any_of(%0, %1) + %3 = irdl.parametric @parametric<%2> irdl.results(%3) } // CHECK: irdl.operation @constraint_vars { - // CHECK: %[[v0:[^ ]*]] = irdl.any - // CHECK: irdl.results(%[[v0]], %[[v0]]) + // CHECK: %[[v0:[^ ]*]] = irdl.is i32 + // CHECK: %[[v1:[^ ]*]] = irdl.is i64 + // CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]]) + // CHECK: irdl.results(%[[v2]], %[[v2]]) // CHECK: } irdl.operation @constraint_vars { - %0 = irdl.any - irdl.results(%0, %0) + %0 = irdl.is i32 + %1 = irdl.is i64 + %2 = irdl.any_of(%0, %1) + irdl.results(%2, %2) } } diff --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir --- a/mlir/test/Dialect/IRDL/testd.mlir +++ b/mlir/test/Dialect/IRDL/testd.mlir @@ -55,6 +55,56 @@ // ----- +//===----------------------------------------------------------------------===// +// AnyOf constraint (`testd.anyof` accepts `i32` or `i64`) +//===----------------------------------------------------------------------===// + +func.func @succeededAnyOfConstraint() { + // CHECK: "testd.anyof"() : () -> i32 + "testd.anyof"() : () -> i32 + // CHECK: "testd.anyof"() : () ->
i64 + "testd.anyof"() : () -> i64 + return +} + +// ----- + +func.func @failedAnyOfConstraint() { + // expected-error@+1 {{'i1' does not satisfy the constraint}} + "testd.anyof"() : () -> i1 + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// AllOf constraint (`testd.all_of` only accepts `i64`) +//===----------------------------------------------------------------------===// + +func.func @succeededAllOfConstraint() { + // CHECK: "testd.all_of"() : () -> i64 + "testd.all_of"() : () -> i64 + return +} + +// ----- + +func.func @failedAllOfConstraint1() { + // expected-error@+1 {{'i1' does not satisfy the constraint}} + "testd.all_of"() : () -> i1 + return +} + +// ----- + +func.func @failedAllOfConstraint2() { + // expected-error@+1 {{expected 'i64' but got 'i32'}} + "testd.all_of"() : () -> i32 + return +} + +// ----- + //===----------------------------------------------------------------------===// // Any constraint //===----------------------------------------------------------------------===// @@ -76,8 +126,10 @@ func.func @succeededDynBaseConstraint() { // CHECK: "testd.dynbase"() : () -> !testd.parametric "testd.dynbase"() : () -> !testd.parametric - // CHECK: "testd.dynbase"() : () -> !testd.parametric> - "testd.dynbase"() : () -> !testd.parametric> + // CHECK: "testd.dynbase"() : () -> !testd.parametric + "testd.dynbase"() : () -> !testd.parametric + // CHECK: "testd.dynbase"() : () -> !testd.parametric> + "testd.dynbase"() : () -> !testd.parametric> return } @@ -98,6 +150,8 @@ func.func @succeededDynParamsConstraint() { // CHECK: "testd.dynparams"() : () -> !testd.parametric "testd.dynparams"() : () -> !testd.parametric + // CHECK: "testd.dynparams"() : () -> !testd.parametric + "testd.dynparams"() : () -> !testd.parametric return } @@ -112,7 +166,7 @@ func.func @failedDynParamsConstraintParam() { - // expected-error@+1 {{expected 'i32' but got 'i1'}} + // expected-error@+1 {{'i1' does not satisfy the constraint}}
"testd.dynparams"() : () -> !testd.parametric return }