diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -442,6 +442,36 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// TupleExpr
+//===----------------------------------------------------------------------===//
+
+/// This expression builds a tuple from a set of element values.
+class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
+                        private llvm::TrailingObjects<TupleExpr, Expr *> {
+public:
+  static TupleExpr *create(Context &ctx, llvm::SMRange loc,
+                           ArrayRef<Expr *> elements,
+                           ArrayRef<StringRef> elementNames);
+
+  /// Return the element expressions of this tuple.
+  MutableArrayRef<Expr *> getElements() {
+    return {getTrailingObjects<Expr *>(), getType().size()};
+  }
+  ArrayRef<Expr *> getElements() const {
+    return const_cast<TupleExpr *>(this)->getElements();
+  }
+
+  /// Return the tuple result type of this expression.
+  TupleType getType() const { return Base::getType().cast<TupleType>(); }
+
+private:
+  TupleExpr(llvm::SMRange loc, TupleType type) : Base(loc, type) {}
+
+  /// TrailingObject utilities.
+  friend class llvm::TrailingObjects<TupleExpr, Expr *>;
+};
+
 //===----------------------------------------------------------------------===//
 // TypeExpr
 //===----------------------------------------------------------------------===//
@@ -844,7 +874,7 @@
 
 inline bool Expr::classof(const Node *node) {
   return isa<AttributeExpr, DeclRefExpr, MemberAccessExpr, OperationExpr,
-             TypeExpr>(node);
+             TupleExpr, TypeExpr>(node);
 }
 
 inline bool OpRewriteStmt::classof(const Node *node) {
diff --git a/mlir/include/mlir/Tools/PDLL/AST/Types.h b/mlir/include/mlir/Tools/PDLL/AST/Types.h
--- a/mlir/include/mlir/Tools/PDLL/AST/Types.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Types.h
@@ -22,6 +22,7 @@
 struct ConstraintTypeStorage;
 struct OperationTypeStorage;
 struct RangeTypeStorage;
+struct TupleTypeStorage;
 struct TypeTypeStorage;
 struct ValueTypeStorage;
 } // namespace detail
@@ -203,6 +204,35 @@
   static ValueRangeType get(Context &context);
 };
 
+//===----------------------------------------------------------------------===//
+// TupleType
+//===----------------------------------------------------------------------===//
+
+/// This class represents a PDLL tuple type, i.e. an ordered set of element
+/// types with optional names.
+class TupleType : public Type::TypeBase<detail::TupleTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Return an instance of the Tuple type.
+  static TupleType get(Context &context, ArrayRef<Type> elementTypes,
+                       ArrayRef<StringRef> elementNames);
+  static TupleType get(Context &context,
+                       ArrayRef<Type> elementTypes = llvm::None);
+
+  /// Return the element types of this tuple.
+  ArrayRef<Type> getElementTypes() const;
+
+  /// Return the element names of this tuple.
+  ArrayRef<StringRef> getElementNames() const;
+
+  /// Return the number of elements within this tuple.
+  size_t size() const { return getElementTypes().size(); }
+
+  /// Return if the tuple has no elements.
+  bool empty() const { return size() == 0; }
+};
+
 //===----------------------------------------------------------------------===//
 // TypeType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/AST/Context.cpp b/mlir/lib/Tools/PDLL/AST/Context.cpp
--- a/mlir/lib/Tools/PDLL/AST/Context.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Context.cpp
@@ -20,4 +20,5 @@
 
   typeUniquer.registerParametricStorageType<detail::OperationTypeStorage>();
   typeUniquer.registerParametricStorageType<detail::RangeTypeStorage>();
+  typeUniquer.registerParametricStorageType<detail::TupleTypeStorage>();
 }
diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
--- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
+++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp
@@ -80,6 +80,7 @@
   void printImpl(const DeclRefExpr *expr);
   void printImpl(const MemberAccessExpr *expr);
   void printImpl(const OperationExpr *expr);
+  void printImpl(const TupleExpr *expr);
   void printImpl(const TypeExpr *expr);
 
   void printImpl(const AttrConstraintDecl *decl);
@@ -132,6 +133,17 @@
         print(type.getElementType());
         os << "Range";
       })
+      .Case([&](TupleType type) {
+        os << "Tuple<";
+        llvm::interleaveComma(
+            llvm::zip(type.getElementNames(), type.getElementTypes()), os,
+            [&](auto it) {
+              if (!std::get<0>(it).empty())
+                os << std::get<0>(it) << ": ";
+              this->print(std::get<1>(it));
+            });
+        os << ">";
+      })
       .Case([&](TypeType) { os << "Type"; })
       .Case([&](ValueType) { os << "Value"; })
       .Default([](Type) { llvm_unreachable("unknown AST type"); });
