diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -17,6 +17,9 @@
 #include "mlir/IR/Visitors.h"
 
 namespace mlir {
+class TypeRange;
+template <typename ValueRangeT> class ValueTypeRange;
+
 /// `Block` represents an ordered list of `Operation`s.
 class Block : public IRObjectWithUseList<BlockOperand>,
               public llvm::ilist_node_with_parent<Block, Region> {
@@ -67,6 +70,9 @@
 
   BlockArgListType getArguments() { return arguments; }
 
+  /// Return a range containing the types of the arguments for this block.
+  ValueTypeRange<BlockArgListType> getArgumentTypes();
+
   using args_iterator = BlockArgListType::iterator;
   using reverse_args_iterator = BlockArgListType::reverse_iterator;
   args_iterator args_begin() { return getArguments().begin(); }
@@ -85,15 +91,18 @@
   BlockArgument insertArgument(args_iterator it, Type type);
 
   /// Add one argument to the argument list for each type specified in the list.
-  iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
+  /// `types` must not be derived from this block's own arguments (e.g. via
+  /// getArgumentTypes()): reserving space for the new arguments may reallocate
+  /// the argument storage such a range points into.
+  iterator_range<args_iterator> addArguments(TypeRange types);
 
-  // Add one value to the argument list at the specified position.
+  /// Add one value to the argument list at the specified position.
   BlockArgument insertArgument(unsigned index, Type type);
 
-  /// Erase the argument at 'index' and remove it from the argument list. If
-  /// 'updatePredTerms' is set to true, this argument is also removed from the
-  /// terminators of each predecessor to this block.
-  void eraseArgument(unsigned index, bool updatePredTerms = true);
+  /// Erase the argument at 'index' and remove it from the argument list. The
+  /// successor operands of predecessor terminators are not updated; callers
+  /// must erase the corresponding operand themselves.
+  void eraseArgument(unsigned index);
 
   unsigned getNumArguments() { return arguments.size(); }
   BlockArgument getArgument(unsigned i) { return arguments[i]; }
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -563,10 +563,19 @@
   explicit TypeRange(OperandRange values);
   explicit TypeRange(ResultRange values);
   explicit TypeRange(ValueRange values);
+  explicit TypeRange(ArrayRef<Value> values);
+  explicit TypeRange(ArrayRef<BlockArgument> values)
+      : TypeRange(ArrayRef<Value>(values.data(), values.size())) {}
   template <typename ValueRangeT>
   TypeRange(ValueTypeRange<ValueRangeT> values)
       : TypeRange(ValueRangeT(values.begin().getCurrent(),
                               values.end().getCurrent())) {}
+  template <typename Arg,
+            typename = typename std::enable_if_t<
+                std::is_constructible<ArrayRef<Type>, Arg>::value>>
+  TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
+  TypeRange(std::initializer_list<Type> types)
+      : TypeRange(ArrayRef<Type>(types)) {}
 
  private:
   /// The owner of the range is either:
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -2792,7 +2792,9 @@
 } // namespace
 
 static void ensureDistinctSuccessors(Block &bb) {
-  auto *terminator = bb.getTerminator();
+  Operation *terminator = bb.getTerminator();
+  if (terminator->getNumSuccessors() < 2)
+    return;
 
   // Find repeated successors with arguments.
   llvm::SmallDenseMap<Block *, SmallVector<int, 4>> successorPositions;
@@ -2811,21 +2813,15 @@
-  // There is no need to pass arguments to the dummy block because it will be
-  // dominated by the original block and can therefore use any values defined in
-  // the original block.
+  // The dummy block takes the same argument types as the original destination
+  // and forwards its arguments to it unchanged, so the terminator's successor
+  // operands now flow through the dummy block.
+  OpBuilder builder(terminator->getContext());
   for (const auto &successor : successorPositions) {
-    const auto &positions = successor.second;
     // Start from the second occurrence of a block in the successor list.
-    for (auto position = std::next(positions.begin()), end = positions.end();
-         position != end; ++position) {
-      auto *dummyBlock = new Block();
-      bb.getParent()->push_back(dummyBlock);
-      auto builder = OpBuilder(dummyBlock);
-      SmallVector<Value, 8> operands(
-          terminator->getSuccessorOperands(*position));
-      builder.create<BranchOp>(terminator->getLoc(), successor.first, operands);
-      terminator->setSuccessor(dummyBlock, *position);
-      for (int i = 0, e = terminator->getNumSuccessorOperands(*position); i < e;
-           ++i)
-        terminator->eraseSuccessorOperand(*position, i);
+    for (int position : llvm::drop_begin(successor.second, 1)) {
+      Block *dummyBlock = builder.createBlock(bb.getParent());
+      terminator->setSuccessor(dummyBlock, position);
+      dummyBlock->addArguments(successor.first->getArgumentTypes());
+      builder.create<BranchOp>(terminator->getLoc(), successor.first,
+                               dummyBlock->getArguments());
     }
   }
 }
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
--- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -1621,9 +1621,8 @@
         "symbol count must match");
 
   // Verify that the operands are valid dimension/symbols.
-  if (failed(verifyDimAndSymbolIdentifiers(
-          op, op.getOperation()->getNonSuccessorOperands(),
-          condition.getNumDims())))
+  if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
+                                           condition.getNumDims())))
     return failure();
 
   // Verify that the entry of each child region does not have arguments.
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -143,6 +143,11 @@
 // Argument list management.
 //===----------------------------------------------------------------------===//
 
