diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
@@ -11,12 +11,14 @@
 
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/OpImplementation.h"
 
 namespace mlir {
 namespace vector {
 class VectorOp;
+struct LowerVectorsOptions;
 } // namespace vector
 } // namespace mlir
 
@@ -32,6 +34,34 @@
 
 namespace vector {
 void registerTransformDialectExtension(DialectRegistry &registry);
+
+/// Helper structure used to hold the different options of LowerVectorsOp.
+struct LowerVectorsOptions : public VectorTransformsOptions {
+  // Have the default values match the LowerVectorsOp values in the td file.
+  LowerVectorsOptions() : VectorTransformsOptions() {
+    setVectorTransformsOptions(VectorContractLowering::OuterProduct);
+    setVectorMultiReductionLowering(
+        VectorMultiReductionLowering::InnerParallel);
+    setVectorTransposeLowering(VectorTransposeLowering::EltWise);
+    setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
+  }
+
+  /// Option added on top of the inherited ones; defaults to false.
+  bool transposeAVX2Lowering = false;
+  /// Sets transposeAVX2Lowering and returns *this so calls can be chained.
+  LowerVectorsOptions &setTransposeAVX2Lowering(bool opt) {
+    transposeAVX2Lowering = opt;
+    return *this;
+  }
+
+  /// Option added on top of the inherited ones; defaults to true.
+  bool unrollVectorTransfers = true;
+  /// Sets unrollVectorTransfers and returns *this so calls can be chained.
+  LowerVectorsOptions &setUnrollVectorTransfers(bool opt) {
+    unrollVectorTransfers = opt;
+    return *this;
+  }
+};
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -46,6 +46,20 @@
   );
   let results = (outs PDL_Operation:$results);
 
+  let builders = [
+    // Unpacks a vector::LowerVectorsOptions into the positional arguments of
+    // the build() overload called below.
+    OpBuilder<(ins "Type":$resultType, "Value":$target,
+      "const vector::LowerVectorsOptions &":$options), [{
+        return build($_builder, $_state, resultType, target,
+          options.vectorContractLowering,
+          options.vectorMultiReductionLowering, options.vectorTransferSplit,
+          options.vectorTransposeLowering, options.transposeAVX2Lowering,
+          options.unrollVectorTransfers);
+      }]
+    >
+  ];
+
   let assemblyFormat = [{
     $target
     oilist (