diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -47,12 +47,29 @@
 /// Iteration graph sorting.
 enum SortMask {
-  kSparseOnly = 0x0,
-  kIncludeDense = 0x1,
-  kIncludeUndef = 0x2,
-  kIncludeAll = 0x3
+  // The individual mask bits.
+  kIncludeDenseOutput = 0x1, // b001
+  kIncludeDenseInput = 0x2,  // b010
+  kIncludeUndef = 0x4,       // b100
+  // The subsets of mask bits.
+  kIncludeAll = 0x7,   // b111
+  kIncludeDense = 0x3, // b011
+  kSparseOnly = 0x0,   // b000
 };
 
+/// SortMask tests on individual bits.
+inline static bool includeDenseInput(unsigned mask) {
+  return mask & SortMask::kIncludeDenseInput;
+}
+
+inline static bool includeDenseOutput(unsigned mask) {
+  return mask & SortMask::kIncludeDenseOutput;
+}
+
+inline static bool includeUndef(unsigned mask) {
+  return mask & SortMask::kIncludeUndef;
+}
+
 /// A helper class that visits an affine expression and tries to find an
 /// AffineDimExpr to which the corresponding iterator from a GenericOp matches
 /// the desired iterator type.
@@ -453,9 +470,35 @@
     const auto map = env.op().getMatchingIndexingMap(&t);
     const auto enc = getSparseTensorEncoding(t.get().getType());
     assert(map.getNumDims() + getNumCompoundAffineOnSparseDims(env.op()) == n);
-    // Skip dense tensor constraints when not requested.
-    if (!(mask & SortMask::kIncludeDense) && !enc)
+
+    bool isDenseInput = !enc && env.op().isDpsInput(&t);
+    bool isDenseOutput = !enc && !isDenseInput;
+
+    // Skips dense inputs/outputs when not requested.
+    if ((isDenseInput && !includeDenseInput(mask)) ||
+        (isDenseOutput && !includeDenseOutput(mask)))
       continue;
+
+    // Push unrelated loops into sparse iteration space, so these
+    // will be skipped more often.
+    // TODO: Do we really need this?
+    if (includeUndef(mask)) {
+      unsigned tensor = t.getOperandNumber();
+      for (unsigned i = 0; i < n; i++) {
+        if (isCompressedDLT(env.dlt(tensor, i)) ||
+            isSingletonDLT(env.dlt(tensor, i))) {
+          for (unsigned j = 0; j < n; j++)
+            if (isUndefDLT(env.dlt(tensor, j))) {
+              adjM[i][j] = true;
+              inDegree[j]++;
+            }
+        } else {
+          assert(isDenseDLT(env.dlt(tensor, i)) ||
+                 isUndefDLT(env.dlt(tensor, i)));
+        }
+      }
+    }
+
     // Each tensor expression and optional dimension ordering (row-major
     // by default) puts an ordering constraint on the loop indices. For
     // example, the tensor expression A_ijk forces the ordering i < j < k
@@ -508,24 +551,6 @@
         addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx);
       }
     }
-    // Push unrelated loops into sparse iteration space, so these
-    // will be skipped more often.
-    if (mask & SortMask::kIncludeUndef) {
-      unsigned tensor = t.getOperandNumber();
-      for (unsigned i = 0; i < n; i++) {
-        if (isCompressedDLT(env.dlt(tensor, i)) ||
-            isSingletonDLT(env.dlt(tensor, i))) {
-          for (unsigned j = 0; j < n; j++)
-            if (isUndefDLT(env.dlt(tensor, j))) {
-              adjM[i][j] = true;
-              inDegree[j]++;
-            }
-        } else {
-          assert(isDenseDLT(env.dlt(tensor, i)) ||
-                 isUndefDLT(env.dlt(tensor, i)));
-        }
-      }
-    }
   }
   // Topologically sort the iteration graph to determine loop order.
   // Report failure for a cyclic iteration graph.
@@ -1532,8 +1557,12 @@
     // A const list of all masks that we used for iteration graph
    // computation. Must be ordered from more strict to less strict.
