diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -26,15 +26,15 @@ /// If updating, keep them in sync and update the static_assert in the impl /// file. enum MlirSparseTensorDimLevelType { - MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE, - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED, - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU, - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO, - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO, - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON, - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU, - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO, - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO, + MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b001_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b010_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b010_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b010_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b010_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b100_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b100_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b100_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b100_11 }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -168,10 +168,16 @@ // // TODO: separate type and property in encoding // - enum class DimLevelType { - Dense, - Compressed, CompressedNu, CompressedNo, CompressedNuNo, - Singleton, SingletonNu, SingletonNo, SingletonNuNo, + enum class DimLevelType : uint8_t { + Dense = 4, // 0b001_00 + Compressed = 8, // 0b010_00 + CompressedNu = 9, // 0b010_01 + CompressedNo = 10, // 0b010_10 + CompressedNuNo = 11, // 0b010_11 + Singleton = 16, // 0b100_00 + SingletonNu = 17, // 0b100_01 + SingletonNo = 18, // 0b100_10 + SingletonNuNo = 19, // 0b100_11 }; }]; } diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h @@ -146,15 +146,15 @@ /// breaking dependency cycles. `SparseTensorEncodingAttr::DimLevelType` /// is the source of truth and this enum should be kept consistent with it. enum class MLIR_SPARSETENSOR_EXPORT DimLevelType : uint8_t { - kDense = 0, - kCompressed = 1, - kCompressedNu = 2, - kCompressedNo = 3, - kCompressedNuNo = 4, - kSingleton = 5, - kSingletonNu = 6, - kSingletonNo = 7, - kSingletonNuNo = 8, + kDense = 4, // 0b001_00 + kCompressed = 8, // 0b010_00 + kCompressedNu = 9, // 0b010_01 + kCompressedNo = 10, // 0b010_10 + kCompressedNuNo = 11, // 0b010_11 + kSingleton = 16, // 0b100_00 + kSingletonNu = 17, // 0b100_01 + kSingletonNo = 18, // 0b100_10 + kSingletonNuNo = 19, // 0b100_11 }; /// Check if the `DimLevelType` is dense. @@ -164,56 +164,71 @@ /// Check if the `DimLevelType` is compressed (regardless of properties). constexpr MLIR_SPARSETENSOR_EXPORT bool isCompressedDLT(DimLevelType dlt) { - switch (dlt) { - case DimLevelType::kCompressed: - case DimLevelType::kCompressedNu: - case DimLevelType::kCompressedNo: - case DimLevelType::kCompressedNuNo: - return true; - default: - return false; - } + return static_cast(dlt) & + static_cast(DimLevelType::kCompressed); } /// Check if the `DimLevelType` is singleton (regardless of properties). constexpr MLIR_SPARSETENSOR_EXPORT bool isSingletonDLT(DimLevelType dlt) { - switch (dlt) { - case DimLevelType::kSingleton: - case DimLevelType::kSingletonNu: - case DimLevelType::kSingletonNo: - case DimLevelType::kSingletonNuNo: - return true; - default: - return false; - } + return static_cast(dlt) & + static_cast(DimLevelType::kSingleton); } /// Check if the `DimLevelType` is ordered (regardless of storage format). constexpr MLIR_SPARSETENSOR_EXPORT bool isOrderedDLT(DimLevelType dlt) { - switch (dlt) { - case DimLevelType::kCompressedNo: - case DimLevelType::kCompressedNuNo: - case DimLevelType::kSingletonNo: - case DimLevelType::kSingletonNuNo: - return false; - default: - return true; - } + return !(static_cast(dlt) & 2); } /// Check if the `DimLevelType` is unique (regardless of storage format). constexpr MLIR_SPARSETENSOR_EXPORT bool isUniqueDLT(DimLevelType dlt) { - switch (dlt) { - case DimLevelType::kCompressedNu: - case DimLevelType::kCompressedNuNo: - case DimLevelType::kSingletonNu: - case DimLevelType::kSingletonNuNo: - return false; - default: - return true; - } + return !(static_cast(dlt) & 1); } +// Ensure the above predicates work as intended. +static_assert((!isCompressedDLT(DimLevelType::kDense) && + isCompressedDLT(DimLevelType::kCompressed) && + isCompressedDLT(DimLevelType::kCompressedNu) && + isCompressedDLT(DimLevelType::kCompressedNo) && + isCompressedDLT(DimLevelType::kCompressedNuNo) && + !isCompressedDLT(DimLevelType::kSingleton) && + !isCompressedDLT(DimLevelType::kSingletonNu) && + !isCompressedDLT(DimLevelType::kSingletonNo) && + !isCompressedDLT(DimLevelType::kSingletonNuNo)), + "isCompressedDLT definition is broken"); + +static_assert((!isSingletonDLT(DimLevelType::kDense) && + !isSingletonDLT(DimLevelType::kCompressed) && + !isSingletonDLT(DimLevelType::kCompressedNu) && + !isSingletonDLT(DimLevelType::kCompressedNo) && + !isSingletonDLT(DimLevelType::kCompressedNuNo) && + isSingletonDLT(DimLevelType::kSingleton) && + isSingletonDLT(DimLevelType::kSingletonNu) && + isSingletonDLT(DimLevelType::kSingletonNo) && + isSingletonDLT(DimLevelType::kSingletonNuNo)), + "isSingletonDLT definition is broken"); + +static_assert((isOrderedDLT(DimLevelType::kDense) && + isOrderedDLT(DimLevelType::kCompressed) && + isOrderedDLT(DimLevelType::kCompressedNu) && + !isOrderedDLT(DimLevelType::kCompressedNo) && + !isOrderedDLT(DimLevelType::kCompressedNuNo) && + isOrderedDLT(DimLevelType::kSingleton) && + isOrderedDLT(DimLevelType::kSingletonNu) && + !isOrderedDLT(DimLevelType::kSingletonNo) && + !isOrderedDLT(DimLevelType::kSingletonNuNo)), + "isOrderedDLT definition is broken"); + +static_assert((isUniqueDLT(DimLevelType::kDense) && + isUniqueDLT(DimLevelType::kCompressed) && + !isUniqueDLT(DimLevelType::kCompressedNu) && + isUniqueDLT(DimLevelType::kCompressedNo) && + !isUniqueDLT(DimLevelType::kCompressedNuNo) && + isUniqueDLT(DimLevelType::kSingleton) && + !isUniqueDLT(DimLevelType::kSingletonNu) && + isUniqueDLT(DimLevelType::kSingletonNo) && + !isUniqueDLT(DimLevelType::kSingletonNuNo)), + "isUniqueDLT definition is broken"); + } // namespace sparse_tensor } // namespace mlir diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h --- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h @@ -634,7 +634,10 @@ "Value position is out of bounds"); // TODO: yield(this->cursor, src.values[parentPos]); - } else if (src.isCompressedDim(d)) { + return; + } + const auto dlt = src.getDimType(d); + if (isCompressedDLT(dlt)) { // Look up the bounds of the `d`-level segment determined by the // `d-1`-level position `parentPos`. const std::vector

