diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -472,6 +472,23 @@
   let parser = [{ return parseCondBrOp(parser, result); }];
   let printer = [{ printCondBrOp(p, *this); }];
 }
+// Mirrors LLVM's `indirectbr`: branches to the block address held in
+// `$address`; the successor list enumerates the possible destinations.
+def LLVM_IndirectBrOp : LLVM_TerminatorOp<"indirect_br", []> {
+  let arguments = (ins LLVM_Type:$address);
+  // Successors are added in DenseMap iteration order, which is unspecified;
+  // use an OperationState directly when a stable order is required.
+  let builders = [OpBuilder<
+    "Builder *b, OperationState &result, Value address,"
+    "const DenseMap<Block *, SmallVector<Value, 4>> &successors",
+    [{
+      result.addOperands(address);
+      for (const auto &successor : successors)
+        result.addSuccessor(successor.getFirst(), successor.getSecond());
+    }]>];
+  let printer = [{ printIndirectBrOp(p, *this); }];
+  let parser = [{ return parseIndirectBrOp(parser, result); }];
+}
 def LLVM_ReturnOp : LLVM_TerminatorOp<"return", []> {
   string llvmBuilder = [{
     if ($_numOperands != 0)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -844,6 +844,66 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::IndirectBrOp.
+//===----------------------------------------------------------------------===//
+
+/// Prints the address operand and its type, the bracketed successor list
+/// (each with its operands, if any) and the optional attribute dictionary.
+static void printIndirectBrOp(OpAsmPrinter &p, IndirectBrOp op) {
+  p << op.getOperationName() << ' ' << op.address() << " : "
+    << op.address().getType() << " [";
+  interleaveComma(llvm::seq<unsigned>(0, op.getNumSuccessors()), p,
+                  [&](unsigned i) { p.printSuccessorAndUseList(op, i); });
+  p << ']';
+  // The parser accepts an attribute dictionary; print it so it round-trips.
+  p.printOptionalAttrDict(op.getAttrs());
+}
+
+/// <operation> ::= `llvm.indirect_br` ssa-use `:` type
+///                 `[` successor (`,` successor)* `]` attribute-dict?
+/// successor   ::= bb-id (`(` ssa-use-and-type-list `)`)?
+///
+/// Successors keep their first-occurrence order; a label repeated with the
+/// same operands is dropped, and one repeated with different operands is an
+/// error.
+static ParseResult parseIndirectBrOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  OpAsmParser::OperandType address;
+  Type type;
+  if (parser.parseOperand(address) || parser.parseColonType(type) ||
+      parser.resolveOperand(address, type, result.operands) ||
+      parser.parseLSquare())
+    return failure();
+
+  // Parallel lists (rather than a DenseMap) so that the order, and hence the
+  // printed form, is deterministic.
+  SmallVector<Block *, 4> dests;
+  SmallVector<SmallVector<Value, 4>, 4> destOperands;
+  do {
+    llvm::SMLoc loc = parser.getCurrentLocation();
+    Block *dest = nullptr;
+    SmallVector<Value, 4> operands;
+    if (parser.parseSuccessorAndUseList(dest, operands))
+      return failure();
+    auto *it = llvm::find(dests, dest);
+    if (it == dests.end()) {
+      dests.push_back(dest);
+      destOperands.push_back(std::move(operands));
+    } else if (destOperands[it - dests.begin()] != operands) {
+      return parser.emitError(
+          loc, "successor listed more than once with different operands");
+    }
+  } while (succeeded(parser.parseOptionalComma()));
+
+  if (parser.parseRSquare() || parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  for (unsigned i = 0, e = dests.size(); i != e; ++i)
+    result.addSuccessor(dests[i], destOperands[i]);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::ReturnOp.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -645,6 +645,29 @@
     b.createOperation(state);
     return success();
   }
+  case llvm::Instruction::IndirectBr: {
+    auto *indBr = cast<llvm::IndirectBrInst>(inst);
+    Value address = processValue(indBr->getAddress());
+    if (!address)
+      return failure();
+    OperationState state(loc, LLVM::IndirectBrOp::getOperationName());
+    state.addOperands(address);
+    // Add each distinct destination once, in source order, so the imported op
+    // is deterministic. Repeated destinations are redundant: the address alone
+    // selects the target.
+    SmallVector<llvm::BasicBlock *, 4> seen;
+    for (llvm::BasicBlock *succ : indBr->successors()) {
+      if (llvm::is_contained(seen, succ))
+        continue;
+      seen.push_back(succ);
+      SmallVector<Value, 4> blockArguments;
+      if (failed(processBranchArgs(indBr, succ, blockArguments)))
+        return failure();
+      state.addSuccessor(blocks[succ], blockArguments);
+    }
+    b.createOperation(state);
+    return success();
+  }
   case llvm::Instruction::PHI: {
     LLVMType type = processType(inst->getType());
     if (!type)
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -361,6 +361,17 @@
                          blockMapping[condbrOp.getSuccessor(1)]);
     return success();
   }
+  if (auto indBrOp = dyn_cast<LLVM::IndirectBrOp>(opInst)) {
+    // Emit `indirectbr` with one destination per MLIR successor, in order.
+    // NOTE(review): successor operands are not forwarded here; confirm that
+    // PHI-node wiring handles predecessors terminated by llvm.indirect_br.
+    unsigned numSuccessors = indBrOp.getNumSuccessors();
+    llvm::IndirectBrInst *indBrInst = builder.CreateIndirectBr(
+        valueMapping.lookup(indBrOp.address()), numSuccessors);
+    for (unsigned i = 0; i < numSuccessors; ++i)
+      indBrInst->addDestination(blockMapping[indBrOp.getSuccessor(i)]);
+    return success();
+  }
 
   // Emit addressof. We need to look up the global value referenced by the
   // operation and store it in the MLIR-to-LLVM value mapping. This does not
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -282,3 +282,25 @@
 ^bb4:
   llvm.return %0 : !llvm.i32
 }
+
+// CHECK-LABEL: @indirectBr
+llvm.func @indirectBr(%arg0: !llvm<"i8*">, %arg1: !llvm<"i32*">) {
+  %0 = llvm.mlir.constant(2 : i32) : !llvm.i32
+  %1 = llvm.mlir.constant(1 : i32) : !llvm.i32
+  %2 = llvm.mlir.constant(0 : i32) : !llvm.i32
+  // CHECK: llvm.indirect_br %arg0 : !llvm<"i8*"> [^bb{{[12]}}, ^bb{{[12]}}]
+  llvm.indirect_br %arg0 : !llvm<"i8*"> [^bb2, ^bb1]
+^bb1:  // pred: ^bb0
+  llvm.store %2, %arg1 : !llvm<"i32*">
+  llvm.br ^bb4
+^bb2:  // 2 preds: ^bb0, ^bb4
+  llvm.store %1, %arg1 : !llvm<"i32*">
+  llvm.br ^bb4
+^bb3:  // pred: ^bb4
+  llvm.store %0, %arg1 : !llvm<"i32*">
+  llvm.br ^bb4
+// CHECK: ^bb4:
+^bb4:  // 3 preds: ^bb1, ^bb2, ^bb3
+  // CHECK: llvm.indirect_br %arg0 : !llvm<"i8*"> [^bb{{[23]}}, ^bb{{[23]}}]
+  llvm.indirect_br %arg0 : !llvm<"i8*"> [^bb3, ^bb2, ^bb2]
+}
diff --git a/mlir/test/Target/import.ll b/mlir/test/Target/import.ll
--- a/mlir/test/Target/import.ll
+++ b/mlir/test/Target/import.ll
@@ -297,3 +297,31 @@
   ; CHECK: llvm.return %{{[0-9]+}} : !llvm.i32
   ret i32 0
 }
