diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -907,8 +907,23 @@
     Value rhs = processValue(inst->getOperand(1));
     if (!lhs || !rhs)
       return failure();
+
+    // LLVM IR requires both fcmp operands to have identical types; bail out
+    // defensively if the imported values disagree.
+    if (lhs.getType() != rhs.getType())
+      return failure();
+
+    // A vector comparison yields a vector of i1 with the operands' element
+    // count. Pass the ElementCount through unchanged so scalable operands
+    // (<vscale x N x float>) work too; getFixedValue() would assert on them.
+    Type boolType = b.getI1Type();
+    Type resType = boolType;
+    if (LLVM::isCompatibleVectorType(lhs.getType()))
+      resType = LLVM::getVectorType(
+          boolType, LLVM::getVectorNumElements(lhs.getType()));
+
     instMap[inst] = b.create<FCmpOp>(
-        loc, b.getI1Type(),
+        loc, resType,
         getFCmpPredicate(cast<llvm::FCmpInst>(inst)->getPredicate()), lhs, rhs);
     return success();
   }
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -449,7 +449,7 @@
 }
 
 // CHECK-LABEL: @fastmathFlags
-func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32) {
+func.func @fastmathFlags(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: vector<2 x f32>, %arg4: vector<2 x f32>) {
 // CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
 // CHECK: {{.*}} = llvm.fsub %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
 // CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
@@ -461,8 +461,14 @@
   %3 = llvm.fdiv %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
   %4 = llvm.frem %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
 
-// CHECK: {{.*}} = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
+// CHECK: %[[SCALAR_PRED0:.+]] = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
   %5 = llvm.fcmp "oeq" %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
+// CHECK: %{{.*}} = llvm.add %[[SCALAR_PRED0]], %[[SCALAR_PRED0]] : i1
+  %typecheck_5 = llvm.add %5, %5 : i1
+// CHECK: %[[VEC_PRED0:.+]] = llvm.fcmp "oeq" %arg3, %arg4 {fastmathFlags = #llvm.fastmath<fast>} : vector<2xf32>
+  %vcmp = llvm.fcmp "oeq" %arg3, %arg4 {fastmathFlags = #llvm.fastmath<fast>} : vector<2xf32>
+// CHECK: %{{.*}} = llvm.add %[[VEC_PRED0]], %[[VEC_PRED0]] : vector<2xi1>
+  %typecheck_vcmp = llvm.add %vcmp, %vcmp : vector<2xi1>
 
 // CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : f32
   %6 = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : f32
diff --git a/mlir/test/Target/LLVMIR/Import/fcmp-vector.ll b/mlir/test/Target/LLVMIR/Import/fcmp-vector.ll
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/fcmp-vector.ll
@@ -0,0 +1,20 @@
+; RUN: mlir-translate -import-llvm %s | FileCheck %s
+
+; fcmp on vector operands must be imported with a vector-of-i1 result that
+; keeps the operands' element count, for both fixed and scalable vectors.
+
+; CHECK-LABEL: @fcmp_fixed_vector
+define <2 x i1> @fcmp_fixed_vector(<2 x float> %a, <2 x float> %b) {
+  ; CHECK: %[[RES:.+]] = llvm.fcmp "oeq" {{.*}} : vector<2xf32>
+  %cmp = fcmp oeq <2 x float> %a, %b
+  ; CHECK: llvm.return %[[RES]] : vector<2xi1>
+  ret <2 x i1> %cmp
+}
+
+; CHECK-LABEL: @fcmp_scalable_vector
+define <vscale x 2 x i1> @fcmp_scalable_vector(<vscale x 2 x float> %a, <vscale x 2 x float> %b) {
+  ; CHECK: %[[RES:.+]] = llvm.fcmp "olt" {{.*}} : !llvm.vec<? x 2 x f32>
+  %cmp = fcmp olt <vscale x 2 x float> %a, %b
+  ; CHECK: llvm.return %[[RES]] : !llvm.vec<? x 2 x i1>
+  ret <vscale x 2 x i1> %cmp
+}