[mlir][llvm] Support branch weights on llvm.call and llvm.invoke

Add an optional `branch_weights` attribute to the LLVM dialect call and
invoke operations, and carry it through LLVM IR import and export as
`!prof` branch weight metadata.

* LLVMOps.td: append an optional `$branch_weights` argument to InvokeOp
  and CallOp. Pass `nullptr` for it from the existing InvokeOp builder,
  and add a CallOp builder that takes a FlatSymbolRefAttr callee.
* LLVMDialect.cpp: the existing CallOp builders forward an extra `nullptr`
  for the new attribute. The new FlatSymbolRefAttr builder forwards
  `nullptr, nullptr`.
* LLVMIRToLLVMTranslation.cpp (import): replace the CondBrOp/SwitchOp
  dyn_cast chain with an llvm::TypeSwitch. Its single Case sets the
  branch weights as an i32 vector attribute, and its Default returns
  failure().
* LLVMToLLVMIRTranslation.cpp (export): factor the metadata construction
  into a static `convertBranchWeights` helper, which returns nullptr when
  the attribute is absent. Use it for call, invoke, cond_br and switch.
  The `convertCall` lambda is folded into the CallOp branch so that the
  resulting llvm::CallInst can receive `MD_prof`.
* Tests: import (profiling-metadata.ll) and export (llvmir.mlir) cases for
  a call with one weight and an invoke with two weights.

Review notes:
* NOTE(review): this copy of the patch is damaged.
  - Line breaks inside each hunk are collapsed.
  - Text between `<` and `>` appears to be stripped, e.g.
    `std::optional weights`, `Variadic:$callee_operands`,
    `dyn_cast(opInst)`, `llvm::TypeSwitch(op)` / `.Case(...)`.
  - The removed `"{}">:$fastmathFlags);` line of the CallOp hunk also
    seems to have been swallowed.
  Regenerate the patch from the source branch before applying it. The
  hunk line counts cannot be checked from this copy.
* NOTE(review): nothing here validates how many weights an op carries.
  The tests use one weight for llvm.call and two for llvm.invoke. LLVM's
  IR verifier presumably rejects other counts on call/invoke `!prof`.
  TODO: confirm, and consider checking the arity in the MLIR op
  verifiers.
