[mlir][llvm] Import icmp and fcmp using tablegen-generated MLIR builders.

- Add an `llvmArgIndices` field to LLVM_OpBase that maps MLIR argument
  indices to LLVM IR operand indices (identity permutation by default, -1 for
  arguments without a matching operand) and use it in mlir-tblgen when
  resolving $name variables in `mlirBuilder` strings.
- Add the LLVM_ArithmeticCmpOp base class and derive ICmpOp and FCmpOp from
  it. Both ops now define an `mlirBuilder`, so the hand-written icmp/fcmp
  import in ConvertFromLLVMIR.cpp is removed.
- Add getI1SameShape and an llvm::ElementCount overload of getVectorType so
  the builders and the parser derive the i1 result shape from the operand
  type, including scalable vectors.
- Move the compare import tests from basic.ll to instructions.ll.

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -228,9 +228,16 @@
   // - $_builder - substituted with the MLIR builder;
   // - $_qualCppClassName - substitiuted with the MLIR operation class name.
   // Additionally, `$$` can be used to produce the dollar character.
-  // NOTE: The $name variable resolution assumes the MLIR and LLVM argument
-  // orders match and there are no optional or variadic arguments.
+  // FIXME: The $name variable resolution does not support variadic arguments.
   string mlirBuilder = "";
+
+  // An array that specifies a mapping from MLIR argument indices to LLVM IR
+  // operand indices. The mapping is necessary since argument and operand
+  // indices do not always match. If not defined, the array is set to the
+  // identity permutation. An operation may define any custom index permutation
+  // and set a specific argument index to -1 if it does not map to an LLVM IR
+  // operand.
+  list<int> llvmArgIndices = [];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -284,19 +284,39 @@
   let cppNamespace = "::mlir::LLVM";
 }
 
+// Base class for compare operations. A compare operation takes two operands
+// of the same type and returns a boolean result. If the operands are
+// vectors, then the result has to be a boolean vector of the same shape.
+class LLVM_ArithmeticCmpOp<string mnemonic, list<Trait> traits = []> :
+    LLVM_Op<mnemonic, !listconcat(traits, [SameTypeOperands,
+      TypesMatchWith<
+        "result type has i1 element type and same shape as operands",
+        "lhs", "res", "::getI1SameShape($_self)">])> {
+  let results = (outs LLVM_ScalarOrVectorOf<I1>:$res);
+}
+
 // Other integer operations.
-def LLVM_ICmpOp : LLVM_Op<"icmp", [Pure]> {
+def LLVM_ICmpOp : LLVM_ArithmeticCmpOp<"icmp", [Pure]> {
   let arguments = (ins ICmpPredicate:$predicate,
-    AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs,
-    AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs);
-  let results = (outs LLVM_ScalarOrVectorOf<I1>:$res);
-  let llvmBuilder = [{
-    $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
-  }];
+                   AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>,
+                              LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs,
+                   AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>,
+                              LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs);
   let builders = [
     OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
   ];
   let hasCustomAssemblyFormat = 1;
+  string llvmInstName = "ICmp";
+  string llvmBuilder = [{
+    $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
+  }];
+  string mlirBuilder = [{
+    auto *iCmpInst = cast<llvm::ICmpInst>(inst);
+    $res = $_builder.create<$_qualCppClassName>(
+      $_location, getICmpPredicate(iCmpInst->getPredicate()), $lhs, $rhs);
+  }];
+  // Set the $predicate index to -1 to indicate there is no matching operand
+  // and decrement the following indices.
+  list<int> llvmArgIndices = [-1, 0, 1];
 }
 
 // Predicate for float comparisons
@@ -329,17 +349,29 @@
 }
 
 // Other floating-point operations.
