Summary of this patch (free text before the first diff header; git apply / patch skip it).

mlir/include/mlir/IR/Matchers.h:
  * Includes mlir/IR/BuiltinAttributes.h.
  * attr_value_binder::match now takes `Attribute` by value instead of `const Attribute &`.
  * Renames detail::constant_float_op_binder to constant_float_value_binder.
  * Renames detail::constant_int_op_binder to constant_int_value_binder.
  * Both binders gain a match(Attribute) overload. It binds either a plain scalar attribute or the splat value of a splat elements attribute.
  * match(Operation *) in both binders now checks the result type and delegates to match(Attribute).
  * constant_float_predicate_matcher and constant_int_predicate_matcher gain match(Attribute) overloads.
  * Renames has_operation_or_value_matcher_t to has_compatible_matcher_t (Value, Operation, or Attribute target).
  * Adds an entry point matchPattern(Attribute, Pattern). It static_asserts that the pattern can match Attributes and returns false when the attribute is null.
  * The existing matchPattern(Value, ...) and matchPattern(Operation *, ...) entry points now assert a non-null argument.
  * m_ConstantFloat / m_ConstantInt return the renamed binder types.

mlir/lib/Dialect/Arith/IR/ArithOps.cpp:
  * getIntOrSplatIntValue is reimplemented with m_ConstantInt.
  * Folders now match on fold-adaptor attributes (adaptor.getRhs(), adaptor.getCondition(), adaptor.getTrueValue(), adaptor.getFalseValue()) instead of the operand Values.
  * In OrIOp and the MaxSI/MaxUI/MinSI/MinUI folders, m_ConstantInt binds the RHS constant once and then tests it; the 0 and all-ones cases of OrIOp now go through that path.
  * The muli(x, 1) fold returns getLhs() rather than getOperand(0).

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp:
  * Includes mlir/IR/Matchers.h.
  * Drops the static extractIntConstant helper.
  * BitwiseAndOp::fold and BitwiseOrOp::fold use matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask)) instead.

NOTE(review): This copy of the patch is damaged and will not apply with `git apply` or `patch`:
  * All newlines inside each file section have been collapsed.
  * Every template-argument list has been stripped, e.g. `llvm::dyn_cast(attr)` and `std::enable_if_t::value`.
  * Please regenerate it from the originating commit rather than hand-repairing it.

