diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1457,37 +1457,48 @@ // When strideW == 1, we can batch the contiguous loads and avoid unrolling int64_t wSizeStep = strideW == 1 ? wSize : 1; - VectorType lhsType = VectorType::get({nSize, wSizeStep, cSize}, - lhsShapedType.getElementType()); - VectorType rhsType = - VectorType::get({cSize, fSize}, rhsShapedType.getElementType()); - VectorType resType = VectorType::get({nSize, wSizeStep, fSize}, - resShapedType.getElementType()); + Type lhsEltType = lhsShapedType.getElementType(); + Type rhsEltType = rhsShapedType.getElementType(); + Type resEltType = resShapedType.getElementType(); + VectorType lhsType = VectorType::get( + {nSize, (wSize - 1) * strideW + (kwSize - 1) * dilationW + 1, + cSize}, + lhsEltType); + VectorType rhsType = VectorType::get({kwSize, cSize, fSize}, rhsEltType); + VectorType resType = VectorType::get({nSize, wSize, fSize}, resEltType); + + // Read lhs slice of size {n, (w-1) * sw + (kw-1) * dw + 1, c} @ [0, 0, 0]. + Value lhs = builder.create<vector::TransferReadOp>( + loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}); + // Read rhs slice of size {kw, c, f} @ [0, 0, 0]. + Value rhs = builder.create<vector::TransferReadOp>( + loc, rhsType, rhsShaped, ValueRange{zero, zero, zero}); + // Read res slice of size {n, w, f} @ [0, 0, 0]. + Value res = builder.create<vector::TransferReadOp>( + loc, resType, resShaped, ValueRange{zero, zero, zero}); - SmallVector<Value> lhsVals, rhsVals, resVals; - // Unroll along kw and read slices of lhs and rhs. - // Alternatively we could preload both 3-d slices and extract smaller slices - // iteratively without touching memory. But this will quickly spill. + // Unroll along kw and extract slices of lhs, rhs and res. + SmallVector<Value> lhsVals, rhsVals, resVals; for (int64_t kw = 0; kw < kwSize; ++kw) { - // Read rhs slice of size {c, f} @ [kw, 0, 0]. 
- Value kwVal = builder.create<arith::ConstantIndexOp>(loc, kw); - rhsVals.push_back(builder.create<vector::TransferReadOp>( - loc, rhsType, rhsShaped, ValueRange{kwVal, zero, zero})); + // Extract rhs slice of size {c, f} @ [kw]. + rhsVals.push_back(builder.create<vector::ExtractOp>( + loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw})); - for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) { - // Read lhs slice of size {n, wSizeStep, c} + for (int64_t w = 0; w < wSize; w += wSizeStep) { + // Extract lhs slice of size {n, wSizeStep, c} // @ [0, sw * w + dw * kw, 0]. - Value lhsStridedIdx = builder.create<arith::ConstantIndexOp>( - loc, strideW * w_iv + dilationW * kw); - lhsVals.push_back(builder.create<vector::TransferReadOp>( - loc, lhsType, lhsShaped, ValueRange{zero, lhsStridedIdx, zero})); - - // Read res slice: {n, wSizeStep, f} @ [0, w, 0]. - Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv); - // When operating on tensors, reading from the updated value is required - // for vector.transfer_read/write hoisting to function as expected. - resVals.push_back(builder.create<vector::TransferReadOp>( - loc, resType, resShaped, ValueRange{zero, wVal, zero})); + lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>( + loc, lhs, + /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0}, + /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize}, + /*strides=*/ArrayRef<int64_t>{1, 1, 1})); + + // Extract res slice: {n, wSizeStep, f} @ [0, w, 0]. + resVals.push_back(builder.create<vector::ExtractStridedSliceOp>( + loc, res, + /*offsets=*/ArrayRef<int64_t>{0, w, 0}, + /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize}, + /*strides=*/ArrayRef<int64_t>{1, 1, 1})); } } @@ -1503,19 +1514,22 @@ resVals[linearIndex(kw, w_iv)]); } } + for (int64_t kw = 0; kw < kwSize; ++kw) { - for (int64_t w_iv = 0; w_iv < wSize; w_iv += wSizeStep) { - Value wVal = builder.create<arith::ConstantIndexOp>(loc, w_iv); + for (int64_t w = 0; w < wSize; w += wSizeStep) { // Write back res slice: {n, wSizeStep, f} @ [0, w, 0]. 
- write = builder.create<vector::TransferWriteOp>( - loc, resVals[linearIndex(kw, w_iv)], resShaped, - ValueRange{zero, wVal, zero}); - if (write.getNumResults() == 1) - resShaped = write->getResult(0); + res = builder.create<vector::InsertStridedSliceOp>( + loc, resVals[linearIndex(kw, w)], res, + /*offsets=*/ArrayRef<int64_t>{0, w, 0}, + /*strides=*/ArrayRef<int64_t>{1, 1, 1}); } } - return write.getOperation(); + // Write back res slice of size {n, w, f} @ [0, 0, 0]. + return builder + .create<vector::TransferWriteOp>(loc, res, resShaped, + ValueRange{zero, zero, zero}) + .getOperation(); } // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}