diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -678,7 +678,7 @@
     Tensor_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
     Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyTensor:$result)> {
-  
+
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrName() { return "reassociation"; }
     SmallVector<AffineMap, 4> getReassociationMaps();
@@ -982,6 +982,9 @@
         return getConstantIntValue(ofr) == static_cast<int64_t>(0);
       });
     }
+    /// Return the dimensions with a non-zero low or high padding. Dimensions
+    /// padded by a non-constant amount are conservatively included.
+    llvm::SmallBitVector getPaddedDims();
   }];
 
   let builders = [
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1858,6 +1858,18 @@
   result.addAttributes(attrs);
 }
 
+llvm::SmallBitVector PadOp::getPaddedDims() {
+  llvm::SmallBitVector paddedDims(getSourceType().getRank());
+  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
+    for (const auto &en : enumerate(paddingWidths))
+      if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
+        paddedDims.set(en.index());
+  };
+  extractPaddedDims(getMixedLowPad());
+  extractPaddedDims(getMixedHighPad());
+  return paddedDims;
+}
+
 namespace {
 // Folds tensor.pad when padding is static zeros and the attribute
 // doesn't request otherwise.
@@ -1940,13 +1952,175 @@
     return success();
   }
 };
+
+/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
+/// different dimensions. The pattern applies if the following preconditions
+/// hold:
+///   1) the tensor::ExtractSliceOps are not rank-reducing,
+///   2) the tensor::ExtractSliceOps have only unit-strides,
+///   3) the tensor::PadOps perform only high-padding,
+///   4) the tensor::PadOps have the same constant padding value,
+///   5) the tensor::PadOps do not have common padding dimensions,
+///   6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
+///      zero-offset for every dimension,
+///   7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
+///      padded source dimensions, which must be static.
+///
+/// Example:
+///
+/// ```mlir
+///   %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
+///       : tensor<64x64xf32> to tensor<?x64xf32>
+///   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
+///     } : tensor<?x64xf32> to tensor<8x64xf32>
+///   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
+///       : tensor<8x64xf32> to tensor<8x?xf32>
+///   %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
+///     } : tensor<8x?xf32> to tensor<8x4xf32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
+///       : tensor<64x64xf32> to tensor<?x?xf32>
+///   %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
+///     } : tensor<?x?xf32> to tensor<8x4xf32>
+/// ```
+struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    auto innerSliceOp = padOp.source().getDefiningOp<ExtractSliceOp>();
+    if (!innerSliceOp)
+      return failure();
+    auto outerPadOp = innerSliceOp.source().getDefiningOp<PadOp>();
+    if (!outerPadOp || outerPadOp.nofold())
+      return failure();
+    auto outerSliceOp = outerPadOp.source().getDefiningOp<ExtractSliceOp>();
+    if (!outerSliceOp)
+      return failure();
+
+    // 1) Fail if the chain is rank-reducing.
+    int64_t rank = padOp.getSourceType().getRank();
+    if (outerSliceOp.getSourceType().getRank() != rank) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "cannot fold rank-reducing chain");
+    }
+
+    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
+    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot fold non-unit stride ExtractSliceOps");
+    }
+
+    // 3) Fail if the tensor::PadOps have non-zero low padding.
+    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "cannot fold PadOps with low padding");
+    }
+
+    // 4) Fail if the tensor::PadOps padding values do not match.
+    Attribute innerAttr, outerAttr;
+    Value innerValue = padOp.getConstantPaddingValue();
+    Value outerValue = outerPadOp.getConstantPaddingValue();
+    if (!innerValue || !outerValue ||
+        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
+        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
+        innerAttr != outerAttr) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot fold PadOps with different padding values");
+    }
+
+    // 5) Fail if a dimension is padded by both tensor::PadOps.
+    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
+    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
+    if (innerDims.anyCommon(outerDims)) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot fold PadOps with common padding dimensions");
+    }
+
+    // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
+    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
+    // for every dimension, and use the offset of the other pair. Fail if no
+    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
+    // exists.
+    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
+    for (auto &en : enumerate(newOffsets)) {
+      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
+      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
+      if (!innerDims.test(en.index()) &&
+          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
+        en.value() = outerOffset;
+        continue;
+      }
+      if (!outerDims.test(en.index()) &&
+          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
+        en.value() = innerOffset;
+        continue;
+      }
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot find zero-offset and zero-padding pair");
+    }
+
+    // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size of
+    // the outer tensor::ExtractSliceOp for the dimensions padded by the outer
+    // tensor::PadOp and fail if the padded dimension has a dynamic size or if
+    // the size of the inner tensor::ExtractSliceOp does not match it.
+    // Otherwise, take the size of the inner tensor::ExtractSliceOp.
+    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
+    for (auto &en : enumerate(newSizes)) {
+      if (!outerDims.test(en.index()))
+        continue;
+      OpFoldResult extractedSize = innerSliceOp.getMixedSizes()[en.index()];
+      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
+      // The outer tensor::PadOp result may be dynamic along a padded
+      // dimension, e.g., when padding by a non-constant amount. Its size
+      // cannot be compared statically then, so do not fold.
+      if (ShapedType::isDynamic(sourceSize)) {
+        return rewriter.notifyMatchFailure(
+            padOp, "cannot fold since the previously padded tensor dimension "
+                   "has a dynamic size");
+      }
+      if (getConstantIntValue(extractedSize) != sourceSize) {
+        return rewriter.notifyMatchFailure(
+            padOp, "cannot fold since the ExtractSliceOp size does not match "
+                   "the size of the previously padded tensor dimension");
+      }
+      en.value() = outerSliceOp.getMixedSizes()[en.index()];
+    }
+
+    // Combine the high paddings of the two tensor::PadOps.
+    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
+    for (auto &en : enumerate(newHighPad)) {
+      if (innerDims.test(en.index()))
+        en.value() = padOp.getMixedHighPad()[en.index()];
+      if (outerDims.test(en.index()))
+        en.value() = outerPadOp.getMixedHighPad()[en.index()];
+    }
+
+    // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
+    // the two paddings in one step.
+    auto newSliceOp = rewriter.create<ExtractSliceOp>(
+        padOp.getLoc(), outerSliceOp.source(), newOffsets, newSizes,
+        innerSliceOp.getMixedStrides());
+    auto newPadOp = rewriter.create<PadOp>(
+        padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
+        padOp.getMixedLowPad(), newHighPad, padOp.nofold());
+    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
+                                newPadOp.getRegion().begin());
+    rewriter.replaceOp(padOp, newPadOp.getResult());
+    return success();
+  }
+};
+
 } // namespace
 
 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results
