diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -553,69 +553,32 @@
                        SmallVectorImpl<Value> &ivs,
                        const LoopIndexToRangeIndexMap &loopIndexToRangeIndex);
 
-/// Callback returning the padding value to use for a given OpOperand or failure
-/// for no padding. This should be a function of both the operation and the
-/// operand type.
-using PaddingValueComputationFunction =
-    std::function<FailureOr<Value>(OpBuilder &, OpOperand &)>;
-
-/// Callback returning true if the PadOp defining the given OpOperand shall be
-/// marked as nofold to enable packing.
-using PaddingNoFoldComputationFunction = std::function<bool(OpOperand &)>;
-
-/// Callback returning the number of loops to hoist the PadOp defining the given
-/// OpOperand.
-using PaddingHoistComputationFunction = std::function<int64_t(OpOperand &)>;
-
-/// Callback returning the transpose vector used to permute the result tensor
-/// dimensions of the PadOp defining the given OpOperand.
-using PaddingTransposeComputationFunction =
-    std::function<SmallVector<int64_t>(OpOperand &)>;
-
 struct LinalgPaddingOptions {
-  /// Callback returning the padding value to use for a given OpOperand or
-  /// failure for no padding. Padding operations are introduced if
-  /// `paddingValueComputationFunction` is set and does not return failure.
-  /// Padding all operands guarantees the operation is statically shaped and
-  /// thus can be vectorized.
-  PaddingValueComputationFunction paddingValueComputationFunction = nullptr;
-
-  LinalgPaddingOptions &
-  setPaddingValueComputationFunction(PaddingValueComputationFunction fun) {
-    paddingValueComputationFunction = std::move(fun);
+  /// A padding value for every operand.
+  SmallVector<Attribute> paddingValues;
+  LinalgPaddingOptions &setPaddingValues(ArrayRef<Attribute> pv) {
+    paddingValues.assign(pv.begin(), pv.end());
     return *this;
   }
-
-  /// Callback returning true if the PadOp defining the given OpOperand shall be
-  /// marked as nofold to enable packing. A padding operation is only marked
-  /// nofold if `paddingNoFoldComputationFunction` is set and returns true.
-  /// Otherwise, the nofold attribute is set to false.
-  PaddingNoFoldComputationFunction paddingNoFoldComputationFunction = nullptr;
-
-  LinalgPaddingOptions &
-  setPaddingNoFoldComputationFunction(PaddingNoFoldComputationFunction fun) {
-    paddingNoFoldComputationFunction = std::move(fun);
+  /// A flag for every operand to mark the PadOp as nofold which enables packing
+  /// for statically shaped operands.
+  SmallVector<bool> packPaddings;
+  LinalgPaddingOptions &setPackPaddings(ArrayRef<bool> pp) {
+    packPaddings.assign(pp.begin(), pp.end());
     return *this;
   }
-
-  /// Callback returning the number of loops to hoist the PadOp defining the
-  /// given OpOperand.
-  PaddingHoistComputationFunction paddingHoistComputationFunction = nullptr;
-
-  LinalgPaddingOptions &
-  setPaddingHoistComputationFunction(PaddingHoistComputationFunction fun) {
-    paddingHoistComputationFunction = std::move(fun);
+  /// A number of loops to hoist the PadOp out for every operand.
+  SmallVector<int64_t> hoistPaddings;
+  LinalgPaddingOptions &setHoistPaddings(ArrayRef<int64_t> hp) {
+    hoistPaddings.assign(hp.begin(), hp.end());
     return *this;
   }
-
-  /// Callback returning the transpose vector used to permute the result tensor
-  /// dimensions of the PadOp defining the given OpOperand.
-  PaddingTransposeComputationFunction paddingTransposeComputationFunction =
-      nullptr;
-
-  LinalgPaddingOptions &setPaddingTransposeComputationFunction(
-      PaddingTransposeComputationFunction fun) {
-    paddingTransposeComputationFunction = std::move(fun);
+  /// A permutation vector for every operand used to transpose the packed PadOp
+  /// results.
+  SmallVector<SmallVector<int64_t>> transposePaddings;
+  LinalgPaddingOptions &
+  setTransposePaddings(ArrayRef<SmallVector<int64_t>> tp) {
+    transposePaddings.assign(tp.begin(), tp.end());
     return *this;
   }
 };
@@ -1254,16 +1217,15 @@
                                 PatternRewriter &rewriter) const override;
 };
 
-/// Pad the operands of `opToPad` to a static bounding box. Use `paddingFunc`
-/// and `nofoldFunc` to set the padding value and the nofold attribute of the
+/// Pad the operands of `opToPad` to a static bounding box. Use `paddingValues`
+/// and `packPaddings` to set the padding value and the nofold attribute of the
 /// introduced tensor::PadOps, respectively. Update `paddedOp` to the cloned
 /// statically shaped operation and return the extracted dynamically shaped
 /// results. If padding fails, return failure.
 FailureOr<SmallVector<Value>>
 rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
-                  const PaddingValueComputationFunction &paddingFunc,
-                  const PaddingNoFoldComputationFunction &nofoldFunc,
-                  LinalgOp &paddedOp);
+                  ArrayRef<Attribute> paddingValues,
+                  ArrayRef<bool> packPaddings, LinalgOp &paddedOp);
 
 using OptimizeCopyFn =
     std::function<LogicalResult(PatternRewriter &, tensor::PadOp, Value)>;
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -158,16 +158,17 @@
   return *this;
 }
 
-/// Helper function that tries to pad `opOperand`. Exit early for scalar
-/// operands, if `paddingFunc` returns failure, or if `opOperand` is not defined
-/// by an ExtractSliceOp. Otherwise, try to pad the operand even if it already
-/// has a static shape. Set `result` to the result of the created tensor::PadOp
-/// or and return success if the operand either has been padded to a static
-/// shape or already had a static shape and failure otherwise.
+/// Pad `opOperand` using the provided `paddingValues`. Exit early for scalar
+/// operands, if `paddingValues` contains no value for the `opOperand`, or if
+/// `opOperand` is not defined by an ExtractSliceOp. Otherwise, try to pad the
+/// operand even if it already has a static shape. Set `result` to the result
+/// of the created tensor::PadOp and return success if the operand either has
+/// been padded to a static shape or already had a static shape, and failure
+/// otherwise.
 static LogicalResult padOperandToSmallestStaticBoundingBox(
     OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
-    const PaddingValueComputationFunction &paddingFunc,
-    const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) {
+    ArrayRef<Attribute> paddingValues, ArrayRef<bool> packPaddings,
+    Value &result) {
   // Get the shape of the operand and check if it has a dynamic shape. Only
   // return failure if the operand is not a scalar and has a dynamic shape.
   ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
@@ -178,9 +179,11 @@
     return success();
 
   // Cannot pad if the padding value is unknown.
-  FailureOr<Value> paddingValue = paddingFunc(b, *opOperand);
-  if (failed(paddingValue))
+  if (opOperand->getOperandNumber() >= paddingValues.size())
     return failure(hasDynamicShape);
+  Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
+  Value paddingValue = b.create<arith::ConstantOp>(
+      opToPad.getLoc(), paddingAttr.getType(), paddingAttr);
 
   // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
   OpOperand *currOpOperand = opOperand;
@@ -227,18 +230,18 @@
   // Pad the operand to the bounding box defined by `staticSizes`.
   auto staticTensorType = RankedTensorType::get(
       staticSizes, getElementTypeOrSelf(opOperand->get()));
-  bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;
-  result =
-      makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
-                            opOperand->get(), paddingValue.getValue(), nofold);
+  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
+                    ? packPaddings[opOperand->getOperandNumber()]
+                    : false;
+  result = makeComposedPadHighOp(b, opToPad->getLoc(), staticTensorType,
+                                 opOperand->get(), paddingValue, nofold);
   return success();
 }
 
 FailureOr<SmallVector<Value>>
 linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
-                          const PaddingValueComputationFunction &paddingFunc,
-                          const PaddingNoFoldComputationFunction &nofoldFunc,
-                          LinalgOp &paddedOp) {
+                          ArrayRef<Attribute> paddingValues,
+                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
   Location loc = opToPad->getLoc();
 
   // TODO: there are cases where we may still want to pad to larger sizes.
@@ -256,7 +259,7 @@
     // If padding was requested but the shape cannot be bounded statically then
     // the pattern fails to apply.
     if (failed(padOperandToSmallestStaticBoundingBox(
-            b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand)))
+            b, opToPad, opOperand, paddingValues, packPaddings, paddedOperand)))
       return failure();
     newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get());
   }
@@ -498,21 +501,16 @@
   // Pad the operation.
   LinalgOp paddedOp;
-  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
-      rewriter, linalgOp, options.paddingValueComputationFunction,
-      options.paddingNoFoldComputationFunction, paddedOp);
+  FailureOr<SmallVector<Value>> newResults =
+      rewriteAsPaddedOp(rewriter, linalgOp, options.paddingValues,
+                        options.packPaddings, paddedOp);
   if (failed(newResults))
     return failure();
 
-  // Compute the desired hoisting depths.
-  SmallVector<int64_t> depths;
-  if (options.paddingHoistComputationFunction) {
-    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands())
-      depths.push_back(options.paddingHoistComputationFunction(*opOperand));
-  }
-
   // Hoist the padding.
-  for (const auto &en : enumerate(depths)) {
+  for (const auto &en : enumerate(options.hoistPaddings)) {
+    if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
+      break;
     OpOperand &opOperand = paddedOp->getOpOperand(en.index());
     auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
     if (!padOp || en.value() == 0)
@@ -520,7 +518,9 @@
     tensor::PadOp hoistedOp;
     SmallVector<GenericOp> transposeOps;
     SmallVector<int64_t> transposeVector =
-        options.paddingTransposeComputationFunction(opOperand);
+        en.index() < options.transposePaddings.size()
+            ? options.transposePaddings[en.index()]
+            : SmallVector<int64_t>{};
 
     FailureOr<Value> newResult = hoistPaddingOnTensors(
         padOp, en.value(), transposeVector, hoistedOp, transposeOps);
diff --git a/mlir/test/Dialect/Linalg/codegen-strategy.mlir b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
--- a/mlir/test/Dialect/Linalg/codegen-strategy.mlir
+++ b/mlir/test/Dialect/Linalg/codegen-strategy.mlir
@@ -1,13 +1,13 @@
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=2,4,8 vectorize vectorize-contraction-to=matrixintrinsics unroll-vector-transfers=true" -split-input-file | FileCheck %s --check-prefix=CHECK-INTRINSIC
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 promote promote-full-tile-pad register-tile-sizes=2,4,8 vectorize vectorize-contraction-to=outerproduct split-transfers=true unroll-vector-transfers=false" -split-input-file | FileCheck %s --check-prefix=CHECK-OUTER
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 tile-interchange=1,2,0 generalize iterator-interchange=0,2,1" -split-input-file | FileCheck %s --check-prefix=CHECK-INTERCHANGE
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 pad pack-paddings=1,1,0 hoist-paddings=3,3,0" -split-input-file | FileCheck %s --check-prefix=CHECK-PAD
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 fuse pad vectorize" -split-input-file | FileCheck %s --check-prefix=CHECK-FUSE
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=conv anchor-op=linalg.conv_2d_nhwc_hwcf tile-sizes=1,1,8,32,1,1,8 fuse pad decompose vectorize vectorize-padding" -split-input-file | FileCheck %s --check-prefix=CHECK-DECOMP
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 pad padding-values=0.:f32,0.:f32,0.:f32 pack-paddings=1,1,0 hoist-paddings=3,3,0" -split-input-file | FileCheck %s --check-prefix=CHECK-PAD
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=matmul anchor-op=linalg.matmul tile-sizes=16,32,64 fuse pad padding-values=0.:f32,0.:f32,0.:f32 vectorize" -split-input-file | FileCheck %s --check-prefix=CHECK-FUSE
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-func=conv anchor-op=linalg.conv_2d_nhwc_hwcf tile-sizes=1,1,8,32,1,1,8 fuse pad padding-values=0.:f32,0.:f32,0.:f32 decompose vectorize vectorize-padding" -split-input-file | FileCheck %s --check-prefix=CHECK-DECOMP
 
 // CHECK-INTRINSIC: func @matmul(
 // CHECK-OUTER: func @matmul(
-func @matmul(%arg0: memref<72x72xf32>, %arg1: memref<72x72xf32>, %arg2: memref<72x72xf32>) {
+func.func @matmul(%arg0: memref<72x72xf32>, %arg1: memref<72x72xf32>, %arg2: memref<72x72xf32>) {
 
   // Check the matrix intrinsic lowering is triggered.
   // CHECK-INTRINSIC: vector.matrix_multiply
@@ -17,13 +17,13 @@
   // Check the outer product lowering is triggered.
   // CHECK-OUTER: vector.outerproduct {{.*}} : vector<2xf32>, vector<4xf32>
   linalg.matmul ins(%arg0, %arg1: memref<72x72xf32>, memref<72x72xf32>) outs(%arg2: memref<72x72xf32>)
-  return
+  func.return
 }
 
 // -----
 
 // CHECK-INTERCHANGE: func @matmul(
-func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
+func.func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
   // CHECK-INTERCHANGE-DAG: %[[C16:.*]] = arith.constant 16
   // CHECK-INTERCHANGE-DAG: %[[C32:.*]] = arith.constant 32
   // CHECK-INTERCHANGE-DAG: %[[C64:.*]] = arith.constant 64
@@ -37,7 +37,7 @@
   // CHECK-INTERCHANGE: linalg.generic
   // CHECK-INTERCHANGE-SAME: iterator_types = ["parallel", "reduction", "parallel"]
   %0 = linalg.matmul ins(%arg0, %arg1: tensor<72x72xf32>, tensor<72x72xf32>) outs(%arg2: tensor<72x72xf32>) -> tensor<72x72xf32>
-  return %0 : tensor<72x72xf32>
+  func.return %0 : tensor<72x72xf32>
 }
 
 // -----
 
@@ -45,7 +45,7 @@
 // CHECK-PAD-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0) -> (-d0 + 72, 16)>
 // CHECK-PAD: func @matmul(
-func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
+func.func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
 
   // Check the padding of the input operands has been hoisted out of the tile loop nest.
   // CHECK-PAD-COUNT=2: tensor.pad %{{.*}} nofold
@@ -56,13 +56,13 @@
   // CHECK-PAD-COUNT=2: scf.for
   // CHECK-PAD: linalg.matmul
   %0 = linalg.matmul ins(%arg0, %arg1: tensor<72x72xf32>, tensor<72x72xf32>) outs(%arg2: tensor<72x72xf32>) -> tensor<72x72xf32>
-  return %0 : tensor<72x72xf32>
+  func.return %0 : tensor<72x72xf32>
 }
 
 // -----
 
 // CHECK-FUSE: func @matmul(
-func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
+func.func @matmul(%arg0: tensor<72x72xf32>, %arg1: tensor<72x72xf32>, %arg2: tensor<72x72xf32>) -> tensor<72x72xf32> {
 
   // Check the padding and vectorization applies to the fill operation due to the empty anchor op string.
   // CHECK-FUSE: %[[CST:.*]] = arith.constant dense<0.000000e+00>
@@ -73,13 +73,13 @@
   // Check the matmul is padded and vectorized despite the empty anchor op string.
   // CHECK-FUSE: vector.outerproduct
   %1 = linalg.matmul ins(%arg0, %arg1: tensor<72x72xf32>, tensor<72x72xf32>) outs(%0: tensor<72x72xf32>) -> tensor<72x72xf32>
-  return %1 : tensor<72x72xf32>
+  func.return %1 : tensor<72x72xf32>
 }
 
 // -----
 
 // CHECK-DECOMP: func @conv(
-func @conv(%arg0: tensor<8x18x17x32xf32>, %arg1: tensor<3x3x32x64xf32>, %arg2: tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32> {
+func.func @conv(%arg0: tensor<8x18x17x32xf32>, %arg1: tensor<3x3x32x64xf32>, %arg2: tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32>
@@ -88,5 +88,5 @@
   // CHECK-DECOMP: vector.outerproduct
   // CHECK-DECOMP: vector.transfer_write {{.*}}: vector<1x8x32xf32>, tensor<1x1x?x32xf32>
   %1 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<8x18x17x32xf32>, tensor<3x3x32x64xf32>) outs(%0 : tensor<8x16x15x64xf32>) -> tensor<8x16x15x64xf32>
-  return %1 : tensor<8x16x15x64xf32>
+  func.return %1 : tensor<8x16x15x64xf32>
 }
diff --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir
--- a/mlir/test/Dialect/Linalg/pad.mlir
+++ b/mlir/test/Dialect/Linalg/pad.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad pack-paddings=1,1,0 run-enable-pass=false" -cse -split-input-file | FileCheck %s --check-prefix=MATMUL
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.fill pad pack-paddings=1,1 run-enable-pass=false" -cse -split-input-file | FileCheck %s --check-prefix=FILL
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.fill pad pack-paddings=1,0 run-enable-pass=false" -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad pack-paddings=1,0 run-enable-pass=false" -cse -split-input-file | FileCheck %s --check-prefix=FILL-MATMUL
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad pack-paddings=1,1,0 pad-inputs-only run-enable-pass=false" -cse -split-input-file | FileCheck %s --check-prefix=INPUTS-ONLY
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad padding-values=0.:f32,0.:f32,0.:f32 pack-paddings=1,1,0 run-enable-pass=false" -cse -split-input-file | FileCheck %s --check-prefix=MATMUL
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.fill pad padding-values=0.:f32,1.:f32 pack-paddings=1,1 run-enable-pass=false" -cse -split-input-file | FileCheck %s --check-prefix=FILL
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.fill pad padding-values=0.:f32,0.:f32 pack-paddings=1,0 run-enable-pass=false" -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad padding-values=0.:f32,0.:f32,0.:f32 pack-paddings=1,1,0 run-enable-pass=false" -cse -split-input-file | FileCheck %s --check-prefix=FILL-MATMUL
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul pad padding-values=0.:f32,0.:f32 pack-paddings=1,1,0 run-enable-pass=false" -cse -split-input-file | FileCheck %s --check-prefix=INPUTS-ONLY
 
 // MATMUL-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<()[s0] -> (-s0 + 12, 7)>
 // MATMUL-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<()[s0] -> (-s0 + 7)>
@@ -18,6 +18,7 @@
                  %arg1: tensor<12x25xf32>,
                  %arg2: tensor<24x25xf32>,
                  %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
+  // MATMUL-DAG: %[[CST:.*]] = arith.constant 0.
   // MATMUL-DAG: %[[C0:.*]] = arith.constant 0 : index
 
   // MATMUL: %[[TS2:.*]] = affine.min #[[MAP0]]()[%[[IV2]]]
@@ -35,6 +36,7 @@
   // MATMUL: %[[T3:.*]] = tensor.pad %[[T0]] nofold
   // MATMUL-SAME: [%[[C0]], %[[C0]]]
   // MATMUL-SAME: [%[[C0]], %[[V0]]
+  // MATMUL: tensor.yield %[[CST]]
   // MATMUL: %[[T4:.*]] = tensor.pad %[[T1]] nofold
 
   // Check the statically sized matmul output with fully divisible sizes is not padded.
@@ -62,6 +64,7 @@
                  %arg1: tensor<12x25xf32>,
                  %arg2: tensor<24x25xf32>,
                  %iv0 : index, %iv1 : index, %iv2 : index) -> tensor<24x25xf32> {
+  // MATMUL-DAG: %[[CST:.*]] = arith.constant 0.
   // MATMUL-DAG: %[[C0:.*]] = arith.constant 0 : index
   %3 = tensor.extract_slice %arg0[%iv0, %iv2] [4, 6] [1, 1] : tensor<24x12xf32> to tensor<4x6xf32>
@@ -78,6 +81,7 @@
   // MATMUL: %[[T1:.*]] = tensor.pad %[[T0]] low
   // MATMUL-SAME: [%[[C0]], %[[C0]]]
   // MATMUL-SAME: [%[[C0]], %[[V0]]
+  // MATMUL: tensor.yield %[[CST]]
   // MATMUL: %[[T2:.*]] = linalg.matmul
   // MATMUL-SAME: outs(%[[T1]] : tensor<4x7xf32>)
@@ -485,12 +489,14 @@
 // FILL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<1x64x1x64xf32>
 func.func @rank_reducing(%arg0: tensor<1x64x1x64xf32>, %iv0 : index) -> tensor<1x?x?xf32> {
+  // FILL: %[[CST:.*]] = arith.constant 1.
   %cst = arith.constant 0.0 : f32
   %size = affine.min #map0()[%iv0]
   %0 = tensor.extract_slice %arg0[0, 0, 0, 0] [1, %size, 1, %size] [1, 1, 1, 1] : tensor<1x64x1x64xf32> to tensor<1x?x?xf32>
 
   // Check the fill is padded despite the rank-reducing slice operation.
   // FILL: %[[T0:.*]] = tensor.pad
+  // FILL: tensor.yield %[[CST]]
   // FILL: %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]]
   // FILL-SAME: tensor<1x64x64xf32>
   // FILL: = tensor.extract_slice %[[T1]]
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -20,6 +20,7 @@
   MLIRLinalgTransforms
   MLIRLLVMToLLVMIRTranslation
   MLIRMemRef
+  MLIRParser
   MLIRPass
   MLIRSCF
   MLIRSCFTransforms
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Parser/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/ADT/SetVector.h"
@@ -96,23 +97,21 @@
                     llvm::cl::init(false)};
   Option<bool> pad{*this, "pad", llvm::cl::desc("Pad the operands."),
                    llvm::cl::init(false)};
-  Option<bool> padInputsOnly{
-      *this, "pad-inputs-only",
-      llvm::cl::desc("Only pad input operands when test-pad-pattern"),
-      llvm::cl::init(false)};
+  ListOption<std::string> paddingValues{
+      *this, "padding-values",
+      llvm::cl::desc("Operand padding values parsed by the attribute parser."),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   ListOption<bool> packPaddings{
-      *this, "pack-paddings",
-      llvm::cl::desc("Operand packing flags when test-pad-pattern."),
+      *this, "pack-paddings", llvm::cl::desc("Operand packing flags."),
       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   ListOption<int64_t> hoistPaddings{
-      *this, "hoist-paddings",
-      llvm::cl::desc("Operand hoisting depths when test-pad-pattern."),
+      *this, "hoist-paddings", llvm::cl::desc("Operand hoisting depths."),
       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
   ListOption<std::string> transposePaddings{
       *this, "transpose-paddings",
       llvm::cl::desc(
-          "Transpose paddings when test-pad-pattern. Specify a "
-          "operand dimension interchange using the following format:\n"
+          "Transpose paddings. Specify an operand dimension interchange "
+          "using the following format:\n"
           "-transpose-paddings=1:0:2,0:1,0:1\n"
           "It defines the interchange [1, 0, 2] for operand one and "
           "the interchange [0, 1] (no transpose) for the remaining operands."
@@ -226,14 +225,6 @@
   }
 } // namespace
 
-// For now, just assume it is the zero of type.
-// In the future, it should be the zero of type + op.
-static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
-  auto t = getElementTypeOrSelf(op.get());
-  return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t,
-                                     b.getZeroAttr(t));
-}
-
 /// Apply transformations specified as patterns.
 void TestLinalgCodegenStrategy::runOnOperation() {
   if (!anchorFuncOpName.empty() && anchorFuncOpName != getOperation().getName())
     return;
@@ -256,44 +247,32 @@
   registerTilingOptions =
       registerTilingOptions.setTileSizes(registerTileSizes);
 
-  LinalgPaddingOptions paddingOptions;
-  auto packFunc = [&](OpOperand &opOperand) {
-    return opOperand.getOperandNumber() < packPaddings.size()
-               ? packPaddings[opOperand.getOperandNumber()]
-               : false;
-  };
-  auto hoistingFunc = [&](OpOperand &opOperand) {
-    return opOperand.getOperandNumber() < hoistPaddings.size()
-               ? hoistPaddings[opOperand.getOperandNumber()]
-               : 0;
-  };
-  auto transposeFunc = [&](OpOperand &opOperand) {
-    SmallVector<int64_t> transposeVector = {};
-    if (opOperand.getOperandNumber() >= transposePaddings.size())
-      return transposeVector;
-    SmallVector<StringRef> elems;
-    StringRef(transposePaddings[opOperand.getOperandNumber()])
-        .split(elems, ':');
-    for (StringRef elem : elems)
-      transposeVector.push_back(std::stoi(elem.str()));
-    return transposeVector;
-  };
-  paddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
-  paddingOptions.setPaddingNoFoldComputationFunction(packFunc);
-  paddingOptions.setPaddingHoistComputationFunction(hoistingFunc);
-  paddingOptions.setPaddingTransposeComputationFunction(transposeFunc);
+  // Parse the padding values.
+  SmallVector<Attribute> paddingValueAttributes;
+  for (const std::string &paddingValue : paddingValues) {
+    paddingValueAttributes.push_back(
+        parseAttribute(paddingValue, &getContext()));
+  }
 
-  // Compute input padding values only an return failure for output operands.
-  if (padInputsOnly) {
-    paddingOptions.setPaddingValueComputationFunction(
-        [](OpBuilder &b, OpOperand &op) -> FailureOr<Value> {
-          auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOwner());
-          if (linalgOp && linalgOp.isInputTensor(&op))
-            return getNeutralOfLinalgOp(b, op);
-          return failure();
-        });
+  // Parse the transpose vectors.
+  SmallVector<SmallVector<int64_t>> transposePaddingVectors;
+  for (const std::string &transposePadding : transposePaddings) {
+    SmallVector<int64_t> transposeVector = {};
+    SmallVector<StringRef> tokens;
+    StringRef(transposePadding).split(tokens, ':');
+    for (StringRef token : tokens)
+      transposeVector.push_back(std::stoi(token.str()));
+    transposePaddingVectors.push_back(transposeVector);
   }
 
+  LinalgPaddingOptions paddingOptions;
+  paddingOptions.setPaddingValues(paddingValueAttributes);
+  paddingOptions.setPackPaddings(
+      SmallVector<bool>{packPaddings.begin(), packPaddings.end()});
+  paddingOptions.setHoistPaddings(
+      SmallVector<int64_t>{hoistPaddings.begin(), hoistPaddings.end()});
+  paddingOptions.setTransposePaddings(transposePaddingVectors);
+
   vector::VectorContractLowering vectorContractLowering =
       llvm::StringSwitch<vector::VectorContractLowering>(
           vectorizeContractionTo.getValue())
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -402,6 +402,7 @@
         "//mlir:LinalgOps",
         "//mlir:LinalgTransforms",
        "//mlir:MemRefDialect",
+       "//mlir:Parser",
        "//mlir:Pass",
        "//mlir:SCFDialect",
        "//mlir:SCFTransforms",
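
Note for reviewers (not part of the patch): the sketch below shows how downstream code might populate the new list-based LinalgPaddingOptions, mirroring the wiring in TestLinalgCodegenStrategy.cpp. It assumes an existing OpBuilder `b`, f32 element type, and a three-operand op such as matmul; the operand count, padding values, and permutations are illustrative assumptions, only the setters come from this patch.

```c++
// Pad all three operands of a matmul-like op with f32 zeros (illustrative).
SmallVector<Attribute> paddingValues(3, b.getZeroAttr(b.getF32Type()));

LinalgPaddingOptions paddingOptions;
paddingOptions.setPaddingValues(paddingValues);
// Mark the two input pads nofold so they can be packed; leave the output pad foldable.
paddingOptions.setPackPaddings(SmallVector<bool>{true, true, false});
// Hoist the input pads out of three enclosing loops; do not hoist the output pad.
paddingOptions.setHoistPaddings(SmallVector<int64_t>{3, 3, 0});
// Identity permutations, i.e. no transposition of the packed pads.
paddingOptions.setTransposePaddings(
    SmallVector<SmallVector<int64_t>>{{0, 1}, {0, 1}, {0, 1}});
```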