diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
@@ -65,6 +65,22 @@
 // Conversion rules.
 //===----------------------------------------------------------------------===//
 
+/// Sparse tensor storage conversion rule for sparse_tensor::storage.
+class SparseStorageConversion : public OpConversionPattern<StorageOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Simply convert it to an unrealized_conversion_cast.
+    // We should guarantee that all uses of the sparse_tensor.storage op will
+    // eventually be eliminated by accessing the flattened SSA values directly.
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TypeRange{op.getType()}, adaptor.getInputs());
+    return success();
+  }
+};
+
 /// Sparse tensor storage conversion rule for sparse_tensor::storage_get.
 class SparseStorageGetConverter : public OpConversionPattern<StorageGetOp> {
 public:
@@ -195,7 +211,8 @@
 /// to expand compounded sparse tensor tuples.
 void mlir::populateSparseTensorStorageExpansionPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<SparseStorageGetConverter, SparseStorageSetConverter>(
-      typeConverter, patterns.getContext());
+  patterns.add<SparseStorageConversion, SparseStorageGetConverter,
+               SparseStorageSetConverter>(typeConverter,
+                                          patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,6 +1,5 @@
 // RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s --check-prefix=CHECK-CODEGEN
-// FIXME:
-// R_U_N: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefix=CHECK-STORAGE
+// RUN: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefix=CHECK-STORAGE
 
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = [ "compressed" ],
@@ -279,6 +278,23 @@
 // CHECK-CODEGEN: %[[T6:.*]] = memref.cast %[[T5]] : memref<1xf64> to memref<?xf64>
 // CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[T0]], %[[T2]], %[[T4]], %[[T6]])
 // CHECK-CODEGEN: return %[[T]] : tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>>
+//
+// CHECK-STORAGE-LABEL: func @sparse_alloc_csc(
+// CHECK-STORAGE-SAME: %[[TMP_arg0:.*]]: index)
+// CHECK-STORAGE-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-STORAGE-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-STORAGE-DAG: %[[TMP_c10:.*]] = arith.constant 10 : index
+// CHECK-STORAGE: %[[TMP_0:.*]] = memref.alloc() : memref<2xindex>
+// CHECK-STORAGE: memref.store %[[TMP_arg0]], %[[TMP_0]][%[[TMP_c0]]] : memref<2xindex>
+// CHECK-STORAGE: memref.store %[[TMP_c10]], %[[TMP_0]][%[[TMP_c1]]] : memref<2xindex>
+// CHECK-STORAGE: %[[TMP_1:.*]] = memref.alloc() : memref<1xindex>
+// CHECK-STORAGE: %[[TMP_2:.*]] = memref.cast %[[TMP_1]] : memref<1xindex> to memref<?xindex>
+// CHECK-STORAGE: %[[TMP_3:.*]] = memref.alloc() : memref<1xindex>
+// CHECK-STORAGE: %[[TMP_4:.*]] = memref.cast %[[TMP_3]] : memref<1xindex> to memref<?xindex>
+// CHECK-STORAGE: %[[TMP_5:.*]] = memref.alloc() : memref<1xf64>
+// CHECK-STORAGE: %[[TMP_6:.*]] = memref.cast %[[TMP_5]] : memref<1xf64> to memref<?xf64>
+// CHECK-STORAGE: return %[[TMP_0]], %[[TMP_2]], %[[TMP_4]], %[[TMP_6]] : memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+
 func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
   %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
   %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC>
@@ -300,6 +316,21 @@
 // CHECK-CODEGEN: %[[A1:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
 // CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]])
 // CHECK-CODEGEN: return %[[T]] : tuple<memref<3xindex>, memref<?xf64>>
+//
+// CHECK-STORAGE-LABEL: func @sparse_alloc_3d() -> (memref<3xindex>, memref<?xf64>) {
+// CHECK-STORAGE-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-STORAGE-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK-STORAGE-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index
+// CHECK-STORAGE-DAG: %[[TMP_c10:.*]] = arith.constant 10 : index
+// CHECK-STORAGE-DAG: %[[TMP_c20:.*]] = arith.constant 20 : index
+// CHECK-STORAGE-DAG: %[[TMP_c30:.*]] = arith.constant 30 : index
+// CHECK-STORAGE: %[[TMP_0:.*]] = memref.alloc() : memref<3xindex>
+// CHECK-STORAGE: memref.store %[[TMP_c30]], %[[TMP_0]][%[[TMP_c0]]] : memref<3xindex>
+// CHECK-STORAGE: memref.store %[[TMP_c10]], %[[TMP_0]][%[[TMP_c1]]] : memref<3xindex>
+// CHECK-STORAGE: memref.store %[[TMP_c20]], %[[TMP_0]][%[[TMP_c2]]] : memref<3xindex>
+// CHECK-STORAGE: %[[TMP_1:.*]] = memref.alloc() : memref<6000xf64>
+// CHECK-STORAGE: %[[TMP_2:.*]] = memref.cast %[[TMP_1]] : memref<6000xf64> to memref<?xf64>
+// CHECK-STORAGE: return %[[TMP_0]], %[[TMP_2]] : memref<3xindex>, memref<?xf64>
 func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
   %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
   %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
@@ -13,8 +13,8 @@
 // CHECK-LABEL: func @call_sparse_storage_expand(
 // CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
 // CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg2:.*]]: f64) 
-// CHECK: %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]) 
+// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
+// CHECK: %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]])
 // CHECK: return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2 : memref<?xf64>, memref<?xf64>, f64
 func.func @call_sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
                                       -> tuple<memref<?xf64>, memref<?xf64>, f64> {
@@ -23,10 +23,21 @@
   return %1 : tuple<memref<?xf64>, memref<?xf64>, f64>
 }
 
+// CHECK-LABEL: func @sparse_storage(
+// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME: %[[TMP_arg2:.*2]]: memref<?xf64>)
+// CHECK: return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]
+func.func @sparse_storage(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: memref<?xf64>)
+    -> tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>> {
+  %1 = sparse_tensor.storage(%arg0, %arg1, %arg2) : memref<?xf64>, memref<?xf64>, memref<?xf64> to tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
+  return %1 : tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
+}
+
 // CHECK-LABEL: func @sparse_storage_get(
 // CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
 // CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg2:.*]]: f64) 
+// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
 // CHECK: return %[[TMP_arg0]] : memref<?xf64>
 func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
   %0 = sparse_tensor.storage_get %arg0[0]
@@ -38,7 +49,7 @@
 // CHECK-LABEL: func @sparse_storage_set(
 // CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
 // CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
 // CHECK-SAME: %[[TMP_arg2:.*]]: f64,
-// CHECK-SAME: %[[TMP_arg3:.*]]: memref<?xf64>) 
+// CHECK-SAME: %[[TMP_arg3:.*]]: memref<?xf64>)
 // CHECK: return %[[TMP_arg3]], %[[TMP_arg1]], %[[TMP_arg2]] : memref<?xf64>, memref<?xf64>, f64
 func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
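
Note: as a quick illustration of what the new SparseStorageConversion pattern does, consider the sketch below (hypothetical function and value names; the op syntax mirrors the tests above). During --sparse-tensor-storage-expansion, the tuple produced by sparse_tensor.storage is rewritten into a builtin.unrealized_conversion_cast over the flattened SSA values; each sparse_tensor.storage_get on that tuple then forwards the corresponding flattened value directly, so the cast is left without uses and is erased, as the comment in the pattern promises.

// Input IR: pack two memrefs into a storage tuple, then project field 1.
func.func @demo(%p: memref<?xindex>, %v: memref<?xf64>) -> memref<?xf64> {
  %t = sparse_tensor.storage(%p, %v)
     : memref<?xindex>, memref<?xf64> to tuple<memref<?xindex>, memref<?xf64>>
  %0 = sparse_tensor.storage_get %t[1]
     : tuple<memref<?xindex>, memref<?xf64>> to memref<?xf64>
  return %0 : memref<?xf64>
}

// After expansion and cleanup, the tuple is gone from the ABI entirely:
// func.func @demo(%p: memref<?xindex>, %v: memref<?xf64>) -> memref<?xf64> {
//   return %v : memref<?xf64>
// }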