[mlir][Index] Port fold hooks to the FoldAdaptor API

Set `useFoldAPI = kEmitFoldAdaptorFolder` on the Index dialect and change every
`fold(ArrayRef<Attribute> operands)` in IndexOps.cpp to `fold(FoldAdaptor
adaptor)`. The binary-op folders forward `adaptor.getOperands()` to the existing
foldBinaryOpChecked/foldBinaryOpUnchecked helpers; CmpOp::fold reads its
operands through the named getLhs()/getRhs() accessors and drops the
operand-count assertion. All added lines are clang-formatted (80 columns), with
the Min*Op folders wrapped the same way as the Max*Op folders.

diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexDialect.td b/mlir/include/mlir/Dialect/Index/IR/IndexDialect.td
--- a/mlir/include/mlir/Dialect/Index/IR/IndexDialect.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexDialect.td
@@ -83,6 +83,7 @@
 
   let hasConstantMaterializer = 1;
   let useDefaultAttributePrinterParser = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }
 
 #endif // INDEX_DIALECT
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -115,36 +115,40 @@
 // AddOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // SubOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // DivSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // Don't fold division by zero.
         if (rhs.isZero())
           return std::nullopt;
@@ -156,9 +160,10 @@
 // DivUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult DivUOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // Don't fold division by zero.
         if (rhs.isZero())
           return std::nullopt;
@@ -193,18 +198,19 @@
   return (n + x).sdiv(m) + 1;
 }
 
-OpFoldResult CeilDivSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, calculateCeilDivS);
+OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
 }
 
 //===----------------------------------------------------------------------===//
 // CeilDivUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult CeilDivUOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
   // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
   return foldBinaryOpChecked(
-      operands, [](const APInt &n, const APInt &m) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &n, const APInt &m) -> Optional<APInt> {
         // Don't fold division by zero.
         if (m.isZero())
           return std::nullopt;
@@ -242,58 +248,61 @@
   return -1 - (x - n).sdiv(m);
 }
 
-OpFoldResult FloorDivSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, calculateFloorDivS);
+OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
 }
 
 //===----------------------------------------------------------------------===//
 // RemSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult RemSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.srem(rhs);
-  });
+OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs.srem(rhs); });
 }
 
 //===----------------------------------------------------------------------===//
 // RemUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult RemUOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.urem(rhs);
-  });
+OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs.urem(rhs); });
 }
 
 //===----------------------------------------------------------------------===//
 // MaxSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MaxSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.sgt(rhs) ? lhs : rhs;
-  });
+OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(),
+                             [](const APInt &lhs, const APInt &rhs) {
+                               return lhs.sgt(rhs) ? lhs : rhs;
+                             });
 }
 
 //===----------------------------------------------------------------------===//
 // MaxUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MaxUOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.ugt(rhs) ? lhs : rhs;
-  });
+OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(),
+                             [](const APInt &lhs, const APInt &rhs) {
+                               return lhs.ugt(rhs) ? lhs : rhs;
+                             });
 }
 
 //===----------------------------------------------------------------------===//
 // MinSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MinSOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.slt(rhs) ? lhs : rhs;
-  });
+OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(),
+                             [](const APInt &lhs, const APInt &rhs) {
+                               return lhs.slt(rhs) ? lhs : rhs;
+                             });
 }
 
 //===----------------------------------------------------------------------===//
@@ -300,10 +309,11 @@
 // MinUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult MinUOp::fold(ArrayRef<Attribute> operands) {
-  return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
-    return lhs.ult(rhs) ? lhs : rhs;
-  });
+OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
+  return foldBinaryOpChecked(adaptor.getOperands(),
+                             [](const APInt &lhs, const APInt &rhs) {
+                               return lhs.ult(rhs) ? lhs : rhs;
+                             });
 }
 
 //===----------------------------------------------------------------------===//
@@ -310,9 +320,10 @@
 // ShlOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // We cannot fold if the RHS is greater than or equal to 32 because
         // this would be UB in 32-bit systems but not on 64-bit systems. RHS is
         // already treated as unsigned.
@@ -326,9 +337,10 @@
 // ShrSOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // Don't fold if RHS is greater than or equal to 32.
         if (rhs.uge(32))
           return {};
@@ -340,9 +352,10 @@
 // ShrUOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpChecked(
-      operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
         // Don't fold if RHS is greater than or equal to 32.
         if (rhs.uge(32))
           return {};
@@ -354,27 +367,30 @@
 // AndOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // OrOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
 }
 
 //===----------------------------------------------------------------------===//
 // XOrOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
   return foldBinaryOpUnchecked(
-      operands, [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
+      adaptor.getOperands(),
+      [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
 }
 
 //===----------------------------------------------------------------------===//
@@ -425,10 +441,9 @@
   llvm_unreachable("unhandled IndexCmpPredicate predicate");
 }
 
-OpFoldResult CmpOp::fold(ArrayRef<Attribute> operands) {
-  assert(operands.size() == 2 && "compare expected 2 operands");
-  auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
-  auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
+OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
+  auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
+  auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
   if (!lhs || !rhs)
     return {};
 
@@ -453,9 +468,7 @@
   setNameFn(getResult(), specialName.str());
 }
 
-OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
-  return getValueAttr();
-}
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
 
 void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
   build(b, state, b.getIndexType(), b.getIndexAttr(value));
@@ -465,7 +478,7 @@
 // BoolConstantOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
   return getValueAttr();
 }
 