diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -153,6 +153,75 @@ }]; } +def SplitReductionOp : Op { + let description = [{ + Indicates that the given `target` op should be transformed with the + `splitReduction` transformation and split factor provided as attribute. + + The `splitReduction` transformation splits the single reduction dimension + of a linalg op into a parallel dimension and a reduction dimension. + A new `linalg.generic` op is created to perform the rest of the reduction. + + Example: + + ``` + %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()>], + iterator_types = ["reduction"]} + ins(%in : tensor<32xf32>) + outs(%out : tensor<f32>) { + ^bb0(%arg1: f32, %arg2: f32): + %y = arith.addf %arg1, %arg2 : f32 + linalg.yield %y : f32 + } -> tensor<f32> + ``` + + To: + + ``` + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into tensor<4x8xf32> + %1 = linalg.init_tensor [4] : tensor<4xf32> + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32> + %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) { + ^bb0(%arg3: f32, %arg4: f32): + %5 = arith.addf %arg3, %arg4 : f32 + linalg.yield %5 : f32 + } -> tensor<4xf32> + %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> ()>], + iterator_types = ["reduction"]} + ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) { + ^bb0(%arg3: f32, %arg4: f32): + %5 = arith.addf %arg3, %arg4 : f32 + linalg.yield %5 : f32 + } -> tensor<f32> + ``` + + This op returns handles to the fill op used to initialize the neutral + element, the 
split op and the result-combining op. + }]; + + let arguments = (ins PDL_Operation:$target, + DefaultValuedAttr:$split_factor, + DefaultValuedAttr:$insert_split_dimension); + let results = (outs PDL_Operation:$fill_op, + PDL_Operation:$split_linalg_op, + PDL_Operation:$combining_linalg_op); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( + ::mlir::linalg::LinalgOp target); + }]; +} + def TileOp : Op, FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1466,6 +1466,7 @@ /// reduction dimension. The dimension index is used to control where the extra /// dimension is added to the intermediate tensor shape. If the ratio value is /// less or equal to 1 then nothing will be done. +// TODO: don't use unsigned unless doing bit manipulation. using ControlSplitReductionFn = std::function(LinalgOp op)>; @@ -1519,6 +1520,18 @@ const ControlSplitReductionFn &controlSplitReductionFn, const LinalgTransformationFilter &f); +/// Filterless version of the above (no LinalgTransformationFilter is used). +/// Returns both the new linalg ops as well as the fillOp needed to initialize +/// the temporary expanded tensor with the proper neutral element.
+struct SplitReductionResult { + FillOp fillOp; // Fills the intermediate tensor with the neutral element. + LinalgOp splitLinalgOp; // Computes the partial reductions. + LinalgOp resultCombiningLinalgOp; // Reduces the partial results. +}; +FailureOr +splitReduction(PatternRewriter &b, LinalgOp op, + const ControlSplitReductionFn &controlSplitReductionFn); + } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -98,13 +98,15 @@ /// Streams the given values into the diagnotic. Expects this object to be a /// silencable failure. - template DiagnosedSilenceableFailure &operator<<(T &&value) & { + template + DiagnosedSilenceableFailure &operator<<(T &&value) & { assert(isSilenceableFailure() && "can only append output in silencable failure state"); *diagnostic << std::forward(value); return *this; } - template DiagnosedSilenceableFailure &&operator<<(T &&value) && { + template + DiagnosedSilenceableFailure &&operator<<(T &&value) && { return std::move(this->operator<<(std::forward(value))); } @@ -577,16 +579,17 @@ }; /// Trait implementing the TransformOpInterface for operations applying a -/// transformation to a single operation handle and producing a single operation -/// handle. The op must implement a method with one of the following signatures: +/// transformation to a single operation handle and producing zero, one or +/// multiple operation handles. +/// The op must implement a method with one of the following signatures: /// - FailureOr applyToOne(OpTy) +/// - FailureOr<SmallVector<Operation *>> applyToOne(OpTy) /// - LogicalResult applyToOne(OpTy) /// to perform a transformation that is applied in turn to all payload IR /// operations that correspond to the handle of the transform IR operation.
/// In the functions above, OpTy is either Operation * or a concrete payload IR /// Op class that the transformation is applied to (NOT the class of the -/// transform IR op). The op is expected to have one operand and zero or one -/// results. +/// transform IR op). The op is expected to have a single operand. template class TransformEachOpTrait : public OpTrait::TraitBase { @@ -713,33 +716,53 @@ namespace detail { /// Appends `result` to the vector assuming it corresponds to the success state /// in `FailureOr`. If `result` is just a -/// `LogicalResult`, does nothing. +/// `LogicalResult`, appends an empty vector. template std::enable_if_t::value, LogicalResult> -appendTransformResultToVector(Ty result, - SmallVectorImpl &results) { +appendTransformResultToVector( + Ty result, SmallVectorImpl> &results) { + results.push_back(SmallVector()); return result; } + template -std::enable_if_t::value, LogicalResult> -appendTransformResultToVector(Ty result, - SmallVectorImpl &results) { - static_assert( - std::is_convertible::value, - "expected transform function to return operations"); +std::enable_if_t< + llvm::conjunction< + llvm::negation>, + std::is_convertible>::value, + LogicalResult> +appendTransformResultToVector( + Ty result, SmallVectorImpl> &results) { if (failed(result)) return failure(); - - results.push_back(*result); + results.push_back(SmallVector{*result}); return success(); } -/// Applies a one-to-one transform to each of the given targets. Puts the -/// results of transforms, if any, in `results` in the same order. Fails if any -/// of the application fails.
Individual transforms must be callable with -/// one of the following signatures: +template +std::enable_if_t< + llvm::conjunction< + llvm::negation>, + llvm::negation>>::value, + LogicalResult> +appendTransformResultToVector( + ContainerTy resultContainer, + SmallVectorImpl> &results) { + if (failed(resultContainer)) + return failure(); + results.push_back(*resultContainer); + return success(); +} +/// Applies a one-to-one or a one-to-many transform to each of the given +/// targets. Puts the results of transforms, if any, in `results` in the same +/// order. Fails if any application fails. Individual transforms must be +/// callable with one of the following signatures: /// - FailureOr(OpTy) /// - LogicalResult(OpTy) +/// - FailureOr<SmallVector<Operation *>>(OpTy), or more generally a +/// FailureOr of a container of Operation *, for one-to-many +/// transforms /// where OpTy is either /// - Operation *, in which case the transform is always applied; /// - a concrete Op class, in which case a check is performed whether @@ -748,7 +771,8 @@ template DiagnosedSilenceableFailure applyTransformToEach(ArrayRef targets, - SmallVectorImpl &results, FnTy transform) { + SmallVectorImpl> &results, + FnTy transform) { using OpTy = typename llvm::function_traits::template arg_t<0>; static_assert(std::is_convertible::value, "expected transform function to take an operation"); @@ -782,17 +806,39 @@ decltype(&OpTy::applyToOne)>::template arg_t<0>; ArrayRef targets = state.getPayloadOps(this->getOperation()->getOperand(0)); - SmallVector results; + SmallVector<SmallVector<Operation *>, 1> results; + // Each entry of `results` holds the payload ops produced by applying the + // transform to one target, in target order.
DiagnosedSilenceableFailure result = detail::applyTransformToEach( targets, results, [&](TransformOpType specificOp) { return static_cast(this)->applyToOne(specificOp); }); if (!result.succeeded()) return result; - - if (OpTy::template hasTrait()) { - transformResults.set( - this->getOperation()->getResult(0).template cast(), results); + if (OpTy::template hasTrait<OpTrait::ZeroResults>()) + return DiagnosedSilenceableFailure::success(); + + // Each transform IR result must be associated with the payload ops produced + // for all targets at once, so transpose the per-target lists into per-result + // lists. Validate every application first so that no result is set when the + // number of produced payload ops does not match the number of op results. + unsigned numResults = this->getOperation()->getNumResults(); + SmallVector<SmallVector<Operation *>, 1> resultsPerHandle(numResults); + for (const SmallVector<Operation *> &oneTargetResults : results) { + if (oneTargetResults.size() != numResults) { + Diagnostic diag(this->getOperation()->getLoc(), + DiagnosticSeverity::Error); + diag << "unexpected number of results (got " << oneTargetResults.size() + << " expected " << numResults << ")"; + return DiagnosedSilenceableFailure::silencableFailure(std::move(diag)); + } + for (unsigned i = 0; i < numResults; ++i) + resultsPerHandle[i].push_back(oneTargetResults[i]); + } + for (unsigned i = 0; i < numResults; ++i) { + transformResults.set( + this->getOperation()->getResult(i).template cast<OpResult>(), + resultsPerHandle[i]); } return DiagnosedSilenceableFailure::success(); } @@ -802,9 +848,6 @@ mlir::transform::TransformEachOpTrait::verifyTrait(Operation *op) { static_assert(OpTy::template hasTrait(), "expected single-operand op"); - static_assert(OpTy::template hasTrait() || - OpTy::template hasTrait(), - "expected zero- or single-result op"); if (!op->getName().getInterface()) { return op->emitError() << "TransformEachOpTrait should only be attached to " "ops that implement TransformOpInterface"; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -394,6 +394,27 @@ return result->op; }
+//===----------------------------------------------------------------------===// +// SplitReductionOp +//===----------------------------------------------------------------------===// + +FailureOr> +transform::SplitReductionOp::applyToOne(LinalgOp target) { + ControlSplitReductionFn splitFn = [&](LinalgOp) { + return std::pair(getSplitFactor(), + getInsertSplitDimension()); + }; + SimpleRewriter rewriter(getContext()); + rewriter.setInsertionPoint(target); + FailureOr splitResult = + splitReduction(rewriter, target, splitFn); + if (failed(splitResult)) + return getOperation()->emitError("failed to apply"); + return SmallVector{splitResult->fillOp, + splitResult->splitLinalgOp, + splitResult->resultCombiningLinalgOp}; +} + //===----------------------------------------------------------------------===// // TileOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -64,11 +64,38 @@ op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 || !op.hasOnlyProjectedPermutations()) return b.notifyMatchFailure(op, "precondition not met"); + + FailureOr res = + splitReduction(b, op, controlSplitReductionFn); + if (failed(res)) + return failure(); + + filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp); + filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp); + + return res->splitLinalgOp; +} + +FailureOr mlir::linalg::splitReduction( + PatternRewriter &b, LinalgOp op, + const ControlSplitReductionFn &controlSplitReductionFn) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPoint(op); + + // This filterless entry point is also called directly (e.g. by + // transform.structured.split_reduction), so it must check the structural + // preconditions the code below relies on (a single reduction dimension is + // asserted, a single output is assumed) instead of trusting the caller. + if (!op.hasTensorSemantics() || op.getNumReductionLoops() != 1 || + op.getNumOutputs() != 1 || !op.hasOnlyProjectedPermutations()) + return b.notifyMatchFailure(op, "precondition not met"); + std::pair control = controlSplitReductionFn(op); int64_t ratio = control.first; unsigned insertDimIndex = control.second; if (ratio <= 1) return b.notifyMatchFailure(op, "split ratio 
needs to be greater than 1"); + SmallVector dims; op.getReductionDims(dims); assert(dims.size() == 1); @@ -79,14 +106,16 @@ reductionDimSize % ratio != 0 || insertDimIndex >= loopRanges.size()) return b.notifyMatchFailure( op, "Reduction dimension not divisible by split ratio"); + SmallVector combinerOps; if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) || combinerOps.size() != 1) return b.notifyMatchFailure(op, "Cannot match the reduction pattern"); + Operation *reductionOp = combinerOps[0]; Optional identity = getIdentity(reductionOp); if (!identity) - return b.notifyMatchFailure(op, "Unknown identity value for the redution"); + return b.notifyMatchFailure(op, "Unknown identity value for the reduction"); Location loc = op->getLoc(); SmallVector newInputs; @@ -127,6 +156,7 @@ loc, newType, operand->get(), reassociation); newInputs.push_back(newInput); } + // Calculate the new output map and shape, we insert the new dimension based // on the index returned by `controlSplitReductionFn`. SmallVector newOutputShape; @@ -169,8 +199,8 @@ b.inlineRegionBefore(op->getRegion(0), genericOp.region(), genericOp.region().begin()); - // Then create a new reduction that only reduce the newly added dimension from - // the previous op. + // Then create a new reduction that only reduces the newly added dimension + // from the previous op.
unsigned intermRank = newOutputShape.size(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); SmallVector outputOperands = op.getOutputOperands(); @@ -197,9 +227,10 @@ b.create(loc, clonedReductionOp->getResult(0)); }); b.replaceOp(op, reduction.getResults()); - filter.replaceLinalgTransformationFilter(b, genericOp); - filter.replaceLinalgTransformationFilter(b, reduction); - return cast(genericOp.getOperation()); + + return SplitReductionResult{identityTensor.getDefiningOp(), + cast(genericOp.getOperation()), + reduction}; } namespace { diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s + +// CHECK-LABEL: func.func @matmul_split +func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { + + // CHECK: linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] + // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<16x4x64xf32>, tensor<4x64x32xf32>) + // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>) { + + // CHECK: linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] + // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<16x32x4xf32>) + // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<16x32xf32>) { + %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>) + outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> + return %0: tensor<16x32xf32> +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = pdl.operation "linalg.matmul"(%args : !pdl.range) -> (%results : !pdl.range) + // TODO: we don't want this, but it is the required terminator for
pdl.pattern + rewrite %0 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1:3 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2} + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -378,3 +378,10 @@ } } } +// ----- + +transform.sequence { +^bb0(%arg0: !pdl.operation): + // expected-error @below {{unexpected number of results (got 0 expected 3)}} + transform.test_wrong_number_of_results %arg0 +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h @@ -16,7 +16,7 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" namespace mlir { class DialectRegistry; diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -226,6 +226,11 @@ return DiagnosedSilenceableFailure::success(); } +FailureOr> +mlir::test::TestWrongNumberOfResultsOp::applyToOne(Operation *) { + return SmallVector{}; +} + namespace { /// Test extension of the Transform dialect.
Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -128,4 +128,20 @@ let cppNamespace = "::mlir::test"; } +def TestWrongNumberOfResultsOp // applyToOne yields 0 handles for 3 results. + : Op { + let arguments = (ins PDL_Operation:$target); + let results = (outs PDL_Operation:$a, + PDL_Operation:$b, + PDL_Operation:$c); + let assemblyFormat = "$target attr-dict"; + let cppNamespace = "::mlir::test"; + let extraClassDeclaration = [{ + ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( + ::mlir::Operation *target); + }]; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD