Adds tests for full sum reduction (tensors summed up into scalars)
and the well-known sampled dense-dense matrix product (SDDMM). Refines
the optimization rules slightly to handle the summation better.
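For readers unfamiliar with the kernel, here is a minimal pure-Python sketch of SDDMM (the function name and the COO triple layout are illustrative assumptions, not the test's actual MLIR): for each nonzero of the sparse sampling matrix S, it scales the corresponding entry of the dense product A·B, so the full dense product is never materialized.

```python
def sddmm(s_coo, a, b):
    # Sampled dense-dense matrix multiplication (sketch, not the generated
    # code): s_coo is a list of (i, j, value) triples for the sparse
    # sampling matrix S; a and b are dense matrices as nested lists.
    # Only positions that are nonzero in S are ever computed.
    return [(i, j, v * sum(a[i][k] * b[k][j] for k in range(len(b))))
            for (i, j, v) in s_coo]
```

The output stays in the same sparse triple format as the input mask, which mirrors why the kernel benefits from sparse code generation.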
Diff Detail
- Repository
- rG LLVM Github Monorepo
Event Timeline
The generated MLIR code looks good to me. Thank you for the tests / examples! :D
mlir/test/Dialect/Linalg/sparse_2d.mlir

- Line 1111: Is changing the order of the mapping here equivalent to changing the traversal order? For example, does (i,j,k) -> (j, i) mean A is DCSC instead of DCSR?
- Line 1123: Nit: Can we use S instead of A to indicate that it's a sparse mask? (That way, we can use A and B for A * B to be consistent with normal notation for multiplication.) I realize you'll need to change the variable names everywhere, so if you prefer to keep the original names, I'm okay with that too.
- Line 1142: VAL_13 and VAL_15 are the same here (the size of the reduction dimension). Could we just reuse VAL_13 in the future, or is there a specific reason to keep all tensor dimension sizes separate?

mlir/test/Dialect/Linalg/sparse_3d.mlir

- Line 1263: For summation, we actually can just loop through the nnz array (VAL_11) and forgo all the multi-level indexing. Is this what you mean when you said you prefer keeping all the i, j, k indices versus flattening?
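The shortcut asked about above can be sketched in plain Python (the `pos`/`crd`/`vals` names follow common compressed-sparse conventions and are assumptions, not the pass's actual variable names). Because a full reduction consumes every stored value regardless of its coordinates, the multi-level loop nest and a single flat pass over the values array compute the same result:

```python
def sum_nested(pos, crd, vals):
    # Full reduction via the compressed structure: walk each compressed
    # row, then the nonzeros within it. Note that the coordinate array
    # crd is never actually consulted -- which is exactly why the
    # flattened version below is valid.
    total = 0.0
    for i in range(len(pos) - 1):
        for k in range(pos[i], pos[i + 1]):
            total += vals[k]
    return total

def sum_flat(vals):
    # The flattened optimization: one pass over the nnz values array,
    # with no multi-level indexing (and much friendlier to vectorization).
    return sum(vals)
```

Both return the same scalar for any valid compressed input.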
mlir/test/Dialect/Linalg/sparse_2d.mlir

- Line 1111: Yes, for now just permuting the indices of the tensor access is similar to TACO's ordering.
- Line 1123: Sure, I don't mind doing that just here, since it makes sense for this kernel.
- Line 1142: No reason other than a mechanical translation of the dimensions. In the long run, we will probably use other APIs to find the loop sizes.

mlir/test/Dialect/Linalg/sparse_3d.mlir

- Line 1263: Yes! Making this one flattened loop will indeed be one of the future optimizations (since it vectorizes so much better too).
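The point about permuted access maps can be illustrated with a small Python sketch (the helper and lambda names are hypothetical): swapping the map to (i, j) -> (j, i) leaves the loop nest itself unchanged but makes it touch A column-by-column, i.e., in the order a CSC/DCSC-style layout favors rather than a CSR/DCSR-style one.

```python
def traverse(rows, cols, access_map):
    # Walk a fixed row-major loop nest (i outer, j inner) and record the
    # storage coordinate each iteration touches under the given index map.
    coords = []
    for i in range(rows):
        for j in range(cols):
            coords.append(access_map(i, j))
    return coords

identity = lambda i, j: (i, j)  # visits A row-by-row (DCSR-friendly)
swapped = lambda i, j: (j, i)   # visits A column-by-column (DCSC-friendly)
```

For a 2x2 matrix, the identity map yields row-major coordinate order while the swapped map yields column-major order, even though the loops themselves never change.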