diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td @@ -24,6 +24,7 @@ let hasConstantMaterializer = 1; let useDefaultAttributePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } // The predicate indicates the type of the comparison to perform: diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -164,9 +164,7 @@ return value.isa(); } -OpFoldResult arith::ConstantOp::fold(ArrayRef operands) { - return getValue(); -} +OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width) { @@ -217,7 +215,7 @@ // AddIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::AddIOp::fold(ArrayRef operands) { +OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) { // addi(x, 0) -> x if (matchPattern(getRhs(), m_Zero())) return getLhs(); @@ -233,7 +231,8 @@ return sub.getLhs(); return constFoldBinaryOp( - operands, [](APInt a, const APInt &b) { return std::move(a) + b; }); + adaptor.getOperands(), + [](APInt a, const APInt &b) { return std::move(a) + b; }); } void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -260,7 +259,7 @@ } LogicalResult -arith::AddUIExtendedOp::fold(ArrayRef operands, +arith::AddUIExtendedOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { Type overflowTy = getOverflow().getType(); // addui_extended(x, 0) -> x, false @@ -280,21 +279,22 @@ // `constFoldBinaryOp` again to calculate the overflow bit because the // constructed attribute is of the same element type as both operands. if (Attribute sumAttr = constFoldBinaryOp( - operands, [](APInt a, const APInt &b) { return std::move(a) + b; })) { + adaptor.getOperands(), + [](APInt a, const APInt &b) { return std::move(a) + b; })) { Attribute overflowAttr; - if (auto lhs = operands[0].dyn_cast()) { + if (auto lhs = adaptor.getLhs().dyn_cast()) { // Both arguments are scalars, calculate the scalar overflow value. auto sum = sumAttr.cast(); overflowAttr = IntegerAttr::get( overflowTy, calculateUnsignedOverflow(sum.getValue(), lhs.getValue())); - } else if (auto lhs = operands[0].dyn_cast()) { + } else if (auto lhs = adaptor.getLhs().dyn_cast()) { // Both arguments are splats, calculate the splat overflow value. auto sum = sumAttr.cast(); APInt overflow = calculateUnsignedOverflow(sum.getSplatValue(), lhs.getSplatValue()); overflowAttr = SplatElementsAttr::get(overflowTy, overflow); - } else if (auto lhs = operands[0].dyn_cast()) { + } else if (auto lhs = adaptor.getLhs().dyn_cast()) { // Othwerwise calculate element-wise overflow values. auto sum = sumAttr.cast(); const auto numElems = static_cast(sum.getNumElements()); @@ -328,7 +328,7 @@ // SubIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::SubIOp::fold(ArrayRef operands) { +OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) { // subi(x,x) -> 0 if (getOperand(0) == getOperand(1)) return Builder(getContext()).getZeroAttr(getType()); @@ -346,7 +346,8 @@ } return constFoldBinaryOp( - operands, [](APInt a, const APInt &b) { return std::move(a) - b; }); + adaptor.getOperands(), + [](APInt a, const APInt &b) { return std::move(a) - b; }); } void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -360,7 +361,7 @@ // MulIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::MulIOp::fold(ArrayRef operands) { +OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) { // muli(x, 0) -> 0 if (matchPattern(getRhs(), m_Zero())) return getRhs(); @@ -371,7 +372,8 @@ // default folder return constFoldBinaryOp( - operands, [](const APInt &a, const APInt &b) { return a * b; }); + adaptor.getOperands(), + [](const APInt &a, const APInt &b) { return a * b; }); } //===----------------------------------------------------------------------===// @@ -386,11 +388,11 @@ } LogicalResult -arith::MulSIExtendedOp::fold(ArrayRef operands, +arith::MulSIExtendedOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { // mulsi_extended(x, 0) -> 0, 0 if (matchPattern(getRhs(), m_Zero())) { - Attribute zero = operands[1]; + Attribute zero = adaptor.getRhs(); results.push_back(zero); results.push_back(zero); return success(); @@ -398,10 +400,11 @@ // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high if (Attribute lowAttr = constFoldBinaryOp( - operands, [](const APInt &a, const APInt &b) { return a * b; })) { + adaptor.getOperands(), + [](const APInt &a, const APInt &b) { return a * b; })) { // Invoke the constant fold helper again to calculate the 'high' result. Attribute highAttr = constFoldBinaryOp( - operands, [](const APInt &a, const APInt &b) { + adaptor.getOperands(), [](const APInt &a, const APInt &b) { unsigned bitWidth = a.getBitWidth(); APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2); return fullProduct.extractBits(bitWidth, bitWidth); @@ -433,11 +436,11 @@ } LogicalResult -arith::MulUIExtendedOp::fold(ArrayRef operands, +arith::MulUIExtendedOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { // mului_extended(x, 0) -> 0, 0 if (matchPattern(getRhs(), m_Zero())) { - Attribute zero = operands[1]; + Attribute zero = adaptor.getRhs(); results.push_back(zero); results.push_back(zero); return success(); @@ -454,10 +457,11 @@ // mului_extended(cst_a, cst_b) -> cst_low, cst_high if (Attribute lowAttr = constFoldBinaryOp( - operands, [](const APInt &a, const APInt &b) { return a * b; })) { + adaptor.getOperands(), + [](const APInt &a, const APInt &b) { return a * b; })) { // Invoke the constant fold helper again to calculate the 'high' result. Attribute highAttr = constFoldBinaryOp( - operands, [](const APInt &a, const APInt &b) { + adaptor.getOperands(), [](const APInt &a, const APInt &b) { unsigned bitWidth = a.getBitWidth(); APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2); return fullProduct.extractBits(bitWidth, bitWidth); @@ -481,21 +485,21 @@ // DivUIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::DivUIOp::fold(ArrayRef operands) { +OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) { // divui (x, 1) -> x. if (matchPattern(getRhs(), m_One())) return getLhs(); // Don't fold if it would require a division by zero. bool div0 = false; - auto result = - constFoldBinaryOp(operands, [&](APInt a, const APInt &b) { - if (div0 || !b) { - div0 = true; - return a; - } - return a.udiv(b); - }); + auto result = constFoldBinaryOp(adaptor.getOperands(), + [&](APInt a, const APInt &b) { + if (div0 || !b) { + div0 = true; + return a; + } + return a.udiv(b); + }); return div0 ? Attribute() : result; } @@ -510,15 +514,15 @@ // DivSIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::DivSIOp::fold(ArrayRef operands) { +OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) { // divsi (x, 1) -> x. if (matchPattern(getRhs(), m_One())) return getLhs(); // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; - auto result = - constFoldBinaryOp(operands, [&](APInt a, const APInt &b) { + auto result = constFoldBinaryOp( + adaptor.getOperands(), [&](APInt a, const APInt &b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; @@ -557,14 +561,14 @@ // CeilDivUIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::CeilDivUIOp::fold(ArrayRef operands) { +OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) { // ceildivui (x, 1) -> x. if (matchPattern(getRhs(), m_One())) return getLhs(); bool overflowOrDiv0 = false; - auto result = - constFoldBinaryOp(operands, [&](APInt a, const APInt &b) { + auto result = constFoldBinaryOp( + adaptor.getOperands(), [&](APInt a, const APInt &b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; @@ -589,15 +593,15 @@ // CeilDivSIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::CeilDivSIOp::fold(ArrayRef operands) { +OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) { // ceildivsi (x, 1) -> x. if (matchPattern(getRhs(), m_One())) return getLhs(); // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; - auto result = - constFoldBinaryOp(operands, [&](APInt a, const APInt &b) { + auto result = constFoldBinaryOp( + adaptor.getOperands(), [&](APInt a, const APInt &b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; @@ -650,15 +654,15 @@ // FloorDivSIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::FloorDivSIOp::fold(ArrayRef operands) { +OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) { // floordivsi (x, 1) -> x. if (matchPattern(getRhs(), m_One())) return getLhs(); // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; - auto result = - constFoldBinaryOp(operands, [&](APInt a, const APInt &b) { + auto result = constFoldBinaryOp( + adaptor.getOperands(), [&](APInt a, const APInt &b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; @@ -699,21 +703,21 @@ // RemUIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::RemUIOp::fold(ArrayRef operands) { +OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) { // remui (x, 1) -> 0. if (matchPattern(getRhs(), m_One())) return Builder(getContext()).getZeroAttr(getType()); // Don't fold if it would require a division by zero. bool div0 = false; - auto result = - constFoldBinaryOp(operands, [&](APInt a, const APInt &b) { - if (div0 || b.isNullValue()) { - div0 = true; - return a; - } - return a.urem(b); - }); + auto result = constFoldBinaryOp(adaptor.getOperands(), + [&](APInt a, const APInt &b) { + if (div0 || b.isNullValue()) { + div0 = true; + return a; + } + return a.urem(b); + }); return div0 ? Attribute() : result; } @@ -722,21 +726,21 @@ // RemSIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::RemSIOp::fold(ArrayRef operands) { +OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) { // remsi (x, 1) -> 0. if (matchPattern(getRhs(), m_One())) return Builder(getContext()).getZeroAttr(getType()); // Don't fold if it would require a division by zero. bool div0 = false; - auto result = - constFoldBinaryOp(operands, [&](APInt a, const APInt &b) { - if (div0 || b.isNullValue()) { - div0 = true; - return a; - } - return a.srem(b); - }); + auto result = constFoldBinaryOp(adaptor.getOperands(), + [&](APInt a, const APInt &b) { + if (div0 || b.isNullValue()) { + div0 = true; + return a; + } + return a.srem(b); + }); return div0 ? Attribute() : result; } @@ -762,7 +766,7 @@ return {}; } -OpFoldResult arith::AndIOp::fold(ArrayRef operands) { +OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) { /// and(x, 0) -> 0 if (matchPattern(getRhs(), m_Zero())) return getRhs(); @@ -786,31 +790,33 @@ return result; return constFoldBinaryOp( - operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); + adaptor.getOperands(), + [](APInt a, const APInt &b) { return std::move(a) & b; }); } //===----------------------------------------------------------------------===// // OrIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::OrIOp::fold(ArrayRef operands) { +OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) { /// or(x, 0) -> x if (matchPattern(getRhs(), m_Zero())) return getLhs(); /// or(x, ) -> - if (auto rhsAttr = operands[1].dyn_cast_or_null()) + if (auto rhsAttr = adaptor.getRhs().dyn_cast_or_null()) if (rhsAttr.getValue().isAllOnes()) return rhsAttr; return constFoldBinaryOp( - operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); + adaptor.getOperands(), + [](APInt a, const APInt &b) { return std::move(a) | b; }); } //===----------------------------------------------------------------------===// // XOrIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::XOrIOp::fold(ArrayRef operands) { +OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) { /// xor(x, 0) -> x if (matchPattern(getRhs(), m_Zero())) return getLhs(); @@ -835,7 +841,8 @@ } return constFoldBinaryOp( - operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; }); + adaptor.getOperands(), + [](APInt a, const APInt &b) { return std::move(a) ^ b; }); } void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -847,11 +854,11 @@ // NegFOp //===----------------------------------------------------------------------===// -OpFoldResult arith::NegFOp::fold(ArrayRef operands) { +OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) { /// negf(negf(x)) -> x if (auto op = this->getOperand().getDefiningOp()) return op.getOperand(); - return constFoldUnaryOp(operands, + return constFoldUnaryOp(adaptor.getOperands(), [](const APFloat &a) { return -a; }); } @@ -859,35 +866,35 @@ // AddFOp //===----------------------------------------------------------------------===// -OpFoldResult arith::AddFOp::fold(ArrayRef operands) { +OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) { // addf(x, -0) -> x if (matchPattern(getRhs(), m_NegZeroFloat())) return getLhs(); return constFoldBinaryOp( - operands, [](const APFloat &a, const APFloat &b) { return a + b; }); + adaptor.getOperands(), + [](const APFloat &a, const APFloat &b) { return a + b; }); } //===----------------------------------------------------------------------===// // SubFOp //===----------------------------------------------------------------------===// -OpFoldResult arith::SubFOp::fold(ArrayRef operands) { +OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) { // subf(x, +0) -> x if (matchPattern(getRhs(), m_PosZeroFloat())) return getLhs(); return constFoldBinaryOp( - operands, [](const APFloat &a, const APFloat &b) { return a - b; }); + adaptor.getOperands(), + [](const APFloat &a, const APFloat &b) { return a - b; }); } //===----------------------------------------------------------------------===// // MaxFOp //===----------------------------------------------------------------------===// -OpFoldResult arith::MaxFOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "maxf takes two operands"); - +OpFoldResult arith::MaxFOp::fold(FoldAdaptor adaptor) { // maxf(x,x) -> x if (getLhs() == getRhs()) return getRhs(); @@ -897,7 +904,7 @@ return getLhs(); return constFoldBinaryOp( - operands, + adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); } @@ -905,9 +912,7 @@ // MaxSIOp //===----------------------------------------------------------------------===// -OpFoldResult MaxSIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary operation takes two operands"); - +OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) { // maxsi(x,x) -> x if (getLhs() == getRhs()) return getRhs(); @@ -923,7 +928,7 @@ intValue.isMinSignedValue()) return getLhs(); - return constFoldBinaryOp(operands, + return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { return llvm::APIntOps::smax(a, b); }); @@ -933,9 +938,7 @@ // MaxUIOp //===----------------------------------------------------------------------===// -OpFoldResult MaxUIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary operation takes two operands"); - +OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) { // maxui(x,x) -> x if (getLhs() == getRhs()) return getRhs(); @@ -949,7 +952,7 @@ if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) return getLhs(); - return constFoldBinaryOp(operands, + return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { return llvm::APIntOps::umax(a, b); }); @@ -959,9 +962,7 @@ // MinFOp //===----------------------------------------------------------------------===// -OpFoldResult arith::MinFOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "minf takes two operands"); - +OpFoldResult arith::MinFOp::fold(FoldAdaptor adaptor) { // minf(x,x) -> x if (getLhs() == getRhs()) return getRhs(); @@ -971,7 +972,7 @@ return getLhs(); return constFoldBinaryOp( - operands, + adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); } @@ -979,9 +980,7 @@ // MinSIOp //===----------------------------------------------------------------------===// -OpFoldResult MinSIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary operation takes two operands"); - +OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) { // minsi(x,x) -> x if (getLhs() == getRhs()) return getRhs(); @@ -997,7 +996,7 @@ intValue.isMaxSignedValue()) return getLhs(); - return constFoldBinaryOp(operands, + return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { return llvm::APIntOps::smin(a, b); }); @@ -1007,9 +1006,7 @@ // MinUIOp //===----------------------------------------------------------------------===// -OpFoldResult MinUIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary operation takes two operands"); - +OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) { // minui(x,x) -> x if (getLhs() == getRhs()) return getRhs(); @@ -1023,7 +1020,7 @@ if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) return getLhs(); - return constFoldBinaryOp(operands, + return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { return llvm::APIntOps::umin(a, b); }); @@ -1033,13 +1030,14 @@ // MulFOp //===----------------------------------------------------------------------===// -OpFoldResult arith::MulFOp::fold(ArrayRef operands) { +OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) { // mulf(x, 1) -> x if (matchPattern(getRhs(), m_OneFloat())) return getLhs(); return constFoldBinaryOp( - operands, [](const APFloat &a, const APFloat &b) { return a * b; }); + adaptor.getOperands(), + [](const APFloat &a, const APFloat &b) { return a * b; }); } void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -1051,13 +1049,14 @@ // DivFOp //===----------------------------------------------------------------------===// -OpFoldResult arith::DivFOp::fold(ArrayRef operands) { +OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) { // divf(x, 1) -> x if (matchPattern(getRhs(), m_OneFloat())) return getLhs(); return constFoldBinaryOp( - operands, [](const APFloat &a, const APFloat &b) { return a / b; }); + adaptor.getOperands(), + [](const APFloat &a, const APFloat &b) { return a / b; }); } void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -1069,8 +1068,8 @@ // RemFOp //===----------------------------------------------------------------------===// -OpFoldResult arith::RemFOp::fold(ArrayRef operands) { - return constFoldBinaryOp(operands, +OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) { + return constFoldBinaryOp(adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { APFloat result(a); (void)result.remainder(b); @@ -1170,7 +1169,7 @@ // ExtUIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::ExtUIOp::fold(ArrayRef operands) { +OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) { if (auto lhs = getIn().getDefiningOp()) { getInMutable().assign(lhs.getIn()); return getResult(); @@ -1179,7 +1178,8 @@ Type resType = getElementTypeOrSelf(getType()); unsigned bitWidth = resType.cast().getWidth(); return constFoldCastOp( - operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { + adaptor.getOperands(), getType(), + [bitWidth](const APInt &a, bool &castStatus) { return a.zext(bitWidth); }); } @@ -1196,7 +1196,7 @@ // ExtSIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::ExtSIOp::fold(ArrayRef operands) { +OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) { if (auto lhs = getIn().getDefiningOp()) { getInMutable().assign(lhs.getIn()); return getResult(); @@ -1205,7 +1205,8 @@ Type resType = getElementTypeOrSelf(getType()); unsigned bitWidth = resType.cast().getWidth(); return constFoldCastOp( - operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { + adaptor.getOperands(), getType(), + [bitWidth](const APInt &a, bool &castStatus) { return a.sext(bitWidth); }); } @@ -1237,9 +1238,7 @@ // TruncIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::TruncIOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary operation takes one operand"); - +OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) { // trunci(zexti(a)) -> a // trunci(sexti(a)) -> a if (matchPattern(getOperand(), m_Op()) || @@ -1255,7 +1254,8 @@ Type resType = getElementTypeOrSelf(getType()); unsigned bitWidth = resType.cast().getWidth(); return constFoldCastOp( - operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { + adaptor.getOperands(), getType(), + [bitWidth](const APInt &a, bool &castStatus) { return a.trunc(bitWidth); }); } @@ -1280,10 +1280,8 @@ /// Perform safe const propagation for truncf, i.e. only propagate if FP value /// can be represented without precision loss or rounding. -OpFoldResult arith::TruncFOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary operation takes one operand"); - - auto constOperand = operands.front(); +OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) { + auto constOperand = adaptor.getIn(); if (!constOperand || !constOperand.isa()) return {}; @@ -1348,10 +1346,11 @@ return checkIntFloatCast(inputs, outputs); } -OpFoldResult arith::UIToFPOp::fold(ArrayRef operands) { +OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) { Type resEleType = getElementTypeOrSelf(getType()); return constFoldCastOp( - operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { + adaptor.getOperands(), getType(), + [&resEleType](const APInt &a, bool &castStatus) { FloatType floatTy = resEleType.cast(); APFloat apf(floatTy.getFloatSemantics(), APInt::getZero(floatTy.getWidth())); @@ -1369,10 +1368,11 @@ return checkIntFloatCast(inputs, outputs); } -OpFoldResult arith::SIToFPOp::fold(ArrayRef operands) { +OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) { Type resEleType = getElementTypeOrSelf(getType()); return constFoldCastOp( - operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { + adaptor.getOperands(), getType(), + [&resEleType](const APInt &a, bool &castStatus) { FloatType floatTy = resEleType.cast(); APFloat apf(floatTy.getFloatSemantics(), APInt::getZero(floatTy.getWidth())); @@ -1389,11 +1389,12 @@ return checkIntFloatCast(inputs, outputs); } -OpFoldResult arith::FPToUIOp::fold(ArrayRef operands) { +OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) { Type resType = getElementTypeOrSelf(getType()); unsigned bitWidth = resType.cast().getWidth(); return constFoldCastOp( - operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) { + adaptor.getOperands(), getType(), + [&bitWidth](const APFloat &a, bool &castStatus) { bool ignored; APSInt api(bitWidth, /*isUnsigned=*/true); castStatus = APFloat::opInvalidOp != @@ -1410,11 +1411,12 @@ return checkIntFloatCast(inputs, outputs); } -OpFoldResult arith::FPToSIOp::fold(ArrayRef operands) { +OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) { Type resType = getElementTypeOrSelf(getType()); unsigned bitWidth = resType.cast().getWidth(); return constFoldCastOp( - operands, getType(), [&bitWidth](const APFloat &a, bool &castStatus) { + adaptor.getOperands(), getType(), + [&bitWidth](const APFloat &a, bool &castStatus) { bool ignored; APSInt api(bitWidth, /*isUnsigned=*/false); castStatus = APFloat::opInvalidOp != @@ -1445,11 +1447,11 @@ return areIndexCastCompatible(inputs, outputs); } -OpFoldResult arith::IndexCastOp::fold(ArrayRef operands) { +OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) { // index_cast(constant) -> constant // A little hack because we go through int. Otherwise, the size of the // constant might need to change. - if (auto value = operands[0].dyn_cast_or_null()) + if (auto value = adaptor.getIn().dyn_cast_or_null()) return IntegerAttr::get(getType(), value.getInt()); return {}; @@ -1469,11 +1471,11 @@ return areIndexCastCompatible(inputs, outputs); } -OpFoldResult arith::IndexCastUIOp::fold(ArrayRef operands) { +OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) { // index_castui(constant) -> constant // A little hack because we go through int. Otherwise, the size of the // constant might need to change. - if (auto value = operands[0].dyn_cast_or_null()) + if (auto value = adaptor.getIn().dyn_cast_or_null()) return IntegerAttr::get(getType(), value.getValue().getZExtValue()); return {}; @@ -1502,11 +1504,9 @@ return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth(); } -OpFoldResult arith::BitcastOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "bitcast op expects 1 operand"); - +OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) { auto resType = getType(); - auto operand = operands[0]; + auto operand = adaptor.getIn(); if (!operand) return {}; @@ -1620,9 +1620,7 @@ return std::nullopt; } -OpFoldResult arith::CmpIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "cmpi takes two operands"); - +OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) { // cmpi(pred, x, x) if (getLhs() == getRhs()) { auto val = applyCmpPredicateToEqualOperands(getPredicate()); @@ -1649,7 +1647,7 @@ } // Move constant to the right side. - if (operands[0] && !operands[1]) { + if (adaptor.getLhs() && !adaptor.getRhs()) { // Do not use invertPredicate, as it will change eq to ne and vice versa. using Pred = CmpIPredicate; const std::pair invPreds[] = { @@ -1672,13 +1670,13 @@ llvm_unreachable("unknown cmpi predicate kind"); } - auto lhs = operands.front().dyn_cast_or_null(); + auto lhs = adaptor.getLhs().dyn_cast_or_null(); if (!lhs) return {}; // We are moving constants to the right side; So if lhs is constant rhs is // guaranteed to be a constant. - auto rhs = operands.back().cast(); + auto rhs = adaptor.getRhs().cast(); auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); return BoolAttr::get(getContext(), val); @@ -1741,11 +1739,9 @@ llvm_unreachable("unknown cmpf predicate kind"); } -OpFoldResult arith::CmpFOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "cmpf takes two operands"); - - auto lhs = operands.front().dyn_cast_or_null(); - auto rhs = operands.back().dyn_cast_or_null(); +OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) { + auto lhs = adaptor.getLhs().dyn_cast_or_null(); + auto rhs = adaptor.getRhs().dyn_cast_or_null(); // If one operand is NaN, making them both NaN does not change the result. if (lhs && lhs.getValue().isNaN()) @@ -2123,7 +2119,7 @@ results.add(context); } -OpFoldResult arith::SelectOp::fold(ArrayRef operands) { +OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) { Value trueVal = getTrueValue(); Value falseVal = getFalseValue(); if (trueVal == falseVal) @@ -2220,14 +2216,14 @@ // ShLIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::ShLIOp::fold(ArrayRef operands) { +OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) { // shli(x, 0) -> x if (matchPattern(getRhs(), m_Zero())) return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; auto result = constFoldBinaryOp( - operands, [&](const APInt &a, const APInt &b) { + adaptor.getOperands(), [&](const APInt &a, const APInt &b) { bounded = b.ule(b.getBitWidth()); return a.shl(b); }); @@ -2238,14 +2234,14 @@ // ShRUIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::ShRUIOp::fold(ArrayRef operands) { +OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) { // shrui(x, 0) -> x if (matchPattern(getRhs(), m_Zero())) return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; auto result = constFoldBinaryOp( - operands, [&](const APInt &a, const APInt &b) { + adaptor.getOperands(), [&](const APInt &a, const APInt &b) { bounded = b.ule(b.getBitWidth()); return a.lshr(b); }); @@ -2256,14 +2252,14 @@ // ShRSIOp //===----------------------------------------------------------------------===// -OpFoldResult arith::ShRSIOp::fold(ArrayRef operands) { +OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) { // shrsi(x, 0) -> x if (matchPattern(getRhs(), m_Zero())) return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; auto result = constFoldBinaryOp( - operands, [&](const APInt &a, const APInt &b) { + adaptor.getOperands(), [&](const APInt &a, const APInt &b) { bounded = b.ule(b.getBitWidth()); return a.ashr(b); });