diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -254,6 +254,48 @@ let hasVerifier = 1; } +def SparseTensor_ToSpecifierOp : SparseTensor_Op<"specifier", [Pure]>, + Arguments<(ins AnySparseTensor:$tensor)>, + Results<(outs SparseTensorStorageSpecifier:$result)> { + let summary = ""; + let description = [{ + Get the storage specifier for the tensor. + + Example of getting the storage specifier for a CSR tensor: + + ```mlir + %1 = sparse_tensor.specifier %0 : tensor<64x64xf64, #CSR> to + !sparse_tensor.storage_specifier<#CSR> + ``` + }]; + let assemblyFormat = "$tensor attr-dict `:` " + "type($tensor) `to` qualified(type($result))"; + let builders = [ + // Builds an op by inferring the result type. + OpBuilder<(ins "Value":$tensor)> + ]; +} + +def SparseTensor_SetSpecifierOp : SparseTensor_Op<"set_specifier", +[AllTypesMatch<["tensor", "result"]>]>, + Arguments<(ins AnySparseTensor:$tensor, SparseTensorStorageSpecifier:$specifier)>, + Results<(outs AnySparseTensor:$result)> { + let summary = ""; + let description = [{ + Set the storage specifier for a tensor and return the new representation of + the tensor. 
+ + Example of setting the storage specifier for a CSR tensor: + + ```mlir + %2 = sparse_tensor.set_specifier %0, %1 : tensor<64x64xf64, #CSR>, + !sparse_tensor.storage_specifier<#CSR> to tensor<64x64xf64, #CSR> + ``` + }]; + let assemblyFormat = "$tensor `,` $specifier attr-dict `:` " + "type($tensor) `,` qualified(type($specifier)) `to` type($result)"; +} + def SparseTensor_StorageSpecifierInitOp : SparseTensor_Op<"storage_specifier.init", [Pure]>, Results<(outs SparseTensorStorageSpecifier:$result)> { let summary = ""; diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -716,6 +716,13 @@ return success(); } +void ToSpecifierOp::build(OpBuilder &builder, OperationState &result, + Value tensor) { + build(builder, result, + StorageSpecifierType::get(getSparseTensorEncoding(tensor.getType())), + tensor); +} + LogicalResult GetStorageSpecifierOp::verify() { RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter( getSpecifierKind(), getDim(), getSpecifier(), getOperation())) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -991,6 +991,37 @@ } }; +/// Sparse codegen rule for storage specifier accesses. 
+class SparseToSpecifierConverter : public OpConversionPattern { +public: + using OpAdaptor = typename ToSpecifierOp::Adaptor; + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(ToSpecifierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto desc = getDescriptorFromTensorTuple(adaptor.getTensor()); + rewriter.replaceOp(op, desc.getSpecifier()); + return success(); + } +}; + +/// Sparse codegen rule for setting storage specifier. +class SparseSetSpecifierConverter : public OpConversionPattern { +public: + using OpAdaptor = typename SetSpecifierOp::Adaptor; + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(SetSpecifierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector fields; + auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields); + desc.setSpecifier(op.getSpecifier()); + Location loc = op.getLoc(); + rewriter.replaceOp(op, genTuple(rewriter, loc, op.getType(), fields)); + return success(); + } +}; + /// Sparse codegen rule for the convert operator. class SparseConvertConverter : public OpConversionPattern { public: @@ -1172,6 +1203,7 @@ SparseCompressConverter, SparseInsertConverter, SparseToPointersConverter, SparseToIndicesConverter, SparseToIndicesBufferConverter, SparseToValuesConverter, + SparseToSpecifierConverter, SparseSetSpecifierConverter, SparseConvertConverter, SparseNumberOfEntriesConverter>( typeConverter, patterns.getContext()); patterns.add(typeConverter, patterns.getContext(), diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -254,6 +254,8 @@ /// Getters: get the value for required field. 
/// + Value getSpecifier() const { return fields.back(); } + Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional dim) const { @@ -388,6 +390,8 @@ fields[fidx] = v; } + void setSpecifier(Value spec) { fields[fields.size() - 1] = spec; } + void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional dim, Value v) { diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir --- a/mlir/test/Dialect/SparseTensor/codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen.mlir @@ -250,6 +250,32 @@ return %0 : memref } +// CHECK-LABEL: func.func @sparse_specifier( +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier +// CHECK: return %[[A5]] : !sparse_tensor.storage_specifier +func.func @sparse_specifier(%arg0: tensor) -> !sparse_tensor.storage_specifier<#ccoo> { + %0 = sparse_tensor.specifier %arg0 : tensor to !sparse_tensor.storage_specifier<#ccoo> + return %0 : !sparse_tensor.storage_specifier<#ccoo> +} + +// CHECK-LABEL: func.func @sparse_set_specifier( +// CHECK-SAME: %[[A0:.*0]]: memref, +// CHECK-SAME: %[[A1:.*1]]: memref, +// CHECK-SAME: %[[A2:.*2]]: memref, +// CHECK-SAME: %[[A3:.*3]]: memref, +// CHECK-SAME: %[[A4:.*4]]: memref, +// CHECK-SAME: %[[A5:.*5]]: !sparse_tensor.storage_specifier +// CHECK-SAME: %[[A6:.*6]]: !sparse_tensor.storage_specifier +// CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A6]] +func.func @sparse_set_specifier(%arg0: tensor, %arg1: !sparse_tensor.storage_specifier<#ccoo>) -> tensor { + %0 = sparse_tensor.set_specifier %arg0, %arg1 : tensor, !sparse_tensor.storage_specifier<#ccoo> to tensor + return %0 : tensor +} // CHECK-LABEL: func.func @sparse_indices_coo( // CHECK-SAME: %[[A0:.*0]]: memref, diff --git 
a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -131,6 +131,39 @@ return %0 : memref } +// ----- + +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> + +// CHECK-LABEL: func @sparse_get_specifier( +// CHECK-SAME: %[[A:.*]]: tensor<128xf64, #{{.*}}>) +// CHECK: %[[T:.*]] = sparse_tensor.specifier %[[A]] : tensor<128xf64, #{{.*}}> to !sparse_tensor.storage_specifier<#{{.*}}> +// CHECK: return %[[T]] : !sparse_tensor.storage_specifier<#{{.*}}> +func.func @sparse_get_specifier(%arg0: tensor<128xf64, #SparseVector>) + -> !sparse_tensor.storage_specifier<#SparseVector> { + %0 = sparse_tensor.specifier %arg0 : tensor<128xf64, #SparseVector> to + !sparse_tensor.storage_specifier<#SparseVector> + return %0 : !sparse_tensor.storage_specifier<#SparseVector> +} + +// ----- + +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> + +// CHECK-LABEL: func @sparse_set_specifier( +// CHECK-SAME: %[[A:.*]]: tensor<128xf64, #{{.*}}>, +// CHECK-SAME: %[[B:.*]]: !sparse_tensor.storage_specifier<#{{.*}}>) +// CHECK: %[[T:.*]] = sparse_tensor.set_specifier %[[A]], %[[B]] : tensor<128xf64, #{{.*}}>, !sparse_tensor.storage_specifier<#{{.*}}> to tensor<128xf64, #{{.*}}> +// CHECK: return %[[T]] : tensor<128xf64, #{{.*}}> +func.func @sparse_set_specifier(%arg0: tensor<128xf64, #SparseVector>, + %arg1:!sparse_tensor.storage_specifier<#SparseVector>) + -> tensor<128xf64, #SparseVector> { + %0 = sparse_tensor.set_specifier %arg0, %arg1 : tensor<128xf64, #SparseVector>, + !sparse_tensor.storage_specifier<#SparseVector> to tensor<128xf64, #SparseVector> + return %0 : tensor<128xf64, #SparseVector> +} + + // ----- #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>