diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -9,40 +9,14 @@ #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ -#include "llvm/ADT/STLExtras.h" #include -#include - -namespace llvm { -class Module; -} // namespace llvm namespace mlir { -class DialectConversion; -class FuncOp; class LLVMTypeConverter; -struct LogicalResult; -class MLIRContext; class ModuleOp; template class OpPassBase; -class RewritePattern; -class Type; - -// Owning list of rewriting patterns. class OwningRewritePatternList; -/// Type for a callback constructing the owning list of patterns for the -/// conversion to the LLVMIR dialect. The callback is expected to append -/// patterns to the owning list provided as the second argument. -using LLVMPatternListFiller = - std::function; - -/// Type for a callback constructing the type converter for the conversion to -/// the LLVMIR dialect. The callback is expected to return an instance of the -/// converter. -using LLVMTypeConverterMaker = - std::function(MLIRContext *)>; - /// Collect a set of patterns to convert memory-related operations from the /// Standard dialect to the LLVM dialect, excluding non-memory-related /// operations and FuncOp. @@ -76,36 +50,6 @@ std::unique_ptr> createLowerToLLVMPass(bool useAlloca = false); -/// Creates a pass to convert operations to the LLVMIR dialect. The conversion -/// is defined by a list of patterns and a type converter that will be obtained -/// during the pass using the provided callbacks. -/// By default stdlib malloc/free are used for allocating MemRef payloads. -/// Specifying `useAlloca-true` emits stack allocations instead. In the future -/// this may become an enum when we have concrete uses for other options. -std::unique_ptr> -createLowerToLLVMPass(LLVMPatternListFiller patternListFiller, - LLVMTypeConverterMaker typeConverterMaker, - bool useAlloca = false); - -/// Creates a pass to convert operations to the LLVMIR dialect. The conversion -/// is defined by a list of patterns obtained during the pass using the provided -/// callback and an optional type conversion class, an instance is created -/// during the pass. -/// By default stdlib malloc/free are used for allocating MemRef payloads. -/// Specifying `useAlloca-true` emits stack allocations instead. In the future -/// this may become an enum when we have concrete uses for other options. -template -std::unique_ptr> -createLowerToLLVMPass(LLVMPatternListFiller patternListFiller, - bool useAlloca = false) { - return createLowerToLLVMPass( - patternListFiller, - [](MLIRContext *context) { - return std::make_unique(context); - }, - useAlloca); -} - namespace LLVM { /// Make argument-taking successors of each block distinct. PHI nodes in LLVM /// IR use the predecessor ID to identify which value to take. They do not diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -2354,63 +2354,37 @@ return promotedOperands; } -/// Create an instance of LLVMTypeConverter in the given context. -static std::unique_ptr -makeStandardToLLVMTypeConverter(MLIRContext *context) { - LLVMTypeConverterCustomization customs; - customs.funcArgConverter = structFuncArgTypeConverter; - return std::make_unique(context, customs); -} - -/// Create an instance of BarePtrTypeConverter in the given context. -static std::unique_ptr -makeStandardToLLVMBarePtrTypeConverter(MLIRContext *context) { - LLVMTypeConverterCustomization customs; - customs.funcArgConverter = barePtrFuncArgTypeConverter; - return std::make_unique(context, customs); -} - namespace { /// A pass converting MLIR operations into the LLVM IR dialect. struct LLVMLoweringPass : public ModulePass { - // By default, the patterns are those converting Standard operations to the - // LLVMIR dialect. - explicit LLVMLoweringPass( - bool useAlloca = false, - LLVMPatternListFiller patternListFiller = - populateStdToLLVMConversionPatterns, - LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter) - : patternListFiller(patternListFiller), - typeConverterMaker(converterBuilder) {} - - // Run the dialect converter on the module. - void runOnModule() override { - if (!typeConverterMaker || !patternListFiller) - return signalPassFailure(); + /// Creates an LLVM lowering pass. + explicit LLVMLoweringPass(bool useAlloca = false, + bool useBarePtrCallConv = false) + : useBarePtrCallConv(useBarePtrCallConv) {} + /// Run the dialect converter on the module. + void runOnModule() override { ModuleOp m = getModule(); LLVM::ensureDistinctSuccessors(m); - std::unique_ptr typeConverter = - typeConverterMaker(&getContext()); - if (!typeConverter) - return signalPassFailure(); + + LLVMTypeConverterCustomization customs; + customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter + : structFuncArgTypeConverter; + LLVMTypeConverter typeConverter(&getContext(), customs); OwningRewritePatternList patterns; - patternListFiller(*typeConverter, patterns); + if (useBarePtrCallConv) + populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns); + else + populateStdToLLVMConversionPatterns(typeConverter, patterns); ConversionTarget target(getContext()); target.addLegalDialect(); - if (failed(applyPartialConversion(m, target, patterns, &*typeConverter))) + if (failed(applyPartialConversion(m, target, patterns, &typeConverter))) signalPassFailure(); } - // Callback for creating a list of patterns. It is called every time in - // runOnModule since applyPartialConversion consumes the list. - LLVMPatternListFiller patternListFiller; - - // Callback for creating an instance of type converter. The converter - // constructor needs an MLIRContext, which is not available until runOnModule. - LLVMTypeConverterMaker typeConverterMaker; + bool useBarePtrCallConv; }; } // end namespace @@ -2419,23 +2393,11 @@ return std::make_unique(useAlloca); } -std::unique_ptr> -mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller, - LLVMTypeConverterMaker typeConverterMaker, - bool useAlloca) { - return std::make_unique(useAlloca, patternListFiller, - typeConverterMaker); -} - static PassRegistration pass("convert-std-to-llvm", "Convert scalar and vector operations from the " "Standard to the LLVM dialect", [] { return std::make_unique( - clUseAlloca.getValue(), - clUseBarePtrCallConv ? populateStdToLLVMBarePtrConversionPatterns - : populateStdToLLVMConversionPatterns, - clUseBarePtrCallConv ? makeStandardToLLVMBarePtrTypeConverter - : makeStandardToLLVMTypeConverter); + clUseAlloca.getValue(), clUseBarePtrCallConv.getValue()); });