diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -33,6 +33,7 @@ mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions) mlir_tablegen(LLVMConversionEnumsToLLVM.inc -gen-enum-to-llvmir-conversions) mlir_tablegen(LLVMConversionEnumsFromLLVM.inc -gen-enum-from-llvmir-conversions) +mlir_tablegen(LLVMOpFromLLVMIRConversions.inc -gen-op-from-llvmir-conversions) add_public_tablegen_target(MLIRLLVMConversionsIncGen) set(LLVM_TARGET_DEFINITIONS LLVMIntrinsicOps.td) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -198,7 +198,8 @@ AnyTypeOf<[element, LLVM_VectorOf]>; // Base class for LLVM operations. Defines the interface to the llvm::IRBuilder -// used to translate to LLVM IR proper. +// used to translate to proper LLVM IR and the interface to the mlir::OpBuilder +// used to import from LLVM IR. class LLVM_OpBase traits = []> : Op { // A pattern for constructing the LLVM IR Instruction (or other Value) that @@ -213,6 +214,23 @@ // - $_location - mlir::Location object of the instruction. // Additionally, `$$` can be used to produce the dollar character. string llvmBuilder = ""; + + // A builder to construct the MLIR LLVM dialect operation given the matching + // LLVM IR instruction `inst` and its operands `llvmOperands`. 
The + // following $-variables exist: + // - $name - substituted by the remapped `inst` operand value at the index + // of the MLIR operation argument with the given name, or if the + // name matches the result name, by a reference to store the + // result of the newly created MLIR operation to; + // - $_int_attr - substituted by a call to an integer attribute matcher; + // - $_resultType - substituted with the MLIR result type; + // - $_location - substituted with the MLIR location; + // - $_builder - substituted with the MLIR builder; + // - $_qualCppClassName - substituted with the MLIR operation class name. + // Additionally, `$$` can be used to produce the dollar character. + // NOTE: The $name variable resolution assumes the MLIR and LLVM argument + // orders match and there are no optional or variadic arguments. + string mlirBuilder = ""; } //===----------------------------------------------------------------------===// @@ -335,21 +353,6 @@ "(void) inst;") # !if(!gt(numResults, 0), "$res = inst;", ""); - // A builder to construct the MLIR LLVM dialect operation given the matching - // LLVM IR instruction `inst` and its operands `llvmOperands`. The - // following $-variables exist: - // - $name - substituted by the remapped `inst` operand value at the index - // of the MLIR operation argument with the given name, or if the - // name matches the result name, by a reference to store the - // result of the newly created MLIR operation to; - // - $_int_attr - substituted by a call to an integer attribute matcher; - // - $_resultType - substituted with the MLIR result type; - // - $_location - substituted with the MLIR location; - // - $_builder - substituted with the MLIR builder; - // - $_qualCppClassName - substitiuted with the MLIR operation class name. - // Additionally, `$$` can be used to produce the dollar character.
- // NOTE: The $name variable resolution assumes the MLIR and LLVM argument - // orders match and there are no optional or variadic arguments. string mlirBuilder = [{ SmallVector resultTypes = }] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{ diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -206,24 +206,28 @@ // Class for arithmetic binary operations. class LLVM_ArithmeticOpBase traits = []> : + string instName, list traits = []> : LLVM_Op, - LLVM_Builder<"$res = builder." # builderFunc # "($lhs, $rhs);"> { + LLVM_Builder<"$res = builder.Create" # instName # "($lhs, $rhs);"> { dag commonArgs = (ins LLVM_ScalarOrVectorOf:$lhs, LLVM_ScalarOrVectorOf:$rhs); let results = (outs LLVM_ScalarOrVectorOf:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "$lhs `,` $rhs custom(attr-dict) `:` type($res)"; + string llvmInstName = instName; + string mlirBuilder = [{ + $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + }]; } -class LLVM_IntArithmeticOp traits = []> : - LLVM_ArithmeticOpBase { + LLVM_ArithmeticOpBase { let arguments = commonArgs; } -class LLVM_FloatArithmeticOp traits = []> : - LLVM_ArithmeticOpBase], traits)> { dag fmfArg = (ins DefaultValuedAttr:$fastmathFlags); let arguments = !con(commonArgs, fmfArg); @@ -231,30 +235,34 @@ // Class for arithmetic unary operations. class LLVM_UnaryFloatArithmeticOp traits = []> : + string instName, list traits = []> : LLVM_Op], traits)>, - LLVM_Builder<"$res = builder." 
# builderFunc # "($operand);"> { + LLVM_Builder<"$res = builder.Create" # instName # "($operand);"> { let arguments = (ins type:$operand, DefaultValuedAttr:$fastmathFlags); let results = (outs type:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "$operand custom(attr-dict) `:` type($res)"; + string llvmInstName = instName; + string mlirBuilder = [{ + $res = $_builder.create<$_qualCppClassName>($_location, $operand); + }]; } // Integer binary operations. -def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "CreateAdd", [Commutative]>; -def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "CreateSub">; -def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "CreateMul", [Commutative]>; -def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "CreateUDiv">; -def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "CreateSDiv">; -def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "CreateURem">; -def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "CreateSRem">; -def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "CreateAnd">; -def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "CreateOr">; -def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "CreateXor">; -def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "CreateShl">; -def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "CreateLShr">; -def LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "CreateAShr">; +def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "Add", [Commutative]>; +def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "Sub">; +def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "Mul", [Commutative]>; +def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">; +def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">; +def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">; +def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">; +def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">; +def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or">; +def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">; +def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl">; +def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">; +def 
LLVM_AShrOp : LLVM_IntArithmeticOp<"ashr", "AShr">; // Predicate for integer comparisons. def ICmpPredicateEQ : I64EnumAttrCase<"eq", 0>; @@ -335,13 +343,13 @@ } // Floating point binary operations. -def LLVM_FAddOp : LLVM_FloatArithmeticOp<"fadd", "CreateFAdd">; -def LLVM_FSubOp : LLVM_FloatArithmeticOp<"fsub", "CreateFSub">; -def LLVM_FMulOp : LLVM_FloatArithmeticOp<"fmul", "CreateFMul">; -def LLVM_FDivOp : LLVM_FloatArithmeticOp<"fdiv", "CreateFDiv">; -def LLVM_FRemOp : LLVM_FloatArithmeticOp<"frem", "CreateFRem">; +def LLVM_FAddOp : LLVM_FloatArithmeticOp<"fadd", "FAdd">; +def LLVM_FSubOp : LLVM_FloatArithmeticOp<"fsub", "FSub">; +def LLVM_FMulOp : LLVM_FloatArithmeticOp<"fmul", "FMul">; +def LLVM_FDivOp : LLVM_FloatArithmeticOp<"fdiv", "FDiv">; +def LLVM_FRemOp : LLVM_FloatArithmeticOp<"frem", "FRem">; def LLVM_FNegOp : LLVM_UnaryFloatArithmeticOp< - LLVM_ScalarOrVectorOf, "fneg", "CreateFNeg">; + LLVM_ScalarOrVectorOf, "fneg", "FNeg">; // Common code definition that is used to verify and set the alignment attribute // of LLVM ops that accept such an attribute. @@ -383,6 +391,7 @@ OptionalAttr:$elem_type); let results = (outs Res]>:$res); + string llvmInstName = "Alloca"; string llvmBuilder = [{ auto addrSpace = $_resultType->getPointerAddressSpace(); llvm::Type *elementType = moduleTranslation.convertType( @@ -392,6 +401,14 @@ }] # setAlignmentCode # [{ $res = inst; }]; + // FIXME: Import attributes. 
+ string mlirBuilder = [{ + auto *allocaInst = cast(inst); + Type allocatedType = convertType(allocaInst->getAllocatedType()); + unsigned alignment = allocaInst->getAlign().value(); + $res = $_builder.create( + $_location, $_resultType, allocatedType, $arraySize, alignment); + }]; let builders = [ OpBuilder<(ins "Type":$resultType, "Value":$arraySize, "unsigned":$alignment), @@ -505,6 +522,7 @@ OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal); let results = (outs LLVM_LoadableType:$res); + string llvmInstName = "Load"; string llvmBuilder = [{ auto *inst = builder.CreateLoad($_resultType, $addr, $volatile_); }] # setAlignmentCode @@ -514,6 +532,10 @@ # [{ $res = inst; }]; + // FIXME: Import attributes. + string mlirBuilder = [{ + $res = $_builder.create($_location, $_resultType, $addr); + }]; let builders = [ OpBuilder<(ins "Value":$addr, CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile, CArg<"bool", "false">:$isNonTemporal), @@ -538,12 +560,17 @@ OptionalAttr:$noalias_scopes, OptionalAttr:$alignment, UnitAttr:$volatile_, UnitAttr:$nontemporal); + string llvmInstName = "Store"; string llvmBuilder = [{ auto *inst = builder.CreateStore($value, $addr, $volatile_); }] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode # setAliasScopeMetadataCode; + // FIXME: Import attributes. + string mlirBuilder = [{ + $_builder.create($_location, $value, $addr); + }]; let builders = [ OpBuilder<(ins "Value":$value, "Value":$addr, CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile, @@ -554,55 +581,60 @@ } // Casts. -class LLVM_CastOp traits = []> : LLVM_Op, - LLVM_Builder<"$res = builder." 
# builderFunc # "($arg, $_resultType);"> { + LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType);"> { let arguments = (ins type:$arg); let results = (outs resultType:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "$arg attr-dict `:` type($arg) `to` type($res)"; + string llvmInstName = instName; + string mlirBuilder = [{ + $res = $_builder.create<$_qualCppClassName>( + $_location, $_resultType, $arg); + }]; } -def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "CreateBitCast", +def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "BitCast", LLVM_AnyNonAggregate, LLVM_AnyNonAggregate> { let hasFolder = 1; } -def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast", +def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "AddrSpaceCast", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf> { let hasFolder = 1; } -def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr", +def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "IntToPtr", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; -def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt", +def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "PtrToInt", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; -def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt", +def LLVM_SExtOp : LLVM_CastOp<"sext", "SExt", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; -def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt", +def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; -def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc", +def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; -def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP", +def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "SIToFP", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; -def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP", +def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "UIToFP", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; -def LLVM_FPToSIOp : 
LLVM_CastOp<"fptosi", "CreateFPToSI", +def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "FPToSI", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; -def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI", +def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "FPToUI", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; -def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt", +def LLVM_FPExtOp : LLVM_CastOp<"fpext", "FPExt", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; -def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc", +def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc", LLVM_ScalarOrVectorOf, LLVM_ScalarOrVectorOf>; @@ -718,9 +750,14 @@ $vector `[` $position `:` type($position) `]` attr-dict `:` type($vector) }]; + string llvmInstName = "ExtractElement"; string llvmBuilder = [{ $res = builder.CreateExtractElement($vector, $position); }]; + string mlirBuilder = [{ + $res = $_builder.create( + $_location, $vector, $position); + }]; } //===----------------------------------------------------------------------===// @@ -772,9 +809,14 @@ type($vector) }]; + string llvmInstName = "InsertElement"; string llvmBuilder = [{ $res = builder.CreateInsertElement($vector, $value, $position); }]; + string mlirBuilder = [{ + $res = $_builder.create( + $_location, $vector, $value, $position); + }]; } //===----------------------------------------------------------------------===// @@ -843,13 +885,22 @@ LLVM_Type:$trueValue, LLVM_Type:$falseValue); let results = (outs LLVM_Type:$res); let assemblyFormat = "operands attr-dict `:` type($condition) `,` type($res)"; + string llvmInstName = "Select"; + string mlirBuilder = [{ + $res = $_builder.create( + $_location, $_resultType, $condition, $trueValue, $falseValue); + }]; } def LLVM_FreezeOp : LLVM_Op<"freeze", [SameOperandsAndResultType]> { let arguments = (ins LLVM_Type:$val); let results = (outs LLVM_Type:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "$val attr-dict `:` type($val)"; + string llvmInstName = "Freeze"; 
string llvmBuilder = "builder.CreateFreeze($val);"; + string mlirBuilder = [{ + $res = $_builder.create($_location, $val); + }]; } // Terminators. @@ -916,23 +967,35 @@ let hasVerifier = 1; + string llvmInstName = "Ret"; string llvmBuilder = [{ if ($_numOperands != 0) builder.CreateRet($arg[0]); else builder.CreateRetVoid(); }]; + string mlirBuilder = [{ + $_builder.create($_location, processValues(llvmOperands)); + }]; } def LLVM_ResumeOp : LLVM_TerminatorOp<"resume"> { let arguments = (ins LLVM_Type:$value); - string llvmBuilder = [{ builder.CreateResume($value); }]; let assemblyFormat = "$value attr-dict `:` type($value)"; let hasVerifier = 1; + string llvmInstName = "Resume"; + string llvmBuilder = [{ builder.CreateResume($value); }]; + string mlirBuilder = [{ + $_builder.create($_location, $value); + }]; } def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> { - string llvmBuilder = [{ builder.CreateUnreachable(); }]; let assemblyFormat = "attr-dict"; + string llvmInstName = "Unreachable"; + string llvmBuilder = [{ builder.CreateUnreachable(); }]; + string mlirBuilder = [{ + $_builder.create($_location); + }]; } def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", @@ -1714,11 +1777,19 @@ def LLVM_FenceOp : LLVM_Op<"fence"> { let arguments = (ins AtomicOrdering:$ordering, StrAttr:$syncscope); let builders = [LLVM_VoidResultTypeOpBuilder, LLVM_ZeroResultOpBuilder]; + string llvmInstName = "Fence"; let llvmBuilder = [{ llvm::LLVMContext &llvmContext = builder.getContext(); builder.CreateFence(getLLVMAtomicOrdering($ordering), llvmContext.getOrInsertSyncScopeID($syncscope)); }]; + string mlirBuilder = [{ + llvm::FenceInst *fenceInst = cast(inst); + $_builder.create( + $_location, + getLLVMAtomicOrdering(fenceInst->getOrdering()), + getLLVMSyncScope(fenceInst)); + }]; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- 
a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -103,6 +103,139 @@ } } +static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) { + switch (p) { + default: + llvm_unreachable("incorrect comparison predicate"); + case llvm::CmpInst::Predicate::ICMP_EQ: + return LLVM::ICmpPredicate::eq; + case llvm::CmpInst::Predicate::ICMP_NE: + return LLVM::ICmpPredicate::ne; + case llvm::CmpInst::Predicate::ICMP_SLT: + return LLVM::ICmpPredicate::slt; + case llvm::CmpInst::Predicate::ICMP_SLE: + return LLVM::ICmpPredicate::sle; + case llvm::CmpInst::Predicate::ICMP_SGT: + return LLVM::ICmpPredicate::sgt; + case llvm::CmpInst::Predicate::ICMP_SGE: + return LLVM::ICmpPredicate::sge; + case llvm::CmpInst::Predicate::ICMP_ULT: + return LLVM::ICmpPredicate::ult; + case llvm::CmpInst::Predicate::ICMP_ULE: + return LLVM::ICmpPredicate::ule; + case llvm::CmpInst::Predicate::ICMP_UGT: + return LLVM::ICmpPredicate::ugt; + case llvm::CmpInst::Predicate::ICMP_UGE: + return LLVM::ICmpPredicate::uge; + } + llvm_unreachable("incorrect integer comparison predicate"); +} + +static FCmpPredicate getFCmpPredicate(llvm::CmpInst::Predicate p) { + switch (p) { + default: + llvm_unreachable("incorrect comparison predicate"); + case llvm::CmpInst::Predicate::FCMP_FALSE: + return LLVM::FCmpPredicate::_false; + case llvm::CmpInst::Predicate::FCMP_TRUE: + return LLVM::FCmpPredicate::_true; + case llvm::CmpInst::Predicate::FCMP_OEQ: + return LLVM::FCmpPredicate::oeq; + case llvm::CmpInst::Predicate::FCMP_ONE: + return LLVM::FCmpPredicate::one; + case llvm::CmpInst::Predicate::FCMP_OLT: + return LLVM::FCmpPredicate::olt; + case llvm::CmpInst::Predicate::FCMP_OLE: + return LLVM::FCmpPredicate::ole; + case llvm::CmpInst::Predicate::FCMP_OGT: + return LLVM::FCmpPredicate::ogt; + case llvm::CmpInst::Predicate::FCMP_OGE: + return LLVM::FCmpPredicate::oge; + case llvm::CmpInst::Predicate::FCMP_ORD: + return LLVM::FCmpPredicate::ord; + case 
llvm::CmpInst::Predicate::FCMP_ULT: + return LLVM::FCmpPredicate::ult; + case llvm::CmpInst::Predicate::FCMP_ULE: + return LLVM::FCmpPredicate::ule; + case llvm::CmpInst::Predicate::FCMP_UGT: + return LLVM::FCmpPredicate::ugt; + case llvm::CmpInst::Predicate::FCMP_UGE: + return LLVM::FCmpPredicate::uge; + case llvm::CmpInst::Predicate::FCMP_UNO: + return LLVM::FCmpPredicate::uno; + case llvm::CmpInst::Predicate::FCMP_UEQ: + return LLVM::FCmpPredicate::ueq; + case llvm::CmpInst::Predicate::FCMP_UNE: + return LLVM::FCmpPredicate::une; + } + llvm_unreachable("incorrect floating point comparison predicate"); +} + +static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) { + switch (ordering) { + case llvm::AtomicOrdering::NotAtomic: + return LLVM::AtomicOrdering::not_atomic; + case llvm::AtomicOrdering::Unordered: + return LLVM::AtomicOrdering::unordered; + case llvm::AtomicOrdering::Monotonic: + return LLVM::AtomicOrdering::monotonic; + case llvm::AtomicOrdering::Acquire: + return LLVM::AtomicOrdering::acquire; + case llvm::AtomicOrdering::Release: + return LLVM::AtomicOrdering::release; + case llvm::AtomicOrdering::AcquireRelease: + return LLVM::AtomicOrdering::acq_rel; + case llvm::AtomicOrdering::SequentiallyConsistent: + return LLVM::AtomicOrdering::seq_cst; + } + llvm_unreachable("incorrect atomic ordering"); +} + +static AtomicBinOp getLLVMAtomicBinOp(llvm::AtomicRMWInst::BinOp binOp) { + switch (binOp) { + case llvm::AtomicRMWInst::Xchg: + return LLVM::AtomicBinOp::xchg; + case llvm::AtomicRMWInst::Add: + return LLVM::AtomicBinOp::add; + case llvm::AtomicRMWInst::Sub: + return LLVM::AtomicBinOp::sub; + case llvm::AtomicRMWInst::And: + return LLVM::AtomicBinOp::_and; + case llvm::AtomicRMWInst::Nand: + return LLVM::AtomicBinOp::nand; + case llvm::AtomicRMWInst::Or: + return LLVM::AtomicBinOp::_or; + case llvm::AtomicRMWInst::Xor: + return LLVM::AtomicBinOp::_xor; + case llvm::AtomicRMWInst::Max: + return LLVM::AtomicBinOp::max; + case 
llvm::AtomicRMWInst::Min: + return LLVM::AtomicBinOp::min; + case llvm::AtomicRMWInst::UMax: + return LLVM::AtomicBinOp::umax; + case llvm::AtomicRMWInst::UMin: + return LLVM::AtomicBinOp::umin; + case llvm::AtomicRMWInst::FAdd: + return LLVM::AtomicBinOp::fadd; + case llvm::AtomicRMWInst::FSub: + return LLVM::AtomicBinOp::fsub; + default: + llvm_unreachable("unsupported atomic binary operation"); + } +} + +/// Converts the sync scope identifier of `fenceInst` to the string +/// representation necessary to build the LLVM dialect fence operation. +static StringRef getLLVMSyncScope(llvm::FenceInst *fenceInst) { + llvm::LLVMContext &llvmContext = fenceInst->getContext(); + SmallVector syncScopeNames; + llvmContext.getSyncScopeNames(syncScopeNames); + for (StringRef name : syncScopeNames) + if (fenceInst->getSyncScopeID() == llvmContext.getOrInsertSyncScopeID(name)) + return name; + llvm_unreachable("incorrect sync scope identifier"); +} + DataLayoutSpecInterface mlir::translateDataLayout(const llvm::DataLayout &dataLayout, MLIRContext *context) { @@ -210,6 +343,11 @@ /// counterpart exists. Otherwise, returns failure. LogicalResult convertIntrinsic(OpBuilder &odsBuilder, llvm::CallInst *inst); + /// Converts an LLVM instruction to an MLIR LLVM dialect operation if the + /// operation defines an MLIR Builder. Otherwise, returns failure. + LogicalResult convertOperation(OpBuilder &odsBuilder, + llvm::Instruction *inst); + /// Imports `f` into the current module. LogicalResult processFunction(llvm::Function *f); @@ -313,6 +451,15 @@ return failure(); } +LogicalResult Importer::convertOperation(OpBuilder &odsBuilder, + llvm::Instruction *inst) { + // Copy the instruction operands used for the conversion. + SmallVector llvmOperands(inst->operands()); +#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc" + + return failure(); +} + // We only need integers, floats, doubles, and vectors and tensors thereof for // attributes. 
Scalar and vector types are converted to the standard // equivalents. Array types are converted to ranked tensors; nested array types @@ -626,208 +773,6 @@ return integerAttr; } -/// Return the MLIR OperationName for the given LLVM opcode. -static StringRef lookupOperationNameFromOpcode(unsigned opcode) { -// Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered -// as in llvm/IR/Instructions.def to aid comprehension and spot missing -// instructions. -#define INST(llvm_n, mlir_n) \ - { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() } - static const DenseMap opcMap = { - // clang-format off - INST(Ret, Return), - // Br is handled specially. - // Switch is handled specially. - // FIXME: indirectbr - // Invoke is handled specially. - INST(Resume, Resume), - INST(Unreachable, Unreachable), - // FIXME: cleanupret - // FIXME: catchret - // FIXME: catchswitch - // FIXME: callbr - INST(FNeg, FNeg), - INST(Add, Add), - INST(FAdd, FAdd), - INST(Sub, Sub), - INST(FSub, FSub), - INST(Mul, Mul), - INST(FMul, FMul), - INST(UDiv, UDiv), - INST(SDiv, SDiv), - INST(FDiv, FDiv), - INST(URem, URem), - INST(SRem, SRem), - INST(FRem, FRem), - INST(Shl, Shl), - INST(LShr, LShr), - INST(AShr, AShr), - INST(And, And), - INST(Or, Or), - INST(Xor, XOr), - INST(ExtractElement, ExtractElement), - INST(InsertElement, InsertElement), - // ShuffleVector is handled specially. - // ExtractValue is handled specially. - // InsertValue is handled specially. - INST(Alloca, Alloca), - INST(Load, Load), - INST(Store, Store), - INST(Fence, Fence), - // AtomicCmpXchg is handled specially. - // AtomicRMW is handled specially. - // Getelementptr is handled specially. 
- INST(Trunc, Trunc), - INST(ZExt, ZExt), - INST(SExt, SExt), - INST(FPToUI, FPToUI), - INST(FPToSI, FPToSI), - INST(UIToFP, UIToFP), - INST(SIToFP, SIToFP), - INST(FPTrunc, FPTrunc), - INST(FPExt, FPExt), - INST(PtrToInt, PtrToInt), - INST(IntToPtr, IntToPtr), - INST(BitCast, Bitcast), - INST(AddrSpaceCast, AddrSpaceCast), - // ICmp is handled specially. - // FCmp is handled specially. - // PHI is handled specially. - INST(Select, Select), - INST(Freeze, Freeze), - INST(Call, Call), - // FIXME: vaarg - // FIXME: landingpad - // FIXME: catchpad - // FIXME: cleanuppad - // clang-format on - }; -#undef INST - - return opcMap.lookup(opcode); -} - -static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) { - switch (p) { - default: - llvm_unreachable("incorrect comparison predicate"); - case llvm::CmpInst::Predicate::ICMP_EQ: - return LLVM::ICmpPredicate::eq; - case llvm::CmpInst::Predicate::ICMP_NE: - return LLVM::ICmpPredicate::ne; - case llvm::CmpInst::Predicate::ICMP_SLT: - return LLVM::ICmpPredicate::slt; - case llvm::CmpInst::Predicate::ICMP_SLE: - return LLVM::ICmpPredicate::sle; - case llvm::CmpInst::Predicate::ICMP_SGT: - return LLVM::ICmpPredicate::sgt; - case llvm::CmpInst::Predicate::ICMP_SGE: - return LLVM::ICmpPredicate::sge; - case llvm::CmpInst::Predicate::ICMP_ULT: - return LLVM::ICmpPredicate::ult; - case llvm::CmpInst::Predicate::ICMP_ULE: - return LLVM::ICmpPredicate::ule; - case llvm::CmpInst::Predicate::ICMP_UGT: - return LLVM::ICmpPredicate::ugt; - case llvm::CmpInst::Predicate::ICMP_UGE: - return LLVM::ICmpPredicate::uge; - } - llvm_unreachable("incorrect integer comparison predicate"); -} - -static FCmpPredicate getFCmpPredicate(llvm::CmpInst::Predicate p) { - switch (p) { - default: - llvm_unreachable("incorrect comparison predicate"); - case llvm::CmpInst::Predicate::FCMP_FALSE: - return LLVM::FCmpPredicate::_false; - case llvm::CmpInst::Predicate::FCMP_TRUE: - return LLVM::FCmpPredicate::_true; - case 
llvm::CmpInst::Predicate::FCMP_OEQ: - return LLVM::FCmpPredicate::oeq; - case llvm::CmpInst::Predicate::FCMP_ONE: - return LLVM::FCmpPredicate::one; - case llvm::CmpInst::Predicate::FCMP_OLT: - return LLVM::FCmpPredicate::olt; - case llvm::CmpInst::Predicate::FCMP_OLE: - return LLVM::FCmpPredicate::ole; - case llvm::CmpInst::Predicate::FCMP_OGT: - return LLVM::FCmpPredicate::ogt; - case llvm::CmpInst::Predicate::FCMP_OGE: - return LLVM::FCmpPredicate::oge; - case llvm::CmpInst::Predicate::FCMP_ORD: - return LLVM::FCmpPredicate::ord; - case llvm::CmpInst::Predicate::FCMP_ULT: - return LLVM::FCmpPredicate::ult; - case llvm::CmpInst::Predicate::FCMP_ULE: - return LLVM::FCmpPredicate::ule; - case llvm::CmpInst::Predicate::FCMP_UGT: - return LLVM::FCmpPredicate::ugt; - case llvm::CmpInst::Predicate::FCMP_UGE: - return LLVM::FCmpPredicate::uge; - case llvm::CmpInst::Predicate::FCMP_UNO: - return LLVM::FCmpPredicate::uno; - case llvm::CmpInst::Predicate::FCMP_UEQ: - return LLVM::FCmpPredicate::ueq; - case llvm::CmpInst::Predicate::FCMP_UNE: - return LLVM::FCmpPredicate::une; - } - llvm_unreachable("incorrect floating point comparison predicate"); -} - -static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) { - switch (ordering) { - case llvm::AtomicOrdering::NotAtomic: - return LLVM::AtomicOrdering::not_atomic; - case llvm::AtomicOrdering::Unordered: - return LLVM::AtomicOrdering::unordered; - case llvm::AtomicOrdering::Monotonic: - return LLVM::AtomicOrdering::monotonic; - case llvm::AtomicOrdering::Acquire: - return LLVM::AtomicOrdering::acquire; - case llvm::AtomicOrdering::Release: - return LLVM::AtomicOrdering::release; - case llvm::AtomicOrdering::AcquireRelease: - return LLVM::AtomicOrdering::acq_rel; - case llvm::AtomicOrdering::SequentiallyConsistent: - return LLVM::AtomicOrdering::seq_cst; - } - llvm_unreachable("incorrect atomic ordering"); -} - -static AtomicBinOp getLLVMAtomicBinOp(llvm::AtomicRMWInst::BinOp binOp) { - switch (binOp) { - 
case llvm::AtomicRMWInst::Xchg: - return LLVM::AtomicBinOp::xchg; - case llvm::AtomicRMWInst::Add: - return LLVM::AtomicBinOp::add; - case llvm::AtomicRMWInst::Sub: - return LLVM::AtomicBinOp::sub; - case llvm::AtomicRMWInst::And: - return LLVM::AtomicBinOp::_and; - case llvm::AtomicRMWInst::Nand: - return LLVM::AtomicBinOp::nand; - case llvm::AtomicRMWInst::Or: - return LLVM::AtomicBinOp::_or; - case llvm::AtomicRMWInst::Xor: - return LLVM::AtomicBinOp::_xor; - case llvm::AtomicRMWInst::Max: - return LLVM::AtomicBinOp::max; - case llvm::AtomicRMWInst::Min: - return LLVM::AtomicBinOp::min; - case llvm::AtomicRMWInst::UMax: - return LLVM::AtomicBinOp::umax; - case llvm::AtomicRMWInst::UMin: - return LLVM::AtomicBinOp::umin; - case llvm::AtomicRMWInst::FAdd: - return LLVM::AtomicBinOp::fadd; - case llvm::AtomicRMWInst::FSub: - return LLVM::AtomicBinOp::fsub; - default: - llvm_unreachable("unsupported atomic binary operation"); - } -} - // `br` branches to `target`. Return the branch arguments to `br`, in the // same order of the PHIs in `target`. LogicalResult @@ -842,82 +787,24 @@ } LogicalResult Importer::processInstruction(llvm::Instruction *inst) { - // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math - // flags and call / operand attributes are not supported. + // FIXME: Support uses of SubtargetData. + // FIXME: Add support for inbounds GEPs. + // FIXME: Add support for fast-math flags and call / operand attributes. + // FIXME: Add support for the indirectbr, cleanupret, catchret, catchswitch, + // callbr, vaarg, landingpad, catchpad, cleanuppad instructions. // Convert all intrinsics that provide an MLIR builder. 
- if (auto callInst = dyn_cast(inst)) + if (auto *callInst = dyn_cast(inst)) if (succeeded(convertIntrinsic(b, callInst))) return success(); - Location loc = translateLoc(inst->getDebugLoc()); - switch (inst->getOpcode()) { - default: - return emitError(loc) << "unknown instruction: " << diag(*inst); - case llvm::Instruction::Add: - case llvm::Instruction::FAdd: - case llvm::Instruction::Sub: - case llvm::Instruction::FSub: - case llvm::Instruction::Mul: - case llvm::Instruction::FMul: - case llvm::Instruction::UDiv: - case llvm::Instruction::SDiv: - case llvm::Instruction::FDiv: - case llvm::Instruction::URem: - case llvm::Instruction::SRem: - case llvm::Instruction::FRem: - case llvm::Instruction::Shl: - case llvm::Instruction::LShr: - case llvm::Instruction::AShr: - case llvm::Instruction::And: - case llvm::Instruction::Or: - case llvm::Instruction::Xor: - case llvm::Instruction::Load: - case llvm::Instruction::Store: - case llvm::Instruction::Ret: - case llvm::Instruction::Resume: - case llvm::Instruction::Trunc: - case llvm::Instruction::ZExt: - case llvm::Instruction::SExt: - case llvm::Instruction::FPToUI: - case llvm::Instruction::FPToSI: - case llvm::Instruction::UIToFP: - case llvm::Instruction::SIToFP: - case llvm::Instruction::FPTrunc: - case llvm::Instruction::FPExt: - case llvm::Instruction::PtrToInt: - case llvm::Instruction::IntToPtr: - case llvm::Instruction::AddrSpaceCast: - case llvm::Instruction::Freeze: - case llvm::Instruction::BitCast: - case llvm::Instruction::ExtractElement: - case llvm::Instruction::InsertElement: - case llvm::Instruction::Select: - case llvm::Instruction::FNeg: - case llvm::Instruction::Unreachable: { - OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode())); - SmallVector operands(inst->operand_values()); - SmallVector ops = processValues(operands); - if (!inst->getType()->isVoidTy()) { - Type type = convertType(inst->getType()); - state.addTypes(type); - } - state.addOperands(ops); - Operation *op = 
b.create(state); - if (!inst->getType()->isVoidTy()) - mapValue(inst, op->getResult(0)); + // Convert all operations that provide an MLIR builder. + if (succeeded(convertOperation(b, inst))) return success(); - } - case llvm::Instruction::Alloca: { - Value size = processValue(inst->getOperand(0)); - auto *allocaInst = cast(inst); - Value res = b.create(loc, convertType(inst->getType()), - convertType(allocaInst->getAllocatedType()), - size, allocaInst->getAlign().value()); - mapValue(inst, res); - return success(); - } - case llvm::Instruction::ICmp: { + + // Convert all special instructions that do not provide an MLIR builder. + Location loc = translateLoc(inst->getDebugLoc()); + if (inst->getOpcode() == llvm::Instruction::ICmp) { Value lhs = processValue(inst->getOperand(0)); Value rhs = processValue(inst->getOperand(1)); Value res = b.create( @@ -926,7 +813,7 @@ mapValue(inst, res); return success(); } - case llvm::Instruction::FCmp: { + if (inst->getOpcode() == llvm::Instruction::FCmp) { Value lhs = processValue(inst->getOperand(0)); Value rhs = processValue(inst->getOperand(1)); @@ -947,7 +834,7 @@ mapValue(inst, res); return success(); } - case llvm::Instruction::Br: { + if (inst->getOpcode() == llvm::Instruction::Br) { auto *brInst = cast(inst); OperationState state(loc, brInst->isConditional() ? "llvm.cond_br" : "llvm.br"); @@ -975,7 +862,7 @@ b.create(state); return success(); } - case llvm::Instruction::Switch: { + if (inst->getOpcode() == llvm::Instruction::Switch) { auto *swInst = cast(inst); // Process the condition value. 
Value condition = processValue(swInst->getCondition()); @@ -1006,13 +893,13 @@ caseValues, caseBlocks, caseOperandRefs); return success(); } - case llvm::Instruction::PHI: { + if (inst->getOpcode() == llvm::Instruction::PHI) { Type type = convertType(inst->getType()); mapValue(inst, b.getInsertionBlock()->addArgument( type, translateLoc(inst->getDebugLoc()))); return success(); } - case llvm::Instruction::Call: { + if (inst->getOpcode() == llvm::Instruction::Call) { llvm::CallInst *ci = cast(inst); SmallVector args(ci->args()); SmallVector ops = processValues(args); @@ -1034,7 +921,7 @@ mapValue(inst, op->getResult(0)); return success(); } - case llvm::Instruction::LandingPad: { + if (inst->getOpcode() == llvm::Instruction::LandingPad) { llvm::LandingPadInst *lpi = cast(inst); SmallVector ops; @@ -1046,7 +933,7 @@ mapValue(inst, res); return success(); } - case llvm::Instruction::Invoke: { + if (inst->getOpcode() == llvm::Instruction::Invoke) { llvm::InvokeInst *ii = cast(inst); SmallVector tys; @@ -1077,24 +964,7 @@ mapValue(inst, op->getResult(0)); return success(); } - case llvm::Instruction::Fence: { - StringRef syncscope; - SmallVector ssNs; - llvm::LLVMContext &llvmContext = inst->getContext(); - llvm::FenceInst *fence = cast(inst); - llvmContext.getSyncScopeNames(ssNs); - int fenceSyncScopeID = fence->getSyncScopeID(); - for (unsigned i = 0, e = ssNs.size(); i != e; i++) { - if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) { - syncscope = ssNs[i]; - break; - } - } - b.create(loc, getLLVMAtomicOrdering(fence->getOrdering()), - syncscope); - return success(); - } - case llvm::Instruction::AtomicRMW: { + if (inst->getOpcode() == llvm::Instruction::AtomicRMW) { auto *atomicInst = cast(inst); Value ptr = processValue(atomicInst->getPointerOperand()); Value val = processValue(atomicInst->getValOperand()); @@ -1108,7 +978,7 @@ mapValue(inst, res); return success(); } - case llvm::Instruction::AtomicCmpXchg: { + if (inst->getOpcode() == 
llvm::Instruction::AtomicCmpXchg) { auto *cmpXchgInst = cast(inst); Value ptr = processValue(cmpXchgInst->getPointerOperand()); Value cmpVal = processValue(cmpXchgInst->getCompareOperand()); @@ -1125,7 +995,7 @@ mapValue(inst, res); return success(); } - case llvm::Instruction::GetElementPtr: { + if (inst->getOpcode() == llvm::Instruction::GetElementPtr) { // FIXME: Support inbounds GEPs. llvm::GetElementPtrInst *gep = cast(inst); Value basePtr = processValue(gep->getOperand(0)); @@ -1146,7 +1016,7 @@ mapValue(inst, res); return success(); } - case llvm::Instruction::InsertValue: { + if (inst->getOpcode() == llvm::Instruction::InsertValue) { auto *ivInst = cast(inst); Value inserted = processValue(ivInst->getInsertedValueOperand()); Value aggOperand = processValue(ivInst->getAggregateOperand()); @@ -1157,7 +1027,7 @@ mapValue(inst, res); return success(); } - case llvm::Instruction::ExtractValue: { + if (inst->getOpcode() == llvm::Instruction::ExtractValue) { auto *evInst = cast(inst); Value aggOperand = processValue(evInst->getAggregateOperand()); @@ -1167,7 +1037,7 @@ mapValue(inst, res); return success(); } - case llvm::Instruction::ShuffleVector: { + if (inst->getOpcode() == llvm::Instruction::ShuffleVector) { auto *svInst = cast(inst); Value vec1 = processValue(svInst->getOperand(0)); Value vec2 = processValue(svInst->getOperand(1)); @@ -1177,7 +1047,8 @@ mapValue(inst, res); return success(); } - } + + return emitError(loc) << "unknown instruction: " << diag(*inst); } FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) { diff --git a/mlir/test/Target/LLVMIR/Import/basic.ll b/mlir/test/Target/LLVMIR/Import/basic.ll --- a/mlir/test/Target/LLVMIR/Import/basic.ll +++ b/mlir/test/Target/LLVMIR/Import/basic.ll @@ -227,13 +227,6 @@ ret i32* bitcast (double* @g2 to i32*) } -; CHECK-LABEL: llvm.func @f4() -> !llvm.ptr -define i32* @f4() { -; CHECK: %[[b:[0-9]+]] = llvm.mlir.null : !llvm.ptr -; CHECK: llvm.return %[[b]] : !llvm.ptr - ret i32* bitcast 
(double* null to i32*) -} - ; CHECK-LABEL: llvm.func @f5 define void @f5(i32 %d) { ; FIXME: icmp should return i1. @@ -264,39 +257,6 @@ ret void } -; CHECK-LABEL: llvm.func @FPArithmetic(%arg0: f32, %arg1: f32, %arg2: f64, %arg3: f64) -define void @FPArithmetic(float %a, float %b, double %c, double %d) { - ; CHECK: %[[a1:[0-9]+]] = llvm.mlir.constant(3.030000e+01 : f64) : f64 - ; CHECK: %[[a2:[0-9]+]] = llvm.mlir.constant(3.030000e+01 : f32) : f32 - ; CHECK: %[[a3:[0-9]+]] = llvm.fadd %[[a2]], %arg0 : f32 - %1 = fadd float 0x403E4CCCC0000000, %a - ; CHECK: %[[a4:[0-9]+]] = llvm.fadd %arg0, %arg1 : f32 - %2 = fadd float %a, %b - ; CHECK: %[[a5:[0-9]+]] = llvm.fadd %[[a1]], %arg2 : f64 - %3 = fadd double 3.030000e+01, %c - ; CHECK: %[[a6:[0-9]+]] = llvm.fsub %arg0, %arg1 : f32 - %4 = fsub float %a, %b - ; CHECK: %[[a7:[0-9]+]] = llvm.fsub %arg2, %arg3 : f64 - %5 = fsub double %c, %d - ; CHECK: %[[a8:[0-9]+]] = llvm.fmul %arg0, %arg1 : f32 - %6 = fmul float %a, %b - ; CHECK: %[[a9:[0-9]+]] = llvm.fmul %arg2, %arg3 : f64 - %7 = fmul double %c, %d - ; CHECK: %[[a10:[0-9]+]] = llvm.fdiv %arg0, %arg1 : f32 - %8 = fdiv float %a, %b - ; CHECK: %[[a12:[0-9]+]] = llvm.fdiv %arg2, %arg3 : f64 - %9 = fdiv double %c, %d - ; CHECK: %[[a11:[0-9]+]] = llvm.frem %arg0, %arg1 : f32 - %10 = frem float %a, %b - ; CHECK: %[[a13:[0-9]+]] = llvm.frem %arg2, %arg3 : f64 - %11 = frem double %c, %d - ; CHECK: %{{.+}} = llvm.fneg %{{.+}} : f32 - %12 = fneg float %a - ; CHECK: %{{.+}} = llvm.fneg %{{.+}} : f64 - %13 = fneg double %c - ret void -} - ; CHECK-LABEL: llvm.func @FPComparison(%arg0: f32, %arg1: f32) define void @FPComparison(float %a, float %b) { ; CHECK: llvm.fcmp "_false" %arg0, %arg1 @@ -441,17 +401,6 @@ ret i32 0 } -;CHECK-LABEL: @useFenceInst -define i32 @useFenceInst() { - ;CHECK: llvm.fence syncscope("agent") seq_cst - fence syncscope("agent") seq_cst - ;CHECK: llvm.fence release - fence release - ;CHECK: llvm.fence seq_cst - fence syncscope("") seq_cst - ret i32 0 -} - ; 
Switch instruction declare void @g(i32) @@ -600,43 +549,6 @@ ret <4 x half> %shuffle } -; ExtractElement -; CHECK-LABEL: llvm.func @extract_element -define half @extract_element(<4 x half>* %vec, i32 %idx) { - ; CHECK: %[[V0:.+]] = llvm.load %{{.+}} : !llvm.ptr> - %val0 = load <4 x half>, <4 x half>* %vec - ; CHECK: %[[V1:.+]] = llvm.extractelement %[[V0]][%{{.+}} : i32] : vector<4xf16> - %r = extractelement <4 x half> %val0, i32 %idx - ; CHECK: llvm.return %[[V1]] - ret half %r -} - -; InsertElement -; CHECK-LABEL: llvm.func @insert_element -define <4 x half> @insert_element(<4 x half>* %vec, half %v, i32 %idx) { - ; CHECK: %[[V0:.+]] = llvm.load %{{.+}} : !llvm.ptr> - %val0 = load <4 x half>, <4 x half>* %vec - ; CHECK: %[[V1:.+]] = llvm.insertelement %{{.+}}, %[[V0]][%{{.+}} : i32] : vector<4xf16> - %r = insertelement <4 x half> %val0, half %v, i32 %idx - ; CHECK: llvm.return %[[V1]] - ret <4 x half> %r -} - -; Select -; CHECK-LABEL: llvm.func @select_inst -define void @select_inst(i32 %arg0, i32 %arg1, i1 %pred) { - ; CHECK: %{{.+}} = llvm.select %{{.+}}, %{{.+}}, %{{.+}} : i1, i32 - %1 = select i1 %pred, i32 %arg0, i32 %arg1 - ret void -} - -; Unreachable -; CHECK-LABEL: llvm.func @unreachable_inst -define void @unreachable_inst() { - ; CHECK: llvm.unreachable - unreachable -} - ; Varadic function definition %struct.va_list = type { i8* } diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll new file mode 100644 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/instructions.ll @@ -0,0 +1,276 @@ +; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s + +; CHECK-LABEL: @integer_arith +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]] +define void @integer_arith(i32 %arg1, i32 %arg2, i64 %arg3, i64 %arg4) { + ; CHECK-DAG: %[[C1:[0-9]+]] = llvm.mlir.constant(-7 : i32) : i32 + ; CHECK-DAG: 
%[[C2:[0-9]+]] = llvm.mlir.constant(42 : i32) : i32 + ; CHECK: llvm.add %[[ARG1]], %[[C1]] : i32 + ; CHECK: llvm.add %[[C2]], %[[ARG2]] : i32 + ; CHECK: llvm.sub %[[ARG3]], %[[ARG4]] : i64 + ; CHECK: llvm.mul %[[ARG1]], %[[ARG2]] : i32 + ; CHECK: llvm.udiv %[[ARG3]], %[[ARG4]] : i64 + ; CHECK: llvm.sdiv %[[ARG1]], %[[ARG2]] : i32 + ; CHECK: llvm.urem %[[ARG3]], %[[ARG4]] : i64 + ; CHECK: llvm.srem %[[ARG1]], %[[ARG2]] : i32 + ; CHECK: llvm.shl %[[ARG3]], %[[ARG4]] : i64 + ; CHECK: llvm.lshr %[[ARG1]], %[[ARG2]] : i32 + ; CHECK: llvm.ashr %[[ARG3]], %[[ARG4]] : i64 + ; CHECK: llvm.and %[[ARG1]], %[[ARG2]] : i32 + ; CHECK: llvm.or %[[ARG3]], %[[ARG4]] : i64 + ; CHECK: llvm.xor %[[ARG1]], %[[ARG2]] : i32 + %1 = add i32 %arg1, -7 + %2 = add i32 42, %arg2 + %3 = sub i64 %arg3, %arg4 + %4 = mul i32 %arg1, %arg2 + %5 = udiv i64 %arg3, %arg4 + %6 = sdiv i32 %arg1, %arg2 + %7 = urem i64 %arg3, %arg4 + %8 = srem i32 %arg1, %arg2 + %9 = shl i64 %arg3, %arg4 + %10 = lshr i32 %arg1, %arg2 + %11 = ashr i64 %arg3, %arg4 + %12 = and i32 %arg1, %arg2 + %13 = or i64 %arg3, %arg4 + %14 = xor i32 %arg1, %arg2 + ret void +} + +; // ----- + +; CHECK-LABEL: @fp_arith +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]] +define void @fp_arith(float %arg1, float %arg2, double %arg3, double %arg4) { + ; CHECK: %[[C1:[0-9]+]] = llvm.mlir.constant(3.030000e+01 : f64) : f64 + ; CHECK: %[[C2:[0-9]+]] = llvm.mlir.constant(3.030000e+01 : f32) : f32 + ; CHECK: llvm.fadd %[[C2]], %[[ARG1]] : f32 + ; CHECK: llvm.fadd %[[ARG1]], %[[ARG2]] : f32 + ; CHECK: llvm.fadd %[[C1]], %[[ARG3]] : f64 + ; CHECK: llvm.fsub %[[ARG1]], %[[ARG2]] : f32 + ; CHECK: llvm.fmul %[[ARG3]], %[[ARG4]] : f64 + ; CHECK: llvm.fdiv %[[ARG1]], %[[ARG2]] : f32 + ; CHECK: llvm.frem %[[ARG3]], %[[ARG4]] : f64 + ; CHECK: llvm.fneg %[[ARG1]] : f32 + %1 = fadd float 0x403E4CCCC0000000, %arg1 + %2 = fadd float %arg1, %arg2 + %3 = fadd 
double 3.030000e+01, %arg3 + %4 = fsub float %arg1, %arg2 + %5 = fmul double %arg3, %arg4 + %6 = fdiv float %arg1, %arg2 + %7 = frem double %arg3, %arg4 + %8 = fneg float %arg1 + ret void +} + +; // ----- + +; CHECK-LABEL: @fp_casts +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] +define void @fp_casts(float %arg1, double %arg2, i32 %arg3) { + ; CHECK: llvm.fptrunc %[[ARG2]] : f64 to f32 + ; CHECK: llvm.fpext %[[ARG1]] : f32 to f64 + ; CHECK: llvm.fptosi %[[ARG2]] : f64 to i16 + ; CHECK: llvm.fptoui %[[ARG1]] : f32 to i32 + ; CHECK: llvm.sitofp %[[ARG3]] : i32 to f32 + ; CHECK: llvm.uitofp %[[ARG3]] : i32 to f64 + %1 = fptrunc double %arg2 to float + %2 = fpext float %arg1 to double + %3 = fptosi double %arg2 to i16 + %4 = fptoui float %arg1 to i32 + %5 = sitofp i32 %arg3 to float + %6 = uitofp i32 %arg3 to double + ret void +} + +; // ----- + +; CHECK-LABEL: @integer_extension_and_truncation +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +define void @integer_extension_and_truncation(i32 %arg1) { + ; CHECK: llvm.sext %[[ARG1]] : i32 to i64 + ; CHECK: llvm.zext %[[ARG1]] : i32 to i64 + ; CHECK: llvm.trunc %[[ARG1]] : i32 to i16 + %1 = sext i32 %arg1 to i64 + %2 = zext i32 %arg1 to i64 + %3 = trunc i32 %arg1 to i16 + ret void +} + +; // ----- + +; CHECK-LABEL: @pointer_casts +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +define i32* @pointer_casts(double* %arg1, i64 %arg2) { + ; CHECK: %[[NULL:[0-9]+]] = llvm.mlir.null : !llvm.ptr + ; CHECK: llvm.ptrtoint %[[ARG1]] : !llvm.ptr to i64 + ; CHECK: llvm.inttoptr %[[ARG2]] : i64 to !llvm.ptr + ; CHECK: llvm.bitcast %[[ARG1]] : !llvm.ptr to !llvm.ptr + ; CHECK: llvm.return %[[NULL]] : !llvm.ptr + %1 = ptrtoint double* %arg1 to i64 + %2 = inttoptr i64 %arg2 to i64* + %3 = bitcast double* %arg1 to i32* + ret i32* bitcast (double* null to i32*) +} + +; // ----- + +; CHECK-LABEL: @addrspace_casts +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] 
+define ptr addrspace(2) @addrspace_casts(ptr addrspace(1) %arg1) { + ; CHECK: llvm.addrspacecast %[[ARG1]] : !llvm.ptr<1> to !llvm.ptr<2> + ; CHECK: llvm.return {{.*}} : !llvm.ptr<2> + %1 = addrspacecast ptr addrspace(1) %arg1 to ptr addrspace(2) + ret ptr addrspace(2) %1 +} + +; // ----- + +; CHECK-LABEL: @integer_arith +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]] +define void @integer_arith(i32 %arg1, i32 %arg2, i64 %arg3, i64 %arg4) { + ; CHECK-DAG: %[[C1:[0-9]+]] = llvm.mlir.constant(-7 : i32) : i32 + ; CHECK-DAG: %[[C2:[0-9]+]] = llvm.mlir.constant(42 : i32) : i32 + ; CHECK: llvm.add %[[ARG1]], %[[C1]] : i32 + ; CHECK: llvm.add %[[C2]], %[[ARG2]] : i32 + ; CHECK: llvm.sub %[[ARG3]], %[[ARG4]] : i64 + ; CHECK: llvm.mul %[[ARG1]], %[[ARG2]] : i32 + ; CHECK: llvm.udiv %[[ARG3]], %[[ARG4]] : i64 + ; CHECK: llvm.sdiv %[[ARG1]], %[[ARG2]] : i32 + ; CHECK: llvm.urem %[[ARG3]], %[[ARG4]] : i64 + ; CHECK: llvm.srem %[[ARG1]], %[[ARG2]] : i32 + ; CHECK: llvm.shl %[[ARG3]], %[[ARG4]] : i64 + ; CHECK: llvm.lshr %[[ARG1]], %[[ARG2]] : i32 + ; CHECK: llvm.ashr %[[ARG3]], %[[ARG4]] : i64 + ; CHECK: llvm.and %[[ARG1]], %[[ARG2]] : i32 + ; CHECK: llvm.or %[[ARG3]], %[[ARG4]] : i64 + ; CHECK: llvm.xor %[[ARG1]], %[[ARG2]] : i32 + %1 = add i32 %arg1, -7 + %2 = add i32 42, %arg2 + %3 = sub i64 %arg3, %arg4 + %4 = mul i32 %arg1, %arg2 + %5 = udiv i64 %arg3, %arg4 + %6 = sdiv i32 %arg1, %arg2 + %7 = urem i64 %arg3, %arg4 + %8 = srem i32 %arg1, %arg2 + %9 = shl i64 %arg3, %arg4 + %10 = lshr i32 %arg1, %arg2 + %11 = ashr i64 %arg3, %arg4 + %12 = and i32 %arg1, %arg2 + %13 = or i64 %arg3, %arg4 + %14 = xor i32 %arg1, %arg2 + ret void +} + +; // ----- + +; CHECK-LABEL: @extract_element +; CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]] +define half @extract_element(<4 x half>* %vec, i32 %idx) { + ; CHECK: %[[V1:.+]] = llvm.load %[[VEC]] : !llvm.ptr> + ; 
CHECK: %[[V2:.+]] = llvm.extractelement %[[V1]][%[[IDX]] : i32] : vector<4xf16> + ; CHECK: llvm.return %[[V2]] + %1 = load <4 x half>, <4 x half>* %vec + %2 = extractelement <4 x half> %1, i32 %idx + ret half %2 +} + +; // ----- + +; CHECK-LABEL: @insert_element +; CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]] +define <4 x half> @insert_element(<4 x half>* %vec, half %val, i32 %idx) { + ; CHECK: %[[V1:.+]] = llvm.load %[[VEC]] : !llvm.ptr> + ; CHECK: %[[V2:.+]] = llvm.insertelement %[[VAL]], %[[V1]][%[[IDX]] : i32] : vector<4xf16> + ; CHECK: llvm.return %[[V2]] + %1 = load <4 x half>, <4 x half>* %vec + %2 = insertelement <4 x half> %1, half %val, i32 %idx + ret <4 x half> %2 +} + +; // ----- + +; CHECK-LABEL: @select +; CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +; CHECK-SAME: %[[COND:[a-zA-Z0-9]+]] +define void @select(i32 %arg0, i32 %arg1, i1 %cond) { + ; CHECK: llvm.select %[[COND]], %[[ARG1]], %[[ARG2]] : i1, i32 + %1 = select i1 %cond, i32 %arg0, i32 %arg1 + ret void +} + +; // ----- + +; CHECK-LABEL: @alloca +; CHECK-SAME: %[[SIZE:[a-zA-Z0-9]+]] +define double* @alloca(i64 %size) { + ; CHECK: %[[C1:[0-9]+]] = llvm.mlir.constant(1 : i32) : i32 + ; CHECK: llvm.alloca %[[C1]] x f64 {alignment = 8 : i64} : (i32) -> !llvm.ptr + ; CHECK: llvm.alloca %[[SIZE]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr + ; CHECK: llvm.alloca %[[SIZE]] x i32 {alignment = 4 : i64} : (i64) -> !llvm.ptr + %1 = alloca double + %2 = alloca i32, i64 %size, align 8 + %3 = alloca i32, i64 %size, addrspace(3) + ret double* %1 +} + +; // ----- + +; CHECK-LABEL: @load_store +; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]] +define void @load_store(double* %ptr) { + ; CHECK: %[[V1:[0-9]+]] = llvm.load %[[PTR]] : !llvm.ptr + ; CHECK: llvm.store %[[V1]], %[[PTR]] : !llvm.ptr + %1 = load double, double* %ptr + store double %1, double* %ptr + ret void +} + +; // ----- + +; CHECK-LABEL: @freeze +; CHECK-SAME: 
%[[ARG1:[a-zA-Z0-9]+]] +define void @freeze(i32 %arg1) { + ; CHECK: %[[UNDEF:[0-9]+]] = llvm.mlir.undef : i64 + ; CHECK: llvm.freeze %[[ARG1]] : i32 + ; CHECK: llvm.freeze %[[UNDEF]] : i64 + %1 = freeze i32 %arg1 + %2 = freeze i64 undef + ret void +} + +; // ----- + +; CHECK-LABEL: @unreachable +define void @unreachable() { + ; CHECK: llvm.unreachable + unreachable +} + +; // ----- + +; CHECK-LABEL: @fence +define void @fence() { + ; CHECK: llvm.fence syncscope("agent") seq_cst + ; CHECK: llvm.fence release + ; CHECK: llvm.fence seq_cst + fence syncscope("agent") seq_cst + fence release + fence syncscope("") seq_cst + ret void +} diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -185,10 +185,14 @@ return false; } -// Emit an intrinsic identifier driven check and a call to the builder of the -// MLIR LLVM dialect intrinsic operation to build for the given LLVM IR -// intrinsic identifier. -static LogicalResult emitOneIntrBuilder(const Record &record, raw_ostream &os) { +using ConditionFn = mlir::function_ref; + +// Emit a conditional call to the MLIR builder of the LLVM dialect operation to +// build for the given LLVM IR instruction. A condition function `conditionFn` +// emits a check to verify the opcode or intrinsic identifier of the LLVM IR +// instruction matches the LLVM dialect operation to build. +static LogicalResult emitOneMLIRBuilder(const Record &record, raw_ostream &os, + ConditionFn conditionFn) { auto op = tblgen::Operator(record); if (!record.getValue("mlirBuilder")) @@ -240,8 +244,7 @@ } // Output the check and the builder string. 
- os << "if (intrinsicID == llvm::Intrinsic::" - << record.getValueAsString("llvmEnumName") << ") {\n"; + os << "if (" << conditionFn(record) << ") {\n"; os << bs.str() << builderStrRef << "\n"; os << " return success();\n"; os << "}\n"; @@ -249,13 +252,35 @@ return success(); } -// Emit all intrinsic builders. Returns false on success because of the +// Emit all intrinsic MLIR builders. Returns false on success because of the // generator registration requirements. -static bool emitIntrBuilders(const RecordKeeper &recordKeeper, - raw_ostream &os) { +static bool emitIntrMLIRBuilders(const RecordKeeper &recordKeeper, + raw_ostream &os) { + // Emit condition to check if "llvmEnumName" matches the intrinsic id. + auto emitIntrCond = [](const Record &record) { + return "intrinsicID == llvm::Intrinsic::" + + record.getValueAsString("llvmEnumName"); + }; for (const Record *def : recordKeeper.getAllDerivedDefinitions("LLVM_IntrOpBase")) { - if (failed(emitOneIntrBuilder(*def, os))) + if (failed(emitOneMLIRBuilder(*def, os, emitIntrCond))) + return true; + } + return false; +} + +// Emit all op builders. Returns false on success because of the +// generator registration requirements. +static bool emitOpMLIRBuilders(const RecordKeeper &recordKeeper, + raw_ostream &os) { + // Emit condition to check if "llvmInstName" matches the instruction opcode. 
+ auto emitOpcodeCond = [](const Record &record) { + return "inst->getOpcode() == llvm::Instruction::" + + record.getValueAsString("llvmInstName"); + }; + for (const Record *def : + recordKeeper.getAllDerivedDefinitions("LLVM_OpBase")) { + if (failed(emitOneMLIRBuilder(*def, os, emitOpcodeCond))) return true; } return false; @@ -485,10 +510,13 @@ genLLVMIRConversions("gen-llvmir-conversions", "Generate LLVM IR conversions", emitBuilders); -static mlir::GenRegistration - genIntrFromLLVMIRConversions("gen-intr-from-llvmir-conversions", - "Generate intrinsic conversions from LLVM IR", - emitIntrBuilders); +static mlir::GenRegistration genOpFromLLVMIRConversions( + "gen-op-from-llvmir-conversions", + "Generate conversions of operations from LLVM IR", emitOpMLIRBuilders); + +static mlir::GenRegistration genIntrFromLLVMIRConversions( + "gen-intr-from-llvmir-conversions", + "Generate conversions of intrinsics from LLVM IR", emitIntrMLIRBuilders); static mlir::GenRegistration genEnumToLLVMConversion("gen-enum-to-llvmir-conversions",