diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -195,53 +195,16 @@
     return b ? vectorize(opName, f) : *this;
     return *this;
   }
-  /// Configure the post staged-patterns late vector transformations.
+  /// Configure the post staged-patterns late vector lowering options.
   CodegenStrategy &
-  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
-    vectorTransformOptions = options;
+  setLinalgVectorLoweringOptions(LinalgVectorLoweringOptions options) {
+    lateVectorLoweringOptions = options;
     return *this;
   }
-  /// Configure the post staged-patterns late vector.transfer to scf
-  /// conversion.
-  CodegenStrategy &
-  setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
-    vectorToSCFOptions = options;
-    return *this;
-  }
-  ///
-  /// Configure the application of late transformations.
-  ///
-  CodegenStrategy &setEnableLICM(bool val) {
-    this->lateCodegenStrategyOptions.enableLICM = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableHoistRedundantVectorTransfers(bool val) {
-    this->lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableHoistRedundantVectorTransfersOnTensor(bool val) {
-    this->lateCodegenStrategyOptions
-        .enableHoistRedundantVectorTransfersOnTensor = val;
-    return *this;
-  }
-  CodegenStrategy &setMaxTransferRank(int64_t val) {
-    this->lateCodegenStrategyOptions.maxTransferRank = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableVectorTransferLowering(bool val) {
-    this->lateCodegenStrategyOptions.enableVectorTransferLowering = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableVectorTransferPartialRewrite(bool val) {
-    this->lateCodegenStrategyOptions.enableVectorTransferPartialRewrite = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableVectorContractLowering(bool val) {
-    this->lateCodegenStrategyOptions.enableVectorContractLowering = val;
-    return *this;
-  }
-  CodegenStrategy &setEnableVectorToSCFConversion(bool val) {
-    this->lateCodegenStrategyOptions.enableVectorToSCFConversion = val;
+  /// Configure the options of the enabling transformations (LICM and
+  /// hoisting of redundant vector transfers).
+  CodegenStrategy &setLinalgEnablingOptions(LinalgEnablingOptions options) {
+    linalgEnablingOptions = options;
     return *this;
   }
 
@@ -252,10 +215,9 @@
 private:
   LogicalResult postPatternTransforms(Operation *func) const;
 
-  vector::VectorTransformsOptions vectorTransformOptions;
-  VectorTransferToSCFOptions vectorToSCFOptions;
+  LinalgEnablingOptions linalgEnablingOptions;
+  LinalgVectorLoweringOptions lateVectorLoweringOptions;
   SmallVector<std::unique_ptr<Transformation>, 4> transformationSequence;
-  LateCodegenStrategyOptions lateCodegenStrategyOptions;
 };
 
 } // namespace linalg
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -846,41 +846,85 @@
       : LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
 };
 
-/// Options to control the application of late transformations.
-struct LateCodegenStrategyOptions {
-  /// Hoisting transformations are always deemed beneficial and must disabled
-  /// explicitly.
-  bool enableLICM = true;
-  bool enableHoistRedundantVectorTransfers = true;
-  bool enableHoistRedundantVectorTransfersOnTensor = true;
-  /// Vector lowering operations may result in surprising behavior when
-  /// composing multiple codegen strategies and must be enabled explicitly.
-  int64_t maxTransferRank = 1;
-  bool enableVectorTransferLowering = true;
-  bool enableVectorTransferPartialRewrite = false;
-  bool enableVectorContractLowering = false;
-  bool enableVectorToSCFConversion = false;
-};
-
 /// Options to control the application of enabling transformations.
 /// Hoisting transformations are always deemed beneficial and must be disabled
 /// explicitly.
 struct LinalgEnablingOptions {
-  bool enableLICM = true;
-  bool enableHoistRedundantVectorTransfers = true;
-  bool enableHoistRedundantVectorTransfersOnTensor = true;
+  /// Enable LICM.
+  bool licm = true;
+  LinalgEnablingOptions &enableLICM(bool val) {
+    licm = val;
+    return *this;
+  }
+  /// Enable hoisting of redundant vector transfer ops.
+  bool hoistRedundantVectorTransfers = true;
+  LinalgEnablingOptions &enableHoistRedundantVectorTransfers(bool val) {
+    hoistRedundantVectorTransfers = val;
+    return *this;
+  }
+  /// Enable hoisting of redundant vector transfer ops on tensor.
+  bool hoistRedundantVectorTransfersOnTensor = true;
+  LinalgEnablingOptions &enableHoistRedundantVectorTransfersOnTensor(bool val) {
+    hoistRedundantVectorTransfersOnTensor = val;
+    return *this;
+  }
 };
 
 /// Vector lowering options control how ops are lowered down to 1-D and scf.for
 /// form.
 struct LinalgVectorLoweringOptions {
+  /// Maximal rank of the vector.transfer ops that the transfer lowering
+  /// patterns are allowed to lower.
   int64_t maxTransferRank = 1;
-  bool enableVectorTransferLowering = true;
-  bool enableVectorTransferPartialRewrite = false;
-  bool enableVectorContractLowering = false;
-  bool enableVectorToSCFConversion = false;
+  LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
+    maxTransferRank = val;
+    return *this;
+  }
+  /// Enable the vector.transfer lowering patterns (parameterized by
+  /// `maxTransferRank`). This is the only lowering enabled by default.
+  bool transferLowering = true;
+  LinalgVectorLoweringOptions &enableTransferLowering(bool val) {
+    transferLowering = val;
+    return *this;
+  }
+  // The lowerings below may result in surprising behavior when composing
+  // multiple codegen strategies; they are disabled by default and must be
+  // enabled explicitly.
+
+  /// Trigger full / partial vector.transfer splits.
+  bool transferPartialRewrite = false;
+  LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val) {
+    transferPartialRewrite = val;
+    return *this;
+  }
+  /// Enable lowering of vector.contract (this also adds the patterns that
+  /// lower vector.transfer permutation maps).
+  bool contractionLowering = false;
+  LinalgVectorLoweringOptions &enableContractionLowering(bool val) {
+    contractionLowering = val;
+    return *this;
+  }
+  /// Enable lowering of vector.transfer to scf.
+  bool transferToSCFConversion = false;
+  LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val) {
+    transferToSCFConversion = val;
+    return *this;
+  }
+  /// Configure late vector transformations.
   vector::VectorTransformsOptions vectorTransformOptions;
+  LinalgVectorLoweringOptions &
+  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
+    vectorTransformOptions = options;
+    return *this;
+  }
+  /// Configure the post staged-patterns late vector.transfer to scf
+  /// conversion.
   VectorTransferToSCFOptions vectorTransferToSCFOptions;
