diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -438,6 +438,9 @@
   LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const;
   void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                          Operation *op) const;
+  /// Returns true if `op` carries the `kLinalgTransformMarker` attribute and
+  /// it equals this filter's replacement; false if no replacement is set.
+  bool hasReplacementFilter(Operation *op) const;
 
   LinalgTransformationFilter &addFilter(FilterFunction f) {
     if (f)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -107,6 +107,20 @@
                                    rewriter.getContext()));
 }
 
+/// Returns true when `op` carries the attribute named by
+/// `LinalgTransforms::kLinalgTransformMarker` and that attribute equals this
+/// filter's `replacement`. Returns false when no replacement is configured.
+bool mlir::linalg::LinalgTransformationFilter::hasReplacementFilter(
+    Operation *op) const {
+  if (!replacement)
+    return false;
+  // getAttrOfType yields a null attribute when the marker is absent or is not
+  // a StringAttr, so an unmarked op is handled without casting a null value.
+  auto attr = op->getAttrOfType<StringAttr>(
+      LinalgTransforms::kLinalgTransformMarker);
+  return attr && attr == replacement.getValue();
+}
+
 LinalgTilingOptions &
 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
   assert(!tileSizeComputationFunction && "tile sizes already set");