diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -171,7 +171,7 @@ static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name, TypeRange resultType, ValueRange operands, - bool emitCInterface = false) { + bool emitCInterface = true) { auto fn = getFunc(op, name, resultType, operands, emitCInterface); return builder.create(op->getLoc(), resultType, fn, operands); } @@ -179,7 +179,7 @@ static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op, StringRef name, TypeRange resultType, ValueRange operands, - bool emitCInterface = false) { + bool emitCInterface = true) { auto fn = getFunc(op, name, resultType, operands, emitCInterface); return rewriter.replaceOpWithNewOp(op, resultType, fn, operands); } @@ -196,7 +196,7 @@ StringRef name = "sparseDimSize"; SmallVector params{src, constantIndex(rewriter, loc, idx)}; Type iTp = rewriter.getIndexType(); - return createFuncCall(rewriter, op, name, iTp, params).getResult(0); + return createFuncCall(rewriter, op, name, iTp, params, false).getResult(0); } /// Generates a call into the "swiss army knife" method of the sparse runtime @@ -205,7 +205,7 @@ ArrayRef params) { StringRef name = "newSparseTensor"; Type pTp = getOpaquePointerType(rewriter); - return createFuncCall(rewriter, op, name, pTp, params, true).getResult(0); + return createFuncCall(rewriter, op, name, pTp, params).getResult(0); } /// Populates given sizes array from type. @@ -382,8 +382,7 @@ else llvm_unreachable("Unknown element type"); SmallVector params{ptr, val, ind, perm}; - Type pTp = getOpaquePointerType(rewriter); - createFuncCall(rewriter, op, name, pTp, params, true); + createFuncCall(rewriter, op, name, getOpaquePointerType(rewriter), params); } /// Generates a call to `iter->getNext()`. If there is a next element, @@ -410,7 +409,7 @@ llvm_unreachable("Unknown element type"); SmallVector params{iter, ind, elemPtr}; Type i1 = rewriter.getI1Type(); - return createFuncCall(rewriter, op, name, i1, params, true).getResult(0); + return createFuncCall(rewriter, op, name, i1, params).getResult(0); } /// If the tensor is a sparse constant, generates and returns the pair of @@ -757,7 +756,7 @@ ConversionPatternRewriter &rewriter) const override { StringRef name = "delSparseTensor"; TypeRange noTp; - createFuncCall(rewriter, op, name, noTp, adaptor.getOperands()); + createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(), false); rewriter.eraseOp(op); return success(); } @@ -786,8 +785,7 @@ name = "sparsePointers8"; else return failure(); - auto operands = adaptor.getOperands(); - replaceOpWithFuncCall(rewriter, op, name, resType, operands, true); + replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands()); return success(); } }; @@ -814,8 +812,7 @@ name = "sparseIndices8"; else return failure(); - auto operands = adaptor.getOperands(); - replaceOpWithFuncCall(rewriter, op, name, resType, operands, true); + replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands()); return success(); } }; @@ -844,8 +841,7 @@ name = "sparseValuesI8"; else return failure(); - auto operands = adaptor.getOperands(); - replaceOpWithFuncCall(rewriter, op, name, resType, operands, true); + replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands()); return success(); } }; @@ -861,7 +857,7 @@ // Finalize any pending insertions. StringRef name = "endInsert"; TypeRange noTp; - createFuncCall(rewriter, op, name, noTp, adaptor.getOperands()); + createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(), false); } rewriter.replaceOp(op, adaptor.getOperands()); return success(); @@ -893,8 +889,7 @@ else llvm_unreachable("Unknown element type"); TypeRange noTp; - auto operands = adaptor.getOperands(); - replaceOpWithFuncCall(rewriter, op, name, noTp, operands, true); + replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands()); return success(); } };