-    const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef,
-                          SortMask::kIncludeDense, SortMask::kSparseOnly};
+    // Ideally (though might not be guaranteed), the earlier a constraint mask
+    // can be satisfied, the faster the generated kernel will be.
+    const auto allMask = {
+        SortMask::kIncludeAll, SortMask::kIncludeDense,
+        SortMask::kIncludeDenseInput, SortMask::kIncludeDenseOutput,
+        SortMask::kIncludeUndef, SortMask::kSparseOnly};
     for (auto mask : allMask) {
       if (computeIterationGraph(env, mask)) {
         hasCycle = false;
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1002,46 +1002,45 @@
   doc = "X(i,j) += S(i,j) SUM_k A(i,k) B(k,j)"
 }
-// CHECK-LABEL: func @sampled_dense_dense(
+// CHECK-LABEL: func.func @sampled_dense_dense(
 // CHECK-SAME: %[[VAL_0:.*0]]: tensor>,
 // CHECK-SAME: %[[VAL_1:.*1]]: tensor,
 // CHECK-SAME: %[[VAL_2:.*2]]: tensor,
 // CHECK-SAME: %[[VAL_3:.*3]]: tensor) -> tensor {
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor> to memref
 // CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor> to memref
 // CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor> to memref
 // CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor> to memref
 // CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor> to memref
-// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref
-// CHECK-DAG: %[[VAL_12:.*]] = tensor.dim %[[VAL_1]], %[[VAL_5]] : tensor
+// CHECK-DAG: %[[VAL_11:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor
+// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref
 // CHECK-DAG: %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref
-// CHECK-DAG: %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_3]] : memref
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref
-// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] {
-// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref
-// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_20]], %[[VAL_5]] : index
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref
-// CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] {
-// CHECK-DAG: %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_25]]] : memref
-// CHECK-DAG: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref
-// CHECK-DAG: %[[VAL_28:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref
-// CHECK: %[[VAL_29:.*]] = scf.for %[[VAL_30:.*]] = %[[VAL_4]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_31:.*]] = %[[VAL_28]]) -> (f32) {
-// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref
-// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_30]], %[[VAL_26]]] : memref
-// CHECK: %[[VAL_34:.*]] = arith.mulf %[[VAL_32]], %[[VAL_33]] : f32
-// CHECK: %[[VAL_35:.*]] = arith.mulf %[[VAL_27]], %[[VAL_34]] : f32
-// CHECK: %[[VAL_36:.*]] = arith.addf %[[VAL_31]], %[[VAL_35]] : f32
-// CHECK: scf.yield %[[VAL_36]] : f32
-// CHECK: }
-// CHECK: memref.store %[[VAL_29]], %[[VAL_17]]{{\[}}%[[VAL_21]], %[[VAL_26]]] : memref
-// CHECK: }
-// CHECK: }
-// CHECK: %[[VAL_38:.*]] = bufferization.to_tensor %[[VAL_17]] : memref
-// CHECK: return %[[VAL_38]] : tensor
+// CHECK-DAG: %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_3]] : memref
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref
+// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_4]] {
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref
+// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_11]] step %[[VAL_4]] {
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_18]], %[[VAL_19]]] : memref
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_17]], %[[VAL_4]] : index
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_22]]] : memref
+// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_4]] {
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_24]]] : memref
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_25]]] : memref
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_19]], %[[VAL_25]]] : memref
+// CHECK: %[[VAL_29:.*]] = arith.mulf %[[VAL_20]], %[[VAL_28]] : f32
+// CHECK: %[[VAL_30:.*]] = arith.mulf %[[VAL_27]], %[[VAL_29]] : f32
+// CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_26]], %[[VAL_30]] : f32
+// CHECK: memref.store %[[VAL_31]], %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_25]]] : memref
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_32:.*]] = bufferization.to_tensor %[[VAL_14]] : memref
+// CHECK: return %[[VAL_32]] : tensor
 // CHECK: }
 func.func @sampled_dense_dense(%args: tensor,
                                %arga: tensor,
diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir
@@ -51,6 +51,53 @@
   return %0 : tensor<10x30xf32>
 }
 
