diff --git a/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt @@ -1,5 +1,12 @@ add_mlir_dialect(IRDL irdl) +# Add IRDL interfaces +set(LLVM_TARGET_DEFINITIONS IRDLInterfaces.td) +mlir_tablegen(IRDLInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(IRDLInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRIRDLInterfacesIncGen) +add_dependencies(mlir-generic-headers MLIRIRDLInterfacesIncGen) + # Add IRDL operations set(LLVM_TARGET_DEFINITIONS IRDLOps.td) mlir_tablegen(IRDLOps.h.inc -gen-op-decls) diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h b/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_IRDL_IR_IRDL_H_ #define MLIR_DIALECT_IRDL_IR_IRDL_H_ +#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h" #include "mlir/Dialect/IRDL/IR/IRDLTraits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.h b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.h @@ -0,0 +1,37 @@ +//===- IRDLInterfaces.h - IRDL interfaces definition ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the interfaces used by the IRDL dialect. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IR_IRDLINTERFACES_H_
+#define MLIR_DIALECT_IRDL_IR_IRDLINTERFACES_H_
+
+#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
+#include
+
+namespace mlir {
+namespace irdl {
+class TypeOp;
+class AttributeOp;
+} // namespace irdl
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// IRDL Dialect Interfaces
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_IRDL_IR_IRDLINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLInterfaces.td
@@ -0,0 +1,40 @@
+//===- IRDLInterfaces.td - IRDL Interfaces -----------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the interfaces used by IRDL.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IR_IRDLINTERFACES
+#define MLIR_DIALECT_IRDL_IR_IRDLINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def VerifyConstraintInterface : OpInterface<"VerifyConstraintInterface"> {
+  let cppNamespace = "::mlir::irdl";
+
+  let description = [{
+    Interface to get an IRDL constraint verifier from an operation.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      [{
+        Get an instance of a constraint verifier for the associated operation.
+        Returns `nullptr` upon failure.
+
+        `valueRes` lists the results of the constraint operations of the
+        enclosing definition; a constraint refers to another constraint by
+        its index in that list. `types` and `attrs` map the type and
+        attribute definition operations to their dynamic definitions.
+      }],
+      "std::unique_ptr<::mlir::irdl::Constraint>",
+      "getVerifier",
+      (ins "mlir::SmallVector const&":$valueRes,
+      "llvm::DenseMap> &":$types,
+      "llvm::DenseMap> &":$attrs)
+    >
+  ];
+}
+
+#endif // MLIR_DIALECT_IRDL_IR_IRDLINTERFACES
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -15,6 +15,7 @@

 include "IRDL.td"
 include "IRDLTypes.td"
+include "IRDLInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
@@ -280,7 +281,8 @@
 //===----------------------------------------------------------------------===//

 class IRDL_ConstraintOp traits = []>
-    : IRDL_Op {
+    : IRDL_Op] # traits> {
 }

 def IRDL_Is : IRDL_ConstraintOp<"is",
diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLRegistration.h b/mlir/include/mlir/Dialect/IRDL/IRDLRegistration.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLRegistration.h
@@ -0,0 +1,28 @@
+//===- IRDLRegistration.h - IRDL registration -------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages the registration of MLIR objects from IRDL operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IRDLREGISTRATION_H
+#define MLIR_DIALECT_IRDL_IRDLREGISTRATION_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace irdl {
+
+/// Register all the dialects defined in the module `op`, together with the
+/// types, attributes and operations they define.
+LogicalResult registerDialects(ModuleOp op);
+
+} // namespace irdl
+} // namespace mlir
+
+#endif // MLIR_DIALECT_IRDL_IRDLREGISTRATION_H
diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
@@ -0,0 +1,175 @@
+//===- IRDLVerifiers.h - IRDL verifiers --------------------------- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Verifiers for objects declared by IRDL.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
+#define MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
+
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+namespace irdl {
+
+class Constraint;
+
+/// Provides context to the verification of constraints.
+/// It contains the assignment of variables to attributes, and the assignment
+/// of variables to constraints.
+///
+/// A variable is the index of a constraint in the list given at construction.
+class ConstraintVerifier {
+public:
+  /// Create a verification context over `constraints`. The list is kept as an
+  /// `ArrayRef`, i.e. it is referenced and not copied by this class.
+  ConstraintVerifier(ArrayRef> constraints);
+
+  /// Check that a constraint is satisfied by an attribute.
+  ///
+  /// Constraints may call other constraint verifiers. If that is the case,
+  /// the constraint verifier will check if the variable is already assigned,
+  /// and if so, check that the attribute is the same as the one assigned.
+  /// If the variable is not assigned, the constraint verifier will
+  /// assign the attribute to the variable, and check that the constraint
+  /// is satisfied.
+  ///
+  /// `variable` selects which constraint of the list `attr` is checked
+  /// against; `emitError` is used to report a failed check.
+  LogicalResult verify(function_ref emitError,
+                       Attribute attr, unsigned variable);
+
+private:
+  /// The constraints that can be used for verification.
+  ArrayRef> constraints;
+
+  /// The assignment of variables to attributes, indexed by variable.
+  /// Variables that are not assigned yet have no attribute.
+  SmallVector> assigned;
+};
+
+/// Once turned into IRDL verifiers, all constraints are
+/// attribute constraints. Type constraints are represented
+/// as `TypeAttr` attribute constraints to simplify verification.
+/// Verification that a type constraint must yield a
+/// `TypeAttr` attribute happens before conversion, at the MLIR level.
+class Constraint {
+public:
+  virtual ~Constraint() = default;
+
+  /// Check that an attribute satisfies the constraint.
+  ///
+  /// Constraints may call other constraint verifiers. If that is the case,
+  /// the constraint verifier will check if the variable is already assigned,
+  /// and if so, check that the attribute is the same as the one assigned.
+  /// If the variable is not assigned, the constraint verifier will
+  /// assign the attribute to the variable, and check that the constraint
+  /// is satisfied.
+  ///
+  /// `context` is the verifier through which nested constraints are checked.
+  virtual LogicalResult verify(function_ref emitError,
+                               Attribute attr,
+                               ConstraintVerifier &context) const = 0;
+};
+
+/// A constraint that checks that an attribute is equal to a given attribute.
+class IsConstraint : public Constraint {
+public:
+  /// `expectedAttribute` is the only attribute accepted by this constraint.
+  IsConstraint(Attribute expectedAttribute)
+      : expectedAttribute(expectedAttribute) {}
+
+  virtual ~IsConstraint() = default;
+
+  LogicalResult verify(function_ref emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// The attribute that verified attributes are compared against.
+  Attribute expectedAttribute;
+};
+
+/// A constraint that checks that an attribute is of a
+/// specific dynamic attribute definition, and that all of its parameters
+/// satisfy the given constraints.
+class DynParametricAttrConstraint : public Constraint {
+public:
+  /// `attrDef` is the expected attribute definition. `constraints` holds, for
+  /// each parameter, the variable of the `ConstraintVerifier` it is checked
+  /// against.
+  DynParametricAttrConstraint(DynamicAttrDefinition *attrDef,
+                              SmallVector constraints)
+      : attrDef(attrDef), constraints(std::move(constraints)) {}
+
+  virtual ~DynParametricAttrConstraint() = default;
+
+  LogicalResult verify(function_ref emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// The expected attribute definition; stored as a plain pointer.
+  DynamicAttrDefinition *attrDef;
+  /// One constraint variable per attribute parameter.
+  SmallVector constraints;
+};
+
+/// A constraint that checks that a type is of a specific dynamic type
+/// definition, and that all of its parameters satisfy the given constraints.
+class DynParametricTypeConstraint : public Constraint {
+public:
+  /// `typeDef` is the expected type definition. `constraints` holds, for each
+  /// parameter, the variable of the `ConstraintVerifier` it is checked
+  /// against.
+  DynParametricTypeConstraint(DynamicTypeDefinition *typeDef,
+                              SmallVector constraints)
+      : typeDef(typeDef), constraints(std::move(constraints)) {}
+
+  virtual ~DynParametricTypeConstraint() = default;
+
+  LogicalResult verify(function_ref emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// The expected type definition; stored as a plain pointer.
+  DynamicTypeDefinition *typeDef;
+  /// One constraint variable per type parameter.
+  SmallVector constraints;
+};
+
+/// A constraint checking that one of the given constraints is satisfied.
+class AnyOfConstraint : public Constraint {
+public:
+  /// `constraints` lists the variables of the `ConstraintVerifier` that are
+  /// tried, in order.
+  AnyOfConstraint(SmallVector constraints)
+      : constraints(std::move(constraints)) {}
+
+  virtual ~AnyOfConstraint() = default;
+
+  LogicalResult verify(function_ref emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// The alternatives, as constraint variables.
+  SmallVector constraints;
+};
+
+/// A constraint checking that all of the given constraints are satisfied.
+class AllOfConstraint : public Constraint {
+public:
+  /// `constraints` lists the variables of the `ConstraintVerifier` that must
+  /// all accept the attribute.
+  AllOfConstraint(SmallVector constraints)
+      : constraints(std::move(constraints)) {}
+
+  virtual ~AllOfConstraint() = default;
+
+  LogicalResult verify(function_ref emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+
+private:
+  /// The conjuncts, as constraint variables.
+  SmallVector constraints;
+};
+
+/// A constraint that is always satisfied.
+class AnyAttributeConstraint : public Constraint {
+public:
+  virtual ~AnyAttributeConstraint() = default;
+
+  LogicalResult verify(function_ref emitError,
+                       Attribute attr,
+                       ConstraintVerifier &context) const override;
+};
+
+} // namespace irdl
+} // namespace mlir
+
+#endif // MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -53,25 +53,28 @@
 /// - emitBytecode will generate bytecode output instead of text.
 /// - implicitModule will enable implicit addition of a top-level
 /// 'builtin.module' if one doesn't already exist.
-/// - dumpPassPipeline will dump the pipeline being run to stderr -LogicalResult -MlirOptMain(llvm::raw_ostream &outputStream, - std::unique_ptr buffer, - const PassPipelineCLParser &passPipeline, DialectRegistry ®istry, - bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, - bool allowUnregisteredDialects, - bool preloadDialectsInContext = false, bool emitBytecode = false, - bool implicitModule = false, bool dumpPassPipeline = false); - -/// Support a callback to setup the pass manager. -/// - passManagerSetupFn is the callback invoked to setup the pass manager to -/// apply on the loaded IR. +/// - dumpPassPipeline will dump the pipeline being run to stderr. +/// - irdlFile is the path to an IRDL file to load. LogicalResult MlirOptMain( llvm::raw_ostream &outputStream, std::unique_ptr buffer, - PassPipelineFn passManagerSetupFn, DialectRegistry ®istry, + const PassPipelineCLParser &passPipeline, DialectRegistry ®istry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext = false, - bool emitBytecode = false, bool implicitModule = false); + bool emitBytecode = false, bool implicitModule = false, + bool dumpPassPipeline = false, StringRef irdlFile = ""); + +/// Support a callback to setup the pass manager. +/// - passManagerSetupFn is the callback invoked to setup the pass manager to +/// apply on the loaded IR. +LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, + std::unique_ptr buffer, + PassPipelineFn passManagerSetupFn, + DialectRegistry ®istry, bool splitInputFile, + bool verifyDiagnostics, bool verifyPasses, + bool allowUnregisteredDialects, + bool preloadDialectsInContext = false, + bool emitBytecode = false, + bool implicitModule = false, StringRef irdlFile = ""); /// Implementation for tools like `mlir-opt`. /// - toolName is used for the header displayed by `--help`. 
diff --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -1,6 +1,8 @@
 add_mlir_dialect_library(MLIRIRDL
   IR/IRDL.cpp
   IR/IRDLOps.cpp
+  IRDLRegistration.cpp
+  IRDLVerifiers.cpp

   DEPENDS
   MLIRIRDLIncGen
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLRegistration.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -74,6 +75,8 @@
   return success();
 }

+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.cpp.inc"
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/IRDL/IR/IRDLTypesGen.cpp.inc"

diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -7,6 +7,10 @@
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/TypeID.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include

 using namespace mlir;
 using namespace mlir::irdl;
@@ -18,3 +22,52 @@
 DialectOp TypeOp::getDialectOp() {
   return cast(getOperation()->getParentOp());
 }
+
+/// Build the verifier of an `Is` operation: a constraint accepting only the
+/// expected attribute of the operation. None of the arguments is used.
+std::unique_ptr Is::getVerifier(
+    SmallVector const &valueToConstr,
+    DenseMap> &types,
+    DenseMap> &attrs) {
+  return std::make_unique(getExpectedAttr());
+}
+
+/// Build the verifier of a `Parametric` operation.
+///
+/// Each argument is translated to its index in `valueToConstr`, and the base
+/// symbol is resolved to a type or attribute definition operation, whose
+/// definition is then taken from `types` or `attrs`.
+/// Emits an error and returns nullptr if the base symbol cannot be resolved.
+std::unique_ptr Parametric::getVerifier(
+    SmallVector const &valueToConstr,
+    DenseMap> &types,
+    DenseMap> &attrs) {
+  SmallVector constraints;
+  for (Value arg : getArgs()) {
+    // Linear search for the constraint slot that produced `arg`.
+    // NOTE(review): an argument that is not found in `valueToConstr` is
+    // skipped without a diagnostic, which shifts the slots of the following
+    // arguments -- confirm that the operation verifier rules this out.
+    for (auto [i, value] : enumerate(valueToConstr)) {
+      if (value == arg) {
+        constraints.push_back(i);
+        break;
+      }
+    }
+  }
+
+  // Symbol reference case for the base
+  SymbolRefAttr symRef = getBaseType();
+  Operation *defOp =
+      SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef);
+  if (!defOp) {
+    emitError() << symRef << " does not refer to any existing symbol";
+    return nullptr;
+  }
+
+  // NOTE(review): the two lookups below use `operator[]`; presumably every
+  // type and attribute definition has an entry in the maps -- verify against
+  // the caller.
+  if (auto typeOp = dyn_cast(defOp))
+    return std::make_unique(types[typeOp].get(),
+                            constraints);
+
+  if (auto attrOp = dyn_cast(defOp))
+    return std::make_unique(attrs[attrOp].get(),
+                            constraints);
+
+  llvm_unreachable("verifier should ensure that the referenced operation is "
+                   "either a type or an attribute definition");
+}
+
+/// Build the verifier of an `Any` operation: a constraint that accepts every
+/// attribute. None of the arguments is used.
+std::unique_ptr Any::getVerifier(
+    SmallVector const &valueToConstr,
+    DenseMap> &types,
+    DenseMap> &attrs) {
+  return std::make_unique();
+}
diff --git a/mlir/lib/Dialect/IRDL/IRDLRegistration.cpp b/mlir/lib/Dialect/IRDL/IRDLRegistration.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IRDLRegistration.cpp
@@ -0,0 +1,356 @@
+//===- IRDLRegistration.cpp - IRDL dialect registration ----------- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages the registration of MLIR objects from IRDL operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IRDLRegistration.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/SMLoc.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+/// Verify that the given list of parameters satisfy the given constraints.
+/// This encodes the logic of the verification method for attributes and types
+/// defined with IRDL.
+///
+/// `paramConstraints[i]` is the index in `constraints` of the constraint that
+/// `params[i]` must satisfy.
+static LogicalResult
+irdlAttrOrTypeVerifier(function_ref emitError,
+                       ArrayRef params,
+                       ArrayRef> constraints,
+                       ArrayRef paramConstraints) {
+  if (params.size() != paramConstraints.size()) {
+    emitError() << "expected " << paramConstraints.size()
+                << " type arguments, but had " << params.size();
+    return failure();
+  }
+
+  // A fresh verifier per call: variable assignments are not shared between
+  // two verified attributes or types.
+  ConstraintVerifier verifier(constraints);
+
+  // Check that each parameter satisfies its constraint.
+  for (auto [i, param] : enumerate(params))
+    if (failed(verifier.verify(emitError, param, paramConstraints[i])))
+      return failure();
+
+  return success();
+}
+
+/// Verify that the given operation satisfies the given constraints.
+/// This encodes the logic of the verification method for operations defined
+/// with IRDL.
+///
+/// `operandConstrs[i]` and `resultConstrs[i]` are the indices in `constraints`
+/// of the constraints for the i-th operand and the i-th result.
+static LogicalResult
+irdlOpVerifier(Operation *op, ArrayRef> constraints,
+               ArrayRef operandConstrs,
+               ArrayRef resultConstrs) {
+  // Check that we have the right number of operands.
+  unsigned numOperands = op->getNumOperands();
+  size_t numExpectedOperands = operandConstrs.size();
+  if (numOperands != numExpectedOperands)
+    return op->emitOpError(std::to_string(numExpectedOperands) +
+                           " operands expected, but got " +
+                           std::to_string(numOperands));
+
+  // Check that we have the right number of results.
+  unsigned numResults = op->getNumResults();
+  size_t numExpectedResults = resultConstrs.size();
+  if (numResults != numExpectedResults)
+    return op->emitOpError(std::to_string(numExpectedResults) +
+                           " results expected, but got " +
+                           std::to_string(numResults));
+
+  auto emitError = [op]() { return op->emitError(); };
+
+  // The same verifier is used for operands and results, so a variable
+  // assigned while checking an operand constrains the results as well.
+  ConstraintVerifier verifier(constraints);
+
+  // Check that all operands satisfy the constraints.
+  // Types are wrapped in a `TypeAttr`, since constraints verify attributes.
+  for (auto [i, operandType] : enumerate(op->getOperandTypes())) {
+    if (failed(verifier.verify({emitError}, TypeAttr::get(operandType),
+                               operandConstrs[i]))) {
+      return failure();
+    }
+  }
+
+  // Check that all results satisfy the constraints.
+  for (auto [i, resultType] : enumerate(op->getResultTypes())) {
+    if (failed(verifier.verify({emitError}, TypeAttr::get(resultType),
+                               resultConstrs[i]))) {
+      return failure();
+    }
+  }
+
+  return success();
+}
+
+/// Define and register an operation represented by an `irdl.operation`
+/// operation.
+///
+/// Returns `WalkResult::interrupt()` if a constraint operation is malformed
+/// or if one of the constraint verifiers cannot be built, and
+/// `WalkResult::advance()` once the operation is registered in `dialect`.
+static WalkResult registerOperation(
+    OperationOp op, ExtensibleDialect *dialect,
+    DenseMap> &types,
+    DenseMap> &attrs) {
+  // Resolve SSA values to verifier constraint slots
+  // NOTE: the loop variable below shadows the `op` parameter inside the loop.
+  SmallVector constrToValue;
+  for (Operation &op : op->getRegion(0).getOps()) {
+    if (isa(op)) {
+      if (op.getNumResults() != 1) {
+        op.emitError()
+            << "IRDL constraint operations must have exactly one result";
+        return WalkResult::interrupt();
+      }
+      constrToValue.push_back(op.getResult(0));
+    }
+  }
+
+  // Build the verifiers for each constraint slot
+  SmallVector> constraints;
+  for (Value v : constrToValue) {
+    VerifyConstraintInterface op =
+        cast(v.getDefiningOp());
+    std::unique_ptr verifier =
+        op.getVerifier(constrToValue, types, attrs);
+    if (!verifier) {
+      return WalkResult::interrupt();
+    }
+    constraints.push_back(std::move(verifier));
+  }
+
+  SmallVector operandConstraints;
+  SmallVector resultConstraints;
+
+  // Gather which constraint slots correspond to operand constraints
+  // NOTE(review): an operand value that is not in `constrToValue` gets no
+  // slot and is skipped silently -- confirm this cannot happen.
+  auto operandsOp = op.getOp();
+  if (operandsOp.has_value()) {
+    operandConstraints.reserve(operandsOp->getArgs().size());
+    for (Value operand : operandsOp->getArgs()) {
+      for (auto [i, constr] : enumerate(constrToValue)) {
+        if (constr == operand) {
+          operandConstraints.push_back(i);
+          break;
+        }
+      }
+    }
+  }
+
+  // Gather which constraint slots correspond to result constraints
+  auto resultsOp = op.getOp();
+  if (resultsOp.has_value()) {
+    resultConstraints.reserve(resultsOp->getArgs().size());
+    for (Value result : resultsOp->getArgs()) {
+      for (auto [i, constr] : enumerate(constrToValue)) {
+        if (constr == result) {
+          resultConstraints.push_back(i);
+          break;
+        }
+      }
+    }
+  }
+
+  // IRDL does not support defining custom parsers or printers.
+  // The parser always fails and the printer emits the generic form.
+  auto parser = [](OpAsmParser &parser, OperationState &result) {
+    return failure();
+  };
+  auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
+    printer.printGenericOp(op);
+  };
+
+  // The constraints and slot lists are moved into the verifier callable,
+  // which therefore owns them.
+  auto verifier =
+      [constraints{std::move(constraints)},
+       operandConstraints{std::move(operandConstraints)},
+       resultConstraints{std::move(resultConstraints)}](Operation *op) {
+        return irdlOpVerifier(op, constraints, operandConstraints,
+                              resultConstraints);
+      };
+
+  // IRDL does not support defining regions.
+  auto regionVerifier = [](Operation *op) { return success(); };
+
+  auto opDef = DynamicOpDefinition::get(
+      op.getName(), dialect, std::move(verifier), std::move(regionVerifier),
+      std::move(parser), std::move(printer));
+  dialect->registerDynamicOp(std::move(opDef));
+
+  return WalkResult::advance();
+}
+
+/// Get the verifier of a type or attribute definition.
+/// Return a null (empty) verifier function if the definition is invalid.
+///
+/// The constraint operations found in the first region of `attrOrTypeDef` are
+/// turned into constraint verifiers, which are moved into the returned
+/// callable. `dialect` is not used by this function.
+static DynamicAttrDefinition::VerifierFn getAttrOrTypeVerifier(
+    Operation *attrOrTypeDef, ExtensibleDialect *dialect,
+    DenseMap> &types,
+    DenseMap> &attrs) {
+  assert((isa(attrOrTypeDef) || isa(attrOrTypeDef)) &&
+         "Expected an attribute or type definition");
+
+  // Resolve SSA values to verifier constraint slots
+  SmallVector constrToValue;
+  for (Operation &op : attrOrTypeDef->getRegion(0).getOps()) {
+    if (isa(op)) {
+      assert(op.getNumResults() == 1 &&
+             "IRDL constraint operations must have exactly one result");
+      constrToValue.push_back(op.getResult(0));
+    }
+  }
+
+  // Build the verifiers for each constraint slot
+  SmallVector> constraints;
+  for (Value v : constrToValue) {
+    VerifyConstraintInterface op =
+        cast(v.getDefiningOp());
+    std::unique_ptr verifier =
+        op.getVerifier(constrToValue, types, attrs);
+    if (!verifier) {
+      // An empty verifier function signals the failure to the caller.
+      return {};
+    }
+    constraints.push_back(std::move(verifier));
+  }
+
+  // Get the parameter definitions.
+  std::optional params;
+  if (auto attr = dyn_cast(attrOrTypeDef))
+    params = attr.getOp();
+  else if (auto type = dyn_cast(attrOrTypeDef))
+    params = type.getOp();
+
+  // Gather which constraint slots correspond to parameter constraints
+  // NOTE(review): a parameter value that is not in `constrToValue` gets no
+  // slot and is skipped silently -- confirm this cannot happen.
+  SmallVector paramConstraints;
+  if (params.has_value()) {
+    paramConstraints.reserve(params->getArgs().size());
+    for (Value param : params->getArgs()) {
+      for (auto [i, constr] : enumerate(constrToValue)) {
+        if (constr == param) {
+          paramConstraints.push_back(i);
+          break;
+        }
+      }
+    }
+  }
+
+  // Both lists are moved into the callable, which therefore owns them.
+  auto verifier = [paramConstraints{std::move(paramConstraints)},
+                   constraints{std::move(constraints)}](
+                      function_ref emitError,
+                      ArrayRef params) {
+    return irdlAttrOrTypeVerifier(emitError, params, constraints,
+                                  paramConstraints);
+  };
+  return verifier;
+}
+
+/// Register all dialects in the given module, without registering any
+/// operation, type or attribute definitions.
+/// Returns the map from each dialect definition operation to the dynamic
+/// dialect obtained for it from the context.
+static DenseMap
+registerEmptyDialects(ModuleOp op) {
+  DenseMap dialects;
+  op.walk([&](DialectOp dialectOp) {
+    MLIRContext *ctx = dialectOp.getContext();
+    StringRef dialectName = dialectOp.getName();
+
+    // The initialization callback is empty: nothing is registered here.
+    DynamicDialect *dialect = ctx->getOrLoadDynamicDialect(
+        dialectName, [](DynamicDialect *dialect) {});
+
+    dialects.insert({dialectOp, dialect});
+  });
+  return dialects;
+}
+
+/// Preallocate type definition objects with empty verifiers.
+/// This in particular allocates a TypeID for each type definition.
+/// The placeholder verifiers accept everything; the returned definitions are
+/// not registered in their dialect yet.
+static DenseMap>
+preallocateTypeDefs(ModuleOp op,
+                    DenseMap dialects) {
+  DenseMap> typeDefs;
+  op.walk([&](TypeOp typeOp) {
+    ExtensibleDialect *dialect = dialects[typeOp.getDialectOp()];
+    auto typeDef = DynamicTypeDefinition::get(
+        typeOp.getName(), dialect,
+        [](function_ref, ArrayRef) {
+          return success();
+        });
+    typeDefs.try_emplace(typeOp, std::move(typeDef));
+  });
+  return typeDefs;
+}
+
+/// Preallocate attribute definition objects with empty verifiers.
+/// This in particular allocates a TypeID for each attribute definition.
+/// The placeholder verifiers accept everything; the returned definitions are
+/// not registered in their dialect yet.
+static DenseMap>
+preallocateAttrDefs(ModuleOp op,
+                    DenseMap dialects) {
+  DenseMap> attrDefs;
+  op.walk([&](AttributeOp attrOp) {
+    ExtensibleDialect *dialect = dialects[attrOp.getDialectOp()];
+    auto attrDef = DynamicAttrDefinition::get(
+        attrOp.getName(), dialect,
+        [](function_ref, ArrayRef) {
+          return success();
+        });
+    attrDefs.try_emplace(attrOp, std::move(attrDef));
+  });
+  return attrDefs;
+}
+
+/// Register the dialects defined in `op` and everything they define.
+/// Returns failure as soon as a verifier cannot be built; the dialects loaded
+/// and the operations registered before that point are not rolled back.
+LogicalResult mlir::irdl::registerDialects(ModuleOp op) {
+  // Preallocate all dialects, and type and attribute definitions.
+  // In particular, this allocates TypeIDs so types and attributes can have
+  // verifiers that refer to each other.
+  DenseMap dialects = registerEmptyDialects(op);
+  DenseMap> types =
+      preallocateTypeDefs(op, dialects);
+  DenseMap> attrs =
+      preallocateAttrDefs(op, dialects);
+
+  // Set the verifier for types.
+  WalkResult res = op.walk([&](TypeOp typeOp) {
+    DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
+        typeOp, dialects[typeOp.getDialectOp()], types, attrs);
+    if (!verifier)
+      return WalkResult::interrupt();
+    types[typeOp]->setVerifyFn(std::move(verifier));
+    return WalkResult::advance();
+  });
+  if (res.wasInterrupted())
+    return failure();
+
+  // Set the verifier for attributes.
+  res = op.walk([&](AttributeOp attrOp) {
+    DynamicAttrDefinition::VerifierFn verifier = getAttrOrTypeVerifier(
+        attrOp, dialects[attrOp.getDialectOp()], types, attrs);
+    if (!verifier)
+      return WalkResult::interrupt();
+    attrs[attrOp]->setVerifyFn(std::move(verifier));
+    return WalkResult::advance();
+  });
+  if (res.wasInterrupted())
+    return failure();
+
+  // Define and register all operations.
+  res = op.walk([&](OperationOp opOp) {
+    return registerOperation(opOp, dialects[opOp.getDialectOp()], types, attrs);
+  });
+  if (res.wasInterrupted())
+    return failure();
+
+  // Register all types to their dialects.
+  // Ownership of each definition moves to the dialect, so the maps must not
+  // be used to reach the definitions after this point.
+  for (auto &pair : types) {
+    ExtensibleDialect *dialect = dialects[pair.first.getDialectOp()];
+    dialect->registerDynamicType(std::move(pair.second));
+  }
+
+  // Register all attributes to their dialects.
+  for (auto &pair : attrs) {
+    ExtensibleDialect *dialect = dialects[pair.first.getDialectOp()];
+    dialect->registerDynamicAttr(std::move(pair.second));
+  }
+
+  return success();
+}
diff --git a/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IRDLVerifiers.cpp
@@ -0,0 +1,176 @@
+//===- IRDLVerifiers.cpp - IRDL verifiers ------------------------- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Verifiers for objects declared by IRDL.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+/// Creates a verifier over `constraints` with one unassigned slot per
+/// constraint variable.
+ConstraintVerifier::ConstraintVerifier(
+    ArrayRef<std::unique_ptr<Constraint>> constraints)
+    : constraints(constraints), assigned() {
+  assigned.resize(this->constraints.size());
+}
+
+/// Checks `attr` against constraint variable `variable`. A variable that was
+/// already assigned only accepts the attribute it was assigned; otherwise the
+/// underlying constraint is checked and, on success, `attr` is recorded as the
+/// value of the variable. `emitError` may be null to suppress diagnostics.
+LogicalResult
+ConstraintVerifier::verify(function_ref<InFlightDiagnostic()> emitError,
+                           Attribute attr, unsigned variable) {
+  assert(variable < constraints.size() && "invalid constraint variable");
+
+  // If the variable is already assigned, check that the attribute is the same.
+  if (assigned[variable].has_value()) {
+    if (attr == assigned[variable].value())
+      return success();
+    if (emitError)
+      return emitError() << "expected '" << assigned[variable].value()
+                         << "' but got '" << attr << "'";
+    return failure();
+  }
+
+  // Otherwise, check the constraint and assign the attribute to the variable.
+  LogicalResult result = constraints[variable]->verify(emitError, attr, *this);
+  if (succeeded(result))
+    assigned[variable] = attr;
+
+  return result;
+}
+
+/// Succeeds only if `attr` is exactly `expectedAttribute`.
+LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                                   Attribute attr,
+                                   ConstraintVerifier &context) const {
+  if (attr == expectedAttribute)
+    return success();
+
+  if (emitError)
+    return emitError() << "expected '" << expectedAttribute << "' but got '"
+                       << attr << "'";
+  return failure();
+}
+
+/// Checks that `attr` is a dynamic attribute built from `attrDef` and that
+/// each of its parameters satisfies the constraint variable at the same index.
+LogicalResult DynParametricAttrConstraint::verify(
+    function_ref<InFlightDiagnostic()> emitError, Attribute attr,
+    ConstraintVerifier &context) const {
+  // Check that the base is the expected one.
+  auto dynAttr = attr.dyn_cast<DynamicAttr>();
+  if (!dynAttr || dynAttr.getAttrDef() != attrDef) {
+    if (emitError) {
+      StringRef dialectName = attrDef->getDialect()->getNamespace();
+      StringRef attrName = attrDef->getName();
+      // Print the name as `dialect.name`, like every other diagnostic here.
+      return emitError() << "expected base attribute '" << dialectName << '.'
+                         << attrName << "' but got '" << attr << "'";
+    }
+    return failure();
+  }
+
+  // Check that the parameters satisfy the constraints.
+  ArrayRef<Attribute> params = dynAttr.getParams();
+  if (params.size() != constraints.size()) {
+    if (emitError) {
+      StringRef dialectName = attrDef->getDialect()->getNamespace();
+      StringRef attrName = attrDef->getName();
+      emitError() << "attribute '" << dialectName << "." << attrName
+                  << "' expects " << params.size() << " parameters but got "
+                  << constraints.size();
+    }
+    return failure();
+  }
+
+  for (size_t i = 0, s = params.size(); i < s; i++)
+    if (failed(context.verify(emitError, params[i], constraints[i])))
+      return failure();
+
+  return success();
+}
+
+/// Checks that `attr` wraps a dynamic type built from `typeDef` and that each
+/// of its parameters satisfies the constraint variable at the same index.
+LogicalResult DynParametricTypeConstraint::verify(
+    function_ref<InFlightDiagnostic()> emitError, Attribute attr,
+    ConstraintVerifier &context) const {
+  // Check that the base is a TypeAttr.
+  auto typeAttr = attr.dyn_cast<TypeAttr>();
+  if (!typeAttr) {
+    if (emitError)
+      return emitError() << "expected type, got attribute '" << attr << "'";
+    return failure();
+  }
+
+  // Check that the type base is the expected one.
+  auto dynType = typeAttr.getValue().dyn_cast<DynamicType>();
+  if (!dynType || dynType.getTypeDef() != typeDef) {
+    if (emitError) {
+      StringRef dialectName = typeDef->getDialect()->getNamespace();
+      StringRef typeName = typeDef->getName();
+      return emitError() << "expected base type '" << dialectName << '.'
+                         << typeName << "' but got '" << attr << "'";
+    }
+    return failure();
+  }
+
+  // Check that the parameters satisfy the constraints.
+  ArrayRef<Attribute> params = dynType.getParams();
+  if (params.size() != constraints.size()) {
+    if (emitError) {
+      StringRef dialectName = typeDef->getDialect()->getNamespace();
+      StringRef typeName = typeDef->getName();
+      emitError() << "type '" << dialectName << "." << typeName
+                  << "' expects " << params.size() << " parameters but got "
+                  << constraints.size();
+    }
+    return failure();
+  }
+
+  for (size_t i = 0, s = params.size(); i < s; i++)
+    if (failed(context.verify(emitError, params[i], constraints[i])))
+      return failure();
+
+  return success();
+}
+
+/// Succeeds as soon as `attr` satisfies one of the constraint variables; emits
+/// a single diagnostic if none of them is satisfied.
+LogicalResult
+AnyOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                        Attribute attr, ConstraintVerifier &context) const {
+  for (unsigned constr : constraints) {
+    // We do not pass the `emitError` here, since we want to emit an error
+    // only if none of the constraints are satisfied.
+    // NOTE(review): variables assigned while checking an alternative that
+    // ends up failing appear to stay assigned in `context` -- confirm whether
+    // they should be rolled back before trying the next alternative.
+    if (succeeded(context.verify({}, attr, constr)))
+      return success();
+  }
+
+  if (emitError)
+    return emitError() << "'" << attr << "' does not satisfy the constraint";
+  return failure();
+}
+
+/// Succeeds only if `attr` satisfies every constraint variable; stops at the
+/// first one that fails.
+LogicalResult
+AllOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                        Attribute attr, ConstraintVerifier &context) const {
+  for (unsigned constr : constraints)
+    if (failed(context.verify(emitError, attr, constr)))
+      return failure();
+
+  return success();
+}
+
+/// Accepts any attribute.
+LogicalResult
+AnyAttributeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
+                               Attribute attr,
+                               ConstraintVerifier &context) const {
+  return success();
+}
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -13,6 +13,8 @@
 
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLRegistration.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -104,6 +106,48 @@
   return success();
 }
 
