diff --git a/mlir/include/mlir/CMakeLists.txt b/mlir/include/mlir/CMakeLists.txt --- a/mlir/include/mlir/CMakeLists.txt +++ b/mlir/include/mlir/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(Dialect) add_subdirectory(IR) add_subdirectory(Interfaces) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -35,6 +35,8 @@ #include "mlir/Quantizer/Transforms/Passes.h" #include "mlir/Transforms/LocationSnapshot.h" #include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/ViewOpGraph.h" +#include "mlir/Transforms/ViewRegionGraph.h" #include <cstdlib> @@ -48,6 +50,10 @@ // individual passes. // The global registry is interesting to interact with the command-line tools. inline void registerAllPasses() { + // Init general passes +#define GEN_PASS_REGISTRATION +#include "mlir/Transforms/Passes.h.inc" + // At the moment we still rely on global initializers for registering passes, // but we may not do it in the future.
// We must reference the passes in such a way that compilers will not @@ -57,27 +63,17 @@ if (std::getenv("bar") != (char *)-1) return; - // Init general passes - createCanonicalizerPass(); - createCSEPass(); + // Affine createSuperVectorizePass({}); createLoopUnrollPass(); createLoopUnrollAndJamPass(); createSimplifyAffineStructuresPass(); - createLoopFusionPass(); createLoopInvariantCodeMotionPass(); createAffineLoopInvariantCodeMotionPass(); - createPipelineDataTransferPass(); createLowerAffinePass(); createLoopTilingPass(0); - createLoopCoalescingPass(); createAffineDataCopyGenerationPass(0, 0); createMemRefDataFlowOptPass(); - createStripDebugInfoPass(); - createPrintOpStatsPass(); - createInlinerPass(); - createSymbolDCEPass(); - createLocationSnapshotPass({}); // AVX512 createConvertAVX512ToLLVMPass(); diff --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Pass/PassBase.td @@ -0,0 +1,35 @@ +//===-- PassBase.td - Base pass definition file ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions for defining pass registration and other +// mechanisms. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_PASS_PASSBASE +#define MLIR_PASS_PASSBASE + +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +class Pass<string passArg> { + // The command line argument of the pass. + string argument = passArg; + + // A short 1-line summary of the pass. + string summary = ""; + + // A human readable description of the pass.
+ string description = ""; + + // A C++ constructor call to create an instance of this pass. + code constructor = [{}]; +} + +#endif // MLIR_PASS_PASSBASE diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h --- a/mlir/include/mlir/Pass/PassRegistry.h +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -122,7 +122,7 @@ /// Register a specific dialect pass allocator function with the system, /// typically used through the PassRegistration template. -void registerPass(StringRef arg, StringRef description, const PassID *passID, +void registerPass(StringRef arg, StringRef description, const PassAllocatorFunction &function); /// PassRegistration provides a global initializer that registers a Pass @@ -138,7 +138,7 @@ template <typename ConcretePass> struct PassRegistration { PassRegistration(StringRef arg, StringRef description, const PassAllocatorFunction &constructor) { - registerPass(arg, description, PassID::getID<ConcretePass>(), constructor); + registerPass(arg, description, constructor); } PassRegistration(StringRef arg, StringRef description) diff --git a/mlir/include/mlir/TableGen/Pass.h b/mlir/include/mlir/TableGen/Pass.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/TableGen/Pass.h @@ -0,0 +1,49 @@ +//===- Pass.h - TableGen pass definitions -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_PASS_H_ +#define MLIR_TABLEGEN_PASS_H_ + +#include "mlir/Support/LLVM.h" +#include + +namespace llvm { +class Record; +} // end namespace llvm + +namespace mlir { +namespace tblgen { +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +/// Wrapper class providing helper methods for Passes defined in TableGen. +class Pass { +public: + explicit Pass(const llvm::Record *def); + + /// Return the command line argument of the pass. + StringRef getArgument() const; + + /// Return the short 1-line summary of the pass. + StringRef getSummary() const; + + /// Return the description of the pass. + StringRef getDescription() const; + + /// Return the C++ constructor call to create an instance of this pass. + StringRef getConstructor() const; + +private: + const llvm::Record *def; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_PASS_H_ diff --git a/mlir/include/mlir/Transforms/CMakeLists.txt b/mlir/include/mlir/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Transforms/CMakeLists.txt @@ -0,0 +1,4 @@ + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(MLIRTransformsPassIncGen) diff --git a/mlir/include/mlir/Transforms/LocationSnapshot.h b/mlir/include/mlir/Transforms/LocationSnapshot.h --- a/mlir/include/mlir/Transforms/LocationSnapshot.h +++ b/mlir/include/mlir/Transforms/LocationSnapshot.h @@ -58,6 +58,8 @@ std::unique_ptr<Pass> createLocationSnapshotPass(OpPrintingFlags flags, StringRef fileName = "", StringRef tag = ""); +/// Overload utilizing pass options for initialization.
+std::unique_ptr<Pass> createLocationSnapshotPass(); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Transforms/Passes.td @@ -0,0 +1,96 @@ +//===-- Passes.td - Transforms pass definition file --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions for passes within the Transforms/ directory. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_PASSES +#define MLIR_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def AffinePipelineDataTransfer : Pass< + "affine-pipeline-data-transfer"> { + let summary = "Pipeline non-blocking data transfers between explicitly " + "managed levels of the memory hierarchy"; + let constructor = "mlir::createPipelineDataTransferPass()"; +} + +def AffineLoopFusion : Pass<"affine-loop-fusion"> { + let summary = "Fuse affine loop nests"; + let constructor = "mlir::createLoopFusionPass()"; +} + +def Canonicalizer : Pass<"canonicalize"> { + let summary = "Canonicalize operations"; + let constructor = "mlir::createCanonicalizerPass()"; +} + +def CSE : Pass<"cse"> { + let summary = "Eliminate common sub-expressions"; + let constructor = "mlir::createCSEPass()"; +} + +def Inliner : Pass<"inline"> { + let summary = "Inline function calls"; + let constructor = "mlir::createInlinerPass()"; +} + +def LocationSnapshot : Pass<"snapshot-op-locations"> { + let summary = "Generate new locations from the current IR"; + let constructor = "mlir::createLocationSnapshotPass()"; +} + +def LoopCoalescing : Pass<"loop-coalescing"> { + let summary = "Coalesce nested loops with 
independent bounds into a single " + "loop"; + let constructor = "mlir::createLoopCoalescingPass()"; +} + +def LoopInvariantCodeMotion : Pass<"loop-invariant-code-motion"> { + let summary = "Hoist loop invariant instructions outside of the loop"; + let constructor = "mlir::createLoopInvariantCodeMotionPass()"; +} + +def MemRefDataFlowOpt : Pass<"memref-dataflow-opt"> { + let summary = "Perform store/load forwarding for memrefs"; + let constructor = "mlir::createMemRefDataFlowOptPass()"; +} + +def ParallelLoopCollapsing : Pass<"parallel-loop-collapsing"> { + let summary = "Collapse parallel loops to use less induction variables"; + let constructor = "mlir::createParallelLoopCollapsingPass()"; +} + +def PrintCFG : Pass<"print-cfg-graph"> { + let summary = "Print CFG graph per-Region"; + let constructor = "mlir::createPrintCFGGraphPass()"; +} + +def PrintOpStats : Pass<"print-op-stats"> { + let summary = "Print statistics of operations"; + let constructor = "mlir::createPrintOpStatsPass()"; +} + +def PrintOp : Pass<"print-op-graph"> { + let summary = "Print op graph per-Region"; + let constructor = "mlir::createPrintOpGraphPass()"; +} + +def StripDebugInfo : Pass<"strip-debuginfo"> { + let summary = "Strip debug info from all operations"; + let constructor = "mlir::createStripDebugInfoPass()"; +} + +def SymbolDCE : Pass<"symbol-dce"> { + let summary = "Eliminate dead symbols"; + let constructor = "mlir::createSymbolDCEPass()"; +} + +#endif // MLIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -96,8 +96,9 @@ }) {} void mlir::registerPass(StringRef arg, StringRef description, - const PassID *passID, const PassAllocatorFunction &function) { + // TODO: Use 'arg' as the lookup key; getting the pass id needs an instance.
+ const PassID *passID = function()->getPassID(); PassInfo passInfo(arg, description, passID, function); bool inserted = passRegistry->try_emplace(passID, passInfo).second; assert(inserted && "Pass registered multiple times"); diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt --- a/mlir/lib/TableGen/CMakeLists.txt +++ b/mlir/lib/TableGen/CMakeLists.txt @@ -8,6 +8,7 @@ OpClass.cpp OpInterfaces.cpp OpTrait.cpp + Pass.cpp Pattern.cpp Predicate.cpp SideEffects.cpp diff --git a/mlir/lib/TableGen/Pass.cpp b/mlir/lib/TableGen/Pass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/TableGen/Pass.cpp @@ -0,0 +1,33 @@ +//===- Pass.cpp - Pass related classes ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/Pass.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +Pass::Pass(const llvm::Record *def) : def(def) {} + +StringRef Pass::getArgument() const { + return def->getValueAsString("argument"); +} + +StringRef Pass::getSummary() const { return def->getValueAsString("summary"); } + +StringRef Pass::getDescription() const { + return def->getValueAsString("description"); +} + +StringRef Pass::getConstructor() const { + return def->getValueAsString("constructor"); +} diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -23,6 +23,7 @@ DEPENDS MLIRStandardOpsIncGen + MLIRTransformsPassIncGen )
target_link_libraries(MLIRTransforms diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -262,5 +262,3 @@ } std::unique_ptr mlir::createCSEPass() { return std::make_unique(); } - -static PassRegistration pass("cse", "Eliminate common sub-expressions"); diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -40,6 +40,3 @@ std::unique_ptr mlir::createCanonicalizerPass() { return std::make_unique(); } - -static PassRegistration pass("canonicalize", - "Canonicalize operations"); diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -622,5 +622,3 @@ std::unique_ptr mlir::createInlinerPass() { return std::make_unique(); } - -static PassRegistration pass("inline", "Inline function calls"); diff --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp --- a/mlir/lib/Transforms/LocationSnapshot.cpp +++ b/mlir/lib/Transforms/LocationSnapshot.cpp @@ -157,6 +157,6 @@ StringRef tag) { return std::make_unique(flags, fileName, tag); } - -static PassRegistration - reg("snapshot-op-locations", "generate new locations from the current IR"); +std::unique_ptr mlir::createLocationSnapshotPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Transforms/LoopCoalescing.cpp b/mlir/lib/Transforms/LoopCoalescing.cpp --- a/mlir/lib/Transforms/LoopCoalescing.cpp +++ b/mlir/lib/Transforms/LoopCoalescing.cpp @@ -89,7 +89,3 @@ std::unique_ptr> mlir::createLoopCoalescingPass() { return std::make_unique(); } - -static PassRegistration - reg(PASS_NAME, - "coalesce nested loops with independent bounds into a single loop"); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- 
a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1973,6 +1973,3 @@ GreedyFusion(&g, localBufSizeThreshold, fastMemorySpace, maximalFusion) .run(); } - -static PassRegistration pass("affine-loop-fusion", - "Fuse loop nests"); diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -121,7 +121,3 @@ std::unique_ptr mlir::createLoopInvariantCodeMotionPass() { return std::make_unique(); } - -static PassRegistration - pass("loop-invariant-code-motion", - "Hoist loop invariant instructions outside of the loop"); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -222,6 +222,3 @@ defInst->erase(); } } - -static PassRegistration - pass("memref-dataflow-opt", "Perform store/load forwarding for memrefs"); diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp --- a/mlir/lib/Transforms/OpStats.cpp +++ b/mlir/lib/Transforms/OpStats.cpp @@ -84,6 +84,3 @@ std::unique_ptr> mlir::createPrintOpStatsPass() { return std::make_unique(); } - -static PassRegistration - pass("print-op-stats", "Print statistics of operations"); diff --git a/mlir/lib/Transforms/ParallelLoopCollapsing.cpp b/mlir/lib/Transforms/ParallelLoopCollapsing.cpp --- a/mlir/lib/Transforms/ParallelLoopCollapsing.cpp +++ b/mlir/lib/Transforms/ParallelLoopCollapsing.cpp @@ -15,8 +15,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#define PASS_NAME "parallel-loop-collapsing" -#define DEBUG_TYPE PASS_NAME +#define DEBUG_TYPE "parallel-loop-collapsing" using namespace mlir; @@ -64,6 +63,3 @@ std::unique_ptr mlir::createParallelLoopCollapsingPass() { return std::make_unique(); } - -static PassRegistration - reg(PASS_NAME, "collapse 
parallel loops to use less induction variables."); diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -366,8 +366,3 @@ return; } } - -static PassRegistration pass( - "affine-pipeline-data-transfer", - "Pipeline non-blocking data transfers between explicitly managed levels of " - "the memory hierarchy"); diff --git a/mlir/lib/Transforms/StripDebugInfo.cpp b/mlir/lib/Transforms/StripDebugInfo.cpp --- a/mlir/lib/Transforms/StripDebugInfo.cpp +++ b/mlir/lib/Transforms/StripDebugInfo.cpp @@ -29,6 +29,3 @@ std::unique_ptr mlir::createStripDebugInfoPass() { return std::make_unique(); } - -static PassRegistration - pass("strip-debuginfo", "Strip debug info from all operations"); diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp --- a/mlir/lib/Transforms/SymbolDCE.cpp +++ b/mlir/lib/Transforms/SymbolDCE.cpp @@ -156,5 +156,3 @@ std::unique_ptr mlir::createSymbolDCEPass() { return std::make_unique(); } - -static PassRegistration pass("symbol-dce", "Eliminate dead symbols"); diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -161,6 +161,3 @@ const Twine &title) { return std::make_unique(os, shortNames, title); } - -static PassRegistration pass("print-op-graph", - "Print op graph per region"); diff --git a/mlir/lib/Transforms/ViewRegionGraph.cpp b/mlir/lib/Transforms/ViewRegionGraph.cpp --- a/mlir/lib/Transforms/ViewRegionGraph.cpp +++ b/mlir/lib/Transforms/ViewRegionGraph.cpp @@ -80,6 +80,3 @@ const Twine &title) { return std::make_unique(os, shortNames, title); } - -static PassRegistration pass("print-cfg-graph", - "Print CFG graph per Function"); diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt --- 
a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -13,6 +13,7 @@
   OpDocGen.cpp
   OpFormatGen.cpp
   OpInterfacesGen.cpp
