diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -148,13 +148,17 @@ return LLVM::LLVMPointerType::get(rewriter.getI8Type()); } +/// Makes the `emitCInterface` argument of `getFunc()` and the call helpers +/// below explicit at every call site, in place of an easy-to-misread `bool`. +enum EmitCInterface : bool { CINTERFACE_ON = true, CINTERFACE_OFF = false }; + /// Returns a function reference (first hit also inserts into module). Sets /// the "_emit_c_interface" on the function declaration when requested, /// so that LLVM lowering generates a wrapper function that takes care /// of ABI complications with passing in and returning MemRefs to C functions. static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, TypeRange resultType, ValueRange operands, - bool emitCInterface) { + EmitCInterface emitCInterface) { MLIRContext *context = op->getContext(); auto module = op->getParentOfType(); auto result = SymbolRefAttr::get(context, name); @@ -174,7 +178,7 @@ /// Creates a `CallOp` to the function reference returned by `getFunc()`.
static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name, TypeRange resultType, ValueRange operands, - bool emitCInterface = false) { + EmitCInterface emitCInterface) { auto fn = getFunc(op, name, resultType, operands, emitCInterface); return builder.create(op->getLoc(), resultType, fn, operands); } @@ -184,7 +188,7 @@ static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op, StringRef name, TypeRange resultType, ValueRange operands, - bool emitCInterface = false) { + EmitCInterface emitCInterface) { auto fn = getFunc(op, name, resultType, operands, emitCInterface); return rewriter.replaceOpWithNewOp(op, resultType, fn, operands); } @@ -200,7 +204,8 @@ StringRef name = "sparseDimSize"; SmallVector params{src, constantIndex(rewriter, op->getLoc(), idx)}; Type iTp = rewriter.getIndexType(); - return createFuncCall(rewriter, op, name, iTp, params).getResult(0); + auto call = createFuncCall(rewriter, op, name, iTp, params, CINTERFACE_OFF); + return call.getResult(0); } /// Generates a call into the "swiss army knife" method of the sparse runtime @@ -209,8 +214,7 @@ ArrayRef params) { StringRef name = "newSparseTensor"; Type pTp = getOpaquePointerType(rewriter); - auto call = createFuncCall(rewriter, op, name, pTp, params, - /*emitCInterface=*/true); + auto call = createFuncCall(rewriter, op, name, pTp, params, CINTERFACE_ON); return call.getResult(0); } @@ -388,7 +392,7 @@ llvm_unreachable("Unknown element type"); SmallVector params{ptr, val, ind, perm}; Type pTp = getOpaquePointerType(rewriter); - createFuncCall(rewriter, op, name, pTp, params, /*emitCInterface=*/true); + createFuncCall(rewriter, op, name, pTp, params, CINTERFACE_ON); } /// Generates a call to `iter->getNext()`.
If there is a next element, @@ -415,8 +419,7 @@ llvm_unreachable("Unknown element type"); SmallVector params{iter, ind, elemPtr}; Type i1 = rewriter.getI1Type(); - auto call = createFuncCall(rewriter, op, name, i1, params, - /*emitCInterface=*/true); + auto call = createFuncCall(rewriter, op, name, i1, params, CINTERFACE_ON); return call.getResult(0); } @@ -763,8 +766,9 @@ matchAndRewrite(ReleaseOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef name = "delSparseTensor"; + auto params = adaptor.getOperands(); TypeRange noTp; - createFuncCall(rewriter, op, name, noTp, adaptor.getOperands()); + createFuncCall(rewriter, op, name, noTp, params, CINTERFACE_OFF); rewriter.eraseOp(op); return success(); } @@ -793,8 +797,8 @@ name = "sparsePointers8"; else return failure(); - replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(), - /*emitCInterface=*/true); + auto params = adaptor.getOperands(); + replaceOpWithFuncCall(rewriter, op, name, resType, params, CINTERFACE_ON); return success(); } }; @@ -821,8 +825,8 @@ name = "sparseIndices8"; else return failure(); - replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(), - /*emitCInterface=*/true); + auto params = adaptor.getOperands(); + replaceOpWithFuncCall(rewriter, op, name, resType, params, CINTERFACE_ON); return success(); } }; @@ -851,8 +855,8 @@ name = "sparseValuesI8"; else return failure(); - replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(), - /*emitCInterface=*/true); + auto params = adaptor.getOperands(); + replaceOpWithFuncCall(rewriter, op, name, resType, params, CINTERFACE_ON); return success(); } }; @@ -867,8 +871,9 @@ if (op.hasInserts()) { // Finalize any pending insertions.
StringRef name = "endInsert"; + auto params = adaptor.getOperands(); TypeRange noTp; - createFuncCall(rewriter, op, name, noTp, adaptor.getOperands()); + createFuncCall(rewriter, op, name, noTp, params, CINTERFACE_OFF); } rewriter.replaceOp(op, adaptor.getOperands()); return success(); @@ -899,9 +904,9 @@ name = "lexInsertI8"; else llvm_unreachable("Unknown element type"); + auto params = adaptor.getOperands(); TypeRange noTp; - replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(), - /*emitCInterface=*/true); + replaceOpWithFuncCall(rewriter, op, name, noTp, params, CINTERFACE_ON); return success(); } };