diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -334,6 +334,13 @@
   /// operation.
   virtual void printNewline() = 0;
 
+  /// Increase the indentation used for subsequent lines started with
+  /// printNewline(). Must be balanced by a matching decreaseIndent().
+  virtual void increaseIndent() = 0;
+
+  /// Decrease the indentation previously increased with increaseIndent().
+  virtual void decreaseIndent() = 0;
+
   /// Print a block argument in the usual format of:
   ///   %ssaName : type {attr1=42} loc("here")
   /// where location printing is controlled by the standard internal option.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -173,6 +173,22 @@
     p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
 }
 
+/// Like printCommonStructuredOpParts, but prints each of the `ins` and `outs`
+/// clauses on its own line at the printer's current indentation. Empty
+/// clauses are omitted entirely, so no whitespace-only lines are emitted.
+static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p,
+                                                    ValueRange inputs,
+                                                    ValueRange outputs) {
+  if (!inputs.empty()) {
+    p.printNewline();
+    p << "ins(" << inputs << " : " << inputs.getTypes() << ")";
+  }
+  if (!outputs.empty()) {
+    p.printNewline();
+    p << "outs(" << outputs << " : " << outputs.getTypes() << ")";
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Specific parsing and printing for named structured ops created by ods-gen.
 //===----------------------------------------------------------------------===//
@@ -1335,16 +1351,20 @@
 }
 
 void MapOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
-                               SmallVector<Value>(getOutputOperands()));
+  p.increaseIndent();
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getInputOperands()),
+      SmallVector<Value>(getOutputOperands()));
   p.printOptionalAttrDict((*this)->getAttrs());
 
+  p.printNewline();
   p << "(";
   llvm::interleaveComma(getMapper().getArguments(), p,
                         [&](auto arg) { p.printRegionArgument(arg); });
   p << ") ";
 
   p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
+  p.decreaseIndent();
 }
 
 LogicalResult MapOp::verify() {
@@ -1481,21 +1501,26 @@
 
 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                    ArrayRef<int64_t> attributeValue) {
-  p << " " << attributeName << " = [" << attributeValue << "] ";
+  p << attributeName << " = [" << attributeValue << "]";
 }
 
 void ReduceOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
-                               SmallVector<Value>(getOutputOperands()));
+  p.increaseIndent();
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getInputOperands()),
+      SmallVector<Value>(getOutputOperands()));
+  p.printNewline();
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
 
+  p.printNewline();
   p << "(";
   llvm::interleaveComma(getCombiner().getArguments(), p,
                         [&](auto arg) { p.printRegionArgument(arg); });
   p << ") ";
 
   p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
+  p.decreaseIndent();
 }
 
 LogicalResult ReduceOp::verify() {
@@ -1657,10 +1682,14 @@
 }
 
 void TransposeOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
-                               SmallVector<Value>(getOutputOperands()));
+  p.increaseIndent();
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getInputOperands()),
+      SmallVector<Value>(getOutputOperands()));
+  p.printNewline();
   printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
   p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
+  p.decreaseIndent();
 }
 
 LogicalResult TransposeOp::verify() {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -716,6 +716,8 @@
   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
   void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
   void printNewline() override {}
+  void increaseIndent() override {}
+  void decreaseIndent() override {}
   void printOperand(Value) override {}
   void printOperand(Value, raw_ostream &os) override {
     // Users expect the output string to have at least the prefixed % to signal
@@ -2768,6 +2770,15 @@
     os.indent(currentIndent);
   }
 
+  /// Increase indentation.
+  void increaseIndent() override { currentIndent += indentWidth; }
+
+  /// Decrease indentation. Calls must be balanced with increaseIndent().
+  void decreaseIndent() override {
+    assert(currentIndent >= indentWidth && "unbalanced decreaseIndent()");
+    currentIndent -= indentWidth;
+  }
+
   /// Print a block argument in the usual format of:
   ///   %ssaName : type {attr1=42} loc("here")
   /// where location printing is controlled by the standard internal option.
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
@@ -341,7 +341,7 @@
 func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
                       %init: tensor<64xf32>) -> tensor<64xf32> {
    // CHECK:      linalg.map
-   // CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<64xf32
+   // CHECK-NEXT: ins(%[[LHS]], %[[RHS]] : memref<64xf32
    %add = linalg.map
           ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
           outs(%init:tensor<64xf32>)
@@ -359,7 +359,7 @@
 func.func @reduce(%input: tensor<16x32x64xf32>,
                   %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
   // CHECK:      linalg.reduce
-  // CHECK-SAME: ins(%[[INPUT]] : memref<16x32x64xf32
+  // CHECK-NEXT: ins(%[[INPUT]] : memref<16x32x64xf32
   %reduce = linalg.reduce
       ins(%input:tensor<16x32x64xf32>)
       outs(%init:tensor<16x64xf32>)
@@ -378,7 +378,7 @@
 func.func @transpose(%input: tensor<16x32x64xf32>,
                      %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
   // CHECK:      linalg.transpose
-  // CHECK-SAME: ins(%[[ARG0]] : memref<16x32x64xf32
+  // CHECK-NEXT: ins(%[[ARG0]] : memref<16x32x64xf32
   %transpose = linalg.transpose
       ins(%input:tensor<16x32x64xf32>)
       outs(%init:tensor<32x64x16xf32>)
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -338,7 +338,13 @@
   func.return %add : tensor<64xf32>
 }
 // CHECK-LABEL: func @map_binary
-// CHECK: linalg.map
+// CHECK:         linalg.map
+// CHECK-NEXT:    ins
+// CHECK-NEXT:    outs
+// CHECK-NEXT:    (%{{.*}}: f32, %{{.*}}: f32) {
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// CHECK-NEXT:    }
 
 // -----
 
@@ -401,7 +407,14 @@
   func.return %reduce : tensor<16x64xf32>
 }
 // CHECK-LABEL: func @reduce
-// CHECK: linalg.reduce
+// CHECK:         linalg.reduce
+// CHECK-NEXT:    ins
+// CHECK-NEXT:    outs
+// CHECK-NEXT:    dimensions = [1]
+// CHECK-NEXT:    (%{{.*}}: f32, %{{.*}}: f32) {
+// CHECK-NEXT:      arith.addf
+// CHECK-NEXT:      linalg.yield
+// CHECK-NEXT:    }
 
 // -----
 
@@ -469,6 +482,10 @@
   func.return %transpose : tensor<32x64x16xf32>
 }
 // CHECK-LABEL: func @transpose
+// CHECK:         linalg.transpose
+// CHECK-NEXT:    ins
+// CHECK-NEXT:    outs
+// CHECK-NEXT:    permutation
 
 // -----
 