diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -23,6 +23,7 @@
 class Context;
 class Decl;
 class Expr;
+class NamedAttributeDecl;
 class OpNameDecl;
 class VariableDecl;
 
@@ -344,6 +345,84 @@
   StringRef memberName;
 };
 
+//===----------------------------------------------------------------------===//
+// OperationExpr
+//===----------------------------------------------------------------------===//
+
+/// This expression represents the structural form of an MLIR Operation. It
+/// represents either the structural form to match an input operation, or an
+/// operation to create within a rewrite.
+class OperationExpr final
+    : public Node::NodeBase<OperationExpr, Expr>,
+      private llvm::TrailingObjects<OperationExpr, Expr *,
+                                    NamedAttributeDecl *> {
+public:
+  static OperationExpr *create(Context &ctx, llvm::SMRange loc,
+                               const OpNameDecl *nameDecl,
+                               ArrayRef<Expr *> operands,
+                               ArrayRef<Expr *> resultTypes,
+                               ArrayRef<NamedAttributeDecl *> attributes);
+
+  /// Return the name of the operation, or None if there isn't one.
+  Optional<StringRef> getName() const;
+
+  /// Return the declaration of the operation name.
+  const OpNameDecl *getNameDecl() const { return nameDecl; }
+
+  /// Return the location of the name of the operation expression, or an invalid
+  /// location if there isn't a name.
+  llvm::SMRange getNameLoc() const { return nameLoc; }
+
+  /// Return the operands of this operation.
+  MutableArrayRef<Expr *> getOperands() {
+    return llvm::makeMutableArrayRef(getTrailingObjects<Expr *>(), numOperands);
+  }
+  ArrayRef<Expr *> getOperands() const {
+    return const_cast<OperationExpr *>(this)->getOperands();
+  }
+
+  /// Return the result types of this operation.
+  MutableArrayRef<Expr *> getResultTypes() {
+    return llvm::makeMutableArrayRef(getTrailingObjects<Expr *>() + numOperands,
+                                     numResultTypes);
+  }
+  ArrayRef<Expr *> getResultTypes() const {
+    return const_cast<OperationExpr *>(this)->getResultTypes();
+  }
+
+  /// Return the attributes of this operation.
+  MutableArrayRef<NamedAttributeDecl *> getAttributes() {
+    return llvm::makeMutableArrayRef(getTrailingObjects<NamedAttributeDecl *>(),
+                                     numAttributes);
+  }
+  ArrayRef<NamedAttributeDecl *> getAttributes() const {
+    return const_cast<OperationExpr *>(this)->getAttributes();
+  }
+
+private:
+  OperationExpr(llvm::SMRange loc, Type type, const OpNameDecl *nameDecl,
+                unsigned numOperands, unsigned numResultTypes,
+                unsigned numAttributes, llvm::SMRange nameLoc)
+      : Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
+        numResultTypes(numResultTypes), numAttributes(numAttributes),
+        nameLoc(nameLoc) {}
+
+  /// The name decl of this expression.
+  const OpNameDecl *nameDecl;
+
+  /// The number of operands, result types, and attributes of the operation.
+  unsigned numOperands, numResultTypes, numAttributes;
+
+  /// The location of the operation name in the expression if it has a name.
+  llvm::SMRange nameLoc;
+
+  /// TrailingObject utilities.
+  friend llvm::TrailingObjects<OperationExpr, Expr *, NamedAttributeDecl *>;
+  size_t numTrailingObjects(OverloadToken<Expr *>) const {
+    return numOperands + numResultTypes;
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // TypeExpr
 //===----------------------------------------------------------------------===//
@@ -558,6 +637,30 @@
   Expr *typeExpr;
 };
 
+//===----------------------------------------------------------------------===//
+// NamedAttributeDecl
+//===----------------------------------------------------------------------===//
+
+/// This Decl represents a NamedAttribute, and contains a string name and
+/// attribute value.
+class NamedAttributeDecl : public Node::NodeBase<NamedAttributeDecl, Decl> {
+public:
+  static NamedAttributeDecl *create(Context &ctx, Name name, Expr *value);
+
+  /// Return the name of the attribute.
+  const Name &getName() const { return *Decl::getName(); }
+
+  /// Return value of the attribute.
+  Expr *getValue() const { return value; }
+
+private:
+  NamedAttributeDecl(Name name, Expr *value)
+      : Base(name.location, name), value(value) {}
+
+  /// The value of the attribute.
+  Expr *value;
+};
+
 //===----------------------------------------------------------------------===//
 // OpNameDecl
 //===----------------------------------------------------------------------===//
@@ -709,7 +812,8 @@
 //===----------------------------------------------------------------------===//
 
 inline bool Decl::classof(const Node *node) {
-  return isa<ConstraintDecl, OpNameDecl, PatternDecl, VariableDecl>(node);
+  return isa<ConstraintDecl, NamedAttributeDecl, OpNameDecl, PatternDecl,
+             VariableDecl>(node);
 }
 
 inline bool ConstraintDecl::classof(const Node *node) {
@@ -723,7 +827,8 @@
 }
 
 inline bool Expr::classof(const Node *node) {
-  return isa<AttributeExpr, DeclRefExpr, MemberAccessExpr, TypeExpr>(node);
+  return isa<AttributeExpr, DeclRefExpr, MemberAccessExpr, OperationExpr,
+             TypeExpr>(node);
 }
 
 inline bool OpRewriteStmt::classof(const Node *node) {
diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
--- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
+++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
@@ -79,6 +79,7 @@
   void printImpl(const AttributeExpr *expr);
   void printImpl(const DeclRefExpr *expr);
   void printImpl(const MemberAccessExpr *expr);
+  void printImpl(const OperationExpr *expr);
   void printImpl(const TypeExpr *expr);
 
   void printImpl(const AttrConstraintDecl *decl);
@@ -87,6 +88,7 @@
   void printImpl(const TypeRangeConstraintDecl *decl);
   void printImpl(const ValueConstraintDecl *decl);
   void printImpl(const ValueRangeConstraintDecl *decl);
+  void printImpl(const NamedAttributeDecl *decl);
   void printImpl(const OpNameDecl *decl);
   void printImpl(const PatternDecl *decl);
   void printImpl(const VariableDecl *decl);
@@ -147,13 +149,14 @@
 
           // Expressions.
           const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
-          const TypeExpr,
+          const OperationExpr, const TypeExpr,
 
           // Decls.
           const AttrConstraintDecl, const OpConstraintDecl,
           const TypeConstraintDecl, const TypeRangeConstraintDecl,
           const ValueConstraintDecl, const ValueRangeConstraintDecl,
-          const OpNameDecl, const PatternDecl, const VariableDecl,
+          const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
+          const VariableDecl,
 
           const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
       .Default([](const Node *) { llvm_unreachable("unknown AST node"); });
@@ -194,6 +197,17 @@
   printChildren(expr->getParentExpr());
 }
 
+void NodePrinter::printImpl(const OperationExpr *expr) {
+  os << "OperationExpr " << expr << " Type<";
+  print(expr->getType());
+  os << ">\n";
+
+  printChildren(expr->getNameDecl());
+  printChildren("Operands", expr->getOperands());
+  printChildren("Result Types", expr->getResultTypes());
+  printChildren("Attributes", expr->getAttributes());
+}
+
 void NodePrinter::printImpl(const TypeExpr *expr) {
   os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
 }
@@ -229,6 +243,12 @@
     printChildren(typeExpr);
 }
 
