diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -158,6 +158,14 @@ "unsigned":$indexBitWidth ); + let extraClassDeclaration = [{ + /// Returns the type used for pointer storage: an IntegerType of pointerBitWidth bits, or IndexType when pointerBitWidth is 0. + Type getPointerType() const; + + /// Returns the type used for index storage: an IntegerType of indexBitWidth bits, or IndexType when indexBitWidth is 0. + Type getIndexType() const; + }]; + let genVerifyDecl = 1; let hasCustomAssemblyFormat = 1; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -41,6 +41,18 @@ } } +Type SparseTensorEncodingAttr::getPointerType() const { + unsigned ptrWidth = getPointerBitWidth(); + Type indexType = IndexType::get(getContext()); + return ptrWidth ? IntegerType::get(getContext(), ptrWidth) : indexType; +} + +Type SparseTensorEncodingAttr::getIndexType() const { + unsigned idxWidth = getIndexBitWidth(); + Type indexType = IndexType::get(getContext()); + return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType; +} + Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { if (failed(parser.parseLess())) return {}; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -198,12 +198,10 @@ return llvm::None; // Construct the basic types. 
auto *context = type.getContext(); - unsigned idxWidth = enc.getIndexBitWidth(); - unsigned ptrWidth = enc.getPointerBitWidth(); RankedTensorType rType = type.cast(); Type indexType = IndexType::get(context); - Type idxType = idxWidth ? IntegerType::get(context, idxWidth) : indexType; - Type ptrType = ptrWidth ? IntegerType::get(context, ptrWidth) : indexType; + Type idxType = enc.getIndexType(); + Type ptrType = enc.getPointerType(); Type eltType = rType.getElementType(); // // Sparse tensor storage scheme for rank-dimensional tensor is organized @@ -263,9 +261,7 @@ // Append linear x pointers, initialized to zero. Since each compressed // dimension initially already has a single zero entry, this maintains // the desired "linear + 1" length property at all times. - unsigned ptrWidth = getSparseTensorEncoding(rtp).getPointerBitWidth(); - Type indexType = builder.getIndexType(); - Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType; + Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); Value ptrZero = constantZero(builder, loc, ptrType); createPushback(builder, loc, fields, field, ptrZero, linear); return; @@ -273,11 +269,11 @@ if (isSingletonDim(rtp, r)) { return; // nothing to do } // Keep compounding the size, but nothing needs to be initialized - // at this level. We will eventually reach a compressed level or - // otherwise the values array for the from-here "all-dense" case. - assert(isDenseDim(rtp, r)); - Value size = sizeAtStoredDim(builder, loc, rtp, fields, r); - linear = builder.create(loc, linear, size); + // at this level. We will eventually reach a compressed level or + // otherwise the values array for the from-here "all-dense" case. + assert(isDenseDim(rtp, r)); + Value size = sizeAtStoredDim(builder, loc, rtp, fields, r); + linear = builder.create(loc, linear, size); } // Reached values array so prepare for an insertion. 
Value valZero = constantZero(builder, loc, rtp.getElementType()); @@ -310,13 +306,10 @@ SmallVectorImpl &fields) { auto enc = getSparseTensorEncoding(type); assert(enc); - // Construct the basic types. - unsigned idxWidth = enc.getIndexBitWidth(); - unsigned ptrWidth = enc.getPointerBitWidth(); RankedTensorType rtp = type.cast(); Type indexType = builder.getIndexType(); - Type idxType = idxWidth ? builder.getIntegerType(idxWidth) : indexType; - Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType; + Type idxType = enc.getIndexType(); + Type ptrType = enc.getPointerType(); Type eltType = rtp.getElementType(); auto shape = rtp.getShape(); unsigned rank = shape.size(); @@ -534,9 +527,7 @@ // TODO: avoid cleanup and keep compressed scheme consistent at all times? // if (d > 0) { - unsigned ptrWidth = getSparseTensorEncoding(rtp).getPointerBitWidth(); - Type indexType = builder.getIndexType(); - Type ptrType = ptrWidth ? builder.getIntegerType(ptrWidth) : indexType; + Type ptrType = getSparseTensorEncoding(rtp).getPointerType(); Value mz = constantIndex(builder, loc, getMemSizesIndex(field)); Value hi = genLoad(builder, loc, fields[memSizesIdx], mz); Value zero = constantIndex(builder, loc, 0);