diff --git a/mlir/include/mlir/Reducer/ReductionNode.h b/mlir/include/mlir/Reducer/ReductionNode.h --- a/mlir/include/mlir/Reducer/ReductionNode.h +++ b/mlir/include/mlir/Reducer/ReductionNode.h @@ -20,6 +20,7 @@ #include #include +#include "mlir/IR/OwningOpRef.h" #include "mlir/Reducer/Tester.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/ArrayRef.h" @@ -57,7 +58,7 @@ /// will have been applied certain reduction strategies. Note that it's not /// necessary to be an interesting case or a reduced module (has smaller size /// than parent's). - ModuleOp getModule() const { return module; } + ModuleOp getModule() const { return module.get(); } /// Return the region we're reducing. Region &getRegion() const { return *region; } @@ -141,7 +142,7 @@ /// This is a copy of module from parent node. All the reducer patterns will /// be applied to this instance. - ModuleOp module; + OwningOpRef module; /// The region of certain operation we're reducing in the module Region *region; diff --git a/mlir/lib/Reducer/ReductionNode.cpp b/mlir/lib/Reducer/ReductionNode.cpp --- a/mlir/lib/Reducer/ReductionNode.cpp +++ b/mlir/lib/Reducer/ReductionNode.cpp @@ -112,6 +112,9 @@ // This module may has been updated. Reset the range. ranges.clear(); ranges.push_back({0, std::distance(region->op_begin(), region->op_end())}); + } else { + // Release the uninteresting module to save some memory. + module.release()->erase(); } } diff --git a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp --- a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp +++ b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp @@ -28,7 +28,8 @@ using namespace mlir; // Parse and verify the input MLIR file. -static LogicalResult loadModule(MLIRContext &context, OwningModuleRef &module, +static LogicalResult loadModule(MLIRContext &context, + OwningOpRef &module, StringRef inputFilename) { module = parseSourceFile(inputFilename, &context); if (!module) @@ -62,7 +63,7 @@ if (!output) return failure(); - mlir::OwningModuleRef moduleRef; + OwningOpRef moduleRef; if (failed(loadModule(context, moduleRef, inputFilename))) return failure(); @@ -75,12 +76,12 @@ if (failed(parser.addToPipeline(pm, errorHandler))) return failure(); - ModuleOp m = moduleRef.get().clone(); + OwningOpRef m = moduleRef.get().clone(); - if (failed(pm.run(m))) + if (failed(pm.run(m.get()))) return failure(); - m.print(output->os()); + m->print(output->os()); output->keep(); return success();