diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLLoading.h b/mlir/include/mlir/Dialect/IRDL/IRDLLoading.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLLoading.h
@@ -0,0 +1,30 @@
+//===- IRDLLoading.h - IRDL dialect loading ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages the loading of MLIR objects from IRDL operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IRDLLOADING_H
+#define MLIR_DIALECT_IRDL_IRDLLOADING_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace irdl {
+
+/// Load all the dialects defined in the module.
+/// Each dialect is loaded in the context of its defining operation as a
+/// dynamic dialect, along with its operations, types and attributes.
+LogicalResult loadDialects(ModuleOp op);
+
+} // namespace irdl
+} // namespace mlir
+
+#endif // MLIR_DIALECT_IRDL_IRDLLOADING_H
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -23,8 +23,8 @@
 #include "mlir/IR/Operation.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
-#include
 #include
+#include
 
 namespace mlir {
 class Builder;
@@ -633,7 +633,7 @@
   class Impl
       : public TraitBase::Impl> {
   public:
-    mlir::TypedValue getResult() {
+    mlir::TypedValue getResult() {
       return cast>(
           this->getOperation()->getResult(0));
     }
@@ -1255,6 +1255,14 @@
              << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
              << llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'";
     }
+
+    template >>
+    std::enable_if_t
+    getParentOp() {
+      Operation *parent = this->getOperation()->getParentOp();
+      return llvm::cast(parent);
+    }
   };
 };
 
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -74,6 +74,13 @@
   }
   bool shouldEmitBytecode() const { return emitBytecodeFlag; }
 
+  /// Set the IRDL file to load before processing the input.
+  MlirOptMainConfig &setIrdlFile(llvm::StringRef file) {
+    irdlFileFlag = file;
+    return *this;
+  }
+  llvm::StringRef getIrdlFile() const { return irdlFileFlag; }
+
   /// Set the filename to use for logging actions, use "-" for stdout.
   MlirOptMainConfig &logActionsTo(StringRef filename) {
     logActionsToFlag = filename;
@@ -157,6 +164,9 @@
   /// Emit bytecode instead of textual assembly when generating output.
   bool emitBytecodeFlag = false;
 
+  /// IRDL file to register before processing the input (empty means none).
+  std::string irdlFileFlag;
+
   /// Log action execution to the given file (or "-" for stdout)
   std::string logActionsToFlag;
 
diff --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRIRDL
   IR/IRDL.cpp
+  IRDLLoading.cpp
 
   DEPENDS
   MLIRIRDLIncGen
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -0,0 +1,131 @@
+//===- IRDLLoading.cpp - IRDL dialect loading -------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages the loading of MLIR objects from IRDL operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IRDLLoading.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/SMLoc.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+/// Register the operation described by an `irdl.operation` op as a dynamic
+/// operation of `dialect`, then let the enclosing walk continue.
+static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect) {
+  // Both the operation verifier and the region verifier accept everything.
+  // IRDL does not support defining regions.
+  auto acceptOp = [](Operation *) { return success(); };
+  auto acceptRegions = [](Operation *) { return success(); };
+
+  // IRDL does not support defining custom parsers or printers: the custom
+  // form never parses, and printing falls back to the generic form.
+  auto parseFn = [](OpAsmParser &, OperationState &) { return failure(); };
+  auto printFn = [](Operation *target, OpAsmPrinter &out, StringRef) {
+    out.printGenericOp(target);
+  };
+
+  // Build the definition in place and hand it over to the dialect.
+  dialect->registerDynamicOp(DynamicOpDefinition::get(
+      op.getName(), dialect, std::move(acceptOp), std::move(acceptRegions),
+      std::move(parseFn), std::move(printFn)));
+
+  // This function always advances; it never interrupts the walk.
+  return WalkResult::advance();
+}
+
+/// Load all dialects in the given module, without loading any
+/// operation, type or attribute definitions.
+static DenseMap loadEmptyDialects(ModuleOp op) {
+  DenseMap dialects;
+  op.walk([&](DialectOp dialectOp) {
+    MLIRContext *ctx = dialectOp.getContext();
+    StringRef dialectName = dialectOp.getName();
+
+    DynamicDialect *dialect = ctx->getOrLoadDynamicDialect(
+        dialectName, [](DynamicDialect *dialect) {});
+
+    dialects.insert({dialectOp, dialect});
+  });
+  return dialects;
+}
+
+/// Preallocate type definitions objects with empty verifiers.
+/// This in particular allocates a TypeID for each type definition.
+static DenseMap>
+preallocateTypeDefs(ModuleOp op,
+                    DenseMap dialects) {
+  DenseMap> typeDefs;
+  op.walk([&](TypeOp typeOp) {
+    ExtensibleDialect *dialect = dialects[typeOp.getParentOp()];
+    auto typeDef = DynamicTypeDefinition::get(
+        typeOp.getName(), dialect,
+        [](function_ref, ArrayRef) {
+          return success();
+        });
+    typeDefs.try_emplace(typeOp, std::move(typeDef));
+  });
+  return typeDefs;
+}
+
+/// Preallocate attribute definitions objects with empty verifiers.
+/// This in particular allocates a TypeID for each attribute definition.
+static DenseMap>
+preallocateAttrDefs(ModuleOp op,
+                    DenseMap dialects) {
+  DenseMap> attrDefs;
+  op.walk([&](AttributeOp attrOp) {
+    ExtensibleDialect *dialect = dialects[attrOp.getParentOp()];
+    auto attrDef = DynamicAttrDefinition::get(
+        attrOp.getName(), dialect,
+        [](function_ref, ArrayRef) {
+          return success();
+        });
+    attrDefs.try_emplace(attrOp, std::move(attrDef));
+  });
+  return attrDefs;
+}
+
+LogicalResult mlir::irdl::loadDialects(ModuleOp op) {
+  // Preallocate all dialects, and type and attribute definitions.
+  // In particular, this allocates TypeIDs so type and attributes can have
+  // verifiers that refer to each other.
+  DenseMap dialects = loadEmptyDialects(op);
+  DenseMap> types =
+      preallocateTypeDefs(op, dialects);
+  DenseMap> attrs =
+      preallocateAttrDefs(op, dialects);
+
+  // Define and load all operations.
+  WalkResult res = op.walk([&](OperationOp opOp) {
+    return loadOperation(opOp, dialects[opOp.getParentOp()]);
+  });
+  if (res.wasInterrupted())
+    return failure();
+
+  // Load all types in their dialects.
+  for (auto &pair : types) {
+    ExtensibleDialect *dialect = dialects[pair.first.getParentOp()];
+    dialect->registerDynamicType(std::move(pair.second));
+  }
+
+  // Load all attributes in their dialects.
+  for (auto &pair : attrs) {
+    ExtensibleDialect *dialect = dialects[pair.first.getParentOp()];
+    dialect->registerDynamicAttr(std::move(pair.second));
+  }
+
+  return success();
+}
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -16,6 +16,8 @@
 #include "mlir/Debug/Counter.h"
 #include "mlir/Debug/ExecutionContext.h"
 #include "mlir/Debug/Observers/ActionLogging.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLLoading.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -69,6 +71,11 @@
         "emit-bytecode", cl::desc("Emit bytecode when generating output"),
         cl::location(emitBytecodeFlag), cl::init(false));
 
