Turn the `test-lower-to-llvm` test pass into a registered pass pipeline.

* The `TestLowerToLLVM` pass is removed. Its `runOnOperation` built a local
  PassManager, ran it on the module, and on failure dumped the operation and
  called `signalPassFailure()`. It is replaced by
  `buildTestLowerToLLVM(OpPassManager &pm, const TestLowerToLLVMOptions &options)`,
  which only appends the same lowering passes to the caller-provided `pm`.
* The `reassociate-fp-reductions` option moves from the pass to the new
  `TestLowerToLLVMOptions` (a `PassPipelineOptions`). The builder forwards it as
  `LowerVectorToLLVMOptions().enableReassociateFPReductions(options.reassociateFPReductions)`.
* `registerTestLowerToLLVM()` now uses `PassPipelineRegistration` with
  `buildTestLowerToLLVM`, under the same "test-lower-to-llvm" argument the old
  pass returned from `getArgument()`.
* `mlir/Pass/PassOptions.h` is now included for `PassOptions::Option`.

NOTE(review): This copy of the patch has lost its line breaks and indentation.
Its template argument lists also appear to have been stripped (e.g.
`PassWrapper>`, `registry.insert();`, `addNestedPass(`,
`PassPipelineRegistration(`). It will not apply as-is, so regenerate it from
the source tree instead of hand-editing.
NOTE(review): The option help text "Allow reassociation og FP reductions" is an
unchanged context line, and it contains a typo: "og" should be "of".

diff --git a/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp b/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp --- a/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp +++ b/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp @@ -27,43 +27,26 @@ #include "mlir/IR/DialectRegistry.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassOptions.h" #include "mlir/Transforms/Passes.h" using namespace mlir; namespace { -struct TestLowerToLLVM - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLowerToLLVM) - - TestLowerToLLVM() = default; - TestLowerToLLVM(const TestLowerToLLVM &pass) : PassWrapper(pass) {} - - StringRef getArgument() const final { return "test-lower-to-llvm"; } - StringRef getDescription() const final { - return "Test lowering to LLVM as a generally usable sink pass"; - } - void getDependentDialects(DialectRegistry &registry) const override { - registry.insert(); - } - - Option reassociateFPReductions{ +struct TestLowerToLLVMOptions + : public PassPipelineOptions { + PassOptions::Option reassociateFPReductions{ *this, "reassociate-fp-reductions", llvm::cl::desc("Allow reassociation og FP reductions"), llvm::cl::init(false)}; - - void runOnOperation() final; }; -} // namespace -void TestLowerToLLVM::runOnOperation() { - MLIRContext *context = &this->getContext(); - RewritePatternSet patterns(context); +void buildTestLowerToLLVM(OpPassManager &pm, + const TestLowerToLLVMOptions &options) { // TODO: it is feasible to scope lowering at arbitrary level and introduce // unrealized casts, but there needs to be the final module-wise cleanup in // the end. Keep module-level for now. - PassManager pm(&getContext()); // Blanket-convert any remaining high-level vector ops to loops if any remain. pm.addNestedPass(createConvertVectorToSCFPass()); @@ -82,7 +65,7 @@ pm.addPass(createConvertVectorToLLVMPass( // TODO: add more options on a per-need basis.
LowerVectorToLLVMOptions().enableReassociateFPReductions( - reassociateFPReductions))); + options.reassociateFPReductions))); // Convert Math to LLVM (always needed). pm.addNestedPass(createConvertMathToLLVMPass()); // Convert MemRef to LLVM (always needed). @@ -93,14 +76,17 @@ pm.addPass(createConvertIndexToLLVMPass()); // Convert remaining unrealized_casts (always needed). pm.addPass(createReconcileUnrealizedCastsPass()); - if (failed(pm.run(getOperation()))) { - getOperation()->dump(); - return signalPassFailure(); - } } +} // namespace namespace mlir { namespace test { -void registerTestLowerToLLVM() { PassRegistration(); } +void registerTestLowerToLLVM() { + PassPipelineRegistration( + "test-lower-to-llvm", + "An example of pipeline to lower the main dialects (arith, linalg, " + "memref, scf, vector) down to LLVM.", + buildTestLowerToLLVM); +} } // namespace test } // namespace mlir