NOTE(review): The SPIR-V context line reads `// spirv.BtiwiseAndOp`. The typo lives in the target file, not in this patch; fix it in a follow-up change.

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -15,6 +15,7 @@ #ifndef MLIR_IR_MATCHERS_H #define MLIR_IR_MATCHERS_H +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" @@ -38,7 +39,7 @@ /// Creates a matcher instance that binds the value to bv if match succeeds. attr_value_binder(ValueType *bv) : bind_value(bv) {} - bool match(const Attribute &attr) { + bool match(Attribute attr) { if (auto intAttr = llvm::dyn_cast(attr)) { *bind_value = intAttr.getValue(); return true; @@ -123,27 +124,33 @@ }; /// The matcher that matches a constant scalar / vector splat / tensor splat -/// float operation and binds the constant float value. -struct constant_float_op_binder { +/// float Attribute or Operation and binds the constant float value. +struct constant_float_value_binder { FloatAttr::ValueType *bind_value; /// Creates a matcher instance that binds the value to bv if match succeeds.
- constant_float_op_binder(FloatAttr::ValueType *bv) : bind_value(bv) {} + constant_float_value_binder(FloatAttr::ValueType *bv) : bind_value(bv) {} + + bool match(Attribute attr) { + attr_value_binder matcher(bind_value); + if (matcher.match(attr)) + return true; + + if (auto splatAttr = dyn_cast(attr)) + return matcher.match(splatAttr.getSplatValue()); + + return false; + } bool match(Operation *op) { Attribute attr; if (!constant_op_binder(&attr).match(op)) return false; - auto type = op->getResult(0).getType(); - - if (llvm::isa(type)) - return attr_value_binder(bind_value).match(attr); - if (llvm::isa(type)) { - if (auto splatAttr = llvm::dyn_cast(attr)) { - return attr_value_binder(bind_value) - .match(splatAttr.getSplatValue()); - } - } + + Type type = op->getResult(0).getType(); + if (isa(type)) + return match(attr); + return false; } }; @@ -153,34 +160,45 @@ struct constant_float_predicate_matcher { bool (*predicate)(const APFloat &); + bool match(Attribute attr) { + APFloat value(APFloat::Bogus()); + return constant_float_value_binder(&value).match(attr) && predicate(value); + } + bool match(Operation *op) { APFloat value(APFloat::Bogus()); - return constant_float_op_binder(&value).match(op) && predicate(value); + return constant_float_value_binder(&value).match(op) && predicate(value); } }; /// The matcher that matches a constant scalar / vector splat / tensor splat -/// integer operation and binds the constant integer value. -struct constant_int_op_binder { +/// integer Attribute or Operation and binds the constant integer value. +struct constant_int_value_binder { IntegerAttr::ValueType *bind_value; /// Creates a matcher instance that binds the value to bv if match succeeds.
- constant_int_op_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {} + constant_int_value_binder(IntegerAttr::ValueType *bv) : bind_value(bv) {} + + bool match(Attribute attr) { + attr_value_binder matcher(bind_value); + if (matcher.match(attr)) + return true; + + if (auto splatAttr = dyn_cast(attr)) + return matcher.match(splatAttr.getSplatValue()); + + return false; + } bool match(Operation *op) { Attribute attr; if (!constant_op_binder(&attr).match(op)) return false; - auto type = op->getResult(0).getType(); - - if (llvm::isa(type)) - return attr_value_binder(bind_value).match(attr); - if (llvm::isa(type)) { - if (auto splatAttr = llvm::dyn_cast(attr)) { - return attr_value_binder(bind_value) - .match(splatAttr.getSplatValue()); - } - } + + Type type = op->getResult(0).getType(); + if (isa(type)) + return match(attr); + return false; } }; @@ -190,9 +208,14 @@ struct constant_int_predicate_matcher { bool (*predicate)(const APInt &); + bool match(Attribute attr) { + APInt value; + return constant_int_value_binder(&value).match(attr) && predicate(value); + } + bool match(Operation *op) { APInt value; - return constant_int_op_binder(&value).match(op) && predicate(value); + return constant_int_value_binder(&value).match(op) && predicate(value); } }; @@ -203,14 +226,14 @@ }; /// Trait to check whether T provides a 'match' method with type -/// `OperationOrValue`. -template -using has_operation_or_value_matcher_t = - decltype(std::declval().match(std::declval())); +/// `MatchTarget` (Value, Operation, or Attribute). +template +using has_compatible_matcher_t = + decltype(std::declval().match(std::declval())); /// Statically switch to a Value matcher. template -std::enable_if_t::value, bool> matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { @@ -219,7 +242,7 @@ /// Statically switch to an Operation matcher.
template -std::enable_if_t::value, bool> matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) { @@ -376,6 +399,7 @@ /// Entry point for matching a pattern over a Value. template inline bool matchPattern(Value value, const Pattern &pattern) { + assert(value); // TODO: handle other cases if (auto *op = value.getDefiningOp()) return const_cast(pattern).match(op); @@ -385,21 +409,34 @@ /// Entry point for matching a pattern over an Operation. template inline bool matchPattern(Operation *op, const Pattern &pattern) { + assert(op); return const_cast(pattern).match(op); } +/// Entry point for matching a pattern over an Attribute. Returns `false` +/// when `attr` is null. +template +inline bool matchPattern(Attribute attr, const Pattern &pattern) { + static_assert(llvm::is_detected::value, + "Pattern does not support matching Attributes"); + if (!attr) + return false; + return const_cast(pattern).match(attr); +} + /// Matches a constant holding a scalar/vector/tensor float (splat) and /// writes the float value to bind_value. -inline detail::constant_float_op_binder +inline detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value) { - return detail::constant_float_op_binder(bind_value); + return detail::constant_float_value_binder(bind_value); } /// Matches a constant holding a scalar/vector/tensor integer (splat) and /// writes the integer value to bind_value.
-inline detail::constant_int_op_binder +inline detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value) { - return detail::constant_int_op_binder(bind_value); + return detail::constant_int_value_binder(bind_value); } template diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -106,12 +106,9 @@ } static FailureOr getIntOrSplatIntValue(Attribute attr) { - if (auto intAttr = llvm::dyn_cast(attr)) - return intAttr.getValue(); - - if (auto splatAttr = llvm::dyn_cast(attr)) - if (llvm::isa(splatAttr.getElementType())) - return splatAttr.getSplatValue(); + APInt value; + if (matchPattern(attr, m_ConstantInt(&value))) + return value; return failure(); } @@ -258,7 +255,7 @@ OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) { // addi(x, 0) -> x - if (matchPattern(getRhs(), m_Zero())) + if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); // addi(subi(a, b), b) -> a @@ -349,7 +346,7 @@ if (getOperand(0) == getOperand(1)) return Builder(getContext()).getZeroAttr(getType()); // subi(x,0) -> x - if (matchPattern(getRhs(), m_Zero())) + if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); if (auto add = getLhs().getDefiningOp()) { @@ -379,11 +376,11 @@ OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) { // muli(x, 0) -> 0 - if (matchPattern(getRhs(), m_Zero())) + if (matchPattern(adaptor.getRhs(), m_Zero())) return getRhs(); // muli(x, 1) -> x - if (matchPattern(getRhs(), m_One())) - return getOperand(0); + if (matchPattern(adaptor.getRhs(), m_One())) + return getLhs(); // TODO: Handle the overflow case.
// default folder @@ -412,7 +409,7 @@ arith::MulSIExtendedOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { // mulsi_extended(x, 0) -> 0, 0 - if (matchPattern(getRhs(), m_Zero())) { + if (matchPattern(adaptor.getRhs(), m_Zero())) { Attribute zero = adaptor.getRhs(); results.push_back(zero); results.push_back(zero); @@ -460,7 +457,7 @@ arith::MulUIExtendedOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { // mului_extended(x, 0) -> 0, 0 - if (matchPattern(getRhs(), m_Zero())) { + if (matchPattern(adaptor.getRhs(), m_Zero())) { Attribute zero = adaptor.getRhs(); results.push_back(zero); results.push_back(zero); @@ -468,7 +465,7 @@ } // mului_extended(x, 1) -> x, 0 - if (matchPattern(getRhs(), m_One())) { + if (matchPattern(adaptor.getRhs(), m_One())) { Builder builder(getContext()); Attribute zero = builder.getZeroAttr(getLhs().getType()); results.push_back(getLhs()); @@ -508,7 +505,7 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) { // divui (x, 1) -> x. - if (matchPattern(getRhs(), m_One())) + if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); // Don't fold if it would require a division by zero. @@ -537,7 +534,7 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) { // divsi (x, 1) -> x. - if (matchPattern(getRhs(), m_One())) + if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); // Don't fold if it would overflow or if it requires a division by zero. @@ -584,7 +581,7 @@ OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) { // ceildivui (x, 1) -> x. - if (matchPattern(getRhs(), m_One())) + if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); bool overflowOrDiv0 = false; @@ -616,7 +613,7 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) { // ceildivsi (x, 1) -> x. - if (matchPattern(getRhs(), m_One())) + if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); // Don't fold if it would overflow or if it requires a division by zero.
@@ -677,7 +674,7 @@ OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) { // floordivsi (x, 1) -> x. - if (matchPattern(getRhs(), m_One())) + if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); // Don't fold if it would overflow or if it requires a division by zero. @@ -726,7 +723,7 @@ OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) { // remui (x, 1) -> 0. - if (matchPattern(getRhs(), m_One())) + if (matchPattern(adaptor.getRhs(), m_One())) return Builder(getContext()).getZeroAttr(getType()); // Don't fold if it would require a division by zero. @@ -749,7 +746,7 @@ OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) { // remsi (x, 1) -> 0. - if (matchPattern(getRhs(), m_One())) + if (matchPattern(adaptor.getRhs(), m_One())) return Builder(getContext()).getZeroAttr(getType()); // Don't fold if it would require a division by zero. @@ -789,11 +786,12 @@ OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) { /// and(x, 0) -> 0 - if (matchPattern(getRhs(), m_Zero())) + if (matchPattern(adaptor.getRhs(), m_Zero())) return getRhs(); /// and(x, allOnes) -> x APInt intValue; - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) + if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) && + intValue.isAllOnes()) return getLhs(); /// and(x, not(x)) -> 0 if (matchPattern(getRhs(), m_Op(matchers::m_Val(getLhs()), @@ -820,13 +818,14 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) { - /// or(x, 0) -> x - if (matchPattern(getRhs(), m_Zero())) - return getLhs(); - /// or(x, ) -> - if (auto rhsAttr = llvm::dyn_cast_if_present(adaptor.getRhs())) - if (rhsAttr.getValue().isAllOnes()) - return rhsAttr; + if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) { + /// or(x, 0) -> x + if (rhsVal.isZero()) + return getLhs(); + /// or(x, ) -> + if (rhsVal.isAllOnes()) + return adaptor.getRhs(); + } APInt intValue; /// or(x,
xor(x, 1)) -> 1 @@ -851,7 +850,7 @@ OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) { /// xor(x, 0) -> x - if (matchPattern(getRhs(), m_Zero())) + if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); /// xor(x, x) -> 0 if (getLhs() == getRhs()) @@ -901,7 +900,7 @@ OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) { // addf(x, -0) -> x - if (matchPattern(getRhs(), m_NegZeroFloat())) + if (matchPattern(adaptor.getRhs(), m_NegZeroFloat())) return getLhs(); return constFoldBinaryOp( @@ -915,7 +914,7 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) { // subf(x, +0) -> x - if (matchPattern(getRhs(), m_PosZeroFloat())) + if (matchPattern(adaptor.getRhs(), m_PosZeroFloat())) return getLhs(); return constFoldBinaryOp( @@ -933,7 +932,7 @@ return getRhs(); // maxf(x, -inf) -> x - if (matchPattern(getRhs(), m_NegInfFloat())) + if (matchPattern(adaptor.getRhs(), m_NegInfFloat())) return getLhs(); return constFoldBinaryOp( @@ -950,16 +949,15 @@ if (getLhs() == getRhs()) return getRhs(); - APInt intValue; - // maxsi(x,MAX_INT) -> MAX_INT - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && - intValue.isMaxSignedValue()) - return getRhs(); - - // maxsi(x, MIN_INT) -> x - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && - intValue.isMinSignedValue()) - return getLhs(); + if (APInt intValue; + matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { + // maxsi(x,MAX_INT) -> MAX_INT + if (intValue.isMaxSignedValue()) + return getRhs(); + // maxsi(x, MIN_INT) -> x + if (intValue.isMinSignedValue()) + return getLhs(); + } return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { @@ -976,14 +974,15 @@ if (getLhs() == getRhs()) return getRhs(); - APInt intValue; - // maxui(x,MAX_INT) -> MAX_INT - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) - return getRhs(); - - // maxui(x, MIN_INT) -> x - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) - return
getLhs(); + if (APInt intValue; + matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { + // maxui(x,MAX_INT) -> MAX_INT + if (intValue.isMaxValue()) + return getRhs(); + // maxui(x, MIN_INT) -> x + if (intValue.isMinValue()) + return getLhs(); + } return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { @@ -1001,7 +1000,7 @@ return getRhs(); // minf(x, +inf) -> x - if (matchPattern(getRhs(), m_PosInfFloat())) + if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) return getLhs(); return constFoldBinaryOp( @@ -1018,16 +1017,15 @@ if (getLhs() == getRhs()) return getRhs(); - APInt intValue; - // minsi(x,MIN_INT) -> MIN_INT - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && - intValue.isMinSignedValue()) - return getRhs(); - - // minsi(x, MAX_INT) -> x - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && - intValue.isMaxSignedValue()) - return getLhs(); + if (APInt intValue; + matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { + // minsi(x,MIN_INT) -> MIN_INT + if (intValue.isMinSignedValue()) + return getRhs(); + // minsi(x, MAX_INT) -> x + if (intValue.isMaxSignedValue()) + return getLhs(); + } return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { @@ -1044,14 +1042,15 @@ if (getLhs() == getRhs()) return getRhs(); - APInt intValue; - // minui(x,MIN_INT) -> MIN_INT - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) - return getRhs(); - - // minui(x, MAX_INT) -> x - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) - return getLhs(); + if (APInt intValue; + matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { + // minui(x,MIN_INT) -> MIN_INT + if (intValue.isMinValue()) + return getRhs(); + // minui(x, MAX_INT) -> x + if (intValue.isMaxValue()) + return getLhs(); + } return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { @@ -1065,7 +1064,7 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor
adaptor) { // mulf(x, 1) -> x - if (matchPattern(getRhs(), m_OneFloat())) + if (matchPattern(adaptor.getRhs(), m_OneFloat())) return getLhs(); return constFoldBinaryOp( @@ -1084,7 +1083,7 @@ OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) { // divf(x, 1) -> x - if (matchPattern(getRhs(), m_OneFloat())) + if (matchPattern(adaptor.getRhs(), m_OneFloat())) return getLhs(); return constFoldBinaryOp( @@ -1685,7 +1684,7 @@ return getBoolAttribute(getType(), getContext(), val); } - if (matchPattern(getRhs(), m_Zero())) { + if (matchPattern(adaptor.getRhs(), m_Zero())) { if (auto extOp = getLhs().getDefiningOp()) { // extsi(%x : i1 -> iN) != 0 -> %x std::optional integerWidth = @@ -2188,11 +2187,11 @@ Value condition = getCondition(); // select true, %0, %1 => %0 - if (matchPattern(condition, m_One())) + if (matchPattern(adaptor.getCondition(), m_One())) return trueVal; // select false, %0, %1 => %1 - if (matchPattern(condition, m_Zero())) + if (matchPattern(adaptor.getCondition(), m_Zero())) return falseVal; // If either operand is fully poisoned, return the other. @@ -2203,8 +2202,8 @@ return trueVal; // select %x, true, false => %x - if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) && - matchPattern(getFalseValue(), m_Zero())) + if (getType().isInteger(1) && matchPattern(adaptor.getTrueValue(), m_One()) && + matchPattern(adaptor.getFalseValue(), m_Zero())) return condition; if (auto cmp = dyn_cast_or_null(condition.getDefiningOp())) { @@ -2313,7 +2312,7 @@ OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) { // shli(x, 0) -> x - if (matchPattern(getRhs(), m_Zero())) + if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); // Don't fold if shifting more than the bit width.
bool bounded = false; @@ -2331,7 +2330,7 @@ OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) { // shrui(x, 0) -> x - if (matchPattern(getRhs(), m_Zero())) + if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; @@ -2349,7 +2348,7 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) { // shrsi(x, 0) -> x - if (matchPattern(getRhs(), m_Zero())) + if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" @@ -1966,27 +1967,12 @@ // spirv.BtiwiseAndOp //===----------------------------------------------------------------------===// -static std::optional extractIntConstant(Attribute attr) { - IntegerAttr intAttr; - if (auto splat = dyn_cast_if_present(attr)) - intAttr = dyn_cast(splat.getSplatValue()); - else - intAttr = dyn_cast_if_present(attr); - - if (!intAttr) - return std::nullopt; - - return intAttr.getValue(); -} - OpFoldResult spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) { - std::optional rhsVal = extractIntConstant(adaptor.getOperand2()); - if (!rhsVal) + APInt rhsMask; + if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) return {}; - APInt rhsMask = *rhsVal; - // x & 0 -> 0 if (rhsMask.isZero()) return getOperand2(); @@ -2011,12 +1997,10 @@ //===----------------------------------------------------------------------===// OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) { - std::optional
rhsVal = extractIntConstant(adaptor.getOperand2()); - if (!rhsVal) + APInt rhsMask; + if (!matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) return {}; - APInt rhsMask = *rhsVal; - // x | 0 -> x if (rhsMask.isZero()) return getOperand1();