Index: llvm/utils/BranchAccuracy/README.md
===================================================================
--- /dev/null
+++ llvm/utils/BranchAccuracy/README.md
@@ -0,0 +1,50 @@
+# Branch Accuracy
+
+This directory contains tools for comparing the branch probability
+information computed at compile time against an accurate trace of the program.
+
+## Extension
+
+This infrastructure could be extended to other targets. Emission of branch
+probabilities is target-agnostic in AsmPrinter but requires that the target
+overrides `TargetInstrInfo::getBranchDestBlock`. A new target would need a
+tracing tool that can accurately count branch executions. Depending on the
+output format of the trace, "branch_accuracy_data.py" will need to be updated
+to accept the target as an option and to parse the trace output.
+
+## Contents
+
+### branch_accuracy_data.py
+
+"branch_accuracy_data.py" is a script that takes an executable and a trace
+file, produces a CSV with all branch data, and prints a high-level summary of
+the probability difference error and direction match percentage.
+
+#### Supported Targets
+* x86_64 - using the Pin tool
+* x86 - using the Pin tool
+
+### X86 - cond_br_trace.cpp
+
+The X86 subdirectory contains a custom Pin tool for tracing conditional
+branches. It requires that the Pin kit is installed at an accessible
+location.
+
+The tool can be built with the following command for x86_64. Switch `intel64`
+to `ia32` to build the x86 (32-bit) tool.
+```bash
+cd X86/
+make PIN_ROOT=<pin-kit-dir> OBJDIR=<output-dir>/ \
+  <output-dir>/cond_br_trace.so TARGET=intel64
+```
+
+Running the Pin tool has the following form. It has an optional argument,
+`-o`, to specify the path of the output file; otherwise the default is
+"cond_br_trace.out".
+```bash
+<pin-kit-dir>/pin -t <output-dir>/cond_br_trace.so -- <program> [<args>...]
+```
+
+The trace output can then be used together with the program's ELF file to
+collect data with "branch_accuracy_data.py".
+
Index: llvm/utils/BranchAccuracy/X86/cond_br_trace.cpp
===================================================================
--- /dev/null
+++ llvm/utils/BranchAccuracy/X86/cond_br_trace.cpp
@@ -0,0 +1,136 @@
+#include <fstream>
+#include <ios>
+#include <iostream>
+#include <string>
+#include <unordered_map>
+
+#include "pin.H"
+// Need to include this after pin.H
+#include "instlib.H"
+
+namespace {
+/// Execution counts collected for one conditional branch instruction.
+struct BranchCounter {
+  /// Number of times the branch was executed.
+  UINT64 Branch = 0;
+  /// Number of those executions in which the branch was taken.
+  UINT64 Taken = 0;
+};
+
+/// Counts for every executed conditional branch of the main executable, keyed
+/// by the branch's runtime instruction address.
+std::unordered_map<ADDRINT, BranchCounter> CounterMap;
+
+/// Serializes access to CounterMap. Analysis routines run on the application's
+/// own threads, so a multi-threaded program would otherwise race on the map.
+PIN_LOCK CounterLock;
+
+/// Runtime address range of the main executable. Used to ignore branches in
+/// other images and written to the trace so that runtime addresses can be
+/// mapped back to addresses in the executable after the trace.
+ADDRINT MainExecutableAddressLow;
+ADDRINT MainExecutableAddressHigh;
+
+/// Analysis routine run before every instrumented conditional branch; counts
+/// the execution and, if the branch is taken, the taken execution.
+VOID atBranch(ADDRINT Ip, BOOL Taken) {
+  PIN_GetLock(&CounterLock, PIN_ThreadId() + 1);
+  BranchCounter &Counter = CounterMap[Ip];
+  ++Counter.Branch;
+  if (Taken)
+    ++Counter.Taken;
+  PIN_ReleaseLock(&CounterLock);
+}
+
+/// Trace instrumentation routine: inserts a call to atBranch before the last
+/// instruction of each basic block if it is a conditional branch located in
+/// the main executable.
+VOID tagCondBr(TRACE Trace, VOID *) {
+  for (BBL BB = TRACE_BblHead(Trace); BBL_Valid(BB); BB = BBL_Next(BB)) {
+    // A branch ends its basic block, so only the tail needs to be checked.
+    INS Ins = BBL_InsTail(BB);
+    if (!INS_Valid(Ins))
+      continue;
+
+    ADDRINT IP = INS_Address(Ins);
+    if (IP < MainExecutableAddressLow || IP > MainExecutableAddressHigh)
+      continue;
+
+    if (INS_Category(Ins) != XED_CATEGORY_COND_BR)
+      continue;
+
+    INS_InsertCall(Ins, IPOINT_BEFORE, reinterpret_cast<AFUNPTR>(atBranch),
+                   IARG_INST_PTR, IARG_BRANCH_TAKEN, IARG_END);
+  }
+}
+
+/// Image-load callback: records the relocated address range of the main
+/// executable before it executes.
+VOID tagImageLoad(IMG Img, VOID *) {
+  if (!IMG_IsMainExecutable(Img))
+    return;
+  MainExecutableAddressLow = IMG_LowAddress(Img);
+  MainExecutableAddressHigh = IMG_HighAddress(Img);
+}
+
+/// -o <file>: path of the output branch trace.
+KNOB<std::string> KnobOutputFile(KNOB_MODE_WRITEONCE, "pintool", "o",
+                                 "cond_br_trace.out",
+                                 "specify output file name");
+
+/// -silent: suppress the startup banner printed to stderr.
+KNOB<BOOL> Silent(KNOB_MODE_WRITEONCE, "pintool", "silent", "0",
+                  "Silence stderr and stdout");
+
+/// Fini callback: writes the executable's address range on the first line,
+/// then one line per executed conditional branch, to the output file.
+VOID fini(INT32, VOID *) {
+  // Write to a file since cout and cerr may be closed by the application.
+  std::ofstream OutFile(KnobOutputFile.Value().c_str());
+  if (!OutFile) {
+    std::cerr << "cond_br_trace: cannot open output file "
+              << KnobOutputFile.Value() << '\n';
+    return;
+  }
+
+  OutFile.setf(std::ios::showbase);
+  OutFile << "Executable Address: " << std::hex << MainExecutableAddressLow
+          << " - " << MainExecutableAddressHigh << std::dec << '\n';
+  // Output results
+  for (const auto &KeyVal : CounterMap) {
+    OutFile << std::hex << KeyVal.first << std::dec
+            << " => branch count: " << KeyVal.second.Branch
+            << " => taken count: " << KeyVal.second.Taken << '\n';
+  }
+}
+
+} // namespace
+
+int main(int Argc, char *Argv[]) {
+  // Initialize pin
+  if (PIN_Init(Argc, Argv)) {
+    std::cerr << "This tool counts how often each conditional branch of the "
+                 "main executable is executed and taken\n";
+    std::cerr << '\n' << KNOB_BASE::StringKnobSummary() << '\n';
+    return -1;
+  }
+
+  PIN_InitLock(&CounterLock);
+
+  if (!Silent.Value()) {
+    std::cerr << "===============================================\n"
+              << "This application is instrumented by cond_br_trace\n"
+              << "See file " << KnobOutputFile.Value()
+              << " for analysis results\n"
+              << "===============================================\n";
+  }
+
+  IMG_AddInstrumentFunction(tagImageLoad, nullptr);
+  TRACE_AddInstrumentFunction(tagCondBr, nullptr);
+  PIN_AddFiniFunction(fini, nullptr);
+
+  // Starts the application under Pin; this call does not return.
+  PIN_StartProgram();
+
+  return 0;
+}
Index: llvm/utils/BranchAccuracy/X86/makefile
===================================================================
--- /dev/null
+++ llvm/utils/BranchAccuracy/X86/makefile
@@ -0,0 +1,9 @@
+ifdef PIN_ROOT
+CONFIG_ROOT := $(PIN_ROOT)/source/tools/Config
+else
+$(error Must specify PIN_ROOT to build tool)
+endif
+
+include $(CONFIG_ROOT)/makefile.config
+include makefile.rules
+include $(TOOLS_ROOT)/Config/makefile.default.rules
Index: llvm/utils/BranchAccuracy/X86/makefile.rules
===================================================================
--- /dev/null
+++ llvm/utils/BranchAccuracy/X86/makefile.rules
@@ -0,0 +1,2 @@
+TOOL_ROOTS := cond_br_trace
+
Index: llvm/utils/BranchAccuracy/branch_accuracy_data.py
===================================================================
--- /dev/null
+++ llvm/utils/BranchAccuracy/branch_accuracy_data.py
@@ -0,0 +1,562 @@
+#!/usr/bin/env python3
+
+"""Generates compiler and execution branch data for a program and trace"""
+
+
+import argparse
+import dataclasses
+import json
+import re
+import subprocess
+import sys
+
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import pandas as pd
+
+
+# Target ISA of the analyzed program; assigned from --isa in parse_arguments().
+ISA: Optional[str] = None
+# Values accepted by --single-stat; each names a statistic from print_summary().
+STAT_CHOICES: List[str] = [
+    "total_branches",
+    "total_unique",
+    "total_unique_encountered",
+    "prob_diff_error",
+    "prob_diff_error_unique",
+    "match_total",
+    "match_unique",
+    "match_percent",
+    "match_percent_unique",
+]
+
+
+@dataclasses.dataclass
+class DebugLoc:
+    """Debug location of a jump as reported by llvm-symbolizer"""
+
+    file: str = ""
+    line: int = 0
+    col: int = -1
+    discriminator: int = -1
+    inlined: int = -1  # -1 = unknown, 0 = false, 1 = true
+
+    def to_flat_dict(self) -> Dict[str, Any]:
+        """Converts this class into a flat dictionary of numbers and strings"""
+        return vars(self).copy()
+
+
+@dataclasses.dataclass
+class JumpAnnotation:
+    """Jump data provided from object dumping
+
+    Attributes:
+        probability Compiler probability for taken prediction
+        proc        Function containing the branch after inlining
+        dloc        Debug location for the branch
+    """
+
+    probability: float
+    proc: str
+    dloc: DebugLoc
+
+    def to_flat_dict(self) -> Dict[str, Any]:
+        """Converts this class into a flat dictionary of numbers and strings"""
+        d: Dict[str, Any] = vars(self).copy()
+        del d["dloc"]
+        d.update(self.dloc.to_flat_dict())
+        return d
+
+
+@dataclasses.dataclass
+class JumpStats:
+    """Stats for a jump computed from the trace
+
+    Attributes:
+        total     Number of executions for a given branch
+        taken     Number of executions that took the branch
+        predicted Number of executions that correctly predicted taken/non-taken
+    """
+
+    total: int = 0
+    taken: int = 0
+    predicted: int = 0
+
+    def is_empty(self) -> bool:
+        """Checks if a jump was actually executed"""
+        return self.total == 0
+
+    def taken_percent(self) -> float:
+        """Fraction of executions that took the branch (NaN if never executed)"""
+        if self.total == 0:
+            return float("nan")
+        return float(self.taken) / float(self.total)
+
+    def to_flat_dict(self) -> Dict[str, Any]:
+        """Converts this class into a flat dictionary of numbers and strings"""
+        return vars(self).copy()
+
+
+class Jump:
+    """Class that holds all information about a unique jump from the trace"""
+
+    def __init__(
+        self,
+        pc: int,
+        asm: str,
+        annotation: JumpAnnotation,
+        stats: Optional[JumpStats] = None,
+    ):
+        self.pc: int = pc
+        self.asm: str = asm
+        self.annotation: JumpAnnotation = annotation
+        self.stats: JumpStats = stats if stats else JumpStats()
+
+    def does_direction_match(self) -> bool:
+        """True if annotation and stats bias same taken/not-taken
+
+        Branches that are not encountered are considered a mismatch
+        and likely should be left out of overall metrics.
+
+        Branches with a 50% compiler probability are considered only
+        a match if the taken percent is also exactly 50%. The probability
+        usually shows up when the compiler does not have profile information
+        but it can come from the profile. Exact comparison is preferred over
+        treating 50% as always taken or always nontaken because preferring
+        a given direction taken/non-taken would skew overall direction match.
+        The skewing may incorrectly suggest that branch data should be
+        discarded to create more 50% probabilities.
+        """
+        if self.stats.total == 0:
+            return False
+        taken_percent: float = self.stats.taken_percent()
+        probability: float = self.annotation.probability
+        return (
+            (taken_percent > 0.50 and probability > 0.50)
+            or (taken_percent < 0.50 and probability < 0.50)
+            or (taken_percent == 0.50 and probability == 0.50)
+        )
+
+    def is_ambiguous(self) -> bool:
+        """True if compiler probability is exactly 50%"""
+        return self.annotation.probability == 0.50
+
+    def prob_diff(self) -> float:
+        """|taken_percent - probability|"""
+        if self.stats.total == 0:
+            return float("nan")
+        taken_percent: float = self.stats.taken_percent()
+        probability: float = self.annotation.probability
+        return abs(taken_percent - probability)
+
+    def prob_diff_weighted(self) -> float:
+        """(taken_percent - probability)^2 * total"""
+        if self.stats.total == 0:
+            return float("nan")
+        prob_diff = self.prob_diff()
+        return prob_diff * prob_diff * self.stats.total
+
+    def to_flat_dict(self) -> Dict[str, Any]:
+        """Converts this class into a flat dictionary of numbers and strings"""
+        return {
+            **{k: v for k, v in vars(self).items() if k not in ["annotation", "stats"]},
+            **self.annotation.to_flat_dict(),
+            **self.stats.to_flat_dict(),
+            "taken_percent": self.stats.taken_percent(),
+            "prob_diff": self.prob_diff(),
+            "prob_diff_weighted": self.prob_diff_weighted(),
+            "direction_match": self.does_direction_match(),
+            "is_ambiguous": self.is_ambiguous(),
+        }
+
+    def to_dataframe(self) -> pd.DataFrame:
+        """Flattens all data in jump then constructs a dataframe"""
+        d: Dict[str, Any] = self.to_flat_dict()
+        del d["pc"]
+        return pd.DataFrame(d, index=[self.pc])
+
+
+def collect_fragment_jump_annotations(program: str, llvm_objdump: str) -> List[Jump]:
+    """
+    Reads disassembly and collects all annotated jumps from the
+    .branch_probabilities section emitted in the backend.
+
+    Each section entry is a little-endian (address, probability) pair whose
+    probability is scaled by 10000. On x86_64 an entry is followed by four
+    zero bytes, so each 16-byte row of the hex dump holds exactly one entry.
+    """
+
+    assert ISA in ["x86_64", "x86_32"], f"Unsupported isa: {ISA}"
+
+    addr_width: int = 64 if ISA in ["x86_64"] else 32
+    byte_order: str = "little"
+
+    bp_res = subprocess.check_output(
+        [llvm_objdump, "-s", "--section=.branch_probabilities", program],
+    )
+
+    addr_to_data: Dict[int, float] = {}
+
+    def byte_str_to_int(bs: str) -> int:
+        return int.from_bytes(bytes.fromhex(bs), byteorder=byte_order)
+
+    def handle_addr_prob_strs(addr: str, prob: str):
+        address: int = byte_str_to_int(addr)
+        probability: int = byte_str_to_int(prob)
+
+        assert address != 0, "Unexpected null address"
+        assert probability >= 0 and probability <= 10000, f"Unexpected probability {probability}"
+        addr_to_data[address] = float(probability) / 10000.0
+
+    # Keep only the hex dump rows (" <offset> <word> ..."), skipping the file
+    # format and "Contents of section" header lines however many there are.
+    row_pattern: re.Pattern = re.compile(r"^ [0-9a-f]+ ")
+    lines: List[str] = [
+        line for line in bp_res.decode("utf-8").splitlines() if row_pattern.match(line)
+    ]
+
+    # Collect the branch probability section into a dictionary of address to
+    # branch data
+    if addr_width == 64:
+        for line in lines:
+            [addr1, addr2, data, padding] = line.split(" ")[2:6]
+            assert padding == "00000000", "unexpected values in the padding"
+            handle_addr_prob_strs(addr1 + addr2, data)
+
+    elif addr_width == 32:
+        for line in lines:
+            [addr1, data1, addr2, data2] = line.split(" ")[2:6]
+            handle_addr_prob_strs(addr1, data1)
+            if addr2 and data2:
+                handle_addr_prob_strs(addr2, data2)
+
+    else:
+        assert False, "unreachable"
+
+    assert addr_to_data, f"No jumps found in .branch_probabilities section of `{program}`\n"
+
+    # Pattern for finding instruction addresses in object dumps.
+    pattern: re.Pattern = re.compile(r" *([0-9a-f]+): ")
+
+    # Run object dump to get assembly string
+    dump_res = subprocess.check_output([llvm_objdump, "-d", program])
+
+    # Iterate over assembly and collect the jump instructions
+    jumps: List[Jump] = []
+    line: str
+    for line in dump_res.decode("utf-8").splitlines():
+        m = pattern.search(line)
+        if not m:
+            continue
+        address: int = int(m.group(1), base=16)
+        probability = addr_to_data.get(address)
+        if probability is None:
+            continue
+
+        line = line.strip()
+        jump = Jump(address, line, JumpAnnotation(probability, "", DebugLoc()))
+        jumps.append(jump)
+
+    assert jumps, (
+        f"No jumps were found in program `{program}`\n"
+        f"Ensure that clang is built with jump annotation patch"
+    )
+
+    return jumps
+
+
+def collect_debug_info(jumps: List[Jump], program: str, llvm_objdump: str):
+    """
+    Fills in the function name and debug location of every jump in place.
+
+    llvm-symbolizer is looked up next to `llvm_objdump` (or on PATH when
+    `llvm_objdump` is a bare command name).
+    """
+    llvm_symbolizer = Path(llvm_objdump).parent / "llvm-symbolizer"
+    syms_res = subprocess.check_output(
+        [
+            llvm_symbolizer,
+            "--output-style=JSON",
+            f"--obj={program}",
+            # Hex with a 0x prefix is unambiguous to llvm-symbolizer.
+            *[hex(j.pc) for j in jumps],
+        ],
+    )
+
+    debug_infos = json.loads(syms_res.decode("utf-8"))
+    assert len(jumps) == len(debug_infos)
+    for jump, debug in zip(jumps, debug_infos):
+        assert jump.pc == int(debug["Address"], base=16)
+        # One frame per inlining level; more than one frame means the branch
+        # was inlined. The first frame is used for the reported location.
+        frames = debug["Symbol"]
+        symbol = frames[0]
+
+        # NOTE(review): proc is documented as the function after inlining,
+        # which may be the last (outermost) frame; confirm which is intended.
+        jump.annotation.proc = symbol["FunctionName"]
+        jump.annotation.dloc = DebugLoc(
+            file=symbol["FileName"],
+            line=symbol["Line"],
+            col=symbol["Column"],
+            discriminator=symbol["Discriminator"],
+            inlined=int(len(frames) > 1),
+        )
+
+
+def collect_stats_for_x86_annotated_jumps(jumps: List[Jump], trace_file: str):
+    """
+    Collects branch data from the custom Pin tool and associates each entry
+    with its instruction in the binary. Performed in place on `jumps`.
+
+    The first trace line holds the runtime address range of the main
+    executable; every other line holds one branch's runtime address together
+    with its execution and taken counts.
+    """
+
+    # Set up map of PC to Jump data for faster lookup
+    pc_jump_map: Dict[int, Jump] = {j.pc: j for j in jumps}
+
+    with open(trace_file, "r") as f:
+        lines: List[str] = [line.strip() for line in f]
+    assert lines, f"Trace file `{trace_file}` is empty"
+
+    # The Pin tool prints addresses with std::showbase/std::hex and no zero
+    # padding, so the digit count varies (e.g. 0x400000 for a non-PIE binary).
+    header: Optional[re.Match] = re.match(
+        r"^Executable Address: 0x([0-9a-f]+) - 0x([0-9a-f]+)$", lines[0]
+    )
+    assert header, f"Unexpected trace header in `{trace_file}`: {lines[0]!r}"
+    low_addr: int = int(header.group(1), base=16)
+    high_addr: int = int(header.group(2), base=16)
+
+    pattern: re.Pattern = re.compile(
+        r"^0x([0-9a-f]+) => branch count: (\d+) => taken count: (\d+)$"
+    )
+    encountered_annotated_jumps: bool = False
+    for line in lines[1:]:
+        if not line:
+            continue
+        match: Optional[re.Match] = pattern.match(line)
+        assert match, f"Unexpected trace line in `{trace_file}`: {line!r}"
+
+        address: int = int(match.group(1), base=16)
+        total_count: int = int(match.group(2), base=10)
+        taken_count: int = int(match.group(3), base=10)
+
+        if address < low_addr or address > high_addr:
+            continue
+
+        # Convert the runtime address to an address in the disassembly.
+        # NOTE(review): assumes the image's lowest virtual address is 0 (as is
+        # typical for PIE); a non-PIE executable runs at its link address.
+        address = address - low_addr
+        if address not in pc_jump_map:
+            continue
+
+        jump: Jump = pc_jump_map[address]
+        jump.stats.total += total_count
+        jump.stats.taken += taken_count
+
+        encountered_annotated_jumps = True
+
+    assert encountered_annotated_jumps, "Never encountered annotated jump in branch trace"
+
+
+def data_from_program(
+    *,
+    program: str,
+    trace_file: str,
+    objdump: str,
+    csv_name: Optional[str],
+    no_save: bool,
+) -> pd.DataFrame:
+    """Generate branch dataframe then saves artifacts to current directory"""
+
+    if ISA in ["x86_64", "x86_32"]:
+        jumps: List[Jump] = collect_fragment_jump_annotations(program, objdump)
+        assert len(jumps) != 0, "Did not find annotations in objdump"
+
+        collect_debug_info(jumps, program, objdump)
+        collect_stats_for_x86_annotated_jumps(jumps, trace_file)
+    else:
+        assert False, "Unknown ISA should be unreachable"
+
+    # Convert to one dataframe containing all jumps, indexed by pc. Building
+    # it from flat records in one call avoids creating a frame per jump.
+    rows: List[Dict[str, Any]] = [j.to_flat_dict() for j in jumps]
+    pcs: List[int] = [row.pop("pc") for row in rows]
+    jumps_df: pd.DataFrame = pd.DataFrame(rows, index=pcs)
+
+    if not no_save:
+        # Save the per-branch data as a CSV file in the current directory
+        csv_name: str = csv_name if csv_name else Path(program).stem + ".branch_data.csv"
+        jumps_df.to_csv(csv_name, index_label="pc")
+
+    return jumps_df
+
+
+def print_summary(jumps_df: pd.DataFrame, single_stat: Optional[str]):
+    """Prints overall accuracy statistics for the branches in `jumps_df`.
+
+    Prints only the value of `single_stat` (one of STAT_CHOICES) when given,
+    otherwise every statistic as a "name = value" line.
+    """
+    encountered_jumps_df: pd.DataFrame = jumps_df.loc[jumps_df["total"] != 0]
+    match_df: pd.DataFrame = jumps_df.loc[jumps_df["direction_match"] & (jumps_df["total"] != 0)]
+
+    total_branches: int = jumps_df["total"].sum()
+    total_unique: int = len(jumps_df)
+    total_unique_encountered: int = len(encountered_jumps_df)
+
+    prob_diff_error: float = jumps_df["prob_diff_weighted"].sum() / float(total_branches)
+    prob_diff_error_unique: float = encountered_jumps_df["prob_diff"].sum() / float(
+        total_unique_encountered
+    )
+
+    match_total: int = match_df["total"].sum()
+    match_unique: int = len(match_df)
+    match_percent: float = float(match_total) / float(total_branches)
+    match_percent_unique: float = float(match_unique) / float(total_unique_encountered)
+
+    if single_stat:
+        stat = {
+            "total_branches": total_branches,
+            "total_unique": total_unique,
+            "total_unique_encountered": total_unique_encountered,
+            "prob_diff_error": prob_diff_error,
+            "prob_diff_error_unique": prob_diff_error_unique,
+            "match_total": match_total,
+            "match_unique": match_unique,
+            "match_percent": match_percent,
+            "match_percent_unique": match_percent_unique,
+        }[single_stat]
+        print(stat)
+    else:
+        for name, val in {
+            "Total Dynamic Branches": total_branches,
+            "Total Unique Branches": total_unique,
+            "Total Unique Branches Executed": total_unique_encountered,
+            "Probability Difference Error Dynamic": prob_diff_error,
+            "Probability Difference Error Unique": prob_diff_error_unique,
+            "Direction Match Total Dynamic": match_total,
+            "Direction Match Total Unique Encountered": match_unique,
+            "Direction Match Percent Dynamic": match_percent,
+            "Direction Match Percent Unique Encountered": match_percent_unique,
+        }.items():
+            print(f"{name} = {val}")
+
+
+def add_parser(argp: argparse.ArgumentParser):
+    """Add subparsers and arguments used in this module"""
+
+    # Argument shared with all modes
+    argp.add_argument(
+        "program",
+        metavar="<program>",
+        type=str,
+        help="program executable",
+    )
+    argp.add_argument(
+        "trace_file",
+        metavar="<trace>",
+        type=str,
+        help="Trace file that came from running the program executable",
+    )
+    argp.add_argument(
+        "--csv-name",
+        type=str,
+        default=None,
+        help="Name to use for csv dump (default uses name of program)",
+    )
+
+    objdump_path: str = "llvm-objdump"
+    argp.add_argument(
+        "--objdump",
+        metavar="<path>",
+        type=str,
+        help="Path to llvm-objdump executable",
+        default=objdump_path,
+    )
+
+    argp.add_argument(
+        "--single-stat",
+        type=str,
+        default=None,
+        choices=STAT_CHOICES,
+        help="Print just the value of a single overall stat",
+    )
+    argp.add_argument(
+        "--no-save",
+        action="store_true",
+        help="Do not save individual branch data to CSV",
+    )
+
+    isa_choices = [
+        "x86_64",
+        "x86_32",
+    ]
+    argp.add_argument(
+        "--isa",
+        metavar="<isa>",
+        type=str,
+        choices=isa_choices,
+        default=isa_choices[0],
+        help="The target ISA in the executable",
+    )
+
+
+def parse_arguments(
+    argp: argparse.ArgumentParser, override_args: Optional[List[str]] = None
+) -> argparse.Namespace:
+    """Parse program arguments and error on issues"""
+
+    args: argparse.Namespace = argp.parse_args(override_args)
+
+    global ISA
+    ISA = args.isa
+
+    # Make sure the provided llvm-objdump can disassemble the requested ISA.
+    if "objdump" in args and not re.search(
+        {
+            "x86_64": r"x86-64",
+            "x86_32": r"x86 ",
+        }[ISA],
+        subprocess.check_output([args.objdump, "--version"]).decode("utf-8"),
+    ):
+        argp.error(f"Provided llvm-objdump ({args.objdump}) does not have the {ISA} target")
+
+    return args
+
+
+def run_branch_accuracy_data(args: argparse.Namespace):
+    """Extract args then run the data collection"""
+
+    program: str = args.program
+    trace_file: str = args.trace_file
+    objdump: str = args.objdump
+    csv_name: str = args.csv_name
+    no_save: bool = args.no_save
+    jumps_df: pd.DataFrame = data_from_program(
+        program=program,
+        trace_file=trace_file,
+        objdump=objdump,
+        csv_name=csv_name,
+        no_save=no_save,
+    )
+
+    single_stat: Optional[str] = args.single_stat
+    print_summary(jumps_df, single_stat)
+
+
+def main() -> int:
+    argp = argparse.ArgumentParser(description=__doc__)
+    add_parser(argp)
+    args: argparse.Namespace = parse_arguments(argp)
+    run_branch_accuracy_data(args)
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())