diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -133,7 +133,7 @@ [AttrSizedOperandSegments, NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = "tensor pad operation"; let description = [{ diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -56,7 +56,7 @@ step for the loops of the operation. }], /*retTy=*/"SmallVector", - /*methodName=*/"getLoopBounds", + /*methodName=*/"getIterationDomain", /*args=*/(ins "OpBuilder &":$b) >, InterfaceMethod< @@ -64,7 +64,7 @@ Method to generate the tiled implementation of an operation. The iteration space of the operation is returned by - `getLoopBounds`. The caller provides the information of the + `getIterationDomain`. The caller provides the information of the tile within this iteration space whose implementation the caller needs. - `dest` are the Value into which the result of the tiled @@ -74,20 +74,24 @@ - `offsets` provides the offset of the tile within the iteration space - `sizes` provides the size of the tile. + - `tileDestOperands` specifies whether to also tile `dest` operands + or not. Avoiding tiling `dest` operands can be useful for + composition with various looping container ops. The method returns the operation that is the tiled implementation. }], - /*retType=*/"Operation *", + /*retType=*/"SmallVector", /*methodName=*/"getTiledImplementation", /*args=*/(ins "OpBuilder &":$b, "ValueRange ":$dest, "ArrayRef ":$offsets, - "ArrayRef ":$sizes), + "ArrayRef ":$sizes, + "bool ":$tileDestOperands), /*methodBody=*/"", /*defaultImplementation=*/[{ - return nullptr; + return {}; }] > ]; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1233,7 +1233,7 @@ return iteratorTypes; } -SmallVector PadTensorOp::getLoopBounds(OpBuilder &b) { +SmallVector PadTensorOp::getIterationDomain(OpBuilder &b) { ReifiedRankedShapedTypeDims reifiedShapes; (void)reifyResultShapes(b, reifiedShapes); Value zero = b.create(getLoc(), 0); @@ -1246,13 +1246,13 @@ return loopRanges; } -Operation *PadTensorOp::getTiledImplementation(OpBuilder &b, ValueRange dest, - ArrayRef offsets, - ArrayRef sizes) { +SmallVector PadTensorOp::getTiledImplementation( + OpBuilder &b, ValueRange dest, ArrayRef offsets, + ArrayRef sizes, bool /*tileDestOperands*/) { // Only constant padding value supported. Value padValue = getConstantPaddingValue(); if (!padValue) - return nullptr; + return {}; // Helper variables and functions for various arithmetic operations. These are // used extensively for computing new offset/length and padding values. @@ -1431,7 +1431,7 @@ // Rewrite subtensor(pad_tensor(x)) into a GenerateOp it is statically known // that the original data source x is not used. if (hasZeroLen) { - return createGenerateOp(); + return {createGenerateOp()}; } // If there are dynamic dimensions: Generate an scf.if check to avoid creating @@ -1448,9 +1448,9 @@ b.create(loc, createPadTensorOfSubTensor()->getResult(0)); }); - return result; + return {result}; } - return createPadTensorOfSubTensor(); + return {createPadTensorOfSubTensor()}; } namespace { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -351,7 +351,7 @@ options.tileSizeComputationFunction(builder, op); assert(static_cast(tileSizes.size()) == rank); // Compute lower and upper bounds of the loop nest. - SmallVector ranges = op.getLoopBounds(builder); + SmallVector ranges = op.getIterationDomain(builder); SmallVector lbs, dims, allDims, steps; for (int64_t i = 0; i < rank; ++i) { allDims.push_back(ranges[i].size); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -907,9 +907,12 @@ if (!sliceOp.hasUnitStride()) return failure(); - Operation *tiledPadOp = padOp.getTiledImplementation( - rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(), - sliceOp.getMixedSizes()); + Operation *tiledPadOp = + padOp + .getTiledImplementation( + rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(), + sliceOp.getMixedSizes(), /*tileDestOperands=*/false) + .front(); // All shapes are static and the data source is actually used. Rewrite into // pad_tensor(subtensor(x)). rewriter.replaceOp(sliceOp, tiledPadOp->getResults());