diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -866,6 +866,11 @@
   int64_t rank = padOp.getSourceType().getRank();
   for (unsigned dim = 0; dim < rank; ++dim) {
     auto low = asValue(rewriter, loc, padOp.getMixedLowPad()[dim]);
+    // A pad amount that is not a compile-time constant is conservatively
+    // treated as non-zero, so the general formulas below are used for it.
+    bool hasLowPad = getConstantIntValue(low) != static_cast<int64_t>(0);
+    auto high = asValue(rewriter, loc, padOp.getMixedHighPad()[dim]);
+    bool hasHighPad = getConstantIntValue(high) != static_cast<int64_t>(0);
     auto offset = asValue(rewriter, loc, sliceOp.getMixedOffsets()[dim]);
     auto length = asValue(rewriter, loc, sliceOp.getMixedSizes()[dim]);
     auto srcSize =
@@ -874,7 +879,9 @@
     // The new amount of low padding is `low - offset`. Except for the case
     // where none of the low padding is read. In that case, the new amount of
     // low padding is zero.
-    Value newLow = max(zero, sub(low, offset));
+    //
+    // Optimization: If low = 0, then newLow = 0.
+    Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
     appendIndex(newLow, newLows, staticNewLows);
 
     // Start reading the data from position `offset - low`. Since the original
@@ -887,7 +894,10 @@
     // In that case, set the offset to the end of source tensor. The new
     // ExtractSliceOp length will be zero in that case. (Effectively reading no
     // data from the source.)
-    Value newOffset = min(max(sub(offset, low), zero), srcSize);
+    //
+    // Optimization: If low = 0, then the formula can be simplified.
+    Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
+                                : min(offset, srcSize);
     newOffsets.push_back(getAsOpFoldResult(newOffset));
 
     // The original ExtractSliceOp was reading until position `offset + length`.
@@ -906,7 +916,11 @@
     //   endLoc = min(max(offset - low + length, 0), srcSize)
     //
     // The new ExtractSliceOp length is `endLoc - newOffset`.
-    Value endLoc = min(max(add(sub(offset, low), length), zero), srcSize);
+    //
+    // Optimization: If low = 0, then the formula can be simplified.
+    Value endLoc = hasLowPad
+                       ? min(max(add(sub(offset, low), length), zero), srcSize)
+                       : min(add(offset, length), srcSize);
     Value newLength = sub(endLoc, newOffset);
     newLengths.push_back(getAsOpFoldResult(newLength));
 
@@ -925,7 +939,9 @@
 
     // The amount of high padding is simply the number of elements remaining,
     // so that the result has the same length as the original ExtractSliceOp.
-    Value newHigh = sub(sub(length, newLength), newLow);
+    // As an optimization, if the original high padding is zero, then the new
+    // high padding must also be zero.
+    Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
     appendIndex(newHigh, newHighs, staticNewHighs);
 
     // Only unit stride supported.
diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
--- a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -153,3 +153,42 @@
   return %1 : tensor<3x4xf32>
 }
 
+// -----
+
+// CHECK-LABEL: @dynamic_zero_low_padding
+//       CHECK:   scf.if
+//       CHECK:     tensor.generate
+//       CHECK:   else
+//       CHECK:     %[[SLICE:.*]] = tensor.extract_slice
+//       CHECK:     linalg.pad_tensor %[[SLICE]] low[0, 0]
+func @dynamic_zero_low_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
+                               %o1 : index, %o2 : index,
+                               %s1 : index, %s2 : index)
+    -> tensor<?x?xf32> {
+  %0 = linalg.pad_tensor %arg0 low[0, 0] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[%o1, %o2] [%s1, %s2] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dynamic_zero_high_padding
+//       CHECK:   scf.if
+//       CHECK:     tensor.generate
+//       CHECK:   else
+//       CHECK:     %[[SLICE:.*]] = tensor.extract_slice
+//       CHECK:     linalg.pad_tensor %[[SLICE]] low[%{{.*}}, %{{.*}}] high[0, 0]
+func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
+                                %o1 : index, %o2 : index,
+                                %s1 : index, %s2 : index)
+    -> tensor<?x?xf32> {
+  %0 = linalg.pad_tensor %arg0 low[7, 8] high[0, 0] {
+    ^bb0(%arg1: index, %arg2: index):
+      linalg.yield %pad : f32
+    } : tensor<?x?xf32> to tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[%o1, %o2] [%s1, %s2] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}