diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -350,13 +350,15 @@
 /// Test for admissible types on operands (with output parameter `isCOO`).
+/// Set `isMatVec` for matrix-vector products. When `CUSPARSE_COO_AOS` is
+/// defined, a COO operand `aTp` is admissible only for matrix-vector products.
 static bool areAdmissibleTypes(SparseTensorType aTp, SparseTensorType bTp,
                                SparseTensorType cTp, bool enableRT,
-                               bool &isCOO) {
+                               bool isMatVec, bool &isCOO) {
   if (bTp.hasEncoding() || cTp.hasEncoding())
     return false;
   if (isAdmissibleCOO(aTp)) {
     isCOO = true;
 #ifdef CUSPARSE_COO_AOS
-    return true;
+    return isMatVec;
 #else
     return enableRT;
 #endif
@@ -424,7 +426,7 @@
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType xTp = getSparseTensorType(x);
   SparseTensorType yTp = getSparseTensorType(y);
-  if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, isCOO))
+  if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, /*isMatVec=*/true, isCOO))
     return failure();
 
   // Start sparse kernel and copy data from host to device.
@@ -530,7 +532,7 @@
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType bTp = getSparseTensorType(b);
   SparseTensorType cTp = getSparseTensorType(c);
-  if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, isCOO))
+  if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, /*isMatVec=*/false, isCOO))
     return failure();
 
   // Start sparse kernel and copy data from host to device.