diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1090,6 +1090,19 @@
       /*methodName=*/"getRegionBuilder",
       (ins),
       [{ return ConcreteOp::getRegionBuilder(); }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return true if all the indexing maps are projected permutations.
+        Otherwise return false.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasOnlyProjectedPermutations",
+      (ins),
+      [{
+        return llvm::all_of($_op.getIndexingMaps(),
+                            [](AffineMap map) { return map.isProjectedPermutation(); });
+      }]
     >
   ];
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1447,6 +1447,65 @@
   }
 };
 
+/// Function signature to control reduction splitting. This returns a pair
+/// containing a ratio and a dimension index. The ratio is used to split the
+/// reduction dimension. The dimension index is used to control where the extra
+/// dimension is added to the intermediate tensor shape; it has to be smaller
+/// than the number of loops and at most the rank of the output. If the ratio
+/// value is less than or equal to 1 then nothing will be done.
+using ControlSplitReductionFn =
+    std::function<std::pair<int64_t, unsigned>(LinalgOp op)>;
+
+/// Patterns to apply `splitReduction` below.
+void populateSplitReductionPattern(
+    RewritePatternSet &patterns,
+    ControlSplitReductionFn controlSplitReductionFn,
+    LinalgTransformationFilter f = LinalgTransformationFilter());
+
+/// Apply transformation to split the single linalg op reduction into a parallel
+/// and reduction dimension. Then create a new linalg.generic op doing the rest
+/// of the reduction. Return the new linalg op with an extra parallel dimension
+/// or failure if the transformation didn't happen.
+/// Example:
+/// ```
+/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+///                                       affine_map<(d0) -> ()>],
+///      iterator_types = ["reduction"]}
+/// ins(%in : tensor<32xf32>)
+/// outs(%out : tensor<f32>) {
+/// ^bb0(%arg1: f32, %arg2: f32):
+///   %y = arith.addf %arg1, %arg2 : f32
+///   linalg.yield %y : f32
+/// } -> tensor<f32>
+/// ```
+/// To:
+/// ```
+/// %cst = arith.constant 0.000000e+00 : f32
+/// %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
+/// %1 = linalg.init_tensor [4] : tensor<4xf32>
+/// %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32>
+/// %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+///                                       affine_map<(d0, d1) -> (d0)>],
+///   iterator_types = ["parallel", "reduction"]}
+///   ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) {
+///   ^bb0(%arg3: f32, %arg4: f32):
+///   %5 = arith.addf %arg3, %arg4 : f32
+///   linalg.yield %5 : f32
+/// } -> tensor<4xf32>
+/// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+///                                       affine_map<(d0) -> ()>],
+///   iterator_types = ["reduction"]}
+///   ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) {
+///   ^bb0(%arg3: f32, %arg4: f32):
+///   %5 = arith.addf %arg3, %arg4 : f32
+///   linalg.yield %5 : f32
+/// } -> tensor<f32>
+/// ```
+FailureOr<LinalgOp>
+splitReduction(PatternRewriter &b, LinalgOp op,
+               ControlSplitReductionFn controlSplitReductionFn,
+               LinalgTransformationFilter f);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@
   PadOpInterchange.cpp
   Promotion.cpp
   SparseTensorRewriting.cpp
+  SplitReduction.cpp
   Tiling.cpp
   Transforms.cpp
   Vectorization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -0,0 +1,268 @@
