diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -16,56 +16,446 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" -def LowerVectorsOp : Op, - DeclareOpInterfaceMethods]> { + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def ApplyRankReducingSubviewPatternsOp : + Op { let description = [{ - Indicates that the vector operations nested under the isolated from above op - `target` should be lowered to finer-grained vector primitives. + Apply opt-in rank-reducing subview patterns that include: + - TransferReadDropUnitDimsPattern + - TransferWriteDropUnitDimsPattern + + These patterns have the effect of rewriting a vector.transfer with unit + dimensions into a rank-reduced version thanks to subview operations. + This is complemented by shape_cast folding patterns. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); - At this time, the transform is all or nothing. + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. 
+def ApplyTransferPermutationPatternsOp : + Op { + let description = [{ + Apply opt-in vector transfer permutation patterns that include: + - TransferReadPermutationLowering + - TransferWritePermutationLowering + - TransferOpReduceRank + - TransferWriteNonPermutationLowering + + These patterns have the effect of rewriting a vector.transfer with an + arbitrary permutation_map to a vector.transfer with a permutation_map that is + a minor identity followed by a vector.transpose. + + In other words, this makes the vector.transfer contiguous on the most minor + dimensions and materializes the permutation_map as a vector.transpose. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. +def LowerBroadcastOp : Op { + let description = [{ + Indicates that the vector broadcast operations nested under the isolated + from above op `target` should be lowered to finer-grained vector primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. +def LowerContractionOp : Op { + let description = [{ + Indicates that the vector contraction-like operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. This is usally a late step that is run after bufferization as part of the process of lowering to e.g. LLVM or NVVM. }]; - // TODO: evolve this to proper enums. - let arguments = (ins PDL_Operation:$target, + let arguments = (ins TransformHandleTypeInterface:$target, DefaultValuedAttr:$contraction_lowering, + "vector::VectorContractLowering::OuterProduct">:$lowering_strategy + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + (`lowering_strategy` `=` $lowering_strategy^)? + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. 
+def LowerMaskOp : Op { + let description = [{ + Indicates that the vector mask operations nested under the isolated from + above op `target` should be lowered to finer-grained vector primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. +def LowerMultiReductionOp : Op { + let description = [{ + Indicates that the vector multi_reduction-like operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, DefaultValuedAttr: - $multireduction_lowering, - DefaultValuedAttr:$split_transfers, + $lowering_strategy + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + (`lowering_strategy` `=` $lowering_strategy^)? + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. 
+// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. +def LowerOuterProductOp : Op { + let description = [{ + Indicates that the vector outerproduct operations nested under the isolated + from above op `target` should be lowered to finer-grained vector primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def LowerShapeCastOp : Op { + let description = [{ + Indicates that the vector shape_cast operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. 
+def LowerTransferOp : Op { + let description = [{ + Indicates that the vector transfer operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$max_transfer_rank + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + (`max_transfer_rank` `=` $max_transfer_rank^)? + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. +def LowerTransposeOp : Op { + let description = [{ + Indicates that the vector transpose-like operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target, DefaultValuedAttr:$transpose_lowering, - DefaultValuedAttr:$transpose_avx2_lowering, - DefaultValuedAttr:$unroll_vector_transfers + "vector::VectorTransposeLowering::EltWise">:$lowering_strategy, + DefaultValuedAttr:$avx2_lowering_strategy + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + oilist ( + `lowering_strategy` `=` $lowering_strategy + | `avx2_lowering_strategy` `=` $avx2_lowering_strategy + ) + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve split_transfer_strategy to proper enums. +def SplitTransferFullPartialOp : Op { + let description = [{ + Indicates that the vector transfer operations nested under the + isolated from above op `target` should be split to full and partial parts. + + This is usally a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$split_transfer_strategy + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + (`split_transfer_strategy` `=` $split_transfer_strategy^)? + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. 
+def TransferToScfOp : Op { + let description = [{ + Indicates that the vector transfer operations nested under the + isolated from above op `target` should be rewritten with scf.for loops over + finer-grained vector primitives. + + This is usally a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$max_transfer_rank, + DefaultValuedAttr:$full_unroll ); - let results = (outs PDL_Operation:$results); - - let builders = [ - OpBuilder<(ins "Type":$resultType, "Value":$target, - "const vector::LowerVectorsOptions &":$options), [{ - return build($_builder, $_state, resultType, target, - options.vectorContractLowering, - options.vectorMultiReductionLowering, options.vectorTransferSplit, - options.vectorTransposeLowering, options.transposeAVX2Lowering, - options.unrollVectorTransfers); - }] - > - ]; + let results = (outs TransformHandleTypeInterface:$results); let assemblyFormat = [{ $target oilist ( - `contraction_lowering` `=` $contraction_lowering - | `multireduction_lowering` `=` $multireduction_lowering - | `split_transfers` `=` $split_transfers - | `transpose_lowering` `=` $transpose_lowering + `max_transfer_rank` `=` $max_transfer_rank + | `full_unroll` `=` $full_unroll ) attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); }]; } diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -50,6 +50,14 @@ RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit = 1, bool 
disableOuterProductLowering = false); +/// Populate the pattern set with the following patterns: +/// +/// [OuterProductOpLowering] +/// Progressively lower a `vector.outerproduct` to linearized +/// `vector.extract` + `vector.fma` + `vector.insert`. +void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Collect a set of patterns to convert vector.multi_reduction op into /// a sequence of vector.reduction ops. The patterns comprise: /// diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" @@ -25,101 +26,354 @@ using namespace mlir::transform; //===----------------------------------------------------------------------===// -// LowerVectorsOp +// ApplyRankReducingSubviewPatternsOp //===----------------------------------------------------------------------===// -void transform::LowerVectorsOp::getEffects( - SmallVectorImpl &effects) { - consumesHandle(getTarget(), effects); - producesHandle(getResults(), effects); - modifiesPayload(effects); +DiagnosedSilenceableFailure +transform::ApplyRankReducingSubviewPatternsOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorTransferDropUnitDimsPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// ApplyTransferPermutationPatternsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::ApplyTransferPermutationPatternsOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerBroadcastOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerBroadcastOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + populateVectorBroadcastLoweringPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerContractionOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerContractionOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); + populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, + /*benefit=*/1, + /*disableOuterProductLowering=*/true); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerMaskOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerMaskOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + populateVectorMaskOpLoweringPatterns(patterns); + populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerMultiReductionOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerMultiReductionOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); + vector::populateVectorMultiReductionLoweringPatterns( + patterns, vectorTransformOptions.vectorMultiReductionLowering); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerOuterProductOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerOuterProductOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + populateVectorOuterProductLoweringPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure transform::LowerVectorsOp::apply( - mlir::transform::TransformResults &transformResults, - mlir::transform::TransformState &state) { - - SmallVector results; - ArrayRef payloadOps = state.getPayloadOps(getTarget()); - for (Operation *target : payloadOps) { - // This check can't be part of the verifier because payload IR is - // independent from transform IR and may not even exist. - if (!target->hasTrait()) { - return mlir::emitDefiniteFailure(target, - "applies only to isolated-from-above " - "targets because it needs to apply " - "patterns greedily"); - } - - MLIRContext *ctx = getContext(); - RewritePatternSet patterns(ctx); - vector::VectorTransposeLowering vectorTransposeLowering = - getTransposeLowering(); - vector::VectorMultiReductionLowering vectorMultiReductionLowering = - getMultireductionLowering(); - vector::VectorContractLowering vectorContractLowering = - getContractionLowering(); - vector::VectorTransferSplit vectorTransferSplit = getSplitTransfers(); - - vector::VectorTransformsOptions vectorTransformOptions; - vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering) - .setVectorMultiReductionLowering(vectorMultiReductionLowering) - .setVectorTransposeLowering(vectorTransposeLowering) - .setVectorTransferSplit(vectorTransferSplit); - - VectorTransferToSCFOptions vectorTransferToSCFOptions = - VectorTransferToSCFOptions().enableFullUnroll( - getUnrollVectorTransfers()); - - int maxTransferRank = 1; 
+//===----------------------------------------------------------------------===// +// LowerShapeCastOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerShapeCastOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. + if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorShapeCastLoweringPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerTransferOp +//===----------------------------------------------------------------------===// +DiagnosedSilenceableFailure transform::LowerTransferOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorTransferLoweringPatterns(patterns, + getMaxTransferRank()); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerTransposeOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerTransposeOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. + if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorTransposeLoweringPatterns( + patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( + getLoweringStrategy())); + if (getAvx2LoweringStrategy()) { auto avx2LoweringOptions = x86vector::avx2::LoweringOptions().setTransposeOptions( x86vector::avx2::TransposeLoweringOptions() - .lower4x8xf32(getTransposeAvx2Lowering()) - .lower8x8xf32(getTransposeAvx2Lowering())); - - vector::populateVectorToVectorCanonicalizationPatterns(patterns); + .lower4x8xf32(true) + .lower8x8xf32(true)); + x86vector::avx2::populateSpecializedTransposeLoweringPatterns( + patterns, avx2LoweringOptions, /*benefit=*/10); + } - // In the future we may want to more finely select particular stages. 
- // Stage 1: contraction lowerings. - populateVectorContractLoweringPatterns( - patterns, vectorTransformOptions, /*benefit=*/1, - /*disableOuterProductLowering*/ true); - vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); - // Stage 2: multi-reduction lowerings. - vector::populateVectorMultiReductionLoweringPatterns( - patterns, vectorTransformOptions.vectorMultiReductionLowering); + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} - // Stage 3: Rewrite vector.transfer into full and partial parts. - populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); +//===----------------------------------------------------------------------===// +// SplitTransferFullPartialOp +//===----------------------------------------------------------------------===// - // Stage 4: Lower vector transfers. - vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank); +DiagnosedSilenceableFailure transform::SplitTransferFullPartialOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. + if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } - // Stage 5: Vector to scf patterns. 
- populateVectorToSCFConversionPatterns( - patterns, vectorTransferToSCFOptions.setTargetRank(maxTransferRank)); + RewritePatternSet patterns(getContext()); + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); + populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); - // Stage 6: Lower vector.shape_cast. - vector::populateVectorShapeCastLoweringPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); - // Stage 7: Lower vector.transpose. - vector::populateVectorTransposeLoweringPatterns( - patterns, vectorTransformOptions, /*benefit=*/1); - if (getTransposeAvx2Lowering()) - x86vector::avx2::populateSpecializedTransposeLoweringPatterns( - patterns, avx2LoweringOptions, /*benefit=*/10); + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} - // Apply everything. - if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) - return DiagnosedSilenceableFailure::definiteFailure(); +//===----------------------------------------------------------------------===// +// TransferToScfOp +//===----------------------------------------------------------------------===// - results.push_back(target); +DiagnosedSilenceableFailure transform::TransferToScfOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); } - transformResults.set(getResults().cast(), results); + RewritePatternSet patterns(getContext()); + VectorTransferToSCFOptions vectorTransferToSCFOptions = + VectorTransferToSCFOptions() + .enableFullUnroll(getFullUnroll()) + .setTargetRank(getMaxTransferRank()); + populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); return DiagnosedSilenceableFailure::success(); } @@ -137,6 +391,7 @@ VectorTransformDialectExtension() { declareDependentDialect(); declareDependentDialect(); + declareDependentDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -1327,3 +1327,8 @@ ContractionOpToOuterProductOpLowering>( options, patterns.getContext(), benefit); } + +void mlir::vector::populateVectorOuterProductLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -78,18 +78,20 @@ PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. 
if (op.getTransferRank() == 0) - return failure(); + return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); SmallVector permutation; AffineMap map = op.getPermutationMap(); if (map.getNumResults() == 0) - return failure(); - if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) - return failure(); + return rewriter.notifyMatchFailure(op, "0 result permutation map"); + if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) { + return rewriter.notifyMatchFailure( + op, "map is not permutable to minor identity, apply another pattern"); + } AffineMap permutationMap = map.getPermutationMap(permutation, op.getContext()); if (permutationMap.isIdentity()) - return failure(); + return rewriter.notifyMatchFailure(op, "map is not identity"); permutationMap = map.getPermutationMap(permutation, op.getContext()); // Calculate the map of the new read by applying the inverse permutation. @@ -148,14 +150,17 @@ PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. if (op.getTransferRank() == 0) - return failure(); + return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); SmallVector permutation; AffineMap map = op.getPermutationMap(); if (map.isMinorIdentity()) - return failure(); - if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) - return failure(); + return rewriter.notifyMatchFailure(op, "map is already minor identity"); + + if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) { + return rewriter.notifyMatchFailure( + op, "map is not permutable to minor identity, apply another pattern"); + } // Remove unused dims from the permutation map. E.g.: // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4) @@ -209,19 +214,23 @@ LogicalResult matchAndRewrite(vector::TransferWriteOp op, PatternRewriter &rewriter) const override { + // TODO: support 0-d corner case. 
if (op.getTransferRank() == 0) - return failure(); + return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); + SmallVector permutation; AffineMap map = op.getPermutationMap(); - if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) - return failure(); + if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) { + return rewriter.notifyMatchFailure( + op, + "map is already permutable to minor identity, apply another pattern"); + } // Missing outer dimensions are allowed, find the most outer existing // dimension then deduce the missing inner dimensions. SmallVector foundDim(map.getNumDims(), false); - for (AffineExpr exp : map.getResults()) { + for (AffineExpr exp : map.getResults()) foundDim[exp.cast().getPosition()] = true; - } SmallVector exprs; bool foundFirstDim = false; SmallVector missingInnerDim; @@ -274,7 +283,7 @@ PatternRewriter &rewriter) const override { // TODO: support 0-d corner case. if (op.getTransferRank() == 0) - return failure(); + return rewriter.notifyMatchFailure(op, "0-d corner case not supported"); AffineMap map = op.getPermutationMap(); unsigned numLeadingBroadcast = 0; @@ -286,7 +295,8 @@ } // If there are no leading zeros in the map there is nothing to do. if (numLeadingBroadcast == 0) - return failure(); + return rewriter.notifyMatchFailure(op, "no leading broadcasts in map"); + VectorType originalVecType = op.getVectorType(); unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast; // Calculate new map, vector type and masks without the leading zeros. @@ -295,8 +305,10 @@ op.getContext()); // Only remove the leading zeros if the rest of the map is a minor identity // with broadcasting. Otherwise we first want to permute the map. 
- if (!newMap.isMinorIdentityWithBroadcasting()) - return failure(); + if (!newMap.isMinorIdentityWithBroadcasting()) { + return rewriter.notifyMatchFailure( + op, "map is not a minor identity with broadcasting"); + } // TODO: support zero-dimension vectors natively. See: // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097. @@ -315,11 +327,13 @@ newRead); return success(); } + SmallVector newShape = llvm::to_vector<4>( originalVecType.getShape().take_back(reducedShapeRank)); // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering. if (newShape.empty()) - return failure(); + return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d"); + VectorType newReadType = VectorType::get(newShape, originalVecType.getElementType()); ArrayAttr newInBoundsAttr = @@ -370,8 +384,10 @@ LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { - if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) - return failure(); + if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) { + return rewriter.notifyMatchFailure( + read, "vector type is greater than max transfer rank"); + } SmallVector broadcastedDims; // Permutations are handled by VectorToSCF or @@ -379,15 +395,15 @@ // We let the 0-d corner case pass-through as it is supported. if (!read.getPermutationMap().isMinorIdentityWithBroadcasting( &broadcastedDims)) - return failure(); + return rewriter.notifyMatchFailure(read, "not minor identity + bcast"); auto memRefType = read.getShapedType().dyn_cast(); if (!memRefType) - return failure(); + return rewriter.notifyMatchFailure(read, "not a memref source"); // Non-unit strides are handled by VectorToSCF. 
if (!vector::isLastMemrefDimUnitStride(memRefType)) - return failure(); + return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF"); // If there is broadcasting involved then we first load the unbroadcasted // vector, and then broadcast it with `vector.broadcast`. @@ -403,16 +419,16 @@ // resulting vector type is the same as the element type. auto memrefElTy = memRefType.getElementType(); if (memrefElTy.isa() && memrefElTy != unbroadcastedVectorType) - return failure(); + return rewriter.notifyMatchFailure(read, "incompatible element type"); // Otherwise, element types of the memref and the vector must match. if (!memrefElTy.isa() && memrefElTy != read.getVectorType().getElementType()) - return failure(); + return rewriter.notifyMatchFailure(read, "non-matching element type"); // Out-of-bounds dims are handled by MaterializeTransferMask. if (read.hasOutOfBoundsDim()) - return failure(); + return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask"); // Create vector load op. Operation *loadOp; @@ -458,7 +474,8 @@ PatternRewriter &rewriter) const override { auto vecType = loadOp.getVectorType(); if (vecType.getNumElements() != 1) - return failure(); + return rewriter.notifyMatchFailure(loadOp, "not a single element vector"); + auto memrefLoad = rewriter.create( loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices()); rewriter.replaceOpWithNewOp(loadOp, vecType, @@ -476,7 +493,8 @@ PatternRewriter &rewriter) const override { auto vecType = storeOp.getVectorType(); if (vecType.getNumElements() != 1) - return failure(); + return rewriter.notifyMatchFailure(storeOp, "not single element vector"); + Value extracted; if (vecType.getRank() == 0) { // TODO: Unify once ExtractOp supports 0-d vectors. 
@@ -512,10 +530,10 @@ LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { - if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) - return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { - diag << "rank exceeds maxTransferRank: " << write; - }); + if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) { + return rewriter.notifyMatchFailure( + write, "vector type is greater than max transfer rank"); + } // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -20,6 +20,39 @@ transform.structured.vectorize %2 transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op {bufferize_function_boundaries = true} - %func = transform.structured.match ops{["func.func"]} in %module_op : (!pdl.operation) -> !pdl.operation - transform.vector.lower_vectors %func multireduction_lowering = "innerreduction" + + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + // TODO: group these lower-level controls into various properly named vector + // lowering TD macros. 
+ %func = transform.vector.lower_contraction %f + lowering_strategy = "outerproduct" + : (!pdl.operation) -> !pdl.operation + + %func_2 = transform.vector.apply_transfer_permutation_patterns %func + : (!pdl.operation) -> !pdl.operation + + %func_3 = transform.vector.lower_multi_reduction %func_2 + lowering_strategy = "innerparallel" + : (!pdl.operation) -> !pdl.operation + + %func_4 = transform.vector.split_transfer_full_partial %func_3 + split_transfer_strategy = "linalg-copy" + : (!pdl.operation) -> !pdl.operation + + %func_5 = transform.vector.transfer_to_scf %func_4 + max_transfer_rank = 1 full_unroll = true + : (!pdl.operation) -> !pdl.operation + + %func_6 = transform.vector.lower_transfer %func_5 + max_transfer_rank = 1 + : (!pdl.operation) -> !pdl.operation + + %func_7 = transform.vector.lower_shape_cast %func_6 + : (!pdl.operation) -> !pdl.operation + + %func_8 = transform.vector.lower_transpose %func_7 + lowering_strategy = "shuffle" + : (!pdl.operation) -> !pdl.operation } diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -16,11 +16,45 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!pdl.operation) -> !pdl.operation - %1, %loops:3 = transform.structured.tile %0 [8, 4, 2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) + %1, %loops:3 = transform.structured.tile %0 [8, 4, 2] + : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation transform.structured.vectorize %2 - transform.bufferization.one_shot_bufferize %module_op + transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op + {bufferize_function_boundaries = true, 
allow_return_allocs = true} - %func = transform.structured.match ops{["func.func"]} in %module_op : (!pdl.operation) -> !pdl.operation - transform.vector.lower_vectors %func multireduction_lowering = "innerreduction" + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + // TODO: group these lower-level controls into various properly named vector + // lowering TD macros. + %func = transform.vector.lower_contraction %f + lowering_strategy = "outerproduct" + : (!pdl.operation) -> !pdl.operation + + %func_2 = transform.vector.apply_transfer_permutation_patterns %func + : (!pdl.operation) -> !pdl.operation + + %func_3 = transform.vector.lower_multi_reduction %func_2 + lowering_strategy = "innerparallel" + : (!pdl.operation) -> !pdl.operation + + %func_4 = transform.vector.split_transfer_full_partial %func_3 + split_transfer_strategy = "linalg-copy" + : (!pdl.operation) -> !pdl.operation + + %func_5 = transform.vector.transfer_to_scf %func_4 + max_transfer_rank = 1 full_unroll = true + : (!pdl.operation) -> !pdl.operation + + %func_6 = transform.vector.lower_transfer %func_5 + max_transfer_rank = 1 + : (!pdl.operation) -> !pdl.operation + + %func_7 = transform.vector.lower_shape_cast %func_6 + : (!pdl.operation) -> !pdl.operation + + %func_8 = transform.vector.lower_transpose %func_7 + lowering_strategy = "shuffle" + : (!pdl.operation) -> !pdl.operation } diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir @@ -0,0 +1,172 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func @broadcast_vec1d_from_scalar +// CHECK-SAME: %[[A:.*0]]: f32 +// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32> +// CHECK: return %[[T0]] : vector<2xf32> + +func.func 
@broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func @broadcast_vec2d_from_scalar +// CHECK-SAME: %[[A:.*0]]: f32 +// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32> +// CHECK: return %[[T0]] : vector<2x3xf32> + +func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +// CHECK-LABEL: func @broadcast_vec3d_from_scalar +// CHECK-SAME: %[[A:.*0]]: f32 +// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32> +// CHECK: return %[[T0]] : vector<2x3x4xf32> + +func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32> + return %0 : vector<2x3x4xf32> +} + +// CHECK-LABEL: func @broadcast_vec1d_from_vec1d +// CHECK-SAME: %[[A:.*0]]: vector<2xf32> +// CHECK: return %[[A]] : vector<2xf32> + +func.func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func @broadcast_vec2d_from_vec1d +// CHECK-SAME: %[[A:.*0]]: vector<2xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: return %[[T2]] : vector<3x2xf32> + +func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32> + return %0 : vector<3x2xf32> +} + +// CHECK-LABEL: func @broadcast_vec3d_from_vec1d +// CHECK-SAME: %[[A:.*0]]: vector<2xf32> +// CHECK: %[[C0:.*]] = arith.constant 
dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> +// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C1]] [0] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T4:.*]] = vector.insert %[[T2]], %[[T3]] [1] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T2]], %[[T4]] [2] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T6:.*]] = vector.insert %[[T2]], %[[T5]] [3] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: return %[[T6]] : vector<4x3x2xf32> + +func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} + +// CHECK-LABEL: func @broadcast_vec3d_from_vec2d +// CHECK-SAME: %[[A:.*0]]: vector<3x2xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> +// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[A]], %[[T2]] [3] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: return %[[T3]] : vector<4x3x2xf32> + +func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} + +// CHECK-LABEL: func @broadcast_stretch +// CHECK-SAME: %[[A:.*0]]: vector<1xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1xf32> +// 
CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32> +// CHECK: return %[[T1]] : vector<4xf32> + +func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { + %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @broadcast_stretch_at_start +// CHECK-SAME: %[[A:.*0]]: vector<1x4xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1x4xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C0]] [0] : vector<4xf32> into vector<3x4xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[T1]] [1] : vector<4xf32> into vector<3x4xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [2] : vector<4xf32> into vector<3x4xf32> +// CHECK: return %[[T3]] : vector<3x4xf32> + +func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32> + return %0 : vector<3x4xf32> +} + +// CHECK-LABEL: func @broadcast_stretch_at_end +// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32> +// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32> +// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32> +// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32> +// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32> +// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : 
vector<3xf32> +// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> +// CHECK: return %[[T15]] : vector<4x3xf32> + +func.func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { + %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32> + return %0 : vector<4x3xf32> +} + +// CHECK-LABEL: func @broadcast_stretch_in_middle +// CHECK-SAME: %[[A:.*0]]: vector<4x1x2xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> +// CHECK: %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T4:.*]] = vector.insert %[[T0]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<4x1x2xf32> +// CHECK: %[[T8:.*]] = vector.insert %[[T6]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T6]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T10:.*]] = vector.insert %[[T6]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T12:.*]] = vector.extract %[[A]][2, 0] : vector<4x1x2xf32> +// CHECK: %[[T14:.*]] = vector.insert %[[T12]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T15:.*]] = vector.insert %[[T12]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T16:.*]] = vector.insert %[[T12]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T11]] [2] : vector<3x2xf32> into vector<4x3x2xf32> +// 
CHECK: %[[T18:.*]] = vector.extract %[[A]][3, 0] : vector<4x1x2xf32> +// CHECK: %[[T20:.*]] = vector.insert %[[T18]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T21:.*]] = vector.insert %[[T18]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T22:.*]] = vector.insert %[[T18]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T23:.*]] = vector.insert %[[T22]], %[[T17]] [3] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: return %[[T23]] : vector<4x3x2xf32> + +func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + transform.vector.lower_broadcast %f + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s #matvec_accesses = [ affine_map<(i, j) -> (i, j)>, @@ -207,3 +207,10 @@ memref.store %0, %arg2[] : memref> return } + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_contraction %module_op + lowering_strategy = "outerproduct" + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir new file mode 100644 --- /dev/null +++ 
b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir @@ -0,0 +1,306 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +#dotp_accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()> +] +#dotp_trait = { + indexing_maps = #dotp_accesses, + iterator_types = ["reduction"] +} + +// CHECK-LABEL: func @extract_contract1 +// CHECK-SAME: %[[A:.*0]]: vector<4xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<4xf32>, +// CHECK-SAME: %[[C:.*2]]: f32 +// CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32> +// CHECK: %[[R:.*]] = vector.reduction , %[[F]], %[[C]] : vector<4xf32> into f32 +// CHECK: return %[[R]] : f32 + +func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 { + %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 + : vector<4xf32>, vector<4xf32> into f32 + return %0 : f32 +} + +// CHECK-LABEL: func @masked_extract_contract1 +// CHECK-SAME: %[[A:.*0]]: vector<4xf32>, %[[B:.*1]]: vector<4xf32>, %[[C:.*2]]: f32 +// CHECK-SAME: %[[M:.*]]: vector<4xi1> +// CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32> +// CHECK: %[[R:.*]] = vector.mask %[[M]] { vector.reduction , %0, %arg2 : vector<4xf32> into f32 } : vector<4xi1> -> f32 +// CHECK: return %[[R]] : f32 + +func.func @masked_extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32, %mask: vector<4xi1>) -> f32 { + %0 = vector.mask %mask { vector.contract #dotp_trait %arg0, %arg1, %arg2 : vector<4xf32>, vector<4xf32> into f32 } : vector<4xi1> -> f32 + return %0 : f32 +} + +// CHECK-LABEL: func @extract_contract1_int +// CHECK-SAME: %[[A:.*0]]: vector<4xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<4xi32>, +// CHECK-SAME: %[[C:.*2]]: i32 +// CHECK: %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32> +// CHECK: %[[R:.*]] = vector.reduction , %[[F]], %[[C]] : vector<4xi32> into i32 +// CHECK: return %[[R]] : i32 + +func.func @extract_contract1_int(%arg0: vector<4xi32>, 
%arg1: vector<4xi32>, %arg2: i32) -> i32 { + %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 + : vector<4xi32>, vector<4xi32> into i32 + return %0 : i32 +} + +#matvec_accesses = [ + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (j)>, + affine_map<(i, j) -> (i)> +] +#matvec_trait = { + indexing_maps = #matvec_accesses, + iterator_types = ["parallel", "reduction"] +} + +// CHECK-LABEL: func @extract_contract2 +// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2xf32> +// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> +// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]] : vector<3xf32> into f32 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> +// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xf32> into f32 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> +// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32> +// CHECK: return %[[T10]] : vector<2xf32> + +func.func @extract_contract2(%arg0: vector<2x3xf32>, + %arg1: vector<3xf32>, + %arg2: vector<2xf32>) -> vector<2xf32> { + %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xf32>, vector<3xf32> into vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func @extract_contract2_int +// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, +// CHECK-SAME: %[[C:.*2]]: vector<2xi32> +// CHECK: %[[R:.*]] = arith.constant dense<0> : vector<2xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32> +// CHECK: %[[T2:.*]] = arith.muli %[[T0]], %[[B]] : vector<3xi32> +// CHECK: %[[T3:.*]] = 
vector.reduction , %[[T2]] : vector<3xi32> into i32 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32> +// CHECK: %[[T7:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> +// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xi32> into i32 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32> +// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32> +// CHECK: return %[[T10]] : vector<2xi32> +func.func @extract_contract2_int(%arg0: vector<2x3xi32>, + %arg1: vector<3xi32>, + %arg2: vector<2xi32>) -> vector<2xi32> { + %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xi32>, vector<3xi32> into vector<2xi32> + return %0 : vector<2xi32> +} + +#vecmat_accesses = [ + affine_map<(i, j) -> (j)>, + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i)> +] +#vecmat_trait = { + indexing_maps = #vecmat_accesses, + iterator_types = ["parallel", "reduction"] +} + +// CHECK-LABEL: func @extract_contract3 +// CHECK-SAME: %[[A:.*0]]: vector<3xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2xf32> +// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> +// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[A]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]] : vector<3xf32> into f32 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> +// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[A]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xf32> into f32 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> +// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32> +// CHECK: return %[[T10]] : vector<2xf32> + +func.func 
@extract_contract3(%arg0: vector<3xf32>, + %arg1: vector<2x3xf32>, + %arg2: vector<2xf32>) -> vector<2xf32> { + %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2 + : vector<3xf32>, vector<2x3xf32> into vector<2xf32> + return %0 : vector<2xf32> +} + +#matmat_accesses = [ + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @extract_contract4 +// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32> +// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> +// CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> +// CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32> +// CHECK: %[[T10:.*]] = vector.reduction , %[[T9]] : vector<2xf32> into f32 +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32> +// +// CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32> +// CHECK: %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32> +// CHECK: %[[T20:.*]] = vector.reduction , %[[T19]] : vector<2xf32> into f32 +// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32> +// +// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32> +// CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> +// CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32> +// CHECK: %[[T33:.*]] = vector.reduction , %[[T32]] : vector<2xf32> into f32 +// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32> +// +// CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32> +// CHECK: 
%[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32> +// CHECK: %[[T42:.*]] = vector.reduction , %[[T41]] : vector<2xf32> into f32 +// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32> +// +// CHECK: %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32> +// CHECK: return %[[T52]] : vector<2x2xf32> + +func.func @extract_contract4(%arg0: vector<2x2xf32>, + %arg1: vector<2x2xf32>, + %arg2: vector<2x2xf32>) -> vector<2x2xf32> { + %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> + return %0 : vector<2x2xf32> +} + + +#contraction2d_accesses = [ + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> ()> +] +#contraction2d_trait = { + indexing_maps = #contraction2d_accesses, + iterator_types = ["reduction", "reduction"] +} + +// CHECK-LABEL: func @full_contract1 +// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, +// CHECK-SAME: %[[C:.*2]]: f32 +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> +// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> +// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]], %[[C]] : vector<3xf32> into f32 +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> +// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]], %[[T3]] : vector<3xf32> into f32 +// CHECK: return %[[T8]] : f32 + +func.func @full_contract1(%arg0: vector<2x3xf32>, + %arg1: vector<2x3xf32>, + %arg2: f32) -> f32 { + %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2 + : vector<2x3xf32>, vector<2x3xf32> into f32 + return %0 : f32 +} + +#contraction2d_trans_accesses = [ + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (j, i)>, + affine_map<(i, j) -> ()> 
+] +#contraction2d_trans_trait = { + indexing_maps = #contraction2d_trans_accesses, + iterator_types = ["reduction", "reduction"] +} + +// CHECK-LABEL: func @full_contract2 +// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>, +// CHECK-SAME: %[[C:.*2]]: f32 +// CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> +// CHECK: %[[T1:.*]] = vector.extract %[[B]][0, 0] : vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[B]][1, 0] : vector<3x2xf32> +// CHECK: %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32> +// CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32> +// CHECK: %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32> +// CHECK: %[[T11:.*]] = vector.reduction <add>, %[[T10]], %[[C]] : vector<3xf32> into f32 +// +// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> +// CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf32> +// CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32> +// CHECK: %[[T16:.*]] = vector.extract %[[B]][1, 1] : vector<3x2xf32> +// CHECK: %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32> +// CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32> +// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32> +// CHECK: %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32> +// CHECK: %[[T23:.*]] = vector.reduction <add>, %[[T22]], %[[T11]] : vector<3xf32> into f32 +// CHECK: return %[[T23]] : f32 + +func.func @full_contract2(%arg0: vector<2x3xf32>, + %arg1: vector<3x2xf32>, + %arg2: f32) -> f32 { + %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2 + : vector<2x3xf32>, vector<3x2xf32> into 
f32 + return %0 : f32 +} + +// CHECK-LABEL: @contract_one_sided_unit_reduction_dim +// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>) +// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32> +// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<1x2xi32> +// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2x2xi32> +// CHECK: %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32> +// CHECK: %[[R0:.+]] = vector.reduction , %[[M0]] : vector<2xi32> into i32 +// CHECK: %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32> +// CHECK: %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2x2xi32> +// CHECK: %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32> +// CHECK: %[[R1:.+]] = vector.reduction , %[[M1]] : vector<2xi32> into i32 +// CHECK: %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32> +// CHECK: %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32> +// CHECK: return %[[S]] : vector<2xi32> + +func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> { + %res = vector.contract { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d1)> + ], + iterator_types = ["reduction", "parallel", "reduction"], + kind = #vector.kind + } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32> + return %res : vector<2xi32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + %f2 = transform.vector.lower_contraction %f + lowering_strategy = "dot" + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir 
b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir @@ -0,0 +1,57 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +#matmat_accesses = [ + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[vcst:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32> +// CHECK: %[[vcst_0:.*]] = arith.constant dense<0.000000e+00> : vector<12xf32> +// CHECK: %[[vcst_1:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> +// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32> +// CHECK: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32> +// CHECK: %[[a2:.*]] = vector.extract %[[A]][1] : vector<2x4xf32> +// CHECK: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32> +// CHECK: %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32> +// CHECK: %[[b2:.*]] = vector.extract %[[B]][1] : vector<4x3xf32> +// CHECK: %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32> +// CHECK: %[[b4:.*]] = vector.extract %[[B]][2] : vector<4x3xf32> +// CHECK: %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into 
vector<12xf32> +// CHECK: %[[b6:.*]] = vector.extract %[[B]][3] : vector<4x3xf32> +// CHECK: %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32> +// CHECK: %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32> +// CHECK: %[[mm2:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> +// CHECK: %[[mm3:.*]] = vector.insert %[[mm2]], %[[vcst_1]] [0] : vector<3xf32> into vector<2x3xf32> +// CHECK: %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> +// CHECK: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32> +// CHECK: %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32> +func.func @matmul(%arg0: vector<2x4xf32>, + %arg1: vector<4x3xf32>, + %arg2: vector<2x3xf32>) -> vector<2x3xf32> { + %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + %f2 = transform.vector.lower_contraction %f + lowering_strategy = "matmulintrinsics" + : (!pdl.operation) -> !pdl.operation + + %f3 = transform.vector.lower_shape_cast %f2 + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir @@ -0,0 +1,353 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | 
FileCheck %s + +#matvec_accesses = [ + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (j)>, + affine_map<(i, j) -> (i)> +] +#matvec_trait = { + indexing_maps = #matvec_accesses, + iterator_types = ["parallel", "reduction"] +} + +#matmat_accesses = [ + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +#matmat_accesses_0 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmat_trait_0 = { + indexing_maps = #matmat_accesses_0, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func.func @masked_extract_contract2( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<3xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: vector<2xf32>, +// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32> +// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1> +// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1> +// CHECK: vector.mask %[[MASK0]] { vector.outerproduct + +// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1> +// CHECK: vector.mask %[[MASK1]] { vector.outerproduct + +// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1> +// CHECK: vector.mask %[[MASK2]] { vector.outerproduct + +func.func @masked_extract_contract2(%arg0: vector<2x3xf32>, + %arg1: vector<3xf32>, + %arg2: vector<2xf32>, + %m: vector<2x3xi1>) -> vector<2xf32> { + %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func.func @masked_extract_contract4( +// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>, +// 
CHECK-SAME: %[[VAL_2:.*]]: vector<3x7xf32>, +// CHECK-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> { +// CHECK: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1> +// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<5x3x7xi1> +// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<5x3x7xi1> +// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<5x3x7xi1> +// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<5x3x7xi1> +// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<5x3x7xi1> +// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> + +func.func @masked_extract_contract4(%arg0: vector<3x5xf32>, + %arg1: vector<5x7xf32>, + %arg2: vector<3x7xf32>, + %m : vector<3x7x5xi1>) -> vector<3x7xf32> { + %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32> + return %0 : vector<3x7xf32> +} + +// CHECK-LABEL: func @matmul +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, +// CHECK-SAME: 
%[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32> +// +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<4x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK-SAME: : vector<2xf32>, vector<3xf32> +// +// CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<4x2xf32> +// CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<4x3xf32> +// CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]] +// CHECK-SAME: : vector<2xf32>, vector<3xf32> +// +// CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<4x2xf32> +// CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<4x3xf32> +// CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]] +// CHECK-SAME: : vector<2xf32>, vector<3xf32> +// +// CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<4x2xf32> +// CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<4x3xf32> +// CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]] +// CHECK-SAME: : vector<2xf32>, vector<3xf32> +// +// CHECK: return %[[c3]] : vector<2x3xf32> +func.func @matmul(%arg0: vector<2x4xf32>, + %arg1: vector<4x3xf32>, + %arg2: vector<2x3xf32>) -> vector<2x3xf32> { + %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +// CHECK-LABEL: func @matmul_0 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<2x3xf32> +func.func 
@matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>) +-> vector<2x3xf32> +{ + %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +// CHECK-LABEL: func @matmul_0_mixed +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf16> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf16> +// CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32> +// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]] +// CHECK: return %[[c0]] : vector<2x3xf32> +func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: vector<2x3xf32>) +-> vector<2x3xf32> +{ + %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 + : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +#matmat_accesses_1 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmat_trait_1 = { + indexing_maps = #matmat_accesses_1, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_1 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<2x3xf32> 
+func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>) +-> vector<2x3xf32> +{ + %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +#matmat_accesses_2 = [ + affine_map<(m, n, k) -> (k, m)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmat_trait_2 = { + indexing_maps = #matmat_accesses_2, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_2 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<2x3xf32> +func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>) +-> vector<2x3xf32> +{ + %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2 + : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +#matmat_accesses_3 = [ + affine_map<(m, n, k) -> (k, m)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmat_trait_3 = { + indexing_maps = #matmat_accesses_3, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_3 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<2x3xf32> +func.func 
@matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>) +-> vector<2x3xf32> +{ + %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2 + : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +#matmat_accesses_4 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (n, m)> +] +#matmat_trait_4 = { + indexing_maps = #matmat_accesses_4, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_4 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: return %[[c0]] : vector<3x2xf32> +func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) +-> vector<3x2xf32> +{ + %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> + return %0 : vector<3x2xf32> +} + +#matmat_accesses_5 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (n, m)> +] +#matmat_trait_5 = { + indexing_maps = #matmat_accesses_5, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_5 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> +// CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: 
return %[[c0]] : vector<3x2xf32> +func.func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) +-> vector<3x2xf32> +{ + %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> + return %0 : vector<3x2xf32> +} + +#matmat_accesses_6 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (n, m)> +] +#matmat_trait_6 = { + indexing_maps = #matmat_accesses_6, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_6 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> +// CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: return %[[c0]] : vector<3x2xf32> +func.func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) +-> vector<3x2xf32> +{ + %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> + return %0 : vector<3x2xf32> +} + +#matmat_accesses_7 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (n, m)> +] +#matmat_trait_7 = { + indexing_maps = #matmat_accesses_7, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_7 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> +// CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = 
vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: return %[[c0]] : vector<3x2xf32> +func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) +-> vector<3x2xf32> +{ + %0 = vector.contract #matmat_trait_7 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> + return %0 : vector<3x2xf32> +} + + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + %f2 = transform.vector.lower_contraction %f + lowering_strategy = "outerproduct" + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir @@ -0,0 +1,62 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func @parallel_contract_lowering +// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> +// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> +// CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32> +// CHECK: return %[[F]] : vector<4xf32> +func.func @parallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { + %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @parallel_contract_lowering_broadcast +// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to 
vector<4x1x1xf32> +// CHECK: %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32> +// CHECK: %[[E0:.*]] = vector.extract %[[T]][0, 0] : vector<1x1x4xf32> +// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> +// CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32> +// CHECK: return %[[F]] : vector<4xf32> +func.func @parallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { + %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @parallel_contract_lowering +// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32> +// CHECK: %[[T0:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32> +// CHECK: %[[T1:.*]] = vector.transpose %{{.*}}, [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32> +// CHECK: %[[E0:.*]] = vector.extract %[[T0]][0, 0] : vector<1x1x4xf32> +// CHECK: %[[E1:.*]] = vector.extract %[[T1]][0, 0] : vector<1x1x4xf32> +// CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %arg2 : vector<4xf32> +// CHECK: return %[[F]] : vector<4xf32> +func.func @parallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { + %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func 
@parallel_contract_lowering_scalar +// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32> +// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32> +// CHECK: %[[M:.*]] = arith.mulf %[[E0]], %[[E1]] : f32 +// CHECK: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32 +// CHECK: return %[[A]] : f32 +func.func @parallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vector<1x1xf32>, %arg2: f32) -> f32 { + %0 = vector.contract { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> ()>], + iterator_types = ["reduction", "reduction"], kind = #vector.kind} + %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32 + return %0 : f32 +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + %f2 = transform.vector.lower_contraction %f + lowering_strategy = "parallelarith" + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ /dev/null @@ -1,1235 +0,0 @@ -// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s -// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX -// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT -// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL - -#dotp_accesses = [ - affine_map<(i) -> (i)>, - affine_map<(i) -> (i)>, - affine_map<(i) -> ()> -] -#dotp_trait = { - indexing_maps = #dotp_accesses, - iterator_types = ["reduction"] -} - -// CHECK-LABEL: func @extract_contract1 -// CHECK-SAME: %[[A:.*0]]: 
vector<4xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<4xf32>, -// CHECK-SAME: %[[C:.*2]]: f32 -// CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32> -// CHECK: %[[R:.*]] = vector.reduction , %[[F]], %[[C]] : vector<4xf32> into f32 -// CHECK: return %[[R]] : f32 - -func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 { - %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 - : vector<4xf32>, vector<4xf32> into f32 - return %0 : f32 -} - -// CHECK-LABEL: func @masked_extract_contract1 -// CHECK-SAME: %[[A:.*0]]: vector<4xf32>, %[[B:.*1]]: vector<4xf32>, %[[C:.*2]]: f32 -// CHECK-SAME: %[[M:.*]]: vector<4xi1> -// CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32> -// CHECK: %[[R:.*]] = vector.mask %[[M]] { vector.reduction , %0, %arg2 : vector<4xf32> into f32 } : vector<4xi1> -> f32 -// CHECK: return %[[R]] : f32 - -func.func @masked_extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32, %mask: vector<4xi1>) -> f32 { - %0 = vector.mask %mask { vector.contract #dotp_trait %arg0, %arg1, %arg2 : vector<4xf32>, vector<4xf32> into f32 } : vector<4xi1> -> f32 - return %0 : f32 -} - -// CHECK-LABEL: func @extract_contract1_int -// CHECK-SAME: %[[A:.*0]]: vector<4xi32>, -// CHECK-SAME: %[[B:.*1]]: vector<4xi32>, -// CHECK-SAME: %[[C:.*2]]: i32 -// CHECK: %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32> -// CHECK: %[[R:.*]] = vector.reduction , %[[F]], %[[C]] : vector<4xi32> into i32 -// CHECK: return %[[R]] : i32 - -func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 { - %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 - : vector<4xi32>, vector<4xi32> into i32 - return %0 : i32 -} - -#matvec_accesses = [ - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (j)>, - affine_map<(i, j) -> (i)> -] -#matvec_trait = { - indexing_maps = #matvec_accesses, - iterator_types = ["parallel", "reduction"] -} - -// CHECK-LABEL: func @extract_contract2 -// 
CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, -// CHECK-SAME: %[[C:.*2]]: vector<2xf32> -// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> -// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]] : vector<3xf32> into f32 -// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> -// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> -// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xf32> into f32 -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> -// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32> -// CHECK: return %[[T10]] : vector<2xf32> - -func.func @extract_contract2(%arg0: vector<2x3xf32>, - %arg1: vector<3xf32>, - %arg2: vector<2xf32>) -> vector<2xf32> { - %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 - : vector<2x3xf32>, vector<3xf32> into vector<2xf32> - return %0 : vector<2xf32> -} - -// OUTERPRODUCT-LABEL: func.func @masked_extract_contract2( -// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<2x3xf32>, -// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<3xf32>, -// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<2xf32>, -// OUTERPRODUCT-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32> -// OUTERPRODUCT: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1> -// OUTERPRODUCT: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1> -// OUTERPRODUCT: vector.mask %[[MASK0]] { vector.outerproduct - -// OUTERPRODUCT: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1> -// OUTERPRODUCT: vector.mask %[[MASK1]] { vector.outerproduct - -// OUTERPRODUCT: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1> -// OUTERPRODUCT: vector.mask %[[MASK2]] { 
vector.outerproduct - -func.func @masked_extract_contract2(%arg0: vector<2x3xf32>, - %arg1: vector<3xf32>, - %arg2: vector<2xf32>, - %m: vector<2x3xi1>) -> vector<2xf32> { - %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2 - : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32> - return %0 : vector<2xf32> -} - -// CHECK-LABEL: func @extract_contract2_int -// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>, -// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, -// CHECK-SAME: %[[C:.*2]]: vector<2xi32> -// CHECK: %[[R:.*]] = arith.constant dense<0> : vector<2xi32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32> -// CHECK: %[[T2:.*]] = arith.muli %[[T0]], %[[B]] : vector<3xi32> -// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]] : vector<3xi32> into i32 -// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32> -// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32> -// CHECK: %[[T7:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> -// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xi32> into i32 -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32> -// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32> -// CHECK: return %[[T10]] : vector<2xi32> -func.func @extract_contract2_int(%arg0: vector<2x3xi32>, - %arg1: vector<3xi32>, - %arg2: vector<2xi32>) -> vector<2xi32> { - %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 - : vector<2x3xi32>, vector<3xi32> into vector<2xi32> - return %0 : vector<2xi32> -} - -#vecmat_accesses = [ - affine_map<(i, j) -> (j)>, - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (i)> -] -#vecmat_trait = { - indexing_maps = #vecmat_accesses, - iterator_types = ["parallel", "reduction"] -} - -// CHECK-LABEL: func @extract_contract3 -// CHECK-SAME: %[[A:.*0]]: vector<3xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, -// CHECK-SAME: %[[C:.*2]]: vector<2xf32> -// CHECK: %[[R:.*]] = 
arith.constant dense<0.000000e+00> : vector<2xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> -// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[A]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]] : vector<3xf32> into f32 -// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> -// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[A]] : vector<3xf32> -// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xf32> into f32 -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> -// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32> -// CHECK: return %[[T10]] : vector<2xf32> - -func.func @extract_contract3(%arg0: vector<3xf32>, - %arg1: vector<2x3xf32>, - %arg2: vector<2xf32>) -> vector<2xf32> { - %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2 - : vector<3xf32>, vector<2x3xf32> into vector<2xf32> - return %0 : vector<2xf32> -} - -#matmat_accesses = [ - affine_map<(i, j, k) -> (i, k)>, - affine_map<(i, j, k) -> (k, j)>, - affine_map<(i, j, k) -> (i, j)> -] -#matmat_trait = { - indexing_maps = #matmat_accesses, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// CHECK-LABEL: func @extract_contract4 -// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>, -// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32> -// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> -// CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> -// CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32> -// CHECK: %[[T10:.*]] = vector.reduction , %[[T9]] : vector<2xf32> into f32 -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32> -// -// CHECK: 
%[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32> -// CHECK: %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32> -// CHECK: %[[T20:.*]] = vector.reduction , %[[T19]] : vector<2xf32> into f32 -// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32> -// -// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32> -// CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> -// CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32> -// CHECK: %[[T33:.*]] = vector.reduction , %[[T32]] : vector<2xf32> into f32 -// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32> -// -// CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32> -// CHECK: %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32> -// CHECK: %[[T42:.*]] = vector.reduction , %[[T41]] : vector<2xf32> into f32 -// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32> -// -// CHECK: %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32> -// CHECK: return %[[T52]] : vector<2x2xf32> - -func.func @extract_contract4(%arg0: vector<2x2xf32>, - %arg1: vector<2x2xf32>, - %arg2: vector<2x2xf32>) -> vector<2x2xf32> { - %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 - : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - return %0 : vector<2x2xf32> -} - -// OUTERPRODUCT-LABEL: func.func @masked_extract_contract4( -// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<3x5xf32>, -// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<5x7xf32>, -// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<3x7xf32>, -// OUTERPRODUCT-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> { -// OUTERPRODUCT: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1> -// OUTERPRODUCT: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<5x3x7xi1> -// OUTERPRODUCT: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind} 
: vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> -// OUTERPRODUCT: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<5x3x7xi1> -// OUTERPRODUCT: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> -// OUTERPRODUCT: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<5x3x7xi1> -// OUTERPRODUCT: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> -// OUTERPRODUCT: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<5x3x7xi1> -// OUTERPRODUCT: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> -// OUTERPRODUCT: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<5x3x7xi1> -// OUTERPRODUCT: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> - -func.func @masked_extract_contract4(%arg0: vector<3x5xf32>, - %arg1: vector<5x7xf32>, - %arg2: vector<3x7xf32>, - %m : vector<3x7x5xi1>) -> vector<3x7xf32> { - %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2 - : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32> - return %0 : vector<3x7xf32> -} - -#contraction2d_accesses = [ - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> ()> -] -#contraction2d_trait = { - indexing_maps = #contraction2d_accesses, - iterator_types = ["reduction", "reduction"] -} - -// CHECK-LABEL: func @full_contract1 -// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, -// CHECK-SAME: %[[C:.*2]]: f32 -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> -// CHECK: %[[T1:.*]] = vector.extract 
%[[B]][0] : vector<2x3xf32> -// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]], %[[C]] : vector<3xf32> into f32 -// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> -// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32> -// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]], %[[T3]] : vector<3xf32> into f32 -// CHECK: return %[[T8]] : f32 - -func.func @full_contract1(%arg0: vector<2x3xf32>, - %arg1: vector<2x3xf32>, - %arg2: f32) -> f32 { - %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2 - : vector<2x3xf32>, vector<2x3xf32> into f32 - return %0 : f32 -} - -#contraction2d_trans_accesses = [ - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (j, i)>, - affine_map<(i, j) -> ()> -] -#contraction2d_trans_trait = { - indexing_maps = #contraction2d_trans_accesses, - iterator_types = ["reduction", "reduction"] -} - -// CHECK-LABEL: func @full_contract2 -// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>, -// CHECK-SAME: %[[C:.*2]]: f32 -// CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> -// CHECK: %[[T1:.*]] = vector.extract %[[B]][0, 0] : vector<3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[B]][1, 0] : vector<3x2xf32> -// CHECK: %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32> -// CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32> -// CHECK: %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32> -// CHECK: %[[T11:.*]] = vector.reduction , %[[T10]], %[[C]] : vector<3xf32> into f32 -// -// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> -// CHECK: 
%[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf -// CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32> -// CHECK: %[[T16:.*]] = vector.extract %[[B]][1, 1] : vector<3x2xf32> -// CHECK: %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32> -// CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32> -// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32> -// CHECK: %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32> -// CHECK: %[[T23:.*]] = vector.reduction , %[[T22]], %[[T11]] : vector<3xf32> into f32 -// CHECK: return %[[T23]] : f32 - -func.func @full_contract2(%arg0: vector<2x3xf32>, - %arg1: vector<3x2xf32>, - %arg2: f32) -> f32 { - %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2 - : vector<2x3xf32>, vector<3x2xf32> into f32 - return %0 : f32 -} - -// CHECK-LABEL: func @outerproduct_noacc -// CHECK-SAME: %[[A:.*0]]: vector<2xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<3xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> -// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32> -// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32> -// CHECK: return %[[T7]] : vector<2x3xf32> - -func.func @outerproduct_noacc(%arg0: vector<2xf32>, - %arg1: vector<3xf32>) -> vector<2x3xf32> { - %0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> - return %0: vector<2x3xf32> -} - -// CHECK-LABEL: func @outerproduct_acc -// CHECK-SAME: %[[A:.*0]]: 
vector<2xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, -// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32> -// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> -// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32> -// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32> -// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> -// CHECK: return %[[T9]] : vector<2x3xf32> - -func.func @outerproduct_acc(%arg0: vector<2xf32>, - %arg1: vector<3xf32>, - %arg2: vector<2x3xf32>) -> vector<2x3xf32> { - %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32> - return %0: vector<2x3xf32> -} - -// CHECK-LABEL: func @outerproduct_noacc_int -// CHECK-SAME: %[[A:.*0]]: vector<2xi32>, -// CHECK-SAME: %[[B:.*1]]: vector<3xi32> -// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> -// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32> -// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32> -// CHECK: return %[[T7]] : 
vector<2x3xi32> -func.func @outerproduct_noacc_int(%arg0: vector<2xi32>, - %arg1: vector<3xi32>) -> vector<2x3xi32> { - %0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32> - return %0: vector<2x3xi32> -} - -// CHECK-LABEL: func @outerproduct_acc_int -// CHECK-SAME: %[[A:.*0]]: vector<2xi32>, -// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, -// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> -// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> -// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32> -// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> -// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32> -// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32> -// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32> -// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32> -// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32> -// CHECK: return %[[T11]] : vector<2x3xi32> -func.func @outerproduct_acc_int(%arg0: vector<2xi32>, - %arg1: vector<3xi32>, - %arg2: vector<2x3xi32>) -> vector<2x3xi32> { - %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32> - return %0: vector<2x3xi32> -} - -// CHECK-LABEL: func @axpy_fp( -// CHECK-SAME: %[[A:.*0]]: vector<16xf32>, -// CHECK-SAME: %[[B:.*1]]: f32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> -// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32> -// CHECK: return %[[T1]] : vector<16xf32> -func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { - %0 = vector.outerproduct %arg0, %arg1: 
vector<16xf32>, f32 - return %0: vector<16xf32> -} - -// CHECK-LABEL: func @axpy_fp_add( -// CHECK-SAME: %[[A:.*0]]: vector<16xf32>, -// CHECK-SAME: %[[B:.*1]]: f32, -// CHECK-SAME: %[[C:.*2]]: vector<16xf32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> -// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32> -// CHECK: return %[[T1]] : vector<16xf32> -func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> { - %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32 - return %0: vector<16xf32> -} - -// CHECK-LABEL: func @axpy_int( -// CHECK-SAME: %[[A:.*0]]: vector<16xi32>, -// CHECK-SAME: %[[B:.*1]]: i32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> -// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> -// CHECK: return %[[T1]] : vector<16xi32> -func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { - %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32 - return %0: vector<16xi32> -} - -// CHECK-LABEL: func @axpy_int_add( -// CHECK-SAME: %[[A:.*0]]: vector<16xi32>, -// CHECK-SAME: %[[B:.*1]]: i32, -// CHECK-SAME: %[[C:.*2]]: vector<16xi32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> -// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> -// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32> -// CHECK: return %[[T2]] : vector<16xi32> -func.func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> { - %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32 - return %0: vector<16xi32> -} - -// CHECK-LABEL: func @nop_shape_cast -// CHECK-SAME: %[[A:.*]]: vector<16xf32> -// CHECK: return %[[A]] : vector<16xf32> - -func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { - %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32> - return %0 : vector<16xf32> -} - -// CHECK-LABEL: func @cancel_shape_cast -// CHECK-SAME: 
%[[A:.*]]: vector<16xf32> -// CHECK: return %[[A]] : vector<16xf32> - -func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { - %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> - %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> - return %1 : vector<16xf32> -} - -// Shape up and downcasts for 2-D vectors, for supporting conversion to -// llvm.matrix operations -// CHECK-LABEL: func @shape_casts -func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) { - // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> - // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> - // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2x2xf32> - // - // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]] - // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> - // - // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x2xf32> - // - // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]] - // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> - // - %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32> - // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32> - %r0 = arith.addf %0, %0: vector<4xf32> - // - // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]] - // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : - // CHECK-SAME: vector<4xf32> to vector<2xf32> - // - // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] : - // CHECK-SAME: vector<2xf32> into vector<2x2xf32> - // - // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]] - // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} : - // CHECK-SAME: vector<4xf32> to vector<2xf32> - // - // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] : - // CHECK-SAME: vector<2xf32> into vector<2x2xf32> - // - %1 = vector.shape_cast %r0 : vector<4xf32> to 
vector<2x2xf32> - // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32> - return %r0, %1 : vector<4xf32>, vector<2x2xf32> -} - -// CHECK-LABEL: func @shape_cast_2d2d -// CHECK-SAME: %[[A:.*]]: vector<3x2xf32> -// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32> -// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32> -// CHECK: return %[[T11]] : vector<2x3xf32> - -func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> { - %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32> - return %s : vector<2x3xf32> -} - -// CHECK-LABEL: func @shape_cast_3d1d -// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32> -// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32> -// CHECK: %[[T4:.*]] = vector.extract 
%[[A]][0, 1, 0] : vector<1x3x2xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32> -// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32> -// CHECK: return %[[T11]] : vector<6xf32> - -func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> { - %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32> - return %s : vector<6xf32> -} - -// CHECK-LABEL: func @shape_cast_1d3d -// CHECK-SAME: %[[A:.*]]: vector<6xf32> -// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32> -// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32> -// CHECK: return %[[T11]] : vector<2x1x3xf32> - -func.func 
@shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> { - %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32> - return %s : vector<2x1x3xf32> -} - -// MATRIX-LABEL: func @matmul -// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, -// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, -// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// MATRIX: %[[vcst:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// MATRIX: %[[vcst_0:.*]] = arith.constant dense<0.000000e+00> : vector<12xf32> -// MATRIX: %[[vcst_1:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> -// MATRIX: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32> -// MATRIX: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32> -// MATRIX: %[[a2:.*]] = vector.extract %[[A]][1] : vector<2x4xf32> -// MATRIX: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32> -// MATRIX: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32> -// MATRIX: %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32> -// MATRIX: %[[b2:.*]] = vector.extract %[[B]][1] : vector<4x3xf32> -// MATRIX: %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32> -// MATRIX: %[[b4:.*]] = vector.extract %[[B]][2] : vector<4x3xf32> -// MATRIX: %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32> -// MATRIX: %[[b6:.*]] = vector.extract %[[B]][3] : vector<4x3xf32> -// MATRIX: %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32> -// MATRIX: %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, 
vector<12xf32>) -> vector<6xf32> -// MATRIX: %[[mm2:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> -// MATRIX: %[[mm3:.*]] = vector.insert %[[mm2]], %[[vcst_1]] [0] : vector<3xf32> into vector<2x3xf32> -// MATRIX: %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> -// MATRIX: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32> -// MATRIX: %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32> - -// OUTERPRODUCT-LABEL: func @matmul -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT-SAME: : vector<2x4xf32> to vector<4x2xf32> -// -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<4x2xf32> -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] -// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> -// -// OUTERPRODUCT: %[[a1:.*]] = vector.extract %[[At]][1] : vector<4x2xf32> -// OUTERPRODUCT: %[[b1:.*]] = vector.extract %[[B]][1] : vector<4x3xf32> -// OUTERPRODUCT: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]] -// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> -// -// OUTERPRODUCT: %[[a2:.*]] = vector.extract %[[At]][2] : vector<4x2xf32> -// OUTERPRODUCT: %[[b2:.*]] = vector.extract %[[B]][2] : vector<4x3xf32> -// OUTERPRODUCT: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]] -// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> -// -// OUTERPRODUCT: %[[a3:.*]] = vector.extract %[[At]][3] : vector<4x2xf32> -// OUTERPRODUCT: %[[b3:.*]] = vector.extract %[[B]][3] : vector<4x3xf32> -// OUTERPRODUCT: %[[c3:.*]] = vector.outerproduct %[[a3]], 
%[[b3]], %[[c2]] -// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> -// -// OUTERPRODUCT: return %[[c3]] : vector<2x3xf32> - -// REDUCE-LABEL: func @matmul -// REDUCE-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, -// REDUCE-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, -// REDUCE-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// -// REDUCE: %[[RES:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> -// REDUCE: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] -// REDUCE-SAME: : vector<4x3f32> to vector<3x4xf32> -// -// REDUCE: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32> -// REDUCE-NEXT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3x4xf32> -// REDUCE-NEXT: %[[ab00:.*]] = mul %[[a0]], %[[b0]] : vector<4xf32> -// REDUCE-NEXT: %[[s00:.*]] = vector.reduction , %[[ab00]] : vector<4xf32> into f32 -// REDUCE-NEXT: %[[r00:.*]] = vector.insert %[[s00]], %[[RES]] [0, 0] : f32 into vector<2x3xf32> -// -// ... -// -// REDUCE: %[[a1:.*]] = vector.extract %[[A]][1] : vector<2x4xf32> -// REDUCE-NEXT: %[[b2:.*]] = vector.extract %[[Bt]][2] : vector<3x4xf32> -// REDUCE-NEXT: %[[ab12:.*]] = mul %[[a1]], %[[b02]] : vector<4xf32> -// REDUCE-NEXT: %[[s12:.*]] = vector.reduction , %[[ab12]] : vector<4xf32> into f32 -// REDUCE-NEXT: %[[r12:.*]] = vector.insert %[[s12]], %{{.*}} [1, 2] : f32 into vector<2x3xf32> -// -// REDUCE: return %[[c3]] : vector<2x3xf32> -func.func @matmul(%arg0: vector<2x4xf32>, - %arg1: vector<4x3xf32>, - %arg2: vector<2x3xf32>) -> vector<2x3xf32> { - %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 - : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -// CHECK-LABEL: func @broadcast_vec1d_from_scalar -// CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32> -// CHECK: return %[[T0]] : vector<2xf32> - -func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { - %0 = vector.broadcast %arg0 : f32 to vector<2xf32> - return %0 : vector<2xf32> -} - -// CHECK-LABEL: 
func @broadcast_vec2d_from_scalar -// CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32> -// CHECK: return %[[T0]] : vector<2x3xf32> - -func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { - %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -// CHECK-LABEL: func @broadcast_vec3d_from_scalar -// CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32> -// CHECK: return %[[T0]] : vector<2x3x4xf32> - -func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { - %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32> - return %0 : vector<2x3x4xf32> -} - -// CHECK-LABEL: func @broadcast_vec1d_from_vec1d -// CHECK-SAME: %[[A:.*0]]: vector<2xf32> -// CHECK: return %[[A]] : vector<2xf32> - -func.func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> { - %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32> - return %0 : vector<2xf32> -} - -// CHECK-LABEL: func @broadcast_vec2d_from_vec1d -// CHECK-SAME: %[[A:.*0]]: vector<2xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> -// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32> -// CHECK: return %[[T2]] : vector<3x2xf32> - -func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { - %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32> - return %0 : vector<3x2xf32> -} - -// CHECK-LABEL: func @broadcast_vec3d_from_vec1d -// CHECK-SAME: %[[A:.*0]]: vector<2xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> -// CHECK: %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> -// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] 
: vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C1]] [0] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T4:.*]] = vector.insert %[[T2]], %[[T3]] [1] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T2]], %[[T4]] [2] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T6:.*]] = vector.insert %[[T2]], %[[T5]] [3] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: return %[[T6]] : vector<4x3x2xf32> - -func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> { - %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32> - return %0 : vector<4x3x2xf32> -} - -// CHECK-LABEL: func @broadcast_vec3d_from_vec2d -// CHECK-SAME: %[[A:.*0]]: vector<3x2xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> -// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[A]], %[[T2]] [3] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: return %[[T3]] : vector<4x3x2xf32> - -func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> { - %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32> - return %0 : vector<4x3x2xf32> -} - -// CHECK-LABEL: func @broadcast_stretch -// CHECK-SAME: %[[A:.*0]]: vector<1xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32> -// CHECK: return %[[T1]] : vector<4xf32> - -func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { - 
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32> - return %0 : vector<4xf32> -} - -// CHECK-LABEL: func @broadcast_stretch_at_start -// CHECK-SAME: %[[A:.*0]]: vector<1x4xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1x4xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C0]] [0] : vector<4xf32> into vector<3x4xf32> -// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[T1]] [1] : vector<4xf32> into vector<3x4xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [2] : vector<4xf32> into vector<3x4xf32> -// CHECK: return %[[T3]] : vector<3x4xf32> - -func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> { - %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32> - return %0 : vector<3x4xf32> -} - -// CHECK-LABEL: func @broadcast_stretch_at_end -// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32> -// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32> -// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32> -// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32> -// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32> -// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> -// CHECK: return %[[T15]] : vector<4x3xf32> - -func.func 
@broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { - %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32> - return %0 : vector<4x3xf32> -} - -// CHECK-LABEL: func @broadcast_stretch_in_middle -// CHECK-SAME: %[[A:.*0]]: vector<4x1x2xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> -// CHECK: %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1x2xf32> -// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T4:.*]] = vector.insert %[[T0]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<4x1x2xf32> -// CHECK: %[[T8:.*]] = vector.insert %[[T6]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T6]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T10:.*]] = vector.insert %[[T6]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T12:.*]] = vector.extract %[[A]][2, 0] : vector<4x1x2xf32> -// CHECK: %[[T14:.*]] = vector.insert %[[T12]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T15:.*]] = vector.insert %[[T12]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T16:.*]] = vector.insert %[[T12]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T11]] [2] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T18:.*]] = vector.extract %[[A]][3, 0] : vector<4x1x2xf32> -// CHECK: %[[T20:.*]] = vector.insert %[[T18]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> -// 
CHECK: %[[T21:.*]] = vector.insert %[[T18]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T22:.*]] = vector.insert %[[T18]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T23:.*]] = vector.insert %[[T22]], %[[T17]] [3] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: return %[[T23]] : vector<4x3x2xf32> - -func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> { - %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> - return %0 : vector<4x3x2xf32> -} - -// CHECK-LABEL: func @genbool_1d -// CHECK: %[[T0:.*]] = arith.constant dense<[true, true, true, true, false, false, false, false]> : vector<8xi1> -// CHECK: return %[[T0]] : vector<8xi1> - -func.func @genbool_1d() -> vector<8xi1> { - %0 = vector.constant_mask [4] : vector<8xi1> - return %0 : vector<8xi1> -} - -// CHECK-LABEL: func @genbool_2d -// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1> -// CHECK: %[[C2:.*]] = arith.constant dense : vector<4x4xi1> -// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1> -// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1> -// CHECK: return %[[T1]] : vector<4x4xi1> - -func.func @genbool_2d() -> vector<4x4xi1> { - %v = vector.constant_mask [2, 2] : vector<4x4xi1> - return %v: vector<4x4xi1> -} - -// CHECK-LABEL: func @genbool_3d -// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1> -// CHECK: %[[C2:.*]] = arith.constant dense : vector<3x4xi1> -// CHECK: %[[C3:.*]] = arith.constant dense : vector<2x3x4xi1> -// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1> -// CHECK: return %[[T1]] : vector<2x3x4xi1> - -func.func @genbool_3d() -> vector<2x3x4xi1> { - %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1> - 
return %v: vector<2x3x4xi1> -} - -// CHECK-LABEL: func @genbool_var_1d( -// CHECK-SAME: %[[A:.*]]: index) -// CHECK: %[[T0:.*]] = vector.create_mask %[[A]] : vector<3xi1> -// CHECK: return %[[T0]] : vector<3xi1> - -func.func @genbool_var_1d(%arg0: index) -> vector<3xi1> { - %0 = vector.create_mask %arg0 : vector<3xi1> - return %0 : vector<3xi1> -} - -// CHECK-LABEL: func @genbool_var_2d( -// CHECK-SAME: %[[A:.*0]]: index, -// CHECK-SAME: %[[B:.*1]]: index) -// CHECK: %[[C1:.*]] = arith.constant dense : vector<3xi1> -// CHECK: %[[C2:.*]] = arith.constant dense : vector<2x3xi1> -// CHECK: %[[c0:.*]] = arith.constant 0 : index -// CHECK: %[[c1:.*]] = arith.constant 1 : index -// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1> -// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index -// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1> -// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index -// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1> -// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1> -// CHECK: return %[[T6]] : vector<2x3xi1> - -func.func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> { - %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1> - return %0 : vector<2x3xi1> -} - -// CHECK-LABEL: func @genbool_var_3d( -// CHECK-SAME: %[[A:.*0]]: index, -// CHECK-SAME: %[[B:.*1]]: index, -// CHECK-SAME: %[[C:.*2]]: index) -// CHECK-DAG: %[[C1:.*]] = arith.constant dense : vector<7xi1> -// CHECK-DAG: %[[C2:.*]] = arith.constant dense : vector<1x7xi1> -// CHECK-DAG: %[[C3:.*]] = arith.constant dense : vector<2x1x7xi1> -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1> -// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[B]], %[[c0]] : 
index -// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1> -// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index -// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1> -// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1> -// CHECK: %[[T7:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index -// CHECK: %[[T8:.*]] = arith.select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1> -// CHECK: return %[[T9]] : vector<2x1x7xi1> - -func.func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x1x7xi1> { - %0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1> - return %0 : vector<2x1x7xi1> -} - -// CHECK-LABEL: @contract_one_sided_unit_reduction_dim -// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>) -// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32> -// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<1x2xi32> -// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2x2xi32> -// CHECK: %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32> -// CHECK: %[[R0:.+]] = vector.reduction , %[[M0]] : vector<2xi32> into i32 -// CHECK: %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32> -// CHECK: %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2x2xi32> -// CHECK: %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32> -// CHECK: %[[R1:.+]] = vector.reduction , %[[M1]] : vector<2xi32> into i32 -// CHECK: %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32> -// CHECK: %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32> -// CHECK: return %[[S]] : vector<2xi32> - -func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : 
vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> { - %res = vector.contract { - indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d1, d2)>, - affine_map<(d0, d1, d2) -> (d1)> - ], - iterator_types = ["reduction", "parallel", "reduction"], - kind = #vector.kind - } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32> - return %res : vector<2xi32> -} - -#matmat_accesses_0 = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (m, n)> -] -#matmat_trait_0 = { - indexing_maps = #matmat_accesses_0, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_0 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> -func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>) --> vector<2x3xf32> -{ - %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 - : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -// OUTERPRODUCT-LABEL: func @matmul_0_mixed -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf16> -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf16> -// OUTERPRODUCT: %[[a1:.*]] = arith.extf 
%[[a0]] : vector<2xf16> to vector<2xf32> -// OUTERPRODUCT: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> -func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: vector<2x3xf32>) --> vector<2x3xf32> -{ - %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 - : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -#matmat_accesses_1 = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (n, k)>, - affine_map<(m, n, k) -> (m, n)> -] -#matmat_trait_1 = { - indexing_maps = #matmat_accesses_1, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_1 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> -func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>) --> vector<2x3xf32> -{ - %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2 - : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -#matmat_accesses_2 = [ - affine_map<(m, n, k) -> (k, m)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (m, n)> -] -#matmat_trait_2 = { - indexing_maps = #matmat_accesses_2, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_2 -// OUTERPRODUCT-SAME: 
%[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32> -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> -func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>) --> vector<2x3xf32> -{ - %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2 - : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -#matmat_accesses_3 = [ - affine_map<(m, n, k) -> (k, m)>, - affine_map<(m, n, k) -> (n, k)>, - affine_map<(m, n, k) -> (m, n)> -] -#matmat_trait_3 = { - indexing_maps = #matmat_accesses_3, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_3 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32> -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> -func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>) --> vector<2x3xf32> -{ - %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2 - : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -#matmat_accesses_4 = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (n, m)> -] -#matmat_trait_4 = { - indexing_maps = #matmat_accesses_4, - iterator_types = 
["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_4 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32> -func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) --> vector<3x2xf32> -{ - %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2 - : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> - return %0 : vector<3x2xf32> -} - -#matmat_accesses_5 = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (n, m)> -] -#matmat_trait_5 = { - indexing_maps = #matmat_accesses_5, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_5 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> -// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32> -func.func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) --> vector<3x2xf32> -{ - %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2 - : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> - return %0 : vector<3x2xf32> -} - -#matmat_accesses_6 = [ - affine_map<(m, 
n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (n, m)> -] -#matmat_trait_6 = { - indexing_maps = #matmat_accesses_6, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_6 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> -// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32> -func.func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) --> vector<3x2xf32> -{ - %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2 - : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> - return %0 : vector<3x2xf32> -} - -#matmat_accesses_7 = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (n, m)> -] -#matmat_trait_7 = { - indexing_maps = #matmat_accesses_7, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_7 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> -// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32> -func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) --> vector<3x2xf32> -{ - %0 = 
vector.contract #matmat_trait_7 %arg0, %arg1, %arg2 - : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> - return %0 : vector<3x2xf32> -} - -// PARALLEL-LABEL: func @parrallel_contract_lowering -// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> -// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> -// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32> -// PARALLEL: return %[[F]] : vector<4xf32> -func.func @parrallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { - %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32> - return %0 : vector<4xf32> -} - -// PARALLEL-LABEL: func @parrallel_contract_lowering_broadcast -// PARALLEL: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32> -// PARALLEL: %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32> -// PARALLEL: %[[E0:.*]] = vector.extract %[[T]][0, 0] : vector<1x1x4xf32> -// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> -// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32> -// PARALLEL: return %[[F]] : vector<4xf32> -func.func @parrallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { - %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32> - return %0 : vector<4xf32> -} - -// PARALLEL-LABEL: func 
@parrallel_contract_lowering -// PARALLEL: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32> -// PARALLEL: %[[T0:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32> -// PARALLEL: %[[T1:.*]] = vector.transpose %{{.*}}, [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32> -// PARALLEL: %[[E0:.*]] = vector.extract %[[T0]][0, 0] : vector<1x1x4xf32> -// PARALLEL: %[[E1:.*]] = vector.extract %[[T1]][0, 0] : vector<1x1x4xf32> -// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %arg2 : vector<4xf32> -// PARALLEL: return %[[F]] : vector<4xf32> -func.func @parrallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { - %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32> - return %0 : vector<4xf32> -} - -// PARALLEL-LABEL: func @parrallel_contract_lowering_scalar -// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32> -// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32> -// PARALLEL: %[[M:.*]] = arith.mulf %[[E0]], %[[E1]] : f32 -// PARALLEL: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32 -// PARALLEL: return %[[A]] : f32 -func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vector<1x1xf32>, %arg2: f32) -> f32 { - %0 = vector.contract { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> ()>], - iterator_types = ["reduction", "reduction"], kind = #vector.kind} - %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32 - return %0 : f32 -} - - diff --git a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir 
b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir @@ -0,0 +1,101 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func @genbool_1d +// CHECK: %[[T0:.*]] = arith.constant dense<[true, true, true, true, false, false, false, false]> : vector<8xi1> +// CHECK: return %[[T0]] : vector<8xi1> + +func.func @genbool_1d() -> vector<8xi1> { + %0 = vector.constant_mask [4] : vector<8xi1> + return %0 : vector<8xi1> +} + +// CHECK-LABEL: func @genbool_2d +// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1> +// CHECK: %[[C2:.*]] = arith.constant dense : vector<4x4xi1> +// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1> +// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1> +// CHECK: return %[[T1]] : vector<4x4xi1> + +func.func @genbool_2d() -> vector<4x4xi1> { + %v = vector.constant_mask [2, 2] : vector<4x4xi1> + return %v: vector<4x4xi1> +} + +// CHECK-LABEL: func @genbool_3d +// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1> +// CHECK: %[[C2:.*]] = arith.constant dense : vector<3x4xi1> +// CHECK: %[[C3:.*]] = arith.constant dense : vector<2x3x4xi1> +// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1> +// CHECK: return %[[T1]] : vector<2x3x4xi1> + +func.func @genbool_3d() -> vector<2x3x4xi1> { + %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1> + return %v: vector<2x3x4xi1> +} + +// CHECK-LABEL: func @genbool_var_1d( +// CHECK-SAME: %[[A:.*]]: index) +// CHECK: %[[T0:.*]] = vector.create_mask %[[A]] : vector<3xi1> +// CHECK: return %[[T0]] : vector<3xi1> + +func.func @genbool_var_1d(%arg0: index) -> 
vector<3xi1> { + %0 = vector.create_mask %arg0 : vector<3xi1> + return %0 : vector<3xi1> +} + +// CHECK-LABEL: func @genbool_var_2d( +// CHECK-SAME: %[[A:.*0]]: index, +// CHECK-SAME: %[[B:.*1]]: index) +// CHECK: %[[C1:.*]] = arith.constant dense : vector<3xi1> +// CHECK: %[[C2:.*]] = arith.constant dense : vector<2x3xi1> +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[c1:.*]] = arith.constant 1 : index +// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1> +// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index +// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1> +// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index +// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1> +// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1> +// CHECK: return %[[T6]] : vector<2x3xi1> + +func.func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> { + %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1> + return %0 : vector<2x3xi1> +} + +// CHECK-LABEL: func @genbool_var_3d( +// CHECK-SAME: %[[A:.*0]]: index, +// CHECK-SAME: %[[B:.*1]]: index, +// CHECK-SAME: %[[C:.*2]]: index) +// CHECK-DAG: %[[C1:.*]] = arith.constant dense : vector<7xi1> +// CHECK-DAG: %[[C2:.*]] = arith.constant dense : vector<1x7xi1> +// CHECK-DAG: %[[C3:.*]] = arith.constant dense : vector<2x1x7xi1> +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1> +// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[B]], %[[c0]] : index +// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1> +// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index +// CHECK: 
%[[T5:.*]] = arith.select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1> +// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1> +// CHECK: %[[T7:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index +// CHECK: %[[T8:.*]] = arith.select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1> +// CHECK: return %[[T9]] : vector<2x1x7xi1> + +func.func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x1x7xi1> { + %0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1> + return %0 : vector<2x1x7xi1> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + transform.vector.lower_mask %f + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns -split-input-file | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> { %0 = vector.multi_reduction , %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> @@ -264,3 +264,10 @@ // CHECK-LABEL: func @vector_multi_reduction_parallel_middle // CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32> // CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32> + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_multi_reduction %module_op + lowering_strategy = "innerreduction" + : (!pdl.operation) -> 
!pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> { %0 = vector.multi_reduction , %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> @@ -187,3 +187,10 @@ } // CHECK-LABEL: func @vector_multi_reduction_to_scalar // CHECK: return %{{.+}} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_multi_reduction %module_op + lowering_strategy = "innerparallel" + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir @@ -0,0 +1,148 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func @outerproduct_noacc +// CHECK-SAME: %[[A:.*0]]: vector<2xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> +// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32> +// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : 
vector<3xf32> +// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32> +// CHECK: return %[[T7]] : vector<2x3xf32> + +func.func @outerproduct_noacc(%arg0: vector<2xf32>, + %arg1: vector<3xf32>) -> vector<2x3xf32> { + %0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> + return %0: vector<2x3xf32> +} + +// CHECK-LABEL: func @outerproduct_acc +// CHECK-SAME: %[[A:.*0]]: vector<2xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> +// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32> +// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32> +// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32> +// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32> +// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> +// CHECK: return %[[T9]] : vector<2x3xf32> + +func.func @outerproduct_acc(%arg0: vector<2xf32>, + %arg1: vector<3xf32>, + %arg2: vector<2x3xf32>) -> vector<2x3xf32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32> + return %0: vector<2x3xf32> +} + +// CHECK-LABEL: func @outerproduct_noacc_int +// CHECK-SAME: %[[A:.*0]]: vector<2xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi32> +// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32> +// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : 
vector<3xi32> +// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32> +// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32> +// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32> +// CHECK: return %[[T7]] : vector<2x3xi32> +func.func @outerproduct_noacc_int(%arg0: vector<2xi32>, + %arg1: vector<3xi32>) -> vector<2x3xi32> { + %0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32> + return %0: vector<2x3xi32> +} + +// CHECK-LABEL: func @outerproduct_acc_int +// CHECK-SAME: %[[A:.*0]]: vector<2xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, +// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> +// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32> +// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> +// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32> +// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> +// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32> +// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32> +// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32> +// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32> +// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32> +// CHECK: return %[[T11]] : vector<2x3xi32> +func.func @outerproduct_acc_int(%arg0: vector<2xi32>, + %arg1: vector<3xi32>, + %arg2: vector<2x3xi32>) -> vector<2x3xi32> { + %0 = 
vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32> + return %0: vector<2x3xi32> +} + +// CHECK-LABEL: func @axpy_fp( +// CHECK-SAME: %[[A:.*0]]: vector<16xf32>, +// CHECK-SAME: %[[B:.*1]]: f32) +// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32> +// CHECK: return %[[T1]] : vector<16xf32> +func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { + %0 = vector.outerproduct %arg0, %arg1: vector<16xf32>, f32 + return %0: vector<16xf32> +} + +// CHECK-LABEL: func @axpy_fp_add( +// CHECK-SAME: %[[A:.*0]]: vector<16xf32>, +// CHECK-SAME: %[[B:.*1]]: f32, +// CHECK-SAME: %[[C:.*2]]: vector<16xf32>) +// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32> +// CHECK: return %[[T1]] : vector<16xf32> +func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32 + return %0: vector<16xf32> +} + +// CHECK-LABEL: func @axpy_int( +// CHECK-SAME: %[[A:.*0]]: vector<16xi32>, +// CHECK-SAME: %[[B:.*1]]: i32) +// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> +// CHECK: return %[[T1]] : vector<16xi32> +func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { + %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32 + return %0: vector<16xi32> +} + +// CHECK-LABEL: func @axpy_int_add( +// CHECK-SAME: %[[A:.*0]]: vector<16xi32>, +// CHECK-SAME: %[[B:.*1]]: i32, +// CHECK-SAME: %[[C:.*2]]: vector<16xi32>) +// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> +// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32> +// CHECK: return %[[T2]] : vector<16xi32> +func.func @axpy_int_add(%arg0: vector<16xi32>, 
%arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32 + return %0: vector<16xi32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + %f2 = transform.vector.lower_outerproduct %f + : (!pdl.operation) -> !pdl.operation + + %f3 = transform.vector.lower_broadcast %f2 + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir @@ -0,0 +1,134 @@ + +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func @nop_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> +func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: func @cancel_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> + +func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> + return %1 : vector<16xf32> +} + +// Shape up and downcasts for 2-D vectors, for supporting conversion to +// llvm.matrix operations +// CHECK-LABEL: func @shape_casts +func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) { + // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> + // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> + // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : 
vector<2x2xf32> + // + // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]] + // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> + // + // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x2xf32> + // + // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]] + // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> + // + %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32> + // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32> + %r0 = arith.addf %0, %0: vector<4xf32> + // + // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]] + // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : + // CHECK-SAME: vector<4xf32> to vector<2xf32> + // + // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] : + // CHECK-SAME: vector<2xf32> into vector<2x2xf32> + // + // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]] + // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} : + // CHECK-SAME: vector<4xf32> to vector<2xf32> + // + // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] : + // CHECK-SAME: vector<2xf32> into vector<2x2xf32> + // + %1 = vector.shape_cast %r0 : vector<4xf32> to vector<2x2xf32> + // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32> + return %r0, %1 : vector<4xf32>, vector<2x2xf32> +} + +// CHECK-LABEL: func @shape_cast_2d2d +// CHECK-SAME: %[[A:.*]]: vector<3x2xf32> +// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] 
[0, 2] : f32 into vector<2x3xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32> +// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32> +// CHECK: return %[[T11]] : vector<2x3xf32> + +func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> { + %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32> + return %s : vector<2x3xf32> +} + +// CHECK-LABEL: func @shape_cast_3d1d +// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32> +// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : vector<1x3x2xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32> +// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32> +// CHECK: return %[[T11]] : vector<6xf32> + +func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> { + %s = vector.shape_cast 
%arg0 : vector<1x3x2xf32> to vector<6xf32> + return %s : vector<6xf32> +} + +// CHECK-LABEL: func @shape_cast_1d3d +// CHECK-SAME: %[[A:.*]]: vector<6xf32> +// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32> +// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32> +// CHECK: return %[[T11]] : vector<2x1x3xf32> + +func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> { + %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32> + return %s : vector<2x1x3xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + %f2 = transform.vector.lower_shape_cast %f + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir @@ 
-1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-transfer-drop-unit-dims-patterns -split-input-file | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s func.func @transfer_read_rank_reducing( %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> { @@ -15,8 +15,6 @@ // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> // CHECK: vector.transfer_read %[[SUBVIEW]] -// ----- - func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) { %c0 = arith.constant 0 : index vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : @@ -28,4 +26,11 @@ // CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8 // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1] // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> -// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]] \ No newline at end of file +// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]] + + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.apply_rank_reducing_subview_patterns %module_op + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir @@ -0,0 +1,243 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)> +// CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)> +// CHECK-DAG: #[[$bounds_map_4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)> +// CHECK-DAG: #[[$bounds_map_8:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)> + +// CHECK-LABEL: split_vector_transfer_read_2d( +// CHECK-SAME: 
%[[A:[a-zA-Z0-9_]*]]: memref +// CHECK-SAME: %[[i:[a-zA-Z0-9_]*]]: index +// CHECK-SAME: %[[j:[a-zA-Z0-9_]*]]: index +func.func @split_vector_transfer_read_2d(%A: memref, %i: index, %j: index) -> vector<4x8xf32> { + %c0 = arith.constant 0 : index + %f0 = arith.constant 0.0 : f32 + + // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index + // CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index + // alloca for boundary full tile + // CHECK: %[[alloc:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32> + // %i + 4 <= dim(%A, 0) + // CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]] + // CHECK: %[[d0:.*]] = memref.dim %[[A]], %[[c0]] : memref + // CHECK: %[[cmp0:.*]] = arith.cmpi sle, %[[idx0]], %[[d0]] : index + // %j + 8 <= dim(%A, 1) + // CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]] + // CHECK: %[[cmp1:.*]] = arith.cmpi sle, %[[idx1]], %[[c8]] : index + // are both conds true + // CHECK: %[[cond:.*]] = arith.andi %[[cmp0]], %[[cmp1]] : i1 + // CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref, index, index) { + // inBounds, just yield %A + // CHECK: scf.yield %[[A]], %[[i]], %[[j]] : memref, index, index + // CHECK: } else { + // slow path, fill tmp alloc and yield a memref_casted version of it + // CHECK: linalg.fill ins(%cst : f32) outs(%[[alloc]] : memref<4x8xf32>) + // CHECK: %[[d0:.*]] = memref.dim %[[A]], %[[c0]] : memref + // CHECK: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[d0]], %[[i]], %[[c4]]) + // CHECK: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]]) + // CHECK: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1] + // CHECK-SAME: memref to memref> + // CHECK: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1] + // CHECK: memref.copy %[[sv]], %[[alloc_view]] : memref> to memref + // CHECK: %[[yielded:.*]] = memref.cast %[[alloc]] : + // CHECK-SAME: memref<4x8xf32> to memref + // CHECK: 
scf.yield %[[yielded]], %[[c0]], %[[c0]] : + // CHECK-SAME: memref, index, index + // CHECK: } + // CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %cst + // CHECK-SAME: {in_bounds = [true, true]} : memref, vector<4x8xf32> + %1 = vector.transfer_read %A[%i, %j], %f0 : memref, vector<4x8xf32> + + // CHECK: return %[[res]] : vector<4x8xf32> + return %1: vector<4x8xf32> +} + +// CHECK-LABEL: split_vector_transfer_read_strided_2d( +// CHECK-SAME: %[[A:[a-zA-Z0-9_]*]]: memref +// CHECK-SAME: %[[i:[a-zA-Z0-9_]*]]: index +// CHECK-SAME: %[[j:[a-zA-Z0-9_]*]]: index +func.func @split_vector_transfer_read_strided_2d( + %A: memref<7x8xf32, strided<[?, 1], offset: ?>>, + %i: index, %j: index) -> vector<4x8xf32> { + %c0 = arith.constant 0 : index + %f0 = arith.constant 0.0 : f32 + + + // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index + // CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index + // CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index + // alloca for boundary full tile + // CHECK: %[[alloc:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32> + // %i + 4 <= dim(%A, 0) + // CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]] + // CHECK: %[[cmp0:.*]] = arith.cmpi sle, %[[idx0]], %[[c7]] : index + // %j + 8 <= dim(%A, 1) + // CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]] + // CHECK: %[[cmp1:.*]] = arith.cmpi sle, %[[idx1]], %[[c8]] : index + // are both conds true + // CHECK: %[[cond:.*]] = arith.andi %[[cmp0]], %[[cmp1]] : i1 + // CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref>, index, index) { + // inBounds but not cast-compatible: yield a memref_casted form of %A + // CHECK: %[[casted:.*]] = memref.cast %arg0 : + // CHECK-SAME: memref<7x8xf32, strided<[?, 1], offset: ?>> to memref> + // CHECK: scf.yield %[[casted]], %[[i]], %[[j]] : + // CHECK-SAME: memref>, index, index + // CHECK: } else { + // slow path, fill tmp alloc and yield a memref_casted 
version of it + // CHECK: linalg.fill ins(%cst : f32) outs(%[[alloc]] : memref<4x8xf32>) + // CHECK: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[c7]], %[[i]], %[[c4]]) + // CHECK: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]]) + // CHECK: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1] + // CHECK-SAME: memref<7x8xf32, strided<[?, 1], offset: ?>> to memref> + // CHECK: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1] + // CHECK: memref.copy %[[sv]], %[[alloc_view]] : memref> to memref + // CHECK: %[[yielded:.*]] = memref.cast %[[alloc]] : + // CHECK-SAME: memref<4x8xf32> to memref> + // CHECK: scf.yield %[[yielded]], %[[c0]], %[[c0]] : + // CHECK-SAME: memref>, index, index + // CHECK: } + // CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {in_bounds = [true, true]} : + // CHECK-SAME: memref>, vector<4x8xf32> + %1 = vector.transfer_read %A[%i, %j], %f0 : + memref<7x8xf32, strided<[?, 1], offset: ?>>, vector<4x8xf32> + + return %1 : vector<4x8xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.split_transfer_full_partial %module_op + split_transfer_strategy = "linalg-copy" + : (!pdl.operation) -> !pdl.operation +} + +// ----- + +func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref, %i: index, %j: index) { + vector.transfer_write %V, %A[%i, %j] : + vector<4x8xf32>, memref + return +} + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 + 4)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 8)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)> +// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)> + +// CHECK-LABEL: func @split_vector_transfer_write_2d( +// CHECK-SAME: %[[VEC:.*]]: vector<4x8xf32>, +// CHECK-SAME: %[[DEST:.*]]: memref, +// CHECK-SAME: %[[I:.*]]: index, +// CHECK-SAME: %[[J:.*]]: index) { +// CHECK-DAG: %[[CT:.*]] = arith.constant true 
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32> +// CHECK: %[[IDX0:.*]] = affine.apply #[[$MAP0]]()[%[[I]]] +// CHECK: %[[DIM0:.*]] = memref.dim %[[DEST]], %[[C0]] : memref +// CHECK: %[[DIM0_IN:.*]] = arith.cmpi sle, %[[IDX0]], %[[DIM0]] : index +// CHECK: %[[DIM1:.*]] = affine.apply #[[$MAP1]]()[%[[J]]] +// CHECK: %[[DIM1_IN:.*]] = arith.cmpi sle, %[[DIM1]], %[[C8]] : index +// CHECK: %[[IN_BOUNDS:.*]] = arith.andi %[[DIM0_IN]], %[[DIM1_IN]] : i1 +// CHECK: %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]] +// CHECK-SAME: -> (memref, index, index) { +// CHECK: scf.yield %[[DEST]], %[[I]], %[[J]] : memref, index, index +// CHECK: } else { +// CHECK: %[[VAL_16:.*]] = memref.cast %[[TEMP]] : memref<4x8xf32> to memref +// CHECK: scf.yield %[[VAL_16]], %[[C0]], %[[C0]] : memref, index, index +// CHECK: } +// CHECK: vector.transfer_write %[[VEC]], +// CHECK-SAME: %[[IN_BOUND_DEST:.*]]#0[%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2] +// CHECK-SAME: {in_bounds = [true, true]} : vector<4x8xf32>, memref +// CHECK: %[[OUT_BOUNDS:.*]] = arith.xori %[[IN_BOUNDS]], %[[CT]] : i1 +// CHECK: scf.if %[[OUT_BOUNDS]] { +// CHECK: %[[VAL_19:.*]] = memref.dim %[[DEST]], %[[C0]] : memref +// CHECK-DAG: %[[VAL_20:.*]] = affine.min #[[$MAP2]](%[[VAL_19]], %[[I]], %[[C4]]) +// CHECK-DAG: %[[VAL_21:.*]] = affine.min #[[$MAP3]](%[[C8]], %[[J]], %[[C8]]) +// CHECK: %[[VAL_22:.*]] = memref.subview %[[TEMP]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]] +// CHECK-SAME: [1, 1] : memref<4x8xf32> to memref> +// CHECK: %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1] +// CHECK: memref.copy %[[VAL_22]], %[[DEST_VIEW]] +// CHECK-SAME: : memref> to memref +// CHECK: } +// CHECK: return +// CHECK: } + +transform.sequence failures(propagate) { +^bb1(%module_op: 
!pdl.operation): + transform.vector.split_transfer_full_partial %module_op + split_transfer_strategy = "linalg-copy" + : (!pdl.operation) -> !pdl.operation +} + +// ----- + +func.func @split_vector_transfer_write_strided_2d( + %V: vector<4x8xf32>, %A: memref<7x8xf32, strided<[?, 1], offset: ?>>, + %i: index, %j: index) { + vector.transfer_write %V, %A[%i, %j] : + vector<4x8xf32>, memref<7x8xf32, strided<[?, 1], offset: ?>> + return +} + +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 4)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 8)> +// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)> +// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)> +// CHECK-LABEL: func @split_vector_transfer_write_strided_2d( +// CHECK-SAME: %[[VEC:.*]]: vector<4x8xf32>, +// CHECK-SAME: %[[DEST:.*]]: memref<7x8xf32, strided<[?, 1], offset: ?>>, +// CHECK-SAME: %[[I:.*]]: index, +// CHECK-SAME: %[[J:.*]]: index) { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[CT:.*]] = arith.constant true +// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32> +// CHECK: %[[DIM0:.*]] = affine.apply #[[$MAP1]]()[%[[I]]] +// CHECK: %[[DIM0_IN:.*]] = arith.cmpi sle, %[[DIM0]], %[[C7]] : index +// CHECK: %[[DIM1:.*]] = affine.apply #[[$MAP2]]()[%[[J]]] +// CHECK: %[[DIM1_IN:.*]] = arith.cmpi sle, %[[DIM1]], %[[C8]] : index +// CHECK: %[[IN_BOUNDS:.*]] = arith.andi %[[DIM0_IN]], %[[DIM1_IN]] : i1 +// CHECK: %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]] +// CHECK-SAME: -> (memref>, index, index) { +// CHECK: %[[VAL_16:.*]] = memref.cast %[[DEST]] +// CHECK-SAME: : memref<7x8xf32, strided<[?, 1], offset: ?>> to memref> +// CHECK: scf.yield %[[VAL_16]], %[[I]], %[[J]] +// CHECK-SAME: : memref>, index, index +// CHECK: } else { +// CHECK: 
%[[VAL_17:.*]] = memref.cast %[[TEMP]] +// CHECK-SAME: : memref<4x8xf32> to memref> +// CHECK: scf.yield %[[VAL_17]], %[[C0]], %[[C0]] +// CHECK-SAME: : memref>, index, index +// CHECK: } +// CHECK: vector.transfer_write %[[VEC]], +// CHECK-SAME: %[[IN_BOUND_DEST:.*]]#0 +// CHECK-SAME: [%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2] +// CHECK-SAME: {in_bounds = [true, true]} +// CHECK-SAME: : vector<4x8xf32>, memref> +// CHECK: %[[OUT_BOUNDS:.*]] = arith.xori %[[IN_BOUNDS]], %[[CT]] : i1 +// CHECK: scf.if %[[OUT_BOUNDS]] { +// CHECK-DAG: %[[VAL_20:.*]] = affine.min #[[$MAP3]](%[[C7]], %[[I]], %[[C4]]) +// CHECK-DAG: %[[VAL_21:.*]] = affine.min #[[$MAP4]](%[[C8]], %[[J]], %[[C8]]) +// CHECK: %[[VAL_22:.*]] = memref.subview %[[TEMP]] +// CHECK-SAME: [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]] +// CHECK-SAME: [1, 1] : memref<4x8xf32> to memref> +// CHECK: %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1] +// CHECK: memref.copy %[[VAL_22]], %[[DEST_VIEW]] +// CHECK-SAME: : memref> to memref> +// CHECK: } +// CHECK: return +// CHECK: } + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.split_transfer_full_partial %module_op + split_transfer_strategy = "linalg-copy" + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir @@ -1,23 +1,14 @@ -// RUN: mlir-opt %s -test-vector-transfer-full-partial-split -split-input-file | FileCheck %s -// RUN: mlir-opt %s -test-vector-transfer-full-partial-split=use-memref-copy -split-input-file | FileCheck %s --check-prefix=LINALG +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + // CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)> // 
CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)> -// LINALG-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)> -// LINALG-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)> -// LINALG-DAG: #[[$bounds_map_4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)> -// LINALG-DAG: #[[$bounds_map_8:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)> - // CHECK-LABEL: split_vector_transfer_read_2d( // CHECK-SAME: %[[A:[a-zA-Z0-9_]*]]: memref // CHECK-SAME: %[[i:[a-zA-Z0-9_]*]]: index // CHECK-SAME: %[[j:[a-zA-Z0-9_]*]]: index -// LINALG-LABEL: split_vector_transfer_read_2d( -// LINALG-SAME: %[[A:[a-zA-Z0-9_]*]]: memref -// LINALG-SAME: %[[i:[a-zA-Z0-9_]*]]: index -// LINALG-SAME: %[[j:[a-zA-Z0-9_]*]]: index func.func @split_vector_transfer_read_2d(%A: memref, %i: index, %j: index) -> vector<4x8xf32> { %c0 = arith.constant 0 : index %f0 = arith.constant 0.0 : f32 @@ -53,43 +44,8 @@ // CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %cst // CHECK-SAME: {in_bounds = [true, true]} : memref, vector<4x8xf32> - // LINALG-DAG: %[[c0:.*]] = arith.constant 0 : index - // LINALG-DAG: %[[c4:.*]] = arith.constant 4 : index - // LINALG-DAG: %[[c8:.*]] = arith.constant 8 : index - // alloca for boundary full tile - // LINALG: %[[alloc:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32> - // %i + 4 <= dim(%A, 0) - // LINALG: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]] - // LINALG: %[[d0:.*]] = memref.dim %[[A]], %[[c0]] : memref - // LINALG: %[[cmp0:.*]] = arith.cmpi sle, %[[idx0]], %[[d0]] : index - // %j + 8 <= dim(%A, 1) - // LINALG: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]] - // LINALG: %[[cmp1:.*]] = arith.cmpi sle, %[[idx1]], %[[c8]] : index - // are both conds true - // LINALG: %[[cond:.*]] = arith.andi %[[cmp0]], %[[cmp1]] : i1 - // LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref, index, index) { - // inBounds, just yield %A - // LINALG: scf.yield %[[A]], %[[i]], %[[j]] : memref, index, index - // LINALG: 
} else { - // slow path, fill tmp alloc and yield a memref_casted version of it - // LINALG: linalg.fill ins(%cst : f32) outs(%[[alloc]] : memref<4x8xf32>) - // LINALG: %[[d0:.*]] = memref.dim %[[A]], %[[c0]] : memref - // LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[d0]], %[[i]], %[[c4]]) - // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]]) - // LINALG: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1] - // LINALG-SAME: memref to memref> - // LINALG: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1] - // LINALG: memref.copy %[[sv]], %[[alloc_view]] : memref> to memref - // LINALG: %[[yielded:.*]] = memref.cast %[[alloc]] : - // LINALG-SAME: memref<4x8xf32> to memref - // LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] : - // LINALG-SAME: memref, index, index - // LINALG: } - // LINALG: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %cst - // LINALG-SAME: {in_bounds = [true, true]} : memref, vector<4x8xf32> %1 = vector.transfer_read %A[%i, %j], %f0 : memref, vector<4x8xf32> - // LINALG: return %[[res]] : vector<4x8xf32> return %1: vector<4x8xf32> } @@ -98,10 +54,6 @@ // CHECK-SAME: %[[i:[a-zA-Z0-9_]*]]: index // CHECK-SAME: %[[j:[a-zA-Z0-9_]*]]: index -// LINALG-LABEL: split_vector_transfer_read_strided_2d( -// LINALG-SAME: %[[A:[a-zA-Z0-9_]*]]: memref -// LINALG-SAME: %[[i:[a-zA-Z0-9_]*]]: index -// LINALG-SAME: %[[j:[a-zA-Z0-9_]*]]: index func.func @split_vector_transfer_read_strided_2d( %A: memref<7x8xf32, strided<[?, 1], offset: ?>>, %i: index, %j: index) -> vector<4x8xf32> { @@ -142,43 +94,6 @@ // CHECK: } // CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {in_bounds = [true, true]} : // CHECK-SAME: memref>, vector<4x8xf32> - - // LINALG-DAG: %[[c0:.*]] = arith.constant 0 : index - // LINALG-DAG: %[[c4:.*]] = arith.constant 4 : index - // LINALG-DAG: %[[c7:.*]] = arith.constant 7 : index - // LINALG-DAG: %[[c8:.*]] = 
arith.constant 8 : index - // alloca for boundary full tile - // LINALG: %[[alloc:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32> - // %i + 4 <= dim(%A, 0) - // LINALG: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]] - // LINALG: %[[cmp0:.*]] = arith.cmpi sle, %[[idx0]], %[[c7]] : index - // %j + 8 <= dim(%A, 1) - // LINALG: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]] - // LINALG: %[[cmp1:.*]] = arith.cmpi sle, %[[idx1]], %[[c8]] : index - // are both conds true - // LINALG: %[[cond:.*]] = arith.andi %[[cmp0]], %[[cmp1]] : i1 - // LINALG: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref>, index, index) { - // inBounds but not cast-compatible: yield a memref_casted form of %A - // LINALG: %[[casted:.*]] = memref.cast %arg0 : - // LINALG-SAME: memref<7x8xf32, strided<[?, 1], offset: ?>> to memref> - // LINALG: scf.yield %[[casted]], %[[i]], %[[j]] : - // LINALG-SAME: memref>, index, index - // LINALG: } else { - // slow path, fill tmp alloc and yield a memref_casted version of it - // LINALG: linalg.fill ins(%cst : f32) outs(%[[alloc]] : memref<4x8xf32>) - // LINALG: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[c7]], %[[i]], %[[c4]]) - // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]]) - // LINALG: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1] - // LINALG-SAME: memref<7x8xf32, strided<[?, 1], offset: ?>> to memref> - // LINALG: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1] - // LINALG: memref.copy %[[sv]], %[[alloc_view]] : memref> to memref - // LINALG: %[[yielded:.*]] = memref.cast %[[alloc]] : - // LINALG-SAME: memref<4x8xf32> to memref> - // LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] : - // LINALG-SAME: memref>, index, index - // LINALG: } - // LINALG: %[[res:.*]] = vector.transfer_read {{.*}} {in_bounds = [true, true]} : - // LINALG-SAME: memref>, vector<4x8xf32> %1 = vector.transfer_read %A[%i, %j], %f0 : memref<7x8xf32, strided<[?, 
1], offset: ?>>, vector<4x8xf32> @@ -186,6 +101,13 @@ return %1 : vector<4x8xf32> } +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.split_transfer_full_partial %module_op + split_transfer_strategy = "vector-transfer" + : (!pdl.operation) -> !pdl.operation +} + // ----- func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref, %i: index, %j: index) { @@ -236,50 +158,13 @@ // CHECK: return // CHECK: } -// LINALG-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 4)> -// LINALG-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 8)> -// LINALG-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)> -// LINALG-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)> -// LINALG: func @split_vector_transfer_write_2d( -// LINALG-SAME: %[[VEC:.*]]: vector<4x8xf32>, -// LINALG-SAME: %[[DEST:.*]]: memref, -// LINALG-SAME: %[[I:.*]]: index, -// LINALG-SAME: %[[J:.*]]: index) { -// LINALG-DAG: %[[CT:.*]] = arith.constant true -// LINALG-DAG: %[[C0:.*]] = arith.constant 0 : index -// LINALG-DAG: %[[C4:.*]] = arith.constant 4 : index -// LINALG-DAG: %[[C8:.*]] = arith.constant 8 : index -// LINALG: %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32> -// LINALG: %[[IDX0:.*]] = affine.apply #[[MAP0]]()[%[[I]]] -// LINALG: %[[DIM0:.*]] = memref.dim %[[DEST]], %[[C0]] : memref -// LINALG: %[[DIM0_IN:.*]] = arith.cmpi sle, %[[IDX0]], %[[DIM0]] : index -// LINALG: %[[DIM1:.*]] = affine.apply #[[MAP1]]()[%[[J]]] -// LINALG: %[[DIM1_IN:.*]] = arith.cmpi sle, %[[DIM1]], %[[C8]] : index -// LINALG: %[[IN_BOUNDS:.*]] = arith.andi %[[DIM0_IN]], %[[DIM1_IN]] : i1 -// LINALG: %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]] -// LINALG-SAME: -> (memref, index, index) { -// LINALG: scf.yield %[[DEST]], %[[I]], %[[J]] : memref, index, index -// LINALG: } else { -// LINALG: %[[VAL_16:.*]] = memref.cast %[[TEMP]] : memref<4x8xf32> to memref -// LINALG: scf.yield %[[VAL_16]], %[[C0]], %[[C0]] : memref, index, index 
-// LINALG: } -// LINALG: vector.transfer_write %[[VEC]], -// LINALG-SAME: %[[IN_BOUND_DEST:.*]]#0[%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2] -// LINALG-SAME: {in_bounds = [true, true]} : vector<4x8xf32>, memref -// LINALG: %[[OUT_BOUNDS:.*]] = arith.xori %[[IN_BOUNDS]], %[[CT]] : i1 -// LINALG: scf.if %[[OUT_BOUNDS]] { -// LINALG: %[[VAL_19:.*]] = memref.dim %[[DEST]], %[[C0]] : memref -// LINALG-DAG: %[[VAL_20:.*]] = affine.min #[[MAP2]](%[[VAL_19]], %[[I]], %[[C4]]) -// LINALG-DAG: %[[VAL_21:.*]] = affine.min #[[MAP3]](%[[C8]], %[[J]], %[[C8]]) -// LINALG: %[[VAL_22:.*]] = memref.subview %[[TEMP]] -// LINALG-SAME: [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]] -// LINALG-SAME: [1, 1] : memref<4x8xf32> to memref> -// LINALG: %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1] -// LINALG: memref.copy %[[VAL_22]], %[[DEST_VIEW]] -// LINALG-SAME: : memref> to memref -// LINALG: } -// LINALG: return -// LINALG: } + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.split_transfer_full_partial %module_op + split_transfer_strategy = "vector-transfer" + : (!pdl.operation) -> !pdl.operation +} // ----- @@ -336,56 +221,12 @@ // CHECK: return // CHECK: } -// LINALG-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 4)> -// LINALG-DAG: #[[MAP2:.*]] = affine_map<()[s0] -> (s0 + 8)> -// LINALG-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)> -// LINALG-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)> -// LINALG: func @split_vector_transfer_write_strided_2d( -// LINALG-SAME: %[[VEC:.*]]: vector<4x8xf32>, -// LINALG-SAME: %[[DEST:.*]]: memref<7x8xf32, strided<[?, 1], offset: ?>>, -// LINALG-SAME: %[[I:.*]]: index, -// LINALG-SAME: %[[J:.*]]: index) { -// LINALG-DAG: %[[C0:.*]] = arith.constant 0 : index -// LINALG-DAG: %[[CT:.*]] = arith.constant true -// LINALG-DAG: %[[C7:.*]] = arith.constant 7 : index -// LINALG-DAG: %[[C4:.*]] = arith.constant 4 : index -// LINALG-DAG: %[[C8:.*]] 
= arith.constant 8 : index -// LINALG: %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32> -// LINALG: %[[DIM0:.*]] = affine.apply #[[MAP1]]()[%[[I]]] -// LINALG: %[[DIM0_IN:.*]] = arith.cmpi sle, %[[DIM0]], %[[C7]] : index -// LINALG: %[[DIM1:.*]] = affine.apply #[[MAP2]]()[%[[J]]] -// LINALG: %[[DIM1_IN:.*]] = arith.cmpi sle, %[[DIM1]], %[[C8]] : index -// LINALG: %[[IN_BOUNDS:.*]] = arith.andi %[[DIM0_IN]], %[[DIM1_IN]] : i1 -// LINALG: %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]] -// LINALG-SAME: -> (memref>, index, index) { -// LINALG: %[[VAL_16:.*]] = memref.cast %[[DEST]] -// LINALG-SAME: : memref<7x8xf32, strided<[?, 1], offset: ?>> to memref> -// LINALG: scf.yield %[[VAL_16]], %[[I]], %[[J]] -// LINALG-SAME: : memref>, index, index -// LINALG: } else { -// LINALG: %[[VAL_17:.*]] = memref.cast %[[TEMP]] -// LINALG-SAME: : memref<4x8xf32> to memref> -// LINALG: scf.yield %[[VAL_17]], %[[C0]], %[[C0]] -// LINALG-SAME: : memref>, index, index -// LINALG: } -// LINALG: vector.transfer_write %[[VEC]], -// LINALG-SAME: %[[IN_BOUND_DEST:.*]]#0 -// LINALG-SAME: [%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2] -// LINALG-SAME: {in_bounds = [true, true]} -// LINALG-SAME: : vector<4x8xf32>, memref> -// LINALG: %[[OUT_BOUNDS:.*]] = arith.xori %[[IN_BOUNDS]], %[[CT]] : i1 -// LINALG: scf.if %[[OUT_BOUNDS]] { -// LINALG-DAG: %[[VAL_20:.*]] = affine.min #[[MAP3]](%[[C7]], %[[I]], %[[C4]]) -// LINALG-DAG: %[[VAL_21:.*]] = affine.min #[[MAP4]](%[[C8]], %[[J]], %[[C8]]) -// LINALG: %[[VAL_22:.*]] = memref.subview %[[TEMP]] -// LINALG-SAME: [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]] -// LINALG-SAME: [1, 1] : memref<4x8xf32> to memref> -// LINALG: %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1] -// LINALG: memref.copy %[[VAL_22]], %[[DEST_VIEW]] -// LINALG-SAME: : memref> to memref> -// LINALG: } -// LINALG: return -// LINALG: } +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + 
transform.vector.split_transfer_full_partial %module_op + split_transfer_strategy = "vector-transfer" + : (!pdl.operation) -> !pdl.operation +} // ----- @@ -406,10 +247,6 @@ return %token : !async.token } -// ----- - -func.func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> () - // Ensure that `alloca`s are inserted outside of loops even though loops are // consdered allocation scopes. // CHECK-LABEL: transfer_read_within_scf_for @@ -425,3 +262,10 @@ } return } + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.split_transfer_full_partial %module_op + split_transfer_strategy = "vector-transfer" + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize --split-input-file | FileCheck %s // CHECK-LABEL: func @vector_transfer_ops_0d_memref( // CHECK-SAME: %[[MEM:.*]]: memref @@ -21,8 +21,6 @@ return } -// ----- - // CHECK-LABEL: func @vector_transfer_ops_0d_tensor( // CHECK-SAME: %[[SOURCE:.*]]: tensor func.func @vector_transfer_ops_0d_tensor(%M: tensor) -> vector<1xf32> { @@ -37,8 +35,6 @@ return %0: vector<1xf32> } -// ----- - // transfer_read/write are lowered to vector.load/store // CHECK-LABEL: func @transfer_to_load( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, @@ -55,8 +51,6 @@ return %res : vector<4xf32> } -// ----- - // n-D results are also supported. 
// CHECK-LABEL: func @transfer_2D( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, @@ -73,8 +67,6 @@ return %res : vector<2x4xf32> } -// ----- - // Vector element types are supported when the result has the same type. // CHECK-LABEL: func @transfer_vector_element( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>, @@ -91,8 +83,6 @@ return %res : vector<2x4xf32> } -// ----- - // TODO: Vector element types are not supported yet when the result has a // different type. // CHECK-LABEL: func @transfer_vector_element_different_types( @@ -111,8 +101,6 @@ return %res : vector<1x2x4xf32> } -// ----- - // TODO: transfer_read/write cannot be lowered because there is a dimension // that is not guaranteed to be in-bounds. // CHECK-LABEL: func @transfer_2D_not_inbounds( @@ -131,8 +119,6 @@ return %res : vector<2x4xf32> } -// ----- - // TODO: transfer_read/write cannot be lowered because they are not guaranteed // to be in-bounds. // CHECK-LABEL: func @transfer_not_inbounds( @@ -151,8 +137,6 @@ return %res : vector<4xf32> } -// ----- - // CHECK-LABEL: func @transfer_nondefault_layout( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> { @@ -169,8 +153,6 @@ return %res : vector<4xf32> } -// ----- - // TODO: transfer_read/write cannot be lowered to vector.load/store yet when the // permutation map is not the minor identity map (up to broadcasting). // CHECK-LABEL: func @transfer_perm_map( @@ -187,8 +169,6 @@ return %res : vector<4xf32> } -// ----- - // Lowering of transfer_read with broadcasting is supported (note that a `load` // is generated instead of a `vector.load`). 
// CHECK-LABEL: func @transfer_broadcasting( @@ -199,15 +179,15 @@ // CHECK-NEXT: return %[[RES]] : vector<4xf32> // CHECK-NEXT: } -#broadcast = affine_map<(d0, d1) -> (0)> +#broadcast_1d = affine_map<(d0, d1) -> (0)> func.func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> { %cf0 = arith.constant 0.0 : f32 - %res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true], permutation_map = #broadcast} : memref<8x8xf32>, vector<4xf32> + %res = vector.transfer_read %mem[%i, %i], %cf0 + {in_bounds = [true], permutation_map = #broadcast_1d} + : memref<8x8xf32>, vector<4xf32> return %res : vector<4xf32> } -// ----- - // CHECK-LABEL: func @transfer_scalar( // CHECK-SAME: %[[MEM:.*]]: memref, // CHECK-SAME: %[[IDX:.*]]: index) -> vector<1xf32> { @@ -221,8 +201,6 @@ return %res : vector<1xf32> } -// ----- - // An example with two broadcasted dimensions. // CHECK-LABEL: func @transfer_broadcasting_2D( // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>, @@ -232,15 +210,15 @@ // CHECK-NEXT: return %[[RES]] : vector<4x4xf32> // CHECK-NEXT: } -#broadcast = affine_map<(d0, d1) -> (0, 0)> +#broadcast_2d = affine_map<(d0, d1) -> (0, 0)> func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vector<4x4xf32> { %cf0 = arith.constant 0.0 : f32 - %res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true, true], permutation_map = #broadcast} : memref<8x8xf32>, vector<4x4xf32> + %res = vector.transfer_read %mem[%i, %i], %cf0 + {in_bounds = [true, true], permutation_map = #broadcast_2d} + : memref<8x8xf32>, vector<4x4xf32> return %res : vector<4x4xf32> } -// ----- - // More complex broadcasting case (here a `vector.load` is generated). 
// CHECK-LABEL: func @transfer_broadcasting_complex( // CHECK-SAME: %[[MEM:.*]]: memref<10x20x30x8x8xf32>, @@ -250,13 +228,25 @@ // CHECK-NEXT: return %[[RES]] : vector<3x2x4x5xf32> // CHECK-NEXT: } -#broadcast = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)> +#broadcast_2d_in_4d = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)> func.func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index) -> vector<3x2x4x5xf32> { %cf0 = arith.constant 0.0 : f32 - %res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 {in_bounds = [true, true, true, true], permutation_map = #broadcast} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32> + %res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 + {in_bounds = [true, true, true, true], permutation_map = #broadcast_2d_in_4d} + : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32> return %res : vector<3x2x4x5xf32> } + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %m2 = transform.vector.lower_transfer %module_op + max_transfer_rank = 99 + : (!pdl.operation) -> !pdl.operation + transform.vector.apply_transfer_permutation_patterns %m2 + : (!pdl.operation) -> !pdl.operation +} + // ----- #map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, 0, 0)> @@ -321,8 +311,6 @@ vector<7x14x8x16xf32>, vector<8xf32> } -// ----- - // CHECK-LABEL: func @transfer_write_permutations // CHECK-SAME: %[[ARG0:.*]]: memref // CHECK-SAME: %[[ARG1:.*]]: tensor @@ -348,8 +336,6 @@ return %0 : tensor } -// ----- - // CHECK-LABEL: func @transfer_write_broadcast_unit_dim // CHECK-SAME: %[[ARG0:.*]]: memref // CHECK-SAME: %[[ARG1:.*]]: tensor @@ -374,3 +360,12 @@ return %0 : tensor } + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %m2 = transform.vector.lower_transfer %module_op + max_transfer_rank = 99 + : (!pdl.operation) -> !pdl.operation + transform.vector.apply_transfer_permutation_patterns %m2 + : (!pdl.operation) -> !pdl.operation +} diff --git 
a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -1,666 +1,611 @@ -// RUN: mlir-opt %s -test-vector-transpose-lowering=eltwise=1 -split-input-file | FileCheck %s --check-prefix=ELTWISE -// RUN: mlir-opt %s -test-vector-transpose-lowering=shuffle=1 -split-input-file | FileCheck %s --check-prefix=SHUFFLE -// RUN: mlir-opt %s -test-vector-transpose-lowering=flat=1 -split-input-file | FileCheck %s --check-prefix=FLAT -// RUN: mlir-opt %s -test-vector-transpose-lowering=avx2=1 -split-input-file | FileCheck %s --check-prefix=AVX2 - -// ELTWISE-LABEL: func @transpose23 -// ELTWISE-SAME: %[[A:.*]]: vector<2x3xf32> -// ELTWISE: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> -// ELTWISE: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32> -// ELTWISE: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32> -// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> -// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : f32 into vector<3x2xf32> -// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> -// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2, 0] : f32 into vector<3x2xf32> -// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> -// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [0, 1] : f32 into vector<3x2xf32> -// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> -// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<3x2xf32> -// ELTWISE: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32> -// ELTWISE: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32> -// ELTWISE: return %[[T11]] : vector<3x2xf32> +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | 
FileCheck %s + +// CHECK-LABEL: func @transpose23 +// CHECK-SAME: %[[A:.*]]: vector<2x3xf32> +// CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : f32 into vector<3x2xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2, 0] : f32 into vector<3x2xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [0, 1] : f32 into vector<3x2xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<3x2xf32> +// CHECK: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32> +// CHECK: return %[[T11]] : vector<3x2xf32> func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> { %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32> return %0 : vector<3x2xf32> } +// CHECK-LABEL: func @transpose102_1x8x8xf32 +func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> { + // CHECK: vector.extract {{.*}}[0, 0] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 1] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 2] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 3] : vector<1x8x8xf32> + // CHECK-NEXT: 
vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 4] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 5] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 6] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 7] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32> + %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32> + return %0 : vector<8x1x8xf32> +} + +// CHECK-LABEL: func @transpose102_8x1x8xf32 +func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> { + // CHECK: vector.extract {{.*}}[0, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[2, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[3, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[4, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[5, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[6, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[7, 0] : 
vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32> + %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32> + return %0 : vector<1x8x8xf32> +} + +// CHECK-LABEL: func @transpose1023_1x1x8x8xf32( +func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> { + // Note the single 2-D extract/insert pair since 2 and 3 are not transposed! + // CHECK: vector.extract {{.*}}[0, 0] : vector<1x1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32> + %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32> + return %0 : vector<1x1x8x8xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_transpose %module_op + lowering_strategy = "eltwise" + : (!pdl.operation) -> !pdl.operation +} + // ----- -// SHUFFLE-LABEL: func @transpose -// FLAT-LABEL: func @transpose( +// CHECK-LABEL: func @transpose func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { - // SHUFFLE: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32> // 0 4 // 0 1 2 3 1 5 // 4 5 6 7 -> 2 6 // 3 7 - // SHUFFLE-NEXT: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32> - // SHUFFLE-NEXT: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32> + // CHECK-NEXT: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32> + return %0 : vector<4x2xf32> +} + + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_transpose %module_op + lowering_strategy = "shuffle" + : (!pdl.operation) -> !pdl.operation +} + +// ----- - // FLAT: vector.shape_cast {{.*}} : 
vector<2x4xf32> to vector<8xf32> - // FLAT: vector.flat_transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> -> vector<8xf32> - // FLAT: vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32> +// CHECK-LABEL: func @transpose( +func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { + // CHECK: vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32> + // CHECK: vector.flat_transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> -> vector<8xf32> + // CHECK: vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32> %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32> return %0 : vector<4x2xf32> } + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_transpose %module_op + lowering_strategy = "flat_transpose" + : (!pdl.operation) -> !pdl.operation +} + // ----- -// AVX2-LABEL: func @transpose4x8 +// CHECK-LABEL: func @transpose4x8 func.func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> { - // AVX2: vector.extract {{.*}}[0] - // AVX2-NEXT: vector.extract {{.*}}[1] - // AVX2-NEXT: vector.extract {{.*}}[2] - // AVX2-NEXT: vector.extract {{.*}}[3] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle 
{{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32> - // AVX2-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32> + // CHECK: vector.extract {{.*}}[0] + // CHECK-NEXT: vector.extract {{.*}}[1] + // CHECK-NEXT: vector.extract {{.*}}[2] + // CHECK-NEXT: vector.extract {{.*}}[3] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, 
vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32> + // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32> %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32> return %0 : vector<8x4xf32> } -// ----- - -// AVX2-LABEL: func @transpose021_1x4x8 +// CHECK-LABEL: func @transpose021_1x4x8 func.func @transpose021_1x4x8xf32(%arg0: vector<1x4x8xf32>) -> vector<1x8x4xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[0, 1] - // AVX2-NEXT: vector.extract {{.*}}[0, 2] - // AVX2-NEXT: vector.extract {{.*}}[0, 3] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // 
AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32> - // AVX2-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<1x8x4xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[0, 1] + // CHECK-NEXT: vector.extract {{.*}}[0, 2] + // CHECK-NEXT: vector.extract {{.*}}[0, 3] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32> + // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<1x8x4xf32> %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x4x8xf32> to 
vector<1x8x4xf32> return %0 : vector<1x8x4xf32> } -// ----- - -// AVX2-LABEL: func @transpose8x8 +// CHECK-LABEL: func @transpose8x8 func.func @transpose8x8xf32(%arg0: vector<8x8xf32>) -> vector<8x8xf32> { - // AVX2: vector.extract {{.*}}[0] - // AVX2-NEXT: vector.extract {{.*}}[1] - // AVX2-NEXT: vector.extract {{.*}}[2] - // AVX2-NEXT: vector.extract {{.*}}[3] - // AVX2-NEXT: vector.extract {{.*}}[4] - // AVX2-NEXT: vector.extract {{.*}}[5] - // AVX2-NEXT: vector.extract {{.*}}[6] - // AVX2-NEXT: vector.extract {{.*}}[7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // 
AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] + // CHECK: vector.extract {{.*}}[0] + // CHECK-NEXT: vector.extract {{.*}}[1] + // CHECK-NEXT: vector.extract {{.*}}[2] + // CHECK-NEXT: vector.extract {{.*}}[3] + // CHECK-NEXT: vector.extract {{.*}}[4] + // CHECK-NEXT: vector.extract {{.*}}[5] + // CHECK-NEXT: vector.extract {{.*}}[6] + // CHECK-NEXT: vector.extract {{.*}}[7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : 
vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + 
// CHECK-NEXT: vector.insert {{.*}}[7] %0 = vector.transpose %arg0, [1, 0] : vector<8x8xf32> to vector<8x8xf32> return %0 : vector<8x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose021_1x8x8 +// CHECK-LABEL: func @transpose021_1x8x8 func.func @transpose021_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<1x8x8xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[0, 1] - // AVX2-NEXT: vector.extract {{.*}}[0, 2] - // AVX2-NEXT: vector.extract {{.*}}[0, 3] - // AVX2-NEXT: vector.extract {{.*}}[0, 4] - // AVX2-NEXT: vector.extract {{.*}}[0, 5] - // AVX2-NEXT: vector.extract {{.*}}[0, 6] - // AVX2-NEXT: vector.extract {{.*}}[0, 7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: 
llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[0, 1] + // CHECK-NEXT: vector.extract {{.*}}[0, 2] + // CHECK-NEXT: vector.extract {{.*}}[0, 3] + // CHECK-NEXT: vector.extract {{.*}}[0, 4] + // CHECK-NEXT: vector.extract {{.*}}[0, 5] + // CHECK-NEXT: vector.extract {{.*}}[0, 6] + // CHECK-NEXT: vector.extract {{.*}}[0, 7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle 
{{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // 
CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32> %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x8x8xf32> to vector<1x8x8xf32> return %0 : vector<1x8x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose120_8x1x8 +// CHECK-LABEL: func @transpose120_8x1x8 func.func @transpose120_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[1, 0] - // AVX2-NEXT: vector.extract {{.*}}[2, 0] - // AVX2-NEXT: vector.extract {{.*}}[3, 0] - // AVX2-NEXT: vector.extract {{.*}}[4, 0] - // AVX2-NEXT: vector.extract {{.*}}[5, 0] - // AVX2-NEXT: vector.extract {{.*}}[6, 0] - // AVX2-NEXT: vector.extract {{.*}}[7, 0] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, 
vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[1, 0] + // CHECK-NEXT: vector.extract {{.*}}[2, 0] + // CHECK-NEXT: vector.extract {{.*}}[3, 0] + // CHECK-NEXT: vector.extract {{.*}}[4, 0] + // CHECK-NEXT: vector.extract {{.*}}[5, 0] + // CHECK-NEXT: vector.extract {{.*}}[6, 0] + // CHECK-NEXT: 
vector.extract {{.*}}[7, 0] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = 
intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32> %0 = vector.transpose %arg0, [1, 2, 0] : vector<8x1x8xf32> to vector<1x8x8xf32> return %0 : vector<1x8x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose120_8x8x1 +// CHECK-LABEL: func @transpose120_8x8x1 func.func @transpose120_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<8x1x8xf32> { - // AVX2: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32> - // AVX2-NEXT: vector.extract {{.*}}[0] - // AVX2-NEXT: vector.extract {{.*}}[1] - // AVX2-NEXT: vector.extract {{.*}}[2] - // AVX2-NEXT: vector.extract {{.*}}[3] - // AVX2-NEXT: vector.extract {{.*}}[4] - // AVX2-NEXT: vector.extract {{.*}}[5] - // AVX2-NEXT: vector.extract {{.*}}[6] - // AVX2-NEXT: vector.extract {{.*}}[7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // 
AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: 
vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0] + // CHECK-NEXT: vector.extract {{.*}}[1] + // CHECK-NEXT: vector.extract {{.*}}[2] + // CHECK-NEXT: vector.extract {{.*}}[3] + // CHECK-NEXT: vector.extract {{.*}}[4] + // CHECK-NEXT: vector.extract {{.*}}[5] + // CHECK-NEXT: vector.extract {{.*}}[6] + // CHECK-NEXT: vector.extract {{.*}}[7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect 
= intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32> %0 = vector.transpose %arg0, [1, 2, 0] : vector<8x8x1xf32> to vector<8x1x8xf32> return %0 : vector<8x1x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose102_8x8x1 +// CHECK-LABEL: func @transpose102_8x8x1 func.func @transpose102_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<8x8x1xf32> { - // AVX2: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32> - // AVX2-NEXT: vector.extract {{.*}}[0] - // AVX2-NEXT: vector.extract {{.*}}[1] - // AVX2-NEXT: vector.extract {{.*}}[2] - // AVX2-NEXT: vector.extract {{.*}}[3] - // AVX2-NEXT: vector.extract {{.*}}[4] - // AVX2-NEXT: vector.extract {{.*}}[5] - // AVX2-NEXT: vector.extract {{.*}}[6] - // AVX2-NEXT: vector.extract {{.*}}[7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 
11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, 
vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x8x1xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0] + // CHECK-NEXT: vector.extract {{.*}}[1] + // CHECK-NEXT: vector.extract {{.*}}[2] + // CHECK-NEXT: vector.extract {{.*}}[3] + // CHECK-NEXT: vector.extract {{.*}}[4] + // CHECK-NEXT: vector.extract {{.*}}[5] + // CHECK-NEXT: vector.extract {{.*}}[6] + // CHECK-NEXT: vector.extract {{.*}}[7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm 
asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x8x1xf32> %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x8x1xf32> to vector<8x8x1xf32> return %0 : vector<8x8x1xf32> } -// ----- - -// AVX2-LABEL: func @transpose201_8x1x8 +// CHECK-LABEL: func @transpose201_8x1x8 func.func @transpose201_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<8x8x1xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[1, 0] - // 
AVX2-NEXT: vector.extract {{.*}}[2, 0] - // AVX2-NEXT: vector.extract {{.*}}[3, 0] - // AVX2-NEXT: vector.extract {{.*}}[4, 0] - // AVX2-NEXT: vector.extract {{.*}}[5, 0] - // AVX2-NEXT: vector.extract {{.*}}[6, 0] - // AVX2-NEXT: vector.extract {{.*}}[7, 0] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) 
-> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x8x1xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[1, 0] + // CHECK-NEXT: vector.extract {{.*}}[2, 0] + // CHECK-NEXT: vector.extract {{.*}}[3, 0] + // CHECK-NEXT: vector.extract {{.*}}[4, 0] + // CHECK-NEXT: vector.extract {{.*}}[5, 0] + // CHECK-NEXT: vector.extract {{.*}}[6, 0] + // CHECK-NEXT: vector.extract {{.*}}[7, 0] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 
3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x8x1xf32> %0 = vector.transpose %arg0, [2, 0, 1] : 
vector<8x1x8xf32> to vector<8x8x1xf32> return %0 : vector<8x8x1xf32> } -// ----- - -// AVX2-LABEL: func @transpose201_1x8x8 +// CHECK-LABEL: func @transpose201_1x8x8 func.func @transpose201_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[0, 1] - // AVX2-NEXT: vector.extract {{.*}}[0, 2] - // AVX2-NEXT: vector.extract {{.*}}[0, 3] - // AVX2-NEXT: vector.extract {{.*}}[0, 4] - // AVX2-NEXT: vector.extract {{.*}}[0, 5] - // AVX2-NEXT: vector.extract {{.*}}[0, 6] - // AVX2-NEXT: vector.extract {{.*}}[0, 7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", 
"=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[0, 1] + // CHECK-NEXT: vector.extract {{.*}}[0, 2] + // CHECK-NEXT: vector.extract {{.*}}[0, 3] + // CHECK-NEXT: vector.extract {{.*}}[0, 4] + // CHECK-NEXT: vector.extract {{.*}}[0, 5] + // CHECK-NEXT: vector.extract {{.*}}[0, 6] + // CHECK-NEXT: vector.extract {{.*}}[0, 7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, 
vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: 
vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32> %0 = vector.transpose %arg0, [2, 0, 1] : vector<1x8x8xf32> to vector<8x1x8xf32> return %0 : vector<8x1x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose210_8x1x8 +// CHECK-LABEL: func @transpose210_8x1x8 func.func @transpose210_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<8x1x8xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[1, 0] - // AVX2-NEXT: vector.extract {{.*}}[2, 0] - // AVX2-NEXT: vector.extract {{.*}}[3, 0] - // AVX2-NEXT: vector.extract {{.*}}[4, 0] - // AVX2-NEXT: vector.extract {{.*}}[5, 0] - // AVX2-NEXT: vector.extract {{.*}}[6, 0] - // AVX2-NEXT: vector.extract {{.*}}[7, 0] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: 
llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[1, 0] + // CHECK-NEXT: vector.extract {{.*}}[2, 0] + // CHECK-NEXT: vector.extract {{.*}}[3, 0] + // CHECK-NEXT: vector.extract {{.*}}[4, 0] + // CHECK-NEXT: vector.extract {{.*}}[5, 0] + // CHECK-NEXT: vector.extract {{.*}}[6, 0] + // CHECK-NEXT: vector.extract {{.*}}[7, 0] + // CHECK-NEXT: 
vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" 
{{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32> %0 = vector.transpose %arg0, [2, 1, 0] : vector<8x1x8xf32> to vector<8x1x8xf32> return %0 : vector<8x1x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose210_8x8x1 +// CHECK-LABEL: func @transpose210_8x8x1 func.func @transpose210_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<1x8x8xf32> { - // AVX2: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32> - // AVX2-NEXT: vector.extract {{.*}}[0] - // AVX2-NEXT: vector.extract {{.*}}[1] - // AVX2-NEXT: vector.extract {{.*}}[2] - // AVX2-NEXT: vector.extract {{.*}}[3] - // AVX2-NEXT: vector.extract {{.*}}[4] - // AVX2-NEXT: vector.extract {{.*}}[5] - // AVX2-NEXT: vector.extract {{.*}}[6] - // AVX2-NEXT: vector.extract {{.*}}[7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 
12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to 
vector<1x8x8xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0] + // CHECK-NEXT: vector.extract {{.*}}[1] + // CHECK-NEXT: vector.extract {{.*}}[2] + // CHECK-NEXT: vector.extract {{.*}}[3] + // CHECK-NEXT: vector.extract {{.*}}[4] + // CHECK-NEXT: vector.extract {{.*}}[5] + // CHECK-NEXT: vector.extract {{.*}}[6] + // CHECK-NEXT: vector.extract {{.*}}[7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" 
{{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32> %0 = vector.transpose %arg0, [2, 1, 0] : vector<8x8x1xf32> to vector<1x8x8xf32> return %0 : vector<1x8x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose210_1x8x8 +// CHECK-LABEL: func @transpose210_1x8x8 func.func @transpose210_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[0, 1] - // AVX2-NEXT: vector.extract {{.*}}[0, 2] - // AVX2-NEXT: vector.extract {{.*}}[0, 3] - // AVX2-NEXT: vector.extract {{.*}}[0, 4] - // AVX2-NEXT: vector.extract {{.*}}[0, 5] - // AVX2-NEXT: vector.extract {{.*}}[0, 6] - // AVX2-NEXT: vector.extract {{.*}}[0, 7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 
12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, 
vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x8x1xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[0, 1] + // CHECK-NEXT: vector.extract {{.*}}[0, 2] + // CHECK-NEXT: vector.extract {{.*}}[0, 3] + // CHECK-NEXT: vector.extract {{.*}}[0, 4] + // CHECK-NEXT: vector.extract {{.*}}[0, 5] + // CHECK-NEXT: vector.extract {{.*}}[0, 6] + // CHECK-NEXT: vector.extract {{.*}}[0, 7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = 
intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x8x1xf32> %0 = vector.transpose %arg0, [2, 1, 0] : vector<1x8x8xf32> to vector<8x8x1xf32> return %0 : vector<8x8x1xf32> } -// ----- - -func.func @do_not_lower_nonf32_to_avx2(%arg0: vector<4x8xi32>) -> vector<8x4xi32> { - %0 = vector.transpose %arg0, [1, 0] : vector<4x8xi32> to vector<8x4xi32> - return %0 : vector<8x4xi32> -} - -// AVX2-NOT: vector.shuffle - -// ----- - -// AVX2-LABEL: func @transpose021_8x1x8 -func.func @transpose021_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<8x8x1xf32> { - %0 = vector.transpose %arg0, [0, 2, 1] : 
vector<8x1x8xf32> to vector<8x8x1xf32> - return %0 : vector<8x8x1xf32> +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_transpose %module_op + avx2_lowering_strategy = true + : (!pdl.operation) -> !pdl.operation } - -// AVX2-NOT: vector.shuffle - -// ----- - -// AVX2-LABEL: func @transpose021_8x8x1 -func.func @transpose021_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<8x1x8xf32> { - %0 = vector.transpose %arg0, [0, 2, 1] : vector<8x8x1xf32> to vector<8x1x8xf32> - return %0 : vector<8x1x8xf32> -} - -// AVX2-NOT: vector.shuffle - -// ----- - -// ELTWISE-LABEL: func @transpose102_1x8x8xf32 -// AVX2-LABEL: func @transpose102_1x8x8 -func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> { - // ELTWISE: vector.extract {{.*}}[0, 0] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 1] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 2] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 3] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 4] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 5] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 6] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 7] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32> - %0 = 
vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32> - return %0 : vector<8x1x8xf32> -} - -// AVX2-NOT: vector.shuffle - -// ----- - -// ELTWISE-LABEL: func @transpose102_8x1x8xf32 -// AVX2-LABEL: func @transpose102_8x1x8 -func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> { - // ELTWISE: vector.extract {{.*}}[0, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[1, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[2, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[3, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[4, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[5, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[6, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[7, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32> - %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32> - return %0 : vector<1x8x8xf32> -} - -// AVX2-NOT: vector.shuffle - -// ----- - -// ELTWISE-LABEL: func @transpose1023_1x1x8x8xf32( -// AVX2-LABEL: func @transpose1023_1x1x8x8 -func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> { - // Note the single 2-D extract/insert pair since 2 and 3 are not transposed! 
- // ELTWISE: vector.extract {{.*}}[0, 0] : vector<1x1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32> - %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32> - return %0 : vector<1x1x8x8xf32> -} - -// AVX2-NOT: vector.shuffle - -// ----- - -// AVX2-LABEL: func @transpose120_1x8x8 -func.func @transpose120_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32> { - - %0 = vector.transpose %arg0, [1, 2, 0] : vector<1x8x8xf32> to vector<8x8x1xf32> - return %0 : vector<8x8x1xf32> -} - -// AVX2-NOT: vector.shuffle - -// ----- - -// AVX2-LABEL: func @transpose201_8x8x1 -func.func @transpose201_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<1x8x8xf32> { - %0 = vector.transpose %arg0, [2, 0, 1] : vector<8x8x1xf32> to vector<1x8x8xf32> - return %0 : vector<1x8x8xf32> -} - -// AVX2-NOT: vector.shuffle - diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -113,75 +113,6 @@ } }; -struct TestVectorContractionLowering - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering) - - StringRef getArgument() const final { - return "test-vector-contraction-lowering"; - } - StringRef getDescription() const final { - return "Test lowering patterns that lower contract ops in the vector " - "dialect"; - } - TestVectorContractionLowering() = default; - TestVectorContractionLowering(const TestVectorContractionLowering &pass) - : PassWrapper(pass) {} - - Option lowerToFlatMatrix{ - *this, "vector-lower-matrix-intrinsics", - llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), - llvm::cl::init(false)}; - Option lowerToOuterProduct{ - *this, "vector-outerproduct", - llvm::cl::desc("Lower vector.contract to vector.outerproduct"), - llvm::cl::init(false)}; - 
Option lowerToParallelArith{ - *this, "vector-parallel-arith", - llvm::cl::desc("Lower vector.contract to elementwise vector ops."), - llvm::cl::init(false)}; - - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - - // Test on one pattern in isolation. - if (lowerToOuterProduct) { - VectorContractLowering lowering = VectorContractLowering::OuterProduct; - VectorTransformsOptions options{lowering}; - populateVectorContractLoweringPatterns( - patterns, options, /*benefit=*/1, - /*disableOuterProductlowering=*/true); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - return; - } - - if (lowerToParallelArith) { - vector::populateVectorContractLoweringPatterns( - patterns, - vector::VectorTransformsOptions().setVectorTransformsOptions( - vector::VectorContractLowering::ParallelArith)); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - return; - } - - // Test on all contract lowering patterns. - VectorContractLowering contractLowering = VectorContractLowering::Dot; - if (lowerToFlatMatrix) - contractLowering = VectorContractLowering::Matmul; - VectorMultiReductionLowering vectorMultiReductionLowering = - VectorMultiReductionLowering::InnerParallel; - VectorTransformsOptions options{contractLowering, - vectorMultiReductionLowering, - VectorTransposeLowering()}; - populateVectorBroadcastLoweringPatterns(patterns); - populateVectorContractLoweringPatterns(patterns, options); - populateVectorMaskOpLoweringPatterns(patterns); - populateVectorShapeCastLoweringPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; - struct TestVectorContractionPrepareForMMTLowering : public PassWrapper> { @@ -209,82 +140,6 @@ } }; -struct TestVectorTransposeLowering - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering) - - StringRef getArgument() const final { - return "test-vector-transpose-lowering"; - } - StringRef 
getDescription() const final { - return "Test lowering patterns that lower contract ops in the vector " - "dialect"; - } - TestVectorTransposeLowering() = default; - TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) - : PassWrapper(pass) {} - - Option lowerToEltwise{ - *this, "eltwise", - llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), - llvm::cl::init(false)}; - Option lowerToFlatTranspose{ - *this, "flat", - llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), - llvm::cl::init(false)}; - Option lowerToShuffleTranspose{ - *this, "shuffle", - llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), - llvm::cl::init(false)}; - Option lowerToAvx2{ - *this, "avx2", - llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), - llvm::cl::init(false)}; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - func::FuncOp funcOp = getOperation(); - MLIRContext *context = funcOp.getContext(); - RewritePatternSet patterns(context); - - vector::VectorTransformsOptions vectorTransformOptions; - if (lowerToEltwise) { - vectorTransformOptions = - vectorTransformOptions.setVectorTransposeLowering( - VectorTransposeLowering::EltWise); - } - if (lowerToFlatTranspose) { - vectorTransformOptions = - vectorTransformOptions.setVectorTransposeLowering( - VectorTransposeLowering::Flat); - } - if (lowerToShuffleTranspose) { - vectorTransformOptions = - vectorTransformOptions.setVectorTransposeLowering( - VectorTransposeLowering::Shuffle); - } - vector::populateVectorTransposeLoweringPatterns(patterns, - vectorTransformOptions); - - if (lowerToAvx2) { - auto avx2LoweringOptions = - x86vector::avx2::LoweringOptions().setTransposeOptions( - x86vector::avx2::TransposeLoweringOptions() - .lower4x8xf32() - .lower8x8xf32()); - x86vector::avx2::populateSpecializedTransposeLoweringPatterns( - patterns, avx2LoweringOptions, 
/*benefit=*/10); - } - - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) - return signalPassFailure(); - } -}; - struct TestVectorUnrollingPatterns : public PassWrapper> { @@ -435,47 +290,6 @@ llvm::cl::init(false)}; }; -struct TestVectorTransferFullPartialSplitPatterns - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - TestVectorTransferFullPartialSplitPatterns) - - StringRef getArgument() const final { - return "test-vector-transfer-full-partial-split"; - } - StringRef getDescription() const final { - return "Test lowering patterns to split " - "transfer ops via scf.if + linalg ops"; - } - TestVectorTransferFullPartialSplitPatterns() = default; - TestVectorTransferFullPartialSplitPatterns( - const TestVectorTransferFullPartialSplitPatterns &pass) - : PassWrapper(pass) {} - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - Option useLinalgOps{ - *this, "use-memref-copy", - llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " - "memref.copy operations."), - llvm::cl::init(false)}; - void runOnOperation() override { - MLIRContext *ctx = &getContext(); - RewritePatternSet patterns(ctx); - VectorTransformsOptions options; - if (useLinalgOps) - options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); - else - options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); - populateVectorTransferFullPartialPatterns(patterns, options); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; - struct TestScalarVectorTransferLoweringPatterns : public PassWrapper> { @@ -514,63 +328,6 @@ void runOnOperation() override { transferOpflowOpt(getOperation()); } }; -struct TestVectorTransferLoweringPatterns - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - TestVectorTransferLoweringPatterns) - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - StringRef 
getArgument() const final { - return "test-vector-transfer-lowering-patterns"; - } - StringRef getDescription() const final { - return "Test lowering patterns to lower transfer ops to other vector ops"; - } - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateVectorTransferLoweringPatterns(patterns); - populateVectorTransferPermutationMapLoweringPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; - -struct TestVectorMultiReductionLoweringPatterns - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - TestVectorMultiReductionLoweringPatterns) - - TestVectorMultiReductionLoweringPatterns() = default; - TestVectorMultiReductionLoweringPatterns( - const TestVectorMultiReductionLoweringPatterns &pass) - : PassWrapper(pass) {} - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - StringRef getArgument() const final { - return "test-vector-multi-reduction-lowering-patterns"; - } - StringRef getDescription() const final { - return "Test lowering patterns to lower vector.multi_reduction to other " - "vector ops"; - } - Option useOuterReductions{ - *this, "use-outer-reductions", - llvm::cl::desc("Move reductions to outer most dimensions"), - llvm::cl::init(false)}; - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateVectorMultiReductionLoweringPatterns( - patterns, useOuterReductions - ? 
vector::VectorMultiReductionLowering::InnerParallel - : vector::VectorMultiReductionLowering::InnerReduction); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; - struct TestVectorTransferCollapseInnerMostContiguousDims : public PassWrapper> { @@ -621,25 +378,6 @@ } }; -struct TestVectorTransferDropUnitDimsPatterns - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - TestVectorTransferDropUnitDimsPatterns) - - StringRef getArgument() const final { - return "test-vector-transfer-drop-unit-dims-patterns"; - } - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateVectorTransferDropUnitDimsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; - struct TestFlattenVectorTransferPatterns : public PassWrapper> { @@ -923,32 +661,20 @@ void registerTestVectorLowerings() { PassRegistration(); - PassRegistration(); - PassRegistration(); - PassRegistration(); - PassRegistration(); PassRegistration(); - PassRegistration(); - PassRegistration(); PassRegistration(); - PassRegistration(); - - PassRegistration(); - PassRegistration(); PassRegistration(); - PassRegistration(); - PassRegistration(); PassRegistration();