diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -286,11 +286,8 @@
     $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
   }];
   let builders = [
-    OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs),
-    [{
-      build($_builder, $_state, IntegerType::get(lhs.getType().getContext(), 1),
-            predicate, lhs, rhs);
-    }]>];
+    OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
+  ];
   let hasCustomAssemblyFormat = 1;
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -90,8 +90,29 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::CmpOp.
-//===----------------------------------------------------------------------===//
+// Printing, parsing and builder for LLVM::CmpOp.
+//===----------------------------------------------------------------------===//
+
+/// Builds an `icmp` whose result type is inferred from the operand types:
+/// `i1` for scalar operands, or a vector of `i1` with the same (fixed or
+/// scalable) element count for vector operands.
+void ICmpOp::build(OpBuilder &builder, OperationState &result,
+                   ICmpPredicate predicate, Value lhs, Value rhs) {
+  Type resultType = IntegerType::get(lhs.getType().getContext(), 1);
+  // `icmp` requires identical operand types, so `lhs` determines the shape;
+  // `rhs` is only consulted when `lhs` is not a vector.
+  Type shapeType = LLVM::isCompatibleVectorType(lhs.getType()) ? lhs.getType()
+                                                                : rhs.getType();
+  if (LLVM::isCompatibleVectorType(shapeType)) {
+    // Forward the known-minimum count together with the scalable flag rather
+    // than calling `getFixedValue()`, which asserts on scalable vectors.
+    llvm::ElementCount numElements = LLVM::getVectorNumElements(shapeType);
+    resultType = LLVM::getVectorType(resultType,
+                                     numElements.getKnownMinValue(),
+                                     numElements.isScalable());
+  }
+  build(builder, result, resultType, predicate, lhs, rhs);
+}
 
 void ICmpOp::print(OpAsmPrinter &p) {
   p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -14,9 +14,12 @@
 // CHECK: {{.*}} = llvm.sdiv %[[I32]], %[[I32]] : i32
 // CHECK: {{.*}} = llvm.urem %[[I32]], %[[I32]] : i32
 // CHECK: {{.*}} = llvm.srem %[[I32]], %[[I32]] : i32
-// CHECK: {{.*}} = llvm.icmp "ne" %[[I32]], %[[I32]] : i32
-// CHECK: {{.*}} = llvm.icmp "ne" %[[I8PTR1]], %[[I8PTR1]] : !llvm.ptr
-// CHECK: {{.*}} = llvm.icmp "ne" %[[VI8PTR1]], %[[VI8PTR1]] : !llvm.vec<2 x ptr>
+// CHECK: %[[SCALAR_PRED0:.+]] = llvm.icmp "ne" %[[I32]], %[[I32]] : i32
+// CHECK: {{.*}} = llvm.add %[[SCALAR_PRED0]], %[[SCALAR_PRED0]] : i1
+// CHECK: %[[SCALAR_PRED1:.+]] = llvm.icmp "ne" %[[I8PTR1]], %[[I8PTR1]] : !llvm.ptr
+// CHECK: {{.*}} = llvm.add %[[SCALAR_PRED1]], %[[SCALAR_PRED1]] : i1
+// CHECK: %[[VEC_PRED:.+]] = llvm.icmp "ne" %[[VI8PTR1]], %[[VI8PTR1]] : !llvm.vec<2 x ptr>
+// CHECK: {{.*}} = llvm.add %[[VEC_PRED]], %[[VEC_PRED]] : vector<2xi1>
   %0 = llvm.add %arg0, %arg0 : i32
   %1 = llvm.sub %arg0, %arg0 : i32
   %2 = llvm.mul %arg0, %arg0 : i32
@@ -25,8 +28,11 @@
   %5 = llvm.urem %arg0, %arg0 : i32
   %6 = llvm.srem %arg0, %arg0 : i32
   %7 = llvm.icmp "ne" %arg0, %arg0 : i32
+  %typecheck_7 = llvm.add %7, %7 : i1
   %ptrcmp = llvm.icmp "ne" %arg2, %arg2 : !llvm.ptr
+  %typecheck_ptrcmp = llvm.add %ptrcmp, %ptrcmp : i1
   %vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr>
+  %typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>
 
 // Floating point binary operations.
 //