diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -304,6 +304,14 @@
     /// reset to the default, and all other fields inherited from `this`.
     SparseTensorEncodingAttr withoutBitWidths() const;
 
+    /// Constructs a new encoding with the given dimSlices, and all
+    /// other fields inherited from `this`.
+    SparseTensorEncodingAttr withDimSlices(ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
+
+    /// Constructs a new encoding with the dimSlices reset to the default,
+    /// and all other fields inherited from `this`.
+    SparseTensorEncodingAttr withoutDimSlices() const;
+
     //
     // Rank methods.
     //
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -111,6 +111,17 @@
     return withEncoding(enc.withoutBitWidths());
   }
 
+  /// Returns this type with the encoding's dimSlices replaced.
+  SparseTensorType
+  withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
+    return withEncoding(enc.withDimSlices(dimSlices));
+  }
+
+  /// Returns this type with the encoding's dimSlices reset to the default.
+  SparseTensorType withoutDimSlices() const {
+    return withEncoding(enc.withoutDimSlices());
+  }
+
   //
   // Other methods.
   //
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -291,6 +291,17 @@
   return withBitWidths(0, 0);
 }
 
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
+    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
+  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
+                                       getDimToLvl(), getPosWidth(),
+                                       getCrdWidth(), dimSlices);
+}
+
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
+  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
+}
+
 bool SparseTensorEncodingAttr::isAllDense() const {
   return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1138,10 +1138,7 @@
     // TODO: We should check these in ExtractSliceOp::verify.
     if (!srcEnc || !dstEnc || !dstEnc.isSlice())
       return failure();
-    assert(srcEnc.getLvlTypes() == dstEnc.getLvlTypes());
-    assert(srcEnc.getDimToLvl() == dstEnc.getDimToLvl());
-    assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
-    assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth());
+    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
 
     SmallVector<Value> fields;
     auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);