diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -46,29 +46,30 @@
 ///
 /// When applying the pattern a second time, the existing alloca() operation
 /// is reused and only a second vector.type_cast is added.
-
 struct VectorTransferToSCFOptions {
+  /// Minimal rank to which vector transfers are lowered.
   unsigned targetRank = 1;
+  VectorTransferToSCFOptions &setTargetRank(unsigned r) {
+    targetRank = r;
+    return *this;
+  }
+  /// Enable lowering of vector transfer permutation maps (NOTE(review):
+  /// semantics inferred from the option name; confirm in VectorToSCF.cpp).
   bool lowerPermutationMaps = false;
-  bool lowerTensors = false;
-  bool unroll = false;
-
-  VectorTransferToSCFOptions &setLowerPermutationMaps(bool l) {
+  VectorTransferToSCFOptions &enableLowerPermutationMaps(bool l = true) {
     lowerPermutationMaps = l;
     return *this;
   }
-
-  VectorTransferToSCFOptions &setLowerTensors(bool l) {
+  /// Allows vector transfers that operate on tensors to be lowered (this is an
+  /// uncommon alternative).
+  bool lowerTensors = false;
+  VectorTransferToSCFOptions &enableLowerTensors(bool l = true) {
     lowerTensors = l;
     return *this;
   }
-
-  VectorTransferToSCFOptions &setTargetRank(unsigned r) {
-    targetRank = r;
-    return *this;
-  }
-
-  VectorTransferToSCFOptions &setUnroll(bool u) {
+  /// Triggers full unrolling (vs iterating with a loop) during transfer to scf.
+  bool unroll = false;
+  VectorTransferToSCFOptions &enableFullUnroll(bool u = true) {
     unroll = u;
     return *this;
   }
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -113,6 +113,22 @@
   linalg::LinalgVectorizationOptions options;
 };
 
+/// Represents one application of createLinalgStrategyLowerVectorsPass.
+struct VectorLowering : public Transformation {
+  explicit VectorLowering(
+      linalg::LinalgVectorLoweringOptions options,
+      LinalgTransformationFilter::FilterFunction f = nullptr)
+      : Transformation(f), options(options) {}
+
+  void addToPassPipeline(OpPassManager &pm,
+                         LinalgTransformationFilter m) const override {
+    pm.addPass(createLinalgStrategyLowerVectorsPass(options, m));
+  }
+
+private:
+  linalg::LinalgVectorLoweringOptions options;
+};
+
 /// Codegen strategy controls how a Linalg op is progressively lowered.
 struct CodegenStrategy {
   /// Append a pattern to add a level of tiling for Op `opName` with tiling
@@ -195,53 +211,24 @@
     return b ? vectorize(opName, f) : *this;
     return *this;
   }
-  /// Configure the post staged-patterns late vector transformations.
-  CodegenStrategy &
-  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
-    vectorTransformOptions = options;
+  /// Append a step that lowers vector operations as configured by `options`.
+  CodegenStrategy &vectorLowering(LinalgVectorLoweringOptions options) {
+    transformationSequence.emplace_back(
+        std::make_unique<VectorLowering>(options));
     return *this;
   }
-  /// Configure the post staged-patterns late vector.transfer to scf
-  /// conversion.
+  /// Configure the options of the enabling pass that runs after each step of
+  /// the transformation sequence.
+  CodegenStrategy &setLinalgEnablingOptions(LinalgEnablingOptions options) {
+    linalgEnablingOptions = options;
+    return *this;
+  }
+  /// Alias of setLinalgEnablingOptions kept for compatibility. Despite its
+  /// name, it does not configure vector.transfer to scf lowering; pass
+  /// LinalgVectorLoweringOptions to vectorLowering for that.
   CodegenStrategy &
-  setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
-    vectorToSCFOptions = options;
-    return *this;
-  }
-  ///
-  /// Configure the application of late transformations.
-  ///
-  CodegenStrategy &setEnableLICM(bool val) {
-    this->lateCodegenStrategyOptions.enableLICM = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableHoistRedundantVectorTransfers(bool val) {
-    this->lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableHoistRedundantVectorTransfersOnTensor(bool val) {
-    this->lateCodegenStrategyOptions
-        .enableHoistRedundantVectorTransfersOnTensor = val;
-    return *this;
-  }
-  CodegenStrategy &setMaxTransferRank(int64_t val) {
-    this->lateCodegenStrategyOptions.maxTransferRank = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableVectorTransferLowering(bool val) {
-    this->lateCodegenStrategyOptions.enableVectorTransferLowering = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableVectorTransferPartialRewrite(bool val) {
-    this->lateCodegenStrategyOptions.enableVectorTransferPartialRewrite = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableVectorContractLowering(bool val) {
-    this->lateCodegenStrategyOptions.enableVectorContractLowering = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableVectorToSCFConversion(bool val) {
-    this->lateCodegenStrategyOptions.enableVectorToSCFConversion = val;
+  setVectorTransferToSCFOptions(LinalgEnablingOptions options) {
+    linalgEnablingOptions = options;
     return *this;
   }
 
@@ -252,10 +239,8 @@
 private:
   LogicalResult postPatternTransforms(Operation *func) const;
 
-  vector::VectorTransformsOptions vectorTransformOptions;
-  VectorTransferToSCFOptions vectorToSCFOptions;
+  LinalgEnablingOptions linalgEnablingOptions;
   SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
-  LateCodegenStrategyOptions lateCodegenStrategyOptions;
 };
 
 } // namespace linalg
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -846,41 +846,82 @@
       : LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
 };
 
-/// Options to control the application of late transformations.
-struct LateCodegenStrategyOptions {
-  /// Hoisting transformations are always deemed beneficial and must disabled
-  /// explicitly.
-  bool enableLICM = true;
-  bool enableHoistRedundantVectorTransfers = true;
-  bool enableHoistRedundantVectorTransfersOnTensor = true;
-  /// Vector lowering operations may result in surprising behavior when
-  /// composing multiple codegen strategies and must be enabled explicitly.
-  int64_t maxTransferRank = 1;
-  bool enableVectorTransferLowering = true;
-  bool enableVectorTransferPartialRewrite = false;
-  bool enableVectorContractLowering = false;
-  bool enableVectorToSCFConversion = false;
-};
-
 /// Options to control the application of enabling transformations.
 /// Hoisting transformations are always deemed beneficial and must be disabled
 /// explicitly.
 struct LinalgEnablingOptions {
-  bool enableLICM = true;
-  bool enableHoistRedundantVectorTransfers = true;
-  bool enableHoistRedundantVectorTransfersOnTensor = true;
+  /// Enable LICM.
+  bool licm = true;
+  LinalgEnablingOptions &enableLICM(bool val = true) {
+    licm = val;
+    return *this;
+  }
+  /// Enable hoisting of redundant vector transfer ops.
+  bool hoistRedundantVectorTransfers = true;
+  LinalgEnablingOptions &enableHoistRedundantVectorTransfers(bool val = true) {
+    hoistRedundantVectorTransfers = val;
+    return *this;
+  }
+  /// Enable hoisting of redundant vector transfer ops on tensors.
+  bool hoistRedundantVectorTransfersOnTensor = true;
+  LinalgEnablingOptions &
+  enableHoistRedundantVectorTransfersOnTensor(bool val = true) {
+    hoistRedundantVectorTransfersOnTensor = val;
+    return *this;
+  }
 };
 
 /// Vector lowering options control how ops are lowered down to 1-D and scf.for
 /// form.
 struct LinalgVectorLoweringOptions {
+  /// Maximal transfer rank under which we do not lower further.
   int64_t maxTransferRank = 1;
-  bool enableVectorTransferLowering = true;
-  bool enableVectorTransferPartialRewrite = false;
-  bool enableVectorContractLowering = false;
-  bool enableVectorToSCFConversion = false;
+  LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
+    maxTransferRank = val;
+    return *this;
+  }
+  /// Enable the vector.transfer lowering patterns, parameterized by
+  /// maxTransferRank. Unlike the lowerings below, this one is on by default.
+  bool transferLowering = true;
+  LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) {
+    transferLowering = val;
+    return *this;
+  }
+  /// Trigger full / partial vector.transfer splits. Off by default, like the
+  /// lowerings below: they may result in surprising behavior when composing
+  /// multiple codegen strategies and must be enabled explicitly.
+  bool transferPartialRewrite = false;
+  LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
+    transferPartialRewrite = val;
+    return *this;
+  }
+  /// Enable lowering of vector.contract.
+  bool contractionLowering = false;
+  LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
+    contractionLowering = val;
+    return *this;
+  }
+  /// Enable lowering of vector.transfer to scf.
+  bool transferToSCFConversion = false;
+  LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
+    transferToSCFConversion = val;
+    return *this;
+  }
+  /// Options for the partial-rewrite and contraction lowering patterns.
   vector::VectorTransformsOptions vectorTransformOptions;
+  LinalgVectorLoweringOptions &
+  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
+    vectorTransformOptions = options;
+    return *this;
+  }
+  /// Options forwarded to the vector.transfer to scf conversion (used when
+  /// transferToSCFConversion is enabled).
   VectorTransferToSCFOptions vectorTransferToSCFOptions;
+  LinalgVectorLoweringOptions &
+  setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
+    vectorTransferToSCFOptions = options;
+    return *this;
+  }
 };
 
 /// Trait to check if T provides a `getOperationName` method.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -44,21 +44,7 @@
                       : linalg::LinalgTransformationFilter(
                             t->filter, currentState, nextState);
     t->addToPassPipeline(pm, filter);
