diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -853,13 +853,13 @@ }; /// Defines a type for "pointer" and "index" storage in the sparse storage -/// scheme, with a choice between the native platform-dependent index width, -/// 64-bit integers, or 32-bit integers. A narrow width obviously reduces +/// scheme, with a choice between the native platform-dependent index width +/// or any of 64-/32-/16-/8-bit integers. A narrow width obviously reduces /// the memory footprint of the sparse storage scheme, but the width should /// suffice to define the total required range (viz. the maximum number of /// stored entries per indirection level for the "pointers" and the maximum /// value of each tensor index over all dimensions for the "indices"). -enum class SparseIntType { kNative, kI64, kI32 }; +enum class SparseIntType { kNative, kI64, kI32, kI16, kI8 }; /// Sparsification options. 
struct SparsificationOptions { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp @@ -512,6 +512,10 @@ return rewriter.getIntegerType(64); case linalg::SparseIntType::kI32: return rewriter.getIntegerType(32); + case linalg::SparseIntType::kI16: + return rewriter.getIntegerType(16); + case linalg::SparseIntType::kI8: + return rewriter.getIntegerType(8); } llvm_unreachable("unexpected SparseIntType"); } diff --git a/mlir/test/Dialect/Linalg/sparse_storage.mlir b/mlir/test/Dialect/Linalg/sparse_storage.mlir --- a/mlir/test/Dialect/Linalg/sparse_storage.mlir +++ b/mlir/test/Dialect/Linalg/sparse_storage.mlir @@ -6,6 +6,10 @@ // RUN: FileCheck %s --check-prefix=CHECK-TYPE2 // RUN: mlir-opt %s -test-sparsification="ptr-type=2 ind-type=2" | \ // RUN: FileCheck %s --check-prefix=CHECK-TYPE3 +// RUN: mlir-opt %s -test-sparsification="ptr-type=3 ind-type=3" | \ +// RUN: FileCheck %s --check-prefix=CHECK-TYPE4 +// RUN: mlir-opt %s -test-sparsification="ptr-type=4 ind-type=4" | \ +// RUN: FileCheck %s --check-prefix=CHECK-TYPE5 #trait_mul_1d = { indexing_maps = [ @@ -86,6 +90,38 @@ // CHECK-TYPE3: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64> // CHECK-TYPE3: } +// CHECK-TYPE4-LABEL: func @mul_dd( +// CHECK-TYPE4: %[[C0:.*]] = constant 0 : index +// CHECK-TYPE4: %[[C1:.*]] = constant 1 : index +// CHECK-TYPE4: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi16> +// CHECK-TYPE4: %[[B0:.*]] = index_cast %[[P0]] : i16 to index +// CHECK-TYPE4: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi16> +// CHECK-TYPE4: %[[B1:.*]] = index_cast %[[P1]] : i16 to index +// CHECK-TYPE4: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] { +// CHECK-TYPE4: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi16> +// CHECK-TYPE4: %[[INDC:.*]] = index_cast %[[IND0]] : i16 to index +// CHECK-TYPE4: %[[VAL0:.*]] = load 
%{{.*}}[%[[I]]] : memref<?xf64> +// CHECK-TYPE4: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64> +// CHECK-TYPE4: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64 +// CHECK-TYPE4: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64> +// CHECK-TYPE4: } + +// CHECK-TYPE5-LABEL: func @mul_dd( +// CHECK-TYPE5: %[[C0:.*]] = constant 0 : index +// CHECK-TYPE5: %[[C1:.*]] = constant 1 : index +// CHECK-TYPE5: %[[P0:.*]] = load %{{.*}}[%[[C0]]] : memref<?xi8> +// CHECK-TYPE5: %[[B0:.*]] = index_cast %[[P0]] : i8 to index +// CHECK-TYPE5: %[[P1:.*]] = load %{{.*}}[%[[C1]]] : memref<?xi8> +// CHECK-TYPE5: %[[B1:.*]] = index_cast %[[P1]] : i8 to index +// CHECK-TYPE5: scf.for %[[I:.*]] = %[[B0]] to %[[B1]] step %[[C1]] { +// CHECK-TYPE5: %[[IND0:.*]] = load %{{.*}}[%[[I]]] : memref<?xi8> +// CHECK-TYPE5: %[[INDC:.*]] = index_cast %[[IND0]] : i8 to index +// CHECK-TYPE5: %[[VAL0:.*]] = load %{{.*}}[%[[I]]] : memref<?xf64> +// CHECK-TYPE5: %[[VAL1:.*]] = load %{{.*}}[%[[INDC]]] : memref<32xf64> +// CHECK-TYPE5: %[[MUL:.*]] = mulf %[[VAL0]], %[[VAL1]] : f64 +// CHECK-TYPE5: store %[[MUL]], %{{.*}}[%[[INDC]]] : memref<32xf64> +// CHECK-TYPE5: } + func @mul_dd(%arga: tensor<32xf64>, %argb: tensor<32xf64>) -> tensor<32xf64> { %0 = linalg.generic #trait_mul_1d ins(%arga, %argb: tensor<32xf64>, tensor<32xf64>) diff --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp --- a/mlir/test/lib/Transforms/TestSparsification.cpp +++ b/mlir/test/lib/Transforms/TestSparsification.cpp @@ -82,6 +82,10 @@ return linalg::SparseIntType::kI64; case 2: return linalg::SparseIntType::kI32; + case 3: + return linalg::SparseIntType::kI16; + case 4: + return linalg::SparseIntType::kI8; } }