diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -889,11 +890,18 @@ VectorType vtp = vectorType(codegen, itype); ival = rewriter.create(loc, vtp, ival); if (idx == ldx) { - SmallVector integers; - for (unsigned i = 0; i < vl; i++) - integers.push_back(APInt(/*width=*/64, i)); - auto values = DenseElementsAttr::get(vtp, integers); - Value incr = rewriter.create(loc, vtp, values); + Value incr; + if (vtp.isScalable()) { + Type stepvty = vectorType(codegen, rewriter.getI64Type()); + Value stepv = rewriter.create(loc, stepvty); + incr = rewriter.create(loc, vtp, stepv); + } else { + SmallVector integers; + for (unsigned i = 0; i < vl; i++) + integers.push_back(APInt(/*width=*/64, i)); + auto values = DenseElementsAttr::get(vtp, integers); + incr = rewriter.create(loc, vtp, values); + } ival = rewriter.create(loc, ival, incr); } }