+/// Parses the IRDL definitions in `irdlFile` and registers the dialects they
+/// declare in `ctx`. Emits a diagnostic if the file cannot be opened, and
+/// fails if the file cannot be parsed or the registration fails.
+static LogicalResult registerIRDL(StringRef irdlFile, MLIRContext &ctx) {
+  DialectRegistry registry;
+  registry.insert<irdl::IRDLDialect>();
+  ctx.appendDialectRegistry(registry);
+
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr<MemoryBuffer> file = openInputFile(irdlFile, &errorMessage);
+  if (!file) {
+    emitError(UnknownLoc::get(&ctx)) << errorMessage;
+    return failure();
+  }
+
+  // Give the buffer to the source manager.
+  // This will be picked up by the parser.
+  SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+  SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
+
+  // Disable multi-threading when parsing the input file. This removes the
+  // unnecessary/costly context synchronization when parsing.
+  // We also disable it during registration of the IRDL dialects.
+  bool wasThreadingEnabled = ctx.isMultithreadingEnabled();
+  ctx.disableMultithreading();
+
+  // Parse the input file. The parser yields a null module on failure, in
+  // which case there is nothing to register.
+  OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx));
+
+  // Register IRDL dialects.
+  LogicalResult result =
+      module ? irdl::registerDialects(module.get()) : failure();
+
+  // Restore the threading mode on every path, not only on success.
+  ctx.enableMultithreading(wasThreadingEnabled);
+  return result;
+}
+
 /// Parses the memory buffer. If successfully, run a series of passes against
 /// it and print the result.
 static LogicalResult
