diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -347,10 +347,19 @@
     if (c->isString())
       return b.getStringAttr(c->getAsString());
   if (auto *c = dyn_cast<llvm::ConstantFP>(value)) {
-    if (c->getType()->isDoubleTy())
-      return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF());
-    if (c->getType()->isFloatingPointTy())
-      return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF());
+    llvm::Type *type = c->getType();
+    FloatType floatTy;
+    // bfloat and ppc_fp128 cannot be selected by bit width alone: bfloat would
+    // alias f16, and ppc_fp128 (double-double) would alias f128.
+    if (type->isBFloatTy())
+      floatTy = FloatType::getBF16(context);
+    else if (!type->isPPC_FP128Ty())
+      floatTy = getDLFloatType(*context, type->getScalarSizeInBits());
+    // Report unsupported types with a null attribute instead of asserting, so
+    // release builds never call getFloatAttr with a null type.
+    if (!floatTy)
+      return nullptr;
+    return b.getFloatAttr(floatTy, c->getValueAPF());
   }
   if (auto *f = dyn_cast<llvm::Function>(value))
     return SymbolRefAttr::get(b.getContext(), f->getName());
@@ -607,7 +616,7 @@
       // FIXME: cleanuppad
       // FIXME: catchpad
       // ICmp is handled specially.
-      // FIXME: fcmp
+      // FCmp is handled specially.
       // PHI is handled specially.
       INST(Freeze, Freeze), INST(Call, Call),
       // FIXME: select
@@ -649,7 +658,49 @@
   case llvm::CmpInst::Predicate::ICMP_UGE:
     return LLVM::ICmpPredicate::uge;
   }
-  llvm_unreachable("incorrect comparison predicate");
+  llvm_unreachable("incorrect integer comparison predicate");
+}
+
+/// Returns the LLVM dialect counterpart of the LLVM IR floating point
+/// comparison predicate `p`.
+static FCmpPredicate getFCmpPredicate(llvm::CmpInst::Predicate p) {
+  switch (p) {
+  default:
+    llvm_unreachable("incorrect floating point comparison predicate");
+  case llvm::CmpInst::Predicate::FCMP_FALSE:
+    return LLVM::FCmpPredicate::_false;
+  case llvm::CmpInst::Predicate::FCMP_TRUE:
+    return LLVM::FCmpPredicate::_true;
+  case llvm::CmpInst::Predicate::FCMP_OEQ:
+    return LLVM::FCmpPredicate::oeq;
+  case llvm::CmpInst::Predicate::FCMP_ONE:
+    return LLVM::FCmpPredicate::one;
+  case llvm::CmpInst::Predicate::FCMP_OLT:
+    return LLVM::FCmpPredicate::olt;
+  case llvm::CmpInst::Predicate::FCMP_OLE:
+    return LLVM::FCmpPredicate::ole;
+  case llvm::CmpInst::Predicate::FCMP_OGT:
+    return LLVM::FCmpPredicate::ogt;
+  case llvm::CmpInst::Predicate::FCMP_OGE:
+    return LLVM::FCmpPredicate::oge;
+  case llvm::CmpInst::Predicate::FCMP_ORD:
+    return LLVM::FCmpPredicate::ord;
+  case llvm::CmpInst::Predicate::FCMP_ULT:
+    return LLVM::FCmpPredicate::ult;
+  case llvm::CmpInst::Predicate::FCMP_ULE:
+    return LLVM::FCmpPredicate::ule;
+  case llvm::CmpInst::Predicate::FCMP_UGT:
+    return LLVM::FCmpPredicate::ugt;
+  case llvm::CmpInst::Predicate::FCMP_UGE:
+    return LLVM::FCmpPredicate::uge;
+  case llvm::CmpInst::Predicate::FCMP_UNO:
+    return LLVM::FCmpPredicate::uno;
+  case llvm::CmpInst::Predicate::FCMP_UEQ:
+    return LLVM::FCmpPredicate::ueq;
+  case llvm::CmpInst::Predicate::FCMP_UNE:
+    return LLVM::FCmpPredicate::une;
+  }
+  llvm_unreachable("incorrect floating point comparison predicate");
 }
 
 static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) {
@@ -774,6 +825,19 @@
                              rhs);
     return success();
   }
