diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -306,6 +306,31 @@
         return getConstantIntValue(ofr) == static_cast<int64_t>(0);
       });
     }
+
+    /// Bubbles up a slice of this pad by taking the slice first and then
+    /// performing the padding. `offsets` and `sizes` specify each dimension's
+    /// start offset and size for the slice. The slice has unit strides along
+    /// all dimensions.
+    ///
+    /// Specifically, this function converts:
+    /// ```
+    /// %0 = linalg.pad_tensor %source low[...] high[...] { linalg.yield %cst }
+    /// %1 = tensor.extract_slice %0 offsets=[...], sizes=[...]
+    /// ```
+    /// into
+    /// ```
+    /// %0 = tensor.extract_slice %source ...
+    /// %1 = linalg.pad_tensor %0 low[...] high[...] { linalg.yield %cst }
+    /// ```
+    ///
+    /// If `generateZeroSliceGuard` is true, the generated IR will contain logic
+    /// to guard against the case that we might take a zero-sized slice from the
+    /// original source. For such cases, we use `tensor.generate` to generate
+    /// the full tensor.
+    Operation *bubbleUpSlice(OpBuilder &,
+                             ArrayRef<OpFoldResult> offsets,
+                             ArrayRef<OpFoldResult> sizes,
+                             bool generateZeroSliceGuard = true);
   }];

   let builders = [
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1326,10 +1326,27 @@
 /// Rewrite extract_slice(pad_tensor(x)) into pad_tensor(extract_slice(x)).
 struct ExtractSliceOfPadTensorSwapPattern
     : public OpRewritePattern<tensor::ExtractSliceOp> {
-  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+  /// A function to control pattern application and rewrite logic.
+  ///
+  /// The function will be given the slice op and should return:
+  /// - None: to fail the match and not apply the pattern;
+  /// - true: to apply the pattern with zero slice guard;
+  /// - false: to apply the pattern without zero slice guard.
+  ///
+  /// See the documentation for PadTensorOp::bubbleUpSlice regarding zero slice
+  /// guard.
+  using ControlFn = std::function<llvm::Optional<bool>(tensor::ExtractSliceOp)>;
+
+  ExtractSliceOfPadTensorSwapPattern(MLIRContext *context,
+                                     ControlFn controlFn = nullptr,
+                                     PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), controlFn(controlFn) {}

   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                 PatternRewriter &rewriter) const override;
+
+private:
+  ControlFn controlFn;
 };

 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1293,12 +1293,22 @@
 }

 SmallVector<Operation *> PadTensorOp::getTiledImplementation(
-    OpBuilder &b, ValueRange dest, ArrayRef<OpFoldResult> offsets,
+    OpBuilder &b, ValueRange /*dest*/, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, bool /*tileDestOperands*/) {
+  Operation *result = bubbleUpSlice(b, offsets, sizes);
+  if (!result)
+    return {};
+  return {result};
+}
+
+Operation *PadTensorOp::bubbleUpSlice(OpBuilder &b,
+                                      ArrayRef<OpFoldResult> offsets,
+                                      ArrayRef<OpFoldResult> sizes,
+                                      bool generateZeroSliceGuard) {
   // Only constant padding value supported.
   Value padValue = getConstantPaddingValue();
   if (!padValue)
-    return {};
+    return nullptr;

   // Helper variables and functions for various arithmetic operations. These are
   // used extensively for computing new offset/length and padding values.
@@ -1440,8 +1450,7 @@

   // Insert cast to ensure that types match. (May be folded away.)
   auto castResult = [&](Value val) -> Operation * {
-    auto castOp = b.create<tensor::CastOp>(loc, resultType, val);
-    return castOp;
+    return b.create<tensor::CastOp>(loc, resultType, val);
   };

   // In cases where the original data source is unused: Emit a GenerateOp and
@@ -1459,8 +1468,8 @@

   // Emit a SliceOp and a PadTensorOp. Should not be used in cases where
   // the result shape of the new SliceOp has a zero dimension.
-  auto createPadTensorOfSubTensor = [&]() {
-    // Create pad_tensor(subtensor(x)).
+  auto createPadTensorOfExtractSlice = [&]() {
+    // Create pad_tensor(extract_slice(x)).
     auto newSliceOp = b.create<tensor::ExtractSliceOp>(
         loc, source(), newOffsets, newLengths, newStrides);
     auto newPadTensorOp = b.create<PadTensorOp>(
@@ -1474,15 +1483,14 @@
     return castResult(newPadTensorOp);
   };

-  // Rewrite subtensor(pad_tensor(x)) into a GenerateOp it is statically known
-  // that the original data source x is not used.
-  if (hasZeroLen) {
-    return {createGenerateOp()};
-  }
+  // Rewrite extract_slice(pad_tensor(x)) into a GenerateOp if it is statically
+  // known that the original data source x is not used.
+  if (hasZeroLen)
+    return createGenerateOp();

   // If there are dynamic dimensions: Generate an scf.if check to avoid creating
   // SliceOps with result dimensions of size 0 at runtime.
-  if (dynHasZeroLenCond) {
+  if (generateZeroSliceGuard && dynHasZeroLenCond) {
     auto result = b.create<scf::IfOp>(
         loc, resultType, dynHasZeroLenCond,
         /*thenBuilder=*/
@@ -1492,11 +1500,11 @@
         /*elseBuilder=*/
         [&](OpBuilder &b, Location loc) {
           b.create<scf::YieldOp>(loc,
-                                 createPadTensorOfSubTensor()->getResult(0));
+                                 createPadTensorOfExtractSlice()->getResult(0));
         });
-    return {result};
+    return result;
   }
-  return {createPadTensorOfSubTensor()};
+  return createPadTensorOfExtractSlice();
 }

 namespace {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -901,21 +901,30 @@
 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
+  if (!sliceOp.hasUnitStride())
+    return failure();
+
   auto padOp = sliceOp.source().getDefiningOp<PadTensorOp>();
   if (!padOp)
     return failure();
-  // Only unit stride supported.
-  if (!sliceOp.hasUnitStride())
-    return failure();
+
+  bool zeroSliceGuard = true;
+  if (controlFn) {
+    if (Optional<bool> control = controlFn(sliceOp))
+      zeroSliceGuard = control.getValue();
+    else
+      return failure();
+  }

   Operation *tiledPadOp =
-      padOp
-          .getTiledImplementation(
-              rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
-              sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
-          .front();
+      padOp.bubbleUpSlice(rewriter, sliceOp.getMixedOffsets(),
+                          sliceOp.getMixedSizes(), zeroSliceGuard);
+  // bubbleUpSlice returns nullptr when the padding value is not a constant;
+  // bail out instead of dereferencing a null pointer below.
+  if (!tiledPadOp)
+    return failure();
   // All shapes are static and the data source is actually used. Rewrite into
-  // pad_tensor(subtensor(x)).
+  // pad_tensor(extract_slice(x)).
   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
   return success();
 }