diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp --- a/mlir/lib/Reducer/ReductionTreePass.cpp +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -217,7 +217,11 @@ Operation *topOperation = getOperation(); while (topOperation->getParentOp() != nullptr) topOperation = topOperation->getParentOp(); - ModuleOp module = cast(topOperation); + ModuleOp module = dyn_cast(topOperation); + if (!module) + return emitError(getOperation()->getLoc()) + << "top-level op must be 'builtin.module'", + signalPassFailure(); SmallVector workList; workList.push_back(getOperation()); diff --git a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp --- a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp +++ b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp @@ -23,16 +23,27 @@ #include "mlir/Support/FileUtilities.h" #include "mlir/Support/LogicalResult.h" #include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" using namespace mlir; // Parse and verify the input MLIR file. static LogicalResult loadModule(MLIRContext &context, - OwningOpRef &module, + OwningOpRef &op, StringRef inputFilename) { - module = parseSourceFile(inputFilename, &context); - if (!module) + // Set up the input file. + std::string errorMessage; + auto file = openInputFile(inputFilename, &errorMessage); + if (!file) { + llvm::errs() << errorMessage << "\n"; + return failure(); + } + + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); + op = parseSourceFileForTool(sourceMgr, &context); + if (!op) return failure(); return success(); @@ -62,6 +73,7 @@ registerReducerPasses(); PassPipelineCLParser parser("", "Reduction Passes to Run"); + registerToolParserCLOptions(); llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR test case reduction tool.\n"); @@ -76,8 +88,8 @@ if (!output) return failure(); - OwningOpRef moduleRef; - if (failed(loadModule(context, moduleRef, inputFilename))) + OwningOpRef opRef; + if (failed(loadModule(context, opRef, inputFilename))) return failure(); auto errorHandler = [&](const Twine &msg) { @@ -85,16 +97,16 @@ }; // Reduction pass pipeline. - PassManager pm(&context); + PassManager pm(&context, opRef.get()->getName().getStringRef()); if (failed(parser.addToPipeline(pm, errorHandler))) return failure(); - OwningOpRef m = moduleRef.get().clone(); + OwningOpRef op = opRef.get()->clone(); - if (failed(pm.run(m.get()))) + if (failed(pm.run(op.get()))) return failure(); - m->print(output->os()); + op.get()->print(output->os()); output->keep(); return success(); diff --git a/mlir/test/mlir-reduce/invalid.mlir b/mlir/test/mlir-reduce/invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-reduce/invalid.mlir @@ -0,0 +1,8 @@ +// UNSUPPORTED: system-windows +// RUN: not mlir-reduce -opt-reduction-pass --no-implicit-module %s |& FileCheck %s --check-prefix=CHECK-PASS +// RUN: not mlir-reduce -reduction-tree --no-implicit-module %s |& FileCheck %s --check-prefix=CHECK-TREE + +// The reduction passes are currently restricted to 'builtin.module'. +// CHECK-PASS: error: Can't add pass '{{.+}}' restricted to 'builtin.module' on a PassManager intended to run on 'func.func' +// CHECK-TREE: error: top-level op must be 'builtin.module' +func.func private @foo()