diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -30,6 +30,9 @@
 
 namespace {
 
+// Iteration graph sorting.
+enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };
+
 // Code generation.
 struct CodeGen {
   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops)
@@ -141,7 +144,7 @@
 /// order yields innermost unit-stride access with better spatial locality.
 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                   std::vector<unsigned> &topSort,
-                                  bool sparseOnly) {
+                                  unsigned mask) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
   unsigned n = op.getNumLoops();
@@ -152,8 +155,8 @@
     auto map = op.getTiedIndexingMap(t);
     auto enc = getSparseTensorEncoding(t->get().getType());
     assert(map.getNumDims() == n);
-    // Skip dense tensor constraints when sparse only is requested.
-    if (sparseOnly && !enc)
+    // Skip dense tensor constraints when not requested.
+    if (!(mask & SortMask::kIncludeDense) && !enc)
       continue;
     // Each tensor expression and optional dimension ordering (row-major
     // by default) puts an ordering constraint on the loop indices. For
@@ -164,6 +167,16 @@
       unsigned t = map.getDimPosition(perm(enc, d));
       adjM[f][t] = true;
     }
+    // Push unrelated loops into sparse iteration space, so these
+    // will be skipped more often.
+    if (mask & SortMask::kIncludeUndef) {
+      unsigned tensor = t->getOperandNumber();
+      for (unsigned i = 0; i < n; i++)
+        if (merger.isDim(tensor, i, Dim::kSparse))
+          for (unsigned j = 0; j < n; j++)
+            if (merger.isDim(tensor, j, Dim::kUndef))
+              adjM[i][j] = true;
+    }
   }
 
   // Topologically sort the iteration graph to determine loop order.
@@ -1134,8 +1147,12 @@
     // This assumes that higher-level passes have already put the
     // tensors in each tensor expression in a feasible order.
     std::vector<unsigned> topSort;
-    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
-        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
+    if (!computeIterationGraph(merger, op, topSort,
+                               SortMask::kIncludeUndef |
+                                   SortMask::kIncludeDense) &&
+        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeUndef) &&
+        !computeIterationGraph(merger, op, topSort, SortMask::kIncludeDense) &&
+        !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly))
       return failure();
 
     // Builds the tensor expression for the Linalg operation in SSA form.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1043,25 +1043,26 @@
 // CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref
 // CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] {
 // CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref
-// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] {
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_22]]] : memref
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref
-// CHECK: %[[VAL_25:.*]] = addi %[[VAL_20]], %[[VAL_5]] : index
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_25]]] : memref
-// CHECK: scf.for %[[VAL_27:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] {
-// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref
-// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref
-// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_27]]] : memref
-// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_22]], %[[VAL_28]]] : memref
-// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_23]], %[[VAL_31]] : f32
-// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_30]], %[[VAL_32]] : f32
-// CHECK: %[[VAL_34:.*]] = addf %[[VAL_29]], %[[VAL_33]] : f32
-// CHECK: memref.store %[[VAL_34]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref
+// CHECK: %[[VAL_23:.*]] = addi %[[VAL_20]], %[[VAL_5]] : index
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref
+// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] {
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_25]]] : memref
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref
+// CHECK: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (f32) {
+// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref
+// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_30]], %[[VAL_26]]] : memref
+// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_32]], %[[VAL_33]] : f32
+// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_27]], %[[VAL_34]] : f32
+// CHECK: %[[VAL_36:.*]] = addf %[[VAL_31]], %[[VAL_35]] : f32
+// CHECK: scf.yield %[[VAL_36]] : f32
 // CHECK: }
+// CHECK: memref.store %[[VAL_37:.*]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref
 // CHECK: }
 // CHECK: }
-// CHECK: %[[VAL_35:.*]] = memref.tensor_load %[[VAL_17]] : memref
-// CHECK: return %[[VAL_35]] : tensor
+// CHECK: %[[VAL_38:.*]] = memref.tensor_load %[[VAL_17]] : memref
+// CHECK: return %[[VAL_38]] : tensor
 // CHECK: }
 func @sampled_dense_dense(%args: tensor, %arga: tensor,
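As a rough illustration of the fallback order introduced in the Sparsification.cpp hunk above: the rewriter first asks for the most constrained iteration graph (dense and undef constraints included) and then progressively relaxes the mask until a topological sort succeeds, giving up only when even the sparse-only graph is cyclic. The sketch below is not part of the patch; trySort() is a hypothetical stand-in for computeIterationGraph(), wired to succeed only once the dense constraints are dropped.

// Minimal sketch of the SortMask fallback cascade (assumption: trySort()
// models computeIterationGraph() returning true iff an order exists).
#include <cstdio>

enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 };

static bool trySort(unsigned mask) {
  // Stand-in: pretend a topological sort exists only once the dense
  // tensor constraints are excluded from the iteration graph.
  return !(mask & kIncludeDense);
}

int main() {
  // Same order as the patched rewrite rule: most to least constrained.
  const unsigned masks[] = {kIncludeUndef | kIncludeDense, kIncludeUndef,
                            kIncludeDense, kSparseOnly};
  for (unsigned mask : masks) {
    if (trySort(mask)) {
      std::printf("loop order found with mask 0x%x\n", mask);
      return 0;
    }
  }
  std::puts("no admissible loop order; the rewrite would return failure()");
  return 1;
}

Per the new comment in the patch, including the undef constraints (kIncludeUndef) orders loops unrelated to a sparse tensor after that tensor's sparse dimensions, so they land in the sparse iteration space and are skipped more often, which is why that bit is kept in the first two, preferred, attempts.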