diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h --- a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h +++ b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h @@ -18,8 +18,22 @@ extern "C" { -/// Encoding of the elemental type, for "overloading" @newSparseTensor. -enum class OverheadType : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 }; +/// This type is used in the public API at all places where MLIR expects +/// values with the built-in type "index". For now, we simply assume that +/// type is 64-bit, but targets with different "index" bit widths should link +/// with an alternatively built runtime support library. +// TODO: support such targets? +typedef uint64_t index_t; + +/// Encoding of overhead types (both pointer overhead and indices +/// overhead), for "overloading" @newSparseTensor. +enum class OverheadType : uint32_t { + kIndex = 0, + kU64 = 1, + kU32 = 2, + kU16 = 3, + kU8 = 4 +}; /// Encoding of the elemental type, for "overloading" @newSparseTensor. enum class PrimaryType : uint32_t { diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/CodegenUtils.cpp @@ -26,7 +26,7 @@ OverheadType overheadTypeEncoding(unsigned width) { switch (width) { - default: + case 64: return OverheadType::kU64; case 32: return OverheadType::kU32; @@ -34,11 +34,16 @@ return OverheadType::kU16; case 8: return OverheadType::kU8; + case 0: + return OverheadType::kIndex; } + llvm_unreachable("Unsupported overhead bitwidth"); } Type getOverheadType(Builder &builder, OverheadType ot) { switch (ot) { + case OverheadType::kIndex: + return builder.getIndexType(); case OverheadType::kU64: return builder.getIntegerType(64); case OverheadType::kU32: @@ -53,20 +58,13 @@ Type getPointerOverheadType(Builder &builder, const SparseTensorEncodingAttr &enc) { - // NOTE(wrengr): This workaround will be fixed in D115010. - unsigned width = enc.getPointerBitWidth(); - if (width == 0) - return builder.getIndexType(); - return getOverheadType(builder, overheadTypeEncoding(width)); + return getOverheadType(builder, + overheadTypeEncoding(enc.getPointerBitWidth())); } Type getIndexOverheadType(Builder &builder, const SparseTensorEncodingAttr &enc) { - // NOTE(wrengr): This workaround will be fixed in D115010. - unsigned width = enc.getIndexBitWidth(); - if (width == 0) - return builder.getIndexType(); - return getOverheadType(builder, overheadTypeEncoding(width)); + return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth())); } PrimaryType primaryTypeEncoding(Type elemTp) { diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp --- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp @@ -634,13 +634,6 @@ extern "C" { -/// This type is used in the public API at all places where MLIR expects -/// values with the built-in type "index". For now, we simply assume that -/// type is 64-bit, but targets with different "index" bit widths should link -/// with an alternatively built runtime support library. -// TODO: support such targets? -typedef uint64_t index_t; - //===----------------------------------------------------------------------===// // // Public API with methods that operate on MLIR buffers (memrefs) to interact @@ -753,6 +746,12 @@ static_cast(tensor)->lexInsert(cursor, val); \ } +// Assume index_t is in fact uint64_t, so that _mlir_ciface_newSparseTensor +// can safely rewrite kIndex to kU64. We make this assertion to guarantee +// that this file cannot get out of sync with its header. +static_assert(std::is_same::value, + "Expected index_t == uint64_t"); + /// Constructs a new sparse tensor. This is the "swiss army knife" /// method for materializing sparse tensors into the computation. /// @@ -778,6 +777,13 @@ const index_t *perm = pref->data + pref->offset; uint64_t rank = aref->sizes[0]; + // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases. + // This is safe because of the static_assert above. + if (ptrTp == OverheadType::kIndex) + ptrTp = OverheadType::kU64; + if (indTp == OverheadType::kIndex) + indTp = OverheadType::kU64; + // Double matrices with all combinations of overhead storage. CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t, uint64_t, double); diff --git a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir --- a/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion_sparse2dense.mlir @@ -27,16 +27,15 @@ // CHECK-DAG: %[[PermS:.*]] = memref.alloca() : memref<1xindex> // CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<1xindex> to memref // CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<1xindex> -// CHECK-DAG: %[[SecTp:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[zeroI32:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[ElemTp:.*]] = arith.constant 4 : i32 // CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32 -// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[SecTp]], %[[SecTp]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref, memref, memref, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr +// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[zeroI32]], %[[zeroI32]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref, memref, memref, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr // CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<1xindex> // CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref // CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref // CHECK-DAG: %[[M:.*]] = memref.alloc() : memref<13xi32> -// CHECK-DAG: %[[E0:.*]] = arith.constant 0 : i32 -// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : i32, memref<13xi32> +// CHECK-DAG: linalg.fill(%[[zeroI32]], %[[M]]) : i32, memref<13xi32> // CHECK: scf.while : () -> () { // CHECK: %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr, memref, memref) -> i1 // CHECK: scf.condition(%[[Cond]]) @@ -67,16 +66,15 @@ // CHECK-DAG: %[[PermS:.*]] = memref.alloca() : memref<1xindex> // CHECK-DAG: %[[PermD:.*]] = memref.cast %[[PermS]] : memref<1xindex> to memref // CHECK-DAG: memref.store %[[I0]], %[[PermS]][%[[I0]]] : memref<1xindex> -// CHECK-DAG: %[[SecTp:.*]] = arith.constant 1 : i32 +// CHECK-DAG: %[[zeroI32:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[ElemTp:.*]] = arith.constant 4 : i32 // CHECK-DAG: %[[ActionToIter:.*]] = arith.constant 5 : i32 -// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[SecTp]], %[[SecTp]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref, memref, memref, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr +// CHECK-DAG: %[[Iter:.*]] = call @newSparseTensor(%[[AttrsD]], %[[SizesD]], %[[PermD]], %[[zeroI32]], %[[zeroI32]], %[[ElemTp]], %[[ActionToIter]], %[[Arg]]) : (memref, memref, memref, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr // CHECK-DAG: %[[IndS:.*]] = memref.alloca() : memref<1xindex> // CHECK-DAG: %[[IndD:.*]] = memref.cast %[[IndS]] : memref<1xindex> to memref // CHECK-DAG: %[[ElemBuffer:.*]] = memref.alloca() : memref // CHECK-DAG: %[[M:.*]] = memref.alloc(%[[SizeI0]]) : memref -// CHECK-DAG: %[[E0:.*]] = arith.constant 0 : i32 -// CHECK-DAG: linalg.fill(%[[E0]], %[[M]]) : i32, memref +// CHECK-DAG: linalg.fill(%[[zeroI32]], %[[M]]) : i32, memref // CHECK: scf.while : () -> () { // CHECK: %[[Cond:.*]] = call @getNextI32(%[[Iter]], %[[IndD]], %[[ElemBuffer]]) : (!llvm.ptr, memref, memref) -> i1 // CHECK: scf.condition(%[[Cond]])