diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h
--- a/mlir/include/mlir/Dialect/CommonFolders.h
+++ b/mlir/include/mlir/Dialect/CommonFolders.h
@@ -108,6 +108,59 @@
   return {};
 }
 
+/// Folds a unary cast op whose operand is a constant scalar attribute of type
+/// `AttrElementT` or an `ElementsAttr` of such elements, producing a
+/// `TargetAttrElementT` (scalar) or `DenseElementsAttr` (shaped) of `resType`.
+/// `calculate` converts one element value; it sets its `bool &` argument to
+/// false when the value cannot be cast, in which case nothing is folded.
+template <
+    class AttrElementT, class TargetAttrElementT,
+    class ElementValueT = typename AttrElementT::ValueType,
+    class TargetElementValueT = typename TargetAttrElementT::ValueType,
+    class CalculationT =
+        function_ref<TargetElementValueT(ElementValueT, bool &)>>
+Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
+                          const CalculationT &calculate) {
+  assert(operands.size() == 1 && "Cast op takes one operand");
+  if (!operands[0])
+    return {};
+
+  if (auto attr = operands[0].dyn_cast<AttrElementT>()) {
+    bool castStatus = true;
+    auto res = calculate(attr.getValue(), castStatus);
+    if (!castStatus)
+      return {};
+    return TargetAttrElementT::get(resType, res);
+  }
+  if (auto attr = operands[0].dyn_cast<SplatElementsAttr>()) {
+    // The operand is a splat, so we can avoid expanding the values out and
+    // just fold based on the splat value.
+    bool castStatus = true;
+    auto elementResult =
+        calculate(attr.getSplatValue<ElementValueT>(), castStatus);
+    if (!castStatus)
+      return {};
+    return DenseElementsAttr::get(resType, elementResult);
+  }
+  if (auto attr = operands[0].dyn_cast<ElementsAttr>()) {
+    // The operand is ElementsAttr-derived; perform an element-wise fold by
+    // expanding the values.
+    bool castStatus = true;
+    auto attrIt = attr.value_begin<ElementValueT>();
+    SmallVector<TargetElementValueT> elementResults;
+    elementResults.reserve(attr.getNumElements());
+    for (size_t i = 0, e = attr.getNumElements(); i < e; ++i, ++attrIt) {
+      auto elt = calculate(*attrIt, castStatus);
+      if (!castStatus)
+        return {};
+      elementResults.push_back(elt);
+    }
+
+    return DenseElementsAttr::get(resType, elementResults);
+  }
+  return {};
+}
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_COMMONFOLDERS_H
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -875,16 +875,16 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
-  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
-    return IntegerAttr::get(
-        getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
-
   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
     getInMutable().assign(lhs.getIn());
     return getResult();
   }
-
-  return {};
+  // Use the element type so that vector/tensor constants fold too.
+  unsigned bitWidth = getElementTypeOrSelf(getType()).getIntOrFloatBitWidth();
+  return constFoldCastOp<IntegerAttr, IntegerAttr>(
+      operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
+        return a.zext(bitWidth);
+      });
 }
 
 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -900,16 +900,16 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
-  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
-    return IntegerAttr::get(
-        getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
-
   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
     getInMutable().assign(lhs.getIn());
     return getResult();
   }
-
-  return {};
+  // Use the element type so that vector/tensor constants fold too.
+  unsigned bitWidth = getElementTypeOrSelf(getType()).getIntOrFloatBitWidth();
+  return constFoldCastOp<IntegerAttr, IntegerAttr>(
+      operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
+        return a.sext(bitWidth);
+      });
 }
 
 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -954,15 +954,12 @@
     return getResult();
   }
 
-  if (!operands[0])
-    return {};
-
-  if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
-    return IntegerAttr::get(
-        getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
-  }
-
-  return {};
+  // Use the element type so that vector/tensor constants fold too.
+  unsigned bitWidth = getElementTypeOrSelf(getType()).getIntOrFloatBitWidth();
+  return constFoldCastOp<IntegerAttr, IntegerAttr>(
+      operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
+        return a.trunc(bitWidth);
+      });
 }
 
 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@@ -1048,15 +1045,16 @@
 }
 
 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