+//===-------- SplitReduction.cpp - Split reduction dimension --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements linalg transformation to break a reduction dimension
+// between a parallel and a reduction dimension.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Return the identity (neutral) value of the binary combiner `op`, i.e. the
+/// value `e` such that combining any `x` with `e` yields `x`. Return llvm::None
+/// if the kind of combiner is not supported.
+static Optional<Attribute> getIdentity(Operation *op) {
+  // Builder only used as helper for attribute creation.
+  OpBuilder b(op->getContext());
+  Type resultType = op->getResult(0).getType();
+  if (auto floatType = resultType.dyn_cast<FloatType>()) {
+    const llvm::fltSemantics &semantic = floatType.getFloatSemantics();
+    if (isa<arith::AddFOp>(op))
+      return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic));
+    if (isa<arith::MulFOp>(op))
+      return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1));
+    if (isa<arith::MaxFOp>(op))
+      return b.getFloatAttr(
+          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/true));
+    // The identity of `min` is the largest *positive* value.
+    if (isa<arith::MinFOp>(op))
+      return b.getFloatAttr(
+          resultType, llvm::APFloat::getLargest(semantic, /*Negative=*/false));
+    return llvm::None;
+  }
+  if (isa<arith::AddIOp, arith::OrIOp, arith::XOrIOp, arith::MaxUIOp>(op))
+    return b.getIntegerAttr(resultType, 0);
+  // All bits set: identity of `and`, and the unsigned maximum for `minui`.
+  if (isa<arith::AndIOp, arith::MinUIOp>(op))
+    return b.getIntegerAttr(resultType, -1);
+  if (isa<arith::MulIOp>(op))
+    return b.getIntegerAttr(resultType, 1);
+  if (isa<arith::MaxSIOp, arith::MinSIOp>(op)) {
+    // Use the signed limits of the actual bit width: the int64_t limits would
+    // be truncated to 0 (min) and -1 (max) for a 32-bit integer type.
+    unsigned bitWidth = resultType.isIndex()
+                            ? IndexType::kInternalStorageBitWidth
+                            : resultType.getIntOrFloatBitWidth();
+    if (isa<arith::MaxSIOp>(op))
+      return b.getIntegerAttr(resultType, APInt::getSignedMinValue(bitWidth));
+    return b.getIntegerAttr(resultType, APInt::getSignedMaxValue(bitWidth));
+  }
+  return llvm::None;
+}
+
+FailureOr<LinalgOp>
+mlir::linalg::splitReduction(PatternRewriter &b, LinalgOp op,
+                             ControlSplitReductionFn controlSplitReductionFn,
+                             LinalgTransformationFilter filter) {
+  if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
+      op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
+      !op.hasOnlyProjectedPermutations())
+    return b.notifyMatchFailure(op, "precondition not met");
+  std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
+  int64_t ratio = control.first;
+  unsigned insertDimIndex = control.second;
+  if (ratio <= 1)
+    return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
+  SmallVector<unsigned> dims;
+  op.getReductionDims(dims);
+  assert(dims.size() == 1);
+  unsigned reductionDim = dims[0];
+  Optional<SmallVector<int64_t, 4>> loopRanges = op.getStaticLoopRanges();
+  if (!loopRanges)
+    return b.notifyMatchFailure(op, "Cannot analyze loops");
+  int64_t reductionDimSize = (*loopRanges)[reductionDim];
+  if (reductionDimSize == ShapedType::kDynamicSize ||
+      reductionDimSize % ratio != 0)
+    return b.notifyMatchFailure(
+        op, "Reduction dimension not divisible by split ratio");
+  // `insertDimIndex` is both the position of the new loop and the position of
+  // the new dimension in the intermediate result, so it must be valid for both
+  // (otherwise the output shape computation below reads past `oldShape`).
+  AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0));
+  ArrayRef<int64_t> oldShape = op.getShape(op.getOutputOperand(0));
+  if (insertDimIndex >= loopRanges->size() ||
+      insertDimIndex > oldOutputMap.getNumResults())
+    return b.notifyMatchFailure(op, "Invalid split dimension insertion index");
+  // The intermediate init_tensor is created from static sizes only and the
+  // inputs are expanded with static split sizes, so reject dynamic shapes
+  // before any IR is created.
+  auto isDynamic = [](int64_t size) {
+    return size == ShapedType::kDynamicSize;
+  };
+  if (llvm::any_of(oldShape, isDynamic))
+    return b.notifyMatchFailure(op, "Dynamic output shape is not supported");
+  for (OpOperand *operand : op.getInputOperands()) {
+    AffineMap map = op.getTiedIndexingMap(operand);
+    ArrayRef<int64_t> shape = op.getShape(operand);
+    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults()))
+      if (map.getDimPosition(idx) == reductionDim && isDynamic(shape[idx]))
+        return b.notifyMatchFailure(
+            op, "Dynamic reduction dimension in an input is not supported");
+  }
+  SmallVector<Operation *, 4> combinerOps;
+  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
+      combinerOps.size() != 1)
+    return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
+  Operation *reductionOp = combinerOps[0];
+  Optional<Attribute> identity = getIdentity(reductionOp);
+  if (!identity)
+    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
+
+  Location loc = op->getLoc();
+  SmallVector<Value> newInputs;
+  SmallVector<AffineMap> newMaps;
+  // Calculate the new shapes and indexing maps of the input operands.
+  for (OpOperand *operand : op.getInputOperands()) {
+    AffineMap map = op.getTiedIndexingMap(operand);
+    SmallVector<int64_t> newShape;
+    SmallVector<AffineExpr> exprs;
+    SmallVector<ReassociationIndices> reassociation;
+    unsigned index = 0;
+    for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
+      unsigned dim = map.getDimPosition(idx);
+      if (reductionDim == dim) {
+        newShape.push_back(ratio);
+        newShape.push_back(op.getShape(operand)[idx] / ratio);
+        reassociation.push_back({index++, index++});
+        exprs.push_back(b.getAffineDimExpr(insertDimIndex));
+        exprs.push_back(
+            b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
+        continue;
+      }
+      newShape.push_back(op.getShape(operand)[idx]);
+      exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
+      reassociation.push_back({index++});
+    }
+    newMaps.push_back(
+        AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
+    // If the shape is unchanged the input doesn't change.
+    if (newShape == op.getShape(operand)) {
+      newInputs.push_back(operand->get());
+      continue;
+    }
+    Type newType = RankedTensorType::get(
+        newShape,
+        operand->get().getType().cast<RankedTensorType>().getElementType());
+    Value newInput = b.create<tensor::ExpandShapeOp>(
+        loc, newType, operand->get(), reassociation);
+    newInputs.push_back(newInput);
+  }
+  // Calculate the new output map and shape, we insert the new dimension based
+  // on the index returned by `controlSplitReductionFn`.
+  SmallVector<int64_t> newOutputShape;
+  SmallVector<AffineExpr> outputExpr;
+  for (unsigned idx :
+       llvm::seq<unsigned>(0, oldOutputMap.getNumResults() + 1)) {
+    if (idx == insertDimIndex) {
+      newOutputShape.push_back(ratio);
+      outputExpr.push_back(b.getAffineDimExpr(insertDimIndex));
+      continue;
+    }
+    unsigned oldDim = idx < insertDimIndex ? idx : idx - 1;
+    newOutputShape.push_back(oldShape[oldDim]);
+    unsigned dim = oldOutputMap.getDimPosition(oldDim);
+    outputExpr.push_back(
+        b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1));
+  }
+  Value initTensor = b.create<linalg::InitTensorOp>(
+      loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
+  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
+  Value identityTensor =
+      b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
+          .getResult(0);
+
+  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
+                                   op.getContext()));
+  SmallVector<StringRef> newIteratorTypes;
+  for (auto &it : llvm::enumerate(op.iterator_types())) {
+    if (insertDimIndex == it.index())
+      newIteratorTypes.push_back(getParallelIteratorTypeName());
+    newIteratorTypes.push_back(it.value().cast<StringAttr>().getValue());
+  }
+  // Create the new op matching the original op with an extra parallel
+  // dimension.
+  GenericOp genericOp = b.create<GenericOp>(
+      loc, TypeRange({initTensor.getType()}), newInputs,
+      ValueRange({identityTensor}), newMaps, newIteratorTypes);
+  b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
+                       genericOp.region().begin());
+
+  // Then create a new reduction that only reduce the newly added dimension from
+  // the previous op.
+  unsigned intermRank = newOutputShape.size();
+  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
+  SmallVector<Value> outputOperands = op.getOutputOperands();
+  SmallVector<StringRef> reductionIteratorTypes;
+  SmallVector<AffineExpr> exprs;
+  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
+    if (insertDimIndex == i) {
+      reductionIteratorTypes.push_back(getReductionIteratorTypeName());
+    } else {
+      exprs.push_back(b.getAffineDimExpr(i));
+      reductionIteratorTypes.push_back(getParallelIteratorTypeName());
+    }
+  }
+  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
+  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
+
+  auto reduction = b.create<GenericOp>(
+      loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
+      outputOperands, reductionMaps, reductionIteratorTypes,
+      [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
+        Operation *clonedReductionOp = b.clone(*reductionOp);
+        clonedReductionOp->setOperand(0, inputs[0]);
+        clonedReductionOp->setOperand(1, inputs[1]);
+        b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
+      });
+  b.replaceOp(op, reduction.getResults());
+  filter.replaceLinalgTransformationFilter(b, genericOp);
+  filter.replaceLinalgTransformationFilter(b, reduction);
+  return cast<LinalgOp>(genericOp.getOperation());
+}
+
+namespace {
+
+/// Pattern applying `splitReduction` to every LinalgOp accepted by `filter`,
+/// with `controlSplitReductionFn` choosing the split ratio and insert index.
+struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
+  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
+  LinalgSplitReduction(MLIRContext *context,
+                       ControlSplitReductionFn controlSplitReductionFn,
+                       LinalgTransformationFilter f, PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+        controlSplitReductionFn(std::move(controlSplitReductionFn)),
+        filter(std::move(f)) {}
+
+  LogicalResult matchAndRewrite(LinalgOp op,
+                                PatternRewriter &rewriter) const override {
+    return splitReduction(rewriter, op, controlSplitReductionFn, filter);
+  }
+
+private:
+  ControlSplitReductionFn controlSplitReductionFn;
+  LinalgTransformationFilter filter;
+};
+
+} // namespace
+
+void linalg::populateSplitReductionPattern(
+    RewritePatternSet &patterns,
+    ControlSplitReductionFn controlSplitReductionFn,
+    LinalgTransformationFilter f) {
+  patterns.add<LinalgSplitReduction>(
+      patterns.getContext(), std::move(controlSplitReductionFn), std::move(f));
+}
diff --git a/mlir/test/Dialect/Linalg/split_reduction.mlir b/mlir/test/Dialect/Linalg/split_reduction.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/split_reduction.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction -split-input-file | FileCheck %s
+
+func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
+                    outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
+//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
+//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+//  CHECK-LABEL: @matmul_split
+//  CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32>
+//  CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32>
+//  CHECK-DAG: %[[INI:.*]] = linalg.init_tensor [16, 32, 4] : tensor<16x32x4xf32>
+//      CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
+//      CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: , iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<16x4x64xf32>, tensor<4x64x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
+//      CHECK: arith.mulf
+//      CHECK: arith.addf
+//      CHECK: linalg.yield
+//      CHECK: } -> tensor<16x32x4xf32>
+//      CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
+//      CHECK: arith.addf
+//      CHECK: linalg.yield %{{.*}} : f32
+//      CHECK: } -> tensor<16x32xf32>
+//      CHECK: return %[[R]] : tensor<16x32xf32>
+
+// -----
+
+func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                          affine_map<(d0) -> ()>,
+                                          affine_map<(d0) -> ()>],
+   iterator_types = ["reduction"]}
+   ins(%arg0, %arg1 : tensor<32xf32>, tensor<f32>)
+   outs(%out : tensor<f32>) {
+    ^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
+      %40 = arith.subf %arg7, %arg8 : f32
+      %41 = math.exp %40 : f32
+      %42 = arith.mulf %41, %arg9 : f32
+      linalg.yield %42 : f32
+    } -> tensor<f32>
+  return %red : tensor<f32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
+//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
+//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
+//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @generic_split_1d
+//     CHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
+//     CHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
+//     CHECK: %[[INI:.*]] = linalg.init_tensor [4] : tensor<4xf32>
+//     CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
+//     CHECK: %[[G:.*]] = linalg.generic
+//     CHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
+//     CHECK: iterator_types = ["parallel", "reduction"]} ins(%[[I1]], %{{.*}} : tensor<4x8xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
+//     CHECK:   arith.subf
+//     CHECK:   math.exp
+//     CHECK:   arith.mulf
+//     CHECK:   linalg.yield
+//     CHECK: } -> tensor<4xf32>
+//     CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
+//     CHECK:   arith.mulf
+//     CHECK:   linalg.yield
+//     CHECK: } -> tensor<f32>
+//     CHECK: return %[[R]] : tensor<f32>
+
+// -----
+
+func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
+  -> tensor<5x2xf32>
+{
+  %0 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1, d2) -> (d1, d0)>,
+        affine_map<(d0, d1, d2) -> (d2, d1)>,
+        affine_map<(d0, d1, d2) -> (d2, d0)>
+      ],
+      iterator_types = ["parallel", "reduction", "parallel"]
+    } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
+    ^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
+      %3 = arith.addf %arg0, %arg1 : f32
+      %4 = arith.maxf %3, %arg2 : f32
+      linalg.yield %4 : f32
+    } -> tensor<5x2xf32>
+  return %0 : tensor<5x2xf32>
+}
+
+//  CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
+//  CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+//  CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
+//  CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-LABEL:  func @generic_split_3d
+//       CHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
+//   CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
+//   CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
+//       CHECK: %[[INI:.*]] = linalg.init_tensor [5, 2, 4] : tensor<5x2x4xf32>
+//       CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
+//       CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
+//  CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
+//       CHECK: arith.addf
+//       CHECK: arith.maxf
+//       CHECK: linalg.yield
+//       CHECK: } -> tensor<5x2x4xf32>
+//       CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
+//  CHECK-SAME: ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
+//       CHECK: arith.maxf
+//       CHECK: linalg.yield
+//       CHECK: } -> tensor<5x2xf32>
+//       CHECK: return %[[R]] : tensor<5x2xf32>
+
+// -----
+
+func @generic_split_min_f32(%in: tensor<32xf32>, %out: tensor<f32>) -> tensor<f32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                          affine_map<(d0) -> ()>],
+   iterator_types = ["reduction"]}
+   ins(%in : tensor<32xf32>) outs(%out : tensor<f32>) {
+    ^bb0(%arg0: f32, %arg1: f32):
+      %0 = arith.minf %arg0, %arg1 : f32
+      linalg.yield %0 : f32
+    } -> tensor<f32>
+  return %red : tensor<f32>
+}
+
+// The identity of minf is the largest positive value.
+// CHECK-LABEL: @generic_split_min_f32
+//       CHECK: %[[ID:.*]] = arith.constant 3.40282347E+38 : f32
+//       CHECK: linalg.fill ins(%[[ID]] : f32)
+
+// -----
+
+func @generic_split_maxsi_i32(%in: tensor<32xi32>, %out: tensor<i32>) -> tensor<i32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
+                                          affine_map<(d0) -> ()>],
+   iterator_types = ["reduction"]}
+   ins(%in : tensor<32xi32>) outs(%out : tensor<i32>) {
+    ^bb0(%arg0: i32, %arg1: i32):
+      %0 = arith.maxsi %arg0, %arg1 : i32
+      linalg.yield %0 : i32
+    } -> tensor<i32>
+  return %red : tensor<i32>
+}
+
+// The identity of a signed max is the minimum value of the element type.
+// CHECK-LABEL: @generic_split_maxsi_i32
+//       CHECK: %[[ID:.*]] = arith.constant -2147483648 : i32
+//       CHECK: linalg.fill ins(%[[ID]] : i32)
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -111,6 +111,10 @@
       llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
                      "pad_tensor(subtensor)"),
       llvm::cl::init(false)};
+  Option<bool> testSplitReduction{
+      *this, "test-split-reduction",
+      llvm::cl::desc("Test split reduction transformation"),
+      llvm::cl::init(false)};
   ListOption<int64_t> peeledLoops{
       *this, "peeled-loops",
       llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
@@ -617,6 +621,20 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
 }
 
+static void applySplitReduction(FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  linalg::populateSplitReductionPattern(
+      patterns,
+      [](LinalgOp op) {
+        unsigned insertDimIndex = op.getNumLoops() - 1;
+        return std::make_pair(4, insertDimIndex);
+      },
+      LinalgTransformationFilter(
+          ArrayRef<StringAttr>{},
+          StringAttr::get(funcOp.getContext(), "SPLIT")));
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   auto lambda = [&](void *) {
@@ -666,6 +684,8 @@
   if (testTileScalarizeDynamicDims)
     return applyTilePattern(getOperation(), loopType, tileSizes,
                             /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
+  if (testSplitReduction)
+    return applySplitReduction(getOperation());
 }
 
 namespace mlir {