diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -73,6 +73,11 @@
     "int64_t" : $stride
   );
 
+  let builders = [
+    // The nop slice (i.e., the slice that includes everything).
+    AttrBuilder<(ins), [{ return $_get($_ctxt, 0, kDynamic, 1); }]>
+  ];
+
   let extraClassDeclaration = [{
     void print(llvm::raw_ostream &os) const;
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -213,6 +213,7 @@
 }
 
 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
+  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
   os << '(';
   os << getStaticString(getOffset());
   os << ", ";
@@ -528,10 +529,37 @@
       ir_detail::DimLvlMapParser cParser(parser);
       auto res = cParser.parseDimLvlMap();
       RETURN_ON_FAIL(res);
-      // Proof of concept result.
-      // TODO: use DimLvlMap directly as storage representation
-      for (Level lvl = 0, lvlRank = res->getLvlRank(); lvl < lvlRank; lvl++)
-        lvlTypes.push_back(res->getLvlType(lvl));
+      // TODO: use DimLvlMap directly as storage representation, rather
+      // than converting things over.
+      const auto &dlm = *res;
+
+      ERROR_IF(!lvlTypes.empty(), "Cannot mix `lvlTypes` with `NEW_SYNTAX`")
+      const Level lvlRank = dlm.getLvlRank();
+      for (Level lvl = 0; lvl < lvlRank; lvl++)
+        lvlTypes.push_back(dlm.getLvlType(lvl));
+
+      ERROR_IF(!dimSlices.empty(), "Cannot mix `dimSlices` with `NEW_SYNTAX`")
+      const Dimension dimRank = dlm.getDimRank();
+      for (Dimension dim = 0; dim < dimRank; dim++)
+        dimSlices.push_back(dlm.getDimSlice(dim));
+      // NOTE: the old syntax requires an all-or-nothing approach to
+      // `dimSlices`; therefore, if any slice actually exists then we need
+      // to convert null-DSA into default/nop DSA.
+      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
+        return static_cast<bool>(slice.getImpl());
+      };
+      if (llvm::any_of(dimSlices, isDefined)) {
+        const auto defaultSlice =
+            SparseTensorDimSliceAttr::get(parser.getContext());
+        for (Dimension dim = 0; dim < dimRank; dim++)
+          if (!isDefined(dimSlices[dim]))
+            dimSlices[dim] = defaultSlice;
+      } else {
+        dimSlices.clear();
+      }
+
+      ERROR_IF(dimToLvl, "Cannot mix `dimToLvl` with `NEW_SYNTAX`")
+      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
     }
 
     // Only the last item can omit the comma
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -156,9 +156,9 @@
   (d0, d1) -> (d0 : dense, d1 : compressed)
 }>
 
-// CHECK-LABEL: func private @foo(
+// CHECK-LABEL: func private @CSR_implicit(
 // CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
-func.func private @foo(%arg0: tensor<?x?xf64, #CSR_implicit>) {
+func.func private @CSR_implicit(%arg0: tensor<?x?xf64, #CSR_implicit>) {
   return
 }
 
@@ -169,9 +169,9 @@
   {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
 }>
 
-// CHECK-LABEL: func private @foo(
+// CHECK-LABEL: func private @CSR_explicit(
 // CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
-func.func private @foo(%arg0: tensor<?x?xf64, #CSR_explicit>) {
+func.func private @CSR_explicit(%arg0: tensor<?x?xf64, #CSR_explicit>) {
   return
 }
 
@@ -187,11 +187,9 @@
   )
 }>
 
-// FIXME: should not have to use 4 dims ;-)
-//
-// CHECK-LABEL: func private @foo(
-// CHECK-SAME: tensor<?x?x?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed", "dense", "dense" ] }>>
-func.func private @foo(%arg0: tensor<?x?x?x?xf64, #BCSR_implicit>) {
+// CHECK-LABEL: func private @BCSR_implicit(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed", "dense", "dense" ], dimToLvl = affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)> }>>
+func.func private @BCSR_implicit(%arg0: tensor<?x?xf64, #BCSR_implicit>) {
   return
 }
 
@@ -210,11 +208,9 @@
   )
 }>
 
-// FIXME: should not have to use 4 dims ;-)
-//
-// CHECK-LABEL: func private @foo(
-// CHECK-SAME: tensor<?x?x?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed", "dense", "dense" ] }>>
-func.func private @foo(%arg0: tensor<?x?x?x?xf64, #BCSR_explicit>) {
+// CHECK-LABEL: func private @BCSR_explicit(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed", "dense", "dense" ], dimToLvl = affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)> }>>
+func.func private @BCSR_explicit(%arg0: tensor<?x?xf64, #BCSR_explicit>) {
   return
 }
 
@@ -229,9 +225,8 @@
   )
 }>
 
-//
-// CHECK-LABEL: func private @foo_2_out_of_4(
-// CHECK-SAME: tensor<?x?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "dense", "compressed24" ] }>>
-func.func private @foo_2_out_of_4(%arg0: tensor<?x?x?xf64, #NV_24>) {
+// CHECK-LABEL: func private @NV_24(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "dense", "compressed24" ], dimToLvl = affine_map<(d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)> }>>
+func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
   return
 }
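For readers unfamiliar with the all-or-nothing normalization in the parser change above, here is a minimal self-contained sketch of the same idea in plain C++ (no MLIR dependencies; `Slice`, `kNopSlice`, and `normalizeSlices` are hypothetical stand-ins for `SparseTensorDimSliceAttr`, the new zero-argument `AttrBuilder`, and the parser logic, and the `kDynamic` sentinel value is an assumption): if any dimension carries a real slice, every unspecified entry is filled with the default/nop slice `(0, ?, 1)`; if no dimension does, the list is dropped so the encoding omits `dimSlices` entirely.

// Sketch of the all-or-nothing dimSlices normalization; these names are
// illustrative stand-ins, not MLIR APIs.
#include <algorithm>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Stand-in for SparseTensorDimSliceAttr: (offset, size, stride).
struct Slice {
  int64_t offset, size, stride;
};

// Sentinel for a dynamic slice size (assumed analogue of the patch's kDynamic).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// The default/nop slice that includes everything, like the patch's new
// zero-argument AttrBuilder: offset 0, dynamic size, stride 1.
constexpr Slice kNopSlice{0, kDynamic, 1};

void normalizeSlices(std::vector<std::optional<Slice>> &slices) {
  const bool anyDefined =
      std::any_of(slices.begin(), slices.end(),
                  [](const std::optional<Slice> &s) { return s.has_value(); });
  if (anyDefined) {
    // At least one real slice exists: fill every hole with the nop slice.
    for (std::optional<Slice> &s : slices)
      if (!s)
        s = kNopSlice;
  } else {
    // No slices at all: drop the list so `dimSlices` is omitted entirely.
    slices.clear();
  }
}

This is also why the patch adds the zero-argument builder in SparseTensorAttrDefs.td: the old surface syntax cannot represent a partially-sliced encoding, so null entries must be materialized as explicit nop slices before storage.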