diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -96,6 +96,131 @@
   }
 };
 
+static mlir::Type getComplexEleTy(mlir::Type complex) {
+  if (auto cc = complex.dyn_cast<mlir::ComplexType>())
+    return cc.getElementType();
+  return complex.cast<fir::ComplexType>().getElementType();
+}
+
+/// Convert a value of the from-type to a value of the to-type by emitting
+/// the matching LLVM dialect cast: fptrunc/fpext between floats, fptosi and
+/// sitofp between floats and integers, trunc/sext between integers, and
+/// inttoptr/ptrtoint/bitcast for pointers; complex parts use fptrunc/fpext.
+struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
+  using FIROpConversion::FIROpConversion;
+
+  static bool isFloatingPointTy(mlir::Type ty) {
+    return ty.isa<mlir::FloatType>();
+  }
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::ConvertOp convert, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto fromTy = convertType(convert.value().getType());
+    auto toTy = convertType(convert.res().getType());
+    mlir::Value op0 = adaptor.getOperands()[0];
+    if (fromTy == toTy) {
+      rewriter.replaceOp(convert, op0);
+      return mlir::success();
+    }
+    auto loc = convert.getLoc();
+    auto convertFpToFp = [&](mlir::Value val, unsigned fromBits,
+                             unsigned toBits, mlir::Type toTy) -> mlir::Value {
+      if (fromBits == toBits) {
+        // TODO: Converting between two floating-point representations with the
+        // same bitwidth is not allowed for now.
+        mlir::emitError(loc,
+                        "cannot implicitly convert between two floating-point "
+                        "representations of the same bitwidth");
+        return {};
+      }
+      if (fromBits > toBits)
+        return rewriter.create<mlir::LLVM::FPTruncOp>(loc, toTy, val);
+      return rewriter.create<mlir::LLVM::FPExtOp>(loc, toTy, val);
+    };
+    // Complex to complex conversion.
+    if (fir::isa_complex(convert.value().getType()) &&
+        fir::isa_complex(convert.res().getType())) {
+      // Special case: handle the conversion of a complex such that both the
+      // real and imaginary parts are converted together.
+      auto zero = mlir::ArrayAttr::get(convert.getContext(),
+                                       rewriter.getI32IntegerAttr(0));
+      auto one = mlir::ArrayAttr::get(convert.getContext(),
+                                      rewriter.getI32IntegerAttr(1));
+      auto ty = convertType(getComplexEleTy(convert.value().getType()));
+      auto rp = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ty, op0, zero);
+      auto ip = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, ty, op0, one);
+      auto nt = convertType(getComplexEleTy(convert.res().getType()));
+      auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(ty);
+      auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(nt);
+      auto rc = convertFpToFp(rp, fromBits, toBits, nt);
+      // A null result means convertFpToFp already emitted a diagnostic. Both
+      // parts share their element types, so checking `rc` is sufficient.
+      if (!rc)
+        return mlir::failure();
+      auto ic = convertFpToFp(ip, fromBits, toBits, nt);
+      auto un = rewriter.create<mlir::LLVM::UndefOp>(loc, toTy);
+      auto i1 =
+          rewriter.create<mlir::LLVM::InsertValueOp>(loc, toTy, un, rc, zero);
+      rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(convert, toTy, i1,
+                                                             ic, one);
+      return mlir::success();
+    }
+    // Floating point to floating point conversion.
+    if (isFloatingPointTy(fromTy)) {
+      if (isFloatingPointTy(toTy)) {
+        auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy);
+        auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy);
+        auto v = convertFpToFp(op0, fromBits, toBits, toTy);
+        // A null result means a diagnostic has already been emitted.
+        if (!v)
+          return mlir::failure();
+        rewriter.replaceOp(convert, v);
+        return mlir::success();
+      }
+      if (toTy.isa<mlir::IntegerType>()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+    } else if (fromTy.isa<mlir::IntegerType>()) {
+      // Integer to integer conversion.
+      if (toTy.isa<mlir::IntegerType>()) {
+        auto fromBits = mlir::LLVM::getPrimitiveTypeSizeInBits(fromTy);
+        auto toBits = mlir::LLVM::getPrimitiveTypeSizeInBits(toTy);
+        assert(fromBits != toBits);
+        if (fromBits > toBits) {
+          rewriter.replaceOpWithNewOp<mlir::LLVM::TruncOp>(convert, toTy, op0);
+          return mlir::success();
+        }
+        rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+      // Integer to floating point conversion.
+      if (isFloatingPointTy(toTy)) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+      // Integer to pointer conversion.
+      if (toTy.isa<mlir::LLVM::LLVMPointerType>()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+    } else if (fromTy.isa<mlir::LLVM::LLVMPointerType>()) {
+      // Pointer to integer conversion.
+      if (toTy.isa<mlir::IntegerType>()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+      // Pointer to pointer conversion.
+      if (toTy.isa<mlir::LLVM::LLVMPointerType>()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(convert, toTy, op0);
+        return mlir::success();
+      }
+    }
+    return emitError(loc) << "cannot convert " << fromTy << " to " << toTy;
+  }
+};
+
 /// Lower `fir.has_value` operation to `llvm.return` operation.
 struct HasValueOpConversion : public FIROpConversion<fir::HasValueOp> {
   using FIROpConversion::FIROpConversion;
@@ -489,12 +614,6 @@
   }
 };
 