* NOTE(review): the old cond_br export read elements 0 and 1 with
  getSExtValue() and cast them. The shared helper instead converts every
  element with APInt::getLimitedValue(). For 32-bit weights the stored
  values should agree. Confirm the element type of `weightValues`, which
  is stripped here.

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -477,7 +477,8 @@ let arguments = (ins OptionalAttr:$callee, Variadic:$callee_operands, Variadic:$normalDestOperands, - Variadic:$unwindDestOperands); + Variadic:$unwindDestOperands, + OptionalAttr:$branch_weights); let results = (outs Variadic); let successors = (successor AnySuccessor:$normalDest, AnySuccessor:$unwindDest); @@ -494,7 +495,7 @@ "ValueRange":$normalOps, "Block*":$unwind, "ValueRange":$unwindOps), [{ build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps, - unwindOps, normal, unwind); + unwindOps, nullptr, normal, unwind); }]>]; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; @@ -547,13 +548,16 @@ let arguments = (ins OptionalAttr:$callee, Variadic, DefaultValuedAttr:$fastmathFlags); + "{}">:$fastmathFlags, + OptionalAttr:$branch_weights); let results = (outs Optional:$result); let builders = [ OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$args)>, OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee, CArg<"ValueRange", "{}">:$args)>, + OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee, + CArg<"ValueRange", "{}">:$args)>, OpBuilder<(ins "TypeRange":$results, "StringRef":$callee, CArg<"ValueRange", "{}">:$args)> ]; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1140,7 +1140,13 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, StringAttr callee, ValueRange args) { - build(builder, state, results, SymbolRefAttr::get(callee), args, nullptr); + build(builder, state, results, SymbolRefAttr::get(callee), args, nullptr, + nullptr); +} + +void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, +
FlatSymbolRefAttr callee, ValueRange args) { + build(builder, state, results, callee, args, nullptr, nullptr); } void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, @@ -1149,7 +1155,8 @@ Type resultType = func.getFunctionType().getReturnType(); if (!resultType.isa()) results.push_back(resultType); - build(builder, state, results, SymbolRefAttr::get(func), args, nullptr); + build(builder, state, results, SymbolRefAttr::get(func), args, nullptr, + nullptr); } CallInterfaceCallable CallOp::getCallableForCallee() { diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringSet.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Constants.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/Instructions.h" @@ -116,15 +117,13 @@ } // Attach the branch weights to the operations that support it.
- if (auto condBrOp = dyn_cast(op)) { - condBrOp.setBranchWeightsAttr(builder.getI32VectorAttr(branchWeights)); - return success(); - } - if (auto switchOp = dyn_cast(op)) { - switchOp.setBranchWeightsAttr(builder.getI32VectorAttr(branchWeights)); - return success(); - } - return failure(); + return llvm::TypeSwitch(op) + .Case([&](auto branchWeightOp) { + branchWeightOp.setBranchWeightsAttr( + builder.getI32VectorAttr(branchWeights)); + return success(); + }) + .Default([](auto) { return failure(); }); } namespace { diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -322,6 +322,21 @@ return success(); } +/// Constructs branch weight metadata if the provided `weights` hold a value, +/// otherwise return `nullptr`. +static llvm::MDNode * +convertBranchWeights(std::optional weights, + LLVM::ModuleTranslation &moduleTranslation) { + if (!weights) + return nullptr; + SmallVector weightValues; + weightValues.reserve(weights->size()); + for (APInt weight : weights->cast()) + weightValues.push_back(weight.getLimitedValue()); + return llvm::MDBuilder(moduleTranslation.getLLVMContext()) + .createBranchWeights(weightValues); +} + static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -336,32 +351,34 @@ // Emit function calls. If the "callee" attribute is present, this is a // direct function call and we also need to look up the remapped function // itself. Otherwise, this is an indirect call and the callee is the first - // operand, look it up as a normal value. Return the llvm::Value - // representing the function result, which may be of llvm::VoidTy type.
- auto convertCall = [&](Operation &op) -> llvm::Value * { - auto operands = moduleTranslation.lookupValues(op.getOperands()); + // operand, look it up as a normal value. + if (auto callOp = dyn_cast(opInst)) { + auto operands = moduleTranslation.lookupValues(callOp.getOperands()); ArrayRef operandsRef(operands); - if (auto attr = op.getAttrOfType("callee")) - return builder.CreateCall( + llvm::CallInst *call; + if (auto attr = callOp.getCalleeAttr()) { + call = builder.CreateCall( moduleTranslation.lookupFunction(attr.getValue()), operandsRef); - auto calleeType = - op.getOperands().front().getType().cast(); - auto *calleeFunctionType = cast( - moduleTranslation.convertType(calleeType.getElementType())); - return builder.CreateCall(calleeFunctionType, operandsRef.front(), - operandsRef.drop_front()); - }; - - // Emit calls. If the called function has a result, remap the corresponding - // value. Note that LLVM IR dialect CallOp has either 0 or 1 result. - if (isa(opInst)) { - llvm::Value *result = convertCall(opInst); + } else { + auto calleeType = + callOp->getOperands().front().getType().cast(); + auto *calleeFunctionType = cast( + moduleTranslation.convertType(calleeType.getElementType())); + call = builder.CreateCall(calleeFunctionType, operandsRef.front(), + operandsRef.drop_front()); + } + llvm::MDNode *branchWeights = + convertBranchWeights(callOp.getBranchWeights(), moduleTranslation); + if (branchWeights) + call->setMetadata(llvm::LLVMContext::MD_prof, branchWeights); + // If the called function has a result, remap the corresponding value. Note + // that LLVM IR dialect CallOp has either 0 or 1 result. if (opInst.getNumResults() != 0) { - moduleTranslation.mapValue(opInst.getResult(0), result); + moduleTranslation.mapValue(opInst.getResult(0), call); return success(); } // Check that LLVM call returns void for 0-result functions.
- return success(result->getType()->isVoidTy()); + return success(call->getType()->isVoidTy()); } if (auto inlineAsmOp = dyn_cast(opInst)) { @@ -442,6 +459,10 @@ moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front()); } + llvm::MDNode *branchWeights = + convertBranchWeights(invOp.getBranchWeights(), moduleTranslation); + if (branchWeights) + result->setMetadata(llvm::LLVMContext::MD_prof, branchWeights); moduleTranslation.mapBranch(invOp, result); // InvokeOp can only have 0 or 1 result if (invOp->getNumResults() != 0) { @@ -478,17 +499,8 @@ return success(); } if (auto condbrOp = dyn_cast(opInst)) { - llvm::MDNode *branchWeights = nullptr; - if (auto weights = condbrOp.getBranchWeights()) { - // Map weight attributes to LLVM metadata. - auto weightValues = weights->getValues(); - auto trueWeight = weightValues[0].getSExtValue(); - auto falseWeight = weightValues[1].getSExtValue(); - branchWeights = - llvm::MDBuilder(moduleTranslation.getLLVMContext()) - .createBranchWeights(static_cast(trueWeight), - static_cast(falseWeight)); - } + llvm::MDNode *branchWeights = + convertBranchWeights(condbrOp.getBranchWeights(), moduleTranslation); llvm::BranchInst *branch = builder.CreateCondBr( moduleTranslation.lookupValue(condbrOp.getOperand(0)), moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)), @@ -498,16 +510,8 @@ return success(); } if (auto switchOp = dyn_cast(opInst)) { - llvm::MDNode *branchWeights = nullptr; - if (auto weights = switchOp.getBranchWeights()) { - llvm::SmallVector weightValues; - weightValues.reserve(weights->size()); - for (llvm::APInt weight : weights->cast()) - weightValues.push_back(weight.getLimitedValue()); - branchWeights = llvm::MDBuilder(moduleTranslation.getLLVMContext()) - .createBranchWeights(weightValues); - } - + llvm::MDNode *branchWeights = + convertBranchWeights(switchOp.getBranchWeights(), moduleTranslation); llvm::SwitchInst *switchInst = builder.CreateSwitch(
moduleTranslation.lookupValue(switchOp.getValue()), moduleTranslation.lookupBlock(switchOp.getDefaultDestination()), diff --git a/mlir/test/Target/LLVMIR/Import/profiling-metadata.ll b/mlir/test/Target/LLVMIR/Import/profiling-metadata.ll --- a/mlir/test/Target/LLVMIR/Import/profiling-metadata.ll +++ b/mlir/test/Target/LLVMIR/Import/profiling-metadata.ll @@ -33,3 +33,36 @@ } !0 = !{!"branch_weights", i32 42, i32 3, i32 5} + +; // ----- + +; CHECK: llvm.func @fn() +declare void @fn() + +; CHECK-LABEL: @call_branch_weights +define void @call_branch_weights() { + ; CHECK: llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} + call void @fn(), !prof !0 + ret void +} + +!0 = !{!"branch_weights", i32 42} + +; // ----- + +declare void @foo() +declare i32 @__gxx_personality_v0(...) + +; CHECK-LABEL: @invoke_branch_weights +define i32 @invoke_branch_weights() personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) { + ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> () + invoke void @foo() to label %bb2 unwind label %bb1, !prof !0 +bb1: + %1 = landingpad { i8*, i32 } cleanup + br label %bb2 +bb2: + ret i32 1 + +} + +!0 = !{!"branch_weights", i32 42, i32 99} diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1631,6 +1631,38 @@ // ----- +llvm.func @fn() + +// CHECK-LABEL: @call_branch_weights +llvm.func @call_branch_weights() { + // CHECK: !prof ![[NODE:[0-9]+]] + llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} : () -> () + llvm.return +} + +// CHECK: ![[NODE]] = !{!"branch_weights", i32 42} + +// ----- + +llvm.func @foo() +llvm.func @__gxx_personality_v0(...)
-> i32 + +// CHECK-LABEL: @invoke_branch_weights +llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: !prof ![[NODE:[0-9]+]] + llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> () +^bb1: // pred: ^bb0 + %1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)> + llvm.br ^bb2 +^bb2: // 2 preds: ^bb0, ^bb1 + llvm.return %0 : i32 +} + +// CHECK: ![[NODE]] = !{!"branch_weights", i32 42, i32 99} + +// ----- + llvm.func @volatile_store_and_load() { %val = llvm.mlir.constant(5 : i32) : i32 %size = llvm.mlir.constant(1 : i64) : i64