diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -182,12 +182,19 @@ return call.getResult(0); } +/// Generates a constant zero of the given type. +static Value getZero(ConversionPatternRewriter &rewriter, Location loc, + Type t) { + return rewriter.create(loc, rewriter.getZeroAttr(t)); +} + /// Generates the comparison `v != 0` where `v` is of numeric type `t`. /// For floating types, we use the "unordered" comparator (i.e., returns /// true if `v` is NaN). static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc, - Type t, Value v) { - Value zero = rewriter.create(loc, rewriter.getZeroAttr(t)); + Value v) { + Type t = v.getType(); + Value zero = getZero(rewriter, loc, t); if (t.isa()) return rewriter.create(loc, CmpFPredicate::UNE, v, zero); if (t.isIntOrIndex()) @@ -203,11 +210,11 @@ /// if (tensor[ivs]!=0) { /// ind = ivs static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter, - Operation *op, Type eltType, Value tensor, - Value ind, ValueRange ivs) { + Operation *op, Value tensor, Value ind, + ValueRange ivs) { Location loc = op->getLoc(); Value val = rewriter.create(loc, tensor, ivs); - Value cond = genIsNonzero(rewriter, loc, eltType, val); + Value cond = genIsNonzero(rewriter, loc, val); scf::IfOp ifOp = rewriter.create(loc, cond, /*else*/ false); rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); unsigned i = 0; @@ -446,8 +453,8 @@ val = genIndexAndValueForSparse( rewriter, op, indices, values, ind, ivs, rank); else - val = genIndexAndValueForDense(rewriter, op, eltType, - tensor, ind, ivs); + val = genIndexAndValueForDense(rewriter, op, tensor, + ind, ivs); genAddEltCall(rewriter, op, eltType, ptr, val, ind, perm); return {};