diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -745,6 +745,17 @@
 mlir::SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
   addConversion([](Type type) { return type; });
   addConversion(convertSparseTensorType);
+
+  // Source materialization, required by scf.for 1:N type conversion.
+  addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
+                              ValueRange inputs,
+                              Location loc) -> Optional<Value> {
+    // Not a sparse tensor: decline by returning no value.
+    if (!getSparseTensorEncoding(tp))
+      return llvm::None;
+    // Sparse compiler knows how to cancel out these casts.
+    return genTuple(builder, loc, tp, inputs);
+  });
 }
 
 //===----------------------------------------------------------------------===//