diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
@@ -18,8 +18,22 @@
 extern "C" {
 
-/// Encoding of the elemental type, for "overloading" @newSparseTensor.
-enum class OverheadType : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
+/// This type is used in the public API at all places where MLIR expects
+/// values with the built-in type "index". For now, we simply assume that
+/// type is 64-bit, but targets with different "index" bit widths should link
+/// with an alternatively built runtime support library.
+// TODO: support such targets?
+using index_t = uint64_t;
+
+/// Encoding of overhead types (both pointer overhead and indices
+/// overhead), for "overloading" @newSparseTensor.
+enum class OverheadType : uint32_t {
+  kIndex = 0,
+  kU64 = 1,
+  kU32 = 2,
+  kU16 = 3,
+  kU8 = 4
+};
 
 /// Encoding of the elemental type, for "overloading" @newSparseTensor.
 enum class PrimaryType : uint32_t {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -74,7 +74,7 @@
                                           Location loc, unsigned width) {
   OverheadType sec;
   switch (width) {
-  default:
+  case 64:
     sec = OverheadType::kU64;
     break;
   case 32:
@@ -86,6 +86,11 @@
   case 8:
     sec = OverheadType::kU8;
     break;
+  case 0:
+    sec = OverheadType::kIndex;
+    break;
+  default:
+    llvm_unreachable("Unsupported overhead bitwidth");
   }
   return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
 }
diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -686,13 +686,6 @@
 extern "C" {
 
-/// This type is used in the public API at all places where MLIR expects
-/// values with the built-in type "index". For now, we simply assume that
-/// type is 64-bit, but targets with different "index" bit widths should link
-/// with an alternatively built runtime support library.
-// TODO: support such targets?
-using index_t = uint64_t;
-
 //===----------------------------------------------------------------------===//
 //
 // Public API with methods that operate on MLIR buffers (memrefs) to interact
@@ -821,6 +814,12 @@
       cursor, values, filled, added, count);                                  \
   }
 
+// Assume index_t is in fact uint64_t, so that _mlir_ciface_newSparseTensor
+// can safely rewrite kIndex to kU64. We make this assertion to guarantee
+// that this file cannot get out of sync with its header.
+static_assert(std::is_same<index_t, uint64_t>::value,
+              "Expected index_t == uint64_t");
+
 /// Constructs a new sparse tensor. This is the "swiss army knife"
 /// method for materializing sparse tensors into the computation.
 ///
@@ -846,6 +845,13 @@
   const index_t *perm = pref->data + pref->offset;
   uint64_t rank = aref->sizes[0];
 
+  // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
+  // This is safe because of the static_assert above.
+  if (ptrTp == OverheadType::kIndex)
+    ptrTp = OverheadType::kU64;
+  if (indTp == OverheadType::kIndex)
+    indTp = OverheadType::kU64;
+
   // Double matrices with all combinations of overhead storage.
   CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t, uint64_t, double);
diff --git a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
--- a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir
@@ -27,16 +27,15 @@
 // CHECK-DAG: %[[PermS:.*]] = memref.alloca() : memref<1xindex>
 // CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<1xindex> to memref<?xindex>
 // CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<1xindex>
-// CHECK-DAG: %[[SecTp:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[zeroI32:.*]] = arith.constant 0 : i32
 // CHECK-DAG: %[[ElemTp:.*]] = arith.constant 4 : i32
 // CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
-// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[SecTp]], %[[SecTp]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[zeroI32]], %[[zeroI32]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
 // CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<1xindex>
 // CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref<?xindex>
 // CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<i32>
 // CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<13xi32>
-// CHECK-DAG: %[[E0:.*]] = arith.constant 0 : i32
-// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : i32, memref<13xi32>
+// CHECK-DAG: linalg.fill(%[[zeroI32]], %[[M]]) : i32, memref<13xi32>
 // CHECK: scf.while : () -> () {
 // CHECK: %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<i32>) -> i1
 // CHECK: scf.condition(%[[Cond]])
@@ -67,16 +66,15 @@
 // CHECK-DAG: %[[PermS:.*]] = memref.alloca() : memref<1xindex>
 // CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<1xindex> to memref<?xindex>
 // CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<1xindex>
-// CHECK-DAG: %[[SecTp:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[zeroI32:.*]] = arith.constant 0 : i32
 // CHECK-DAG: %[[ElemTp:.*]] = arith.constant 4 : i32
 // CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32
-// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[SecTp]], %[[SecTp]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[zeroI32]], %[[zeroI32]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr<i8>) -> !llvm.ptr<i8>
 // CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<1xindex>
 // CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref<?xindex>
 // CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref<i32>
 // CHECK-DAG: %[[M:.*]] = memref.alloc(%[[SizeI0]]) : memref<?xi32>
-// CHECK-DAG: %[[E0:.*]] = arith.constant 0 : i32
-// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : i32, memref<?xi32>
+// CHECK-DAG: linalg.fill(%[[zeroI32]], %[[M]]) : i32, memref<?xi32>
 // CHECK: scf.while : () -> () {
 // CHECK: %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr<i8>, memref<?xindex>, memref<i32>) -> i1
 // CHECK: scf.condition(%[[Cond]])
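
To see the new encoding in isolation, here is a minimal, self-contained C++ sketch of the kIndex normalization the runtime now performs. The `index_t` alias and `OverheadType` enum are copied from the header above so the snippet compiles on its own; `normalizeOverheadType` is a hypothetical helper introduced only for illustration, since the patch itself inlines the two if-statements directly in `_mlir_ciface_newSparseTensor`.

```cpp
#include <cstdint>
#include <type_traits>

// Stand-ins mirroring SparseTensorUtils.h, copied here so the sketch is
// self-contained; the real declarations live in the header shown in the diff.
using index_t = uint64_t;
enum class OverheadType : uint32_t {
  kIndex = 0,
  kU64 = 1,
  kU32 = 2,
  kU16 = 3,
  kU8 = 4
};

// The normalization is only sound while index_t is 64-bit, hence the same
// static_assert the patch adds to SparseTensorUtils.cpp.
static_assert(std::is_same<index_t, uint64_t>::value,
              "Expected index_t == uint64_t");

// Hypothetical helper (not part of the patch): map the abstract "index"
// overhead encoding onto its concrete 64-bit representation, leaving all
// explicit-width encodings untouched.
static OverheadType normalizeOverheadType(OverheadType tp) {
  return tp == OverheadType::kIndex ? OverheadType::kU64 : tp;
}

int main() {
  // kIndex collapses to kU64; explicit widths pass through unchanged.
  bool ok = normalizeOverheadType(OverheadType::kIndex) == OverheadType::kU64 &&
            normalizeOverheadType(OverheadType::kU32) == OverheadType::kU32;
  return ok ? 0 : 1;
}
```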