[mlir][irdl] Verify operations, types and attributes defined with IRDL

Add the `VerifyConstraintInterface` op interface (IRDLInterfaces.td) and
attach it to every IRDL constraint op through `IRDL_ConstraintOp`. Its
`getVerifier` method builds an `irdl::Constraint` for the op; `Is`,
`Parametric` and `Any` implement it in the new IR/IRDLOps.cpp.

When an IRDL module is loaded (IRDLLoading.cpp), the constraint ops in the
body of every `TypeOp`, `AttributeOp` and `OperationOp` are resolved into
verifier slots. The resulting verifiers are installed with `setVerifyFn`
(types and attributes) or used as the op verifier (operations), replacing
the operation verifier that previously always succeeded. testd.mlir gains
negative tests for the new checks.

diff --git a/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt
@@ -1,5 +1,12 @@
 add_mlir_dialect(IRDL irdl)
 
+# Add IRDL interfaces
+set(LLVM_TARGET_DEFINITIONS IRDLInterfaces.td)
+mlir_tablegen(IRDLInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(IRDLInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRIRDLInterfacesIncGen)
+add_dependencies(mlir-generic-headers MLIRIRDLInterfacesIncGen)
+
 # Add IRDL operations
 set(LLVM_TARGET_DEFINITIONS IRDLOps.td)
 mlir_tablegen(IRDLOps.h.inc -gen-op-decls)
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h b/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_IRDL_IR_IRDL_H_
 #define MLIR_DIALECT_IRDL_IR_IRDL_H_
 
+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
 #include "mlir/Dialect/IRDL/IR/IRDLTraits.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.h b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.h
@@ -0,0 +1,38 @@
+//===- IRDLInterfaces.h - IRDL interfaces definition ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the interfaces used by the IRDL dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IR_IRDLINTERFACES_H_
+#define MLIR_DIALECT_IRDL_IR_IRDLINTERFACES_H_
+
+#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
+#include <memory>
+
+namespace mlir {
+namespace irdl {
+class TypeOp;
+class AttributeOp;
+} // namespace irdl
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// IRDL Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_IRDL_IR_IRDLINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
@@ -0,0 +1,40 @@
+//===- IRDLInterfaces.td - IRDL Interfaces -----------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the interfaces used by IRDL.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IR_IRDLINTERFACES
+#define MLIR_DIALECT_IRDL_IR_IRDLINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def VerifyConstraintInterface : OpInterface<"VerifyConstraintInterface"> {
+  let cppNamespace = "::mlir::irdl";
+
+  let description = [{
+    Interface to get an IRDL constraint verifier from an operation.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{
+        Get an instance of a constraint verifier for the associated operation.
+        Returns `nullptr` upon failure.
+      }],
+      "std::unique_ptr<::mlir::irdl::Constraint>",
+      "getVerifier",
+      (ins "::mlir::SmallVector<Value> const&":$valueRes,
+      "::mlir::DenseMap<::mlir::irdl::TypeOp, std::unique_ptr<::mlir::DynamicTypeDefinition>> &":$types,
+      "::mlir::DenseMap<::mlir::irdl::AttributeOp, std::unique_ptr<::mlir::DynamicAttrDefinition>> &":$attrs)
+    >
+  ];
+}
+
+#endif // MLIR_DIALECT_IRDL_IR_IRDLINTERFACES
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -15,6 +15,7 @@
 
 include "IRDL.td"
 include "IRDLTypes.td"
+include "IRDLInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
@@ -264,7 +265,8 @@
 //===----------------------------------------------------------------------===//
 
 class IRDL_ConstraintOp<string mnemonic, list<Trait> traits = []>
-    : IRDL_Op<mnemonic, traits> {
+    : IRDL_Op<mnemonic, [VerifyConstraintInterface,
+        DeclareOpInterfaceMethods<VerifyConstraintInterface>] # traits> {
 }
 
 def IRDL_Is : IRDL_ConstraintOp<"is",
diff --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRIRDL
   IR/IRDL.cpp
+  IR/IRDLOps.cpp
   IRDLLoading.cpp
   IRDLVerifiers.cpp
 
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -71,6 +71,8 @@
   return success();
 }
 
