diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -514,21 +514,29 @@
      NoSideEffect]> {
   let arguments = (ins LLVMI1:$condition,
                    Variadic<LLVM_Type>:$trueDestOperands,
-                   Variadic<LLVM_Type>:$falseDestOperands);
+                   Variadic<LLVM_Type>:$falseDestOperands,
+                   OptionalAttr<ElementsAttr>:$branch_weights);
   let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
   let assemblyFormat = [{
-    $condition `,`
+    $condition ( `weights` `(` $branch_weights^ `)` )? `,`
     $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
     $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
     attr-dict
   }];
 
   let builders = [OpBuilder<
-    "OpBuilder &builder, OperationState &result, Value condition,"
-    "Block *trueDest, ValueRange trueOperands,"
-    "Block *falseDest, ValueRange falseOperands", [{
-      build(builder, result, condition, trueOperands, falseOperands, trueDest,
-            falseDest);
+      "OpBuilder &builder, OperationState &result, Value condition,"
+      "Block *trueDest, ValueRange trueOperands,"
+      "Block *falseDest, ValueRange falseOperands,"
+      "Optional<std::pair<uint32_t, uint32_t>> weights = {}", [{
+        ElementsAttr weightsAttr;
+        if (weights) {
+          weightsAttr =
+              builder.getI32VectorAttr({static_cast<int32_t>(weights->first),
+                                        static_cast<int32_t>(weights->second)});
+        }
+        build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
+              trueDest, falseDest);
   }]>, OpBuilder<
     "OpBuilder &builder, OperationState &result, Value condition,"
     "Block *trueDest, Block *falseDest, ValueRange falseOperands = {}", [{
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -30,6 +30,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
@@ -594,9 +595,27 @@
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
+    auto weights = condbrOp.branch_weights();
+    llvm::MDNode *branchWeights = nullptr;
+    if (weights) {
+      // Map weight attributes to LLVM metadata. The attribute is an
+      // unconstrained ElementsAttr, so check its shape before indexing into
+      // it: a malformed attribute must produce an error, not an assertion.
+      ElementsAttr weightsAttr = weights.getValue();
+      if (weightsAttr.getNumElements() != 2 ||
+          !weightsAttr.getType().getElementType().isIntOrIndex())
+        return opInst.emitError(
+            "expected exactly two integer branch weights");
+      auto trueWeight = weightsAttr.getValue(0).cast<IntegerAttr>().getInt();
+      auto falseWeight = weightsAttr.getValue(1).cast<IntegerAttr>().getInt();
+      branchWeights =
+          llvm::MDBuilder(llvmModule->getContext())
+              .createBranchWeights(static_cast<uint32_t>(trueWeight),
+                                   static_cast<uint32_t>(falseWeight));
+    }
     builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
                          blockMapping[condbrOp.getSuccessor(0)],
-                         blockMapping[condbrOp.getSuccessor(1)]);
+                         blockMapping[condbrOp.getSuccessor(1)], branchWeights);
     return success();
   }
 
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1240,3 +1240,17 @@
   %0 = llvm.mlir.addressof @address_taken : !llvm<"void()*">
   llvm.return %0 : !llvm<"void()*">
 }
+
+// -----
+
+// Check that branch weight attributes are exported properly as metadata.
+llvm.func @cond_br_weights(%cond : !llvm.i1, %arg0 : !llvm.i32, %arg1 : !llvm.i32) -> !llvm.i32 {
+  // CHECK: !prof ![[NODE:[0-9]+]]
+  llvm.cond_br %cond weights(dense<[5, 10]> : vector<2xi32>), ^bb1, ^bb2
+^bb1:  // pred: ^bb0
+  llvm.return %arg0 : !llvm.i32
+^bb2:  // pred: ^bb0
+  llvm.return %arg1 : !llvm.i32
+}
+
+// CHECK: ![[NODE]] = !{!"branch_weights", i32 5, i32 10}