@@ -112,7 +156,7 @@
               bool allowUnregisteredDialects, bool preloadDialectsInContext,
               bool emitBytecode, bool implicitModule,
               PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
-              llvm::ThreadPool *threadPool) {
+              llvm::ThreadPool *threadPool, StringRef irdlFile) {
   // Tell sourceMgr about this buffer, which is what the parser will pick up.
   auto sourceMgr = std::make_shared<SourceMgr>();
   sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -123,6 +167,11 @@
   if (threadPool)
     context.setThreadPool(*threadPool);
 
+  if (!irdlFile.empty()) {
+    if (failed(registerIRDL(irdlFile, context)))
+      return failure();
+  }
+
   // Parse the input file.
   if (preloadDialectsInContext)
     context.loadAllAvailableDialects();
@@ -153,14 +202,12 @@
   return sourceMgrHandler.verify();
 }
 
-LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
-                                std::unique_ptr<MemoryBuffer> buffer,
-                                PassPipelineFn passManagerSetupFn,
-                                DialectRegistry &registry, bool splitInputFile,
-                                bool verifyDiagnostics, bool verifyPasses,
-                                bool allowUnregisteredDialects,
-                                bool preloadDialectsInContext,
-                                bool emitBytecode, bool implicitModule) {
+LogicalResult mlir::MlirOptMain(
+    raw_ostream &outputStream, std::unique_ptr<MemoryBuffer> buffer,
+    PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
+    bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
+    bool allowUnregisteredDialects, bool preloadDialectsInContext,
+    bool emitBytecode, bool implicitModule, StringRef irdlFile) {
   // The split-input-file mode is a very specific mode that slices the file
   // up into small pieces and checks each independently.
   // We use an explicit threadpool to avoid creating and joining/destroying
@@ -180,18 +227,21 @@
     return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
                          verifyPasses, allowUnregisteredDialects,
                          preloadDialectsInContext, emitBytecode, implicitModule,
-                         passManagerSetupFn, registry, threadPool);
+                         passManagerSetupFn, registry, threadPool, irdlFile);
   };
   return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
                                splitInputFile, /*insertMarkerInOutput=*/true);
 }
 
