Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td =================================================================== --- mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -45,6 +45,7 @@ let cppNamespace = "mlir::tosa"; let hasConstantMaterializer = 1; let useDefaultAttributePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// Index: mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp =================================================================== --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -539,7 +539,7 @@ return {}; } -OpFoldResult AddOp::fold(ArrayRef operands) { +OpFoldResult AddOp::fold(FoldAdaptor adaptor) { auto lhsTy = getInput1().getType().dyn_cast(); auto rhsTy = getInput2().getType().dyn_cast(); auto resultTy = getType().dyn_cast(); @@ -549,8 +549,8 @@ return {}; auto resultETy = resultTy.getElementType(); - auto lhsAttr = operands[0].dyn_cast_or_null(); - auto rhsAttr = operands[1].dyn_cast_or_null(); + auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); + auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); if (lhsAttr && lhsAttr.isSplat() && resultETy.isa()) { if (lhsAttr.getSplatValue().isZero()) @@ -579,7 +579,7 @@ lhsTy); } -OpFoldResult DivOp::fold(ArrayRef operands) { +OpFoldResult DivOp::fold(FoldAdaptor adaptor) { auto lhsTy = getInput1().getType().dyn_cast(); auto rhsTy = getInput2().getType().dyn_cast(); auto resultTy = getType().dyn_cast(); @@ -589,8 +589,8 @@ return {}; auto resultETy = resultTy.getElementType(); - auto lhsAttr = operands[0].dyn_cast_or_null(); - auto rhsAttr = operands[1].dyn_cast_or_null(); + auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); + auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); if (lhsAttr && lhsAttr.isSplat()) { if (resultETy.isa() && lhsAttr.getSplatValue().isZero()) return lhsAttr; @@ -646,7 +646,7 @@ } } // namespace -OpFoldResult MulOp::fold(ArrayRef operands) { +OpFoldResult MulOp::fold(FoldAdaptor adaptor) { auto lhs = getInput1(); auto rhs = getInput2(); auto lhsTy = lhs.getType().dyn_cast(); @@ -658,8 +658,8 @@ return {}; auto resultETy = resultTy.getElementType(); - auto lhsAttr = operands[0].dyn_cast_or_null(); - auto rhsAttr = operands[1].dyn_cast_or_null(); + auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); + auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); if (lhsAttr && lhsAttr.isSplat() && resultETy.isa()) { auto val = lhsAttr.getSplatValue(); @@ -700,7 +700,7 @@ return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift()); } -OpFoldResult SubOp::fold(ArrayRef operands) { +OpFoldResult SubOp::fold(FoldAdaptor adaptor) { auto lhsTy = getInput1().getType().dyn_cast(); auto rhsTy = getInput2().getType().dyn_cast(); auto resultTy = getType().dyn_cast(); @@ -710,8 +710,8 @@ return {}; auto resultETy = resultTy.getElementType(); - auto lhsAttr = operands[0].dyn_cast_or_null(); - auto rhsAttr = operands[1].dyn_cast_or_null(); + auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); + auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); if (rhsAttr && rhsAttr.isSplat() && resultETy.isa()) { if (rhsAttr.getSplatValue().isZero()) @@ -757,10 +757,10 @@ }; } // namespace -OpFoldResult GreaterOp::fold(ArrayRef operands) { +OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { auto resultTy = getType().dyn_cast(); - auto lhsAttr = operands[0].dyn_cast_or_null(); - auto rhsAttr = operands[1].dyn_cast_or_null(); + auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); + auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); if (!lhsAttr || !rhsAttr) return {}; @@ -769,10 +769,10 @@ lhsAttr, rhsAttr, resultTy); } -OpFoldResult GreaterEqualOp::fold(ArrayRef operands) { +OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { auto resultTy = getType().dyn_cast(); - auto lhsAttr = operands[0].dyn_cast_or_null(); - auto rhsAttr = operands[1].dyn_cast_or_null(); + auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); + auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); if (!lhsAttr || !rhsAttr) return {}; @@ -782,10 +782,10 @@ lhsAttr, rhsAttr, resultTy); } -OpFoldResult EqualOp::fold(ArrayRef operands) { +OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { auto resultTy = getType().dyn_cast(); - auto lhsAttr = operands[0].dyn_cast_or_null(); - auto rhsAttr = operands[1].dyn_cast_or_null(); + auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); + auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); Value lhs = getInput1(); Value rhs = getInput2(); auto lhsTy = lhs.getType().cast(); @@ -805,11 +805,11 @@ resultTy); } -OpFoldResult CastOp::fold(ArrayRef operands) { +OpFoldResult CastOp::fold(FoldAdaptor adaptor) { if (getInput().getType() == getType()) return getInput(); - auto operand = operands[0].dyn_cast_or_null(); + auto operand = adaptor.getInput().dyn_cast_or_null(); if (!operand) return {}; @@ -868,13 +868,10 @@ return {}; } -OpFoldResult ConstOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); - return getValueAttr(); -} +OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } #define REDUCE_FOLDER(OP) \ - OpFoldResult OP::fold(ArrayRef operands) { \ + OpFoldResult OP::fold(FoldAdaptor adaptor) { \ ShapedType inputTy = getInput().getType().cast(); \ if (!inputTy.hasRank()) \ return {}; \ @@ -891,7 +888,7 @@ REDUCE_FOLDER(ReduceSumOp) #undef REDUCE_FOLDER -OpFoldResult ReshapeOp::fold(ArrayRef operands) { +OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { auto inputTy = getInput1().getType().dyn_cast(); auto outputTy = getType().dyn_cast(); @@ -901,7 +898,7 @@ if (inputTy == outputTy) return getInput1(); - auto operand = operands[0].dyn_cast_or_null(); + auto operand = adaptor.getInput1().dyn_cast_or_null(); if (operand && outputTy.hasStaticShape() && operand.isSplat()) { return SplatElementsAttr::get(outputTy, operand.getSplatValue()); } @@ -909,10 +906,10 @@ return {}; } -OpFoldResult PadOp::fold(ArrayRef operands) { +OpFoldResult PadOp::fold(FoldAdaptor adaptor) { // If the pad is all zeros we can fold this operation away. - if (operands[1]) { - auto densePad = operands[1].cast(); + if (adaptor.getPadding()) { + auto densePad = adaptor.getPadding().cast(); if (densePad.isSplat() && densePad.getSplatValue().isZero()) { return getInput1(); } @@ -923,7 +920,7 @@ // Fold away cases where a tosa.resize operation returns a copy // of the input image. -OpFoldResult ResizeOp::fold(ArrayRef operands) { +OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) { ArrayRef offset = getOffset(); ArrayRef border = getBorder(); ArrayRef scale = getScale(); @@ -952,11 +949,11 @@ return input; } -OpFoldResult ReverseOp::fold(ArrayRef operands) { +OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) { auto operand = getInput(); auto operandTy = operand.getType().cast(); auto axis = getAxis(); - auto operandAttr = operands[0].dyn_cast_or_null(); + auto operandAttr = adaptor.getInput().dyn_cast_or_null(); if (operandAttr) return operandAttr; @@ -967,7 +964,7 @@ return {}; } -OpFoldResult SliceOp::fold(ArrayRef operands) { +OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { auto inputTy = getInput().getType().dyn_cast(); auto outputTy = getType().dyn_cast(); @@ -977,10 +974,10 @@ if (inputTy == outputTy && inputTy.hasStaticShape()) return getInput(); - if (!operands[0]) + if (!adaptor.getInput()) return {}; - auto operand = operands[0].cast(); + auto operand = adaptor.getInput().cast(); if (operand.isSplat() && outputTy.hasStaticShape()) { return SplatElementsAttr::get(outputTy, operand.getSplatValue()); } @@ -995,11 +992,11 @@ return {}; } -OpFoldResult tosa::SelectOp::fold(ArrayRef operands) { +OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { if (getOnTrue() == getOnFalse()) return getOnTrue(); - auto predicate = operands[0].dyn_cast_or_null(); + auto predicate = adaptor.getPred().dyn_cast_or_null(); if (!predicate) return {}; @@ -1009,19 +1006,19 @@ : getOnFalse(); } -OpFoldResult TileOp::fold(ArrayRef operands) { +OpFoldResult TileOp::fold(FoldAdaptor adaptor) { bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; }); if (allOnes && getInput1().getType() == getType()) return getInput1(); return {}; } -OpFoldResult TransposeOp::fold(ArrayRef operands) { +OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { auto inputTy = getInput1().getType().cast(); auto resultTy = getType().cast(); // Transposing splat values just means reshaping. - if (auto input = operands[0].dyn_cast_or_null()) { + if (auto input = adaptor.getInput1().dyn_cast_or_null()) { if (input.isSplat() && resultTy.hasStaticShape() && inputTy.getElementType() == resultTy.getElementType()) return input.reshape(resultTy);