-static mlir::Type getComplexEleTy(mlir::Type complex) {
-  if (auto cc = complex.dyn_cast<mlir::ComplexType>())
-    return cc.getElementType();
-  return complex.cast<fir::ComplexType>().getElementType();
-}
-
 //
 // Primitive operations on Complex types
 //
@@ -679,13 +798,14 @@
     auto *context = getModule().getContext();
     fir::LLVMTypeConverter typeConverter{getModule()};
     mlir::OwningRewritePatternList pattern(context);
-    pattern.insert(typeConverter);
+    pattern
+        .insert(typeConverter);
     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             pattern);
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.h b/flang/lib/Optimizer/CodeGen/TypeConverter.h
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.h
@@ -169,24 +169,6 @@
     return fromRealTypeID(kindMapping.getRealTypeID(kind), kind);
   }
 
-  // Use the target specifics to figure out how to map complex to LLVM IR. The
-  // use of complex values in function signatures is handled before conversion
-  // to LLVM IR dialect here.
-  //
-  // fir.complex | std.complex --> llvm<"{t,t}">
-  template <typename C>
-  mlir::Type convertComplexType(C cmplx) {
-    LLVM_DEBUG(llvm::dbgs() << "type convert: " << cmplx << '\n');
-    auto eleTy = cmplx.getElementType();
-    return convertType(specifics->complexMemoryType(eleTy));
-  }
-
-  // convert a front-end kind value to either a std or LLVM IR dialect type
-  // fir.real<n> --> llvm.anyfloat where anyfloat is a kind mapping
-  mlir::Type convertRealType(fir::KindTy kind) {
-    return fromRealTypeID(kindMapping.getRealTypeID(kind), kind);
-  }
-
   template <typename A>
   mlir::Type convertPointerLike(A &ty) {
     mlir::Type eleTy = ty.getEleTy();
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -514,3 +514,122 @@
 // CHECK: %{{.*}} = llvm.insertvalue %[[NEGX]], %{{.*}}[0 : i32] : !llvm.struct<(f128, f128)>
 // CHECK: %{{.*}} = llvm.insertvalue %[[NEGY]], %{{.*}}[1 : i32] : !llvm.struct<(f128, f128)>
 // CHECK: llvm.return %{{.*}} : !llvm.struct<(f128, f128)>
+
+// -----
+
+// Test `fir.convert` operation conversion from Float type.
+
+func @convert_from_float(%arg0 : f32) {
+  %0 = fir.convert %arg0 : (f32) -> f16
+  %1 = fir.convert %arg0 : (f32) -> f32
+  %2 = fir.convert %arg0 : (f32) -> f64
+  %3 = fir.convert %arg0 : (f32) -> f80
+  %4 = fir.convert %arg0 : (f32) -> f128
+  %5 = fir.convert %arg0 : (f32) -> i1
+  %6 = fir.convert %arg0 : (f32) -> i8
+  %7 = fir.convert %arg0 : (f32) -> i16
+  %8 = fir.convert %arg0 : (f32) -> i32
+  %9 = fir.convert %arg0 : (f32) -> i64
+  return
+}
+
+// CHECK-LABEL: convert_from_float(
+// CHECK-SAME: %[[ARG0:.*]]: f32
+// CHECK: %{{.*}} = llvm.fptrunc %[[ARG0]] : f32 to f16
+// CHECK-NOT: f32 to f32
+// CHECK: %{{.*}} = llvm.fpext %[[ARG0]] : f32 to f64
+// CHECK: %{{.*}} = llvm.fpext %[[ARG0]] : f32 to f80
+// CHECK: %{{.*}} = llvm.fpext %[[ARG0]] : f32 to f128
+// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i1
+// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i8
+// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i16
+// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i32
+// CHECK: %{{.*}} = llvm.fptosi %[[ARG0]] : f32 to i64
+
+// -----
+
+// Test `fir.convert` operation conversion from Integer type.
+
+func @convert_from_int(%arg0 : i32) {
+  %0 = fir.convert %arg0 : (i32) -> f16
+  %1 = fir.convert %arg0 : (i32) -> f32
+  %2 = fir.convert %arg0 : (i32) -> f64
+  %3 = fir.convert %arg0 : (i32) -> f80
+  %4 = fir.convert %arg0 : (i32) -> f128
+  %5 = fir.convert %arg0 : (i32) -> i1
+  %6 = fir.convert %arg0 : (i32) -> i8
+  %7 = fir.convert %arg0 : (i32) -> i16
+  %8 = fir.convert %arg0 : (i32) -> i32
+  %9 = fir.convert %arg0 : (i32) -> i64
+  %ptr = fir.convert %9 : (i64) -> !fir.ref<i64>
+  return
+}
+
+// CHECK-LABEL: convert_from_int(
+// CHECK-SAME: %[[ARG0:.*]]: i32
+// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f16
+// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f32
+// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f64
+// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f80
+// CHECK: %{{.*}} = llvm.sitofp %[[ARG0]] : i32 to f128
+// CHECK: %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i1
+// CHECK: %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i8
+// CHECK: %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i16
+// CHECK-NOT: %{{.*}} = llvm.trunc %[[ARG0]] : i32 to i32
+// CHECK: %{{.*}} = llvm.sext %[[ARG0]] : i32 to i64
+// CHECK: %{{.*}} = llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<i64>
+
+// -----
+
+// Test `fir.convert` operation conversion from !fir.ref<> type.
+
+func @convert_from_ref(%arg0 : !fir.ref<i32>) {
+  %0 = fir.convert %arg0 : (!fir.ref<i32>) -> !fir.ref<i8>
+  %1 = fir.convert %arg0 : (!fir.ref<i32>) -> i32
+  return
+}
+
+// CHECK-LABEL: convert_from_ref(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<i32>
+// CHECK: %{{.*}} = llvm.bitcast %[[ARG0]] : !llvm.ptr<i32> to !llvm.ptr<i8>
+// CHECK: %{{.*}} = llvm.ptrtoint %[[ARG0]] : !llvm.ptr<i32> to i32
+
+// -----
+
+// Test `fir.convert` widening conversion between fir.complex types.
+
+func @convert_complex4(%arg0 : !fir.complex<4>) -> !fir.complex<8> {
+  %0 = fir.convert %arg0 : (!fir.complex<4>) -> !fir.complex<8>
+  return %0 : !fir.complex<8>
+}
+
+// CHECK-LABEL: func @convert_complex4(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(f32, f32)>) -> !llvm.struct<(f64, f64)>
+// CHECK: %[[X:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f32, f32)>
+// CHECK: %[[Y:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f32, f32)>
+// CHECK: %[[CONVERTX:.*]] = llvm.fpext %[[X]] : f32 to f64
+// CHECK: %[[CONVERTY:.*]] = llvm.fpext %[[Y]] : f32 to f64
+// CHECK: %[[STRUCT0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)>
+// CHECK: %[[STRUCT1:.*]] = llvm.insertvalue %[[CONVERTX]], %[[STRUCT0]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK: %[[STRUCT2:.*]] = llvm.insertvalue %[[CONVERTY]], %[[STRUCT1]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK: llvm.return %[[STRUCT2]] : !llvm.struct<(f64, f64)>
+
+// -----
+
+// Test `fir.convert` narrowing conversion between fir.complex types.
+
+func @convert_complex16(%arg0 : !fir.complex<16>) -> !fir.complex<2> {
+  %0 = fir.convert %arg0 : (!fir.complex<16>) -> !fir.complex<2>
+  return %0 : !fir.complex<2>
+}
+
+// CHECK-LABEL: func @convert_complex16(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.struct<(f128, f128)>) -> !llvm.struct<(f16, f16)>
+// CHECK: %[[X:.*]] = llvm.extractvalue %[[ARG0]][0 : i32] : !llvm.struct<(f128, f128)>
+// CHECK: %[[Y:.*]] = llvm.extractvalue %[[ARG0]][1 : i32] : !llvm.struct<(f128, f128)>
+// CHECK: %[[CONVERTX:.*]] = llvm.fptrunc %[[X]] : f128 to f16
+// CHECK: %[[CONVERTY:.*]] = llvm.fptrunc %[[Y]] : f128 to f16
+// CHECK: %[[STRUCT0:.*]] = llvm.mlir.undef : !llvm.struct<(f16, f16)>
+// CHECK: %[[STRUCT1:.*]] = llvm.insertvalue %[[CONVERTX]], %[[STRUCT0]][0 : i32] : !llvm.struct<(f16, f16)>
+// CHECK: %[[STRUCT2:.*]] = llvm.insertvalue %[[CONVERTY]], %[[STRUCT1]][1 : i32] : !llvm.struct<(f16, f16)>
+// CHECK: llvm.return %[[STRUCT2]] : !llvm.struct<(f16, f16)>