diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -108,8 +108,20 @@
       return converter.isLegal(op.getOperandTypes());
     });
     target.addDynamicallyLegalOp<tensor::CastOp>([&](tensor::CastOp op) {
-      return converter.isLegal(op.getOperand().getType());
+      return converter.isLegal(op.source().getType()) &&
+             converter.isLegal(op.dest().getType());
     });
+    // Reshapes are legal only when both source and result types are legal.
+    target.addDynamicallyLegalOp<tensor::ExpandShapeOp>(
+        [&](tensor::ExpandShapeOp op) {
+          return converter.isLegal(op.src().getType()) &&
+                 converter.isLegal(op.result().getType());
+        });
+    target.addDynamicallyLegalOp<tensor::CollapseShapeOp>(
+        [&](tensor::CollapseShapeOp op) {
+          return converter.isLegal(op.src().getType()) &&
+                 converter.isLegal(op.result().getType());
+        });
     target.addDynamicallyLegalOp<bufferization::AllocTensorOp>(
         [&](bufferization::AllocTensorOp op) {
           return converter.isLegal(op.getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TensorEncoding.h"
@@ -1824,6 +1825,50 @@
   SparsificationOptions options;
 };
 
+/// Sparse rewriting rule for the reshape operators tensor::ExpandShapeOp
+/// and tensor::CollapseShapeOp. Since a pure dense reshape is very cheap
+/// (change of view), a sparse2dense or dense2sparse reshape is rewritten by
+/// unfusing a sparse conversion from the reshape operation itself.
+/// Reshapes between two sparse tensors are not rewritten (yet).
+template <typename ReshapeOp>
+struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
+  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto encDst = getSparseTensorEncoding(op.result().getType());
+    auto encSrc = getSparseTensorEncoding(op.src().getType());
+    if (encDst && encSrc)
+      return failure(); // TODO: implement sparse2sparse
+    if (encSrc) {
+      // sparse2dense: convert the source to a dense tensor and let the op
+      // reshape that instead. This modifies the op in place, which must be
+      // announced to the rewriter.
+      auto rtp = op.src().getType().template cast<RankedTensorType>();
+      auto denseTp =
+          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.src());
+      rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); });
+      return success();
+    }
+    if (encDst) {
+      // dense2sparse: reshape into a dense tensor of the same shape as the
+      // result, then convert that into the requested sparse result type.
+      auto rtp = op.result().getType().template cast<RankedTensorType>();
+      auto denseTp =
+          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
+      auto reshape = rewriter.create<ReshapeOp>(
+          loc, denseTp, op.src(), op.getReassociation());
+      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
+      rewriter.replaceOp(op, convert);
+      return success();
+    }
+    // dense2dense: nothing to rewrite.
+    return failure();
+  }
+};
+
 } // namespace
 
 /// Populates the given patterns list with rewriting rules required for
