diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -161,6 +161,36 @@
     return success();
   }
 };
+
+/// Lower `fir.zero_bits` to an LLVM dialect zero value: `llvm.mlir.null` for
+/// pointer types and a zero `llvm.mlir.constant` for integer and
+/// floating-point types. Any other type (e.g. FIR aggregates) is not supported
+/// yet; it is reported as an op error and the pattern fails.
+struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::ZeroOp zero, OpAdaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto ty = convertType(zero.getType());
+    // Build the constant attributes from the converted type so that the
+    // attribute type always matches the result type of the new LLVM op.
+    if (ty.isa<mlir::LLVM::LLVMPointerType>()) {
+      rewriter.replaceOpWithNewOp<mlir::LLVM::NullOp>(zero, ty);
+    } else if (ty.isa<mlir::IntegerType>()) {
+      rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
+          zero, ty, mlir::IntegerAttr::get(ty, 0));
+    } else if (mlir::LLVM::isCompatibleFloatingPointType(ty)) {
+      rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
+          zero, ty, mlir::FloatAttr::get(ty, 0.0));
+    } else {
+      // TODO: create ConstantAggregateZero for FIR aggregate types.
+      return zero.emitOpError(
+          "conversion of fir.zero with aggregate type not implemented yet");
+    }
+    return success();
+  }
+};
 } // namespace
 
 namespace {
@@ -180,7 +210,7 @@
     auto loc = mlir::UnknownLoc::get(context);
     mlir::OwningRewritePatternList pattern(context);
     pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
-                   UndefOpConversion>(typeConverter);
+                   UndefOpConversion, ZeroOpConversion>(typeConverter);
     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             pattern);
@@ -193,7 +223,6 @@
     // apply the patterns
     if (mlir::failed(mlir::applyFullConversion(getModule(), target,
                                                std::move(pattern)))) {
-      mlir::emitError(loc, "error in converting to LLVM-IR dialect\n");
       signalPassFailure();
     }
   }
diff --git a/flang/test/Fir/convert-to-llvm-invalid.fir b/flang/test/Fir/convert-to-llvm-invalid.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/Fir/convert-to-llvm-invalid.fir
@@ -0,0 +1,11 @@
+// Test FIR to LLVM IR conversion invalid cases and diagnostics.
+
+// RUN: fir-opt --split-input-file --fir-to-llvm-ir --verify-diagnostics %s
+
+func @zero_aggregate() {
+  // expected-error@+2{{'fir.zero_bits' op conversion of fir.zero with aggregate type not implemented yet}}
+  // expected-error@+1{{failed to legalize operation 'fir.zero_bits'}}
+  %a = fir.zero_bits !fir.array<10xf32>
+  return
+}
+
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -81,3 +81,55 @@
 // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<1> : vector<32x32xi32>) : !llvm.array<32 x array<32 x i32>>
 // CHECK: llvm.return %[[CST]] : !llvm.array<32 x array<32 x i32>>
 // CHECK: }
+
+// -----
+
+// Test fir.zero_bits operation with LLVM ptr type.
+
+func @zero_test_ptr() {
+  %z = fir.zero_bits !llvm.ptr<i32>
+  return
+}
+
+// CHECK: %{{.*}} = llvm.mlir.null : !llvm.ptr<i32>
+// CHECK-NOT: fir.zero_bits
+
+// -----
+
+// Test fir.zero_bits operation with integer type.
+
+func @zero_test_integer() {
+  %z0 = fir.zero_bits i8
+  %z1 = fir.zero_bits i16
+  %z2 = fir.zero_bits i32
+  %z3 = fir.zero_bits i64
+  return
+}
+
+// CHECK: %{{.*}} = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %{{.*}} = llvm.mlir.constant(0 : i16) : i16
+// CHECK: %{{.*}} = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %{{.*}} = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NOT: fir.zero_bits
+
+// -----
+
+// Test fir.zero_bits operation with floating-point types.
+
+func @zero_test_float() {
+  %zf16 = fir.zero_bits f16
+  %zbf16 = fir.zero_bits bf16
+  %zf32 = fir.zero_bits f32
+  %zf64 = fir.zero_bits f64
+  %zf80 = fir.zero_bits f80
+  %zf128 = fir.zero_bits f128
+  return
+}
+
+// CHECK: %{{.*}} = llvm.mlir.constant(0.000000e+00 : f16) : f16
+// CHECK: %{{.*}} = llvm.mlir.constant(0.000000e+00 : bf16) : bf16
+// CHECK: %{{.*}} = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %{{.*}} = llvm.mlir.constant(0.000000e+00 : f64) : f64
+// CHECK: %{{.*}} = llvm.mlir.constant(0.000000e+00 : f80) : f80
+// CHECK: %{{.*}} = llvm.mlir.constant(0.000000e+00 : f128) : f128
+// CHECK-NOT: fir.zero_bits