diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -222,6 +222,89 @@ }]; } +def SplitReductionByScalingOp : + Op { + let description = [{ + Indicates that the given `target` op should be transformed with the + `splitReductionByScaling` transformation and split factor provided as + attribute. + + Instead of introducing an ExpandShapeOp, this scaling-based implementation + rewrites a reduction dimension `k` into `k * split_factor + kk`. + The new outer dimension `k` becomes parallel while `kk` stays a reduction; + `k` is materialized in the intermediate output at `insert_split_dimension`. + + Consider a minimal example where `k` is reduced: + O(i, j) += I(i, j, k) + Assume i=3, j=5, k=128, split_factor=16 and insert_split_dimension=0. + The compute is rewritten as: + a. O_i(k, i, j) += I(i, j, 16 * k + kk) + b. O(i, j) += O_i(k, i, j) + The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
+ + Example (split_factor = 4, insert_split_dimension = 2): + + ``` + %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>) + outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> + ``` + + Is transformed to: + + ``` + #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)> + #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)> + #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> + #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> + #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> + #map5 = affine_map<(d0, d1, d2) -> (d0, d1)> + %0 = linalg.init_tensor [16, 32, 64] : tensor<16x32x64xf32> + %cst = arith.constant 0.000000e+00 : f32 + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) -> + tensor<16x32x64xf32> + %2 = linalg.init_tensor [64, 4] : tensor<64x4xi1> + + %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3], + iterator_types = ["parallel", "parallel", "parallel", "reduction"]} + ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>) + outs(%1 : tensor<16x32x64xf32>) { + ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32): + %5 = arith.mulf %arg3, %arg4 : f32 + %6 = arith.addf %arg6, %5 : f32 + linalg.yield %6 : f32 + } -> tensor<16x32x64xf32> + + %4 = linalg.generic {indexing_maps = [#map4, #map5], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%3 : tensor<16x32x64xf32>) + outs(%C : tensor<16x32xf32>) { + ^bb0(%arg3: f32, %arg4: f32): + %5 = arith.addf %arg3, %arg4 : f32 + linalg.yield %5 : f32 + } -> tensor<16x32xf32> + + return %4 : tensor<16x32xf32> + ``` + + }]; + + let arguments = (ins PDL_Operation:$target, + DefaultValuedAttr:$split_factor, + DefaultValuedAttr:$insert_split_dimension); + let results = (outs PDL_Operation:$fill_op, + PDL_Operation:$split_linalg_op, + PDL_Operation:$combining_linalg_op); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( + ::mlir::linalg::LinalgOp target, TransformState &state);
+ }]; +} + def TileOp : Op, FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1532,6 +1532,56 @@ splitReduction(PatternRewriter &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn); +/// Scaling-based implementation of the split reduction transformation. +/// Instead of introducing an ExpandShapeOp, this rewrites a reduction dimension +/// `k` into `k * scale + kk`; the outer `k` stays parallel, `kk` is reduced. +/// +/// Example (split factor 4, split dimension inserted at position 2): +/// ``` +/// %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>) +/// outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> +/// ``` +/// +/// Is transformed to: +/// +/// ``` +/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)> +/// #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)> +/// #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> +/// #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +/// #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +/// #map5 = affine_map<(d0, d1, d2) -> (d0, d1)> +/// %0 = linalg.init_tensor [16, 32, 64] : tensor<16x32x64xf32> +/// %cst = arith.constant 0.000000e+00 : f32 +/// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) -> +/// tensor<16x32x64xf32> +/// %2 = linalg.init_tensor [64, 4] : tensor<64x4xi1> +/// +/// %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3], +/// iterator_types = ["parallel", "parallel", "parallel", "reduction"]} +/// ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>) +/// outs(%1 : tensor<16x32x64xf32>) { +/// ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32): +/// %5 = arith.mulf %arg3, %arg4 : f32 +/// %6 = arith.addf %arg6, %5 : f32 +/// linalg.yield %6 : f32 +/// } -> tensor<16x32x64xf32> +/// +/// %4 = linalg.generic {indexing_maps = [#map4, #map5],
+/// iterator_types = ["parallel", "parallel", "reduction"]} +/// ins(%3 : tensor<16x32x64xf32>) +/// outs(%C : tensor<16x32xf32>) { +/// ^bb0(%arg3: f32, %arg4: f32): +/// %5 = arith.addf %arg3, %arg4 : f32 +/// linalg.yield %5 : f32 +/// } -> tensor<16x32xf32> +/// +/// return %4 : tensor<16x32xf32> +/// ``` +FailureOr +splitReductionByScaling(PatternRewriter &b, LinalgOp op, + const ControlSplitReductionFn &controlSplitReductionFn); + } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -240,6 +240,22 @@ getContext()); } + /// Returns a new AffineMap with the same number of dims and symbols and with + /// the result at position `pos` dropped. + AffineMap dropResult(unsigned pos) const { + auto exprs = llvm::to_vector<4>(getResults()); + exprs.erase(exprs.begin() + pos); + return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); + } + + /// Returns a new AffineMap with the same number of dims and symbols and with + /// `expr` inserted as an extra result at position `pos`. + AffineMap insertResult(AffineExpr expr, unsigned pos) const { + auto exprs = llvm::to_vector<4>(getResults()); + exprs.insert(exprs.begin() + pos, expr); + return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext()); + } + /// Folds the results of the application of an affine map on the provided /// operands to a constant if possible. LogicalResult constantFold(ArrayRef operandConstants, diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -249,6 +249,16 @@ return *this; } + /// Insert `val` into the shape at position `pos` (0 <= pos <= rank).
+ Builder &insertDim(int64_t val, unsigned pos) { + assert(pos <= shape.size() && "overflow"); + if (storage.empty()) + storage.append(shape.begin(), shape.end()); + storage.insert(storage.begin() + pos, val); + shape = {storage.data(), storage.size()}; // Re-point `shape` at `storage`. + return *this; + } + operator RankedTensorType() { return RankedTensorType::get(shape, elementType, encoding); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -421,6 +421,28 @@ splitResult->resultCombiningLinalgOp}; } +//===----------------------------------------------------------------------===// +// SplitReductionByScalingOp +//===----------------------------------------------------------------------===// + +FailureOr> +transform::SplitReductionByScalingOp::applyToOne(LinalgOp target, + TransformState &state) { + ControlSplitReductionFn splitFn = [&](LinalgOp) { + return std::pair(getSplitFactor(), + getInsertSplitDimension()); + }; + SimpleRewriter rewriter(getContext()); + rewriter.setInsertionPoint(target); + FailureOr splitResult = + splitReductionByScaling(rewriter, target, splitFn); + if (failed(splitResult)) + return getOperation()->emitError("failed to apply"); + return SmallVector{splitResult->fillOp, + splitResult->splitLinalgOp, + splitResult->resultCombiningLinalgOp}; +} + //===----------------------------------------------------------------------===// // TileOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -19,13 +19,14 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include
"mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; using namespace mlir::linalg; /// Return the identity numeric value associated to the give op. -static Optional getIdentity(Operation *op) { +static Attribute getNeutralElement(Operation *op) { // Builder only used as helper for attribute creation. OpBuilder b(op->getContext()); Type resultType = op->getResult(0).getType(); @@ -41,7 +42,7 @@ if (isa(op)) return b.getFloatAttr(resultType, llvm::APFloat::getLargest(semantic, true)); - return llvm::None; + return Attribute(); } if (isa(op)) return b.getIntegerAttr(resultType, 0); @@ -53,7 +54,7 @@ return b.getIntegerAttr(resultType, std::numeric_limits::max()); if (isa(op)) return b.getIntegerAttr(resultType, 1); - return llvm::None; + return Attribute(); } FailureOr mlir::linalg::splitReduction( @@ -84,7 +85,7 @@ std::pair control = controlSplitReductionFn(op); int64_t ratio = control.first; - unsigned insertDimIndex = control.second; + unsigned insertSplitDimension = control.second; if (ratio <= 1) return b.notifyMatchFailure(op, "split ratio needs to be greater than 1"); @@ -95,7 +96,8 @@ SmallVector loopRanges = op.getStaticLoopRanges(); int64_t reductionDimSize = loopRanges[reductionDim]; if (reductionDimSize == ShapedType::kDynamicSize || - reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size()) + reductionDimSize % ratio != 0 || + insertSplitDimension >= loopRanges.size()) return b.notifyMatchFailure( op, "Reduction dimension not divisible by split ratio"); @@ -105,7 +107,7 @@ return b.notifyMatchFailure(op, "Cannot match the reduction pattern"); Operation *reductionOp = combinerOps[0]; - Optional identity = getIdentity(reductionOp); + Attribute identity = getNeutralElement(reductionOp); if (!identity) return b.notifyMatchFailure(op, "Unknown identity value for the reduction"); @@ -125,13 +127,14 @@
newShape.push_back(ratio); newShape.push_back(op.getShape(operand)[idx] / ratio); reassociation.push_back({index++, index++}); - exprs.push_back(b.getAffineDimExpr(insertDimIndex)); + exprs.push_back(b.getAffineDimExpr(insertSplitDimension)); exprs.push_back( - b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); + b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); continue; } newShape.push_back(op.getShape(operand)[idx]); - exprs.push_back(b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); + exprs.push_back( + b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); reassociation.push_back({index++}); } newMaps.push_back( @@ -157,20 +160,20 @@ SmallVector outputExpr; for (unsigned idx : llvm::seq(0, oldOutputMap.getNumResults() + 1)) { - if (idx == insertDimIndex) { + if (idx == insertSplitDimension) { newOutputShape.push_back(ratio); - outputExpr.push_back(b.getAffineDimExpr(insertDimIndex)); + outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension)); continue; } - unsigned oldDim = idx < insertDimIndex ? idx : idx - 1; + unsigned oldDim = idx < insertSplitDimension ? idx : idx - 1; newOutputShape.push_back(oldShape[oldDim]); unsigned dim = oldOutputMap.getDimPosition(oldDim); outputExpr.push_back( - b.getAffineDimExpr(dim < insertDimIndex ? dim : dim + 1)); + b.getAffineDimExpr(dim < insertSplitDimension ?
dim : dim + 1)); } Value initTensor = b.create( loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); - Value constantOp = b.create(loc, *identity); + Value constantOp = b.create(loc, identity); Value identityTensor = b.create(op->getLoc(), constantOp, initTensor) .getResult(0); @@ -179,7 +182,7 @@ op.getContext())); SmallVector newIteratorTypes; for (auto &it : llvm::enumerate(op.iterator_types())) { - if (insertDimIndex == it.index()) + if (insertSplitDimension == it.index()) newIteratorTypes.push_back(getParallelIteratorTypeName()); newIteratorTypes.push_back(it.value().cast().getValue()); } @@ -199,7 +202,7 @@ SmallVector reductionIteratorTypes; SmallVector exprs; for (unsigned i : llvm::seq(0, intermRank)) { - if (insertDimIndex == i) { + if (insertSplitDimension == i) { reductionIteratorTypes.push_back(getReductionIteratorTypeName()); } else { exprs.push_back(b.getAffineDimExpr(i)); @@ -225,6 +228,206 @@ reduction}; } +/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...), 0 <= kk < ratio. +/// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into +/// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better +/// done as a transform to enable better vectorization.
+static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand, + unsigned reductionDimPos, + int64_t reductionRatio) { + auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext()); + auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext()); + AffineMap map = op.getTiedIndexingMap(&opOperand); + AffineMap idMap = + AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext()); + AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1); + AffineMap composeMap = shiftedIdMap.replace( + reductionDim, reductionDim * reductionRatio + reductionDimP1, + shiftedIdMap.getNumDims(), /*numSymbols=*/0); + return map.compose(composeMap); +} + +static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand, + unsigned reductionDimPos, unsigned resPos) { + auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext()); + AffineMap map = op.getTiedIndexingMap(&opOperand); + AffineMap idMap = + AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext()); + AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1); + return map.compose(shiftedIdMap).insertResult(reductionDim, resPos); +} + +/// Core rewrite implementation. +FailureOr mlir::linalg::splitReductionByScaling( + PatternRewriter &b, LinalgOp op, + const ControlSplitReductionFn &controlSplitReductionFn) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(op); + + // Matcher part, enforce preconditions.
+ std::pair control = controlSplitReductionFn(op); + int64_t splitFactor = control.first; + unsigned insertSplitDimension = control.second; + if (splitFactor <= 1) + return b.notifyMatchFailure(op, "split factor needs to be greater than 1"); + + SmallVector dims; + op.getReductionDims(dims); + if (dims.empty()) + return b.notifyMatchFailure(op, "needs at least 1 reduction dimension"); + + unsigned reductionDimPos = dims[0]; + SmallVector loopRanges = op.getStaticLoopRanges(); + int64_t reductionDimSize = loopRanges[reductionDimPos]; + if (reductionDimSize == ShapedType::kDynamicSize || + reductionDimSize % splitFactor != 0) + return b.notifyMatchFailure( + op, "first reduction dimension not divisible by split factor"); + + SmallVector combinerOps; + if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps)) + return b.notifyMatchFailure(op, "cannot match a reduction pattern"); + + SmallVector neutralElements = llvm::to_vector<4>( + llvm::map_range(combinerOps, [&](Operation *reductionOp) { + return getNeutralElement(reductionOp); + })); + if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; })) + return b.notifyMatchFailure(op, "unknown reduction neutral"); + + // TODO: relax this when multi-reduction support is available. + if (op.getNumOutputs() != 1 || neutralElements.size() != 1) + return b.notifyMatchFailure(op, "expect a single output and reduction"); + if (!op.hasTensorSemantics()) + return b.notifyMatchFailure(op, "expected tensor semantics"); + if (insertSplitDimension > op.getRank(op.getOutputOperand(0))) + return b.notifyMatchFailure(op, "insert_split_dimension out of bounds"); + + // Rewrite part. + // Step 1. Build the intermediate outputs filled with the proper + // neutralElements. Such outputs are of the same shape with an extra dimension + // inserted at `insertSplitDimension`. + // + // Consider a minimal example where `k` is reduced: + // O(i, j) += I(i, j, k) + // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0. + // The compute is rewritten as: + // a. O_i(k, i, j) += I(i, j, 16 * k + kk) + // b. O(i, j) += O_i(k, i, j) + // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
+ Location loc = op->getLoc(); + MLIRContext *context = op.getContext(); + // The matcher currently guarantees a single output with a single reduction. + // TODO: generalize when multi-reduction support is available. + SmallVector newOutputs; + SmallVector fillOps; + for (auto it : llvm::zip(op.outputs(), neutralElements)) { + Value rankedTensor = std::get<0>(it); + auto t = rankedTensor.getType().cast(); + RankedTensorType newT = RankedTensorType::Builder(t).insertDim( + reductionDimSize / splitFactor, insertSplitDimension); + SmallVector dynamicDims = + tensor::createDynamicDimValues(b, loc, rankedTensor); + Value initTensor = b.create( + loc, dynamicDims, newT.getShape(), t.getElementType()); + Value constantOp = b.create(loc, std::get<1>(it)); + fillOps.push_back( + b.create(op->getLoc(), constantOp, initTensor)); + newOutputs.push_back(fillOps.back().getResult(0)); + } + + // Step 2. Reindex / expand indexing maps. + // Reindex existing input indexings: k -> k * splitFactor + k'. + SmallVector newMaps; + for (OpOperand *o : op.getInputOperands()) + newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor)); + // Provision a new indexing for the shape-only tensor. + auto nDims = op.getNumLoops() + 1; + auto redDim = getAffineDimExpr(reductionDimPos, context); + auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context); + newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context)); + // Expand existing output indexings: insert `k` at `insertSplitDimension`. + // TODO: a subset of these may not reduce along reducePos and should be + // reindexed: k -> k * splitFactor + k', when multi-reduction support is + // available. + for (OpOperand *o : op.getOutputOperands()) + newMaps.push_back( + insertParallelDim(op, *o, reductionDimPos, insertSplitDimension)); + + // Step 3. Handle operands. + // Compute the new input tensors.
+ auto newInputs = llvm::to_vector<4>(op.inputs()); + // Add a single shape-only tensor to carry the dimensions without resorting to + // more complex inversions. + newInputs.push_back(b.create( + loc, ArrayRef{reductionDimSize / splitFactor, splitFactor}, + b.getIntegerType(1))); + // Output tensors are already good to go. + + // Step 4. Create the new op with an extra parallel dimension. The i1 block + // argument of the shape-only tensor is inserted right after the inputs. + SmallVector iteratorTypes = + llvm::to_vector<4>(op.getIteratorTypes().getAsValueRange()); + iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos, + getParallelIteratorTypeName()); + GenericOp genericOp = + b.create(loc, ValueRange(newOutputs).getTypes(), newInputs, + newOutputs, newMaps, iteratorTypes); + b.inlineRegionBefore(op->getRegion(0), genericOp.region(), + genericOp.region().begin()); + genericOp.region().front().insertArgument(op.getNumInputs(), + b.getIntegerType(1), loc); + + // Step 5. Create new reduction ops that only reduce the newly added + // dimensions from the previous op. + // For now assume outputs are 1-1 with reduction ops. + // TODO: a subset of these may not reduce in the first place and do not + // require a new op, when multi-reduction support is available. + // TODO: all results can be handled in a single GenericOp, when + // multi-reduction support is available.
+ SmallVector results; + for (auto it : + llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) { + Value reindexedOutput = std::get<0>(it); + Value originalOutput = std::get<1>(it); + auto originalOutputType = originalOutput.getType().cast(); + Operation *combinerOp = std::get<2>(it); + + AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1); + SmallVector indexingMaps = { + map, map.dropResult(insertSplitDimension)}; + SmallVector reductionIteratorTypes( + originalOutputType.getRank() + 1, getParallelIteratorTypeName()); + reductionIteratorTypes[insertSplitDimension] = + getReductionIteratorTypeName(); + + // clang-format off + auto reductionOp = b.create( + loc, + originalOutputType, + reindexedOutput, + originalOutput, + indexingMaps, + reductionIteratorTypes, + [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) { + Operation *clonedReductionOp = b.clone(*combinerOp); + clonedReductionOp->setOperand(0, bbArgs[0]); + clonedReductionOp->setOperand(1, bbArgs[1]); + b.create(loc, clonedReductionOp->getResult(0)); + }); + // clang-format on + + results.push_back(reductionOp); + }
+ + // Guaranteed by the matcher. TODO: extend with multi-reduction support. + assert(fillOps.size() == results.size() && results.size() == 1); + b.replaceOp(op, results.front()->getResults()); + return SplitReductionResult{fillOps.front(), + cast(genericOp.getOperation()), + results.front()}; +} + namespace { struct LinalgSplitReduction : public OpInterfaceRewritePattern { diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s + +// CHECK-LABEL: func.func @matmul_split +func.func @matmul_split(%A : tensor, %B: tensor<256x32xf32>, %C: tensor) -> tensor { + + // CHECK: linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] + // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor, tensor<256x32xf32>, tensor<64x4xi1>) + // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor) { + + // CHECK: linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] + // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor) + // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor) { + %0 = linalg.matmul ins(%A, %B: tensor, tensor<256x32xf32>) + outs(%C: tensor) -> tensor + return %0: tensor +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = pdl.operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + // TODO: we don't want this, but it is the required terminator for pdl.pattern + rewrite %0 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1:3 = transform.structured.split_reduction_by_scaling %0 { split_factor = 4, insert_split_dimension = 2} + } +}