diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -15,6 +15,7 @@ #ifndef MLIR_DIALECT_COMMONFOLDERS_H #define MLIR_DIALECT_COMMONFOLDERS_H +#include "mlir/Dialect/UB/IR/UBOps.h" // PoisonAttr #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/ArrayRef.h" @@ -33,6 +34,12 @@ Type resultType, const CalculationT &calculate) { assert(operands.size() == 2 && "binary op takes two operands"); + if (isa_and_nonnull(operands[0])) + return operands[0]; + + if (isa_and_nonnull(operands[1])) + return operands[1]; + if (!resultType || !operands[0] || !operands[1]) return {}; @@ -102,6 +109,12 @@ Attribute constFoldBinaryOpConditional(ArrayRef operands, const CalculationT &calculate) { assert(operands.size() == 2 && "binary op takes two operands"); + if (isa_and_nonnull(operands[0])) + return operands[0]; + + if (isa_and_nonnull(operands[1])) + return operands[1]; + auto getResultType = [](Attribute attr) -> Type { if (auto typed = dyn_cast_or_null(attr)) return typed.getType(); @@ -158,6 +171,9 @@ if (!operands[0]) return {}; + if (isa(operands[0])) + return operands[0]; + if (isa(operands[0])) { auto op = cast(operands[0]); diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" @@ -49,5 +50,8 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (isa(value)) + return builder.create(loc, type, 
cast(value)); + return ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -28,6 +28,7 @@ MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR + MLIRUBDialect ) add_mlir_dialect_library(MLIRArithValueBoundsOpInterfaceImpl diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -9,7 +9,6 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/CommonFolders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" diff --git a/mlir/lib/Dialect/Math/IR/CMakeLists.txt b/mlir/lib/Dialect/Math/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Math/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/IR/CMakeLists.txt @@ -12,4 +12,5 @@ MLIRArithDialect MLIRDialect MLIRIR + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -522,5 +522,8 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (isa(value)) + return builder.create(loc, type, cast(value)); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt @@ -42,4 +42,5 @@ MLIRSideEffectInterfaces MLIRSupport MLIRTransforms + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- 
a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" @@ -973,6 +974,9 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (isa(value)) + return builder.create(loc, type, cast(value)); + if (!spirv::ConstantOp::isBuildableWith(type)) return nullptr; diff --git a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Shape/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/IR/CMakeLists.txt @@ -22,4 +22,5 @@ MLIRIR MLIRSideEffectInterfaces MLIRTensorDialect + MLIRUBDialect ) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -147,6 +147,9 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (isa(value)) + return builder.create(loc, type, cast(value)); + if (llvm::isa(type) || isExtentTensorType(type)) return builder.create( loc, type, llvm::cast(value)); @@ -156,6 +159,7 @@ if (llvm::isa(type)) return builder.create(loc, type, llvm::cast(value)); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2567,3 +2567,49 @@ %2 = arith.ori %arg0, %1 : index return %2 : index } + +// CHECK-LABEL: @addi_poison1 +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @addi_poison1(%arg: i32) -> i32 { + %0 = ub.poison 
: i32 + %1 = arith.addi %0, %arg : i32 + return %1 : i32 +} + +// CHECK-LABEL: @addi_poison2 +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @addi_poison2(%arg: i32) -> i32 { + %0 = ub.poison : i32 + %1 = arith.addi %arg, %0 : i32 + return %1 : i32 +} + +// CHECK-LABEL: @addf_poison1 +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @addf_poison1(%arg: f32) -> f32 { + %0 = ub.poison : f32 + %1 = arith.addf %0, %arg : f32 + return %1 : f32 +} + +// CHECK-LABEL: @addf_poison2 +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @addf_poison2(%arg: f32) -> f32 { + %0 = ub.poison : f32 + %1 = arith.addf %arg, %0 : f32 + return %1 : f32 +} + + +// CHECK-LABEL: @negf_poison( +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @negf_poison() -> f32 { + %0 = ub.poison : f32 + %1 = arith.negf %0 : f32 + return %1 : f32 +} diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir --- a/mlir/test/Dialect/Math/canonicalize.mlir +++ b/mlir/test/Dialect/Math/canonicalize.mlir @@ -483,3 +483,12 @@ %0 = math.erf %v1 : vector<4xf32> return %0 : vector<4xf32> } + +// CHECK-LABEL: @abs_poison +// CHECK: %[[P:.*]] = ub.poison : f32 +// CHECK: return %[[P]] +func.func @abs_poison() -> f32 { + %0 = ub.poison : f32 + %1 = math.absf %0 : f32 + return %1 : f32 +} diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir @@ -325,6 +325,15 @@ return %0: vector<3xi32> } +// CHECK-LABEL: @iadd_poison +// CHECK: %[[P:.*]] = ub.poison : i32 +// CHECK: return %[[P]] +func.func @iadd_poison(%arg0: i32) -> i32 { + %0 = ub.poison : i32 + %1 = spirv.IAdd %arg0, %0 : i32 + return %1: i32 +} + // ----- //===----------------------------------------------------------------------===// diff 
--git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1479,3 +1479,16 @@ // CHECK: return %[[DIM]] return %result : index } + + +// ----- + +// CHECK-LABEL: @add_poison +// CHECK: %[[P:.*]] = ub.poison : !shape.size +// CHECK: return %[[P]] +func.func @add_poison() -> !shape.size { + %1 = shape.const_size 2 + %2 = ub.poison : !shape.size + %result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size + return %result : !shape.size +}