diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -929,10 +929,10 @@ /// One: /// %x = vector.insert_slices %0 /// is replaced by: -/// %r0 = vector.splat 0 -// %t1 = vector.tuple_get %0, 0 +/// %r0 = zero-result +/// %t1 = vector.tuple_get %0, 0 /// %r1 = vector.insert_strided_slice %r0, %t1 -// %t2 = vector.tuple_get %0, 1 +/// %t2 = vector.tuple_get %0, 1 /// %r2 = vector.insert_strided_slice %r1, %t2 /// .. /// %x = .. @@ -953,10 +953,8 @@ op.getStrides(strides); // all-ones at the moment // Prepare result. - auto elemType = vectorType.getElementType(); - Value zero = rewriter.create(loc, elemType, - rewriter.getZeroAttr(elemType)); - Value result = rewriter.create(loc, vectorType, zero); + Value result = rewriter.create( + loc, vectorType, rewriter.getZeroAttr(vectorType)); // For each element in the tuple, extract the proper strided slice. TupleType tupleType = op.getSourceTupleType(); @@ -1015,9 +1013,8 @@ VectorType::get(dstType.getShape().drop_front(), eltType); Value bcst = rewriter.create(loc, resType, op.source()); - Value zero = rewriter.create(loc, eltType, - rewriter.getZeroAttr(eltType)); - Value result = rewriter.create(loc, dstType, zero); + Value result = rewriter.create(loc, dstType, + rewriter.getZeroAttr(dstType)); for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) result = rewriter.create(loc, bcst, result, d); rewriter.replaceOp(op, result); @@ -1064,9 +1061,8 @@ // %x = [%a,%b,%c,%d] VectorType resType = VectorType::get(dstType.getShape().drop_front(), eltType); - Value zero = rewriter.create(loc, eltType, - rewriter.getZeroAttr(eltType)); - Value result = rewriter.create(loc, dstType, zero); + Value result = rewriter.create(loc, dstType, + rewriter.getZeroAttr(dstType)); if (m == 0) { // Stetch at start. Value ext = rewriter.create(loc, op.source(), 0); @@ -1104,7 +1100,6 @@ auto loc = op.getLoc(); VectorType resType = op.getResultType(); - Type eltType = resType.getElementType(); // Set up convenience transposition table. SmallVector transp; @@ -1112,9 +1107,8 @@ transp.push_back(attr.cast().getInt()); // Generate fully unrolled extract/insert ops. - Value zero = rewriter.create(loc, eltType, - rewriter.getZeroAttr(eltType)); - Value result = rewriter.create(loc, resType, zero); + Value result = rewriter.create(loc, resType, + rewriter.getZeroAttr(resType)); SmallVector lhs(transp.size(), 0); SmallVector rhs(transp.size(), 0); rewriter.replaceOp(op, expandIndices(loc, resType, 0, transp, lhs, rhs, @@ -1173,9 +1167,8 @@ Type eltType = resType.getElementType(); Value acc = (op.acc().empty()) ? nullptr : op.acc()[0]; - Value zero = rewriter.create(loc, eltType, - rewriter.getZeroAttr(eltType)); - Value result = rewriter.create(loc, resType, zero); + Value result = rewriter.create(loc, resType, + rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { auto pos = rewriter.getI64ArrayAttr(d); Value x = rewriter.create(loc, eltType, op.lhs(), pos); @@ -1346,7 +1339,8 @@ rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); // Unroll into a series of lower dimensional vector.contract ops. Location loc = op.getLoc(); - Value result = zeroVector(loc, resType, rewriter); + Value result = rewriter.create(loc, resType, + rewriter.getZeroAttr(resType)); for (int64_t d = 0; d < dimSize; ++d) { auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); @@ -1381,7 +1375,8 @@ // Base case. if (lhsType.getRank() == 1) { assert(rhsType.getRank() == 1 && "corrupt contraction"); - Value zero = zeroVector(loc, lhsType, rewriter); + Value zero = rewriter.create(loc, lhsType, + rewriter.getZeroAttr(lhsType)); Value fma = rewriter.create(loc, op.lhs(), op.rhs(), zero); StringAttr kind = rewriter.getStringAttr("add"); return rewriter.create(loc, resType, kind, fma, @@ -1409,15 +1404,6 @@ return result; } - // Helper method to construct a zero vector. - static Value zeroVector(Location loc, VectorType vType, - PatternRewriter &rewriter) { - Type eltType = vType.getElementType(); - Value zero = rewriter.create(loc, eltType, - rewriter.getZeroAttr(eltType)); - return rewriter.create(loc, vType, zero); - } - // Helper to find an index in an affine map. static Optional getResultIndex(AffineMap map, int64_t index) { for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { @@ -1493,7 +1479,8 @@ // Unroll leading dimensions. VectorType vType = lowType.cast(); VectorType resType = adjustType(type, index).cast(); - Value result = zeroVector(loc, resType, rewriter); + Value result = rewriter.create(loc, resType, + rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) { auto posAttr = rewriter.getI64ArrayAttr(d); Value ext = rewriter.create(loc, vType, val, posAttr); @@ -1555,10 +1542,8 @@ return failure(); auto loc = op.getLoc(); - auto elemType = sourceVectorType.getElementType(); - Value zero = rewriter.create(loc, elemType, - rewriter.getZeroAttr(elemType)); - Value desc = rewriter.create(loc, resultVectorType, zero); + Value desc = rewriter.create( + loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); unsigned mostMinorVectorSize = sourceVectorType.getShape()[1]; for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) { Value vec = rewriter.create(loc, op.source(), i); @@ -1589,10 +1574,8 @@ return failure(); auto loc = op.getLoc(); - auto elemType = sourceVectorType.getElementType(); - Value zero = rewriter.create(loc, elemType, - rewriter.getZeroAttr(elemType)); - Value desc = rewriter.create(loc, resultVectorType, zero); + Value desc = rewriter.create( + loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); unsigned mostMinorVectorSize = resultVectorType.getShape()[1]; for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) { Value vec = rewriter.create(