diff --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h --- a/mlir/include/mlir/IR/BlockAndValueMapping.h +++ b/mlir/include/mlir/IR/BlockAndValueMapping.h @@ -27,9 +27,26 @@ public: /// Inserts a new mapping for 'from' to 'to'. If there is an existing mapping, /// it is overwritten. - void map(Block *from, Block *to) { valueMap[from] = to; } - void map(Value from, Value to) { - valueMap[from.getAsOpaquePointer()] = to.getAsOpaquePointer(); + template ::value && + std::is_assignable::value> * = nullptr> + void map(S from, T to) { + valueMap[Value{from}.getAsOpaquePointer()] = Value{to}.getAsOpaquePointer(); + } + + template ::value && + std::is_same::value> * = nullptr> + void map(S from, T to) { + valueMap[from] = to; + } + + template ::value && + !std::is_same::value> * = nullptr> + void map(S from, T to) { + for (auto pair : llvm::zip(from, to)) + map(std::get<0>(pair), std::get<1>(pair)); } /// Erases a mapping for 'from'. diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -273,13 +273,11 @@ // 2. Inline region, currently only works for a single basic block. BlockAndValueMapping map; auto &block = genericOp.region().front(); - for (auto it : llvm::zip(block.getArguments(), indexedValues)) - map.map(std::get<0>(it), std::get<1>(it)); + map.map(block.getArguments(), indexedValues); for (auto &op : block.without_terminator()) { assert(op.getNumRegions() == 0); auto *newOp = b.clone(op, map); - for (auto it : llvm::zip(op.getResults(), newOp->getResults())) - map.map(std::get<0>(it), std::get<1>(it)); + map.map(op.getResults(), newOp->getResults()); } // 3. Emit std_store. @@ -377,13 +375,11 @@ // 2. Inline region, currently only works for a single basic block. BlockAndValueMapping map; auto &block = indexedGenericOp.region().front(); - for (auto it : llvm::zip(block.getArguments(), indexedValues)) - map.map(std::get<0>(it), std::get<1>(it)); + map.map(block.getArguments(), indexedValues); for (auto &op : block.without_terminator()) { assert(op.getNumRegions() == 0); auto *newOp = b.clone(op, map); - for (auto it : llvm::zip(op.getResults(), newOp->getResults())) - map.map(std::get<0>(it), std::get<1>(it)); + map.map(op.getResults(), newOp->getResults()); } // 3. Emit std_store.