diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -347,17 +347,27 @@
 
 /// Helper method to inspect sparse encodings in the tensor types.
 /// Fills the per-dimension sparsity information for all tensors.
-static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
+static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
+  bool annotated = false;
   unsigned numTensors = op.getNumShapedOperands();
+  unsigned lhs = numTensors - 1;
   for (unsigned t = 0; t < numTensors; t++) {
     auto map = op.getIndexingMap(t);
     unsigned rank = op.getShapedType(t).getRank();
     auto enc = getSparseTensorEncoding(op.getShapedType(t));
+    if (enc) {
+      annotated = true;
+      if (enc.getDimOrdering() && !enc.getDimOrdering().isIdentity())
+        return false; // TODO: handle permutations
+      if (t == lhs)
+        return false; // TODO: handle sparse outputs
+    }
     for (unsigned d = 0; d < rank; d++) {
       unsigned idx = map.getDimPosition(d);
       merger.setDim(t, idx, toDim(enc, d));
     }
   }
+  return annotated;
 }
 
 /// A DFS helper to compute a topological sort. Note that recursion is
@@ -1356,7 +1366,8 @@
     unsigned numTensors = op.getNumShapedOperands();
     unsigned numLoops = op.iterator_types().getValue().size();
     Merger merger(numTensors, numLoops);
-    findSparseAnnotations(merger, op);
+    if (!findSparseAnnotations(merger, op))
+      return failure();
 
     // Computes a topologically sorted iteration graph to ensure
     // tensors are visited in natural index order. Fails on cycles.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
@@ -1,6 +1,8 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // RUN: mlir-opt %s -sparsification | FileCheck %s
 
+#Td = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
+
 #Tddd = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense" ] }>
 #Tdds = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ] }>
 #Tdsd = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ] }>
@@ -1249,7 +1251,7 @@
 
 // CHECK-LABEL:   func @sum_reduction_inv(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?x?xf32>,
-// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:      %[[VAL_2:.*]]: tensor<f32>) -> tensor<f32> {
 // CHECK:           %[[VAL_3:.*]] = constant 2 : index
 // CHECK:           %[[VAL_4:.*]] = constant 0 : index
@@ -1257,8 +1259,8 @@
 // CHECK:           %[[VAL_6:.*]] = memref.dim %[[VAL_0]], %[[VAL_5]] : tensor<?x?x?xf32>
 // CHECK:           %[[VAL_7:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32>
 // CHECK:           %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_0]] : memref<?x?x?xf32>
-// CHECK:           %[[VAL_9:.*]] = memref.dim %[[VAL_1]], %[[VAL_4]] : tensor<?xf32>
-// CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<?xf32>
+// CHECK:           %[[VAL_9:.*]] = memref.dim %[[VAL_1]], %[[VAL_4]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<f32>
 // CHECK:           %[[VAL_12:.*]] = memref.alloc() : memref<f32>
 // CHECK:           linalg.copy(%[[VAL_11]], %[[VAL_12]]) : memref<f32>, memref<f32>
@@ -1279,10 +1281,10 @@
 // CHECK:           return %[[VAL_24]] : tensor<f32>
 // CHECK:         }
 func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
-                        %argb: tensor<?xf32>,
+                        %argb: tensor<?xf32, #Td>,
                         %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_sum_reduction_inv
-    ins(%arga, %argb: tensor<?x?x?xf32>, tensor<?xf32>)
+    ins(%arga, %argb: tensor<?x?x?xf32>, tensor<?xf32, #Td>)
     outs(%argx: tensor<f32>) {
      ^bb(%a: f32, %b: f32, %x: f32):
        %0 = mulf %a, %b : f32
@@ -1304,7 +1306,7 @@
 }
 
 // CHECK-LABEL:   func @invariants(
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xf32>,
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:      %[[VAL_1:.*]]: tensor<20xf32>,
 // CHECK-SAME:      %[[VAL_2:.*]]: tensor<30xf32>,
 // CHECK-SAME:      %[[VAL_3:.*]]: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
@@ -1313,14 +1315,14 @@
 // CHECK:           %[[VAL_6:.*]] = constant 30 : index
 // CHECK:           %[[VAL_7:.*]] = constant 0 : index
 // CHECK:           %[[VAL_8:.*]] = constant 1 : index
-// CHECK:           %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_0]] : memref<10xf32>
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK:           %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<20xf32>
 // CHECK:           %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<30xf32>
 // CHECK:           %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_3]] : memref<10x20x30xf32>
 // CHECK:           %[[VAL_13:.*]] = memref.alloc() : memref<10x20x30xf32>
 // CHECK:           linalg.copy(%[[VAL_12]], %[[VAL_13]]) : memref<10x20x30xf32>, memref<10x20x30xf32>
 // CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
-// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<10xf32>
+// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<?xf32>
 // CHECK:             scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
 // CHECK:               %[[VAL_17:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<20xf32>
 // CHECK:               scf.for %[[VAL_18:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] {
@@ -1334,12 +1336,12 @@
 // CHECK:           %[[VAL_22:.*]] = memref.tensor_load %[[VAL_13]] : memref<10x20x30xf32>
 // CHECK:           return %[[VAL_22]] : tensor<10x20x30xf32>
 // CHECK:         }
-func @invariants(%arga: tensor<10xf32>,
+func @invariants(%arga: tensor<10xf32, #Td>,
                  %argb: tensor<20xf32>,
                  %argc: tensor<30xf32>,
                  %argx: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
   %0 = linalg.generic #trait_invariants
-    ins(%arga, %argb, %argc : tensor<10xf32>, tensor<20xf32>, tensor<30xf32>)
+    ins(%arga, %argb, %argc : tensor<10xf32, #Td>, tensor<20xf32>, tensor<30xf32>)
     outs(%argx: tensor<10x20x30xf32>) {
      ^bb(%a: f32, %b: f32, %c: f32, %x: f32):
        %0 = mulf %a, %b : f32
diff --git a/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir b/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
@@ -9,6 +9,10 @@
 // RUN: mlir-opt %s -sparsification="parallelization-strategy=4" | \
 // RUN:   FileCheck %s --check-prefix=CHECK-PAR4
 
+#DenseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "dense" ]
+}>
+
 #SparseMatrix = #sparse_tensor.encoding<{
   dimLevelType = [ "compressed", "compressed" ]
 }>
@@ -52,9 +56,11 @@
 // CHECK-PAR4:       scf.parallel
 // CHECK-PAR4:       return
 //
-func @scale_dd(%scale: f32, %arga: tensor<?x?xf32>, %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @scale_dd(%scale: f32,
+               %arga: tensor<?x?xf32, #DenseMatrix>,
+               %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic #trait_dd
-    ins(%arga: tensor<?x?xf32>)
+    ins(%arga: tensor<?x?xf32, #DenseMatrix>)
     outs(%argx: tensor<?x?xf32>) {
      ^bb(%a: f32, %x: f32):
        %0 = mulf %a, %scale : f32
@@ -98,7 +104,9 @@
 // CHECK-PAR4:       scf.parallel
 // CHECK-PAR4:       return
 //
-func @scale_ss(%scale: f32, %arga: tensor<?x?xf32, #SparseMatrix>, %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @scale_ss(%scale: f32,
+               %arga: tensor<?x?xf32, #SparseMatrix>,
+               %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic #trait_ss
     ins(%arga: tensor<?x?xf32, #SparseMatrix>)
     outs(%argx: tensor<?x?xf32>) {
@@ -145,9 +153,11 @@
 // CHECK-PAR4:       scf.for
 // CHECK-PAR4:       return
 //
-func @matvec(%argA: tensor<16x32xf32, #CSR>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> {
+func @matvec(%arga: tensor<16x32xf32, #CSR>,
+             %argb: tensor<32xf32>,
+             %argx: tensor<16xf32>) -> tensor<16xf32> {
   %0 = linalg.generic #trait_matvec
-    ins(%argA, %argb : tensor<16x32xf32, #CSR>, tensor<32xf32>)
+    ins(%arga, %argb : tensor<16x32xf32, #CSR>, tensor<32xf32>)
     outs(%argx: tensor<16xf32>) {
     ^bb(%A: f32, %b: f32, %x: f32):
       %0 = mulf %A, %b : f32
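---

Note on the behavioral change, with an illustrative kernel that the updated pass now rejects. The snippet below is not part of the patch; the names #PermutedMatrix, #trait_scale, and @scale_p are hypothetical, and the encoding syntax is assumed to match the sparse_tensor encoding attribute used in the tests above. Because the encoding carries a non-identity dimOrdering, findSparseAnnotations returns false (the "TODO: handle permutations" path), matchAndRewrite returns failure(), and the linalg.generic is left untouched. An encoding on the outs() operand is rejected the same way, since lhs = numTensors - 1 indexes the output tensor ("TODO: handle sparse outputs").

  #PermutedMatrix = #sparse_tensor.encoding<{
    dimLevelType = [ "compressed", "compressed" ],
    dimOrdering = affine_map<(i,j) -> (j,i)>  // non-identity ordering
  }>

  #trait_scale = {
    indexing_maps = [
      affine_map<(i,j) -> (i,j)>,  // A
      affine_map<(i,j) -> (i,j)>   // X (out)
    ],
    iterator_types = ["parallel", "parallel"]
  }

  // Rejected by the updated findSparseAnnotations: non-identity dimOrdering.
  func @scale_p(%scale: f32,
                %arga: tensor<?x?xf32, #PermutedMatrix>,
                %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %0 = linalg.generic #trait_scale
      ins(%arga: tensor<?x?xf32, #PermutedMatrix>)
      outs(%argx: tensor<?x?xf32>) {
       ^bb(%a: f32, %x: f32):
         %0 = mulf %a, %scale : f32
         linalg.yield %0 : f32
    } -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }

Conversely, a kernel with no sparse encodings at all now also returns failure, because annotated stays false. That is why the test changes add the all-dense encodings #Td and #DenseMatrix: they keep the dense test kernels "annotated" so the sparsification pass still rewrites them.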