-    pm.addPass(createLinalgStrategyEnablePass());
+    pm.addPass(createLinalgStrategyEnablePass(linalgEnablingOptions));
   }
-  LinalgVectorLoweringOptions vectorLoweringOptions;
-  vectorLoweringOptions.maxTransferRank =
-      lateCodegenStrategyOptions.maxTransferRank;
-  vectorLoweringOptions.enableVectorTransferLowering =
-      lateCodegenStrategyOptions.enableVectorTransferLowering;
-  vectorLoweringOptions.enableVectorTransferPartialRewrite =
-      lateCodegenStrategyOptions.enableVectorTransferPartialRewrite;
-  vectorLoweringOptions.enableVectorContractLowering =
-      lateCodegenStrategyOptions.enableVectorContractLowering;
-  vectorLoweringOptions.enableVectorToSCFConversion =
-      lateCodegenStrategyOptions.enableVectorToSCFConversion;
-  vectorLoweringOptions.vectorTransformOptions = vectorTransformOptions;
-  vectorLoweringOptions.vectorTransferToSCFOptions = vectorToSCFOptions;
-  pm.addPass(createLinalgStrategyLowerVectorsPass(vectorLoweringOptions));
   pm.addPass(createLinalgStrategyRemoveMarkersPass());
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -224,7 +224,7 @@
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
       return signalPassFailure();
 
