diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -195,53 +195,15 @@
     return b ? vectorize(opName, f) : *this;
     return *this;
   }
-  /// Configure the post staged-patterns late vector transformations.
+  /// Configure the post staged-patterns late vector lowering options.
   CodegenStrategy &
-  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
-    vectorTransformOptions = options;
+  setLinalgVectorLoweringOptions(LinalgVectorLoweringOptions options) {
+    lateVectorLoweringOptions = options;
     return *this;
   }
-  /// Configure the post staged-patterns late vector.transfer to scf
-  /// conversion.
-  CodegenStrategy &
-  setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
-    vectorToSCFOptions = options;
-    return *this;
-  }
-  ///
-  /// Configure the application of late transformations.
-  ///
-  CodegenStrategy &setEnableLICM(bool val) {
-    this->lateCodegenStrategyOptions.enableLICM = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableHoistRedundantVectorTransfers(bool val) {
-    this->lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableHoistRedundantVectorTransfersOnTensor(bool val) {
-    this->lateCodegenStrategyOptions
-        .enableHoistRedundantVectorTransfersOnTensor = val;
-    return *this;
-  }
-  CodegenStrategy &setMaxTransferRank(int64_t val) {
-    this->lateCodegenStrategyOptions.maxTransferRank = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableVectorTransferLowering(bool val) {
-    this->lateCodegenStrategyOptions.enableVectorTransferLowering = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableVectorTransferPartialRewrite(bool val) {
-    this->lateCodegenStrategyOptions.enableVectorTransferPartialRewrite = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableVectorContractLowering(bool val) {
-    this->lateCodegenStrategyOptions.enableVectorContractLowering = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableVectorToSCFConversion(bool val) {
-    this->lateCodegenStrategyOptions.enableVectorToSCFConversion = val;
+  /// Configure the post staged-patterns global enabling passes options.
+  CodegenStrategy &setLinalgEnablingOptions(LinalgEnablingOptions options) {
+    linalgEnablingOptions = options;
     return *this;
   }
 
@@ -252,10 +214,9 @@
 private:
   LogicalResult postPatternTransforms(Operation *func) const;
 
-  vector::VectorTransformsOptions vectorTransformOptions;
-  VectorTransferToSCFOptions vectorToSCFOptions;
+  LinalgEnablingOptions linalgEnablingOptions;
+  LinalgVectorLoweringOptions lateVectorLoweringOptions;
   SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
-  LateCodegenStrategyOptions lateCodegenStrategyOptions;
 };
 
 } // namespace linalg
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -845,41 +845,80 @@
       : LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
 };
 
-/// Options to control the application of late transformations.
-struct LateCodegenStrategyOptions {
-  /// Hoisting transformations are always deemed beneficial and must disabled
-  /// explicitly.
-  bool enableLICM = true;
-  bool enableHoistRedundantVectorTransfers = true;
-  bool enableHoistRedundantVectorTransfersOnTensor = true;
-  /// Vector lowering operations may result in surprising behavior when
-  /// composing multiple codegen strategies and must be enabled explicitly.
-  int64_t maxTransferRank = 1;
-  bool enableVectorTransferLowering = true;
-  bool enableVectorTransferPartialRewrite = false;
-  bool enableVectorContractLowering = false;
-  bool enableVectorToSCFConversion = false;
-};
-
 /// Options to control the application of enabling transformations.
 /// Hoisting transformations are always deemed beneficial and must be disabled
 /// explicitly.
 struct LinalgEnablingOptions {
+  /// Enable LICM.
   bool enableLICM = true;
+  LinalgEnablingOptions &setEnableLICM(bool val) {
+    enableLICM = val;
+    return *this;
+  }
+  /// Enable hoisting of redundant vector transfer ops.
   bool enableHoistRedundantVectorTransfers = true;
+  LinalgEnablingOptions &setEnableHoistRedundantVectorTransfers(bool val) {
+    enableHoistRedundantVectorTransfers = val;
+    return *this;
+  }
+  /// Enable hoisting of redundant vector transfer ops on tensor.
   bool enableHoistRedundantVectorTransfersOnTensor = true;
+  LinalgEnablingOptions &
+  setEnableHoistRedundantVectorTransfersOnTensor(bool val) {
+    enableHoistRedundantVectorTransfersOnTensor = val;
+    return *this;
+  }
 };
 
 /// Vector lowering options control how ops are lowered down to 1-D and scf.for
 /// form.
 struct LinalgVectorLoweringOptions {
+  /// Maximal transfer rank under which we do not lower further.
   int64_t maxTransferRank = 1;
+  LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
+    maxTransferRank = val;
+    return *this;
+  }
+  /// Enable vector.transfer lowering (on by default). The other lowerings
+  /// below may surprise when composing strategies and are opt-in.
   bool enableVectorTransferLowering = true;
+  LinalgVectorLoweringOptions &setEnableVectorTransferLowering(bool val) {
+    enableVectorTransferLowering = val;
+    return *this;
+  }
+  /// Trigger full / partial vector.transfer splits.
   bool enableVectorTransferPartialRewrite = false;
+  LinalgVectorLoweringOptions &setEnableVectorTransferPartialRewrite(bool val) {
+    enableVectorTransferPartialRewrite = val;
+    return *this;
+  }
+  /// Enable lowering of vector.contract.
   bool enableVectorContractLowering = false;
+  LinalgVectorLoweringOptions &setEnableVectorContractLowering(bool val) {
+    enableVectorContractLowering = val;
+    return *this;
+  }
+  /// Enable lowering of vector.transfer to scf.
   bool enableVectorToSCFConversion = false;
+  LinalgVectorLoweringOptions &setEnableVectorToSCFConversion(bool val) {
+    enableVectorToSCFConversion = val;
+    return *this;
+  }
+  /// Configure late vector transformations.
   vector::VectorTransformsOptions vectorTransformOptions;
+  LinalgVectorLoweringOptions &
+  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
+    vectorTransformOptions = options;
+    return *this;
+  }
+  /// Configure the post staged-patterns late vector.transfer to scf
+  /// conversion.
   VectorTransferToSCFOptions vectorTransferToSCFOptions;
+  LinalgVectorLoweringOptions &
+  setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
+    vectorTransferToSCFOptions = options;
+    return *this;
+  }
 };
 
 /// Trait to check if T provides a `getOperationName` method.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -46,19 +46,6 @@
     t->addToPassPipeline(pm, filter);
-    pm.addPass(createLinalgStrategyEnablePass());
+    pm.addPass(createLinalgStrategyEnablePass(linalgEnablingOptions));
   }
