diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -16,56 +16,528 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" -def LowerVectorsOp : Op, - DeclareOpInterfaceMethods]> { + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def ApplyFoldingPatternsOp : + Op { + let description = [{ + Apply opt-in vector folding patterns that include: + - CreateMaskFolder + - MaskedLoadFolder + - MaskedStoreFolder + - GatherFolder + - ScatterFolder + - ExpandLoadFolder + - CompressStoreFolder + - StridedSliceConstantMaskFolder + - TransposeFolder (folds consecutive vector.transpose ops) + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. 
+def ApplyRankReducingShapeCastPatternsOp : + Op { + let description = [{ + Apply opt-in rank-reducing vector.shape_cast patterns that include: + - CastAwayExtractStridedSliceLeadingOneDim + - CastAwayInsertStridedSliceLeadingOneDim + - CastAwayInsertLeadingOneDim + - CastAwayTransferReadLeadingOneDim + - CastAwayTransferWriteLeadingOneDim + - CastAwayElementwiseLeadingOneDim + - CastAwayContractionLeadingOneDim + + These patterns are defined in VectorDropLeadUnitDim.cpp and have the effect + of rewriting a vector.transfer with a leading 1 dimension into a rank-reduced + version thanks to vector.broadcast and vector.shape_cast operations. + + This is complemented by vector.shape_cast folding patterns to try to + eliminate such operations. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def ApplyRankReducingSubviewPatternsOp : + Op { + let description = [{ + Apply opt-in rank-reducing subview patterns that include: + - TransferReadDropUnitDimsPattern + - TransferWriteDropUnitDimsPattern + + These patterns have the effect of rewriting a vector.transfer with unit + dimensions into a rank-reduced version thanks to subview operations. + This is complemented by shape_cast folding patterns. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def ApplyTransferPermutationPatternsOp : + Op { + let description = [{ + Apply opt-in vector transfer permutation patterns that include: + - TransferReadPermutationLowering + - TransferWritePermutationLowering + - TransferOpReduceRank + - TransferWriteNonPermutationLowering + + These patterns have the effect of rewriting a vector.transfer with an + arbitrary permutation_map to a vector.transfer with a permutation_map that is + a minor identity followed by a vector.transpose. + + In other words, this makes the vector.transfer contiguous on the most minor + dimensions and materializes the permutation_map as a vector.transpose. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. 
+def LowerBroadcastOp : Op { let description = [{ - Indicates that the vector operations nested under the isolated from above op - `target` should be lowered to finer-grained vector primitives. + Indicates that the vector broadcast operations nested under the isolated + from above op `target` should be lowered to finer-grained vector primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; - At this time, the transform is all or nothing. + let arguments = (ins TransformHandleTypeInterface:$target + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. +def LowerContractionOp : Op { + let description = [{ + Indicates that the vector contraction-like operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. This is usally a late step that is run after bufferization as part of the process of lowering to e.g. LLVM or NVVM. }]; - // TODO: evolve this to proper enums. - let arguments = (ins PDL_Operation:$target, + let arguments = (ins TransformHandleTypeInterface:$target, DefaultValuedAttr:$contraction_lowering, + "vector::VectorContractLowering::OuterProduct">:$lowering_strategy + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + (`lowering_strategy` `=` $lowering_strategy^)? 
+ attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. +def LowerMaskOp : Op { + let description = [{ + Indicates that the vector mask operations nested under the isolated from + above op `target` should be lowered to finer-grained vector primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. +def LowerMultiReductionOp : Op { + let description = [{ + Indicates that the vector multi_reduction-like operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target, DefaultValuedAttr: - $multireduction_lowering, + $lowering_strategy + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + (`lowering_strategy` `=` $lowering_strategy^)? + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve lowering_strategy to proper enums. +def LowerOuterProductOp : Op { + let description = [{ + Indicates that the vector outerproduct operations nested under the isolated + from above op `target` should be lowered to finer-grained vector primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +// TODO: evolve split_transfer_strategy to proper enums. +def SplitTransferFullPartialOp : Op { + let description = [{ + Indicates that the vector transfer operations nested under the + isolated from above op `target` should be split to full and partial parts. 
+ + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, DefaultValuedAttr:$split_transfers, + "vector::VectorTransferSplit::LinalgCopy">:$split_transfer_strategy + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + (`split_transfer_strategy` `=` $split_transfer_strategy^)? + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def LowerTransferOp : Op { + let description = [{ + Indicates that the vector transfer operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$max_transfer_rank + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + (`max_transfer_rank` `=` $max_transfer_rank^)? + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. 
+def TransferToScfOp : Op { + let description = [{ + Indicates that the vector transfer operations nested under the + isolated from above op `target` should be rewritten with scf.for loops over + finer-grained vector primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$max_transfer_rank, + DefaultValuedAttr:$full_unroll + ); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + oilist ( + `max_transfer_rank` `=` $max_transfer_rank + | `full_unroll` `=` $full_unroll + ) + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. +def LowerShapeCastOp : Op { + let description = [{ + Indicates that the vector shape_cast operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$results); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +// TODO: not isolated from above and better targetability of the op. +// TODO: not a functional-style transform. 
+// TODO: evolve lowering_strategy to proper enums. +def LowerTransposeOp : Op { + let description = [{ + Indicates that the vector transpose-like operations nested under the + isolated from above op `target` should be lowered to finer-grained vector + primitives. + + This is usually a late step that is run after bufferization as part of the + process of lowering to e.g. LLVM or NVVM. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, DefaultValuedAttr:$transpose_lowering, - DefaultValuedAttr:$transpose_avx2_lowering, - DefaultValuedAttr:$unroll_vector_transfers + "vector::VectorTransposeLowering::EltWise">:$lowering_strategy, + DefaultValuedAttr:$avx2_lowering_strategy ); - let results = (outs PDL_Operation:$results); - - let builders = [ - OpBuilder<(ins "Type":$resultType, "Value":$target, - "const vector::LowerVectorsOptions &":$options), [{ - return build($_builder, $_state, resultType, target, - options.vectorContractLowering, - options.vectorMultiReductionLowering, options.vectorTransferSplit, - options.vectorTransposeLowering, options.transposeAVX2Lowering, - options.unrollVectorTransfers); - }] - > - ]; + let results = (outs TransformHandleTypeInterface:$results); let assemblyFormat = [{ $target oilist ( - `contraction_lowering` `=` $contraction_lowering - | `multireduction_lowering` `=` $multireduction_lowering - | `split_transfers` `=` $split_transfers - | `transpose_lowering` `=` $transpose_lowering + `lowering_strategy` `=` $lowering_strategy + | `avx2_lowering_strategy` `=` $avx2_lowering_strategy ) attr-dict + `:` functional-type($target, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); }]; } diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h --- 
a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -50,6 +50,14 @@ RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit = 1, bool disableOuterProductLowering = false); +/// Populate the pattern set with the following patterns: +/// +/// [OuterProductOpLowering] +/// Progressively lower a `vector.outerproduct` to linearized +/// `vector.extract` + `vector.fma` + `vector.insert`. +void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Collect a set of patterns to convert vector.multi_reduction op into /// a sequence of vector.reduction ops. The patterns comprise: /// @@ -60,9 +68,9 @@ /// /// [ReduceMultiDimReductionRank] /// Once in innermost or outermost reduction -/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction, -/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand -/// back. +/// form, rewrites n-D vector.multi_reduction into 2-D +/// vector.multi_reduction, by introducing vector.shape_cast ops to collapse +/// + multi-reduce + expand back. /// /// [TwoDimMultiReductionToElementWise] /// Once in 2-D vector.multi_reduction form, with an **outermost** reduction @@ -72,14 +80,15 @@ /// [TwoDimMultiReductionToReduction] /// Once in 2-D vector.multi_reduction form, with an **innermost** reduction /// dimension, unroll the outer dimension to obtain a sequence of extract + -/// vector.reduction + insert. This can further lower to horizontal reduction -/// ops. +/// vector.reduction + insert. This can further lower to horizontal +/// reduction ops. 
/// /// [OneDimMultiReductionToTwoDim] /// For cases that reduce to 1-D vector reduction (and are thus missing -/// either a parallel or a reduction), we lift them back up to 2-D with a simple -/// vector.shape_cast to vector<1xk> so that the other patterns can kick in, -/// thus fully exiting out of the vector.multi_reduction abstraction. +/// either a parallel or a reduction), we lift them back up to 2-D with a +/// simple vector.shape_cast to vector<1xk> so that the other patterns can +/// kick in, thus fully exiting out of the vector.multi_reduction +/// abstraction. void populateVectorMultiReductionLoweringPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit = 1); diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -25,101 +25,409 @@ using namespace mlir::transform; //===----------------------------------------------------------------------===// -// LowerVectorsOp +// LowerBroadcastOp //===----------------------------------------------------------------------===// -void transform::LowerVectorsOp::getEffects( - SmallVectorImpl &effects) { - consumesHandle(getTarget(), effects); - producesHandle(getResults(), effects); - modifiesPayload(effects); +DiagnosedSilenceableFailure transform::LowerBroadcastOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + populateVectorBroadcastLoweringPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerContractionOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerContractionOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); + populateVectorContractLoweringPatterns(patterns, vectorTransformOptions, + /*benefit=*/1, + /*disableOuterProductLowering=*/true); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerMaskOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerMaskOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + populateVectorMaskOpLoweringPatterns(patterns); + populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerMultiReductionOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerMultiReductionOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); + vector::populateVectorMultiReductionLoweringPatterns( + patterns, vectorTransformOptions.vectorMultiReductionLowering); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerOuterProductOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerOuterProductOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + populateVectorOuterProductLoweringPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// SplitTransferFullPartialOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::SplitTransferFullPartialOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); + populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerTransferOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerTransferOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorTransferLoweringPatterns(patterns, + getMaxTransferRank()); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// TransferToScfOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::TransferToScfOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + VectorTransferToSCFOptions vectorTransferToSCFOptions = + VectorTransferToSCFOptions() + .enableFullUnroll(getFullUnroll()) + .setTargetRank(getMaxTransferRank()); + populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure transform::LowerVectorsOp::apply( - mlir::transform::TransformResults &transformResults, - mlir::transform::TransformState &state) { - - SmallVector results; - ArrayRef payloadOps = state.getPayloadOps(getTarget()); - for (Operation *target : payloadOps) { - // This check can't be part of the verifier because payload IR is - // independent from transform IR and may not even exist. 
- if (!target->hasTrait()) { - return mlir::emitDefiniteFailure(target, - "applies only to isolated-from-above " - "targets because it needs to apply " - "patterns greedily"); - } - - MLIRContext *ctx = getContext(); - RewritePatternSet patterns(ctx); - vector::VectorTransposeLowering vectorTransposeLowering = - getTransposeLowering(); - vector::VectorMultiReductionLowering vectorMultiReductionLowering = - getMultireductionLowering(); - vector::VectorContractLowering vectorContractLowering = - getContractionLowering(); - vector::VectorTransferSplit vectorTransferSplit = getSplitTransfers(); - - vector::VectorTransformsOptions vectorTransformOptions; - vectorTransformOptions.setVectorTransformsOptions(vectorContractLowering) - .setVectorMultiReductionLowering(vectorMultiReductionLowering) - .setVectorTransposeLowering(vectorTransposeLowering) - .setVectorTransferSplit(vectorTransferSplit); - - VectorTransferToSCFOptions vectorTransferToSCFOptions = - VectorTransferToSCFOptions().enableFullUnroll( - getUnrollVectorTransfers()); - - int maxTransferRank = 1; +//===----------------------------------------------------------------------===// +// LowerShapeCastOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerShapeCastOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + RewritePatternSet patterns(getContext()); + vector::populateVectorShapeCastLoweringPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// LowerTransposeOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LowerTransposeOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorTransposeLoweringPatterns( + patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( + getLoweringStrategy())); + if (getAvx2LoweringStrategy()) { auto avx2LoweringOptions = x86vector::avx2::LoweringOptions().setTransposeOptions( x86vector::avx2::TransposeLoweringOptions() - .lower4x8xf32(getTransposeAvx2Lowering()) - .lower8x8xf32(getTransposeAvx2Lowering())); + .lower4x8xf32(true) + .lower8x8xf32(true)); + x86vector::avx2::populateSpecializedTransposeLoweringPatterns( + patterns, avx2LoweringOptions, /*benefit=*/10); + } - vector::populateVectorToVectorCanonicalizationPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); - // In the future we may want to more finely select particular stages. - // Stage 1: contraction lowerings. - populateVectorContractLoweringPatterns( - patterns, vectorTransformOptions, /*benefit=*/1, - /*disableOuterProductLowering*/ true); - vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// ApplyFoldingPatternsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::ApplyFoldingPatternsOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. 
+ if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorToVectorCanonicalizationPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// ApplyTransferPermutationPatternsOp +//===----------------------------------------------------------------------===// - // Stage 2: multi-reduction lowerings. - vector::populateVectorMultiReductionLoweringPatterns( - patterns, vectorTransformOptions.vectorMultiReductionLowering); +DiagnosedSilenceableFailure +transform::ApplyTransferPermutationPatternsOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. + if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } + + RewritePatternSet patterns(getContext()); + vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); - // Stage 3: Rewrite vector.transfer into full and partial parts. - populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); - // Stage 4: Lower vector transfers. 
- vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank); + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} - // Stage 5: Vector to scf patterns. - populateVectorToSCFConversionPatterns( - patterns, vectorTransferToSCFOptions.setTargetRank(maxTransferRank)); +//===----------------------------------------------------------------------===// +// ApplyRankReducingShapeCastPatternsOp +//===----------------------------------------------------------------------===// - // Stage 6: Lower vector.shape_cast. - vector::populateVectorShapeCastLoweringPatterns(patterns); +DiagnosedSilenceableFailure +transform::ApplyRankReducingShapeCastPatternsOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. + if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); + } - // Stage 7: Lower vector.transpose. - vector::populateVectorTransposeLoweringPatterns( - patterns, vectorTransformOptions, /*benefit=*/1); - if (getTransposeAvx2Lowering()) - x86vector::avx2::populateSpecializedTransposeLoweringPatterns( - patterns, avx2LoweringOptions, /*benefit=*/10); + RewritePatternSet patterns(getContext()); + vector::populateCastAwayVectorLeadingOneDimPatterns(patterns); - // Apply everything. 
- if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) - return DiagnosedSilenceableFailure::definiteFailure(); + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); - results.push_back(target); + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// ApplyRankReducingSubviewPatternsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::ApplyRankReducingSubviewPatternsOp::applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state) { + // This check can't be part of the verifier because payload IR is + // independent from transform IR and may not even exist. + if (!target->hasTrait()) { + return mlir::emitDefiniteFailure(target, + "applies only to isolated-from-above " + "targets because it needs to apply " + "patterns greedily"); } - transformResults.set(getResults().cast(), results); + RewritePatternSet patterns(getContext()); + vector::populateVectorTransferDropUnitDimsPatterns(patterns); + + if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.push_back(target); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -1327,3 +1327,8 @@ ContractionOpToOuterProductOpLowering>( options, patterns.getContext(), benefit); } + +void mlir::vector::populateVectorOuterProductLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + 
patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -379,15 +379,15 @@ // We let the 0-d corner case pass-through as it is supported. if (!read.getPermutationMap().isMinorIdentityWithBroadcasting( &broadcastedDims)) - return failure(); + return rewriter.notifyMatchFailure(read, "not minor identity + bcast"); auto memRefType = read.getShapedType().dyn_cast(); if (!memRefType) - return failure(); + return rewriter.notifyMatchFailure(read, "not a memref source"); // Non-unit strides are handled by VectorToSCF. if (!vector::isLastMemrefDimUnitStride(memRefType)) - return failure(); + return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF"); // If there is broadcasting involved then we first load the unbroadcasted // vector, and then broadcast it with `vector.broadcast`. @@ -403,16 +403,16 @@ // resulting vector type is the same as the element type. auto memrefElTy = memRefType.getElementType(); if (memrefElTy.isa() && memrefElTy != unbroadcastedVectorType) - return failure(); + return rewriter.notifyMatchFailure(read, "incompatible element type"); // Otherwise, element types of the memref and the vector must match. if (!memrefElTy.isa() && memrefElTy != read.getVectorType().getElementType()) - return failure(); + return rewriter.notifyMatchFailure(read, "non-matching element type"); // Out-of-bounds dims are handled by MaterializeTransferMask. if (read.hasOutOfBoundsDim()) - return failure(); + return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask"); // Create vector load op. 
Operation *loadOp; @@ -458,7 +458,8 @@ PatternRewriter &rewriter) const override { auto vecType = loadOp.getVectorType(); if (vecType.getNumElements() != 1) - return failure(); + return rewriter.notifyMatchFailure(loadOp, "not a single element vector"); + auto memrefLoad = rewriter.create( loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices()); rewriter.replaceOpWithNewOp(loadOp, vecType, @@ -476,7 +477,8 @@ PatternRewriter &rewriter) const override { auto vecType = storeOp.getVectorType(); if (vecType.getNumElements() != 1) - return failure(); + return rewriter.notifyMatchFailure(storeOp, "not single element vector"); + Value extracted; if (vecType.getRank() == 0) { // TODO: Unifiy once ExtractOp supports 0-d vectors. diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -20,6 +20,39 @@ transform.structured.vectorize %2 transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op {bufferize_function_boundaries = true} - %func = transform.structured.match ops{["func.func"]} in %module_op : (!pdl.operation) -> !pdl.operation - transform.vector.lower_vectors %func multireduction_lowering = "innerreduction" + + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + // TODO: group these lower-level controls into various properly named vector + // lowering TD macros. 
+ %func = transform.vector.lower_contraction %f + lowering_strategy = "outerproduct" + : (!pdl.operation) -> !pdl.operation + + %func_2 = transform.vector.apply_transfer_permutation_patterns %func + : (!pdl.operation) -> !pdl.operation + + %func_3 = transform.vector.lower_multi_reduction %func_2 + lowering_strategy = "innerparallel" + : (!pdl.operation) -> !pdl.operation + + %func_4 = transform.vector.split_transfer_full_partial %func_3 + split_transfer_strategy = "linalg-copy" + : (!pdl.operation) -> !pdl.operation + + %func_5 = transform.vector.transfer_to_scf %func_4 + max_transfer_rank = 1 full_unroll = true + : (!pdl.operation) -> !pdl.operation + + %func_6 = transform.vector.lower_transfer %func_5 + max_transfer_rank = 1 + : (!pdl.operation) -> !pdl.operation + + %func_7 = transform.vector.lower_shape_cast %func_6 + : (!pdl.operation) -> !pdl.operation + + %func_8 = transform.vector.lower_transpose %func_7 + lowering_strategy = "shuffle" + : (!pdl.operation) -> !pdl.operation } diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -16,11 +16,45 @@ transform.sequence failures(propagate) { ^bb1(%module_op: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!pdl.operation) -> !pdl.operation - %1, %loops:3 = transform.structured.tile %0 [8, 4, 2] : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) + %1, %loops:3 = transform.structured.tile %0 [8, 4, 2] + : (!pdl.operation) -> (!pdl.operation, !pdl.operation, !pdl.operation, !pdl.operation) %2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation transform.structured.vectorize %2 - transform.bufferization.one_shot_bufferize %module_op + transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op + {bufferize_function_boundaries = true, 
allow_return_allocs = true} - %func = transform.structured.match ops{["func.func"]} in %module_op : (!pdl.operation) -> !pdl.operation - transform.vector.lower_vectors %func multireduction_lowering = "innerreduction" + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + // TODO: group these lower-level controls into various properly named vector + // lowering TD macros. + %func = transform.vector.lower_contraction %f + lowering_strategy = "outerproduct" + : (!pdl.operation) -> !pdl.operation + + %func_2 = transform.vector.apply_transfer_permutation_patterns %func + : (!pdl.operation) -> !pdl.operation + + %func_3 = transform.vector.lower_multi_reduction %func_2 + lowering_strategy = "innerparallel" + : (!pdl.operation) -> !pdl.operation + + %func_4 = transform.vector.split_transfer_full_partial %func_3 + split_transfer_strategy = "linalg-copy" + : (!pdl.operation) -> !pdl.operation + + %func_5 = transform.vector.transfer_to_scf %func_4 + max_transfer_rank = 1 full_unroll = true + : (!pdl.operation) -> !pdl.operation + + %func_6 = transform.vector.lower_transfer %func_5 + max_transfer_rank = 1 + : (!pdl.operation) -> !pdl.operation + + %func_7 = transform.vector.lower_shape_cast %func_6 + : (!pdl.operation) -> !pdl.operation + + %func_8 = transform.vector.lower_transpose %func_7 + lowering_strategy = "shuffle" + : (!pdl.operation) -> !pdl.operation } diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir @@ -0,0 +1,172 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func @broadcast_vec1d_from_scalar +// CHECK-SAME: %[[A:.*0]]: f32 +// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32> +// CHECK: return %[[T0]] : vector<2xf32> + +func.func 
@broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func @broadcast_vec2d_from_scalar +// CHECK-SAME: %[[A:.*0]]: f32 +// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32> +// CHECK: return %[[T0]] : vector<2x3xf32> + +func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +// CHECK-LABEL: func @broadcast_vec3d_from_scalar +// CHECK-SAME: %[[A:.*0]]: f32 +// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32> +// CHECK: return %[[T0]] : vector<2x3x4xf32> + +func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32> + return %0 : vector<2x3x4xf32> +} + +// CHECK-LABEL: func @broadcast_vec1d_from_vec1d +// CHECK-SAME: %[[A:.*0]]: vector<2xf32> +// CHECK: return %[[A]] : vector<2xf32> + +func.func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func @broadcast_vec2d_from_vec1d +// CHECK-SAME: %[[A:.*0]]: vector<2xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: return %[[T2]] : vector<3x2xf32> + +func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32> + return %0 : vector<3x2xf32> +} + +// CHECK-LABEL: func @broadcast_vec3d_from_vec1d +// CHECK-SAME: %[[A:.*0]]: vector<2xf32> +// CHECK: %[[C0:.*]] = arith.constant 
dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> +// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C1]] [0] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T4:.*]] = vector.insert %[[T2]], %[[T3]] [1] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T2]], %[[T4]] [2] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T6:.*]] = vector.insert %[[T2]], %[[T5]] [3] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: return %[[T6]] : vector<4x3x2xf32> + +func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} + +// CHECK-LABEL: func @broadcast_vec3d_from_vec2d +// CHECK-SAME: %[[A:.*0]]: vector<3x2xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> +// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[A]], %[[T2]] [3] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: return %[[T3]] : vector<4x3x2xf32> + +func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} + +// CHECK-LABEL: func @broadcast_stretch +// CHECK-SAME: %[[A:.*0]]: vector<1xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1xf32> +// 
CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32> +// CHECK: return %[[T1]] : vector<4xf32> + +func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { + %0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @broadcast_stretch_at_start +// CHECK-SAME: %[[A:.*0]]: vector<1x4xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1x4xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C0]] [0] : vector<4xf32> into vector<3x4xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[T1]] [1] : vector<4xf32> into vector<3x4xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [2] : vector<4xf32> into vector<3x4xf32> +// CHECK: return %[[T3]] : vector<3x4xf32> + +func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> { + %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32> + return %0 : vector<3x4xf32> +} + +// CHECK-LABEL: func @broadcast_stretch_at_end +// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32> +// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32> +// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32> +// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32> +// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32> +// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : 
vector<3xf32> +// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> +// CHECK: return %[[T15]] : vector<4x3xf32> + +func.func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { + %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32> + return %0 : vector<4x3xf32> +} + +// CHECK-LABEL: func @broadcast_stretch_in_middle +// CHECK-SAME: %[[A:.*0]]: vector<4x1x2xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> +// CHECK: %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1x2xf32> +// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T4:.*]] = vector.insert %[[T0]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<4x1x2xf32> +// CHECK: %[[T8:.*]] = vector.insert %[[T6]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T6]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T10:.*]] = vector.insert %[[T6]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: %[[T12:.*]] = vector.extract %[[A]][2, 0] : vector<4x1x2xf32> +// CHECK: %[[T14:.*]] = vector.insert %[[T12]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T15:.*]] = vector.insert %[[T12]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T16:.*]] = vector.insert %[[T12]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T11]] [2] : vector<3x2xf32> into vector<4x3x2xf32> +// 
CHECK: %[[T18:.*]] = vector.extract %[[A]][3, 0] : vector<4x1x2xf32> +// CHECK: %[[T20:.*]] = vector.insert %[[T18]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T21:.*]] = vector.insert %[[T18]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T22:.*]] = vector.insert %[[T18]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32> +// CHECK: %[[T23:.*]] = vector.insert %[[T22]], %[[T17]] [3] : vector<3x2xf32> into vector<4x3x2xf32> +// CHECK: return %[[T23]] : vector<4x3x2xf32> + +func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> { + %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> + return %0 : vector<4x3x2xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + transform.vector.lower_broadcast %f + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s #matvec_accesses = [ affine_map<(i, j) -> (i, j)>, @@ -207,3 +207,10 @@ memref.store %0, %arg2[] : memref> return } + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_contraction %module_op + lowering_strategy = "outerproduct" + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir new file mode 100644 --- /dev/null +++ 
b/mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir @@ -0,0 +1,306 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +#dotp_accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()> +] +#dotp_trait = { + indexing_maps = #dotp_accesses, + iterator_types = ["reduction"] +} + +// CHECK-LABEL: func @extract_contract1 +// CHECK-SAME: %[[A:.*0]]: vector<4xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<4xf32>, +// CHECK-SAME: %[[C:.*2]]: f32 +// CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32> +// CHECK: %[[R:.*]] = vector.reduction , %[[F]], %[[C]] : vector<4xf32> into f32 +// CHECK: return %[[R]] : f32 + +func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 { + %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 + : vector<4xf32>, vector<4xf32> into f32 + return %0 : f32 +} + +// CHECK-LABEL: func @masked_extract_contract1 +// CHECK-SAME: %[[A:.*0]]: vector<4xf32>, %[[B:.*1]]: vector<4xf32>, %[[C:.*2]]: f32 +// CHECK-SAME: %[[M:.*]]: vector<4xi1> +// CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32> +// CHECK: %[[R:.*]] = vector.mask %[[M]] { vector.reduction , %0, %arg2 : vector<4xf32> into f32 } : vector<4xi1> -> f32 +// CHECK: return %[[R]] : f32 + +func.func @masked_extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32, %mask: vector<4xi1>) -> f32 { + %0 = vector.mask %mask { vector.contract #dotp_trait %arg0, %arg1, %arg2 : vector<4xf32>, vector<4xf32> into f32 } : vector<4xi1> -> f32 + return %0 : f32 +} + +// CHECK-LABEL: func @extract_contract1_int +// CHECK-SAME: %[[A:.*0]]: vector<4xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<4xi32>, +// CHECK-SAME: %[[C:.*2]]: i32 +// CHECK: %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32> +// CHECK: %[[R:.*]] = vector.reduction , %[[F]], %[[C]] : vector<4xi32> into i32 +// CHECK: return %[[R]] : i32 + +func.func @extract_contract1_int(%arg0: vector<4xi32>, 
%arg1: vector<4xi32>, %arg2: i32) -> i32 { + %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 + : vector<4xi32>, vector<4xi32> into i32 + return %0 : i32 +} + +#matvec_accesses = [ + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (j)>, + affine_map<(i, j) -> (i)> +] +#matvec_trait = { + indexing_maps = #matvec_accesses, + iterator_types = ["parallel", "reduction"] +} + +// CHECK-LABEL: func @extract_contract2 +// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2xf32> +// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> +// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]] : vector<3xf32> into f32 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> +// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xf32> into f32 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> +// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32> +// CHECK: return %[[T10]] : vector<2xf32> + +func.func @extract_contract2(%arg0: vector<2x3xf32>, + %arg1: vector<3xf32>, + %arg2: vector<2xf32>) -> vector<2xf32> { + %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xf32>, vector<3xf32> into vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func @extract_contract2_int +// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, +// CHECK-SAME: %[[C:.*2]]: vector<2xi32> +// CHECK: %[[R:.*]] = arith.constant dense<0> : vector<2xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32> +// CHECK: %[[T2:.*]] = arith.muli %[[T0]], %[[B]] : vector<3xi32> +// CHECK: %[[T3:.*]] = 
vector.reduction , %[[T2]] : vector<3xi32> into i32 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32> +// CHECK: %[[T7:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> +// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xi32> into i32 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32> +// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32> +// CHECK: return %[[T10]] : vector<2xi32> +func.func @extract_contract2_int(%arg0: vector<2x3xi32>, + %arg1: vector<3xi32>, + %arg2: vector<2xi32>) -> vector<2xi32> { + %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xi32>, vector<3xi32> into vector<2xi32> + return %0 : vector<2xi32> +} + +#vecmat_accesses = [ + affine_map<(i, j) -> (j)>, + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i)> +] +#vecmat_trait = { + indexing_maps = #vecmat_accesses, + iterator_types = ["parallel", "reduction"] +} + +// CHECK-LABEL: func @extract_contract3 +// CHECK-SAME: %[[A:.*0]]: vector<3xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2xf32> +// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> +// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[A]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]] : vector<3xf32> into f32 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> +// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[A]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xf32> into f32 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> +// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32> +// CHECK: return %[[T10]] : vector<2xf32> + +func.func 
@extract_contract3(%arg0: vector<3xf32>, + %arg1: vector<2x3xf32>, + %arg2: vector<2xf32>) -> vector<2xf32> { + %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2 + : vector<3xf32>, vector<2x3xf32> into vector<2xf32> + return %0 : vector<2xf32> +} + +#matmat_accesses = [ + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @extract_contract4 +// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32> +// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> +// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> +// CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32> +// CHECK: %[[T10:.*]] = vector.reduction <add>, %[[T9]] : vector<2xf32> into f32 +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32> +// +// CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32> +// CHECK: %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32> +// CHECK: %[[T20:.*]] = vector.reduction <add>, %[[T19]] : vector<2xf32> into f32 +// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32> +// +// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32> +// CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> +// CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32> +// CHECK: %[[T33:.*]] = vector.reduction <add>, %[[T32]] : vector<2xf32> into f32 +// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32> +// +// CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32> +// CHECK: 
%[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32> +// CHECK: %[[T42:.*]] = vector.reduction , %[[T41]] : vector<2xf32> into f32 +// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32> +// +// CHECK: %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32> +// CHECK: return %[[T52]] : vector<2x2xf32> + +func.func @extract_contract4(%arg0: vector<2x2xf32>, + %arg1: vector<2x2xf32>, + %arg2: vector<2x2xf32>) -> vector<2x2xf32> { + %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> + return %0 : vector<2x2xf32> +} + + +#contraction2d_accesses = [ + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> ()> +] +#contraction2d_trait = { + indexing_maps = #contraction2d_accesses, + iterator_types = ["reduction", "reduction"] +} + +// CHECK-LABEL: func @full_contract1 +// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, +// CHECK-SAME: %[[C:.*2]]: f32 +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> +// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> +// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]], %[[C]] : vector<3xf32> into f32 +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> +// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]], %[[T3]] : vector<3xf32> into f32 +// CHECK: return %[[T8]] : f32 + +func.func @full_contract1(%arg0: vector<2x3xf32>, + %arg1: vector<2x3xf32>, + %arg2: f32) -> f32 { + %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2 + : vector<2x3xf32>, vector<2x3xf32> into f32 + return %0 : f32 +} + +#contraction2d_trans_accesses = [ + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (j, i)>, + affine_map<(i, j) -> ()> 
+] +#contraction2d_trans_trait = { + indexing_maps = #contraction2d_trans_accesses, + iterator_types = ["reduction", "reduction"] +} + +// CHECK-LABEL: func @full_contract2 +// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>, +// CHECK-SAME: %[[C:.*2]]: f32 +// CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> +// CHECK: %[[T1:.*]] = vector.extract %[[B]][0, 0] : vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[B]][1, 0] : vector<3x2xf32> +// CHECK: %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32> +// CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32> +// CHECK: %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32> +// CHECK: %[[T11:.*]] = vector.reduction <add>, %[[T10]], %[[C]] : vector<3xf32> into f32 +// +// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> +// CHECK: %[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf32> +// CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32> +// CHECK: %[[T16:.*]] = vector.extract %[[B]][1, 1] : vector<3x2xf32> +// CHECK: %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32> +// CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32> +// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32> +// CHECK: %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32> +// CHECK: %[[T23:.*]] = vector.reduction <add>, %[[T22]], %[[T11]] : vector<3xf32> into f32 +// CHECK: return %[[T23]] : f32 + +func.func @full_contract2(%arg0: vector<2x3xf32>, + %arg1: vector<3x2xf32>, + %arg2: f32) -> f32 { + %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2 + : vector<2x3xf32>, vector<3x2xf32> into 
f32 + return %0 : f32 +} + +// CHECK-LABEL: @contract_one_sided_unit_reduction_dim +// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>) +// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32> +// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<1x2xi32> +// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2x2xi32> +// CHECK: %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32> +// CHECK: %[[R0:.+]] = vector.reduction , %[[M0]] : vector<2xi32> into i32 +// CHECK: %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32> +// CHECK: %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2x2xi32> +// CHECK: %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32> +// CHECK: %[[R1:.+]] = vector.reduction , %[[M1]] : vector<2xi32> into i32 +// CHECK: %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32> +// CHECK: %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32> +// CHECK: return %[[S]] : vector<2xi32> + +func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> { + %res = vector.contract { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d1)> + ], + iterator_types = ["reduction", "parallel", "reduction"], + kind = #vector.kind + } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32> + return %res : vector<2xi32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + %f2 = transform.vector.lower_contraction %f + lowering_strategy = "dot" + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir 
b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir @@ -0,0 +1,57 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +#matmat_accesses = [ + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[vcst:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32> +// CHECK: %[[vcst_0:.*]] = arith.constant dense<0.000000e+00> : vector<12xf32> +// CHECK: %[[vcst_1:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> +// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32> +// CHECK: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32> +// CHECK: %[[a2:.*]] = vector.extract %[[A]][1] : vector<2x4xf32> +// CHECK: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32> +// CHECK: %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32> +// CHECK: %[[b2:.*]] = vector.extract %[[B]][1] : vector<4x3xf32> +// CHECK: %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32> +// CHECK: %[[b4:.*]] = vector.extract %[[B]][2] : vector<4x3xf32> +// CHECK: %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into 
vector<12xf32> +// CHECK: %[[b6:.*]] = vector.extract %[[B]][3] : vector<4x3xf32> +// CHECK: %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32> +// CHECK: %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, vector<12xf32>) -> vector<6xf32> +// CHECK: %[[mm2:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> +// CHECK: %[[mm3:.*]] = vector.insert %[[mm2]], %[[vcst_1]] [0] : vector<3xf32> into vector<2x3xf32> +// CHECK: %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> +// CHECK: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32> +// CHECK: %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32> +func.func @matmul(%arg0: vector<2x4xf32>, + %arg1: vector<4x3xf32>, + %arg2: vector<2x3xf32>) -> vector<2x3xf32> { + %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + %f2 = transform.vector.lower_contraction %f + lowering_strategy = "matmulintrinsics" + : (!pdl.operation) -> !pdl.operation + + %f3 = transform.vector.lower_shape_cast %f2 + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir @@ -0,0 +1,353 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | 
FileCheck %s + +#matvec_accesses = [ + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (j)>, + affine_map<(i, j) -> (i)> +] +#matvec_trait = { + indexing_maps = #matvec_accesses, + iterator_types = ["parallel", "reduction"] +} + +#matmat_accesses = [ + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +#matmat_accesses_0 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmat_trait_0 = { + indexing_maps = #matmat_accesses_0, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func.func @masked_extract_contract2( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<3xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: vector<2xf32>, +// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32> +// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1> +// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1> +// CHECK: vector.mask %[[MASK0]] { vector.outerproduct + +// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1> +// CHECK: vector.mask %[[MASK1]] { vector.outerproduct + +// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1> +// CHECK: vector.mask %[[MASK2]] { vector.outerproduct + +func.func @masked_extract_contract2(%arg0: vector<2x3xf32>, + %arg1: vector<3xf32>, + %arg2: vector<2xf32>, + %m: vector<2x3xi1>) -> vector<2xf32> { + %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func.func @masked_extract_contract4( +// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>, +// 
CHECK-SAME: %[[VAL_2:.*]]: vector<3x7xf32>, +// CHECK-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> { +// CHECK: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1> +// CHECK: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<5x3x7xi1> +// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<5x3x7xi1> +// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<5x3x7xi1> +// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<5x3x7xi1> +// CHECK: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<5x3x7xi1> +// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> + +func.func @masked_extract_contract4(%arg0: vector<3x5xf32>, + %arg1: vector<5x7xf32>, + %arg2: vector<3x7xf32>, + %m : vector<3x7x5xi1>) -> vector<3x7xf32> { + %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32> + return %0 : vector<3x7xf32> +} + +// CHECK-LABEL: func @matmul +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, +// CHECK-SAME: 
%[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32> +// +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<4x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK-SAME: : vector<2xf32>, vector<3xf32> +// +// CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<4x2xf32> +// CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<4x3xf32> +// CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]] +// CHECK-SAME: : vector<2xf32>, vector<3xf32> +// +// CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<4x2xf32> +// CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<4x3xf32> +// CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]] +// CHECK-SAME: : vector<2xf32>, vector<3xf32> +// +// CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<4x2xf32> +// CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<4x3xf32> +// CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]] +// CHECK-SAME: : vector<2xf32>, vector<3xf32> +// +// CHECK: return %[[c3]] : vector<2x3xf32> +func.func @matmul(%arg0: vector<2x4xf32>, + %arg1: vector<4x3xf32>, + %arg2: vector<2x3xf32>) -> vector<2x3xf32> { + %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +// CHECK-LABEL: func @matmul_0 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<2x3xf32> +func.func 
@matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>) +-> vector<2x3xf32> +{ + %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +// CHECK-LABEL: func @matmul_0_mixed +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf16> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf16> +// CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32> +// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]] +// CHECK: return %[[c0]] : vector<2x3xf32> +func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: vector<2x3xf32>) +-> vector<2x3xf32> +{ + %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 + : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +#matmat_accesses_1 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmat_trait_1 = { + indexing_maps = #matmat_accesses_1, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_1 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<2x3xf32> 
+func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>) +-> vector<2x3xf32> +{ + %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +#matmat_accesses_2 = [ + affine_map<(m, n, k) -> (k, m)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmat_trait_2 = { + indexing_maps = #matmat_accesses_2, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_2 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<2x3xf32> +func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>) +-> vector<2x3xf32> +{ + %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2 + : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +#matmat_accesses_3 = [ + affine_map<(m, n, k) -> (k, m)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmat_trait_3 = { + indexing_maps = #matmat_accesses_3, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_3 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] +// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32> +// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// CHECK: return %[[c0]] : vector<2x3xf32> +func.func 
@matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>) +-> vector<2x3xf32> +{ + %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2 + : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32> + return %0 : vector<2x3xf32> +} + +#matmat_accesses_4 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (n, m)> +] +#matmat_trait_4 = { + indexing_maps = #matmat_accesses_4, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_4 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> +// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: return %[[c0]] : vector<3x2xf32> +func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) +-> vector<3x2xf32> +{ + %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> + return %0 : vector<3x2xf32> +} + +#matmat_accesses_5 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (n, m)> +] +#matmat_trait_5 = { + indexing_maps = #matmat_accesses_5, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_5 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> +// CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: 
return %[[c0]] : vector<3x2xf32> +func.func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) +-> vector<3x2xf32> +{ + %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> + return %0 : vector<3x2xf32> +} + +#matmat_accesses_6 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (n, m)> +] +#matmat_trait_6 = { + indexing_maps = #matmat_accesses_6, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_6 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> +// CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: return %[[c0]] : vector<3x2xf32> +func.func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) +-> vector<3x2xf32> +{ + %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> + return %0 : vector<3x2xf32> +} + +#matmat_accesses_7 = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (n, m)> +] +#matmat_trait_7 = { + indexing_maps = #matmat_accesses_7, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @matmul_7 +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, +// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> +// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// CHECK-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> +// CHECK-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> +// CHECK: %[[c0:.*]] = 
vector.outerproduct %[[b0]], %[[a0]], %[[C]] +// CHECK: return %[[c0]] : vector<3x2xf32> +func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) +-> vector<3x2xf32> +{ + %0 = vector.contract #matmat_trait_7 %arg0, %arg1, %arg2 + : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> + return %0 : vector<3x2xf32> +} + + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + %f2 = transform.vector.lower_contraction %f + lowering_strategy = "outerproduct" + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir @@ -0,0 +1,62 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func @parallel_contract_lowering +// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> +// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> +// CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32> +// CHECK: return %[[F]] : vector<4xf32> +func.func @parallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { + %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @parallel_contract_lowering_broadcast +// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to 
vector<4x1x1xf32> +// CHECK: %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32> +// CHECK: %[[E0:.*]] = vector.extract %[[T]][0, 0] : vector<1x1x4xf32> +// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> +// CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32> +// CHECK: return %[[F]] : vector<4xf32> +func.func @parallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { + %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func @parallel_contract_lowering_transpose +// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32> +// CHECK: %[[T0:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32> +// CHECK: %[[T1:.*]] = vector.transpose %{{.*}}, [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32> +// CHECK: %[[E0:.*]] = vector.extract %[[T0]][0, 0] : vector<1x1x4xf32> +// CHECK: %[[E1:.*]] = vector.extract %[[T1]][0, 0] : vector<1x1x4xf32> +// CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32> +// CHECK: return %[[F]] : vector<4xf32> +func.func @parallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { + %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: func 
@parallel_contract_lowering_scalar +// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32> +// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32> +// CHECK: %[[M:.*]] = arith.mulf %[[E0]], %[[E1]] : f32 +// CHECK: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32 +// CHECK: return %[[A]] : f32 +func.func @parallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vector<1x1xf32>, %arg2: f32) -> f32 { + %0 = vector.contract { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> ()>], + iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>} + %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32 + return %0 : f32 +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + %f2 = transform.vector.lower_contraction %f + lowering_strategy = "parallelarith" + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ /dev/null @@ -1,1235 +0,0 @@ -// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s -// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX -// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT -// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL - -#dotp_accesses = [ - affine_map<(i) -> (i)>, - affine_map<(i) -> (i)>, - affine_map<(i) -> ()> -] -#dotp_trait = { - indexing_maps = #dotp_accesses, - iterator_types = ["reduction"] -} - -// CHECK-LABEL: func @extract_contract1 -// CHECK-SAME: %[[A:.*0]]: 
vector<4xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<4xf32>, -// CHECK-SAME: %[[C:.*2]]: f32 -// CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32> -// CHECK: %[[R:.*]] = vector.reduction , %[[F]], %[[C]] : vector<4xf32> into f32 -// CHECK: return %[[R]] : f32 - -func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32) -> f32 { - %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 - : vector<4xf32>, vector<4xf32> into f32 - return %0 : f32 -} - -// CHECK-LABEL: func @masked_extract_contract1 -// CHECK-SAME: %[[A:.*0]]: vector<4xf32>, %[[B:.*1]]: vector<4xf32>, %[[C:.*2]]: f32 -// CHECK-SAME: %[[M:.*]]: vector<4xi1> -// CHECK: %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32> -// CHECK: %[[R:.*]] = vector.mask %[[M]] { vector.reduction , %0, %arg2 : vector<4xf32> into f32 } : vector<4xi1> -> f32 -// CHECK: return %[[R]] : f32 - -func.func @masked_extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32, %mask: vector<4xi1>) -> f32 { - %0 = vector.mask %mask { vector.contract #dotp_trait %arg0, %arg1, %arg2 : vector<4xf32>, vector<4xf32> into f32 } : vector<4xi1> -> f32 - return %0 : f32 -} - -// CHECK-LABEL: func @extract_contract1_int -// CHECK-SAME: %[[A:.*0]]: vector<4xi32>, -// CHECK-SAME: %[[B:.*1]]: vector<4xi32>, -// CHECK-SAME: %[[C:.*2]]: i32 -// CHECK: %[[F:.*]] = arith.muli %[[A]], %[[B]] : vector<4xi32> -// CHECK: %[[R:.*]] = vector.reduction , %[[F]], %[[C]] : vector<4xi32> into i32 -// CHECK: return %[[R]] : i32 - -func.func @extract_contract1_int(%arg0: vector<4xi32>, %arg1: vector<4xi32>, %arg2: i32) -> i32 { - %0 = vector.contract #dotp_trait %arg0, %arg1, %arg2 - : vector<4xi32>, vector<4xi32> into i32 - return %0 : i32 -} - -#matvec_accesses = [ - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (j)>, - affine_map<(i, j) -> (i)> -] -#matvec_trait = { - indexing_maps = #matvec_accesses, - iterator_types = ["parallel", "reduction"] -} - -// CHECK-LABEL: func @extract_contract2 -// 
CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, -// CHECK-SAME: %[[C:.*2]]: vector<2xf32> -// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> -// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[B]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]] : vector<3xf32> into f32 -// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> -// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> -// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xf32> into f32 -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> -// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32> -// CHECK: return %[[T10]] : vector<2xf32> - -func.func @extract_contract2(%arg0: vector<2x3xf32>, - %arg1: vector<3xf32>, - %arg2: vector<2xf32>) -> vector<2xf32> { - %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 - : vector<2x3xf32>, vector<3xf32> into vector<2xf32> - return %0 : vector<2xf32> -} - -// OUTERPRODUCT-LABEL: func.func @masked_extract_contract2( -// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<2x3xf32>, -// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<3xf32>, -// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<2xf32>, -// OUTERPRODUCT-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32> -// OUTERPRODUCT: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1> -// OUTERPRODUCT: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1> -// OUTERPRODUCT: vector.mask %[[MASK0]] { vector.outerproduct - -// OUTERPRODUCT: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1> -// OUTERPRODUCT: vector.mask %[[MASK1]] { vector.outerproduct - -// OUTERPRODUCT: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1> -// OUTERPRODUCT: vector.mask %[[MASK2]] { 
vector.outerproduct - -func.func @masked_extract_contract2(%arg0: vector<2x3xf32>, - %arg1: vector<3xf32>, - %arg2: vector<2xf32>, - %m: vector<2x3xi1>) -> vector<2xf32> { - %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2 - : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32> - return %0 : vector<2xf32> -} - -// CHECK-LABEL: func @extract_contract2_int -// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>, -// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, -// CHECK-SAME: %[[C:.*2]]: vector<2xi32> -// CHECK: %[[R:.*]] = arith.constant dense<0> : vector<2xi32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xi32> -// CHECK: %[[T2:.*]] = arith.muli %[[T0]], %[[B]] : vector<3xi32> -// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]] : vector<3xi32> into i32 -// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : i32 into vector<2xi32> -// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xi32> -// CHECK: %[[T7:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> -// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xi32> into i32 -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : i32 into vector<2xi32> -// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[C]] : vector<2xi32> -// CHECK: return %[[T10]] : vector<2xi32> -func.func @extract_contract2_int(%arg0: vector<2x3xi32>, - %arg1: vector<3xi32>, - %arg2: vector<2xi32>) -> vector<2xi32> { - %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 - : vector<2x3xi32>, vector<3xi32> into vector<2xi32> - return %0 : vector<2xi32> -} - -#vecmat_accesses = [ - affine_map<(i, j) -> (j)>, - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (i)> -] -#vecmat_trait = { - indexing_maps = #vecmat_accesses, - iterator_types = ["parallel", "reduction"] -} - -// CHECK-LABEL: func @extract_contract3 -// CHECK-SAME: %[[A:.*0]]: vector<3xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, -// CHECK-SAME: %[[C:.*2]]: vector<2xf32> -// CHECK: %[[R:.*]] = 
arith.constant dense<0.000000e+00> : vector<2xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> -// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[A]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]] : vector<3xf32> into f32 -// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> -// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[A]] : vector<3xf32> -// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]] : vector<3xf32> into f32 -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> -// CHECK: %[[T10:.*]] = arith.addf %[[T9]], %[[C]] : vector<2xf32> -// CHECK: return %[[T10]] : vector<2xf32> - -func.func @extract_contract3(%arg0: vector<3xf32>, - %arg1: vector<2x3xf32>, - %arg2: vector<2xf32>) -> vector<2xf32> { - %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2 - : vector<3xf32>, vector<2x3xf32> into vector<2xf32> - return %0 : vector<2xf32> -} - -#matmat_accesses = [ - affine_map<(i, j, k) -> (i, k)>, - affine_map<(i, j, k) -> (k, j)>, - affine_map<(i, j, k) -> (i, j)> -] -#matmat_trait = { - indexing_maps = #matmat_accesses, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// CHECK-LABEL: func @extract_contract4 -// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>, -// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32> -// CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> -// CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> -// CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32> -// CHECK: %[[T10:.*]] = vector.reduction , %[[T9]] : vector<2xf32> into f32 -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32> -// -// CHECK: 
%[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32> -// CHECK: %[[T19:.*]] = arith.mulf %[[T0]], %[[T12]] : vector<2xf32> -// CHECK: %[[T20:.*]] = vector.reduction , %[[T19]] : vector<2xf32> into f32 -// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32> -// -// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32> -// CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> -// CHECK: %[[T32:.*]] = arith.mulf %[[T23]], %[[T24]] : vector<2xf32> -// CHECK: %[[T33:.*]] = vector.reduction , %[[T32]] : vector<2xf32> into f32 -// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32> -// -// CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32> -// CHECK: %[[T41:.*]] = arith.mulf %[[T23]], %[[T40]] : vector<2xf32> -// CHECK: %[[T42:.*]] = vector.reduction , %[[T41]] : vector<2xf32> into f32 -// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 1] : f32 into vector<2x2xf32> -// -// CHECK: %[[T52:.*]] = arith.addf %[[T43]], %[[C]] : vector<2x2xf32> -// CHECK: return %[[T52]] : vector<2x2xf32> - -func.func @extract_contract4(%arg0: vector<2x2xf32>, - %arg1: vector<2x2xf32>, - %arg2: vector<2x2xf32>) -> vector<2x2xf32> { - %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 - : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - return %0 : vector<2x2xf32> -} - -// OUTERPRODUCT-LABEL: func.func @masked_extract_contract4( -// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<3x5xf32>, -// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<5x7xf32>, -// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<3x7xf32>, -// OUTERPRODUCT-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> { -// OUTERPRODUCT: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1> -// OUTERPRODUCT: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<5x3x7xi1> -// OUTERPRODUCT: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind} 
: vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> -// OUTERPRODUCT: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<5x3x7xi1> -// OUTERPRODUCT: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> -// OUTERPRODUCT: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<5x3x7xi1> -// OUTERPRODUCT: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> -// OUTERPRODUCT: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<5x3x7xi1> -// OUTERPRODUCT: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> -// OUTERPRODUCT: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<5x3x7xi1> -// OUTERPRODUCT: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> - -func.func @masked_extract_contract4(%arg0: vector<3x5xf32>, - %arg1: vector<5x7xf32>, - %arg2: vector<3x7xf32>, - %m : vector<3x7x5xi1>) -> vector<3x7xf32> { - %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2 - : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32> - return %0 : vector<3x7xf32> -} - -#contraction2d_accesses = [ - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> ()> -] -#contraction2d_trait = { - indexing_maps = #contraction2d_accesses, - iterator_types = ["reduction", "reduction"] -} - -// CHECK-LABEL: func @full_contract1 -// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, -// CHECK-SAME: %[[C:.*2]]: f32 -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> -// CHECK: %[[T1:.*]] = vector.extract 
%[[B]][0] : vector<2x3xf32> -// CHECK: %[[T2:.*]] = arith.mulf %[[T0]], %[[T1]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction , %[[T2]], %[[C]] : vector<3xf32> into f32 -// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> -// CHECK: %[[T7:.*]] = arith.mulf %[[T5]], %[[T6]] : vector<3xf32> -// CHECK: %[[T8:.*]] = vector.reduction , %[[T7]], %[[T3]] : vector<3xf32> into f32 -// CHECK: return %[[T8]] : f32 - -func.func @full_contract1(%arg0: vector<2x3xf32>, - %arg1: vector<2x3xf32>, - %arg2: f32) -> f32 { - %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2 - : vector<2x3xf32>, vector<2x3xf32> into f32 - return %0 : f32 -} - -#contraction2d_trans_accesses = [ - affine_map<(i, j) -> (i, j)>, - affine_map<(i, j) -> (j, i)>, - affine_map<(i, j) -> ()> -] -#contraction2d_trans_trait = { - indexing_maps = #contraction2d_trans_accesses, - iterator_types = ["reduction", "reduction"] -} - -// CHECK-LABEL: func @full_contract2 -// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>, -// CHECK-SAME: %[[C:.*2]]: f32 -// CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> -// CHECK: %[[T1:.*]] = vector.extract %[[B]][0, 0] : vector<3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T1]], %[[Z]] [0] : f32 into vector<3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[B]][1, 0] : vector<3x2xf32> -// CHECK: %[[T6:.*]] = vector.insert %[[T4]], %[[T3]] [1] : f32 into vector<3xf32> -// CHECK: %[[T7:.*]] = vector.extract %[[B]][2, 0] : vector<3x2xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T7]], %[[T6]] [2] : f32 into vector<3xf32> -// CHECK: %[[T10:.*]] = arith.mulf %[[T0]], %[[T9]] : vector<3xf32> -// CHECK: %[[T11:.*]] = vector.reduction , %[[T10]], %[[C]] : vector<3xf32> into f32 -// -// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> -// CHECK: 
%[[T13:.*]] = vector.extract %[[B]][0, 1] : vector<3x2xf -// CHECK: %[[T15:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<3xf32> -// CHECK: %[[T16:.*]] = vector.extract %[[B]][1, 1] : vector<3x2xf32> -// CHECK: %[[T18:.*]] = vector.insert %[[T16]], %[[T15]] [1] : f32 into vector<3xf32> -// CHECK: %[[T19:.*]] = vector.extract %[[B]][2, 1] : vector<3x2xf32> -// CHECK: %[[T21:.*]] = vector.insert %[[T19]], %[[T18]] [2] : f32 into vector<3xf32> -// CHECK: %[[T22:.*]] = arith.mulf %[[T12]], %[[T21]] : vector<3xf32> -// CHECK: %[[T23:.*]] = vector.reduction , %[[T22]], %[[T11]] : vector<3xf32> into f32 -// CHECK: return %[[T23]] : f32 - -func.func @full_contract2(%arg0: vector<2x3xf32>, - %arg1: vector<3x2xf32>, - %arg2: f32) -> f32 { - %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2 - : vector<2x3xf32>, vector<3x2xf32> into f32 - return %0 : f32 -} - -// CHECK-LABEL: func @outerproduct_noacc -// CHECK-SAME: %[[A:.*0]]: vector<2xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<3xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> -// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32> -// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32> -// CHECK: return %[[T7]] : vector<2x3xf32> - -func.func @outerproduct_noacc(%arg0: vector<2xf32>, - %arg1: vector<3xf32>) -> vector<2x3xf32> { - %0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> - return %0: vector<2x3xf32> -} - -// CHECK-LABEL: func @outerproduct_acc -// CHECK-SAME: %[[A:.*0]]: 
vector<2xf32>, -// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, -// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32> -// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> -// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32> -// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32> -// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> -// CHECK: return %[[T9]] : vector<2x3xf32> - -func.func @outerproduct_acc(%arg0: vector<2xf32>, - %arg1: vector<3xf32>, - %arg2: vector<2x3xf32>) -> vector<2x3xf32> { - %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32> - return %0: vector<2x3xf32> -} - -// CHECK-LABEL: func @outerproduct_noacc_int -// CHECK-SAME: %[[A:.*0]]: vector<2xi32>, -// CHECK-SAME: %[[B:.*1]]: vector<3xi32> -// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> -// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32> -// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32> -// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32> -// CHECK: return %[[T7]] : 
vector<2x3xi32> -func.func @outerproduct_noacc_int(%arg0: vector<2xi32>, - %arg1: vector<3xi32>) -> vector<2x3xi32> { - %0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32> - return %0: vector<2x3xi32> -} - -// CHECK-LABEL: func @outerproduct_acc_int -// CHECK-SAME: %[[A:.*0]]: vector<2xi32>, -// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, -// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> -// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> -// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32> -// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> -// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32> -// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32> -// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32> -// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32> -// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32> -// CHECK: return %[[T11]] : vector<2x3xi32> -func.func @outerproduct_acc_int(%arg0: vector<2xi32>, - %arg1: vector<3xi32>, - %arg2: vector<2x3xi32>) -> vector<2x3xi32> { - %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32> - return %0: vector<2x3xi32> -} - -// CHECK-LABEL: func @axpy_fp( -// CHECK-SAME: %[[A:.*0]]: vector<16xf32>, -// CHECK-SAME: %[[B:.*1]]: f32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> -// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32> -// CHECK: return %[[T1]] : vector<16xf32> -func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { - %0 = vector.outerproduct %arg0, %arg1: 
vector<16xf32>, f32 - return %0: vector<16xf32> -} - -// CHECK-LABEL: func @axpy_fp_add( -// CHECK-SAME: %[[A:.*0]]: vector<16xf32>, -// CHECK-SAME: %[[B:.*1]]: f32, -// CHECK-SAME: %[[C:.*2]]: vector<16xf32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> -// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32> -// CHECK: return %[[T1]] : vector<16xf32> -func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> { - %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32 - return %0: vector<16xf32> -} - -// CHECK-LABEL: func @axpy_int( -// CHECK-SAME: %[[A:.*0]]: vector<16xi32>, -// CHECK-SAME: %[[B:.*1]]: i32) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> -// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> -// CHECK: return %[[T1]] : vector<16xi32> -func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { - %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32 - return %0: vector<16xi32> -} - -// CHECK-LABEL: func @axpy_int_add( -// CHECK-SAME: %[[A:.*0]]: vector<16xi32>, -// CHECK-SAME: %[[B:.*1]]: i32, -// CHECK-SAME: %[[C:.*2]]: vector<16xi32>) -// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> -// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> -// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32> -// CHECK: return %[[T2]] : vector<16xi32> -func.func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> { - %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32 - return %0: vector<16xi32> -} - -// CHECK-LABEL: func @nop_shape_cast -// CHECK-SAME: %[[A:.*]]: vector<16xf32> -// CHECK: return %[[A]] : vector<16xf32> - -func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { - %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32> - return %0 : vector<16xf32> -} - -// CHECK-LABEL: func @cancel_shape_cast -// CHECK-SAME: 
%[[A:.*]]: vector<16xf32> -// CHECK: return %[[A]] : vector<16xf32> - -func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { - %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> - %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> - return %1 : vector<16xf32> -} - -// Shape up and downcasts for 2-D vectors, for supporting conversion to -// llvm.matrix operations -// CHECK-LABEL: func @shape_casts -func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) { - // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> - // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> - // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2x2xf32> - // - // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]] - // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> - // - // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x2xf32> - // - // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]] - // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> - // - %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32> - // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32> - %r0 = arith.addf %0, %0: vector<4xf32> - // - // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]] - // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : - // CHECK-SAME: vector<4xf32> to vector<2xf32> - // - // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] : - // CHECK-SAME: vector<2xf32> into vector<2x2xf32> - // - // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]] - // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} : - // CHECK-SAME: vector<4xf32> to vector<2xf32> - // - // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] : - // CHECK-SAME: vector<2xf32> into vector<2x2xf32> - // - %1 = vector.shape_cast %r0 : vector<4xf32> to 
vector<2x2xf32> - // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32> - return %r0, %1 : vector<4xf32>, vector<2x2xf32> -} - -// CHECK-LABEL: func @shape_cast_2d2d -// CHECK-SAME: %[[A:.*]]: vector<3x2xf32> -// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32> -// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32> -// CHECK: return %[[T11]] : vector<2x3xf32> - -func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> { - %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32> - return %s : vector<2x3xf32> -} - -// CHECK-LABEL: func @shape_cast_3d1d -// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32> -// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32> -// CHECK: %[[T4:.*]] = vector.extract 
%[[A]][0, 1, 0] : vector<1x3x2xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32> -// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32> -// CHECK: return %[[T11]] : vector<6xf32> - -func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> { - %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32> - return %s : vector<6xf32> -} - -// CHECK-LABEL: func @shape_cast_1d3d -// CHECK-SAME: %[[A:.*]]: vector<6xf32> -// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32> -// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32> -// CHECK: return %[[T11]] : vector<2x1x3xf32> - -func.func 
@shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> { - %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32> - return %s : vector<2x1x3xf32> -} - -// MATRIX-LABEL: func @matmul -// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, -// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, -// MATRIX-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// MATRIX: %[[vcst:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32> -// MATRIX: %[[vcst_0:.*]] = arith.constant dense<0.000000e+00> : vector<12xf32> -// MATRIX: %[[vcst_1:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> -// MATRIX: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32> -// MATRIX: %[[a1:.*]] = vector.insert_strided_slice %[[a0]], %[[vcst]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<8xf32> -// MATRIX: %[[a2:.*]] = vector.extract %[[A]][1] : vector<2x4xf32> -// MATRIX: %[[a3:.*]] = vector.insert_strided_slice %[[a2]], %[[a1]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32> -// MATRIX: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32> -// MATRIX: %[[b1:.*]] = vector.insert_strided_slice %[[b0]], %[[vcst_0]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<12xf32> -// MATRIX: %[[b2:.*]] = vector.extract %[[B]][1] : vector<4x3xf32> -// MATRIX: %[[b3:.*]] = vector.insert_strided_slice %[[b2]], %[[b1]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<12xf32> -// MATRIX: %[[b4:.*]] = vector.extract %[[B]][2] : vector<4x3xf32> -// MATRIX: %[[b5:.*]] = vector.insert_strided_slice %[[b4]], %[[b3]] {offsets = [6], strides = [1]} : vector<3xf32> into vector<12xf32> -// MATRIX: %[[b6:.*]] = vector.extract %[[B]][3] : vector<4x3xf32> -// MATRIX: %[[b7:.*]] = vector.insert_strided_slice %[[b6]], %[[b5]] {offsets = [9], strides = [1]} : vector<3xf32> into vector<12xf32> -// MATRIX: %[[mm1:.*]] = vector.matrix_multiply %[[a3]], %[[b7]] {lhs_columns = 4 : i32, lhs_rows = 2 : i32, rhs_columns = 3 : i32} : (vector<8xf32>, 
vector<12xf32>) -> vector<6xf32> -// MATRIX: %[[mm2:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> -// MATRIX: %[[mm3:.*]] = vector.insert %[[mm2]], %[[vcst_1]] [0] : vector<3xf32> into vector<2x3xf32> -// MATRIX: %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> -// MATRIX: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32> -// MATRIX: %[[mm6:.*]] = arith.addf %[[C]], %[[mm5]] : vector<2x3xf32> - -// OUTERPRODUCT-LABEL: func @matmul -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT-SAME: : vector<2x4xf32> to vector<4x2xf32> -// -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<4x2xf32> -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] -// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> -// -// OUTERPRODUCT: %[[a1:.*]] = vector.extract %[[At]][1] : vector<4x2xf32> -// OUTERPRODUCT: %[[b1:.*]] = vector.extract %[[B]][1] : vector<4x3xf32> -// OUTERPRODUCT: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]] -// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> -// -// OUTERPRODUCT: %[[a2:.*]] = vector.extract %[[At]][2] : vector<4x2xf32> -// OUTERPRODUCT: %[[b2:.*]] = vector.extract %[[B]][2] : vector<4x3xf32> -// OUTERPRODUCT: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]] -// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> -// -// OUTERPRODUCT: %[[a3:.*]] = vector.extract %[[At]][3] : vector<4x2xf32> -// OUTERPRODUCT: %[[b3:.*]] = vector.extract %[[B]][3] : vector<4x3xf32> -// OUTERPRODUCT: %[[c3:.*]] = vector.outerproduct %[[a3]], 
%[[b3]], %[[c2]] -// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> -// -// OUTERPRODUCT: return %[[c3]] : vector<2x3xf32> - -// REDUCE-LABEL: func @matmul -// REDUCE-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, -// REDUCE-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, -// REDUCE-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// -// REDUCE: %[[RES:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> -// REDUCE: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] -// REDUCE-SAME: : vector<4x3f32> to vector<3x4xf32> -// -// REDUCE: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32> -// REDUCE-NEXT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3x4xf32> -// REDUCE-NEXT: %[[ab00:.*]] = mul %[[a0]], %[[b0]] : vector<4xf32> -// REDUCE-NEXT: %[[s00:.*]] = vector.reduction , %[[ab00]] : vector<4xf32> into f32 -// REDUCE-NEXT: %[[r00:.*]] = vector.insert %[[s00]], %[[RES]] [0, 0] : f32 into vector<2x3xf32> -// -// ... -// -// REDUCE: %[[a1:.*]] = vector.extract %[[A]][1] : vector<2x4xf32> -// REDUCE-NEXT: %[[b2:.*]] = vector.extract %[[Bt]][2] : vector<3x4xf32> -// REDUCE-NEXT: %[[ab12:.*]] = mul %[[a1]], %[[b02]] : vector<4xf32> -// REDUCE-NEXT: %[[s12:.*]] = vector.reduction , %[[ab12]] : vector<4xf32> into f32 -// REDUCE-NEXT: %[[r12:.*]] = vector.insert %[[s12]], %{{.*}} [1, 2] : f32 into vector<2x3xf32> -// -// REDUCE: return %[[c3]] : vector<2x3xf32> -func.func @matmul(%arg0: vector<2x4xf32>, - %arg1: vector<4x3xf32>, - %arg2: vector<2x3xf32>) -> vector<2x3xf32> { - %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 - : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -// CHECK-LABEL: func @broadcast_vec1d_from_scalar -// CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2xf32> -// CHECK: return %[[T0]] : vector<2xf32> - -func.func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { - %0 = vector.broadcast %arg0 : f32 to vector<2xf32> - return %0 : vector<2xf32> -} - -// CHECK-LABEL: 
func @broadcast_vec2d_from_scalar -// CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3xf32> -// CHECK: return %[[T0]] : vector<2x3xf32> - -func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { - %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -// CHECK-LABEL: func @broadcast_vec3d_from_scalar -// CHECK-SAME: %[[A:.*0]]: f32 -// CHECK: %[[T0:.*]] = vector.splat %[[A]] : vector<2x3x4xf32> -// CHECK: return %[[T0]] : vector<2x3x4xf32> - -func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> { - %0 = vector.broadcast %arg0 : f32 to vector<2x3x4xf32> - return %0 : vector<2x3x4xf32> -} - -// CHECK-LABEL: func @broadcast_vec1d_from_vec1d -// CHECK-SAME: %[[A:.*0]]: vector<2xf32> -// CHECK: return %[[A]] : vector<2xf32> - -func.func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> { - %0 = vector.broadcast %arg0 : vector<2xf32> to vector<2xf32> - return %0 : vector<2xf32> -} - -// CHECK-LABEL: func @broadcast_vec2d_from_vec1d -// CHECK-SAME: %[[A:.*0]]: vector<2xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> -// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32> -// CHECK: return %[[T2]] : vector<3x2xf32> - -func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { - %0 = vector.broadcast %arg0 : vector<2xf32> to vector<3x2xf32> - return %0 : vector<3x2xf32> -} - -// CHECK-LABEL: func @broadcast_vec3d_from_vec1d -// CHECK-SAME: %[[A:.*0]]: vector<2xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> -// CHECK: %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> -// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] 
: vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C1]] [0] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T4:.*]] = vector.insert %[[T2]], %[[T3]] [1] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T2]], %[[T4]] [2] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T6:.*]] = vector.insert %[[T2]], %[[T5]] [3] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: return %[[T6]] : vector<4x3x2xf32> - -func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> { - %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32> - return %0 : vector<4x3x2xf32> -} - -// CHECK-LABEL: func @broadcast_vec3d_from_vec2d -// CHECK-SAME: %[[A:.*0]]: vector<3x2xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> -// CHECK: %[[T0:.*]] = vector.insert %[[A]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[A]], %[[T0]] [1] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T2:.*]] = vector.insert %[[A]], %[[T1]] [2] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[A]], %[[T2]] [3] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: return %[[T3]] : vector<4x3x2xf32> - -func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> { - %0 = vector.broadcast %arg0 : vector<3x2xf32> to vector<4x3x2xf32> - return %0 : vector<4x3x2xf32> -} - -// CHECK-LABEL: func @broadcast_stretch -// CHECK-SAME: %[[A:.*0]]: vector<1xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1xf32> -// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<4xf32> -// CHECK: return %[[T1]] : vector<4xf32> - -func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> { - 
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<4xf32> - return %0 : vector<4xf32> -} - -// CHECK-LABEL: func @broadcast_stretch_at_start -// CHECK-SAME: %[[A:.*0]]: vector<1x4xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<1x4xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C0]] [0] : vector<4xf32> into vector<3x4xf32> -// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[T1]] [1] : vector<4xf32> into vector<3x4xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [2] : vector<4xf32> into vector<3x4xf32> -// CHECK: return %[[T3]] : vector<3x4xf32> - -func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32> { - %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<3x4xf32> - return %0 : vector<3x4xf32> -} - -// CHECK-LABEL: func @broadcast_stretch_at_end -// CHECK-SAME: %[[A:.*0]]: vector<4x1xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1xf32> -// CHECK: %[[T2:.*]] = vector.splat %[[T0]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<4x3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<4x1xf32> -// CHECK: %[[T6:.*]] = vector.splat %[[T4]] : vector<3xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<4x3xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<4x1xf32> -// CHECK: %[[T10:.*]] = vector.splat %[[T8]] : vector<3xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T7]] [2] : vector<3xf32> into vector<4x3xf32> -// CHECK: %[[T12:.*]] = vector.extract %[[A]][3, 0] : vector<4x1xf32> -// CHECK: %[[T14:.*]] = vector.splat %[[T12]] : vector<3xf32> -// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[T11]] [3] : vector<3xf32> into vector<4x3xf32> -// CHECK: return %[[T15]] : vector<4x3xf32> - -func.func 
@broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> { - %0 = vector.broadcast %arg0 : vector<4x1xf32> to vector<4x3xf32> - return %0 : vector<4x3xf32> -} - -// CHECK-LABEL: func @broadcast_stretch_in_middle -// CHECK-SAME: %[[A:.*0]]: vector<4x1x2xf32> -// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> -// CHECK: %[[C1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<4x1x2xf32> -// CHECK: %[[T2:.*]] = vector.insert %[[T0]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T0]], %[[T2]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T4:.*]] = vector.insert %[[T0]], %[[T3]] [2] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<4x1x2xf32> -// CHECK: %[[T8:.*]] = vector.insert %[[T6]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T6]], %[[T8]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T10:.*]] = vector.insert %[[T6]], %[[T9]] [2] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T12:.*]] = vector.extract %[[A]][2, 0] : vector<4x1x2xf32> -// CHECK: %[[T14:.*]] = vector.insert %[[T12]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T15:.*]] = vector.insert %[[T12]], %[[T14]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T16:.*]] = vector.insert %[[T12]], %[[T15]] [2] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T11]] [2] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: %[[T18:.*]] = vector.extract %[[A]][3, 0] : vector<4x1x2xf32> -// CHECK: %[[T20:.*]] = vector.insert %[[T18]], %[[C1]] [0] : vector<2xf32> into vector<3x2xf32> -// 
CHECK: %[[T21:.*]] = vector.insert %[[T18]], %[[T20]] [1] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T22:.*]] = vector.insert %[[T18]], %[[T21]] [2] : vector<2xf32> into vector<3x2xf32> -// CHECK: %[[T23:.*]] = vector.insert %[[T22]], %[[T17]] [3] : vector<3x2xf32> into vector<4x3x2xf32> -// CHECK: return %[[T23]] : vector<4x3x2xf32> - -func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2xf32> { - %0 = vector.broadcast %arg0 : vector<4x1x2xf32> to vector<4x3x2xf32> - return %0 : vector<4x3x2xf32> -} - -// CHECK-LABEL: func @genbool_1d -// CHECK: %[[T0:.*]] = arith.constant dense<[true, true, true, true, false, false, false, false]> : vector<8xi1> -// CHECK: return %[[T0]] : vector<8xi1> - -func.func @genbool_1d() -> vector<8xi1> { - %0 = vector.constant_mask [4] : vector<8xi1> - return %0 : vector<8xi1> -} - -// CHECK-LABEL: func @genbool_2d -// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1> -// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<4x4xi1> -// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1> -// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1> -// CHECK: return %[[T1]] : vector<4x4xi1> - -func.func @genbool_2d() -> vector<4x4xi1> { - %v = vector.constant_mask [2, 2] : vector<4x4xi1> - return %v: vector<4x4xi1> -} - -// CHECK-LABEL: func @genbool_3d -// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1> -// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<3x4xi1> -// CHECK: %[[C3:.*]] = arith.constant dense<false> : vector<2x3x4xi1> -// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1> -// CHECK: return %[[T1]] : vector<2x3x4xi1> - -func.func @genbool_3d() -> vector<2x3x4xi1> { - %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1> -
return %v: vector<2x3x4xi1> -} - -// CHECK-LABEL: func @genbool_var_1d( -// CHECK-SAME: %[[A:.*]]: index) -// CHECK: %[[T0:.*]] = vector.create_mask %[[A]] : vector<3xi1> -// CHECK: return %[[T0]] : vector<3xi1> - -func.func @genbool_var_1d(%arg0: index) -> vector<3xi1> { - %0 = vector.create_mask %arg0 : vector<3xi1> - return %0 : vector<3xi1> -} - -// CHECK-LABEL: func @genbool_var_2d( -// CHECK-SAME: %[[A:.*0]]: index, -// CHECK-SAME: %[[B:.*1]]: index) -// CHECK: %[[C1:.*]] = arith.constant dense<false> : vector<3xi1> -// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<2x3xi1> -// CHECK: %[[c0:.*]] = arith.constant 0 : index -// CHECK: %[[c1:.*]] = arith.constant 1 : index -// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1> -// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index -// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1> -// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index -// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1> -// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1> -// CHECK: return %[[T6]] : vector<2x3xi1> - -func.func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> { - %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1> - return %0 : vector<2x3xi1> -} - -// CHECK-LABEL: func @genbool_var_3d( -// CHECK-SAME: %[[A:.*0]]: index, -// CHECK-SAME: %[[B:.*1]]: index, -// CHECK-SAME: %[[C:.*2]]: index) -// CHECK-DAG: %[[C1:.*]] = arith.constant dense<false> : vector<7xi1> -// CHECK-DAG: %[[C2:.*]] = arith.constant dense<false> : vector<1x7xi1> -// CHECK-DAG: %[[C3:.*]] = arith.constant dense<false> : vector<2x1x7xi1> -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1> -// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[B]], %[[c0]] :
index -// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1> -// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index -// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1> -// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1> -// CHECK: %[[T7:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index -// CHECK: %[[T8:.*]] = arith.select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1> -// CHECK: return %[[T9]] : vector<2x1x7xi1> - -func.func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x1x7xi1> { - %0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1> - return %0 : vector<2x1x7xi1> -} - -// CHECK-LABEL: @contract_one_sided_unit_reduction_dim -// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>) -// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32> -// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<1x2xi32> -// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2x2xi32> -// CHECK: %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32> -// CHECK: %[[R0:.+]] = vector.reduction <add>, %[[M0]] : vector<2xi32> into i32 -// CHECK: %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32> -// CHECK: %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2x2xi32> -// CHECK: %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32> -// CHECK: %[[R1:.+]] = vector.reduction <add>, %[[M1]] : vector<2xi32> into i32 -// CHECK: %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32> -// CHECK: %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32> -// CHECK: return %[[S]] : vector<2xi32> - -func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 :
vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> { - %res = vector.contract { - indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d2)>, - affine_map<(d0, d1, d2) -> (d1, d2)>, - affine_map<(d0, d1, d2) -> (d1)> - ], - iterator_types = ["reduction", "parallel", "reduction"], - kind = #vector.kind<add> - } %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32> - return %res : vector<2xi32> -} - -#matmat_accesses_0 = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (m, n)> -] -#matmat_trait_0 = { - indexing_maps = #matmat_accesses_0, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_0 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> -func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>) --> vector<2x3xf32> -{ - %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 - : vector<2x1xf32>, vector<1x3xf32> into vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -// OUTERPRODUCT-LABEL: func @matmul_0_mixed -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf16> -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf16> -// OUTERPRODUCT: %[[a1:.*]] = arith.extf
%[[a0]] : vector<2xf16> to vector<2xf32> -// OUTERPRODUCT: %[[b1:.*]] = arith.extf %[[b0]] : vector<3xf16> to vector<3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> -func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2: vector<2x3xf32>) --> vector<2x3xf32> -{ - %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 - : vector<2x1xf16>, vector<1x3xf16> into vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -#matmat_accesses_1 = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (n, k)>, - affine_map<(m, n, k) -> (m, n)> -] -#matmat_trait_1 = { - indexing_maps = #matmat_accesses_1, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_1 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> -func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>) --> vector<2x3xf32> -{ - %0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2 - : vector<2x1xf32>, vector<3x1xf32> into vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -#matmat_accesses_2 = [ - affine_map<(m, n, k) -> (k, m)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (m, n)> -] -#matmat_trait_2 = { - indexing_maps = #matmat_accesses_2, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_2 -// OUTERPRODUCT-SAME: 
%[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32> -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> -func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vector<2x3xf32>) --> vector<2x3xf32> -{ - %0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2 - : vector<1x2xf32>, vector<1x3xf32> into vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -#matmat_accesses_3 = [ - affine_map<(m, n, k) -> (k, m)>, - affine_map<(m, n, k) -> (n, k)>, - affine_map<(m, n, k) -> (m, n)> -] -#matmat_trait_3 = { - indexing_maps = #matmat_accesses_3, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_3 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x1xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> -// OUTERPRODUCT: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[A]][0] : vector<1x2xf32> -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<2x3xf32> -func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vector<2x3xf32>) --> vector<2x3xf32> -{ - %0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2 - : vector<1x2xf32>, vector<3x1xf32> into vector<2x3xf32> - return %0 : vector<2x3xf32> -} - -#matmat_accesses_4 = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (n, m)> -] -#matmat_trait_4 = { - indexing_maps = #matmat_accesses_4, - iterator_types = 
["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_4 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32> -func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) --> vector<3x2xf32> -{ - %0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2 - : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> - return %0 : vector<3x2xf32> -} - -#matmat_accesses_5 = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (n, m)> -] -#matmat_trait_5 = { - indexing_maps = #matmat_accesses_5, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_5 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> -// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32> -func.func @matmul_5(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) --> vector<3x2xf32> -{ - %0 = vector.contract #matmat_trait_5 %arg0, %arg1, %arg2 - : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> - return %0 : vector<3x2xf32> -} - -#matmat_accesses_6 = [ - affine_map<(m, 
n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (n, m)> -] -#matmat_trait_6 = { - indexing_maps = #matmat_accesses_6, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_6 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> -// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32> -func.func @matmul_6(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) --> vector<3x2xf32> -{ - %0 = vector.contract #matmat_trait_6 %arg0, %arg1, %arg2 - : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> - return %0 : vector<3x2xf32> -} - -#matmat_accesses_7 = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (k, n)>, - affine_map<(m, n, k) -> (n, m)> -] -#matmat_trait_7 = { - indexing_maps = #matmat_accesses_7, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// OUTERPRODUCT-LABEL: func @matmul_7 -// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>, -// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>, -// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x2xf32> -// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] -// OUTERPRODUCT-DAG: %[[a0:.*]] = vector.extract %[[At]][0] : vector<1x2xf32> -// OUTERPRODUCT-DAG: %[[b0:.*]] = vector.extract %[[B]][0] : vector<1x3xf32> -// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]] -// OUTERPRODUCT: return %[[c0]] : vector<3x2xf32> -func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x2xf32>) --> vector<3x2xf32> -{ - %0 = 
vector.contract #matmat_trait_7 %arg0, %arg1, %arg2 - : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32> - return %0 : vector<3x2xf32> -} - -// PARALLEL-LABEL: func @parrallel_contract_lowering -// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> -// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> -// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32> -// PARALLEL: return %[[F]] : vector<4xf32> -func.func @parrallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { - %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32> - return %0 : vector<4xf32> -} - -// PARALLEL-LABEL: func @parrallel_contract_lowering_broadcast -// PARALLEL: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32> -// PARALLEL: %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32> -// PARALLEL: %[[E0:.*]] = vector.extract %[[T]][0, 0] : vector<1x1x4xf32> -// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> -// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32> -// PARALLEL: return %[[F]] : vector<4xf32> -func.func @parrallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { - %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32> - return %0 : vector<4xf32> -} - -// PARALLEL-LABEL: func 
@parrallel_contract_lowering -// PARALLEL: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32> -// PARALLEL: %[[T0:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32> -// PARALLEL: %[[T1:.*]] = vector.transpose %{{.*}}, [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32> -// PARALLEL: %[[E0:.*]] = vector.extract %[[T0]][0, 0] : vector<1x1x4xf32> -// PARALLEL: %[[E1:.*]] = vector.extract %[[T1]][0, 0] : vector<1x1x4xf32> -// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %arg2 : vector<4xf32> -// PARALLEL: return %[[F]] : vector<4xf32> -func.func @parrallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { - %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32> - return %0 : vector<4xf32> -} - -// PARALLEL-LABEL: func @parrallel_contract_lowering_scalar -// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32> -// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32> -// PARALLEL: %[[M:.*]] = arith.mulf %[[E0]], %[[E1]] : f32 -// PARALLEL: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32 -// PARALLEL: return %[[A]] : f32 -func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vector<1x1xf32>, %arg2: f32) -> f32 { - %0 = vector.contract { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> ()>], - iterator_types = ["reduction", "reduction"], kind = #vector.kind} - %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32 - return %0 : f32 -} - - diff --git a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir 
b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir @@ -0,0 +1,101 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func @genbool_1d +// CHECK: %[[T0:.*]] = arith.constant dense<[true, true, true, true, false, false, false, false]> : vector<8xi1> +// CHECK: return %[[T0]] : vector<8xi1> + +func.func @genbool_1d() -> vector<8xi1> { + %0 = vector.constant_mask [4] : vector<8xi1> + return %0 : vector<8xi1> +} + +// CHECK-LABEL: func @genbool_2d +// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1> +// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<4x4xi1> +// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1> +// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1> +// CHECK: return %[[T1]] : vector<4x4xi1> + +func.func @genbool_2d() -> vector<4x4xi1> { + %v = vector.constant_mask [2, 2] : vector<4x4xi1> + return %v: vector<4x4xi1> +} + +// CHECK-LABEL: func @genbool_3d +// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1> +// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<3x4xi1> +// CHECK: %[[C3:.*]] = arith.constant dense<false> : vector<2x3x4xi1> +// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1> +// CHECK: return %[[T1]] : vector<2x3x4xi1> + +func.func @genbool_3d() -> vector<2x3x4xi1> { + %v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1> + return %v: vector<2x3x4xi1> +} + +// CHECK-LABEL: func @genbool_var_1d( +// CHECK-SAME: %[[A:.*]]: index) +// CHECK: %[[T0:.*]] = vector.create_mask %[[A]] : vector<3xi1> +// CHECK: return %[[T0]] : vector<3xi1> + +func.func @genbool_var_1d(%arg0: index) ->
vector<3xi1> { + %0 = vector.create_mask %arg0 : vector<3xi1> + return %0 : vector<3xi1> +} + +// CHECK-LABEL: func @genbool_var_2d( +// CHECK-SAME: %[[A:.*0]]: index, +// CHECK-SAME: %[[B:.*1]]: index) +// CHECK: %[[C1:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<2x3xi1> +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[c1:.*]] = arith.constant 1 : index +// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1> +// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index +// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1> +// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index +// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1> +// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1> +// CHECK: return %[[T6]] : vector<2x3xi1> + +func.func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> { + %0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1> + return %0 : vector<2x3xi1> +} + +// CHECK-LABEL: func @genbool_var_3d( +// CHECK-SAME: %[[A:.*0]]: index, +// CHECK-SAME: %[[B:.*1]]: index, +// CHECK-SAME: %[[C:.*2]]: index) +// CHECK-DAG: %[[C1:.*]] = arith.constant dense<false> : vector<7xi1> +// CHECK-DAG: %[[C2:.*]] = arith.constant dense<false> : vector<1x7xi1> +// CHECK-DAG: %[[C3:.*]] = arith.constant dense<false> : vector<2x1x7xi1> +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1> +// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[B]], %[[c0]] : index +// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1> +// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index +// CHECK:
%[[T5:.*]] = arith.select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1> +// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1> +// CHECK: %[[T7:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index +// CHECK: %[[T8:.*]] = arith.select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1> +// CHECK: return %[[T9]] : vector<2x1x7xi1> + +func.func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x1x7xi1> { + %0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1> + return %0 : vector<2x1x7xi1> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + transform.vector.lower_mask %f + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns -split-input-file | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> { %0 = vector.multi_reduction , %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> @@ -264,3 +264,10 @@ // CHECK-LABEL: func @vector_multi_reduction_parallel_middle // CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32> // CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32> + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_multi_reduction %module_op + lowering_strategy = "innerreduction" + : (!pdl.operation) -> 
!pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s +// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> { %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> @@ -187,3 +187,10 @@ } // CHECK-LABEL: func @vector_multi_reduction_to_scalar // CHECK: return %{{.+}} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_multi_reduction %module_op + lowering_strategy = "innerparallel" + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir @@ -0,0 +1,148 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func @outerproduct_noacc +// CHECK-SAME: %[[A:.*0]]: vector<2xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> +// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32> +// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : 
vector<3xf32> +// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32> +// CHECK: return %[[T7]] : vector<2x3xf32> + +func.func @outerproduct_noacc(%arg0: vector<2xf32>, + %arg1: vector<3xf32>) -> vector<2x3xf32> { + %0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32> + return %0: vector<2x3xf32> +} + +// CHECK-LABEL: func @outerproduct_acc +// CHECK-SAME: %[[A:.*0]]: vector<2xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32> +// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32> +// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32> +// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32> +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32> +// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32> +// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32> +// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32> +// CHECK: return %[[T9]] : vector<2x3xf32> + +func.func @outerproduct_acc(%arg0: vector<2xf32>, + %arg1: vector<3xf32>, + %arg2: vector<2x3xf32>) -> vector<2x3xf32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32> + return %0: vector<2x3xf32> +} + +// CHECK-LABEL: func @outerproduct_noacc_int +// CHECK-SAME: %[[A:.*0]]: vector<2xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi32> +// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32> +// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : 
vector<3xi32> +// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32> +// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32> +// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32> +// CHECK: return %[[T7]] : vector<2x3xi32> +func.func @outerproduct_noacc_int(%arg0: vector<2xi32>, + %arg1: vector<3xi32>) -> vector<2x3xi32> { + %0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32> + return %0: vector<2x3xi32> +} + +// CHECK-LABEL: func @outerproduct_acc_int +// CHECK-SAME: %[[A:.*0]]: vector<2xi32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xi32>, +// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32> +// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32> +// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32> +// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32> +// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32> +// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32> +// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32> +// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32> +// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32> +// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32> +// CHECK: return %[[T11]] : vector<2x3xi32> +func.func @outerproduct_acc_int(%arg0: vector<2xi32>, + %arg1: vector<3xi32>, + %arg2: vector<2x3xi32>) -> vector<2x3xi32> { + %0 = 
vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32> + return %0: vector<2x3xi32> +} + +// CHECK-LABEL: func @axpy_fp( +// CHECK-SAME: %[[A:.*0]]: vector<16xf32>, +// CHECK-SAME: %[[B:.*1]]: f32) +// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32> +// CHECK: return %[[T1]] : vector<16xf32> +func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> { + %0 = vector.outerproduct %arg0, %arg1: vector<16xf32>, f32 + return %0: vector<16xf32> +} + +// CHECK-LABEL: func @axpy_fp_add( +// CHECK-SAME: %[[A:.*0]]: vector<16xf32>, +// CHECK-SAME: %[[B:.*1]]: f32, +// CHECK-SAME: %[[C:.*2]]: vector<16xf32>) +// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32> +// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32> +// CHECK: return %[[T1]] : vector<16xf32> +func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32 + return %0: vector<16xf32> +} + +// CHECK-LABEL: func @axpy_int( +// CHECK-SAME: %[[A:.*0]]: vector<16xi32>, +// CHECK-SAME: %[[B:.*1]]: i32) +// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> +// CHECK: return %[[T1]] : vector<16xi32> +func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> { + %0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32 + return %0: vector<16xi32> +} + +// CHECK-LABEL: func @axpy_int_add( +// CHECK-SAME: %[[A:.*0]]: vector<16xi32>, +// CHECK-SAME: %[[B:.*1]]: i32, +// CHECK-SAME: %[[C:.*2]]: vector<16xi32>) +// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32> +// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32> +// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32> +// CHECK: return %[[T2]] : vector<16xi32> +func.func @axpy_int_add(%arg0: vector<16xi32>, 
%arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> { + %0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32 + return %0: vector<16xi32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + %f2 = transform.vector.lower_outerproduct %f + : (!pdl.operation) -> !pdl.operation + + %f3 = transform.vector.lower_broadcast %f2 + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir @@ -0,0 +1,134 @@ + +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func @nop_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> +func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: func @cancel_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> + +func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> + %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> + return %1 : vector<16xf32> +} + +// Shape up and downcasts for 2-D vectors, for supporting conversion to +// llvm.matrix operations +// CHECK-LABEL: func @shape_casts +func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) { + // CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> + // CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32> + // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : 
vector<2x2xf32> + // + // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]] + // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> + // + // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x2xf32> + // + // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]] + // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> + // + %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32> + // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32> + %r0 = arith.addf %0, %0: vector<4xf32> + // + // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]] + // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : + // CHECK-SAME: vector<4xf32> to vector<2xf32> + // + // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] : + // CHECK-SAME: vector<2xf32> into vector<2x2xf32> + // + // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]] + // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} : + // CHECK-SAME: vector<4xf32> to vector<2xf32> + // + // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] : + // CHECK-SAME: vector<2xf32> into vector<2x2xf32> + // + %1 = vector.shape_cast %r0 : vector<4xf32> to vector<2x2xf32> + // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32> + return %r0, %1 : vector<4xf32>, vector<2x2xf32> +} + +// CHECK-LABEL: func @shape_cast_2d2d +// CHECK-SAME: %[[A:.*]]: vector<3x2xf32> +// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] 
[0, 2] : f32 into vector<2x3xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32> +// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32> +// CHECK: return %[[T11]] : vector<2x3xf32> + +func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> { + %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32> + return %s : vector<2x3xf32> +} + +// CHECK-LABEL: func @shape_cast_3d1d +// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32> +// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : vector<1x3x2xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32> +// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32> +// CHECK: return %[[T11]] : vector<6xf32> + +func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> { + %s = vector.shape_cast 
%arg0 : vector<1x3x2xf32> to vector<6xf32> + return %s : vector<6xf32> +} + +// CHECK-LABEL: func @shape_cast_1d3d +// CHECK-SAME: %[[A:.*]]: vector<6xf32> +// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32> +// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32> +// CHECK: return %[[T11]] : vector<2x1x3xf32> + +func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> { + %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32> + return %s : vector<2x1x3xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!pdl.operation) -> !pdl.operation + + %f2 = transform.vector.lower_shape_cast %f + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -1,666 +1,681 @@ -// RUN: mlir-opt %s 
-test-vector-transpose-lowering=eltwise=1 -split-input-file | FileCheck %s --check-prefix=ELTWISE -// RUN: mlir-opt %s -test-vector-transpose-lowering=shuffle=1 -split-input-file | FileCheck %s --check-prefix=SHUFFLE -// RUN: mlir-opt %s -test-vector-transpose-lowering=flat=1 -split-input-file | FileCheck %s --check-prefix=FLAT -// RUN: mlir-opt %s -test-vector-transpose-lowering=avx2=1 -split-input-file | FileCheck %s --check-prefix=AVX2 - -// ELTWISE-LABEL: func @transpose23 -// ELTWISE-SAME: %[[A:.*]]: vector<2x3xf32> -// ELTWISE: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> -// ELTWISE: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32> -// ELTWISE: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32> -// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> -// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : f32 into vector<3x2xf32> -// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> -// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2, 0] : f32 into vector<3x2xf32> -// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> -// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [0, 1] : f32 into vector<3x2xf32> -// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> -// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<3x2xf32> -// ELTWISE: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32> -// ELTWISE: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32> -// ELTWISE: return %[[T11]] : vector<3x2xf32> +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file --allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: func @transpose23 +// CHECK-SAME: %[[A:.*]]: vector<2x3xf32> +// CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32> +// CHECK: 
%[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : f32 into vector<3x2xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2, 0] : f32 into vector<3x2xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [0, 1] : f32 into vector<3x2xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<3x2xf32> +// CHECK: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32> +// CHECK: return %[[T11]] : vector<3x2xf32> func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> { %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32> return %0 : vector<3x2xf32> } +// CHECK-LABEL: func @transpose102_1x8x8xf32 +func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> { + // CHECK: vector.extract {{.*}}[0, 0] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 1] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 2] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 3] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 4] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract 
{{.*}}[0, 5] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 6] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0, 7] : vector<1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32> + %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32> + return %0 : vector<8x1x8xf32> +} + +// CHECK-LABEL: func @transpose102_8x1x8xf32 +func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> { + // CHECK: vector.extract {{.*}}[0, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[2, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[3, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[4, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[5, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[6, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[7, 0] : vector<8x1x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32> + %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32> + return %0 : vector<1x8x8xf32> +} + +// CHECK-LABEL: func 
@transpose1023_1x1x8x8xf32( +func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> { + // Note the single 2-D extract/insert pair since 2 and 3 are not transposed! + // CHECK: vector.extract {{.*}}[0, 0] : vector<1x1x8x8xf32> + // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32> + %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32> + return %0 : vector<1x1x8x8xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_transpose %module_op + lowering_strategy = "eltwise" + : (!pdl.operation) -> !pdl.operation +} + // ----- -// SHUFFLE-LABEL: func @transpose -// FLAT-LABEL: func @transpose( +// CHECK-LABEL: func @transpose func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { - // SHUFFLE: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32> // 0 4 // 0 1 2 3 1 5 // 4 5 6 7 -> 2 6 // 3 7 - // SHUFFLE-NEXT: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32> - // SHUFFLE-NEXT: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32> + // CHECK-NEXT: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32> + return %0 : vector<4x2xf32> +} + + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_transpose %module_op + lowering_strategy = "shuffle" + : (!pdl.operation) -> !pdl.operation +} + +// ----- - // FLAT: vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32> - // FLAT: vector.flat_transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> -> vector<8xf32> - // FLAT: vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32> +// CHECK-LABEL: func 
@transpose( +func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { + // CHECK: vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32> + // CHECK: vector.flat_transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> -> vector<8xf32> + // CHECK: vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32> %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32> return %0 : vector<4x2xf32> } + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_transpose %module_op + lowering_strategy = "flat_transpose" + : (!pdl.operation) -> !pdl.operation +} + // ----- -// AVX2-LABEL: func @transpose4x8 +// CHECK-LABEL: func @transpose4x8 func.func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> { - // AVX2: vector.extract {{.*}}[0] - // AVX2-NEXT: vector.extract {{.*}}[1] - // AVX2-NEXT: vector.extract {{.*}}[2] - // AVX2-NEXT: vector.extract {{.*}}[3] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : 
vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32> - // AVX2-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32> + // CHECK: vector.extract {{.*}}[0] + // CHECK-NEXT: vector.extract {{.*}}[1] + // CHECK-NEXT: vector.extract {{.*}}[2] + // CHECK-NEXT: vector.extract {{.*}}[3] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> 
to vector<32xf32> + // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32> %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32> return %0 : vector<8x4xf32> } -// ----- - -// AVX2-LABEL: func @transpose021_1x4x8 +// CHECK-LABEL: func @transpose021_1x4x8 func.func @transpose021_1x4x8xf32(%arg0: vector<1x4x8xf32>) -> vector<1x8x4xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[0, 1] - // AVX2-NEXT: vector.extract {{.*}}[0, 2] - // AVX2-NEXT: vector.extract {{.*}}[0, 3] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32> - // AVX2-NEXT: vector.shape_cast {{.*}} 
vector<32xf32> to vector<1x8x4xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[0, 1] + // CHECK-NEXT: vector.extract {{.*}}[0, 2] + // CHECK-NEXT: vector.extract {{.*}}[0, 3] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32> + // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<1x8x4xf32> %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x4x8xf32> to vector<1x8x4xf32> return %0 : vector<1x8x4xf32> } -// ----- - -// AVX2-LABEL: func @transpose8x8 +// CHECK-LABEL: func @transpose8x8 func.func @transpose8x8xf32(%arg0: vector<8x8xf32>) -> vector<8x8xf32> { - // AVX2: vector.extract 
{{.*}}[0] - // AVX2-NEXT: vector.extract {{.*}}[1] - // AVX2-NEXT: vector.extract {{.*}}[2] - // AVX2-NEXT: vector.extract {{.*}}[3] - // AVX2-NEXT: vector.extract {{.*}}[4] - // AVX2-NEXT: vector.extract {{.*}}[5] - // AVX2-NEXT: vector.extract {{.*}}[6] - // AVX2-NEXT: vector.extract {{.*}}[7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" 
{{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] + // CHECK: vector.extract {{.*}}[0] + // CHECK-NEXT: vector.extract {{.*}}[1] + // CHECK-NEXT: vector.extract {{.*}}[2] + // CHECK-NEXT: vector.extract {{.*}}[3] + // CHECK-NEXT: vector.extract {{.*}}[4] + // CHECK-NEXT: vector.extract {{.*}}[5] + // CHECK-NEXT: vector.extract {{.*}}[6] + // CHECK-NEXT: vector.extract {{.*}}[7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // 
CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] %0 = vector.transpose %arg0, [1, 0] : vector<8x8xf32> to vector<8x8xf32> return %0 : vector<8x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose021_1x8x8 +// CHECK-LABEL: func 
@transpose021_1x8x8 func.func @transpose021_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<1x8x8xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[0, 1] - // AVX2-NEXT: vector.extract {{.*}}[0, 2] - // AVX2-NEXT: vector.extract {{.*}}[0, 3] - // AVX2-NEXT: vector.extract {{.*}}[0, 4] - // AVX2-NEXT: vector.extract {{.*}}[0, 5] - // AVX2-NEXT: vector.extract {{.*}}[0, 6] - // AVX2-NEXT: vector.extract {{.*}}[0, 7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", 
"=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[0, 1] + // CHECK-NEXT: vector.extract {{.*}}[0, 2] + // CHECK-NEXT: vector.extract {{.*}}[0, 3] + // CHECK-NEXT: vector.extract {{.*}}[0, 4] + // CHECK-NEXT: vector.extract {{.*}}[0, 5] + // CHECK-NEXT: vector.extract {{.*}}[0, 6] + // CHECK-NEXT: vector.extract {{.*}}[0, 7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 
3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: 
vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32> %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x8x8xf32> to vector<1x8x8xf32> return %0 : vector<1x8x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose120_8x1x8 +// CHECK-LABEL: func @transpose120_8x1x8 func.func @transpose120_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[1, 0] - // AVX2-NEXT: vector.extract {{.*}}[2, 0] - // AVX2-NEXT: vector.extract {{.*}}[3, 0] - // AVX2-NEXT: vector.extract {{.*}}[4, 0] - // AVX2-NEXT: vector.extract {{.*}}[5, 0] - // AVX2-NEXT: vector.extract {{.*}}[6, 0] - // AVX2-NEXT: vector.extract {{.*}}[7, 0] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = 
intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[1, 0] + // CHECK-NEXT: vector.extract {{.*}}[2, 0] + // CHECK-NEXT: vector.extract {{.*}}[3, 0] + // CHECK-NEXT: vector.extract {{.*}}[4, 0] + // CHECK-NEXT: vector.extract {{.*}}[5, 0] + // CHECK-NEXT: vector.extract {{.*}}[6, 0] + // CHECK-NEXT: vector.extract {{.*}}[7, 0] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // 
CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: 
vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32> %0 = vector.transpose %arg0, [1, 2, 0] : vector<8x1x8xf32> to vector<1x8x8xf32> return %0 : vector<1x8x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose120_8x8x1 +// CHECK-LABEL: func @transpose120_8x8x1 func.func @transpose120_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<8x1x8xf32> { - // AVX2: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32> - // AVX2-NEXT: vector.extract {{.*}}[0] - // AVX2-NEXT: vector.extract {{.*}}[1] - // AVX2-NEXT: vector.extract {{.*}}[2] - // AVX2-NEXT: vector.extract {{.*}}[3] - // AVX2-NEXT: vector.extract {{.*}}[4] - // AVX2-NEXT: vector.extract {{.*}}[5] - // AVX2-NEXT: vector.extract {{.*}}[6] - // AVX2-NEXT: vector.extract {{.*}}[7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle 
{{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0] + // CHECK-NEXT: vector.extract 
{{.*}}[1] + // CHECK-NEXT: vector.extract {{.*}}[2] + // CHECK-NEXT: vector.extract {{.*}}[3] + // CHECK-NEXT: vector.extract {{.*}}[4] + // CHECK-NEXT: vector.extract {{.*}}[5] + // CHECK-NEXT: vector.extract {{.*}}[6] + // CHECK-NEXT: vector.extract {{.*}}[7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : 
(vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32> %0 = vector.transpose %arg0, [1, 2, 0] : vector<8x8x1xf32> to vector<8x1x8xf32> return %0 : vector<8x1x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose102_8x8x1 +// CHECK-LABEL: func @transpose102_8x8x1 func.func @transpose102_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<8x8x1xf32> { - // AVX2: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32> - // AVX2-NEXT: vector.extract {{.*}}[0] - // AVX2-NEXT: vector.extract {{.*}}[1] - // AVX2-NEXT: vector.extract {{.*}}[2] - // AVX2-NEXT: vector.extract {{.*}}[3] - // AVX2-NEXT: vector.extract {{.*}}[4] - // AVX2-NEXT: vector.extract {{.*}}[5] - // AVX2-NEXT: vector.extract {{.*}}[6] - // AVX2-NEXT: vector.extract {{.*}}[7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 
7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: 
vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x8x1xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0] + // CHECK-NEXT: vector.extract {{.*}}[1] + // CHECK-NEXT: vector.extract {{.*}}[2] + // CHECK-NEXT: vector.extract {{.*}}[3] + // CHECK-NEXT: vector.extract {{.*}}[4] + // CHECK-NEXT: vector.extract {{.*}}[5] + // CHECK-NEXT: vector.extract {{.*}}[6] + // CHECK-NEXT: vector.extract {{.*}}[7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : 
(vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x8x1xf32> %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x8x1xf32> to vector<8x8x1xf32> return %0 : vector<8x8x1xf32> } -// ----- - -// AVX2-LABEL: func @transpose201_8x1x8 +// CHECK-LABEL: func @transpose201_8x1x8 func.func @transpose201_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<8x8x1xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[1, 0] - // AVX2-NEXT: vector.extract {{.*}}[2, 0] - // AVX2-NEXT: vector.extract {{.*}}[3, 0] - // AVX2-NEXT: vector.extract {{.*}}[4, 0] - // AVX2-NEXT: vector.extract {{.*}}[5, 0] - // AVX2-NEXT: vector.extract {{.*}}[6, 0] - // 
AVX2-NEXT: vector.extract {{.*}}[7, 0] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel 
"vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x8x1xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[1, 0] + // CHECK-NEXT: vector.extract {{.*}}[2, 0] + // CHECK-NEXT: vector.extract {{.*}}[3, 0] + // CHECK-NEXT: vector.extract {{.*}}[4, 0] + // CHECK-NEXT: vector.extract {{.*}}[5, 0] + // CHECK-NEXT: vector.extract {{.*}}[6, 0] + // CHECK-NEXT: vector.extract {{.*}}[7, 0] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps 
$0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x8x1xf32> %0 = vector.transpose %arg0, [2, 0, 1] : vector<8x1x8xf32> to vector<8x8x1xf32> return %0 : vector<8x8x1xf32> } -// ----- - -// AVX2-LABEL: func @transpose201_1x8x8 +// CHECK-LABEL: func @transpose201_1x8x8 func.func @transpose201_1x8x8xf32(%arg0: vector<1x8x8xf32>) 
-> vector<8x1x8xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[0, 1] - // AVX2-NEXT: vector.extract {{.*}}[0, 2] - // AVX2-NEXT: vector.extract {{.*}}[0, 3] - // AVX2-NEXT: vector.extract {{.*}}[0, 4] - // AVX2-NEXT: vector.extract {{.*}}[0, 5] - // AVX2-NEXT: vector.extract {{.*}}[0, 6] - // AVX2-NEXT: vector.extract {{.*}}[0, 7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: 
llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[0, 1] + // CHECK-NEXT: vector.extract {{.*}}[0, 2] + // CHECK-NEXT: vector.extract {{.*}}[0, 3] + // CHECK-NEXT: vector.extract {{.*}}[0, 4] + // CHECK-NEXT: vector.extract {{.*}}[0, 5] + // CHECK-NEXT: vector.extract {{.*}}[0, 6] + // CHECK-NEXT: vector.extract {{.*}}[0, 7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: 
vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // 
CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32> %0 = vector.transpose %arg0, [2, 0, 1] : vector<1x8x8xf32> to vector<8x1x8xf32> return %0 : vector<8x1x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose210_8x1x8 +// CHECK-LABEL: func @transpose210_8x1x8 func.func @transpose210_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<8x1x8xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[1, 0] - // AVX2-NEXT: vector.extract {{.*}}[2, 0] - // AVX2-NEXT: vector.extract {{.*}}[3, 0] - // AVX2-NEXT: vector.extract {{.*}}[4, 0] - // AVX2-NEXT: vector.extract {{.*}}[5, 0] - // AVX2-NEXT: vector.extract {{.*}}[6, 0] - // AVX2-NEXT: vector.extract {{.*}}[7, 0] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, 
vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[1, 0] + // CHECK-NEXT: vector.extract {{.*}}[2, 0] + // CHECK-NEXT: vector.extract {{.*}}[3, 0] + // CHECK-NEXT: vector.extract {{.*}}[4, 0] + // CHECK-NEXT: vector.extract {{.*}}[5, 0] + // CHECK-NEXT: vector.extract {{.*}}[6, 0] + // CHECK-NEXT: vector.extract {{.*}}[7, 0] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : 
vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, 
vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32> %0 = vector.transpose %arg0, [2, 1, 0] : vector<8x1x8xf32> to vector<8x1x8xf32> return %0 : vector<8x1x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose210_8x8x1 +// CHECK-LABEL: func @transpose210_8x8x1 func.func @transpose210_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<1x8x8xf32> { - // AVX2: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32> - // AVX2-NEXT: vector.extract {{.*}}[0] - // AVX2-NEXT: vector.extract {{.*}}[1] - // AVX2-NEXT: vector.extract {{.*}}[2] - // AVX2-NEXT: vector.extract {{.*}}[3] - // AVX2-NEXT: vector.extract {{.*}}[4] - // AVX2-NEXT: vector.extract {{.*}}[5] - // AVX2-NEXT: vector.extract {{.*}}[6] - // AVX2-NEXT: vector.extract {{.*}}[7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // 
AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert {{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32> + // CHECK-NEXT: vector.extract {{.*}}[0] + // CHECK-NEXT: vector.extract {{.*}}[1] + // CHECK-NEXT: vector.extract {{.*}}[2] + // CHECK-NEXT: 
vector.extract {{.*}}[3] + // CHECK-NEXT: vector.extract {{.*}}[4] + // CHECK-NEXT: vector.extract {{.*}}[5] + // CHECK-NEXT: vector.extract {{.*}}[6] + // CHECK-NEXT: vector.extract {{.*}}[7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm 
asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32> %0 = vector.transpose %arg0, [2, 1, 0] : vector<8x8x1xf32> to vector<1x8x8xf32> return %0 : vector<1x8x8xf32> } -// ----- - -// AVX2-LABEL: func @transpose210_1x8x8 +// CHECK-LABEL: func @transpose210_1x8x8 func.func @transpose210_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32> { - // AVX2: vector.extract {{.*}}[0, 0] - // AVX2-NEXT: vector.extract {{.*}}[0, 1] - // AVX2-NEXT: vector.extract {{.*}}[0, 2] - // AVX2-NEXT: vector.extract {{.*}}[0, 3] - // AVX2-NEXT: vector.extract {{.*}}[0, 4] - // AVX2-NEXT: vector.extract {{.*}}[0, 5] - // AVX2-NEXT: vector.extract {{.*}}[0, 6] - // AVX2-NEXT: vector.extract {{.*}}[0, 7] - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - 
// AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> - // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> - // AVX2-NEXT: vector.insert {{.*}}[0] - // AVX2-NEXT: vector.insert {{.*}}[1] - // AVX2-NEXT: vector.insert {{.*}}[2] - // AVX2-NEXT: vector.insert {{.*}}[3] - // AVX2-NEXT: vector.insert {{.*}}[4] - // AVX2-NEXT: vector.insert 
{{.*}}[5] - // AVX2-NEXT: vector.insert {{.*}}[6] - // AVX2-NEXT: vector.insert {{.*}}[7] - // AVX2-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x8x1xf32> + // CHECK: vector.extract {{.*}}[0, 0] + // CHECK-NEXT: vector.extract {{.*}}[0, 1] + // CHECK-NEXT: vector.extract {{.*}}[0, 2] + // CHECK-NEXT: vector.extract {{.*}}[0, 3] + // CHECK-NEXT: vector.extract {{.*}}[0, 4] + // CHECK-NEXT: vector.extract {{.*}}[0, 5] + // CHECK-NEXT: vector.extract {{.*}}[0, 6] + // CHECK-NEXT: vector.extract {{.*}}[0, 7] + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> 
vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // CHECK-NEXT: vector.insert {{.*}}[0] + // CHECK-NEXT: vector.insert {{.*}}[1] + // CHECK-NEXT: vector.insert {{.*}}[2] + // CHECK-NEXT: vector.insert {{.*}}[3] + // CHECK-NEXT: vector.insert {{.*}}[4] + // CHECK-NEXT: vector.insert {{.*}}[5] + // CHECK-NEXT: vector.insert {{.*}}[6] + // CHECK-NEXT: vector.insert {{.*}}[7] + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x8x1xf32> %0 = vector.transpose %arg0, [2, 1, 0] : vector<1x8x8xf32> to vector<8x8x1xf32> return %0 : vector<8x8x1xf32> } -// ----- - +// CHECK-LABEL: func @transpose021_8x1x8 func.func @do_not_lower_nonf32_to_avx2(%arg0: vector<4x8xi32>) -> vector<8x4xi32> { %0 = vector.transpose %arg0, [1, 0] : vector<4x8xi32> to vector<8x4xi32> return %0 : vector<8x4xi32> } -// AVX2-NOT: vector.shuffle - -// ----- +// CHECK-NOT: vector.shuffle -// AVX2-LABEL: func @transpose021_8x1x8 +// CHECK-LABEL: func @transpose021_8x1x8 func.func @transpose021_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<8x8x1xf32> { %0 = vector.transpose %arg0, [0, 2, 1] : vector<8x1x8xf32> to vector<8x8x1xf32> return %0 : vector<8x8x1xf32> } -// AVX2-NOT: vector.shuffle - -// ----- +// CHECK-NOT: vector.shuffle -// 
AVX2-LABEL: func @transpose021_8x8x1 +// CHECK-LABEL: func @transpose021_8x8x1 func.func @transpose021_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<8x1x8xf32> { %0 = vector.transpose %arg0, [0, 2, 1] : vector<8x8x1xf32> to vector<8x1x8xf32> return %0 : vector<8x1x8xf32> } -// AVX2-NOT: vector.shuffle - -// ----- +// CHECK-NOT: vector.shuffle -// ELTWISE-LABEL: func @transpose102_1x8x8xf32 -// AVX2-LABEL: func @transpose102_1x8x8 +// CHECK-LABEL: func @transpose102_1x8x8 func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> { - // ELTWISE: vector.extract {{.*}}[0, 0] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 1] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 2] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 3] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 4] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 5] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 6] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[0, 7] : vector<1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32> %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32> return %0 : vector<8x1x8xf32> } -// AVX2-NOT: vector.shuffle +// CHECK-NOT: vector.shuffle -// ----- -// ELTWISE-LABEL: func 
@transpose102_8x1x8xf32 -// AVX2-LABEL: func @transpose102_8x1x8 +// CHECK-LABEL: func @transpose102_8x1x8 func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> { - // ELTWISE: vector.extract {{.*}}[0, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[1, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[2, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[3, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[4, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[5, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[6, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32> - // ELTWISE-NEXT: vector.extract {{.*}}[7, 0] : vector<8x1x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32> %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32> return %0 : vector<1x8x8xf32> } -// AVX2-NOT: vector.shuffle +// CHECK-NOT: vector.shuffle -// ----- -// ELTWISE-LABEL: func @transpose1023_1x1x8x8xf32( -// AVX2-LABEL: func @transpose1023_1x1x8x8 +// CHECK-LABEL: func @transpose1023_1x1x8x8 func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8x8xf32> { // Note the single 2-D extract/insert pair since 2 and 3 are not transposed! 
- // ELTWISE: vector.extract {{.*}}[0, 0] : vector<1x1x8x8xf32> - // ELTWISE-NEXT: vector.insert {{.*}} [0, 0] : vector<8x8xf32> into vector<1x1x8x8xf32> %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<1x1x8x8xf32> to vector<1x1x8x8xf32> return %0 : vector<1x1x8x8xf32> } -// AVX2-NOT: vector.shuffle - -// ----- +// CHECK-NOT: vector.shuffle -// AVX2-LABEL: func @transpose120_1x8x8 +// CHECK-LABEL: func @transpose120_1x8x8 func.func @transpose120_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32> { %0 = vector.transpose %arg0, [1, 2, 0] : vector<1x8x8xf32> to vector<8x8x1xf32> return %0 : vector<8x8x1xf32> } -// AVX2-NOT: vector.shuffle +// CHECK-NOT: vector.shuffle -// ----- -// AVX2-LABEL: func @transpose201_8x8x1 +// CHECK-LABEL: func @transpose201_8x8x1 func.func @transpose201_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<1x8x8xf32> { %0 = vector.transpose %arg0, [2, 0, 1] : vector<8x8x1xf32> to vector<1x8x8xf32> return %0 : vector<1x8x8xf32> } -// AVX2-NOT: vector.shuffle +// CHECK-NOT: vector.shuffle + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_transpose %module_op + avx2_lowering_strategy = true + : (!pdl.operation) -> !pdl.operation +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -113,75 +113,6 @@ } }; -struct TestVectorContractionLowering - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering) - - StringRef getArgument() const final { - return "test-vector-contraction-lowering"; - } - StringRef getDescription() const final { - return "Test lowering patterns that lower contract ops in the vector " - "dialect"; - } - TestVectorContractionLowering() = default; - TestVectorContractionLowering(const TestVectorContractionLowering &pass) - : 
PassWrapper(pass) {} - - Option lowerToFlatMatrix{ - *this, "vector-lower-matrix-intrinsics", - llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), - llvm::cl::init(false)}; - Option lowerToOuterProduct{ - *this, "vector-outerproduct", - llvm::cl::desc("Lower vector.contract to vector.outerproduct"), - llvm::cl::init(false)}; - Option lowerToParallelArith{ - *this, "vector-parallel-arith", - llvm::cl::desc("Lower vector.contract to elementwise vector ops."), - llvm::cl::init(false)}; - - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - - // Test on one pattern in isolation. - if (lowerToOuterProduct) { - VectorContractLowering lowering = VectorContractLowering::OuterProduct; - VectorTransformsOptions options{lowering}; - populateVectorContractLoweringPatterns( - patterns, options, /*benefit=*/1, - /*disableOuterProductlowering=*/true); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - return; - } - - if (lowerToParallelArith) { - vector::populateVectorContractLoweringPatterns( - patterns, - vector::VectorTransformsOptions().setVectorTransformsOptions( - vector::VectorContractLowering::ParallelArith)); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - return; - } - - // Test on all contract lowering patterns. 
- VectorContractLowering contractLowering = VectorContractLowering::Dot; - if (lowerToFlatMatrix) - contractLowering = VectorContractLowering::Matmul; - VectorMultiReductionLowering vectorMultiReductionLowering = - VectorMultiReductionLowering::InnerParallel; - VectorTransformsOptions options{contractLowering, - vectorMultiReductionLowering, - VectorTransposeLowering()}; - populateVectorBroadcastLoweringPatterns(patterns); - populateVectorContractLoweringPatterns(patterns, options); - populateVectorMaskOpLoweringPatterns(patterns); - populateVectorShapeCastLoweringPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; - struct TestVectorContractionPrepareForMMTLowering : public PassWrapper> { @@ -209,82 +140,6 @@ } }; -struct TestVectorTransposeLowering - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering) - - StringRef getArgument() const final { - return "test-vector-transpose-lowering"; - } - StringRef getDescription() const final { - return "Test lowering patterns that lower contract ops in the vector " - "dialect"; - } - TestVectorTransposeLowering() = default; - TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) - : PassWrapper(pass) {} - - Option lowerToEltwise{ - *this, "eltwise", - llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), - llvm::cl::init(false)}; - Option lowerToFlatTranspose{ - *this, "flat", - llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), - llvm::cl::init(false)}; - Option lowerToShuffleTranspose{ - *this, "shuffle", - llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), - llvm::cl::init(false)}; - Option lowerToAvx2{ - *this, "avx2", - llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), - llvm::cl::init(false)}; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - void runOnOperation() override { - 
func::FuncOp funcOp = getOperation(); - MLIRContext *context = funcOp.getContext(); - RewritePatternSet patterns(context); - - vector::VectorTransformsOptions vectorTransformOptions; - if (lowerToEltwise) { - vectorTransformOptions = - vectorTransformOptions.setVectorTransposeLowering( - VectorTransposeLowering::EltWise); - } - if (lowerToFlatTranspose) { - vectorTransformOptions = - vectorTransformOptions.setVectorTransposeLowering( - VectorTransposeLowering::Flat); - } - if (lowerToShuffleTranspose) { - vectorTransformOptions = - vectorTransformOptions.setVectorTransposeLowering( - VectorTransposeLowering::Shuffle); - } - vector::populateVectorTransposeLoweringPatterns(patterns, - vectorTransformOptions); - - if (lowerToAvx2) { - auto avx2LoweringOptions = - x86vector::avx2::LoweringOptions().setTransposeOptions( - x86vector::avx2::TransposeLoweringOptions() - .lower4x8xf32() - .lower8x8xf32()); - x86vector::avx2::populateSpecializedTransposeLoweringPatterns( - patterns, avx2LoweringOptions, /*benefit=*/10); - } - - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) - return signalPassFailure(); - } -}; - struct TestVectorUnrollingPatterns : public PassWrapper> { @@ -537,40 +392,6 @@ } }; -struct TestVectorMultiReductionLoweringPatterns - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - TestVectorMultiReductionLoweringPatterns) - - TestVectorMultiReductionLoweringPatterns() = default; - TestVectorMultiReductionLoweringPatterns( - const TestVectorMultiReductionLoweringPatterns &pass) - : PassWrapper(pass) {} - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - StringRef getArgument() const final { - return "test-vector-multi-reduction-lowering-patterns"; - } - StringRef getDescription() const final { - return "Test lowering patterns to lower vector.multi_reduction to other " - "vector ops"; - } - Option useOuterReductions{ - *this, "use-outer-reductions", - llvm::cl::desc("Move 
reductions to outer most dimensions"), - llvm::cl::init(false)}; - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateVectorMultiReductionLoweringPatterns( - patterns, useOuterReductions - ? vector::VectorMultiReductionLowering::InnerParallel - : vector::VectorMultiReductionLowering::InnerReduction); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; - struct TestVectorTransferCollapseInnerMostContiguousDims : public PassWrapper> { @@ -923,12 +744,8 @@ void registerTestVectorLowerings() { PassRegistration(); - PassRegistration(); - PassRegistration(); - PassRegistration(); - PassRegistration(); PassRegistration(); @@ -941,8 +758,6 @@ PassRegistration(); - PassRegistration(); - PassRegistration(); PassRegistration();