diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -153,6 +153,25 @@
   }];
 }
 
+def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
+  let description = [{
+    Indicates that the given `target` op should be transformed with the
+    `splitReduction` transformation and split factor provided as attribute.
+
+    This op returns handles to the split op and the result-combining op.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64Attr, "2">:$split_factor,
+                   DefaultValuedAttr<I64Attr, "2">:$split_dimension);
+  let results = (outs PDL_Operation:$split_linalg_op,
+                 PDL_Operation:$combining_linalg_op);
+
+  let assemblyFormat = "$target attr-dict";
+}
+
 def TileOp : Op<Transform_Dialect, "structured.tile",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1466,6 +1466,7 @@
 /// reduction dimension. The dimension index is used to control where the extra
 /// dimension is added to the intermediate tensor shape. If the ratio value is
 /// less or equal to 1 then nothing will be done.
+// TODO: don't use unsigned unless doing bit manipulation.
 using ControlSplitReductionFn =
     std::function<std::pair<int64_t, unsigned>(LinalgOp op)>;
 
@@ -1519,6 +1520,16 @@
     const ControlSplitReductionFn &controlSplitReductionFn,
     const LinalgTransformationFilter &f);
 
+/// Filterless version of the above.
+/// Returns both the new linalg ops.
+struct SplitReductionResult {
+  LinalgOp splitLinalgOp;
+  LinalgOp resultCombiningLinalgOp;
+};
+FailureOr<SplitReductionResult>
+splitReduction(PatternRewriter &b, LinalgOp op,
+               const ControlSplitReductionFn &controlSplitReductionFn);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -394,6 +394,43 @@
   return result->op;
 }
 
+//===----------------------------------------------------------------------===//
+// SplitReductionOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::SplitReductionOp::apply(TransformResults &transformResults,
+                                   TransformState &state) {
+  ControlSplitReductionFn splitFn = [&](LinalgOp) {
+    return std::pair<int64_t, unsigned>(getSplitFactor(), getSplitDimension());
+  };
+
+  SimpleRewriter rewriter(getContext());
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  // Check the size before dereferencing: front() on an empty ArrayRef is UB.
+  auto linalgOp = payloadOps.size() == 1
+                      ? dyn_cast_or_null<LinalgOp>(payloadOps.front())
+                      : LinalgOp();
+  if (!linalgOp) {
+    getOperation()->emitError("only single LinalgOp payload supported");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  FailureOr<SplitReductionResult> splitResult =
+      splitReduction(rewriter, linalgOp, splitFn);
+  if (failed(splitResult)) {
+    getOperation()->emitError("failed to apply");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  transformResults.set(getOperation()->getOpResult(0),
+                       splitResult->splitLinalgOp.getOperation());
+  transformResults.set(getOperation()->getOpResult(1),
+                       splitResult->resultCombiningLinalgOp.getOperation());
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // TileOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
@@ -64,11 +64,29 @@
       op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
       !op.hasOnlyProjectedPermutations())
     return b.notifyMatchFailure(op, "precondition not met");
+
+  auto res = splitReduction(b, op, controlSplitReductionFn);
+  if (failed(res))
+    return failure();
+
+  filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
+  filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);
+
+  return res->splitLinalgOp;
+}
+
+FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
+    PatternRewriter &b, LinalgOp op,
+    const ControlSplitReductionFn &controlSplitReductionFn) {
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPoint(op);
+
   std::pair<int64_t, unsigned> control = controlSplitReductionFn(op);
   int64_t ratio = control.first;
   unsigned insertDimIndex = control.second;
   if (ratio <= 1)
     return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
+
   SmallVector<unsigned> dims;
   op.getReductionDims(dims);
   assert(dims.size() == 1);
@@ -79,14 +97,16 @@
       reductionDimSize % ratio != 0 ||
       insertDimIndex >= loopRanges.size())
     return b.notifyMatchFailure(
         op, "Reduction dimension not divisible by split ratio");
+
   SmallVector<Operation *, 4> combinerOps;
   if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
       combinerOps.size() != 1)
     return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
+
   Operation *reductionOp = combinerOps[0];
   Optional<Attribute> identity = getIdentity(reductionOp);
   if (!identity)
-    return b.notifyMatchFailure(op, "Unknown identity value for the redution");
+    return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
   Location loc = op->getLoc();
   SmallVector<Value> newInputs;
@@ -127,6 +147,7 @@
         loc, newType, operand->get(), reassociation);
     newInputs.push_back(newInput);
   }
+
   // Calculate the new output map and shape, we insert the new dimension based
shape, we insert the new dimension based // on the index returned by `controlSplitReductionFn`. SmallVector newOutputShape; @@ -169,8 +190,8 @@ b.inlineRegionBefore(op->getRegion(0), genericOp.region(), genericOp.region().begin()); - // Then create a new reduction that only reduce the newly added dimension from - // the previous op. + // Then create a new reduction that only reduce the newly added dimension + // from the previous op. unsigned intermRank = newOutputShape.size(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); SmallVector outputOperands = op.getOutputOperands(); @@ -197,9 +218,9 @@ b.create(loc, clonedReductionOp->getResult(0)); }); b.replaceOp(op, reduction.getResults()); - filter.replaceLinalgTransformationFilter(b, genericOp); - filter.replaceLinalgTransformationFilter(b, reduction); - return cast(genericOp.getOperation()); + + return SplitReductionResult{cast(genericOp.getOperation()), + reduction}; } namespace { diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s + +// CHECK-LABEL: func.func @matmul_split +func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { + + // CHECK: linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] + // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<16x4x64xf32>, tensor<4x64x32xf32>) + // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>) { + + // CHECK: linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] + // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>) + // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32xf32>) { + %0 = linalg.matmul ins(%A, %B: 
+                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1:2 = transform.structured.split_reduction %0 { split_factor = 4, split_dimension = 2}
+  }
+}