diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -255,6 +255,16 @@ linalg.yield %0: f32 } ``` + + A shortened print form is available for simple maps whose body consists of + one non-yield operation applied to the block arguments in order. + + The example above will be printed as: + ``` + %add = linalg.map { arith.addf } + ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>) + outs(%init: tensor<64xf32>) + ``` }]; let arguments = (ins @@ -329,10 +339,22 @@ outs(%init:tensor<16x64xf32>) dimensions = [1] (%in: f32, %out: f32) { - %0 = arith.addf %in, %out: f32 + %0 = arith.addf %out, %in: f32 linalg.yield %0: f32 } ``` + + A shortened print form is available for simple (non-variadic) reduces whose + body consists of one non-yield operation. It applies only if that operation + takes `%out` as its first operand, followed by `%in`. + + The example above will be printed as: + ``` + %reduce = linalg.reduce { arith.addf } + ins(%input:tensor<16x32x64xf32>) + outs(%init:tensor<16x64xf32>) + dimensions = [1] + ``` }]; let arguments = (ins diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1014,7 +1014,8 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, - ArrayRef operands) { + ArrayRef operands, + bool initFirst = false) { OpBuilder b(parser.getContext()); Region *body = result.addRegion(); Block &block = body->emplaceBlock(); @@ -1024,14 +1025,24 @@ block.addArgument(operand.getType().cast().getElementType(), b.getUnknownLoc()); } + SmallVector payloadOpOperands; + // If initFirst is set, the init (last block argument) becomes the first + // payload operand, followed by the remaining block arguments in order.
+ if (initFirst) { + payloadOpOperands.push_back(block.getArguments().back()); + for (const auto &arg : block.getArguments().drop_back()) + payloadOpOperands.push_back(arg); + } else { + payloadOpOperands = {block.getArguments().begin(), + block.getArguments().end()}; + } Operation *payloadOp = b.create( result.location, b.getStringAttr(payloadOpName.getStringRef()), - block.getArguments(), + payloadOpOperands, TypeRange{ result.operands.back().getType().cast().getElementType()}, payloadOpAttrs); - b.create(result.location, payloadOp->getResults()); } @@ -1070,7 +1081,9 @@ // Retrieve the operation from the body, if it is the only one (except // yield) and if it gets the same amount of arguments as the body does. -static Operation *findPayloadOp(Block *body) { +// If initFirst is set, the payload operands must instead be the init (last +// block argument) followed by the remaining block arguments in order. +static Operation *findPayloadOp(Block *body, bool initFirst = false) { if (body->getOperations().size() != 2) return nullptr; Operation &payload = body->getOperations().front(); @@ -1079,10 +1092,22 @@ if (payload.getNumOperands() == 0 || payload.getNumOperands() != body->getNumArguments()) return nullptr; - for (const auto &[bbArg, operand] : - llvm::zip(payload.getOperands(), body->getArguments())) { - if (bbArg != operand) + if (initFirst) { + // The init (last block argument) must be the first payload operand. + if (payload.getOperand(0) != body->getArguments().back()) return nullptr; + // The other block arguments must follow, in their original order. + for (unsigned i = 0, e = body->getNumArguments() - 1; i < e; ++i) { + if (payload.getOperand(i + 1) != body->getArgument(i)) { + return nullptr; + } + } + } else { + for (const auto &[bbArg, operand] : + llvm::zip(payload.getOperands(), body->getArguments())) { + if (bbArg != operand) + return nullptr; + } } return &payload; } @@ -1281,7 +1306,7 @@ if (payloadOpName.has_value()) { addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs, - makeArrayRef(result.operands)); + makeArrayRef(result.operands), /*initFirst=*/true); } else { SmallVector
regionArgs; if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, @@ -1304,7 +1329,7 @@ void ReduceOp::print(OpAsmPrinter &p) { Block *mapper = getBody(); - Operation *payloadOp = findPayloadOp(mapper); + Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true); if (payloadOp) { printShortForm(p, payloadOp); } diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir @@ -363,7 +363,7 @@ outs(%init:tensor<16x64xf32>) dimensions = [1] (%in: f32, %out: f32) { - %0 = arith.addf %in, %out: f32 + %0 = arith.addf %out, %in: f32 linalg.yield %0: f32 } func.return %reduce : tensor<16x64xf32> diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -414,7 +414,7 @@ outs(%init:tensor<16x64xf32>) dimensions = [1] (%in: f32, %out: f32) { - %0 = arith.addf %in, %out: f32 + %0 = arith.addf %out, %in: f32 linalg.yield %0: f32 } func.return %reduce : tensor<16x64xf32> @@ -433,7 +433,7 @@ outs(%init:memref<16x64xf32>) dimensions = [1] (%in: f32, %out: f32) { - %0 = arith.addf %in, %out: f32 + %0 = arith.addf %out, %in: f32 linalg.yield %0: f32 } func.return @@ -587,7 +587,7 @@ outs(%init:tensor<16x64xf32>) dimensions = [1] (%in: f32, %out: f32) { - %0 = arith.addf %in, %out fastmath : f32 + %0 = arith.addf %out, %in fastmath : f32 linalg.yield %0: f32 } func.return %reduce : tensor<16x64xf32>