diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -337,6 +337,25 @@ let parser = [{ return parseCallOp(parser, result); }]; let printer = [{ printCallOp(p, *this); }]; } +def LLVM_InvokeOp : LLVM_Op<"invoke", [Terminator]>, + Arguments<(ins OptionalAttr:$callee, + Variadic:$args)>, + Results<(outs Variadic)>, + LLVM_TwoBuilders { + let verifier = [{ + if (getNumResults() > 1) + return emitOpError("must have 0 or 1 result"); + if (getNumSuccessors() != 2) + return emitOpError("must have a success and failure successor"); + for(auto &inst : *getSuccessor(1)) { + // FIXME: Ignore all phi inst and check for landingpad + } + return success(); + }]; + let parser = [{ return parseInvokeOp(parser, result); }]; + let printer = [{ printInvokeOp(p, *this); }]; +} def LLVM_ExtractElementOp : LLVM_OneResultOp<"extractelement", [NoSideEffect]>, Arguments<(ins LLVM_Type:$vector, LLVM_Type:$position)> { @@ -420,6 +439,12 @@ let parser = [{ return parseSelectOp(parser, result); }]; let printer = [{ printSelectOp(p, *this); }]; } +def LLVM_LandingPadOp + : LLVM_OneResultOp<"landingpad", [NoSideEffect]>, + Arguments<(ins BoolAttr:$cleanup, StrArrayAttr:$clauseTypes, Variadic:$clauses)> { + let parser = [{ return parseLandingPadOp(parser, result); }]; + let printer = [{ printLandingPadOp(p, *this); }]; +} // Terminators. def LLVM_BrOp : LLVM_TerminatorOp<"br", []> { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -400,6 +400,131 @@ return success(); } +//===----------------------------------------------------------------------===// +// Printing/parsing for LLVM::InvokeOp. +//===----------------------------------------------------------------------===// + +static void printInvokeOp(OpAsmPrinter &p, InvokeOp &op) { + auto callee = op.callee(); + bool isDirect = callee.hasValue(); + + // Print the direct callee if present as a function attribute, or an indirect + // callee (first operand) otherwise. + p << op.getOperationName() << ' '; + if (isDirect) + p.printSymbolName(callee.getValue()); + else + p << *op.getOperand(0); + + p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; + p.printOptionalAttrDict(op.getAttrs(), {"callee"}); + p << " "; + p.printSuccessorAndUseList(op.getOperation(), 0); // success block + p << ", "; + p.printSuccessorAndUseList(op.getOperation(), 1); // fail block + + // Reconstruct the function MLIR function type from operand and result types. + SmallVector resultTypes(op.getResultTypes()); + SmallVector argTypes( + llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); + p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); +} + +// ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` +// attribute-dict? `:` function-type +static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) { + SmallVector attrs; + SmallVector operands; + Block *successDest; + Block *failureDest; + SmallVector successOperands; + SmallVector failureOperands; + Type type; + SymbolRefAttr funcAttr; + llvm::SMLoc trailingTypeLoc; + + // Parse an operand list that will, in practice, contain 0 or 1 operand. In + // case of an indirect call, there will be 1 operand before `(`. In case of a + // direct call, there will be no operands and the parser will stop at the + // function identifier without complaining. + if (parser.parseOperandList(operands)) + return failure(); + bool isDirect = operands.empty(); + + // Optionally parse a function identifier. + if (isDirect) + if (parser.parseAttribute(funcAttr, "callee", attrs)) + return failure(); + + if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttrDict(attrs) || + parser.parseSuccessorAndUseList(successDest, successOperands) || + parser.parseComma() || + parser.parseSuccessorAndUseList(failureDest, failureOperands) || + parser.parseColon() || parser.getCurrentLocation(&trailingTypeLoc) || + parser.parseType(type)) + return failure(); + auto funcType = type.dyn_cast(); + if (!funcType) + return parser.emitError(trailingTypeLoc, "expected function type"); + if (isDirect) { + // Make sure types match. + if (parser.resolveOperands(operands, funcType.getInputs(), + parser.getNameLoc(), result.operands)) + return failure(); + result.addTypes(funcType.getResults()); + } else { + // Construct the LLVM IR Dialect function type that the first operand + // should match. + if (funcType.getNumResults() > 1) + return parser.emitError(trailingTypeLoc, + "expected function with 0 or 1 result"); + + Builder &builder = parser.getBuilder(); + auto *llvmDialect = + builder.getContext()->getRegisteredDialect(); + LLVM::LLVMType llvmResultType; + if (funcType.getNumResults() == 0) { + llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect); + } else { + llvmResultType = funcType.getResult(0).dyn_cast(); + if (!llvmResultType) + return parser.emitError(trailingTypeLoc, + "expected result to have LLVM type"); + } + + SmallVector argTypes; + argTypes.reserve(funcType.getNumInputs()); + for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { + auto argType = funcType.getInput(i).dyn_cast(); + if (!argType) + return parser.emitError(trailingTypeLoc, + "expected LLVM types as inputs"); + argTypes.push_back(argType); + } + auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, + /*isVarArg=*/false); + auto wrappedFuncType = llvmFuncType.getPointerTo(); + + auto funcArguments = + ArrayRef(operands).drop_front(); + + // Make sure that the first operand (indirect callee) matches the wrapped + // LLVM IR function type, and that the types of the other call operands + // match the types of the function arguments. + if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) || + parser.resolveOperands(funcArguments, funcType.getInputs(), + parser.getNameLoc(), result.operands)) + return failure(); + + result.addTypes(llvmResultType); + } + result.addSuccessor(successDest, successOperands); + result.addSuccessor(failureDest, failureOperands); + result.attributes = attrs; + return success(); +} + //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::ExtractElementOp. //===----------------------------------------------------------------------===// @@ -658,6 +783,22 @@ return success(); } +// LandingPad Op +static void printLandingPadOp(OpAsmPrinter &p, LandingPadOp &op) { + p << op.getOperationName() << ' '; + p.printOptionalAttrDict(op.getAttrs()); + p << ' '; + p.printOperands(op.getOperands()); + p << " : "; + p.printType(op.getType()); +} + +static ParseResult parseLandingPadOp(OpAsmParser &parser, + OperationState &result) { + llvm::dbgs() << "Parsing LandingPadOp\n"; + return success(); +} + //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::BrOp. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -76,7 +76,7 @@ /// `br` branches to `target`. Return the block arguments to attach to the /// generated branch op. These should be in the same order as the PHIs in /// `target`. - SmallVector processBranchArgs(llvm::BranchInst *br, + SmallVector processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target); /// Return `value` as an attribute to attach to a GlobalOp. Attribute getConstantAsAttr(llvm::Constant *value); @@ -196,19 +196,23 @@ // Get the given constant as an attribute. Not all constants can be represented // as attributes. Attribute Importer::getConstantAsAttr(llvm::Constant *value) { + value->dump(); if (auto *ci = dyn_cast(value)) return b.getIntegerAttr( IntegerType::get(ci->getType()->getBitWidth(), context), ci->getValue()); - if (auto *c = dyn_cast(value)) - if (c->isString()) + if (auto *c = dyn_cast(value)) { + if (c->isString()) { return b.getStringAttr(c->getAsString()); + } + } if (auto *c = dyn_cast(value)) { if (c->getType()->isDoubleTy()) return b.getFloatAttr(FloatType::getF64(context), c->getValueAPF()); else if (c->getType()->isFloatingPointTy()) return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF()); } + // FIXME: Constant array is not handled - used by filter in launchpad return Attribute(); } @@ -295,6 +299,7 @@ return nullptr; } +// FIXME: processValue breaks for function pointers Value Importer::processValue(llvm::Value *value) { auto it = instMap.find(value); if (it != instMap.end()) @@ -333,7 +338,7 @@ // Br is handled specially. // FIXME: switch // FIXME: indirectbr - // FIXME: invoke + INST(Invoke, Invoke), // FIXME: resume // FIXME: unreachable // FIXME: cleanupret @@ -370,8 +375,7 @@ // FIXME: shufflevector // FIXME: extractvalue // FIXME: insertvalue - // FIXME: landingpad -}; + INST(LandingPad, LandingPad)}; #undef INST static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) { @@ -404,7 +408,7 @@ // `br` branches to `target`. Return the branch arguments to `br`, in the // same order of the PHIs in `target`. -SmallVector Importer::processBranchArgs(llvm::BranchInst *br, +SmallVector Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target) { SmallVector v; for (auto inst = target->begin(); isa(inst); ++inst) { @@ -527,6 +531,63 @@ ArrayRef()); return success(); } + case llvm::Instruction::Invoke: { + auto *invokeInst = cast(inst); + SmallVector ops; + ops.reserve(inst->getNumOperands()); + for (auto &op : invokeInst->arg_operands()) { + ops.push_back(processValue(op.get())); + } + SmallVector tys; + if (!invokeInst->getType()->isVoidTy()) { + tys.push_back(processType(inst->getType())); + } + Operation *op; + OperationState state(loc, "llvm.invoke"); + if (llvm::Function *callee = invokeInst->getCalledFunction()) { + state.addAttribute(b.getIdentifier("callee"), + b.getSymbolRefAttr(callee->getName())); + } else { + ops.insert(ops.begin(), processValue(invokeInst->getCalledValue())); + } + state.addOperands(ops); + state.addTypes(tys); + llvm::BasicBlock *normalDest = invokeInst->getNormalDest(); + llvm::BasicBlock *unwindDest = invokeInst->getUnwindDest(); + state.addSuccessor(blocks[normalDest], + processBranchArgs(invokeInst, normalDest)); + state.addSuccessor(blocks[unwindDest], + processBranchArgs(invokeInst, unwindDest)); + b.createOperation(state); + if (!invokeInst->getType()->isVoidTy()) + v = op->getResult(0); + return success(); + } + case llvm::Instruction::LandingPad: { + auto *lp = dyn_cast(inst); + + SmallVector tys; + tys.push_back(processType(lp->getType())); + + SmallVector ops; + SmallVector clauseTys; + ops.reserve(lp->getNumClauses()); + clauseTys.reserve(lp->getNumClauses()); + for (int i = 0, ie = lp->getNumClauses(); i != ie; i++) { + clauseTys.push_back(lp->isCatch(i) ? "catch" : "filter"); + ops.push_back(processValue(lp->getClause(i))); + } + OperationState state(loc, "llvm.landingpad"); + state.addTypes(tys); + state.addAttribute(b.getIdentifier("cleanup"), + b.getBoolAttr(lp->isCleanup())); + state.addAttribute(b.getIdentifier("clauseTypes"), + b.getStrArrayAttr(clauseTys)); + state.addOperands(ops); + Operation *op = b.createOperation(state); + v = op->getResult(0); + return success(); + } } } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -194,6 +194,23 @@ return success(result->getType()->isVoidTy()); } + if (auto invOp = dyn_cast(opInst)) { + auto operands = lookupValues(opInst.getOperands()); + ArrayRef operandsRef(operands); + if (auto attr = opInst.getAttrOfType("callee")) { + builder.CreateInvoke(functionMapping.lookup(attr.getValue()), + blockMapping[invOp.getSuccessor(0)], + blockMapping[invOp.getSuccessor(1)], operandsRef); + // builder.CreateCall(functionMapping.lookup(attr.getValue()), + // operandsRef); + } else { + builder.CreateInvoke( + operandsRef.front(), blockMapping[invOp.getSuccessor(0)], + blockMapping[invOp.getSuccessor(1)], operandsRef.drop_front()); + // builder.CreateCall(operandsRef.front(), operandsRef.drop_front()); + } + return success(); + } // Emit branches. We need to look up the remapped blocks and ignore the block // arguments that were transformed into PHI nodes. if (auto brOp = dyn_cast(opInst)) {