diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -776,6 +776,7 @@
       (ins TransformHandleTypeInterface:$target,
            DefaultValuedAttr<ArrayAttr, "{}">:$padding_values,
            DefaultValuedAttr<I64ArrayAttr, "{}">:$padding_dimensions,
+           OptionalAttr<I64ArrayAttr>:$pad_to_multiple_of,
            DefaultValuedAttr<I64ArrayAttr, "{}">:$pack_paddings,
            DefaultValuedAttr<
               TypedArrayAttrBase<I64ArrayAttr, "array of arrays of i64">,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -148,6 +148,12 @@
     paddingDimensions.assign(pd.begin(), pd.end());
     return *this;
   }
+  /// A list of multiples to which each padding dimension should be padded.
+  std::optional<SmallVector<int64_t>> padToMultipleOf;
+  LinalgPaddingOptions &setPadToMultipleOf(ArrayRef<int64_t> m) {
+    padToMultipleOf.emplace(m.begin(), m.end());
+    return *this;
+  }
   /// A flag for every operand to mark the PadOp as nofold which enables
   /// packing for statically shaped operands.
   SmallVector<bool> packPaddings;
@@ -350,14 +356,17 @@
 void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops);
 
 /// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands
-/// to a static bounding box. Use `paddingValues` and `packPaddings` to set
-/// padding value and nofold attribute of the created tensor::PadOps,
-/// respectively. Update `paddedOp` to the cloned operation with statically
-/// shaped `paddingDimensions` and return the extracted dynamically shaped
-/// results. If padding fails, return failure.
+/// to a static bounding box. `padToMultipleOf` indicates that each padding
+/// dimension should be padded to the specified multiple. If the derived
+/// padding sizes should not be rounded up to any multiple, use "1". Use
+/// `paddingValues` and `packPaddings` to set the padding value and the
+/// nofold attribute of the created tensor::PadOps, respectively. Update
+/// `paddedOp` to the cloned operation with statically shaped
+/// `paddingDimensions` and return the extracted dynamically shaped results.
+/// If padding fails, return failure.
 FailureOr<SmallVector<Value>>
 rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                   ArrayRef<int64_t> paddingDimensions,
+                  ArrayRef<int64_t> padToMultipleOf,
                   ArrayRef<Attribute> paddingValues,
                   ArrayRef<bool> packPaddings, LinalgOp &paddedOp);
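The doc comment above pins down the rounding rule: each derived padding size is rounded up to the requested multiple, with "1" acting as the neutral value. A minimal standalone sketch of that arithmetic (the helper name is ours; the body matches the `ceil` lambda added to Transforms.cpp further down):

#include <cstdint>
#include <iostream>

// Round `val` up to the nearest multiple of `multiple`; a multiple of 1
// leaves the value unchanged.
static int64_t ceilToMultiple(int64_t val, int64_t multiple) {
  return ((val + multiple - 1) / multiple) * multiple;
}

int main() {
  std::cout << ceilToMultiple(7, 2) << "\n"; // 8: bound 7 rounded up to 2s
  std::cout << ceilToMultiple(7, 1) << "\n"; // 7: multiple 1 is a no-op
  std::cout << ceilToMultiple(4, 2) << "\n"; // 4: already a multiple of 2
}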
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1594,9 +1594,14 @@
   TrackingListener listener(state, *this);
   IRRewriter rewriter(getContext(), &listener);
   LinalgOp paddedOp;
-  FailureOr<SmallVector<Value>> result = rewriteAsPaddedOp(
-      rewriter, target, extractFromI64ArrayAttr(getPaddingDimensions()),
-      paddingValues, packPaddings, paddedOp);
+  SmallVector<int64_t> paddingDimensions =
+      extractFromI64ArrayAttr(getPaddingDimensions());
+  SmallVector<int64_t> padToMultipleOf(paddingDimensions.size(), 1);
+  if (getPadToMultipleOf().has_value())
+    padToMultipleOf = extractFromI64ArrayAttr(*getPadToMultipleOf());
+  FailureOr<SmallVector<Value>> result =
+      rewriteAsPaddedOp(rewriter, target, paddingDimensions, padToMultipleOf,
+                        paddingValues, packPaddings, paddedOp);
   if (succeeded(result)) {
     // We need to perform our own replacement here because this API is still
     // used in patterns that "pad and hoist", for which the replacement values
@@ -1630,7 +1635,11 @@
                             "integers, found "
                          << getPaddingDimensions();
   }
-
+  if (getPadToMultipleOf().has_value()) {
+    if (getPadToMultipleOf()->size() != paddingDimensions.size()) {
+      return emitOpError() << "expects as many multiples as padding_dimensions";
+    }
+  }
   ArrayAttr transposes = getTransposePaddings();
   for (Attribute attr : transposes) {
     SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
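The verifier above accepts `pad_to_multiple_of` only when it has exactly as many entries as `padding_dimensions`, and `apply` substitutes all-ones when the attribute is absent. A self-contained sketch of that resolution logic (the function name is ours, not MLIR API):

#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Illustrative stand-in: reproduces the verifier rule and the defaulting
// in apply(). When pad_to_multiple_of is absent, every padding dimension
// gets the neutral multiple 1; when present, its length must match
// padding_dimensions.
static std::vector<int64_t>
resolveMultiples(const std::vector<int64_t> &paddingDimensions,
                 const std::optional<std::vector<int64_t>> &padToMultipleOf) {
  if (!padToMultipleOf)
    return std::vector<int64_t>(paddingDimensions.size(), 1);
  if (padToMultipleOf->size() != paddingDimensions.size())
    throw std::invalid_argument(
        "expects as many multiples as padding_dimensions");
  return *padToMultipleOf;
}

int main() {
  // {0, 1, 2} with no multiples given resolves to {1, 1, 1}.
  auto defaulted = resolveMultiples({0, 1, 2}, std::nullopt);
  return defaulted == std::vector<int64_t>{1, 1, 1} ? 0 : 1;
}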
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -48,34 +48,50 @@
 /// Pad the `opOperand` in the `paddingDimensions` using the padding value and
 /// the nofold flag found in `paddingValues` and `packPaddings`, respectively.
-/// Exit early and return the `opOperand` value if the shape dimensions that
-/// match `paddingDimensions` have a static size and the nofold flag is not set.
+///
+/// Exit early and return the `opOperand` value if it already has the requested
+/// shape, i.e.:
+/// - static shape
+/// - nofold is not set
+/// - dim sizes are multiples of `padToMultipleOf`
+///
 /// Otherwise, try to pad the shape dimensions that match the iterator
 /// dimensions `paddingDimensions` and return the tensor::PadOp result if
 /// padding succeeds or failure otherwise.
 static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
     RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
-    ArrayRef<int64_t> paddingDimensions, ArrayRef<Attribute> paddingValues,
-    ArrayRef<bool> packPaddings) {
+    ArrayRef<int64_t> paddingDimensions, ArrayRef<int64_t> padToMultipleOf,
+    ArrayRef<Attribute> paddingValues, ArrayRef<bool> packPaddings) {
+  assert(padToMultipleOf.size() == paddingDimensions.size() &&
+         "invalid number of elements in padToMultipleOf");
+
   AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
   ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
-  // Collect the shape dimension that are a function of the `paddingDimensions`.
-  llvm::SmallDenseSet<int64_t> shapeDimsToPad;
-  for (int64_t dim : paddingDimensions)
-    for (const auto &en : enumerate(indexingMap.getResults()))
-      if (en.value().isFunctionOfDim(dim))
-        shapeDimsToPad.insert(en.index());
+  // Collect the shape dimensions that are a function of `paddingDimensions`,
+  // along with the multiple that they should be padded to ("1" if none).
+  bool alreadyHasRequestedShape = true;
+  DenseMap<int64_t, int64_t> shapeDimToMultiple;
+  for (const auto &dimEn : enumerate(paddingDimensions)) {
+    for (const auto &en : enumerate(indexingMap.getResults())) {
+      if (en.value().isFunctionOfDim(dimEn.value())) {
+        int64_t dimSize = shape[en.index()];
+        shapeDimToMultiple[en.index()] = padToMultipleOf[dimEn.index()];
+        if (ShapedType::isDynamic(dimSize)) {
+          alreadyHasRequestedShape = false;
+        } else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
+          alreadyHasRequestedShape = false;
+        }
+      }
+    }
+  }
 
   // Return the unpadded operand if padding to a static shape is not needed and
   // if the nofold flag is not set.
   bool nofold = opOperand->getOperandNumber() < packPaddings.size()
                     ? packPaddings[opOperand->getOperandNumber()]
                     : false;
-  bool hasStaticShape = llvm::none_of(shapeDimsToPad, [&](int64_t dim) {
-    return ShapedType::isDynamic(shape[dim]);
-  });
-  if (!nofold && hasStaticShape)
+  if (!nofold && alreadyHasRequestedShape)
     return opOperand->get();
 
   // Fail if `paddingValues` specifies no padding value.
@@ -86,12 +102,17 @@
   Value paddingValue = rewriter.create<arith::ConstantOp>(
       opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
 
+  // Helper function to round a number up to a given multiple.
+  auto ceil = [](int64_t val, int64_t multiple) {
+    return ((val + multiple - 1) / multiple) * multiple;
+  };
+
   // Upper bound the sizes to obtain a static bounding box.
   SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
   for (int64_t i = 0, e = shape.size(); i < e; ++i) {
     LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
     // Skip dimensions that do not require padding.
-    if (!shapeDimsToPad.contains(i)) {
+    if (!shapeDimToMultiple.contains(i)) {
       LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
       continue;
     }
@@ -105,7 +126,7 @@
       return rewriter.notifyMatchFailure(
           opToPad, "count not compute a bounding box for padding");
     }
-    paddedShape[i] = *upperBound;
+    paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]);
     LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
   }
@@ -131,9 +152,11 @@
 //===----------------------------------------------------------------------===//
 // rewriteAsPaddedOp transformation.
 //===----------------------------------------------------------------------===//
+
 FailureOr<SmallVector<Value>>
 linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                           ArrayRef<int64_t> paddingDimensions,
+                          ArrayRef<int64_t> padToMultipleOf,
                           ArrayRef<Attribute> paddingValues,
                           ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
   LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
@@ -153,8 +176,8 @@
   newOperands.reserve(opToPad->getNumOperands());
   for (OpOperand &opOperand : opToPad->getOpOperands()) {
     FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
-        rewriter, opToPad, &opOperand, paddingDimensions, paddingValues,
-        packPaddings);
+        rewriter, opToPad, &opOperand, paddingDimensions, padToMultipleOf,
+        paddingValues, packPaddings);
     // Exit if `paddingDimensions` cannot be bounded statically.
     if (failed(paddedOperand)) {
       LLVM_DEBUG(DBGS() << "--operand cannot be bound statically : "
@@ -241,9 +264,13 @@
   // Pad the operation.
   LinalgOp paddedOp;
-  FailureOr<SmallVector<Value>> newResults =
-      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingDimensions,
-                        options.paddingValues, options.packPaddings, paddedOp);
+  SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
+  if (options.padToMultipleOf.has_value())
+    padToMultipleOf.assign(options.padToMultipleOf->begin(),
+                           options.padToMultipleOf->end());
+  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
+      rewriter, linalgOp, options.paddingDimensions, padToMultipleOf,
+      options.paddingValues, options.packPaddings, paddedOp);
   if (failed(newResults))
     return rewriter.notifyMatchFailure(linalgOp,
                                        "failed to rewrite as a padded op");
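The new test below exercises this on a matmul whose iterator dims (m, n, k) = (0, 1, 2) have static bounds (4, 5, 7), the 7 coming from the affine.min, with pad_to_multiple_of=[2, 2, 1]. Working the rounding through the operand shapes (LHS is m x k, RHS is k x n, the init is m x n) predicts 4x7, 7x6, and 4x6, exactly what the CHECK lines expect. A throwaway program checking that arithmetic (dim-to-operand mapping hard-coded for matmul; not MLIR code):

#include <cstdint>
#include <cstdio>

static int64_t ceilToMultiple(int64_t val, int64_t multiple) {
  return ((val + multiple - 1) / multiple) * multiple;
}

int main() {
  // Upper bounds of the matmul iteration dims in the test: m=4, n=5, k<=7
  // (7 comes from affine.min(-s0 + 12, 7)); multiples from the transform op.
  const int64_t bound[3] = {4, 5, 7};    // m, n, k
  const int64_t multiple[3] = {2, 2, 1}; // pad_to_multiple_of=[2, 2, 1]
  int64_t padded[3];
  for (int i = 0; i < 3; ++i)
    padded[i] = ceilToMultiple(bound[i], multiple[i]);
  // LHS is (m, k), RHS is (k, n), OUT is (m, n).
  std::printf("lhs %lldx%lld rhs %lldx%lld out %lldx%lld\n",
              (long long)padded[0], (long long)padded[2],
              (long long)padded[2], (long long)padded[1],
              (long long)padded[0], (long long)padded[1]);
  // Prints: lhs 4x7 rhs 7x6 out 4x6, matching the CHECK lines.
}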
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -45,6 +45,39 @@
 
 #map = affine_map<()[s0] -> (-s0 + 12, 7)>
 
+// CHECK-LABEL: @pad_to_multiple
+func.func @pad_to_multiple(%arg0: tensor<24x12xf32>,
+                           %arg1: tensor<12x25xf32>,
+                           %arg2: tensor<24x25xf32>,
+                           %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
+  %0 = affine.min #map()[%iv2]
+  %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+  %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+  %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+  // CHECK: linalg.matmul
+  // CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x7xf32>, tensor<7x6xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<4x6xf32>)
+  %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+  %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+  func.return %5 : tensor<24x25xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.pad %0 {
+    padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
+    padding_dimensions=[0, 1, 2],
+    pad_to_multiple_of=[2, 2, 1],
+    pack_paddings=[1, 1, 0]
+  } : (!transform.any_op) -> !transform.any_op
+}
+
+// -----
+
+#map = affine_map<()[s0] -> (-s0 + 12, 7)>
+
 // CHECK-LABEL: @static_sizes_output_divisible_on_empty_op
 func.func @static_sizes_output_divisible_on_empty_op(%arg0: tensor<24x12xf32>,
     %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>, %iv0: index,
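The same configuration can be driven from C++ through the options added in Transforms.h. A hedged sketch, assuming the pre-existing setters `setPaddingDimensions` and `setPackPaddings` alongside the new `setPadToMultipleOf`; building the `paddingValues` attributes requires an MLIRContext and is elided:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

// Sketch only: mirrors the attributes used by transform.structured.pad in
// the test above (padding_dimensions=[0, 1, 2], pad_to_multiple_of=[2, 2, 1],
// pack_paddings=[1, 1, 0]).
static void configurePadding(mlir::linalg::LinalgPaddingOptions &options) {
  options.setPaddingDimensions({0, 1, 2}); // pad m, n, and k
  options.setPadToMultipleOf({2, 2, 1});   // one multiple per padding dim
  options.setPackPaddings({true, true, false});
}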