diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -63,7 +63,7 @@ ```c++ mlir::ConversionTarget target(getContext()); target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); ``` ### Type Converter diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -172,7 +172,7 @@ // final target for this lowering. For this lowering, we are only targeting // the LLVM dialect. LLVMConversionTarget target(getContext()); - target.addLegalOp(); + target.addLegalOp(); // During this lowering, we will also be lowering the MemRef types, that are // currently being operated on, to a representation in LLVM. To perform this diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -172,7 +172,7 @@ // final target for this lowering. For this lowering, we are only targeting // the LLVM dialect. LLVMConversionTarget target(getContext()); - target.addLegalOp(); + target.addLegalOp(); // During this lowering, we will also be lowering the MemRef types, that are // currently being operated on, to a representation in LLVM. 
To perform this diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -15,6 +15,7 @@ #include "mlir/IR/FunctionSupport.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/RegionKindInterface.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -15,6 +15,7 @@ #define BUILTIN_OPS include "mlir/IR/BuiltinDialect.td" +include "mlir/IR/RegionKindInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" @@ -158,9 +159,8 @@ //===----------------------------------------------------------------------===// def ModuleOp : Builtin_Op<"module", [ - AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol, - SingleBlockImplicitTerminator<"ModuleTerminatorOp"> -]> { + AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol] + # NoTerminator.traits> { let summary = "A top level container operation"; let description = [{ A `module` represents a top-level container operation. It contains a single @@ -206,22 +206,6 @@ let skipDefaultBuilders = 1; } -//===----------------------------------------------------------------------===// -// ModuleTerminatorOp -//===----------------------------------------------------------------------===// - -def ModuleTerminatorOp : Builtin_Op<"module_terminator", [ - Terminator, HasParent<"ModuleOp"> -]> { - let summary = "A pseudo op that marks the end of a module"; - let description = [{ - `module_terminator` is a special terminator operation for the body of a - `module`, it has no semantic meaning beyond keeping the body of a `module` - well-formed. 
- }]; - let assemblyFormat = "attr-dict"; -} - //===----------------------------------------------------------------------===// // UnrealizedConversionCastOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1808,6 +1808,9 @@ ]; } +// Op's regions have a single block. +def SingleBlock : NativeOpTrait<"SingleBlock">; + // Op's regions have a single block with the specified terminator. class SingleBlockImplicitTerminator : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>; diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -675,6 +675,11 @@ //===----------------------------------------------------------------------===// // Terminator Traits +/// This class indicates that the regions associated with this op don't have +/// terminators. This is valid only for Graph regions (see RegionKind). +template +class NoTerminator : public TraitBase {}; + /// This class provides the API for ops that are known to be terminators. template class IsTerminator : public TraitBase { @@ -778,6 +783,90 @@ : public detail::MultiSuccessorTraitBase { }; +//===----------------------------------------------------------------------===// +// SingleBlock + +/// This class provides APIs and verifiers for ops with regions having a single +/// block. +template +struct SingleBlock : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { + Region &region = op->getRegion(i); + + // Empty regions are fine. + if (region.empty()) + continue; + + // Non-empty regions must contain a single basic block. 
+ if (std::next(region.begin()) != region.end()) + return op->emitOpError("expects region #") + << i << " to have 0 or 1 blocks"; + + if (!ConcreteType::template hasTrait()) { + Block &block = region.front(); + if (block.empty()) + return op->emitOpError() << "expects a non-empty block"; + } + } + return success(); + } + + Block *getBody(unsigned idx = 0) { + Region &region = this->getOperation()->getRegion(idx); + assert(!region.empty() && "unexpected empty region"); + return &region.front(); + } + Region &getBodyRegion(unsigned idx = 0) { + return this->getOperation()->getRegion(idx); + } + + //===------------------------------------------------------------------===// + // Single Region Utilities + //===------------------------------------------------------------------===// + + /// The following are a set of methods only enabled when the parent + /// operation has a single region. Each of these methods take an additional + /// template parameter that represents the concrete operation so that we + /// can use SFINAE to disable the methods for non-single region operations. + template + using enable_if_single_region = + typename std::enable_if_t(), T>; + + template + enable_if_single_region begin() { + return getBody()->begin(); + } + template + enable_if_single_region end() { + return getBody()->end(); + } + template + enable_if_single_region front() { + return *begin(); + } + + /// Insert the operation into the back of the body. + template + enable_if_single_region push_back(Operation *op) { + insert(Block::iterator(getBody()->end()), op); + } + + /// Insert the operation at the given insertion point. 
+ template + enable_if_single_region insert(Operation *insertPt, Operation *op) { + insert(Block::iterator(insertPt), op); + } + template + enable_if_single_region insert(Block::iterator insertPt, Operation *op) { + auto *body = getBody(); + if (insertPt == body->end()) + insertPt = Block::iterator(body->end()); + body->getOperations().insert(insertPt, op); + } +}; + //===----------------------------------------------------------------------===// // SingleBlockImplicitTerminator @@ -786,8 +875,9 @@ template struct SingleBlockImplicitTerminator { template - class Impl : public TraitBase { + class Impl : public SingleBlock { private: + using Base = SingleBlock; /// Builds a terminator operation without relying on OpBuilder APIs to avoid /// cyclic header inclusion. static Operation *buildTerminator(OpBuilder &builder, Location loc) { @@ -801,22 +891,14 @@ using ImplicitTerminatorOpT = TerminatorOpType; static LogicalResult verifyTrait(Operation *op) { + if (failed(Base::verifyTrait(op))) + return failure(); for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { Region &region = op->getRegion(i); - // Empty regions are fine. if (region.empty()) continue; - - // Non-empty regions must contain a single basic block. 
- if (std::next(region.begin()) != region.end()) - return op->emitOpError("expects region #") - << i << " to have 0 or 1 blocks"; - - Block &block = region.front(); - if (block.empty()) - return op->emitOpError() << "expects a non-empty block"; - Operation &terminator = block.back(); + Operation &terminator = region.front().back(); if (isa(terminator)) continue; @@ -849,40 +931,15 @@ buildTerminator); } - Block *getBody(unsigned idx = 0) { - Region &region = this->getOperation()->getRegion(idx); - assert(!region.empty() && "unexpected empty region"); - return &region.front(); - } - Region &getBodyRegion(unsigned idx = 0) { - return this->getOperation()->getRegion(idx); - } - //===------------------------------------------------------------------===// // Single Region Utilities //===------------------------------------------------------------------===// + using Base::getBody; - /// The following are a set of methods only enabled when the parent - /// operation has a single region. Each of these methods take an additional - /// template parameter that represents the concrete operation so that we - /// can use SFINAE to disable the methods for non-single region operations. template using enable_if_single_region = typename std::enable_if_t(), T>; - template - enable_if_single_region begin() { - return getBody()->begin(); - } - template - enable_if_single_region end() { - return getBody()->end(); - } - template - enable_if_single_region front() { - return *begin(); - } - /// Insert the operation into the back of the body, before the terminator. template enable_if_single_region push_back(Operation *op) { diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -92,8 +92,11 @@ virtual void printGenericOp(Operation *op) = 0; /// Prints a region. + /// If printEmptyBlock is true, then the block header is printed even if the + /// block is empty. 
virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, - bool printBlockTerminators = true) = 0; + bool printBlockTerminators = true, + bool printEmptyBlock = false) = 0; /// Renumber the arguments for the specified region to the same names as the /// SSA values in namesToUse. This may only be used for IsolatedFromAbove diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -43,6 +43,10 @@ using BlockListType = llvm::iplist; BlockListType &getBlocks() { return blocks; } + Block &emplaceBlock() { + push_back(new Block); + return back(); + } // Iteration over the blocks in the region. using iterator = BlockListType::iterator; diff --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h --- a/mlir/include/mlir/IR/RegionKindInterface.h +++ b/mlir/include/mlir/IR/RegionKindInterface.h @@ -28,6 +28,16 @@ Graph, }; +namespace OpTrait { +/// A trait that specifies that an operation is only defining graph regions. +template +class HasOnlyGraphRegion : public TraitBase { +public: + static RegionKind getRegionKind(unsigned index) { return RegionKind::Graph; } + static bool hasSSADominance(unsigned index) { return false; } +}; +} // namespace OpTrait + } // namespace mlir #include "mlir/IR/RegionKindInterface.h.inc" diff --git a/mlir/include/mlir/IR/RegionKindInterface.td b/mlir/include/mlir/IR/RegionKindInterface.td --- a/mlir/include/mlir/IR/RegionKindInterface.td +++ b/mlir/include/mlir/IR/RegionKindInterface.td @@ -50,4 +50,17 @@ ]; } +def HasOnlyGraphRegion : NativeOpTrait<"HasOnlyGraphRegion">; + +// Op's regions that don't need a terminator: requires some other traits +// so it defines a list that must be concatenated. 
+def NoTerminator { + list traits = [ + NativeOpTrait<"NoTerminator">, + SingleBlock, + RegionKindInterface, + HasOnlyGraphRegion + ]; +} + #endif // MLIR_IR_REGIONKINDINTERFACE diff --git a/mlir/include/mlir/Parser.h b/mlir/include/mlir/Parser.h --- a/mlir/include/mlir/Parser.h +++ b/mlir/include/mlir/Parser.h @@ -37,12 +37,11 @@ Block *parsedBlock, MLIRContext *context, Location sourceFileLoc) { static_assert( ContainerOpT::template hasTrait() && - std::is_base_of:: - template Impl, - ContainerOpT>::value, + (ContainerOpT::template hasTrait() || + ContainerOpT::template hasTrait< + OpTrait::SingleBlockImplicitTerminator>()), "Expected `ContainerOpT` to have a single region with a single " - "block that has an implicit terminator"); + "block that has an implicit terminator or does not require one"); // Check to see if we parsed a single instance of this operation. if (llvm::hasSingleElement(*parsedBlock)) { diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -156,8 +156,8 @@ /// Adds Async Runtime C API declarations to the module. 
static void addAsyncRuntimeApiDeclarations(ModuleOp module) { - auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(), - module.getBody()); + auto builder = + ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); auto addFuncDecl = [&](StringRef name, FunctionType type) { if (module.lookupSymbol(name)) @@ -207,8 +207,8 @@ using namespace mlir::LLVM; MLIRContext *ctx = module.getContext(); - ImplicitLocOpBuilder builder(module.getLoc(), - module.getBody()->getTerminator()); + auto builder = + ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody()); auto voidTy = LLVMVoidType::get(ctx); auto i64 = IntegerType::get(ctx, 64); @@ -232,15 +232,14 @@ return; MLIRContext *ctx = module.getContext(); - - OpBuilder moduleBuilder(module.getBody()->getTerminator()); - Location loc = module.getLoc(); + auto loc = module.getLoc(); + auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); auto voidTy = LLVM::LLVMVoidType::get(ctx); auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); auto resumeOp = moduleBuilder.create( - loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); + kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})); resumeOp.setPrivate(); auto *block = resumeOp.addEntryBlock(); diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -329,7 +329,7 @@ auto function = [&] { if (auto function = module.lookupSymbol(functionName)) return function; - return OpBuilder(module.getBody()->getTerminator()) + return OpBuilder::atBlockEnd(module.getBody()) .create(loc, functionName, functionType); }(); return builder.create( diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp 
b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -99,7 +99,7 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc( Location loc, gpu::LaunchFuncOp launchOp) { - OpBuilder builder(getOperation().getBody()->getTerminator()); + auto builder = OpBuilder::atBlockEnd(getOperation().getBody()); // Workgroup size is written into the kernel. So to properly modelling // vulkan launch, we have to skip local workgroup size configuration here. diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -291,7 +291,7 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { ModuleOp module = getOperation(); - OpBuilder builder(module.getBody()->getTerminator()); + auto builder = OpBuilder::atBlockEnd(module.getBody()); if (!module.lookupSymbol(kSetEntryPoint)) { builder.create( diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -227,7 +227,7 @@ LLVMConversionTarget target(getContext()); target.addIllegalOp(); - target.addLegalOp(); + target.addLegalOp(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp @@ -35,7 +35,7 @@ 
populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns); // Allow builtin ops. - target->addLegalOp(); + target->addLegalOp(); target->addDynamicallyLegalOp([&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()) && typeConverter.isLegal(&op.getBody()); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -214,7 +214,7 @@ auto module = getOperation(); ConversionTarget target(getContext()); target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); target.addLegalOp(); OwningRewritePatternList patterns; populateLinalgToStandardConversionPatterns(patterns, &getContext()); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1361,7 +1361,7 @@ matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(moduleEndOp); + rewriter.eraseOp(moduleEndOp); return success(); } }; diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp @@ -48,10 +48,8 @@ target.addIllegalDialect(); target.addLegalDialect(); - // Set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module` - // conversion. + // Set `ModuleOp` as legal for `spv.module` conversion. 
target.addLegalOp(); - target.addLegalOp(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -672,7 +672,7 @@ ConversionTarget target(ctx); target .addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); // Setup conversion patterns. OwningRewritePatternList patterns; diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp @@ -40,7 +40,7 @@ OwningRewritePatternList patterns; populateVectorToSPIRVPatterns(context, typeConverter, patterns); - target->addLegalOp(); + target->addLegalOp(); target->addLegalOp(); if (failed(applyFullConversion(module, *target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -199,7 +199,7 @@ // TODO: Derive outlined function name from the parent FuncOp (support // multiple nested async.execute operations). 
FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs); - symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator())); + symbolTable.insert(func); SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp @@ -43,8 +43,7 @@ populateBranchOpInterfaceTypeConversionPattern(patterns, context, typeConverter); populateReturnOpTypeConversionPattern(patterns, context, typeConverter); - target.addLegalOp(); + target.addLegalOp(); target.markUnknownOpDynamicallyLegal([&](Operation *op) { return isNotBranchOpInterfaceOrReturnLikeOp(op) || diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -418,7 +418,8 @@ /// Print the given region. void printRegion(Region &region, bool printEntryBlockArgs, - bool printBlockTerminators) override { + bool printBlockTerminators, + bool printEmptyBlock = false) override { if (region.empty()) return; @@ -2324,7 +2325,7 @@ /// Print the given region. void printRegion(Region &region, bool printEntryBlockArgs, - bool printBlockTerminators) override; + bool printBlockTerminators, bool printEmptyBlock) override; /// Renumber the arguments for the specified region to the same names as the /// SSA values in namesToUse. This may only be used for IsolatedFromAbove @@ -2435,7 +2436,7 @@ os << " ("; interleaveComma(op->getRegions(), [&](Region &region) { printRegion(region, /*printEntryBlockArgs=*/true, - /*printBlockTerminators=*/true); + /*printBlockTerminators=*/true, /*printEmptyBlock=*/true); }); os << ')'; } @@ -2499,7 +2500,6 @@ } os << newLine; } - currentIndent += indentWidth; auto range = llvm::make_range( block->begin(), std::prev(block->end(), printBlockTerminator ? 
0 : 1)); @@ -2536,11 +2536,14 @@ } void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs, - bool printBlockTerminators) { + bool printBlockTerminators, + bool printEmptyBlock) { os << " {" << newLine; if (!region.empty()) { auto *entryBlock = &region.front(); - print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0, + print(entryBlock, + (printEmptyBlock && entryBlock->empty()) || + (printEntryBlockArgs && entryBlock->getNumArguments() != 0), printBlockTerminators); for (auto &b : llvm::drop_begin(region.getBlocks(), 1)) print(&b); diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -305,9 +305,11 @@ SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {} SuccessorRange::SuccessorRange(Block *block) : SuccessorRange() { - if (Operation *term = block->getTerminator()) + if (!llvm::hasSingleElement(*block->getParent())) { + Operation *term = block->getTerminator(); if ((count = term->getNumSuccessors())) base = term->getBlockOperands().data(); + } } SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() { diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -217,7 +217,7 @@ void ModuleOp::build(OpBuilder &builder, OperationState &state, Optional name) { - ensureTerminator(*state.addRegion(), builder, state.location); + state.addRegion()->emplaceBlock(); if (name) { state.attributes.push_back(builder.getNamedAttr( mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name))); diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -161,11 +161,15 @@ // TODO: consider if SymbolTable's constructor should behave the same. 
if (!symbol->getParentOp()) { auto &body = symbolTableOp->getRegion(0).front(); - if (insertPt == Block::iterator() || insertPt == body.end()) - insertPt = Block::iterator(body.getTerminator()); - - assert(insertPt->getParentOp() == symbolTableOp && - "expected insertPt to be in the associated module operation"); + if (insertPt == Block::iterator()) + insertPt = Block::iterator(body.end()); + else + assert(insertPt->getParentOp() == symbolTableOp && + "expected insertPt to be in the associated module operation"); + // Insert before the terminator, if any. + if (insertPt == Block::iterator(body.end()) && !body.empty() && + std::prev(body.end())->hasTrait()) + insertPt = std::prev(body.end()); body.getOperations().insert(insertPt, symbol); } @@ -291,11 +295,13 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, StringRef symbol) { assert(symbolTableOp->hasTrait()); + if (symbolTableOp->getRegion(0).empty()) + return nullptr; // Look for a symbol with the given name. Identifier symbolNameId = Identifier::get(SymbolTable::getSymbolAttrName(), symbolTableOp->getContext()); - for (auto &op : symbolTableOp->getRegion(0).front().without_terminator()) + for (auto &op : symbolTableOp->getRegion(0).front()) if (getNameIfSymbol(&op, symbolNameId) == symbol) return &op; return nullptr; diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -48,8 +48,9 @@ private: /// Verify the given potentially nested region or block. - LogicalResult verifyRegion(Region &region); - LogicalResult verifyBlock(Block &block); + LogicalResult verifyRegion(Region &region, + RegionKind kind = RegionKind::SSACFG); + LogicalResult verifyBlock(Block &block, RegionKind kind); LogicalResult verifyOperation(Operation &op); /// Verify the dominance property of operations within the given Region. 
@@ -96,7 +97,7 @@ return success(); } -LogicalResult OperationVerifier::verifyRegion(Region &region) { +LogicalResult OperationVerifier::verifyRegion(Region &region, RegionKind kind) { if (region.empty()) return success(); @@ -108,19 +109,20 @@ // Verify each of the blocks within the region. for (Block &block : region) - if (failed(verifyBlock(block))) + if (failed(verifyBlock(block, kind))) return failure(); return success(); } -LogicalResult OperationVerifier::verifyBlock(Block &block) { +LogicalResult OperationVerifier::verifyBlock(Block &block, RegionKind kind) { for (auto arg : block.getArguments()) if (arg.getOwner() != &block) return emitError(block, "block argument not owned by block"); // Verify that this block has a terminator. - if (block.empty()) - return emitError(block, "block with no terminator"); + + if (block.empty() && kind != RegionKind::Graph) + return emitError(block, "empty block: expect at least a terminator"); // Verify the non-terminator operations separately so that we can verify // they has no successors. @@ -134,11 +136,20 @@ } // Verify the terminator. + + if (kind == RegionKind::Graph) { + // Graph region don't have terminators. + if (!block.empty()) + return verifyOperation(block.back()); + return success(); + } + Operation &terminator = block.back(); if (failed(verifyOperation(terminator))) return failure(); if (!terminator.mightHaveTrait()) - return block.back().emitError("block with no terminator"); + return block.back().emitError("block with no terminator, has ") + << terminator; // Verify that this block is not branching to a block of a different // region. @@ -176,13 +187,14 @@ unsigned numRegions = op.getNumRegions(); for (unsigned i = 0; i < numRegions; i++) { Region &region = op.getRegion(i); + RegionKind kind = + kindInterface ? kindInterface.getRegionKind(i) : RegionKind::SSACFG; // Check that Graph Regions only have a single basic block. 
This is // similar to the code in SingleBlockImplicitTerminator, but doesn't // require the trait to be specified. This arbitrary limitation is // designed to limit the number of cases that have to be handled by // transforms and conversions until the concept stabilizes. - if (op.isRegistered() && kindInterface && - kindInterface.getRegionKind(i) == RegionKind::Graph) { + if (op.isRegistered() && kind == RegionKind::Graph) { // Empty regions are fine. if (region.empty()) continue; @@ -192,7 +204,7 @@ return op.emitOpError("expects graph region #") << i << " to have 0 or 1 blocks"; } - if (failed(verifyRegion(region))) + if (failed(verifyRegion(region, kind))) return failure(); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2056,7 +2056,7 @@ auto &parsedOps = (*topLevelOp)->getRegion(0).front().getOperations(); auto &destOps = topLevelBlock->getOperations(); destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()), - parsedOps, parsedOps.begin(), std::prev(parsedOps.end())); + parsedOps, parsedOps.begin(), parsedOps.end()); return success(); } diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -269,10 +269,11 @@ /// Globals are inserted before the first function, if any. Block::iterator getGlobalInsertPt() { - auto i = module.getBody()->begin(); - while (!isa(i)) - ++i; - return i; + auto it = module.getBody()->begin(); + auto endIt = module.getBody()->end(); + while (it != endIt && !isa(it)) + ++it; + return it; } /// Functions are always inserted before the module terminator. 
diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp --- a/mlir/lib/Transforms/SymbolDCE.cpp +++ b/mlir/lib/Transforms/SymbolDCE.cpp @@ -61,8 +61,7 @@ if (!nestedSymbolTable->hasTrait()) return; for (auto &block : nestedSymbolTable->getRegion(0)) { - for (Operation &op : - llvm::make_early_inc_range(block.without_terminator())) { + for (Operation &op : llvm::make_early_inc_range(block)) { if (isa(&op) && !liveSymbols.count(&op)) op.erase(); } @@ -84,7 +83,7 @@ // are known to be live. for (auto &block : symbolTableOp->getRegion(0)) { // Add all non-symbols or symbols that can't be discarded. - for (Operation &op : block.without_terminator()) { + for (Operation &op : block) { SymbolOpInterface symbol = dyn_cast(&op); if (!symbol) { worklist.push_back(&op); diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -300,6 +300,7 @@ for (Region &region : regions) { if (region.empty()) continue; + bool hasSingleBlock = llvm::hasSingleElement(region); // We do the deletion in an order that deletes all uses before deleting // defs. @@ -311,7 +312,8 @@ // before domtree parents. A CFG post-order (with reverse iteration with a // block) satisfies that without needing an explicit domtree calculation. 
for (Block *block : llvm::post_order(&region.front())) { - eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap); + if (!hasSingleBlock) + eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap); for (Operation &childOp : llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) { erasedAnything |= diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -293,7 +293,7 @@ fprintf(stderr, "Number of op results: %u\n", stats.numOpResults); // clang-format off // CHECK-LABEL: @stats - // CHECK: Number of operations: 13 + // CHECK: Number of operations: 12 // CHECK: Number of attributes: 4 // CHECK: Number of blocks: 3 // CHECK: Number of regions: 3 diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c --- a/mlir/test/CAPI/pass.c +++ b/mlir/test/CAPI/pass.c @@ -42,7 +42,6 @@ // Run the print-op-stats pass on the top-level module: // CHECK-LABEL: Operations encountered: // CHECK: func , 1 - // CHECK: module_terminator , 1 // CHECK: std.addi , 1 // CHECK: std.return , 1 { diff --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir --- a/mlir/test/Dialect/GPU/outlining.mlir +++ b/mlir/test/Dialect/GPU/outlining.mlir @@ -56,201 +56,3 @@ // CHECK-NEXT: "use"(%[[KERNEL_ARG0]]) : (f32) -> () // CHECK-NEXT: "some_op"(%[[BID]], %[[BDIM]]) : (index, index) -> () // CHECK-NEXT: = load %[[KERNEL_ARG1]][%[[TID]]] : memref - -// ----- - -// CHECK: module attributes {gpu.container_module} -// CHECK-LABEL: @multiple_launches -func @multiple_launches() { - // CHECK: %[[CST:.*]] = constant 8 : index - %cst = constant 8 : index - // CHECK: gpu.launch_func @multiple_launches_kernel::@multiple_launches_kernel blocks in (%[[CST]], %[[CST]], %[[CST]]) threads in (%[[CST]], %[[CST]], %[[CST]]) - gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, - %grid_z = %cst) - threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, - %block_z = %cst) { - gpu.terminator 
- } - // CHECK: gpu.launch_func @multiple_launches_kernel_0::@multiple_launches_kernel blocks in (%[[CST]], %[[CST]], %[[CST]]) threads in (%[[CST]], %[[CST]], %[[CST]]) - gpu.launch blocks(%bx2, %by2, %bz2) in (%grid_x2 = %cst, %grid_y2 = %cst, - %grid_z2 = %cst) - threads(%tx2, %ty2, %tz2) in (%block_x2 = %cst, %block_y2 = %cst, - %block_z2 = %cst) { - gpu.terminator - } - return -} - -// CHECK: module @multiple_launches_kernel -// CHECK: func @multiple_launches_kernel -// CHECK: module @multiple_launches_kernel_0 -// CHECK: func @multiple_launches_kernel - -// ----- - -// CHECK-LABEL: @extra_constants_not_inlined -func @extra_constants_not_inlined(%arg0: memref) { - // CHECK: %[[CST:.*]] = constant 8 : index - %cst = constant 8 : index - %cst2 = constant 2 : index - %c0 = constant 0 : index - %cst3 = "secret_constant"() : () -> index - // CHECK: gpu.launch_func @extra_constants_not_inlined_kernel::@extra_constants_not_inlined_kernel blocks in (%[[CST]], %[[CST]], %[[CST]]) threads in (%[[CST]], %[[CST]], %[[CST]]) args({{.*}} : memref, {{.*}} : index) - gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, - %grid_z = %cst) - threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, - %block_z = %cst) { - "use"(%cst2, %arg0, %cst3) : (index, memref, index) -> () - gpu.terminator - } - return -} - -// CHECK-LABEL: func @extra_constants_not_inlined_kernel(%{{.*}}: memref, %{{.*}}: index) -// CHECK: constant 2 - -// ----- - -// CHECK-LABEL: @extra_constants -// CHECK-SAME: %[[ARG0:.*]]: memref -func @extra_constants(%arg0: memref) { - // CHECK: %[[CST:.*]] = constant 8 : index - %cst = constant 8 : index - %cst2 = constant 2 : index - %c0 = constant 0 : index - %cst3 = dim %arg0, %c0 : memref - // CHECK: gpu.launch_func @extra_constants_kernel::@extra_constants_kernel blocks in (%[[CST]], %[[CST]], %[[CST]]) threads in (%[[CST]], %[[CST]], %[[CST]]) args(%[[ARG0]] : memref) - gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, - 
%grid_z = %cst) - threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, - %block_z = %cst) { - "use"(%cst2, %arg0, %cst3) : (index, memref, index) -> () - gpu.terminator - } - return -} - -// CHECK-LABEL: func @extra_constants_kernel( -// CHECK-SAME: %[[KARG0:.*]]: memref -// CHECK: constant 2 -// CHECK: constant 0 -// CHECK: dim %[[KARG0]] - -// ----- - -// CHECK-LABEL: @extra_constants_noarg -// CHECK-SAME: %[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref -func @extra_constants_noarg(%arg0: memref, %arg1: memref) { - // CHECK: %[[CST:.*]] = constant 8 : index - %cst = constant 8 : index - %cst2 = constant 2 : index - %c0 = constant 0 : index - // CHECK: dim %[[ARG1]] - %cst3 = dim %arg1, %c0 : memref - // CHECK: gpu.launch_func @extra_constants_noarg_kernel::@extra_constants_noarg_kernel blocks in (%[[CST]], %[[CST]], %[[CST]]) threads in (%[[CST]], %[[CST]], %[[CST]]) args(%[[ARG0]] : memref, {{.*}} : index) - gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, - %grid_z = %cst) - threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, - %block_z = %cst) { - "use"(%cst2, %arg0, %cst3) : (index, memref, index) -> () - gpu.terminator - } - return -} - -// CHECK-LABEL: func @extra_constants_noarg_kernel( -// CHECK-SAME: %[[KARG0:.*]]: memref, %[[KARG1:.*]]: index -// CHECK: %[[KCST:.*]] = constant 2 -// CHECK: "use"(%[[KCST]], %[[KARG0]], %[[KARG1]]) - -// ----- - -// CHECK-LABEL: @multiple_uses -func @multiple_uses(%arg0 : memref) { - %c1 = constant 1 : index - %c2 = constant 2 : index - // CHECK: gpu.func {{.*}} { - // CHECK: %[[C2:.*]] = constant 2 : index - // CHECK: "use1"(%[[C2]], %[[C2]]) - // CHECK: "use2"(%[[C2]]) - // CHECK: gpu.return - // CHECK: } - gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, - %grid_z = %c1) - threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, - %block_z = %c1) { - "use1"(%c2, %c2) : (index, index) -> () - "use2"(%c2) : (index) -> () - gpu.terminator - } - return -} - -// ----- - 
-// CHECK-LABEL: @multiple_uses2 -func @multiple_uses2(%arg0 : memref<*xf32>) { - %c1 = constant 1 : index - %c2 = constant 2 : index - %d = dim %arg0, %c2 : memref<*xf32> - // CHECK: gpu.func {{.*}} { - // CHECK: %[[C2:.*]] = constant 2 : index - // CHECK: %[[D:.*]] = dim %[[ARG:.*]], %[[C2]] - // CHECK: "use1"(%[[D]]) - // CHECK: "use2"(%[[C2]], %[[C2]]) - // CHECK: "use3"(%[[ARG]]) - // CHECK: gpu.return - // CHECK: } - gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, - %grid_z = %c1) - threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, - %block_z = %c1) { - "use1"(%d) : (index) -> () - "use2"(%c2, %c2) : (index, index) -> () - "use3"(%arg0) : (memref<*xf32>) -> () - gpu.terminator - } - return -} - -// ----- - -llvm.mlir.global internal @global(42 : i64) : i64 - -//CHECK-LABEL: @function_call -func @function_call(%arg0 : memref) { - %cst = constant 8 : index - gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst, - %grid_z = %cst) - threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst, - %block_z = %cst) { - call @device_function() : () -> () - call @device_function() : () -> () - %0 = llvm.mlir.addressof @global : !llvm.ptr - gpu.terminator - } - return -} - -func @device_function() { - call @recursive_device_function() : () -> () - return -} - -func @recursive_device_function() { - call @recursive_device_function() : () -> () - return -} - -// CHECK: gpu.module @function_call_kernel { -// CHECK: gpu.func @function_call_kernel() -// CHECK: call @device_function() : () -> () -// CHECK: call @device_function() : () -> () -// CHECK: llvm.mlir.addressof @global : !llvm.ptr -// CHECK: gpu.return -// -// CHECK: llvm.mlir.global internal @global(42 : i64) : i64 -// -// CHECK: func @device_function() -// CHECK: func @recursive_device_function() -// CHECK-NOT: func @device_function diff --git a/mlir/test/IR/invalid-module-op.mlir b/mlir/test/IR/invalid-module-op.mlir --- a/mlir/test/IR/invalid-module-op.mlir +++ 
b/mlir/test/IR/invalid-module-op.mlir @@ -19,31 +19,12 @@ // expected-error@+1 {{region should have no arguments}} module { ^bb1(%arg: i32): - "module_terminator"() : () -> () - } - return -} - -// ----- - -func @module_op() { - // expected-error@below {{expects regions to end with 'module_terminator'}} - // expected-note@below {{the absence of terminator implies 'module_terminator'}} - module { - return } return } // ----- -func @module_op() { - // expected-error@+1 {{expects parent op 'module'}} - "module_terminator"() : () -> () -} - -// ----- - // expected-error@+1 {{can only contain attributes with dialect-prefixed names}} module attributes {attr} { } diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -120,7 +120,7 @@ // ----- -func @no_terminator() { // expected-error {{block with no terminator}} +func @no_terminator() { // expected-error {{empty block: expect at least a terminator}} ^bb40: return ^bb41: diff --git a/mlir/test/IR/print-ir-defuse.mlir b/mlir/test/IR/print-ir-defuse.mlir --- a/mlir/test/IR/print-ir-defuse.mlir +++ b/mlir/test/IR/print-ir-defuse.mlir @@ -18,8 +18,6 @@ // CHECK: Has 0 results: // CHECK: Visiting op 'dialect.op3' with 0 operands: // CHECK: Has 0 results: -// CHECK: Visiting op 'module_terminator' with 0 operands: -// CHECK: Has 0 results: // CHECK: Visiting op 'module' with 0 operands: // CHECK: Has 0 results: diff --git a/mlir/test/IR/print-ir-nesting.mlir b/mlir/test/IR/print-ir-nesting.mlir --- a/mlir/test/IR/print-ir-nesting.mlir +++ b/mlir/test/IR/print-ir-nesting.mlir @@ -3,7 +3,7 @@ // CHECK: visiting op: 'module' with 0 operands and 0 results // CHECK: 1 nested regions: // CHECK: Region with 1 blocks: -// CHECK: Block with 0 arguments, 0 successors, and 3 operations +// CHECK: Block with 0 arguments, 0 successors, and 2 operations module { @@ -52,6 +52,4 @@ "dialect.innerop7"() : () -> () }) : () -> () -// CHECK: visiting op: 
'module_terminator' with 0 operands and 0 results - } // module diff --git a/mlir/test/Transforms/test-legalizer-analysis.mlir b/mlir/test/Transforms/test-legalizer-analysis.mlir --- a/mlir/test/Transforms/test-legalizer-analysis.mlir +++ b/mlir/test/Transforms/test-legalizer-analysis.mlir @@ -1,6 +1,5 @@ // RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s // expected-remark@-2 {{op 'module' is legalizable}} -// expected-remark@-3 {{op 'module_terminator' is legalizable}} // expected-remark@+1 {{op 'func' is legalizable}} func @test(%arg0: f32) { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -570,7 +570,7 @@ // Define the conversion target used for the test. ConversionTarget target(getContext()); - target.addLegalOp(); + target.addLegalOp(); target.addLegalOp(); target @@ -699,7 +699,7 @@ patterns.insert(&getContext()); mlir::ConversionTarget target(getContext()); - target.addLegalOp(); + target.addLegalOp(); // We make OneVResOneVOperandOp1 legal only when it has more that one // operand. This will trigger the conversion that will replace one-operand // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. @@ -967,9 +967,8 @@ .insert( context); ConversionTarget target(*context); - target.addLegalOp(); + target.addLegalOp(); target.addIllegalOp(); /// Expect the op to have a single block after legalization. 
diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp --- a/mlir/test/lib/Transforms/TestConvVectorization.cpp +++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp @@ -55,7 +55,7 @@ ConversionTarget target(*context); target.addLegalDialect(); - target.addLegalOp(); + target.addLegalOp(); target.addLegalOp(); SmallVector stage1Patterns; diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -437,6 +437,15 @@ llvm::any_of(op.getTraits(), [](const OpTrait &trait) { return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator"); }); + + hasSingleBlockTrait = + hasImplicitTermTrait || + llvm::any_of(op.getTraits(), [](const OpTrait &trait) { + if (auto *native = dyn_cast(&trait)) { + return native->getTrait() == "::mlir::OpTrait::SingleBlock"; + } + return false; + }); } /// Generate the operation parser from this format. @@ -472,6 +481,9 @@ /// trait. bool hasImplicitTermTrait; + /// A flag indicating if this operation has the SingleBlock trait. + bool hasSingleBlockTrait; + /// A map of buildable types to indices. llvm::MapVector> buildableTypes; @@ -667,6 +679,14 @@ ensureTerminator(*region, parser.getBuilder(), result.location); )"; +/// The code snippet used to ensure a list of regions have a block. +/// +/// {0}: The name of the region list. +const char *regionListEnsureSingleBlockParserCode = R"( + for (auto &region : {0}Regions) + if (region.empty()) *{0}Region.emplaceBlock(); +)"; + /// The code snippet used to generate a parser call for an optional region. /// /// {0}: The name of the region. @@ -693,6 +713,13 @@ ensureTerminator(*{0}Region, parser.getBuilder(), result.location); )"; +/// The code snippet used to ensure a region has a block. +/// +/// {0}: The name of the region.
+const char *regionEnsureSingleBlockParserCode = R"( + if ({0}Region->empty()) {0}Region->emplaceBlock(); +)"; + /// The code snippet used to generate a parser call for a successor list. /// /// {0}: The name for the successor list. @@ -1120,6 +1147,9 @@ body << " if (!" << region->name << "Region->empty()) {\n "; if (hasImplicitTermTrait) body << llvm::formatv(regionEnsureTerminatorParserCode, region->name); + else if (hasSingleBlockTrait) + body << llvm::formatv(regionEnsureSingleBlockParserCode, + region->name); } } @@ -1193,11 +1223,14 @@ bool isVariadic = region->getVar()->isVariadic(); body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode, region->getVar()->name); - if (hasImplicitTermTrait) { + if (hasImplicitTermTrait) body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode : regionEnsureTerminatorParserCode, region->getVar()->name); - } + else if (hasSingleBlockTrait) + body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode + : regionEnsureSingleBlockParserCode, + region->getVar()->name); } else if (auto *successor = dyn_cast(element)) { bool isVariadic = successor->getVar()->isVariadic(); @@ -1222,6 +1255,8 @@ body << llvm::formatv(regionListParserCode, "full"); if (hasImplicitTermTrait) body << llvm::formatv(regionListEnsureTerminatorParserCode, "full"); + else if (hasSingleBlockTrait) + body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full"); } else if (isa(element)) { body << llvm::formatv(successorListParserCode, "full");