diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -105,7 +105,7 @@ that only legal operations will remain after the conversion. ```c++ - mlir::ModuleOp module = getModule(); + mlir::ModuleOp module = getOperation(); if (mlir::failed(mlir::applyFullConversion(module, target, patterns, &typeConverter))) signalPassFailure(); diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -153,12 +153,13 @@ //===----------------------------------------------------------------------===// namespace { -struct ToyToLLVMLoweringPass : public ModulePass { - void runOnModule() final; +struct ToyToLLVMLoweringPass + : public OperationPass { + void runOnOperation() final; }; } // end anonymous namespace -void ToyToLLVMLoweringPass::runOnModule() { +void ToyToLLVMLoweringPass::runOnOperation() { // The first thing to define is the conversion target. This will define the // final target for this lowering. For this lowering, we are only targeting // the LLVM dialect. @@ -191,7 +192,7 @@ // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. - auto module = getModule(); + auto module = getOperation(); if (failed(applyFullConversion(module, target, patterns, &typeConverter))) signalPassFailure(); } diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -153,12 +153,13 @@ //===----------------------------------------------------------------------===// namespace { -struct ToyToLLVMLoweringPass : public ModulePass { - void runOnModule() final; +struct ToyToLLVMLoweringPass + : public OperationPass { + void runOnOperation() final; }; } // end anonymous namespace -void ToyToLLVMLoweringPass::runOnModule() { +void ToyToLLVMLoweringPass::runOnOperation() { // The first thing to define is the conversion target. This will define the // final target for this lowering. For this lowering, we are only targeting // the LLVM dialect. @@ -191,7 +192,7 @@ // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. - auto module = getModule(); + auto module = getOperation(); if (failed(applyFullConversion(module, target, patterns, &typeConverter))) signalPassFailure(); } diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -341,24 +341,9 @@ runOnFunction(); } - /// Return the current module being transformed. + /// Return the current function being transformed. FuncOp getFunction() { return this->getOperation(); } }; - -/// A model for providing module pass specific utilities. -/// -/// Derived module passes are expected to provide the following: -/// - A 'void runOnModule()' method. -template struct ModulePass : public OperationPass { - /// The polymorphic API that runs the pass over the currently held module. - virtual void runOnModule() = 0; - - /// The polymorphic API that runs the pass over the currently held operation. - void runOnOperation() final { runOnModule(); } - - /// Return the current module being transformed. - ModuleOp getModule() { return this->getOperation(); } -}; } // end namespace mlir #endif // MLIR_PASS_PASS_H diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -163,16 +163,17 @@ } namespace { -struct ConvertAVX512ToLLVMPass : public ModulePass { +struct ConvertAVX512ToLLVMPass + : public OperationPass { /// Include the generated pass utilities. #define GEN_PASS_ConvertAVX512ToLLVM #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void ConvertAVX512ToLLVMPass::runOnModule() { +void ConvertAVX512ToLLVMPass::runOnOperation() { // Convert to the LLVM IR dialect. OwningRewritePatternList patterns; LLVMTypeConverter converter(&getContext()); @@ -186,8 +187,8 @@ target.addIllegalDialect(); target.addDynamicallyLegalOp( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); - if (failed( - applyPartialConversion(getModule(), target, patterns, &converter))) { + if (failed(applyPartialConversion(getOperation(), target, patterns, + &converter))) { signalPassFailure(); } } diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp --- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp +++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp @@ -61,7 +61,7 @@ /// /// Intermediate data structures are allocated on the stack. class GpuLaunchFuncToCudaCallsPass - : public ModulePass { + : public OperationPass { private: /// Include the generated pass utilities. #define GEN_PASS_ConvertGpuLaunchFuncToCudaCalls @@ -126,20 +126,19 @@ public: // Run the dialect converter on the module. - void runOnModule() override { + void runOnOperation() override { // Cache the LLVMDialect for the current module. llvmDialect = getContext().getRegisteredDialect(); // Cache the used LLVM types. initializeCachedTypes(); - getModule().walk([this](mlir::gpu::LaunchFuncOp op) { - translateGpuLaunchCalls(op); - }); + getOperation().walk( + [this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); }); // GPU kernel modules are no longer necessary since we have a global // constant with the CUBIN data. for (auto m : - llvm::make_early_inc_range(getModule().getOps())) + llvm::make_early_inc_range(getOperation().getOps())) m.erase(); } @@ -160,7 +159,7 @@ // The types in comments give the actual types expected/returned but the API // uses void pointers. This is fine as they have the same linkage in C. void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) { - ModuleOp module = getModule(); + ModuleOp module = getOperation(); OpBuilder builder(module.getBody()->getTerminator()); if (!module.lookupSymbol(cuModuleLoadName)) { builder.create( @@ -391,7 +390,7 @@ builder.getI32IntegerAttr(0)); // Create an LLVM global with CUBIN extracted from the kernel annotation and // obtain a pointer to the first byte in it. - auto kernelModule = getModule().lookupSymbol( + auto kernelModule = getOperation().lookupSymbol( launchOp.getKernelModuleName()); assert(kernelModule && "expected a kernel module"); @@ -412,7 +411,7 @@ // in the called helper function. auto cuModule = allocatePointer(builder, loc); auto cuModuleLoad = - getModule().lookupSymbol(cuModuleLoadName); + getOperation().lookupSymbol(cuModuleLoadName); builder.create(loc, ArrayRef{getCUResultType()}, builder.getSymbolRefAttr(cuModuleLoad), ArrayRef{cuModule, data}); @@ -423,20 +422,20 @@ auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder); auto cuFunction = allocatePointer(builder, loc); auto cuModuleGetFunction = - getModule().lookupSymbol(cuModuleGetFunctionName); + getOperation().lookupSymbol(cuModuleGetFunctionName); builder.create( loc, ArrayRef{getCUResultType()}, builder.getSymbolRefAttr(cuModuleGetFunction), ArrayRef{cuFunction, cuOwningModuleRef, kernelName}); // Grab the global stream needed for execution. auto cuGetStreamHelper = - getModule().lookupSymbol(cuGetStreamHelperName); + getOperation().lookupSymbol(cuGetStreamHelperName); auto cuStream = builder.create( loc, ArrayRef{getPointerType()}, builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef{}); // Invoke the function with required arguments. auto cuLaunchKernel = - getModule().lookupSymbol(cuLaunchKernelName); + getOperation().lookupSymbol(cuLaunchKernelName); auto cuFunctionRef = builder.create(loc, getPointerType(), cuFunction); auto paramsArray = setupParamsArray(launchOp, builder); @@ -458,7 +457,7 @@ nullpointer /* extra */}); // Sync on the stream to make it synchronous. auto cuStreamSync = - getModule().lookupSymbol(cuStreamSynchronizeName); + getOperation().lookupSymbol(cuStreamSynchronizeName); builder.create(loc, ArrayRef{getCUResultType()}, builder.getSymbolRefAttr(cuStreamSync), ArrayRef(cuStream.getResult(0))); diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -33,18 +33,18 @@ /// replace it). /// /// 2) Lower the body of the spirv::ModuleOp. -struct GPUToSPIRVPass : public ModulePass { +struct GPUToSPIRVPass : public OperationPass { /// Include the generated pass utilities. #define GEN_PASS_ConvertGpuToSPIRV #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void GPUToSPIRVPass::runOnModule() { +void GPUToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); - ModuleOp module = getModule(); + ModuleOp module = getOperation(); SmallVector kernelModules; OpBuilder builder(context); diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -38,13 +38,13 @@ /// function and attaching binary data and entry point name as an attributes to /// created vulkan launch call op. class ConvertGpuLaunchFuncToVulkanLaunchFunc - : public ModulePass { + : public OperationPass { public: /// Include the generated pass utilities. #define GEN_PASS_ConvertGpuLaunchFuncToVulkanLaunchFunc #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; private: /// Creates a SPIR-V binary shader from the given `module` using @@ -68,14 +68,13 @@ /// operand is unsupported by Vulkan runtime. LogicalResult declareVulkanLaunchFunc(Location loc, gpu::LaunchFuncOp launchOp); - }; } // anonymous namespace -void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() { +void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnOperation() { bool done = false; - getModule().walk([this, &done](gpu::LaunchFuncOp op) { + getOperation().walk([this, &done](gpu::LaunchFuncOp op) { if (done) { op.emitError("should only contain one 'gpu::LaunchFuncOp' op"); return signalPassFailure(); @@ -86,17 +85,17 @@ // Erase `gpu::GPUModuleOp` and `spirv::Module` operations. for (auto gpuModule : - llvm::make_early_inc_range(getModule().getOps())) + llvm::make_early_inc_range(getOperation().getOps())) gpuModule.erase(); for (auto spirvModule : - llvm::make_early_inc_range(getModule().getOps())) + llvm::make_early_inc_range(getOperation().getOps())) spirvModule.erase(); } LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc( Location loc, gpu::LaunchFuncOp launchOp) { - OpBuilder builder(getModule().getBody()->getTerminator()); + OpBuilder builder(getOperation().getBody()->getTerminator()); // TODO: Workgroup size is written into the kernel. So to properly modelling // vulkan launch, we cannot have the local workgroup size configuration here. SmallVector vulkanLaunchTypes{launchOp.getOperandTypes()}; @@ -138,7 +137,7 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc( gpu::LaunchFuncOp launchOp) { - ModuleOp module = getModule(); + ModuleOp module = getOperation(); OpBuilder builder(launchOp); Location loc = launchOp.getLoc(); diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -57,7 +57,7 @@ /// * deinitVulkan -- deinitializes vulkan runtime /// class VulkanLaunchFuncToVulkanCallsPass - : public ModulePass { + : public OperationPass { private: /// Include the generated pass utilities. #define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls @@ -153,7 +153,7 @@ LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank); public: - void runOnModule() override; + void runOnOperation() override; private: LLVM::LLVMDialect *llvmDialect; @@ -171,18 +171,18 @@ } // anonymous namespace -void VulkanLaunchFuncToVulkanCallsPass::runOnModule() { +void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { initializeCachedTypes(); // Collect SPIR-V attributes such as `spirv_blob` and // `spirv_entry_point_name`. - getModule().walk([this](LLVM::CallOp op) { + getOperation().walk([this](LLVM::CallOp op) { if (isVulkanLaunchCallOp(op)) collectSPIRVAttributes(op); }); // Convert vulkan launch call op into a sequence of Vulkan runtime calls. - getModule().walk([this](LLVM::CallOp op) { + getOperation().walk([this](LLVM::CallOp op) { if (isCInterfaceVulkanLaunchCallOp(op)) translateVulkanLaunchCall(op); }); @@ -280,7 +280,7 @@ } void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { - ModuleOp module = getModule(); + ModuleOp module = getOperation(); OpBuilder builder(module.getBody()->getTerminator()); if (!module.lookupSymbol(kSetEntryPoint)) { diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -556,17 +556,18 @@ } namespace { -struct ConvertLinalgToLLVMPass : public ModulePass { +struct ConvertLinalgToLLVMPass + : public OperationPass { /// Include the generated pass utilities. #define GEN_PASS_ConvertLinalgToLLVM #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void ConvertLinalgToLLVMPass::runOnModule() { - auto module = getModule(); +void ConvertLinalgToLLVMPass::runOnOperation() { + auto module = getOperation(); // Convert to the LLVM IR dialect using the converter defined above. OwningRewritePatternList patterns; diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp @@ -16,18 +16,18 @@ namespace { /// A pass converting MLIR Linalg ops into SPIR-V ops. -class LinalgToSPIRVPass : public ModulePass { +class LinalgToSPIRVPass : public OperationPass { /// Include the generated pass utilities. #define GEN_PASS_ConvertLinalgToSPIRV #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void LinalgToSPIRVPass::runOnModule() { +void LinalgToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); - ModuleOp module = getModule(); + ModuleOp module = getOperation(); auto targetAttr = spirv::lookupTargetEnvOrDefault(module); std::unique_ptr target = diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2776,7 +2776,7 @@ namespace { /// A pass converting MLIR operations into the LLVM IR dialect. -struct LLVMLoweringPass : public ModulePass { +struct LLVMLoweringPass : public OperationPass { /// Include the generated pass utilities. #define GEN_PASS_ConvertStandardToLLVM #include "mlir/Conversion/Passes.h.inc" @@ -2793,16 +2793,16 @@ LLVMLoweringPass(const LLVMLoweringPass &pass) {} /// Run the dialect converter on the module. - void runOnModule() override { + void runOnOperation() override { if (useBarePtrCallConv && emitCWrappers) { - getModule().emitError() + getOperation().emitError() << "incompatible conversion options: bare-pointer calling convention " "and C wrapper emission"; signalPassFailure(); return; } - ModuleOp m = getModule(); + ModuleOp m = getOperation(); LLVMTypeConverterCustomization customs; customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -22,18 +22,18 @@ namespace { /// A pass converting MLIR Standard operations into the SPIR-V dialect. class ConvertStandardToSPIRVPass - : public ModulePass { + : public OperationPass { /// Include the generated pass utilities. #define GEN_PASS_ConvertStandardToSPIRV #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void ConvertStandardToSPIRVPass::runOnModule() { +void ConvertStandardToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); - ModuleOp module = getModule(); + ModuleOp module = getOperation(); auto targetAttr = spirv::lookupTargetEnvOrDefault(module); std::unique_ptr target = diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1118,23 +1118,24 @@ } namespace { -struct LowerVectorToLLVMPass : public ModulePass { +struct LowerVectorToLLVMPass + : public OperationPass { /// Include the generated pass utilities. #define GEN_PASS_ConvertVectorToLLVM #include "mlir/Conversion/Passes.h.inc" - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void LowerVectorToLLVMPass::runOnModule() { +void LowerVectorToLLVMPass::runOnOperation() { // Perform progressive lowering of operations on slices and // all contraction operations. Also applies folding and DCE. { OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); populateVectorContractLoweringPatterns(patterns, &getContext()); - applyPatternsGreedily(getModule(), patterns); + applyPatternsGreedily(getOperation(), patterns); } // Convert to the LLVM IR dialect. @@ -1148,8 +1149,8 @@ LLVMConversionTarget target(getContext()); target.addDynamicallyLegalOp( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); - if (failed( - applyPartialConversion(getModule(), target, patterns, &converter))) { + if (failed(applyPartialConversion(getOperation(), target, patterns, + &converter))) { signalPassFailure(); } } diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -214,16 +214,17 @@ /// The gpu.modules are intended to be compiled to a cubin blob independently in /// a separate pass. The external functions can then be annotated with the /// symbol of the cubin accessor function. -class GpuKernelOutliningPass : public ModulePass { +class GpuKernelOutliningPass + : public OperationPass { public: /// Include the generated pass utilities. #define GEN_PASS_GpuKernelOutlining #include "mlir/Dialect/GPU/Passes.h.inc" - void runOnModule() override { - SymbolTable symbolTable(getModule()); + void runOnOperation() override { + SymbolTable symbolTable(getOperation()); bool modified = false; - for (auto func : getModule().getOps()) { + for (auto func : getOperation().getOps()) { // Insert just after the function. Block::iterator insertPt(func.getOperation()->getNextNode()); auto funcWalkResult = func.walk([&](gpu::LaunchOp op) { @@ -255,8 +256,8 @@ // If any new module was inserted in this module, annotate this module as // a container module. if (modified) - getModule().setAttr(gpu::GPUDialect::getContainerModuleAttrName(), - UnitAttr::get(&getContext())); + getOperation().setAttr(gpu::GPUDialect::getContainerModuleAttrName(), + UnitAttr::get(&getContext())); } private: @@ -267,7 +268,7 @@ // a SymbolTable by the caller. SymbolTable needs to be refactored to // prevent manual building of Ops with symbols in code using SymbolTables // and then this needs to use the OpBuilder. - auto context = getModule().getContext(); + auto context = getOperation().getContext(); Builder builder(context); OperationState state(kernelFunc.getLoc(), gpu::GPUModuleOp::getOperationName()); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -80,14 +80,14 @@ namespace { class DecorateSPIRVCompositeTypeLayoutPass - : public ModulePass { + : public OperationPass { private: - void runOnModule() override; + void runOnOperation() override; }; } // namespace -void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() { - auto module = getModule(); +void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() { + auto module = getOperation(); OwningRewritePatternList patterns; populateSPIRVLayoutInfoPatterns(patterns, module.getContext()); ConversionTarget target(*(module.getContext())); diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp --- a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp +++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp @@ -71,7 +71,8 @@ } // end namespace llvm namespace { -class InferQuantizedTypesPass : public ModulePass { +class InferQuantizedTypesPass + : public OperationPass { public: /// Include the generated pass utilities. #define GEN_PASS_QuantizerInferQuantizedTypes @@ -82,7 +83,7 @@ const TargetConfiguration &config) : explicitSolverContext(&solverContext), explicitConfig(&config) {} - void runOnModule() override; + void runOnOperation() override; void runWithConfig(SolverContext &solverContext, const TargetConfiguration &config); @@ -108,7 +109,7 @@ return success(); } -void InferQuantizedTypesPass::runOnModule() { +void InferQuantizedTypesPass::runOnOperation() { if (explicitSolverContext && explicitConfig) { // If explicitly constructed with a config and context. runWithConfig(*explicitSolverContext, *explicitConfig); @@ -116,7 +117,7 @@ } // For global pass registration, use defaults. - SolverContext solverContext(*getModule().getContext()); + SolverContext solverContext(*getOperation().getContext()); auto config = FxpMathTargetConfig::create(solverContext); runWithConfig(solverContext, *config); } @@ -124,7 +125,7 @@ void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext, const TargetConfiguration &config) { CAGSlice cag(solverContext); - for (auto f : getModule().getOps()) { + for (auto f : getOperation().getOps()) { f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); }); } config.finalizeAnchors(cag); diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp --- a/mlir/lib/Transforms/OpStats.cpp +++ b/mlir/lib/Transforms/OpStats.cpp @@ -18,7 +18,7 @@ using namespace mlir; namespace { -struct PrintOpStatsPass : public ModulePass { +struct PrintOpStatsPass : public OperationPass { /// Include the generated pass utilities. #define GEN_PASS_PrintOpStats #include "mlir/Transforms/Passes.h.inc" @@ -26,7 +26,7 @@ explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {} // Prints the resultant operation statistics post iterating over the module. - void runOnModule() override; + void runOnOperation() override; // Print summary of op stats. void printSummary(); @@ -37,11 +37,11 @@ }; } // namespace -void PrintOpStatsPass::runOnModule() { +void PrintOpStatsPass::runOnOperation() { opCount.clear(); // Compute the operation statistics for each function in the module. - for (auto &op : getModule()) + for (auto &op : getOperation()) op.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; }); printSummary(); } diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -100,7 +100,7 @@ // PrintOpPass is simple pass to write graph per function. // Note: this is a module pass only to avoid interleaving on the same ostream // due to multi-threading over functions. -struct PrintOpPass : public ModulePass { +struct PrintOpPass : public OperationPass { /// Include the generated pass utilities. #define GEN_PASS_PrintOpGraph #include "mlir/Transforms/Passes.h.inc" @@ -140,7 +140,7 @@ } } - void runOnModule() override { processModule(getModule()); } + void runOnOperation() override { processModule(getOperation()); } private: raw_ostream &os; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -363,13 +363,13 @@ }; struct TestLegalizePatternDriver - : public ModulePass { + : public OperationPass { /// The mode of conversion to use with the driver. enum class ConversionMode { Analysis, Full, Partial }; TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} - void runOnModule() override { + void runOnOperation() override { TestTypeConverter converter; mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); @@ -414,7 +414,8 @@ // Handle a partial conversion. if (mode == ConversionMode::Partial) { - (void)applyPartialConversion(getModule(), target, patterns, &converter); + (void)applyPartialConversion(getOperation(), target, patterns, + &converter); return; } @@ -425,7 +426,7 @@ return (bool)op->getAttrOfType("test.dynamically_legal"); }); - (void)applyFullConversion(getModule(), target, patterns, &converter); + (void)applyFullConversion(getOperation(), target, patterns, &converter); return; } @@ -434,7 +435,7 @@ // Analyze the convertible operations. DenseSet legalizedOps; - if (failed(applyAnalysisConversion(getModule(), target, patterns, + if (failed(applyAnalysisConversion(getOperation(), target, patterns, legalizedOps, &converter))) return signalPassFailure(); diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp --- a/mlir/test/lib/IR/TestFunc.cpp +++ b/mlir/test/lib/IR/TestFunc.cpp @@ -13,9 +13,9 @@ namespace { /// This is a test pass for verifying FuncOp's eraseArgument method. -struct TestFuncEraseArg : public ModulePass { - void runOnModule() override { - auto module = getModule(); +struct TestFuncEraseArg : public OperationPass { + void runOnOperation() override { + auto module = getOperation(); for (FuncOp func : module.getOps()) { SmallVector indicesToErase; @@ -36,9 +36,9 @@ }; /// This is a test pass for verifying FuncOp's setType method. -struct TestFuncSetType : public ModulePass { - void runOnModule() override { - auto module = getModule(); +struct TestFuncSetType : public OperationPass { + void runOnOperation() override { + auto module = getOperation(); SymbolTable symbolTable(module); for (FuncOp func : module.getOps()) { diff --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp --- a/mlir/test/lib/IR/TestSideEffects.cpp +++ b/mlir/test/lib/IR/TestSideEffects.cpp @@ -12,9 +12,9 @@ using namespace mlir; namespace { -struct SideEffectsPass : public ModulePass { - void runOnModule() override { - auto module = getModule(); +struct SideEffectsPass : public OperationPass { + void runOnOperation() override { + auto module = getOperation(); // Walk operations detecting side effects. SmallVector effects; diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp --- a/mlir/test/lib/IR/TestSymbolUses.cpp +++ b/mlir/test/lib/IR/TestSymbolUses.cpp @@ -15,7 +15,7 @@ namespace { /// This is a symbol test pass that tests the symbol uselist functionality /// provided by the symbol table along with erasing from the symbol table. -struct SymbolUsesPass : public ModulePass { +struct SymbolUsesPass : public OperationPass { WalkResult operateOnSymbol(Operation *symbol, ModuleOp module, SmallVectorImpl &deadFunctions) { // Test computing uses on a non symboltable op. @@ -59,8 +59,8 @@ return WalkResult::advance(); } - void runOnModule() override { - auto module = getModule(); + void runOnOperation() override { + auto module = getOperation(); // Walk nested symbols. SmallVector deadFunctions; @@ -86,9 +86,10 @@ /// This is a symbol test pass that tests the symbol use replacement /// functionality provided by the symbol table. -struct SymbolReplacementPass : public ModulePass { - void runOnModule() override { - auto module = getModule(); +struct SymbolReplacementPass + : public OperationPass { + void runOnOperation() override { + auto module = getOperation(); // Walk nested functions and modules. module.getBodyRegion().walk([&](Operation *nestedOp) { diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -13,8 +13,8 @@ using namespace mlir; namespace { -struct TestModulePass : public ModulePass { - void runOnModule() final {} +struct TestModulePass : public OperationPass { + void runOnOperation() final {} }; struct TestFunctionPass : public FunctionPass { void runOnFunction() final {} diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp --- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp +++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp @@ -18,11 +18,11 @@ namespace { struct TestAllReduceLoweringPass - : public ModulePass { - void runOnModule() override { + : public OperationPass { + void runOnOperation() override { OwningRewritePatternList patterns; populateGpuRewritePatterns(&getContext(), patterns); - applyPatternsGreedily(getModule(), patterns); + applyPatternsGreedily(getOperation(), patterns); } }; } // namespace diff --git a/mlir/test/lib/Transforms/TestCallGraph.cpp b/mlir/test/lib/Transforms/TestCallGraph.cpp --- a/mlir/test/lib/Transforms/TestCallGraph.cpp +++ b/mlir/test/lib/Transforms/TestCallGraph.cpp @@ -17,9 +17,9 @@ using namespace mlir; namespace { -struct TestCallGraphPass : public ModulePass { - void runOnModule() { - llvm::errs() << "Testing : " << getModule().getAttr("test.name") << "\n"; +struct TestCallGraphPass : public OperationPass { + void runOnOperation() override { + llvm::errs() << "Testing : " << getOperation().getAttr("test.name") << "\n"; getAnalysis().print(llvm::errs()); } }; diff --git a/mlir/test/lib/Transforms/TestOpaqueLoc.cpp b/mlir/test/lib/Transforms/TestOpaqueLoc.cpp --- a/mlir/test/lib/Transforms/TestOpaqueLoc.cpp +++ b/mlir/test/lib/Transforms/TestOpaqueLoc.cpp @@ -17,7 +17,7 @@ /// It also takes all operations that are not function operations or /// terminators and clones them with opaque locations which store the initial /// locations. -struct TestOpaqueLoc : public ModulePass { +struct TestOpaqueLoc : public OperationPass { /// A simple structure which is used for testing as an underlying location in /// OpaqueLoc. @@ -29,11 +29,11 @@ int id; }; - void runOnModule() override { + void runOnOperation() override { std::vector> myLocs; int last_it = 0; - getModule().walk([&](Operation *op) { + getOperation().walk([&](Operation *op) { myLocs.push_back(std::make_unique(last_it++)); Location loc = op->getLoc(); @@ -74,7 +74,7 @@ os.flush(); }); - getModule().walk([&](Operation *op) { op->emitOpError(); }); + getOperation().walk([&](Operation *op) { op->emitOpError(); }); } };