-  LinalgVectorLoweringOptions vectorLoweringOptions;
-  vectorLoweringOptions.maxTransferRank =
-      lateCodegenStrategyOptions.maxTransferRank;
-  vectorLoweringOptions.enableVectorTransferLowering =
-      lateCodegenStrategyOptions.enableVectorTransferLowering;
-  vectorLoweringOptions.enableVectorTransferPartialRewrite =
-      lateCodegenStrategyOptions.enableVectorTransferPartialRewrite;
-  vectorLoweringOptions.enableVectorContractLowering =
-      lateCodegenStrategyOptions.enableVectorContractLowering;
-  vectorLoweringOptions.enableVectorToSCFConversion =
-      lateCodegenStrategyOptions.enableVectorToSCFConversion;
-  vectorLoweringOptions.vectorTransformOptions = vectorTransformOptions;
-  vectorLoweringOptions.vectorTransferToSCFOptions = vectorToSCFOptions;
-  pm.addPass(createLinalgStrategyLowerVectorsPass(vectorLoweringOptions));
+  pm.addPass(createLinalgStrategyLowerVectorsPass(lateVectorLoweringOptions));
   pm.addPass(createLinalgStrategyRemoveMarkersPass());
 }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -153,15 +153,17 @@
       .generalizeIf(generalize, anchorOpName)
       .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
       .vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
-      .setEnableVectorTransferPartialRewrite(true)
-      .setEnableVectorContractLowering(true)
-      .setEnableVectorToSCFConversion(true)
-      .setVectorTransformsOptions(
-          vector::VectorTransformsOptions()
-              .setVectorTransformsOptions(vectorContractLowering)
-              .setVectorTransferSplit(vectorTransferSplit))
-      .setVectorTransferToSCFOptions(
-          VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
+      .setLinalgVectorLoweringOptions(
+          LinalgVectorLoweringOptions()
+              .setVectorTransformsOptions(
+                  vector::VectorTransformsOptions()
+                      .setVectorTransformsOptions(vectorContractLowering)
+                      .setVectorTransferSplit(vectorTransferSplit))
+              .setVectorTransferToSCFOptions(
+                  VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers))
+              .setEnableVectorTransferPartialRewrite(true)
+              .setEnableVectorContractLowering(true)
+              .setEnableVectorToSCFConversion(true));
 
   // Created a nested OpPassManager and run.
   FuncOp funcOp = getFunction();