diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -15,6 +15,7 @@
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -485,4 +486,58 @@
   }];
 }
 
+def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
+     AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+     RecursiveSideEffects,
+     SingleBlockImplicitTerminator<"linalg::YieldOp">
+   ]> {
+  let summary = "Linalg tiled loop operation";
+  let description = [{
+    This is a loop-like operation with additional properties. The arguments
+    also include the input and the output tensors and the attributes to specify
+    the iterator types. The body region of the loop contains `subtensor`
+    operations applied to every tensor argument of TiledLoopOp.
+
+    The body region must contain exactly one block that terminates with
+    `linalg.yield` with the operands resulting from `subtensor_insert`
+    operations.
+
+    Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
+    to "parallel" type, when it is absent from the custom format.
+
+    Example:
+
+    ```mlir
+    linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
+        ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
+        outs(%out : tensor<24x64xi8>)
+        iterators("parallel") {
+      %lhs_sub = subtensor %lhs[%i, 0] [%c4, %c64] [1, 1]
+        : tensor<24x64xi8> to tensor<?x?xi8>
+      %rhs_sub = subtensor %rhs[%i, 0] [%c4, %c64] [1, 1]
+        : tensor<24x64xi8> to tensor<?x?xi8>
+      %out_sub = subtensor %out[%i, 0] [%c4, %c64] [1, 1]
+        : tensor<24x64xi8> to tensor<?x?xi8>
+
+      %result_sub = linalg.generic ...
+
+      %result = subtensor_insert %result_sub into %out[%i, 0][%c4, %c64][1, 1]
+        : tensor<?x?xi8> into tensor<24x64xi8>
+      linalg.yield %result : tensor<24x64xi8>
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<Index>:$lowerBound,
+                       Variadic<Index>:$upperBound,
+                       Variadic<Index>:$step,
+                       Variadic<AnyRankedTensor>:$inputs,
+                       Variadic<AnyRankedTensor>:$outputs,
+                       ArrayAttr:$iterator_types);
+  let results = (outs Variadic<AnyRankedTensor>:$results);
+  let regions = (region SizedRegion<1>:$region);
+}
+
 
 #endif // LINALG_OPS
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1704,9 +1704,157 @@
     return success();
   }
+  if (auto tiledLoopOp = dyn_cast<TiledLoopOp>(parentOp)) {
+    return success();
+  }
   return op.emitOpError("expected parent op with LinalgOp interface");
 }
 
+//===----------------------------------------------------------------------===//
+// TiledLoopOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, TiledLoopOp op) {
+  p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("
+    << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()
+    << ")";
+
+  if (!op.inputs().empty())
+    p << " ins (" << op.inputs() << ")";
+  if (!op.outputs().empty())
+    p << " outs (" << op.outputs() << ")";
+
+  // Only print `iterator_types` when at least one iterator is not "parallel";
+  // the parser defaults every iterator to "parallel" when the clause is absent.
+  if (llvm::any_of(op.iterator_types(), [](Attribute attr) {
+        return attr.cast<StringAttr>().getValue() !=
+               getParallelIteratorTypeName();
+      })) {
+    p << " iterators(" << op.iterator_types() << ")";
+  }
+
+  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
+  p.printOptionalAttrDict(
+      op.getAttrs(), /*elidedAttrs=*/{TiledLoopOp::getOperandSegmentSizeAttr(),
+                                      getIteratorTypesAttrName()});
+}
+
+static ParseResult parseTiledLoopOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  auto &builder = parser.getBuilder();
+  // Parse an opening `(` followed by induction variables followed by `)`
+  SmallVector<OpAsmParser::OperandType, 4> ivs;
+  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
+                                     OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  // Parse loop bounds.
+  SmallVector<OpAsmParser::OperandType, 4> lower;
+  if (parser.parseEqual() ||
+      parser.parseOperandList(lower, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(lower, builder.getIndexType(), result.operands))
+    return failure();
+
+  SmallVector<OpAsmParser::OperandType, 4> upper;
+  if (parser.parseKeyword("to") ||
+      parser.parseOperandList(upper, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(upper, builder.getIndexType(), result.operands))
+    return failure();
+
+  // Parse step values.
+  SmallVector<OpAsmParser::OperandType, 4> steps;
+  if (parser.parseKeyword("step") ||
+      parser.parseOperandList(steps, ivs.size(),
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(steps, builder.getIndexType(), result.operands))
+    return failure();
+
+  // Parse input tensors.
+  SmallVector<OpAsmParser::OperandType, 4> inputs;
+  if (succeeded(parser.parseOptionalKeyword("ins"))) {
+    SmallVector<Type, 4> inputTypes;
+    llvm::SMLoc inputsOperandsLoc = parser.getCurrentLocation();
+
+    if (parser.parseLParen() || parser.parseOperandList(inputs) ||
+        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
+      return failure();
+
+    if (parser.resolveOperands(inputs, inputTypes, inputsOperandsLoc,
+                               result.operands))
+      return failure();
+  }
+
+  // Parse output tensors.
+  SmallVector<OpAsmParser::OperandType, 4> outputs;
+  if (succeeded(parser.parseOptionalKeyword("outs"))) {
+    SmallVector<Type, 4> outputTypes;
+    llvm::SMLoc outputsOperandsLoc = parser.getCurrentLocation();
+
+    if (parser.parseLParen() || parser.parseOperandList(outputs) ||
+        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
+      return failure();
+
+    if (parser.resolveOperands(outputs, outputTypes, outputsOperandsLoc,
+                               result.operands))
+      return failure();
+    // The op results mirror the output tensor types.
+    result.addTypes(outputTypes);
+  }
+
+  // Parse attributes.
+  SmallVector<Attribute, 4> iterTypes;
+  if (succeeded(parser.parseOptionalKeyword("iterators"))) {
+    StringAttr iterType;
+
+    if (parser.parseLParen() || parser.parseAttribute(iterType))
+      return failure();
+    iterTypes.push_back(iterType);
+    for (int i = 1, e = ivs.size(); i < e; ++i) {
+      if (parser.parseComma() || parser.parseAttribute(iterType))
+        return failure();
+      iterTypes.push_back(iterType);
+    }
+    if (parser.parseRParen())
+      return failure();
+  } else {
+    // Default: one "parallel" iterator per induction variable.
+    auto parallelIter = builder.getStringAttr(getParallelIteratorTypeName());
+    iterTypes = SmallVector<Attribute, 4>(ivs.size(), parallelIter);
+  }
+  result.addAttribute(getIteratorTypesAttrName(),
+                      builder.getArrayAttr(iterTypes));
+  result.addAttribute(
+      TiledLoopOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
+                                static_cast<int32_t>(upper.size()),
+                                static_cast<int32_t>(steps.size()),
+                                static_cast<int32_t>(inputs.size()),
+                                static_cast<int32_t>(outputs.size())}));
+
+  // Parse the body.
+  Region *body = result.addRegion();
+  SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
+  if (parser.parseRegion(*body, ivs, types))
+    return failure();
+
+  // Parse optional attributes.
+  parser.parseOptionalAttrDict(result.attributes);
+
+  return success();
+}
+
+Region &TiledLoopOp::getLoopBody() { return region(); }
+
+LogicalResult TiledLoopOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
+  for (auto *op : ops)
+    op->moveBefore(*this);
+  return success();
+}
+
+bool TiledLoopOp::isDefinedOutsideOfLoop(Value value) {
+  return !region().isAncestor(value.getParentRegion());
+}
+
+static LogicalResult verify(TiledLoopOp op) { return success(); }
+
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
 template
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -794,3 +794,110 @@
   return %1 : tensor<?x?xf32>
 }
 // CHECK: %{{.+}} = linalg.fill(%{{.+}}, %{{.+}}) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+
+// -----
+
+#accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i, j)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"]
+}
+
+func @tiled_loop(%lhs: tensor<24x64xi8>, %rhs: tensor<24x64xi8>,
+                 %out: tensor<24x64xi8>) -> tensor<24x64xi8> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c24 = constant 24 : index
+  %c64 = constant 64 : index
+  %prod = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
+      ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
+      outs(%out : tensor<24x64xi8>) {
+    %lhs_sub = subtensor %lhs[%i, 0] [%c4, %c64] [1, 1]
+      : tensor<24x64xi8> to tensor<?x?xi8>
+    %rhs_sub = subtensor %rhs[%i, 0] [%c4, %c64] [1, 1]
+      : tensor<24x64xi8> to tensor<?x?xi8>
+    %out_sub = subtensor %out[%i, 0] [%c4, %c64] [1, 1]
+      : tensor<24x64xi8> to tensor<?x?xi8>
+
+    %sum = linalg.generic #trait
+        ins(%lhs_sub, %rhs_sub : tensor<?x?xi8>, tensor<?x?xi8>)
+        outs(%out_sub : tensor<?x?xi8>) {
+      ^bb(%l: i8, %r: i8, %o: i8) :
+        %s = addi %l, %r : i8
+        linalg.yield %s : i8
+      } -> tensor<?x?xi8>
+
+    %sum_sub = subtensor_insert %sum into %out[%i, 0][%c4, %c64][1, 1]
+      : tensor<?x?xi8> into tensor<24x64xi8>
+    linalg.yield %sum_sub : tensor<24x64xi8>
+  }
+  return %prod : tensor<24x64xi8>
+}
+// CHECK-LABEL: func @tiled_loop
+// CHECK-NOT: iterators(
+
+// -----
+
+#id_3d = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#id_2d = affine_map<(d0, d1, d2) -> (d0, d2)>
+#id_1d = affine_map<(d0, d1, d2) -> (d1)>
+
+#trait = {
+  indexing_maps = [
+    #id_3d,
+    #id_2d,
+    #id_1d,
+    #id_1d
+  ],
+  iterator_types = ["reduction", "parallel", "reduction"]
+}
+
+func @tiled_loop_reduction(%input_3d: tensor<16x24x32xf32>,
+                           %input_2d: tensor<16x32xf32>,
+                           %input_1d: tensor<24xf32>,
+                           %output: tensor<24xf32>) -> tensor<24xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c4 = constant 4 : index
+  %c8 = constant 8 : index
+  %X = dim %input_3d, %c0 : tensor<16x24x32xf32>
+  %Y = dim %input_3d, %c1 : tensor<16x24x32xf32>
+  %Z = dim %input_3d, %c2 : tensor<16x24x32xf32>
+  %result = linalg.tiled_loop (%i, %j, %k)
+      = (%c0, %c0, %c0) to (%X, %Y, %Z) step (%c2, %c4, %c8)
+      ins(%input_3d, %input_2d: tensor<16x24x32xf32>, tensor<16x32xf32>)
+      outs( %output: tensor<24xf32>)
+      iterators("reduction", "parallel", "reduction") {
+    %sub_3d = subtensor %input_3d[%i, %j, %k][2, 4, 8][1, 1, 1]
+      : tensor<16x24x32xf32> to tensor<2x4x8xf32>
+    %sub_2d = subtensor %input_2d[%i, %k][2, 8][1, 1]
+      : tensor<16x32xf32> to tensor<2x8xf32>
+    %sub_1d = subtensor %input_1d[%j] [4] [1]
+      : tensor<24xf32> to tensor<4xf32>
+    %sub_out = subtensor %output[%j] [4] [1]
+      : tensor<24xf32> to tensor<4xf32>
+    %acc = linalg.generic #trait
+        ins(%sub_3d, %sub_2d, %sub_1d
+            : tensor<2x4x8xf32>, tensor<2x8xf32>, tensor<4xf32>)
+        outs(%sub_out : tensor<4xf32>) {
+      ^bb0(%i3d: f32, %i2d: f32, %i1d: f32, %o: f32):
+        %0 = addf %i3d, %i2d : f32
+        %1 = addf %0, %i1d : f32
+        linalg.yield %1 : f32
+      } -> tensor<4xf32>
+
+    %sum_sub = subtensor_insert %acc into %output[%j][%c4][1]
+      : tensor<4xf32> into tensor<24xf32>
+    linalg.yield %sum_sub : tensor<24xf32>
+  }
+  return %result : tensor<24xf32>
+}
+// CHECK-LABEL: func @tiled_loop_reduction
+// CHECK: iterators(