diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -17,6 +17,7 @@ #include "CodegenUtils.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" @@ -232,7 +233,31 @@ } }; -/// Sparse conversion rule for pointer accesses. +/// Sparse codegen rule for the dealloc operator. +class SparseTensorDeallocConverter + : public OpConversionPattern<bufferization::DeallocTensorOp> { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto enc = getSparseTensorEncoding(op.getTensor().getType()); + if (!enc) + return failure(); + // Replace the tuple deallocation with field deallocations. + Location loc = op->getLoc(); + Value tuple = adaptor.getTensor(); + for (unsigned i = 0, sz = tuple.getType().cast<TupleType>().size(); i < sz; + i++) { + Value mem = createTupleGet(rewriter, loc, tuple, i); + rewriter.create<memref::DeallocOp>(loc, mem); + } + rewriter.eraseOp(op); + return success(); + } +}; + +/// Sparse codegen rule for pointer accesses. class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> { public: using OpConversionPattern::OpConversionPattern; @@ -251,7 +276,7 @@ } }; -/// Sparse conversion rule for index accesses. +/// Sparse codegen rule for index accesses. class SparseToIndicesConverter : public OpConversionPattern<ToIndicesOp> { public: using OpConversionPattern::OpConversionPattern; @@ -270,7 +295,7 @@ } }; -/// Sparse conversion rule for value accesses. +/// Sparse codegen rule for value accesses. 
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> { public: using OpConversionPattern::OpConversionPattern; @@ -280,7 +305,7 @@ // Replace the requested values access with corresponding field. Location loc = op->getLoc(); Value tuple = adaptor.getTensor(); - unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last + unsigned i = tuple.getType().cast<TupleType>().size() - 1; // last rewriter.replaceOp(op, createTupleGet(rewriter, loc, tuple, i)); return success(); } @@ -306,6 +331,7 @@ void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add<SparseReturnConverter, SparseDimOpConverter, SparseCastConverter, + SparseTensorDeallocConverter, SparseToPointersConverter, + SparseToIndicesConverter, SparseToValuesConverter>( + typeConverter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -124,8 +124,7 @@ }); // The following operations and dialects may be introduced by the // rewriting rules, and are therefore marked as legal. - target.addLegalOp<bufferization::DeallocTensorOp>(); target.addLegalDialect< arith::ArithmeticDialect, bufferization::BufferizationDialect, @@ -160,7 +159,9 @@ // Almost everything in the sparse dialect must go! target.addIllegalDialect<SparseTensorDialect>(); target.addLegalOp(); - // All dynamic rules below accept new function, call, return. + // All dynamic rules below accept new function, call, return, and various + // tensor and bufferization operations as legal output of the rewriting + // provided that all sparse tensor types have been fully rewritten. 
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { return converter.isSignatureLegal(op.getFunctionType()); }); @@ -170,6 +171,10 @@ target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) { return converter.isLegal(op.getOperandTypes()); }); + target.addDynamicallyLegalOp<bufferization::DeallocTensorOp>( + [&](bufferization::DeallocTensorOp op) { + return converter.isLegal(op.getTensor().getType()); + }); // Legal dialects may occur in generated code. target.addLegalDialect to memref return %0 : memref } + +// CHECK-LABEL: func @sparse_dealloc_csr( +// CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>>) +// CHECK: %[[F0:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>> to memref<2xindex> +// CHECK: memref.dealloc %[[F0]] : memref<2xindex> +// CHECK: %[[F1:.*]] = sparse_tensor.storage_get %[[A]][1] : tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>> to memref<?xindex> +// CHECK: memref.dealloc %[[F1]] : memref<?xindex> +// CHECK: %[[F2:.*]] = sparse_tensor.storage_get %[[A]][2] : tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>> to memref<?xindex> +// CHECK: memref.dealloc %[[F2]] : memref<?xindex> +// CHECK: %[[F3:.*]] = sparse_tensor.storage_get %[[A]][3] : tuple<memref<2xindex>, memref<?xindex>, memref<?xindex>, memref<?xf64>> to memref<?xf64> +// CHECK: memref.dealloc %[[F3]] : memref<?xf64> +// CHECK: return +func.func @sparse_dealloc_csr(%arg0: tensor<?x?xf64, #CSR>) { + bufferization.dealloc_tensor %arg0 : tensor<?x?xf64, #CSR> + return +}