diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -48,6 +48,42 @@
   ];
 }
 
+def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
+  let description = [{
+    An interface for operations that can carry branch weights metadata. It
+    provides setters and getters for the operation's branch weights attribute.
+    The default implementation of the interface methods expects the operation to
+    have an attribute of type DenseI32ArrayAttr named branch_weights.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns the branch weights attribute or nullptr",
+      /*returnType=*/  "DenseI32ArrayAttr",
+      /*methodName=*/  "getBranchWeightsOrNull",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getBranchWeightsAttr();
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Sets the branch weights attribute",
+      /*returnType=*/  "void",
+      /*methodName=*/  "setBranchWeights",
+      /*args=*/        (ins "DenseI32ArrayAttr":$attr),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        op.setBranchWeightsAttr(attr);
+      }]
+      >
+  ];
+}
+
 def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
   let description = [{
     An interface for memory operations that can carry access groups metadata.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -536,12 +536,14 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [ AttrSizedOperandSegments, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, Terminator]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + Terminator]> { let arguments = (ins OptionalAttr:$callee, Variadic:$callee_operands, Variadic:$normalDestOperands, Variadic:$unwindDestOperands, - OptionalAttr:$branch_weights); + OptionalAttr:$branch_weights); let results = (outs Variadic); let successors = (successor AnySuccessor:$normalDest, AnySuccessor:$unwindDest); @@ -582,7 +584,8 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Call to an LLVM function."; let description = [{ In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect @@ -616,7 +619,7 @@ Variadic, DefaultValuedAttr:$fastmathFlags, - OptionalAttr:$branch_weights); + OptionalAttr:$branch_weights); // Append the aliasing related attributes defined in LLVM_MemAccessOpBase. 
let arguments = !con(args, aliasAttrs); let results = (outs Optional:$result); @@ -847,12 +850,14 @@ ]; } def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", - [AttrSizedOperandSegments, DeclareOpInterfaceMethods, + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Pure]> { let arguments = (ins I1:$condition, Variadic:$trueDestOperands, Variadic:$falseDestOperands, - OptionalAttr:$branch_weights, + OptionalAttr:$branch_weights, OptionalAttr:$loop_annotation); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); let assemblyFormat = [{ @@ -874,7 +879,7 @@ falseOperands); }]>, OpBuilder<(ins "Value":$condition, "ValueRange":$trueOperands, "ValueRange":$falseOperands, - "ElementsAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest), + "DenseI32ArrayAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest), [{ build($_builder, $_state, condition, trueOperands, falseOperands, branchWeights, {}, trueDest, falseDest); @@ -934,7 +939,9 @@ } def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", - [AttrSizedOperandSegments, DeclareOpInterfaceMethods, + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Pure]> { let arguments = (ins AnyInteger:$value, @@ -942,7 +949,7 @@ VariadicOfVariadic:$caseOperands, OptionalAttr:$case_values, DenseI32ArrayAttr:$case_operand_segments, - OptionalAttr:$branch_weights + OptionalAttr:$branch_weights ); let successors = (successor AnySuccessor:$defaultDestination, diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -118,6 +118,20 @@ return branchMapping.lookup(op); } + /// Stores a mapping between an MLIR call operation and a corresponding LLVM + /// IR instruction. 
+ void mapCall(Operation *mlir, llvm::CallInst *llvm) { + auto result = callMapping.try_emplace(mlir, llvm); + (void)result; + assert(result.second && "attempting to map a call that is already mapped"); + } + + /// Finds an LLVM IR call instruction that corresponds to the given MLIR call + /// operation. + llvm::CallInst *lookupCall(Operation *op) const { + return callMapping.lookup(op); + } + /// Removes the mapping for blocks contained in the region and values defined /// in these blocks. void forgetMapping(Region ®ion); @@ -141,6 +155,9 @@ /// Sets LLVM TBAA metadata for memory operations that have TBAA attributes. void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst); + /// Sets LLVM profiling metadata for operations that have branch weights. + void setBranchWeightsMetadata(BranchWeightOpInterface op); + /// Sets LLVM loop metadata for branch operations that have a loop annotation /// attribute. void setLoopMetadata(Operation *op, llvm::Instruction *inst); @@ -328,6 +345,11 @@ /// values after all operations are converted. DenseMap branchMapping; + /// A mapping between MLIR LLVM dialect call operations and LLVM IR call + /// instructions. This allows for adding branch weights after the operations + /// have been converted. + DenseMap callMapping; + /// Mapping from an alias scope metadata operation to its LLVM metadata. /// This map is populated on module entry. DenseMap aliasScopeMetadataMapping; diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -553,10 +553,12 @@ matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // If branch weights exist, map them to 32-bit integer vector. 
- ElementsAttr branchWeights = nullptr; + DenseI32ArrayAttr branchWeights = nullptr; if (auto weights = op.getBranchWeights()) { - VectorType weightType = VectorType::get(2, rewriter.getI32Type()); - branchWeights = DenseElementsAttr::get(weightType, weights->getValue()); + SmallVector weightValues; + for (auto weight : weights->getAsRange()) + weightValues.push_back(weight.getInt()); + branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues); } rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -310,11 +310,11 @@ Value condition, Block *trueDest, ValueRange trueOperands, Block *falseDest, ValueRange falseOperands, std::optional> weights) { - ElementsAttr weightsAttr; + DenseI32ArrayAttr weightsAttr; if (weights) weightsAttr = - builder.getI32VectorAttr({static_cast(weights->first), - static_cast(weights->second)}); + builder.getDenseI32ArrayAttr({static_cast(weights->first), + static_cast(weights->second)}); build(builder, result, condition, trueOperands, falseOperands, weightsAttr, /*loop_annotation=*/{}, trueDest, falseDest); @@ -330,9 +330,9 @@ BlockRange caseDestinations, ArrayRef caseOperands, ArrayRef branchWeights) { - ElementsAttr weightsAttr; + DenseI32ArrayAttr weightsAttr; if (!branchWeights.empty()) - weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); + weightsAttr = builder.getDenseI32ArrayAttr(branchWeights); build(builder, result, value, defaultOperands, caseOperands, caseValues, weightsAttr, defaultDestination, caseDestinations); diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -125,13 +125,11 @@ 
branchWeights.push_back(branchWeight->getZExtValue()); } - return TypeSwitch(op) - .Case([&](auto branchWeightOp) { - branchWeightOp.setBranchWeightsAttr( - builder.getI32VectorAttr(branchWeights)); - return success(); - }) - .Default([](auto) { return failure(); }); + if (auto iface = dyn_cast(op)) { + iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights)); + return success(); + } + return failure(); } /// Searches the symbol reference pointing to the metadata operation that diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -124,21 +124,6 @@ return success(); } -/// Constructs branch weights metadata if the provided `weights` hold a value, -/// otherwise returns nullptr. -static llvm::MDNode * -convertBranchWeights(std::optional weights, - LLVM::ModuleTranslation &moduleTranslation) { - if (!weights) - return nullptr; - SmallVector weightValues; - weightValues.reserve(weights->size()); - for (APInt weight : llvm::cast(*weights)) - weightValues.push_back(weight.getLimitedValue()); - return llvm::MDBuilder(moduleTranslation.getLLVMContext()) - .createBranchWeights(weightValues); -} - static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -182,10 +167,6 @@ callOp.getArgOperands()), operandsRef.front(), operandsRef.drop_front()); } - llvm::MDNode *branchWeights = - convertBranchWeights(callOp.getBranchWeights(), moduleTranslation); - if (branchWeights) - call->setMetadata(llvm::LLVMContext::MD_prof, branchWeights); moduleTranslation.setAccessGroupsMetadata(callOp, call); moduleTranslation.setAliasScopeMetadata(callOp, call); moduleTranslation.setTBAAMetadata(callOp, call); @@ -196,7 +177,10 @@ return success(); } // Check that LLVM 
call returns void for 0-result functions. - return success(call->getType()->isVoidTy()); + if (!call->getType()->isVoidTy()) + return failure(); + moduleTranslation.mapCall(callOp, call); + return success(); } if (auto inlineAsmOp = dyn_cast(opInst)) { @@ -274,10 +258,6 @@ moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front()); } - llvm::MDNode *branchWeights = - convertBranchWeights(invOp.getBranchWeights(), moduleTranslation); - if (branchWeights) - result->setMetadata(llvm::LLVMContext::MD_prof, branchWeights); moduleTranslation.mapBranch(invOp, result); // InvokeOp can only have 0 or 1 result if (invOp->getNumResults() != 0) { @@ -314,23 +294,19 @@ return success(); } if (auto condbrOp = dyn_cast(opInst)) { - llvm::MDNode *branchWeights = - convertBranchWeights(condbrOp.getBranchWeights(), moduleTranslation); llvm::BranchInst *branch = builder.CreateCondBr( moduleTranslation.lookupValue(condbrOp.getOperand(0)), moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)), - moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights); + moduleTranslation.lookupBlock(condbrOp.getSuccessor(1))); moduleTranslation.mapBranch(&opInst, branch); moduleTranslation.setLoopMetadata(&opInst, branch); return success(); } if (auto switchOp = dyn_cast(opInst)) { - llvm::MDNode *branchWeights = - convertBranchWeights(switchOp.getBranchWeights(), moduleTranslation); llvm::SwitchInst *switchInst = builder.CreateSwitch( moduleTranslation.lookupValue(switchOp.getValue()), moduleTranslation.lookupBlock(switchOp.getDefaultDestination()), - switchOp.getCaseDestinations().size(), branchWeights); + switchOp.getCaseDestinations().size()); auto *ty = llvm::cast( moduleTranslation.convertType(switchOp.getValue().getType())); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -664,6 +664,10 @@ if 
(failed(convertOperation(op, builder))) return failure(); + + // Set the branch weight metadata on the translated instruction. + if (auto iface = dyn_cast(op)) + setBranchWeightsMetadata(iface); } return success(); @@ -1190,6 +1194,19 @@ inst->setMetadata(llvm::LLVMContext::MD_tbaa, node); } +void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) { + DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull(); + if (!weightsAttr) + return; + + llvm::Instruction *inst = isa(op) ? lookupCall(op) : lookupBranch(op); + assert(inst && "expected the operation to have a mapping to an instruction"); + SmallVector weights(weightsAttr.asArrayRef()); + inst->setMetadata( + llvm::LLVMContext::MD_prof, + llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights)); +} + LogicalResult ModuleTranslation::createTBAAMetadata() { llvm::LLVMContext &ctx = llvmModule->getContext(); llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64); diff --git a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir @@ -68,7 +68,7 @@ } spirv.func @cond_branch_with_weights(%cond: i1) -> () "None" { - // CHECK: llvm.cond_br %{{.*}} weights(dense<[1, 2]> : vector<2xi32>), ^bb1, ^bb2 + // CHECK: llvm.cond_br %{{.*}} weights([1, 2]), ^bb1, ^bb2 spirv.BranchConditional %cond [1, 2], ^true, ^false // CHECK: ^bb1: ^true: diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -874,7 +874,7 @@ // expected-error@+1 {{expects number of branch weights to match number of successors: 3 vs 2}} llvm.switch %arg0 : i32, ^bb1 [ 42: ^bb2(%arg0, %arg0 : i32, i32) - ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>} + ] {branch_weights = array} ^bb1: 
// pred: ^bb0 llvm.return diff --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll --- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll +++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll @@ -4,7 +4,7 @@ define i64 @cond_br(i1 %arg1, i64 %arg2) { entry: ; CHECK: llvm.cond_br - ; CHECK-SAME: weights(dense<[0, 3]> : vector<2xi32>) + ; CHECK-SAME: weights([0, 3]) br i1 %arg1, label %bb1, label %bb2, !prof !0 bb1: ret i64 %arg2 @@ -19,7 +19,7 @@ ; CHECK-LABEL: @simple_switch( define i32 @simple_switch(i32 %arg1) { ; CHECK: llvm.switch - ; CHECK: {branch_weights = dense<[42, 3, 5]> : vector<3xi32>} + ; CHECK: {branch_weights = array} switch i32 %arg1, label %bbd [ i32 0, label %bb1 i32 9, label %bb2 @@ -41,7 +41,7 @@ ; CHECK-LABEL: @call_branch_weights define void @call_branch_weights() { - ; CHECK: llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} + ; CHECK: llvm.call @fn() {branch_weights = array} call void @fn(), !prof !0 ret void } @@ -55,7 +55,7 @@ ; CHECK-LABEL: @invoke_branch_weights define i32 @invoke_branch_weights() personality ptr @__gxx_personality_v0 { - ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> () + ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array} : () -> () invoke void @foo() to label %bb2 unwind label %bb1, !prof !0 bb1: %1 = landingpad { ptr, i32 } cleanup diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1802,7 +1802,7 @@ // Check that branch weight attributes are exported properly as metadata. 
llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 { // CHECK: !prof ![[NODE:[0-9]+]] - llvm.cond_br %cond weights(dense<[5, 10]> : vector<2xi32>), ^bb1, ^bb2 + llvm.cond_br %cond weights([5, 10]), ^bb1, ^bb2 ^bb1: // pred: ^bb0 llvm.return %arg0 : i32 ^bb2: // pred: ^bb0 @@ -1818,7 +1818,7 @@ // CHECK-LABEL: @call_branch_weights llvm.func @call_branch_weights() { // CHECK: !prof ![[NODE:[0-9]+]] - llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} : () -> () + llvm.call @fn() {branch_weights = array} : () -> () llvm.return } @@ -1833,7 +1833,7 @@ llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} { %0 = llvm.mlir.constant(1 : i32) : i32 // CHECK: !prof ![[NODE:[0-9]+]] - llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> () + llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array} : () -> () ^bb1: // pred: ^bb0 %1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)> llvm.br ^bb2 @@ -2062,7 +2062,7 @@ llvm.switch %arg0 : i32, ^bb1(%0 : i32) [ 9: ^bb2(%1, %2 : i32, i32), 99: ^bb3 - ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>} + ] {branch_weights = array} ^bb1(%3: i32): // pred: ^bb0 llvm.return %3 : i32