+void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
+  os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().name
+     << ">\n";
+  printChildren(decl->getValue());
+}
+
 void NodePrinter::printImpl(const OpNameDecl *decl) {
   os << "OpNameDecl " << decl;
   if (Optional<StringRef> name = decl->getName())
diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -112,6 +112,37 @@
       loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
 }
 
+//===----------------------------------------------------------------------===//
+// OperationExpr
+//===----------------------------------------------------------------------===//
+
+OperationExpr *
+OperationExpr::create(Context &ctx, llvm::SMRange loc, const OpNameDecl *name,
+                      ArrayRef<Expr *> operands, ArrayRef<Expr *> resultTypes,
+                      ArrayRef<NamedAttributeDecl *> attributes) {
+  unsigned allocSize =
+      OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
+          operands.size() + resultTypes.size(), attributes.size());
+  void *rawData =
+      ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr));
+
+  Type resultType = OperationType::get(ctx, name->getName());
+  OperationExpr *opExpr = new (rawData)
+      OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
+                    attributes.size(), name->getLoc());
+  std::uninitialized_copy(operands.begin(), operands.end(),
+                          opExpr->getOperands().begin());
+  std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
+                          opExpr->getResultTypes().begin());
+  std::uninitialized_copy(attributes.begin(), attributes.end(),
+                          opExpr->getAttributes().begin());
+  return opExpr;
+}
+
+Optional<StringRef> OperationExpr::getName() const {
+  return getNameDecl()->getName();
+}
+
 //===----------------------------------------------------------------------===//
 // TypeExpr
 //===----------------------------------------------------------------------===//
