diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1076,6 +1076,15 @@
     const FrozenRewritePatternSet &stage2Patterns,
     function_ref<LogicalResult(FuncOp)> stage3Lambda = nullptr);
 
+/// Rewrite subtensor(pad_tensor(x)) into pad_tensor(subtensor(x)).
+struct SubTensorOfPadTensorSwapPattern
+    : public OpRewritePattern<SubTensorOp> {
+  using OpRewritePattern<SubTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -700,3 +700,225 @@
 
   return success();
 }
+
+/// Given an OpFoldResult, return a Value. If the OpFoldResult is an Attribute,
+/// it must be of type Integer.
+static Value asValue(OpBuilder &builder, Location loc, OpFoldResult ofr) {
+  if (auto val = ofr.dyn_cast<Value>())
+    return val;
+  auto intVal = getConstantIntValue(ofr);
+  assert(intVal && "expected Value or IntegerAttr");
+  return builder.create<ConstantIndexOp>(loc, *intVal);
+}
+
+/// Given a Value, try to extract a constant index-typed integer as an
+/// Attribute. If this fails, return the original Value.
+static OpFoldResult asOpFoldResult(OpBuilder &builder, Value val) {
+  if (auto constInt = getConstantIntValue(val))
+    return builder.getIndexAttr(*constInt);
+  return val;
+}
+
+LogicalResult SubTensorOfPadTensorSwapPattern::matchAndRewrite(
+    SubTensorOp subTensorOp, PatternRewriter &rewriter) const {
+  auto padOp = subTensorOp.source().getDefiningOp<PadTensorOp>();
+  if (!padOp)
+    return failure();
+  // Only unit stride supported.
+  if (!subTensorOp.hasUnitStride())
+    return failure();
+  // Only constant padding value supported.
+  Value padValue = padOp.getConstantPaddingValue();
+  if (!padValue)
+    return failure();
+  // Only zero low padding supported at the moment.
+  if (!padOp.hasZeroLowPad())
+    return failure();
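+
+  // Illustration of the rewrite, with values taken from the
+  // static_mixed_data_high_pad test case below:
+  //
+  //   %0 = linalg.pad_tensor %src low[0, 0] high[7, 8]
+  //       : tensor<4x5xf32> to tensor<11x13xf32>
+  //   %1 = subtensor %0[2, 4] [3, 4] [1, 1]
+  //       : tensor<11x13xf32> to tensor<3x4xf32>
+  //
+  // is rewritten into:
+  //
+  //   %s = subtensor %src[2, 4] [2, 1] [1, 1]
+  //       : tensor<4x5xf32> to tensor<2x1xf32>
+  //   %1 = linalg.pad_tensor %s low[0, 0] high[1, 3]
+  //       : tensor<2x1xf32> to tensor<3x4xf32>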
+
+  // Helper variables and functions for various arithmetic operations. These
+  // are used extensively for computing new offset/length and padding values.
+  Location loc = subTensorOp.getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  // Add two integers.
+  auto addMap = AffineMap::get(2, 0, {dim0 + dim1});
+  auto add = [&](Value v1, Value v2) {
+    return rewriter.createOrFold<AffineApplyOp>(loc, addMap,
+                                                ValueRange{v1, v2});
+  };
+  // Subtract two integers.
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](Value v1, Value v2) {
+    return rewriter.createOrFold<AffineApplyOp>(loc, subMap,
+                                                ValueRange{v1, v2});
+  };
+  // Take the minimum of two integers.
+  auto idMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
+  auto min = [&](Value v1, Value v2) {
+    return rewriter.createOrFold<AffineMinOp>(loc, idMap, ValueRange{v1, v2});
+  };
+  // Take the maximum of two integers.
+  auto max = [&](Value v1, Value v2) {
+    return rewriter.createOrFold<AffineMaxOp>(loc, idMap, ValueRange{v1, v2});
+  };
+  // Zero index-typed integer.
+  auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+
+  // Helper function for filling the static/dynamic low/high padding index
+  // vectors of PadTensorOp.
+  auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
+                         SmallVector<int64_t> &staticIndices) {
+    if (auto constInt = getConstantIntValue(val)) {
+      staticIndices.push_back(*constInt);
+    } else {
+      staticIndices.push_back(ShapedType::kDynamicSize);
+      dynIndices.push_back(val);
+    }
+  };
+
+  // Compute new offsets, lengths, low padding, high padding.
+  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+  SmallVector<Value> newLows, newHighs;
+  SmallVector<int64_t> staticNewLows, staticNewHighs;
+  // Set to true if the original data source is not read at all.
+  bool hasZeroLen = false;
+  // Same as hasZeroLen, but for dynamic dimension sizes. This condition is
+  // true if the original data source turns out to be unused at runtime.
+  Value dynHasZeroLenCond;
+
+  int64_t rank = padOp.getSourceType().getRank();
+  for (unsigned dim = 0; dim < rank; ++dim) {
+    auto offset = asValue(rewriter, loc, subTensorOp.getMixedOffsets()[dim]);
+    auto length = asValue(rewriter, loc, subTensorOp.getMixedSizes()[dim]);
+    auto srcSize =
+        rewriter.createOrFold<memref::DimOp>(loc, padOp.source(), dim);
+
+    // The existing low padding is zero, so the new low padding is also zero.
+    Value newLow = zero;
+    appendIndex(newLow, newLows, staticNewLows);
+
+    // There is no low padding, so the offset remains unchanged, except when
+    // the SubTensorOp starts reading from a position within the high padding.
+    // In that case, set the offset to the end of the source tensor. The new
+    // SubTensorOp length will be zero in that case. (Effectively reading no
+    // data from the source.)
+    Value newOffset = min(offset, srcSize);
+    newOffsets.push_back(asOpFoldResult(rewriter, newOffset));
+
+    // The new SubTensorOp starts reading at `newOffset` and reads until
+    // `offset + length`. This position may be outside of the source (i.e.,
+    // within the high padding). In that case, read only until the end of the
+    // source. In mathematical terms:
+    //
+    //   endLoc = min(offset + length, srcSize)
+    //
+    // The new SubTensorOp length is `endLoc - newOffset`.
+    Value newLength = sub(min(add(offset, length), srcSize), newOffset);
+    newLengths.push_back(asOpFoldResult(rewriter, newLength));
+    if (auto newLengthInt = getConstantIntValue(newLength)) {
+      hasZeroLen |= *newLengthInt == 0;
+    } else {
+      Value check =
+          rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, newLength, zero);
+      dynHasZeroLenCond =
+          dynHasZeroLenCond
+              ? rewriter.create<AndOp>(loc, check, dynHasZeroLenCond)
+              : check;
+    }
+
+    // The number of elements available to read from the source (starting
+    // from the new offset) is `maxRead = srcSize - newOffset`. The original
+    // SubTensorOp may have read a larger number of elements
+    // `length > maxRead`. In that case, the missing number of elements
+    // `length - maxRead` must be padded. (If `maxRead > length`, more than
+    // enough data is available to read and no high padding is needed.)
+    Value newHigh = max(zero, add(sub(newOffset, srcSize), length));
+    appendIndex(newHigh, newHighs, staticNewHighs);
+
+    // Only unit stride supported.
+    newStrides.push_back(rewriter.getIndexAttr(1));
+  }
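+
+  // Worked example for a single dimension (illustrative): with srcSize = 5,
+  // offset = 4 and length = 4 (a read that extends into the high padding),
+  // the formulas above yield newOffset = min(4, 5) = 4,
+  // newLength = min(4 + 4, 5) - 4 = 1 and newHigh = max(0, 4 - 5 + 4) = 3:
+  // one element is read from the source and three elements are re-padded.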
+
+  // Insert a cast to ensure that the types match. (May be folded away.)
+  auto castResult = [&](Value val) -> Value {
+    auto castOp =
+        rewriter.create<tensor::CastOp>(loc, subTensorOp.getType(), val);
+    return castOp;
+  };
+
+  // If the original data source is unused, emit a GenerateOp and no
+  // SubTensorOp. (The result shape of the SubTensorOp would have a dimension
+  // of size 0, the semantics of which is unclear.)
+  auto createGenerateOp = [&]() {
+    // The shape of the GenerateOp is the same as the existing SubTensorOp.
+    RankedTensorType type = subTensorOp.getType();
+    SmallVector<Value> dynDims;
+    for (unsigned i = 0; i < type.getRank(); ++i) {
+      if (type.isDynamicDim(i))
+        dynDims.push_back(
+            asValue(rewriter, loc, subTensorOp.getMixedSizes()[i]));
+    }
+
+    // Create the GenerateOp.
+    auto generateOp = rewriter.create<tensor::GenerateOp>(loc, type, dynDims);
+
+    // Copy the region to the new op.
+    BlockAndValueMapping bvm;
+    padOp.region().cloneInto(&generateOp.getRegion(), bvm);
+    // Rewrite linalg::YieldOp to tensor::YieldOp.
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      auto yieldOp = dyn_cast<linalg::YieldOp>(
+          generateOp.getRegion().front().getTerminator());
+      assert(yieldOp && "malformed PadTensorOp: expected YieldOp terminator");
+      assert(yieldOp.values().size() == 1);
+      rewriter.setInsertionPoint(yieldOp);
+      rewriter.replaceOpWithNewOp<tensor::YieldOp>(yieldOp,
+                                                   yieldOp.values()[0]);
+    }
+
+    return castResult(generateOp);
+  };
+
+  // Emit a SubTensorOp and a PadTensorOp. Should not be used in cases where
+  // the result shape of the new SubTensorOp has a zero dimension.
+  auto createPadTensorOfSubTensor = [&]() {
+    // Create pad_tensor(subtensor(x)).
+    auto newSubTensorOp = rewriter.create<SubTensorOp>(
+        loc, padOp.source(), newOffsets, newLengths, newStrides);
+    auto newPadTensorOp = rewriter.create<PadTensorOp>(
+        loc, newSubTensorOp, staticNewLows, staticNewHighs, newLows, newHighs);
+
+    // Copy the region to the new PadTensorOp.
+    BlockAndValueMapping bvm;
+    padOp.region().cloneInto(&newPadTensorOp.getRegion(), bvm);
+
+    // Cast the result and return.
+    return castResult(newPadTensorOp);
+  };
+
+  // Rewrite subtensor(pad_tensor(x)) into a GenerateOp if it is statically
+  // known that the original data source x is not used.
+  if (hasZeroLen) {
+    rewriter.replaceOp(subTensorOp, createGenerateOp());
+    return success();
+  }
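+
+  // For dynamic dimension sizes, the rewrite below guards both alternatives
+  // with an scf.if. Illustrative structure of the generated IR (cf. the
+  // dynamic_high_pad test case):
+  //
+  //   %r = scf.if %source_unused -> (tensor<3x4xf32>) {
+  //     %g = tensor.generate ...   // yields only the padding value
+  //     scf.yield %g : tensor<3x4xf32>
+  //   } else {
+  //     %s = subtensor %src[...] [...] [1, 1]
+  //     %p = linalg.pad_tensor %s low[0, 0] high[...] ...
+  //     %c = tensor.cast %p : tensor<?x4xf32> to tensor<3x4xf32>
+  //     scf.yield %c : tensor<3x4xf32>
+  //   }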
+
+  // If there are dynamic dimensions: Generate an scf.if check to avoid
+  // creating SubTensorOps with result dimensions of size 0 at runtime.
+  if (dynHasZeroLenCond) {
+    auto result = rewriter.create<scf::IfOp>(
+        loc, subTensorOp.getType(), dynHasZeroLenCond,
+        /*thenBuilder=*/
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(loc, createGenerateOp());
+        },
+        /*elseBuilder=*/
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(loc, createPadTensorOfSubTensor());
+        });
+    rewriter.replaceOp(subTensorOp, result.getResult(0));
+    return success();
+  }
+
+  // All shapes are static and the data source is actually used. Rewrite into
+  // pad_tensor(subtensor(x)).
+  rewriter.replaceOp(subTensorOp, createPadTensorOfSubTensor());
+  return success();
+}
diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-swap-subtensor-padtensor -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @static_data_only(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<4x5xf32>
+// CHECK: %[[RESULT:.*]] = subtensor %[[ARG0]][1, 2] [2, 1] [1, 1] : tensor<4x5xf32> to tensor<2x1xf32>
+// CHECK: return %[[RESULT]]
+func @static_data_only(%arg0 : tensor<4x5xf32>, %pad : f32)
+    -> tensor<2x1xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+  } : tensor<4x5xf32> to tensor<11x13xf32>
+  %1 = subtensor %0[1, 2] [2, 1] [1, 1] : tensor<11x13xf32> to tensor<2x1xf32>
+  return %1 : tensor<2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @static_high_pad_only
+// CHECK-SAME: %[[ARG0:.*]]: tensor<4x5xf32>, %[[PAD:.*]]: f32
+// CHECK-NOT: linalg.pad_tensor
+// CHECK-NOT: subtensor
+// CHECK: %[[RESULT:.*]] = tensor.generate
+// CHECK: tensor.yield %[[PAD]]
+// CHECK: return %[[RESULT]] : tensor<2x4xf32>
+func @static_high_pad_only(%arg0 : tensor<4x5xf32>, %pad : f32)
+    -> tensor<2x4xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+  } : tensor<4x5xf32> to tensor<11x13xf32>
+  %1 = subtensor %0[4, 5] [2, 4] [1, 1] : tensor<11x13xf32> to tensor<2x4xf32>
+  return %1 : tensor<2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @static_mixed_data_high_pad
+// CHECK-SAME: %[[ARG0:.*]]: tensor<4x5xf32>, %[[PAD:.*]]: f32
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: %[[SUBTENSOR:.*]] = subtensor %[[ARG0]][2, 4] [2, 1] [1, 1] : tensor<4x5xf32> to tensor<2x1xf32>
+// CHECK: %[[RESULT:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[0, 0] high[1, 3]
+// CHECK: linalg.yield %[[PAD]]
+// CHECK: return %[[RESULT]] : tensor<3x4xf32>
+func @static_mixed_data_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
+    -> tensor<3x4xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+  } : tensor<4x5xf32> to tensor<11x13xf32>
+  %1 = subtensor %0[2, 4] [3, 4] [1, 1] : tensor<11x13xf32> to tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_high_pad
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x5xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[RESULT:.*]] = scf.if %{{.*}} -> (tensor<3x4xf32>) {
+// CHECK: %[[GEN:.*]] = tensor.generate
+// CHECK: scf.yield %[[GEN]]
+// CHECK: } else {
+// CHECK: %[[SUBTENSOR:.*]] = subtensor %[[ARG0]][%{{.*}}, 4] [%{{.*}}, 1] [1, 1] : tensor<?x5xf32> to tensor<?x1xf32>
+// CHECK: %[[PADTENSOR:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[0, 0] high[%{{.*}}, 3]
+// CHECK: %[[CAST:.*]] = tensor.cast %[[PADTENSOR]] : tensor<?x4xf32> to tensor<3x4xf32>
+// CHECK: scf.yield %[[CAST]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+func @dynamic_high_pad(%arg0 : tensor<?x5xf32>, %h1: index, %pad : f32) -> tensor<3x4xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[%h1, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+  } : tensor<?x5xf32> to tensor<?x13xf32>
+  %1 = subtensor %0[2, 4] [3, 4] [1, 1] : tensor<?x13xf32> to tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+}
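+
+// -----
+
+// Illustrative extra case (an assumption, not part of the cases above): low
+// padding is not supported yet, so the pattern is expected to leave this IR
+// unchanged.
+// CHECK-LABEL: @static_low_pad_not_supported
+// CHECK: linalg.pad_tensor
+// CHECK: subtensor
+func @static_low_pad_not_supported(%arg0 : tensor<4x5xf32>, %pad : f32)
+    -> tensor<2x1xf32> {
+  %0 = linalg.pad_tensor %arg0 low[1, 1] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+  } : tensor<4x5xf32> to tensor<12x14xf32>
+  %1 = subtensor %0[1, 2] [2, 1] [1, 1] : tensor<12x14xf32> to tensor<2x1xf32>
+  return %1 : tensor<2x1xf32>
+}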
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -97,6 +97,11 @@
       *this, "test-transform-pad-tensor",
       llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
       llvm::cl::init(false)};
+  Option<bool> testSwapSubTensorPadTensor{
+      *this, "test-swap-subtensor-padtensor",
+      llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
+                     "pad_tensor(subtensor)"),
+      llvm::cl::init(false)};
   ListOption<int64_t> tileSizesForPadding{
       *this, "tile-sizes-for-padding",
       llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore,
@@ -524,6 +529,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applySubTensorOfPadTensorSwapPattern(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  patterns.add<SubTensorOfPadTensorSwapPattern>(funcOp.getContext());
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
   RewritePatternSet foldPattern(funcOp.getContext());
   foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
@@ -602,6 +613,8 @@
     return applyLinalgToVectorPatterns(getFunction());
   if (testTransformPadTensor)
     return applyPadTensorToGenericPatterns(getFunction());
+  if (testSwapSubTensorPadTensor)
+    return applySubTensorOfPadTensorSwapPattern(getFunction());
   if (testAffineMinSCFCanonicalizationPatterns)
     return applyAffineMinSCFCanonicalizationPatterns(getFunction());
   if (testTileAndPadPattern)