+  PassGen.cpp
   RewriterGen.cpp
   SPIRVUtilsGen.cpp
   StructsGen.cpp
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -0,0 +1,95 @@
+//===- PassGen.cpp - MLIR pass registration generator ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// PassGen uses the description of passes to generate the command line
+// registration code for those passes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Pass.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// GEN: Pass registration generation
+//===----------------------------------------------------------------------===//
+
+/// Return `str` with the characters that are significant inside a C++ string
+/// literal (backslash, double quote, newline, tab) escaped, so that the text
+/// can be placed between double quotes in the generated code.
+static std::string escapeCppString(StringRef str) {
+  std::string result;
+  result.reserve(str.size());
+  for (char c : str) {
+    switch (c) {
+    case '\\':
+      result += "\\\\";
+      break;
+    case '"':
+      result += "\\\"";
+      break;
+    case '\n':
+      result += "\\n";
+      break;
+    case '\t':
+      result += "\\t";
+      break;
+    default:
+      result += c;
+      break;
+    }
+  }
+  return result;
+}
+
+/// Emit the code for registering each of the given passes with the global
+/// PassRegistry. The emitted code is only active when the including file
+/// defines `GEN_PASS_REGISTRATION`, and it undefines that macro afterwards.
+static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
+  os << "#ifdef GEN_PASS_REGISTRATION\n";
+  for (const Pass &pass : passes) {
+    // The argument and summary end up inside C++ string literals and must be
+    // escaped; the constructor is a C++ expression and is emitted as-is.
+    os << llvm::formatv("::mlir::registerPass(\"{0}\", \"{1}\", []() -> "
+                        "std::unique_ptr<::mlir::Pass> {{ return {2}; });\n",
+                        escapeCppString(pass.getArgument()),
+                        escapeCppString(pass.getSummary()),
+                        pass.getConstructor());
+  }
+  os << "#undef GEN_PASS_REGISTRATION\n";
+  os << "#endif // GEN_PASS_REGISTRATION\n";
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Registration hooks
+//===----------------------------------------------------------------------===//
+
+/// Emit the generated code for every TableGen def deriving from the `Pass`
+/// class; currently this is only the registration section.
+static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
+  os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
+
+  std::vector<Pass> passes;
+  for (const llvm::Record *def : recordKeeper.getAllDerivedDefinitions("Pass"))
+    passes.emplace_back(def);
+  emitRegistration(passes, os);
+}
+
+/// Register the `-gen-pass-decls` generator with mlir-tblgen.
+static mlir::GenRegistration
+    genRegister("gen-pass-decls", "Generate pass declarations",
+                [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                  emitDecls(records, os);
+                  return false;
+                });