diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -470,27 +470,46 @@
      NoSideEffect]> {
   let arguments = (ins LLVMI1:$condition,
                    Variadic<LLVM_Type>:$trueDestOperands,
-                   Variadic<LLVM_Type>:$falseDestOperands);
+                   Variadic<LLVM_Type>:$falseDestOperands,
+                   // Optional profile weights of the true and false
+                   // successors, in that order.
+                   OptionalAttr<I32ArrayAttr>:$branch_weights);
   let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
-  let assemblyFormat = [{
-    $condition `,`
-    $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
-    $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
-    attr-dict
-  }];
   let builders = [OpBuilder<
     "OpBuilder &builder, OperationState &result, Value condition,"
     "Block *trueDest, ValueRange trueOperands,"
-    "Block *falseDest, ValueRange falseOperands", [{
-      build(builder, result, condition, trueOperands, falseOperands, trueDest,
-            falseDest);
+    "Block *falseDest, ValueRange falseOperands,"
+    "Optional<std::pair<uint32_t, uint32_t>> weights = {}", [{
+      // `weights` holds the (true, false) branch weights. They are stored as
+      // i32 attributes, so weights above INT32_MAX are stored as negative
+      // values.
+      ArrayAttr weightsAttr;
+      if (weights) {
+        weightsAttr =
+            builder.getI32ArrayAttr({static_cast<int32_t>(weights->first),
+                                     static_cast<int32_t>(weights->second)});
+      }
+      build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
+            trueDest, falseDest);
   }]>, OpBuilder<
     "OpBuilder &builder, OperationState &result, Value condition,"
     "Block *trueDest, Block *falseDest, ValueRange falseOperands = {}", [{
       build(builder, result, condition, trueDest, ValueRange(), falseDest,
             falseOperands);
   }]>, LLVM_TerminatorPassthroughOpBuilder];
+
+  let parser = [{ return parseCondBrOp(parser, result); }];
+  let printer = [{ printCondBrOp(p, *this); }];
+
+  // When present, `branch_weights` must hold exactly one weight per successor.
+  let verifier = [{
+    if (auto weights = branch_weights())
+      if (weights->size() != 2)
+        return emitOpError("expects exactly 2 branch weights, got ")
+               << weights->size();
+    return success();
+  }];
 }
 
 def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [NoSideEffect]>,
                     Arguments<(ins Variadic<LLVM_Type>:$args)> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -177,6 +177,110 @@
   return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
 }
 
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::CondBrOp.
+//===----------------------------------------------------------------------===//
+
+/// Parses one branch weight of the form `integer-literal : i32` into `weight`.
+/// `name` only keys the parsed attribute in a scratch list that is discarded;
+/// the caller assembles the final `branch_weights` array itself.
+static ParseResult parseCondBrWeight(OpAsmParser &parser, StringRef name,
+                                     IntegerAttr &weight) {
+  NamedAttrList scratch;
+  Type weightType;
+  if (parser.parseAttribute(weight, parser.getBuilder().getIntegerType(32),
+                            name, scratch) ||
+      parser.parseColon())
+    return failure();
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (parser.parseType(weightType))
+    return failure();
+  if (!weightType.isInteger(32))
+    return parser.emitError(typeLoc, "expected branch weight of type i32, got ")
+           << weightType;
+  return success();
+}
+
+// <operation> ::= `llvm.cond_br` ssa-use
+//                 (`[` weight `:` `i32` `,` weight `:` `i32` `]`)? `,`
+//                 successor (`(` ssa-use-list `:` type-list `)`)? `,`
+//                 successor (`(` ssa-use-list `:` type-list `)`)?
+//                 attribute-dict?
+//
+// The optional bracketed pair holds the true and false branch weights, in
+// that order, and is stored as the `branch_weights` array attribute.
+static ParseResult parseCondBrOp(OpAsmParser &parser, OperationState &result) {
+  auto &builder = parser.getBuilder();
+
+  // Parse the condition, which is always an LLVM `i1`.
+  OpAsmParser::OperandType condition;
+  auto int1Ty = LLVMType::getInt1Ty(
+      builder.getContext()->getRegisteredDialect<LLVMDialect>());
+  if (parser.parseOperand(condition) ||
+      parser.resolveOperand(condition, int1Ty, result.operands))
+    return failure();
+
+  // Parse the optional branch weights.
+  if (succeeded(parser.parseOptionalLSquare())) {
+    IntegerAttr trueWeight, falseWeight;
+    if (parseCondBrWeight(parser, "true_branch_weight", trueWeight) ||
+        parser.parseComma() ||
+        parseCondBrWeight(parser, "false_branch_weight", falseWeight) ||
+        parser.parseRSquare())
+      return failure();
+    result.addAttribute("branch_weights",
+                        builder.getArrayAttr({trueWeight, falseWeight}));
+  }
+
+  // Parse the true branch.
+  Block *dest = nullptr;
+  SmallVector<Value, 4> trueOperands;
+  if (parser.parseComma() ||
+      parser.parseSuccessorAndUseList(dest, trueOperands))
+    return failure();
+  result.addSuccessors(dest);
+  result.addOperands(trueOperands);
+
+  // Parse the false branch.
+  SmallVector<Value, 4> falseOperands;
+  if (parser.parseComma() ||
+      parser.parseSuccessorAndUseList(dest, falseOperands))
+    return failure();
+  result.addSuccessors(dest);
+  result.addOperands(falseOperands);
+
+  // Accept trailing attributes so that the form emitted by printCondBrOp,
+  // which prints them, round-trips.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  // Record how the flat operand list splits into the condition, the true
+  // destination operands and the false destination operands.
+  result.addAttribute(
+      CondBrOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr({1, static_cast<int32_t>(trueOperands.size()),
+                                static_cast<int32_t>(falseOperands.size())}));
+  return success();
+}
+
+/// Prints `op` in the custom form accepted by parseCondBrOp. The weights, if
+/// any, are printed as an array attribute, e.g. `[5 : i32, 10 : i32]`.
+static void printCondBrOp(OpAsmPrinter &p, CondBrOp &op) {
+  p << op.getOperationName() << ' ' << op.condition();
+  if (auto weights = op.branch_weights()) {
+    p << ' ';
+    p.printAttribute(weights.getValue());
+  }
+  p << ", ";
+  p.printSuccessorAndUseList(op.trueDest(), op.trueDestOperands());
+  p << ", ";
+  p.printSuccessorAndUseList(op.falseDest(), op.falseDestOperands());
+  // The weights and operand segment sizes are already encoded in the syntax
+  // above; print only the remaining attributes.
+  p.printOptionalAttrDict(
+      op.getAttrs(), {CondBrOp::getOperandSegmentSizeAttr(), "branch_weights"});
+}
+
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::LoadOp.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -29,6 +29,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 
@@ -436,8 +437,25 @@
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
+    // Attach the optional branch weights as `!prof` metadata; a null node
+    // leaves the branch without profile information.
+    llvm::MDNode *branchWeights = nullptr;
+    if (auto weights = condbrOp.branch_weights()) {
+      // LLVM expects exactly one weight per successor.
+      if (weights->size() != 2)
+        return condbrOp.emitError("expected 2 branch weights, got ")
+               << weights->size();
+      // The weights are stored as i32 attributes; reinterpret them as the
+      // uint32_t values LLVM's metadata builder takes.
+      auto weightsVector = llvm::to_vector<2>(
+          llvm::map_range(weights.getValue(), [](Attribute attr) {
+            return static_cast<uint32_t>(attr.cast<IntegerAttr>().getInt());
+          }));
+      branchWeights = llvm::MDBuilder(llvmModule->getContext())
+                          .createBranchWeights(weightsVector);
+    }
     builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
                          blockMapping[condbrOp.getSuccessor(0)],
-                         blockMapping[condbrOp.getSuccessor(1)]);
+                         blockMapping[condbrOp.getSuccessor(1)], branchWeights);
     return success();
   }
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1240,3 +1240,17 @@
   %0 = llvm.mlir.addressof @address_taken : !llvm<"void()*">
   llvm.return %0 : !llvm<"void()*">
 }
+
+// -----
+
+// Check that branch weight attributes are exported properly as metadata.
+llvm.func @cond_br_weights(%cond : !llvm.i1, %arg0 : !llvm.i32, %arg1 : !llvm.i32) -> !llvm.i32 {
+  // CHECK: !prof ![[NODE:[0-9]+]]
+  llvm.cond_br %cond [5 : i32, 10 : i32], ^bb1, ^bb2
+^bb1:  // pred: ^bb0
+  llvm.return %arg0 : !llvm.i32
+^bb2:  // pred: ^bb0
+  llvm.return %arg1 : !llvm.i32
+}
+
+// CHECK: ![[NODE]] = !{!"branch_weights", i32 5, i32 10}