+  case llvm::Instruction::FCmp: {
+    Value lhs = processValue(inst->getOperand(0));
+    Value rhs = processValue(inst->getOperand(1));
+    // Take the result type from the instruction rather than hard-coding i1:
+    // comparing vectors of floats produces a vector of i1.
+    Type type = processType(inst->getType());
+    if (!lhs || !rhs || !type)
+      return failure();
+    instMap[inst] = b.create<FCmpOp>(
+        loc, type,
+        getFCmpPredicate(cast<llvm::FCmpInst>(inst)->getPredicate()), lhs, rhs);
+    return success();
+  }
   case llvm::Instruction::Br: {
     auto *brInst = cast<llvm::BranchInst>(inst);
     OperationState state(loc,
diff --git a/mlir/test/Target/LLVMIR/Import/basic.ll b/mlir/test/Target/LLVMIR/Import/basic.ll
--- a/mlir/test/Target/LLVMIR/Import/basic.ll
+++ b/mlir/test/Target/LLVMIR/Import/basic.ll
@@ -281,6 +281,71 @@
   ret void
 }
 
+; CHECK-LABEL: llvm.func @FPComparison(%arg0: f32, %arg1: f32)
+define void @FPComparison(float %a, float %b) {
+  ; CHECK: llvm.fcmp "_false" %arg0, %arg1
+  %1 = fcmp false float %a, %b
+  ; CHECK: llvm.fcmp "oeq" %arg0, %arg1
+  %2 = fcmp oeq float %a, %b
+  ; CHECK: llvm.fcmp "ogt" %arg0, %arg1
+  %3 = fcmp ogt float %a, %b
+  ; CHECK: llvm.fcmp "oge" %arg0, %arg1
+  %4 = fcmp oge float %a, %b
+  ; CHECK: llvm.fcmp "olt" %arg0, %arg1
+  %5 = fcmp olt float %a, %b
+  ; CHECK: llvm.fcmp "ole" %arg0, %arg1
+  %6 = fcmp ole float %a, %b
+  ; CHECK: llvm.fcmp "one" %arg0, %arg1
+  %7 = fcmp one float %a, %b
+  ; CHECK: llvm.fcmp "ord" %arg0, %arg1
+  %8 = fcmp ord float %a, %b
+  ; CHECK: llvm.fcmp "ueq" %arg0, %arg1
+  %9 = fcmp ueq float %a, %b
+  ; CHECK: llvm.fcmp "ugt" %arg0, %arg1
+  %10 = fcmp ugt float %a, %b
+  ; CHECK: llvm.fcmp "uge" %arg0, %arg1
+  %11 = fcmp uge float %a, %b
+  ; CHECK: llvm.fcmp "ult" %arg0, %arg1
+  %12 = fcmp ult float %a, %b
+  ; CHECK: llvm.fcmp "ule" %arg0, %arg1
+  %13 = fcmp ule float %a, %b
+  ; CHECK: llvm.fcmp "une" %arg0, %arg1
+  %14 = fcmp une float %a, %b
+  ; CHECK: llvm.fcmp "uno" %arg0, %arg1
+  %15 = fcmp uno float %a, %b
+  ; CHECK: llvm.fcmp "_true" %arg0, %arg1
+  %16 = fcmp true float %a, %b
+  ret void
+}
+
+; Comparing vectors yields a vector of i1 rather than a scalar i1.
+; CHECK-LABEL: llvm.func @FPVectorComparison
+define <2 x i1> @FPVectorComparison(<2 x float> %a, <2 x float> %b) {
+  ; CHECK: %[[R:.+]] = llvm.fcmp "olt" %arg0, %arg1
+  %1 = fcmp olt <2 x float> %a, %b
+  ; CHECK: llvm.return %[[R]] : vector<2xi1>
+  ret <2 x i1> %1
+}
+
+; Testing rest of the floating point constant kinds.
+; CHECK-LABEL: llvm.func @FPConstant(%arg0: f16, %arg1: bf16, %arg2: f128, %arg3: f80)
+define void @FPConstant(half %a, bfloat %b, fp128 %c, x86_fp80 %d) {
+  ; CHECK-DAG: %[[C0:.+]] = llvm.mlir.constant(7.000000e+00 : f80) : f80
+  ; CHECK-DAG: %[[C1:.+]] = llvm.mlir.constant(0.000000e+00 : f128) : f128
+  ; CHECK-DAG: %[[C2:.+]] = llvm.mlir.constant(1.000000e+00 : bf16) : bf16
+  ; CHECK-DAG: %[[C3:.+]] = llvm.mlir.constant(1.000000e+00 : f16) : f16
+
+  ; CHECK: llvm.fadd %[[C3]], %arg0 : f16
+  %1 = fadd half 1.0, %a
+  ; CHECK: llvm.fadd %[[C2]], %arg1 : bf16
+  %2 = fadd bfloat 1.0, %b
+  ; CHECK: llvm.fadd %[[C1]], %arg2 : f128
+  %3 = fadd fp128 0xL00000000000000000000000000000000, %c
+  ; CHECK: llvm.fadd %[[C0]], %arg3 : f80
+  %4 = fadd x86_fp80 0xK4001E000000000000000, %d
+  ret void
+}
+
 ;
 ; Functions as constants.
 ;