diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -217,7 +217,12 @@
   Operation *topOperation = getOperation();
   while (topOperation->getParentOp() != nullptr)
     topOperation = topOperation->getParentOp();
-  ModuleOp module = cast<ModuleOp>(topOperation);
+  ModuleOp module = dyn_cast<ModuleOp>(topOperation);
+  if (!module) {
+    emitError(topOperation->getLoc())
+        << "top-level op must be 'builtin.module'";
+    return signalPassFailure();
+  }
 
   SmallVector<Operation *, 8> workList;
   workList.push_back(getOperation());
diff --git a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
--- a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
+++ b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
@@ -22,20 +22,28 @@
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Support/FileUtilities.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Tools/ParseUtilties.h"
 #include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/ToolOutputFile.h"
 
 using namespace mlir;
 
-// Parse and verify the input MLIR file.
-static LogicalResult loadModule(MLIRContext &context,
-                                OwningOpRef<ModuleOp> &module,
-                                StringRef inputFilename) {
-  module = parseSourceFile<ModuleOp>(inputFilename, &context);
-  if (!module)
-    return failure();
+// Parse and verify the input MLIR file. Returns null on error.
+static OwningOpRef<Operation *> loadModule(MLIRContext &context,
+                                           StringRef inputFilename,
+                                           bool insertImplicitModule) {
+  // Set up the input file.
+  std::string errorMessage;
+  auto file = openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return nullptr;
+  }
 
-  return success();
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+  return parseSourceFileForTool(sourceMgr, &context, insertImplicitModule);
 }
 
 LogicalResult mlir::mlirReduceMain(int argc, char **argv,
@@ -55,6 +63,12 @@
       "o", llvm::cl::desc("Output filename for the reduced test case"),
       llvm::cl::init("-"), llvm::cl::cat(mlirReduceCategory));
 
+  static llvm::cl::opt<bool> noImplicitModule{
+      "no-implicit-module",
+      llvm::cl::desc(
+          "Disable implicit addition of a top-level module op during parsing"),
+      llvm::cl::init(false), llvm::cl::cat(mlirReduceCategory)};
+
   llvm::cl::HideUnrelatedOptions(mlirReduceCategory);
 
   llvm::InitLLVM y(argc, argv);
@@ -76,8 +90,9 @@
   if (!output)
     return failure();
 
-  OwningOpRef<ModuleOp> moduleRef;
-  if (failed(loadModule(context, moduleRef, inputFilename)))
+  OwningOpRef<Operation *> opRef =
+      loadModule(context, inputFilename, !noImplicitModule);
+  if (!opRef)
     return failure();
 
   auto errorHandler = [&](const Twine &msg) {
@@ -85,16 +100,16 @@
   };
 
   // Reduction pass pipeline.
-  PassManager pm(&context);
+  PassManager pm(&context, opRef.get()->getName().getStringRef());
   if (failed(parser.addToPipeline(pm, errorHandler)))
     return failure();
 
-  OwningOpRef<ModuleOp> m = moduleRef.get().clone();
+  OwningOpRef<Operation *> op = opRef.get()->clone();
 
-  if (failed(pm.run(m.get())))
+  if (failed(pm.run(op.get())))
     return failure();
 
-  m->print(output->os());
+  op.get()->print(output->os());
   output->keep();
 
   return success();
diff --git a/mlir/test/mlir-reduce/invalid.mlir b/mlir/test/mlir-reduce/invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-reduce/invalid.mlir
@@ -0,0 +1,8 @@
+// UNSUPPORTED: system-windows
+// RUN: not mlir-reduce -opt-reduction-pass --no-implicit-module %s |& FileCheck %s --check-prefix=CHECK-PASS
+// RUN: not mlir-reduce -reduction-tree --no-implicit-module %s |& FileCheck %s --check-prefix=CHECK-TREE
+
+// The reduction passes are currently restricted to 'builtin.module'.
+// CHECK-PASS: error: Can't add pass '{{.+}}' restricted to 'builtin.module' on a PassManager intended to run on 'func.func'
+// CHECK-TREE: error: top-level op must be 'builtin.module'
+func.func private @foo()