diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -823,8 +823,7 @@ rewriter.create(loc, bufferType, buffer, capacity); if (enableBufferInitialization) { Value fillSize = rewriter.create(loc, capacity, newSize); - Value fillValue = rewriter.create( - loc, value.getType(), rewriter.getZeroAttr(value.getType())); + Value fillValue = constantZero(rewriter, loc, value.getType()); Value subBuffer = rewriter.create( loc, newBuffer, /*offset=*/ValueRange{newSize}, /*size=*/ValueRange{fillSize}, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -235,8 +235,7 @@ Value buffer = builder.create(loc, memRefType, sz); Type elemType = memRefType.getElementType(); if (enableInit) { - Value fillValue = builder.create( - loc, elemType, builder.getZeroAttr(elemType)); + Value fillValue = constantZero(builder, loc, elemType); builder.create(loc, fillValue, buffer); } return buffer; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -216,9 +216,9 @@ // The following operations and dialects may be introduced by the // codegen rules, and are therefore marked as legal. target.addLegalOp(); - target.addLegalDialect(); + target.addLegalDialect< + arith::ArithDialect, bufferization::BufferizationDialect, + complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>(); target.addLegalOp(); // Populate with rules and apply rewriting rules. populateFunctionOpInterfaceTypeConversionPattern(patterns, diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir @@ -8,7 +8,7 @@ // RUN: %{command} // // Do the same run, but now with direct IR generation. -// REDEFINE: %{option} = enable-runtime-library=false +// REDEFINE: %{option} = "enable-runtime-library=false enable-buffer-initialization=true" // RUN: %{command} #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>