diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h --- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h @@ -566,6 +566,40 @@ } }; +//===----------------------------------------------------------------------===// +// RangeExpr +//===----------------------------------------------------------------------===// + +/// This expression builds a range from a set of element values (which may be +/// ranges themselves). +class RangeExpr final : public Node::NodeBase, + private llvm::TrailingObjects { +public: + static RangeExpr *create(Context &ctx, SMRange loc, ArrayRef elements, + RangeType type); + + /// Return the element expressions of this range. + MutableArrayRef getElements() { + return {getTrailingObjects(), numElements}; + } + ArrayRef getElements() const { + return const_cast(this)->getElements(); + } + + /// Return the range result type of this expression. + RangeType getType() const { return Base::getType().cast(); } + +private: + RangeExpr(SMRange loc, RangeType type, unsigned numElements) + : Base(loc, type), numElements(numElements) {} + + /// The number of trailing element expressions stored after this node. + unsigned numElements; + + /// TrailingObject utilities.
+ friend class llvm::TrailingObjects; +}; + //===----------------------------------------------------------------------===// // TupleExpr //===----------------------------------------------------------------------===// @@ -1284,7 +1318,7 @@ inline bool Expr::classof(const Node *node) { return isa(node); + OperationExpr, RangeExpr, TupleExpr, TypeExpr>(node); } inline bool OpRewriteStmt::classof(const Node *node) { diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp --- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp +++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp @@ -84,6 +84,7 @@ void printImpl(const DeclRefExpr *expr); void printImpl(const MemberAccessExpr *expr); void printImpl(const OperationExpr *expr); + void printImpl(const RangeExpr *expr); void printImpl(const TupleExpr *expr); void printImpl(const TypeExpr *expr); @@ -169,8 +170,8 @@ // Expressions. const AttributeExpr, const CallExpr, const DeclRefExpr, - const MemberAccessExpr, const OperationExpr, const TupleExpr, - const TypeExpr, + const MemberAccessExpr, const OperationExpr, const RangeExpr, + const TupleExpr, const TypeExpr, // Decls. const AttrConstraintDecl, const OpConstraintDecl, @@ -254,6 +255,14 @@ printChildren("Attributes", expr->getAttributes()); } +void NodePrinter::printImpl(const RangeExpr *expr) { + os << "RangeExpr " << expr << " Type<"; + print(expr->getType()); + os << ">\n"; + + printChildren(expr->getElements()); +} + void NodePrinter::printImpl(const TupleExpr *expr) { os << "TupleExpr " << expr << " Type<"; print(expr->getType()); diff --git a/mlir/lib/Tools/PDLL/AST/Nodes.cpp b/mlir/lib/Tools/PDLL/AST/Nodes.cpp --- a/mlir/lib/Tools/PDLL/AST/Nodes.cpp +++ b/mlir/lib/Tools/PDLL/AST/Nodes.cpp @@ -57,8 +57,8 @@ // Expressions.
const AttributeExpr, const CallExpr, const DeclRefExpr, - const MemberAccessExpr, const OperationExpr, const TupleExpr, - const TypeExpr, + const MemberAccessExpr, const OperationExpr, const RangeExpr, + const TupleExpr, const TypeExpr, // Core Constraint Decls. const AttrConstraintDecl, const OpConstraintDecl, @@ -109,6 +109,10 @@ for (const Node *child : expr->getAttributes()) visit(child); } + void visitImpl(const RangeExpr *expr) { + for (const Node *child : expr->getElements()) + visit(child); + } void visitImpl(const TupleExpr *expr) { for (const Node *child : expr->getElements()) visit(child); @@ -325,6 +329,22 @@ return getNameDecl()->getName(); } +//===----------------------------------------------------------------------===// +// RangeExpr +//===----------------------------------------------------------------------===// + +RangeExpr *RangeExpr::create(Context &ctx, SMRange loc, + ArrayRef elements, RangeType type) { + unsigned allocSize = RangeExpr::totalSizeToAlloc(elements.size()); + // Align for RangeExpr itself; the trailing elements follow it. + void *rawData = ctx.getAllocator().Allocate(allocSize, alignof(RangeExpr)); + + RangeExpr *expr = new (rawData) RangeExpr(loc, type, elements.size()); + std::uninitialized_copy(elements.begin(), elements.end(), + expr->getElements().begin()); + return expr; +} + //===----------------------------------------------------------------------===// // TupleExpr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -97,6 +97,7 @@ SmallVector genExprImpl(const ast::DeclRefExpr *expr); Value genExprImpl(const ast::MemberAccessExpr *expr); Value genExprImpl(const ast::OperationExpr *expr); + Value genExprImpl(const ast::RangeExpr *expr); SmallVector genExprImpl(const ast::TupleExpr *expr); Value genExprImpl(const ast::TypeExpr *expr); @@ -377,7 +378,8 @@ Value
CodeGen::genSingleExpr(const ast::Expr *expr) { return TypeSwitch(expr) .Case( + const ast::OperationExpr, const ast::RangeExpr, + const ast::TypeExpr>( [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) .Case( [&](auto derivedNode) { @@ -517,6 +519,15 @@ attrValues, results); } +Value CodeGen::genExprImpl(const ast::RangeExpr *expr) { + SmallVector elements; + for (const ast::Expr *element : expr->getElements()) + llvm::append_range(elements, genExpr(element)); + + return builder.create(genLoc(expr->getLoc()), + genType(expr->getType()), elements); +} + SmallVector CodeGen::genExprImpl(const ast::TupleExpr *expr) { SmallVector elements; for (const ast::Expr *element : expr->getElements()) diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -47,10 +47,9 @@ bool enableDocumentation, CodeCompleteContext *codeCompleteContext) : ctx(ctx), lexer(sourceMgr, ctx.getDiagEngine(), codeCompleteContext), curToken(lexer.lexToken()), enableDocumentation(enableDocumentation), - valueTy(ast::ValueType::get(ctx)), - valueRangeTy(ast::ValueRangeType::get(ctx)), - typeTy(ast::TypeType::get(ctx)), + typeTy(ast::TypeType::get(ctx)), valueTy(ast::ValueType::get(ctx)), typeRangeTy(ast::TypeRangeType::get(ctx)), + valueRangeTy(ast::ValueRangeType::get(ctx)), attrTy(ast::AttributeType::get(ctx)), codeCompleteContext(codeCompleteContext) {} @@ -116,6 +115,14 @@ LogicalResult convertExpressionTo( ast::Expr *&expr, ast::Type type, function_ref noteAttachFn = {}); + LogicalResult + convertOpExpressionTo(ast::Expr *&expr, ast::OperationType exprType, + ast::Type type, + function_ref emitErrorFn); + LogicalResult convertTupleExpressionTo( + ast::Expr *&expr, ast::TupleType exprType, ast::Type type, + function_ref emitErrorFn, + function_ref noteAttachFn); /// Given an operation expression, convert it to a Value or ValueRange /// typed expression.
@@ -555,8 +562,8 @@ ParserContext parserContext = ParserContext::Global; /// Cached types to simplify verification and expression creation. - ast::Type valueTy, valueRangeTy; - ast::Type typeTy, typeRangeTy; + ast::Type typeTy, valueTy; + ast::RangeType typeRangeTy, valueRangeTy; ast::Type attrTy; /// A counter used when naming anonymous constraints and rewrites. @@ -619,55 +626,8 @@ return diag; }; - if (auto exprOpType = exprType.dyn_cast()) { - // Two operation types are compatible if they have the same name, or if the - // expected type is more general. - if (auto opType = type.dyn_cast()) { - if (opType.getName()) - return emitConvertError(); - return success(); - } - - // An operation can always convert to a ValueRange. - if (type == valueRangeTy) { - expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, - valueRangeTy); - return success(); - } - - // Allow conversion to a single value by constraining the result range. - if (type == valueTy) { - // If the operation is registered, we can verify if it can ever have a - // single result.
- if (const ods::Operation *odsOp = exprOpType.getODSOperation()) { - if (odsOp->getResults().empty()) { - return emitConvertError()->attachNote( - llvm::formatv("see the definition of `{0}`, which was defined " - "with zero results", - odsOp->getName()), - odsOp->getLoc()); - } - - unsigned numSingleResults = llvm::count_if( - odsOp->getResults(), [](const ods::OperandOrResult &result) { - return result.getVariableLengthKind() == - ods::VariableLengthKind::Single; - }); - if (numSingleResults > 1) { - return emitConvertError()->attachNote( - llvm::formatv("see the definition of `{0}`, which was defined " - "with at least {1} results", - odsOp->getName(), numSingleResults), - odsOp->getLoc()); - } - } - - expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, - valueTy); - return success(); - } - return emitConvertError(); - } + if (auto exprOpType = exprType.dyn_cast()) + return convertOpExpressionTo(expr, exprOpType, type, emitConvertError); // FIXME: Decide how to allow/support converting a single result to multiple, // and multiple to a single result. For now, we just allow Single->Range, @@ -681,22 +641,85 @@ return success(); // Handle tuple types. - if (auto exprTupleType = exprType.dyn_cast()) { - auto tupleType = type.dyn_cast(); - if (!tupleType || tupleType.size() != exprTupleType.size()) - return emitConvertError(); + if (auto exprTupleType = exprType.dyn_cast()) + return convertTupleExpressionTo(expr, exprTupleType, type, emitConvertError, + noteAttachFn); + + return emitConvertError(); +} + +LogicalResult Parser::convertOpExpressionTo( + ast::Expr *&expr, ast::OperationType exprType, ast::Type type, + function_ref emitErrorFn) { + // Two operation types are compatible if they have the same name, or if the + // expected type is more general. + if (auto opType = type.dyn_cast()) { + if (opType.getName()) + return emitErrorFn(); + return success(); + } + + // An operation can always convert to a ValueRange.
+ if (type == valueRangeTy) { + expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, + valueRangeTy); + return success(); + } + + // Allow conversion to a single value by constraining the result range. + if (type == valueTy) { + // If the operation is registered, we can verify if it can ever have a + // single result. + if (const ods::Operation *odsOp = exprType.getODSOperation()) { + if (odsOp->getResults().empty()) { + return emitErrorFn()->attachNote( + llvm::formatv("see the definition of `{0}`, which was defined " + "with zero results", + odsOp->getName()), + odsOp->getLoc()); + } + + unsigned numSingleResults = llvm::count_if( + odsOp->getResults(), [](const ods::OperandOrResult &result) { + return result.getVariableLengthKind() == + ods::VariableLengthKind::Single; + }); + if (numSingleResults > 1) { + return emitErrorFn()->attachNote( + llvm::formatv("see the definition of `{0}`, which was defined " + "with at least {1} results", + odsOp->getName(), numSingleResults), + odsOp->getLoc()); + } + } + + expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, + valueTy); + return success(); + } + return emitErrorFn(); +} + +LogicalResult Parser::convertTupleExpressionTo( + ast::Expr *&expr, ast::TupleType exprType, ast::Type type, + function_ref emitErrorFn, + function_ref noteAttachFn) { + // Handle conversions between tuples. + if (auto tupleType = type.dyn_cast()) { + if (tupleType.size() != exprType.size()) + return emitErrorFn(); // Build a new tuple expression using each of the elements of the current // tuple.
SmallVector newExprs; - for (unsigned i = 0, e = exprTupleType.size(); i < e; ++i) { + for (unsigned i = 0, e = exprType.size(); i < e; ++i) { newExprs.push_back(ast::MemberAccessExpr::create( ctx, expr->getLoc(), expr, llvm::to_string(i), - exprTupleType.getElementTypes()[i])); + exprType.getElementTypes()[i])); auto diagFn = [&](ast::Diagnostic &diag) { diag.attachNote(llvm::formatv("when converting element #{0} of `{1}`", - i, exprTupleType)); + i, exprType)); if (noteAttachFn) noteAttachFn(diag); }; @@ -709,7 +732,37 @@ return success(); } - return emitConvertError(); + // Handle conversion to a range. + auto convertToRange = [&](ArrayRef allowedElementTypes, + ast::RangeType resultTy) -> LogicalResult { + // TODO: We currently only allow range conversion within a rewrite context. + if (parserContext != ParserContext::Rewrite) { + return emitErrorFn()->attachNote("Tuple to Range conversion is currently " + "only allowed within a rewrite context"); + } + + // All of the tuple elements must be allowed types. + for (ast::Type elementType : exprType.getElementTypes()) + if (!llvm::is_contained(allowedElementTypes, elementType)) + return emitErrorFn(); + + // Build a new range expression using each of the elements of the current + // tuple. + SmallVector newExprs; + for (unsigned i = 0, e = exprType.size(); i < e; ++i) { + newExprs.push_back(ast::MemberAccessExpr::create( + ctx, expr->getLoc(), expr, llvm::to_string(i), + exprType.getElementTypes()[i])); + } + expr = ast::RangeExpr::create(ctx, expr->getLoc(), newExprs, resultTy); + return success(); + }; + if (type == valueRangeTy) + return convertToRange({valueTy, valueRangeTy}, valueRangeTy); + if (type == typeRangeTy) + return convertToRange({typeTy, typeRangeTy}, typeRangeTy); + + return emitErrorFn(); } //===----------------------------------------------------------------------===// @@ -2955,6 +3008,10 @@ } } + // Otherwise, try to convert the expression to a range.
+ if (succeeded(convertExpressionTo(valueExpr, rangeTy))) + continue; + return emitError( valueExpr->getLoc(), llvm::formatv( diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -114,6 +114,26 @@ // ----- +//===----------------------------------------------------------------------===// +// RangeExpr +//===----------------------------------------------------------------------===// + +// CHECK: pdl.pattern @RangeExpr +// CHECK: %[[ARG:.*]] = operand +// CHECK: %[[ARGS:.*]] = operands +// CHECK: %[[TYPE:.*]] = type +// CHECK: %[[TYPES:.*]] = types +// CHECK: range : !pdl.range +// CHECK: range %[[ARG]], %[[ARGS]] : !pdl.value, !pdl.range +// CHECK: range : !pdl.range +// CHECK: range %[[TYPE]], %[[TYPES]] : !pdl.type, !pdl.range +Pattern RangeExpr { + replace op<>(arg: Value, args: ValueRange) -> (type: Type, types: TypeRange) + with op((), (arg, args)) -> ((), (type, types)); +} + +// ----- + //===----------------------------------------------------------------------===// // TypeExpr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -124,6 +124,25 @@ // ----- +//===----------------------------------------------------------------------===// +// Range Expr +//===----------------------------------------------------------------------===// + +Pattern { + // CHECK: unable to convert expression of type `Tuple<>` to the expected type of `ValueRange` + // CHECK: Tuple to Range conversion is currently only allowed within a rewrite context + erase op<>(()); +} + +// ----- + +Pattern { + // CHECK: unable to convert expression of type `Tuple` to the expected type of `ValueRange` + replace
op<>(arg: Value) -> (type: Type) with op((arg, type)); +} + +// ----- + //===----------------------------------------------------------------------===// // Tuple Expr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -235,6 +235,29 @@ // ----- +//===----------------------------------------------------------------------===// +// RangeExpr +//===----------------------------------------------------------------------===// + +// CHECK: Module +// CHECK: `Operands` +// CHECK: -RangeExpr {{.*}} Type +// CHECK: -RangeExpr {{.*}} Type +// CHECK: -MemberAccessExpr {{.*}} Member<0> Type +// CHECK: -MemberAccessExpr {{.*}} Member<1> Type +// CHECK: `Result Types` +// CHECK: -RangeExpr {{.*}} Type +// CHECK: -RangeExpr {{.*}} Type +// CHECK: -MemberAccessExpr {{.*}} Member<0> Type +// CHECK: -MemberAccessExpr {{.*}} Member<1> Type +Pattern { + rewrite op<>(arg: Value, args: ValueRange) -> (type: Type, types: TypeRange) with { + op((), (arg, args)) -> ((), (type, types)); + }; +} + +// ----- + //===----------------------------------------------------------------------===// // TypeExpr //===----------------------------------------------------------------------===//