diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -521,7 +521,8 @@ %0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4) ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>) outs(%out : tensor<24x64xi8>) - iterators("parallel") { + iterators["parallel"] + distribution["block_x"] { %lhs_sub = subtensor %lhs[%i, 0] [%c4, %c64] [1, 1] : tensor<24x64xi8> to tensor %rhs_sub = subtensor %rhs[%i, 0] [%c4, %c64] [1, 1] @@ -551,7 +552,8 @@ linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4) ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>) outs(%out : memref<24x64xi8>) - iterators("parallel") { + iterators["parallel"] + distribution["block_x"] { %lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1] : memref<24x64xi8> to memref %rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1] @@ -570,11 +572,18 @@ Variadic:$step, Variadic:$inputs, Variadic:$outputs, - ArrayAttr:$iterator_types); + ArrayAttr:$iterator_types, + OptionalAttr:$distribution_types); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); let builders = [ + OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds, + "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs, + "ArrayAttr":$iteratorTypes, "Optional":$distributionTypes, + CArg<"function_ref", + "nullptr">:$bodyBuilderFn)>, OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds, "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs, "ArrayAttr":$iteratorTypes, diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -50,6 +50,12 @@ /// op's iterators. 
constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; } +/// Attribute name for the StrArrayAttr which encodes the distribution type for +/// `linalg.tiled_loop`. +constexpr StringRef getDistributionTypesAttrName() { + return "distribution_types"; +} + /// Attribute name for the StringAttr which encodes an optional documentation /// string of the structured op. constexpr StringRef getDocAttrName() { return "doc"; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2075,6 +2075,18 @@ function_ref bodyBuilderFn) { + build(builder, result, lowerBounds, upperBounds, steps, inputs, outputs, + iteratorTypes, llvm::None, bodyBuilderFn); +} + +void TiledLoopOp::build(OpBuilder &builder, OperationState &result, + ValueRange lowerBounds, ValueRange upperBounds, + ValueRange steps, ValueRange inputs, ValueRange outputs, + ArrayAttr iteratorTypes, + Optional distributionTypes, + function_ref + bodyBuilderFn) { result.addOperands(lowerBounds); result.addOperands(upperBounds); result.addOperands(steps); @@ -2089,6 +2101,10 @@ static_cast(outputs.size())})); result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); + if (distributionTypes.hasValue()) + result.addAttribute(getDistributionTypesAttrName(), + distributionTypes.getValue()); + // Add output types for `RankedTensorType` output arguments. 
for (Value output : outputs) { Type outputType = output.getType(); @@ -2143,14 +2159,17 @@ if (llvm::any_of(op.iterator_types(), [](Attribute attr) { return attr.cast().getValue() != getParallelIteratorTypeName(); - })) { + })) p << " iterators" << op.iterator_types() << ""; - } + + if (op.distribution_types().hasValue()) + p << " distribution" << op.distribution_types().getValue() << ""; p.printRegion(op.region(), /*printEntryBlockArgs=*/false); p.printOptionalAttrDict( op->getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(), - getIteratorTypesAttrName()}); + getIteratorTypesAttrName(), + getDistributionTypesAttrName()}); } static ParseResult parseTiledLoopOp(OpAsmParser &parser, @@ -2219,26 +2238,38 @@ } // Parse attributes. - SmallVector iterTypes; - if (succeeded(parser.parseOptionalKeyword("iterators"))) { - StringAttr iterType; + SmallVector iterTypes, distributionTypes; + auto parseAttr = [&](StringRef keyword, SmallVector *attrs) { + if (succeeded(parser.parseOptionalKeyword(keyword))) { + StringAttr attr; - if (parser.parseLSquare() || parser.parseAttribute(iterType)) - return failure(); - iterTypes.push_back(iterType); - for (int i = 1, e = ivs.size(); i < e; ++i) { - if (parser.parseComma() || parser.parseAttribute(iterType)) + if (parser.parseLSquare() || parser.parseAttribute(attr)) + return failure(); + attrs->push_back(attr); + for (int i = 1, e = ivs.size(); i < e; ++i) { + if (parser.parseComma() || parser.parseAttribute(attr)) + return failure(); + attrs->push_back(attr); + } + if (parser.parseRSquare()) return failure(); - iterTypes.push_back(iterType); } - if (parser.parseRSquare()) - return failure(); - } else { + return success(); + }; + if (failed(parseAttr("iterators", &iterTypes)) || + failed(parseAttr("distribution", &distributionTypes))) + return failure(); + + // Set all loop iterator types to "parallel" if they are not printed in IR. 
+ if (iterTypes.empty()) { auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName()); iterTypes = SmallVector(ivs.size(), parallelIter); } result.addAttribute(getIteratorTypesAttrName(), builder.getArrayAttr(iterTypes)); + if (!distributionTypes.empty()) + result.addAttribute(getDistributionTypesAttrName(), + builder.getArrayAttr(distributionTypes)); result.addAttribute( TiledLoopOp::getOperandSegmentSizeAttr(), builder.getI32VectorAttr({static_cast(lower.size()), @@ -2352,7 +2383,8 @@ Location loc = tiledLoop.getLoc(); auto newTiledLoop = rewriter.create( loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(), - newInputs, tiledLoop.outputs(), tiledLoop.iterator_types()); + newInputs, tiledLoop.outputs(), tiledLoop.iterator_types(), + tiledLoop.distribution_types()); // Clone the region. BlockAndValueMapping bvm; @@ -2441,7 +2473,8 @@ Location loc = tiledLoop.getLoc(); auto newTiledLoop = rewriter.create( loc, tiledLoop.lowerBound(), tiledLoop.upperBound(), tiledLoop.step(), - tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types()); + tiledLoop.inputs(), newOutputOperands, tiledLoop.iterator_types(), + tiledLoop.distribution_types()); // Clone the region. BlockAndValueMapping bvm; diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -827,7 +827,8 @@ %i2d_ = %input_2d: tensor<16x32xf32>, %i1d_ = %input_1d: tensor<24xf32>) outs(%o_ = %output: tensor<24xf32>) - iterators["reduction", "parallel", "reduction"] { + iterators["reduction", "parallel", "reduction"] + distribution["block_x", "block_y", "none"] { %sub_3d = subtensor %i3d_[%i, %j, %k][2, 4, 8][1, 1, 1] : tensor<16x24x32xf32> to tensor<2x4x8xf32> %sub_2d = subtensor %i2d_[%i, %k][2, 8][1, 1]