+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"
 
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -0,0 +1,61 @@
+//===- IRDLOps.cpp - IRDL dialect -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+std::unique_ptr<Constraint> Is::getVerifier(
+    SmallVector<Value> const &valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+  return std::make_unique<IsConstraint>(getExpectedAttr());
+}
+
+std::unique_ptr<Constraint> Parametric::getVerifier(
+    SmallVector<Value> const &valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+  SmallVector<unsigned> constraints;
+  for (Value arg : getArgs()) {
+    for (auto [i, value] : enumerate(valueToConstr)) {
+      if (value == arg) {
+        constraints.push_back(i);
+        break;
+      }
+    }
+  }
+
+  // Symbol reference case for the base
+  SymbolRefAttr symRef = getBaseType();
+  Operation *defOp =
+      SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef);
+  if (!defOp) {
+    emitError() << symRef << " does not refer to any existing symbol";
+    return nullptr;
+  }
+
+  if (auto typeOp = dyn_cast<TypeOp>(defOp))
+    return std::make_unique<DynParametricTypeConstraint>(types[typeOp].get(),
+                                                         constraints);
+
+  if (auto attrOp = dyn_cast<AttributeOp>(defOp))
+    return std::make_unique<DynParametricAttrConstraint>(attrs[attrOp].get(),
+                                                         constraints);
+
+  llvm_unreachable("verifier should ensure that the referenced operation is "
+                   "either a type or an attribute definition");
+}
+
+std::unique_ptr<Constraint> Any::getVerifier(
+    SmallVector<Value> const &valueToConstr,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+  return std::make_unique<AnyAttributeConstraint>();
+}
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/IRDL/IRDLLoading.h"
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/Support/LogicalResult.h"
@@ -22,9 +23,130 @@
 using namespace mlir;
 using namespace mlir::irdl;
 
+/// Verify that the given list of parameters satisfy the given constraints.
+/// This encodes the logic of the verification method for attributes and types
+/// defined with IRDL.
+static LogicalResult
+irdlAttrOrTypeVerifier(function_ref<InFlightDiagnostic()> emitError,
+                       ArrayRef<Attribute> params,
+                       ArrayRef<std::unique_ptr<Constraint>> constraints,
+                       ArrayRef<size_t> paramConstraints) {
+  if (params.size() != paramConstraints.size()) {
+    emitError() << "expected " << paramConstraints.size()
+                << " type arguments, but had " << params.size();
+    return failure();
+  }
+
+  ConstraintVerifier verifier(constraints);
+
+  // Check that each parameter satisfies its constraint.
+  for (auto [i, param] : enumerate(params))
+    if (failed(verifier.verify(emitError, param, paramConstraints[i])))
+      return failure();
+
+  return success();
+}
+
+/// Verify that the given operation satisfies the given constraints.
+/// This encodes the logic of the verification method for operations defined
+/// with IRDL.
+static LogicalResult
+irdlOpVerifier(Operation *op, ArrayRef<std::unique_ptr<Constraint>> constraints,
+               ArrayRef<size_t> operandConstrs,
+               ArrayRef<size_t> resultConstrs) {
+  // Check that we have the right number of operands.
+  unsigned numOperands = op->getNumOperands();
+  size_t numExpectedOperands = operandConstrs.size();
+  if (numOperands != numExpectedOperands)
+    return op->emitOpError() << numExpectedOperands
+                             << " operands expected, but got " << numOperands;
+
+  // Check that we have the right number of results.
+  unsigned numResults = op->getNumResults();
+  size_t numExpectedResults = resultConstrs.size();
+  if (numResults != numExpectedResults)
+    return op->emitOpError()
+           << numExpectedResults << " results expected, but got " << numResults;
+
+  auto emitError = [op]() { return op->emitError(); };
+
+  ConstraintVerifier verifier(constraints);
+
+  // Check that all operands satisfy the constraints.
+  for (auto [i, operandType] : enumerate(op->getOperandTypes()))
+    if (failed(verifier.verify({emitError}, TypeAttr::get(operandType),
+                               operandConstrs[i])))
+      return failure();
+
+  // Check that all results satisfy the constraints.
+  for (auto [i, resultType] : enumerate(op->getResultTypes()))
+    if (failed(verifier.verify({emitError}, TypeAttr::get(resultType),
+                               resultConstrs[i])))
+      return failure();
+
+  return success();
+}
+
 /// Define and load an operation represented by a `irdl.operation`
 /// operation.