@@ -149,7 +161,7 @@
 
           // Expressions.
           const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
-          const OperationExpr, const TypeExpr,
+          const OperationExpr, const TupleExpr, const TypeExpr,
 
           // Decls.
           const AttrConstraintDecl, const OpConstraintDecl,
@@ -208,6 +220,14 @@
   printChildren("Attributes", expr->getAttributes());
 }
 
+void NodePrinter::printImpl(const TupleExpr *expr) {
+  os << "TupleExpr " << expr << " Type<";
+  print(expr->getType());
+  os << ">\n";
+
+  printChildren(expr->getElements());
+}
+
 void NodePrinter::printImpl(const TypeExpr *expr) {
   os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
 }
diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
--- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp
@@ -147,6 +147,26 @@
   return getNameDecl()->getName();
 }
 
+//===----------------------------------------------------------------------===//
+// TupleExpr
+//===----------------------------------------------------------------------===//
+
+TupleExpr *TupleExpr::create(Context &ctx, llvm::SMRange loc,
+                             ArrayRef<Expr *> elements,
+                             ArrayRef<StringRef> names) {
+  unsigned allocSize = TupleExpr::totalSizeToAlloc<Expr *>(elements.size());
+  void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(TupleExpr));
+
+  auto elementTypes = llvm::map_range(
+      elements, [](const Expr *expr) { return expr->getType(); });
+  TupleType type = TupleType::get(ctx, llvm::to_vector(elementTypes), names);
+
+  TupleExpr *expr = new (rawData) TupleExpr(loc, type);
+  std::uninitialized_copy(elements.begin(), elements.end(),
+                          expr->getElements().begin());
+  return expr;
+}
+
 //===----------------------------------------------------------------------===//
 // TypeExpr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/AST/TypeDetail.h b/mlir/lib/Tools/PDLL/AST/TypeDetail.h
--- a/mlir/lib/Tools/PDLL/AST/TypeDetail.h
+++ b/mlir/lib/Tools/PDLL/AST/TypeDetail.h
@@ -96,6 +96,26 @@
   using Base::Base;
 };
 
+//===----------------------------------------------------------------------===//
+// TupleType
+//===----------------------------------------------------------------------===//
+
+struct TupleTypeStorage
+    : public TypeStorageBase<TupleTypeStorage,
+                             std::pair<ArrayRef<Type>, ArrayRef<StringRef>>> {
+  using Base::Base;
+
+  static TupleTypeStorage *
+  construct(StorageUniquer::StorageAllocator &alloc,
+            std::pair<ArrayRef<Type>, ArrayRef<StringRef>> key) {
+    SmallVector<StringRef> names = llvm::to_vector(llvm::map_range(
+        key.second, [&](StringRef name) { return alloc.copyInto(name); }));
+    return new (alloc.allocate<TupleTypeStorage>()) TupleTypeStorage(
+        std::make_pair(alloc.copyInto(key.first),
+                       alloc.copyInto(llvm::makeArrayRef(names))));
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // TypeType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/AST/Types.cpp b/mlir/lib/Tools/PDLL/AST/Types.cpp
--- a/mlir/lib/Tools/PDLL/AST/Types.cpp
+++ b/mlir/lib/Tools/PDLL/AST/Types.cpp
@@ -107,6 +107,29 @@
       .cast<ValueRangeType>();
 }
 
+//===----------------------------------------------------------------------===//
+// TupleType
+//===----------------------------------------------------------------------===//
+
+TupleType TupleType::get(Context &context, ArrayRef<Type> elementTypes,
+                         ArrayRef<StringRef> elementNames) {
+  assert(elementTypes.size() == elementNames.size());
+  return context.getTypeUniquer().get<ImplTy>(
+      /*initFn=*/function_ref<void(ImplTy *)>(), elementTypes, elementNames);
+}
+TupleType TupleType::get(Context &context, ArrayRef<Type> elementTypes) {
+  SmallVector<StringRef> elementNames(elementTypes.size());
+  return get(context, elementTypes, elementNames);
+}
+
+ArrayRef<Type> TupleType::getElementTypes() const {
+  return getImplAs<ImplTy>()->getValue().first;
+}
+
+ArrayRef<StringRef> TupleType::getElementNames() const {
+  return getImplAs<ImplTy>()->getValue().second;
+}
+
 //===----------------------------------------------------------------------===//
 // TypeType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -13,9 +13,11 @@
 #include "mlir/Tools/PDLL/AST/Diagnostic.h"
 #include "mlir/Tools/PDLL/AST/Nodes.h"
 #include "mlir/Tools/PDLL/AST/Types.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/SaveAndRestore.h"
+#include "llvm/Support/ScopedPrinter.h"
 #include <string>
 
 using namespace mlir;
@@ -147,6 +149,7 @@
   FailureOr<ast::OpNameDecl *> parseOperationName();
   FailureOr<ast::OpNameDecl *> parseWrappedOperationName();
   FailureOr<ast::Expr *> parseOperationExpr();
+  FailureOr<ast::Expr *> parseTupleExpr();
   FailureOr<ast::Expr *> parseTypeExpr();
   FailureOr<ast::Expr *> parseUnderscoreExpr();
 
@@ -227,6 +230,9 @@
                                      Optional<StringRef> name,
                                      MutableArrayRef<ast::Expr *> values,
                                      ast::Type singleTy, ast::Type rangeTy);
