diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -228,9 +228,16 @@
   // - $_builder - substituted with the MLIR builder;
   // - $_qualCppClassName - substitiuted with the MLIR operation class name.
   // Additionally, `$$` can be used to produce the dollar character.
-  // NOTE: The $name variable resolution assumes the MLIR and LLVM argument
-  // orders match and there are no optional or variadic arguments.
+  // FIXME: The $name variable resolution does not support variadic arguments.
   string mlirBuilder = "";
+
+  // An array that specifies a mapping from MLIR argument indices to LLVM IR
+  // operand indices. The mapping is necessary since argument and operand
+  // indices do not always match. If not defined, the array is set to the
+  // identity permutation. An operation may define a custom mapping and set
+  // an argument index to -1 if the argument (e.g. an attribute) does not map
+  // to an LLVM IR operand.
+  list<int> llvmArgIndices = [];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -284,19 +284,39 @@
   let cppNamespace = "::mlir::LLVM";
 }
 
+// Base class for compare operations. A compare operation takes two operands
+// of the same type and returns a boolean result. If the operands are
+// vectors, then the result has to be a boolean vector of the same shape.
+class LLVM_ArithmeticCmpOp<string mnemonic, list<Trait> traits = []> :
+    LLVM_Op<mnemonic, traits # [SameTypeOperands, TypesMatchWith<
+    "result type has i1 element type and same shape as operands",
+    "lhs", "res", "::getI1SameShape($_self)">]> {
+  let results = (outs LLVM_ScalarOrVectorOf<I1>:$res);
+}
+
 // Other integer operations.
-def LLVM_ICmpOp : LLVM_Op<"icmp", [Pure]> {
+def LLVM_ICmpOp : LLVM_ArithmeticCmpOp<"icmp", [Pure]> {
   let arguments = (ins ICmpPredicate:$predicate,
-                   AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs,
-                   AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs);
-  let results = (outs LLVM_ScalarOrVectorOf<I1>:$res);
-  let llvmBuilder = [{
-    $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
-  }];
+                   AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>,
+                              LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs,
+                   AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>,
+                              LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs);
   let builders = [
     OpBuilder<(ins "ICmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
   ];
   let hasCustomAssemblyFormat = 1;
+  string llvmInstName = "ICmp";
+  string llvmBuilder = [{
+    $res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
+  }];
+  string mlirBuilder = [{
+    auto *iCmpInst = cast<llvm::ICmpInst>(inst);
+    $res = $_builder.create<$_qualCppClassName>(
+      $_location, getICmpPredicate(iCmpInst->getPredicate()), $lhs, $rhs);
+  }];
+  // Set the $predicate index to -1 to indicate there is no matching operand
+  // and decrement the following indices.
+  list<int> llvmArgIndices = [-1, 0, 1];
 }
 
 // Predicate for float comparisons
@@ -329,17 +349,29 @@
 }
 
 // Other floating-point operations.