+// CHECK-LABEL: func.func @matmul_sparse_rhs(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<20x30xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<10x30xf32>) -> tensor<10x30xf32> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<10x20xf32>
+// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index}
+// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index}
+// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index}
+// CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index}
+// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]]
+// CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<10x30xf32>
+// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] {
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]], %[[VAL_17]]] : memref<10x20xf32>
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref
+// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_16]], %[[VAL_5]] : index
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref
+// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_19]] to %[[VAL_21]] step %[[VAL_5]] {
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_22]]] : memref
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_23]]] : memref<10x30xf32>
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref
+// CHECK: %[[VAL_26:.*]] = arith.mulf %[[VAL_18]], %[[VAL_25]] : f32
+// CHECK: %[[VAL_27:.*]] = arith.addf %[[VAL_24]], %[[VAL_26]] : f32
+// CHECK: memref.store %[[VAL_27]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_23]]] : memref<10x30xf32>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_28:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<10x30xf32>
+// CHECK: return %[[VAL_28]] : tensor<10x30xf32>
+// CHECK: }
+// IMPORTANT! Ensures that dense inputs are visited in row-major order.
+func.func @matmul_sparse_rhs(%a: tensor<10x20xf32>,
+                             %b: tensor<20x30xf32, #DCSR>,
+                             %c: tensor<10x30xf32>) -> tensor<10x30xf32> {
+  %0 = linalg.matmul
+    ins(%a, %b: tensor<10x20xf32>, tensor<20x30xf32,#DCSR>)
+    outs(%c: tensor<10x30xf32>) -> tensor<10x30xf32>
+  return %0 : tensor<10x30xf32>
+}
+
+
 //
 // Computes C = A x B with all matrices sparse (SpMSpM) in DCSR.
 //
@@ -162,31 +209,32 @@
 // CHECK-DAG: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<3x3xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref
 // CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x3xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref
 // CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<6x6xi32>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref
-// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_5]] {
-// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref
-// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_5]] : index
-// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref
-// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_5]] {
-// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref
-// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]], %[[VAL_23]]] : memref<6x6xi32>
-// CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_24]], %[[VAL_16]] : index
-// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_23]], %[[VAL_21]] : index
-// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_26]], %[[VAL_27]]] : memref<8x8xi32>
-// CHECK: %[[VAL_29:.*]] = arith.muli %[[VAL_28]], %[[VAL_22]] : i32
-// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_25]], %[[VAL_29]] : i32
-// CHECK: memref.store %[[VAL_30]], %[[VAL_12]]{{\[}}%[[VAL_24]], %[[VAL_23]]] : memref<6x6xi32>
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: }
-// CHECK: %[[VAL_31:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x6xi32>
-// CHECK: return %[[VAL_31]] : tensor<6x6xi32>
+// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref
+// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] {
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref
+// CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_18]]] : memref<6x6xi32>
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref
+// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_16]], %[[VAL_5]] : index
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref
+// CHECK: %[[VAL_23:.*]] = scf.for %[[VAL_24:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] iter_args(%[[VAL_25:.*]] = %[[VAL_19]]) -> (i32) {
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref
+// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_13]], %[[VAL_17]] : index
+// CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_18]], %[[VAL_26]] : index
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_27]], %[[VAL_28]]] : memref<8x8xi32>
+// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_24]]] : memref
+// CHECK: %[[VAL_31:.*]] = arith.muli %[[VAL_29]], %[[VAL_30]] : i32
+// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_25]], %[[VAL_31]] : i32
+// CHECK: scf.yield %[[VAL_32]] : i32
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: memref.store %[[VAL_33:.*]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_18]]] : memref<6x6xi32>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<6x6xi32>
+// CHECK: return %[[VAL_34]] : tensor<6x6xi32>
 // CHECK: }
 func.func @conv2d(%input: tensor<8x8xi32>,
                   %filter: tensor<3x3xi32, #DCSR>,
@@ -212,28 +260,28 @@
 // CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<3x6xi8, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref
 // CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<3x6xi8, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> to memref
 // CHECK: %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : memref<5x6xi64>
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref
-// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref
-// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] {
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref
-// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_16]], %[[VAL_5]] : index
-// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref
-// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_5]] {
-// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_21]]] : memref
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref
-// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
-// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]], %[[VAL_22]]] : memref<5x6xi64>
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_24]], %[[VAL_17]]] : memref<5x3xi8>
-// CHECK: %[[VAL_27:.*]] = arith.extsi %[[VAL_26]] : i8 to i64
-// CHECK: %[[VAL_28:.*]] = arith.subi %[[VAL_27]], %[[VAL_6]] : i64
-// CHECK: %[[VAL_29:.*]] = arith.extsi %[[VAL_23]] : i8 to i64
-// CHECK: %[[VAL_30:.*]] = arith.muli %[[VAL_28]], %[[VAL_29]] : i64
+// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref
+// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_5]] {
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_14]], %[[VAL_18]]] : memref<5x3xi8>
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_17]]] : memref
+// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_17]], %[[VAL_5]] : index
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref
+// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] {
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]]] : memref
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_24]]] : memref<5x6xi64>
+// CHECK: %[[VAL_26:.*]] = arith.extsi %[[VAL_19]] : i8 to i64
+// CHECK: %[[VAL_27:.*]] = arith.subi %[[VAL_26]], %[[VAL_6]] : i64
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref
+// CHECK: %[[VAL_29:.*]] = arith.extsi %[[VAL_28]] : i8 to i64
+// CHECK: %[[VAL_30:.*]] = arith.muli %[[VAL_27]], %[[VAL_29]] : i64
 // CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_25]], %[[VAL_30]] : i64
