diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -166,6 +166,19 @@ auto *ctx = &getContext(); RewritePatternSet patterns(ctx); SparseTensorTypeToBufferConverter converter; + // Required by scf.for 1:N type conversion. + converter.addSourceMaterialization( + [](OpBuilder &builder, RankedTensorType tp, ValueRange inputs, + Location loc) -> Optional { + if (!getSparseTensorEncoding(tp)) + // not a sparse tensor. + return llvm::None; + // Sparse compiler knows how to cancel out these casts. + return builder + .create(loc, TypeRange(tp), inputs) + .getResult(0); + }); + ConversionTarget target(*ctx); // Most ops in the sparse dialect must go! target.addIllegalDialect();