diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -92,8 +92,8 @@ (de)serialization. * Ops with `mlir.snake_case` names are those that have no corresponding instructions (or concepts) in the binary format. They are introduced to - satisfy MLIR structural requirements. For example, `spv.mlir.endmodule` and - `spv.mlir.merge`. They map to no instructions during (de)serialization. + satisfy MLIR structural requirements. For example, `spv.mlir.merge`. They + map to no instructions during (de)serialization. (TODO: consider merging the last two cases and adopting `spv.mlir.` prefix for them.) diff --git a/mlir/docs/SPIRVToLLVMDialectConversion.md b/mlir/docs/SPIRVToLLVMDialectConversion.md --- a/mlir/docs/SPIRVToLLVMDialectConversion.md +++ b/mlir/docs/SPIRVToLLVMDialectConversion.md @@ -810,8 +810,6 @@ `spv.module` is converted into `ModuleOp`. This plays a role of enclosing scope to LLVM ops. At the moment, SPIR-V module attributes are ignored. -`spv.mlir.endmodule` is mapped to an equivalent terminator `ModuleTerminatorOp`. - ## `mlir-spirv-cpu-runner` `mlir-spirv-cpu-runner` allows to execute `gpu` dialect kernel on the CPU via diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -407,9 +407,8 @@ // ----- def SPV_ModuleOp : SPV_Op<"module", - [IsolatedFromAbove, - SingleBlockImplicitTerminator<"ModuleEndOp">, - SymbolTable, Symbol]> { + [IsolatedFromAbove, NoRegionArguments, NoTerminator, + SingleBlock, SymbolTable, Symbol]> { let summary = "The top-level op that defines a SPIR-V module"; let description = [{ @@ -426,7 +425,7 @@ implicitly capture values from the enclosing environment. This op has only one region, which only contains one block. 
The block - must be terminated via the `spv.mlir.endmodule` op. + has no terminator. @@ -463,7 +462,7 @@ let results = (outs); - let regions = (region SizedRegion<1>:$body); + let regions = (region AnyRegion); let builders = [ OpBuilder<(ins CArg<"Optional", "llvm::None">:$name)>, @@ -487,40 +486,11 @@ Optional getName() { return sym_name(); } static StringRef getVCETripleAttrName() { return "vce_triple"; } - - Block& getBlock() { - return this->getOperation()->getRegion(0).front(); - } }]; } // ----- -def SPV_ModuleEndOp : SPV_Op<"mlir.endmodule", [InModuleScope, Terminator]> { - let summary = "The pseudo op that ends a SPIR-V module"; - - let description = [{ - This op terminates the only block inside a `spv.module`'s only region. - This op does not have a corresponding SPIR-V instruction and thus will - not be serialized into the binary format; it is used solely to satisfy - the structual requirement that an block must be ended with a terminator. - }]; - - let arguments = (ins); - - let results = (outs); - - let assemblyFormat = "attr-dict"; - - let verifier = [{ return success(); }]; - - let hasOpcode = 0; - - let autogenSerialization = 0; -} - -// ----- - def SPV_ReferenceOfOp : SPV_Op<"mlir.referenceof", [NoSideEffect]> { let summary = "Reference a specialization constant."; diff --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt @@ -1,14 +1,9 @@ -set(LLVM_TARGET_DEFINITIONS GPUToSPIRV.td) -mlir_tablegen(GPUToSPIRV.cpp.inc -gen-rewriters) -add_public_tablegen_target(MLIRGPUToSPIRVIncGen) - add_mlir_conversion_library(MLIRGPUToSPIRV GPUToSPIRV.cpp GPUToSPIRVPass.cpp DEPENDS MLIRConversionPassIncGen - MLIRGPUToSPIRVIncGen LINK_LIBS PUBLIC MLIRGPU diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ 
b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -85,6 +85,19 @@ ConversionPatternRewriter &rewriter) const override; }; +class GPUModuleEndConversion final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(gpu::ModuleEndOp endOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(endOp); + return success(); + } +}; + /// Pattern to convert a gpu.return into a SPIR-V return. // TODO: This can go to DRR when GPU return has operands. class GPUReturnOpConversion final : public OpConversionPattern { @@ -301,12 +314,10 @@ StringRef(spvModuleName)); // Move the region from the module op into the SPIR-V module. - Region &spvModuleRegion = spvModule.body(); + Region &spvModuleRegion = spvModule.getRegion(); rewriter.inlineRegionBefore(moduleOp.body(), spvModuleRegion, spvModuleRegion.begin()); - // The spv.module build method adds a block with a terminator. Remove that - // block. The terminator of the module op in the remaining block will be - // legalized later. + // Erase the empty block added by the spv.module builder (now last). rewriter.eraseBlock(&spvModuleRegion.back()); rewriter.eraseOp(moduleOp); return success(); @@ -330,15 +341,11 @@ // GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===// -namespace { -#include "GPUToSPIRV.cpp.inc" -} - void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - populateWithGenerated(patterns); patterns.add< - GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion, + GPUFuncOpConversion, GPUModuleConversion, GPUModuleEndConversion, + GPUReturnOpConversion, LaunchConfigConversion, LaunchConfigConversion, LaunchConfigConversion; - -#endif // MLIR_CONVERSION_GPU_TO_SPIRV diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1342,7 +1342,7 @@ auto newModuleOp = rewriter.create(spvModuleOp.getLoc(), spvModuleOp.getName()); - rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody()); + rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody()); // Remove the terminator block that was automatically added by builder rewriter.eraseBlock(&newModuleOp.getBodyRegion().back()); @@ -1351,20 +1351,6 @@ } }; -class ModuleEndConversionPattern - : public SPIRVToLLVMConversion { -public: - using SPIRVToLLVMConversion::SPIRVToLLVMConversion; - - LogicalResult - matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - - rewriter.eraseOp(moduleEndOp); - return success(); - } -}; - } // namespace //===----------------------------------------------------------------------===// @@ -1507,8 +1493,7 @@ void mlir::populateSPIRVToLLVMModuleConversionPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add( - patterns.getContext(), typeConverter); + patterns.add(patterns.getContext(), typeConverter); } //===----------------------------------------------------------------------===// diff --git 
a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2529,7 +2529,8 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, Optional name) { - ensureTerminator(*state.addRegion(), builder, state.location); + OpBuilder::InsertionGuard guard(builder); + builder.createBlock(state.addRegion()); if (name) { state.attributes.append(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)); @@ -2545,7 +2546,8 @@ builder.getI32IntegerAttr(static_cast(addressingModel))); state.addAttribute("memory_model", builder.getI32IntegerAttr( static_cast(memoryModel))); - ensureTerminator(*state.addRegion(), builder, state.location); + OpBuilder::InsertionGuard guard(builder); + builder.createBlock(state.addRegion()); if (name) { state.attributes.append(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)); @@ -2581,7 +2583,10 @@ if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) return failure(); - spirv::ModuleOp::ensureTerminator(*body, parser.getBuilder(), state.location); + // Make sure we have at least one block. 
+ if (body->empty()) + body->push_back(new Block()); + return success(); } @@ -2608,8 +2613,7 @@ } printer.printOptionalAttrDictWithKeyword(moduleOp->getAttrs(), elidedAttrs); - printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); + printer.printRegion(moduleOp.getRegion()); } static LogicalResult verify(spirv::ModuleOp moduleOp) { @@ -2619,7 +2623,7 @@ entryPoints; SymbolTable table(moduleOp); - for (auto &op : moduleOp.getBlock()) { + for (auto &op : *moduleOp.getBody()) { if (op.getDialect() != dialect) return op.emitError("'spv.module' can only contain spv.* ops"); diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp --- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp +++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp @@ -134,7 +134,7 @@ auto combinedModule = combinedModuleBuilder.create( modules[0].getLoc(), addressingModel, memoryModel); - combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody()); + combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody()); // In some cases, a symbol in the (current state of the) combined module is // renamed in order to maintain the conflicting symbol in the input module @@ -160,7 +160,7 @@ // for spv.funcs. This way, if the conflicting op in the input module is // non-spv.func, we rename that symbol instead and maintain the spv.func in // the combined module name as it is. - for (auto &op : combinedModule.getBlock().without_terminator()) { + for (auto &op : *combinedModule.getBody()) { if (auto symbolOp = dyn_cast(op)) { StringRef oldSymName = symbolOp.getName(); @@ -195,7 +195,7 @@ // In the current input module, rename all symbols that conflict with // symbols from the combined module. This includes renaming spv.funcs. 
- for (auto &op : moduleClone.getBlock().without_terminator()) { + for (auto &op : *moduleClone.getBody()) { if (auto symbolOp = dyn_cast(op)) { StringRef oldSymName = symbolOp.getName(); @@ -225,7 +225,7 @@ } // Clone all the module's ops to the combined module. - for (auto &op : moduleClone.getBlock().without_terminator()) + for (auto &op : *moduleClone.getBody()) combinedModuleBuilder.insert(op.clone()); } @@ -233,7 +233,7 @@ DenseMap hashToSymbolOp; SmallVector eraseList; - for (auto &op : combinedModule.getBlock().without_terminator()) { + for (auto &op : *combinedModule.getBody()) { llvm::hash_code hashCode(0); SymbolOpInterface symbolOp = dyn_cast(op); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -115,7 +115,7 @@ OpBuilder::InsertionGuard moduleInsertionGuard(builder); auto spirvModule = funcOp->getParentOfType(); - builder.setInsertionPoint(spirvModule.body().front().getTerminator()); + builder.setInsertionPointToEnd(spirvModule.getBody()); // Adds the spv.EntryPointOp after collecting all the interface variables // needed. 
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -51,7 +51,7 @@ spirv::Deserializer::Deserializer(ArrayRef binary, MLIRContext *context) : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)), - module(createModuleOp()), opBuilder(module->body()) {} + module(createModuleOp()), opBuilder(module->getRegion()) {} LogicalResult spirv::Deserializer::deserialize() { LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n"); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -99,7 +99,7 @@ // Iterate over the module body to serialize it. Assumptions are that there is // only one basic block in the moduleOp - for (auto &op : module.getBlock()) { + for (auto &op : *module.getBody()) { if (failed(processOperation(&op))) { return failure(); } @@ -1090,7 +1090,6 @@ return processGlobalVariableOp(op); }) .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) - .Case([&](spirv::ModuleEndOp) { return success(); }) .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) diff --git a/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir @@ -13,12 +13,6 @@ // CHECK: module spv.module Logical GLSL450 requires #spv.vce {} -// CHECK: module -spv.module Logical GLSL450 { - // CHECK: } - spv.mlir.endmodule -} - // CHECK: module spv.module 
Logical GLSL450 { // CHECK-LABEL: llvm.func @empty() diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -425,12 +425,6 @@ requires #spv.vce attributes {foo = "bar"} { } -// Module with explicit spv.mlir.endmodule -// CHECK: spv.module -spv.module Logical GLSL450 { - spv.mlir.endmodule -} - // Module with function // CHECK: spv.module spv.module Logical GLSL450 { @@ -476,15 +470,6 @@ // ----- -// Module with wrong terminator -// expected-error@+2 {{expects regions to end with 'spv.mlir.endmodule'}} -// expected-note@+1 {{in custom textual format, the absence of terminator implies 'spv.mlir.endmodule'}} -"spv.module"() ({ - %0 = spv.Constant true -}) {addressing_model = 0 : i32, memory_model = 1 : i32} : () -> () - -// ----- - // Use non SPIR-V op inside module spv.module Logical GLSL450 { // expected-error @+1 {{'spv.module' can only contain spv.* ops}} @@ -511,17 +496,6 @@ // ----- -//===----------------------------------------------------------------------===// -// spv.mlir.endmodule -//===----------------------------------------------------------------------===// - -func @module_end_not_in_module() -> () { - // expected-error @+1 {{op must appear in a module-like op's block}} - spv.mlir.endmodule -} - -// ----- - //===----------------------------------------------------------------------===// // spv.mlir.referenceof //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -59,7 +59,7 @@ } Type getFloatStructType() { - OpBuilder opBuilder(module->body()); + OpBuilder opBuilder(module->getRegion()); llvm::SmallVector 
elementTypes{opBuilder.getF32Type()}; llvm::SmallVector offsetInfo{0}; auto structType = spirv::StructType::get(elementTypes, offsetInfo); @@ -67,7 +67,7 @@ } void addGlobalVar(Type type, llvm::StringRef name) { - OpBuilder opBuilder(module->body()); + OpBuilder opBuilder(module->getRegion()); auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); opBuilder.create( UnknownLoc::get(&context), TypeAttr::get(ptrType),