diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -174,16 +174,6 @@ p << " outs(" << outputs << " : " << outputs.getTypes() << ")"; } -static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p, - ValueRange inputs, - ValueRange outputs) { - if (!inputs.empty()) { - p << " ins(" << inputs << " : " << inputs.getTypes() << ")"; - } - if (!outputs.empty()) { - p << " outs(" << outputs << " : " << outputs.getTypes() << ")"; - } -} //===----------------------------------------------------------------------===// // Specific parsing and printing for named structured ops created by ods-gen. //===----------------------------------------------------------------------===// @@ -1023,38 +1013,121 @@ inputs, /*outputs=*/{}, bodyBuild); } -ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { - if (parseDstStyleOp(parser, result)) - return failure(); +static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, + const OperationName &payloadOpName, + const NamedAttrList &payloadOpAttrs, + ArrayRef operands) { + OpBuilder b(parser.getContext()); + Region *body = result.addRegion(); + Block &block = body->emplaceBlock(); + b.setInsertionPointToStart(&block); + SmallVector bbArgs; + for (auto &operand : operands) { + block.addArgument(operand.getType().cast().getElementType(), + b.getUnknownLoc()); + } - SmallVector regionArgs; - if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, - /*allowType=*/true, /*allowAttrs=*/true)) { - return failure(); + Operation *payloadOp = b.create( + result.location, b.getStringAttr(payloadOpName.getStringRef()), + block.getArguments(), + TypeRange{ + result.operands.back().getType().cast().getElementType()}, + payloadOpAttrs); + + b.create(result.location, payloadOp->getResults()); +} + +ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { + std::optional payloadOpName; + NamedAttrList payloadOpAttrs; + if (succeeded(parser.parseOptionalLBrace())) { + FailureOr operationName = parser.parseCustomOperationName(); + if (failed(operationName)) + return failure(); + if (parser.parseOptionalAttrDict(payloadOpAttrs)) + return failure(); + payloadOpName = operationName.value(); + if (parser.parseRBrace()) + return failure(); } - Region *body = result.addRegion(); - if (parser.parseRegion(*body, regionArgs)) + if (parseDstStyleOp(parser, result)) return failure(); + if (payloadOpName.has_value()) { + addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs, + makeArrayRef(result.operands).drop_back()); + } else { + SmallVector regionArgs; + if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, + /*allowType=*/true, /*allowAttrs=*/true)) { + return failure(); + } + Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionArgs)) + return failure(); + } return success(); } +// Retrieve the operation from the body, if it is the only one (except +// yield) and if it gets the same amount of arguments as the body does. +static Operation *findPayloadOp(Block *body) { + if (body->getOperations().size() != 2) + return nullptr; + Operation &payload = body->getOperations().front(); + assert(isa(body->getOperations().back())); + + if (payload.getNumOperands() == 0 || + payload.getNumOperands() != body->getNumArguments()) + return nullptr; + for (const auto &[bbArg, operand] : + llvm::zip(payload.getOperands(), body->getArguments())) { + if (bbArg != operand) + return nullptr; + } + return &payload; +} + +void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { + SmallVector elidedAttrs; + std::string attrToElide; + p << " { " << payloadOp->getName().getStringRef(); + for (const auto &attr : payloadOp->getAttrs()) { + auto fastAttr = attr.getValue().dyn_cast(); + if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) { + attrToElide = attr.getName().str(); + elidedAttrs.push_back(attrToElide); + break; + } + } + p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs); + p << " }"; +} + void MapOp::print(OpAsmPrinter &p) { - printCommonStructuredOpPartsWithNewLine( - p, SmallVector(getDpsInputOperands()), - SmallVector(getDpsInitOperands())); - p.printOptionalAttrDict((*this)->getAttrs()); + Block *mapper = getBody(); + Operation *payloadOp = findPayloadOp(mapper); + if (payloadOp) { + printShortForm(p, payloadOp); + } - p.increaseIndent(); - p.printNewline(); - p << "("; - llvm::interleaveComma(getMapper().getArguments(), p, - [&](auto arg) { p.printRegionArgument(arg); }); - p << ") "; + printCommonStructuredOpParts(p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); + p.printOptionalAttrDict((*this)->getAttrs()); - p.printRegion(getMapper(), /*printEntryBlockArgs=*/false); - p.decreaseIndent(); + if (!payloadOp) { + // Print region if the payload op was not detected. + p.increaseIndent(); + p.printNewline(); + p << "("; + llvm::interleaveComma(mapper->getArguments(), p, + [&](auto arg) { p.printRegionArgument(arg); }); + p << ") "; + + p.printRegion(getMapper(), /*printEntryBlockArgs=*/false); + p.decreaseIndent(); + } } LogicalResult MapOp::verify() { @@ -1067,7 +1140,7 @@ "mapper, but got: " << getInputs().size() << " and " << blockArgs.size(); - // The parameters of mapper should all match the element type // of inputs. + // The parameters of mapper should all match the element type of inputs. for (const auto &[bbArgType, inputArg] : llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) { auto inputElemType = inputArg.getType().cast().getElementType(); @@ -1189,21 +1262,39 @@ } ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { + std::optional payloadOpName; + NamedAttrList payloadOpAttrs; + if (succeeded(parser.parseOptionalLBrace())) { + FailureOr operationName = parser.parseCustomOperationName(); + if (failed(operationName)) + return failure(); + if (parser.parseOptionalAttrDict(payloadOpAttrs)) + return failure(); + payloadOpName = operationName.value(); + if (parser.parseRBrace()) + return failure(); + } + if (parseDstStyleOp( parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { return parseDenseI64ArrayAttr(parser, attributes, "dimensions"); })) return failure(); - SmallVector regionArgs; - if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, - /*allowType=*/true, /*allowAttrs=*/true)) { - return failure(); - } + if (payloadOpName.has_value()) { + addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs, + makeArrayRef(result.operands)); + } else { + SmallVector regionArgs; + if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, + /*allowType=*/true, /*allowAttrs=*/true)) { + return failure(); + } - Region *body = result.addRegion(); - if (parser.parseRegion(*body, regionArgs)) - return failure(); + Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionArgs)) + return failure(); + } return success(); } @@ -1214,22 +1305,28 @@ } void ReduceOp::print(OpAsmPrinter &p) { - printCommonStructuredOpPartsWithNewLine( - p, SmallVector(getDpsInputOperands()), - SmallVector(getDpsInitOperands())); + Block *mapper = getBody(); + Operation *payloadOp = findPayloadOp(mapper); + if (payloadOp) { + printShortForm(p, payloadOp); + } + printCommonStructuredOpParts(p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); - - p.increaseIndent(); - p.printNewline(); - p << "("; - llvm::interleaveComma(getCombiner().getArguments(), p, - [&](auto arg) { p.printRegionArgument(arg); }); - p << ") "; - - p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false); - p.decreaseIndent(); + if (!payloadOp) { + // Print region if the payload op was not detected. + p.increaseIndent(); + p.printNewline(); + p << "("; + llvm::interleaveComma(mapper->getArguments(), p, + [&](auto arg) { p.printRegionArgument(arg); }); + p << ") "; + + p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false); + p.decreaseIndent(); + } } LogicalResult ReduceOp::verify() { @@ -1378,9 +1475,8 @@ } void TransposeOp::print(OpAsmPrinter &p) { - printCommonStructuredOpPartsWithNewLine( - p, SmallVector(getDpsInputOperands()), - SmallVector(getDpsInitOperands())); + printCommonStructuredOpParts(p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation()); p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); } @@ -1493,9 +1589,8 @@ } void BroadcastOp::print(OpAsmPrinter &p) { - printCommonStructuredOpPartsWithNewLine( - p, SmallVector(getDpsInputOperands()), - SmallVector(getDpsInitOperands())); + printCommonStructuredOpParts(p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); } diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir @@ -340,7 +340,7 @@ // CHECK-SAME: %[[RHS:[0-9a-zA-Z]*]]: memref<64xf32 func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64xf32> { - // CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : memref<64xf32 + // CHECK: linalg.map { arith.addf } ins(%[[LHS]], %[[RHS]] : memref<64xf32 %add = linalg.map ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>) outs(%init:tensor<64xf32>) @@ -357,7 +357,7 @@ // CHECK-SAME: %[[INPUT:.*]]: memref<16x32x64xf32 func.func @reduce(%input: tensor<16x32x64xf32>, %init: tensor<16x64xf32>) -> tensor<16x64xf32> { - // CHECK: linalg.reduce ins(%[[INPUT]] : memref<16x32x64xf32 + // CHECK: linalg.reduce { arith.addf } ins(%[[INPUT]] : memref<16x32x64xf32 %reduce = linalg.reduce ins(%input:tensor<16x32x64xf32>) outs(%init:tensor<16x64xf32>) diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -356,12 +356,8 @@ func.return %add : tensor<64xf32> } // CHECK-LABEL: func @map_binary -// CHECK: linalg.map ins +// CHECK: linalg.map { arith.addf } ins // CHECK-SAME: outs -// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) { -// CHECK-NEXT: arith.addf -// CHECK-NEXT: linalg.yield -// CHECK-NEXT: } // ----- @@ -424,13 +420,9 @@ func.return %reduce : tensor<16x64xf32> } // CHECK-LABEL: func @reduce -// CHECK: linalg.reduce ins +// CHECK: linalg.reduce { arith.addf } ins // CHECK-SAME: outs // CHECK-SAME: dimensions = [1] -// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) { -// CHECK-NEXT: arith.addf -// CHECK-NEXT: linalg.yield -// CHECK-NEXT: } // ----- @@ -446,8 +438,10 @@ } func.return } -// CHECK-LABEL: func @reduce_memref -// CHECK: linalg.reduce +// CHECK-LABEL: func @reduce +// CHECK: linalg.reduce { arith.addf } ins +// CHECK-SAME: outs +// CHECK-SAME: dimensions = [1] // ----- @@ -467,6 +461,7 @@ } // CHECK-LABEL: func @variadic_reduce // CHECK: linalg.reduce +// CHECK-NOT: { arith.addf // ----- @@ -484,8 +479,9 @@ } func.return } -// CHECK-LABEL: func @variadic_reduce_memref +// CHECK-LABEL: func @variadic_reduce_memref // CHECK: linalg.reduce +// CHECK-NOT: { arith.addf // ----- @@ -560,3 +556,46 @@ // CHECK: linalg.broadcast ins // CHECK-SAME: outs // CHECK-SAME: dimensions + +// ----- + +func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, + %init: tensor<64xf32>) -> tensor<64xf32> { + %add = linalg.map + ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>) + outs(%init:tensor<64xf32>) + (%lhs_elem: f32, %rhs_elem: f32) { + %0 = arith.addf %lhs_elem, %rhs_elem fastmath : f32 + linalg.yield %0: f32 + } + func.return %add : tensor<64xf32> +} + +// CHECK-LABEL: func @map_arith_with_attr +// CHECK-NEXT: %[[MAPPED:.*]] = linalg.map +// CHECK-SAME: { arith.addf {fastmath = #arith.fastmath} } +// CHECK-SAME: ins +// CHECK-SAME: outs +// CHECK-NEXT: return %[[MAPPED]] : tensor<64xf32> + +// ----- + +func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>, + %init: tensor<16x64xf32>) -> tensor<16x64xf32> { + %reduce = linalg.reduce + ins(%input:tensor<16x32x64xf32>) + outs(%init:tensor<16x64xf32>) + dimensions = [1] + (%in: f32, %out: f32) { + %0 = arith.addf %in, %out fastmath : f32 + linalg.yield %0: f32 + } + func.return %reduce : tensor<16x64xf32> +} +// CHECK-LABEL: func @reduce_arith_with_attr +// CHECK-NEXT: %[[REDUCED:.*]] = linalg.reduce +// CHECK-SAME: { arith.addf {fastmath = #arith.fastmath} } +// CHECK-SAME: ins +// CHECK-SAME: outs +// CHECK-SAME: dimensions = [1] +// CHECK-NEXT: return %[[REDUCED]] : tensor<16x64xf32>