-      .add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast>(
-          context);
+  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
+              FoldOrthogonalPaddings>(context);
 }
 
 /// Return the padding value of the PadOp if it constant.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1198,6 +1198,95 @@
 
 // -----
 
+// CHECK-LABEL: func @fold_orthogonal_pad_chains(
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<64x64xf32>,
+//  CHECK-SAME:   %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
+func.func @fold_orthogonal_pad_chains(%arg0: tensor<64x64xf32>,
+                                      %sz0 : index, %sz1 : index,
+                                      %pw0 : index, %pw1 : index) -> tensor<8x4xf32> {
+  //       CHECK:   %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  //  CHECK-SAME:                     [16, 4] [%[[SZ0]], %[[SZ1]]]
+  //       CHECK:   %[[PAD:.*]] = tensor.pad %[[T0]] nofold
+  //  CHECK-SAME:                     high[%[[PW0]], %[[PW1]]]
+  //       CHECK:   return %[[PAD]]
+  %pad_value = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
+  %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<?x64xf32> to tensor<8x64xf32>
+  %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
+  %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<8x?xf32> to tensor<8x4xf32>
+  func.return %3 : tensor<8x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dont_fold_pad_chains(
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<64x64xf32>,
+//  CHECK-SAME:   %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
+func.func @dont_fold_pad_chains(%arg0: tensor<64x64xf32>,
+                                %sz0 : index, %sz1 : index,
+                                %pw0 : index, %pw1 : index) -> (tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>) {
+  //       CHECK:   %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  //       CHECK:   %[[T1:.*]] = tensor.pad %[[T0]]
+  %pad_value = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
+  %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<?x64xf32> to tensor<8x64xf32>
+
+  // Don't fold if the padding values are different.
+  //       CHECK:   %[[T2:.*]] = tensor.extract_slice %[[T1]]
+  //  CHECK-SAME:                     [0, 4] [8, %[[SZ1]]]
+  //       CHECK:   %[[PAD0:.*]] = tensor.pad %[[T2]]
+  %different_value = arith.constant 1.0 : f32
+  %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
+  %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %different_value : f32
+    } : tensor<8x?xf32> to tensor<8x4xf32>
+
+  // Don't fold if the pad ops have common padding dimensions.
+  //       CHECK:   %[[T3:.*]] = tensor.extract_slice %[[T1]]
+  //  CHECK-SAME:                     [4, 0] [%[[SZ1]], 64]
+  //       CHECK:   %[[PAD1:.*]] = tensor.pad %[[T3]]
+  %4 = tensor.extract_slice %1[4, 0] [%sz1, 64] [1, 1] : tensor<8x64xf32> to tensor<?x64xf32>
+  %5 = tensor.pad %4 nofold low[0, 0] high[%pw1, 0] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<?x64xf32> to tensor<4x64xf32>
+
+  // Don't fold if a padded source tensor dimension is accessed at an offset.
+  //       CHECK:   %[[T4:.*]] = tensor.extract_slice %[[T1]]
+  //  CHECK-SAME:                     [%[[SZ0]], 4] [8, %[[SZ1]]]
+  //       CHECK:   %[[PAD2:.*]] = tensor.pad %[[T4]]
+  %6 = tensor.extract_slice %1[%sz0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
+  %7 = tensor.pad %6 nofold low[0, 0] high[0, %pw1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<8x?xf32> to tensor<8x4xf32>
+
+  // Don't fold if a padded source tensor dimension is sliced.
+  //       CHECK:   %[[T5:.*]] = tensor.extract_slice %[[T1]]
+  //  CHECK-SAME:                     [0, 4] [6, %[[SZ1]]]
+  //       CHECK:   %[[PAD3:.*]] = tensor.pad %[[T5]]
+  %8 = tensor.extract_slice %1[0, 4] [6, %sz1] [1, 1] : tensor<8x64xf32> to tensor<6x?xf32>
+  %9 = tensor.pad %8 nofold low[0, 0] high[0, %pw1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<6x?xf32> to tensor<6x4xf32>
+
+  //       CHECK:   return %[[PAD0]], %[[PAD1]], %[[PAD2]], %[[PAD3]]
+  func.return %3, %5, %7, %9 : tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_collapse_shape_from_elements
 func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
   // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32>