diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -164,8 +164,24 @@ reduction into a parallel and reduction dimension. A new `linalg.generic` op is created to perform the rest of the reduction. - Example: - + The transformation supports the following configuration attributes: + - split_factor: the factor by which to split (i.e. the size of the + remaining reduction after splitting). + - insert_split_dimension: the dimension in the temporary tensor into + which the new parallel dimension is inserted. + - use_scaling_algorithm: whether to use a scaling-based formulation that + does not create an ExpandShapeOp (default: do not use scaling). + - use_alloc: whether to use a bufferization.alloc_tensor op to allocate + the temporary tensor (default: do not use an alloc op). + + This op returns 3 handles to: + - the fill op used to initialize the neutral element, + - the split op and + - the result-combining op. + The init op (or alloc_tensor op if use_alloc = true) is not returned. + + Example (default: use_scaling_algorithm = false, use_alloc = false): + ==================================================================== ``` %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], @@ -178,7 +194,7 @@ } -> tensor ``` - To: + is split into: ``` %cst = arith.constant 0.000000e+00 : f32 @@ -203,34 +219,8 @@ } -> tensor ``` - This op returns handles to the fill op used to initialize the neutral - element, the split op and the result-combining op. 
- }]; - - let arguments = (ins PDL_Operation:$target, - DefaultValuedAttr:$split_factor, - DefaultValuedAttr:$insert_split_dimension); - let results = (outs PDL_Operation:$fill_op, - PDL_Operation:$split_linalg_op, - PDL_Operation:$combining_linalg_op); - - let assemblyFormat = "$target attr-dict"; - - let extraClassDeclaration = [{ - ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); - }]; -} - -def SplitReductionByScalingOp : - Op { - let description = [{ - Indicates that the given `target` op should be transformed with the - `splitReductionByScaling` transformation and split factor provided as - attribute. - + Example (use_scaling_algorithm = true, use_alloc = true): + ========================================================= Instead of introducing an ExpandShapeOp, this scaling-based implementation rewrites a reduction dimension `k` into `k * split_factor + kk`. The dimension `kk` is added as an extra parallel dimension to the @@ -287,12 +277,13 @@ return %4 : tensor<16x32xf32> ``` - }]; let arguments = (ins PDL_Operation:$target, DefaultValuedAttr:$split_factor, - DefaultValuedAttr:$insert_split_dimension); + DefaultValuedAttr:$insert_split_dimension, + UnitAttr:$use_scaling_algorithm, + UnitAttr:$use_alloc); let results = (outs PDL_Operation:$fill_op, PDL_Operation:$split_linalg_op, PDL_Operation:$combining_linalg_op); diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1474,7 +1474,8 @@ void populateSplitReductionPattern( RewritePatternSet &patterns, const ControlSplitReductionFn &controlSplitReductionFn, - const LinalgTransformationFilter &f = LinalgTransformationFilter()); + const LinalgTransformationFilter &f = LinalgTransformationFilter(), + bool useAlloc = false); /// 
Apply transformation to split the single linalg op reduction into a parallel /// and reduction dimension. Then create a new linalg.generic op doing the rest @@ -1518,19 +1519,21 @@ FailureOr splitReduction(PatternRewriter &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, - const LinalgTransformationFilter &f); + const LinalgTransformationFilter &f, bool useAlloc = false); /// Filterless version of the above. /// Returns both the new linalg ops as well as the fillOp needed to initialize /// the temporary expanded tensor with the proper neutral element. struct SplitReductionResult { + Operation *initOrAlloc; FillOp fillOp; LinalgOp splitLinalgOp; LinalgOp resultCombiningLinalgOp; }; FailureOr splitReduction(PatternRewriter &b, LinalgOp op, - const ControlSplitReductionFn &controlSplitReductionFn); + const ControlSplitReductionFn &controlSplitReductionFn, + bool useAlloc = false); /// Scaling-based implementation of the split reduction transformation. /// Instead of introducing an ExpandShapeOp, this rewrites a reduction dimension @@ -1580,7 +1583,8 @@ /// ``` FailureOr splitReductionByScaling(PatternRewriter &b, LinalgOp op, - const ControlSplitReductionFn &controlSplitReductionFn); + const ControlSplitReductionFn &controlSplitReductionFn, + bool useAlloc = false); } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -413,29 +413,9 @@ SimpleRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr splitResult = - splitReduction(rewriter, target, splitFn); - if (failed(splitResult)) - return getOperation()->emitError("failed to apply"); - return SmallVector{splitResult->fillOp, - splitResult->splitLinalgOp, - splitResult->resultCombiningLinalgOp}; -} - 
-//===----------------------------------------------------------------------===// -// SplitReductionByScalingOp -//===----------------------------------------------------------------------===// - -FailureOr> -transform::SplitReductionByScalingOp::applyToOne(LinalgOp target, - TransformState &state) { - ControlSplitReductionFn splitFn = [&](LinalgOp) { - return std::pair(getSplitFactor(), - getInsertSplitDimension()); - }; - SimpleRewriter rewriter(getContext()); - rewriter.setInsertionPoint(target); - FailureOr splitResult = - splitReductionByScaling(rewriter, target, splitFn); + (getUseScalingAlgorithm()) + ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) + : splitReduction(rewriter, target, splitFn, getUseAlloc()); if (failed(splitResult)) return getOperation()->emitError("failed to apply"); return SmallVector{splitResult->fillOp, diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -15,6 +15,7 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" @@ -60,14 +61,14 @@ FailureOr mlir::linalg::splitReduction( PatternRewriter &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, - const LinalgTransformationFilter &filter) { + const LinalgTransformationFilter &filter, bool useAlloc) { if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() || op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 || !op.hasOnlyProjectedPermutations()) return b.notifyMatchFailure(op, "precondition not met"); FailureOr res = - splitReduction(b, op, controlSplitReductionFn); + splitReduction(b, op, 
controlSplitReductionFn, useAlloc); if (failed(res)) return failure(); @@ -79,7 +80,7 @@ FailureOr mlir::linalg::splitReduction( PatternRewriter &b, LinalgOp op, - const ControlSplitReductionFn &controlSplitReductionFn) { + const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) { OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(op); @@ -171,11 +172,20 @@ outputExpr.push_back( b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); } - Value initTensor = b.create( - loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); + Value initOrAllocTensor; + if (useAlloc) { + initOrAllocTensor = b.create( + loc, + RankedTensorType::get(newOutputShape, + op.getRegionOutputArgs()[0].getType()), + ValueRange{}); + } else { + initOrAllocTensor = b.create( + loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); + } Value constantOp = b.create(loc, identity); Value identityTensor = - b.create(op->getLoc(), constantOp, initTensor) + b.create(op->getLoc(), constantOp, initOrAllocTensor) .getResult(0); newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr, @@ -189,7 +199,7 @@ // Create the new op matching the original op with an extra parallel // dimension. GenericOp genericOp = b.create( - loc, TypeRange({initTensor.getType()}), newInputs, + loc, TypeRange({initOrAllocTensor.getType()}), newInputs, ValueRange({identityTensor}), newMaps, newIteratorTypes); b.inlineRegionBefore(op->getRegion(0), genericOp.region(), genericOp.region().begin()); @@ -223,9 +233,9 @@ }); b.replaceOp(op, reduction.getResults()); - return SplitReductionResult{identityTensor.getDefiningOp(), - cast(genericOp.getOperation()), - reduction}; + return SplitReductionResult{ + initOrAllocTensor.getDefiningOp(), identityTensor.getDefiningOp(), + cast(genericOp.getOperation()), reduction}; } /// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...) @@ -260,7 +270,7 @@ /// Core rewrite implementation. 
FailureOr mlir::linalg::splitReductionByScaling( PatternRewriter &b, LinalgOp op, - const ControlSplitReductionFn &controlSplitReductionFn) { + const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) { OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(op); @@ -297,7 +307,7 @@ return b.notifyMatchFailure(op, "unknown reduction neutral"); // TODO: relax this when multi-reduction support is available. - if (op.getNumOutputs() != (int)neutralElements.size()) + if (op.getNumOutputs() != static_cast(neutralElements.size())) return b.notifyMatchFailure(op, "expect one reduction per output"); // Rewrite part. @@ -318,6 +328,7 @@ // TODO: generalize when multi-reduction support is available. SmallVector newOutputs; newOutputs.reserve(op.getNumOutputs()); + SmallVector initOrAllocTensorOps; SmallVector fillOps; fillOps.reserve(op.getNumOutputs()); for (auto it : llvm::zip(op.outputs(), neutralElements)) { @@ -327,12 +338,19 @@ reductionDimSize / splitFactor, insertSplitDimension); SmallVector dims = tensor::createDynamicDimValues(b, loc, rankedTensor); - Value initTensor = b.create( - loc, dims, newT.getShape(), t.getElementType()); + Value initOrAllocTensor; + if (useAlloc) { + initOrAllocTensor = + b.create(loc, newT, dims); + } else { + initOrAllocTensor = b.create( + loc, dims, newT.getShape(), t.getElementType()); + } Value constantOp = b.create(loc, std::get<1>(it)); fillOps.push_back( - b.create(op->getLoc(), constantOp, initTensor)); + b.create(op->getLoc(), constantOp, initOrAllocTensor)); newOutputs.push_back(fillOps.back().getResult(0)); + initOrAllocTensorOps.push_back(initOrAllocTensor.getDefiningOp()); } // Step 2. Reindex / expand indexing maps. @@ -423,7 +441,7 @@ // TODO: extend when multi-reduction support is available. 
assert(fillOps.size() == results.size() && results.size() == 1); b.replaceOp(op, results.front()->getResults()); - return SplitReductionResult{fillOps.front(), + return SplitReductionResult{initOrAllocTensorOps.front(), fillOps.front(), cast(genericOp.getOperation()), results.front()}; } @@ -434,18 +452,21 @@ /// Construct a generic pattern applied to all LinalgOp that verify `filter`. LinalgSplitReduction(MLIRContext *context, ControlSplitReductionFn controlSplitReductionFn, - LinalgTransformationFilter f, PatternBenefit benefit = 1) + LinalgTransformationFilter f, bool useAlloc = false, + PatternBenefit benefit = 1) : OpInterfaceRewritePattern(context, benefit), controlSplitReductionFn(std::move(controlSplitReductionFn)), - filter(std::move(f)) {} + useAlloc(useAlloc), filter(std::move(f)) {} LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { - return splitReduction(rewriter, op, controlSplitReductionFn, filter); + return splitReduction(rewriter, op, controlSplitReductionFn, filter, + useAlloc); } private: ControlSplitReductionFn controlSplitReductionFn; + bool useAlloc; LinalgTransformationFilter filter; }; @@ -454,7 +475,7 @@ void linalg::populateSplitReductionPattern( RewritePatternSet &patterns, const ControlSplitReductionFn &controlSplitReductionFn, - const LinalgTransformationFilter &f) { + const LinalgTransformationFilter &f, bool useAlloc) { patterns.add(patterns.getContext(), - controlSplitReductionFn, f); + controlSplitReductionFn, f, useAlloc); } diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir --- a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir @@ -3,6 +3,7 @@ // CHECK-LABEL: func.func @matmul_split func.func @matmul_split(%A : tensor, %B: tensor<256x32xf32>, %C: tensor) -> tensor { + // CHECK: 
bufferization.alloc_tensor({{.*}}) : tensor // CHECK: linalg.generic // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor, tensor<256x32xf32>, tensor<64x4xi1>) @@ -30,6 +31,7 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @pdl_target in %arg1 - %1:3 = transform.structured.split_reduction_by_scaling %0 { split_factor = 4, insert_split_dimension = 2} + %1:3 = transform.structured.split_reduction %0 + { split_factor = 4, insert_split_dimension = 2, use_scaling_algorithm, use_alloc} } } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -41,6 +42,7 @@ void getDependentDialects(DialectRegistry ®istry) const override { // clang-format off registry.insert