diff --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md --- a/mlir/docs/PDLL.md +++ b/mlir/docs/PDLL.md @@ -1135,17 +1135,86 @@ ``` The arguments of the constraint are accessible within the code block via the -same name. The type of these native variables are mapped directly to the -corresponding MLIR type of the [core constraint](#core-constraints) used. For -example, an `Op` corresponds to a variable of type `Operation *`. +same name. See the ["type translation"](#native-constraint-type-translations) section below for +detailed information on how PDLL types are converted to native types. In addition to the +PDLL arguments, the code block may also access the current `PatternRewriter` using +`rewriter`. The result type of the native constraint function is implicitly defined +as a `::mlir::LogicalResult`. -The results of the constraint can be populated using the provided `results` -variable. This variable is a `PDLResultList`, and expects results to be -populated in the order that they are defined within the result list of the -constraint declaration. +Taking the constraints defined above as an example, these functions would roughly be +translated into: -In addition to the above, the code block may also access the current -`PatternRewriter` using `rewriter`. +```c++ +LogicalResult HasOneUse(PatternRewriter &rewriter, Value value) { + return success(value.hasOneUse()); +} +LogicalResult HasSameElementType(PatternRewriter &rewriter, Value value1, Value value2) { + return success(value1.getType().cast().getElementType() == + value2.getType().cast().getElementType()); +} +``` + +TODO: Native constraints should also be allowed to return values in certain cases. + +###### Native Constraint Type Translations + +The types of argument and result variables are generally mapped to the corresponding +MLIR type of the [constraint](#constraints) used. Below is a detailed description +of how the mapped type of a variable is determined for the different kinds of +constraints.
+ +* Attr, Op, Type, TypeRange, Value, ValueRange: + +These are all core constraints, and are mapped directly to the MLIR equivalents +that their names suggest, namely: + + * `Attr` -> "::mlir::Attribute" + * `Op` -> "::mlir::Operation *" + * `Type` -> "::mlir::Type" + * `TypeRange` -> "::mlir::TypeRange" + * `Value` -> "::mlir::Value" + * `ValueRange` -> "::mlir::ValueRange" + +* Op + +A named operation constraint has a unique translation. If the ODS registration of the +referenced operation has been included, the qualified C++ class name of the operation is used. If the ODS information +is not available, this constraint maps to "::mlir::Operation *", similarly to the unnamed +variant. For example, given the following: + +```pdll +// `my_ops.td` provides the ODS definition of the `my_dialect` operations, such as +// `my_dialect.bar` used below. +#include "my_ops.td" + +Constraint Cst(op: Op) [{ + return success(op ... ); +}]; +``` + +The native type used for `op` may be of the form `my_dialect::BarOp`, as opposed to the +default `::mlir::Operation *`. Below is a sample translation of the above constraint: + +```c++ +LogicalResult Cst(PatternRewriter &rewriter, my_dialect::BarOp op) { + return success(op ... ); +} +``` + +* Imported ODS Constraints + +Aside from the core constraints, certain constraints imported from ODS may use a unique +native type. How to enable this unique type depends on the ODS constraint construct that +was imported: + + * `Attr` constraints + - Imported `Attr` constraints use the `storageType` field for native type translation. + + * `Type` constraints + - Imported `Type` constraints use the `cppClassName` field for native type translation. + + * `AttrInterface`/`OpInterface`/`TypeInterface` constraints + - Imported interfaces use the `cppClassName` field for native type translation.
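The lookup order described above can be illustrated with a small standalone C++ sketch. It is not part of this patch and does not use the MLIR or PDLL APIs: `PDLLKind` and `nativeTypeFor` are hypothetical names, and native types are modeled as plain strings. Assuming the precedence implemented by `CodeGen::getNativeTypeName` in this diff, an explicit native type recorded on an imported ODS constraint wins. Otherwise, a named `Op` with ODS information uses the operation's qualified C++ class name, and every other case falls back to the core mapping.

```cpp
#include <optional>
#include <string>

// Hypothetical enumeration of the core PDLL constraint kinds (not a PDLL API).
enum class PDLLKind { Attr, Op, Type, TypeRange, Value, ValueRange };

// Hypothetical helper sketching the translation rules described above.
//  - explicitNativeType: a native type attached to an imported ODS constraint
//    (e.g. taken from its `storageType` or `cppClassName`), if any.
//  - odsOpClassName: the qualified C++ class name of a named `Op` whose ODS
//    definition is available, if any.
std::string nativeTypeFor(PDLLKind kind,
                          std::optional<std::string> explicitNativeType = std::nullopt,
                          std::optional<std::string> odsOpClassName = std::nullopt) {
  // An explicit native type from an imported ODS constraint takes precedence.
  if (explicitNativeType)
    return *explicitNativeType;

  // Otherwise, use the mapping of the core constraint.
  switch (kind) {
  case PDLLKind::Attr:
    return "::mlir::Attribute";
  case PDLLKind::Op:
    // Use the ODS class when it is known, else the generic operation pointer.
    return odsOpClassName.value_or("::mlir::Operation *");
  case PDLLKind::Type:
    return "::mlir::Type";
  case PDLLKind::TypeRange:
    return "::mlir::TypeRange";
  case PDLLKind::Value:
    return "::mlir::Value";
  case PDLLKind::ValueRange:
    return "::mlir::ValueRange";
  }
  return {}; // Not reached for valid enumerators.
}
```

For example, `nativeTypeFor(PDLLKind::Op, std::nullopt, "my_dialect::BarOp")` yields `"my_dialect::BarOp"`, matching the `Cst` translation shown above, while `nativeTypeFor(PDLLKind::Op)` yields `"::mlir::Operation *"`.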
#### Defining Constraints Inline @@ -1403,10 +1472,7 @@ ```pdll Rewrite BuildOp(value: Value) -> (foo: Op, bar: Op) [{ - // We push back the results into the `results` variable in the order defined - // by the result list of the rewrite declaration. - results.push_back(rewriter.create(value)); - results.push_back(rewriter.create()); + return {rewriter.create(value), rewriter.create()}; }]; Pattern { @@ -1420,17 +1486,85 @@ ``` The arguments of the rewrite are accessible within the code block via the -same name. The type of these native variables are mapped directly to the -corresponding MLIR type of the [core constraint](#core-constraints) used. For -example, an `Op` corresponds to a variable of type `Operation *`. +same name. See the ["type translation"](#native-rewrite-type-translations) section below for +detailed information on how PDLL types are converted to native types. In addition to the +PDLL arguments, the code block may also access the current `PatternRewriter` using +`rewriter`. See the ["result translation"](#native-rewrite-result-translation) section +for detailed information on how the result type of the native function is determined. + +Taking the rewrite defined above as an example, this function would roughly be +translated into: + +```c++ +std::tuple BuildOp(PatternRewriter &rewriter, Value value) { + return {rewriter.create(value), rewriter.create()}; +} +``` -The results of the rewrite can be populated using the provided `results` -variable. This variable is a `PDLResultList`, and expects results to be -populated in the order that they are defined within the result list of the -rewrite declaration. +###### Native Rewrite Type Translations -In addition to the above, the code block may also access the current -`PatternRewriter` using `rewriter`. +The types of argument and result variables are generally mapped to the corresponding +MLIR type of the [constraint](#constraints) used.
The rules of native `Rewrite` type translation +are identical to those of native `Constraint`s; see the corresponding +[native `Constraint` type translation](#native-constraint-type-translations) section for a +detailed description of how the mapped type of a variable is determined. + +###### Native Rewrite Result Translation + +The results of a native rewrite are directly translated to the results of the native function, +using the type translation rules [described above](#native-rewrite-type-translations). The list +below describes the various result translation scenarios: + +* Zero Results + +```pdll +Rewrite createOp() [{ + rewriter.create(); +}]; +``` + +When a native `Rewrite` has no results, the native function returns `void`: + +```c++ +void createOp(PatternRewriter &rewriter) { + rewriter.create(); +} +``` + +* Single Result + +```pdll +Rewrite createOp() -> Op [{ + return rewriter.create(); +}]; +``` + +When a native `Rewrite` has a single result, the native function returns the corresponding +native type for that single result: + +```c++ +my_dialect::FooOp createOp(PatternRewriter &rewriter) { + return rewriter.create(); +} +``` + +* Multiple Results + +```pdll +Rewrite complexRewrite(value: Value) -> (Op, FunctionOpInterface) [{ + ... +}]; +``` + +When a native `Rewrite` has multiple results, the native function returns a `std::tuple<...>` +containing the corresponding native types for each of the results: + +```c++ +std::tuple +complexRewrite(PatternRewriter &rewriter, Value value) { + ...
+} +``` #### Defining Rewrites Inline diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -943,9 +943,13 @@ " to be of type: " + llvm::getTypeName()); }); } + using ProcessPDLValueBasedOn::verifyAsArg; + static T processAsArg(BaseT baseValue) { return baseValue.template cast(); } + using ProcessPDLValueBasedOn::processAsArg; + static void processAsResult(PatternRewriter &, PDLResultList &results, T value) { results.push_back(value); @@ -967,6 +971,8 @@ struct ProcessPDLValue : public ProcessPDLValueBasedOn { static StringRef processAsArg(StringAttr value) { return value.getValue(); } + using ProcessPDLValueBasedOn::processAsArg; + static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, StringRef value) { results.push_back(rewriter.getStringAttr(value)); diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h --- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h @@ -506,6 +506,7 @@ NamedAttributeDecl *> { public: static OperationExpr *create(Context &ctx, SMRange loc, + const ods::Operation *odsOp, const OpNameDecl *nameDecl, ArrayRef operands, ArrayRef resultTypes, @@ -830,16 +831,15 @@ /// - This is a constraint which is defined using only PDLL constructs. class UserConstraintDecl final : public Node::NodeBase, - llvm::TrailingObjects { + llvm::TrailingObjects { public: /// Create a native constraint with the given optional code block. 
- static UserConstraintDecl *createNative(Context &ctx, const Name &name, - ArrayRef inputs, - ArrayRef results, - Optional codeBlock, - Type resultType) { - return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr, - resultType); + static UserConstraintDecl * + createNative(Context &ctx, const Name &name, ArrayRef inputs, + ArrayRef results, Optional codeBlock, + Type resultType, ArrayRef nativeInputTypes = {}) { + return createImpl(ctx, name, inputs, nativeInputTypes, results, codeBlock, + /*body=*/nullptr, resultType); } /// Create a PDLL constraint with the given body. @@ -848,8 +848,8 @@ ArrayRef results, const CompoundStmt *body, Type resultType) { - return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None, - body, resultType); + return createImpl(ctx, name, inputs, /*nativeInputTypes=*/llvm::None, + results, /*codeBlock=*/llvm::None, body, resultType); } /// Return the name of the constraint. @@ -863,6 +863,10 @@ return const_cast(this)->getInputs(); } + /// Return the explicit native type to use for the given input. Returns None + /// if no explicit type was set. + Optional getNativeInputType(unsigned index) const; + /// Return the explicit results of the constraint declaration. May be empty, /// even if the constraint has results (e.g. in the case of inferred results). MutableArrayRef getResults() { @@ -891,10 +895,12 @@ /// components. static UserConstraintDecl * createImpl(Context &ctx, const Name &name, ArrayRef inputs, + ArrayRef nativeInputTypes, ArrayRef results, Optional codeBlock, const CompoundStmt *body, Type resultType); - UserConstraintDecl(const Name &name, unsigned numInputs, unsigned numResults, + UserConstraintDecl(const Name &name, unsigned numInputs, + bool hasNativeInputTypes, unsigned numResults, Optional codeBlock, const CompoundStmt *body, Type resultType) : Base(name.getLoc(), &name), numInputs(numInputs), @@ -916,8 +922,14 @@ /// The result type of the constraint. 
Type resultType; + /// Flag indicating if this constraint has explicit native input types. + bool hasNativeInputTypes; + /// Allow access to various internals. - friend llvm::TrailingObjects; + friend llvm::TrailingObjects; + size_t numTrailingObjects(OverloadToken) const { + return numInputs + numResults; + } }; //===----------------------------------------------------------------------===// @@ -1145,6 +1157,23 @@ return cast(this)->getResultType(); } + /// Return the explicit results of the declaration. Note that these may be + /// empty, even if the callable has results (e.g. in the case of inferred + /// results). + ArrayRef getResults() const { + if (const auto *cst = dyn_cast(this)) + return cst->getResults(); + return cast(this)->getResults(); + } + + /// Return the optional code block of this callable, if this is a native + /// callable with a provided implementation. + Optional getCodeBlock() const { + if (const auto *cst = dyn_cast(this)) + return cst->getCodeBlock(); + return cast(this)->getCodeBlock(); + } + /// Support LLVM type casting facilities. static bool classof(const Node *decl) { return isa(decl); diff --git a/mlir/include/mlir/Tools/PDLL/AST/Types.h b/mlir/include/mlir/Tools/PDLL/AST/Types.h --- a/mlir/include/mlir/Tools/PDLL/AST/Types.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Types.h @@ -14,6 +14,10 @@ namespace mlir { namespace pdll { +namespace ods { +class Operation; +} // namespace ods + namespace ast { class Context; @@ -151,10 +155,15 @@ /// Return an instance of the Operation type with an optional operation name. /// If no name is provided, this type may refer to any operation. static OperationType get(Context &context, - Optional name = llvm::None); + Optional name = llvm::None, + const ods::Operation *odsOp = nullptr); /// Return the name of this operation type, or None if it doesn't have on. Optional getName() const; + + /// Return the ODS operation that this type refers to, or nullptr if the ODS + /// operation is unknown. 
+ const ods::Operation *getODSOperation() const; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Context.h b/mlir/include/mlir/Tools/PDLL/ODS/Context.h --- a/mlir/include/mlir/Tools/PDLL/ODS/Context.h +++ b/mlir/include/mlir/Tools/PDLL/ODS/Context.h @@ -63,7 +63,8 @@ /// operation already existed). std::pair insertOperation(StringRef name, StringRef summary, StringRef desc, - bool supportsResultTypeInferrence, SMLoc loc); + StringRef nativeClassName, bool supportsResultTypeInferrence, + SMLoc loc); /// Lookup an operation registered with the given name, or null if no /// operation with that name is registered. diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h --- a/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h +++ b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h @@ -35,7 +35,8 @@ /// operation already existed). std::pair insertOperation(StringRef name, StringRef summary, StringRef desc, - bool supportsResultTypeInferrence, SMLoc loc); + StringRef nativeClassName, bool supportsResultTypeInferrence, + SMLoc loc); /// Lookup an operation registered with the given name, or null if no /// operation with that name is registered. diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Operation.h b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h --- a/mlir/include/mlir/Tools/PDLL/ODS/Operation.h +++ b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h @@ -154,6 +154,9 @@ /// Returns the description of the operation. StringRef getDescription() const { return description; } + /// Returns the native class name of the operation. + StringRef getNativeClassName() const { return nativeClassName; } + /// Returns the attributes of this operation. 
ArrayRef getAttributes() const { return attributes; } @@ -168,7 +171,7 @@ private: Operation(StringRef name, StringRef summary, StringRef desc, - bool supportsTypeInferrence, SMLoc loc); + StringRef nativeClassName, bool supportsTypeInferrence, SMLoc loc); /// The name of the operation. std::string name; @@ -177,6 +180,9 @@ std::string summary; std::string description; + /// The native class name of the operation, used when generating native code. + std::string nativeClassName; + /// Flag indicating if the operation is known to support type inferrence. bool supportsTypeInferrence; diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp --- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp +++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp @@ -298,17 +298,18 @@ // OperationExpr //===----------------------------------------------------------------------===// -OperationExpr *OperationExpr::create( - Context &ctx, SMRange loc, const OpNameDecl *name, - ArrayRef operands, ArrayRef resultTypes, - ArrayRef attributes) { +OperationExpr * +OperationExpr::create(Context &ctx, SMRange loc, const ods::Operation *odsOp, + const OpNameDecl *name, ArrayRef operands, + ArrayRef resultTypes, + ArrayRef attributes) { unsigned allocSize = OperationExpr::totalSizeToAlloc( operands.size() + resultTypes.size(), attributes.size()); void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr)); - Type resultType = OperationType::get(ctx, name->getName()); + Type resultType = OperationType::get(ctx, name->getName(), odsOp); OperationExpr *opExpr = new (rawData) OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(), attributes.size(), name->getLoc()); @@ -426,23 +427,41 @@ // UserConstraintDecl //===----------------------------------------------------------------------===// +Optional +UserConstraintDecl::getNativeInputType(unsigned index) const { + return hasNativeInputTypes ? 
getTrailingObjects()[index] + : Optional(); +} + UserConstraintDecl *UserConstraintDecl::createImpl( Context &ctx, const Name &name, ArrayRef inputs, - ArrayRef results, Optional codeBlock, - const CompoundStmt *body, Type resultType) { - unsigned allocSize = UserConstraintDecl::totalSizeToAlloc( - inputs.size() + results.size()); + ArrayRef nativeInputTypes, ArrayRef results, + Optional codeBlock, const CompoundStmt *body, Type resultType) { + bool hasNativeInputTypes = !nativeInputTypes.empty(); + assert(!hasNativeInputTypes || nativeInputTypes.size() == inputs.size()); + + unsigned allocSize = + UserConstraintDecl::totalSizeToAlloc( + inputs.size() + results.size(), + hasNativeInputTypes ? inputs.size() : 0); void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(UserConstraintDecl)); if (codeBlock) codeBlock = codeBlock->copy(ctx.getAllocator()); - UserConstraintDecl *decl = new (rawData) UserConstraintDecl( - name, inputs.size(), results.size(), codeBlock, body, resultType); + UserConstraintDecl *decl = new (rawData) + UserConstraintDecl(name, inputs.size(), hasNativeInputTypes, + results.size(), codeBlock, body, resultType); std::uninitialized_copy(inputs.begin(), inputs.end(), decl->getInputs().begin()); std::uninitialized_copy(results.begin(), results.end(), decl->getResults().begin()); + if (hasNativeInputTypes) { + StringRef *nativeInputTypesPtr = decl->getTrailingObjects(); + for (unsigned i = 0, e = inputs.size(); i < e; ++i) + nativeInputTypesPtr[i] = nativeInputTypes[i].copy(ctx.getAllocator()); + } + return decl; } diff --git a/mlir/lib/Tools/PDLL/AST/TypeDetail.h b/mlir/lib/Tools/PDLL/AST/TypeDetail.h --- a/mlir/lib/Tools/PDLL/AST/TypeDetail.h +++ b/mlir/lib/Tools/PDLL/AST/TypeDetail.h @@ -75,13 +75,15 @@ //===----------------------------------------------------------------------===// struct OperationTypeStorage - : public TypeStorageBase { + : public TypeStorageBase> { using Base::Base; static OperationTypeStorage * - 
construct(StorageUniquer::StorageAllocator &alloc, StringRef key) { - return new (alloc.allocate()) - OperationTypeStorage(alloc.copyInto(key)); + construct(StorageUniquer::StorageAllocator &alloc, + const std::pair &key) { + return new (alloc.allocate()) OperationTypeStorage( + std::make_pair(alloc.copyInto(key.first), key.second)); } }; diff --git a/mlir/lib/Tools/PDLL/AST/Types.cpp b/mlir/lib/Tools/PDLL/AST/Types.cpp --- a/mlir/lib/Tools/PDLL/AST/Types.cpp +++ b/mlir/lib/Tools/PDLL/AST/Types.cpp @@ -11,6 +11,7 @@ #include "mlir/Tools/PDLL/AST/Context.h" using namespace mlir; +using namespace mlir::pdll; using namespace mlir::pdll::ast; MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::pdll::ast::detail::AttributeTypeStorage) @@ -68,16 +69,22 @@ // OperationType //===----------------------------------------------------------------------===// -OperationType OperationType::get(Context &context, Optional name) { +OperationType OperationType::get(Context &context, Optional name, + const ods::Operation *odsOp) { return context.getTypeUniquer().get( - /*initFn=*/function_ref(), name.getValueOr("")); + /*initFn=*/function_ref(), + std::make_pair(name.getValueOr(""), odsOp)); } Optional OperationType::getName() const { - StringRef name = getImplAs()->getValue(); + StringRef name = getImplAs()->getValue().first; return name.empty() ? 
Optional() : Optional(name); } +const ods::Operation *OperationType::getODSOperation() const { + return getImplAs()->getValue().second; +} + //===----------------------------------------------------------------------===// // RangeType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp --- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Tools/PDLL/AST/Nodes.h" +#include "mlir/Tools/PDLL/ODS/Operation.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" @@ -51,11 +52,20 @@ void generate(const ast::UserConstraintDecl *decl, StringSet<> &nativeFunctions); void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions); - void generateConstraintOrRewrite(StringRef name, bool isConstraint, - ArrayRef inputs, - StringRef codeBlock, + void generateConstraintOrRewrite(const ast::CallableDecl *decl, + bool isConstraint, StringSet<> &nativeFunctions); + /// Return the native name of the type to use for the single input argument of + /// the given constraint. + StringRef getNativeTypeName(const ast::UserConstraintDecl *cst); + + /// Return the native name for the type of the given type. + StringRef getNativeTypeName(ast::Type type); + + /// Return the native name for the type of the given variable decl. + StringRef getNativeTypeName(ast::VariableDecl *decl); + /// The stream to output to. 
raw_ostream &os; }; @@ -152,55 +162,91 @@ void CodeGen::generate(const ast::UserConstraintDecl *decl, StringSet<> &nativeFunctions) { - return generateConstraintOrRewrite(decl->getName().getName(), - /*isConstraint=*/true, decl->getInputs(), - *decl->getCodeBlock(), nativeFunctions); + return generateConstraintOrRewrite(cast(decl), + /*isConstraint=*/true, nativeFunctions); } void CodeGen::generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions) { - return generateConstraintOrRewrite(decl->getName().getName(), - /*isConstraint=*/false, decl->getInputs(), - *decl->getCodeBlock(), nativeFunctions); + return generateConstraintOrRewrite(cast(decl), + /*isConstraint=*/false, nativeFunctions); +} + +StringRef CodeGen::getNativeTypeName(const ast::UserConstraintDecl *cst) { + if (Optional name = cst->getNativeInputType(0)) + return *name; + return getNativeTypeName(cst->getInputs()[0]); +} + +StringRef CodeGen::getNativeTypeName(ast::Type type) { + return llvm::TypeSwitch(type) + .Case([&](ast::AttributeType) { return "::mlir::Attribute"; }) + .Case([&](ast::OperationType opType) -> StringRef { + // Use the derived Op class when available. + if (const auto *odsOp = opType.getODSOperation()) + return odsOp->getNativeClassName(); + return "::mlir::Operation *"; + }) + .Case([&](ast::TypeType) { return "::mlir::Type"; }) + .Case([&](ast::ValueType) { return "::mlir::Value"; }) + .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; }) + .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; }); +} + +StringRef CodeGen::getNativeTypeName(ast::VariableDecl *decl) { + // Try to extract a type name from the variable's constraints. + for (ast::ConstraintRef &cst : decl->getConstraints()) { + if (auto *userCst = dyn_cast(cst.constraint)) + return getNativeTypeName(userCst); + } + + // Otherwise, use the type of the variable. 
+ return getNativeTypeName(decl->getType()); } -void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint, - ArrayRef inputs, - StringRef codeBlock, +void CodeGen::generateConstraintOrRewrite(const ast::CallableDecl *decl, + bool isConstraint, StringSet<> &nativeFunctions) { + StringRef name = decl->getName()->getName(); nativeFunctions.insert(name); - // TODO: Should there be something explicit for handling optionality? - auto getCppType = [&](ast::Type type) -> StringRef { - return llvm::TypeSwitch(type) - .Case([&](ast::AttributeType) { return "::mlir::Attribute"; }) - .Case([&](ast::OperationType) { - // TODO: Allow using the derived Op class when possible. - return "::mlir::Operation *"; - }) - .Case([&](ast::TypeType) { return "::mlir::Type"; }) - .Case([&](ast::ValueType) { return "::mlir::Value"; }) - .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; }) - .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; }); - }; - os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name - << "PDLFn(::mlir::PatternRewriter &rewriter, " - << (isConstraint ? "" : "::mlir::PDLResultList &results, ") - << "::llvm::ArrayRef<::mlir::PDLValue> values) {\n"; - - const char *argumentInitStr = R"( - {0} {1} = {{}; - if (values[{2}]) - {1} = values[{2}].cast<{0}>(); - (void){1}; -)"; - for (const auto &it : llvm::enumerate(inputs)) { - const ast::VariableDecl *input = it.value(); - os << llvm::formatv(argumentInitStr, getCppType(input->getType()), - input->getName().getName(), it.index()); + os << "static "; + + // TODO: Work out a proper modeling for "optionality". + + // Emit the result type. + // If this is a constraint, we always return a LogicalResult. + // TODO: This will need to change if we allow Constraints to return values as + // well. + if (isConstraint) { + os << "::mlir::LogicalResult"; + } else { + // Otherwise, generate a type based on the results of the callable. 
+ // If the callable has explicit results, use those to build the result. + // Otherwise, use the type of the callable. + ArrayRef results = decl->getResults(); + if (results.empty()) { + os << "void"; + } else if (results.size() == 1) { + os << getNativeTypeName(results[0]); + } else { + os << "std::tuple<"; + llvm::interleaveComma(results, os, [&](ast::VariableDecl *result) { + os << getNativeTypeName(result); + }); + os << ">"; + } } - os << " " << codeBlock.trim() << "\n}\n"; + os << " " << name << "PDLFn(::mlir::PatternRewriter &rewriter"; + if (!decl->getInputs().empty()) { + os << ", "; + llvm::interleaveComma(decl->getInputs(), os, [&](ast::VariableDecl *input) { + os << getNativeTypeName(input) << " " << input->getName().getName(); + }); + } + os << ") {\n"; + os << " " << decl->getCodeBlock()->trim() << "\n}\n\n"; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -442,8 +442,7 @@ return builder.create(loc, mlirType, parentExprs[0]); } - assert(opType.getName() && "expected valid operation name"); - const ods::Operation *odsOp = odsContext.lookupOperation(*opType.getName()); + const ods::Operation *odsOp = opType.getODSOperation(); assert(odsOp && "expected valid ODS operation information"); // Find the result with the member name or by index. 
diff --git a/mlir/lib/Tools/PDLL/ODS/Context.cpp b/mlir/lib/Tools/PDLL/ODS/Context.cpp --- a/mlir/lib/Tools/PDLL/ODS/Context.cpp +++ b/mlir/lib/Tools/PDLL/ODS/Context.cpp @@ -61,10 +61,12 @@ std::pair Context::insertOperation(StringRef name, StringRef summary, StringRef desc, + StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc) { std::pair dialectAndName = name.split('.'); return insertDialect(dialectAndName.first) - .insertOperation(name, summary, desc, supportsResultTypeInferrence, loc); + .insertOperation(name, summary, desc, nativeClassName, + supportsResultTypeInferrence, loc); } const Operation *Context::lookupOperation(StringRef name) const { diff --git a/mlir/lib/Tools/PDLL/ODS/Dialect.cpp b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp --- a/mlir/lib/Tools/PDLL/ODS/Dialect.cpp +++ b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp @@ -23,13 +23,14 @@ std::pair Dialect::insertOperation(StringRef name, StringRef summary, StringRef desc, + StringRef nativeClassName, bool supportsResultTypeInferrence, llvm::SMLoc loc) { std::unique_ptr &operation = operations[name]; if (operation) return std::make_pair(&*operation, /*wasInserted*/ false); - operation.reset( - new Operation(name, summary, desc, supportsResultTypeInferrence, loc)); + operation.reset(new Operation(name, summary, desc, nativeClassName, + supportsResultTypeInferrence, loc)); return std::make_pair(&*operation, /*wasInserted*/ true); } diff --git a/mlir/lib/Tools/PDLL/ODS/Operation.cpp b/mlir/lib/Tools/PDLL/ODS/Operation.cpp --- a/mlir/lib/Tools/PDLL/ODS/Operation.cpp +++ b/mlir/lib/Tools/PDLL/ODS/Operation.cpp @@ -18,8 +18,10 @@ //===----------------------------------------------------------------------===// Operation::Operation(StringRef name, StringRef summary, StringRef desc, - bool supportsTypeInferrence, llvm::SMLoc loc) + StringRef nativeClassName, bool supportsTypeInferrence, + llvm::SMLoc loc) : name(name.str()), summary(summary.str()), + nativeClassName(nativeClassName.str()), 
supportsTypeInferrence(supportsTypeInferrence), location(loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)) { llvm::raw_string_ostream descOS(description); diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -142,11 +142,13 @@ template ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, SMRange loc, - ast::Type type); + ast::Type type, + StringRef nativeType); template ast::Decl * createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, - SMRange loc, ast::Type type); + SMRange loc, ast::Type type, + StringRef nativeType); //===--------------------------------------------------------------------===// // Decls @@ -610,8 +612,7 @@ if (type == valueTy) { // If the operation is registered, we can verify if it can ever have a // single result. - Optional opName = exprOpType.getName(); - if (const ods::Operation *odsOp = lookupODSOperation(opName)) { + if (const ods::Operation *odsOp = exprOpType.getODSOperation()) { if (odsOp->getResults().empty()) { return emitConvertError()->attachNote( llvm::formatv("see the definition of `{0}`, which was defined " @@ -821,7 +822,8 @@ ods::Operation *odsOp = nullptr; std::tie(odsOp, inserted) = odsContext.insertOperation( op.getOperationName(), op.getSummary(), op.getDescription(), - supportsResultTypeInferrence, op.getLoc().front()); + op.getQualCppClassName(), supportsResultTypeInferrence, + op.getLoc().front()); // Ignore operations that have already been added. if (!inserted) @@ -846,19 +848,21 @@ /// Attr constraints. 
   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) {
     if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
+      tblgen::Attribute constraint(def);
       decls.push_back(
           createODSNativePDLLConstraintDecl(
-              tblgen::AttrConstraint(def),
-              convertLocToRange(def->getLoc().front()), attrTy));
+              constraint, convertLocToRange(def->getLoc().front()), attrTy,
+              constraint.getStorageType()));
     }
   }
   /// Type constraints.
   for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) {
     if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) {
+      tblgen::TypeConstraint constraint(def);
       decls.push_back(
           createODSNativePDLLConstraintDecl(
-              tblgen::TypeConstraint(def),
-              convertLocToRange(def->getLoc().front()), typeTy));
+              constraint, convertLocToRange(def->getLoc().front()), typeTy,
+              constraint.getCPPClassName()));
     }
   }
   /// Interfaces.
@@ -870,24 +874,26 @@
       continue;
 
     SMRange loc = convertLocToRange(def->getLoc().front());
-    StringRef className = def->getValueAsString("cppClassName");
-    StringRef cppNamespace = def->getValueAsString("cppNamespace");
+    std::string cppClassName =
+        llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"),
+                      def->getValueAsString("cppClassName"))
+            .str();
     std::string codeBlock =
-        llvm::formatv("return ::mlir::success(llvm::isa<{0}::{1}>(self));",
-                      cppNamespace, className)
+        llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));",
+                      cppClassName)
             .str();
 
     if (def->isSubClassOf("OpInterface")) {
       decls.push_back(createODSNativePDLLConstraintDecl(
-          name, codeBlock, loc, opTy));
+          name, codeBlock, loc, opTy, cppClassName));
     } else if (def->isSubClassOf("AttrInterface")) {
       decls.push_back(
           createODSNativePDLLConstraintDecl(
-              name, codeBlock, loc, attrTy));
+              name, codeBlock, loc, attrTy, cppClassName));
     } else if (def->isSubClassOf("TypeInterface")) {
       decls.push_back(
           createODSNativePDLLConstraintDecl(
-              name, codeBlock, loc, typeTy));
+              name, codeBlock, loc, typeTy, cppClassName));
     }
   }
 }
@@ -895,7 +901,8 @@
 template
 ast::Decl *
 Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock,
-                                          SMRange loc, ast::Type type) {
+                                          SMRange loc, ast::Type type,
+                                          StringRef nativeType) {
   // Build the single input parameter.
   ast::DeclScope *argScope = pushDeclScope();
   auto *paramVar = ast::VariableDecl::create(
@@ -907,7 +914,7 @@
   // Build the native constraint.
   auto *constraintDecl = ast::UserConstraintDecl::createNative(
       ctx, ast::Name::create(ctx, name, loc), paramVar,
-      /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx));
+      /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx), nativeType);
   curDeclScope->add(constraintDecl);
   return constraintDecl;
 }
@@ -915,7 +922,8 @@
 template
 ast::Decl *
 Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint,
-                                          SMRange loc, ast::Type type) {
+                                          SMRange loc, ast::Type type,
+                                          StringRef nativeType) {
   // Format the condition template.
   tblgen::FmtContext fmtContext;
   fmtContext.withSelf("self");
@@ -924,7 +932,7 @@
                           &fmtContext);
 
   return createODSNativePDLLConstraintDecl(
-      constraint.getUniqueDefName(), codeBlock, loc, type);
+      constraint.getUniqueDefName(), codeBlock, loc, type, nativeType);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2531,7 +2539,8 @@
       constraintType = ast::AttributeType::get(ctx);
     } else if (const auto *cst =
                    dyn_cast(ref.constraint)) {
-      constraintType = ast::OperationType::get(ctx, cst->getName());
+      constraintType = ast::OperationType::get(
+          ctx, cst->getName(), lookupODSOperation(cst->getName()));
     } else if (isa(ref.constraint)) {
       constraintType = typeTy;
     } else if (isa(ref.constraint)) {
@@ -2707,7 +2716,7 @@
     return valueRangeTy;
 
   // Verify member access based on the operation type.
-  if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) {
+  if (const ods::Operation *odsOp = opType.getODSOperation()) {
     auto results = odsOp->getResults();
 
     // Handle indexed results.
@@ -2784,7 +2793,7 @@
     checkOperationResultTypeInferrence(loc, *opNameRef, odsOp);
   }
 
-  return ast::OperationExpr::create(ctx, loc, name, operands, results,
+  return ast::OperationExpr::create(ctx, loc, odsOp, name, operands, results,
                                     attributes);
 }
 
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
--- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
+++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp
@@ -707,9 +707,7 @@
   }
 
   void codeCompleteOperationMemberAccess(ast::OperationType opType) final {
-    Optional opName = opType.getName();
-    const ods::Operation *odsOp =
-        opName ? odsContext.lookupOperation(*opName) : nullptr;
+    const ods::Operation *odsOp = opType.getODSOperation();
     if (!odsOp)
       return;
 
diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
--- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
+++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll
@@ -44,43 +44,22 @@
 // Check the generation of native constraints and rewrites.
 
 // CHECK: static ::mlir::LogicalResult TestCstPDLFn(::mlir::PatternRewriter &rewriter,
-// CHECK-SAME: ::llvm::ArrayRef<::mlir::PDLValue> values) {
-// CHECK: ::mlir::Attribute attr = {};
-// CHECK: if (values[0])
-// CHECK: attr = values[0].cast<::mlir::Attribute>();
-// CHECK: ::mlir::Operation * op = {};
-// CHECK: if (values[1])
-// CHECK: op = values[1].cast<::mlir::Operation *>();
-// CHECK: ::mlir::Type type = {};
-// CHECK: if (values[2])
-// CHECK: type = values[2].cast<::mlir::Type>();
-// CHECK: ::mlir::Value value = {};
-// CHECK: if (values[3])
-// CHECK: value = values[3].cast<::mlir::Value>();
-// CHECK: ::mlir::TypeRange typeRange = {};
-// CHECK: if (values[4])
-// CHECK: typeRange = values[4].cast<::mlir::TypeRange>();
-// CHECK: ::mlir::ValueRange valueRange = {};
-// CHECK: if (values[5])
-// CHECK: valueRange = values[5].cast<::mlir::ValueRange>();
-
-// CHECK: return success();
+// CHECK-SAME: ::mlir::Attribute attr, ::mlir::Operation * op, ::mlir::Type type,
+// CHECK-SAME: ::mlir::Value value, ::mlir::TypeRange typeRange, ::mlir::ValueRange valueRange) {
+// CHECK-NEXT: return success();
 // CHECK: }
 
 // CHECK-NOT: TestUnusedCst
 
-// CHECK: static void TestRewritePDLFn(::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results,
-// CHECK-SAME: ::llvm::ArrayRef<::mlir::PDLValue> values) {
-// CHECK: ::mlir::Attribute attr = {};
-// CHECK: ::mlir::Operation * op = {};
-// CHECK: ::mlir::Type type = {};
-// CHECK: ::mlir::Value value = {};
-// CHECK: ::mlir::TypeRange typeRange = {};
-// CHECK: ::mlir::ValueRange valueRange = {};
-
+// CHECK: static void TestRewritePDLFn(::mlir::PatternRewriter &rewriter,
+// CHECK-SAME: ::mlir::Attribute attr, ::mlir::Operation * op, ::mlir::Type type,
+// CHECK-SAME: ::mlir::Value value, ::mlir::TypeRange typeRange, ::mlir::ValueRange valueRange) {
 // CHECK: foo;
 // CHECK: }
 
+// CHECK: static ::mlir::Attribute TestRewriteSinglePDLFn(::mlir::PatternRewriter &rewriter) {
+// CHECK: std::tuple<::mlir::Attribute, ::mlir::Type> TestRewriteTuplePDLFn(::mlir::PatternRewriter &rewriter) {
+
 // CHECK-NOT: TestUnusedRewrite
 
 // CHECK: struct TestCstAndRewrite : ::mlir::PDLPatternModule {
@@ -93,6 +72,8 @@
 Constraint TestUnusedCst() [{ return success(); }];
 
 Rewrite TestRewrite(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: ValueRange) [{ foo; }];
+Rewrite TestRewriteSingle() -> Attr [{}];
+Rewrite TestRewriteTuple() -> (Attr, Type) [{}];
 Rewrite TestUnusedRewrite(op: Op) [{}];
 
 Pattern TestCstAndRewrite {
@@ -100,6 +81,8 @@
   TestCst(attr<"true">, root, type, operand, types, operands);
   rewrite root with {
     TestRewrite(attr<"true">, root, type, operand, types, operands);
+    TestRewriteSingle();
+    TestRewriteTuple();
     erase root;
   };
 }