diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -37,8 +37,19 @@ bool useBarePtrCallConv = false; bool emitCWrappers = false; - /// Use aligned_alloc for heap allocations. - bool useAlignedAlloc = false; + enum class AllocLowering { + /// Use malloc for heap allocations. + Malloc, + + /// Use aligned_alloc for heap allocations. + AlignedAlloc, + + /// Do not lower heap allocations. Users must provide their own patterns for + /// AllocOp and DeallocOp lowering. + None + }; + + AllocLowering allocLowering = AllocLowering::Malloc; /// The data layout of the module to produce. This must be consistent with the /// data layout used in the upper levels of the lowering pipeline. diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -3922,7 +3922,6 @@ // clang-format off patterns.add< AssumeAlignmentOpLowering, - DeallocOpLowering, DimOpLowering, GlobalMemrefOpLowering, GetGlobalMemrefOpLowering, @@ -3936,10 +3935,11 @@ TransposeOpLowering, ViewOpLowering>(converter); // clang-format on - if (converter.getOptions().useAlignedAlloc) - patterns.add(converter); - else - patterns.add(converter); + auto allocLowering = converter.getOptions().allocLowering; + if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc) + patterns.add(converter); + else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc) + patterns.add(converter); } void mlir::populateStdToLLVMFuncOpConversionPattern( @@ -4071,7 +4071,9 @@ options.emitCWrappers = emitCWrappers; if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 
options.overrideIndexBitwidth(indexBitwidth); - options.useAlignedAlloc = useAlignedAlloc; + options.allocLowering = + (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc + : LowerToLLVMOptions::AllocLowering::Malloc); options.dataLayout = llvm::DataLayout(this->dataLayout); LLVMTypeConverter typeConverter(&getContext(), options); @@ -4139,9 +4141,15 @@ std::unique_ptr> mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) { + auto allocLowering = options.allocLowering; + // There is no way to provide additional patterns for the pass, so + // AllocLowering::None will always fail. + assert(allocLowering != LowerToLLVMOptions::AllocLowering::None); + bool useAlignedAlloc = + (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc); return std::make_unique( options.useBarePtrCallConv, options.emitCWrappers, - options.getIndexBitwidth(), options.useAlignedAlloc, options.dataLayout); + options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout); } mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx)