diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -136,6 +136,11 @@
 /// of the given type, and returns the `memref<$tp>`.
 Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp);
 
+/// Generates a temporary buffer, initializes it with the given contents,
+/// and returns it as type `memref<? x $tp>` (rather than as type
+/// `memref<$sz x $tp>`).
+Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values);
+
 /// Generates code to allocate a buffer of the given type, and zero
 /// initialize it. If the buffer type has any dynamic sizes, then the
 /// `sizes` parameter should be as filled by sizesFromPtr(); that way
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -1151,6 +1151,18 @@
   return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
 }
 
+Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
+                                        ValueRange values) {
+  const unsigned sz = values.size();
+  assert(sz >= 1);
+  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
+  for (unsigned i = 0; i < sz; i++) {
+    Value idx = constantIndex(builder, loc, i);
+    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
+  }
+  return buffer;
+}
+
 Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
                                             RankedTensorType tensorTp,
                                             ValueRange sizes) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -207,18 +207,6 @@
   return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
 }
 
-/// Generates a temporary buffer of the given type and given contents.
-static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
-  unsigned sz = values.size();
-  assert(sz >= 1);
-  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
-  for (unsigned i = 0; i < sz; i++) {
-    Value idx = constantIndex(builder, loc, i);
-    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
-  }
-  return buffer;
-}
-
 /// Generates a temporary buffer for the level-types of the given encoding.
 static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                                SparseTensorEncodingAttr enc) {
@@ -227,7 +215,7 @@
   lvlTypes.reserve(dlts.size());
   for (auto dlt : dlts)
     lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
-  return genBuffer(builder, loc, lvlTypes);
+  return allocaBuffer(builder, loc, lvlTypes);
 }
 
 /// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
@@ -329,7 +317,7 @@
   // Dimension-sizes array of the enveloping tensor. Useful for either
   // verification of external data, or for construction of internal data.
   assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
-  params[kParamDimSizes] = genBuffer(builder, loc, dimSizes);
+  params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizes);
   // The level-sizes array must be passed as well, since for arbitrary
   // dim2lvl mappings it cannot be trivially reconstructed at runtime.
   // For now however, since we're still assuming permutations, we will
@@ -358,10 +346,10 @@
       lvlSizes[i] = dimSizes[i];
     }
   }
-  params[kParamLvlSizes] = genBuffer(builder, loc, lvlSizes);
-  params[kParamLvl2Dim] = genBuffer(builder, loc, lvl2dim);
+  params[kParamLvlSizes] = allocaBuffer(builder, loc, lvlSizes);
+  params[kParamLvl2Dim] = allocaBuffer(builder, loc, lvl2dim);
   params[kParamDim2Lvl] =
-      dimOrder ? genBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim];
+      dimOrder ? allocaBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim];
   // Secondary and primary types encoding.
   setTemplateTypes(enc, stp);
   // Finally, make note that initialization is complete.
@@ -780,7 +768,7 @@
     // Construct the dimShape.
     const auto dimShape = stp.getShape();
     SmallVector<Value> dimShapeValues = getDimShape(rewriter, loc, stp);
-    Value dimShapeBuffer = genBuffer(rewriter, loc, dimShapeValues);
+    Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues);
     // Allocate `SparseTensorReader` and perform all initial setup that
     // does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc).
     Type opaqueTp = getOpaquePointerType(rewriter);
@@ -833,9 +821,9 @@
                 ? rewriter.create<memref::LoadOp>(loc, dimSizesBuffer, dim)
                 : dimShapeValues[d];
       }
-      lvlSizesBuffer = genBuffer(rewriter, loc, lvlSizeValues);
-      lvl2dimBuffer = genBuffer(rewriter, loc, lvl2dimValues);
-      dim2lvlBuffer = genBuffer(rewriter, loc, dim2lvlValues);
+      lvlSizesBuffer = allocaBuffer(rewriter, loc, lvlSizeValues);
+      lvl2dimBuffer = allocaBuffer(rewriter, loc, lvl2dimValues);
+      dim2lvlBuffer = allocaBuffer(rewriter, loc, dim2lvlValues);
     } else {
       assert(dimRank == lvlRank && "Rank mismatch");
       SmallVector<Value> iotaValues;
@@ -843,7 +831,7 @@
       for (unsigned i = 0; i < lvlRank; i++)
         iotaValues.push_back(constantIndex(rewriter, loc, i));
       lvlSizesBuffer = dimSizesBuffer ? dimSizesBuffer : dimShapeBuffer;
-      dim2lvlBuffer = lvl2dimBuffer = genBuffer(rewriter, loc, iotaValues);
+      dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(rewriter, loc, iotaValues);
     }
     // Use the `reader` to parse the file.
     SmallVector<Value> params{
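Note for reviewers: the patch hoists the static genBuffer helper out of
SparseTensorConversion.cpp into CodegenUtils as sparse_tensor::allocaBuffer
so that other lowering files can reuse it; the body is unchanged apart from
marking `sz` const. A minimal usage sketch follows (not part of the patch;
the size values are hypothetical, and `rewriter`/`loc` are assumed to be in
scope as in the conversion patterns above):

  // Collect the per-level sizes as index-typed SSA values.
  SmallVector<Value> sizes;
  sizes.push_back(constantIndex(rewriter, loc, 10)); // hypothetical size
  sizes.push_back(constantIndex(rewriter, loc, 20)); // hypothetical size
  // allocaBuffer emits one memref.alloca (dynamic size 2 here) plus one
  // memref.store per element, and returns the filled buffer as a Value
  // of type memref<?xindex>, ready to pass to the runtime-support calls.
  Value buf = allocaBuffer(rewriter, loc, sizes);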