-static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect) {
+static WalkResult loadOperation(
+    OperationOp op, ExtensibleDialect *dialect,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+  // Resolve SSA values to verifier constraint slots
+  SmallVector<Value> constrToValue;
+  for (Operation &constrOp : op->getRegion(0).getOps()) {
+    if (isa<VerifyConstraintInterface>(constrOp)) {
+      if (constrOp.getNumResults() != 1)
+        return constrOp.emitError()
+               << "IRDL constraint operations must have exactly one result";
+      constrToValue.push_back(constrOp.getResult(0));
+    }
+  }
+
+  // Build the verifiers for each constraint slot
+  SmallVector<std::unique_ptr<Constraint>> constraints;
+  for (Value v : constrToValue) {
+    VerifyConstraintInterface constrOp =
+        cast<VerifyConstraintInterface>(v.getDefiningOp());
+    std::unique_ptr<Constraint> verifier =
+        constrOp.getVerifier(constrToValue, types, attrs);
+    if (!verifier)
+      return WalkResult::interrupt();
+    constraints.push_back(std::move(verifier));
+  }
+
+  SmallVector<size_t> operandConstraints;
+  SmallVector<size_t> resultConstraints;
+
+  // Gather which constraint slots correspond to operand constraints
+  auto operandsOp = op.getOp<OperandsOp>();
+  if (operandsOp.has_value()) {
+    operandConstraints.reserve(operandsOp->getArgs().size());
+    for (Value operand : operandsOp->getArgs()) {
+      for (auto [i, constr] : enumerate(constrToValue)) {
+        if (constr == operand) {
+          operandConstraints.push_back(i);
+          break;
+        }
+      }
+    }
+  }
+
+  // Gather which constraint slots correspond to result constraints
+  auto resultsOp = op.getOp<ResultsOp>();
+  if (resultsOp.has_value()) {
+    resultConstraints.reserve(resultsOp->getArgs().size());
+    for (Value result : resultsOp->getArgs()) {
+      for (auto [i, constr] : enumerate(constrToValue)) {
+        if (constr == result) {
+          resultConstraints.push_back(i);
+          break;
+        }
+      }
+    }
+  }
+
   // IRDL does not support defining custom parsers or printers.
   auto parser = [](OpAsmParser &parser, OperationState &result) {
     return failure();
@@ -33,7 +155,13 @@
     printer.printGenericOp(op);
   };
 
-  auto verifier = [](Operation *op) { return success(); };
+  auto verifier =
+      [constraints{std::move(constraints)},
+       operandConstraints{std::move(operandConstraints)},
+       resultConstraints{std::move(resultConstraints)}](Operation *op) {
+        return irdlOpVerifier(op, constraints, operandConstraints,
+                              resultConstraints);
+      };
 
   // IRDL does not support defining regions.
   auto regionVerifier = [](Operation *op) { return success(); };
@@ -46,6 +174,68 @@
   return WalkResult::advance();
 }
 
+/// Get the verifier of a type or attribute definition.
+/// Return nullptr if the definition is invalid.
+static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(
+    Operation *attrOrTypeDef, ExtensibleDialect *dialect,
+    DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
+    DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
+  assert((isa<AttributeOp>(attrOrTypeDef) || isa<TypeOp>(attrOrTypeDef)) &&
+         "Expected an attribute or type definition");
+
+  // Resolve SSA values to verifier constraint slots
+  SmallVector<Value> constrToValue;
+  for (Operation &constrOp : attrOrTypeDef->getRegion(0).getOps()) {
+    if (isa<VerifyConstraintInterface>(constrOp)) {
+      assert(constrOp.getNumResults() == 1 &&
+             "IRDL constraint operations must have exactly one result");
+      constrToValue.push_back(constrOp.getResult(0));
+    }
+  }
+
+  // Build the verifiers for each constraint slot
+  SmallVector<std::unique_ptr<Constraint>> constraints;
+  for (Value v : constrToValue) {
+    VerifyConstraintInterface constrOp =
+        cast<VerifyConstraintInterface>(v.getDefiningOp());
+    std::unique_ptr<Constraint> verifier =
+        constrOp.getVerifier(constrToValue, types, attrs);
+    if (!verifier)
+      return {};
+    constraints.push_back(std::move(verifier));
+  }
+
+  // Get the parameter definitions.
+  std::optional<ParametersOp> params;
+  if (auto attr = dyn_cast<AttributeOp>(attrOrTypeDef))
+    params = attr.getOp<ParametersOp>();
+  else if (auto type = dyn_cast<TypeOp>(attrOrTypeDef))
+    params = type.getOp<ParametersOp>();
+
+  // Gather which constraint slots correspond to parameter constraints
+  SmallVector<size_t> paramConstraints;
+  if (params.has_value()) {
+    paramConstraints.reserve(params->getArgs().size());
+    for (Value param : params->getArgs()) {
+      for (auto [i, constr] : enumerate(constrToValue)) {
+        if (constr == param) {
+          paramConstraints.push_back(i);
+          break;
+        }
+      }
+    }
+  }
+
+  auto verifier = [paramConstraints{std::move(paramConstraints)},
+                   constraints{std::move(constraints)}](
+                      function_ref<InFlightDiagnostic()> emitError,
+                      ArrayRef<Attribute> params) {
+    return irdlAttrOrTypeVerifier(emitError, params, constraints,
+                                  paramConstraints);
+  };
+  return verifier;
+}
+
 /// Load all dialects in the given module, without loading any operation, type
 /// or attribute definitions.
 static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) {
@@ -108,9 +298,33 @@
   DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrs =
       preallocateAttrDefs(op, dialects);
 
+  // Set the verifier for types.
+  WalkResult res = op.walk([&](TypeOp typeOp) {
+    DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
+        typeOp, dialects[typeOp.getParentOp()], types, attrs);
+    if (!verifier)
+      return WalkResult::interrupt();
+    types[typeOp]->setVerifyFn(std::move(verifier));
+    return WalkResult::advance();
+  });
+  if (res.wasInterrupted())
+    return failure();
+
+  // Set the verifier for attributes.
+  res = op.walk([&](AttributeOp attrOp) {
+    DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
+        attrOp, dialects[attrOp.getParentOp()], types, attrs);
+    if (!verifier)
+      return WalkResult::interrupt();
+    attrs[attrOp]->setVerifyFn(std::move(verifier));
+    return WalkResult::advance();
+  });
+  if (res.wasInterrupted())
+    return failure();
+
   // Define and load all operations.
