diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -74,6 +74,10 @@
   BlockArgListType getArguments() {
     return empty() ? BlockArgListType() : front().getArguments();
   }
+
+  /// Returns the argument types of the first block within the region.
+  ValueTypeRange<BlockArgListType> getArgumentTypes();
+
   using args_iterator = BlockArgListType::iterator;
   using reverse_args_iterator = BlockArgListType::reverse_iterator;
   args_iterator args_begin() { return getArguments().begin(); }
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -76,7 +76,7 @@
 // upon the type of the callback function.
 
 /// Walk all of the operations nested under and including the given operation.
-/// This method is selected for callbacks that operation on Operation*.
+/// This method is selected for callbacks that operate on Operation*.
 ///
 /// Example:
 ///   op->walk([](Operation *op) { ... });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -43,9 +43,9 @@
     return false;
 
   using mlir::matchers::m_Val;
-  auto a = m_Val(r.front().getArgument(0));
-  auto b = m_Val(r.front().getArgument(1));
-  auto c = m_Val(r.front().getArgument(2));
+  auto a = m_Val(r.getArgument(0));
+  auto b = m_Val(r.getArgument(1));
+  auto c = m_Val(r.getArgument(2));
   // TODO: Update this detection once we have matcher support for specifying
   // that any permutation of operands matches.
   auto pattern1 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(a, b), c));
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -33,6 +33,11 @@
   return container->getLoc();
 }
 
+/// Return a range containing the types of the arguments for this region.
+auto Region::getArgumentTypes() -> ValueTypeRange<BlockArgListType> {
+  return ValueTypeRange<BlockArgListType>(getArguments());
+}
+
 /// Add one argument to the argument list for each type specified in the list.
 iterator_range<Region::args_iterator> Region::addArguments(TypeRange types) {
   return front().addArguments(types);
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -334,7 +334,7 @@
     mapper.map(regionArg, operand);
   }
 
-  // Ensure that the resultant values of the call, match the callable.
+  // Ensure that the resultant values of the call match the callable.
   castBuilder.setInsertionPointAfter(call);
   for (unsigned i = 0, e = callResults.size(); i != e; ++i) {
     Value callResult = callResults[i];