diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -357,12 +357,7 @@ return *this; } - /// In the particular case where the vector has a single dimension that we - /// drop, return the scalar element type. - // TODO: unify once we have a VectorType that supports 0-D. - operator Type() { - if (shape.empty()) - return elementType; + operator VectorType() { return VectorType::get(shape, elementType, scalableDims); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2216,7 +2216,7 @@ return failure(); if (mask.size() != 1) return failure(); - Type resType = VectorType::Builder(v1VectorType).setShape({1}); + VectorType resType = VectorType::Builder(v1VectorType).setShape({1}); if (llvm::cast(mask[0]).getInt() == 0) rewriter.replaceOpWithNewOp(shuffleOp, resType, shuffleOp.getV1()); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -89,21 +89,20 @@ PatternRewriter &rewriter) { if (index == -1) return val; - Type lowType = VectorType::Builder(type).dropDim(0); + Type lowType = type.getRank() > 1 ? VectorType::Builder(type).dropDim(0) + : type.getElementType(); // At extraction dimension? if (index == 0) return rewriter.create(loc, lowType, val, pos); // Unroll leading dimensions. VectorType vType = cast(lowType); - Type resType = VectorType::Builder(type).dropDim(index); - auto resVectorType = cast(resType); + VectorType resType = VectorType::Builder(type).dropDim(index); Value result = rewriter.create( - loc, resVectorType, rewriter.getZeroAttr(resVectorType)); - for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) { + loc, resType, rewriter.getZeroAttr(resType)); + for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) { Value ext = rewriter.create(loc, vType, val, d); Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); - result = - rewriter.create(loc, resVectorType, load, result, d); + result = rewriter.create(loc, resType, load, result, d); } return result; } @@ -120,13 +119,13 @@ if (index == 0) return rewriter.create(loc, type, val, result, pos); // Unroll leading dimensions. - Type lowType = VectorType::Builder(type).dropDim(0); - VectorType vType = cast(lowType); - Type insType = VectorType::Builder(vType).dropDim(0); + VectorType lowType = VectorType::Builder(type).dropDim(0); + Type insType = lowType.getRank() > 1 ? VectorType::Builder(lowType).dropDim(0) + : lowType.getElementType(); for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { - Value ext = rewriter.create(loc, vType, result, d); + Value ext = rewriter.create(loc, lowType, result, d); Value ins = rewriter.create(loc, insType, val, d); - Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); + Value sto = reshapeStore(loc, ins, ext, lowType, index - 1, pos, rewriter); result = rewriter.create(loc, type, sto, result, d); } return result;