+  LinalgVectorLoweringOptions &
+  setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
+    vectorTransferToSCFOptions = options;
+    return *this;
+  }
 };
 
 /// Trait to check if T provides a `getOperationName` method.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -46,19 +46,7 @@
     t->addToPassPipeline(pm, filter);
-    pm.addPass(createLinalgStrategyEnablePass());
+    // Apply the enabling transformations configured on this strategy.
+    pm.addPass(createLinalgStrategyEnablePass(linalgEnablingOptions));
   }
-  LinalgVectorLoweringOptions vectorLoweringOptions;
-  vectorLoweringOptions.maxTransferRank =
-      lateCodegenStrategyOptions.maxTransferRank;
-  vectorLoweringOptions.enableVectorTransferLowering =
-      lateCodegenStrategyOptions.enableVectorTransferLowering;
-  vectorLoweringOptions.enableVectorTransferPartialRewrite =
-      lateCodegenStrategyOptions.enableVectorTransferPartialRewrite;
-  vectorLoweringOptions.enableVectorContractLowering =
-      lateCodegenStrategyOptions.enableVectorContractLowering;
-  vectorLoweringOptions.enableVectorToSCFConversion =
-      lateCodegenStrategyOptions.enableVectorToSCFConversion;
-  vectorLoweringOptions.vectorTransformOptions = vectorTransformOptions;
-  vectorLoweringOptions.vectorTransferToSCFOptions = vectorToSCFOptions;
-  pm.addPass(createLinalgStrategyLowerVectorsPass(vectorLoweringOptions));
+  pm.addPass(createLinalgStrategyLowerVectorsPass(lateVectorLoweringOptions));
   pm.addPass(createLinalgStrategyRemoveMarkersPass());
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -224,7 +224,7 @@
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
       return signalPassFailure();
 
-    if (options.enableLICM) {
+    if (options.licm) {
       if (funcOp
               ->walk([&](LoopLikeOpInterface loopLike) {
                 if (failed(moveLoopInvariantCode(loopLike)))
@@ -236,10 +236,10 @@
     }
 
     promoteSingleIterationLoops(funcOp);
-    if (options.enableHoistRedundantVectorTransfers)
+    if (options.hoistRedundantVectorTransfers)
       hoistRedundantVectorTransfers(funcOp);
 
-    if (options.enableHoistRedundantVectorTransfersOnTensor)
+    if (options.hoistRedundantVectorTransfersOnTensor)
       hoistRedundantVectorTransfersOnTensor(funcOp);
   }
 
@@ -263,21 +263,21 @@
     MLIRContext *context = funcOp.getContext();
     RewritePatternSet patterns(context);
 
-    if (options.enableVectorTransferLowering) {
+    if (options.transferLowering) {
       vector::populateVectorTransferLoweringPatterns(patterns,
                                                      options.maxTransferRank);
     }
-    if (options.enableVectorTransferPartialRewrite) {
+    if (options.transferPartialRewrite) {
       patterns.add<vector::VectorTransferFullPartialRewriter>(
           context, options.vectorTransformOptions);
     }
-    if (options.enableVectorContractLowering) {
+    if (options.contractionLowering) {
       patterns.add<ContractionOpToOuterProductOpLowering,
                    ContractionOpToMatmulOpLowering, ContractionOpLowering>(
           options.vectorTransformOptions, context);
       vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
     }
-    if (options.enableVectorToSCFConversion) {
+    if (options.transferToSCFConversion) {
       populateVectorToSCFConversionPatterns(patterns,
                                             options.vectorTransferToSCFOptions);
     }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -153,15 +153,17 @@
       .generalizeIf(generalize, anchorOpName)
       .interchangeIf(!iteratorInterchange.empty(), iteratorInterchange)
       .vectorizeIf(vectorize, generalize ? genericOpName : anchorOpName)
-      .setEnableVectorTransferPartialRewrite(true)
-      .setEnableVectorContractLowering(true)
-      .setEnableVectorToSCFConversion(true)
-      .setVectorTransformsOptions(
-          vector::VectorTransformsOptions()
-              .setVectorTransformsOptions(vectorContractLowering)
-              .setVectorTransferSplit(vectorTransferSplit))
-      .setVectorTransferToSCFOptions(
-          VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
+      .setLinalgVectorLoweringOptions(
+          LinalgVectorLoweringOptions()
+              .setVectorTransformsOptions(
+                  vector::VectorTransformsOptions()
+                      .setVectorTransformsOptions(vectorContractLowering)
+                      .setVectorTransferSplit(vectorTransferSplit))
+              .setVectorTransferToSCFOptions(
+                  VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers))
+              .enableTransferPartialRewrite(true)
+              .enableContractionLowering(true)
+              .enableTransferToSCFConversion(true));
 
   // Created a nested OpPassManager and run.
   FuncOp funcOp = getFunction();