diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -31,11 +31,14 @@
     "apply_patterns.tensor.fold_tensor_empty",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that reassociative reshapes (tensor.collapse_shape /
-    tensor.expand_shape) should be folded with inverse rank expansions / rank
-    reductions (via tensor.insert_slice / tensor.extract_slice).
+    Indicates that tensor.extract_slice and reassociative reshapes should be
+    folded into tensor.empty.
+
+    If `fold_single_use_only` is set to "true", only tensor.empty that have a
+    single use are folded.
   }];
+  let arguments = (ins DefaultValuedAttr<BoolAttr, "false">:$fold_single_use_only);
   let assemblyFormat = "attr-dict";
 }
 
 def ApplyFoldIntoPackAndUnpackPatternsOp : Op<Transform_Dialect,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -20,16 +20,29 @@
 template <typename ReshapeOp>
 struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
-  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+  FoldEmptyTensorWithReshapeOp(MLIRContext *ctx, PatternBenefit benefit = 1,
+                               bool foldSingleUseOnly = false)
+      : OpRewritePattern<ReshapeOp>(ctx, benefit),
+        foldSingleUseOnly(foldSingleUseOnly) {}
 
   LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
                                 PatternRewriter &rewriter) const override {
-    if (!reshapeOp.getSrc().template getDefiningOp<EmptyOp>())
+    // Check for tensor.empty source.
+    auto emptyOp = reshapeOp.getSrc().template getDefiningOp<EmptyOp>();
+    if (!emptyOp)
       return failure();
+
+    // Check for single use.
+    if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
+      return failure();
+
+    // Reify result shape.
     Location loc = reshapeOp.getLoc();
     ReifiedRankedShapedTypeDims resultShapes;
     if (failed(reifyResultShapes(rewriter, reshapeOp, resultShapes)) ||
         !llvm::hasSingleElement(resultShapes))
       return failure();
+
+    // Create new tensor.empty op.
     // TODO: Do not drop tensor type encoding.
     Value emptyTensor = rewriter.create<EmptyOp>(
         loc, resultShapes[0], reshapeOp.getResultType().getElementType());
@@ -40,21 +53,34 @@
     }
     return success();
   }
+
+private:
+  bool foldSingleUseOnly = false;
 };
 
-/// `tensor.empty` does not define any tensor contents, so a slice of a
-/// `tensor.empty` can be canonicalized to a smaller `tensor.empty`.
+/// tensor.empty does not define any tensor contents, so a slice of a
+/// tensor.empty can be folded to a smaller tensor.empty.
 struct FoldEmptyTensorWithExtractSliceOp
     : public OpRewritePattern<ExtractSliceOp> {
-  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
+  FoldEmptyTensorWithExtractSliceOp(MLIRContext *ctx,
+                                    PatternBenefit benefit = 1,
+                                    bool foldSingleUseOnly = false)
+      : OpRewritePattern<ExtractSliceOp>(ctx, benefit),
+        foldSingleUseOnly(foldSingleUseOnly) {}
 
   LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                 PatternRewriter &rewriter) const override {
-    if (!sliceOp.getSource().getDefiningOp<EmptyOp>())
+    // Check for tensor.empty source.
+    auto emptyOp = sliceOp.getSource().template getDefiningOp<EmptyOp>();
+    if (!emptyOp)
+      return failure();
+
+    // Check for single use.
+    if (foldSingleUseOnly && !llvm::hasSingleElement(emptyOp->getUses()))
       return failure();
 
-    // ExtractSliceOp may be rank-reducing; its dynamic sizes must be
-    // preserved as well as its result type.
+    // Create new tensor.empty op. tensor.extract_slice may be rank-reducing;
+    // its dynamic sizes must be preserved as well as its result type.
     auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
                                            sliceOp.getType().getElementType(),
                                            sliceOp.getType().getEncoding());
@@ -62,14 +88,17 @@
                                                 sliceOp.getSizes());
     return success();
   }
+
+private:
+  bool foldSingleUseOnly = false;
 };
 
 } // namespace
 
-void mlir::tensor::populateFoldTensorEmptyPatterns(
-    RewritePatternSet &patterns) {
+void mlir::tensor::populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
+                                                   bool foldSingleUseOnly) {
   patterns.add<FoldEmptyTensorWithExtractSliceOp,
                FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>,
                FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>>(
-      patterns.getContext());
+      patterns.getContext(), foldSingleUseOnly);
 }