@@ -189,6 +220,16 @@
       ValueRangeConstraintDecl(loc, typeExpr);
 }
 
+//===----------------------------------------------------------------------===//
+// NamedAttributeDecl
+//===----------------------------------------------------------------------===//
+
+NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, Name name,
+                                               Expr *value) {
+  return new (ctx.getAllocator().Allocate<NamedAttributeDecl>())
+      NamedAttributeDecl(name, value);
+}
+
 //===----------------------------------------------------------------------===//
 // OpNameDecl
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -78,6 +78,11 @@
       ast::Expr *&expr, ast::Type type,
       function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
 
+  /// Given an operation expression, convert it to a Value or ValueRange
+  /// typed expression.
+  ast::Expr *convertOpToValue(const ast::Expr *opExpr,
+                              ast::OperationType opType);
+
   //===--------------------------------------------------------------------===//
   // Directives
 
@@ -88,6 +93,7 @@
   // Decls
 
   FailureOr<ast::Decl *> parseTopLevelDecl();
+  FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
   FailureOr<ast::Decl *> parsePatternDecl();
   LogicalResult parsePatternDeclMetadata(Optional<uint16_t> &benefit,
                                          bool &hasBoundedRecursion);
@@ -120,6 +126,7 @@
   FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
   FailureOr<ast::OpNameDecl *> parseOperationName(bool allowEmptyName = false);
   FailureOr<ast::OpNameDecl *> parseWrappedOperationName();
+  FailureOr<ast::Expr *> parseOperationExpr();
   FailureOr<ast::Expr *> parseTypeExpr();
   FailureOr<ast::Expr *> parseUnderscoreExpr();
 
@@ -170,6 +177,22 @@
                                                 llvm::SMRange loc);
   FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
                                             StringRef name, llvm::SMRange loc);
+  FailureOr<ast::OperationExpr *>
+  createOperationExpr(llvm::SMRange loc, const ast::OpNameDecl *name,
+                      MutableArrayRef<ast::Expr *> operands,
+                      MutableArrayRef<ast::NamedAttributeDecl *> attributes,
+                      MutableArrayRef<ast::Expr *> results);
+  LogicalResult
+  validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name,
+                            MutableArrayRef<ast::Expr *> operands);
+  LogicalResult validateOperationResults(llvm::SMRange loc,
+                                         Optional<StringRef> name,
+                                         MutableArrayRef<ast::Expr *> results);
+  LogicalResult
+  validateOperationOperandsOrResults(StringRef groupName, llvm::SMRange loc,
+                                     Optional<StringRef> name,
+                                     MutableArrayRef<ast::Expr *> values,
+                                     ast::Type singleTy, ast::Type rangeTy);
 
   //===--------------------------------------------------------------------===//
   // Stmts
@@ -289,6 +312,13 @@
   return success();
 }
 
+ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr,
+                                    ast::OperationType opType) {
+  // $results is a special member access representing all of the results.
+  return ast::MemberAccessExpr::create(context, opExpr->getLoc(), opExpr,
+                                       "$results", valueRangeTy);
+}
+
 LogicalResult Parser::convertExpressionTo(
     ast::Expr *&expr, ast::Type type,
     function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
@@ -414,6 +444,35 @@
   return decl;
 }
 
+FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
+  std::string attrNameStr;
+  if (curToken.isString())
+    attrNameStr = curToken.getStringValue();
+  else if (curToken.is(Token::identifier) || curToken.isKeyword())
+    attrNameStr = curToken.getSpelling().str();
+  else
+    return emitError("expected identifier or string attribute name");
+  // NOTE(review): `attrNameStr` is a local string; confirm that `ast::Name`
+  // owns (or copies) its name rather than holding a reference to this buffer.
+  ast::Name name(std::move(attrNameStr), curToken.getLoc());
+  consumeToken();
+
+  // Check for a value of the attribute.
+  ast::Expr *attrValue = nullptr;
+  if (consumeIf(Token::equal)) {
+    FailureOr<ast::Expr *> attrExpr = parseExpr();
+    if (failed(attrExpr))
+      return failure();
+    attrValue = *attrExpr;
+  } else {
+    // If there isn't a concrete value, create an expression representing a
+    // UnitAttr.
+    attrValue = ast::AttributeExpr::create(context, name.location, "unit");
+  }
+
+  return ast::NamedAttributeDecl::create(context, std::move(name), attrValue);
+}
+
 FailureOr<ast::Decl *> Parser::parsePatternDecl() {
   llvm::SMRange loc = curToken.getLoc();
   consumeToken(Token::kw_Pattern);
@@ -706,6 +765,9 @@
   case Token::identifier:
     lhsExpr = parseIdentifierExpr();
     break;
+  case Token::kw_op:
+    lhsExpr = parseOperationExpr();
+    break;
   case Token::kw_type:
     lhsExpr = parseTypeExpr();
     break;
@@ -838,6 +900,81 @@
   return opNameDecl;
 }
 
+FailureOr<ast::Expr *> Parser::parseOperationExpr() {
+  llvm::SMRange loc = curToken.getLoc();
+  consumeToken(Token::kw_op);
+
+  // If we aren't followed by a `<`, the `op` keyword is treated as a normal
+  // identifier.
+  if (curToken.isNot(Token::less)) {
+    resetToken(loc);
+    return parseIdentifierExpr();
+  }
+
+  // Parse the operation name.
+  FailureOr<ast::OpNameDecl *> opNameDecl = parseWrappedOperationName();
+  if (failed(opNameDecl))
+    return failure();
+
+  // Check for the optional list of operands; an empty list `()` is accepted.
+  SmallVector<ast::Expr *> operands;
+  if (consumeIf(Token::l_paren)) {
+    if (curToken.isNot(Token::r_paren)) {
+      do {
+        FailureOr<ast::Expr *> operand = parseExpr();
+        if (failed(operand))
+          return failure();
+        operands.push_back(*operand);
+      } while (consumeIf(Token::comma));
+    }
+
+    if (failed(parseToken(Token::r_paren,
+                          "expected `)` after operation operand list")))
+      return failure();
+  }
+
+  // Check for the optional list of attributes; an empty list `{}` is accepted.
+  SmallVector<ast::NamedAttributeDecl *> attributes;
+  if (consumeIf(Token::l_brace)) {
+    if (curToken.isNot(Token::r_brace)) {
+      do {
+        FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl();
+        if (failed(decl))
+          return failure();
+        attributes.emplace_back(*decl);
+      } while (consumeIf(Token::comma));
+    }
+
+    if (failed(parseToken(Token::r_brace,
+                          "expected `}` after operation attribute list")))
+      return failure();
+  }
+
+  // Check for the optional list of result types; `-> ()` is accepted.
+  SmallVector<ast::Expr *> resultTypes;
+  if (consumeIf(Token::arrow)) {
+    if (failed(parseToken(Token::l_paren,
+                          "expected `(` before operation result type list")))
+      return failure();
+
+    if (curToken.isNot(Token::r_paren)) {
+      do {
+        FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
+        if (failed(resultTypeExpr))
+          return failure();
+        resultTypes.push_back(*resultTypeExpr);
+      } while (consumeIf(Token::comma));
+    }
+
+    if (failed(parseToken(Token::r_paren,
+                          "expected `)` after operation result type list")))
+      return failure();
+  }
+
+  return createOperationExpr(loc, *opNameDecl, operands, attributes,
+                             resultTypes);
+}
+
 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
   llvm::SMRange loc = curToken.getLoc();
   consumeToken(Token::kw_type);
