diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -214,6 +214,7 @@
   let assemblyFormat = "$specifier $specifierKind (`at` $dim^)? attr-dict `:` "
                        "qualified(type($specifier)) `to` type($result)";
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def SparseTensor_SetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.set",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -493,6 +493,20 @@
   return success();
 }
 
+template <typename SpecifierOp>
+static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
+  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
+}
+
+OpFoldResult GetStorageSpecifierOp::fold(ArrayRef<Attribute> operands) {
+  StorageSpecifierKind kind = getSpecifierKind();
+  Optional<APInt> dim = getDim();
+  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
+    if (kind == op.getSpecifierKind() && dim == op.getDim())
+      return op.getValue();
+  return {};
+}
+
 LogicalResult SetStorageSpecifierOp::verify() {
   if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(),
                                           getSpecifier(), getOperation()))) {
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -32,6 +32,7 @@
   %2 = sparse_tensor.values %arg0 : tensor<64xf32, #SparseVector> to memref<?xf32>
   return
 }
+
 // CHECK-LABEL: func @sparse_concat_dce(
 // CHECK-NOT: sparse_tensor.concatenate
 // CHECK: return
@@ -45,3 +46,19 @@
   return
 }
 
+// CHECK-LABEL: func @sparse_get_specifier_dce_fold(
+// CHECK-SAME: %[[A0:.*]]: !sparse_tensor.storage_specifier
+// CHECK-SAME: %[[A1:.*]]: i64,
+// CHECK-SAME: %[[A2:.*]]: i64)
+// CHECK-NOT: sparse_tensor.storage_specifier.set
+// CHECK-NOT: sparse_tensor.storage_specifier.get
+// CHECK: return %[[A1]]
+func.func @sparse_get_specifier_dce_fold(%arg0: !sparse_tensor.storage_specifier<#SparseVector>, %arg1: i64, %arg2: i64) -> i64 {
+  %0 = sparse_tensor.storage_specifier.set %arg0 dim_sz at 0 with %arg1
+       : i64, !sparse_tensor.storage_specifier<#SparseVector>
+  %1 = sparse_tensor.storage_specifier.set %0 ptr_mem_sz at 0 with %arg2
+       : i64, !sparse_tensor.storage_specifier<#SparseVector>
+  %2 = sparse_tensor.storage_specifier.get %1 dim_sz at 0
+       : !sparse_tensor.storage_specifier<#SparseVector> to i64
+  return %2 : i64
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -133,7 +133,7 @@
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 // CHECK-LABEL: func @sparse_set_md(
-// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>, 
+// CHECK-SAME: %[[A:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>,
 // CHECK-SAME: %[[I:.*]]: i64)
 // CHECK: %[[T:.*]] = sparse_tensor.storage_specifier.set %[[A]] dim_sz at 0 with %[[I]]
 // CHECK: return %[[T]] : !sparse_tensor.storage_specifier<#{{.*}}>
@@ -553,4 +553,3 @@
   sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2 { nx=2 : index, ny=1 : index}: memref<?xi64> jointly memref<?xf32>
   return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
 }
-