+  FailureOr<ast::TupleExpr *> createTupleExpr(llvm::SMRange loc,
+                                              ArrayRef<ast::Expr *> elements,
+                                              ArrayRef<StringRef> elementNames);
 
   //===--------------------------------------------------------------------===//
   // Stmts
@@ -403,6 +409,35 @@
       (type == typeTy || type == typeRangeTy))
     return success();
 
+  // Handle tuple types.
+  if (auto exprTupleType = exprType.dyn_cast<ast::TupleType>()) {
+    auto tupleType = type.dyn_cast<ast::TupleType>();
+    if (!tupleType || tupleType.size() != exprTupleType.size())
+      return emitConvertError();
+
+    // Build a new tuple expression using each of the elements of the current
+    // tuple.
+    SmallVector<ast::Expr *> newExprs;
+    for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) {
+      newExprs.push_back(ast::MemberAccessExpr::create(
+          ctx, expr->getLoc(), expr, llvm::to_string(i),
+          exprTupleType.getElementTypes()[i]));
+
+      auto diagFn = [&](ast::Diagnostic &diag) {
+        diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`",
+                                      i, exprTupleType));
+        if (noteAttachFn)
+          noteAttachFn(diag);
+      };
+      if (failed(convertExpressionTo(newExprs.back(),
+                                     tupleType.getElementTypes()[i], diagFn)))
+        return failure();
+    }
+    expr = ast::TupleExpr::create(ctx, expr->getLoc(), newExprs,
+                                  tupleType.getElementNames());
+    return success();
+  }
+
   return emitConvertError();
 }
 
@@ -799,6 +834,9 @@
   case Token::kw_type:
     lhsExpr = parseTypeExpr();
     break;
+  case Token::l_paren:
+    lhsExpr = parseTupleExpr();
+    break;
   default:
     return emitError("expected expression");
   }
@@ -996,6 +1034,58 @@
                              resultTypes);
 }
 
+FailureOr<ast::Expr *> Parser::parseTupleExpr() {
+  llvm::SMRange loc = curToken.getLoc();
+  consumeToken(Token::l_paren);
+
+  DenseMap<StringRef, llvm::SMRange> usedNames;
+  SmallVector<StringRef> elementNames;
+  SmallVector<ast::Expr *> elements;
+  if (curToken.isNot(Token::r_paren)) {
+    do {
+      // Check for the optional element name assignment before the value.
+      StringRef elementName;
+      if (curToken.is(Token::identifier) || curToken.isDependentKeyword()) {
+        Token elementNameTok = curToken;
+        consumeToken();
+
+        // The element name is only present if followed by an `=`.
+        if (consumeIf(Token::equal)) {
+          elementName = elementNameTok.getSpelling();
+
+          // Check to see if this name is already used.
+          auto elementNameIt =
+              usedNames.try_emplace(elementName, elementNameTok.getLoc());
+          if (!elementNameIt.second) {
+            return emitErrorAndNote(
+                elementNameTok.getLoc(),
+                llvm::formatv("duplicate tuple element label `{0}`",
+                              elementName),
+                elementNameIt.first->getSecond(),
+                "see previous label use here");
+          }
+        } else {
+          // Otherwise, we treat this as part of an expression so reset the
+          // lexer.
+          resetToken(elementNameTok.getLoc());
+        }
+      }
+      elementNames.push_back(elementName);
+
+      // Parse the tuple element value.
+      FailureOr<ast::Expr *> element = parseExpr();
+      if (failed(element))
+        return failure();
+      elements.push_back(*element);
+    } while (consumeIf(Token::comma));
+  }
+  loc.End = curToken.getEndLoc();
+  if (failed(
+          parseToken(Token::r_paren, "expected `)` after tuple element list")))
+    return failure();
+  return createTupleExpr(loc, elements, elementNames);
+}
+
 FailureOr<ast::Expr *> Parser::parseTypeExpr() {
   llvm::SMRange loc = curToken.getLoc();
   consumeToken(Token::kw_type);
@@ -1329,6 +1419,19 @@
   if (parentType.isa<ast::OperationType>()) {
     if (name == ast::AllResultsMemberAccessExpr::getMemberName())
       return valueRangeTy;
+  } else if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
+    // Handle indexed results.
+    unsigned index = 0;
+    if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) &&
+        index < tupleType.size()) {
+      return tupleType.getElementTypes()[index];
+    }
+
+    // Handle named results.
+    auto elementNames = tupleType.getElementNames();
+    auto it = llvm::find(elementNames, name);
+    if (it != elementNames.end())
+      return tupleType.getElementTypes()[it - elementNames.begin()];
   }
   return emitError(
       loc,
@@ -1419,6 +1522,20 @@
   return success();
 }
 
+FailureOr<ast::TupleExpr *>
+Parser::createTupleExpr(llvm::SMRange loc, ArrayRef<ast::Expr *> elements,
+                        ArrayRef<StringRef> elementNames) {
+  for (const ast::Expr *element : elements) {
+    ast::Type eleTy = element->getType();
+    if (eleTy.isa<ast::ConstraintType, ast::TupleType>()) {
+      return emitError(
+          element->getLoc(),
+          llvm::formatv("unable to build a tuple with `{0}` element", eleTy));
+    }
+  }
+  return ast::TupleExpr::create(ctx, loc, elements, elementNames);
+}
+
 //===----------------------------------------------------------------------===//
 // Stmts
 
diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
--- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll
@@ -63,6 +63,65 @@
 
 // -----
 
+Pattern {
+  let tuple = (result1 = value: Value, result2 = value);
+
+  // CHECK: invalid member access `unknown_result` on expression of type `Tuple<result1: Value, result2: Value>`
+  let tuple2 = (tuple.result1, tuple.unknown_result);
+
+  erase op<>;
+}
+
+// -----
+
+Pattern {
+  let tuple = (result1 = value: Value, result2 = value);
+
+  // CHECK: invalid member access `2` on expression of type `Tuple<result1: Value, result2: Value>`
+  let tuple2 = (tuple.0, tuple.2);
+
+  erase op<>;
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Tuple Expr
+//===----------------------------------------------------------------------===//
+
+Pattern {
+  // CHECK: expected `)` after tuple element list
+  let tuple = (value: Value, value;
+}
+
+// -----
+
+Pattern {
+  // CHECK: unable to build a tuple with `Tuple<Value, Value>` element
+  let tuple = (_: Value, _: Value);
+  let var = (tuple);
+  erase op<>;
+}
+
+// -----
+
+Pattern {
+  // CHECK: expected expression
+  let tuple = (10 = _: Value);
+  erase op<>;
+}
+
+// -----
+
+Pattern {
+  // CHECK: duplicate tuple element label `field`
+  // CHECK: see previous label use here
+  let tuple = (field = _: Value, field = _: Value);
+  erase op<>;
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // `attr` Expr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll
--- a/mlir/test/mlir-pdll/Parser/expr.pdll
+++ b/mlir/test/mlir-pdll/Parser/expr.pdll
@@ -14,6 +14,27 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// MemberAccessExpr
+//===----------------------------------------------------------------------===//
+
+// CHECK: Module
+// CHECK: `-VariableDecl {{.*}} Name<firstEltIndex> Type<Op>
+// CHECK:   `-MemberAccessExpr {{.*}} Member<0> Type<Op>
+// CHECK:     `-DeclRefExpr {{.*}} Type<Tuple<firstElt: Op>>
+// CHECK: `-VariableDecl {{.*}} Name<firstEltName> Type<Op>
+// CHECK:   `-MemberAccessExpr {{.*}} Member<firstElt> Type<Op>
+// CHECK:     `-DeclRefExpr {{.*}} Type<Tuple<firstElt: Op>>
+Pattern {
+  let tuple = (firstElt = _: Op);
+  let firstEltIndex = tuple.0;
+  let firstEltName = tuple.firstElt;
+
+  erase _: Op;
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // OperationExpr
 //===----------------------------------------------------------------------===//
@@ -90,6 +111,28 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// TupleExpr
+//===----------------------------------------------------------------------===//
+
+// CHECK: Module
+// CHECK: `-VariableDecl {{.*}} Name<emptyTuple>
+// CHECK:   `-TupleExpr {{.*}} Type<Tuple<>>
+// CHECK: `-VariableDecl {{.*}} Name<mixedTuple>
+// CHECK:   `-TupleExpr {{.*}} Type<Tuple<arg1: Attr, Value>>
+// CHECK:     |-DeclRefExpr {{.*}} Type<Attr>
+// CHECK:     `-DeclRefExpr {{.*}} Type<Value>
+Pattern {
+  let value: Value;
+
+  let emptyTuple = ();
+  let mixedTuple = (arg1 = _: Attr, value);
+
+  erase _: Op;
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // TypeExpr
 //===----------------------------------------------------------------------===//