-LogicalResult mlir::MlirOptMain(
-    raw_ostream &outputStream, std::unique_ptr<MemoryBuffer> buffer,
-    const PassPipelineCLParser &passPipeline, DialectRegistry &registry,
-    bool splitInputFile, bool verifyDiagnostics, bool verifyPasses,
-    bool allowUnregisteredDialects, bool preloadDialectsInContext,
-    bool emitBytecode, bool implicitModule, bool dumpPassPipeline) {
+LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
+                                std::unique_ptr<MemoryBuffer> buffer,
+                                const PassPipelineCLParser &passPipeline,
+                                DialectRegistry &registry, bool splitInputFile,
+                                bool verifyDiagnostics, bool verifyPasses,
+                                bool allowUnregisteredDialects,
+                                bool preloadDialectsInContext,
+                                bool emitBytecode, bool implicitModule,
+                                bool dumpPassPipeline, StringRef irdlFile) {
   auto passManagerSetupFn = [&](PassManager &pm) {
     auto errorHandler = [&](const Twine &msg) {
       emitError(UnknownLoc::get(pm.getContext())) << msg;
@@ -208,7 +258,7 @@
   return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
                      registry, splitInputFile, verifyDiagnostics, verifyPasses,
                      allowUnregisteredDialects, preloadDialectsInContext,
-                     emitBytecode, implicitModule);
+                     emitBytecode, implicitModule, irdlFile);
 }
 
 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
