diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -249,7 +249,7 @@ // ControlFlowToLLVM //===----------------------------------------------------------------------===// -def ConvertControlFlowToLLVMPass : Pass<"convert-cf-to-llvm", "ModuleOp"> { +def ConvertControlFlowToLLVMPass : Pass<"convert-cf-to-llvm"> { let summary = "Convert ControlFlow operations to the LLVM dialect"; let description = [{ Convert ControlFlow operations into LLVM IR dialect operations. @@ -257,6 +257,16 @@ If other operations are present and their results are required by the LLVM IR dialect operations, the pass will fail. Any LLVM IR operations or types already present in the IR will be kept as is. + + This pass is not restricted to run on any op but certain conversions are + omitted when the pass is run on a non-builtin ModuleOp. The omitted + conversions are the ones that assume a ModuleOp is accessible, for + manipulating its symbol table (for the purpose of calling into external + libraries), namely: + 1. cf.assert lowering + + Such decoupling of partial lowerings is particularly useful in the context + of accelerators that may depend on custom module operations. }]; let dependentDialects = ["LLVM::LLVMDialect"]; let options = [ @@ -645,15 +655,25 @@ // MemRefToLLVM //===----------------------------------------------------------------------===// -def FinalizeMemRefToLLVMConversionPass : - Pass<"finalize-memref-to-llvm", "ModuleOp"> { +def FinalizeMemRefToLLVMConversionPass : Pass<"finalize-memref-to-llvm"> { let summary = "Finalize MemRef dialect to LLVM dialect conversion"; let description = [{ - Finalize the conversion of the operations from the MemRef - dialect to the LLVM dialect. - This conversion will not convert some complex MemRef - operations. Make sure to run `expand-strided-metadata` - beforehand for these. 
+ Finalize the conversion of the operations from the MemRef dialect to the + LLVM dialect. + This conversion will not convert some complex MemRef operations. Make sure + to run `expand-strided-metadata` beforehand for these. + + This pass is not restricted to run on any op but certain conversions are + omitted when the pass is run on a non-builtin ModuleOp. The omitted + conversions are the ones that assume a ModuleOp is accessible, for + manipulating its symbol table (for the purpose of calling into external + libraries), namely: + 1. memref.copy lowering + 2. memref.alloc lowering + 3. memref.dealloc lowering + + Such decoupling of partial lowerings is particularly useful in the context + of accelerators that may depend on custom module operations. }]; let dependentDialects = ["LLVM::LLVMDialect"]; let options = [ diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -240,17 +240,26 @@ } // namespace -void mlir::cf::populateControlFlowToLLVMConversionPatterns( +/// Adds the cf-to-LLVM patterns that do not need access to a ModuleOp. +static void populateModuleIndependentControlFlowToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< - AssertOpLowering, BranchOpLowering, CondBranchOpLowering, SwitchOpLowering>(converter); // clang-format on } +void mlir::cf::populateControlFlowToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + // clang-format off + patterns.add<AssertOpLowering>(converter); + // clang-format on + populateModuleIndependentControlFlowToLLVMConversionPatterns(converter, + patterns); +} + void mlir::cf::populateAssertToLLVMConversionPattern( LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure) { @@ -279,8 +288,13 @@ options.useOpaquePointers = useOpaquePointers; LLVMTypeConverter 
converter(&getContext(), options); - mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); - + if (isa<ModuleOp>(getOperation())) { + mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, + patterns); + } else { + populateModuleIndependentControlFlowToLLVMConversionPatterns(converter, + patterns); + } if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1865,7 +1865,7 @@ } // namespace -void mlir::populateFinalizeMemRefToLLVMConversionPatterns( +static void populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< @@ -1881,7 +1881,6 @@ GetGlobalMemrefOpLowering, LoadOpLowering, MemRefCastOpLowering, - MemRefCopyOpLowering, MemorySpaceCastOpLowering, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, @@ -1894,6 +1893,15 @@ TransposeOpLowering, ViewOpLowering>(converter); // clang-format on +} + +void mlir::populateFinalizeMemRefToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + // clang-format off + patterns.add< + MemRefCopyOpLowering>(converter); + // clang-format on + auto allocLowering = converter.getOptions().allocLowering; if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc) patterns.add<AlignedAllocOpLowering, AlignedReallocOpLowering, @@ -1901,6 +1909,9 @@ else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc) patterns.add<AllocOpLowering, ReallocOpLowering, DeallocOpLowering>( converter); + + populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns(converter, + patterns); } namespace { @@ -1929,6 +1940,12 @@ &dataLayoutAnalysis); RewritePatternSet patterns(&getContext()); - populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); + // Only a builtin ModuleOp gets the lowerings that touch its symbol table. + if (isa<ModuleOp>(getOperation())) { + populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); + } else { + 
populateModuleIndependentFinalizeMemRefToLLVMConversionPatterns( + typeConverter, patterns); + } LLVMConversionTarget target(getContext()); target.addLegalOp<func::FuncOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns))))