-  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
-    const APInt &api = lhs.getValue();
-    FloatType floatTy = getType().cast<FloatType>();
-    APFloat apf(floatTy.getFloatSemantics(),
-                APInt::getZero(floatTy.getWidth()));
-    apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven);
-    return FloatAttr::get(floatTy, apf);
-  }
-  return {};
+  // Use the element type so that vector/tensor constants fold too.
+  FloatType floatTy = getElementTypeOrSelf(getType()).cast<FloatType>();
+  return constFoldCastOp<IntegerAttr, FloatAttr>(
+      operands, getType(), [&floatTy](const APInt &a, bool &castStatus) {
+        APFloat apf(floatTy.getFloatSemantics(),
+                    APInt::getZero(floatTy.getWidth()));
+        apf.convertFromAPInt(a, /*IsSigned=*/false,
+                             APFloat::rmNearestTiesToEven);
+        return apf;
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1068,15 +1066,16 @@
 }
 
 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
-  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
-    const APInt &api = lhs.getValue();
-    FloatType floatTy = getType().cast<FloatType>();
-    APFloat apf(floatTy.getFloatSemantics(),
-                APInt::getZero(floatTy.getWidth()));
-    apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
-    return FloatAttr::get(floatTy, apf);
-  }
-  return {};
+  // Use the element type so that vector/tensor constants fold too.
+  FloatType floatTy = getElementTypeOrSelf(getType()).cast<FloatType>();
+  return constFoldCastOp<IntegerAttr, FloatAttr>(
+      operands, getType(), [&floatTy](const APInt &a, bool &castStatus) {
+        APFloat apf(floatTy.getFloatSemantics(),
+                    APInt::getZero(floatTy.getWidth()));
+        apf.convertFromAPInt(a, /*IsSigned=*/true,
+                             APFloat::rmNearestTiesToEven);
+        return apf;
+      });
 }
 //===----------------------------------------------------------------------===//
 // FPToUIOp
@@ -1087,21 +1086,17 @@
 }
 
 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
-  if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
-    const APFloat &apf = lhs.getValue();
-    IntegerType intTy = getType().cast<IntegerType>();
-    bool ignored;
-    APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
-    if (APFloat::opInvalidOp ==
-        apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
-      // Undefined behavior invoked - the destination type can't represent
-      // the input constant.
-      return {};
-    }
-    return IntegerAttr::get(getType(), api);
-  }
-
-  return {};
+  // Use the element type so that vector/tensor constants fold too.
+  IntegerType intTy = getElementTypeOrSelf(getType()).cast<IntegerType>();
+  return constFoldCastOp<FloatAttr, IntegerAttr>(
+      operands, getType(), [&intTy](const APFloat &a, bool &castStatus) {
+        bool ignored;
+        APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
+        // An unrepresentable input makes the cast undefined; do not fold it.
+        castStatus = APFloat::opInvalidOp !=
+                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
+        return api;
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1113,21 +1108,17 @@
 }
 
 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
-  if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
-    const APFloat &apf = lhs.getValue();
-    IntegerType intTy = getType().cast<IntegerType>();
-    bool ignored;
-    APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
-    if (APFloat::opInvalidOp ==
-        apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) {
-      // Undefined behavior invoked - the destination type can't represent
-      // the input constant.
-      return {};
-    }
-    return IntegerAttr::get(getType(), api);
-  }
-
-  return {};
+  // Use the element type so that vector/tensor constants fold too.
+  IntegerType intTy = getElementTypeOrSelf(getType()).cast<IntegerType>();
+  return constFoldCastOp<FloatAttr, IntegerAttr>(
+      operands, getType(), [&intTy](const APFloat &a, bool &castStatus) {
+        bool ignored;
+        APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
+        // An unrepresentable input makes the cast undefined; do not fold it.
+        castStatus = APFloat::opInvalidOp !=
+                     a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
+        return api;
+      });
 }
 
 //===----------------------------------------------------------------------===//