diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -139,9 +139,13 @@
 /// for sparse storage formats since these only support access along fixed
 /// dimensions. Even for dense storage formats, however, the natural index
 /// order yields innermost unit-stride access with better spatial locality.
+///
+/// bit-mask XY
+///          |+--> include dense constraints
+///          +---> include sparse outer constraints
 static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
                                   std::vector<unsigned> &topSort,
-                                  bool sparseOnly) {
+                                  unsigned mask) {
   // Set up an n x n from/to adjacency matrix of the iteration graph
   // for the implicit loop indices i_0 .. i_n-1.
   unsigned n = op.getNumLoops();
@@ -153,7 +157,7 @@
     auto enc = getSparseTensorEncoding(t->get().getType());
     assert(map.getNumDims() == n);
     // Skip dense tensor constraints when sparse only is requested.
-    if (sparseOnly && !enc)
+    if (!(mask & 1) && !enc)
       continue;
     // Each tensor expression and optional dimension ordering (row-major
     // by default) puts an ordering constraint on the loop indices. For
@@ -164,6 +168,16 @@
       unsigned t = map.getDimPosition(perm(enc, d));
       adjM[f][t] = true;
     }
+    // Push unrelated loops into sparse iteration space, so these
+    // will be skipped more often.
+    if (mask & 2) {
+      unsigned tensor = t->getOperandNumber();
+      for (unsigned i = 0; i < n; i++)
+        if (merger.isDim(tensor, i, Dim::kSparse))
+          for (unsigned j = 0; j < n; j++)
+            if (merger.isDim(tensor, j, Dim::kUndef))
+              adjM[i][j] = true;
+    }
   }

   // Topologically sort the iteration graph to determine loop order.
@@ -1134,8 +1148,10 @@
     // This assumes that higher-level passes have already put the
     // tensors in each tensor expression in a feasible order.
    std::vector<unsigned> topSort;
-    if (!computeIterationGraph(merger, op, topSort, /*sparseOnly=*/false) &&
-        !computeIterationGraph(merger, op, topSort, /*sparseOnly=*/true))
+    if (!computeIterationGraph(merger, op, topSort, /*mask=*/0x03) &&
+        !computeIterationGraph(merger, op, topSort, /*mask=*/0x02) &&
+        !computeIterationGraph(merger, op, topSort, /*mask=*/0x01) &&
+        !computeIterationGraph(merger, op, topSort, /*mask=*/0x00))
       return failure();

     // Builds the tensor expression for the Linalg operation in SSA form.
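To make the new mask semantics concrete, the sketch below shows the retry scheme in isolation: bit 0 decides whether dense operands contribute ordering constraints, bit 1 adds the "sparse outer" edges from sparse loop indices to loop indices that are undefined for that operand, and the call site relaxes the mask from 0x03 down to 0x00 until an acyclic iteration graph is found. This is a minimal, self-contained illustration, not the MLIR code itself: `TensorConstraints`, `computeOrder`, and `pickLoopOrder` are hypothetical stand-ins for the Merger/adjM machinery in Sparsification.cpp.

```cpp
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

// Per-loop classification of one tensor's dimensions, mirroring Dim::kSparse
// and Dim::kUndef from the patch (kDense covers everything else here).
enum class Dim { kDense, kSparse, kUndef };

// Hypothetical stand-in for one tensor operand: whether it carries a sparse
// encoding, its explicit ordering constraints f -> t ("loop f runs outside
// loop t"), and the per-loop Dim classification used by the new heuristic.
struct TensorConstraints {
  bool sparse;
  std::vector<std::pair<unsigned, unsigned>> edges;
  std::vector<Dim> dims; // size n, one entry per loop index
};

// Build the n x n adjacency matrix under the given mask and topologically
// sort it; returns false if the graph has a cycle.
static bool computeOrder(const std::vector<TensorConstraints> &tensors,
                         unsigned n, unsigned mask,
                         std::vector<unsigned> &topSort) {
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
  for (const TensorConstraints &tc : tensors) {
    if (!(mask & 1) && !tc.sparse)
      continue; // bit 0 off: drop dense tensor constraints
    for (auto [f, t] : tc.edges)
      adjM[f][t] = true;
    if (mask & 2) // bit 1 on: push undefined loops inside sparse loops
      for (unsigned i = 0; i < n; i++)
        if (tc.dims[i] == Dim::kSparse)
          for (unsigned j = 0; j < n; j++)
            if (tc.dims[j] == Dim::kUndef)
              adjM[i][j] = true;
  }
  // Recursive DFS topological sort (0 = unseen, 1 = on stack, 2 = done).
  std::vector<int> state(n, 0);
  std::function<bool(unsigned)> visit = [&](unsigned u) -> bool {
    state[u] = 1;
    for (unsigned v = 0; v < n; v++)
      if (adjM[u][v]) {
        if (state[v] == 1)
          return false; // back edge: cycle
        if (state[v] == 0 && !visit(v))
          return false;
      }
    state[u] = 2;
    topSort.push_back(u); // post-order: innermost loops first
    return true;
  };
  for (unsigned u = 0; u < n; u++)
    if (state[u] == 0 && !visit(u))
      return false;
  std::reverse(topSort.begin(), topSort.end()); // outermost loop first
  return true;
}

// The caller retries with progressively weaker masks, as the rewritten call
// site does: 0x03 (all constraints) down to 0x00 (sparse-only constraints,
// no heuristic), taking the first acyclic iteration graph.
bool pickLoopOrder(const std::vector<TensorConstraints> &tensors, unsigned n,
                   std::vector<unsigned> &topSort) {
  for (unsigned mask : {0x03u, 0x02u, 0x01u, 0x00u}) {
    topSort.clear();
    if (computeOrder(tensors, n, mask, topSort))
      return true;
  }
  return false;
}
```

The weakened retries reflect that the extra constraints are only heuristics: if they make the graph cyclic, the pass still falls back to the minimal sparse-only constraints rather than failing outright.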
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1043,25 +1043,26 @@
 // CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref
 // CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] {
 // CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref
-// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] {
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_22]]] : memref
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref
-// CHECK: %[[VAL_25:.*]] = addi %[[VAL_20]], %[[VAL_5]] : index
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_25]]] : memref
-// CHECK: scf.for %[[VAL_27:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] {
-// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref
-// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref
-// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_27]]] : memref
-// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_22]], %[[VAL_28]]] : memref
-// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_23]], %[[VAL_31]] : f32
-// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_30]], %[[VAL_32]] : f32
-// CHECK: %[[VAL_34:.*]] = addf %[[VAL_29]], %[[VAL_33]] : f32
-// CHECK: memref.store %[[VAL_34]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_28]]] : memref
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref
+// CHECK: %[[VAL_23:.*]] = addi %[[VAL_20]], %[[VAL_5]] : index
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref
+// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] {
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_25]]] : memref
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref
+// CHECK: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (f32) {
+// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref
+// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_30]], %[[VAL_26]]] : memref
+// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_32]], %[[VAL_33]] : f32
+// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_27]], %[[VAL_34]] : f32
+// CHECK: %[[VAL_36:.*]] = addf %[[VAL_31]], %[[VAL_35]] : f32
+// CHECK: scf.yield %[[VAL_36]] : f32
 // CHECK: }
+// CHECK: memref.store %[[VAL_37:.*]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref
 // CHECK: }
 // CHECK: }
-// CHECK: %[[VAL_35:.*]] = memref.tensor_load %[[VAL_17]] : memref
-// CHECK: return %[[VAL_35]] : tensor
+// CHECK: %[[VAL_38:.*]] = memref.tensor_load %[[VAL_17]] : memref
+// CHECK: return %[[VAL_38]] : tensor
 // CHECK: }

 func @sampled_dense_dense(%args: tensor, %arga: tensor,
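Roughly, the updated CHECK pattern corresponds to the following scalar rendering of the sampled dense-dense kernel, X(i,j) += S(i,j) * sum_k A(i,k) * B(k,j) over the stored entries of S. The sketch assumes S is compressed in both dimensions and uses f32 values; all names below (ptr0, idx0, sampledDenseDense, ...) are illustrative and not taken from the test.

```cpp
#include <cstdint>
#include <vector>

// Illustrative scalar version of the kernel the new CHECK pattern verifies:
//   X(i,j) += S(i,j) * sum_k A(i,k) * B(k,j)   for stored entries (i,j) of S.
// S is assumed compressed in both dimensions (hypothetical ptr/idx/val
// arrays); A, B, X are dense. The dense k-reduction runs innermost and
// accumulates into a scalar, matching the scf.for ... iter_args above.
void sampledDenseDense(const std::vector<int64_t> &ptr0,
                       const std::vector<int64_t> &idx0,
                       const std::vector<int64_t> &ptr1,
                       const std::vector<int64_t> &idx1,
                       const std::vector<float> &val,
                       const std::vector<std::vector<float>> &A,
                       const std::vector<std::vector<float>> &B,
                       std::vector<std::vector<float>> &X, int64_t K) {
  for (int64_t ii = ptr0[0]; ii < ptr0[1]; ++ii) {         // stored rows of S
    int64_t i = idx0[ii];
    for (int64_t jj = ptr1[ii]; jj < ptr1[ii + 1]; ++jj) { // stored cols in row
      int64_t j = idx1[jj];
      float s = val[jj];     // sample value: loaded once per stored entry
      float acc = X[i][j];   // output element: loaded once per stored entry
      for (int64_t k = 0; k < K; ++k)                      // dense reduction
        acc += s * (A[i][k] * B[k][j]);
      X[i][j] = acc;         // single store per stored entry
    }
  }
}
```

Before this change the dense k-loop sat outside the second sparse loop, so the compressed index structure of S was re-traversed and the output element re-read and re-stored for every k. With the new bit-1 heuristic, the loop that is undefined for the sparse operand is pushed inside the sparse loops, which is exactly the scalarized scf.for ... iter_args reduction checked above.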