diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -790,14 +790,6 @@ return genInvariantValue(merger, codegen, rewriter, exp); Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0); Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1); - if (merger.exp(exp).kind == Kind::kNegI) { - // TODO: no negi in std, need to make zero explicit. - Type tp = op.getOutputTensorTypes()[0].getElementType(); - v1 = v0; - v0 = rewriter.create(loc, tp, rewriter.getZeroAttr(tp)); - if (codegen.curVecLength > 1) - v0 = genVectorInvariantValue(codegen, rewriter, v0); - } return merger.buildExp(rewriter, loc, exp, v0, v1); } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -568,7 +568,7 @@ if (isa(def)) return addExp(kFloorF, e); if (isa(def)) - return addExp(kNegF, e); // TODO: no negi in std? + return addExp(kNegF, e); // no negi in std if (isa(def)) return addExp(kTruncF, e, v); if (isa(def)) @@ -651,9 +651,12 @@ return rewriter.create(loc, v0); case kNegF: return rewriter.create(loc, v0); - case kNegI: - assert(v1); // no negi in std - return rewriter.create(loc, v0, v1); + case kNegI: // no negi in std + return rewriter.create( + loc, + rewriter.create(loc, v0.getType(), + rewriter.getZeroAttr(v0.getType())), + v0); case kTruncF: return rewriter.create(loc, v0, inferType(e, v0)); case kExtF: