diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -451,24 +451,22 @@
   AffineMap dimToLvl = {};
   unsigned posWidth = 0;
   unsigned crdWidth = 0;
-  StringRef attrName;
-  // Exactly 6 keys.
   SmallVector<StringRef, 6> keys = {"lvlTypes", "dimToLvl", "posWidth",
-                                    "crdWidth", "dimSlices", "NEW_SYNTAX"};
+                                    "crdWidth", "dimSlices", "map"};
   while (succeeded(parser.parseOptionalKeyword(&attrName))) {
-    if (!llvm::is_contained(keys, attrName)) {
+    // Detect admissible keyword.
+    auto *it = find(keys, attrName);
+    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
-
+    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after keys
    RETURN_ON_FAIL(parser.parseEqual())
-    // FIXME: using `operator==` below duplicates the string comparison
-    // cost of the `is_contained` check above. Should instead use some
-    // "find" function that returns the index into `keys` so that we can
-    // dispatch on that instead.
-    if (attrName == "lvlTypes") {
+    // Dispatch on keyword.
+    switch (keyWordIndex) {
+    case 0: { // lvlTypes
      Attribute attr;
      RETURN_ON_FAIL(parser.parseAttribute(attr));
      auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr);
@@ -485,25 +483,33 @@
          return {};
        }
      }
-    } else if (attrName == "dimToLvl") {
+      break;
+    }
+    case 1: { // dimToLvl
      Attribute attr;
      RETURN_ON_FAIL(parser.parseAttribute(attr))
      auto affineAttr = llvm::dyn_cast<AffineMapAttr>(attr);
      ERROR_IF(!affineAttr, "expected an affine map for dimToLvl")
      dimToLvl = affineAttr.getValue();
-    } else if (attrName == "posWidth") {
+      break;
+    }
+    case 2: { // posWidth
      Attribute attr;
      RETURN_ON_FAIL(parser.parseAttribute(attr))
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      ERROR_IF(!intAttr, "expected an integral position bitwidth")
      posWidth = intAttr.getInt();
-    } else if (attrName == "crdWidth") {
+      break;
+    }
+    case 3: { // crdWidth
      Attribute attr;
      RETURN_ON_FAIL(parser.parseAttribute(attr))
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      ERROR_IF(!intAttr, "expected an integral index bitwidth")
      crdWidth = intAttr.getInt();
-    } else if (attrName == "dimSlices") {
+      break;
+    }
+    case 4: { // dimSlices
      RETURN_ON_FAIL(parser.parseLSquare())
      // Dispatches to DimSliceAttr to skip mnemonic
      bool finished = false;
@@ -519,13 +525,9 @@
      if (!finished)
        return {};
      RETURN_ON_FAIL(parser.parseRSquare())
-    } else if (attrName == "NEW_SYNTAX") {
-      // Note that we are in the process of migrating to a new STEA surface
-      // syntax. While this is ongoing we use the temporary "NEW_SYNTAX = ...."
-      // to switch to the new parser. This allows us to gradually migrate
-      // examples over to the new surface syntax before making the complete
-      // switch once work is completed.
-      // TODO: replace everything here with new STEA surface syntax parser
+      break;
+    }
+    case 5: { // map (new STEA surface syntax)
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      RETURN_ON_FAIL(res);
@@ -533,12 +535,12 @@
      // than converting things over.
      const auto &dlm = *res;

-      ERROR_IF(!lvlTypes.empty(), "Cannot mix `lvlTypes` with `NEW_SYNTAX`")
+      ERROR_IF(!lvlTypes.empty(), "Cannot mix `lvlTypes` with `map`")
      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

-      ERROR_IF(!dimSlices.empty(), "Cannot mix `dimSlices` with `NEW_SYNTAX`")
+      ERROR_IF(!dimSlices.empty(), "Cannot mix `dimSlices` with `map`")
      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
@@ -558,11 +560,12 @@
        dimSlices.clear();
      }

-      ERROR_IF(dimToLvl, "Cannot mix `dimToLvl` with `NEW_SYNTAX`")
+      ERROR_IF(dimToLvl, "Cannot mix `dimToLvl` with `map`")
      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
+      break;
    }
-
-    // Only the last item can omit the comma
+    } // switch
+    // Only last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -70,18 +70,11 @@
 }>
 func.func private @sparse_slice(tensor)

-///////////////////////////////////////////////////////////////////////////////
-// Migration plan for new STEA surface syntax,
-// use the NEW_SYNTAX on selected examples
-// and then TODO: remove when fully migrated
-///////////////////////////////////////////////////////////////////////////////
-
 // -----

-// expected-error@+3 {{Level-rank mismatch between forward-declarations and specifiers. Declared 3 level-variables; but got 2 level-specifiers.}}
+// expected-error@+2 {{Level-rank mismatch between forward-declarations and specifiers. Declared 3 level-variables; but got 2 level-specifiers.}}
 #TooManyLvlDecl = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  {l0, l1, l2} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+  map = {l0, l1, l2} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
 }>
 func.func private @too_many_lvl_decl(%arg0: tensor) {
   return
@@ -89,15 +82,9 @@

 // -----

-// NOTE: We don't get the "level-rank mismatch" error here, because this
-// "undeclared identifier" error occurs first. The error message is a bit
-// misleading because `parseLvlVarBinding` calls `parseVarUsage` rather
-// than `parseVarBinding` (and the error message generated by `parseVar`
-// is assuming that `parseVarUsage` is only called for *uses* of variables).
-// expected-error@+3 {{use of undeclared identifier 'l1'}}
+// expected-error@+2 {{use of undeclared identifier 'l1'}}
 #TooFewLvlDecl = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  {l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+  map = {l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
 }>
 func.func private @too_few_lvl_decl(%arg0: tensor) {
   return
@@ -105,12 +92,10 @@

 // -----

-// expected-error@+3 {{Level-variable ordering mismatch. The variable 'l0' was forward-declared as the 1st level; but is bound by the 0th specification.}}
+// expected-error@+2 {{Level-variable ordering mismatch. The variable 'l0' was forward-declared as the 1st level; but is bound by the 0th specification.}}
 #WrongOrderLvlDecl = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  {l1, l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+  map = {l1, l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
 }>
 func.func private @wrong_order_lvl_decl(%arg0: tensor) {
   return
 }
-
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -2,13 +2,12 @@

 // CHECK-LABEL: func private @sparse_1d_tensor(
 // CHECK-SAME: tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>>)
-func.func private @sparse_1d_tensor(tensor<32xf64, #sparse_tensor.encoding<{ lvlTypes = ["compressed"] }>>)
+func.func private @sparse_1d_tensor(tensor<32xf64, #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>>)

 // -----

 #CSR = #sparse_tensor.encoding<{
-  lvlTypes = [ "dense", "compressed" ],
-  dimToLvl = affine_map<(i,j) -> (i,j)>,
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
   posWidth = 64,
   crdWidth = 64
 }>
@@ -19,9 +18,20 @@

 // -----

+#CSR_explicit = #sparse_tensor.encoding<{
+  map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+}>
+
+// CHECK-LABEL: func private @CSR_explicit(
+// CHECK-SAME: tensor>
+func.func private @CSR_explicit(%arg0: tensor) {
+  return
+}
+
+// -----
+
 #CSC = #sparse_tensor.encoding<{
-  lvlTypes = [ "dense", "compressed" ],
-  dimToLvl = affine_map<(i,j) -> (j,i)>,
+  map = (d0, d1) -> (d1 : dense, d0 : compressed),
   posWidth = 0,
   crdWidth = 0
 }>
@@ -33,8 +43,7 @@
 // -----

 #DCSC = #sparse_tensor.encoding<{
-  lvlTypes = [ "compressed", "compressed" ],
-  dimToLvl = affine_map<(i,j) -> (j,i)>,
+  map = (d0, d1) -> (d1 : compressed, d0 : compressed),
   posWidth = 0,
   crdWidth = 64
 }>
@@ -129,7 +138,6 @@
 // CHECK-SAME: tensor>
 func.func private @sparse_slice(tensor)

-
 // -----

 // TODO: It is probably better to use [dense, dense, 2:4] (see NV_24 defined using new syntax
@@ -143,43 +151,10 @@
 // CHECK-SAME: tensor>
 func.func private @sparse_2_out_of_4(tensor)

-///////////////////////////////////////////////////////////////////////////////
-// Migration plan for new STEA surface syntax,
-// use the NEW_SYNTAX on selected examples
-// and then TODO: remove when fully migrated
-///////////////////////////////////////////////////////////////////////////////
-
 // -----

-#CSR_implicit = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  (d0, d1) -> (d0 : dense, d1 : compressed)
-}>
-
-// CHECK-LABEL: func private @CSR_implicit(
-// CHECK-SAME: tensor>
-func.func private @CSR_implicit(%arg0: tensor) {
-  return
-}
-
-// -----
-
-#CSR_explicit = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
-}>
-
-// CHECK-LABEL: func private @CSR_explicit(
-// CHECK-SAME: tensor>
-func.func private @CSR_explicit(%arg0: tensor) {
-  return
-}
-
-// -----
-
-#BCSR_implicit = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  ( i, j ) ->
+#BCSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
   ( i floordiv 2 : compressed,
     j floordiv 3 : compressed,
     i mod 2 : dense,
@@ -187,16 +162,16 @@
   )
 }>

-// CHECK-LABEL: func private @BCSR_implicit(
+// CHECK-LABEL: func private @BCSR(
 // CHECK-SAME: tensor (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)> }>>
-func.func private @BCSR_implicit(%arg0: tensor) {
+func.func private @BCSR(%arg0: tensor) {
   return
 }

 // -----

 #BCSR_explicit = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
+  map =
   {il, jl, ii, jj}
   ( i = il * 2 + ii,
     j = jl * 3 + jj
@@ -217,8 +192,7 @@
 // -----

 #NV_24 = #sparse_tensor.encoding<{
-  NEW_SYNTAX =
-  ( i, j ) ->
+  map = ( i, j ) ->
   ( i : dense,
     j floordiv 4 : dense,
     j mod 4 : compressed24