diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -448,11 +448,21 @@
     return addFilter(
         [](Operation *op) { return success(isa(op)); });
   }
+
+  /// Treat operations that do not carry the filter attribute as a match, even
+  /// when `matchDisjunction` is non-empty. Returns `*this` for chaining.
+  LinalgTransformationFilter &setMatchByDefault() {
+    matchByDefault = true;
+    return *this;
+  }
 
 private:
   SmallVector filters;
   SmallVector matchDisjunction;
   Optional replacement;
+  /// When true, an operation without the filter attribute is treated as a
+  /// match. Defaults to false; enable it with setMatchByDefault().
+  bool matchByDefault = false;
 };
 
 using TileSizeComputationFunction =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -74,7 +74,8 @@
 
   if (!attr) {
-    // 1. Has no filter case and matchDisjunction is empty.
-    if (matchDisjunction.empty())
+    // 1. Has no filter case and either matchDisjunction is empty or the
+    //    filter was configured to match by default.
+    if (matchDisjunction.empty() || matchByDefault)
       return success();
 
     // 2. Has no filter but was expecting a filter.