diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h --- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h @@ -943,7 +943,7 @@ Type resultType) : Base(name.getLoc(), &name), numInputs(numInputs), numResults(numResults), codeBlock(codeBlock), constraintBody(body), - resultType(resultType) {} + resultType(resultType), hasNativeInputTypes(hasNativeInputTypes) {} /// The number of inputs to this constraint. unsigned numInputs; diff --git a/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h b/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h --- a/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h +++ b/mlir/include/mlir/Tools/PDLL/Parser/CodeComplete.h @@ -48,12 +48,9 @@ /// Signal code completion for a constraint name with an optional decl scope. /// `currentType` is the current type of the variable that will use the - /// constraint, or nullptr if a type is unknown. `allowNonCoreConstraints` - /// indicates if user defined constraints are allowed in the completion - /// results. `allowInlineTypeConstraints` enables inline type constraints for - /// Attr/Value/ValueRange. + /// constraint, or nullptr if a type is unknown. `allowInlineTypeConstraints` + /// enables inline type constraints for Attr/Value/ValueRange. virtual void codeCompleteConstraintName(ast::Type currentType, - bool allowNonCoreConstraints, bool allowInlineTypeConstraints, const ast::DeclScope *scope); diff --git a/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp b/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp --- a/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp +++ b/mlir/lib/Tools/PDLL/Parser/CodeComplete.cpp @@ -24,5 +24,5 @@ ast::OperationType opType) {} void CodeCompleteContext::codeCompleteConstraintName( - ast::Type currentType, bool allowNonCoreConstraints, - bool allowInlineTypeConstraints, const ast::DeclScope *scope) {} + ast::Type currentType, bool allowInlineTypeConstraints, + const ast::DeclScope *scope) {} diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -297,13 +297,10 @@ /// existing constraints that have already been parsed for the same entity /// that will be constrained by this constraint. `allowInlineTypeConstraints` /// allows the use of inline Type constraints, e.g. `Value`. - /// If `allowNonCoreConstraints` is true, then complex (e.g. user defined - /// constraints) may be used with the variable. FailureOr parseConstraint(Optional &typeConstraint, ArrayRef existingConstraints, - bool allowInlineTypeConstraints, - bool allowNonCoreConstraints); + bool allowInlineTypeConstraints); /// Try to parse the constraint for a UserConstraintDecl/UserRewriteDecl /// argument or result variable. The constraints for these variables do not @@ -389,20 +386,16 @@ /// `inferredType` is the type of the variable inferred by the constraints /// within the list, and is updated to the most refined type as determined by /// the constraints. Returns success if the constraint list is valid, failure - /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user - /// defined constraints) may be used with the variable. + /// otherwise. LogicalResult validateVariableConstraints(ArrayRef constraints, - ast::Type &inferredType, - bool allowNonCoreConstraints = true); + ast::Type &inferredType); /// Validate a single reference to a constraint. `inferredType` contains the /// currently inferred variabled type and is refined within the type defined /// by the constraint. Returns success if the constraint is valid, failure - /// otherwise. If `allowNonCoreConstraints` is true, then complex (e.g. user - /// defined constraints) may be used with the variable. + /// otherwise. LogicalResult validateVariableConstraint(const ast::ConstraintRef &ref, - ast::Type &inferredType, - bool allowNonCoreConstraints = true); + ast::Type &inferredType); LogicalResult validateTypeConstraintExpr(const ast::Expr *typeExpr); LogicalResult validateTypeRangeConstraintExpr(const ast::Expr *typeExpr); @@ -469,7 +462,6 @@ LogicalResult codeCompleteMemberAccess(ast::Expr *parentExpr); LogicalResult codeCompleteAttributeName(Optional opName); LogicalResult codeCompleteConstraintName(ast::Type inferredType, - bool allowNonCoreConstraints, bool allowInlineTypeConstraints); LogicalResult codeCompleteDialectName(); LogicalResult codeCompleteOperationName(StringRef dialectName); @@ -1129,18 +1121,7 @@ // Check to see if this result is named. if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) { // Check to see if this name actually refers to a Constraint. - ast::Decl *existingDecl = curDeclScope->lookup(curToken.getSpelling()); - if (isa_and_nonnull(existingDecl)) { - // If yes, and this is a Rewrite, give a nice error message as non-Core - // constraints are not supported on Rewrite results. - if (parserContext == ParserContext::Rewrite) { - return emitError( - "`Rewrite` results are only permitted to use core constraints, " - "such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange`"); - } - - // Otherwise, parse this as an unnamed result variable. - } else { + if (!curDeclScope->lookup(curToken.getSpelling())) { // If it wasn't a constraint, parse the result similarly to a variable. If // there is already an existing decl, we will emit an error when defining // this variable later. @@ -1662,8 +1643,7 @@ Optional typeConstraint; auto parseSingleConstraint = [&] { FailureOr constraint = parseConstraint( - typeConstraint, constraints, /*allowInlineTypeConstraints=*/true, - /*allowNonCoreConstraints=*/true); + typeConstraint, constraints, /*allowInlineTypeConstraints=*/true); if (failed(constraint)) return failure(); constraints.push_back(*constraint); @@ -1684,8 +1664,7 @@ FailureOr Parser::parseConstraint(Optional &typeConstraint, ArrayRef existingConstraints, - bool allowInlineTypeConstraints, - bool allowNonCoreConstraints) { + bool allowInlineTypeConstraints) { auto parseTypeConstraint = [&](ast::Expr *&typeExpr) -> LogicalResult { if (!allowInlineTypeConstraints) { return emitError( @@ -1791,12 +1770,10 @@ case Token::code_complete: { // Try to infer the current type for use by code completion. ast::Type inferredType; - if (failed(validateVariableConstraints(existingConstraints, inferredType, - allowNonCoreConstraints))) + if (failed(validateVariableConstraints(existingConstraints, inferredType))) return failure(); - return codeCompleteConstraintName(inferredType, allowNonCoreConstraints, - allowInlineTypeConstraints); + return codeCompleteConstraintName(inferredType, allowInlineTypeConstraints); } default: break; @@ -1805,13 +1782,9 @@ } FailureOr Parser::parseArgOrResultConstraint() { - // Constraint arguments may apply more complex constraints via the arguments. - bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; - Optional typeConstraint; return parseConstraint(typeConstraint, /*existingConstraints=*/llvm::None, - /*allowInlineTypeConstraints=*/false, - allowNonCoreConstraints); + /*allowInlineTypeConstraints=*/false); } //===----------------------------------------------------------------------===// @@ -2598,29 +2571,23 @@ FailureOr Parser::createArgOrResultVariableDecl(StringRef name, SMRange loc, const ast::ConstraintRef &constraint) { - // Constraint arguments may apply more complex constraints via the arguments. - bool allowNonCoreConstraints = parserContext == ParserContext::Constraint; ast::Type argType; - if (failed(validateVariableConstraint(constraint, argType, - allowNonCoreConstraints))) + if (failed(validateVariableConstraint(constraint, argType))) return failure(); return defineVariableDecl(name, loc, argType, constraint); } LogicalResult Parser::validateVariableConstraints(ArrayRef constraints, - ast::Type &inferredType, - bool allowNonCoreConstraints) { + ast::Type &inferredType) { for (const ast::ConstraintRef &ref : constraints) - if (failed(validateVariableConstraint(ref, inferredType, - allowNonCoreConstraints))) + if (failed(validateVariableConstraint(ref, inferredType))) return failure(); return success(); } LogicalResult Parser::validateVariableConstraint(const ast::ConstraintRef &ref, - ast::Type &inferredType, - bool allowNonCoreConstraints) { + ast::Type &inferredType) { ast::Type constraintType; if (const auto *cst = dyn_cast(ref.constraint)) { if (const ast::Expr *typeExpr = cst->getTypeExpr()) { @@ -2652,13 +2619,6 @@ constraintType = valueRangeTy; } else if (const auto *cst = dyn_cast(ref.constraint)) { - if (!allowNonCoreConstraints) { - return emitError(ref.referenceLoc, - "`Rewrite` arguments and results are only permitted to " - "use core constraints, such as `Attr`, `Op`, `Type`, " - "`TypeRange`, `Value`, `ValueRange`"); - } - ArrayRef inputs = cst->getInputs(); if (inputs.size() != 1) { return emitErrorAndNote(ref.referenceLoc, @@ -3160,11 +3120,9 @@ LogicalResult Parser::codeCompleteConstraintName(ast::Type inferredType, - bool allowNonCoreConstraints, bool allowInlineTypeConstraints) { codeCompleteContext->codeCompleteConstraintName( - inferredType, allowNonCoreConstraints, allowInlineTypeConstraints, - curDeclScope); + inferredType, allowInlineTypeConstraints, curDeclScope); return failure(); } diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp @@ -760,7 +760,6 @@ } void codeCompleteConstraintName(ast::Type currentType, - bool allowNonCoreConstraints, bool allowInlineTypeConstraints, const ast::DeclScope *scope) final { auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType, @@ -808,9 +807,6 @@ while (scope) { for (const ast::Decl *decl : scope->getDecls()) { if (const auto *cst = dyn_cast(decl)) { - if (!allowNonCoreConstraints) - continue; - lsp::CompletionItem item; item.label = cst->getName().getName().str(); item.kind = lsp::CompletionItemKind::Interface; diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll --- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll +++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll @@ -1,4 +1,4 @@ -// RUN: mlir-pdll %s -I %S -split-input-file -x cpp | FileCheck %s +// RUN: mlir-pdll %s -I %S -I %S/../../../../include -split-input-file -x cpp | FileCheck %s // Check that we generate a wrapper pattern for each PDL pattern. Also // add in a pattern awkwardly named the same as our generated patterns to @@ -44,6 +44,8 @@ // Check the generation of native constraints and rewrites. +#include "include/ods.td" + // CHECK: static ::mlir::LogicalResult TestCstPDLFn(::mlir::PatternRewriter &rewriter, // CHECK-SAME: ::mlir::Attribute attr, ::mlir::Operation * op, ::mlir::Type type, // CHECK-SAME: ::mlir::Value value, ::mlir::TypeRange typeRange, ::mlir::ValueRange valueRange) { @@ -58,6 +60,7 @@ // CHECK: foo; // CHECK: } +// CHECK: TestAttrInterface TestRewriteODSPDLFn(::mlir::PatternRewriter &rewriter, TestAttrInterface attr) { // CHECK: static ::mlir::Attribute TestRewriteSinglePDLFn(::mlir::PatternRewriter &rewriter) { // CHECK: std::tuple<::mlir::Attribute, ::mlir::Type> TestRewriteTuplePDLFn(::mlir::PatternRewriter &rewriter) { @@ -73,6 +76,7 @@ Constraint TestUnusedCst() [{ return success(); }]; Rewrite TestRewrite(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: ValueRange) [{ foo; }]; +Rewrite TestRewriteODS(attr: TestAttrInterface) -> TestAttrInterface [{}]; Rewrite TestRewriteSingle() -> Attr [{}]; Rewrite TestRewriteTuple() -> (Attr, Type) [{}]; Rewrite TestUnusedRewrite(op: Op) [{}]; @@ -82,6 +86,7 @@ TestCst(attr<"true">, root, type, operand, types, operands); rewrite root with { TestRewrite(attr<"true">, root, type, operand, types, operands); + TestRewriteODS(attr<"true">); TestRewriteSingle(); TestRewriteTuple(); erase root; diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/include/ods.td b/mlir/test/mlir-pdll/CodeGen/CPP/include/ods.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/CodeGen/CPP/include/ods.td @@ -0,0 +1,3 @@ +include "mlir/IR/OpBase.td" + +def TestAttrInterface : AttrInterface<"TestAttrInterface">; diff --git a/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll b/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll --- a/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/rewrite-failure.pdll @@ -88,13 +88,6 @@ // ----- -Constraint ValueConstraint(value: Value); - -// CHECK: arguments and results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange` -Rewrite Foo(arg: ValueConstraint); - -// ----- - // CHECK: expected `)` to end argument list Rewrite Foo(arg: Value{} @@ -139,13 +132,6 @@ // ----- -Constraint ValueConstraint(value: Value); - -// CHECK: results are only permitted to use core constraints, such as `Attr`, `Op`, `Type`, `TypeRange`, `Value`, `ValueRange` -Rewrite Foo() -> ValueConstraint; - -// ----- - //===----------------------------------------------------------------------===// // Native Rewrites //===----------------------------------------------------------------------===//