-    if (options.enableLICM) {
+    if (options.licm) {
       if (funcOp
               ->walk([&](LoopLikeOpInterface loopLike) {
                 if (failed(moveLoopInvariantCode(loopLike)))
@@ -236,10 +236,10 @@
     }
 
     promoteSingleIterationLoops(funcOp);
-    if (options.enableHoistRedundantVectorTransfers)
+    if (options.hoistRedundantVectorTransfers)
       hoistRedundantVectorTransfers(funcOp);
 
-    if (options.enableHoistRedundantVectorTransfersOnTensor)
+    if (options.hoistRedundantVectorTransfersOnTensor)
       hoistRedundantVectorTransfersOnTensor(funcOp);
   }
 
@@ -263,21 +263,21 @@
 
     MLIRContext *context = funcOp.getContext();
     RewritePatternSet patterns(context);
-    if (options.enableVectorTransferLowering) {
+    if (options.transferLowering) {
       vector::populateVectorTransferLoweringPatterns(patterns,
                                                      options.maxTransferRank);
     }
-    if (options.enableVectorTransferPartialRewrite) {
+    if (options.transferPartialRewrite) {
       patterns.add<vector::VectorTransferFullPartialRewriter>(
           context, options.vectorTransformOptions);
     }
-    if (options.enableVectorContractLowering) {
+    if (options.contractionLowering) {
       patterns.add<ContractionOpToOuterProductOpLowering,
                    ContractionOpToMatmulOpLowering, ContractionOpLowering>(
           options.vectorTransformOptions, context);
       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
     }
-    if (options.enableVectorToSCFConversion) {
+    if (options.transferToSCFConversion) {
       populateVectorToSCFConversionPatterns(patterns,
                                             options.vectorTransferToSCFOptions);
     }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -153,15 +153,18 @@
       .generalizeIf(generalize, anchorOpName)
       .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
       .vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
-      .setEnableVectorTransferPartialRewrite(true)
-      .setEnableVectorContractLowering(true)
-      .setEnableVectorToSCFConversion(true)
-      .setVectorTransformsOptions(
-          vector::VectorTransformsOptions()
-              .setVectorTransformsOptions(vectorContractLowering)
-              .setVectorTransferSplit(vectorTransferSplit))
-      .setVectorTransferToSCFOptions(
-          VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
+      .vectorLowering(
+          LinalgVectorLoweringOptions()
+              .setVectorTransformsOptions(
+                  vector::VectorTransformsOptions()
+                      .setVectorTransformsOptions(vectorContractLowering)
+                      .setVectorTransferSplit(vectorTransferSplit))
+              .setVectorTransferToSCFOptions(
+                  VectorTransferToSCFOptions().enableFullUnroll(
+                      unrollVectorTransfers))
+              .enableTransferPartialRewrite()
+              .enableContractionLowering()
+              .enableTransferToSCFConversion());
 
   // Created a nested OpPassManager and run.
   FuncOp funcOp = getFunction();