-def LLVM_FCmpOp : LLVM_Op<"fcmp", [
+def LLVM_FCmpOp : LLVM_ArithmeticCmpOp<"fcmp", [
     Pure, DeclareOpInterfaceMethods<FastmathFlagsInterface>]> {
   let arguments = (ins FCmpPredicate:$predicate,
                    LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$lhs,
                    LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs,
                    DefaultValuedAttr<LLVM_FMFAttr, "{}">:$fastmathFlags);
-  let results = (outs LLVM_ScalarOrVectorOf<I1>:$res);
-  let llvmBuilder = [{
+  let builders = [
+    OpBuilder<(ins "FCmpPredicate":$predicate, "Value":$lhs, "Value":$rhs)>
+  ];
+  let hasCustomAssemblyFormat = 1;
+  string llvmInstName = "FCmp";
+  string llvmBuilder = [{
     $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
   }];
-  let hasCustomAssemblyFormat = 1;
+  // FIXME: Import fastmath flags.
+  string mlirBuilder = [{
+    auto *fCmpInst = cast<llvm::FCmpInst>(inst);
+    $res = $_builder.create<$_qualCppClassName>(
+      $_location, getFCmpPredicate(fCmpInst->getPredicate()), $lhs, $rhs);
+  }];
+  // Map the $predicate and $fastmathFlags attributes to -1 since neither has
+  // a matching LLVM IR operand; $lhs and $rhs map to operands 0 and 1.
+  list<int> llvmArgIndices = [-1, 0, 1, -1];
 }
 
 // Floating point binary operations.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -88,28 +88,34 @@
   return success();
 }
 
+/// Returns a boolean type that has the same shape as `type`. It supports both
+/// fixed size vectors as well as scalable vectors.
+static Type getI1SameShape(Type type) {
+  Type i1Type = IntegerType::get(type.getContext(), 1);
+  if (!LLVM::isCompatibleVectorType(type))
+    return i1Type;
+  if (LLVM::isScalableVectorType(type)) {
+    return LLVM::getVectorType(
+        i1Type, LLVM::getVectorNumElements(type).getKnownMinValue(),
+        /*isScalable=*/true);
+  }
+  return LLVM::getVectorType(i1Type,
+                             LLVM::getVectorNumElements(type).getFixedValue(),
+                             /*isScalable=*/false);
+}
+
 //===----------------------------------------------------------------------===//
 // Printing, parsing and builder for LLVM::CmpOp.
 //===----------------------------------------------------------------------===//
 
 void ICmpOp::build(OpBuilder &builder, OperationState &result,
                    ICmpPredicate predicate, Value lhs, Value rhs) {
-  auto boolType = IntegerType::get(lhs.getType().getContext(), 1);
-  if (LLVM::isCompatibleVectorType(lhs.getType()) ||
-      LLVM::isCompatibleVectorType(rhs.getType())) {
-    int64_t numLHSElements = 1, numRHSElements = 1;
-    if (LLVM::isCompatibleVectorType(lhs.getType()))
-      numLHSElements =
-          LLVM::getVectorNumElements(lhs.getType()).getFixedValue();
-    if (LLVM::isCompatibleVectorType(rhs.getType()))
-      numRHSElements =
-          LLVM::getVectorNumElements(rhs.getType()).getFixedValue();
-    build(builder, result,
-          VectorType::get({std::max(numLHSElements, numRHSElements)}, boolType),
-          predicate, lhs, rhs);
-  } else {
-    build(builder, result, boolType, predicate, lhs, rhs);
-  }
+  build(builder, result, getI1SameShape(lhs.getType()), predicate, lhs, rhs);
+}
+
+void FCmpOp::build(OpBuilder &builder, OperationState &result,
+                   FCmpPredicate predicate, Value lhs, Value rhs) {
+  build(builder, result, getI1SameShape(lhs.getType()), predicate, lhs, rhs);
 }
 
 void ICmpOp::print(OpAsmPrinter &p) {
@@ -173,23 +179,10 @@
 
   // The result type is either i1 or a vector type if the inputs are
   // vectors.
-  Type resultType = IntegerType::get(builder.getContext(), 1);
   if (!isCompatibleType(type))
     return parser.emitError(trailingTypeLoc,
                             "expected LLVM dialect-compatible type");
-  if (LLVM::isCompatibleVectorType(type)) {
-    if (LLVM::isScalableVectorType(type)) {
-      resultType = LLVM::getVectorType(
-          resultType, LLVM::getVectorNumElements(type).getKnownMinValue(),
-          /*isScalable=*/true);
-    } else {
-      resultType = LLVM::getVectorType(
-          resultType, LLVM::getVectorNumElements(type).getFixedValue(),
-          /*isScalable=*/false);
-    }
-  }
-
-  result.addTypes({resultType});
+  result.addTypes(getI1SameShape(type));
   return success();
 }
 
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -804,36 +804,6 @@
 
   // Convert all special instructions that do not provide an MLIR builder.
   Location loc = translateLoc(inst->getDebugLoc());
-  if (inst->getOpcode() == llvm::Instruction::ICmp) {
-    Value lhs = processValue(inst->getOperand(0));
-    Value rhs = processValue(inst->getOperand(1));
-    Value res = b.create<ICmpOp>(
-        loc, getICmpPredicate(cast<llvm::ICmpInst>(inst)->getPredicate()), lhs,
-        rhs);
-    mapValue(inst, res);
-    return success();
-  }
-  if (inst->getOpcode() == llvm::Instruction::FCmp) {
-    Value lhs = processValue(inst->getOperand(0));
-    Value rhs = processValue(inst->getOperand(1));
-
-    if (lhs.getType() != rhs.getType())
-      return failure();
-
-    Type boolType = b.getI1Type();
-    Type resType = boolType;
-    if (LLVM::isCompatibleVectorType(lhs.getType())) {
-      unsigned numElements =
-          LLVM::getVectorNumElements(lhs.getType()).getFixedValue();
-      resType = VectorType::get({numElements}, boolType);
-    }
-
-    Value res = b.create<FCmpOp>(
-        loc, resType,
-        getFCmpPredicate(cast<llvm::FCmpInst>(inst)->getPredicate()), lhs, rhs);
-    mapValue(inst, res);
-    return success();
-  }
   if (inst->getOpcode() == llvm::Instruction::Br) {
     auto *brInst = cast<llvm::BranchInst>(inst);
     OperationState state(loc,
diff
--git a/mlir/test/Target/LLVMIR/Import/basic.ll b/mlir/test/Target/LLVMIR/Import/basic.ll
--- a/mlir/test/Target/LLVMIR/Import/basic.ll
+++ b/mlir/test/Target/LLVMIR/Import/basic.ll
@@ -227,28 +227,6 @@
   ret i32* bitcast (double* @g2 to i32*)
 }
 
-; CHECK-LABEL: llvm.func @f5
-define void @f5(i32 %d) {
-; FIXME: icmp should return i1.
-; CHECK: = llvm.icmp "eq"
-  %1 = icmp eq i32 %d, 2
-; CHECK: = llvm.icmp "slt"
-  %2 = icmp slt i32 %d, 2
-; CHECK: = llvm.icmp "sle"
-  %3 = icmp sle i32 %d, 2
-; CHECK: = llvm.icmp "sgt"
-  %4 = icmp sgt i32 %d, 2
-; CHECK: = llvm.icmp "sge"
-  %5 = icmp sge i32 %d, 2
-; CHECK: = llvm.icmp "ult"
-  %6 = icmp ult i32 %d, 2
-; CHECK: = llvm.icmp "ule"
-  %7 = icmp ule i32 %d, 2
-; CHECK: = llvm.icmp "ugt"
-  %8 = icmp ugt i32 %d, 2
-  ret void
-}
-
 ; CHECK-LABEL: llvm.func @f6(%arg0: !llvm.ptr<func<void (i16)>>)
 define void @f6(void (i16) *%fn) {
 ; CHECK: %[[c:[0-9]+]] = llvm.mlir.constant(0 : i16) : i16
@@ -257,43 +235,6 @@
   ret void
 }
 
-; CHECK-LABEL: llvm.func @FPComparison(%arg0: f32, %arg1: f32)
-define void @FPComparison(float %a, float %b) {
-  ; CHECK: llvm.fcmp "_false" %arg0, %arg1
-  %1 = fcmp false float %a, %b
-  ; CHECK: llvm.fcmp "oeq" %arg0, %arg1
-  %2 = fcmp oeq float %a, %b
-  ; CHECK: llvm.fcmp "ogt" %arg0, %arg1
-  %3 = fcmp ogt float %a, %b
-  ; CHECK: llvm.fcmp "oge" %arg0, %arg1
-  %4 = fcmp oge float %a, %b
-  ; CHECK: llvm.fcmp "olt" %arg0, %arg1
-  %5 = fcmp olt float %a, %b
-  ; CHECK: llvm.fcmp "ole" %arg0, %arg1
-  %6 = fcmp ole float %a, %b
-  ; CHECK: llvm.fcmp "one" %arg0, %arg1
-  %7 = fcmp one float %a, %b
-  ; CHECK: llvm.fcmp "ord" %arg0, %arg1
-  %8 = fcmp ord float %a, %b
-  ; CHECK: llvm.fcmp "ueq" %arg0, %arg1
-  %9 = fcmp ueq float %a, %b
-  ; CHECK: llvm.fcmp "ugt" %arg0, %arg1
-  %10 = fcmp ugt float %a, %b
-  ; CHECK: llvm.fcmp "uge" %arg0, %arg1
-  %11 = fcmp uge float %a, %b
-  ; CHECK: llvm.fcmp "ult" %arg0, %arg1
-  %12 = fcmp ult float %a, %b
-  ; CHECK: llvm.fcmp "ule" %arg0, %arg1
-  %13 = fcmp ule float %a, %b
-  ; CHECK: llvm.fcmp "une" %arg0, %arg1
-  %14 = fcmp une float %a, %b
-  ; CHECK: llvm.fcmp "uno" %arg0, %arg1
-  %15 = fcmp uno float %a, %b
-  ; CHECK: llvm.fcmp "_true" %arg0, %arg1
-  %16 = fcmp true float %a, %b
-  ret void
-}
-
 ; Testing rest of the floating point constant kinds.
 ; CHECK-LABEL: llvm.func @FPConstant(%arg0: f16, %arg1: bf16, %arg2: f128, %arg3: f80)
 define void @FPConstant(half %a, bfloat %b, fp128 %c, x86_fp80 %d) {
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -41,6 +41,34 @@
 
 ; // -----
 
+; CHECK-LABEL: @integer_compare
+; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[ARG3:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[ARG4:[a-zA-Z0-9]+]]
+define i1 @integer_compare(i32 %arg1, i32 %arg2, <4 x i64> %arg3, <4 x i64> %arg4) {
+  ; CHECK-NEXT:  llvm.icmp "eq" %[[ARG3]], %[[ARG4]] : vector<4xi64>
+  ; CHECK-NEXT:  llvm.icmp "slt" %[[ARG1]], %[[ARG2]] : i32
+  ; CHECK-NEXT:  llvm.icmp "sle" %[[ARG1]], %[[ARG2]] : i32
+  ; CHECK-NEXT:  llvm.icmp "sgt" %[[ARG1]], %[[ARG2]] : i32
+  ; CHECK-NEXT:  llvm.icmp "sge" %[[ARG1]], %[[ARG2]] : i32
+  ; CHECK-NEXT:  llvm.icmp "ult" %[[ARG1]], %[[ARG2]] : i32
+  ; CHECK-NEXT:  llvm.icmp "ule" %[[ARG1]], %[[ARG2]] : i32
+  ; CHECK-NEXT:  llvm.icmp "ugt" %[[ARG1]], %[[ARG2]] : i32
+  %1 = icmp eq <4 x i64> %arg3, %arg4
+  %2 = icmp slt i32 %arg1, %arg2
+  %3 = icmp sle i32 %arg1, %arg2
+  %4 = icmp sgt i32 %arg1, %arg2
+  %5 = icmp sge i32 %arg1, %arg2
+  %6 = icmp ult i32 %arg1, %arg2
+  %7 = icmp ule i32 %arg1, %arg2
+  ; Verify scalar comparisons return a scalar boolean
+  %8 = icmp ugt i32 %arg1, %arg2
+  ret i1 %8
+}
+
+; // -----
+
 ; CHECK-LABEL: @fp_arith
 ; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
 ; CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]
@@ -70,6 +98,50 @@
 
 ; // -----
 
+; CHECK-LABEL: @fp_compare
+; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[ARG3:[a-zA-Z0-9]+]]
+; CHECK-SAME:  %[[ARG4:[a-zA-Z0-9]+]]
+define <4 x i1> @fp_compare(float %arg1, float %arg2, <4 x double> %arg3, <4 x double> %arg4) {
+  ; CHECK-NEXT:  llvm.fcmp "_false" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "oeq" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "ogt" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "oge" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "olt" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "ole" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "one" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "ord" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "ueq" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "ugt" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "uge" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "ult" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "ule" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "une" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "uno" %[[ARG1]], %[[ARG2]] : f32
+  ; CHECK-NEXT:  llvm.fcmp "_true" %[[ARG3]], %[[ARG4]] : vector<4xf64>
+  %1 = fcmp false float %arg1, %arg2
+  %2 = fcmp oeq float %arg1, %arg2
+  %3 = fcmp ogt float %arg1, %arg2
+  %4 = fcmp oge float %arg1, %arg2
+  %5 = fcmp olt float %arg1, %arg2
+  %6 = fcmp ole float %arg1, %arg2
+  %7 = fcmp one float %arg1, %arg2
+  %8 = fcmp ord float %arg1, %arg2
+  %9 = fcmp ueq float %arg1, %arg2
+  %10 = fcmp ugt float %arg1, %arg2
+  %11 = fcmp uge float %arg1, %arg2
+  %12 = fcmp ult float %arg1, %arg2
+  %13 = fcmp ule float %arg1, %arg2
+  %14 = fcmp une float %arg1, %arg2
+  %15 = fcmp uno float %arg1, %arg2
+  ; Verify vector comparisons return a vector of booleans
+  %16 = fcmp true <4 x double> %arg3, %arg4
+  ret <4 x i1> %16
+}
+
+; // -----
+
 ; CHECK-LABEL: @fp_casts
 ; CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
 ; CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]
diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -17,6 +17,7 @@
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -203,6 +204,30 @@
   if (builderStrRef.empty())
     return success();
 
+  // Access the argument index array that maps argument indices to LLVM IR
+  // operand indices. If the operation defines no custom mapping, set the array
+  // to the identity permutation.
+  std::vector<int64_t> llvmArgIndices =
+      record.getValueAsListOfInts("llvmArgIndices");
+  if (llvmArgIndices.empty()) {
+    llvmArgIndices.resize(op.getNumArgs());
+    for (int64_t idx : seq<int64_t>(0, op.getNumArgs()))
+      llvmArgIndices[idx] = idx;
+  }
+  if (llvmArgIndices.size() != static_cast<size_t>(op.getNumArgs())) {
+    return emitError(
+        "'llvmArgIndices' does not match the number of arguments for op " +
+        op.getOperationName());
+  }
+  // Every entry must be -1 (the argument has no matching LLVM IR operand) or
+  // a non-negative LLVM IR operand index.
+  for (int64_t idx : llvmArgIndices) {
+    if (idx < -1) {
+      return emitError("'llvmArgIndices' contains an invalid index for op " +
+                       op.getOperationName());
+    }
+  }
+
   // Progressively create the builder string by replacing $-variables. Keep only
   // the not-yet-traversed part of the builder pattern to avoid re-traversing
   // the string multiple times.
@@ -215,9 +240,19 @@
     // Then, rewrite the name based on its kind.
     FailureOr<int> argIndex = getArgumentIndex(op, name);
     if (succeeded(argIndex)) {
-      // Process the argument value assuming the MLIR and LLVM operand orders
-      // match and there are no optional or variadic arguments.
-      bs << formatv("processValue(llvmOperands[{0}])", *argIndex);
+      // Access the LLVM IR operand that maps to the given argument index using
+      // the provided argument indices mapping.
+      // FIXME: support trailing variadic arguments.
+      int64_t operandIdx = llvmArgIndices[*argIndex];
+      // Report a generator error instead of asserting: an argument mapped to
+      // -1 (e.g. an attribute) has no LLVM IR operand that could be processed.
+      if (operandIdx < 0) {
+        return emitError(formatv("argument '{0}' of op {1} has no matching "
+                                 "LLVM IR operand in 'llvmArgIndices'",
+                                 name, op.getOperationName())
+                             .str());
+      }
+      bs << formatv("processValue(llvmOperands[{0}])", operandIdx);
     } else if (isResultName(op, name)) {
       assert(op.getNumResults() == 1 &&
              "expected operation to have one result");