@@ -1831,4 +1876,7 @@
 void mlir::populateSparsificationPatterns(
     RewritePatternSet &patterns, const SparsificationOptions &options) {
   patterns.add<GenericOpSparsifier>(patterns.getContext(), options);
+  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
+               ReshapeRewriter<tensor::CollapseShapeOp>>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/rewriting.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/rewriting.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed", "compressed"]
+}>
+
+// CHECK-LABEL: func.func @expand_dense(
+// CHECK-SAME: %[[A:.*]]: tensor<12xf64>) -> tensor<3x4xf64> {
+// CHECK: %[[E:.*]] = tensor.expand_shape %[[A]] {{.*}} : tensor<12xf64> into tensor<3x4xf64>
+// CHECK: return %[[E]] : tensor<3x4xf64>
+// CHECK: }
+func.func @expand_dense(%arg0: tensor<12xf64>) -> tensor<3x4xf64> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64> into tensor<3x4xf64>
+  return %0 : tensor<3x4xf64>
+}
+
+// CHECK-LABEL: func.func @expand_from_sparse(
+// CHECK-SAME: %[[A:.*]]:
tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<3x4xf64> {
+// CHECK: %[[C:.*]] = sparse_tensor.convert %[[A]] : tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<12xf64>
+// CHECK: %[[E:.*]] = tensor.expand_shape %[[C]] {{.*}} : tensor<12xf64> into tensor<3x4xf64>
+// CHECK: return %[[E]] : tensor<3x4xf64>
+// CHECK: }
+func.func @expand_from_sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64>
+  return %0 : tensor<3x4xf64>
+}
+
+// CHECK-LABEL: func.func @expand_to_sparse(
+// CHECK-SAME: %[[A:.*]]: tensor<12xf64>) -> tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> {
+// CHECK: %[[E:.*]] = tensor.expand_shape %[[A]] {{.*}} : tensor<12xf64> into tensor<3x4xf64>
+// CHECK: %[[C:.*]] = sparse_tensor.convert %[[E]] : tensor<3x4xf64> to tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[C]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: }
+func.func @expand_to_sparse(%arg0: tensor<12xf64>) -> tensor<3x4xf64, #SparseMatrix> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64> into tensor<3x4xf64, #SparseMatrix>
+  return %0 : tensor<3x4xf64, #SparseMatrix>
+}
+
+// TODO: sparse2sparse expansion is not rewritten yet; only the label is checked.
+// CHECK-LABEL: func.func @expand_sparse2sparse(
+func.func @expand_sparse2sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> {
+  %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix>
+  return %0 : tensor<3x4xf64, #SparseMatrix>
+}
+
+// CHECK-LABEL: func.func @collapse_dense(
+// CHECK-SAME: %[[A:.*]]: tensor<3x4xf64>) -> tensor<12xf64> {
+// CHECK: %[[R:.*]] = tensor.collapse_shape %[[A]] {{.*}} : tensor<3x4xf64> into tensor<12xf64>
+// CHECK: return %[[R]] : tensor<12xf64>
+// CHECK: }
+func.func @collapse_dense(%arg0: tensor<3x4xf64>) -> tensor<12xf64> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64> into tensor<12xf64>
+  return %0 : tensor<12xf64>
+}
+
+// CHECK-LABEL: func.func @collapse_from_sparse(
+// CHECK-SAME: %[[A:.*]]: tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<12xf64> {
+// CHECK: %[[C:.*]] = sparse_tensor.convert %[[A]] : tensor<3x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<3x4xf64>
+// CHECK: %[[R:.*]] = tensor.collapse_shape %[[C]] {{.*}} : tensor<3x4xf64> into tensor<12xf64>
+// CHECK: return %[[R]] : tensor<12xf64>
+// CHECK: }
+func.func @collapse_from_sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64>
+  return %0 : tensor<12xf64>
+}
+
+// CHECK-LABEL: func.func @collapse_to_sparse(
+// CHECK-SAME: %[[A:.*]]: tensor<3x4xf64>) -> tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>> {
+// CHECK: %[[R:.*]] = tensor.collapse_shape %[[A]] {{.*}} : tensor<3x4xf64> into tensor<12xf64>
+// CHECK: %[[C:.*]] = sparse_tensor.convert %[[R]] : tensor<12xf64> to tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[C]] : tensor<12xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: }
+func.func @collapse_to_sparse(%arg0: tensor<3x4xf64>) -> tensor<12xf64, #SparseVector> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64> into tensor<12xf64, #SparseVector>
+  return %0 : tensor<12xf64, #SparseVector>
+}
+
+// TODO: sparse2sparse collapse is not rewritten yet; only the label is checked.
+// CHECK-LABEL: func.func @collapse_sparse2sparse(
+func.func @collapse_sparse2sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64, #SparseVector>
+  return %0 : tensor<12xf64, #SparseVector>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reshape.mlir
@@ -0,0 +1,129 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed", "compressed"]
+}>
+
+//
+// Test with various forms of the two most elementary reshape
+// operations: expand/collapse.
+//
+module {
+
+  func.func @expand_dense(%arg0: tensor<12xf64>) -> tensor<3x4xf64> {
+    %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64> into tensor<3x4xf64>
+    return %0 : tensor<3x4xf64>
+  }
+
+  func.func @expand_from_sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64> {
+    %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64>
+    return %0 : tensor<3x4xf64>
+  }
+
+  func.func @expand_to_sparse(%arg0: tensor<12xf64>) -> tensor<3x4xf64, #SparseMatrix> {
+    %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64> into tensor<3x4xf64, #SparseMatrix>
+    return %0 : tensor<3x4xf64, #SparseMatrix>
+  }
+
+// TODO: enable once sparse2sparse expansion is supported.
+// func.func @expand_sparse2sparse(%arg0: tensor<12xf64, #SparseVector>) -> tensor<3x4xf64, #SparseMatrix> {
+//   %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<12xf64, #SparseVector> into tensor<3x4xf64, #SparseMatrix>
+//   return %0 : tensor<3x4xf64, #SparseMatrix>
+// }
+
+  func.func @collapse_dense(%arg0: tensor<3x4xf64>) -> tensor<12xf64> {
+    %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64> into tensor<12xf64>
+    return %0 : tensor<12xf64>
+  }
+
+  func.func @collapse_from_sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64> {
+    %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64>
+    return %0 : tensor<12xf64>
+  }
+
+  func.func @collapse_to_sparse(%arg0: tensor<3x4xf64>) -> tensor<12xf64, #SparseVector> {
+    %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64> into tensor<12xf64, #SparseVector>
+    return %0 : tensor<12xf64, #SparseVector>
+  }
+
+// TODO: enable once sparse2sparse collapse is supported.
+// func.func @collapse_sparse2sparse(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> {
+//   %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x4xf64, #SparseMatrix> into tensor<12xf64, #SparseVector>
+//   return %0 : tensor<12xf64, #SparseVector>
+// }
+
+
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %df = arith.constant -1.0 : f64
+
+    // Set up test vectors and matrices.
+    %v = arith.constant dense <[ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
+                                 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]> : tensor<12xf64>
+    %m = arith.constant dense <[ [ 1.1, 1.2, 1.3, 1.4 ],
+                                 [ 2.1, 2.2, 2.3, 2.4 ],
+                                 [ 3.1, 3.2, 3.3, 3.4 ]]> : tensor<3x4xf64>
+    %sv = sparse_tensor.convert %v : tensor<12xf64> to tensor<12xf64, #SparseVector>
+    %sm = sparse_tensor.convert %m : tensor<3x4xf64> to tensor<3x4xf64, #SparseMatrix>
+
+
+    // Call the kernels.
+    %expand0 = call @expand_dense(%v) : (tensor<12xf64>) -> tensor<3x4xf64>
+    %expand1 = call @expand_from_sparse(%sv) : (tensor<12xf64, #SparseVector>) -> tensor<3x4xf64>
+    %expand2 = call @expand_to_sparse(%v) : (tensor<12xf64>) -> tensor<3x4xf64, #SparseMatrix>
+
+    %collapse0 = call @collapse_dense(%m) : (tensor<3x4xf64>) -> tensor<12xf64>
+    %collapse1 = call @collapse_from_sparse(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64>
+    %collapse2 = call @collapse_to_sparse(%m) : (tensor<3x4xf64>) -> tensor<12xf64, #SparseVector>
+
+    //
+    // Verify result.
+    //
+    // CHECK:      ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) )
+    // CHECK-NEXT: ( ( 1, 2, 3, 4 ), ( 5, 6, 7, 8 ), ( 9, 10, 11, 12 ) )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4 )
+    // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4 )
+    // CHECK-NEXT: ( 1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4, -1, -1, -1, -1 )
+    //
+    %m0 = vector.transfer_read %expand0[%c0, %c0], %df: tensor<3x4xf64>, vector<3x4xf64>
+    vector.print %m0 : vector<3x4xf64>
+    %m1 = vector.transfer_read %expand1[%c0, %c0], %df: tensor<3x4xf64>, vector<3x4xf64>
+    vector.print %m1 : vector<3x4xf64>
+    %a2 = sparse_tensor.values %expand2 : tensor<3x4xf64, #SparseMatrix> to memref<?xf64>
+    %m2 = vector.transfer_read %a2[%c0], %df: memref<?xf64>, vector<16xf64>
+    vector.print %m2 : vector<16xf64>
+
+    %v0 = vector.transfer_read %collapse0[%c0], %df: tensor<12xf64>, vector<12xf64>
+    vector.print %v0 : vector<12xf64>
+    %v1 = vector.transfer_read %collapse1[%c0], %df: tensor<12xf64>, vector<12xf64>
+    vector.print %v1 : vector<12xf64>
+    %b2 = sparse_tensor.values %collapse2 : tensor<12xf64, #SparseVector> to memref<?xf64>
+    %v2 = vector.transfer_read %b2[%c0], %df: memref<?xf64>, vector<16xf64>
+    vector.print %v2 : vector<16xf64>
+
+    // Release sparse resources.
+    sparse_tensor.release %sv : tensor<12xf64, #SparseVector>
+    sparse_tensor.release %sm : tensor<3x4xf64, #SparseMatrix>
+    sparse_tensor.release %expand2 : tensor<3x4xf64, #SparseMatrix>
+    sparse_tensor.release %collapse2 : tensor<12xf64, #SparseVector>
+
+    // Release dense resources.
+    %meme1 = bufferization.to_memref %expand1 : memref<3x4xf64>
+    memref.dealloc %meme1 : memref<3x4xf64>
+    %memc1 = bufferization.to_memref %collapse1 : memref<12xf64>
+    memref.dealloc %memc1 : memref<12xf64>
+
+    return
+  }
+}