[mlir][sparse] Factor out runtime-library call helpers

Adds getOpaquePointerType(), createFuncCall() and replaceOpWithFuncCall() so
that each call into the sparse runtime support library is built in one step
instead of pairing getFunc() with a separate create<CallOp>() or
replaceOpWithNewOp<CallOp>(). getFunc() no longer defaults emitCInterface;
call sites that request the C interface spell it as /*emitCInterface=*/true.

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -142,13 +142,18 @@
   return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
 }
 
+/// Returns the LLVM pointer-to-i8 type, used here as the opaque pointer type.
+static Type getOpaquePointerType(PatternRewriter &rewriter) {
+  return LLVM::LLVMPointerType::get(rewriter.getI8Type());
+}
+
 /// Returns a function reference (first hit also inserts into module). Sets
 /// the "_emit_c_interface" on the function declaration when requested,
 /// so that LLVM lowering generates a wrapper function that takes care
 /// of ABI complications with passing in and returning MemRefs to C functions.
 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name,
                                  TypeRange resultType, ValueRange operands,
-                                 bool emitCInterface = false) {
+                                 bool emitCInterface) {
   MLIRContext *context = op->getContext();
   auto module = op->getParentOfType<ModuleOp>();
   auto result = SymbolRefAttr::get(context, name);
@@ -165,6 +170,25 @@
   return result;
 }
 
+/// Creates a `CallOp` at the location of `op` that calls the function `name`,
+/// declaring the function in the enclosing module on first use.
+static CallOp createFuncCall(OpBuilder &builder, Operation *op, StringRef name,
+                             TypeRange resultType, ValueRange operands,
+                             bool emitCInterface = false) {
+  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
+  return builder.create<CallOp>(op->getLoc(), resultType, fn, operands);
+}
+
+/// Replaces `op` with a `CallOp` that calls the function `name`, declaring
+/// the function in the enclosing module on first use.
+static CallOp replaceOpWithFuncCall(PatternRewriter &rewriter, Operation *op,
+                                    StringRef name, TypeRange resultType,
+                                    ValueRange operands,
+                                    bool emitCInterface = false) {
+  auto fn = getFunc(op, name, resultType, operands, emitCInterface);
+  return rewriter.replaceOpWithNewOp<CallOp>(op, resultType, fn, operands);
+}
+
 /// Generates dimension size call.
 static Value genDimSizeCall(ConversionPatternRewriter &rewriter, Operation *op,
                             SparseTensorEncodingAttr &enc, Value src,
@@ -175,24 +199,20 @@
   // Generate the call.
   Location loc = op->getLoc();
   StringRef name = "sparseDimSize";
-  SmallVector<Value, 2> params;
-  params.push_back(src);
-  params.push_back(constantIndex(rewriter, loc, idx));
+  SmallVector<Value, 2> params{src, constantIndex(rewriter, loc, idx)};
   Type iTp = rewriter.getIndexType();
-  auto fn = getFunc(op, name, iTp, params);
-  return rewriter.create<CallOp>(loc, iTp, fn, params).getResult(0);
+  return createFuncCall(rewriter, op, name, iTp, params).getResult(0);
 }
 
 /// Generates a call into the "swiss army knife" method of the sparse runtime
 /// support library for materializing sparse tensors into the computation.
 static Value genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
                         ArrayRef<Value> params) {
-  Location loc = op->getLoc();
   StringRef name = "newSparseTensor";
-  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
-  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
-  auto call = rewriter.create<CallOp>(loc, pTp, fn, params);
-  return call.getResult(0);
+  Type pTp = getOpaquePointerType(rewriter);
+  return createFuncCall(rewriter, op, name, pTp, params,
+                        /*emitCInterface=*/true)
+      .getResult(0);
 }
 
 /// Populates given sizes array from type.
@@ -210,8 +230,8 @@
 static void sizesFromSrc(ConversionPatternRewriter &rewriter,
                          SmallVector<Value, 4> &sizes, Location loc,
                          Value src) {
-  ShapedType stp = src.getType().cast<ShapedType>();
-  for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
+  unsigned rank = src.getType().cast<ShapedType>().getRank();
+  for (unsigned i = 0; i < rank; i++)
     sizes.push_back(linalg::createOrFoldDimOp(rewriter, loc, src, i));
 }
 
@@ -221,12 +241,13 @@
                          SmallVector<Value, 4> &sizes, Operation *op,
                          SparseTensorEncodingAttr &enc, ShapedType stp,
                          Value src) {
+  Location loc = op->getLoc();
   auto shape = stp.getShape();
   for (unsigned i = 0, rank = stp.getRank(); i < rank; i++)
     if (shape[i] == ShapedType::kDynamicSize)
       sizes.push_back(genDimSizeCall(rewriter, op, enc, src, i));
     else
-      sizes.push_back(constantIndex(rewriter, op->getLoc(), shape[i]));
+      sizes.push_back(constantIndex(rewriter, loc, shape[i]));
 }
 
 /// Generates an uninitialized temporary buffer of the given size and
@@ -293,16 +314,16 @@
   }
   params.push_back(genBuffer(rewriter, loc, rev));
   // Secondary and primary types encoding.
-  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
+  Type elemTp = op->getResult(0).getType().cast<ShapedType>().getElementType();
   params.push_back(constantPointerTypeEncoding(rewriter, loc, enc));
   params.push_back(constantIndexTypeEncoding(rewriter, loc, enc));