+
+; CHECK-LABEL: llvm.func @indirectbrTest(%arg0: !llvm<"i8*">, %arg1: !llvm<"i32*">)
+define void @indirectbrTest(i8* %address, i32* %sink) #0 {
+
+entry:
+  ; CHECK: llvm.indirect_br %arg0 : !llvm<"i8*"> [^bb{{[12]}}, ^bb{{[12]}}]
+  indirectbr i8* %address, [label %bb1, label %bb2]
+
+; CHECK: ^bb1: // pred: ^bb0
+bb1:
+  store volatile i32 0, i32* %sink
+  br label %latch
+
+; CHECK: ^bb2: // 2 preds: ^bb0, ^bb4
+bb2:
+  store volatile i32 1, i32* %sink
+  br label %latch
+
+; CHECK: ^bb3: // pred: ^bb4
+bb3:
+  store volatile i32 2, i32* %sink
+  br label %latch
+
+; CHECK: ^bb4: // 3 preds: ^bb1, ^bb2, ^bb3
+latch:
+  ; CHECK: llvm.indirect_br %arg0 : !llvm<"i8*"> [^bb{{[23]}}, ^bb{{[23]}}]
+  indirectbr i8* %address, [label %bb2, label %bb3, label %bb3]
+}
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1171,3 +1171,27 @@
 ^bb3:  // pred: ^bb1
   %8 = llvm.invoke @bar(%6) to ^bb2 unwind ^bb1 : (!llvm<"i8*">) -> !llvm<"i8*">
 }
+
+// CHECK-LABEL: @indirectbr(i8* %0, i32* %1)
+llvm.func @indirectbr(%arg0: !llvm<"i8*">, %arg1: !llvm<"i32*">) {
+  %0 = llvm.mlir.constant(2 : i32) : !llvm.i32
+  %1 = llvm.mlir.constant(1 : i32) : !llvm.i32
+  %2 = llvm.mlir.constant(0 : i32) : !llvm.i32
+  // In the emitted function %0 and %1 are the arguments and the entry block is
+  // %2, so ^bb1..^bb4 are expected as %3..%6. Either destination order passes.
+  // CHECK: indirectbr i8* %0, [label %{{[34]}}, label %{{[34]}}]
+  llvm.indirect_br %arg0 : !llvm<"i8*"> [^bb1, ^bb2]
+^bb1:  // pred: ^bb0
+  llvm.store %2, %arg1 : !llvm<"i32*">
+  llvm.br ^bb4
+^bb2:  // 2 preds: ^bb0, ^bb4
+  llvm.store %1, %arg1 : !llvm<"i32*">
+  llvm.br ^bb4
+^bb3:  // pred: ^bb4
+  llvm.store %0, %arg1 : !llvm<"i32*">
+  llvm.br ^bb4
+^bb4:  // 3 preds: ^bb1, ^bb2, ^bb3
+  // The repeated ^bb2 is expected to appear only once among the destinations.
+  // CHECK: indirectbr i8* %0, [label %{{[45]}}, label %{{[45]}}]
+  llvm.indirect_br %arg0 : !llvm<"i8*"> [^bb3, ^bb2, ^bb2]
+}