-  WalkResult res = op.walk([&](OperationOp opOp) {
-    return loadOperation(opOp, dialects[opOp.getParentOp()]);
+  res = op.walk([&](OperationOp opOp) {
+    return loadOperation(opOp, dialects[opOp.getParentOp()], types, attrs);
   });
   if (res.wasInterrupted())
     return failure();
diff --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir
--- a/mlir/test/Dialect/IRDL/testd.mlir
+++ b/mlir/test/Dialect/IRDL/testd.mlir
@@ -45,6 +45,13 @@
   return
 }
 
+// -----
+
+func.func @failedEqConstraint() {
+  // expected-error@+1 {{expected 'i32' but got 'i64'}}
+  "testd.eq"() : () -> i64
+  return
+}
 
 // -----
 
@@ -74,6 +81,13 @@
   return
 }
 
+// -----
+
+func.func @failedDynBaseConstraint() {
+  // expected-error@+1 {{expected base type 'testd.parametric' but got 'i32'}}
+  "testd.dynbase"() : () -> i32
+  return
+}
 
 // -----
 
@@ -89,6 +103,22 @@
 
 // -----
 
+func.func @failedDynParamsConstraintBase() {
+  // expected-error@+1 {{expected base type 'testd.parametric' but got 'i32'}}
+  "testd.dynparams"() : () -> i32
+  return
+}
+
+// -----
+
+func.func @failedDynParamsConstraintParam() {
+  // expected-error@+1 {{expected 'i32' but got 'i1'}}
+  "testd.dynparams"() : () -> !testd.parametric<i1>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Constraint variables
 //===----------------------------------------------------------------------===//
@@ -106,3 +136,11 @@
   "testd.constraint_vars"() : () -> (i64, i64)
   return
 }
+
+// -----
+
+func.func @failedConstraintVars() {
+  // expected-error@+1 {{expected 'i64' but got 'i32'}}
+  "testd.constraint_vars"() : () -> (i64, i32)
+  return
+}