@@ -1186,6 +1323,88 @@
               .str());
 }
 
+FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
+    llvm::SMRange loc, const ast::OpNameDecl *name,
+    MutableArrayRef<ast::Expr *> operands,
+    MutableArrayRef<ast::NamedAttributeDecl *> attributes,
+    MutableArrayRef<ast::Expr *> results) {
+  Optional<StringRef> opNameRef = name->getName();
+
+  // Verify the inputs operands.
+  if (failed(validateOperationOperands(loc, opNameRef, operands)))
+    return failure();
+
+  // Verify the attribute list.
+  for (ast::NamedAttributeDecl *attr : attributes) {
+    // Check for an attribute type, or a type awaiting resolution.
+    ast::Type attrType = attr->getValue()->getType();
+    if (!attrType.isa<ast::AttributeType>()) {
+      return emitError(
+          attr->getValue()->getLoc(),
+          llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
+    }
+  }
+
+  // Verify the result types.
+  if (failed(validateOperationResults(loc, opNameRef, results)))
+    return failure();
+
+  return ast::OperationExpr::create(context, loc, name, operands, results,
+                                    attributes);
+}
+
+LogicalResult
+Parser::validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name,
+                                  MutableArrayRef<ast::Expr *> operands) {
+  return validateOperationOperandsOrResults("operand", loc, name, operands,
+                                            valueTy, valueRangeTy);
+}
+
+LogicalResult
+Parser::validateOperationResults(llvm::SMRange loc, Optional<StringRef> name,
+                                 MutableArrayRef<ast::Expr *> results) {
+  return validateOperationOperandsOrResults("result", loc, name, results,
+                                            typeTy, typeRangeTy);
+}
+
+LogicalResult Parser::validateOperationOperandsOrResults(
+    StringRef groupName, llvm::SMRange loc, Optional<StringRef> name,
+    MutableArrayRef<ast::Expr *> values, ast::Type singleTy,
+    ast::Type rangeTy) {
+  // All operation types accept a single range parameter.
+  if (values.size() == 1) {
+    if (failed(convertExpressionTo(values[0], rangeTy)))
+      return failure();
+    return success();
+  }
+
+  // Otherwise, accept the value groups as they have been defined and just
+  // ensure they are one of the expected types.
+  for (ast::Expr *&valueExpr : values) {
+    ast::Type valueExprType = valueExpr->getType();
+
+    // Check if this is one of the expected types.
+    if (valueExprType == rangeTy || valueExprType == singleTy)
+      continue;
+
+    // If the operand is an Operation, allow converting to a Value or
+    // ValueRange.
+    if (singleTy == valueTy) {
+      if (auto opType = valueExprType.dyn_cast<ast::OperationType>()) {
+        valueExpr = convertOpToValue(valueExpr, opType);
+        continue;
+      }
+    }
+
+    return emitError(
+        valueExpr->getLoc(),
+        llvm::formatv(
+            "expected `{0}` or `{1}` convertible expression, but got `{2}`",
+            singleTy, rangeTy, valueExprType));
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Stmts
 
diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
--- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
@@ -81,6 +81,83 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// `op` Expr
+//===----------------------------------------------------------------------===//
+
+Pattern {
+  // CHECK: expected `)` after operation operand list
+  let value: Value;
+  let foo = op<func>(value<;
+}
+
+// -----
+
+Pattern {
+  // CHECK: unable to convert expression of type `Attr` to the expected type of `ValueRange`
+  let attr: Attr;
+  let foo = op<func>(attr);
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `Value` or `ValueRange` convertible expression, but got `Type`
+  let foo = op<>(_: Type, _: TypeRange);
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected identifier or string attribute name
+  let foo = op<> { 10;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `Attr` expression, but got `Value`
+  let foo = op<> { foo = _: Value };
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `}` after operation attribute list
+  let foo = op<> { "foo" {;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `(` before operation result type list
+  let foo = op<> -> );
+}
+
+// -----
+
+Pattern {
+  // CHECK: unable to convert expression of type `ValueRange` to the expected type of `TypeRange`
+  let foo = op<> -> (_: ValueRange);
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `Type` or `TypeRange` convertible expression, but got `Value`
+  let foo = op<> -> (_: Value, _: ValueRange);
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected `)` after operation result type list
+  let value: TypeRange;
+  let foo = op<> -> (value<;
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // `type` Expr
 //===----------------------------------------------------------------------===//
@@ -95,4 +172,4 @@
 Pattern {
   // CHECK: expected `>` after type literal
   let foo = type<"";
-}
\ No newline at end of file
+}
diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll
--- a/mlir/test/mlir-pdll/Parser/expr.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr.pdll
@@ -14,6 +14,78 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// OperationExpr
+//===----------------------------------------------------------------------===//
+
+// CHECK: Module
+// CHECK: `-OperationExpr {{.*}} Type<Op>
+// CHECK:   `-OpNameDecl
+Pattern {
+  erase op<>;
+}
+
+// -----
+
+// CHECK: Module
+// CHECK: `-OperationExpr {{.*}} Type<Op<my_dialect.foo>>
+// CHECK:   `-OpNameDecl {{.*}} Name<my_dialect.foo>
+Pattern {
+  erase op<my_dialect.foo>;
+}
+
+// -----
+
+// CHECK: Module
+// CHECK: `-OperationExpr {{.*}} Type<Op>
+// CHECK:   `-OpNameDecl
+// CHECK:   `Operands`
+// CHECK:     |-DeclRefExpr {{.*}} Type<Value>
+// CHECK:     |-DeclRefExpr {{.*}} Type<ValueRange>
+// CHECK:     `-DeclRefExpr {{.*}} Type<Value>
+Pattern {
+  erase op<>(_: Value, _: ValueRange, _: Value);
+}
+
+// -----
+
+// CHECK: Module
+// CHECK: `-OperationExpr {{.*}} Type<Op>
+// CHECK:   `-OpNameDecl
+// CHECK:   `Attributes`
+// CHECK:     |-NamedAttributeDecl {{.*}} Name<unitAttr>
+// CHECK:       `-AttributeExpr {{.*}} Value<"unit">
+// CHECK:     `-NamedAttributeDecl {{.*}} Name<normal$Attr>
+// CHECK:       `-DeclRefExpr {{.*}} Type<Attr>
+
+Pattern {
+  erase op<> {unitAttr, "normal$Attr" = _: Attr};
+}
+
+// -----
+
+// CHECK: Module
+// CHECK: `-OperationExpr {{.*}} Type<Op>
+// CHECK:   `-OpNameDecl
+// CHECK:   `Result Types`
+// CHECK:     |-DeclRefExpr {{.*}} Type<Type>
+// CHECK:     |-DeclRefExpr {{.*}} Type<TypeRange>
+// CHECK:     `-DeclRefExpr {{.*}} Type<Type>
+Pattern {
+  erase op<> -> (_: Type, _: TypeRange, _: Type);
+}
+
+// -----
+
+// CHECK: Module
+// CHECK: `-OperationExpr {{.*}} Type<Op>
+// CHECK:   `-OpNameDecl
+Pattern {
+  erase op<>() {} -> ();
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // TypeExpr
 //===----------------------------------------------------------------------===//