diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -142,6 +142,9 @@
 LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
                                       SmallVector<Value> &result);
 
+/// Tests if types are the same when ignoring encoding on ranked tensors.
+bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
+
 /// Function to control the folding of constant and extract slice.
 using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -449,8 +449,6 @@
   return nullptr;
 }
 
-/// Returns true iff the given sparse tensor encoding attribute has a trailing
-/// COO region starting at the given level.
 bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                     Level startLvl, bool isUnique) {
   if (!enc ||
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -346,6 +346,45 @@
   }
 };
 
+// Fuse a tensor cast into its producing operation. Note that a tensor.cast
+// should really not be used to convert between sparse encodings. Since
+// the pattern currently appears as a result of some prior rewriting,
+// we make an attempt to repair very obvious cases.
+// TODO: audit the pure tensor dialect rewriting rules
+struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
+public:
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    Type srcType = op.getSource().getType();
+    Type dstType = op.getDest().getType();
+    // A nop cast simply folds away.
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, op->getResults());
+      return success();
+    }
+    // See if a sparsity changing cast can be fused into its producer.
+    if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
+      if (Operation *def = op.getSource().getDefiningOp()) {
+        if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
+          def->getResult(0).setType(op->getResultTypes()[0]);
+          rewriter.replaceOp(op, def->getResult(0));
+          return success();
+        }
+      }
+    }
+    // Repair tensor casts with at least one sparse operand into the
+    // properly supported sparse_tensor.convert.
+    if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
+      rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
+      return success();
+    }
+    // Fail otherwise.
+    return failure();
+  }
+};
+
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
 template <typename ReshapeOp>
 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
@@ -1125,7 +1164,7 @@
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd>(
+  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast>(
       patterns.getContext());
 }
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -110,6 +110,16 @@
   return success();
 }
 
+bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
+  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
+    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
+      return rtp1.getShape() == rtp2.getShape() &&
+             rtp1.getElementType() == rtp2.getElementType();
+    return false;
+  }
+  return tp1 == tp2; // default implementation
+}
+
 /// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
 /// rank-extending tensor.insert_slice op.
 static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
@@ -1343,18 +1353,6 @@
       getReassociationIndicesAttribute(b, reassociation));
 }
 
-// Checks if types are the same, but ignoring encoding on ranked tensors.
-static bool isSameTypesWithoutEncoding(Type tp1, Type tp2) {
-  if (auto rtp1 = tp1.dyn_cast<RankedTensorType>()) {
-    if (auto rtp2 = tp2.dyn_cast<RankedTensorType>())
-      return rtp1.getShape() == rtp2.getShape() &&
-             rtp1.getElementType() == rtp2.getElementType();
-    return false;
-  }
-  // Default implementation.
-  return tp1 == tp2;
-}
-
 template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                         TensorReshapeOp, ExpandShapeOp>::value>
 static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
@@ -1367,7 +1365,7 @@
   auto maps = op.getReassociationMaps();
   RankedTensorType expectedType =
       CollapseShapeOp::inferCollapsedType(expandedType, maps);
-  if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
+  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
     return op.emitOpError("expected collapsed type to be ")
            << expectedType << ", but got " << collapsedType;
   return success();
diff --git a/mlir/test/Dialect/SparseTensor/rewriting.mlir b/mlir/test/Dialect/SparseTensor/post_rewriting.mlir
old mode 100755
new mode 100644
rename from mlir/test/Dialect/SparseTensor/rewriting.mlir
rename to mlir/test/Dialect/SparseTensor/post_rewriting.mlir
diff --git a/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/pre_rewriting.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -pre-sparsification-rewrite | FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed"]
+}>
+
+#SortedCOO = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ]
+}>
+
+#Slice = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ],
+  slice = [ (?, 1, 1), (?, 3, 1) ]
+}>
+
+// CHECK-LABEL: func @sparse_nop_cast(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK: return %[[A]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_nop_cast(%a : tensor<?xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %a : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  %1 = tensor.cast %0 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  %2 = tensor.cast %1 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  return %2 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_repair_cast(
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>)
+// CHECK: %[[C:.*]] = sparse_tensor.convert %[[A]] : tensor<?xf32> to tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[C]] : tensor<?xf32, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_repair_cast(%a : tensor<?xf32>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %a : tensor<?xf32> to tensor<?xf32, #SparseVector>
+  return %0 : tensor<?xf32, #SparseVector>
+}
+
+// CHECK-LABEL: func @sparse_fuse_slice(
+// CHECK-SAME: %[[A:.*]]: tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>>)
+// CHECK: %[[E:.*]] = tensor.extract_slice %[[A]][1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: %[[C:.*]] = sparse_tensor.convert %[[E]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>> to tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[C]] : tensor<1x3xi64, #sparse_tensor.encoding<{{{.*}}}>>
+func.func @sparse_fuse_slice(%a : tensor<2x3xi64, #SortedCOO>) -> tensor<1x3xi64, #SortedCOO> {
+  %extracted_slice = tensor.extract_slice %a[1, 0] [1, 3] [1, 1] : tensor<2x3xi64, #SortedCOO> to tensor<1x3xi64>
+  %cast = tensor.cast %extracted_slice : tensor<1x3xi64> to tensor<1x3xi64, #Slice>
+  %0 = sparse_tensor.convert %cast : tensor<1x3xi64, #Slice> to tensor<1x3xi64, #SortedCOO>
+  return %0 : tensor<1x3xi64, #SortedCOO>
+}
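
Note: the new FuseTensorCast pattern is only exposed through populatePreSparsificationRewriting, which the pre_rewriting.mlir test above exercises via the -pre-sparsification-rewrite pass. For reference, a minimal sketch of driving the same pattern set programmatically with the greedy rewrite driver; the wrapper function runPreSparsificationRewriting is illustrative only and not part of this patch:

  #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Illustrative sketch (not part of this patch): applies the pre-sparsification
  // rewrite patterns, including FuseTensorCast, to a module using the greedy
  // pattern rewrite driver.
  static mlir::LogicalResult runPreSparsificationRewriting(mlir::ModuleOp module) {
    mlir::RewritePatternSet patterns(module.getContext());
    mlir::populatePreSparsificationRewriting(patterns);
    return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
  }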