diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1176,12 +1176,9 @@
     getInMutable().assign(lhs.getIn());
     return getResult();
   }
-  Type resType = getType();
-  unsigned bitWidth;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    bitWidth = shapedType.getElementTypeBitWidth();
-  else
-    bitWidth = resType.getIntOrFloatBitWidth();
+
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<IntegerAttr, IntegerAttr>(
       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
         return a.zext(bitWidth);
@@ -1205,12 +1202,9 @@
     getInMutable().assign(lhs.getIn());
     return getResult();
   }
-  Type resType = getType();
-  unsigned bitWidth;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    bitWidth = shapedType.getElementTypeBitWidth();
-  else
-    bitWidth = resType.getIntOrFloatBitWidth();
+
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<IntegerAttr, IntegerAttr>(
       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
         return a.sext(bitWidth);
@@ -1259,13 +1253,8 @@
     return getResult();
   }
 
-  Type resType = getType();
-  unsigned bitWidth;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    bitWidth = shapedType.getElementTypeBitWidth();
-  else
-    bitWidth = resType.getIntOrFloatBitWidth();
-
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<IntegerAttr, IntegerAttr>(
       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
         return a.trunc(bitWidth);
@@ -1361,12 +1350,7 @@
 }
 
 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
-  Type resType = getType();
-  Type resEleType;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    resEleType = shapedType.getElementType();
-  else
-    resEleType = resType;
+  Type resEleType = getElementTypeOrSelf(getType());
   return constFoldCastOp<IntegerAttr, FloatAttr>(
       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
         FloatType floatTy = resEleType.cast<FloatType>();
@@ -1387,12 +1371,7 @@
 }
 
 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
-  Type resType = getType();
-  Type resEleType;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    resEleType = shapedType.getElementType();
-  else
-    resEleType = resType;
+  Type resEleType = getElementTypeOrSelf(getType());
   return constFoldCastOp<IntegerAttr, FloatAttr>(
       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
         FloatType floatTy = resEleType.cast<FloatType>();
@@ -1412,17 +1391,12 @@
 }
 
 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
-  Type resType = getType();
-  Type resEleType;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    resEleType = shapedType.getElementType();
-  else
-    resEleType = resType;
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<FloatAttr, IntegerAttr>(
-      operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
-        IntegerType intTy = resEleType.cast<IntegerType>();
+      operands, getType(), [bitWidth](const APFloat &a, bool &castStatus) {
         bool ignored;
-        APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
+        APSInt api(bitWidth, /*isUnsigned=*/true);
         castStatus = APFloat::opInvalidOp !=
                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
         return api;
@@ -1438,17 +1412,12 @@
 }
 
 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
-  Type resType = getType();
-  Type resEleType;
-  if (auto shapedType = resType.dyn_cast<ShapedType>())
-    resEleType = shapedType.getElementType();
-  else
-    resEleType = resType;
+  Type resType = getElementTypeOrSelf(getType());
+  unsigned bitWidth = resType.cast<IntegerType>().getWidth();
   return constFoldCastOp<FloatAttr, IntegerAttr>(
-      operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
-        IntegerType intTy = resEleType.cast<IntegerType>();
+      operands, getType(), [bitWidth](const APFloat &a, bool &castStatus) {
         bool ignored;
-        APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
+        APSInt api(bitWidth, /*isUnsigned=*/false);
         castStatus = APFloat::opInvalidOp !=
                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
         return api;