diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -56,6 +56,7 @@
     private:
       llvm::StringMap<RegionBuilderFunction> namedStructuredOpRegionBuilders;
   }];
+  let hasCanonicalizer = 1;
 }
 
 // Whether a type is a RangeType.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -336,6 +336,53 @@
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// LinalgOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Rewrite a memref::DimOp on a LinalgOp result to a memref::DimOp on the
+/// LinalgOp's respective output. E.g.:
+/// ```
+/// %0 = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
+///                    outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
+/// %r = memref.dim %0, %c0 : tensor<?x?xf32>
+/// ```
+/// is rewritten to:
+/// ```
+/// %0 = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
+///                    outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
+/// %r = memref.dim %c, %c0 : tensor<?x?xf32>
+/// ```
+/// The rewrite relies on each LinalgOp result having the same shape as its
+/// corresponding output operand.
+struct DimOfLinalgResultCanonicalization : OpRewritePattern<memref::DimOp> {
+  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dimOp.memrefOrTensor().getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      return failure();
+
+    // The value has a defining op, so it is an OpResult and the cast is safe.
+    // TODO(gysit): Use LinalgOp interface function that returns an OpOperand&
+    // corresponding to the given OpResult.
+    auto idx = dimOp.memrefOrTensor().cast<OpResult>().getResultNumber();
+    rewriter.replaceOpWithNewOp<memref::DimOp>(dimOp, linalgOp.getOutput(idx),
+                                               dimOp.index());
+    return success();
+  }
+};
+} // namespace
+
+/// Populates the dialect-level canonicalization patterns; enabled by
+/// `hasCanonicalizer = 1` on the Linalg dialect definition.
+void LinalgDialect::getCanonicalizationPatterns(
+    RewritePatternSet &results) const {
+  results.add<DimOfLinalgResultCanonicalization>(this->getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -199,7 +199,7 @@
 
   // Recover the subtensor out of the new static results. This keeps the
   // original linalg op around because it uses the dims of the original results.
-  // This later folds away.
+  // This is canonicalized away by DimOfLinalgResultCanonicalization.
   SmallVector<Value> paddedSubviewResults;
   paddedSubviewResults.reserve(opToPad->getNumResults());
   SetVector<Operation *> newUsersOfOpToPad;