This issue is based on the notebook here. I was comparing next-token prediction (NTP) in student mode, enabled with `set_transformer_early_exit_mode(model, 'sft_student')`, against manual prediction that handles exiting and freezing with KV caching.
- When there were no early exits, the NTP was identical. This is expected, and I was hoping the match would hold for early exits in general.
- When I force an early exit on a single token, the student-mode NTP is gibberish for later tokens, whereas the manual prediction gives sensible predictions that are only slightly different from the base model's.
Maybe some unexpected behaviour is going on in `patched_layer_forward()` and `patched_attention_forward()`?
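For reference, here is a minimal toy sketch (NumPy, not the project's code) of the consistency property being tested. It assumes one particular exit/freeze rule: a token that exits at layer `e` keeps its hidden state unchanged for all layers from `e` on, but its K/V at those layers are still computed from that frozen state so later tokens can attend to it. `forward_parallel` stands in for a whole-sequence pass with per-token exits (roughly what a patched student mode would do), and `forward_incremental` stands in for manual token-by-token decoding with a KV cache. All names here are hypothetical, and the real `patched_layer_forward()` may use different semantics.

```python
import numpy as np

# Toy model only: these names are hypothetical and are NOT the project's
# set_transformer_early_exit_mode / patched_* functions.

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def layer_update(h, attn, w_mlp):
    # Residual attention output followed by a residual tanh "MLP".
    out = h + attn
    return out + np.tanh(out @ w_mlp)

def forward_parallel(x, params, exit_layers):
    """Whole sequence at once; token i is frozen from layer exit_layers[i] on."""
    h = x.copy()
    T, d = h.shape
    causal = np.tril(np.ones((T, T), dtype=bool))
    for l, (wq, wk, wv, wm) in enumerate(params):
        # K/V of exited tokens are computed from their frozen hidden state.
        q, k, v = h @ wq, h @ wk, h @ wv
        scores = np.where(causal, q @ k.T / np.sqrt(d), -np.inf)
        out = layer_update(h, softmax(scores) @ v, wm)
        active = (exit_layers > l)[:, None]  # tokens still running at layer l
        h = np.where(active, out, h)  # exited tokens keep their old state
    return h

def forward_incremental(x, params, exit_layers):
    """Token by token with a per-layer KV cache (assumed exit/freeze rule)."""
    T, d = x.shape
    k_cache = [[] for _ in params]
    v_cache = [[] for _ in params]
    outs = []
    for t in range(T):
        h = x[t]
        for l, (wq, wk, wv, wm) in enumerate(params):
            # Always cache K/V, even for a frozen token, so later tokens see it.
            k_cache[l].append(h @ wk)
            v_cache[l].append(h @ wv)
            if l >= exit_layers[t]:
                continue  # frozen: hidden state carried through unchanged
            K, V = np.stack(k_cache[l]), np.stack(v_cache[l])
            attn = softmax((h @ wq) @ K.T / np.sqrt(d)) @ V
            h = layer_update(h, attn, wm)
        outs.append(h)
    return np.stack(outs)

rng = np.random.default_rng(0)
d, n_layers, T, vocab = 16, 4, 6, 32
params = [tuple(rng.normal(scale=d ** -0.5, size=(d, d)) for _ in range(4))
          for _ in range(n_layers)]
w_out = rng.normal(scale=d ** -0.5, size=(d, vocab))  # toy unembedding
x = rng.normal(size=(T, d))

no_exit = np.full(T, n_layers)
one_exit = no_exit.copy()
one_exit[2] = 1  # force token 2 to exit after layer 0

for name, exits in [("no exit", no_exit), ("one exit", one_exit)]:
    hp = forward_parallel(x, params, exits)
    hi = forward_incremental(x, params, exits)
    same_ntp = np.array_equal((hp @ w_out).argmax(-1), (hi @ w_out).argmax(-1))
    print(name, np.allclose(hp, hi), same_ntp)
```

Under this assumed rule the two paths should agree with and without the forced exit. If a check like this holds for the toy but not for the real model, comparing how each path builds K/V for the exited token at later layers seems like a reasonable place to start.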