Cast the 'mask' input to tf.int32 before passing it to tf.math.argmax in
action() of gen-regalloc-eviction-test-model.py.

NOTE(review): this patch was recovered from a copy whose newlines had been
collapsed to spaces. Line breaks, the blank context line before
`def action`, and leading indentation (2 spaces per level assumed) were
reconstructed to match the 7-line hunk header -- confirm against the target
file before applying.

diff --git a/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py b/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
--- a/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
+++ b/llvm/lib/Analysis/models/gen-regalloc-eviction-test-model.py
@@ -46,7 +46,7 @@
   module.var = tf.Variable(0, dtype=tf.int64)
 
   def action(*inputs):
-    result = tf.math.argmax(inputs[0]['mask'], axis=-1) + module.var
+    result = tf.math.argmax(tf.cast(inputs[0]['mask'], tf.int32), axis=-1) + module.var
     return {POLICY_DECISION_LABEL: result}
   module.action = tf.function()(action)
   action = {