diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -156,26 +156,37 @@
   return *this;
 }
 
-/// Helper function that tries to pad `opOperand`. Exit early and return success
-/// for scalar operands or if `paddingFunc` returns failure. Otherwise, try to
-/// pad the operand even if it already has a static shape. Set `result` to the
-/// result of the created PadTensorOp or return failure if the operand cannot be
-/// padded to a static shape.
+/// Helper function that tries to pad `opOperand`. Exit early for scalar
+/// operands, if `paddingFunc` returns failure, or if `opOperand` is not defined
+/// by an ExtractSliceOp. Otherwise, try to pad the operand even if it already
+/// has a static shape. Set `result` to the result of the created PadTensorOp
+/// and return success if the operand either has been padded to a static shape
+/// or already had a static shape; return failure otherwise.
 static LogicalResult padOperandToSmallestStaticBoundingBox(
     OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
     const PaddingValueComputationFunction &paddingFunc,
     const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
-  // Can't pad scalars.
-  if (opToPad.getShape(opOperand).empty())
+  // Get the shape of the operand and check if it has a dynamic shape. Only
+  // return failure if the operand is not a scalar and has a dynamic shape.
+  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
+  bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize);
+
+  // Cannot pad scalar operands.
+  if (shape.empty())
     return success();
-  // Can't pad if no padding value is known.
+
+  // Cannot pad if the padding value is unknown.
   FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
   if (failed(paddingValue))
-    return success();
+    return failure(hasDynamicShape);
+
+  // Cannot construct a static bounding box if the operand is not defined by an
+  // ExtractSliceOp.
   auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
-  // Not a slice op, cannot construct a static bounding box.
   if (!sliceOp)
-    return failure();
+    return failure(hasDynamicShape);
+
+  // Upper bound the `sliceOp` sizes to obtain a static bounding box.
   SmallVector<int64_t> staticSizes;
   staticSizes.reserve(opToPad.getRank(opOperand));
   auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
@@ -195,6 +206,8 @@
     }
     staticSizes.push_back(upperBound.getValue());
   }
+
+  // Pad the operand to the bounding box defined by `staticSizes`.
   auto staticTensorType = RankedTensorType::get(
       staticSizes, getElementTypeOrSelf(opOperand->get()));
   bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
@@ -490,8 +503,10 @@
   FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
       rewriter, linalgOp, options.paddingValueComputationFunction,
       options.paddingNoFoldComputationFunction, paddedOp);
-  if (failed(newResults))
+  if (failed(newResults)) {
+    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
     return failure();
+  }
 
   // Compute the desired hoisting depths.
   SmallVector<int64_t> depths;
diff --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir
--- a/mlir/test/Dialect/Linalg/pad.mlir
+++ b/mlir/test/Dialect/Linalg/pad.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad pack-paddings=1,1,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.fill pad pack-paddings=1,1,0 run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-FILL
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad pack-paddings=1,1,0 pad-inputs-only run-enable-pass=false" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=INPUTS-ONLY
 
 // CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0) -> (7, -d0 + 12)>
 // CHECK-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<(d0) -> (-d0 + 7)>
@@ -246,3 +247,105 @@
   }
   return %0 : tensor<24x12xf32>
 }
+
+// -----
+
+#map0 = affine_map<()[s0] -> (7, s0)>
+
+// CHECK: static_extract_slice_missing
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<4x5xf32>,
+func @static_extract_slice_missing(%arg0: tensor<24x12xf32>,
+                                    %arg1: tensor<12x25xf32>,
+                                    %arg2: tensor<4x5xf32>,
+                                    %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<4x5xf32> {
+  %0 = affine.min #map0()[%iv2]
+  %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+  %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+
+  // Check the matmul inputs are padded despite the missing slice for the static output.
+  // CHECK: %[[T0:.*]] = linalg.pad_tensor
+  // CHECK: %[[T1:.*]] = linalg.pad_tensor
+  // CHECK: = linalg.matmul ins(%[[T0]], %[[T1]]
+  // CHECK-SAME: outs(%[[ARG2]]
+  %3 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%arg2 : tensor<4x5xf32>) -> tensor<4x5xf32>
+  return %3 : tensor<4x5xf32>
+}
+
+// -----
+
+#map0 = affine_map<()[s0] -> (7, s0)>
+
+// CHECK: static_and_dynamic_extract_slice_missing
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<4x?xf32>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>,
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<4x5xf32>,
+func @static_and_dynamic_extract_slice_missing(%arg0: tensor<4x?xf32>,
+                                                %arg1: tensor<12x25xf32>,
+                                                %arg2: tensor<4x5xf32>,
+                                                %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<4x5xf32> {
+  %0 = affine.min #map0()[%iv2]
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG1]]
+  %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+
+  // Check the matmul is not padded due to the missing slice for the dynamic input.
+  // CHECK: = linalg.matmul ins(%[[ARG0]], %[[T0]]
+  // CHECK-SAME: outs(%[[ARG2]]
+  %3 = linalg.matmul ins(%arg0, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%arg2 : tensor<4x5xf32>) -> tensor<4x5xf32>
+  return %3 : tensor<4x5xf32>
+}
+
+// -----
+
+#map0 = affine_map<()[s0] -> (7, s0)>
+
+// INPUTS-ONLY: static_input_padding_only
+// INPUTS-ONLY-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>,
+func @static_input_padding_only(%arg0: tensor<24x12xf32>,
+                                 %arg1: tensor<12x25xf32>,
+                                 %arg2: tensor<24x25xf32>,
+                                 %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
+  %0 = affine.min #map0()[%iv2]
+  %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+  %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, 5] [1, 1] : tensor<12x25xf32> to tensor<?x5xf32>
+
+  // INPUTS-ONLY: %[[T0:.*]] = tensor.extract_slice %[[ARG2]]
+  %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<24x25xf32> to tensor<4x5xf32>
+
+  // Check the matmul inputs are padded despite the failure to compute a padding value for the static output.
+  // INPUTS-ONLY: %[[T1:.*]] = linalg.pad_tensor
+  // INPUTS-ONLY: %[[T2:.*]] = linalg.pad_tensor
+  // INPUTS-ONLY: = linalg.matmul ins(%[[T1]], %[[T2]]
+  // INPUTS-ONLY-SAME: outs(%[[T0]]
+  %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x5xf32>) outs(%3 : tensor<4x5xf32>) -> tensor<4x5xf32>
+  %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, 5] [1, 1] : tensor<4x5xf32> into tensor<24x25xf32>
+  return %5 : tensor<24x25xf32>
+}
+
+// -----
+
+#map0 = affine_map<()[s0] -> (7, s0)>
+
+// INPUTS-ONLY: dynamic_input_padding_only
+// INPUTS-ONLY-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>,
+// INPUTS-ONLY-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>,
+// INPUTS-ONLY-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>,
+func @dynamic_input_padding_only(%arg0: tensor<24x12xf32>,
+                                  %arg1: tensor<12x25xf32>,
+                                  %arg2: tensor<24x25xf32>,
+                                  %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
+  %0 = affine.min #map0()[%iv2]
+
+  // INPUTS-ONLY: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  // INPUTS-ONLY: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+  // INPUTS-ONLY: %[[T2:.*]] = tensor.extract_slice %[[ARG2]]
+  %1 = tensor.extract_slice %arg0[%iv0, %iv2] [4, %0] [1, 1] : tensor<24x12xf32> to tensor<4x?xf32>
+  %2 = tensor.extract_slice %arg1[%iv2, %iv1] [%0, %0] [1, 1] : tensor<12x25xf32> to tensor<?x?xf32>
+  %3 = tensor.extract_slice %arg2[%iv0, %iv1] [4, %0] [1, 1] : tensor<24x25xf32> to tensor<4x?xf32>
+
+  // Check the matmul is not padded due to the failure to compute a padding value for the dynamic output.
+  // INPUTS-ONLY: = linalg.matmul ins(%[[T0]], %[[T1]]
+  // INPUTS-ONLY-SAME: outs(%[[T2]]
+  %4 = linalg.matmul ins(%1, %2 : tensor<4x?xf32>, tensor<?x?xf32>) outs(%3 : tensor<4x?xf32>) -> tensor<4x?xf32>
+  %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, %0] [1, 1] : tensor<4x?xf32> into tensor<24x25xf32>
+  return %5 : tensor<24x25xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -94,13 +94,17 @@
                        llvm::cl::init(false)};
   Option<bool> pad{*this, "pad", llvm::cl::desc("Pad the operands."),
                    llvm::cl::init(false)};
+  Option<bool> padInputsOnly{
+      *this, "pad-inputs-only",
+      llvm::cl::desc("Only pad inputs when test-pad-pattern."),
+      llvm::cl::init(false)};
   ListOption<int64_t> packPaddings{
       *this, "pack-paddings",
-      llvm::cl::desc("Operand packing flags when test-pad-pattern"),
+      llvm::cl::desc("Operand packing flags when test-pad-pattern."),
       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   ListOption<int64_t> hoistPaddings{
       *this, "hoist-paddings",
-      llvm::cl::desc("Operand hoisting depths when test-pad-pattern"),
+      llvm::cl::desc("Operand hoisting depths when test-pad-pattern."),
       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   Option<bool> generalize{*this, "generalize",
                           llvm::cl::desc("Generalize named operations."),
@@ -245,6 +249,17 @@
   paddingOptions.setPaddingNoFoldComputationFunction(packFunc);
   paddingOptions.setPaddingHoistComputationFunction(hoistingFunc);
 
+  // Compute input padding values only and return failure for output operands.
+  if (padInputsOnly) {
+    paddingOptions.setPaddingValueComputationFunction(
+        [](OpBuilder &b, OpOperand &op) -> FailureOr<Value> {
+          auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOwner());
+          if (linalgOp && linalgOp.isInputTensor(&op))
+            return getNeutralOfLinalgOp(b, op);
+          return failure();
+        });
+  }
+
   vector::VectorContractLowering vectorContractLowering =
       llvm::StringSwitch<vector::VectorContractLowering>(
           vectorizeContractionTo.getValue())
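
Minimal sketch (not part of the patch): how a client pass could opt into the same inputs-only policy that the new pad-inputs-only test flag exercises. It only reuses API names visible in the diff above (setPaddingValueComputationFunction, isInputTensor, dyn_cast<linalg::LinalgOp>); the options type linalg::LinalgPaddingOptions and the reuse of getNeutralOfLinalgOp outside the test pass are assumptions, and a real client would plug in its own padding-value helper.

  // Sketch, assuming linalg::LinalgPaddingOptions and a client-provided
  // neutral-element helper (getNeutralOfLinalgOp here, as in the test pass).
  linalg::LinalgPaddingOptions paddingOptions;
  paddingOptions.setPaddingValueComputationFunction(
      [](OpBuilder &b, OpOperand &opOperand) -> FailureOr<Value> {
        auto linalgOp = dyn_cast<linalg::LinalgOp>(opOperand.getOwner());
        // Output operands report failure and stay unpadded; with this patch
        // that is only an error if the output slice has a dynamic shape.
        if (!linalgOp || !linalgOp.isInputTensor(&opOperand))
          return failure();
        return getNeutralOfLinalgOp(b, opOperand);
      });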