NOTE(review): this patch was re-flowed from a whitespace-collapsed copy. Line
indentation and the angle-bracketed text on the `#include <type_traits>` and
`fastmath<fast>` context lines are reconstructed; confirm them against the
target tree before applying.

diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include <type_traits>
@@ -173,6 +174,16 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Lowers `ub.poison` with an index, integer or float result type to
+/// `llvm.mlir.poison` of the converted type.
+struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> {
+  using ConvertOpToLLVMPattern<ub::PoisonOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -412,6 +423,33 @@
       rewriter);
 }
 
+//===----------------------------------------------------------------------===//
+// PoisonOpLowering
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter) const {
+  // This pattern only handles index, integer and float result types; any
+  // other type is reported as a match failure.
+  Type origType = op.getType();
+  if (!origType.isIntOrIndexOrFloat())
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "only index, integer and float types are supported, got "
+           << origType;
+    });
+
+  Type resType = getTypeConverter()->convertType(origType);
+  if (!resType)
+    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
+      diag << "failed to convert result type " << origType;
+    });
+
+  // Replace the op with `llvm.mlir.poison` of the converted type.
+  rewriter.replaceOpWithNewOp<LLVM::PoisonOp>(op, resType);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pass Definition
 //===----------------------------------------------------------------------===//
@@ -477,6 +515,7 @@
     MulUIExtendedOpLowering,
     NegFOpLowering,
     OrIOpLowering,
+    PoisonOpLowering,
     RemFOpLowering,
     RemSIOpLowering,
     RemUIOpLowering,
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -564,3 +564,17 @@
   %7 = arith.subf %arg0, %arg1 fastmath<fast> : f32
   return
 }
+
+
+// -----
+
+// CHECK-LABEL: @check_poison
+func.func @check_poison() {
+// CHECK: {{.*}} = llvm.mlir.poison : i64
+  %0 = ub.poison : index
+// CHECK: {{.*}} = llvm.mlir.poison : i16
+  %1 = ub.poison : i16
+// CHECK: {{.*}} = llvm.mlir.poison : f64
+  %2 = ub.poison : f64
+  return
+}