diff --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h
--- a/mlir/include/mlir/IR/BlockAndValueMapping.h
+++ b/mlir/include/mlir/IR/BlockAndValueMapping.h
@@ -32,6 +32,19 @@
     valueMap[from.getAsOpaquePointer()] = to.getAsOpaquePointer();
   }
 
+  /// Inserts a mapping for every pair of corresponding elements of the ranges
+  /// 'from' and 'to'. Pairs are formed with llvm::zip, which stops at the end
+  /// of the shorter range. Disabled when 'from' is assignable to a single
+  /// Value or Block *, so the element-wise overloads above keep priority.
+  template <
+      typename S, typename T,
+      std::enable_if_t<!std::is_assignable<Value, S>::value &&
+                       !std::is_assignable<Block *, S>::value> * = nullptr>
+  void map(S &&from, T &&to) {
+    for (auto pair : llvm::zip(from, to))
+      map(std::get<0>(pair), std::get<1>(pair));
+  }
+
   /// Erases a mapping for 'from'.
   void erase(Block *from) { valueMap.erase(from); }
   void erase(Value from) { valueMap.erase(from.getAsOpaquePointer()); }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
@@ -273,13 +273,11 @@
     // 2. Inline region, currently only works for a single basic block.
     BlockAndValueMapping map;
     auto &block = genericOp.region().front();
-    for (auto it : llvm::zip(block.getArguments(), indexedValues))
-      map.map(std::get<0>(it), std::get<1>(it));
+    map.map(block.getArguments(), indexedValues);
     for (auto &op : block.without_terminator()) {
       assert(op.getNumRegions() == 0);
       auto *newOp = b.clone(op, map);
-      for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
-        map.map(std::get<0>(it), std::get<1>(it));
+      map.map(op.getResults(), newOp->getResults());
     }
 
     // 3. Emit std_store.
@@ -377,13 +375,11 @@
     // 2. Inline region, currently only works for a single basic block.
     BlockAndValueMapping map;
     auto &block = indexedGenericOp.region().front();
-    for (auto it : llvm::zip(block.getArguments(), indexedValues))
-      map.map(std::get<0>(it), std::get<1>(it));
+    map.map(block.getArguments(), indexedValues);
     for (auto &op : block.without_terminator()) {
       assert(op.getNumRegions() == 0);
       auto *newOp = b.clone(op, map);
-      for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
-        map.map(std::get<0>(it), std::get<1>(it));
+      map.map(op.getResults(), newOp->getResults());
     }
 
     // 3. Emit std_store.