diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -63,7 +63,7 @@ ```c++ mlir::ConversionTarget target(getContext()); target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); ``` ### Type Converter diff --git a/mlir/docs/Tutorials/UnderstandingTheIRStructure.md b/mlir/docs/Tutorials/UnderstandingTheIRStructure.md --- a/mlir/docs/Tutorials/UnderstandingTheIRStructure.md +++ b/mlir/docs/Tutorials/UnderstandingTheIRStructure.md @@ -110,7 +110,6 @@ "dialect.innerop6"() : () -> () "dialect.innerop7"() : () -> () }) {"other attribute" = 42 : i64} : () -> () - "module_terminator"() : () -> () }) : () -> () ``` @@ -147,7 +146,5 @@ 0 nested regions: visiting op: 'dialect.innerop7' with 0 operands and 0 results 0 nested regions: - visiting op: 'module_terminator' with 0 operands and 0 results - 0 nested regions: ``` diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -174,7 +174,7 @@ // final target for this lowering. For this lowering, we are only targeting // the LLVM dialect. LLVMConversionTarget target(getContext()); - target.addLegalOp(); + target.addLegalOp(); // During this lowering, we will also be lowering the MemRef types, that are // currently being operated on, to a representation in LLVM. To perform this diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -174,7 +174,7 @@ // final target for this lowering. For this lowering, we are only targeting // the LLVM dialect.
LLVMConversionTarget target(getContext()); - target.addLegalOp(); + target.addLegalOp(); // During this lowering, we will also be lowering the MemRef types, that are // currently being operated on, to a representation in LLVM. To perform this diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -226,6 +226,14 @@ /// Returns true if this blocks has no successors. bool hasNoSuccessors() { return succ_begin() == succ_end(); } + /// Returns true if this block may be valid without terminator. That is if: + /// - it does not have a parent region. + /// - Or the parent region have a single block and: + /// - This region does not have a parent op. + /// - Or the parent op is unregistered. + /// - Or the parent op has the NoTerminator trait. + bool mayNotHaveTerminator(); + /// If this block has exactly one predecessor, return it. Otherwise, return /// null. /// diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -15,6 +15,7 @@ #include "mlir/IR/FunctionSupport.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/RegionKindInterface.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -15,6 +15,7 @@ #define BUILTIN_OPS include "mlir/IR/BuiltinDialect.td" +include "mlir/IR/RegionKindInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" @@ -158,9 +159,8 @@ //===----------------------------------------------------------------------===// def ModuleOp : Builtin_Op<"module", [ - AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol, - 
SingleBlockImplicitTerminator<"ModuleTerminatorOp"> -]> { + AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol] + # NoTerminator.traits> { let summary = "A top level container operation"; let description = [{ A `module` represents a top-level container operation. It contains a single @@ -206,22 +206,6 @@ let skipDefaultBuilders = 1; } -//===----------------------------------------------------------------------===// -// ModuleTerminatorOp -//===----------------------------------------------------------------------===// - -def ModuleTerminatorOp : Builtin_Op<"module_terminator", [ - Terminator, HasParent<"ModuleOp"> -]> { - let summary = "A pseudo op that marks the end of a module"; - let description = [{ - `module_terminator` is a special terminator operation for the body of a - `module`, it has no semantic meaning beyond keeping the body of a `module` - well-formed. - }]; - let assemblyFormat = "attr-dict"; -} - //===----------------------------------------------------------------------===// // UnrealizedConversionCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1808,6 +1808,9 @@ ]; } +// Op's regions have a single block. +def SingleBlock : NativeOpTrait<"SingleBlock">; + // Op's regions have a single block with the specified terminator. class SingleBlockImplicitTerminator : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -675,6 +675,11 @@ //===----------------------------------------------------------------------===// // Terminator Traits +/// This class indicates that the regions associated with this op don't have +/// terminators. 
+template +class NoTerminator : public TraitBase {}; + /// This class provides the API for ops that are known to be terminators. template class IsTerminator : public TraitBase { @@ -778,6 +783,88 @@ : public detail::MultiSuccessorTraitBase { }; +//===----------------------------------------------------------------------===// +// SingleBlock + +/// This class provides APIs and verifiers for ops with regions having a single +/// block. +template +struct SingleBlock : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { + Region ®ion = op->getRegion(i); + + // Empty regions are fine. + if (region.empty()) + continue; + + // Non-empty regions must contain a single basic block. + if (!llvm::hasSingleElement(region)) + return op->emitOpError("expects region #") + << i << " to have 0 or 1 blocks"; + + if (!ConcreteType::template hasTrait()) { + Block &block = region.front(); + if (block.empty()) + return op->emitOpError() << "expects a non-empty block"; + } + } + return success(); + } + + Block *getBody(unsigned idx = 0) { + Region ®ion = this->getOperation()->getRegion(idx); + assert(!region.empty() && "unexpected empty region"); + return ®ion.front(); + } + Region &getBodyRegion(unsigned idx = 0) { + return this->getOperation()->getRegion(idx); + } + + //===------------------------------------------------------------------===// + // Single Region Utilities + //===------------------------------------------------------------------===// + + /// The following are a set of methods only enabled when the parent + /// operation has a single region. Each of these methods take an additional + /// template parameter that represents the concrete operation so that we + /// can use SFINAE to disable the methods for non-single region operations. 
+ template + using enable_if_single_region = + typename std::enable_if_t(), T>; + + template + enable_if_single_region begin() { + return getBody()->begin(); + } + template + enable_if_single_region end() { + return getBody()->end(); + } + template + enable_if_single_region front() { + return *begin(); + } + + /// Insert the operation into the back of the body. + template + enable_if_single_region push_back(Operation *op) { + insert(Block::iterator(getBody()->end()), op); + } + + /// Insert the operation at the given insertion point. + template + enable_if_single_region insert(Operation *insertPt, Operation *op) { + insert(Block::iterator(insertPt), op); + } + template + enable_if_single_region insert(Block::iterator insertPt, Operation *op) { + auto *body = getBody(); + body->getOperations().insert(insertPt, op); + } +}; + //===----------------------------------------------------------------------===// // SingleBlockImplicitTerminator @@ -786,8 +873,9 @@ template struct SingleBlockImplicitTerminator { template - class Impl : public TraitBase { + class Impl : public SingleBlock { private: + using Base = SingleBlock; /// Builds a terminator operation without relying on OpBuilder APIs to avoid /// cyclic header inclusion. static Operation *buildTerminator(OpBuilder &builder, Location loc) { @@ -801,22 +889,14 @@ using ImplicitTerminatorOpT = TerminatorOpType; static LogicalResult verifyTrait(Operation *op) { + if (failed(Base::verifyTrait(op))) + return failure(); for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { Region ®ion = op->getRegion(i); - // Empty regions are fine. if (region.empty()) continue; - - // Non-empty regions must contain a single basic block. 
- if (std::next(region.begin()) != region.end()) - return op->emitOpError("expects region #") - << i << " to have 0 or 1 blocks"; - - Block &block = region.front(); - if (block.empty()) - return op->emitOpError() << "expects a non-empty block"; - Operation &terminator = block.back(); + Operation &terminator = region.front().back(); if (isa(terminator)) continue; @@ -849,40 +929,15 @@ buildTerminator); } - Block *getBody(unsigned idx = 0) { - Region ®ion = this->getOperation()->getRegion(idx); - assert(!region.empty() && "unexpected empty region"); - return ®ion.front(); - } - Region &getBodyRegion(unsigned idx = 0) { - return this->getOperation()->getRegion(idx); - } - //===------------------------------------------------------------------===// // Single Region Utilities //===------------------------------------------------------------------===// + using Base::getBody; - /// The following are a set of methods only enabled when the parent - /// operation has a single region. Each of these methods take an additional - /// template parameter that represents the concrete operation so that we - /// can use SFINAE to disable the methods for non-single region operations. template using enable_if_single_region = typename std::enable_if_t(), T>; - template - enable_if_single_region begin() { - return getBody()->begin(); - } - template - enable_if_single_region end() { - return getBody()->end(); - } - template - enable_if_single_region front() { - return *begin(); - } - /// Insert the operation into the back of the body, before the terminator. template enable_if_single_region push_back(Operation *op) { diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -92,8 +92,13 @@ virtual void printGenericOp(Operation *op) = 0; /// Prints a region. + /// If 'printEntryBlockArgs' is false, the arguments of the + /// block are not printed. 
If 'printBlockTerminators' is false, the terminator + /// operation of the entry block is not printed. If 'printEmptyBlock' is true, + /// the entry block header is printed even if that block is empty. virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, - bool printBlockTerminators = true) = 0; + bool printBlockTerminators = true, + bool printEmptyBlock = false) = 0; /// Renumber the arguments for the specified region to the same names as the /// SSA values in namesToUse. This may only be used for IsolatedFromAbove diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -43,6 +43,10 @@ using BlockListType = llvm::iplist; BlockListType &getBlocks() { return blocks; } + Block &emplaceBlock() { + push_back(new Block); + return back(); + } // Iteration over the blocks in the region. using iterator = BlockListType::iterator; diff --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h --- a/mlir/include/mlir/IR/RegionKindInterface.h +++ b/mlir/include/mlir/IR/RegionKindInterface.h @@ -28,6 +28,16 @@ Graph, }; +namespace OpTrait { +/// A trait that specifies that an operation only defines graph regions. +template +class HasOnlyGraphRegion : public TraitBase { +public: + static RegionKind getRegionKind(unsigned index) { return RegionKind::Graph; } + static bool hasSSADominance(unsigned index) { return false; } +}; +} // namespace OpTrait + } // namespace mlir #include "mlir/IR/RegionKindInterface.h.inc" diff --git a/mlir/include/mlir/IR/RegionKindInterface.td b/mlir/include/mlir/IR/RegionKindInterface.td --- a/mlir/include/mlir/IR/RegionKindInterface.td +++ b/mlir/include/mlir/IR/RegionKindInterface.td @@ -50,4 +50,17 @@ ]; } +def HasOnlyGraphRegion : NativeOpTrait<"HasOnlyGraphRegion">; + +// Traits for ops whose regions don't need a terminator. These require other +// traits too, so they form a list to concatenate onto an op's trait list.
+def NoTerminator { + list traits = [ + NativeOpTrait<"NoTerminator">, + SingleBlock, + RegionKindInterface, + HasOnlyGraphRegion + ]; +} + #endif // MLIR_IR_REGIONKINDINTERFACE diff --git a/mlir/include/mlir/Parser.h b/mlir/include/mlir/Parser.h --- a/mlir/include/mlir/Parser.h +++ b/mlir/include/mlir/Parser.h @@ -37,12 +37,11 @@ Block *parsedBlock, MLIRContext *context, Location sourceFileLoc) { static_assert( ContainerOpT::template hasTrait() && - std::is_base_of:: - template Impl, - ContainerOpT>::value, + (ContainerOpT::template hasTrait() || + ContainerOpT::template hasTrait< + OpTrait::SingleBlockImplicitTerminator>()), "Expected `ContainerOpT` to have a single region with a single " - "block that has an implicit terminator"); + "block that has an implicit terminator or does not require one"); // Check to see if we parsed a single instance of this operation. if (llvm::hasSingleElement(*parsedBlock)) { diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py --- a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py @@ -11,8 +11,6 @@ super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip)) body = self.regions[0].blocks.append() - with InsertionPoint(body): - Operation.create("module_terminator") @property def body(self): diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -156,8 +156,8 @@ /// Adds Async Runtime C API declarations to the module. 
static void addAsyncRuntimeApiDeclarations(ModuleOp module) { - auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(), - module.getBody()); + auto builder = + ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); auto addFuncDecl = [&](StringRef name, FunctionType type) { if (module.lookupSymbol(name)) @@ -207,8 +207,8 @@ using namespace mlir::LLVM; MLIRContext *ctx = module.getContext(); - ImplicitLocOpBuilder builder(module.getLoc(), - module.getBody()->getTerminator()); + auto builder = + ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); auto voidTy = LLVMVoidType::get(ctx); auto i64 = IntegerType::get(ctx, 64); @@ -232,15 +232,14 @@ return; MLIRContext *ctx = module.getContext(); - - OpBuilder moduleBuilder(module.getBody()->getTerminator()); - Location loc = module.getLoc(); + auto loc = module.getLoc(); + auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); auto voidTy = LLVM::LLVMVoidType::get(ctx); auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); auto resumeOp = moduleBuilder.create( - loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); + kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); resumeOp.setPrivate(); auto *block = resumeOp.addEntryBlock(); diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -329,7 +329,7 @@ auto function = [&] { if (auto function = module.lookupSymbol(functionName)) return function; - return OpBuilder(module.getBody()->getTerminator()) + return OpBuilder::atBlockEnd(module.getBody()) .create(loc, functionName, functionType); }(); return builder.create( diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp 
b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -99,7 +99,7 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc( Location loc, gpu::LaunchFuncOp launchOp) { - OpBuilder builder(getOperation().getBody()->getTerminator()); + auto builder = OpBuilder::atBlockEnd(getOperation().getBody()); // Workgroup size is written into the kernel. So to properly modelling // vulkan launch, we have to skip local workgroup size configuration here. diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -291,7 +291,7 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { ModuleOp module = getOperation(); - OpBuilder builder(module.getBody()->getTerminator()); + auto builder = OpBuilder::atBlockEnd(module.getBody()); if (!module.lookupSymbol(kSetEntryPoint)) { builder.create( diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -227,7 +227,7 @@ LLVMConversionTarget target(getContext()); target.addIllegalOp(); - target.addLegalOp(); + target.addLegalOp(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp @@ -35,7 +35,7 @@ 
populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns); // Allow builtin ops. - target->addLegalOp(); + target->addLegalOp(); target->addDynamicallyLegalOp([&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()) && typeConverter.isLegal(&op.getBody()); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -216,7 +216,7 @@ ConversionTarget target(getContext()); target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); target.addLegalOp(); OwningRewritePatternList patterns; populateLinalgToStandardConversionPatterns(patterns, &getContext()); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1361,7 +1361,7 @@ matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(moduleEndOp); + rewriter.eraseOp(moduleEndOp); return success(); } }; diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp @@ -48,10 +48,8 @@ target.addIllegalDialect(); target.addLegalDialect(); - // Set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module` - // conversion. + // Set `ModuleOp` as legal for `spv.module` conversion. 
target.addLegalOp(); - target.addLegalOp(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -675,7 +675,7 @@ ConversionTarget target(ctx); target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); // Setup conversion patterns. OwningRewritePatternList patterns; diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp @@ -40,7 +40,7 @@ OwningRewritePatternList patterns; populateVectorToSPIRVPatterns(context, typeConverter, patterns); - target->addLegalOp(); + target->addLegalOp(); target->addLegalOp(); if (failed(applyFullConversion(module, *target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -199,7 +199,7 @@ // TODO: Derive outlined function name from the parent FuncOp (support // multiple nested async.execute operations). 
FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); - symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator())); + symbolTable.insert(func); SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp @@ -43,8 +43,7 @@ populateBranchOpInterfaceTypeConversionPattern(patterns, context, typeConverter); populateReturnOpTypeConversionPattern(patterns, context, typeConverter); - target.addLegalOp(); + target.addLegalOp(); target.markUnknownOpDynamicallyLegal([&](Operation *op) { return isNotBranchOpInterfaceOrReturnLikeOp(op) || diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -418,7 +418,8 @@ /// Print the given region. void printRegion(Region ®ion, bool printEntryBlockArgs, - bool printBlockTerminators) override { + bool printBlockTerminators, + bool printEmptyBlock = false) override { if (region.empty()) return; @@ -2324,7 +2325,7 @@ /// Print the given region. void printRegion(Region ®ion, bool printEntryBlockArgs, - bool printBlockTerminators) override; + bool printBlockTerminators, bool printEmptyBlock) override; /// Renumber the arguments for the specified region to the same names as the /// SSA values in namesToUse. 
This may only be used for IsolatedFromAbove @@ -2435,7 +2436,7 @@ os << " ("; interleaveComma(op->getRegions(), [&](Region ®ion) { printRegion(region, /*printEntryBlockArgs=*/true, - /*printBlockTerminators=*/true); + /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); }); os << ')'; } @@ -2536,12 +2537,18 @@ } void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs, - bool printBlockTerminators) { + bool printBlockTerminators, + bool printEmptyBlock) { os << " {" << newLine; if (!region.empty()) { auto *entryBlock = ®ion.front(); - print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0, - printBlockTerminators); + // Force printing the block header if printEmptyBlock is set and the block + // is empty or if printEntryBlockArgs is set and there are arguments to + // print. + bool shouldAlwaysPrintBlockHeader = + (printEmptyBlock && entryBlock->empty()) || + (printEntryBlockArgs && entryBlock->getNumArguments() != 0); + print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators); for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) print(&b); } diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -294,6 +294,21 @@ return newBB; } +/// Returns true if this block may be valid without terminator. That is if: +/// - it does not have a parent region. +/// - Or the parent region have a single block and: +/// - This region does not have a parent op. +/// - Or the parent op is unregistered. +/// - Or the parent op has the NoTerminator trait. 
+bool Block::mayNotHaveTerminator() { + if (!getParent()) + return true; + if (!llvm::hasSingleElement(*getParent())) + return false; + Operation *op = getParentOp(); + return !op || op->mightHaveTrait(); +} + //===----------------------------------------------------------------------===// // Predecessors //===----------------------------------------------------------------------===// @@ -314,9 +329,16 @@ SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {} SuccessorRange::SuccessorRange(Block *block) : SuccessorRange() { - if (Operation *term = block->getTerminator()) - if ((count = term->getNumSuccessors())) - base = term->getBlockOperands().data(); + // An empty block has no successors. Neither does a block alone in its + // region: it is the entry block, and it may lack a terminator entirely. + if (block->empty()) + return; + Region *parent = block->getParent(); + if (parent && llvm::hasSingleElement(*parent)) + return; + Operation *term = &block->back(); + if ((count = term->getNumSuccessors())) + base = term->getBlockOperands().data(); } SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() { diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -209,7 +209,7 @@ void ModuleOp::build(OpBuilder &builder, OperationState &state, Optional name) { - ensureTerminator(*state.addRegion(), builder, state.location); + state.addRegion()->emplaceBlock(); if (name) { state.attributes.push_back(builder.getNamedAttr( mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name))); diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -161,11 +161,17 @@ // TODO: consider if SymbolTable's constructor should behave the same.
if (!symbol->getParentOp()) { auto &body = symbolTableOp->getRegion(0).front(); - if (insertPt == Block::iterator() || insertPt == body.end()) - insertPt = Block::iterator(body.getTerminator()); - - assert(insertPt->getParentOp() == symbolTableOp && - "expected insertPt to be in the associated module operation"); + if (insertPt == Block::iterator()) { + insertPt = Block::iterator(body.end()); + } else { + assert((insertPt == body.end() || + insertPt->getParentOp() == symbolTableOp) && + "expected insertPt to be in the associated module operation"); + } + // Insert before the terminator, if any. + if (insertPt == Block::iterator(body.end()) && !body.empty() && + std::prev(body.end())->hasTrait()) + insertPt = std::prev(body.end()); body.getOperations().insert(insertPt, symbol); } @@ -291,11 +297,14 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, StringRef symbol) { assert(symbolTableOp->hasTrait()); + Region ®ion0 = symbolTableOp->getRegion(0); + if (region0.empty()) + return nullptr; // Look for a symbol with the given name. Identifier symbolNameId = Identifier::get(SymbolTable::getSymbolAttrName(), symbolTableOp->getContext()); - for (auto &op : symbolTableOp->getRegion(0).front().without_terminator()) + for (auto &op : region0.front()) if (getNameIfSymbol(&op, symbolNameId) == symbol) return &op; return nullptr; diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -119,11 +119,15 @@ return emitError(block, "block argument not owned by block"); // Verify that this block has a terminator. - if (block.empty()) - return emitError(block, "block with no terminator"); + + if (block.empty()) { + if (block.mayNotHaveTerminator()) + return success(); + return emitError(block, "empty block: expect at least a terminator"); + } // Verify the non-terminator operations separately so that we can verify - // they has no successors. + // they have no successors. 
for (auto &op : llvm::make_range(block.begin(), std::prev(block.end()))) { if (op.getNumSuccessors() != 0) return op.emitError( @@ -137,8 +141,13 @@ Operation &terminator = block.back(); if (failed(verifyOperation(terminator))) return failure(); + + if (block.mayNotHaveTerminator()) + return success(); + if (!terminator.mightHaveTrait()) - return block.back().emitError("block with no terminator"); + return block.back().emitError("block with no terminator, has ") + << terminator; // Verify that this block is not branching to a block of a different // region. @@ -176,13 +185,14 @@ unsigned numRegions = op.getNumRegions(); for (unsigned i = 0; i < numRegions; i++) { Region &region = op.getRegion(i); + RegionKind kind = + kindInterface ? kindInterface.getRegionKind(i) : RegionKind::SSACFG; // Check that Graph Regions only have a single basic block. This is // similar to the code in SingleBlockImplicitTerminator, but doesn't // require the trait to be specified. This arbitrary limitation is // designed to limit the number of cases that have to be handled by // transforms and conversions until the concept stabilizes. - if (op.isRegistered() && kindInterface && - kindInterface.getRegionKind(i) == RegionKind::Graph) { + if (op.isRegistered() && kind == RegionKind::Graph) { // Empty regions are fine. if (region.empty()) continue; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2064,7 +2064,7 @@ auto &parsedOps = (*topLevelOp)->getRegion(0).front().getOperations(); auto &destOps = topLevelBlock->getOperations(); - destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()),
- parsedOps, parsedOps.begin(), std::prev(parsedOps.end())); + // Modules no longer end with a terminator: append at the end of the block. + destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end()); return success(); } diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -269,10 +269,11 @@ /// Globals are inserted before the first function, if any. Block::iterator getGlobalInsertPt() { - auto i = module.getBody()->begin(); - while (!isa(i)) - ++i; - return i; + auto it = module.getBody()->begin(); + auto endIt = module.getBody()->end(); + while (it != endIt && !isa(it)) + ++it; + return it; } /// Functions are always inserted before the module terminator. diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp --- a/mlir/lib/Transforms/SymbolDCE.cpp +++ b/mlir/lib/Transforms/SymbolDCE.cpp @@ -61,8 +61,7 @@ if (!nestedSymbolTable->hasTrait()) return; for (auto &block : nestedSymbolTable->getRegion(0)) { - for (Operation &op : - llvm::make_early_inc_range(block.without_terminator())) { + for (Operation &op : llvm::make_early_inc_range(block)) { if (isa(&op) && !liveSymbols.count(&op)) op.erase(); } @@ -84,7 +83,7 @@ // are known to be live. for (auto &block : symbolTableOp->getRegion(0)) { // Add all non-symbols or symbols that can't be discarded.
@@ -322,7 +323,8 @@ // before domtree parents. A CFG post-order (with reverse iteration with a // block) satisfies that without needing an explicit domtree calculation. for (Block *block : llvm::post_order(&region.front())) { - eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap); + if (!hasSingleBlock) + eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap); for (Operation &childOp : llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) { if (!liveMap.wasProvenLive(&childOp)) { diff --git a/mlir/test/Bindings/Python/context_managers.py b/mlir/test/Bindings/Python/context_managers.py --- a/mlir/test/Bindings/Python/context_managers.py +++ b/mlir/test/Bindings/Python/context_managers.py @@ -62,7 +62,7 @@ def testInsertionPointEnterExit(): ctx1 = Context() m = Module.create(Location.unknown(ctx1)) - ip = InsertionPoint.at_block_terminator(m.body) + ip = InsertionPoint(m.body) with ip: assert InsertionPoint.current is ip diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py --- a/mlir/test/Bindings/Python/dialects.py +++ b/mlir/test/Bindings/Python/dialects.py @@ -77,7 +77,7 @@ ctx.allow_unregistered_dialects = True m = Module.create() - with InsertionPoint.at_block_terminator(m.body): + with InsertionPoint(m.body): f32 = F32Type.get() # Create via dialects context collection.
input1 = createInput() diff --git a/mlir/test/Bindings/Python/dialects/linalg/ops.py b/mlir/test/Bindings/Python/dialects/linalg/ops.py --- a/mlir/test/Bindings/Python/dialects/linalg/ops.py +++ b/mlir/test/Bindings/Python/dialects/linalg/ops.py @@ -17,7 +17,7 @@ module = Module.create() f32 = F32Type.get() tensor_type = RankedTensorType.get((2, 3, 4), f32) - with InsertionPoint.at_block_terminator(module.body): + with InsertionPoint(module.body): func = builtin.FuncOp(name="matmul_test", type=FunctionType.get( inputs=[tensor_type, tensor_type], @@ -40,7 +40,7 @@ module = Module.create() f32 = F32Type.get() memref_type = MemRefType.get((2, 3, 4), f32) - with InsertionPoint.at_block_terminator(module.body): + with InsertionPoint(module.body): func = builtin.FuncOp(name="matmul_test", type=FunctionType.get( inputs=[memref_type, memref_type, memref_type], diff --git a/mlir/test/Bindings/Python/insertion_point.py b/mlir/test/Bindings/Python/insertion_point.py --- a/mlir/test/Bindings/Python/insertion_point.py +++ b/mlir/test/Bindings/Python/insertion_point.py @@ -129,8 +129,13 @@ def test_insert_at_end_with_terminator_errors(): with Context() as ctx, Location.unknown(): ctx.allow_unregistered_dialects = True - m = Module.create() # Module is created with a terminator. 
- with InsertionPoint(m.body): + module = Module.parse(r""" + func @foo() -> () { + return + } + """) + entry_block = module.body.operations[0].regions[0].blocks[0] + with InsertionPoint(entry_block): try: Operation.create("custom.op1", results=[], operands=[]) except IndexError as e: diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py --- a/mlir/test/Bindings/Python/ir_operation.py +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -64,7 +64,6 @@ # CHECK: BLOCK 0: # CHECK: OP 0: %0 = "custom.addi" # CHECK: OP 1: return - # CHECK: OP 1: module_terminator walk_operations("", op) run(testTraverseOpRegionBlockIterators) @@ -101,7 +100,6 @@ # CHECK: BLOCK 0: # CHECK: OP 0: %0 = "custom.addi" # CHECK: OP 1: return - # CHECK: OP 1: module_terminator walk_operations("", module.operation) run(testTraverseOpRegionBlockIndices) @@ -546,9 +544,9 @@ def testPrintInvalidOperation(): ctx = Context() with Location.unknown(ctx): - module = Operation.create("module", regions=1) - # This block does not have a terminator, it may crash the custom printer. - # Verify that we fallback to the generic printer for safety. + module = Operation.create("module", regions=2) + # This module has two regions and is invalid; verify that we fall back + # to the generic printer for safety.
block = module.regions[0].blocks.append() # CHECK: // Verification failed, printing generic form # CHECK: "module"() ( { diff --git a/mlir/test/Bindings/Python/ods_helpers.py b/mlir/test/Bindings/Python/ods_helpers.py --- a/mlir/test/Bindings/Python/ods_helpers.py +++ b/mlir/test/Bindings/Python/ods_helpers.py @@ -29,7 +29,7 @@ with Context() as ctx, Location.unknown(): ctx.allow_unregistered_dialects = True m = Module.create() - with InsertionPoint.at_block_terminator(m.body): + with InsertionPoint(m.body): op = TestFixedRegionsOp.build_generic(results=[], operands=[]) # CHECK: NUM_REGIONS: 2 print(f"NUM_REGIONS: {len(op.regions)}") @@ -84,7 +84,7 @@ with Context() as ctx, Location.unknown(): ctx.allow_unregistered_dialects = True m = Module.create() - with InsertionPoint.at_block_terminator(m.body): + with InsertionPoint(m.body): v0 = add_dummy_value() v1 = add_dummy_value() t0 = IntegerType.get_signless(8) @@ -111,7 +111,7 @@ with Context() as ctx, Location.unknown(): ctx.allow_unregistered_dialects = True m = Module.create() - with InsertionPoint.at_block_terminator(m.body): + with InsertionPoint(m.body): v0 = add_dummy_value() v1 = add_dummy_value() v2 = add_dummy_value() @@ -187,7 +187,7 @@ with Context() as ctx, Location.unknown(): ctx.allow_unregistered_dialects = True m = Module.create() - with InsertionPoint.at_block_terminator(m.body): + with InsertionPoint(m.body): v0 = add_dummy_value() v1 = add_dummy_value() t0 = IntegerType.get_signless(8) diff --git a/mlir/test/Bindings/Python/pass_manager.py b/mlir/test/Bindings/Python/pass_manager.py --- a/mlir/test/Bindings/Python/pass_manager.py +++ b/mlir/test/Bindings/Python/pass_manager.py @@ -91,6 +91,5 @@ # CHECK: Operations encountered: # CHECK: func , 1 # CHECK: module , 1 -# CHECK: module_terminator , 1 # CHECK: std.return , 1 run(testRunPipeline) diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -293,7 +293,7 @@ fprintf(stderr, "Number of op 
results: %u\n", stats.numOpResults); // clang-format off // CHECK-LABEL: @stats - // CHECK: Number of operations: 13 + // CHECK: Number of operations: 12 // CHECK: Number of attributes: 4 // CHECK: Number of blocks: 3 // CHECK: Number of regions: 3 diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c --- a/mlir/test/CAPI/pass.c +++ b/mlir/test/CAPI/pass.c @@ -42,7 +42,6 @@ // Run the print-op-stats pass on the top-level module: // CHECK-LABEL: Operations encountered: // CHECK: func , 1 - // CHECK: module_terminator , 1 // CHECK: std.addi , 1 // CHECK: std.return , 1 { @@ -84,7 +83,6 @@ // Run the print-op-stats pass on functions under the top-level module: // CHECK-LABEL: Operations encountered: - // CHECK-NOT: module_terminator // CHECK: func , 1 // CHECK: std.addi , 1 // CHECK: std.return , 1 @@ -101,7 +99,6 @@ } // Run the print-op-stats pass on functions under the nested module: // CHECK-LABEL: Operations encountered: - // CHECK-NOT: module_terminator // CHECK: func , 1 // CHECK: std.addf , 1 // CHECK: std.return , 1 diff --git a/mlir/test/IR/invalid-module-op.mlir b/mlir/test/IR/invalid-module-op.mlir --- a/mlir/test/IR/invalid-module-op.mlir +++ b/mlir/test/IR/invalid-module-op.mlir @@ -19,31 +19,12 @@ // expected-error@+1 {{region should have no arguments}} module { ^bb1(%arg: i32): - "module_terminator"() : () -> () - } - return -} - -// ----- - -func @module_op() { - // expected-error@below {{expects regions to end with 'module_terminator'}} - // expected-note@below {{the absence of terminator implies 'module_terminator'}} - module { - return } return } // ----- -func @module_op() { - // expected-error@+1 {{expects parent op 'module'}} - "module_terminator"() : () -> () -} - -// ----- - // expected-error@+1 {{can only contain attributes with dialect-prefixed names}} module attributes {attr} { } diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -120,7 +120,7 @@ // 
----- -func @no_terminator() { // expected-error {{block with no terminator}} +func @no_terminator() { // expected-error {{empty block: expect at least a terminator}} ^bb40: return ^bb41: diff --git a/mlir/test/IR/module-op.mlir b/mlir/test/IR/module-op.mlir --- a/mlir/test/IR/module-op.mlir +++ b/mlir/test/IR/module-op.mlir @@ -4,16 +4,14 @@ module { } -// CHECK: module { -// CHECK-NEXT: } -module { - "module_terminator"() : () -> () -} +// ----- // CHECK: module attributes {foo.attr = true} { module attributes {foo.attr = true} { } +// ----- + // CHECK: module { module { // CHECK-NEXT: "foo.result_op"() : () -> i32 diff --git a/mlir/test/IR/print-ir-defuse.mlir b/mlir/test/IR/print-ir-defuse.mlir --- a/mlir/test/IR/print-ir-defuse.mlir +++ b/mlir/test/IR/print-ir-defuse.mlir @@ -18,8 +18,6 @@ // CHECK: Has 0 results: // CHECK: Visiting op 'dialect.op3' with 0 operands: // CHECK: Has 0 results: -// CHECK: Visiting op 'module_terminator' with 0 operands: -// CHECK: Has 0 results: // CHECK: Visiting op 'module' with 0 operands: // CHECK: Has 0 results: diff --git a/mlir/test/IR/print-ir-nesting.mlir b/mlir/test/IR/print-ir-nesting.mlir --- a/mlir/test/IR/print-ir-nesting.mlir +++ b/mlir/test/IR/print-ir-nesting.mlir @@ -3,7 +3,7 @@ // CHECK: visiting op: 'module' with 0 operands and 0 results // CHECK: 1 nested regions: // CHECK: Region with 1 blocks: -// CHECK: Block with 0 arguments, 0 successors, and 3 operations +// CHECK: Block with 0 arguments, 0 successors, and 2 operations module { @@ -52,6 +52,4 @@ "dialect.innerop7"() : () -> () }) : () -> () -// CHECK: visiting op: 'module_terminator' with 0 operands and 0 results - } // module diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir --- a/mlir/test/IR/region.mlir +++ b/mlir/test/IR/region.mlir @@ -73,3 +73,11 @@ }) : () -> () return } + +// ----- + +// Region with a single block and no terminator.
+// CHECK: unregistered_without_terminator +"test.unregistered_without_terminator"() ( { + ^bb0: // no predecessors +}) : () -> () diff --git a/mlir/test/Transforms/test-legalizer-analysis.mlir b/mlir/test/Transforms/test-legalizer-analysis.mlir --- a/mlir/test/Transforms/test-legalizer-analysis.mlir +++ b/mlir/test/Transforms/test-legalizer-analysis.mlir @@ -1,6 +1,5 @@ // RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s // expected-remark@-2 {{op 'module' is legalizable}} -// expected-remark@-3 {{op 'module_terminator' is legalizable}} // expected-remark@+1 {{op 'func' is legalizable}} func @test(%arg0: f32) { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -575,7 +575,7 @@ // Define the conversion target used for the test. ConversionTarget target(getContext()); - target.addLegalOp(); + target.addLegalOp(); target.addLegalOp(); target @@ -704,7 +704,7 @@ patterns.insert(&getContext()); mlir::ConversionTarget target(getContext()); - target.addLegalOp(); + target.addLegalOp(); // We make OneVResOneVOperandOp1 legal only when it has more that one // operand. This will trigger the conversion that will replace one-operand // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. @@ -972,9 +972,8 @@ .insert( context); ConversionTarget target(*context); - target.addLegalOp(); + target.addLegalOp(); target.addIllegalOp(); /// Expect the op to have a single block after legalization. 
diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp --- a/mlir/test/lib/Transforms/TestConvVectorization.cpp +++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp @@ -56,7 +56,7 @@ ConversionTarget target(*context); target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); target.addLegalOp(); SmallVector stage1Patterns; diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -437,6 +437,14 @@ llvm::any_of(op.getTraits(), [](const OpTrait &trait) { return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator"); }); + + hasSingleBlockTrait = + hasImplicitTermTrait || + llvm::any_of(op.getTraits(), [](const OpTrait &trait) { + if (auto *native = dyn_cast(&trait)) + return native->getTrait() == "::mlir::OpTrait::SingleBlock"; + return false; + }); } /// Generate the operation parser from this format. @@ -472,6 +480,9 @@ /// trait. bool hasImplicitTermTrait; + /// A flag indicating if this operation has the SingleBlock trait. + bool hasSingleBlockTrait; + /// A map of buildable types to indices. llvm::MapVector> buildableTypes; @@ -667,6 +678,14 @@ ensureTerminator(*region, parser.getBuilder(), result.location); )"; +/// The code snippet used to ensure each region in a list has a block. +/// +/// {0}: The name of the region list. +const char *regionListEnsureSingleBlockParserCode = R"( + for (auto &region : {0}Regions) + if (region->empty()) region->emplaceBlock(); +)"; + /// The code snippet used to generate a parser call for an optional region. /// /// {0}: The name of the region. @@ -693,6 +712,13 @@ ensureTerminator(*{0}Region, parser.getBuilder(), result.location); )"; +/// The code snippet used to ensure a region has a block. +/// +/// {0}: The name of the region.
+const char *regionEnsureSingleBlockParserCode = R"( + if ({0}Region->empty()) {0}Region->emplaceBlock(); +)"; + /// The code snippet used to generate a parser call for a successor list. /// /// {0}: The name for the successor list. @@ -1120,6 +1146,9 @@ body << " if (!" << region->name << "Region->empty()) {\n "; if (hasImplicitTermTrait) body << llvm::formatv(regionEnsureTerminatorParserCode, region->name); + else if (hasSingleBlockTrait) + body << llvm::formatv(regionEnsureSingleBlockParserCode, + region->name); } } @@ -1193,11 +1222,14 @@ bool isVariadic = region->getVar()->isVariadic(); body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode, region->getVar()->name); - if (hasImplicitTermTrait) { + if (hasImplicitTermTrait) body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode : regionEnsureTerminatorParserCode, region->getVar()->name); - } + else if (hasSingleBlockTrait) + body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode + : regionEnsureSingleBlockParserCode, + region->getVar()->name); } else if (auto *successor = dyn_cast(element)) { bool isVariadic = successor->getVar()->isVariadic(); @@ -1222,6 +1254,8 @@ body << llvm::formatv(regionListParserCode, "full"); if (hasImplicitTermTrait) body << llvm::formatv(regionListEnsureTerminatorParserCode, "full"); + else if (hasSingleBlockTrait) + body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full"); } else if (isa(element)) { body << llvm::formatv(successorListParserCode, "full");