diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -460,6 +460,47 @@ /// OpOperand shall be marked as nofold to enable packing. using PaddingNoFoldComputationFunction = std::function; +/// Callback returning the number of loops to hoist the pad tensor operation +/// defining the given OpOperand. +using PaddingHoistComputationFunction = std::function; + +struct LinalgPaddingOptions { + /// Callback returning the padding value to use for a given OpOperand or + /// failure for no padding. Padding operations are introduced if + /// `paddingValueComputationFunction` is set and does not return failure. + /// Padding all operands guarantees the operation is statically shaped and + /// thus can be vectorized. + PaddingValueComputationFunction paddingValueComputationFunction = nullptr; + + LinalgPaddingOptions & + setPaddingValueComputationFunction(PaddingValueComputationFunction fun) { + paddingValueComputationFunction = std::move(fun); + return *this; + } + + /// Callback returning true if the pad tensor operation defining the given + /// OpOperand shall be marked as nofold to enable packing. A padding operation + /// is only marked nofold if `paddingNoFoldComputationFunction` is set and + /// returns true. Otherwise, the nofold attribute is set to false. + PaddingNoFoldComputationFunction paddingNoFoldComputationFunction = nullptr; + + LinalgPaddingOptions & + setPaddingNoFoldComputationFunction(PaddingNoFoldComputationFunction fun) { + paddingNoFoldComputationFunction = std::move(fun); + return *this; + } + + /// Callback returning the number of loops to hoist the pad tensor operation + /// defining the given OpOperand. 
+ PaddingHoistComputationFunction paddingHoistComputationFunction = nullptr; + + LinalgPaddingOptions & + setPaddingHoistComputationFunction(PaddingHoistComputationFunction fun) { + paddingHoistComputationFunction = std::move(fun); + return *this; + } +}; + struct LinalgTilingOptions { /// Computation function that returns the tile sizes for each operation. /// Delayed construction of constant tile sizes should occur to interoperate @@ -650,6 +691,35 @@ } }; +/// +/// Linalg padding pattern. +/// +/// Apply the `padding` transformation as a pattern. +/// `filter` controls LinalgTransformMarker matching and update when specified. +/// See `padding` for more details. +struct LinalgPaddingPattern : public RewritePattern { + // Entry point to match any LinalgOp OpInterface. + LinalgPaddingPattern( + MLIRContext *context, + LinalgPaddingOptions options = LinalgPaddingOptions(), + LinalgTransformationFilter filter = LinalgTransformationFilter(), + PatternBenefit benefit = 1); + // Entry point to match a specific LinalgOp. + LinalgPaddingPattern( + StringRef opName, MLIRContext *context, + LinalgPaddingOptions options = LinalgPaddingOptions(), + LinalgTransformationFilter filter = LinalgTransformationFilter(), + PatternBenefit benefit = 1); + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + /// LinalgTransformMarker handles special attribute manipulations. + LinalgTransformationFilter filter; + /// Options to control padding and hoisting. + LinalgPaddingOptions options; +}; + struct LinalgFusionOptions { /// List of operands indices to use for fusion. 
llvm::SmallSet indicesToFuse = {}; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SCF/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -470,6 +471,62 @@ return success(); } +/// Linalg padding pattern. +mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern( + MLIRContext *context, LinalgPaddingOptions options, + LinalgTransformationFilter filter, PatternBenefit benefit) + : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter), + options(options) {} + +mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern( + StringRef opName, MLIRContext *context, LinalgPaddingOptions options, + LinalgTransformationFilter filter, PatternBenefit benefit) + : RewritePattern(opName, benefit, context, {}), filter(filter), + options(options) {} + +LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + if (failed(filter.checkAndNotify(rewriter, op))) + return failure(); + + // Pad the operation. + LinalgOp paddedOp; + FailureOr> newResults = rewriteAsPaddedOp( + rewriter, linalgOp, options.paddingValueComputationFunction, + options.paddingNoFoldComputationFunction, paddedOp); + if (failed(newResults)) + return failure(); + + // Compute the desired hoisting depths. 
+ SmallVector depths; + if (options.paddingHoistComputationFunction) { + for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) + depths.push_back(options.paddingHoistComputationFunction(*opOperand)); + } + + // Hoist the padding. + for (auto en : enumerate(depths)) { + OpOperand &opOperand = paddedOp->getOpOperand(en.index()); + auto padTensorOp = opOperand.get().getDefiningOp(); + if (!padTensorOp || en.value() == 0) + continue; + PadTensorOp hoistedOp; + FailureOr newResult = + hoistPaddingOnTensors(padTensorOp, en.value(), hoistedOp); + if (failed(newResult)) + continue; + rewriter.replaceOp(padTensorOp, newResult.getValue()); + } + + // Replace the original operation to pad. + rewriter.replaceOp(op, newResults.getValue()); + filter.replaceLinalgTransformationFilter(rewriter, paddedOp); + return success(); +} + /// Linalg generic interchange pattern. mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern( MLIRContext *context, ArrayRef interchangeVector, diff --git a/mlir/test/Dialect/Linalg/pad-and-hoist.mlir b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/pad-and-hoist.mlir @@ -0,0 +1,161 @@ +// RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-operands=1,1,0 hoist-operands=2,1,0" -cse -canonicalize -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns="test-pad-pattern pack-operands=1,1,0 hoist-operands=4,3,0" -cse -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-DOUBLE + +// CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0) -> (5, -d0 + 24)> +// CHECK-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<(d0) -> (8, -d0 + 12)> +// CHECK-DAG: #[[DIV6:[0-9a-z]+]] = affine_map<(d0) -> (d0 ceildiv 6)> +#map0 = affine_map<(d0) -> (5, -d0 + 24)> +#map1 = affine_map<(d0) -> (8, -d0 + 12)> +#map2 = affine_map<(d0) -> (7, -d0 + 25)> + +// CHECK: single_tiling +// CHECK-DOUBLE: single_tiling + +// CHECK-SAME: 
%[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+func @single_tiling(%arg0: tensor<24x12xf32>,
+                    %arg1: tensor<12x25xf32>,
+                    %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C5:.*]] = arith.constant 5
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c25 = arith.constant 25 : index
+  %c24 = arith.constant 24 : index
+  %c6 = arith.constant 6 : index
+  %c7 = arith.constant 7 : index
+  %c5 = arith.constant 5 : index
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %0 = scf.for %arg3 = %c0 to %c24 step %c5 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+
+    // Packing the first input operand for all values of IV2 (IV2x5x6).
+    // CHECK: = linalg.init_tensor [2, 5, 6]
+    // CHECK: %[[PT0:.*]] = scf.for %[[P0IV2:[0-9a-z]+]] =
+    // CHECK: %[[PIDX0:.*]] = affine.apply #[[DIV6]](%[[P0IV2]])
+    // CHECK: %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]])
+    // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+    // CHECK-SAME: %[[IV0]], %[[P0IV2]]
+    // CHECK-SAME: %[[TS0]], 6
+    // CHECK: %[[V0:.*]] = arith.subi %[[C5]], %[[TS0]]
+    // CHECK: %[[T1:.*]] = linalg.pad_tensor %[[T0]] nofold {{.*}} high[%[[V0]]
+    // CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1]] into %{{.*}}[%[[PIDX0]], 0, 0]
+    // CHECK: scf.yield %[[T2]]
+
+    // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %1 = scf.for %arg5 = %c0 to %c25 step %c7 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+
+      // Packing the second input operand for all values of IV2 (IV2x6x8).
+      // CHECK: = linalg.init_tensor [2, 6, 8]
+      // CHECK: %[[PT1:.*]] = scf.for %[[P1IV2:[0-9a-z]+]] =
+      // CHECK: %[[PIDX1:.*]] = affine.apply #[[DIV6]](%[[P1IV2]])
+      // CHECK: %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]])
+      // CHECK: %[[T3:.*]] = tensor.extract_slice %[[ARG1]]
+      // CHECK-SAME: %[[P1IV2]], %[[IV1]]
+      // CHECK-SAME: 6, %[[TS1]]
+      // CHECK: %[[V1:.*]] = arith.subi %[[C8]], %[[TS1]]
+      // CHECK: %[[T4:.*]] = linalg.pad_tensor %[[T3]] nofold {{.*}} high[%[[C0]], %[[V1]]
+      // CHECK: %[[T5:.*]] = tensor.insert_slice %[[T4]] into %{{.*}}[%[[PIDX1]], 0, 0]
+      // CHECK: scf.yield %[[T5]]
+
+      // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG4:.*]] =
+      %2 = scf.for %arg7 = %c0 to %c12 step %c6 iter_args(%arg8 = %arg6) -> (tensor<24x25xf32>) {
+        %3 = affine.min #map0(%arg3)
+        // Index the packed operands.
+        // CHECK-DAG: %[[IDX:.*]] = affine.apply #[[DIV6]](%[[IV2]])
+        // CHECK-DAG: %[[T6:.*]] = tensor.extract_slice %[[PT0]][%[[IDX]]
+        // CHECK-DAG: %[[T7:.*]] = tensor.extract_slice %[[PT1]][%[[IDX]]
+        %4 = tensor.extract_slice %arg0[%arg3, %arg7] [%3, 6] [1, 1] : tensor<24x12xf32> to tensor<?x6xf32>
+        %5 = affine.min #map1(%arg5)
+        %6 = tensor.extract_slice %arg1[%arg7, %arg5] [6, %5] [1, 1] : tensor<12x25xf32> to tensor<6x?xf32>
+
+        // Pad the output operand without setting the nofold attribute.
+        // CHECK-DAG: %[[T8:.*]] = tensor.extract_slice %[[ARG4]][%[[IV0]], %[[IV1]]
+        // CHECK: %[[T9:.*]] = linalg.pad_tensor %[[T8]] low
+        %7 = tensor.extract_slice %arg8[%arg3, %arg5] [%3, %5] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
+
+        // Check matmul uses the packed input operands and the padded output operand.
+        // CHECK: = linalg.matmul ins(%[[T6]], %[[T7]]{{.*}} outs(%[[T9]]
+        %8 = linalg.matmul {__internal_linalg_transform__ = "pad"} ins(%4, %6 : tensor<?x6xf32>, tensor<6x?xf32>) outs(%7 : tensor<?x?xf32>) -> tensor<?x?xf32>
+        %9 = tensor.insert_slice %8 into %arg8[%arg3, %arg5] [%3, %5] [1, 1] : tensor<?x?xf32> into tensor<24x25xf32>
+        scf.yield %9 : tensor<24x25xf32>
+      }
+      scf.yield %2 : tensor<24x25xf32>
+    }
+    scf.yield %1 : tensor<24x25xf32>
+  }
+  return %0 : tensor<24x25xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (15, -d0 + 24)>
+#map1 = affine_map<(d0) -> (16, -d0 + 25)>
+#map2 = affine_map<(d0, d1) -> (15, -d0 + d1)>
+#map3 = affine_map<(d0, d1) -> (d0 + d1)>
+#map4 = affine_map<(d0, d1) -> (16, -d0 + d1)>
+
+// CHECK: double_tiling
+// CHECK-DOUBLE: double_tiling
+
+// CHECK-DOUBLE-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-DOUBLE-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+// CHECK-DOUBLE-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+func @double_tiling(%arg0: tensor<24x12xf32>,
+                    %arg1: tensor<12x25xf32>,
+                    %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %c0 = arith.constant 0 : index
+  %c25 = arith.constant 25 : index
+  %c24 = arith.constant 24 : index
+  %c16 = arith.constant 16 : index
+  %c15 = arith.constant 15 : index
+
+  // Packing the first input operand.
+  // CHECK-DOUBLE: = linalg.init_tensor
+  // CHECK-DOUBLE: = linalg.pad_tensor {{.*}} nofold
+
+  // CHECK-DOUBLE: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  %0 = scf.for %arg3 = %c0 to %c24 step %c15 iter_args(%arg4 = %arg2) -> (tensor<24x25xf32>) {
+
+    // Packing the second input operand.
+    // CHECK-DOUBLE: = linalg.init_tensor
+    // CHECK-DOUBLE: = linalg.pad_tensor {{.*}} nofold
+
+    // CHECK-DOUBLE: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+    %1 = scf.for %arg5 = %c0 to %c25 step %c16 iter_args(%arg6 = %arg4) -> (tensor<24x25xf32>) {
+      %2 = affine.min #map0(%arg3)
+      %3 = affine.min #map1(%arg5)
+      %4 = tensor.extract_slice %arg6[%arg3, %arg5] [%2, %3] [1, 1] : tensor<24x25xf32> to tensor<?x?xf32>
+
+      // CHECK-DOUBLE: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+      %5 = scf.for %arg7 = %c0 to %2 step %c15 iter_args(%arg8 = %4) -> (tensor<?x?xf32>) {
+
+        // CHECK-DOUBLE: scf.for %[[IV3:[0-9a-zA-Z]*]] =
+        %7 = scf.for %arg9 = %c0 to %3 step %c16 iter_args(%arg10 = %arg8) -> (tensor<?x?xf32>) {
+          %8 = affine.min #map2(%arg7, %2)
+          %9 = affine.apply #map3(%arg7, %arg3)
+          %10 = tensor.extract_slice %arg0[%9, 0] [%8, 12] [1, 1] : tensor<24x12xf32> to tensor<?x12xf32>
+          %11 = affine.min #map4(%arg9, %3)
+          %12 = affine.apply #map3(%arg9, %arg5)
+          %13 = tensor.extract_slice %arg1[0, %12] [12, %11] [1, 1] : tensor<12x25xf32> to tensor<12x?xf32>
+          %14 = tensor.extract_slice %arg10[%arg7, %arg9] [%8, %11] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+
+          // Pad the output operand and perform the multiplication.
+ // CHECK-DOUBLE: = linalg.pad_tensor + // CHECK-DOUBLE: = linalg.matmul + %15 = linalg.matmul {__internal_linalg_transform__ = "pad"} ins(%10, %13 : tensor, tensor<12x?xf32>) outs(%14 : tensor) -> tensor + %16 = tensor.insert_slice %15 into %arg10[%arg7, %arg9] [%8, %11] [1, 1] : tensor into tensor + scf.yield %16 : tensor + } + scf.yield %7 : tensor + } + %6 = tensor.insert_slice %5 into %arg6[%arg3, %arg5] [%2, %3] [1, 1] : tensor into tensor<24x25xf32> + scf.yield %6 : tensor<24x25xf32> + } + scf.yield %1 : tensor<24x25xf32> + } + return %0 : tensor<24x25xf32> +} + diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -96,6 +96,9 @@ Option testHoistPadding{*this, "test-hoist-padding", llvm::cl::desc("Test hoist padding"), llvm::cl::init(0)}; + Option testPadPattern{*this, "test-pad-pattern", + llvm::cl::desc("Test pad pattern"), + llvm::cl::init(false)}; Option testTransformPadTensor{ *this, "test-transform-pad-tensor", llvm::cl::desc("Test transform pad tensor by copying with generic ops"), @@ -117,6 +120,14 @@ *this, "nofold-operands", llvm::cl::desc("Operands to set nofold when test-tile-pattern"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; + ListOption packOperands{ + *this, "pack-operands", + llvm::cl::desc("Operand packing flags when test-pad-pattern"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; + ListOption hoistOperands{ + *this, "hoist-operands", + llvm::cl::desc("Operand hoisting depths when test-pad-pattern"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; ListOption peeledLoops{ *this, "peeled-loops", llvm::cl::desc("Loops to be peeled when test-tile-pattern"), @@ -629,6 +640,30 @@ (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); } +static void applyPadPattern(FuncOp funcOp, ArrayRef 
packOperands, + ArrayRef hoistOperands) { + MLIRContext *context = funcOp.getContext(); + RewritePatternSet padPattern(context); + auto linalgPaddingOptions = linalg::LinalgPaddingOptions(); + auto packFunc = [&](OpOperand &opOperand) { + return opOperand.getOperandNumber() < packOperands.size() + ? packOperands[opOperand.getOperandNumber()] + : false; + }; + auto hoistingFunc = [&](OpOperand &opOperand) { + return opOperand.getOperandNumber() < hoistOperands.size() + ? hoistOperands[opOperand.getOperandNumber()] + : 0; + }; + linalgPaddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp); + linalgPaddingOptions.setPaddingNoFoldComputationFunction(packFunc); + linalgPaddingOptions.setPaddingHoistComputationFunction(hoistingFunc); + padPattern.add( + context, linalgPaddingOptions, + LinalgTransformationFilter(Identifier::get("pad", context))); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(padPattern)); +} + static void applyInterchangePattern(FuncOp funcOp, ArrayRef interchangeVector) { MLIRContext *context = funcOp.getContext(); @@ -772,6 +807,8 @@ } }); } + if (testPadPattern) + return applyPadPattern(getFunction(), packOperands, hoistOperands); if (testInterchangePattern.hasValue()) return applyInterchangePattern(getFunction(), testInterchangePattern); }