diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td @@ -22,6 +22,10 @@ dialect also accept vectors and tensors of integers or floats. }]; + let dependentDialects = [ + "::mlir::ub::UBDialect" + ]; + let hasConstantMaterializer = 1; let useDefaultAttributePrinterParser = 1; } diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -22,17 +22,31 @@ #include namespace mlir { +namespace ub { +class PoisonAttr; +} /// Performs constant folding `calculate` with element-wise behavior on the two /// attributes in `operands` and returns the result if possible. /// Uses `resultType` for the type of the returned attribute. +/// Optional PoisonAttr template argument specifies a 'poison' attribute +/// which, if present in an operand, is directly propagated to the result. template (ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, Type resultType, - const CalculationT &calculate) { + CalculationT &&calculate) { assert(operands.size() == 2 && "binary op takes two operands"); + if constexpr (!std::is_void_v) { + if (isa_and_nonnull(operands[0])) + return operands[0]; + + if (isa_and_nonnull(operands[1])) + return operands[1]; + } + if (!resultType || !operands[0] || !operands[1]) return {}; @@ -95,13 +109,24 @@ /// attributes in `operands` and returns the result if possible. /// Uses the operand element type for the element type of the returned /// attribute. +/// Optional PoisonAttr template argument specifies a 'poison' attribute +/// which, if present in an operand, is directly propagated to the result. 
template (ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, - const CalculationT &calculate) { + CalculationT &&calculate) { assert(operands.size() == 2 && "binary op takes two operands"); + if constexpr (!std::is_void_v) { + if (isa_and_nonnull(operands[0])) + return operands[0]; + + if (isa_and_nonnull(operands[1])) + return operands[1]; + } + auto getResultType = [](Attribute attr) -> Type { if (auto typed = dyn_cast_or_null(attr)) return typed.getType(); @@ -115,18 +140,19 @@ if (lhsType != rhsType) return {}; - return constFoldBinaryOpConditional(operands, lhsType, - calculate); + return constFoldBinaryOpConditional( + operands, lhsType, std::forward(calculate)); } template > Attribute constFoldBinaryOp(ArrayRef operands, Type resultType, - const CalculationT &calculate) { - return constFoldBinaryOpConditional( + CalculationT &&calculate) { + return constFoldBinaryOpConditional( operands, resultType, [&](ElementValueT a, ElementValueT b) -> std::optional { return calculate(a, b); @@ -135,11 +161,12 @@ template > Attribute constFoldBinaryOp(ArrayRef operands, - const CalculationT &calculate) { - return constFoldBinaryOpConditional( + CalculationT &&calculate) { + return constFoldBinaryOpConditional( operands, [&](ElementValueT a, ElementValueT b) -> std::optional { return calculate(a, b); @@ -148,16 +175,24 @@ /// Performs constant folding `calculate` with element-wise behavior on the one /// attributes in `operands` and returns the result if possible. +/// Optional PoisonAttr template argument specifies a 'poison' attribute +/// which, if present in an operand, is directly propagated to the result. 
template (ElementValueT)>> Attribute constFoldUnaryOpConditional(ArrayRef operands, - const CalculationT &&calculate) { + CalculationT &&calculate) { assert(operands.size() == 1 && "unary op takes one operands"); if (!operands[0]) return {}; + if constexpr (!std::is_void_v) { + if (isa(operands[0])) + return operands[0]; + } + if (isa(operands[0])) { auto op = cast(operands[0]); @@ -196,10 +231,11 @@ template > Attribute constFoldUnaryOp(ArrayRef operands, - const CalculationT &&calculate) { - return constFoldUnaryOpConditional( + CalculationT &&calculate) { + return constFoldUnaryOpConditional( operands, [&](ElementValueT a) -> std::optional { return calculate(a); }); @@ -209,13 +245,19 @@ class AttrElementT, class TargetAttrElementT, class ElementValueT = typename AttrElementT::ValueType, class TargetElementValueT = typename TargetAttrElementT::ValueType, + class PoisonAttr = void, class CalculationT = function_ref> Attribute constFoldCastOp(ArrayRef operands, Type resType, - const CalculationT &calculate) { + CalculationT &&calculate) { assert(operands.size() == 1 && "Cast op takes one operand"); if (!operands[0]) return {}; + if constexpr (!std::is_void_v) { + if (isa(operands[0])) + return operands[0]; + } + if (isa(operands[0])) { auto op = cast(operands[0]); bool castStatus = true; @@ -255,6 +297,72 @@ return {}; } +template (ElementValueT, ElementValueT)>> +Attribute constFoldBinaryOpConditionalPoison(ArrayRef operands, + CalculationT &&calculate) { + return constFoldBinaryOpConditional( + operands, std::forward(calculate)); +} + +template > +Attribute constFoldBinaryOpPoison(ArrayRef operands, Type resultType, + CalculationT &&calculate) { + return constFoldBinaryOp(operands, resultType, + std::forward(calculate)); +} + +template > +Attribute constFoldBinaryOpPoison(ArrayRef operands, + CalculationT &&calculate) { + return constFoldBinaryOp(operands, + std::forward(calculate)); +} + +template (ElementValueT)>> +Attribute 
constFoldUnaryOpConditionalPoison(ArrayRef operands, + CalculationT &&calculate) { + return constFoldUnaryOpConditional( + operands, std::forward(calculate)); +} + +template > +Attribute constFoldUnaryOpPoison(ArrayRef operands, + CalculationT &&calculate) { + return constFoldUnaryOp(operands, + std::forward(calculate)); +} + +template < + class AttrElementT, class TargetAttrElementT, + class ElementValueT = typename AttrElementT::ValueType, + class TargetElementValueT = typename TargetAttrElementT::ValueType, + class CalculationT = function_ref> +Attribute constFoldCastOpPoison(ArrayRef operands, Type resType, + CalculationT &&calculate) { + return constFoldCastOp( + operands, resType, std::forward(calculate)); +} + } // namespace mlir #endif // MLIR_DIALECT_COMMONFOLDERS_H diff --git a/mlir/include/mlir/Dialect/Math/IR/MathBase.td b/mlir/include/mlir/Dialect/Math/IR/MathBase.td --- a/mlir/include/mlir/Dialect/Math/IR/MathBase.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathBase.td @@ -31,7 +31,8 @@ }]; let hasConstantMaterializer = 1; let dependentDialects = [ - "::mlir::arith::ArithDialect" + "::mlir::arith::ArithDialect", + "::mlir::ub::UBDialect" ]; } #endif // MATH_BASE diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -53,6 +53,10 @@ let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; + let dependentDialects = [ + "::mlir::ub::UBDialect" + ]; + let extraClassDeclaration = [{ void registerAttributes(); void registerTypes(); diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td @@ -36,7 +36,11 @@ }]; let cppNamespace = "::mlir::shape"; - let dependentDialects = ["arith::ArithDialect", "tensor::TensorDialect"]; + let 
dependentDialects = [ + "::mlir::arith::ArithDialect", + "::mlir::tensor::TensorDialect", + "::mlir::ub::UBDialect" + ]; let useDefaultTypePrinterParser = 1; let hasConstantMaterializer = 1; diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" @@ -49,5 +50,8 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast(value)) + return builder.create(loc, type, poison); + return ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -271,7 +271,7 @@ if (getLhs() == sub.getRhs()) return sub.getLhs(); - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) + b; }); } @@ -317,10 +317,10 @@ // Let the `constFoldBinaryOp` utility attempt to fold the sum of both // operands. If that succeeds, calculate the overflow bit based on the sum // and the first (constant) operand, `lhs`. 
- if (Attribute sumAttr = constFoldBinaryOp( + if (Attribute sumAttr = constFoldBinaryOpPoison( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) + b; })) { - Attribute overflowAttr = constFoldBinaryOp( + Attribute overflowAttr = constFoldBinaryOpPoison( ArrayRef({sumAttr, adaptor.getLhs()}), getI1SameShape(llvm::cast(sumAttr).getType()), calculateUnsignedOverflow); @@ -361,7 +361,7 @@ return add.getRhs(); } - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) - b; }); } @@ -387,7 +387,7 @@ // TODO: Handle the overflow case. // default folder - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a * b; }); } @@ -420,11 +420,11 @@ } // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high - if (Attribute lowAttr = constFoldBinaryOp( + if (Attribute lowAttr = constFoldBinaryOpPoison( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a * b; })) { // Invoke the constant fold helper again to calculate the 'high' result. - Attribute highAttr = constFoldBinaryOp( + Attribute highAttr = constFoldBinaryOpPoison( adaptor.getOperands(), [](const APInt &a, const APInt &b) { unsigned bitWidth = a.getBitWidth(); APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2); @@ -477,11 +477,11 @@ } // mului_extended(cst_a, cst_b) -> cst_low, cst_high - if (Attribute lowAttr = constFoldBinaryOp( + if (Attribute lowAttr = constFoldBinaryOpPoison( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a * b; })) { // Invoke the constant fold helper again to calculate the 'high' result. 
- Attribute highAttr = constFoldBinaryOp( + Attribute highAttr = constFoldBinaryOpPoison( adaptor.getOperands(), [](const APInt &a, const APInt &b) { unsigned bitWidth = a.getBitWidth(); APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2); @@ -513,14 +513,14 @@ // Don't fold if it would require a division by zero. bool div0 = false; - auto result = constFoldBinaryOp(adaptor.getOperands(), - [&](APInt a, const APInt &b) { - if (div0 || !b) { - div0 = true; - return a; - } - return a.udiv(b); - }); + auto result = constFoldBinaryOpPoison( + adaptor.getOperands(), [&](APInt a, const APInt &b) { + if (div0 || !b) { + div0 = true; + return a; + } + return a.udiv(b); + }); return div0 ? Attribute() : result; } @@ -542,7 +542,7 @@ // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; - auto result = constFoldBinaryOp( + auto result = constFoldBinaryOpPoison( adaptor.getOperands(), [&](APInt a, const APInt &b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; @@ -588,7 +588,7 @@ return getLhs(); bool overflowOrDiv0 = false; - auto result = constFoldBinaryOp( + auto result = constFoldBinaryOpPoison( adaptor.getOperands(), [&](APInt a, const APInt &b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; @@ -621,7 +621,7 @@ // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; - auto result = constFoldBinaryOp( + auto result = constFoldBinaryOpPoison( adaptor.getOperands(), [&](APInt a, const APInt &b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; @@ -682,7 +682,7 @@ // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; - auto result = constFoldBinaryOp( + auto result = constFoldBinaryOpPoison( adaptor.getOperands(), [&](APInt a, const APInt &b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; @@ -731,14 +731,14 @@ // Don't fold if it would require a division by zero. 
bool div0 = false; - auto result = constFoldBinaryOp(adaptor.getOperands(), - [&](APInt a, const APInt &b) { - if (div0 || b.isZero()) { - div0 = true; - return a; - } - return a.urem(b); - }); + auto result = constFoldBinaryOpPoison( + adaptor.getOperands(), [&](APInt a, const APInt &b) { + if (div0 || b.isZero()) { + div0 = true; + return a; + } + return a.urem(b); + }); return div0 ? Attribute() : result; } @@ -754,14 +754,14 @@ // Don't fold if it would require a division by zero. bool div0 = false; - auto result = constFoldBinaryOp(adaptor.getOperands(), - [&](APInt a, const APInt &b) { - if (div0 || b.isZero()) { - div0 = true; - return a; - } - return a.srem(b); - }); + auto result = constFoldBinaryOpPoison( + adaptor.getOperands(), [&](APInt a, const APInt &b) { + if (div0 || b.isZero()) { + div0 = true; + return a; + } + return a.srem(b); + }); return div0 ? Attribute() : result; } @@ -810,7 +810,7 @@ if (Value result = foldAndIofAndI(*this)) return result; - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) & b; }); } @@ -840,7 +840,7 @@ intValue.isAllOnes()) return getLhs().getDefiningOp().getRhs(); - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) | b; }); } @@ -873,7 +873,7 @@ return prev.getRhs(); } - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) ^ b; }); } @@ -891,8 +891,8 @@ /// negf(negf(x)) -> x if (auto op = this->getOperand().getDefiningOp()) return op.getOperand(); - return constFoldUnaryOp(adaptor.getOperands(), - [](const APFloat &a) { return -a; }); + return constFoldUnaryOpPoison(adaptor.getOperands(), + [](const APFloat &a) { return -a; }); } //===----------------------------------------------------------------------===// @@ -904,7 +904,7 @@ if (matchPattern(getRhs(), m_NegZeroFloat())) 
return getLhs(); - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return a + b; }); } @@ -918,7 +918,7 @@ if (matchPattern(getRhs(), m_PosZeroFloat())) return getLhs(); - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return a - b; }); } @@ -936,7 +936,7 @@ if (matchPattern(getRhs(), m_NegInfFloat())) return getLhs(); - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); }); } @@ -961,10 +961,10 @@ intValue.isMinSignedValue()) return getLhs(); - return constFoldBinaryOp(adaptor.getOperands(), - [](const APInt &a, const APInt &b) { - return llvm::APIntOps::smax(a, b); - }); + return constFoldBinaryOpPoison( + adaptor.getOperands(), [](const APInt &a, const APInt &b) { + return llvm::APIntOps::smax(a, b); + }); } //===----------------------------------------------------------------------===// @@ -985,10 +985,10 @@ if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) return getLhs(); - return constFoldBinaryOp(adaptor.getOperands(), - [](const APInt &a, const APInt &b) { - return llvm::APIntOps::umax(a, b); - }); + return constFoldBinaryOpPoison( + adaptor.getOperands(), [](const APInt &a, const APInt &b) { + return llvm::APIntOps::umax(a, b); + }); } //===----------------------------------------------------------------------===// @@ -1004,7 +1004,7 @@ if (matchPattern(getRhs(), m_PosInfFloat())) return getLhs(); - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); } @@ -1029,10 +1029,10 @@ intValue.isMaxSignedValue()) return getLhs(); - return constFoldBinaryOp(adaptor.getOperands(), - [](const APInt &a, const APInt &b) { - return llvm::APIntOps::smin(a, b); - }); + return 
constFoldBinaryOpPoison( + adaptor.getOperands(), [](const APInt &a, const APInt &b) { + return llvm::APIntOps::smin(a, b); + }); } //===----------------------------------------------------------------------===// @@ -1053,10 +1053,10 @@ if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) return getLhs(); - return constFoldBinaryOp(adaptor.getOperands(), - [](const APInt &a, const APInt &b) { - return llvm::APIntOps::umin(a, b); - }); + return constFoldBinaryOpPoison( + adaptor.getOperands(), [](const APInt &a, const APInt &b) { + return llvm::APIntOps::umin(a, b); + }); } //===----------------------------------------------------------------------===// @@ -1068,7 +1068,7 @@ if (matchPattern(getRhs(), m_OneFloat())) return getLhs(); - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return a * b; }); } @@ -1087,7 +1087,7 @@ if (matchPattern(getRhs(), m_OneFloat())) return getLhs(); - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return a / b; }); } @@ -1102,12 +1102,12 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) { - return constFoldBinaryOp(adaptor.getOperands(), - [](const APFloat &a, const APFloat &b) { - APFloat result(a); - (void)result.remainder(b); - return result; - }); + return constFoldBinaryOpPoison( + adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { + APFloat result(a); + (void)result.remainder(b); + return result; + }); } //===----------------------------------------------------------------------===// @@ -1224,7 +1224,7 @@ Type resType = getElementTypeOrSelf(getType()); unsigned bitWidth = llvm::cast(resType).getWidth(); - return constFoldCastOp( + return constFoldCastOpPoison( adaptor.getOperands(), getType(), [bitWidth](const APInt &a, bool &castStatus) { return 
a.zext(bitWidth); @@ -1251,7 +1251,7 @@ Type resType = getElementTypeOrSelf(getType()); unsigned bitWidth = llvm::cast(resType).getWidth(); - return constFoldCastOp( + return constFoldCastOpPoison( adaptor.getOperands(), getType(), [bitWidth](const APInt &a, bool &castStatus) { return a.sext(bitWidth); @@ -1321,7 +1321,7 @@ Type resType = getElementTypeOrSelf(getType()); unsigned bitWidth = llvm::cast(resType).getWidth(); - return constFoldCastOp( + return constFoldCastOpPoison( adaptor.getOperands(), getType(), [bitWidth](const APInt &a, bool &castStatus) { return a.trunc(bitWidth); @@ -1417,7 +1417,7 @@ OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) { Type resEleType = getElementTypeOrSelf(getType()); - return constFoldCastOp( + return constFoldCastOpPoison( adaptor.getOperands(), getType(), [&resEleType](const APInt &a, bool &castStatus) { FloatType floatTy = llvm::cast(resEleType); @@ -1439,7 +1439,7 @@ OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) { Type resEleType = getElementTypeOrSelf(getType()); - return constFoldCastOp( + return constFoldCastOpPoison( adaptor.getOperands(), getType(), [&resEleType](const APInt &a, bool &castStatus) { FloatType floatTy = llvm::cast(resEleType); @@ -1461,7 +1461,7 @@ OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) { Type resType = getElementTypeOrSelf(getType()); unsigned bitWidth = llvm::cast(resType).getWidth(); - return constFoldCastOp( + return constFoldCastOpPoison( adaptor.getOperands(), getType(), [&bitWidth](const APFloat &a, bool &castStatus) { bool ignored; @@ -1483,7 +1483,7 @@ OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) { Type resType = getElementTypeOrSelf(getType()); unsigned bitWidth = llvm::cast(resType).getWidth(); - return constFoldCastOp( + return constFoldCastOpPoison( adaptor.getOperands(), getType(), [&bitWidth](const APFloat &a, bool &castStatus) { bool ignored; @@ -1522,7 +1522,7 @@ if (auto intTy = dyn_cast(getElementTypeOrSelf(getType()))) resultBitwidth = 
intTy.getWidth(); - return constFoldCastOp( + return constFoldCastOpPoison( adaptor.getOperands(), getType(), [resultBitwidth](const APInt &a, bool & /*castStatus*/) { return a.sextOrTrunc(resultBitwidth); @@ -1549,7 +1549,7 @@ if (auto intTy = dyn_cast(getElementTypeOrSelf(getType()))) resultBitwidth = intTy.getWidth(); - return constFoldCastOp( + return constFoldCastOpPoison( adaptor.getOperands(), getType(), [resultBitwidth](const APInt &a, bool & /*castStatus*/) { return a.zextOrTrunc(resultBitwidth); @@ -1731,7 +1731,7 @@ // We are moving constants to the right side; So if lhs is constant rhs is // guaranteed to be a constant. if (auto lhs = llvm::dyn_cast_if_present(adaptor.getLhs())) { - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), getI1SameShape(lhs.getType()), [pred = getPredicate()](const APInt &lhs, const APInt &rhs) { return APInt(1, @@ -2317,7 +2317,7 @@ return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; - auto result = constFoldBinaryOp( + auto result = constFoldBinaryOpPoison( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { bounded = b.ule(b.getBitWidth()); return a.shl(b); @@ -2335,7 +2335,7 @@ return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; - auto result = constFoldBinaryOp( + auto result = constFoldBinaryOpPoison( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { bounded = b.ule(b.getBitWidth()); return a.lshr(b); @@ -2353,7 +2353,7 @@ return getLhs(); // Don't fold if shifting more than the bit width. 
bool bounded = false; - auto result = constFoldBinaryOp( + auto result = constFoldBinaryOpPoison( adaptor.getOperands(), [&](const APInt &a, const APInt &b) { bounded = b.ule(b.getBitWidth()); return a.ashr(b); diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -9,7 +9,6 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/CommonFolders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" diff --git a/mlir/lib/Dialect/Math/IR/CMakeLists.txt b/mlir/lib/Dialect/Math/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Math/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/IR/CMakeLists.txt @@ -12,4 +12,5 @@ MLIRArithDialect MLIRDialect MLIRIR + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/Math/IR/MathDialect.cpp b/mlir/lib/Dialect/Math/IR/MathDialect.cpp --- a/mlir/lib/Dialect/Math/IR/MathDialect.cpp +++ b/mlir/lib/Dialect/Math/IR/MathDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Transforms/InliningUtils.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include @@ -27,8 +28,8 @@ //===----------------------------------------------------------------------===// OpFoldResult math::AbsFOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOp(adaptor.getOperands(), - [](const APFloat &a) { 
return abs(a); }); + return constFoldUnaryOpPoison( + adaptor.getOperands(), [](const APFloat &a) { return abs(a); }); } //===----------------------------------------------------------------------===// @@ -36,8 +37,8 @@ //===----------------------------------------------------------------------===// OpFoldResult math::AbsIOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOp(adaptor.getOperands(), - [](const APInt &a) { return a.abs(); }); + return constFoldUnaryOpPoison( + adaptor.getOperands(), [](const APInt &a) { return a.abs(); }); } //===----------------------------------------------------------------------===// @@ -45,7 +46,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::AtanOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: @@ -63,7 +64,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) { - return constFoldBinaryOpConditional( + return constFoldBinaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) -> std::optional { if (a.isZero() && b.isZero()) @@ -86,7 +87,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::CeilOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOp( + return constFoldUnaryOpPoison( adaptor.getOperands(), [](const APFloat &a) { APFloat result(a); result.roundToIntegral(llvm::RoundingMode::TowardPositive); @@ -99,12 +100,12 @@ //===----------------------------------------------------------------------===// OpFoldResult math::CopySignOp::fold(FoldAdaptor adaptor) { - return constFoldBinaryOp(adaptor.getOperands(), - [](const APFloat &a, const APFloat &b) { - APFloat result(a); - result.copySign(b); - return result; - }); + 
return constFoldBinaryOpPoison( + adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { + APFloat result(a); + result.copySign(b); + return result; + }); } //===----------------------------------------------------------------------===// @@ -112,7 +113,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::CosOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: @@ -130,7 +131,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::SinOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: @@ -148,7 +149,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::CountLeadingZerosOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOp( + return constFoldUnaryOpPoison( adaptor.getOperands(), [](const APInt &a) { return APInt(a.getBitWidth(), a.countl_zero()); }); } @@ -158,7 +159,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::CountTrailingZerosOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOp( + return constFoldUnaryOpPoison( adaptor.getOperands(), [](const APInt &a) { return APInt(a.getBitWidth(), a.countr_zero()); }); } @@ -168,7 +169,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::CtPopOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOp( + return constFoldUnaryOpPoison( adaptor.getOperands(), [](const APInt &a) { return APInt(a.getBitWidth(), a.popcount()); }); } @@ -178,7 +179,7 @@ 
//===----------------------------------------------------------------------===// OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: @@ -196,7 +197,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::IPowIOp::fold(FoldAdaptor adaptor) { - return constFoldBinaryOpConditional( + return constFoldBinaryOpConditionalPoison( adaptor.getOperands(), [](const APInt &base, const APInt &power) -> std::optional { unsigned width = base.getBitWidth(); @@ -247,7 +248,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::LogOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { if (a.isNegative()) return {}; @@ -267,7 +268,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::Log2Op::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { if (a.isNegative()) return {}; @@ -287,7 +288,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::Log10Op::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { if (a.isNegative()) return {}; @@ -308,7 +309,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::Log1pOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { 
switch (a.getSizeInBits(a.getSemantics())) { case 64: @@ -330,7 +331,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::PowFOp::fold(FoldAdaptor adaptor) { - return constFoldBinaryOpConditional( + return constFoldBinaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) -> std::optional { if (a.getSizeInBits(a.getSemantics()) == 64 && @@ -350,7 +351,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::SqrtOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { if (a.isNegative()) return {}; @@ -371,7 +372,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::ExpOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: @@ -389,7 +390,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::Exp2Op::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: @@ -407,7 +408,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: @@ -425,7 +426,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::TanOp::fold(FoldAdaptor adaptor) { - return 
constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: @@ -443,7 +444,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::TanhOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: @@ -461,7 +462,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::RoundEvenOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOp( + return constFoldUnaryOpPoison( adaptor.getOperands(), [](const APFloat &a) { APFloat result(a); result.roundToIntegral(llvm::RoundingMode::NearestTiesToEven); @@ -474,7 +475,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::FloorOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOp( + return constFoldUnaryOpPoison( adaptor.getOperands(), [](const APFloat &a) { APFloat result(a); result.roundToIntegral(llvm::RoundingMode::TowardNegative); @@ -487,7 +488,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::RoundOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: @@ -505,7 +506,7 @@ //===----------------------------------------------------------------------===// OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditionalPoison( adaptor.getOperands(), [](const APFloat &a) -> std::optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: @@ -522,5 +523,8 @@ Operation 
*math::MathDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast(value)) + return builder.create(loc, type, poison); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -42,4 +42,5 @@ MLIRSideEffectInterfaces MLIRSupport MLIRTransforms + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/STLExtras.h" @@ -243,7 +244,7 @@ // The resulting value will equal the low-order N bits of the correct result // R, where N is the component width and R is computed with enough precision // to avoid overflow and underflow. - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) + b; }); } @@ -265,7 +266,7 @@ // The resulting value will equal the low-order N bits of the correct result // R, where N is the component width and R is computed with enough precision // to avoid overflow and underflow. - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a * b; }); } @@ -284,7 +285,7 @@ // The resulting value will equal the low-order N bits of the correct result // R, where N is the component width and R is computed with enough precision // to avoid overflow and underflow. 
- return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) - b; }); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -973,6 +974,9 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast(value)) + return builder.create(loc, type, poison); + if (!spirv::ConstantOp::isBuildableWith(type)) return nullptr; diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt @@ -22,4 +22,5 @@ MLIRIR MLIRSideEffectInterfaces MLIRTensorDialect + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -147,6 +148,9 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast(value)) + return builder.create(loc, type, poison); + if (llvm::isa(type) || isExtentTensorType(type)) return builder.create( loc, type, llvm::cast(value)); @@ -156,6 +160,7 @@ if 
(llvm::isa(type)) return builder.create(loc, type, llvm::cast(value)); + return arith::ConstantOp::materialize(builder, value, type, loc); } @@ -413,7 +418,7 @@ if (matchPattern(getRhs(), m_Zero())) return getLhs(); - return constFoldBinaryOp( + return constFoldBinaryOpPoison( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) + b; }); } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2584,3 +2584,58 @@ %select4 = arith.select %false, %poison, %arg : i32 return %select1, %select2, %select3, %select4 : i32, i32, i32, i32 } + +// CHECK-LABEL: @addi_poison1 +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @addi_poison1(%arg: i32) -> i32 { + %0 = ub.poison : i32 + %1 = arith.addi %0, %arg : i32 + return %1 : i32 +} + +// CHECK-LABEL: @addi_poison2 +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @addi_poison2(%arg: i32) -> i32 { + %0 = ub.poison : i32 + %1 = arith.addi %arg, %0 : i32 + return %1 : i32 +} + +// CHECK-LABEL: @addf_poison1 +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @addf_poison1(%arg: f32) -> f32 { + %0 = ub.poison : f32 + %1 = arith.addf %0, %arg : f32 + return %1 : f32 +} + +// CHECK-LABEL: @addf_poison2 +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @addf_poison2(%arg: f32) -> f32 { + %0 = ub.poison : f32 + %1 = arith.addf %arg, %0 : f32 + return %1 : f32 +} + + +// CHECK-LABEL: @negf_poison +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @negf_poison() -> f32 { + %0 = ub.poison : f32 + %1 = arith.negf %0 : f32 + return %1 : f32 +} + +// CHECK-LABEL: @extsi_poison +// CHECK: %[[P:.*]] = ub.poison : i64 +// CHECK: return %[[P]] +func.func @extsi_poison() -> i64 { + %0 = ub.poison : i32 + %1 = arith.extsi %0 : i32 to i64 + return %1 : i64 +} diff 
--git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir --- a/mlir/test/Dialect/Math/canonicalize.mlir +++ b/mlir/test/Dialect/Math/canonicalize.mlir @@ -483,3 +483,12 @@ %0 = math.erf %v1 : vector<4xf32> return %0 : vector<4xf32> } + +// CHECK-LABEL: @abs_poison +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @abs_poison() -> f32 { + %0 = ub.poison : f32 + %1 = math.absf %0 : f32 + return %1 : f32 +} diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir @@ -325,6 +325,15 @@ return %0: vector<3xi32> } +// CHECK-LABEL: @iadd_poison +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @iadd_poison(%arg0: i32) -> i32 { + %0 = ub.poison : i32 + %1 = spirv.IAdd %arg0, %0 : i32 + return %1: i32 +} + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1479,3 +1479,16 @@ // CHECK: return %[[DIM]] return %result : index } + + +// ----- + +// CHECK-LABEL: @add_poison +// CHECK: %[[P:.*]] = ub.poison : !shape.siz +// CHECK: return %[[P]] +func.func @add_poison() -> !shape.size { + %1 = shape.const_size 2 + %2 = ub.poison : !shape.size + %result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size + return %result : !shape.size +}