diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -73,11 +73,12 @@ /// is unavailable. Location processDebugLoc(const llvm::DebugLoc &loc, llvm::Instruction *inst = nullptr); - /// `br` branches to `target`. Return the block arguments to attach to the - /// generated branch op. These should be in the same order as the PHIs in - /// `target`. - SmallVector processBranchArgs(llvm::BranchInst *br, - llvm::BasicBlock *target); + /// `br` branches to `target`. Append the block arguments to attach to the + /// generated branch op to `blockArguments`. These should be in the same order + /// as the PHIs in `target`. + LogicalResult processBranchArgs(llvm::BranchInst *br, + llvm::BasicBlock *target, + SmallVectorImpl &blockArguments); /// Returns the standard type equivalent to be used in attributes for the /// given LLVM IR dialect type. Type getStdTypeForAttr(LLVMType type); @@ -151,17 +152,27 @@ return LLVMType::getDoubleTy(dialect); case llvm::Type::IntegerTyID: return LLVMType::getIntNTy(dialect, type->getIntegerBitWidth()); - case llvm::Type::PointerTyID: - return processType(type->getPointerElementType()) - .getPointerTo(type->getPointerAddressSpace()); - case llvm::Type::ArrayTyID: - return LLVMType::getArrayTy(processType(type->getArrayElementType()), - type->getArrayNumElements()); + case llvm::Type::PointerTyID: { + LLVMType elementType = processType(type->getPointerElementType()); + if (!elementType) + return nullptr; + return elementType.getPointerTo(type->getPointerAddressSpace()); + } + case llvm::Type::ArrayTyID: { + LLVMType elementType = processType(type->getArrayElementType()); + if (!elementType) + return nullptr; + return LLVMType::getArrayTy(elementType, type->getArrayNumElements()); + } case llvm::Type::VectorTyID: { - if (type->getVectorIsScalable()) + if (type->getVectorIsScalable()) { emitError(unknownLoc) << "scalable vector types not supported"; - return LLVMType::getVectorTy(processType(type->getVectorElementType()), - type->getVectorNumElements()); + return nullptr; + } + LLVMType elementType = processType(type->getVectorElementType()); + if (!elementType) + return nullptr; + return LLVMType::getVectorTy(elementType, type->getVectorNumElements()); } case llvm::Type::VoidTyID: return LLVMType::getVoidTy(dialect); @@ -171,18 +182,30 @@ return LLVMType::getX86_FP80Ty(dialect); case llvm::Type::StructTyID: { SmallVector elementTypes; - for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) - elementTypes.push_back(processType(type->getStructElementType(i))); + elementTypes.reserve(type->getStructNumElements()); + for (unsigned i = 0, e = type->getStructNumElements(); i != e; ++i) { + LLVMType ty = processType(type->getStructElementType(i)); + if (!ty) + return nullptr; + elementTypes.push_back(ty); + } return LLVMType::getStructTy(dialect, elementTypes, cast(type)->isPacked()); } case llvm::Type::FunctionTyID: { llvm::FunctionType *fty = cast(type); SmallVector paramTypes; - for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) - paramTypes.push_back(processType(fty->getParamType(i))); - return LLVMType::getFunctionTy(processType(fty->getReturnType()), - paramTypes, fty->isVarArg()); + for (unsigned i = 0, e = fty->getNumParams(); i != e; ++i) { + LLVMType ty = processType(fty->getParamType(i)); + if (!ty) + return nullptr; + paramTypes.push_back(ty); + } + LLVMType result = processType(fty->getReturnType()); + if (!result) + return nullptr; + + return LLVMType::getFunctionTy(result, paramTypes, fty->isVarArg()); } default: { // FIXME: Diagnostic should be able to natively handle types that have @@ -191,7 +214,7 @@ llvm::raw_string_ostream os(s); os << *type; emitError(unknownLoc) << "unhandled type: " << os.str(); - return {}; + return nullptr; } } } @@ -217,10 +240,14 @@ // LLVM vectors can only contain scalars. if (type.isVectorTy()) { auto numElements = type.getUnderlyingType()->getVectorElementCount(); - if (numElements.Scalable) + if (numElements.Scalable) { emitError(unknownLoc) << "scalable vectors not supported"; - return VectorType::get(numElements.Min, - getStdTypeForAttr(type.getVectorElementType())); + return nullptr; + } + Type elementType = getStdTypeForAttr(type.getVectorElementType()); + if (!elementType) + return nullptr; + return VectorType::get(numElements.Min, elementType); } // LLVM arrays can contain other arrays or vectors. @@ -239,20 +266,26 @@ LLVMType vectorType = type.getArrayElementType(); auto numElements = vectorType.getUnderlyingType()->getVectorElementCount(); - if (numElements.Scalable) + if (numElements.Scalable) { emitError(unknownLoc) << "scalable vectors not supported"; + return nullptr; + } shape.push_back(numElements.Min); - LLVMType elementType = vectorType.getVectorElementType(); - return VectorType::get(shape, getStdTypeForAttr(elementType)); + Type elementType = getStdTypeForAttr(vectorType.getVectorElementType()); + if (!elementType) + return nullptr; + return VectorType::get(shape, elementType); } // Otherwise use a tensor. - return RankedTensorType::get(shape, - getStdTypeForAttr(type.getArrayElementType())); + Type elementType = getStdTypeForAttr(type.getArrayElementType()); + if (!elementType) + return nullptr; + return RankedTensorType::get(shape, elementType); } - llvm_unreachable("no equivalent standard type for typed attributes"); + return nullptr; } // Get the given constant as an attribute. Not all constants can be represented @@ -277,9 +310,11 @@ // Convert constant data to a dense elements attribute. if (auto *cd = dyn_cast(value)) { LLVMType type = processType(cd->getElementType()); + if (!type) + return nullptr; + auto attrType = getStdTypeForAttr(processType(cd->getType())) .dyn_cast_or_null(); - assert(attrType); if (!attrType) return nullptr; @@ -368,15 +403,19 @@ Attribute valueAttr; if (GV->hasInitializer()) valueAttr = getConstantAsAttr(GV->getInitializer()); + LLVMType type = processType(GV->getValueType()); + if (!type) + return nullptr; GlobalOp op = b.create( - UnknownLoc::get(context), processType(GV->getValueType()), - GV->isConstant(), processLinkage(GV->getLinkage()), GV->getName(), - valueAttr); + UnknownLoc::get(context), type, GV->isConstant(), + processLinkage(GV->getLinkage()), GV->getName(), valueAttr); if (GV->hasInitializer() && !valueAttr) { Region &r = op.getInitializerRegion(); currentEntryBlock = b.createBlock(&r); b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); Value v = processConstant(GV->getInitializer()); + if (!v) + return nullptr; b.create(op.getLoc(), ArrayRef({v})); } return globals[GV] = op; @@ -386,13 +425,17 @@ if (Attribute attr = getConstantAsAttr(c)) { // These constants can be represented as attributes. OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); - return instMap[c] = b.create(unknownLoc, - processType(c->getType()), attr); + LLVMType type = processType(c->getType()); + if (!type) + return nullptr; + return instMap[c] = b.create(unknownLoc, type, attr); } if (auto *cn = dyn_cast(c)) { OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); - return instMap[c] = - b.create(unknownLoc, processType(cn->getType())); + LLVMType type = processType(cn->getType()); + if (!type) + return nullptr; + return instMap[c] = b.create(unknownLoc, type); } if (auto *ce = dyn_cast(c)) { llvm::Instruction *i = ce->getAsInstruction(); @@ -420,13 +463,19 @@ // this instruction yet, create an unknown op and remap it later. if (isa(value)) { OperationState state(UnknownLoc::get(context), "unknown"); - state.addTypes({processType(value->getType())}); + LLVMType type = processType(value->getType()); + if (!type) + return nullptr; + state.addTypes(type); unknownInstMap[value] = b.createOperation(state); return unknownInstMap[value]->getResult(0); } if (auto *GV = dyn_cast(value)) { - return b.create(UnknownLoc::get(context), processGlobal(GV), + auto global = processGlobal(GV); + if (!global) + return nullptr; + return b.create(UnknownLoc::get(context), global, ArrayRef()); } @@ -520,14 +569,17 @@ // `br` branches to `target`. Return the branch arguments to `br`, in the // same order of the PHIs in `target`. -SmallVector Importer::processBranchArgs(llvm::BranchInst *br, - llvm::BasicBlock *target) { - SmallVector v; +LogicalResult +Importer::processBranchArgs(llvm::BranchInst *br, llvm::BasicBlock *target, + SmallVectorImpl &blockArguments) { for (auto inst = target->begin(); isa(inst); ++inst) { auto *PN = cast(&*inst); - v.push_back(processValue(PN->getIncomingValueForBlock(br->getParent()))); + Value value = processValue(PN->getIncomingValueForBlock(br->getParent())); + if (!value) + return failure(); + blockArguments.push_back(value); } - return v; + return success(); } LogicalResult Importer::processInstruction(llvm::Instruction *inst) { @@ -577,20 +629,32 @@ OperationState state(loc, opcMap.lookup(inst->getOpcode())); SmallVector ops; ops.reserve(inst->getNumOperands()); - for (auto *op : inst->operand_values()) - ops.push_back(processValue(op)); + for (auto *op : inst->operand_values()) { + Value value = processValue(op); + if (!value) + return failure(); + ops.push_back(value); + } state.addOperands(ops); - if (!inst->getType()->isVoidTy()) - state.addTypes(ArrayRef({processType(inst->getType())})); + if (!inst->getType()->isVoidTy()) { + LLVMType type = processType(inst->getType()); + if (!type) + return failure(); + state.addTypes(type); + } Operation *op = b.createOperation(state); if (!inst->getType()->isVoidTy()) v = op->getResult(0); return success(); } case llvm::Instruction::ICmp: { + Value lhs = processValue(inst->getOperand(0)); + Value rhs = processValue(inst->getOperand(1)); + if (!lhs || !rhs) + return failure(); v = b.create( - loc, getICmpPredicate(cast(inst)->getPredicate()), - processValue(inst->getOperand(0)), processValue(inst->getOperand(1))); + loc, getICmpPredicate(cast(inst)->getPredicate()), lhs, + rhs); return success(); } case llvm::Instruction::Br: { @@ -598,35 +662,57 @@ OperationState state(loc, brInst->isConditional() ? "llvm.cond_br" : "llvm.br"); SmallVector ops; - if (brInst->isConditional()) - ops.push_back(processValue(brInst->getCondition())); + if (brInst->isConditional()) { + Value condition = processValue(brInst->getCondition()); + if (!condition) + return failure(); + ops.push_back(condition); + } state.addOperands(ops); SmallVector succs; - for (auto *succ : llvm::reverse(brInst->successors())) - state.addSuccessor(blocks[succ], processBranchArgs(brInst, succ)); + for (auto *succ : llvm::reverse(brInst->successors())) { + SmallVector blockArguments; + if (failed(processBranchArgs(brInst, succ, blockArguments))) + return failure(); + state.addSuccessor(blocks[succ], blockArguments); + } b.createOperation(state); return success(); } case llvm::Instruction::PHI: { - v = b.getInsertionBlock()->addArgument(processType(inst->getType())); + LLVMType type = processType(inst->getType()); + if (!type) + return failure(); + v = b.getInsertionBlock()->addArgument(type); return success(); } case llvm::Instruction::Call: { llvm::CallInst *ci = cast(inst); SmallVector ops; ops.reserve(inst->getNumOperands()); - for (auto &op : ci->arg_operands()) - ops.push_back(processValue(op.get())); + for (auto &op : ci->arg_operands()) { + Value arg = processValue(op.get()); + if (!arg) + return failure(); + ops.push_back(arg); + } SmallVector tys; - if (!ci->getType()->isVoidTy()) - tys.push_back(processType(inst->getType())); + if (!ci->getType()->isVoidTy()) { + LLVMType type = processType(inst->getType()); + if (!type) + return failure(); + tys.push_back(type); + } Operation *op; if (llvm::Function *callee = ci->getCalledFunction()) { op = b.create(loc, tys, b.getSymbolRefAttr(callee->getName()), ops); } else { - ops.insert(ops.begin(), processValue(ci->getCalledValue())); + Value calledValue = processValue(ci->getCalledValue()); + if (!calledValue) + return failure(); + ops.insert(ops.begin(), calledValue); op = b.create(loc, tys, ops, ArrayRef()); } if (!ci->getType()->isVoidTy()) @@ -637,10 +723,16 @@ // FIXME: Support inbounds GEPs. llvm::GetElementPtrInst *gep = cast(inst); SmallVector ops; - for (auto *op : gep->operand_values()) - ops.push_back(processValue(op)); - v = b.create(loc, processType(inst->getType()), ops, - ArrayRef()); + for (auto *op : gep->operand_values()) { + Value value = processValue(op); + if (!value) + return failure(); + ops.push_back(value); + } + Type type = processType(inst->getType()); + if (!type) + return failure(); + v = b.create(loc, type, ops, ArrayRef()); return success(); } } @@ -651,9 +743,13 @@ instMap.clear(); unknownInstMap.clear(); + LLVMType functionType = processType(f->getFunctionType()); + if (!functionType) + return failure(); + b.setInsertionPoint(module.getBody(), getFuncInsertPt()); LLVMFuncOp fop = b.create(UnknownLoc::get(context), f->getName(), - processType(f->getFunctionType())); + functionType); if (f->isDeclaration()) return success(); @@ -666,8 +762,9 @@ currentEntryBlock = blockList[0]; // Add function arguments to the entry block. - for (auto &arg : f->args()) - instMap[&arg] = blockList[0]->addArgument(processType(arg.getType())); + for (auto kv : llvm::enumerate(f->args())) + instMap[&kv.value()] = blockList[0]->addArgument( + functionType.getFunctionParamType(kv.index())); for (auto bbs : llvm::zip(*f, blockList)) { if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs))))