diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -302,13 +302,8 @@ << index << " to have no symbols"; auto vectorType = op.getOperand(index).getType().dyn_cast(); unsigned rank = vectorType ? vectorType.getShape().size() : 0; - // Since (...) -> () is parsed into an empty map, we need to add - // a special case for this situation: continue the verification - // of an empty map if the resulting rank is indeed zero, i.e. this - // is a reduction into a scalar. - if (map.getNumDims() == 0 && map.getNumResults() == 0 && rank == 0) - continue; // Verify that the map has the right number of inputs, outputs, and indices. + // This also correctly accounts for (...) -> () for rank-0 results. if (map.getNumDims() != numIterators) return op.emitOpError("expected indexing map ") << index << " to have " << numIterators << " number of inputs"; diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -1077,6 +1077,7 @@ // Helper to construct an affine map with one index removed. static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter) { + auto *ctx = rewriter.getContext(); SmallVector results; for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { int64_t idx = map.getResult(i).cast().getPosition(); @@ -1084,13 +1085,11 @@ continue; // Re-insert remaining indices, but renamed when occurring // after the removed index. - auto targetExpr = - getAffineDimExpr(idx < index ? idx : idx - 1, rewriter.getContext()); + auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx); results.push_back(targetExpr); } - // Since (...) -> () cannot be represented properly, - // we resort to an empty map when this situation happens. 
- return results.empty() ? AffineMap::get(rewriter.getContext()) + // The (...) -> () affine map has its own factory method. + return results.empty() ? AffineMap::get(map.getNumDims() - 1, 0, ctx) : AffineMap::get(map.getNumDims() - 1, 0, results); }