Large deep learning models rely on heavy computations. However, not
every computation is necessary. And, even when a computation is
necessary, it helps if the values needed for the computation are
available in registers (which have low latency) rather than being in
memory (which has high latency).
Compilers can use liveness analysis to:
(1) Remove extraneous computations from a program before it executes on
hardware, and,
(2) Optimize register allocation.
Both these tasks help achieve one very important goal: reducing runtime.
Liveness analysis was recently added to MLIR. This commit uses that
utility to accomplish task (1).
It adds a pass called remove-dead-values whose goal is
optimization (reducing runtime) by removing unnecessary instructions.
Unlike other passes that rely on local information gathered from
patterns to accomplish optimization, this pass uses a full analysis of
the IR, specifically, liveness analysis, and is thus more powerful.
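To make the idea concrete, here is a minimal sketch of backward liveness on a toy straight-line program. The `Op` class and the helper names are hypothetical and exist only for illustration; this is not MLIR's liveness utility or its API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    """A toy op (hypothetical): defines `result` from `operands`."""
    result: str
    operands: tuple
    has_side_effect: bool = False  # e.g. a store to memory


def live_values(ops, returned):
    """A value is live if it is returned, feeds a side-effecting op,
    or feeds an op whose result is live. One backward sweep suffices
    here because every definition precedes its uses."""
    live = set(returned)
    for op in reversed(ops):
        if op.has_side_effect or op.result in live:
            live.update(op.operands)
    return live


def remove_non_live(ops, returned):
    """Keep only ops that have a side effect or a live result."""
    live = live_values(ops, returned)
    return [op for op in ops if op.has_side_effect or op.result in live]


ops = [
    Op("square", ("y",)),
    Op("double", ("y",)),
    Op("unused", ("square",)),
]
kept = remove_non_live(ops, returned=["double"])
print([op.result for op in kept])
```

A local pattern that only looks at `square` sees a use (by `unused`) and would keep it; the whole-program sweep sees that nothing downstream of `square` is ever needed.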
Currently, this pass performs the following optimizations:
(A) Removes function arguments that are not live,
(B) Removes function return values that are not live across all callers of
the function,
(C) Removes unnecessary operands, results, region arguments, and region
terminator operands of region branch ops, and,
(D) Removes simple and region branch ops that have all non-live results and
don't affect memory in any way,
iff
the IR doesn't have any non-function symbol ops, non-call symbol user ops
and branch ops.
Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
region branch op, branch op, region branch terminator op, or return-like.
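The precondition and the "simple op" definition above can be read as two predicates. The sketch below encodes them over hypothetical boolean flags; an actual implementation would query op interfaces and traits instead, and the flag names here are assumptions made for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class OpKind:
    """Hypothetical per-op flags standing in for interface queries."""
    is_symbol: bool = False
    is_function: bool = False
    is_symbol_user: bool = False
    is_call: bool = False
    is_region_branch: bool = False
    is_branch: bool = False  # a branch op, as distinct from a region branch op
    is_region_branch_terminator: bool = False
    is_return_like: bool = False


def pass_applies(ops):
    """True iff no op violates the precondition stated above."""
    for op in ops:
        if op.is_symbol and not op.is_function:
            return False  # non-function symbol op
        if op.is_symbol_user and not op.is_call:
            return False  # non-call symbol user op
        if op.is_branch:
            return False  # branch op
    return True


def is_simple_op(op):
    """A simple op is none of the special kinds listed above."""
    return not (op.is_symbol or op.is_symbol_user or op.is_region_branch
                or op.is_branch or op.is_region_branch_terminator
                or op.is_return_like)


func = OpKind(is_symbol=True, is_function=True)
call = OpKind(is_symbol_user=True, is_call=True)
add = OpKind()
branch = OpKind(is_branch=True)
print(pass_applies([func, call, add]), pass_applies([func, branch]))
```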
It is noteworthy that we do not refer to non-live values as "dead" in this
file to avoid confusing it with dead code analysis's "dead", which refers to
unreachable code (code that never executes on hardware) while "non-live"
refers to code that executes on hardware but is unnecessary. Thus, while the
removal of dead code helps little in reducing runtime, removing non-live
values should theoretically have significant impact (depending on the amount
removed).
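The distinction can be illustrated with a small Python sketch. The `step` helper is hypothetical and only records which computations actually run.

```python
executed = []


def step(name, value):
    """Record that a computation ran, then pass its value through."""
    executed.append(name)
    return value


def square_and_double(y):
    square = step("square", y * y)  # non-live: runs, but its result is never used
    double = step("double", y * 2)  # live: its result is returned
    return double
    step("unreachable", 0)          # dead: placed after the return, never runs


result = square_and_double(5)
print(result, executed)
```

The non-live `square` costs runtime on every call, while the dead statement costs none, which is why removing non-live values is the more promising of the two for reducing runtime.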
It is also important to note that unlike other passes (like canonicalize)
that apply op-specific optimizations through patterns, this pass uses
different interfaces to handle various types of ops and tries to cover all
existing ops through these interfaces.
Its reliance on (a) liveness analysis and (b) interfaces is what makes the
pass powerful enough to optimize ops that don't have a canonicalizer. Even
when an op does have a canonicalizer, the pass can perform more aggressive
optimizations, as observed in the test files associated with this pass.
Example of optimization (A):

    int add_2_to_y(int x, int y) {
      return 2 + y
    }

    print(add_2_to_y(3, 4))
    print(add_2_to_y(5, 6))

becomes

    int add_2_to_y(int y) {
      return 2 + y
    }

    print(add_2_to_y(4))
    print(add_2_to_y(6))
Example of optimization (B):

    int, int get_incremented_values(int y) {
      store y somewhere in memory
      return y + 1, y + 2
    }

    y1, y2 = get_incremented_values(4)
    y3, y4 = get_incremented_values(6)
    print(y2)

becomes

    int get_incremented_values(int y) {
      store y somewhere in memory
      return y + 2
    }

    y2 = get_incremented_values(4)
    y4 = get_incremented_values(6)
    print(y2)
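The "across all callers" condition in (B) means a return value can be dropped only if no call site uses it. A minimal sketch of that check follows, with a hypothetical helper name and with each call site represented simply as the set of result indices it uses:

```python
def removable_return_indices(num_results, live_indices_per_call_site):
    """Return the result positions that are non-live at every call site.

    `live_indices_per_call_site` holds one set per call site, containing
    the result indices that call site actually uses (an assumed encoding).
    """
    return [
        i for i in range(num_results)
        if not any(i in live for live in live_indices_per_call_site)
    ]


# get_incremented_values has two results. The first call site uses only
# its second result (y2); the second call site uses neither result.
call_sites = [{1}, set()]
print(removable_return_indices(2, call_sites))
```

If a third call site used the first result, index 0 would no longer be removable, even though the other two call sites ignore it.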
Example of optimization (C):

Assume only %result1 is live here. Then,

    %result1, %result2, %result3 = scf.while (%arg1 = %operand1, %arg2 = %operand2) {
      %terminator_operand2 = add %arg2, %arg2
      %terminator_operand3 = mul %arg2, %arg2
      %terminator_operand4 = add %arg1, %arg1
      scf.condition(%terminator_operand1) %terminator_operand2, %terminator_operand3, %terminator_operand4
    } do {
    ^bb0(%arg3, %arg4, %arg5):
      %terminator_operand6 = add %arg4, %arg4
      %terminator_operand5 = add %arg5, %arg5
      scf.yield %terminator_operand5, %terminator_operand6
    }

becomes

    %result1, %result2 = scf.while (%arg2 = %operand2) {
      %terminator_operand2 = add %arg2, %arg2
      %terminator_operand3 = mul %arg2, %arg2
      scf.condition(%terminator_operand1) %terminator_operand2, %terminator_operand3
    } do {
    ^bb0(%arg3, %arg4):
      %terminator_operand6 = add %arg4, %arg4
      scf.yield %terminator_operand6
    }
Note that %result2 is not removed even though it is not live:
%terminator_operand3 forwards to it and cannot be removed, because
%terminator_operand3 also forwards to %arg4, which is live.
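The reasoning for the scf.while example above can be traced with a small worklist sketch. The dependency map is written by hand from that example and is an assumption about how liveness flows through it; it is a toy model, not the data structures the pass uses.

```python
def propagate_liveness(depends_on, roots):
    """If a value is live, every value it depends on is live too."""
    live, worklist = set(), list(roots)
    while worklist:
        value = worklist.pop()
        if value in live:
            continue
        live.add(value)
        worklist.extend(depends_on.get(value, ()))
    return live


# value -> values it is computed from or forwarded from (hand-built)
depends_on = {
    "result1": ["terminator_operand2"],
    "result2": ["terminator_operand3"],
    "result3": ["terminator_operand4"],
    "arg3": ["terminator_operand2"],
    "arg4": ["terminator_operand3"],
    "arg5": ["terminator_operand4"],
    "terminator_operand2": ["arg2"],
    "terminator_operand3": ["arg2"],
    "terminator_operand4": ["arg1"],
    "terminator_operand5": ["arg5"],
    "terminator_operand6": ["arg4"],
    "arg1": ["operand1", "terminator_operand5"],
    "arg2": ["operand2", "terminator_operand6"],
}

live = propagate_liveness(depends_on, roots=["result1"])

# Each scf.condition operand forwards to one result and one "do" argument.
# A position is dropped only when its forwarded operand is non-live.
positions = [
    ("terminator_operand2", "result1", "arg3"),
    ("terminator_operand3", "result2", "arg4"),
    ("terminator_operand4", "result3", "arg5"),
]
kept = [p for p in positions if p[0] in live]
print([result for _, result, _ in kept])
```

In this model `result2` itself is never marked live, yet its position survives because `terminator_operand3` is kept alive through `arg4`.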
Example of optimization (D):

    int, int square_and_double_of_y(int y) {
      square = y ^ 2
      double = y * 2
      return square, double
    }

    sq, do = square_and_double_of_y(5)
    print(do)

becomes

    int square_and_double_of_y(int y) {
      double = y * 2
      return double
    }

    do = square_and_double_of_y(5)
    print(do)
Signed-off-by: Srishti Srivastava <srishtisrivastava.ai@gmail.com>
Please list the required invariants for the pass to succeed on the IR. There are invariants enforced in the pass definition but they are not listed here.