-  params.push_back(
-      constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
-  // User action and pointer.
-  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
-  if (!ptr)
-    ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
+  // FIXME(wrengr): this one can hit llvm_unreachable...
+  params.push_back(constantPrimaryTypeEncoding(rewriter, loc, elemTp));
+  // User action.
   params.push_back(constantAction(rewriter, loc, action));
+  // Payload pointer.
+  if (!ptr)
+    ptr = rewriter.create<LLVM::NullOp>(loc, getOpaquePointerType(rewriter));
   params.push_back(ptr);
 }
 
@@ -352,7 +373,6 @@
 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                           Type eltType, Value ptr, Value val, Value ind,
                           Value perm) {
-  Location loc = op->getLoc();
   StringRef name;
   if (eltType.isF64())
     name = "addEltF64";
@@ -368,14 +388,9 @@
     name = "addEltI8";
   else
     llvm_unreachable("Unknown element type");
-  SmallVector<Value, 4> params;
-  params.push_back(ptr);
-  params.push_back(val);
-  params.push_back(ind);
-  params.push_back(perm);
-  Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
-  auto fn = getFunc(op, name, pTp, params, /*emitCInterface=*/true);
-  rewriter.create<CallOp>(loc, pTp, fn, params);
+  SmallVector<Value, 4> params{ptr, val, ind, perm};
+  Type pTp = getOpaquePointerType(rewriter);
+  createFuncCall(rewriter, op, name, pTp, params, /*emitCInterface=*/true);
 }
 
 /// Generates a call to `iter->getNext()`. If there is a next element,
@@ -384,7 +399,6 @@
 /// the memory for `iter` is freed and the return value is false.
 static Value genGetNextCall(ConversionPatternRewriter &rewriter, Operation *op,
                             Value iter, Value ind, Value elemPtr) {
-  Location loc = op->getLoc();
   Type elemTp = elemPtr.getType().cast<ShapedType>().getElementType();
   StringRef name;
   if (elemTp.isF64())
@@ -401,14 +415,10 @@
     name = "getNextI8";
   else
     llvm_unreachable("Unknown element type");
-  SmallVector<Value, 3> params;
-  params.push_back(iter);
-  params.push_back(ind);
-  params.push_back(elemPtr);
+  SmallVector<Value, 3> params{iter, ind, elemPtr};
   Type i1 = rewriter.getI1Type();
-  auto fn = getFunc(op, name, i1, params, /*emitCInterface=*/true);
-  auto call = rewriter.create<CallOp>(loc, i1, fn, params);
-  return call.getResult(0);
+  return createFuncCall(rewriter, op, name, i1, params, /*emitCInterface=*/true)
+      .getResult(0);
 }
 
 /// If the tensor is a sparse constant, generates and returns the pair of
@@ -461,7 +471,7 @@
   }
   Value mem = rewriter.create<memref::AllocOp>(loc, memTp, dynamicSizes);
   Value zero = constantZero(rewriter, loc, elemTp);
-  rewriter.create<linalg::FillOp>(loc, zero, mem).result();
+  rewriter.create<linalg::FillOp>(loc, zero, mem);
   return mem;
 }
 
@@ -754,9 +764,8 @@
   matchAndRewrite(ReleaseOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     StringRef name = "delSparseTensor";
-    TypeRange none;
-    auto fn = getFunc(op, name, none, adaptor.getOperands());
-    rewriter.create<CallOp>(op.getLoc(), none, fn, adaptor.getOperands());
+    TypeRange noTp;
+    createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
     rewriter.eraseOp(op);
     return success();
   }
@@ -785,9 +794,9 @@
       name = "sparsePointers8";
     else
       return failure();
-    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
-                      /*emitCInterface=*/true);
-    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
+    auto operands = adaptor.getOperands();
+    replaceOpWithFuncCall(rewriter, op, name, resType, operands,
+                          /*emitCInterface=*/true);
     return success();
   }
 };
@@ -814,9 +823,9 @@
       name = "sparseIndices8";
     else
       return failure();
-    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
-                      /*emitCInterface=*/true);
-    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
+    auto operands = adaptor.getOperands();
+    replaceOpWithFuncCall(rewriter, op, name, resType, operands,
+                          /*emitCInterface=*/true);
     return success();
   }
 };
@@ -845,9 +854,9 @@
       name = "sparseValuesI8";
     else
       return failure();
-    auto fn = getFunc(op, name, resType, adaptor.getOperands(),
-                      /*emitCInterface=*/true);
-    rewriter.replaceOpWithNewOp<CallOp>(op, resType, fn, adaptor.getOperands());
+    auto operands = adaptor.getOperands();
+    replaceOpWithFuncCall(rewriter, op, name, resType, operands,
+                          /*emitCInterface=*/true);
     return success();
   }
 };
@@ -863,8 +872,7 @@
       // Finalize any pending insertions.
       StringRef name = "endInsert";
       TypeRange noTp;
-      auto fn = getFunc(op, name, noTp, adaptor.getOperands());
-      rewriter.create<CallOp>(op.getLoc(), noTp, fn, adaptor.getOperands());
+      createFuncCall(rewriter, op, name, noTp, adaptor.getOperands());
     }
     rewriter.replaceOp(op, adaptor.getOperands());
     return success();
@@ -896,9 +904,9 @@
     else
       llvm_unreachable("Unknown element type");
     TypeRange noTp;
-    auto fn =
-        getFunc(op, name, noTp, adaptor.getOperands(), /*emitCInterface=*/true);
-    rewriter.replaceOpWithNewOp<CallOp>(op, noTp, fn, adaptor.getOperands());
+    auto operands = adaptor.getOperands();
+    replaceOpWithFuncCall(rewriter, op, name, noTp, operands,
+                          /*emitCInterface=*/true);
     return success();
   }
 };