diff --git a/mlir/include/mlir/Tools/PDLL/CodeGen/CPPGen.h b/mlir/include/mlir/Tools/PDLL/CodeGen/CPPGen.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/CodeGen/CPPGen.h @@ -0,0 +1,28 @@ +//===- CPPGen.h - MLIR PDLL CPP Code Generation -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_CODEGEN_CPPGEN_H_ +#define MLIR_TOOLS_PDLL_CODEGEN_CPPGEN_H_ + +#include "mlir/Support/LLVM.h" +#include + +namespace mlir { +class ModuleOp; + +namespace pdll { +namespace ast { +class Module; +} // namespace ast + +void codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, + raw_ostream &os); +} // namespace pdll +} // namespace mlir + +#endif // MLIR_TOOLS_PDLL_CODEGEN_CPPGEN_H_ diff --git a/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt b/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt --- a/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt +++ b/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRPDLLCodeGen + CPPGen.cpp MLIRGen.cpp LINK_LIBS PUBLIC diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp @@ -0,0 +1,219 @@ +//===- CPPGen.cpp ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a PDLL generator that outputs C++ code that defines PDLL +// patterns as individual C++ PDLPatternModules for direct use in native code, +// and also defines any native constraints whose bodies were defined in PDLL. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/PDLL/CodeGen/CPPGen.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Tools/PDLL/AST/Nodes.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir; +using namespace mlir::pdll; + +//===----------------------------------------------------------------------===// +// CodeGen +//===----------------------------------------------------------------------===// + +namespace { +class CodeGen { +public: + CodeGen(raw_ostream &os) : os(os) {} + + /// Generate C++ code for the given PDL pattern module. + void generate(const ast::Module &astModule, ModuleOp module); + +private: + void generate(pdl::PatternOp pattern, StringRef patternName, + StringSet<> &nativeFunctions); + + /// Generate C++ code for all user defined constraints and rewrites with + /// native code. + void generateConstraintAndRewrites(const ast::Module &astModule, + ModuleOp module, + StringSet<> &nativeFunctions); + void generate(const ast::UserConstraintDecl *decl, + StringSet<> &nativeFunctions); + void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions); + void generateConstraintOrRewrite(StringRef name, bool isConstraint, + ArrayRef inputs, + StringRef codeBlock, + StringSet<> &nativeFunctions); + + /// The stream to output to. 
+ raw_ostream &os; +}; +} // namespace + +void CodeGen::generate(const ast::Module &astModule, ModuleOp module) { + SetVector, StringSet<>> patternNames; + StringSet<> nativeFunctions; + + // Generate code for any native functions within the module. + generateConstraintAndRewrites(astModule, module, nativeFunctions); + + os << "namespace {\n"; + std::string basePatternName = "GeneratedPDLLPattern"; + int patternIndex = 0; + for (pdl::PatternOp pattern : module.getOps()) { + // If the pattern has a name, use that. Otherwise, generate a unique name. + if (Optional patternName = pattern.sym_name()) { + patternNames.insert(patternName->str()); + } else { + std::string name; + do { + name = (basePatternName + Twine(patternIndex++)).str(); + } while (!patternNames.insert(name)); + } + + generate(pattern, patternNames.back(), nativeFunctions); + } + os << "} // end namespace\n\n"; + + // Emit function to add the generated matchers to the pattern list. + os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" + "::mlir::RewritePatternSet &patterns) {\n"; + for (const auto &name : patternNames) + os << " patterns.add<" << name << ">(patterns.getContext());\n"; + os << "}\n"; +} + +void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName, + StringSet<> &nativeFunctions) { + const char *patternClassStartStr = R"( +struct {0} : ::mlir::PDLPatternModule {{ + {0}(::mlir::MLIRContext *context) + : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( +)"; + os << llvm::formatv(patternClassStartStr, patternName); + + os << "R\"mlir("; + pattern->print(os, OpPrintingFlags().enableDebugInfo()); + os << "\n )mlir\", context)) {\n"; + + // Register any native functions used within the pattern. 
+ StringSet<> registeredNativeFunctions; + auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) { + if (!nativeFunctions.count(fnName) || + !registeredNativeFunctions.insert(fnName).second) + return; + os << " register" << fnType << "Function(\"" << fnName << "\", " + << fnName << "PDLFn);\n"; + }; + pattern.walk([&](Operation *op) { + if (auto constraintOp = dyn_cast(op)) + checkRegisterNativeFn(constraintOp.name(), "Constraint"); + else if (auto rewriteOp = dyn_cast(op)) + checkRegisterNativeFn(rewriteOp.name(), "Rewrite"); + }); + os << " }\n};\n\n"; +} + +void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule, + ModuleOp module, + StringSet<> &nativeFunctions) { + // First check to see which constraints and rewrites are actually referenced + // in the module. + StringSet<> usedFns; + module.walk([&](Operation *op) { + TypeSwitch(op) + .Case( + [&](auto op) { usedFns.insert(op.name()); }); + }); + + for (const ast::Decl *decl : astModule.getChildren()) { + TypeSwitch(decl) + .Case( + [&](const auto *decl) { + // We only generate code for inline native decls that have been + // referenced. 
+ if (decl->getCodeBlock() && + usedFns.contains(decl->getName().getName())) + this->generate(decl, nativeFunctions); + }); + } +} + +void CodeGen::generate(const ast::UserConstraintDecl *decl, + StringSet<> &nativeFunctions) { + return generateConstraintOrRewrite(decl->getName().getName(), + /*isConstraint=*/true, decl->getInputs(), + *decl->getCodeBlock(), nativeFunctions); +} + +void CodeGen::generate(const ast::UserRewriteDecl *decl, + StringSet<> &nativeFunctions) { + return generateConstraintOrRewrite(decl->getName().getName(), + /*isConstraint=*/false, decl->getInputs(), + *decl->getCodeBlock(), nativeFunctions); +} + +void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint, + ArrayRef inputs, + StringRef codeBlock, + StringSet<> &nativeFunctions) { + nativeFunctions.insert(name); + + // TODO: Should there be something explicit for handling optionality? + auto getCppType = [&](ast::Type type) -> StringRef { + return llvm::TypeSwitch(type) + .Case([&](ast::AttributeType) { return "::mlir::Attribute"; }) + .Case([&](ast::OperationType) { + // TODO: Allow using the derived Op class when possible. + return "::mlir::Operation *"; + }) + .Case([&](ast::TypeType) { return "::mlir::Type"; }) + .Case([&](ast::ValueType) { return "::mlir::Value"; }) + .Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; }) + .Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; }); + }; + + // FIXME: We currently do not have a modeling for the "constant params" + // support PDL provides. We should either figure out a modeling for this, or + // refactor the support within PDL to be something a bit more reasonable for + // what we need as a frontend. + os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name + << "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, " + "::mlir::ArrayAttr constParams, ::mlir::PatternRewriter &rewriter" + << (isConstraint ? 
"" : ", ::mlir::PDLResultList &results") << ") {\n"; + + const char *argumentInitStr = R"( + {0} {1} = {{}; + if (values[{2}]) + {1} = values[{2}].cast<{0}>(); + (void){1}; +)"; + for (const auto &it : llvm::enumerate(inputs)) { + const ast::VariableDecl *input = it.value(); + os << llvm::formatv(argumentInitStr, getCppType(input->getType()), + input->getName().getName(), it.index()); + } + + os << " " << codeBlock.trim() << "\n}\n"; +} + +//===----------------------------------------------------------------------===// +// CPPGen +//===----------------------------------------------------------------------===// + +void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module, + raw_ostream &os) { + CodeGen codegen(os); + codegen.generate(astModule, module); +} diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll @@ -0,0 +1,105 @@ +// RUN: mlir-pdll %s -I %S -split-input-file -x cpp | FileCheck %s + +// Check that we generate a wrapper pattern for each PDL pattern. Also +// add in a pattern awkwardly named the same as our generated patterns to +// check that we handle overlap. 
+ +// CHECK: struct GeneratedPDLLPattern0 : ::mlir::PDLPatternModule { +// CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( +// CHECK: R"mlir( +// CHECK: pdl.pattern +// CHECK: operation "test.op" +// CHECK: )mlir", context)) + +// CHECK: struct NamedPattern : ::mlir::PDLPatternModule { +// CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( +// CHECK: R"mlir( +// CHECK: pdl.pattern +// CHECK: operation "test.op2" +// CHECK: )mlir", context)) + +// CHECK: struct GeneratedPDLLPattern1 : ::mlir::PDLPatternModule { + +// CHECK: struct GeneratedPDLLPattern2 : ::mlir::PDLPatternModule { +// CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( +// CHECK: R"mlir( +// CHECK: pdl.pattern +// CHECK: operation "test.op3" +// CHECK: )mlir", context)) + +// CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns) { +// CHECK-NEXT: patterns.add(patterns.getContext()); +// CHECK-NEXT: patterns.add(patterns.getContext()); +// CHECK-NEXT: patterns.add(patterns.getContext()); +// CHECK-NEXT: patterns.add(patterns.getContext()); +// CHECK-NEXT: } + +Pattern => erase op; +Pattern NamedPattern => erase op; +Pattern GeneratedPDLLPattern1 => erase op<>; +Pattern => erase op; + +// ----- + +// Check the generation of native constraints and rewrites. 
+ +// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, ::mlir::ArrayAttr constParams, +// CHECK-SAME: ::mlir::PatternRewriter &rewriter) { +// CHECK: ::mlir::Attribute attr = {}; +// CHECK: if (values[0]) +// CHECK: attr = values[0].cast<::mlir::Attribute>(); +// CHECK: ::mlir::Operation * op = {}; +// CHECK: if (values[1]) +// CHECK: op = values[1].cast<::mlir::Operation *>(); +// CHECK: ::mlir::Type type = {}; +// CHECK: if (values[2]) +// CHECK: type = values[2].cast<::mlir::Type>(); +// CHECK: ::mlir::Value value = {}; +// CHECK: if (values[3]) +// CHECK: value = values[3].cast<::mlir::Value>(); +// CHECK: ::mlir::TypeRange typeRange = {}; +// CHECK: if (values[4]) +// CHECK: typeRange = values[4].cast<::mlir::TypeRange>(); +// CHECK: ::mlir::ValueRange valueRange = {}; +// CHECK: if (values[5]) +// CHECK: valueRange = values[5].cast<::mlir::ValueRange>(); + +// CHECK: return success(); +// CHECK: } + +// CHECK-NOT: TestUnusedCst + +// CHECK: static void TestRewritePDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, ::mlir::ArrayAttr constParams, +// CHECK-SAME: ::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results) { +// CHECK: ::mlir::Attribute attr = {}; +// CHECK: ::mlir::Operation * op = {}; +// CHECK: ::mlir::Type type = {}; +// CHECK: ::mlir::Value value = {}; +// CHECK: ::mlir::TypeRange typeRange = {}; +// CHECK: ::mlir::ValueRange valueRange = {}; + +// CHECK: foo; +// CHECK: } + +// CHECK-NOT: TestUnusedRewrite + +// CHECK: struct TestCstAndRewrite : ::mlir::PDLPatternModule { +// CHECK: registerConstraintFunction("TestCst", TestCstPDLFn); +// CHECK: registerRewriteFunction("TestRewrite", TestRewritePDLFn); + +Constraint TestCst(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: ValueRange) [{ + return success(); +}]; +Constraint TestUnusedCst() [{ return success(); }]; + +Rewrite TestRewrite(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: 
ValueRange) [{ foo; }]; +Rewrite TestUnusedRewrite(op: Op) [{}]; + +Pattern TestCstAndRewrite { + let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange); + TestCst(attr<"true">, root, type, operand, types, operands); + rewrite root with { + TestRewrite(attr<"true">, root, type, operand, types, operands); + erase root; + }; +} diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp --- a/mlir/tools/mlir-pdll/mlir-pdll.cpp +++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp @@ -11,6 +11,7 @@ #include "mlir/Support/ToolUtilities.h" #include "mlir/Tools/PDLL/AST/Context.h" #include "mlir/Tools/PDLL/AST/Nodes.h" +#include "mlir/Tools/PDLL/CodeGen/CPPGen.h" #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h" #include "mlir/Tools/PDLL/Parser/Parser.h" #include "llvm/Support/CommandLine.h" @@ -29,6 +30,7 @@ enum class OutputType { AST, MLIR, + CPP, }; static LogicalResult @@ -54,7 +56,12 @@ if (!pdlModule) return failure(); - pdlModule->print(os, OpPrintingFlags().enableDebugInfo()); + if (outputType == OutputType::MLIR) { + pdlModule->print(os, OpPrintingFlags().enableDebugInfo()); + return success(); + } + + codegenPDLLToCPP(**module, *pdlModule, os); return success(); } @@ -82,7 +89,10 @@ llvm::cl::values(clEnumValN(OutputType::AST, "ast", "generate the AST for the input file"), clEnumValN(OutputType::MLIR, "mlir", - "generate the PDL MLIR for the input file"))); + "generate the PDL MLIR for the input file"), + clEnumValN(OutputType::CPP, "cpp", + "generate a C++ source file containing the " + "patterns for the input file"))); llvm::InitLLVM y(argc, argv); llvm::cl::ParseCommandLineOptions(argc, argv, "PDLL Frontend");