diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -15,7 +15,6 @@ class AffineExpr; class BlockAndValueMapping; -class ModuleOp; class UnknownLoc; class FileLineColLoc; class Type; @@ -47,7 +46,7 @@ class Builder { public: explicit Builder(MLIRContext *context) : context(context) {} - explicit Builder(ModuleOp module); + explicit Builder(Operation *op) : Builder(op->getContext()) {} MLIRContext *getContext() const { return context; } diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -197,12 +197,6 @@ /// Return the name of this module if present. Optional getName() { return sym_name(); } - /// Print this module in the custom top-level form. - void print(raw_ostream &os, OpPrintingFlags flags = llvm::None); - void print(raw_ostream &os, AsmState &state, - OpPrintingFlags flags = llvm::None); - void dump(); - //===------------------------------------------------------------------===// // SymbolOpInterface Methods //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -25,7 +25,7 @@ class MLIRContextImpl; class StorageUniquer; -/// MLIRContext is the top-level object for a collection of MLIR modules. It +/// MLIRContext is the top-level object for a collection of MLIR operations. It /// holds immortal uniqued objects like types, and the tables used to unique /// them. /// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h @@ -19,6 +19,8 @@ #include "llvm/ADT/MapVector.h" namespace mlir { +class ModuleOp; + namespace pdl_to_pdl_interp { class MatcherNode; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -15,7 +15,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IntegerSet.h" @@ -2189,8 +2188,9 @@ AsmStateImpl &state) : ModulePrinter(os, flags, &state) {} - /// Print the given top-level module. - void print(ModuleOp op); + /// Print the given top-level operation. + void printTopLevelOperation(Operation *op); + /// Print the given operation with its indent and location. void print(Operation *op); /// Print the bare location, not including indentation/location/etc. @@ -2289,12 +2289,12 @@ }; } // end anonymous namespace -void OperationPrinter::print(ModuleOp op) { +void OperationPrinter::printTopLevelOperation(Operation *op) { // Output the aliases at the top level that can't be deferred. state->getAliasState().printNonDeferredAliases(os, newLine); // Print the module. - print(op.getOperation()); + print(op); os << newLine; // Output the aliases at the top level that can be deferred. @@ -2588,6 +2588,14 @@ } void Operation::print(raw_ostream &os, OpPrintingFlags flags) { + // If this is a top level operation, we also print aliases. + if (!getParent() && !flags.shouldUseLocalScope()) { + AsmState state(this); + state.getImpl().initializeAliases(this, flags); + print(os, state, flags); + return; + } + // Find the operation to number from based upon the provided flags. Operation *printedOp = this; bool shouldUseLocalScope = flags.shouldUseLocalScope(); @@ -2608,7 +2616,11 @@ print(os, state, flags); } void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) { - OperationPrinter(os, flags, state.getImpl()).print(this); + OperationPrinter printer(os, flags, state.getImpl()); + if (!getParent() && !flags.shouldUseLocalScope()) + printer.printTopLevelOperation(this); + else + printer.print(this); } void Operation::dump() { @@ -2649,17 +2661,3 @@ OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl()); printer.printBlockName(this); } - -void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) { - AsmState state(*this); - - // Don't populate aliases when printing at local scope. - if (!flags.shouldUseLocalScope()) - state.getImpl().initializeAliases(*this, flags); - print(os, state, flags); -} -void ModuleOp::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) { - OperationPrinter(os, flags, state.getImpl()).print(*this); -} - -void ModuleOp::dump() { print(llvm::errs()); } diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -9,7 +9,6 @@ #include "mlir/IR/Attributes.h" #include "AttributeDetail.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/IntegerSet.h" diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -10,15 +10,13 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/StandardTypes.h" -#include "llvm/Support/raw_ostream.h" -using namespace mlir; +#include "mlir/IR/SymbolTable.h" -Builder::Builder(ModuleOp module) : context(module.getContext()) {} +using namespace mlir; Identifier Builder::getIdentifier(StringRef str) { return Identifier::get(str, context); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -17,7 +17,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinDialect.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Identifier.h" diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -108,7 +108,7 @@ // CHECK-DAG: [[ri4:%[0-9]+]] = muli {{.*}}, {{.*}} : i32 // CHECK: {{.*}} = subi [[ri3]], [[ri4]] : i32 // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -130,7 +130,7 @@ // CHECK-DAG: [[r1:%[0-9]+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%{{.*}}, %{{.*}}] // CHECK-NEXT: affine.for %{{.*}} = affine_map<(d0) -> (d0)>([[r0]]) to affine_map<(d0) -> (d0)>([[r1]]) step 2 { // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -152,7 +152,7 @@ // CHECK-DAG: [[r1:%[0-9]+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%{{.*}}, %{{.*}}] // CHECK-NEXT: scf.for %{{.*}} = [[r0]] to [[r1]] step {{.*}} { // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -173,7 +173,7 @@ // CHECK: affine.for %{{.*}} = max affine_map<(d0, d1) -> (d0, d1)>(%{{.*}}, %{{.*}}) to min affine_map<(d0, d1) -> (d0, d1)>(%{{.*}}, %{{.*}}) { // CHECK: return // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -205,7 +205,7 @@ // CHECK-NEXT: affine.yield %[[iarg_2]], %[[sum]] : i32, i32 // CHECK-NEXT: } // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -232,7 +232,7 @@ // CHECK-NEXT: return // CHECK-NEXT: } // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -271,7 +271,7 @@ // CHECK-NEXT: br ^bb1(%{{.*}}, %{{.*}} : i32, i32) // CHECK-NEXT: } // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -305,7 +305,7 @@ // CHECK-NEXT: ^bb2(%{{.*}}: i64, %{{.*}}: i32): // pred: ^bb0 // CHECK-NEXT: return // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -363,7 +363,7 @@ // CHECK-DAG: [[e:%.*]] = addf [[d]], [[c]] : f32 // CHECK-NEXT: affine.store [[e]], %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : memref // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -384,7 +384,7 @@ // CHECK: ^bb1: // no predecessors // CHECK: {{.*}} = constant 1 : i32 // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -409,7 +409,7 @@ // CHECK: %[[SRC2:.*]] = affine.load // CHECK: sexti %[[SRC2]] : i1 to i8 // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -430,7 +430,7 @@ // CHECK: [[ARG0:%.*]]: i1, [[ARG1:%.*]]: i1 // CHECK: or [[ARG0]], [[ARG1]] // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -454,7 +454,7 @@ // CHECK: [[TRUE:%.*]] = constant true // CHECK: subi [[TRUE]], [[AND]] : i1 // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -473,7 +473,7 @@ // CHECK-DAG: {{.*}} = constant 2 // CHECK-NEXT: {{.*}} = divi_signed // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -492,7 +492,7 @@ // CHECK-DAG: {{.*}} = constant 2 // CHECK-NEXT: {{.*}} = divi_unsigned // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -511,7 +511,7 @@ // CHECK: {{.*}} = constant 1.0 // CHECK-NEXT: {{.*}} = fpext // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -530,7 +530,7 @@ // CHECK: {{.*}} = constant 1.0 // CHECK-NEXT: {{.*}} = fptrunc // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -606,7 +606,7 @@ // CHECK-DAG: {{.*}} = affine.load // CHECK-NEXT: {{.*}} = select // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -712,7 +712,7 @@ // CHECK-DAG: affine.apply // CHECK-NEXT: select // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -787,7 +787,7 @@ // CHECK-NEXT: {{.*}}= addf {{.*}}, {{.*}} : f32 // CHECK-NEXT: affine.store {{.*}}, {{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : memref // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -821,7 +821,7 @@ // CHECK: load %{{.*}}{{\[}}[[B]]{{\]}} // CHECK: store %{{.*}}, %{{.*}}{{\[}}[[D]]{{\]}} // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -850,7 +850,7 @@ // CHECK: [[A:%.*]] = affine.load %{{.*}}[] // CHECK: affine.store [[A]], %{{.*}}[] // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -886,7 +886,7 @@ intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/false); intrinsics::affine_if(intSet, affineIfArgs, /*withElseRegion=*/true); - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -932,7 +932,7 @@ linalg_generic_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j})); linalg_generic_pointwise_tanh(SA({i, j}), SC({i, j})); - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -963,7 +963,7 @@ ScopedContext scope(builder, f.getLoc()); linalg_generic_matmul(f.getArguments()); - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -999,7 +999,7 @@ linalg_generic_conv_nhwc(f.getArguments(), /*strides=*/{3, 4}, /*dilations=*/{5, 6}); - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -1036,7 +1036,7 @@ /*depth_multiplier=*/7, /*strides=*/{3, 4}, /*dilations=*/{5, 6}); - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -1062,7 +1062,7 @@ auto reshaped = linalg_reshape(v, maps); linalg_reshape(memrefType, reshaped, maps); - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -1134,7 +1134,7 @@ Value o1 = linalg_generic_matmul(A, B, tanhed, tensorType)->getResult(0); linalg_generic_matmul(A, B, o1, tensorType); - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -1156,7 +1156,7 @@ // CHECK-DAG: {{.*}} = constant 0 // CHECK-NEXT: {{.*}} = vector.extractelement // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -1202,7 +1202,7 @@ }; linalg_generic_matmul(A, B, C, contractionBuilder); - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); } @@ -1238,7 +1238,7 @@ // CHECK: addf [[res]]#0, [[res]]#1 : f32 // clang-format on - f.print(llvm::outs()); + f.print(llvm::outs(), OpPrintingFlags().useLocalScope()); f.erase(); }