diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h @@ -171,27 +171,26 @@ switch (dlt) { // TODO: should probably raise an error instead of printing it... case DimLevelType::Undef: - return "\"undef\""; + return "undef"; case DimLevelType::Dense: - return "\"dense\""; + return "dense"; case DimLevelType::Compressed: - return "\"compressed\""; + return "compressed"; case DimLevelType::CompressedNu: - return "\"compressed-nu\""; + return "compressed-nu"; case DimLevelType::CompressedNo: - return "\"compressed-no\""; + return "compressed-no"; case DimLevelType::CompressedNuNo: - return "\"compressed-nu-no\""; + return "compressed-nu-no"; case DimLevelType::Singleton: - return "\"singleton\""; + return "singleton"; case DimLevelType::SingletonNu: - return "\"singleton-nu\""; + return "singleton-nu"; case DimLevelType::SingletonNo: - return "\"singleton-no\""; + return "singleton-no"; case DimLevelType::SingletonNuNo: - return "\"singleton-nu-no\""; + return "singleton-nu-no"; } - return ""; } /// Check that the `DimLevelType` contains a valid (possibly undefined) value. diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -174,7 +174,7 @@ // Print the struct-like storage in dictionary fashion. 
printer << "<{ dimLevelType = [ "; for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { - printer << toMLIRString(getDimLevelType()[i]); + printer << "\"" << toMLIRString(getDimLevelType()[i]) << "\""; if (i != e - 1) printer << ", "; } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -487,7 +488,7 @@ StringRef namePrefix, FuncGeneratorType createFunc) { // The mangled name of the function has this format: - // <namePrefix>_[C|S|D]_<shape>_<ordering>_<eltType> + // <namePrefix>_<DLT>_<shape>_<ordering>_<eltType> // _<indexBitWidth>_<pointerBitWidth> RankedTensorType rtp = desc.getTensorType(); SmallString<32> nameBuffer; @@ -496,13 +497,7 @@ unsigned rank = rtp.getShape().size(); assert(rank == indices.size()); for (unsigned d = 0; d < rank; d++) { - if (isCompressedDim(rtp, d)) { - nameOstream << "C_"; - } else if (isSingletonDim(rtp, d)) { - nameOstream << "S_"; - } else { - nameOstream << "D_"; - } + nameOstream << toMLIRString(getDimLevelType(rtp, d)) << "_"; } // Static dim sizes are used in the generated code while dynamic sizes are // loaded from the dimSizes buffer.
This is the reason for adding the shape diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir --- a/mlir/test/Dialect/SparseTensor/codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s +// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize -cse | FileCheck %s #SV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> @@ -46,6 +46,10 @@ dimOrdering = affine_map<(i, j, k) -> (k, i, j)> }> +#Coo = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed-nu", "singleton" ] +}> + // CHECK-LABEL: func @sparse_nop( // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, @@ -381,7 +385,7 @@ return %added : memref } -// CHECK-LABEL: func.func private @_insert_C_100_f64_0_0( +// CHECK-LABEL: func.func private @_insert_compressed_100_f64_0_0( // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, // CHECK-SAME: %[[A2:.*2]]: memref, @@ -411,7 +415,7 @@ // CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<1xindex>, memref<3xindex>, memref, memref, memref) { // CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref // CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref -// CHECK: %[[C:.*]]:5 = func.call @_insert_C_100_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[INDEX]], %[[VAL]]) +// CHECK: %[[C:.*]]:5 = func.call @_insert_compressed_100_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[INDEX]], %[[VAL]]) // CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref // CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref // CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref @@ -432,7 +436,7 @@ return %1 : tensor<100xf64, #SV> } -// CHECK-LABEL: 
func.func private @_insert_D_C_8_8_f64_64_32( +// CHECK-LABEL: func.func private @_insert_dense_compressed_8_8_f64_64_32( // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, // CHECK-SAME: %[[A2:.*2]]: memref, @@ -464,7 +468,7 @@ // CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { // CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref // CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref -// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_64_32(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]]) +// CHECK: %[[C:.*]]:5 = func.call @_insert_dense_compressed_8_8_f64_64_32(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]]) // CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref // CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref // CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref @@ -486,7 +490,7 @@ return %1 : tensor<8x8xf64, #CSR> } -// CHECK-LABEL: func.func private @_insert_D_C_8_8_f64_0_0( +// CHECK-LABEL: func.func private @"_insert_dense_compressed-no_8_8_f64_0_0"( // CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, // CHECK-SAME: %[[A2:.*2]]: memref, @@ -518,7 +522,7 @@ // CHECK-SAME: iter_args(%[[P0:.*]] = %[[A0]], %[[P1:.*]] = %[[A1]], %[[P2:.*]] = %[[A2]], %[[P3:.*]] = %[[A3]], %[[P4:.*]] = %[[A4]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { // CHECK: %[[INDEX:.*]] = memref.load %[[A7]][%[[I]]] : memref // CHECK: %[[VAL:.*]] = memref.load %[[A5]][%[[INDEX]]] : memref -// CHECK: %[[C:.*]]:5 = func.call @_insert_D_C_8_8_f64_0_0(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]]) +// CHECK: %[[C:.*]]:5 = func.call 
@"_insert_dense_compressed-no_8_8_f64_0_0"(%[[P0]], %[[P1]], %[[P2]], %[[P3]], %[[P4]], %[[A9]], %[[INDEX]], %[[VAL]]) // CHECK: memref.store %[[F0]], %[[A5]][%[[INDEX]]] : memref // CHECK: memref.store %[[B0]], %[[A6]][%[[INDEX]]] : memref // CHECK: scf.yield %[[C]]#0, %[[C]]#1, %[[C]]#2, %[[C]]#3, %[[C]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref @@ -540,7 +544,7 @@ return %1 : tensor<8x8xf64, #UCSR> } -// CHECK-LABEL: func.func private @_insert_C_128_f64_0_0( +// CHECK-LABEL: func.func private @_insert_compressed_128_f64_0_0( // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, // CHECK-SAME: %[[A2:.*2]]: memref, @@ -558,7 +562,7 @@ // CHECK-SAME: %[[A4:.*4]]: memref, // CHECK-SAME: %[[A5:.*5]]: index, // CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) +// CHECK: %[[R:.*]]:5 = call @_insert_compressed_128_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) // CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 // CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> { @@ -567,7 +571,7 @@ return %1 : tensor<128xf64, #SV> } -// CHECK-LABEL: func.func private @_insert_C_128_f64_64_32( +// CHECK-LABEL: func.func private @_insert_compressed_128_f64_64_32( // CHECK-SAME: %[[A0:.*0]]: memref<1xindex>, // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, // CHECK-SAME: %[[A2:.*2]]: memref, @@ -585,7 +589,7 @@ // CHECK-SAME: %[[A4:.*4]]: memref, // CHECK-SAME: %[[A5:.*5]]: index, // CHECK-SAME: %[[A6:.*6]]: f64) -// CHECK: %[[R:.*]]:5 = call @_insert_C_128_f64_64_32(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) +// CHECK: %[[R:.*]]:5 = call @_insert_compressed_128_f64_64_32(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]]) // CHECK: return %[[R]]#0, 
%[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4 // CHECK-SAME: memref<1xindex>, memref<3xindex>, memref, memref, memref func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> { @@ -594,6 +598,37 @@ return %1 : tensor<128xf64, #SparseVector> } +// CHECK-LABEL: func.func private @"_insert_compressed-nu_singleton_5_6_f64_0_0"( +// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref<4xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: index, +// CHECK-SAME: %[[A7:.*7]]: index, +// CHECK-SAME: %[[A8:.*8]]: f64) +// CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[A1]], %[[A3]], %[[A6]] +// CHECK: %[[P1:.*]] = sparse_tensor.push_back %[[A1]], %[[A4]], %[[A7]] +// CHECK: %[[P2:.*]] = sparse_tensor.push_back %[[A1]], %[[A5]], %[[A8]] +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[P0]], %[[P1]], %[[P2]] +// CHECK: func.func @sparse_insert_coo( +// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>, +// CHECK-SAME: %[[A1:.*1]]: memref<4xindex>, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: memref, +// CHECK-SAME: %[[A6:.*6]]: index, +// CHECK-SAME: %[[A7:.*7]]: f64) +// CHECK: %[[R:.*]]:6 = call @"_insert_compressed-nu_singleton_5_6_f64_0_0"(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A6]], %[[A7]]) +// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3, %[[R]]#4, %[[R]]#5 +func.func @sparse_insert_coo(%arg0: tensor<5x6xf64, #Coo>, %arg1: index, %arg2: f64) -> tensor<5x6xf64, #Coo> { + %0 = sparse_tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo> + %1 = sparse_tensor.load %0 hasInserts : tensor<5x6xf64, #Coo> + return %1 : tensor<5x6xf64, #Coo> +} + // CHECK-LABEL: func.func @sparse_nop_convert( // CHECK-SAME: %[[A0:.*0]]: 
memref<1xindex>, // CHECK-SAME: %[[A1:.*1]]: memref<3xindex>, diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir @@ -12,7 +12,7 @@ // // Computes C = A x B with all matrices sparse (SpMSpM) in CSR. // -// CHECK-LABEL: func.func private @_insert_D_C_4_4_f64_0_0( +// CHECK-LABEL: func.func private @_insert_dense_compressed_4_4_f64_0_0( // CHECK-SAME: %[[VAL_0:.*]]: memref<2xindex>, // CHECK-SAME: %[[VAL_1:.*]]: memref<3xindex>, // CHECK-SAME: %[[VAL_2:[^ ]+]]: memref, @@ -118,7 +118,7 @@ // CHECK: %[[VAL_63:.*]]:5 = scf.for %[[VAL_64:.*]] = %[[VAL_12]] to %[[VAL_62]] step %[[VAL_13]] iter_args(%[[VAL_65:.*]] = %[[VAL_32]], %[[VAL_66:.*]] = %[[VAL_33]], %[[VAL_67:.*]] = %[[VAL_34]], %[[VAL_68:.*]] = %[[VAL_35]], %[[VAL_69:.*]] = %[[VAL_36]]) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) { // CHECK: %[[VAL_70:.*]] = memref.load %[[VAL_28]]{{\[}}%[[VAL_64]]] : memref<4xindex> // CHECK: %[[VAL_71:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64> -// CHECK: %[[VAL_72:.*]]:5 = func.call @_insert_D_C_4_4_f64_0_0(%[[VAL_65]], %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_69]], %[[VAL_31]], %[[VAL_70]], %[[VAL_71]]) : (memref<2xindex>, memref<3xindex>, memref, memref, memref, index, index, f64) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) +// CHECK: %[[VAL_72:.*]]:5 = func.call @_insert_dense_compressed_4_4_f64_0_0(%[[VAL_65]], %[[VAL_66]], %[[VAL_67]], %[[VAL_68]], %[[VAL_69]], %[[VAL_31]], %[[VAL_70]], %[[VAL_71]]) : (memref<2xindex>, memref<3xindex>, memref, memref, memref, index, index, f64) -> (memref<2xindex>, memref<3xindex>, memref, memref, memref) // CHECK: memref.store %[[VAL_11]], %[[VAL_26]]{{\[}}%[[VAL_70]]] : memref<4xf64> // CHECK: memref.store %[[VAL_14]], %[[VAL_27]]{{\[}}%[[VAL_70]]] : memref<4xi1> // 
CHECK: scf.yield %[[VAL_72]]#0, %[[VAL_72]]#1, %[[VAL_72]]#2, %[[VAL_72]]#3, %[[VAL_72]]#4 : memref<2xindex>, memref<3xindex>, memref, memref, memref