diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -93,20 +93,42 @@
 ///
 /// Example:
 /// ```
-/// %0 = tensor.extract_slice %arg0[3, 4][3, 32][1, 1] : tensor<64x64xf32> to 
+/// %0 = tensor.extract_slice %arg0[3, 4][3, 32][1, 1] : tensor<64x64xf32> to
 ///     tensor<3x32xf32>
-/// %1 = tensor.extract_slice %0[0, 5][3, 4][1, 1] : tensor<3x32xf32> to 
+/// %1 = tensor.extract_slice %0[0, 5][3, 4][1, 1] : tensor<3x32xf32> to
 ///     tensor<3x4xf32>
 /// ```
 /// folds into:
 /// ```
-/// %1 = tensor.extract_slice %arg0[3, 9][3, 4][1, 1] : tensor<64x64xf32> to 
-///   tensor<3x4xf32>
+/// %1 = tensor.extract_slice %arg0[3, 9][3, 4][1, 1] : tensor<64x64xf32> to
+///     tensor<3x4xf32>
 /// ```
 tensor::ExtractSliceOp makeComposedExtractSliceOp(
     OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides);
 
+/// Create a PadTensorOp that pads `source` to the size of the statically sized
+/// `type` whose static sizes are assumed to be greater than the dynamic
+/// `source` size. The padding introduces trailing `pad` values until the target
+/// size is met. If `source` is a zero-offset, unit-stride slice of the result
+/// of one or more LinalgOps whose output has been padded with the same value
+/// and sizes, return their padded result instead of creating a PadTensorOp.
+///
+/// Example:
+/// ```
+/// %0 = tensor.extract_slice %arg0 [%iv0, %iv1] [%sz0, %sz1]
+/// %1 = linalg.pad_tensor %0 low[0, 0] high[...] { linalg.yield %cst }
+/// %2 = linalg.matmul ins(...) outs(%1)
+/// %3 = tensor.extract_slice %2 [0, 0] [%sz0, %sz1]
+/// ```
+/// makeComposedPadHighOp(source=%3, pad=%cst) returns %2
+/// makeComposedPadHighOp(source=%3, pad=%other_cst) returns %4
+/// ```
+/// %4 = linalg.pad_tensor %3 low[0, 0] high[...] { linalg.yield %other_cst }
+/// ```
+Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
+                            Value source, Value pad, bool nofold);
+
 //===----------------------------------------------------------------------===//
 // Fusion / Tiling utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -211,9 +211,9 @@
   auto staticTensorType = RankedTensorType::get(
       staticSizes, getElementTypeOrSelf(opOperand->get()));
   bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
-  result = linalg::PadTensorOp::createPadHighOp(
-      staticTensorType, opOperand->get(), paddingValue.getValue(),
-      /*nofold=*/nofold, opToPad->getLoc(), b);
+  result =
+      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
+                            opOperand->get(), paddingValue.getValue(), nofold);
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -322,6 +322,87 @@
                                                 foldedOffsets, sizes, strides);
 }
 
+Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
+                            Value source, Value pad, bool nofold) {
+  assert(type.hasStaticShape() && "expect tensor type to have static shape");
+
+  // Return true if `ofr` is not the constant 0 (resp. 1). Dynamic values
+  // conservatively count as not matching.
+  auto isNotZero = [](OpFoldResult ofr) {
+    return getConstantIntValue(ofr) != static_cast<int64_t>(0);
+  };
+  auto isNotOne = [](OpFoldResult ofr) {
+    return getConstantIntValue(ofr) != static_cast<int64_t>(1);
+  };
+
+  // Exit if `source` is not defined by an ExtractSliceOp.
+  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
+  if (!sliceOp)
+    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
+
+  // Search the `source` use-def chain for padded LinalgOps.
+  Value current = sliceOp.source();
+  while (current) {
+    auto linalgOp = current.getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      break;
+    OpResult opResult = current.cast<OpResult>();
+    current = linalgOp.getOutputOperand(opResult.getResultNumber())->get();
+  }
+  auto padTensorOp = current ? current.getDefiningOp<PadTensorOp>() : nullptr;
+
+  // Exit if the search fails to match a PadTensorOp at the end of the matched
+  // LinalgOp sequence.
+  if (!padTensorOp)
+    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
+
+  // Exit if the padded result type does not match.
+  if (sliceOp.source().getType() != type)
+    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
+
+  // Exit if `sliceOp` does not extract the leading elements of the padded
+  // result, i.e., if it has non-zero offsets or non-unit strides. Padding
+  // `source` would then not reproduce the padded result.
+  if (llvm::any_of(sliceOp.getMixedOffsets(), isNotZero) ||
+      llvm::any_of(sliceOp.getMixedStrides(), isNotOne))
+    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
+
+  // Exit if the LinalgOps are not high padded.
+  if (llvm::any_of(padTensorOp.getMixedLowPad(), isNotZero))
+    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
+
+  // Exit if `padTensorOpSliceOp`, which defines the slice padded by
+  // `padTensorOp`, does not exist or is rank-reducing. A rank-reducing slice
+  // has more sizes than `sliceOp`, and `llvm::zip` below would then silently
+  // compare only a prefix of them.
+  auto padTensorOpSliceOp =
+      padTensorOp.source().getDefiningOp<tensor::ExtractSliceOp>();
+  if (!padTensorOpSliceOp || sliceOp.getMixedSizes().size() !=
+                                 padTensorOpSliceOp.getMixedSizes().size())
+    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
+
+  // Exit if the sizes of `sliceOp` do not match the sizes of the slice padded
+  // by `padTensorOp`.
+  if (llvm::any_of(llvm::zip(sliceOp.getMixedSizes(),
+                             padTensorOpSliceOp.getMixedSizes()),
+                   [](std::tuple<OpFoldResult, OpFoldResult> it) {
+                     return !isEqualConstantIntOrValue(std::get<0>(it),
+                                                       std::get<1>(it));
+                   }))
+    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
+
+  // Exit if the padding values do not match.
+  Attribute padTensorOpPadAttr, padAttr;
+  Value padTensorOpPad = padTensorOp.getConstantPaddingValue();
+  if (!padTensorOpPad ||
+      !matchPattern(padTensorOpPad, m_Constant(&padTensorOpPadAttr)) ||
+      !matchPattern(pad, m_Constant(&padAttr)) || padTensorOpPadAttr != padAttr)
+    return PadTensorOp::createPadHighOp(type, source, pad, nofold, loc, b);
+
+  // Return the padded result if the padding values and sizes match.
+  return sliceOp.source();
+}
+
 /// Specialization to build an scf "for" nest.
