diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp
@@ -65,6 +65,22 @@
 // Conversion rules.
 //===----------------------------------------------------------------------===//
 
+/// Sparse tensor storage conversion rule for sparse_tensor::storage.
+class SparseStorageConversion : public OpConversionPattern<StorageOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Simply convert it to an unrealized_conversion_cast.
+    // We should guarantee that all uses of the sparse_tensor.storage op
+    // will eventually be eliminated by accessing the flattened SSA values directly.
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TypeRange{op.getType()}, adaptor.getInputs());
+    return success();
+  }
+};
+
 /// Sparse tensor storage conversion rule for sparse_tensor::storage_get.
 class SparseStorageGetConverter : public OpConversionPattern<StorageGetOp> {
 public:
@@ -195,7 +211,8 @@
 /// to expand compounded sparse tensor tuples.
 void mlir::populateSparseTensorStorageExpansionPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<SparseStorageGetConverter, SparseStorageSetConverter>(
-      typeConverter, patterns.getContext());
+  patterns.add<SparseStorageConversion, SparseStorageGetConverter,
+               SparseStorageSetConverter>(typeConverter,
+                                          patterns.getContext());
 }
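NOTE (editor, not part of the patch): a minimal sketch of the intermediate IR the
new SparseStorageConversion rule produces, shown on a hypothetical two-field tuple
with illustrative value names. The rule only rewrites sparse_tensor.storage into
builtin.unrealized_conversion_cast; per the comment above, the cast is expected to
become dead once every consumer of the tuple uses the flattened SSA values directly.

  // Before the pattern runs (illustrative types):
  %t = sparse_tensor.storage(%ptrs, %vals)
     : memref<?xindex>, memref<?xf64> to tuple<memref<?xindex>, memref<?xf64>>

  // After SparseStorageConversion:
  %t = builtin.unrealized_conversion_cast %ptrs, %vals
     : memref<?xindex>, memref<?xf64> to tuple<memref<?xindex>, memref<?xf64>>

  // Once all uses of %t access %ptrs/%vals directly, the cast is dead and
  // --canonicalize/--cse remove it, yielding the flattened form the tests check.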
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -1,6 +1,5 @@
-// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s --check-prefix=CHECK-CODEGEN
-// FIXME:
-// R_U_N: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefix=CHECK-STORAGE
+// RUN: mlir-opt %s --sparse-tensor-codegen --canonicalize --cse | FileCheck %s --check-prefixes=CHECK,CHECK-CODEGEN
+// RUN: mlir-opt %s --sparse-tensor-codegen --sparse-tensor-storage-expansion --canonicalize --cse | FileCheck %s --check-prefixes=CHECK,CHECK-STORAGE
 
 #SparseVector = #sparse_tensor.encoding<{
   dimLevelType = [ "compressed" ],
@@ -263,43 +262,49 @@
   return
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_alloc_csc(
-// CHECK-CODEGEN-SAME: %[[A:.*]]: index)
-// CHECK-CODEGEN-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-CODEGEN-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-CODEGEN-DAG: %[[C10:.*]] = arith.constant 10 : index
-// CHECK-CODEGEN: %[[T0:.*]] = memref.alloc() : memref<2xindex>
-// CHECK-CODEGEN: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex>
-// CHECK-CODEGEN: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex>
-// CHECK-CODEGEN: %[[T1:.*]] = memref.alloc() : memref<1xindex>
-// CHECK-CODEGEN: %[[T2:.*]] = memref.cast %[[T1]] : memref<1xindex> to memref<?xindex>
-// CHECK-CODEGEN: %[[T3:.*]] = memref.alloc() : memref<1xindex>
-// CHECK-CODEGEN: %[[T4:.*]] = memref.cast %[[T3]] : memref<1xindex> to memref<?xindex>
-// CHECK-CODEGEN: %[[T5:.*]] = memref.alloc() : memref<1xf64>
-// CHECK-CODEGEN: %[[T6:.*]] = memref.cast %[[T5]] : memref<1xf64> to memref<?xf64>
-// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[T0]], %[[T2]], %[[T4]], %[[T6]])
-// CHECK-CODEGEN: return %[[T]] : tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>>
+// CHECK-LABEL: func @sparse_alloc_csc(
+// CHECK-SAME: %[[A:.*]]: index) ->
+// CHECK-CODEGEN-SAME: tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>>
+// CHECK-STORAGE-SAME: memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK: %[[T0:.*]] = memref.alloc() : memref<2xindex>
+// CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<2xindex>
+// CHECK: memref.store %[[C10]], %[[T0]][%[[C1]]] : memref<2xindex>
+// CHECK: %[[T1:.*]] = memref.alloc() : memref<1xindex>
+// CHECK: %[[T2:.*]] = memref.cast %[[T1]] : memref<1xindex> to memref<?xindex>
+// CHECK: %[[T3:.*]] = memref.alloc() : memref<1xindex>
+// CHECK: %[[T4:.*]] = memref.cast %[[T3]] : memref<1xindex> to memref<?xindex>
+// CHECK: %[[T5:.*]] = memref.alloc() : memref<1xf64>
+// CHECK: %[[T6:.*]] = memref.cast %[[T5]] : memref<1xf64> to memref<?xf64>
+// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[T0]], %[[T2]], %[[T4]], %[[T6]])
+// CHECK-CODEGEN: return %[[T]]
+// CHECK-STORAGE: return %[[T0]], %[[T2]], %[[T4]], %[[T6]]
 func.func @sparse_alloc_csc(%arg0: index) -> tensor<10x?xf64, #CSC> {
   %0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSC>
   %1 = sparse_tensor.load %0 : tensor<10x?xf64, #CSC>
   return %1 : tensor<10x?xf64, #CSC>
 }
 
-// CHECK-CODEGEN-LABEL: func @sparse_alloc_3d() -> tuple<memref<3xindex>, memref<?xf64>>
-// CHECK-CODEGEN-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-CODEGEN-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-CODEGEN-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-CODEGEN-DAG: %[[C10:.*]] = arith.constant 10 : index
-// CHECK-CODEGEN-DAG: %[[C20:.*]] = arith.constant 20 : index
-// CHECK-CODEGEN-DAG: %[[C30:.*]] = arith.constant 30 : index
-// CHECK-CODEGEN: %[[A0:.*]] = memref.alloc() : memref<3xindex>
-// CHECK-CODEGEN: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex>
-// CHECK-CODEGEN: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex>
-// CHECK-CODEGEN: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
-// CHECK-CODEGEN: %[[A:.*]] = memref.alloc() : memref<6000xf64>
-// CHECK-CODEGEN: %[[A1:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
-// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]])
-// CHECK-CODEGEN: return %[[T]] : tuple<memref<3xindex>, memref<?xf64>>
+// CHECK-LABEL: func @sparse_alloc_3d() ->
+// CHECK-CODEGEN-SAME: tuple<memref<3xindex>, memref<?xf64>>
+// CHECK-STORAGE-SAME: memref<3xindex>, memref<?xf64>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.*]] = arith.constant 20 : index
+// CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index
+// CHECK: %[[A0:.*]] = memref.alloc() : memref<3xindex>
+// CHECK: memref.store %[[C30]], %[[A0]][%[[C0]]] : memref<3xindex>
+// CHECK: memref.store %[[C10]], %[[A0]][%[[C1]]] : memref<3xindex>
+// CHECK: memref.store %[[C20]], %[[A0]][%[[C2]]] : memref<3xindex>
+// CHECK: %[[A:.*]] = memref.alloc() : memref<6000xf64>
+// CHECK: %[[A1:.*]] = memref.cast %[[A]] : memref<6000xf64> to memref<?xf64>
+// CHECK-CODEGEN: %[[T:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]])
+// CHECK-CODEGEN: return %[[T]] : tuple<memref<3xindex>, memref<?xf64>>
+// CHECK-STORAGE: return %[[A0]], %[[A1]] : memref<3xindex>, memref<?xf64>
 func.func @sparse_alloc_3d() -> tensor<10x20x30xf64, #Dense3D> {
   %0 = bufferization.alloc_tensor() : tensor<10x20x30xf64, #Dense3D>
   %1 = sparse_tensor.load %0 : tensor<10x20x30xf64, #Dense3D>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir
@@ -13,8 +13,8 @@
 // CHECK-LABEL: func @call_sparse_storage_expand(
 // CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
 // CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg2:.*]]: f64) 
-// CHECK: %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]) 
+// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
+// CHECK: %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]])
 // CHECK: return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2 : memref<?xf64>, memref<?xf64>, f64
 func.func @call_sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
                                           -> tuple<memref<?xf64>, memref<?xf64>, f64> {
@@ -23,10 +23,21 @@
   return %1 : tuple<memref<?xf64>, memref<?xf64>, f64>
 }
 
+// CHECK-LABEL: func @sparse_storage(
+// CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME: %[[TMP_arg2:.*2]]: memref<?xf64>)
+// CHECK: return %[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]]
+func.func @sparse_storage(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: memref<?xf64>)
+    -> tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>> {
+  %1 = sparse_tensor.storage(%arg0, %arg1, %arg2) : memref<?xf64>, memref<?xf64>, memref<?xf64> to tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
+  return %1 : tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
+}
+
 // CHECK-LABEL: func @sparse_storage_get(
 // CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
 // CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
-// CHECK-SAME: %[[TMP_arg2:.*]]: f64) 
+// CHECK-SAME: %[[TMP_arg2:.*]]: f64)
 // CHECK: return %[[TMP_arg0]] : memref<?xf64>
 func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
   %0 = sparse_tensor.storage_get %arg0[0]
@@ -38,7 +49,7 @@
 // CHECK-SAME: %[[TMP_arg0:.*0]]: memref<?xf64>,
 // CHECK-SAME: %[[TMP_arg1:.*1]]: memref<?xf64>,
 // CHECK-SAME: %[[TMP_arg2:.*]]: f64,
-// CHECK-SAME: %[[TMP_arg3:.*]]: memref<?xf64>) 
+// CHECK-SAME: %[[TMP_arg3:.*]]: memref<?xf64>)
 // CHECK: return %[[TMP_arg3]], %[[TMP_arg1]], %[[TMP_arg2]] : memref<?xf64>, memref<?xf64>, f64
 func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>,
                               %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
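NOTE (editor, not part of the patch): the two RUN pipelines in codegen.mlir differ
only in whether the storage expansion runs; sketched here on the CSC allocation
test above, with illustrative value names. With --sparse-tensor-codegen alone the
function still returns the compound tuple:

  %t = sparse_tensor.storage(%dims, %ptrs, %idxs, %vals)
  return %t : tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>>

Adding --sparse-tensor-storage-expansion flattens the tuple into its fields, so the
function signature and the return expand accordingly:

  return %dims, %ptrs, %idxs, %vals
      : memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>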