-def LLVM_FCmpOp : LLVM_Op<"fcmp", [
+def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [
     Pure, DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
   let arguments = (ins FCmpPredicate:$predicate,
                    LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$lhs,
                    LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs,
                    DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
-  let results = (outs LLVM_ScalarOrVectorOf<I1>:$res);
-  let llvmBuilder = [{
+  let builders = [
+    OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  string llvmInstName = "FCmp";
+  string llvmBuilder = [{
     $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
   }];
-  let hasCustomAssemblyFormat = 1;
+  // FIXME: Import fastmath flags.
+  string mlirBuilder = [{
+    auto *fCmpInst = cast<llvm::FCmpInst>(inst);
+    $res = $_builder.create<$_qualCppClassName>(
+      $_location, getFCmpPredicate(fCmpInst->getPredicate()), $lhs, $rhs);
+  }];
+  // Set the $predicate and $fastmathFlags indices to -1 since neither maps to
+  // an LLVM IR operand, and decrement the indices of the compared operands.
+  list<int> llvmArgIndices = [-1, 0, 1, -1];
 }
 
 // Floating point binary operations.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -521,6 +521,10 @@
 Type getVectorType(Type elementType, unsigned numElements,
                    bool isScalable = false);
 
+/// Creates an LLVM dialect-compatible vector type with the given element type
+/// and length.
+Type getVectorType(Type elementType, const llvm::ElementCount &numElements);
+
 /// Creates an LLVM dialect-compatible type with the given element type and
 /// length.
 Type getFixedVectorType(Type elementType, unsigned numElements);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -88,28 +88,27 @@
   return success();
 }
 
+/// Returns a boolean type that has the same shape as `type`. It supports both
+/// fixed size vectors as well as scalable vectors.
+static Type getI1SameShape(Type type) {
+  Type i1Type = IntegerType::get(type.getContext(), 1);
+  if (LLVM::isCompatibleVectorType(type))
+    return LLVM::getVectorType(i1Type, LLVM::getVectorNumElements(type));
+  return i1Type;
+}
+
 //===----------------------------------------------------------------------===//
 // Printing, parsing and builder for LLVM::CmpOp.
//===----------------------------------------------------------------------===// void ICmpOp::build(OpBuilder &builder, OperationState &result, ICmpPredicate predicate, Value lhs, Value rhs) { - auto boolType = IntegerType::get(lhs.getType().getContext(), 1); - if (LLVM::isCompatibleVectorType(lhs.getType()) || - LLVM::isCompatibleVectorType(rhs.getType())) { - int64_t numLHSElements = 1, numRHSElements = 1; - if (LLVM::isCompatibleVectorType(lhs.getType())) - numLHSElements = - LLVM::getVectorNumElements(lhs.getType()).getFixedValue(); - if (LLVM::isCompatibleVectorType(rhs.getType())) - numRHSElements = - LLVM::getVectorNumElements(rhs.getType()).getFixedValue(); - build(builder, result, - VectorType::get({std::max(numLHSElements, numRHSElements)}, boolType), - predicate, lhs, rhs); - } else { - build(builder, result, boolType, predicate, lhs, rhs); - } + build(builder, result, getI1SameShape(lhs.getType()), predicate, lhs, rhs); +} + +void FCmpOp::build(OpBuilder &builder, OperationState &result, + FCmpPredicate predicate, Value lhs, Value rhs) { + build(builder, result, getI1SameShape(lhs.getType()), predicate, lhs, rhs); } void ICmpOp::print(OpAsmPrinter &p) { @@ -132,8 +131,6 @@ // attribute-dict? `:` type template static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) { - Builder &builder = parser.getBuilder(); - StringAttr predicateAttr; OpAsmParser::UnresolvedOperand lhs, rhs; Type type; @@ -173,23 +170,10 @@ // The result type is either i1 or a vector type if the inputs are // vectors. 
-  Type resultType = IntegerType::get(builder.getContext(), 1);
   if (!isCompatibleType(type))
     return parser.emitError(trailingTypeLoc,
                             "expected LLVM dialect-compatible type");
-  if (LLVM::isCompatibleVectorType(type)) {
-    if (LLVM::isScalableVectorType(type)) {
-      resultType = LLVM::getVectorType(
-          resultType, LLVM::getVectorNumElements(type).getKnownMinValue(),
-          /*isScalable=*/true);
-    } else {
-      resultType = LLVM::getVectorType(
-          resultType, LLVM::getVectorNumElements(type).getFixedValue(),
-          /*isScalable=*/false);
-    }
-  }
-
-  result.addTypes({resultType});
+  result.addTypes(getI1SameShape(type));
   return success();
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -934,6 +934,15 @@
   return VectorType::get(numElements, elementType, (unsigned)isScalable);
 }
 
+Type mlir::LLVM::getVectorType(Type elementType,
+                               const llvm::ElementCount &numElements) {
+  if (numElements.isScalable())
+    return getVectorType(elementType, numElements.getKnownMinValue(),
+                         /*isScalable=*/true);
+  return getVectorType(elementType, numElements.getFixedValue(),
+                       /*isScalable=*/false);
+}
+
 Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
   bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType);
   bool useBuiltIn = VectorType::isValidElementType(elementType);
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -804,36 +804,6 @@
 
   // Convert all special instructions that do not provide an MLIR builder.
Location loc = translateLoc(inst->getDebugLoc()); - if (inst->getOpcode() == llvm::Instruction::ICmp) { - Value lhs = processValue(inst->getOperand(0)); - Value rhs = processValue(inst->getOperand(1)); - Value res = b.create( - loc, getICmpPredicate(cast(inst)->getPredicate()), lhs, - rhs); - mapValue(inst, res); - return success(); - } - if (inst->getOpcode() == llvm::Instruction::FCmp) { - Value lhs = processValue(inst->getOperand(0)); - Value rhs = processValue(inst->getOperand(1)); - - if (lhs.getType() != rhs.getType()) - return failure(); - - Type boolType = b.getI1Type(); - Type resType = boolType; - if (LLVM::isCompatibleVectorType(lhs.getType())) { - unsigned numElements = - LLVM::getVectorNumElements(lhs.getType()).getFixedValue(); - resType = VectorType::get({numElements}, boolType); - } - - Value res = b.create( - loc, resType, - getFCmpPredicate(cast(inst)->getPredicate()), lhs, rhs); - mapValue(inst, res); - return success(); - } if (inst->getOpcode() == llvm::Instruction::Br) { auto *brInst = cast(inst); OperationState state(loc, diff --git a/mlir/test/Target/LLVMIR/Import/basic.ll b/mlir/test/Target/LLVMIR/Import/basic.ll --- a/mlir/test/Target/LLVMIR/Import/basic.ll +++ b/mlir/test/Target/LLVMIR/Import/basic.ll @@ -227,28 +227,6 @@ ret i32* bitcast (double* @g2 to i32*) } -; CHECK-LABEL: llvm.func @f5 -define void @f5(i32 %d) { -; FIXME: icmp should return i1. 
-; CHECK: = llvm.icmp "eq"
-  %1 = icmp eq i32 %d, 2
-; CHECK: = llvm.icmp "slt"
-  %2 = icmp slt i32 %d, 2
-; CHECK: = llvm.icmp "sle"
-  %3 = icmp sle i32 %d, 2
-; CHECK: = llvm.icmp "sgt"
-  %4 = icmp sgt i32 %d, 2
-; CHECK: = llvm.icmp "sge"
-  %5 = icmp sge i32 %d, 2
-; CHECK: = llvm.icmp "ult"
-  %6 = icmp ult i32 %d, 2
-; CHECK: = llvm.icmp "ule"
-  %7 = icmp ule i32 %d, 2
-; CHECK: = llvm.icmp "ugt"
-  %8 = icmp ugt i32 %d, 2
-  ret void
-}
-
 ; CHECK-LABEL: llvm.func @f6(%arg0: !llvm.ptr<func<void (i16)>>)
 define void @f6(void (i16) *%fn) {
 ; CHECK: %[[c:[0-9]+]] = llvm.mlir.constant(0 : i16) : i16
@@ -257,43 +235,6 @@
   ret void
 }
 
-; CHECK-LABEL: llvm.func @FPComparison(%arg0: f32, %arg1: f32)
-define void @FPComparison(float %a, float %b) {
-  ; CHECK: llvm.fcmp "_false" %arg0, %arg1
-  %1 = fcmp false float %a, %b
-  ; CHECK: llvm.fcmp "oeq" %arg0, %arg1
-  %2 = fcmp oeq float %a, %b
-  ; CHECK: llvm.fcmp "ogt" %arg0, %arg1
-  %3 = fcmp ogt float %a, %b
-  ; CHECK: llvm.fcmp "oge" %arg0, %arg1
-  %4 = fcmp oge float %a, %b
-  ; CHECK: llvm.fcmp "olt" %arg0, %arg1
-  %5 = fcmp olt float %a, %b
-  ; CHECK: llvm.fcmp "ole" %arg0, %arg1
-  %6 = fcmp ole float %a, %b
-  ; CHECK: llvm.fcmp "one" %arg0, %arg1
-  %7 = fcmp one float %a, %b
-  ; CHECK: llvm.fcmp "ord" %arg0, %arg1
-  %8 = fcmp ord float %a, %b
-  ; CHECK: llvm.fcmp "ueq" %arg0, %arg1
-  %9 = fcmp ueq float %a, %b
-  ; CHECK: llvm.fcmp "ugt" %arg0, %arg1
-  %10 = fcmp ugt float %a, %b
-  ; CHECK: llvm.fcmp "uge" %arg0, %arg1
-  %11 = fcmp uge float %a, %b
-  ; CHECK: llvm.fcmp "ult" %arg0, %arg1
-  %12 = fcmp ult float %a, %b
-  ; CHECK: llvm.fcmp "ule" %arg0, %arg1
-  %13 = fcmp ule float %a, %b
-  ; CHECK: llvm.fcmp "une" %arg0, %arg1
-  %14 = fcmp une float %a, %b
-  ; CHECK: llvm.fcmp "uno" %arg0, %arg1
-  %15 = fcmp uno float %a, %b
-  ; CHECK: llvm.fcmp "_true" %arg0, %arg1
-  %16 = fcmp true float %a, %b
-  ret void
-}
-
 ; Testing rest of the floating point constant kinds.
; CHECK-LABEL: llvm.func @FPConstant(%arg0: f16, %arg1: bf16, %arg2: f128, %arg3: f80) define void @FPConstant(half %a, bfloat %b, fp128 %c, x86_fp80 %d) { diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll --- a/mlir/test/Target/LLVMIR/Import/instructions.ll +++ b/mlir/test/Target/LLVMIR/Import/instructions.ll @@ -9,38 +9,66 @@ ; CHECK-DAG: %[[C1:[0-9]+]] = llvm.mlir.constant(-7 : i32) : i32 ; CHECK-DAG: %[[C2:[0-9]+]] = llvm.mlir.constant(42 : i32) : i32 ; CHECK: llvm.add %[[ARG1]], %[[C1]] : i32 - ; CHECK: llvm.add %[[C2]], %[[ARG2]] : i32 - ; CHECK: llvm.sub %[[ARG3]], %[[ARG4]] : i64 - ; CHECK: llvm.mul %[[ARG1]], %[[ARG2]] : i32 - ; CHECK: llvm.udiv %[[ARG3]], %[[ARG4]] : i64 - ; CHECK: llvm.sdiv %[[ARG1]], %[[ARG2]] : i32 - ; CHECK: llvm.urem %[[ARG3]], %[[ARG4]] : i64 - ; CHECK: llvm.srem %[[ARG1]], %[[ARG2]] : i32 - ; CHECK: llvm.shl %[[ARG3]], %[[ARG4]] : i64 - ; CHECK: llvm.lshr %[[ARG1]], %[[ARG2]] : i32 - ; CHECK: llvm.ashr %[[ARG3]], %[[ARG4]] : i64 - ; CHECK: llvm.and %[[ARG1]], %[[ARG2]] : i32 - ; CHECK: llvm.or %[[ARG3]], %[[ARG4]] : i64 - ; CHECK: llvm.xor %[[ARG1]], %[[ARG2]] : i32 %1 = add i32 %arg1, -7 + ; CHECK: llvm.add %[[C2]], %[[ARG2]] : i32 %2 = add i32 42, %arg2 + ; CHECK: llvm.sub %[[ARG3]], %[[ARG4]] : i64 %3 = sub i64 %arg3, %arg4 + ; CHECK: llvm.mul %[[ARG1]], %[[ARG2]] : i32 %4 = mul i32 %arg1, %arg2 + ; CHECK: llvm.udiv %[[ARG3]], %[[ARG4]] : i64 %5 = udiv i64 %arg3, %arg4 + ; CHECK: llvm.sdiv %[[ARG1]], %[[ARG2]] : i32 %6 = sdiv i32 %arg1, %arg2 + ; CHECK: llvm.urem %[[ARG3]], %[[ARG4]] : i64 %7 = urem i64 %arg3, %arg4 + ; CHECK: llvm.srem %[[ARG1]], %[[ARG2]] : i32 %8 = srem i32 %arg1, %arg2 + ; CHECK: llvm.shl %[[ARG3]], %[[ARG4]] : i64 %9 = shl i64 %arg3, %arg4 + ; CHECK: llvm.lshr %[[ARG1]], %[[ARG2]] : i32 %10 = lshr i32 %arg1, %arg2 + ; CHECK: llvm.ashr %[[ARG3]], %[[ARG4]] : i64 %11 = ashr i64 %arg3, %arg4 + ; CHECK: llvm.and %[[ARG1]], %[[ARG2]] : i32 %12 = and i32 
%arg1, %arg2 + ; CHECK: llvm.or %[[ARG3]], %[[ARG4]] : i64 %13 = or i64 %arg3, %arg4 + ; CHECK: llvm.xor %[[ARG1]], %[[ARG2]] : i32 %14 = xor i32 %arg1, %arg2 ret void } ; // ----- +; CHECK-LABEL: @integer_compare +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]] +define i1 @integer_compare(i32 %arg1, i32 %arg2, <4 x i64> %arg3, <4 x i64> %arg4) { + ; CHECK: llvm.icmp "eq" %[[ARG3]], %[[ARG4]] : vector<4xi64> + %1 = icmp eq <4 x i64> %arg3, %arg4 + ; CHECK: llvm.icmp "slt" %[[ARG1]], %[[ARG2]] : i32 + %2 = icmp slt i32 %arg1, %arg2 + ; CHECK: llvm.icmp "sle" %[[ARG1]], %[[ARG2]] : i32 + %3 = icmp sle i32 %arg1, %arg2 + ; CHECK: llvm.icmp "sgt" %[[ARG1]], %[[ARG2]] : i32 + %4 = icmp sgt i32 %arg1, %arg2 + ; CHECK: llvm.icmp "sge" %[[ARG1]], %[[ARG2]] : i32 + %5 = icmp sge i32 %arg1, %arg2 + ; CHECK: llvm.icmp "ult" %[[ARG1]], %[[ARG2]] : i32 + %6 = icmp ult i32 %arg1, %arg2 + ; CHECK: llvm.icmp "ule" %[[ARG1]], %[[ARG2]] : i32 + %7 = icmp ule i32 %arg1, %arg2 + ; Verify scalar comparisons return a scalar boolean + ; CHECK: llvm.icmp "ugt" %[[ARG1]], %[[ARG2]] : i32 + %8 = icmp ugt i32 %arg1, %arg2 + ret i1 %8 +} + +; // ----- + ; CHECK-LABEL: @fp_arith ; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] ; CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] @@ -50,26 +78,70 @@ ; CHECK: %[[C1:[0-9]+]] = llvm.mlir.constant(3.030000e+01 : f64) : f64 ; CHECK: %[[C2:[0-9]+]] = llvm.mlir.constant(3.030000e+01 : f32) : f32 ; CHECK: llvm.fadd %[[C2]], %[[ARG1]] : f32 - ; CHECK: llvm.fadd %[[ARG1]], %[[ARG2]] : f32 - ; CHECK: llvm.fadd %[[C1]], %[[ARG3]] : f64 - ; CHECK: llvm.fsub %[[ARG1]], %[[ARG2]] : f32 - ; CHECK: llvm.fmul %[[ARG3]], %[[ARG4]] : f64 - ; CHECK: llvm.fdiv %[[ARG1]], %[[ARG2]] : f32 - ; CHECK: llvm.frem %[[ARG3]], %[[ARG4]] : f64 - ; CHECK: llvm.fneg %[[ARG1]] : f32 %1 = fadd float 0x403E4CCCC0000000, %arg1 + ; CHECK: llvm.fadd %[[ARG1]], %[[ARG2]] : f32 %2 = fadd float %arg1, %arg2 + ; 
CHECK: llvm.fadd %[[C1]], %[[ARG3]] : f64 %3 = fadd double 3.030000e+01, %arg3 + ; CHECK: llvm.fsub %[[ARG1]], %[[ARG2]] : f32 %4 = fsub float %arg1, %arg2 + ; CHECK: llvm.fmul %[[ARG3]], %[[ARG4]] : f64 %5 = fmul double %arg3, %arg4 + ; CHECK: llvm.fdiv %[[ARG1]], %[[ARG2]] : f32 %6 = fdiv float %arg1, %arg2 + ; CHECK: llvm.frem %[[ARG3]], %[[ARG4]] : f64 %7 = frem double %arg3, %arg4 + ; CHECK: llvm.fneg %[[ARG1]] : f32 %8 = fneg float %arg1 ret void } ; // ----- +; CHECK-LABEL: @fp_compare +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]] +define <4 x i1> @fp_compare(float %arg1, float %arg2, <4 x double> %arg3, <4 x double> %arg4) { + ; CHECK: llvm.fcmp "_false" %[[ARG1]], %[[ARG2]] : f32 + %1 = fcmp false float %arg1, %arg2 + ; CHECK: llvm.fcmp "oeq" %[[ARG1]], %[[ARG2]] : f32 + %2 = fcmp oeq float %arg1, %arg2 + ; CHECK: llvm.fcmp "ogt" %[[ARG1]], %[[ARG2]] : f32 + %3 = fcmp ogt float %arg1, %arg2 + ; CHECK: llvm.fcmp "oge" %[[ARG1]], %[[ARG2]] : f32 + %4 = fcmp oge float %arg1, %arg2 + ; CHECK: llvm.fcmp "olt" %[[ARG1]], %[[ARG2]] : f32 + %5 = fcmp olt float %arg1, %arg2 + ; CHECK: llvm.fcmp "ole" %[[ARG1]], %[[ARG2]] : f32 + %6 = fcmp ole float %arg1, %arg2 + ; CHECK: llvm.fcmp "one" %[[ARG1]], %[[ARG2]] : f32 + %7 = fcmp one float %arg1, %arg2 + ; CHECK: llvm.fcmp "ord" %[[ARG1]], %[[ARG2]] : f32 + %8 = fcmp ord float %arg1, %arg2 + ; CHECK: llvm.fcmp "ueq" %[[ARG1]], %[[ARG2]] : f32 + %9 = fcmp ueq float %arg1, %arg2 + ; CHECK: llvm.fcmp "ugt" %[[ARG1]], %[[ARG2]] : f32 + %10 = fcmp ugt float %arg1, %arg2 + ; CHECK: llvm.fcmp "uge" %[[ARG1]], %[[ARG2]] : f32 + %11 = fcmp uge float %arg1, %arg2 + ; CHECK: llvm.fcmp "ult" %[[ARG1]], %[[ARG2]] : f32 + %12 = fcmp ult float %arg1, %arg2 + ; CHECK: llvm.fcmp "ule" %[[ARG1]], %[[ARG2]] : f32 + %13 = fcmp ule float %arg1, %arg2 + ; CHECK: llvm.fcmp "une" %[[ARG1]], %[[ARG2]] : f32 + %14 = fcmp une float %arg1, 
%arg2 + ; CHECK: llvm.fcmp "uno" %[[ARG1]], %[[ARG2]] : f32 + %15 = fcmp uno float %arg1, %arg2 + ; Verify vector comparisons return a vector of booleans + ; CHECK: llvm.fcmp "_true" %[[ARG3]], %[[ARG4]] : vector<4xf64> + %16 = fcmp true <4 x double> %arg3, %arg4 + ret <4 x i1> %16 +} + +; // ----- + ; CHECK-LABEL: @fp_casts ; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] ; CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -17,6 +17,7 @@ #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/FormatVariadic.h" @@ -203,6 +204,19 @@ if (builderStrRef.empty()) return success(); + // Access the argument index array that maps argument indices to LLVM IR + // operand indices. If the operation defines no custom mapping, set the array + // to the identity permutation. + std::vector llvmArgIndices = + record.getValueAsListOfInts("llvmArgIndices"); + if (llvmArgIndices.empty()) + append_range(llvmArgIndices, seq(0, op.getNumArgs())); + if (llvmArgIndices.size() != static_cast(op.getNumArgs())) { + return emitError( + "'llvmArgIndices' does not match the number of arguments for op " + + op.getOperationName()); + } + // Progressively create the builder string by replacing $-variables. Keep only // the not-yet-traversed part of the builder pattern to avoid re-traversing // the string multiple times. @@ -215,9 +229,13 @@ // Then, rewrite the name based on its kind. FailureOr argIndex = getArgumentIndex(op, name); if (succeeded(argIndex)) { - // Process the argument value assuming the MLIR and LLVM operand orders - // match and there are no optional or variadic arguments. 
-      bs << formatv("processValue(llvmOperands[{0}])", *argIndex);
+      // Access the LLVM IR operand that maps to the given argument index.
+      // FIXME: support trailing variadic arguments.
+      int64_t operandIdx = llvmArgIndices[*argIndex];
+      if (operandIdx < 0 || isVariadicOperandName(op, name))
+        return emitError("cannot map argument '" + name + "' of op " +
+                         op.getOperationName() + " to an LLVM IR operand");
+      bs << formatv("processValue(llvmOperands[{0}])", operandIdx);
     } else if (isResultName(op, name)) {
       assert(op.getNumResults() == 1 &&
              "expected operation to have one result");