&pointersD = src.pointers[d]; @@ -650,11 +653,11 @@ cursorReordD = static_cast(indicesD[pos]); forallElements(yield, pos, d + 1); } - } else if (src.isSingletonDim(d)) { + } else if (isSingletonDLT(dlt)) { this->cursor[this->reord[d]] = src.getIndex(d, parentPos); forallElements(yield, parentPos, d + 1); - } else { // Dense dimension. - assert(src.isDenseDim(d)); // TODO: reuse the ASSERT_DENSE_DIM message + } else { + assert(isDenseDLT(dlt)); // TODO: reuse the ASSERT_DENSE_DIM message const uint64_t sz = src.getDimSizes()[d]; const uint64_t pstart = parentPos * sz; uint64_t &cursorReordD = this->cursor[this->reord[d]]; diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp --- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp @@ -87,11 +87,12 @@ } // Verify that the sparsity values are supported. + // TODO: update this check to match what we actually support. for (uint64_t i = 0; i < rank; ++i) if (sparsity[i] != DimLevelType::kDense && sparsity[i] != DimLevelType::kCompressed) - MLIR_SPARSETENSOR_FATAL("Unsupported sparsity value %d\n", - static_cast(sparsity[i])); + MLIR_SPARSETENSOR_FATAL("unsupported dimension level type: %d\n", + static_cast(sparsity[i])); #endif // Convert external format to internal COO. diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c --- a/mlir/test/CAPI/sparse_tensor.c +++ b/mlir/test/CAPI/sparse_tensor.c @@ -43,9 +43,9 @@ mlirSparseTensorEncodingAttrGetHigherOrdering(originalAttr); // CHECK: (d0, d1)[s0] -> (s0, d0, d1) mlirAffineMapDump(higherOrdering); - // CHECK: level_type: 0 - // CHECK: level_type: 1 - // CHECK: level_type: 1 + // CHECK: level_type: 4 + // CHECK: level_type: 8 + // CHECK: level_type: 8 int numLevelTypes = mlirSparseTensorEncodingGetNumDimLevelTypes(originalAttr); enum MlirSparseTensorDimLevelType *levelTypes = malloc(sizeof(enum MlirSparseTensorDimLevelType) * numLevelTypes); diff --git a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir --- a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir @@ -19,8 +19,8 @@ // CHECK-DAG: %[[I13:.*]] = arith.constant 13 : index // CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<1xi8> // CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<1xi8> to memref -// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8 -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<1xi8> +// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8 +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<1xi8> // CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<1xindex> // CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<1xindex> to memref // CHECK-DAG: memref.store %[[I13]], %[[SizesS]][%[[I0]]] : memref<1xindex> @@ -56,8 +56,8 @@ // CHECK-DAG: %[[I0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<1xi8> // CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<1xi8> to memref -// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8 -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<1xi8> +// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8 +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<1xi8> // CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<1xindex> // CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<1xindex> to memref // CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr, index) -> index @@ -97,9 +97,9 @@ // CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8> // CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref -// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8 -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8> -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8> +// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8 +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<2xi8> +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<2xi8> // CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex> // CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref // CHECK-DAG: memref.store %[[I2]], %[[SizesS]][%[[I0]]] : memref<2xindex> @@ -140,9 +140,9 @@ // CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8> // CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref -// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8 -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8> -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8> +// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8 +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<2xi8> +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<2xi8> // CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex> // CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref // CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr, index) -> index @@ -184,9 +184,9 @@ // CHECK-DAG: %[[I2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8> // CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref -// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8 -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8> -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8> +// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8 +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<2xi8> +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<2xi8> // CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex> // CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref // CHECK-DAG: %[[SizeI1:.*]] = call @sparseDimSize(%[[Arg]], %[[I1]]) : (!llvm.ptr, index) -> index @@ -227,9 +227,9 @@ // CHECK-DAG: %[[I1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<2xi8> // CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<2xi8> to memref -// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8 -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<2xi8> -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<2xi8> +// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8 +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<2xi8> +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<2xi8> // CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<2xindex> // CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<2xindex> to memref // CHECK-DAG: %[[SizeI0:.*]] = call @sparseDimSize(%[[Arg]], %[[I0]]) : (!llvm.ptr, index) -> index @@ -274,10 +274,10 @@ // CHECK-DAG: %[[I4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[AttrsS:.*]] = memref.alloca() : memref<3xi8> // CHECK-DAG: %[[AttrsD:.*]] = memref.cast %[[AttrsS]] : memref<3xi8> to memref -// CHECK-DAG: %[[Attr0:.*]] = arith.constant 0 : i8 -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I0]]] : memref<3xi8> -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I1]]] : memref<3xi8> -// CHECK-DAG: memref.store %[[Attr0]], %[[AttrsS]][%[[I2]]] : memref<3xi8> +// CHECK-DAG: %[[DenseDLT:.*]] = arith.constant 4 : i8 +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I0]]] : memref<3xi8> +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I1]]] : memref<3xi8> +// CHECK-DAG: memref.store %[[DenseDLT]], %[[AttrsS]][%[[I2]]] : memref<3xi8> // CHECK-DAG: %[[SizesS:.*]] = memref.alloca() : memref<3xindex> // CHECK-DAG: %[[SizesD:.*]] = memref.cast %[[SizesS]] : memref<3xindex> to memref // CHECK-DAG: memref.store %[[I2]], %[[SizesS]][%[[I0]]] : memref<3xindex> diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_concat.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_concat.mlir @@ -14,7 +14,7 @@ // CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32 // CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[TMP_c1_i8:.*]] = arith.constant 1 : i8 +// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8 // CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64 @@ -33,8 +33,8 @@ // CHECK: } // CHECK: %[[TMP_1:.*]] = memref.alloca() : memref<2xi8> // CHECK: %[[TMP_2:.*]] = memref.cast %[[TMP_1]] : memref<2xi8> to memref -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_1]][%[[TMP_c0]]] : memref<2xi8> -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_1]][%[[TMP_c1]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_1]][%[[TMP_c0]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_1]][%[[TMP_c1]]] : memref<2xi8> // CHECK: %[[TMP_3:.*]] = memref.alloca() : memref<2xindex> // CHECK: %[[TMP_4:.*]] = memref.cast %[[TMP_3]] : memref<2xindex> to memref // CHECK: memref.store %[[TMP_c3]], %[[TMP_3]][%[[TMP_c0]]] : memref<2xindex> @@ -83,11 +83,11 @@ // CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index // CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[TMP_c1_i8:.*]] = arith.constant 1 : i8 +// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8 // CHECK: %[[TMP_0:.*]] = memref.alloca() : memref<2xi8> // CHECK: %[[TMP_1:.*]] = memref.cast %[[TMP_0]] : memref<2xi8> to memref -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c0]]] : memref<2xi8> -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c1]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_0]][%[[TMP_c0]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_0]][%[[TMP_c1]]] : memref<2xi8> // CHECK: %[[TMP_2:.*]] = memref.alloca() : memref<2xindex> // CHECK: %[[TMP_3:.*]] = memref.cast %[[TMP_2]] : memref<2xindex> to memref // CHECK: memref.store %[[TMP_c5]], %[[TMP_2]][%[[TMP_c0]]] : memref<2xindex> @@ -115,8 +115,8 @@ // CHECK: } // CHECK: %[[TMP_11:.*]] = memref.alloca() : memref<2xi8> // CHECK: %[[TMP_12:.*]] = memref.cast %[[TMP_11]] : memref<2xi8> to memref -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_11]][%[[TMP_c0]]] : memref<2xi8> -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_11]][%[[TMP_c1]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_11]][%[[TMP_c0]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_11]][%[[TMP_c1]]] : memref<2xi8> // CHECK: %[[TMP_13:.*]] = memref.alloca() : memref<2xindex> // CHECK: %[[TMP_14:.*]] = memref.cast %[[TMP_13]] : memref<2xindex> to memref // CHECK: memref.store %[[TMP_c3]], %[[TMP_13]][%[[TMP_c0]]] : memref<2xindex> @@ -167,11 +167,11 @@ // CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index -// CHECK-DAG: %[[TMP_c1_i8:.*]] = arith.constant 1 : i8 +// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8 // CHECK: %[[TMP_0:.*]] = memref.alloca() : memref<2xi8> // CHECK: %[[TMP_1:.*]] = memref.cast %[[TMP_0]] : memref<2xi8> to memref -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c0]]] : memref<2xi8> -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_0]][%[[TMP_c1]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_0]][%[[TMP_c0]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_0]][%[[TMP_c1]]] : memref<2xi8> // CHECK: %[[TMP_2:.*]] = memref.alloca() : memref<2xindex> // CHECK: %[[TMP_3:.*]] = memref.cast %[[TMP_2]] : memref<2xindex> to memref // CHECK: memref.store %[[TMP_c4]], %[[TMP_2]][%[[TMP_c0]]] : memref<2xindex> @@ -199,8 +199,8 @@ // CHECK: } // CHECK: %[[TMP_11:.*]] = memref.alloca() : memref<2xi8> // CHECK: %[[TMP_12:.*]] = memref.cast %[[TMP_11]] : memref<2xi8> to memref -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_11]][%[[TMP_c0]]] : memref<2xi8> -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_11]][%[[TMP_c1]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_11]][%[[TMP_c0]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_11]][%[[TMP_c1]]] : memref<2xi8> // CHECK: %[[TMP_13:.*]] = memref.alloca() : memref<2xindex> // CHECK: %[[TMP_14:.*]] = memref.cast %[[TMP_13]] : memref<2xindex> to memref // CHECK: memref.store %[[TMP_c4]], %[[TMP_13]][%[[TMP_c0]]] : memref<2xindex> @@ -243,7 +243,7 @@ // CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32 // CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[TMP_c1_i8:.*]] = arith.constant 1 : i8 +// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8 // CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64 @@ -262,8 +262,8 @@ // CHECK: } // CHECK: %[[TMP_1:.*]] = memref.alloca() : memref<2xi8> // CHECK: %[[TMP_2:.*]] = memref.cast %[[TMP_1]] : memref<2xi8> to memref -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_1]][%[[TMP_c0]]] : memref<2xi8> -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_1]][%[[TMP_c1]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_1]][%[[TMP_c0]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_1]][%[[TMP_c1]]] : memref<2xi8> // CHECK: %[[TMP_3:.*]] = memref.alloca() : memref<2xindex> // CHECK: %[[TMP_4:.*]] = memref.cast %[[TMP_3]] : memref<2xindex> to memref // CHECK: memref.store %[[TMP_c4]], %[[TMP_3]][%[[TMP_c0]]] : memref<2xindex> @@ -304,7 +304,7 @@ // CHECK-DAG: %[[TMP_c6_i32:.*]] = arith.constant 6 : i32 // CHECK-DAG: %[[TMP_c1_i32:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[TMP_c0_i32:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[TMP_c1_i8:.*]] = arith.constant 1 : i8 +// CHECK-DAG: %[[TMP_c8_i8:.*]] = arith.constant 8 : i8 // CHECK-DAG: %[[TMP_cst:.*]] = arith.constant 0.000000e+00 : f64 // CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index @@ -323,8 +323,8 @@ // CHECK: } // CHECK: %[[TMP_2:.*]] = memref.alloca() : memref<2xi8> // CHECK: %[[TMP_3:.*]] = memref.cast %[[TMP_2]] : memref<2xi8> to memref -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_2]][%[[TMP_c0]]] : memref<2xi8> -// CHECK: memref.store %[[TMP_c1_i8]], %[[TMP_2]][%[[TMP_c1]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_2]][%[[TMP_c0]]] : memref<2xi8> +// CHECK: memref.store %[[TMP_c8_i8]], %[[TMP_2]][%[[TMP_c1]]] : memref<2xi8> // CHECK: %[[TMP_4:.*]] = memref.alloca() : memref<2xindex> // CHECK: %[[TMP_5:.*]] = memref.cast %[[TMP_4]] : memref<2xindex> to memref // CHECK: memref.store %[[TMP_c3]], %[[TMP_4]][%[[TMP_c0]]] : memref<2xindex> diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir @@ -14,7 +14,7 @@ // CHECK-DAG: %[[VAL_8:.*]] = arith.constant true // CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index // CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 1 : i8 +// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 8 : i8 // CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi8> // CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref // CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>