-// CHECK: memref.store %[[VAL_31]], %[[VAL_13]]{{\[}}%[[VAL_24]], %[[VAL_22]]] : memref<5x6xi64>
-// CHECK: }
-// CHECK: }
-// CHECK: }
+// CHECK: memref.store %[[VAL_31]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_24]]] : memref<5x6xi64>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
 // CHECK: %[[VAL_32:.*]] = bufferization.to_tensor %[[VAL_13]] : memref<5x6xi64>
 // CHECK: return %[[VAL_32]] : tensor<5x6xi64>
 // CHECK: }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -56,49 +56,48 @@
   return %1 : tensor<32xf64>
 }
 
-// CHECK-LABEL: func.func @sampled_dd_unfused(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
-// CHECK: %[[VAL_7:.*]] = bufferization.alloc_tensor() copy(%[[VAL_6]]) {bufferization.escape = [false]} : tensor<8x8xf64>
-// CHECK: %[[VAL_8:.*]] = bufferization.alloc_tensor() copy(%[[VAL_6]]) {bufferization.escape = [false], memory_space = 0 : i64} : tensor<8x8xf64>
-// CHECK: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
-// CHECK: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
-// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
-// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
-// CHECK: %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
-// CHECK: %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
-// CHECK: %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
-// CHECK: %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_8]] : memref<8x8xf64>
-// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_4]]] : memref
-// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref
-// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_5]] {
-// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_19]]] : memref
-// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_19]]] : memref
-// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_22]]] : memref
-// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_5]] {
-// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]]] : memref
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<8x8xf64>
-// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_24]]] : memref
-// CHECK: %[[VAL_28:.*]] = scf.for %[[VAL_29:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_30:.*]] = %[[VAL_26]]) -> (f64) {
-// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]], %[[VAL_29]]] : memref<8x8xf64>
-// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_29]], %[[VAL_25]]] : memref<8x8xf64>
-// CHECK: %[[VAL_33:.*]] = arith.mulf %[[VAL_31]], %[[VAL_32]] : f64
-// CHECK: %[[VAL_34:.*]] = arith.mulf %[[VAL_33]], %[[VAL_27]] : f64
-// CHECK: %[[VAL_35:.*]] = arith.addf %[[VAL_30]], %[[VAL_34]] : f64
-// CHECK: scf.yield %[[VAL_35]] : f64
-// CHECK: }
-// CHECK: memref.store %[[VAL_24:.*]], %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<8x8xf64>
-// CHECK: }
+// CHECK-LABEL: func.func @sampled_dd_unfused(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
+// CHECK-DAG: %[[VAL_7:.*]] = bufferization.alloc_tensor() copy(%[[VAL_6]]) {bufferization.escape = [false]} : tensor<8x8xf64>
+// CHECK-DAG: %[[VAL_8:.*]] = bufferization.alloc_tensor() copy(%[[VAL_6]]) {bufferization.escape = [false], memory_space = 0 : i64} : tensor<8x8xf64>
+// CHECK-DAG: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
+// CHECK-DAG: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
+// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK-DAG: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK: %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_8]] : memref<8x8xf64>
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_4]]] : memref
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref
+// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_5]] {
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_19]]] : memref
+// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]], %[[VAL_21]]] : memref<8x8xf64>
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_19]]] : memref
+// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref
+// CHECK: scf.for %[[VAL_26:.*]] = %[[VAL_23]] to %[[VAL_25]] step %[[VAL_5]] {
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_26]]] : memref
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_27]]] : memref<8x8xf64>
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]], %[[VAL_27]]] : memref<8x8xf64>
+// CHECK: %[[VAL_30:.*]] = arith.mulf %[[VAL_22]], %[[VAL_29]] : f64
+// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_26]]] : memref
+// CHECK: %[[VAL_32:.*]] = arith.mulf %[[VAL_30]], %[[VAL_31]] : f64
+// CHECK: %[[VAL_33:.*]] = arith.addf %[[VAL_28]], %[[VAL_32]] : f64
+// CHECK: memref.store %[[VAL_33]], %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_27]]] : memref<8x8xf64>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<8x8xf64>
+// CHECK: return %[[VAL_34]] : tensor<8x8xf64>
 // CHECK: }
-// CHECK: %[[VAL_37:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<8x8xf64>
-// CHECK: return %[[VAL_37]] : tensor<8x8xf64>
-// CHECK: }
 func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
                               %arga: tensor<8x8xf64>,
                               %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {