diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td @@ -22,6 +22,10 @@ dialect also accept vectors and tensors of integers or floats. }]; + let dependentDialects = [ + "::mlir::ub::UBDialect" + ]; + let hasConstantMaterializer = 1; let useDefaultAttributePrinterParser = 1; } diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -22,17 +22,29 @@ #include namespace mlir { +namespace ub { +class PoisonAttr; +} /// Performs constant folding `calculate` with element-wise behavior on the two /// attributes in `operands` and returns the result if possible. /// Uses `resultType` for the type of the returned attribute. template (ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, Type resultType, const CalculationT &calculate) { assert(operands.size() == 2 && "binary op takes two operands"); + if constexpr (!std::is_void_v) { + if (isa_and_nonnull(operands[0])) + return operands[0]; + + if (isa_and_nonnull(operands[1])) + return operands[1]; + } + if (!resultType || !operands[0] || !operands[1]) return {}; @@ -97,11 +109,20 @@ /// attribute. 
template (ElementValueT, ElementValueT)>> Attribute constFoldBinaryOpConditional(ArrayRef operands, const CalculationT &calculate) { assert(operands.size() == 2 && "binary op takes two operands"); + if constexpr (!std::is_void_v) { + if (isa_and_nonnull(operands[0])) + return operands[0]; + + if (isa_and_nonnull(operands[1])) + return operands[1]; + } + auto getResultType = [](Attribute attr) -> Type { if (auto typed = dyn_cast_or_null(attr)) return typed.getType(); @@ -115,18 +136,19 @@ if (lhsType != rhsType) return {}; - return constFoldBinaryOpConditional(operands, lhsType, calculate); } template > Attribute constFoldBinaryOp(ArrayRef operands, Type resultType, const CalculationT &calculate) { - return constFoldBinaryOpConditional( + return constFoldBinaryOpConditional( operands, resultType, [&](ElementValueT a, ElementValueT b) -> std::optional { return calculate(a, b); @@ -135,11 +157,12 @@ template > Attribute constFoldBinaryOp(ArrayRef operands, const CalculationT &calculate) { - return constFoldBinaryOpConditional( + return constFoldBinaryOpConditional( operands, [&](ElementValueT a, ElementValueT b) -> std::optional { return calculate(a, b); @@ -150,6 +173,7 @@ /// attributes in `operands` and returns the result if possible. 
template (ElementValueT)>> Attribute constFoldUnaryOpConditional(ArrayRef operands, @@ -158,6 +182,11 @@ if (!operands[0]) return {}; + if constexpr (!std::is_void_v) { + if (isa(operands[0])) + return operands[0]; + } + if (isa(operands[0])) { auto op = cast(operands[0]); @@ -196,10 +225,11 @@ template > Attribute constFoldUnaryOp(ArrayRef operands, const CalculationT &&calculate) { - return constFoldUnaryOpConditional( + return constFoldUnaryOpConditional( operands, [&](ElementValueT a) -> std::optional { return calculate(a); }); @@ -209,6 +239,7 @@ class AttrElementT, class TargetAttrElementT, class ElementValueT = typename AttrElementT::ValueType, class TargetElementValueT = typename TargetAttrElementT::ValueType, + class PoisonAttr = ub::PoisonAttr, class CalculationT = function_ref> Attribute constFoldCastOp(ArrayRef operands, Type resType, const CalculationT &calculate) { @@ -216,6 +247,11 @@ if (!operands[0]) return {}; + if constexpr (!std::is_void_v) { + if (isa(operands[0])) + return operands[0]; + } + if (isa(operands[0])) { auto op = cast(operands[0]); bool castStatus = true; diff --git a/mlir/include/mlir/Dialect/Math/IR/MathBase.td b/mlir/include/mlir/Dialect/Math/IR/MathBase.td --- a/mlir/include/mlir/Dialect/Math/IR/MathBase.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathBase.td @@ -31,7 +31,8 @@ }]; let hasConstantMaterializer = 1; let dependentDialects = [ - "::mlir::arith::ArithDialect" + "::mlir::arith::ArithDialect", + "::mlir::ub::UBDialect" ]; } #endif // MATH_BASE diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -53,6 +53,10 @@ let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; + let dependentDialects = [ + "::mlir::ub::UBDialect" + ]; + let extraClassDeclaration = [{ void registerAttributes(); void registerTypes(); diff --git 
a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td @@ -36,7 +36,11 @@ }]; let cppNamespace = "::mlir::shape"; - let dependentDialects = ["arith::ArithDialect", "tensor::TensorDialect"]; + let dependentDialects = [ + "::mlir::arith::ArithDialect", + "::mlir::tensor::TensorDialect", + "::mlir::ub::UBDialect" + ]; let useDefaultTypePrinterParser = 1; let hasConstantMaterializer = 1; diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" @@ -49,5 +50,8 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (isa(value)) + return builder.create(loc, type, cast(value)); + return ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -9,7 +9,6 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/CommonFolders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" diff --git a/mlir/lib/Dialect/Math/IR/CMakeLists.txt b/mlir/lib/Dialect/Math/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Math/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/IR/CMakeLists.txt @@ -12,4 +12,5 @@ MLIRArithDialect 
MLIRDialect MLIRIR + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/Math/IR/MathDialect.cpp b/mlir/lib/Dialect/Math/IR/MathDialect.cpp --- a/mlir/lib/Dialect/Math/IR/MathDialect.cpp +++ b/mlir/lib/Dialect/Math/IR/MathDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Transforms/InliningUtils.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include @@ -522,5 +523,8 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (isa(value)) + return builder.create(loc, type, cast(value)); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -42,4 +42,5 @@ MLIRSideEffectInterfaces MLIRSupport MLIRTransforms + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp 
b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -973,6 +974,9 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (isa(value)) + return builder.create(loc, type, cast(value)); + if (!spirv::ConstantOp::isBuildableWith(type)) return nullptr; diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt @@ -22,4 +22,5 @@ MLIRIR MLIRSideEffectInterfaces MLIRTensorDialect + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -147,6 +148,9 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (isa(value)) + return builder.create(loc, type, cast(value)); + if (llvm::isa(type) || isExtentTensorType(type)) return builder.create( loc, type, llvm::cast(value)); @@ -156,6 +160,7 @@ if (llvm::isa(type)) return builder.create(loc, type, llvm::cast(value)); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir 
--- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2584,3 +2584,58 @@ %select4 = arith.select %false, %poison, %arg : i32 return %select1, %select2, %select3, %select4 : i32, i32, i32, i32 } + +// CHECK-LABEL: @addi_poison1 +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @addi_poison1(%arg: i32) -> i32 { + %0 = ub.poison : i32 + %1 = arith.addi %0, %arg : i32 + return %1 : i32 +} + +// CHECK-LABEL: @addi_poison2 +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @addi_poison2(%arg: i32) -> i32 { + %0 = ub.poison : i32 + %1 = arith.addi %arg, %0 : i32 + return %1 : i32 +} + +// CHECK-LABEL: @addf_poison1 +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @addf_poison1(%arg: f32) -> f32 { + %0 = ub.poison : f32 + %1 = arith.addf %0, %arg : f32 + return %1 : f32 +} + +// CHECK-LABEL: @addf_poison2 +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @addf_poison2(%arg: f32) -> f32 { + %0 = ub.poison : f32 + %1 = arith.addf %arg, %0 : f32 + return %1 : f32 +} + + +// CHECK-LABEL: @negf_poison +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @negf_poison() -> f32 { + %0 = ub.poison : f32 + %1 = arith.negf %0 : f32 + return %1 : f32 +} + +// CHECK-LABEL: @extsi_poison +// CHECK: %[[P:.*]] = ub.poison : i64 +// CHECK: return %[[P]] +func.func @extsi_poison() -> i64 { + %0 = ub.poison : i32 + %1 = arith.extsi %0 : i32 to i64 + return %1 : i64 +} diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir --- a/mlir/test/Dialect/Math/canonicalize.mlir +++ b/mlir/test/Dialect/Math/canonicalize.mlir @@ -483,3 +483,12 @@ %0 = math.erf %v1 : vector<4xf32> return %0 : vector<4xf32> } + +// CHECK-LABEL: @abs_poison +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @abs_poison() -> f32 { + %0 = ub.poison : f32 + %1 = math.absf %0 : f32 + return %1 : 
f32 +} diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir @@ -325,6 +325,15 @@ return %0: vector<3xi32> } +// CHECK-LABEL: @iadd_poison +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @iadd_poison(%arg0: i32) -> i32 { + %0 = ub.poison : i32 + %1 = spirv.IAdd %arg0, %0 : i32 + return %1: i32 +} + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1479,3 +1479,16 @@ // CHECK: return %[[DIM]] return %result : index } + + +// ----- + +// CHECK-LABEL: @add_poison +// CHECK: %[[P:.*]] = ub.poison : !shape.size +// CHECK: return %[[P]] +func.func @add_poison() -> !shape.size { + %1 = shape.const_size 2 + %2 = ub.poison : !shape.size + %result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size + return %result : !shape.size +}