diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1399,10 +1399,27 @@
 /// Rewrite extract_slice(pad_tensor(x)) into pad_tensor(extract_slice(x)).
 struct ExtractSliceOfPadTensorSwapPattern
     : public OpRewritePattern<tensor::ExtractSliceOp> {
-  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+  /// A function to control pattern application and rewrite logic.
+  ///
+  /// The function will be given the slice op and should return:
+  /// - None: to fail the match and not apply the pattern;
+  /// - true: to apply the pattern with the zero slice guard;
+  /// - false: to apply the pattern without the zero slice guard.
+  ///
+  /// See the documentation for tensor::bubbleUpPadSlice regarding the zero
+  /// slice guard.
+  using ControlFn =
+      std::function<llvm::Optional<bool>(tensor::ExtractSliceOp)>;
+
+  ExtractSliceOfPadTensorSwapPattern(MLIRContext *context,
+                                     ControlFn controlFn = nullptr,
+                                     PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
   LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                 PatternRewriter &rewriter) const override;
+
+private:
+  ControlFn controlFn;
 };

 //===----------------------------------------------------------------------===//
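The control function gives pattern users three choices per slice op. A minimal sketch of one possible policy follows; the rank and static-shape conditions are invented for illustration, and only the `Optional<bool>` contract comes from this patch:

```
// Hypothetical ControlFn: skip non-2D slices entirely, and elide the
// zero slice guard when the result shape is fully static.
auto controlFn = [](tensor::ExtractSliceOp sliceOp) -> llvm::Optional<bool> {
  RankedTensorType resultType = sliceOp.getType();
  if (resultType.getRank() != 2)
    return llvm::None; // Fail the match; the pattern leaves this op alone.
  if (resultType.hasStaticShape())
    return false;      // Apply the pattern without the zero slice guard.
  return true;         // Apply the pattern with the zero slice guard.
};
```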
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h
@@ -18,6 +18,32 @@
 namespace mlir {
 namespace tensor {

+class PadOp;
+
+/// Bubbles up a slice of the given pad by taking the slice first and then
+/// performing the padding. `offsets` and `sizes` specify each dimension's
+/// start offset and size for the slice. The slice has unit strides along all
+/// dimensions.
+///
+/// Specifically, this function converts:
+/// ```
+/// %0 = tensor.pad %source low[...] high[...] { tensor.yield %cst }
+/// %1 = tensor.extract_slice %0 offsets=[...], sizes=[...]
+/// ```
+/// into
+/// ```
+/// %0 = tensor.extract_slice %source ...
+/// %1 = tensor.pad %0 low[...] high[...] { tensor.yield %cst }
+/// ```
+///
+/// If `generateZeroSliceGuard` is true, the generated IR will contain logic
+/// to guard against the case that we might take a zero-sized slice from the
+/// original source. For such cases, we use `tensor.generate` to generate the
+/// full tensor.
+Operation *bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
+                            ArrayRef<OpFoldResult> offsets,
+                            ArrayRef<OpFoldResult> sizes,
+                            bool generateZeroSliceGuard = true);
+
 /// Registers external models for Tiling interface for tensor ops.
 /// Currently, it registers:
 ///
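To make the offset/length arithmetic in TensorTilingInterfaceImpl.cpp below concrete, consider a hypothetical fully static case (shapes and values invented for illustration). Slicing elements [1, 7) of the 15-element padded tensor touches one element of low padding and five source elements, so the swapped form slices [0, 5) of the source and re-pads with low = 1, high = 0:

```
// Before: slice of the padded tensor.
%pad = tensor.pad %source low[2] high[3] {
^bb0(%i: index):
  tensor.yield %cst : f32
} : tensor<10xf32> to tensor<15xf32>
%slice = tensor.extract_slice %pad[1] [6] [1] : tensor<15xf32> to tensor<6xf32>

// After bubbleUpPadSlice, with low = 2, offset = 1, length = 6, srcSize = 10:
//   newLow    = max(0, 2 - 1)              = 1
//   newOffset = min(max(1 - 2, 0), 10)     = 0
//   endLoc    = min(max(1 - 2 + 6, 0), 10) = 5, newLength = 5 - 0 = 5
//   newHigh   = (6 - 5) - 1                = 0
%src = tensor.extract_slice %source[0] [5] [1] : tensor<10xf32> to tensor<5xf32>
%res = tensor.pad %src low[1] high[0] {
^bb0(%i: index):
  tensor.yield %cst : f32
} : tensor<5xf32> to tensor<6xf32>
```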
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -54,6 +54,7 @@
   MLIRStandardOpsTransforms
   MLIRStandardToLLVM
   MLIRTensor
+  MLIRTensorTilingInterfaceImpl
   MLIRTensorTransforms
   MLIRTransforms
   MLIRTransformUtils
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -911,23 +912,26 @@
 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
+  if (!sliceOp.hasUnitStride())
+    return failure();
+
   auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
   if (!padOp)
     return failure();
-  // Only unit stride supported.
-  if (!sliceOp.hasUnitStride())
-    return failure();

-  TilingInterface tilingInterface =
-      dyn_cast<TilingInterface>(padOp.getOperation());
+  bool zeroSliceGuard = true;
+  if (controlFn) {
+    if (Optional<bool> control = controlFn(sliceOp))
+      zeroSliceGuard = control.getValue();
+    else
+      return failure();
+  }
+
   Operation *tiledPadOp =
-      tilingInterface
-          .getTiledImplementation(
-              rewriter, /*dest=*/ValueRange{}, sliceOp.getMixedOffsets(),
-              sliceOp.getMixedSizes(), /*tileDestOperands=*/false)
-          .front();
+      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
+                               sliceOp.getMixedSizes(), zeroSliceGuard);
   // All shapes are static and the data source is actually used. Rewrite into
-  // pad_tensor(subtensor(x)).
+  // pad(extract_slice(x)).
   rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
   return success();
 }
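A minimal driver sketch for exercising the updated pattern, assuming `controlFn` from the earlier example and a surrounding pass providing `funcOp` and `signalPassFailure` (none of this scaffolding is part of the patch):

```
RewritePatternSet patterns(funcOp.getContext());
patterns.add<linalg::ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext(),
                                                         controlFn);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  signalPassFailure();
```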
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -63,215 +63,223 @@
       ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
       bool /*tileDestOperands*/) const {
-    auto padOp = cast<PadOp>(op);
-    // Only constant padding value supported.
-    Value padValue = padOp.getConstantPaddingValue();
-    if (!padValue)
+    Operation *result =
+        tensor::bubbleUpPadSlice(b, cast<PadOp>(op), offsets, sizes);
+    if (!result)
       return {};
+    return {result};
+  }
+};

-    // Helper variables and functions for various arithmetic operations. These
-    // are used extensively for computing new offset/length and padding values.
-    Location loc = op->getLoc();
-    AffineExpr dim0, dim1;
-    bindDims(b.getContext(), dim0, dim1);
-    // Add two integers.
-    auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
-    auto add = [&](Value v1, Value v2) {
-      return b.createOrFold<AffineApplyOp>(loc, addMap, ValueRange{v1, v2});
-    };
-    // Subtract two integers.
-    auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
-    auto sub = [&](Value v1, Value v2) {
-      return b.createOrFold<AffineApplyOp>(loc, subMap, ValueRange{v1, v2});
-    };
-    // Take the minimum of two integers.
-    auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
-    auto min = [&](Value v1, Value v2) {
-      return b.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
-    };
-    // Take the maximum of two integers.
-    auto max = [&](Value v1, Value v2) {
-      return b.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
-    };
-    // Zero index-typed integer.
-    auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
+} // namespace

-    // Helper function for filling static/dynamic low/high padding indices
-    // vectors of PadOp.
-    auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
-                           SmallVector<int64_t> &staticIndices) {
-      if (auto constInt = getConstantIntValue(val)) {
-        staticIndices.push_back(*constInt);
-      } else {
-        staticIndices.push_back(ShapedType::kDynamicSize);
-        dynIndices.push_back(val);
-      }
-    };
+Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    bool generateZeroSliceGuard) {
+  // Only constant padding value supported.
+  Value padValue = padOp.getConstantPaddingValue();
+  if (!padValue)
+    return nullptr;

-    // Compute new offsets, lengths, low padding, high padding.
-    SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
-    SmallVector<Value> newLows, newHighs;
-    SmallVector<int64_t> staticNewLows, staticNewHighs;
-    // Set to true if the original data source is not read at all.
-    bool hasZeroLen = false;
-    // Same as hasZeroLen, but for dynamic dimension sizes. This condition
-    // is true if the original data source turns out to be unused at runtime.
-    Value dynHasZeroLenCond;
+  // Helper variables and functions for various arithmetic operations. These
+  // are used extensively for computing new offset/length and padding values.
+  Location loc = padOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(b.getContext(), dim0, dim1);
+  // Add two integers.
+  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
+  auto add = [&](Value v1, Value v2) {
+    return b.createOrFold<AffineApplyOp>(loc, addMap, ValueRange{v1, v2});
+  };
+  // Subtract two integers.
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](Value v1, Value v2) {
+    return b.createOrFold<AffineApplyOp>(loc, subMap, ValueRange{v1, v2});
+  };
+  // Take the minimum of two integers.
+  auto idMap = AffineMap::getMultiDimIdentityMap(2, b.getContext());
+  auto min = [&](Value v1, Value v2) {
+    return b.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
+  };
+  // Take the maximum of two integers.
+  auto max = [&](Value v1, Value v2) {
+    return b.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
+  };
+  // Zero index-typed integer.
+  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);

-    int64_t rank = padOp.getSourceType().getRank();
-    for (unsigned dim = 0; dim < rank; ++dim) {
-      auto low =
-          getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedLowPad()[dim]);
-      bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
-      auto high =
-          getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedHighPad()[dim]);
-      bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
-      auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]);
-      auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]);
-      auto srcSize = b.createOrFold<tensor::DimOp>(loc, padOp.source(), dim);
+  // Helper function for filling static/dynamic low/high padding indices
+  // vectors of PadOp.
+  auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
+                         SmallVector<int64_t> &staticIndices) {
+    if (auto constInt = getConstantIntValue(val)) {
+      staticIndices.push_back(*constInt);
+    } else {
+      staticIndices.push_back(ShapedType::kDynamicSize);
+      dynIndices.push_back(val);
+    }
+  };
-      // The new amount of low padding is `low - offset`. Except for the case
-      // where none of the low padding is read. In that case, the new amount of
-      // low padding is zero.
-      //
-      // Optimization: If low = 0, then newLow = 0.
-      Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
-      appendIndex(newLow, newLows, staticNewLows);
+  // Compute new offsets, lengths, low padding, high padding.
+  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+  SmallVector<Value> newLows, newHighs;
+  SmallVector<int64_t> staticNewLows, staticNewHighs;
+  // Set to true if the original data source is not read at all.
+  bool hasZeroLen = false;
+  // Same as hasZeroLen, but for dynamic dimension sizes. This condition
+  // is true if the original data source turns out to be unused at runtime.
+  Value dynHasZeroLenCond;

-      // Start reading the data from position `offset - low`. Since the
-      // original read may have started in the low padding zone, this value
-      // could be negative. Therefore, start reading from:
-      //
-      //   max(offset - low, 0)
-      //
-      // The original read could also have started in the high padding zone.
-      // In that case, set the offset to the end of source tensor. The new
-      // ExtractSliceOp length will be zero in that case. (Effectively reading
-      // no data from the source.)
-      //
-      // Optimization: If low = 0, then the formula can be simplified.
-      Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
-                                  : min(offset, srcSize);
-      newOffsets.push_back(getAsOpFoldResult(newOffset));
+  int64_t rank = padOp.getSourceType().getRank();
+  for (unsigned dim = 0; dim < rank; ++dim) {
+    auto low =
+        getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedLowPad()[dim]);
+    bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
+    auto high =
+        getValueOrCreateConstantIndexOp(b, loc, padOp.getMixedHighPad()[dim]);
+    bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
+    auto offset = getValueOrCreateConstantIndexOp(b, loc, offsets[dim]);
+    auto length = getValueOrCreateConstantIndexOp(b, loc, sizes[dim]);
+    auto srcSize = b.createOrFold<tensor::DimOp>(loc, padOp.source(), dim);

-      // The original ExtractSliceOp was reading until position `offset +
-      // length`. Therefore, the corresponding position within the source
-      // tensor is:
-      //
-      //   offset + length - low
-      //
-      // In case the original ExtractSliceOp stopped reading within the low
-      // padding zone, this value can be negative. In that case, the end
-      // position of the read should be zero. (Similar to newOffset.)
-      //
-      // The original read could also have stopped in the high padding zone.
-      // In that case, the end position of the read should be the end of the
-      // source tensor. (Similar to newOffset.)
-      //
-      //   endLoc = min(max(offset - low + length, 0), srcSize)
-      //
-      // The new ExtractSliceOp length is `endLoc - newOffset`.
-      //
-      // Optimization: If low = 0, then the formula can be simplified.
-      Value endLoc =
-          hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
-                    : min(add(offset, length), srcSize);
-      Value newLength = sub(endLoc, newOffset);
-      newLengths.push_back(getAsOpFoldResult(newLength));
+    // The new amount of low padding is `low - offset`. Except for the case
+    // where none of the low padding is read. In that case, the new amount of
+    // low padding is zero.
+    //
+    // Optimization: If low = 0, then newLow = 0.
+    Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
+    appendIndex(newLow, newLows, staticNewLows);

-      // Check if newLength is zero. In that case, no SubTensorOp should be
-      // executed.
-      if (auto newLengthInt = getConstantIntValue(newLength)) {
-        hasZeroLen |= *newLengthInt == 0;
-      } else {
-        Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
-                                              newLength, zero);
-        dynHasZeroLenCond =
-            dynHasZeroLenCond
-                ? b.create<arith::AndIOp>(loc, check, dynHasZeroLenCond)
-                : check;
-      }
+    // Start reading the data from position `offset - low`. Since the original
+    // read may have started in the low padding zone, this value could be
+    // negative. Therefore, start reading from:
+    //
+    //   max(offset - low, 0)
+    //
+    // The original read could also have started in the high padding zone.
+    // In that case, set the offset to the end of source tensor. The new
+    // ExtractSliceOp length will be zero in that case. (Effectively reading
+    // no data from the source.)
+    //
+    // Optimization: If low = 0, then the formula can be simplified.
+    Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
+                                : min(offset, srcSize);
+    newOffsets.push_back(getAsOpFoldResult(newOffset));

-      // The amount of high padding is simply the number of elements remaining,
-      // so that the result has the same length as the original ExtractSliceOp.
-      // As an optimization, if the original high padding is zero, then the new
-      // high padding must also be zero.
-      Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
-      appendIndex(newHigh, newHighs, staticNewHighs);
+    // The original ExtractSliceOp was reading until position `offset +
+    // length`. Therefore, the corresponding position within the source tensor
+    // is:
+    //
+    //   offset + length - low
+    //
+    // In case the original ExtractSliceOp stopped reading within the low
+    // padding zone, this value can be negative. In that case, the end
+    // position of the read should be zero. (Similar to newOffset.)
+    //
+    // The original read could also have stopped in the high padding zone.
+    // In that case, the end position of the read should be the end of the
+    // source tensor. (Similar to newOffset.)
+    //
+    //   endLoc = min(max(offset - low + length, 0), srcSize)
+    //
+    // The new ExtractSliceOp length is `endLoc - newOffset`.
+    //
+    // Optimization: If low = 0, then the formula can be simplified.
+    Value endLoc = hasLowPad
+                       ? min(max(add(sub(offset, low), length), zero), srcSize)
+                       : min(add(offset, length), srcSize);
+    Value newLength = sub(endLoc, newOffset);
+    newLengths.push_back(getAsOpFoldResult(newLength));

-      // Only unit stride supported.
-      newStrides.push_back(b.getIndexAttr(1));
-    }
+    // Check if newLength is zero. In that case, no SubTensorOp should be
+    // executed.
+    if (auto newLengthInt = getConstantIntValue(newLength)) {
+      hasZeroLen |= *newLengthInt == 0;
+    } else {
+      Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                            newLength, zero);
+      dynHasZeroLenCond =
+          dynHasZeroLenCond
+              ? b.create<arith::AndIOp>(loc, check, dynHasZeroLenCond)
+              : check;
+    }

-    // The shape of the result can be obtained from the sizes passed in.
-    SmallVector<Value> dynDims;
-    SmallVector<int64_t> shape;
-    dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize);
-    RankedTensorType resultType =
-        RankedTensorType::get(shape, padOp.getResultType().getElementType());
+    // The amount of high padding is simply the number of elements remaining,
+    // so that the result has the same length as the original ExtractSliceOp.
+    // As an optimization, if the original high padding is zero, then the new
+    // high padding must also be zero.
+    Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
+    appendIndex(newHigh, newHighs, staticNewHighs);
+
+    // Only unit stride supported.
+    newStrides.push_back(b.getIndexAttr(1));
+  }
-    // Insert cast to ensure that types match. (May be folded away.)
-    auto castResult = [&](Value val) -> Operation * {
-      auto castOp = b.create<tensor::CastOp>(loc, resultType, val);
-      return castOp;
-    };
+  // The shape of the result can be obtained from the sizes passed in.
+  SmallVector<Value> dynDims;
+  SmallVector<int64_t> shape;
+  dispatchIndexOpFoldResults(sizes, dynDims, shape, ShapedType::kDynamicSize);
+  RankedTensorType resultType =
+      RankedTensorType::get(shape, padOp.getResultType().getElementType());

-    // In cases where the original data source is unused: Emit a GenerateOp and
-    // do not generate a SliceOp. (The result shape of the SliceOp would
-    // have a dimension of size 0, the semantics of which is unclear.)
-    auto createGenerateOp = [&]() {
-      // Create GenerateOp.
-      auto generateOp = b.create<tensor::GenerateOp>(
-          loc, resultType, dynDims,
-          [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
-            builder.create<tensor::YieldOp>(gLoc, padValue);
-          });
-      return castResult(generateOp);
-    };
+  // Insert cast to ensure that types match. (May be folded away.)
+  auto castResult = [&](Value val) -> Operation * {
+    return b.create<tensor::CastOp>(loc, resultType, val);
+  };

-    // Emit a SliceOp and a PadOp. Should not be used in cases where
-    // the result shape of the new SliceOp has a zero dimension.
-    auto createPadTensorOfSubTensor = [&]() {
-      // Create pad_tensor(subtensor(x)).
-      auto newSliceOp = b.create<tensor::ExtractSliceOp>(
-          loc, padOp.source(), newOffsets, newLengths, newStrides);
-      auto newPadOp = b.create<PadOp>(loc, newSliceOp, staticNewLows,
-                                      staticNewHighs, newLows, newHighs);
+  // In cases where the original data source is unused: Emit a GenerateOp and
+  // do not generate a SliceOp. (The result shape of the SliceOp would
+  // have a dimension of size 0, the semantics of which is unclear.)
+  auto createGenerateOp = [&]() {
+    // Create GenerateOp.
+    auto generateOp = b.create<tensor::GenerateOp>(
+        loc, resultType, dynDims,
+        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
+          builder.create<tensor::YieldOp>(gLoc, padValue);
+        });
+    return castResult(generateOp);
+  };

-      // Copy region to new PadOp.
-      BlockAndValueMapping bvm;
-      padOp.region().cloneInto(&newPadOp.getRegion(), bvm);
+  // Emit a SliceOp and a PadOp. Should not be used in cases where
+  // the result shape of the new SliceOp has a zero dimension.
+  auto createPadOfExtractSlice = [&]() {
+    // Create pad(extract_slice(x)).
+    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
+        loc, padOp.source(), newOffsets, newLengths, newStrides);
+    auto newPadOp = b.create<PadOp>(loc, newSliceOp, staticNewLows,
+                                    staticNewHighs, newLows, newHighs);

-      // Cast result and return.
-      return castResult(newPadOp);
-    };
+    // Copy region to new PadOp.
+    BlockAndValueMapping bvm;
+    padOp.region().cloneInto(&newPadOp.getRegion(), bvm);

-    // Rewrite subtensor(pad_tensor(x)) into a GenerateOp if it is statically
-    // known that the original data source x is not used.
-    if (hasZeroLen)
-      return {createGenerateOp()};
+    // Cast result and return.
+    return castResult(newPadOp);
+  };

-    // If there are dynamic dimensions: Generate an scf.if check to avoid
-    // creating SliceOps with result dimensions of size 0 at runtime.
-    if (dynHasZeroLenCond) {
-      auto result = b.create<scf::IfOp>(
-          loc, resultType, dynHasZeroLenCond,
-          /*thenBuilder=*/
-          [&](OpBuilder &b, Location loc) {
-            b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
-          },
-          /*elseBuilder=*/
-          [&](OpBuilder &b, Location loc) {
-            b.create<scf::YieldOp>(loc,
-                                   createPadTensorOfSubTensor()->getResult(0));
-          });
-      return {result};
-    }
-    return {createPadTensorOfSubTensor()};
-  }
-};
+  // Rewrite extract_slice(pad(x)) into a GenerateOp if it is statically known
+  // that the original data source x is not used.
+  if (hasZeroLen)
+    return createGenerateOp();

-} // namespace
+  // If there are dynamic dimensions: Generate an scf.if check to avoid
+  // creating SliceOps with result dimensions of size 0 at runtime.
+  if (generateZeroSliceGuard && dynHasZeroLenCond) {
+    auto result = b.create<scf::IfOp>(
+        loc, resultType, dynHasZeroLenCond,
+        /*thenBuilder=*/
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
+        },
+        /*elseBuilder=*/
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(loc, createPadOfExtractSlice()->getResult(0));
+        });
+    return result;
+  }
+  return createPadOfExtractSlice();
+}

 void mlir::tensor::registerTilingOpInterfaceExternalModels(
     DialectRegistry &registry) {
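For dynamically sized cases the helper cannot statically rule out a zero-sized slice, so with `generateZeroSliceGuard` enabled it branches at runtime. A rough sketch of the guarded IR shape (names and types illustrative only; the real code also inserts `tensor.cast` ops so both branches yield `resultType`):

```
%c0 = arith.constant 0 : index
%is_empty = arith.cmpi eq, %new_len, %c0 : index
%res = scf.if %is_empty -> (tensor<?xf32>) {
  // The source contributes nothing; materialize the result from the
  // constant padding value alone.
  %gen = tensor.generate %size {
  ^bb0(%i: index):
    tensor.yield %cst : f32
  } : tensor<?xf32>
  scf.yield %gen : tensor<?xf32>
} else {
  %s = tensor.extract_slice %source[%new_off] [%new_len] [1]
      : tensor<?xf32> to tensor<?xf32>
  %p = tensor.pad %s low[%new_low] high[%new_high] {
  ^bb0(%i: index):
    tensor.yield %cst : f32
  } : tensor<?xf32> to tensor<?xf32>
  scf.yield %p : tensor<?xf32>
}
```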
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7084,6 +7084,7 @@
         ":StandardOpsTransforms",
         ":Support",
         ":TensorDialect",
+        ":TensorTilingInterfaceImpl",
        ":TensorTransforms",
        ":TensorUtils",
        ":TransformUtils",
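Note that the TilingInterface path still requires the external models to be attached to the context. A typical registration sketch (tool setup assumed, not part of this patch):

```
DialectRegistry registry;
registry.insert<tensor::TensorDialect, scf::SCFDialect,
                arith::ArithmeticDialect>();
tensor::registerTilingOpInterfaceExternalModels(registry);
MLIRContext context(registry);
```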