template <> void GenerateLoopNest::doit( diff --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir --- a/mlir/test/Dialect/Linalg/pad.mlir +++ b/mlir/test/Dialect/Linalg/pad.mlir @@ -214,6 +214,123 @@ // ----- +#map0 = affine_map<(d0) -> (64, d0)> + +// CHECK: compose_padding +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<64x64xf32> +func @compose_padding(%arg0: tensor<64x64xf32>, + %iv0 : index) -> tensor { + %cst = arith.constant 0.0 : f32 + + // CHECK: %[[SIZE:.*]] = affine.min + %size = affine.min #map0(%iv0) + + // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]] + // CHECK-SAME: [0, 0] + // CHECK-SAME: [%[[SIZE]], %[[SIZE]]] + // CHECK: %[[T1:.*]] = linalg.pad_tensor %[[T0]] + // CHECK: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]] + // CHECK: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]] + %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor + %1 = linalg.pad_tensor %0 low[0, 0] high[%iv0, %iv0] { + ^bb0(%arg3: index, %arg4: index): // no predecessors + linalg.yield %cst : f32 + } : tensor to tensor<64x64xf32> + %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32> + %3 = linalg.fill(%cst, %2) : f32, tensor<64x64xf32> -> tensor<64x64xf32> + %4 = tensor.extract_slice %3[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor + + // Check there are no additional pad tensor operations. + // CHECK-NOT: linalg.pad_tensor + + // Check the matmul directly uses the result of the fill operation. 
+ // CHECK: %[[T4:.*]] = linalg.matmul ins(%[[T3]] + // CHECK: %[[T5:.*]] = tensor.extract_slice %[[T4]] + // CHECK-SAME: [0, 0] + // CHECK-SAME: [%[[SIZE]], %[[SIZE]]] + %5 = linalg.matmul ins(%4, %4 : tensor, tensor) outs(%4 : tensor) -> tensor + + // CHECK: return %[[T5]] + return %5 : tensor +} + +// ----- + +#map0 = affine_map<(d0) -> (64, d0)> + +// CHECK: different_padding_values +func @different_padding_values(%arg0: tensor<64x64xf32>, + %iv0 : index) -> tensor { + %cst = arith.constant 42.0 : f32 + %size = affine.min #map0(%iv0) + %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor + %1 = linalg.pad_tensor %0 low[0, 0] high[%iv0, %iv0] { + ^bb0(%arg3: index, %arg4: index): // no predecessors + linalg.yield %cst : f32 + } : tensor to tensor<64x64xf32> + %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32> + %4 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor + + // Different padding values prevent composing the paddings (42.0 vs. 0.0). 
+ // CHECK: = linalg.fill + // CHECK: = linalg.pad_tensor + // CHECK: = linalg.matmul + %5 = linalg.matmul ins(%4, %4 : tensor, tensor) outs(%4 : tensor) -> tensor + return %5 : tensor +} + +// ----- + +#map0 = affine_map<(d0) -> (64, d0)> + +// CHECK: different_padding_dynamic_sizes +func @different_padding_dynamic_sizes(%arg0: tensor<64x64xf32>, + %iv0 : index) -> tensor { + %cst = arith.constant 0.0 : f32 + %size = affine.min #map0(%iv0) + %0 = tensor.extract_slice %arg0[0, 0] [%iv0, %iv0] [1, 1] : tensor<64x64xf32> to tensor + %1 = linalg.pad_tensor %0 low[0, 0] high[%iv0, %iv0] { + ^bb0(%arg3: index, %arg4: index): // no predecessors + linalg.yield %cst : f32 + } : tensor to tensor<64x64xf32> + %2 = linalg.fill(%cst, %1) : f32, tensor<64x64xf32> -> tensor<64x64xf32> + %4 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<64x64xf32> to tensor + + // Different dynamic sizes prevent composing the paddings (%iv0 vs %size). + // CHECK: = linalg.fill + // CHECK: = linalg.pad_tensor + // CHECK: = linalg.matmul + %5 = linalg.matmul ins(%4, %4 : tensor, tensor) outs(%4 : tensor) -> tensor + return %5 : tensor +} + +// ----- + +#map0 = affine_map<(d0) -> (64, d0)> + +// CHECK: different_padding_static_sizes +func @different_padding_static_sizes(%arg0: tensor<62x62xf32>, + %iv0 : index) -> tensor { + %cst = arith.constant 0.0 : f32 + %size = affine.min #map0(%iv0) + %0 = tensor.extract_slice %arg0[0, 0] [%size, %size] [1, 1] : tensor<62x62xf32> to tensor + %1 = linalg.pad_tensor %0 low[0, 0] high[%iv0, %iv0] { + ^bb0(%arg3: index, %arg4: index): // no predecessors + linalg.yield %cst : f32 + } : tensor to tensor<62x62xf32> + %2 = linalg.fill(%cst, %1) : f32, tensor<62x62xf32> -> tensor<62x62xf32> + %4 = tensor.extract_slice %2[0, 0] [%size, %size] [1, 1] : tensor<62x62xf32> to tensor + + // Different static sizes prevent composing the paddings (62 vs 64 derived from #map0). 
+ // CHECK: = linalg.fill + // CHECK: = linalg.pad_tensor + // CHECK: = linalg.matmul + %5 = linalg.matmul ins(%4, %4 : tensor, tensor) outs(%4 : tensor) -> tensor + return %5 : tensor +} + +// ----- + #map = affine_map<(d0) -> (7, -d0 + 12)> // CHECK-FILL: scalar_operand