diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -16,56 +16,426 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" -def LowerVectorsOp : Op, - DeclareOpInterfaceMethods]> { + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def ApplyFoldingPatternsOp : + Op { + let description = [{ + Apply opt-in vector folding patterns that include: + - CreateMaskFolder + - MaskedLoadFolder + - MaskedStoreFolder + - GatherFolder + - ScatterFolder + - ExpandLoadFolder + - CompressStoreFolder + - StridedSliceConstantMaskFolder + - TransposeFolder (folds consecutive vector.transpose ops) + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. 
+def ApplyRankReducingShapeCastPatternsOp : + Op { + let description = [{ + Apply opt-in leading-unit-dimension dropping patterns that include: + - CastAwayExtractStridedSliceLeadingOneDim + - CastAwayInsertStridedSliceLeadingOneDim + - CastAwayInsertLeadingOneDim + - CastAwayTransferReadLeadingOneDim + - CastAwayTransferWriteLeadingOneDim + - CastAwayElementwiseLeadingOneDim + - CastAwayContractionLeadingOneDim + + These patterns are defined in VectorDropLeadUnitDim.cpp and have the effect + of rewriting a vector.transfer with a leading 1 dimension into a rank-reduced + version thanks to vector.broadcast and vector.shape_cast operations. + + This is complemented by vector.shape_cast folding patterns to try to + eliminate such operations. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def ApplyRankReducingSubviewPatternsOp : + Op { + let description = [{ + Apply opt-in vector transfer rank-reducing patterns that include: + - TransferReadDropUnitDimsPattern + - TransferWriteDropUnitDimsPattern + + These patterns have the effect of rewriting a vector.transfer with unit + dimensions into a rank-reduced version thanks to subview operations. + This is complemented by shape_cast folding patterns.
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def ApplyTransferPermutationPatternsOp : + Op { let description = [{ - Indicates that the vector operations nested under the isolated from above op - `target` should be lowered to finer-grained vector primitives. + Apply opt-in vector transfer permutation patterns that include: + - TransferReadPermutationLowering + - TransferWritePermutationLowering + - TransferOpReduceRank + - TransferWriteNonPermutationLowering + + These patterns have the effect of rewriting a vector.transfer with an + arbitrary permutation_map to a vector.transfer with a permutation_map that is + a minor identity followed by a vector.transpose. - At this time, the transform is all or nothing. + In other words, this makes the vector.transfer contiguous on the most minor + dimensions and materializes the permutation_map as a vector.transpose. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. 
+// TODO: evolve lowering_strategy to proper enums. +def LowerContractionOp : Op { + let description = [{ + Indicates that the vector contraction-like operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. This is usally a late step that is run after bufferization as part of the process of lowering to e.g. LLVM or NVVM. }]; - // TODO: evolve this to proper enums. - let arguments = (ins PDL_Operation:$target, + let arguments = (ins TransformHandleTypeInterface:$target, DefaultValuedAttr:$contraction_lowering, + "vector::VectorContractLowering::OuterProduct">:$lowering_strategy + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + (`lowering_strategy` `=` $lowering_strategy^)? + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. +def LowerMultiReductionOp : Op { + let description = [{ + Indicates that the vector multi_reduction-like operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, DefaultValuedAttr: - $multireduction_lowering, + $lowering_strategy + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + (`lowering_strategy` `=` $lowering_strategy^)? 
+ attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve split_transfer_strategy to proper enums. +def SplitTransferFullPartialOp : Op { + let description = [{ + Indicates that the vector transfer operations nested under the + isolated from above op `target` should be split to full and partial parts. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, DefaultValuedAttr:$split_transfers, + "vector::VectorTransferSplit::LinalgCopy">:$split_transfer_strategy + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + (`split_transfer_strategy` `=` $split_transfer_strategy^)? + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def LowerTransferOp : Op { + let description = [{ + Indicates that the vector transfer operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM.
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$max_transfer_rank + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + (`max_transfer_rank` `=` $max_transfer_rank^)? + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def TransferToScfOp : Op { + let description = [{ + Indicates that the vector transfer operations nested under the + isolated from above op `target` should be rewritten with scf.for loops over + finer-grained vector primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$max_transfer_rank, + DefaultValuedAttr:$full_unroll + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + oilist ( + `max_transfer_rank` `=` $max_transfer_rank + | `full_unroll` `=` $full_unroll + ) + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def LowerShapeCastOp : Op { + let description = [{ + Indicates that the vector shape_cast operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives.
+ + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. +def LowerTransposeOp : Op { + let description = [{ + Indicates that the vector transpose-like operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM.
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target, DefaultValuedAttr:$transpose_lowering, - DefaultValuedAttr:$transpose_avx2_lowering, - DefaultValuedAttr:$unroll_vector_transfers + "vector::VectorTransposeLowering::EltWise">:$lowering_strategy, + DefaultValuedAttr:$avx2_lowering_strategy ); - let results = (outs PDL_Operation:$results); - - let builders = [ - OpBuilder<(ins "Type":$resultType, "Value":$target, - "const vector::LowerVectorsOptions &":$options), [{ - return build($_builder, $_state, resultType, target, - options.vectorContractLowering, - options.vectorMultiReductionLowering, options.vectorTransferSplit, - options.vectorTransposeLowering, options.transposeAVX2Lowering, - options.unrollVectorTransfers); - }] - > - ]; + let results = (outs TransformHandleTypeInterface:$results); let assemblyFormat = [{ $target oilist ( - `contraction_lowering` `=` $contraction_lowering - | `multireduction_lowering` `=` $multireduction_lowering - | `split_transfers` `=` $split_transfers - | `transpose_lowering` `=` $transpose_lowering + `lowering_strategy` `=` $lowering_strategy + | `avx2_lowering_strategy` `=` $avx2_lowering_strategy ) attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); }]; } diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -25,101 +25,327 @@ using namespace mlir::transform; //===----------------------------------------------------------------------===// -// LowerVectorsOp +// LowerContractionOp //===----------------------------------------------------------------------===// -void
transform::LowerVectorsOp::getEffects( - SmallVectorImpl &effects) { - consumesHandle(getTarget(), effects); - producesHandle(getResults(), effects); - modifiesPayload(effects); +DiagnosedSilenceableFailure transform::LowerContractionOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. + if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); // Permutation-map lowering is in ApplyTransferPermutationPatternsOp. + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); + populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, + /*benefit=*/1, + /*disableOuterProductLowering=*/true); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerMultiReductionOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerMultiReductionOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist.
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); + vector::populateVectorMultiReductionLoweringPatterns( + patterns, vectorTransformOptions.vectorMultiReductionLowering); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// SplitTransferFullPartialOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::SplitTransferFullPartialOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist.
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); + populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerTransferOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerTransferOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist.
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorTransferLoweringPatterns(patterns, + getMaxTransferRank()); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// TransferToScfOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::TransferToScfOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist.
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + VectorTransferToSCFOptions vectorTransferToSCFOptions = + VectorTransferToSCFOptions() + .enableFullUnroll(getFullUnroll()) + .setTargetRank(getMaxTransferRank()); + populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure transform::LowerVectorsOp::apply( - mlir::transform::TransformResults &transformResults, - mlir::transform::TransformState &state) { - - SmallVector results; - ArrayRef payloadOps = state.getPayloadOps(getTarget()); - for (Operation *target : payloadOps) { - // This check can't be part of the verifier because payload IR is - // independent from transform IR and may not even exist.
- if (!target->hasTrait()) { - return mlir::emitDefiniteFailure(target, - "applies only to isolated-from-above " - "targets because it needs to apply " - "patterns greedily"); - } - - MLIRContext *ctx = getContext(); - RewritePatternSet patterns(ctx); - vector::VectorTransposeLowering vectorTransposeLowering = - getTransposeLowering(); - vector::VectorMultiReductionLowering vectorMultiReductionLowering = - getMultireductionLowering(); - vector::VectorContractLowering vectorContractLowering = - getContractionLowering(); - vector::VectorTransferSplit vectorTransferSplit = getSplitTransfers(); - - vector::VectorTransformsOptions vectorTransformOptions; - vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering) - .setVectorMultiReductionLowering(vectorMultiReductionLowering) - .setVectorTransposeLowering(vectorTransposeLowering) - .setVectorTransferSplit(vectorTransferSplit); - - VectorTransferToSCFOptions vectorTransferToSCFOptions = - VectorTransferToSCFOptions().enableFullUnroll( - getUnrollVectorTransfers()); - - int maxTransferRank = 1; +//===----------------------------------------------------------------------===// +// LowerShapeCastOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerShapeCastOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist.
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorShapeCastLoweringPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerTransposeOp +//===----------------------------------------------------------------------===// +DiagnosedSilenceableFailure transform::LowerTransposeOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist.
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorTransposeLoweringPatterns( + patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( + getLoweringStrategy())); + if (getAvx2LoweringStrategy()) { auto avx2LoweringOptions = x86vector::avx2::LoweringOptions().setTransposeOptions( x86vector::avx2::TransposeLoweringOptions() - .lower4x8xf32(getTransposeAvx2Lowering()) - .lower8x8xf32(getTransposeAvx2Lowering())); + .lower4x8xf32(true) + .lower8x8xf32(true)); + x86vector::avx2::populateSpecializedTransposeLoweringPatterns( + patterns, avx2LoweringOptions, /*benefit=*/10); // Higher benefit: tried before the generic transpose patterns. + } - vector::populateVectorToVectorCanonicalizationPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); - // In the future we may want to more finely select particular stages. - // Stage 1: contraction lowerings. - populateVectorContractLoweringPatterns( - patterns, vectorTransformOptions, /*benefit=*/1, - /*disableOuterProductLowering*/ true); - vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} - // Stage 2: multi-reduction lowerings. - vector::populateVectorMultiReductionLoweringPatterns( - patterns, vectorTransformOptions.vectorMultiReductionLowering); +//===----------------------------------------------------------------------===// +// ApplyFoldingPatternsOp +//===----------------------------------------------------------------------===// - // Stage 3: Rewrite vector.transfer into full and partial parts.
- populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); +DiagnosedSilenceableFailure transform::ApplyFoldingPatternsOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. + if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } - // Stage 4: Lower vector transfers. - vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank); + RewritePatternSet patterns(getContext()); + vector::populateVectorToVectorCanonicalizationPatterns(patterns); - // Stage 5: Vector to scf patterns. - populateVectorToSCFConversionPatterns( - patterns, vectorTransferToSCFOptions.setTargetRank(maxTransferRank)); + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); - // Stage 6: Lower vector.shape_cast. - vector::populateVectorShapeCastLoweringPatterns(patterns); + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} - // Stage 7: Lower vector.transpose.
- vector::populateVectorTransposeLoweringPatterns( - patterns, vectorTransformOptions, /*benefit=*/1); - if (getTransposeAvx2Lowering()) - x86vector::avx2::populateSpecializedTransposeLoweringPatterns( - patterns, avx2LoweringOptions, /*benefit=*/10); +//===----------------------------------------------------------------------===// +// ApplyTransferPermutationPatternsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::ApplyTransferPermutationPatternsOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. + if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// ApplyRankReducingShapeCastPatternsOp +//===----------------------------------------------------------------------===//
- // Apply everything. - if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) - return DiagnosedSilenceableFailure::definiteFailure(); +DiagnosedSilenceableFailure +transform::ApplyRankReducingShapeCastPatternsOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. + if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} - results.push_back(target); +//===----------------------------------------------------------------------===// +// ApplyRankReducingSubviewPatternsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::ApplyRankReducingSubviewPatternsOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist.
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); } - transformResults.set(getResults().cast(), results); + RewritePatternSet patterns(getContext()); + vector::populateVectorTransferDropUnitDimsPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -379,15 +379,15 @@ // We let the 0-d corner case pass-through as it is supported. if (!read.getPermutationMap().isMinorIdentityWithBroadcasting( &broadcastedDims)) - return failure(); + return rewriter.notifyMatchFailure(read, "not minor identity + bcast"); auto memRefType = read.getShapedType().dyn_cast(); if (!memRefType) - return failure(); + return rewriter.notifyMatchFailure(read, "not a memref source"); // Non-unit strides are handled by VectorToSCF. if (!vector::isLastMemrefDimUnitStride(memRefType)) - return failure(); + return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF"); // If there is broadcasting involved then we first load the unbroadcasted // vector, and then broadcast it with `vector.broadcast`. @@ -403,16 +403,16 @@ // resulting vector type is the same as the element type. auto memrefElTy = memRefType.getElementType(); if (memrefElTy.isa() && memrefElTy != unbroadcastedVectorType) - return failure(); + return rewriter.notifyMatchFailure(read, "incompatible element type"); // Otherwise, element types of the memref and the vector must match.
if (!memrefElTy.isa() && memrefElTy != read.getVectorType().getElementType()) - return failure(); + return rewriter.notifyMatchFailure(read, "non-matching element type"); // Out-of-bounds dims are handled by MaterializeTransferMask. if (read.hasOutOfBoundsDim()) - return failure(); + return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask"); // Create vector load op. Operation *loadOp; @@ -458,7 +458,8 @@ PatternRewriter &rewriter) const override { auto vecType = loadOp.getVectorType(); if (vecType.getNumElements() != 1) - return failure(); + return rewriter.notifyMatchFailure(loadOp, "not a single element vector"); + auto memrefLoad = rewriter.create( loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices()); rewriter.replaceOpWithNewOp(loadOp, vecType, @@ -476,7 +477,8 @@ PatternRewriter &rewriter) const override { auto vecType = storeOp.getVectorType(); if (vecType.getNumElements() != 1) - return failure(); + return rewriter.notifyMatchFailure(storeOp, "not single element vector"); + Value extracted; if (vecType.getRank() == 0) { // TODO: Unifiy once ExtractOp supports 0-d vectors. diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -20,6 +20,39 @@ transform.structured.vectorize %2 transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op {bufferize_function_boundaries = true} - %func = transform.structured.match ops{["func.func"]} in %module_op : (!pdl.operation) -> !pdl.operation - transform.vector.lower_vectors %func multireduction_lowering = "innerreduction" + + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + // TODO: group these lower-level controls into various properly named vector + // lowering TD macros. Each step below yields a handle to the same func.
+ %func = transform.vector.lower_contraction %f + lowering_strategy = "outerproduct" + : (!pdl.operation) -> !pdl.operation + + %func_2 = transform.vector.apply_transfer_permutation_patterns %func + : (!pdl.operation) -> !pdl.operation + + %func_3 = transform.vector.lower_multi_reduction %func_2 + lowering_strategy = "innerparallel" + : (!pdl.operation) -> !pdl.operation + + %func_4 = transform.vector.split_transfer_full_partial %func_3 + split_transfer_strategy = "linalg-copy" + : (!pdl.operation) -> !pdl.operation + + %func_5 = transform.vector.transfer_to_scf %func_4 + max_transfer_rank = 1 full_unroll = true + : (!pdl.operation) -> !pdl.operation + + %func_6 = transform.vector.lower_transfer %func_5 + max_transfer_rank = 1 + : (!pdl.operation) -> !pdl.operation + + %func_7 = transform.vector.lower_shape_cast %func_6 + : (!pdl.operation) -> !pdl.operation + + %func_8 = transform.vector.lower_transpose %func_7 + lowering_strategy = "shuffle" + : (!pdl.operation) -> !pdl.operation } diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -16,11 +16,45 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!pdl.operation) -> !pdl.operation - %1, %loops:3 = transform.structured.tile %0 [8, 4, 2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) + %1, %loops:3 = transform.structured.tile %0 [8, 4, 2] + : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation transform.structured.vectorize %2 - transform.bufferization.one_shot_bufferize %module_op + transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op + {bufferize_function_boundaries = true,
allow_return_allocs = true} - %func = transform.structured.match ops{["func.func"]} in %module_op : (!pdl.operation) -> !pdl.operation - transform.vector.lower_vectors %func multireduction_lowering = "innerreduction" + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + // TODO: group these lower-level controls into various properly named vector + // lowering TD macros. Each step below yields a handle to the same func. + %func = transform.vector.lower_contraction %f + lowering_strategy = "outerproduct" + : (!pdl.operation) -> !pdl.operation + + %func_2 = transform.vector.apply_transfer_permutation_patterns %func + : (!pdl.operation) -> !pdl.operation + + %func_3 = transform.vector.lower_multi_reduction %func_2 + lowering_strategy = "innerparallel" + : (!pdl.operation) -> !pdl.operation + + %func_4 = transform.vector.split_transfer_full_partial %func_3 + split_transfer_strategy = "linalg-copy" + : (!pdl.operation) -> !pdl.operation + + %func_5 = transform.vector.transfer_to_scf %func_4 + max_transfer_rank = 1 full_unroll = true + : (!pdl.operation) -> !pdl.operation + + %func_6 = transform.vector.lower_transfer %func_5 + max_transfer_rank = 1 + : (!pdl.operation) -> !pdl.operation + + %func_7 = transform.vector.lower_shape_cast %func_6 + : (!pdl.operation) -> !pdl.operation + + %func_8 = transform.vector.lower_transpose %func_7 + lowering_strategy = "shuffle" + : (!pdl.operation) -> !pdl.operation }