diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -182,11 +182,27 @@
   return call.getResult(0);
 }
 
+/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
+/// For floating types, we use the "unordered" comparator (i.e., returns
+/// true if `v` is NaN).
+static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
+                          Type t, Value v) {
+  Value zero = rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(t));
+  if (t.isa<FloatType>())
+    return rewriter.create<CmpFOp>(loc, CmpFPredicate::UNE, v, zero);
+  if (t.isIntOrIndex())
+    return rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, v, zero);
+  llvm_unreachable("Unknown element type");
+}
+
 /// Generates a call that adds one element to a coordinate scheme.
+/// In particular, this generates code like the following:
+///   val = a[i1,..,ik];
+///   if val != 0
+///     t->add(val, [i1,..,ik], [p1,..,pk]);
 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                           Value ptr, Value tensor, Value ind, Value perm,
                           ValueRange ivs) {
-  Location loc = op->getLoc();
   StringRef name;
   Type eltType = tensor.getType().cast<ShapedType>().getElementType();
   if (eltType.isF64())
@@ -203,8 +219,11 @@
     name = "addEltI8";
   else
     llvm_unreachable("Unknown element type");
+  Location loc = op->getLoc();
   Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
-  // TODO: add if here?
+  Value cond = genIsNonzero(rewriter, loc, eltType, val);
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
+  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
   unsigned i = 0;
   for (auto iv : ivs) {
     Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
@@ -321,6 +340,9 @@
     // Note that the dense tensor traversal code is actually implemented
     // using MLIR IR to avoid having to expose too much low-level
     // memref traversal details to the runtime support library.
+    // Also note that the code below only generates the "new" ops and
+    // the loop-nest per se, whereas the entire body of the innermost
+    // loop is generated by genAddEltCall().
     Location loc = op->getLoc();
     ShapedType shape = resType.cast<ShapedType>();
     auto memTp =
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -1,4 +1,4 @@
-//===- SparsificationPass.cpp - Pass for autogen spares tensor code -------===//
+//===- SparseTensorPasses.cpp - Pass for autogen sparse tensor code -------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -114,7 +114,8 @@
     });
     // The following operations and dialects may be introduced by the
     // rewriting rules, and are therefore marked as legal.
-    target.addLegalOp<ConstantOp, tensor::CastOp, tensor::ExtractOp>();
+    target.addLegalOp<CmpFOp, CmpIOp, ConstantOp, tensor::CastOp,
+                      tensor::ExtractOp>();
     target.addLegalDialect<LLVM::LLVMDialect, memref::MemRefDialect,
                            scf::SCFDialect>();
     // Populate with rules and apply rewriting rules.
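An aside on the predicate choice in genIsNonzero above: for floating-point
elements the patch deliberately uses the unordered UNE comparator, so NaN
values count as nonzero and still reach the coordinate scheme. A minimal
standalone C++ sketch of those semantics (illustration only, not part of
the patch; C++ operator!= on doubles happens to match UNE exactly):

#include <cassert>
#include <limits>

int main() {
  double zero = 0.0;
  double nan = std::numeric_limits<double>::quiet_NaN();
  // CmpFPredicate::UNE means "unordered or not equal": true when the
  // operands are unordered (either is NaN) or when they differ, which
  // is the same behavior as operator!= on doubles in C++.
  assert((nan != zero) == true);  // NaN is kept as a nonzero element.
  assert((1.5 != zero) == true);  // ordinary nonzeros pass the guard.
  assert((0.0 != zero) == false); // true zeros are skipped.
  return 0;
}
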
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -548,8 +548,6 @@
   void *_mlir_ciface_##NAME(void *tensor, TYPE value,                        \
                             StridedMemRefType<uint64_t, 1> *iref,            \
                             StridedMemRefType<uint64_t, 1> *pref) {          \
-    if (!value)                                                              \
-      return tensor;                                                         \
     assert(iref->strides[0] == 1 && pref->strides[0] == 1);                  \
     assert(iref->sizes[0] == pref->sizes[0]);                                \
     const uint64_t *indx = iref->data + iref->offset;                        \
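The SparseUtils.cpp change is the runtime half of the same idea: the
`if (!value) return tensor;` early-out disappears because the compiler-emitted
loop nest now tests for zero before ever calling the addElt entry point. A
self-contained C++ sketch of that division of labor; Coords, tensorElements,
and this simplified addEltF64 are hypothetical stand-ins, not the real
runtime types:

#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

// Stand-in for the coordinate scheme built by the real runtime library.
using Coords = std::vector<uint64_t>;
static std::map<Coords, double> tensorElements;

// Runtime entry point: no zero test here anymore (mirrors the diff);
// it unconditionally records whatever the generated code hands it.
static void addEltF64(double value, const Coords &indices) {
  tensorElements[indices] = value;
}

int main() {
  double a[2][2] = {{0.0, 1.5}, {0.0, 2.5}};
  // This loop plays the role of the compiler-emitted loop nest: the
  // `!= 0` guard (CmpFOp UNE wrapped in scf.if) runs before the call.
  for (uint64_t i = 0; i < 2; i++)
    for (uint64_t j = 0; j < 2; j++)
      if (a[i][j] != 0.0) // guard now lives in generated code
        addEltF64(a[i][j], {i, j});
  assert(tensorElements.size() == 2); // zeros were never materialized
  return 0;
}

Moving the guard into generated IR saves a call into the runtime per zero
element, which matters when converting mostly-zero dense tensors.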