Index: mlir/include/mlir/IR/OperationSupport.h =================================================================== --- mlir/include/mlir/IR/OperationSupport.h +++ mlir/include/mlir/IR/OperationSupport.h @@ -726,6 +726,9 @@ /// Always print operations in the generic form. OpPrintingFlags &printGenericOpForm(); + /// Do not verify the operation when using custom operation printers. + OpPrintingFlags &assumeVerified(); + /// Use local scope when printing the operation. This allows for using the /// printer in a more localized and thread-safe setting, but may not /// necessarily be identical to what the IR will look like when dumping @@ -747,6 +750,9 @@ /// Return if operations should be printed in the generic form. bool shouldPrintGenericOpForm() const; + /// Return if operation verification should be skipped. + bool shouldAssumeVerified() const; + /// Return if the printer should use local scope when dumping the IR. bool shouldUseLocalScope() const; @@ -762,6 +768,9 @@ /// Print operations in the generic form. bool printGenericOpFormFlag : 1; + /// Skip operation verification. + bool assumeVerifiedFlag : 1; + /// Print operations with numberings local to the current operation. 
bool printLocalScope : 1; }; Index: mlir/include/mlir/IR/Value.h =================================================================== --- mlir/include/mlir/IR/Value.h +++ mlir/include/mlir/IR/Value.h @@ -24,6 +24,7 @@ class BlockArgument; class Operation; class OpOperand; +class OpPrintingFlags; class OpResult; class Region; class Value; @@ -215,6 +216,7 @@ // Utilities void print(raw_ostream &os); + void print(raw_ostream &os, const OpPrintingFlags &flags); void print(raw_ostream &os, AsmState &state); void dump(); Index: mlir/lib/IR/AsmPrinter.cpp =================================================================== --- mlir/lib/IR/AsmPrinter.cpp +++ mlir/lib/IR/AsmPrinter.cpp @@ -25,6 +25,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SubElementInterfaces.h" +#include "mlir/IR/Verifier.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" @@ -40,6 +41,7 @@ #include "llvm/Support/Endian.h" #include "llvm/Support/Regex.h" #include "llvm/Support/SaveAndRestore.h" +#include "llvm/Support/Threading.h" #include <tuple> @@ -141,6 +143,11 @@ "mlir-print-op-generic", llvm::cl::init(false), llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden}; + llvm::cl::opt<bool> assumeVerifiedOpt{ + "mlir-print-assume-verified", llvm::cl::init(false), + llvm::cl::desc("Skip op verification when using custom printers"), + llvm::cl::Hidden}; + llvm::cl::opt<bool> printLocalScopeOpt{ "mlir-print-local-scope", llvm::cl::init(false), llvm::cl::desc("Print with local scope and inline information (eliding " @@ -160,7 +167,8 @@ /// Initialize the printing flags with default supplied by the cl::opts above. OpPrintingFlags::OpPrintingFlags() : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), - printGenericOpFormFlag(false), printLocalScope(false) { + printGenericOpFormFlag(false), assumeVerifiedFlag(false), + printLocalScope(false) { // Initialize based upon command line options, if they are available. 
if (!clOptions.isConstructed()) return; @@ -169,6 +177,7 @@ printDebugInfoFlag = clOptions->printDebugInfoOpt; printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt; printGenericOpFormFlag = clOptions->printGenericOpFormOpt; + assumeVerifiedFlag = clOptions->assumeVerifiedOpt; printLocalScope = clOptions->printLocalScopeOpt; } @@ -196,6 +205,12 @@ return *this; } +/// Do not verify the operation when using custom operation printers. +OpPrintingFlags &OpPrintingFlags::assumeVerified() { + assumeVerifiedFlag = true; + return *this; +} + /// Use local scope when printing the operation. This allows for using the /// printer in a more localized and thread-safe setting, but may not necessarily /// be identical of what the IR will look like when dumping the full module. @@ -231,6 +246,11 @@ return printGenericOpFormFlag; } +/// Return if operation verification should be skipped. +bool OpPrintingFlags::shouldAssumeVerified() const { + return assumeVerifiedFlag; +} + /// Return if the printer should use local scope when dumping the IR. bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; } @@ -1245,9 +1265,31 @@ } // namespace detail } // namespace mlir +/// Verifies the operation and switches to generic op printing if verification +/// fails. We need to do this because custom print functions may fail for +/// invalid ops. +static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, + OpPrintingFlags printerFlags) { + if (printerFlags.shouldPrintGenericOpForm() || + printerFlags.shouldAssumeVerified()) + return printerFlags; + + // Ignore errors emitted by the verifier. We check the thread id to avoid + // consuming other threads' errors. 
+ auto parentThreadId = llvm::get_threadid(); + ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &) { + return success(parentThreadId == llvm::get_threadid()); + }); + if (failed(verify(op))) + printerFlags.printGenericOpForm(); + + return printerFlags; +} + AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, LocationMap *locationMap) - : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {} + : impl(std::make_unique<AsmStateImpl>( + op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {} AsmState::~AsmState() = default; const OpPrintingFlags &AsmState::getPrinterFlags() const { @@ -2853,14 +2895,15 @@ AsmPrinter::Impl(os).printIntegerSet(*this); } -void Value::print(raw_ostream &os) { +void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); } +void Value::print(raw_ostream &os, const OpPrintingFlags &flags) { if (!impl) { os << "<<NULL VALUE>>"; return; } if (auto *op = getDefiningOp()) - return op->print(os); + return op->print(os, flags); // TODO: Improve BlockArgument print'ing. 
BlockArgument arg = this->cast<BlockArgument>(); os << "<block argument> of type '" << arg.getType() Index: mlir/lib/IR/Diagnostics.cpp =================================================================== --- mlir/lib/IR/Diagnostics.cpp +++ mlir/lib/IR/Diagnostics.cpp @@ -128,8 +128,10 @@ Diagnostic &Diagnostic::appendOp(Operation &val, const OpPrintingFlags &flags) { std::string str; llvm::raw_string_ostream os(str); - val.print(os, - OpPrintingFlags(flags).useLocalScope().elideLargeElementsAttrs()); + val.print(os, OpPrintingFlags(flags) + .useLocalScope() + .elideLargeElementsAttrs() + .printGenericOpForm()); return *this << os.str(); } @@ -137,7 +139,10 @@ Diagnostic &Diagnostic::operator<<(Value val) { std::string str; llvm::raw_string_ostream os(str); - val.print(os); + val.print(os, OpPrintingFlags() + .useLocalScope() + .elideLargeElementsAttrs() + .printGenericOpForm()); return *this << os.str(); } Index: mlir/lib/IR/Operation.cpp =================================================================== --- mlir/lib/IR/Operation.cpp +++ mlir/lib/IR/Operation.cpp @@ -1097,6 +1097,8 @@ // Check that any value that is used by an operation is defined in the // same region as either an operation result. 
auto *operandRegion = operand.getParentRegion(); + if (!operandRegion) + return op.emitError("operation's operand is unlinked"); if (!region.isAncestor(operandRegion)) { return op.emitOpError("using value defined outside the region") .attachNote(isolatedOp->getLoc()) Index: mlir/test/Analysis/test-match-reduction.mlir =================================================================== --- mlir/test/Analysis/test-match-reduction.mlir +++ mlir/test/Analysis/test-match-reduction.mlir @@ -7,7 +7,7 @@ func @linalg_red_add(%in0t : tensor<?xf32>, %out0t : tensor<1xf32>) { // expected-remark@below {{Reduction found in output #0!}} // expected-remark@below {{Reduced Value: <block argument> of type 'f32' at index: 0}} - // expected-remark@below {{Combiner Op: %1 = arith.addf %arg2, %arg3 : f32}} + // expected-remark@below {{Combiner Op: %1 = "arith.addf"(%arg2, %arg3) : (f32, f32) -> f32}} %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (0)>], iterator_types = ["reduction"]} @@ -27,8 +27,8 @@ %cst = arith.constant 0.000000e+00 : f32 affine.for %i = 0 to 256 { // expected-remark@below {{Reduction found in output #0!}} - // expected-remark@below {{Reduced Value: %1 = affine.load %arg0[%arg2, %arg3] : memref<256x512xf32>}} - // expected-remark@below {{Combiner Op: %2 = arith.addf %arg4, %1 : f32}} + // expected-remark@below {{Reduced Value: %2 = "affine.load"}} + // expected-remark@below {{Combiner Op: %3 = "arith.addf"(%arg4, %2) : (f32, f32) -> f32}} %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { %ld = affine.load %in[%i, %j] : memref<256x512xf32> %add = arith.addf %red_iter, %ld : f32 @@ -63,8 +63,8 @@ // expected-remark@below {{Testing function}} func @linalg_fused_red_add(%in0t: tensor<4x4xf32>, %out0t: tensor<4xf32>) { // expected-remark@below {{Reduction found in output #0!}} - // expected-remark@below {{Reduced Value: %2 = arith.subf %1, %arg2 : f32}} - // expected-remark@below {{Combiner Op: %3 = arith.addf %2, %arg3 : f32}} + 
// expected-remark@below {{Reduced Value: %2 = "arith.subf"(%1, %arg2) : (f32, f32) -> f32}} + // expected-remark@below {{Combiner Op: %3 = "arith.addf"(%2, %arg3) : (f32, f32) -> f32}} %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} Index: mlir/test/IR/print-ir-invalid.mlir =================================================================== --- /dev/null +++ mlir/test/IR/print-ir-invalid.mlir @@ -0,0 +1,33 @@ +// # RUN: mlir-opt -test-print-invalid %s | FileCheck %s +// # RUN: mlir-opt -test-print-invalid %s --mlir-print-assume-verified | FileCheck %s --check-prefix=ASSUME-VERIFIED + +// The pass creates some ops and prints them to stdout, the input is just an +// empty module. +module {} + +// The operation is invalid because the body does not have a terminator, print +// the generic form. +// CHECK: Invalid operation: +// CHECK-NEXT: "builtin.func"() ({ +// CHECK-NEXT: ^bb0: +// CHECK-NEXT: }) +// CHECK-SAME: sym_name = "test" + +// The operation is valid because the body has a terminator, print the custom +// form. +// CHECK: Valid operation: +// CHECK-NEXT: func @test() { +// CHECK-NEXT: return +// CHECK-NEXT: } + +// With --mlir-print-assume-verified the custom form is printed in both cases. +// This works in this particular case, but may crash in general. 
+ +// ASSUME-VERIFIED: Invalid operation: +// ASSUME-VERIFIED-NEXT: func @test() { +// ASSUME-VERIFIED-NEXT: } + +// ASSUME-VERIFIED: Valid operation: +// ASSUME-VERIFIED-NEXT: func @test() { +// ASSUME-VERIFIED-NEXT: return +// ASSUME-VERIFIED-NEXT: } Index: mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir =================================================================== --- mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir +++ mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir @@ -5,6 +5,6 @@ func @remove_all_ops(%arg0: i32) -> i32 { // expected-error@below {{failed to legalize operation 'test.illegal_op_a' marked as erased}} %0 = "test.illegal_op_a"() : () -> i32 - // expected-note@below {{found live user of result #0: return %0 : i32}} + // expected-note@below {{found live user of result #0: "std.return"(%0) : (i32) -> ()}} return %0 : i32 } Index: mlir/test/lib/IR/CMakeLists.txt =================================================================== --- mlir/test/lib/IR/CMakeLists.txt +++ mlir/test/lib/IR/CMakeLists.txt @@ -9,6 +9,7 @@ TestOpaqueLoc.cpp TestOperationEquals.cpp TestPrintDefUse.cpp + TestPrintInvalid.cpp TestPrintNesting.cpp TestSideEffects.cpp TestSlicing.cpp Index: mlir/test/lib/IR/TestPrintInvalid.cpp =================================================================== --- /dev/null +++ mlir/test/lib/IR/TestPrintInvalid.cpp @@ -0,0 +1,52 @@ +//===- TestPrintInvalid.cpp - Test printing invalid ops -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass creates and prints to the standard output an invalid operation and +// a valid operation. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +namespace { +struct TestPrintInvalidPass + : public PassWrapper<TestPrintInvalidPass, OperationPass<ModuleOp>> { + StringRef getArgument() const final { return "test-print-invalid"; } + StringRef getDescription() const final { + return "Test printing invalid ops."; + } + void getDependentDialects(DialectRegistry &registry) const { + registry.insert<StandardOpsDialect>(); + } + + void runOnOperation() override { + Location loc = getOperation().getLoc(); + OpBuilder builder(getOperation().body()); + auto funcOp = builder.create<FuncOp>( + loc, "test", FunctionType::get(getOperation().getContext(), {}, {})); + funcOp.addEntryBlock(); + // The created function is invalid because there is no return op. + llvm::outs() << "Invalid operation:\n" << funcOp << "\n"; + builder.setInsertionPointToEnd(&funcOp.getBody().front()); + builder.create<ReturnOp>(loc); + // Now this function is valid. 
+ llvm::outs() << "Valid operation:\n" << funcOp << "\n"; + funcOp.erase(); + } +}; +} // namespace + +namespace mlir { +void registerTestPrintInvalidPass() { + PassRegistration<TestPrintInvalidPass>{}; +} +} // namespace mlir Index: mlir/test/mlir-tblgen/return-types.mlir =================================================================== --- mlir/test/mlir-tblgen/return-types.mlir +++ mlir/test/mlir-tblgen/return-types.mlir @@ -41,9 +41,9 @@ // CHECK-LABEL: testReifyFunctions func @testReifyFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) { - // expected-remark@+1 {{arith.constant 10}} + // expected-remark@+1 {{"arith.constant"() {value = 10 : index} }} %0 = "test.op_with_shaped_type_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xi17> - // expected-remark@+1 {{arith.constant 20}} + // expected-remark@+1 {{"arith.constant"() {value = 20 : index} }} %1 = "test.op_with_shaped_type_infer_type_if"(%arg1, %arg0) : (tensor<20xf32>, tensor<10xf32>) -> tensor<20xi17> return } Index: mlir/tools/mlir-opt/mlir-opt.cpp =================================================================== --- mlir/tools/mlir-opt/mlir-opt.cpp +++ mlir/tools/mlir-opt/mlir-opt.cpp @@ -45,6 +45,7 @@ void registerTestMatchers(); void registerTestOperationEqualPass(); void registerTestPrintDefUsePass(); +void registerTestPrintInvalidPass(); void registerTestPrintNestingPass(); void registerTestReducer(); void registerTestSpirvEntryPointABIPass(); @@ -133,6 +134,7 @@ registerTestMatchers(); registerTestOperationEqualPass(); registerTestPrintDefUsePass(); + registerTestPrintInvalidPass(); registerTestPrintNestingPass(); registerTestReducer(); registerTestSpirvEntryPointABIPass();