diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
@@ -17,6 +17,7 @@
 namespace mlir {
 namespace vector {
 class VectorOp;
+struct LowerVectorsOptions;
 } // namespace vector
 } // namespace mlir
 
@@ -32,6 +33,24 @@
 
 namespace vector {
 void registerTransformDialectExtension(DialectRegistry &registry);
+
+/// Helper structure used to hold the different options of LowerVectorsOp.
+/// Adds two flags on top of VectorTransformsOptions. Each setter stores its
+/// argument and returns *this so that calls can be chained.
+struct LowerVectorsOptions : public VectorTransformsOptions {
+  /// Disabled unless set explicitly.
+  bool transposeAVX2Lowering = false;
+  LowerVectorsOptions &setTransposeAVX2Lowering(bool opt) {
+    transposeAVX2Lowering = opt;
+    return *this;
+  }
+  /// Enabled unless cleared explicitly.
+  bool unrollVectorTransfers = true;
+  LowerVectorsOptions &setUnrollVectorTransfers(bool opt) {
+    unrollVectorTransfers = opt;
+    return *this;
+  }
+};
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -46,6 +46,20 @@
   );
   let results = (outs PDL_Operation:$results);
 
+  // Convenience builder: forwards every field of `options` to the build
+  // overload that takes the settings as separate arguments.
+  let builders = [
+    OpBuilder<(ins "Type":$resultType, "Value":$target,
+      "const vector::LowerVectorsOptions &":$options), [{
+        return build($_builder, $_state, resultType, target,
+          options.vectorContractLowering,
+          options.vectorMultiReductionLowering, options.vectorTransferSplit,
+          options.vectorTransposeLowering, options.transposeAVX2Lowering,
+          options.unrollVectorTransfers);
+      }]
+    >
+  ];
+
   let assemblyFormat = [{
     $target
     oilist (