+/// Return a range containing the types of the arguments for this block.
+auto Block::getArgumentTypes() -> ValueTypeRange<BlockArgListType> {
+  return ValueTypeRange<BlockArgListType>(getArguments());
+}
+
 BlockArgument Block::addArgument(Type type) {
   BlockArgument arg = BlockArgument::create(type, this);
   arguments.push_back(arg);
@@ -150,13 +155,11 @@
 }
 
 /// Add one argument to the argument list for each type specified in the list.
-auto Block::addArguments(ArrayRef<Type> types)
-    -> iterator_range<args_iterator> {
-  arguments.reserve(arguments.size() + types.size());
-  auto initialSize = arguments.size();
-  for (auto type : types) {
+auto Block::addArguments(TypeRange types) -> iterator_range<args_iterator> {
+  size_t initialSize = arguments.size();
+  arguments.reserve(initialSize + types.size());
+  for (auto type : types)
     addArgument(type);
-  }
   return {arguments.data() + initialSize, arguments.data() + arguments.size()};
 }
 
@@ -167,22 +170,8 @@
   return arg;
 }
 
-void Block::eraseArgument(unsigned index, bool updatePredTerms) {
+void Block::eraseArgument(unsigned index) {
   assert(index < arguments.size());
-
-  // If requested, update predecessors. We do this first since this block might
-  // be a predecessor of itself and use this block argument as a successor
-  // operand.
-  if (updatePredTerms) {
-    // Erase this argument from each of the predecessor's terminator.
-    for (auto predIt = pred_begin(), predE = pred_end(); predIt != predE;
-         ++predIt) {
-      auto *predTerminator = (*predIt)->getTerminator();
-      predTerminator->eraseSuccessorOperand(predIt.getSuccessorIndex(), index);
-    }
-  }
-
-  // Delete the argument.
   arguments[index].destroy();
   arguments.erase(arguments.begin() + index);
 }
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -150,6 +150,8 @@
 TypeRange::TypeRange(ResultRange values)
     : TypeRange(values.getBase()->getResultTypes().slice(values.getStartIndex(),
                                                          values.size())) {}
+TypeRange::TypeRange(ArrayRef<Value> values)
+    : TypeRange(values.data(), values.size()) {}
 TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
   detail::ValueRangeOwner owner = values.begin().getBase();
   if (auto *op = reinterpret_cast<Operation *>(owner.ptr.dyn_cast<void *>()))
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -502,8 +502,8 @@
          "different blocks");
 
   return condBranchOp.getSuccessor(0) == current
-             ? terminator.getSuccessorOperand(0, index)
-             : terminator.getSuccessorOperand(1, index);
+             ? condBranchOp.trueDestOperands()[index]
+             : condBranchOp.falseDestOperands()[index];
 }
 
 void ModuleTranslation::connectPHINodes(LLVMFuncOp func) {
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -294,7 +294,7 @@
       // earlier arguments.
       for (unsigned i = 0, e = block.getNumArguments(); i < e; i++)
         if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) {
-          block.eraseArgument(e - i - 1, /*updatePredTerms=*/false);
+          block.eraseArgument(e - i - 1);
           erasedAnything = true;
         }
     }
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -556,16 +556,16 @@
 }
 // CHECK-LABEL: func @cond_br_same_target(%arg0: !llvm.i1, %arg1: !llvm.i32, %arg2: !llvm.i32)
 func @cond_br_same_target(%arg0: i1, %arg1: i32, %arg2 : i32) -> (i32) {
-// CHECK-NEXT:  llvm.cond_br %arg0, ^[[origBlock:bb[0-9]+]](%arg1 : !llvm.i32), ^[[dummyBlock:bb[0-9]+]]
+// CHECK-NEXT:  llvm.cond_br %arg0, ^[[origBlock:bb[0-9]+]](%arg1 : !llvm.i32), ^[[dummyBlock:bb[0-9]+]](%arg2 : !llvm.i32)
   cond_br %arg0, ^bb1(%arg1 : i32), ^bb1(%arg2 : i32)
 
-// CHECK:       ^[[origBlock]](%0: !llvm.i32):
-// CHECK-NEXT:    llvm.return %0 : !llvm.i32
+// CHECK:       ^[[origBlock]](%[[BLOCKARG1:.*]]: !llvm.i32):
+// CHECK-NEXT:    llvm.return %[[BLOCKARG1]] : !llvm.i32
 ^bb1(%0 : i32):
   return %0 : i32
 
-// CHECK:       ^[[dummyBlock]]:
-// CHECK-NEXT:    llvm.br ^[[origBlock]](%arg2 : !llvm.i32)
+// CHECK:       ^[[dummyBlock]](%[[BLOCKARG2:.*]]: !llvm.i32):
+// CHECK-NEXT:    llvm.br ^[[origBlock]](%[[BLOCKARG2]] : !llvm.i32)
 }
 
 // CHECK-LABEL: func @fcmp(%arg0: !llvm.float, %arg1: !llvm.float) {