diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -22,17 +22,35 @@ #include namespace mlir { +namespace ub { +class PoisonAttr; +} /// Performs constant folding `calculate` with element-wise behavior on the two /// attributes in `operands` and returns the result if possible. /// Uses `resultType` for the type of the returned attribute. +/// If either operand is a `PoisonAttr` (template argument), the first such +/// operand is returned as the fold result; pass `void` to disable this. template (ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, Type resultType, - const CalculationT &calculate) { + CalculationT &&calculate) { assert(operands.size() == 2 && "binary op takes two operands"); + static_assert( + std::is_void_v || !llvm::is_incomplete_v, + "PoisonAttr is undefined, either add a dependency on UB dialect or pass " + "void as template argument to opt-out from poison semantics."); + if constexpr (!std::is_void_v) { + if (isa_and_nonnull(operands[0])) + return operands[0]; + + if (isa_and_nonnull(operands[1])) + return operands[1]; + } + if (!resultType || !operands[0] || !operands[1]) return {}; @@ -95,13 +113,28 @@ /// attributes in `operands` and returns the result if possible. /// Uses the operand element type for the element type of the returned /// attribute. +/// If either operand is a `PoisonAttr` (template argument), the first such +/// operand is returned as the fold result; pass `void` to disable this. 
template (ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, - const CalculationT &calculate) { + CalculationT &&calculate) { assert(operands.size() == 2 && "binary op takes two operands"); + static_assert( + std::is_void_v || !llvm::is_incomplete_v, + "PoisonAttr is undefined, either add a dependency on UB dialect or pass " + "void as template argument to opt-out from poison semantics."); + if constexpr (!std::is_void_v) { + if (isa_and_nonnull(operands[0])) + return operands[0]; + + if (isa_and_nonnull(operands[1])) + return operands[1]; + } + auto getResultType = [](Attribute attr) -> Type { if (auto typed = dyn_cast_or_null(attr)) return typed.getType(); @@ -115,18 +148,19 @@ if (lhsType != rhsType) return {}; - return constFoldBinaryOpConditional(operands, lhsType, - calculate); + return constFoldBinaryOpConditional( + operands, lhsType, std::forward(calculate)); } template > Attribute constFoldBinaryOp(ArrayRef operands, Type resultType, - const CalculationT &calculate) { - return constFoldBinaryOpConditional( + CalculationT &&calculate) { + return constFoldBinaryOpConditional( operands, resultType, [&](ElementValueT a, ElementValueT b) -> std::optional { return calculate(a, b); @@ -135,11 +169,12 @@ template > Attribute constFoldBinaryOp(ArrayRef operands, - const CalculationT &calculate) { - return constFoldBinaryOpConditional( + CalculationT &&calculate) { + return constFoldBinaryOpConditional( operands, [&](ElementValueT a, ElementValueT b) -> std::optional { return calculate(a, b); @@ -148,16 +183,28 @@ /// Performs constant folding `calculate` with element-wise behavior on the one /// attributes in `operands` and returns the result if possible. +/// If the operand is a `PoisonAttr` (template argument), it is returned +/// directly as the fold result; pass `void` to disable this. 
template (ElementValueT)>> Attribute constFoldUnaryOpConditional(ArrayRef operands, - const CalculationT &&calculate) { + CalculationT &&calculate) { assert(operands.size() == 1 && "unary op takes one operands"); if (!operands[0]) return {}; + static_assert( + std::is_void_v || !llvm::is_incomplete_v, + "PoisonAttr is undefined, either add a dependency on UB dialect or pass " + "void as template argument to opt-out from poison semantics."); + if constexpr (!std::is_void_v) { + if (isa(operands[0])) + return operands[0]; + } + if (isa(operands[0])) { auto op = cast(operands[0]); @@ -196,10 +243,11 @@ template > Attribute constFoldUnaryOp(ArrayRef operands, - const CalculationT &&calculate) { - return constFoldUnaryOpConditional( + CalculationT &&calculate) { + return constFoldUnaryOpConditional( operands, [&](ElementValueT a) -> std::optional { return calculate(a); }); @@ -209,13 +257,23 @@ class AttrElementT, class TargetAttrElementT, class ElementValueT = typename AttrElementT::ValueType, class TargetElementValueT = typename TargetAttrElementT::ValueType, + class PoisonAttr = ub::PoisonAttr, class CalculationT = function_ref> Attribute constFoldCastOp(ArrayRef operands, Type resType, - const CalculationT &calculate) { + CalculationT &&calculate) { assert(operands.size() == 1 && "Cast op takes one operand"); if (!operands[0]) return {}; + static_assert( + std::is_void_v || !llvm::is_incomplete_v, + "PoisonAttr is undefined, either add a dependency on UB dialect or pass " + "void as template argument to opt-out from poison semantics."); + if constexpr (!std::is_void_v) { + if (isa(operands[0])) + return operands[0]; + } + if (isa(operands[0])) { auto op = cast(operands[0]); bool castStatus = true; @@ -254,7 +312,6 @@ } return {}; } - } // namespace mlir #endif // MLIR_DIALECT_COMMONFOLDERS_H diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ 
b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" @@ -49,5 +50,8 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast(value)) + return builder.create(loc, type, poison); + return ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -9,7 +9,6 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/CommonFolders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" diff --git a/mlir/lib/Dialect/Math/IR/CMakeLists.txt b/mlir/lib/Dialect/Math/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Math/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/IR/CMakeLists.txt @@ -12,4 +12,5 @@ MLIRArithDialect MLIRDialect MLIRIR + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/Math/IR/MathDialect.cpp b/mlir/lib/Dialect/Math/IR/MathDialect.cpp --- a/mlir/lib/Dialect/Math/IR/MathDialect.cpp +++ b/mlir/lib/Dialect/Math/IR/MathDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Transforms/InliningUtils.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -9,6 +9,7 @@ #include 
"mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include @@ -522,5 +523,8 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast(value)) + return builder.create(loc, type, poison); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -43,4 +43,5 @@ MLIRSideEffectInterfaces MLIRSupport MLIRTransforms + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -949,6 +950,9 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast(value)) + return builder.create(loc, type, poison); + if 
(!spirv::ConstantOp::isBuildableWith(type)) return nullptr; diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt @@ -23,4 +23,5 @@ MLIRIR MLIRSideEffectInterfaces MLIRTensorDialect + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -147,6 +148,9 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast(value)) + return builder.create(loc, type, poison); + if (llvm::isa(type) || isExtentTensorType(type)) return builder.create( loc, type, llvm::cast(value)); @@ -156,6 +160,7 @@ if (llvm::isa(type)) return builder.create(loc, type, llvm::cast(value)); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2584,3 +2584,58 @@ %select4 = arith.select %false, %poison, %arg : i32 return %select1, %select2, %select3, %select4 : i32, i32, i32, i32 } + +// CHECK-LABEL: @addi_poison1 +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @addi_poison1(%arg: i32) -> i32 { + %0 = ub.poison : i32 + %1 = arith.addi %0, %arg : i32 + return %1 : i32 +} + +// CHECK-LABEL: @addi_poison2 +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @addi_poison2(%arg: i32) -> i32 { + %0 = ub.poison : i32 + %1 = 
arith.addi %arg, %0 : i32 + return %1 : i32 +} + +// CHECK-LABEL: @addf_poison1 +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @addf_poison1(%arg: f32) -> f32 { + %0 = ub.poison : f32 + %1 = arith.addf %0, %arg : f32 + return %1 : f32 +} + +// CHECK-LABEL: @addf_poison2 +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @addf_poison2(%arg: f32) -> f32 { + %0 = ub.poison : f32 + %1 = arith.addf %arg, %0 : f32 + return %1 : f32 +} + + +// CHECK-LABEL: @negf_poison +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @negf_poison() -> f32 { + %0 = ub.poison : f32 + %1 = arith.negf %0 : f32 + return %1 : f32 +} + +// CHECK-LABEL: @extsi_poison +// CHECK: %[[P:.*]] = ub.poison : i64 +// CHECK: return %[[P]] +func.func @extsi_poison() -> i64 { + %0 = ub.poison : i32 + %1 = arith.extsi %0 : i32 to i64 + return %1 : i64 +} diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir --- a/mlir/test/Dialect/Math/canonicalize.mlir +++ b/mlir/test/Dialect/Math/canonicalize.mlir @@ -483,3 +483,12 @@ %0 = math.erf %v1 : vector<4xf32> return %0 : vector<4xf32> } + +// CHECK-LABEL: @abs_poison +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @abs_poison() -> f32 { + %0 = ub.poison : f32 + %1 = math.absf %0 : f32 + return %1 : f32 +} diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir @@ -325,6 +325,15 @@ return %0: vector<3xi32> } +// CHECK-LABEL: @iadd_poison +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @iadd_poison(%arg0: i32) -> i32 { + %0 = ub.poison : i32 + %1 = spirv.IAdd %arg0, %0 : i32 + return %1: i32 +} + // ----- //===----------------------------------------------------------------------===// diff --git 
a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1479,3 +1479,16 @@ // CHECK: return %[[DIM]] return %result : index } + + +// ----- + +// CHECK-LABEL: @add_poison +// CHECK: %[[P:.*]] = ub.poison : !shape.size +// CHECK: return %[[P]] +func.func @add_poison() -> !shape.size { + %1 = shape.const_size 2 + %2 = ub.poison : !shape.size + %result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size + return %result : !shape.size +}