diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -15,6 +15,8 @@ include "mlir/IR/OpBase.td" +def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>; + def Linalg_Dialect : Dialect { let name = "linalg"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -496,21 +496,25 @@ let summary = "Linalg tiled loop operation"; let description = [{ This is a loop-like operation with additional properties. The arguments - also include the input and the output tensors and the attributes to specify - the iterator types. The body region of the loop contains `subtensor` - operations applied to every tensor argument of TiledLoopOp. + also include the input and the output tensors or memrefs and the attributes + to specify the iterator types. + + Parsing TiledLoopOp will set all elements of the `iterator_types` attribute + to "parallel" type, when it is absent from the custom format. + + Tensor-based version: + + The body region of the loop contains `subtensor` operations applied to + every tensor argument of TiledLoopOp. The body region must contain exactly one block that terminates with `linalg.yield` with the operands resulting from `subtensor_insert` operations. - Parsing TiledLoopOp will set all elements of the `iterator_types` attribute - to "parallel" type, when it is absent from the custom format. 
- Example: ```mlir - linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4) + %0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4) ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>) outs(%out : tensor<24x64xi8>) iterators("parallel") { @@ -528,13 +532,40 @@ linalg.yield %result : tensor<24x64xi8> } ``` + + MemRef-based version: + + The body region of the loop contains `subview` operations applied to + every memref argument of TiledLoopOp. + + The body region must contain exactly one block that terminates with + `linalg.yield` with no operands. + + Example: + + ```mlir + linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4) + ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>) + outs(%out : memref<24x64xi8>) + iterators("parallel") { + %lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1] + : memref<24x64xi8> to memref<?x?xi8, #map> + %rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1] + : memref<24x64xi8> to memref<?x?xi8, #map> + %out_sub = subview %out[%i, 0] [%c4, %c64] [1, 1] + : memref<24x64xi8> to memref<?x?xi8, #map> + + linalg.generic ...
+ linalg.yield + } + ``` }]; let arguments = (ins Variadic:$lowerBound, Variadic:$upperBound, Variadic:$step, - Variadic:$inputs, - Variadic:$outputs, + Variadic:$inputs, + Variadic:$outputs, ArrayAttr:$iterator_types); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); @@ -542,7 +573,7 @@ let builders = [ OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds, "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs, - "ArrayRef":$iteratorTypes, + "ArrayAttr":$iteratorTypes, CArg<"function_ref", "nullptr">:$bodyBuilderFn)>, ]; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -496,8 +496,6 @@ //===----------------------------------------------------------------------===// // Generic Linalg ops. //===----------------------------------------------------------------------===// -def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>; - class LinalgOperandOfRank: Type< And<[ LinalgOperand.predicate, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1743,7 +1743,7 @@ void TiledLoopOp::build( OpBuilder &builder, OperationState &result, ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, ValueRange inputs, - ValueRange outputs, ArrayRef iteratorTypes, + ValueRange outputs, ArrayAttr iteratorTypes, function_ref bodyBuilderFn) { result.addOperands(lowerBounds); result.addOperands(upperBounds); @@ -1757,9 +1757,14 @@ static_cast(steps.size()), static_cast(inputs.size()), static_cast(outputs.size())})); - result.addAttribute(getIteratorTypesAttrName(), - builder.getStrArrayAttr(iteratorTypes)); - result.addTypes(outputs.getTypes()); + 
result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); + + // Add output types for `RankedTensorType` output arguments. + for (Value output : outputs) { + Type outputType = output.getType(); + if (outputType.isa()) + result.addTypes(outputType); + } OpBuilder::InsertionGuard guard(builder); unsigned numIVs = steps.size(); @@ -1770,8 +1775,8 @@ if (bodyBuilderFn) { builder.setInsertionPointToStart(bodyBlock); bodyBuilderFn(builder, result.location, bodyBlock->getArguments()); + TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location); } - TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location); } static void print(OpAsmPrinter &p, TiledLoopOp op) {