diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h --- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h @@ -1239,8 +1239,8 @@ } inline bool Expr::classof(const Node *node) { - return isa(node); + return isa(node); } inline bool OpRewriteStmt::classof(const Node *node) { diff --git a/mlir/include/mlir/Tools/PDLL/CodeGen/MLIRGen.h b/mlir/include/mlir/Tools/PDLL/CodeGen/MLIRGen.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/CodeGen/MLIRGen.h @@ -0,0 +1,41 @@ +//===- MLIRGen.h - MLIR PDLL Code Generation --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_CODEGEN_MLIRGEN_H_ +#define MLIR_TOOLS_PDLL_CODEGEN_MLIRGEN_H_ + +#include + +#include "mlir/Support/LogicalResult.h" + +namespace llvm { +class SourceMgr; +} // namespace llvm + +namespace mlir { +class MLIRContext; +class ModuleOp; +template +class OwningOpRef; + +namespace pdll { +namespace ast { +class Context; +class Module; +} // namespace ast + +/// Given a PDLL module, generate an MLIR PDL pattern module within the given +/// MLIR context. +OwningOpRef codegenPDLLToMLIR(MLIRContext *mlirContext, + const ast::Context &context, + const llvm::SourceMgr &sourceMgr, + const ast::Module &module); +} // namespace pdll +} // namespace mlir + +#endif // MLIR_TOOLS_PDLL_CODEGEN_MLIRGEN_H_ diff --git a/mlir/lib/Tools/PDLL/CMakeLists.txt b/mlir/lib/Tools/PDLL/CMakeLists.txt --- a/mlir/lib/Tools/PDLL/CMakeLists.txt +++ b/mlir/lib/Tools/PDLL/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(AST) +add_subdirectory(CodeGen) add_subdirectory(Parser) diff --git a/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt b/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt @@ -0,0 +1,9 @@ +add_mlir_library(MLIRPDLLCodeGen + MLIRGen.cpp + + LINK_LIBS PUBLIC + MLIRParser + MLIRPDLLAST + MLIRPDL + MLIRSupport + ) diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -0,0 +1,580 @@ +//===- MLIRGen.cpp --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLOps.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser.h" +#include "mlir/Tools/PDLL/AST/Context.h" +#include "mlir/Tools/PDLL/AST/Nodes.h" +#include "mlir/Tools/PDLL/AST/Types.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::pdll; + +//===----------------------------------------------------------------------===// +// CodeGen +//===----------------------------------------------------------------------===// + +namespace { +class CodeGen { +public: + CodeGen(MLIRContext *mlirContext, const ast::Context &context, + const llvm::SourceMgr &sourceMgr) + : builder(mlirContext), sourceMgr(sourceMgr) { + // Make sure that the PDL dialect is loaded. + mlirContext->loadDialect(); + } + + OwningOpRef generate(const ast::Module &module); + +private: + /// Generate an MLIR location from the given source location. + Location genLoc(llvm::SMLoc loc); + Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); } + + /// Generate an MLIR type from the given source type. + Type genType(ast::Type type); + + /// Generate MLIR for the given AST node. + void gen(const ast::Node *node); + + //===--------------------------------------------------------------------===// + // Statements + //===--------------------------------------------------------------------===// + + void genImpl(const ast::CompoundStmt *stmt); + void genImpl(const ast::LetStmt *stmt); + void genImpl(const ast::EraseStmt *stmt); + void genImpl(const ast::ReplaceStmt *stmt); + void genImpl(const ast::RewriteStmt *stmt); + void genImpl(const ast::ReturnStmt *stmt); + + //===--------------------------------------------------------------------===// + // Decls + //===--------------------------------------------------------------------===// + + void genImpl(const ast::UserConstraintDecl *decl); + void genImpl(const ast::UserRewriteDecl *decl); + void genImpl(const ast::PatternDecl *decl); + SmallVector genVar(const ast::VariableDecl *varDecl); + + /// Generate the value for a variable that does not have an initializer + /// expression, i.e. create the PDL value based on the type/constraints of the + /// variable. + Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc); + + /// Apply the constraints of the given variable to `values`, which correspond + /// to the MLIR values of the variable. + void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values); + + //===--------------------------------------------------------------------===// + // Expressions + //===--------------------------------------------------------------------===// + + Value genSingleExpr(const ast::Expr *expr); + SmallVector genExpr(const ast::Expr *expr); + Value genExprImpl(const ast::AttributeExpr *expr); + SmallVector genExprImpl(const ast::CallExpr *expr); + SmallVector genExprImpl(const ast::DeclRefExpr *expr); + Value genExprImpl(const ast::MemberAccessExpr *expr); + Value genExprImpl(const ast::OperationExpr *expr); + SmallVector genExprImpl(const ast::TupleExpr *expr); + Value genExprImpl(const ast::TypeExpr *expr); + + SmallVector genConstraintCall(const ast::UserConstraintDecl *decl, + Location loc, ValueRange inputs); + SmallVector genRewriteCall(const ast::UserRewriteDecl *decl, + Location loc, ValueRange inputs); + template + SmallVector genConstraintOrRewriteCall(const T *decl, Location loc, + ValueRange inputs); + + //===--------------------------------------------------------------------===// + // Fields + //===--------------------------------------------------------------------===// + + /// The MLIR builder used for building the resultant IR. + OpBuilder builder; + + /// A map from variable declarations to the mlir equivalent. + using VariableMapTy = + llvm::ScopedHashTable>; + VariableMapTy variables; + + /// The source manager of the PDLL ast. + const llvm::SourceMgr &sourceMgr; +}; +} // namespace + +OwningOpRef CodeGen::generate(const ast::Module &module) { + OwningOpRef mlirModule = + builder.create(genLoc(module.getLoc())); + builder.setInsertionPointToStart(mlirModule->getBody()); + + // Generate code for each of the decls within the module. + for (const ast::Decl *decl : module.getChildren()) + gen(decl); + + return mlirModule; +} + +Location CodeGen::genLoc(llvm::SMLoc loc) { + unsigned fileID = sourceMgr.FindBufferContainingLoc(loc); + + // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can + // use it here. + auto &bufferInfo = sourceMgr.getBufferInfo(fileID); + unsigned lineNo = bufferInfo.getLineNumber(loc.getPointer()); + unsigned column = + (loc.getPointer() - bufferInfo.getPointerForLineNumber(lineNo)) + 1; + auto *buffer = sourceMgr.getMemoryBuffer(fileID); + + return FileLineColLoc::get(builder.getContext(), + buffer->getBufferIdentifier(), lineNo, column); +} + +Type CodeGen::genType(ast::Type type) { + return TypeSwitch(type) + .Case([&](ast::AttributeType astType) -> Type { + return builder.getType(); + }) + .Case([&](ast::OperationType astType) -> Type { + return builder.getType(); + }) + .Case([&](ast::TypeType astType) -> Type { + return builder.getType(); + }) + .Case([&](ast::ValueType astType) -> Type { + return builder.getType(); + }) + .Case([&](ast::RangeType astType) -> Type { + return pdl::RangeType::get(genType(astType.getElementType())); + }); +} + +void CodeGen::gen(const ast::Node *node) { + TypeSwitch(node) + .Case( + [&](auto derivedNode) { this->genImpl(derivedNode); }) + .Case([&](const ast::Expr *expr) { genExpr(expr); }); +} + +//===----------------------------------------------------------------------===// +// CodeGen: Statements +//===----------------------------------------------------------------------===// + +void CodeGen::genImpl(const ast::CompoundStmt *stmt) { + VariableMapTy::ScopeTy varScope(variables); + for (const ast::Stmt *childStmt : stmt->getChildren()) + gen(childStmt); +} + +void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(stmt->getVarDecl()); } + +/// If the given builder is nested under a PDL PatternOp, build a rewrite +/// operation and update the builder to nest under it. This is necessary for +/// PDLL operation rewrite statements that are directly nested within a Pattern. +static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, + Location loc) { + if (isa(builder.getInsertionBlock()->getParentOp())) { + pdl::RewriteOp rewrite = builder.create( + loc, rootExpr, /*name=*/StringAttr(), + /*externalArgs=*/ValueRange(), /*externalConstParams=*/ArrayAttr()); + builder.createBlock(&rewrite.body()); + } +} + +void CodeGen::genImpl(const ast::EraseStmt *stmt) { + OpBuilder::InsertionGuard insertGuard(builder); + Value rootExpr = genSingleExpr(stmt->getRootOpExpr()); + Location loc = genLoc(stmt->getLoc()); + + // Make sure we are nested in a RewriteOp. + checkAndNestUnderRewriteOp(builder, rootExpr, loc); + builder.create(loc, rootExpr); +} + +void CodeGen::genImpl(const ast::ReplaceStmt *stmt) { + OpBuilder::InsertionGuard insertGuard(builder); + Value rootExpr = genSingleExpr(stmt->getRootOpExpr()); + Location loc = genLoc(stmt->getLoc()); + + // Make sure we are nested in a RewriteOp. + checkAndNestUnderRewriteOp(builder, rootExpr, loc); + + SmallVector replValues; + for (ast::Expr *replExpr : stmt->getReplExprs()) + replValues.push_back(genSingleExpr(replExpr)); + + // Check to see if the statement has a replacement operation, or a range of + // replacement values. + bool usesReplOperation = + replValues.size() == 1 && + replValues.front().getType().isa(); + builder.create( + loc, rootExpr, usesReplOperation ? replValues[0] : Value(), + usesReplOperation ? ValueRange() : ValueRange(replValues)); +} + +void CodeGen::genImpl(const ast::RewriteStmt *stmt) { + OpBuilder::InsertionGuard insertGuard(builder); + Value rootExpr = genSingleExpr(stmt->getRootOpExpr()); + + // Make sure we are nested in a RewriteOp. + checkAndNestUnderRewriteOp(builder, rootExpr, genLoc(stmt->getLoc())); + gen(stmt->getRewriteBody()); +} + +void CodeGen::genImpl(const ast::ReturnStmt *stmt) { + // ReturnStmt generation is handled by the respective constraint or rewrite + // parent node. +} + +//===----------------------------------------------------------------------===// +// CodeGen: Decls +//===----------------------------------------------------------------------===// + +void CodeGen::genImpl(const ast::UserConstraintDecl *decl) { + // All PDLL constraints get inlined when called, and the main native + // constraint declarations doesn't require any MLIR to be generated, only uses + // of it do. +} + +void CodeGen::genImpl(const ast::UserRewriteDecl *decl) { + // All PDLL rewrites get inlined when called, and the main native + // rewrite declarations doesn't require any MLIR to be generated, only uses + // of it do. +} + +void CodeGen::genImpl(const ast::PatternDecl *decl) { + const ast::Name *name = decl->getName(); + + // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it + // here. + pdl::PatternOp pattern = builder.create( + genLoc(decl->getLoc()), decl->getBenefit(), + name ? Optional(name->getName()) : Optional()); + + OpBuilder::InsertionGuard savedInsertPoint(builder); + builder.setInsertionPointToStart(pattern.getBody()); + gen(decl->getBody()); +} + +SmallVector CodeGen::genVar(const ast::VariableDecl *varDecl) { + if (variables.count(varDecl)) + return variables.lookup(varDecl); + Location loc = genLoc(varDecl->getLoc()); + + // If the variable has an initial value, use that as the base value. + // Otherwise, generate a value using the constraint list. + SmallVector values; + if (const ast::Expr *initExpr = varDecl->getInitExpr()) + values = genExpr(initExpr); + else + values.push_back(genNonInitializerVar(varDecl, loc)); + + // Apply the constraints of the values of the variable. + applyVarConstraints(varDecl, values); + + variables.insert(varDecl, values); + return values; +} + +Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl, + Location loc) { + // A functor used to generate expressions nested + auto getTypeConstraint = [&]() -> Value { + for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) { + Value typeValue = + TypeSwitch(constraint.constraint) + .Case([&](auto *cst) -> Value { + if (auto *typeConstraintExpr = cst->getTypeExpr()) + return genSingleExpr(typeConstraintExpr); + return Value(); + }) + .Default(Value()); + if (typeValue) + return typeValue; + } + return Value(); + }; + + // Generate a value based on the type of the variable. + ast::Type type = varDecl->getType(); + Type mlirType = genType(type); + if (type.isa()) + return builder.create(loc, mlirType, getTypeConstraint()); + if (type.isa()) + return builder.create(loc, mlirType, /*type=*/TypeAttr()); + if (type.isa()) + return builder.create(loc, getTypeConstraint()); + if (ast::OperationType opType = type.dyn_cast()) { + Value operands = builder.create( + loc, pdl::RangeType::get(builder.getType()), + /*type=*/Value()); + Value results = builder.create( + loc, pdl::RangeType::get(builder.getType()), + /*types=*/ArrayAttr()); + return builder.create(loc, opType.getName(), operands, + llvm::None, ValueRange(), results); + } + + if (ast::RangeType rangeTy = type.dyn_cast()) { + ast::Type eleTy = rangeTy.getElementType(); + if (eleTy.isa()) + return builder.create(loc, mlirType, + getTypeConstraint()); + if (eleTy.isa()) + return builder.create(loc, mlirType, /*types=*/ArrayAttr()); + } + + llvm_unreachable("invalid non-initialized variable type"); +} + +void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl, + ValueRange values) { + // Generate calls to any user constraints that were attached via the + // constraint list. + for (const ast::ConstraintRef &ref : varDecl->getConstraints()) + if (const auto *userCst = dyn_cast(ref.constraint)) + genConstraintCall(userCst, genLoc(ref.referenceLoc), values); +} + +//===----------------------------------------------------------------------===// +// CodeGen: Expressions +//===----------------------------------------------------------------------===// + +Value CodeGen::genSingleExpr(const ast::Expr *expr) { + return TypeSwitch(expr) + .Case( + [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) + .Case( + [&](auto derivedNode) { + SmallVector results = this->genExprImpl(derivedNode); + assert(results.size() == 1 && "expected single expression result"); + return results[0]; + }); +} + +SmallVector CodeGen::genExpr(const ast::Expr *expr) { + return TypeSwitch>(expr) + .Case( + [&](auto derivedNode) { return this->genExprImpl(derivedNode); }) + .Default([&](const ast::Expr *expr) -> SmallVector { + return {genSingleExpr(expr)}; + }); +} + +Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) { + Attribute attr = parseAttribute(expr->getValue(), builder.getContext()); + assert(attr && "invalid mlir attribute data"); + return builder.create(genLoc(expr->getLoc()), attr); +} + +SmallVector CodeGen::genExprImpl(const ast::CallExpr *expr) { + Location loc = genLoc(expr->getLoc()); + SmallVector arguments; + for (const ast::Expr *arg : expr->getArguments()) + arguments.push_back(genSingleExpr(arg)); + + // Resolve the callable expression of this call. + auto *callableExpr = dyn_cast(expr->getCallableExpr()); + assert(callableExpr && "unhandled CallExpr callable"); + + // Generate the PDL based on the type of callable. + const ast::Decl *callable = callableExpr->getDecl(); + if (const auto *decl = dyn_cast(callable)) + return genConstraintCall(decl, loc, arguments); + if (const auto *decl = dyn_cast(callable)) + return genRewriteCall(decl, loc, arguments); + llvm_unreachable("unhandled CallExpr callable"); +} + +SmallVector CodeGen::genExprImpl(const ast::DeclRefExpr *expr) { + if (const auto *varDecl = dyn_cast(expr->getDecl())) + return genVar(varDecl); + llvm_unreachable("unknown decl reference expression"); +} + +Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) { + Location loc = genLoc(expr->getLoc()); + StringRef name = expr->getMemberName(); + SmallVector parentExprs = genExpr(expr->getParentExpr()); + ast::Type parentType = expr->getParentExpr()->getType(); + + // Handle operation based member access. + if (ast::OperationType opType = parentType.dyn_cast()) { + if (isa(expr)) { + Type mlirType = genType(expr->getType()); + if (mlirType.isa()) + return builder.create(loc, mlirType, parentExprs[0], + builder.getI32IntegerAttr(0)); + return builder.create(loc, mlirType, parentExprs[0]); + } + llvm_unreachable("unhandled operation member access expression"); + } + + // Handle tuple based member access. + if (auto tupleType = parentType.dyn_cast()) { + auto elementNames = tupleType.getElementNames(); + + // The index is either a numeric index, or a name. + unsigned index = 0; + if (llvm::isDigit(name[0])) + name.getAsInteger(/*Radix=*/10, index); + else + index = llvm::find(elementNames, name) - elementNames.begin(); + + assert(index < parentExprs.size() && "invalid result index"); + return parentExprs[index]; + } + + llvm_unreachable("unhandled member access expression"); +} + +Value CodeGen::genExprImpl(const ast::OperationExpr *expr) { + Location loc = genLoc(expr->getLoc()); + Optional opName = expr->getName(); + + // Operands. + SmallVector operands; + for (const ast::Expr *operand : expr->getOperands()) + operands.push_back(genSingleExpr(operand)); + + // Attributes. + SmallVector attrNames; + SmallVector attrValues; + for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) { + attrNames.push_back(attr->getName().getName()); + attrValues.push_back(genSingleExpr(attr->getValue())); + } + + // Results. + SmallVector results; + for (const ast::Expr *result : expr->getResultTypes()) + results.push_back(genSingleExpr(result)); + + return builder.create(loc, opName, operands, attrNames, + attrValues, results); +} + +SmallVector CodeGen::genExprImpl(const ast::TupleExpr *expr) { + SmallVector elements; + for (const ast::Expr *element : expr->getElements()) + elements.push_back(genSingleExpr(element)); + return elements; +} + +Value CodeGen::genExprImpl(const ast::TypeExpr *expr) { + Type type = parseType(expr->getValue(), builder.getContext()); + assert(type && "invalid mlir type data"); + return builder.create(genLoc(expr->getLoc()), + builder.getType(), + TypeAttr::get(type)); +} + +SmallVector +CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc, + ValueRange inputs) { + // Apply any constraints defined on the arguments to the input values. + for (auto it : llvm::zip(decl->getInputs(), inputs)) + applyVarConstraints(std::get<0>(it), std::get<1>(it)); + + // Generate the constraint call. + SmallVector results = + genConstraintOrRewriteCall(decl, loc, + inputs); + + // Apply any constraints defined on the results of the constraint. + for (auto it : llvm::zip(decl->getResults(), results)) + applyVarConstraints(std::get<0>(it), std::get<1>(it)); + return results; +} + +SmallVector CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl, + Location loc, ValueRange inputs) { + return genConstraintOrRewriteCall(decl, loc, + inputs); +} + +template +SmallVector CodeGen::genConstraintOrRewriteCall(const T *decl, + Location loc, + ValueRange inputs) { + const ast::CompoundStmt *cstBody = decl->getBody(); + + // If the decl doesn't have a statement body, it is a native decl. + if (!cstBody) { + ast::Type declResultType = decl->getResultType(); + SmallVector resultTypes; + if (ast::TupleType tupleType = declResultType.dyn_cast()) { + for (ast::Type type : tupleType.getElementTypes()) + resultTypes.push_back(genType(type)); + } else { + resultTypes.push_back(genType(declResultType)); + } + + // FIXME: We currently do not have a modeling for the "constant params" + // support PDL provides. We should either figure out a modeling for this, or + // refactor the support within PDL to be something a bit more reasonable for + // what we need as a frontend. + Operation *pdlOp = builder.create(loc, resultTypes, + decl->getName().getName(), inputs, + /*params=*/ArrayAttr()); + return pdlOp->getResults(); + } + + // Otherwise, this is a PDLL decl. + VariableMapTy::ScopeTy varScope(variables); + + // Map the inputs of the call to the decl arguments. + // Note: This is only valid because we do not support recursion, meaning + // we don't need to worry about conflicting mappings here. + for (auto it : llvm::zip(inputs, decl->getInputs())) + variables.insert(std::get<1>(it), {std::get<0>(it)}); + + // Visit the body of the call as normal. + gen(cstBody); + + // If the decl has no results, there is nothing to do. + if (cstBody->getChildren().empty()) + return SmallVector(); + auto *returnStmt = dyn_cast(cstBody->getChildren().back()); + if (!returnStmt) + return SmallVector(); + + // Otherwise, grab the results from the return statement. + return genExpr(returnStmt->getResultExpr()); +} + +//===----------------------------------------------------------------------===// +// MLIRGen +//===----------------------------------------------------------------------===// + +OwningOpRef mlir::pdll::codegenPDLLToMLIR( + MLIRContext *mlirContext, const ast::Context &context, + const llvm::SourceMgr &sourceMgr, const ast::Module &module) { + CodeGen codegen(mlirContext, context, sourceMgr); + OwningOpRef mlirModule = codegen.generate(module); + if (failed(verify(*mlirModule))) + return nullptr; + return mlirModule; +} diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/decl.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/decl.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/decl.pdll @@ -0,0 +1,97 @@ +// RUN: mlir-pdll %s -I %S -split-input-file -x mlir | FileCheck %s + +//===----------------------------------------------------------------------===// +// PatternDecl +//===----------------------------------------------------------------------===// + +// CHECK: pdl.pattern : benefit(0) { +Pattern => erase _: Op; + +// ----- + +// CHECK: pdl.pattern @NamedPattern : benefit(0) { +Pattern NamedPattern => erase _: Op; + +// ----- + +// CHECK: pdl.pattern @NamedPattern : benefit(10) { +Pattern NamedPattern with benefit(10), recursion => erase _: Op; + +// ----- + +//===----------------------------------------------------------------------===// +// VariableDecl +//===----------------------------------------------------------------------===// + +// Test the case of a variable with an initializer. + +// CHECK: pdl.pattern @VarWithInit +// CHECK: %[[INIT:.*]] = operation "test.op" +// CHECK: rewrite %[[INIT]] { +// CHECK: erase %[[INIT]] +Pattern VarWithInit { + let var = op; + erase var; +} + +// ----- + +// Test range based constraints. + +// CHECK: pdl.pattern @VarWithRangeConstraints +// CHECK: %[[OPERAND_TYPES:.*]] = types +// CHECK: %[[OPERANDS:.*]] = operands : %[[OPERAND_TYPES]] +// CHECK: %[[RESULT_TYPES:.*]] = types +// CHECK: operation(%[[OPERANDS]] : !pdl.range) -> (%[[RESULT_TYPES]] : !pdl.range) +Pattern VarWithRangeConstraints { + erase op<>(operands: ValueRange) -> (results: TypeRange); +} + +// ----- + +// Test single entity constraints. + +// CHECK: pdl.pattern @VarWithConstraints +// CHECK: %[[OPERAND_TYPE:.*]] = type +// CHECK: %[[OPERAND:.*]] = operand : %[[OPERAND_TYPES]] +// CHECK: %[[ATTR_TYPE:.*]] = type +// CHECK: %[[ATTR:.*]] = attribute : %[[ATTR_TYPE]] +// CHECK: %[[RESULT_TYPE:.*]] = type +// CHECK: operation(%[[OPERAND]] : !pdl.value) {"attr" = %[[ATTR]]} -> (%[[RESULT_TYPE]] : !pdl.type) +Pattern VarWithConstraints { + erase op<>(operand: Value) { attr = _: Attr} -> (result: Type); +} + +// ----- + +// Test op constraint. + +// CHECK: pdl.pattern @VarWithNoNameOpConstraint +// CHECK: %[[OPERANDS:.*]] = operands +// CHECK: %[[RESULT_TYPES:.*]] = types +// CHECK: operation(%[[OPERANDS]] : !pdl.range) -> (%[[RESULT_TYPES]] : !pdl.range) +Pattern VarWithNoNameOpConstraint => erase _: Op; + +// CHECK: pdl.pattern @VarWithNamedOpConstraint +// CHECK: %[[OPERANDS:.*]] = operands +// CHECK: %[[RESULT_TYPES:.*]] = types +// CHECK: operation "test.op"(%[[OPERANDS]] : !pdl.range) -> (%[[RESULT_TYPES]] : !pdl.range) +Pattern VarWithNamedOpConstraint => erase _: Op; + +// ----- + +// Test user defined constraints. + +// CHECK: pdl.pattern @VarWithUserConstraint +// CHECK: %[[OPERANDS:.*]] = operands +// CHECK: %[[RESULT_TYPES:.*]] = types +// CHECK: %[[OP:.*]] = operation(%[[OPERANDS]] : !pdl.range) -> (%[[RESULT_TYPES]] : !pdl.range) +// CHECK: apply_native_constraint "NestedArgCst"(%[[OP]] : !pdl.operation) +// CHECK: apply_native_constraint "NestedResCst"(%[[OP]] : !pdl.operation) +// CHECK: apply_native_constraint "OpCst"(%[[OP]] : !pdl.operation) +// CHECK: rewrite %[[OP]] +Constraint NestedArgCst(op: Op); +Constraint NestedResCst(op: Op); +Constraint TestArgResCsts(op: NestedArgCst) -> NestedResCst => op; +Constraint OpCst(op: Op); +Pattern VarWithUserConstraint => erase _: [TestArgResCsts, OpCst]; diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -0,0 +1,93 @@ +// RUN: mlir-pdll %s -I %S -split-input-file -x mlir | FileCheck %s + +//===----------------------------------------------------------------------===// +// AttributeExpr +//===----------------------------------------------------------------------===// + +// CHECK: pdl.pattern @AttrExpr +// CHECK: %[[ATTR:.*]] = attribute 10 +// CHECK: operation({{.*}}) {"attr" = %[[ATTR]]} +Pattern AttrExpr => erase op<> { attr = attr<"10"> }; + +// ----- + +//===----------------------------------------------------------------------===// +// CallExpr +//===----------------------------------------------------------------------===// + +// CHECK: pdl.pattern @TestCallWithArgsAndReturn +// CHECK: %[[ROOT:.*]] = operation +// CHECK: rewrite %[[ROOT]] +// CHECK: %[[REPL_OP:.*]] = operation "test.op" +// CHECK: %[[RESULTS:.*]] = results of %[[REPL_OP]] +// CHECK: replace %[[ROOT]] with(%[[RESULTS]] : !pdl.range) +Rewrite TestRewrite(root: Op) -> ValueRange => root; +Pattern TestCallWithArgsAndReturn => replace root: Op with TestRewrite(op); + +// ----- + +// CHECK: pdl.pattern @TestExternalCall +// CHECK: %[[ROOT:.*]] = operation +// CHECK: rewrite %[[ROOT]] +// CHECK: %[[RESULTS:.*]] = apply_native_rewrite "TestRewrite"(%[[ROOT]] : !pdl.operation) : !pdl.range +// CHECK: replace %[[ROOT]] with(%[[RESULTS]] : !pdl.range) +Rewrite TestRewrite(op: Op) -> ValueRange; +Pattern TestExternalCall => replace root: Op with TestRewrite(root); + +// ----- + +//===----------------------------------------------------------------------===// +// MemberAccessExpr +//===----------------------------------------------------------------------===// + +// Handle implicit "all" operation results access. +// CHECK: pdl.pattern @OpAllResultMemberAccess +// CHECK: %[[OP0:.*]] = operation +// CHECK: %[[OP0_RES:.*]] = result 0 of %[[OP0]] +// CHECK: %[[OP1:.*]] = operation +// CHECK: %[[OP1_RES:.*]] = results of %[[OP1]] +// CHECK: operation(%[[OP0_RES]], %[[OP1_RES]] : !pdl.value, !pdl.range) +Pattern OpAllResultMemberAccess { + let singleVar: Value = op<>; + let rangeVar: ValueRange = op<>; + erase op<>(singleVar, rangeVar); +} + +// ----- + +// CHECK: pdl.pattern @TupleMemberAccessNumber +// CHECK: %[[FIRST:.*]] = operation "test.first" +// CHECK: %[[SECOND:.*]] = operation "test.second" +// CHECK: rewrite %[[FIRST]] { +// CHECK: replace %[[FIRST]] with %[[SECOND]] +Pattern TupleMemberAccessNumber { + let firstOp = op; + let secondOp = op(firstOp); + let tuple = (firstOp, secondOp); + replace tuple.0 with tuple.1; +} + +// ----- + +// CHECK: pdl.pattern @TupleMemberAccessName +// CHECK: %[[FIRST:.*]] = operation "test.first" +// CHECK: %[[SECOND:.*]] = operation "test.second" +// CHECK: rewrite %[[FIRST]] { +// CHECK: replace %[[FIRST]] with %[[SECOND]] +Pattern TupleMemberAccessName { + let firstOp = op; + let secondOp = op(firstOp); + let tuple = (first = firstOp, second = secondOp); + replace tuple.first with tuple.second; +} + +// ----- + +//===----------------------------------------------------------------------===// +// TypeExpr +//===----------------------------------------------------------------------===// + +// CHECK: pdl.pattern @TypeExpr +// CHECK: %[[TYPE:.*]] = type : i32 +// CHECK: operation({{.*}}) -> (%[[TYPE]] : !pdl.type) +Pattern TypeExpr => erase op<> -> (type<"i32">); diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/stmt.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/stmt.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/stmt.pdll @@ -0,0 +1,61 @@ +// RUN: mlir-pdll %s -I %S -split-input-file -x mlir | FileCheck %s + +//===----------------------------------------------------------------------===// +// EraseStmt +//===----------------------------------------------------------------------===// + +// CHECK: pdl.pattern @EraseStmt +// CHECK: %[[OP:.*]] = operation +// CHECK: rewrite %[[OP]] +// CHECK: erase %[[OP]] +Pattern EraseStmt => erase op<>; + +// ----- + +// CHECK: pdl.pattern @EraseStmtNested +// CHECK: %[[OP:.*]] = operation +// CHECK: rewrite %[[OP]] +// CHECK: erase %[[OP]] +Pattern EraseStmtNested => rewrite root: Op with { erase root; }; + +// ----- + +//===----------------------------------------------------------------------===// +// ReplaceStmt +//===----------------------------------------------------------------------===// + +// CHECK: pdl.pattern @ReplaceStmt +// CHECK: %[[OPERANDS:.*]] = operands +// CHECK: %[[OP:.*]] = operation(%[[OPERANDS]] +// CHECK: rewrite %[[OP]] +// CHECK: replace %[[OP]] with(%[[OPERANDS]] : !pdl.range) +Pattern ReplaceStmt => replace op<>(operands: ValueRange) with operands; + +// ----- + +// CHECK: pdl.pattern @ReplaceStmtNested +// CHECK: %[[OPERANDS:.*]] = operands +// CHECK: %[[OP:.*]] = operation(%[[OPERANDS]] +// CHECK: rewrite %[[OP]] +// CHECK: replace %[[OP]] with(%[[OPERANDS]] : !pdl.range) +Pattern ReplaceStmtNested { + let root = op<>(operands: ValueRange); + rewrite root with { replace root with operands; }; +} + +// ----- + +//===----------------------------------------------------------------------===// +// RewriteStmt +//===----------------------------------------------------------------------===// + +// CHECK: pdl.pattern @RewriteStmtNested +// CHECK: %[[OP:.*]] = operation +// CHECK: rewrite %[[OP]] +// CHECK: erase %[[OP]] +Pattern RewriteStmtNested { + rewrite root: Op with { + rewrite root with { erase root; }; + }; +} + diff --git a/mlir/test/mlir-pdll/Parser/lit.local.cfg b/mlir/test/mlir-pdll/lit.local.cfg rename from mlir/test/mlir-pdll/Parser/lit.local.cfg rename to mlir/test/mlir-pdll/lit.local.cfg diff --git a/mlir/tools/mlir-pdll/CMakeLists.txt b/mlir/tools/mlir-pdll/CMakeLists.txt --- a/mlir/tools/mlir-pdll/CMakeLists.txt +++ b/mlir/tools/mlir-pdll/CMakeLists.txt @@ -1,5 +1,6 @@ set(LIBS MLIRPDLLAST + MLIRPDLLCodeGen MLIRPDLLParser ) diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp --- a/mlir/tools/mlir-pdll/mlir-pdll.cpp +++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp @@ -6,10 +6,12 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinOps.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/ToolUtilities.h" #include "mlir/Tools/PDLL/AST/Context.h" #include "mlir/Tools/PDLL/AST/Nodes.h" +#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h" #include "mlir/Tools/PDLL/Parser/Parser.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" @@ -26,6 +28,7 @@ /// The desired output type. enum class OutputType { AST, + MLIR, }; static LogicalResult @@ -40,12 +43,18 @@ if (failed(module)) return failure(); - switch (outputType) { - case OutputType::AST: + if (outputType == OutputType::AST) { (*module)->print(os); - break; + return success(); } + MLIRContext mlirContext; + OwningOpRef pdlModule = + codegenPDLLToMLIR(&mlirContext, astContext, sourceMgr, **module); + if (!pdlModule) + return failure(); + + pdlModule->print(os, OpPrintingFlags().enableDebugInfo()); return success(); } @@ -71,7 +80,9 @@ "x", llvm::cl::init(OutputType::AST), llvm::cl::desc("The type of output desired"), llvm::cl::values(clEnumValN(OutputType::AST, "ast", - "generate the AST for the input file"))); + "generate the AST for the input file"), + clEnumValN(OutputType::MLIR, "mlir", + "generate the PDL MLIR for the input file"))); llvm::InitLLVM y(argc, argv); llvm::cl::ParseCommandLineOptions(argc, argv, "PDLL Frontend");