diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -31,6 +31,10 @@
 
 namespace {
 
+/// Shorthand aliases for the `emitCInterface` argument to `getFunc()`,
+/// `createFuncCall()`, and `replaceOpWithFuncCall()`.
+enum class EmitCInterface : bool { Off = false, On = true };
+
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//
@@ -154,7 +158,7 @@
 /// of ABI complications with passing in and returning MemRefs to C functions.
 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                  TypeRange resultType, ValueRange operands,
-                                 bool emitCInterface) {
+                                 EmitCInterface emitCInterface) {
   MLIRContext *context = op->getContext();
   auto module = op->getParentOfType<ModuleOp>();
   auto result = SymbolRefAttr::get(context, name);
@@ -165,7 +169,7 @@
         op->getLoc(), name,
         FunctionType::get(context, operands.getTypes(), resultType));
     func.setPrivate();
-    if (emitCInterface)
+    if (static_cast<bool>(emitCInterface))
       func->setAttr("llvm.emit_c_interface", UnitAttr::get(context));
   }
   return result;
@@ -174,7 +178,7 @@
 /// Creates a `CallOp` to the function reference returned by `getFunc()`.
 static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
                              TypeRange resultType, ValueRange operands,
-                             bool emitCInterface = false) {
+                             EmitCInterface emitCInterface) {
   auto fn = getFunc(op, name, resultType, operands, emitCInterface);
   return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
 }
@@ -184,7 +188,7 @@
 static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
                                     StringRef name, TypeRange resultType,
                                     ValueRange operands,
-                                    bool emitCInterface = false) {
+                                    EmitCInterface emitCInterface) {
   auto fn = getFunc(op, name, resultType, operands, emitCInterface);
   return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
 }
@@ -200,7 +204,8 @@
   StringRef name = "sparseDimSize";
   SmallVector<Value, 2> params{src, constantIndex(rewriter, op->getLoc(), idx)};
   Type iTp = rewriter.getIndexType();
-  return createFuncCall(rewriter, op, name, iTp, params).getResult(0);
+  return createFuncCall(rewriter, op, name, iTp, params, EmitCInterface::Off)
+      .getResult(0);
 }
 
 /// Generates a call into the "swiss army knife" method of the sparse runtime
@@ -209,9 +214,8 @@
                         ArrayRef<Value> params) {
   StringRef name = "newSparseTensor";
   Type pTp = getOpaquePointerType(rewriter);
-  auto call = createFuncCall(rewriter, op, name, pTp, params,
-                             /*emitCInterface=*/true);
-  return call.getResult(0);
+  return createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On)
+      .getResult(0);
 }
 
 /// Populates given sizes array from type.
@@ -388,7 +392,7 @@
     llvm_unreachable("Unknown element type");
   SmallVector<Value, 4> params{ptr, val, ind, perm};
   Type pTp = getOpaquePointerType(rewriter);
-  createFuncCall(rewriter, op, name, pTp, params, /*emitCInterface=*/true);
+  createFuncCall(rewriter, op, name, pTp, params, EmitCInterface::On);
 }
 
 /// Generates a call to `iter->getNext()`.  If there is a next element,
@@ -415,9 +419,8 @@
     llvm_unreachable("Unknown element type");
   SmallVector<Value, 3> params{iter, ind, elemPtr};
   Type i1 = rewriter.getI1Type();
-  auto call = createFuncCall(rewriter, op, name, i1, params,
-                             /*emitCInterface=*/true);
-  return call.getResult(0);
+  return createFuncCall(rewriter, op, name, i1, params, EmitCInterface::On)
+      .getResult(0);
 }
 
 /// If the tensor is a sparse constant, generates and returns the pair of
@@ -764,7 +767,8 @@
                   ConversionPatternRewriter &rewriter) const override {
     StringRef name = "delSparseTensor";
     TypeRange noTp;
-    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
+    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
+                   EmitCInterface::Off);
     rewriter.eraseOp(op);
     return success();
   }
@@ -794,7 +798,7 @@
     else
       return failure();
     replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
-                          /*emitCInterface=*/true);
+                          EmitCInterface::On);
     return success();
   }
 };
@@ -822,7 +826,7 @@
     else
       return failure();
     replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
-                          /*emitCInterface=*/true);
+                          EmitCInterface::On);
     return success();
   }
 };
@@ -852,7 +856,7 @@
     else
       return failure();
     replaceOpWithFuncCall(rewriter, op, name, resType, adaptor.getOperands(),
-                          /*emitCInterface=*/true);
+                          EmitCInterface::On);
     return success();
   }
 };
@@ -868,7 +872,8 @@
       // Finalize any pending insertions.
       StringRef name = "endInsert";
       TypeRange noTp;
-      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
+      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
+                     EmitCInterface::Off);
     }
     rewriter.replaceOp(op, adaptor.getOperands());
     return success();
@@ -901,7 +906,7 @@
       llvm_unreachable("Unknown element type");
     TypeRange noTp;
     replaceOpWithFuncCall(rewriter, op, name, noTp, adaptor.getOperands(),
-                          /*emitCInterface=*/true);
+                          EmitCInterface::On);
     return success();
   }
 };