+    static cl::opt irdlFile(
+        "irdl-file",
+        cl::desc("IRDL file to register before processing the input"),
+        cl::location(irdlFileFlag), cl::init(""), cl::value_desc("filename"));
+
     static cl::opt explicitModule(
         "no-implicit-module",
         cl::desc("Disable implicit addition of a top-level module op during "
@@ -275,6 +282,38 @@
   return success();
 }
 
+/// Parse the IRDL file `irdlFile` and load the dialects it defines into
+/// `ctx`. Fails if the file cannot be opened or parsed, or if loading fails.
+LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
+  DialectRegistry registry;
+  registry.insert();
+  ctx.appendDialectRegistry(registry);
+
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr file = openInputFile(irdlFile, &errorMessage);
+  if (!file) {
+    emitError(UnknownLoc::get(&ctx)) << errorMessage;
+    return failure();
+  }
+
+  // Give the buffer to the source manager.
+  // This will be picked up by the parser.
+  SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+  SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
+
+  // Parse the input file. Parsing reports its own diagnostics and yields a
+  // null module on failure, which must not be handed to the loader.
+  OwningOpRef module(parseSourceFile(sourceMgr, &ctx));
+  if (!module)
+    return failure();
+
+  // Load IRDL dialects.
+  return irdl::loadDialects(module.get());
+}
+
 /// Parses the memory buffer. If successfully, run a series of passes against
 /// it and print the result.
 static LogicalResult processBuffer(raw_ostream &os,
@@ -292,6 +331,12 @@
   if (threadPool)
     context.setThreadPool(*threadPool);
 
+  StringRef irdlFile = config.getIrdlFile();
+  if (!irdlFile.empty()) {
+    if (failed(loadIRDLDialects(irdlFile, context)))
+      return failure();
+  }
+
   // Parse the input file.
   if (config.shouldPreloadDialectsInContext())
     context.loadAllAvailableDialects();
diff --git a/mlir/test/Dialect/IRDL/test-cmath.mlir b/mlir/test/Dialect/IRDL/test-cmath.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/test-cmath.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --irdl-file=%S/cmath.irdl.mlir | mlir-opt --irdl-file=%S/cmath.irdl.mlir | FileCheck %s
+
+module {
+  // CHECK: func.func @conorm(%[[p:[^:]*]]: !cmath.complex, %[[q:[^:]*]]: !cmath.complex) -> f32 {
+  // CHECK: %[[norm_p:[^ ]*]] = "cmath.norm"(%[[p]]) : (!cmath.complex) -> f32
+  // CHECK: %[[norm_q:[^ ]*]] = "cmath.norm"(%[[q]]) : (!cmath.complex) -> f32
+  // CHECK: %[[pq:[^ ]*]] = arith.mulf %[[norm_p]], %[[norm_q]] : f32
+  // CHECK: return %[[pq]] : f32
+  // CHECK: }
+  func.func @conorm(%p: !cmath.complex, %q: !cmath.complex) -> f32 {
+    %norm_p = "cmath.norm"(%p) : (!cmath.complex) -> f32
+    %norm_q = "cmath.norm"(%q) : (!cmath.complex) -> f32
+    %pq = arith.mulf %norm_p, %norm_q : f32
+    return %pq : f32
+  }
+
+  // CHECK: func.func @conorm2(%[[p:[^:]*]]: !cmath.complex, %[[q:[^:]*]]: !cmath.complex) -> f32 {
+  // CHECK: %[[pq:[^ ]*]] = "cmath.mul"(%[[p]], %[[q]]) : (!cmath.complex, 
!cmath.complex) -> !cmath.complex
+  // CHECK: %[[conorm:[^ ]*]] = "cmath.norm"(%[[pq]]) : (!cmath.complex) -> f32
+  // CHECK: return %[[conorm]] : f32
+  // CHECK: }
+  func.func @conorm2(%p: !cmath.complex, %q: !cmath.complex) -> f32 {
+    %pq = "cmath.mul"(%p, %q) : (!cmath.complex, !cmath.complex) -> !cmath.complex
+    %conorm = "cmath.norm"(%pq) : (!cmath.complex) -> f32
+    return %conorm : f32
+  }
+}
diff --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/testd.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s --irdl-file=%S/testd.irdl.mlir -split-input-file -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Type or attribute constraint
+//===----------------------------------------------------------------------===//
+
+func.func @typeFitsType() {
+  // CHECK: "testd.any"() : () -> !testd.parametric
+  "testd.any"() : () -> !testd.parametric
+  return
+}
+
+// -----
+// NOTE(review): no FileCheck or diagnostic expectation below; confirm intent.
+func.func @attrDoesntFitType() {
+  "testd.any"() : () -> !testd.parametric<"foo">
+  return
+}
+
+// -----
+
+func.func @attrFitsAttr() {
+  // CHECK: "testd.any"() : () -> !testd.attr_in_type_out<"foo">
+  "testd.any"() : () -> !testd.attr_in_type_out<"foo">
+  return
+}
+
+// -----
+
+func.func @typeFitsAttr() {
+  // CHECK: "testd.any"() : () -> !testd.attr_in_type_out
+  "testd.any"() : () -> !testd.attr_in_type_out
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Equality constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededEqConstraint() {
+  // CHECK: "testd.eq"() : () -> i32
+  "testd.eq"() : () -> i32
+  return
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Any constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededAnyConstraint() {
+  // CHECK: "testd.any"() : () -> i32
+  "testd.any"() : () -> i32
+  // CHECK: "testd.any"() : () -> i64
+  "testd.any"() : () -> i64
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Dynamic base constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededDynBaseConstraint() {
+  // CHECK: "testd.dynbase"() : () -> !testd.parametric
+  "testd.dynbase"() : () -> !testd.parametric
+  // CHECK: "testd.dynbase"() : () -> !testd.parametric>
+  "testd.dynbase"() : () -> !testd.parametric>
+  return
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Dynamic parameters constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededDynParamsConstraint() {
+  // CHECK: "testd.dynparams"() : () -> !testd.parametric
+  "testd.dynparams"() : () -> !testd.parametric
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Constraint variables
+//===----------------------------------------------------------------------===//
+
+func.func @succeededConstraintVars() {
+  // CHECK: "testd.constraint_vars"() : () -> (i32, i32)
+  "testd.constraint_vars"() : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func.func @succeededConstraintVars2() {
+  // CHECK: "testd.constraint_vars"() : () -> (i64, i64)
+  "testd.constraint_vars"() : () -> (i64, i64)
+  return
+}