diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -87,6 +87,8 @@ /// Returns the LLVM dialect. LLVM::LLVMDialect *getDialect() { return llvmDialect; } + const LowerToLLVMOptions &getOptions() const { return options; } + /// Promote the LLVM struct representation of all MemRef descriptors to stack /// and use pointers to struct to avoid the complexity of the /// platform-specific C/C++ ABI lowering related to struct argument passing. @@ -390,8 +392,6 @@ public: ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, - const LowerToLLVMOptions &options = - LowerToLLVMOptions::getDefaultOptions(), PatternBenefit benefit = 1); /// Returns the LLVM dialect. @@ -443,9 +443,6 @@ protected: /// Reference to the type converter, with potential extensions. LLVMTypeConverter &typeConverter; - - /// Reference to the llvm lowering options. - const LowerToLLVMOptions &options; }; /// Utility class for operation conversions targeting the LLVM dialect that @@ -454,11 +451,10 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, - const LowerToLLVMOptions &options, PatternBenefit benefit = 1) : ConvertToLLVMPattern(OpTy::getOperationName(), &typeConverter.getContext(), typeConverter, - options, benefit) {} + benefit) {} }; namespace LLVM { diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -42,31 +42,26 @@ /// Standard dialect to the LLVM dialect, excluding non-memory-related /// operations and FuncOp. void populateStdToLLVMMemoryConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - const LowerToLLVMOptions &options); + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Collect a set of patterns to convert from the Standard dialect to the LLVM /// dialect, excluding the memory-related operations. void populateStdToLLVMNonMemoryConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - const LowerToLLVMOptions &options); + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Collect the default pattern to convert a FuncOp to the LLVM dialect. If /// `emitCWrappers` is set, the pattern will also produce functions /// that pass memref descriptors by pointer-to-structure in addition to the /// default unpacked form. void populateStdToLLVMFuncOpConversionPattern( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - const LowerToLLVMOptions &options); + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Collect the patterns to convert from the Standard dialect to LLVM. The /// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions /// by reference meaning the references have to remain alive during the entire /// pattern lifetime. -void populateStdToLLVMConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - const LowerToLLVMOptions &options = - LowerToLLVMOptions::getDefaultOptions()); +void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns); /// Creates a pass to convert the Standard dialect into the LLVMIR dialect. /// stdlib malloc/free is used by default for allocating memrefs allocated with diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -424,10 +424,9 @@ ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, - const LowerToLLVMOptions &options, PatternBenefit benefit) : ConversionPattern(rootOpName, benefit, typeConverter, context), - typeConverter(typeConverter), options(options) {} + typeConverter(typeConverter) {} /*============================================================================*/ /* StructBuilder implementation */ @@ -1124,10 +1123,8 @@ /// information. static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; struct FuncOpConversion : public FuncOpConversionBase { - FuncOpConversion(LLVMTypeConverter &converter, - const LowerToLLVMOptions &options) - : FuncOpConversionBase(converter, options) {} - using ConvertOpToLLVMPattern::options; + FuncOpConversion(LLVMTypeConverter &converter) + : FuncOpConversionBase(converter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -1138,7 +1135,7 @@ if (!newFuncOp) return failure(); - if (options.emitCWrappers || + if (typeConverter.getOptions().emitCWrappers || funcOp.getAttrOfType(kEmitIfaceAttrName)) { if (newFuncOp.isExternal()) wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp, @@ -1652,11 +1649,9 @@ using ConvertOpToLLVMPattern::getIndexType; using ConvertOpToLLVMPattern::typeConverter; using ConvertOpToLLVMPattern::getVoidPtrType; - using ConvertOpToLLVMPattern::options; - explicit AllocLikeOpLowering(LLVMTypeConverter &converter, - const LowerToLLVMOptions &options) - : ConvertOpToLLVMPattern(converter, options) {} + explicit AllocLikeOpLowering(LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter) {} LogicalResult match(Operation *op) const override { MemRefType memRefType = cast(op).getType(); @@ -1823,7 +1818,7 @@ /// allocation size to be a multiple of alignment, Optional getAllocationAlignment(AllocOp allocOp) const { // No alignment can be used for the 'malloc' call itself. - if (!options.useAlignedAlloc) + if (!typeConverter.getOptions().useAlignedAlloc) return None; if (allocOp.alignment()) @@ -2002,9 +1997,8 @@ }; struct AllocOpLowering : public AllocLikeOpLowering { - explicit AllocOpLowering(LLVMTypeConverter &converter, - const LowerToLLVMOptions &options) - : AllocLikeOpLowering(converter, options) {} + explicit AllocOpLowering(LLVMTypeConverter &converter) + : AllocLikeOpLowering(converter) {} }; using AllocaOpLowering = AllocLikeOpLowering; @@ -2174,9 +2168,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - explicit DeallocOpLowering(LLVMTypeConverter &converter, - const LowerToLLVMOptions &options) - : ConvertOpToLLVMPattern(converter, options) {} + explicit DeallocOpLowering(LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -3200,8 +3193,7 @@ /// Collect a set of patterns to convert from the Standard dialect to LLVM. void mlir::populateStdToLLVMNonMemoryConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - const LowerToLLVMOptions &options) { + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // FIXME: this should be tablegen'ed // clang-format off patterns.insert< @@ -3265,13 +3257,12 @@ UnsignedRemIOpLowering, UnsignedShiftRightOpLowering, XOrOpLowering, - ZeroExtendIOpLowering>(converter, options); + ZeroExtendIOpLowering>(converter); // clang-format on } void mlir::populateStdToLLVMMemoryConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - const LowerToLLVMOptions &options) { + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // clang-format off patterns.insert< AssumeAlignmentOpLowering, @@ -3282,25 +3273,23 @@ StoreOpLowering, SubViewOpLowering, ViewOpLowering, - AllocOpLowering>(converter, options); + AllocOpLowering>(converter); // clang-format on } void mlir::populateStdToLLVMFuncOpConversionPattern( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - const LowerToLLVMOptions &options) { - if (options.useBarePtrCallConv) - patterns.insert(converter, options); + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + if (converter.getOptions().useBarePtrCallConv) + patterns.insert(converter); else - patterns.insert(converter, options); + patterns.insert(converter); } void mlir::populateStdToLLVMConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - const LowerToLLVMOptions &options) { - populateStdToLLVMFuncOpConversionPattern(converter, patterns, options); - populateStdToLLVMNonMemoryConversionPatterns(converter, patterns, options); - populateStdToLLVMMemoryConversionPatterns(converter, patterns, options); + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + populateStdToLLVMFuncOpConversionPattern(converter, patterns); + populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); + populateStdToLLVMMemoryConversionPatterns(converter, patterns); } // Create an LLVM IR structure type if there is more than one result. @@ -3395,7 +3384,7 @@ LLVMTypeConverter typeConverter(&getContext(), options); OwningRewritePatternList patterns; - populateStdToLLVMConversionPatterns(typeConverter, patterns, options); + populateStdToLLVMConversionPatterns(typeConverter, patterns); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, patterns)))