diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -48,14 +48,22 @@ } }; +/// Check to see if the specified operation is ConstantLike. This includes some +/// quick filters to avoid a semi-expensive test in the common case. +static inline bool isConstantLike(Operation *op) { + return op->getNumOperands() == 0 && op->getNumResults() == 1 && + op->hasTrait(); +} + /// The matcher that matches operations that have the `ConstantLike` trait. struct constant_op_matcher { - bool match(Operation *op) { return op->hasTrait(); } + bool match(Operation *op) { return isConstantLike(op); } }; /// The matcher that matches operations that have the `ConstantLike` trait, and /// binds the folded attribute value. -template struct constant_op_binder { +template +struct constant_op_binder { AttrT *bind_value; /// Creates a matcher instance that binds the constant attribute value to @@ -65,7 +73,7 @@ constant_op_binder() : bind_value(nullptr) {} bool match(Operation *op) { - if (!op->hasTrait()) + if (!isConstantLike(op)) return false; // Fold the constant to an attribute. @@ -111,7 +119,8 @@ /// The matcher that matches a given target constant scalar / vector splat / /// tensor splat integer value. -template struct constant_int_value_matcher { +template +struct constant_int_value_matcher { bool match(Operation *op) { APInt value; return constant_int_op_binder(&value).match(op) && TargetValue == value; @@ -120,7 +129,8 @@ /// The matcher that matches anything except the given target constant scalar / /// vector splat / tensor splat integer value. -template struct constant_int_not_value_matcher { +template +struct constant_int_not_value_matcher { bool match(Operation *op) { APInt value; return constant_int_op_binder(&value).match(op) && TargetNotValue != value; @@ -128,7 +138,8 @@ }; /// The matcher that matches a certain kind of op. 
-template struct op_matcher { +template +struct op_matcher { bool match(Operation *op) { return isa(op); } }; @@ -224,7 +235,8 @@ } /// Matches the given OpClass. -template inline detail::op_matcher m_Op() { +template +inline detail::op_matcher m_Op() { return detail::op_matcher(); } diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -130,8 +130,7 @@ region.walk([&](Operation *op) { // If this is a constant, process it. Attribute value; - if (op->getNumOperands() == 0 && op->getNumResults() == 1 && - matchPattern(op, m_Constant(&value))) { + if (matchPattern(op, m_Constant(&value))) { processConstant(op, value); // We may have deleted the operation, don't check it for regions. return WalkResult::advance();