diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -158,6 +158,16 @@
     "unsigned":$indexBitWidth
   );
 
+  let extraClassDeclaration = [{
+    /// Returns the type for pointer storage based on pointerBitWidth: an
+    /// integer type of that width, or `index` when the width is zero.
+    Type getPointerType() const;
+
+    /// Returns the type for index storage based on indexBitWidth: an
+    /// integer type of that width, or `index` when the width is zero.
+    Type getIndexType() const;
+  }];
+
   let genVerifyDecl = 1;
   let hasCustomAssemblyFormat = 1;
 }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -41,6 +41,22 @@
   }
 }
 
+/// Returns an integer type of the given bit width, or `index` when the bit
+/// width is zero.
+static Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitWidth) {
+  if (bitWidth)
+    return IntegerType::get(ctx, bitWidth);
+  return IndexType::get(ctx);
+}
+
+Type SparseTensorEncodingAttr::getPointerType() const {
+  return getIntegerOrIndexType(getContext(), getPointerBitWidth());
+}
+
+Type SparseTensorEncodingAttr::getIndexType() const {
+  return getIntegerOrIndexType(getContext(), getIndexBitWidth());
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   if (failed(parser.parseLess()))
     return {};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -543,22 +543,6 @@
     return fields[fidx].getType().template cast<MemRefType>().getElementType();
   }
 
-  // TODO: a better places for these functions should be in
-  // SparseTensorEncodingAttr.
-  Type getPtrElementType() const {
-    auto *ctx = rType.getContext();
-    unsigned ptrWidth = getSparseTensorEncoding(rType).getPointerBitWidth();
-    Type indexType = IndexType::get(ctx);
-    return ptrWidth ? IntegerType::get(ctx, ptrWidth) : indexType;
-  }
-
-  Type getIdxElementType() const {
-    auto *ctx = rType.getContext();
-    unsigned idxWidth = getSparseTensorEncoding(rType).getIndexBitWidth();
-    Type indexType = IndexType::get(ctx);
-    return idxWidth ? IntegerType::get(ctx, idxWidth) : indexType;
-  }
-
 private:
   unsigned getFieldIndex(unsigned dim, SparseTensorFieldKind kind) const {
     unsigned fieldIdx = -1u;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -194,6 +194,7 @@
 static void allocSchemeForRank(OpBuilder &builder, Location loc,
                                MutSparseTensorDescriptor desc, unsigned r0) {
   RankedTensorType rtp = desc.getTensorType();
+  auto enc = getSparseTensorEncoding(rtp);
   unsigned rank = rtp.getShape().size();
   Value linear = constantIndex(builder, loc, 1);
   for (unsigned r = r0; r < rank; r++) {
@@ -201,7 +202,7 @@
       // Append linear x pointers, initialized to zero. Since each compressed
       // dimension initially already has a single zero entry, this maintains
       // the desired "linear + 1" length property at all times.
-      Type ptrType = desc.getPtrElementType();
+      Type ptrType = enc.getPointerType();
       Value ptrZero = constantZero(builder, loc, ptrType);
       unsigned fidx = desc.getPtrMemRefIndex(r);
       createPushback(builder, loc, desc, fidx, ptrZero, linear);
@@ -246,12 +247,14 @@
                               ValueRange dynSizes, bool enableInit,
                               SmallVectorImpl<Value> &fields) {
   RankedTensorType rtp = type.cast<RankedTensorType>();
+  auto enc = getSparseTensorEncoding(rtp);
   Value heuristic = constantIndex(builder, loc, 16);
 
   foreachFieldAndTypeInSparseTensor(
       rtp,
-      [&](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
-          unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
+      [&builder, &fields, loc, heuristic,
+       enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
+                   unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
         assert(fields.size() == fIdx);
         auto memRefTp = fType.cast<MemRefType>();
         Value field;
@@ -293,7 +296,7 @@
       loc, constantZero(builder, loc, builder.getIndexType()),
       desc.getMemSizesMemRef()); // zero memSizes
 
-  Value ptrZero = constantZero(builder, loc, desc.getPtrElementType());
+  Value ptrZero = constantZero(builder, loc, enc.getPointerType());
   for (unsigned r = 0; r < rank; r++) {
     unsigned ro = toOrigDim(rtp, r);
     // Fills dim sizes array.
@@ -476,9 +479,7 @@
       // times?
       //
       if (d > 0) {
-        unsigned ptrWidth = getSparseTensorEncoding(rtp).getPointerBitWidth();
-        Type indexType = builder.getIndexType();
-        Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType;
+        Type ptrType = getSparseTensorEncoding(rtp).getPointerType();
         Value ptrMemRef = desc.getPtrMemRef(d);
         Value mz = constantIndex(builder, loc, desc.getPtrMemSizesIndex(d));
         Value hi = genLoad(builder, loc, desc.getMemSizesMemRef(), mz);