diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -484,6 +484,7 @@ linalg.yield %f0, %f1 : f32, f32 ``` }]; + /* Zero-operand builder; judging by the linalg.yield example above this is presumably the yield op, letting an implicit terminator be built without arguments -- TODO confirm. */ let builders = [OpBuilderDAG<(ins), [{ /* nothing to do */ }]>]; } def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [ @@ -537,6 +538,21 @@ ArrayAttr:$iterator_types); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); + + /* Convenience builder; bodyBuilderFn defaults to nullptr, in which case the C++ build() leaves the body holding only its terminator. NOTE(review): the C++ type strings here (ArrayRef, function_ref) appear to be missing their template arguments -- confirm against the actual patch. */ let builders = [ + OpBuilderDAG<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds, + "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs, + "ArrayRef":$iteratorTypes, + CArg<"function_ref", + "nullptr">:$bodyBuilderFn)>, + ]; + + let extraClassDeclaration = [{ + /* Every body block argument is an induction variable (build() creates exactly one per step). */ ValueRange getInductionVars() { + return getBody()->getArguments(); + } + /* One loop per step operand. */ unsigned getNumLoops() { return step().size(); } + }]; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1701,6 +1701,40 @@ // TiledLoopOp //===----------------------------------------------------------------------===// +/* Builds a linalg.tiled_loop. Operands are lowerBounds, upperBounds, steps, inputs, outputs (in that order, described by an operand segment sizes attribute); iteratorTypes is stored as a string array attribute; one result is declared per output, with that output's type. The single-block body gets steps.size() index-typed arguments (the induction variables); if bodyBuilderFn is non-null it is called with the builder positioned at the start of that block, and a terminator is then ensured. NOTE(review): C++ template arguments (e.g. on ArrayRef, function_ref, static_cast, SmallVector) appear to be missing from this text -- confirm against the actual patch. */ void TiledLoopOp::build( + OpBuilder &builder, OperationState &result, ValueRange lowerBounds, + ValueRange upperBounds, ValueRange steps, ValueRange inputs, + ValueRange outputs, ArrayRef iteratorTypes, + function_ref bodyBuilderFn) { + /* Operands are appended in the same order as the segment sizes listed below; the two must stay in sync. */ result.addOperands(lowerBounds); + result.addOperands(upperBounds); + result.addOperands(steps); + result.addOperands(inputs); + result.addOperands(outputs); + result.addAttribute( + TiledLoopOp::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr({static_cast(lowerBounds.size()), + static_cast(upperBounds.size()), + static_cast(steps.size()), + static_cast(inputs.size()), + static_cast(outputs.size())})); + result.addAttribute(getIteratorTypesAttrName(), 
builder.getStrArrayAttr(iteratorTypes)); + result.addTypes(outputs.getTypes()); + + /* createBlock moves the builder into the new block; the guard restores the caller's insertion point when build() returns. */ OpBuilder::InsertionGuard guard(builder); + /* NOTE(review): the IV count comes from steps alone; lowerBounds/upperBounds/iteratorTypes sizes are not checked against it here. */ unsigned numIVs = steps.size(); + SmallVector argTypes(numIVs, builder.getIndexType()); + Region *bodyRegion = result.addRegion(); + Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes); + + /* Optional callback populates the body, starting at the top of the block, and receives the induction variables. */ if (bodyBuilderFn) { + builder.setInsertionPointToStart(bodyBlock); + bodyBuilderFn(builder, result.location, bodyBlock->getArguments()); + } + /* Add the implicit terminator if the body does not already end with one. NOTE(review): an implicitly created terminator is presumably built with no operands, while the op declares outputs.size() results -- confirm this is intended when outputs is non-empty and bodyBuilderFn is null. */ TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location); +} + static void print(OpAsmPrinter &p, TiledLoopOp op) { p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (" << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()