diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -108,6 +108,56 @@ return {}; } +template < + class AttrElementT, class TargetAttrElementT, + class ElementValueT = typename AttrElementT::ValueType, + class TargetElementValueT = typename TargetAttrElementT::ValueType, + class CalculationT = function_ref> +Attribute constFoldCastOp(ArrayRef operands, Type resType, + const CalculationT &calculate) { + assert(operands.size() == 1 && "Cast op takes one operand"); + if (!operands[0]) + return {}; + + if (operands[0].isa()) { + auto op = operands[0].cast(); + bool castStatus = true; + auto res = calculate(op.getValue(), castStatus); + if (!castStatus) + return {}; + return TargetAttrElementT::get(resType, res); + } + if (operands[0].isa()) { + // The operand is a splat so we can avoid expanding the values out and + // just fold based on the splat value. + auto op = operands[0].cast(); + bool castStatus = true; + auto elementResult = + calculate(op.getSplatValue(), castStatus); + if (!castStatus) + return {}; + return DenseElementsAttr::get(resType, elementResult); + } + if (operands[0].isa()) { + // Operand is ElementsAttr-derived; perform an element-wise fold by + // expanding the value. + auto op = operands[0].cast(); + bool castStatus = true; + auto opIt = op.value_begin(); + SmallVector elementResults; + elementResults.reserve(op.getNumElements()); + for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) { + auto elt = calculate(*opIt, castStatus); + if (!castStatus) + return {}; + elementResults.push_back(elt); + } + + return DenseElementsAttr::get(resType, elementResults); + } + return {}; +} + } // namespace mlir #endif // MLIR_DIALECT_COMMONFOLDERS_H diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -875,16 +875,20 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::ExtUIOp::fold(ArrayRef operands) { - if (auto lhs = operands[0].dyn_cast_or_null()) - return IntegerAttr::get( - getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth())); - if (auto lhs = getIn().getDefiningOp()) { getInMutable().assign(lhs.getIn()); return getResult(); } - - return {}; + Type resType = getType(); + unsigned bitWidth; + if (auto shapedType = resType.dyn_cast()) + bitWidth = shapedType.getElementTypeBitWidth(); + else + bitWidth = resType.getIntOrFloatBitWidth(); + return constFoldCastOp( + operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { + return a.zext(bitWidth); + }); } bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { @@ -900,16 +904,20 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::ExtSIOp::fold(ArrayRef operands) { - if (auto lhs = operands[0].dyn_cast_or_null()) - return IntegerAttr::get( - getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth())); - if (auto lhs = getIn().getDefiningOp()) { getInMutable().assign(lhs.getIn()); return getResult(); } - - return {}; + Type resType = getType(); + unsigned bitWidth; + if (auto shapedType = resType.dyn_cast()) + bitWidth = shapedType.getElementTypeBitWidth(); + else + bitWidth = resType.getIntOrFloatBitWidth(); + return constFoldCastOp( + operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { + return a.sext(bitWidth); + }); } bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { @@ -954,15 +962,17 @@ return getResult(); } - if (!operands[0]) - return {}; + Type resType = getType(); + unsigned bitWidth; + if (auto shapedType = resType.dyn_cast()) + bitWidth = shapedType.getElementTypeBitWidth(); + else + bitWidth = resType.getIntOrFloatBitWidth(); - if (auto lhs = operands[0].dyn_cast()) { - return IntegerAttr::get( - getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); - } - - return {}; + return constFoldCastOp( + operands, getType(), [bitWidth](const APInt &a, bool &castStatus) { + return a.trunc(bitWidth); + }); } bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { @@ -1048,15 +1058,21 @@ } OpFoldResult arith::UIToFPOp::fold(ArrayRef operands) { - if (auto lhs = operands[0].dyn_cast_or_null()) { - const APInt &api = lhs.getValue(); - FloatType floatTy = getType().cast(); - APFloat apf(floatTy.getFloatSemantics(), - APInt::getZero(floatTy.getWidth())); - apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven); - return FloatAttr::get(floatTy, apf); - } - return {}; + Type resType = getType(); + Type resEleType; + if (auto shapedType = resType.dyn_cast()) + resEleType = shapedType.getElementType(); + else + resEleType = resType; + return constFoldCastOp( + operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { + FloatType floatTy = resEleType.cast(); + APFloat apf(floatTy.getFloatSemantics(), + APInt::getZero(floatTy.getWidth())); + apf.convertFromAPInt(a, /*IsSigned=*/false, + APFloat::rmNearestTiesToEven); + return apf; + }); } //===----------------------------------------------------------------------===// @@ -1068,15 +1084,21 @@ } OpFoldResult arith::SIToFPOp::fold(ArrayRef operands) { - if (auto lhs = operands[0].dyn_cast_or_null()) { - const APInt &api = lhs.getValue(); - FloatType floatTy = getType().cast(); - APFloat apf(floatTy.getFloatSemantics(), - APInt::getZero(floatTy.getWidth())); - apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven); - return FloatAttr::get(floatTy, apf); - } - return {}; + Type resType = getType(); + Type resEleType; + if (auto shapedType = resType.dyn_cast()) + resEleType = shapedType.getElementType(); + else + resEleType = resType; + return constFoldCastOp( + operands, getType(), [&resEleType](const APInt &a, bool &castStatus) { + FloatType floatTy = resEleType.cast(); + APFloat apf(floatTy.getFloatSemantics(), + APInt::getZero(floatTy.getWidth())); + apf.convertFromAPInt(a, /*IsSigned=*/true, + APFloat::rmNearestTiesToEven); + return apf; + }); } //===----------------------------------------------------------------------===// // FPToUIOp @@ -1087,21 +1109,21 @@ } OpFoldResult arith::FPToUIOp::fold(ArrayRef operands) { - if (auto lhs = operands[0].dyn_cast_or_null()) { - const APFloat &apf = lhs.getValue(); - IntegerType intTy = getType().cast(); - bool ignored; - APSInt api(intTy.getWidth(), /*isUnsigned=*/true); - if (APFloat::opInvalidOp == - apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { - // Undefined behavior invoked - the destination type can't represent - // the input constant. - return {}; - } - return IntegerAttr::get(getType(), api); - } - - return {}; + Type resType = getType(); + Type resEleType; + if (auto shapedType = resType.dyn_cast()) + resEleType = shapedType.getElementType(); + else + resEleType = resType; + return constFoldCastOp( + operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { + IntegerType intTy = resEleType.cast(); + bool ignored; + APSInt api(intTy.getWidth(), /*isUnsigned=*/true); + castStatus = APFloat::opInvalidOp != + a.convertToInteger(api, APFloat::rmTowardZero, &ignored); + return api; + }); } //===----------------------------------------------------------------------===// @@ -1113,21 +1135,21 @@ } OpFoldResult arith::FPToSIOp::fold(ArrayRef operands) { - if (auto lhs = operands[0].dyn_cast_or_null()) { - const APFloat &apf = lhs.getValue(); - IntegerType intTy = getType().cast(); - bool ignored; - APSInt api(intTy.getWidth(), /*isUnsigned=*/false); - if (APFloat::opInvalidOp == - apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { - // Undefined behavior invoked - the destination type can't represent - // the input constant. - return {}; - } - return IntegerAttr::get(getType(), api); - } - - return {}; + Type resType = getType(); + Type resEleType; + if (auto shapedType = resType.dyn_cast()) + resEleType = shapedType.getElementType(); + else + resEleType = resType; + return constFoldCastOp( + operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) { + IntegerType intTy = resEleType.cast(); + bool ignored; + APSInt api(intTy.getWidth(), /*isUnsigned=*/false); + castStatus = APFloat::opInvalidOp != + a.convertToInteger(api, APFloat::rmTowardZero, &ignored); + return api; + }); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir --- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -282,6 +282,53 @@ return %ext : i16 } +// CHECK-LABEL: @signExtendConstantSplat +// CHECK: %[[cres:.+]] = arith.constant dense<-2> : vector<4xi16> +// CHECK: return %[[cres]] +func @signExtendConstantSplat() -> vector<4xi16> { + %c-2 = arith.constant -2 : i8 + %splat = vector.splat %c-2 : vector<4xi8> + %ext = arith.extsi %splat : vector<4xi8> to vector<4xi16> + return %ext : vector<4xi16> +} + +// CHECK-LABEL: @signExtendConstantVector +// CHECK: %[[cres:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi16> +// CHECK: return %[[cres]] +func @signExtendConstantVector() -> vector<4xi16> { + %vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi8> + %ext = arith.extsi %vector : vector<4xi8> to vector<4xi16> + return %ext : vector<4xi16> +} + +// CHECK-LABEL: @unsignedExtendConstant +// CHECK: %[[cres:.+]] = arith.constant 2 : i16 +// CHECK: return %[[cres]] +func @unsignedExtendConstant() -> i16 { + %c2 = arith.constant 2 : i8 + %ext = arith.extui %c2 : i8 to i16 + return %ext : i16 +} + +// CHECK-LABEL: @unsignedExtendConstantSplat +// CHECK: %[[cres:.+]] = arith.constant dense<2> : vector<4xi16> +// CHECK: return %[[cres]] +func @unsignedExtendConstantSplat() -> vector<4xi16> { + %c2 = arith.constant 2 : i8 + %splat = vector.splat %c2 : vector<4xi8> + %ext = arith.extui %splat : vector<4xi8> to vector<4xi16> + return %ext : vector<4xi16> +} + +// CHECK-LABEL: @unsignedExtendConstantVector +// CHECK: %[[cres:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi16> +// CHECK: return %[[cres]] +func @unsignedExtendConstantVector() -> vector<4xi16> { + %vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi8> + %ext = arith.extui %vector : vector<4xi8> to vector<4xi16> + return %ext : vector<4xi16> +} + // CHECK-LABEL: @truncConstant // CHECK: %[[cres:.+]] = arith.constant -2 : i16 // CHECK: return %[[cres]] @@ -291,6 +338,25 @@ return %tr : i16 } +// CHECK-LABEL: @truncConstantSplat +// CHECK: %[[cres:.+]] = arith.constant dense<-2> : vector<4xi8> +// CHECK: return %[[cres]] +func @truncConstantSplat() -> vector<4xi8> { + %c-2 = arith.constant -2 : i16 + %splat = vector.splat %c-2 : vector<4xi16> + %trunc = arith.trunci %splat : vector<4xi16> to vector<4xi8> + return %trunc : vector<4xi8> +} + +// CHECK-LABEL: @truncConstantVector +// CHECK: %[[cres:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi8> +// CHECK: return %[[cres]] +func @truncConstantVector() -> vector<4xi8> { + %vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi16> + %trunc = arith.trunci %vector : vector<4xi16> to vector<4xi8> + return %trunc : vector<4xi8> +} + // CHECK-LABEL: @truncTrunc // CHECK: %[[cres:.+]] = arith.trunci %arg0 : i64 to i8 // CHECK: return %[[cres]] @@ -921,6 +987,25 @@ return %res : i32 } +// CHECK-LABEL: @constant_FPtoUI_splat( +func @constant_FPtoUI_splat() -> vector<4xi32> { + // CHECK: %[[C0:.+]] = arith.constant dense<2> : vector<4xi32> + // CHECK: return %[[C0]] + %c0 = arith.constant 2.0 : f32 + %splat = vector.splat %c0 : vector<4xf32> + %res = arith.fptoui %splat : vector<4xf32> to vector<4xi32> + return %res : vector<4xi32> +} + +// CHECK-LABEL: @constant_FPtoUI_vector( +func @constant_FPtoUI_vector() -> vector<4xi32> { + // CHECK: %[[C0:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> + // CHECK: return %[[C0]] + %vector = arith.constant dense<[1.0, 3.0, 5.0, 7.0]> : vector<4xf32> + %res = arith.fptoui %vector : vector<4xf32> to vector<4xi32> + return %res : vector<4xi32> +} + // ----- // CHECK-LABEL: @invalid_constant_FPtoUI( func @invalid_constant_FPtoUI() -> i32 { @@ -942,6 +1027,25 @@ return %res : i32 } +// CHECK-LABEL: @constant_FPtoSI_splat( +func @constant_FPtoSI_splat() -> vector<4xi32> { + // CHECK: %[[C0:.+]] = arith.constant dense<-2> : vector<4xi32> + // CHECK: return %[[C0]] + %c0 = arith.constant -2.0 : f32 + %splat = vector.splat %c0 : vector<4xf32> + %res = arith.fptosi %splat : vector<4xf32> to vector<4xi32> + return %res : vector<4xi32> +} + +// CHECK-LABEL: @constant_FPtoSI_vector( +func @constant_FPtoSI_vector() -> vector<4xi32> { + // CHECK: %[[C0:.+]] = arith.constant dense<[-1, -3, -5, -7]> : vector<4xi32> + // CHECK: return %[[C0]] + %vector = arith.constant dense<[-1.0, -3.0, -5.0, -7.0]> : vector<4xf32> + %res = arith.fptosi %vector : vector<4xf32> to vector<4xi32> + return %res : vector<4xi32> +} + // ----- // CHECK-LABEL: @invalid_constant_FPtoSI( func @invalid_constant_FPtoSI() -> i8 { @@ -962,16 +1066,54 @@ return %res : f32 } +// CHECK-LABEL: @constant_SItoFP_splat( +func @constant_SItoFP_splat() -> vector<4xf32> { + // CHECK: %[[C0:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32> + // CHECK: return %[[C0]] + %c0 = arith.constant 2 : i32 + %splat = vector.splat %c0 : vector<4xi32> + %res = arith.sitofp %splat : vector<4xi32> to vector<4xf32> + return %res : vector<4xf32> +} + +// CHECK-LABEL: @constant_SItoFP_vector( +func @constant_SItoFP_vector() -> vector<4xf32> { + // CHECK: %[[C0:.+]] = arith.constant dense<[1.000000e+00, 3.000000e+00, 5.000000e+00, 7.000000e+00]> : vector<4xf32> + // CHECK: return %[[C0]] + %vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> + %res = arith.sitofp %vector : vector<4xi32> to vector<4xf32> + return %res : vector<4xf32> +} + // ----- // CHECK-LABEL: @constant_UItoFP( func @constant_UItoFP() -> f32 { // CHECK: %[[C0:.+]] = arith.constant 2.000000e+00 : f32 // CHECK: return %[[C0]] %c0 = arith.constant 2 : i32 - %res = arith.sitofp %c0 : i32 to f32 + %res = arith.uitofp %c0 : i32 to f32 return %res : f32 } +// CHECK-LABEL: @constant_UItoFP_splat( +func @constant_UItoFP_splat() -> vector<4xf32> { + // CHECK: %[[C0:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32> + // CHECK: return %[[C0]] + %c0 = arith.constant 2 : i32 + %splat = vector.splat %c0 : vector<4xi32> + %res = arith.uitofp %splat : vector<4xi32> to vector<4xf32> + return %res : vector<4xf32> +} + +// CHECK-LABEL: @constant_UItoFP_vector( +func @constant_UItoFP_vector() -> vector<4xf32> { + // CHECK: %[[C0:.+]] = arith.constant dense<[1.000000e+00, 3.000000e+00, 5.000000e+00, 7.000000e+00]> : vector<4xf32> + // CHECK: return %[[C0]] + %vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32> + %res = arith.uitofp %vector : vector<4xi32> to vector<4xf32> + return %res : vector<4xf32> +} + // ----- // Tests rewritten from https://github.com/llvm/llvm-project/blob/main/llvm/test/Transforms/InstCombine/2008-11-08-FCmp.ll