diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -237,7 +237,7 @@
     /// the workgroup memory
     ArrayRef<BlockArgument> getWorkgroupAttributions() {
       auto begin =
-          std::next(getBody().front().args_begin(), getType().getNumInputs());
+          std::next(getBody().args_begin(), getType().getNumInputs());
       auto end = std::next(begin, getNumWorkgroupAttributions());
       return {begin, end};
     }
@@ -248,7 +248,7 @@
 
     /// Returns the number of buffers located in the private memory.
     unsigned getNumPrivateAttributions() {
-      return getBody().front().getNumArguments() - getType().getNumInputs() -
+      return getBody().getNumArguments() - getType().getNumInputs() -
           getNumWorkgroupAttributions();
     }
 
@@ -258,9 +258,9 @@
       // Buffers on the private memory always come after buffers on the workgroup
       // memory.
       auto begin =
-          std::next(getBody().front().args_begin(),
+          std::next(getBody().args_begin(),
                     getType().getNumInputs() + getNumWorkgroupAttributions());
-      return {begin, getBody().front().args_end()};
+      return {begin, getBody().args_end()};
     }
 
     /// Adds a new block argument that corresponds to buffers located in
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -553,7 +553,7 @@
   let extraClassDeclaration = [{
     // The value stored in memref[ivs].
     Value getCurrentValue() {
-      return body().front().getArgument(0);
+      return body().getArgument(0);
     }
     MemRefType getMemRefType() {
       return memref().getType().cast<MemRefType>();
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -216,15 +216,13 @@
   }
 
   /// Gets argument.
-  BlockArgument getArgument(unsigned idx) {
-    return getBlocks().front().getArgument(idx);
-  }
+  BlockArgument getArgument(unsigned idx) { return getBody().getArgument(idx); }
 
   /// Support argument iteration.
-  using args_iterator = Block::args_iterator;
-  args_iterator args_begin() { return front().args_begin(); }
-  args_iterator args_end() { return front().args_end(); }
-  Block::BlockArgListType getArguments() { return front().getArguments(); }
+  using args_iterator = Region::args_iterator;
+  args_iterator args_begin() { return getBody().args_begin(); }
+  args_iterator args_end() { return getBody().args_end(); }
+  Block::BlockArgListType getArguments() { return getBody().getArguments(); }
 
   //===--------------------------------------------------------------------===//
   // Argument Attributes
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -16,6 +16,9 @@
 #include "mlir/IR/Block.h"
 
 namespace mlir {
+class TypeRange;
+template <typename ValueRangeT>
+class ValueTypeRange;
 class BlockAndValueMapping;
 
 /// This class contains a list of basic blocks and a link to the parent
@@ -62,6 +65,48 @@
     return &Region::blocks;
   }
 
+  //===--------------------------------------------------------------------===//
+  // Argument Handling
+  //===--------------------------------------------------------------------===//
+
+  /// The arguments of the entry block; empty if the region has no blocks.
+  using BlockArgListType = MutableArrayRef<BlockArgument>;
+  BlockArgListType getArguments() {
+    return empty() ? BlockArgListType() : front().getArguments();
+  }
+  using args_iterator = BlockArgListType::iterator;
+  using reverse_args_iterator = BlockArgListType::reverse_iterator;
+  args_iterator args_begin() { return getArguments().begin(); }
+  args_iterator args_end() { return getArguments().end(); }
+  reverse_args_iterator args_rbegin() { return getArguments().rbegin(); }
+  reverse_args_iterator args_rend() { return getArguments().rend(); }
+
+  bool args_empty() { return getArguments().empty(); }
+
+  /// Append an argument to the entry block. Requires a non-empty region.
+  BlockArgument addArgument(Type type) { return front().addArgument(type); }
+
+  /// Insert one argument into the entry block at the position indicated by
+  /// the given iterator, shifting existing arguments. Requires a non-empty
+  /// region; the entry block is expected not to have predecessors.
+  BlockArgument insertArgument(args_iterator it, Type type) {
+    return front().insertArgument(it, type);
+  }
+
+  /// Add one argument to the argument list for each type specified in the list.
+  iterator_range<args_iterator> addArguments(TypeRange types);
+
+  /// Insert an entry block argument at 'index'. Requires a non-empty region.
+  BlockArgument insertArgument(unsigned index, Type type) {
+    return front().insertArgument(index, type);
+  }
+
+  /// Erase the entry block argument at 'index'. Requires a non-empty region.
+  void eraseArgument(unsigned index) { front().eraseArgument(index); }
+
+  unsigned getNumArguments() { return getArguments().size(); }
+  BlockArgument getArgument(unsigned i) { return getArguments()[i]; }
+
   //===--------------------------------------------------------------------===//
   // Operation list utilities
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -417,8 +417,8 @@
 
     if (isMappedToProcessor(processor)) {
       // Use the corresponding thread/grid index as replacement for the loop iv.
-      Value operand = launchOp.body().front().getArgument(
-          getLaunchOpArgumentNum(processor));
+      Value operand =
+          launchOp.body().getArgument(getLaunchOpArgumentNum(processor));
       // Take the indexmap and add the lower bound and step computations in.
       // This computes operand * step + lowerBound.
// Use an affine map here so that it composes nicely with the provided diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -127,9 +127,9 @@ return allReduce.emitError( "expected either an op attribute or a non-empty body"); if (!allReduce.body().empty()) { - if (allReduce.body().front().getNumArguments() != 2) + if (allReduce.body().getNumArguments() != 2) return allReduce.emitError("expected two region arguments"); - for (auto argument : allReduce.body().front().getArguments()) { + for (auto argument : allReduce.body().getArguments()) { if (argument.getType() != allReduce.getType()) return allReduce.emitError("incorrect region argument type"); } @@ -219,25 +219,25 @@ KernelDim3 LaunchOp::getBlockIds() { assert(!body().empty() && "LaunchOp body must not be empty."); - auto args = body().front().getArguments(); + auto args = body().getArguments(); return KernelDim3{args[0], args[1], args[2]}; } KernelDim3 LaunchOp::getThreadIds() { assert(!body().empty() && "LaunchOp body must not be empty."); - auto args = body().front().getArguments(); + auto args = body().getArguments(); return KernelDim3{args[3], args[4], args[5]}; } KernelDim3 LaunchOp::getGridSize() { assert(!body().empty() && "LaunchOp body must not be empty."); - auto args = body().front().getArguments(); + auto args = body().getArguments(); return KernelDim3{args[6], args[7], args[8]}; } KernelDim3 LaunchOp::getBlockSize() { assert(!body().empty() && "LaunchOp body must not be empty."); - auto args = body().getBlocks().front().getArguments(); + auto args = body().getArguments(); return KernelDim3{args[9], args[10], args[11]}; } @@ -254,8 +254,7 @@ // sizes and transforms them into kNumConfigRegionAttributes region arguments // for block/thread identifiers and grid/block sizes. 
if (!op.body().empty()) { - Block &entryBlock = op.body().front(); - if (entryBlock.getNumArguments() != + if (op.body().getNumArguments() != LaunchOp::kNumConfigOperands + op.getNumOperands()) return op.emitOpError("unexpected number of region arguments"); } @@ -463,8 +462,8 @@ auto attrName = getNumWorkgroupAttributionsAttrName(); auto attr = getAttrOfType(attrName); setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1)); - return getBody().front().insertArgument( - getType().getNumInputs() + attr.getInt(), type); + return getBody().insertArgument(getType().getNumInputs() + attr.getInt(), + type); } /// Adds a new block argument that corresponds to buffers located in @@ -472,7 +471,7 @@ BlockArgument GPUFuncOp::addPrivateAttribution(Type type) { // Buffers on the private memory always come after buffers on the workgroup // memory. - return getBody().front().addArgument(type); + return getBody().addArgument(type); } void GPUFuncOp::build(OpBuilder &builder, OperationState &result, diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -181,8 +181,8 @@ // Insert accumulator body between split block. 
BlockAndValueMapping mapping; - mapping.map(body.front().getArgument(0), lhs); - mapping.map(body.front().getArgument(1), rhs); + mapping.map(body.getArgument(0), lhs); + mapping.map(body.getArgument(1), rhs); rewriter.cloneRegionBefore(body, *split->getParent(), split->getIterator(), mapping); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -1102,8 +1102,7 @@ unsigned argIndex, NamedAttribute attribute) { return verifyRegionAttribute( - op->getLoc(), - op->getRegion(regionIndex).front().getArgument(argIndex).getType(), + op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(), attribute); } diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -500,22 +500,21 @@ Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new Block()); - bodyRegion->front().addArgument(elementType); + bodyRegion->addArgument(elementType); } } static LogicalResult verify(GenericAtomicRMWOp op) { - auto &block = op.body().front(); - if (block.getNumArguments() != 1) + auto &body = op.body(); + if (body.getNumArguments() != 1) return op.emitOpError("expected single number of entry block arguments"); - if (op.getResult().getType() != block.getArgument(0).getType()) + if (op.getResult().getType() != body.getArgument(0).getType()) return op.emitOpError( "expected block argument of the same type result type"); bool hasSideEffects = - op.body() - .walk([&](Operation *nestedOp) { + body.walk([&](Operation *nestedOp) { if (MemoryEffectOpInterface::hasNoEffect(nestedOp)) return WalkResult::advance(); nestedOp->emitError("body of 'generic_atomic_rmw' should contain " diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp 
@@ -619,7 +619,7 @@
 
 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
   assert(!region.empty() && "cannot shadow arguments of an empty region");
-  assert(region.front().getNumArguments() == namesToUse.size() &&
+  assert(region.getNumArguments() == namesToUse.size() &&
          "incorrect number of names passed in");
   assert(region.getParentOp()->isKnownIsolatedFromAbove() &&
          "only KnownIsolatedFromAbove ops can shadow names");
@@ -629,7 +629,7 @@
     auto nameToUse = namesToUse[i];
     if (nameToUse == nullptr)
       continue;
-    auto nameToReplace = region.front().getArgument(i);
+    auto nameToReplace = region.getArgument(i);
 
     nameStr.clear();
     llvm::raw_svector_ostream nameStream(nameStr);
diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp
--- a/mlir/lib/IR/FunctionImplementation.cpp
+++ b/mlir/lib/IR/FunctionImplementation.cpp
@@ -238,7 +238,7 @@
       p << ", ";
 
     if (!isExternal) {
-      p.printOperand(body.front().getArgument(i));
+      p.printOperand(body.getArgument(i));
       p << ": ";
     }
 
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1022,7 +1022,7 @@
     if (region.empty())
       continue;
 
-    if (region.front().getNumArguments() != 0) {
+    if (region.getNumArguments() != 0) {
       if (op->getNumRegions() > 1)
         return op->emitOpError("region #")
                << region.getRegionNumber() << " should have no arguments";
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -33,6 +33,11 @@
   return container->getLoc();
 }
 
+/// Add one argument to the argument list for each type specified in the list.
+iterator_range<Region::args_iterator> Region::addArguments(TypeRange types) {
+  return front().addArguments(types);
+}
+
 Region *Region::getParentRegion() {
   assert(container && "region is not attached to a container");
   return container->getParentRegion();
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -123,7 +123,7 @@
   /// Build a lattice state with a given callable region, and a specified number
   /// of results to be initialized to the default lattice value (Unknown).
   CallableLatticeState(Region *callableRegion, unsigned numResults)
-      : callableArguments(callableRegion->front().getArguments()),
+      : callableArguments(callableRegion->getArguments()),
         resultLatticeValues(numResults) {}
 
   /// Returns the arguments to the callable region.
@@ -403,7 +403,7 @@
       // If not all of the uses of this symbol are visible, we can't track the
       // state of the arguments.
       if (symbol.isPublic() || (!allUsesVisible && symbol.isNested()))
-        markAllOverdefined(callableRegion->front().getArguments());
+        markAllOverdefined(callableRegion->getArguments());
     }
     if (callableLatticeState.empty())
       return;
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -284,7 +284,7 @@
                   ConversionPatternRewriter &rewriter) const final {
     auto illegalOp =
         rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
-    rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
+    rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
                                         illegalOp);
     rewriter.updateRootInPlace(op, [] {});
     return success();