@@ -260,6 +310,9 @@
       "dump-pass-pipeline", cl::desc("Print the pipeline that will be run"),
       cl::init(false)};
 
+  static cl::opt<std::string> irdlFile("irdl-file", cl::desc("IRDL file"),
+                                       cl::value_desc("filename"));
+
   InitLLVM y(argc, argv);
 
   // Register any command line options.
@@ -306,7 +359,7 @@
                          splitInputFile, verifyDiagnostics, verifyPasses,
                          allowUnregisteredDialects, preloadDialectsInContext,
                          emitBytecode, /*implicitModule=*/!noImplicitModule,
-                         dumpPassPipeline)))
+                         dumpPassPipeline, irdlFile)))
     return failure();
 
   // Keep the output file if the invocation of MlirOptMain was successful.
diff --git a/mlir/test/Dialect/IRDL/test-cmath.mlir b/mlir/test/Dialect/IRDL/test-cmath.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/test-cmath.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --irdl-file=%S/cmath.irdl.mlir | mlir-opt --irdl-file=%S/cmath.irdl.mlir | FileCheck %s
+
+module {
+  // CHECK: func.func @conorm(%[[p:[^:]*]]: !cmath.complex<f32>, %[[q:[^:]*]]: !cmath.complex<f32>) -> f32 {
+  // CHECK:   %[[norm_p:[^ ]*]] = "cmath.norm"(%[[p]]) : (!cmath.complex<f32>) -> f32
+  // CHECK:   %[[norm_q:[^ ]*]] = "cmath.norm"(%[[q]]) : (!cmath.complex<f32>) -> f32
+  // CHECK:   %[[pq:[^ ]*]] = arith.mulf %[[norm_p]], %[[norm_q]] : f32
+  // CHECK:   return %[[pq]] : f32
+  // CHECK: }
+  func.func @conorm(%p: !cmath.complex<f32>, %q: !cmath.complex<f32>) -> f32 {
+    %norm_p = "cmath.norm"(%p) : (!cmath.complex<f32>) -> f32
+    %norm_q = "cmath.norm"(%q) : (!cmath.complex<f32>) -> f32
+    %pq = arith.mulf %norm_p, %norm_q : f32
+    return %pq : f32
+  }
+
+  // CHECK: func.func @conorm2(%[[p:[^:]*]]: !cmath.complex<f32>, %[[q:[^:]*]]: !cmath.complex<f32>) -> f32 {
+  // CHECK:   %[[pq:[^ ]*]] = "cmath.mul"(%[[p]], %[[q]]) : (!cmath.complex<f32>, !cmath.complex<f32>) -> !cmath.complex<f32>
+  // CHECK:   %[[conorm:[^ ]*]] = "cmath.norm"(%[[pq]]) : (!cmath.complex<f32>) -> f32
+  // CHECK:   return %[[conorm]] : f32
+  // CHECK: }
+  func.func @conorm2(%p: !cmath.complex<f32>, %q: !cmath.complex<f32>) -> f32 {
+    %pq = "cmath.mul"(%p, %q) : (!cmath.complex<f32>, !cmath.complex<f32>) -> !cmath.complex<f32>
+    %conorm = "cmath.norm"(%pq) : (!cmath.complex<f32>) -> f32
+    return %conorm : f32
+  }
+}
diff --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/testd.mlir
@@ -0,0 +1,146 @@
+// RUN: mlir-opt %s --irdl-file=%S/testd.irdl.mlir -split-input-file -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Type or attribute constraint
+//===----------------------------------------------------------------------===//
+
+func.func @typeFitsType() {
+  // CHECK: "testd.any"() : () -> !testd.parametric<i32>
+  "testd.any"() : () -> !testd.parametric<i32>
+  return
+}
+
+// -----
+
+func.func @attrDoesntFitType() {
+  "testd.any"() : () -> !testd.parametric<"foo">
+  return
+}
+
+// -----
+
+func.func @attrFitsAttr() {
+  // CHECK: "testd.any"() : () -> !testd.attr_in_type_out<"foo">
+  "testd.any"() : () -> !testd.attr_in_type_out<"foo">
+  return
+}
+
+// -----
+
+func.func @typeFitsAttr() {
+  // CHECK: "testd.any"() : () -> !testd.attr_in_type_out<i32>
+  "testd.any"() : () -> !testd.attr_in_type_out<i32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Equality constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededEqConstraint() {
+  // CHECK: "testd.eq"() : () -> i32
+  "testd.eq"() : () -> i32
+  return
+}
+
+// -----
+
+func.func @failedEqConstraint() {
+  // expected-error@+1 {{expected 'i32' but got 'i64'}}
+  "testd.eq"() : () -> i64
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Any constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededAnyConstraint() {
+  // CHECK: "testd.any"() : () -> i32
+  "testd.any"() : () -> i32
+  // CHECK: "testd.any"() : () -> i64
+  "testd.any"() : () -> i64
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Dynamic base constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededDynBaseConstraint() {
+  // CHECK: "testd.dynbase"() : () -> !testd.parametric<i32>
+  "testd.dynbase"() : () -> !testd.parametric<i32>
+  // CHECK: "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i32>>
+  "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i32>>
+  return
+}
+
+// -----
+
+func.func @failedDynBaseConstraint() {
+  // expected-error@+1 {{expected base type 'testd.parametric' but got 'i32'}}
+  "testd.dynbase"() : () -> i32
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Dynamic parameters constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededDynParamsConstraint() {
+  // CHECK: "testd.dynparams"() : () -> !testd.parametric<i32>
+  "testd.dynparams"() : () -> !testd.parametric<i32>
+  return
+}
+
+// -----
+
+func.func @failedDynParamsConstraintBase() {
+  // expected-error@+1 {{expected base type 'testd.parametric' but got 'i32'}}
+  "testd.dynparams"() : () -> i32
+  return
+}
+
+// -----
+
+func.func @failedDynParamsConstraintParam() {
+  // expected-error@+1 {{expected 'i32' but got 'i1'}}
+  "testd.dynparams"() : () -> !testd.parametric<i1>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Constraint variables
+//===----------------------------------------------------------------------===//
+
+func.func @succeededConstraintVars() {
+  // CHECK: "testd.constraint_vars"() : () -> (i32, i32)
+  "testd.constraint_vars"() : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func.func @succeededConstraintVars2() {
+  // CHECK: "testd.constraint_vars"() : () -> (i64, i64)
+  "testd.constraint_vars"() : () -> (i64, i64)
+  return
+}
+
+// -----
+
+func.func @failedConstraintVars() {
+  // expected-error@+1 {{expected 'i64' but got 'i32'}}
+  "testd.constraint_vars"() : () -> (i64, i32)
+  return
+}