from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
prompt_text = "This is a nice story that makes me"
max_gen_len = 9
input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
prompt_len = input_ids.shape[-1]
print(f'length of prompt: {prompt_len}, length of generation: {max_gen_len}')
print('>>> Way 1: Use `model.generate()` to generate tokens with KV cache')
generated_ids = model.generate(input_ids, max_length=prompt_len+max_gen_len, pad_token_id=tokenizer.eos_token_id)
print('generated_ids:', generated_ids)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print('generated_text:', generated_text)
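# Keep only the newly generated ids from Way 1 (generate() returns the prompt
# ids followed by the continuation), so the result can be compared with Way 2
# at the end of the script.
way1_new_ids = generated_ids[0, prompt_len:].tolist()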
print('>>> Way 2: Use `for loop` to generate tokens with KV cache')
past_key_values = None  # the KV cache starts out empty
print('Prefill Stage..')
# Run the whole prompt through the model in one forward pass; with
# use_cache=True the returned cache holds the key/value tensors of
# every prompt token for reuse during decoding.
outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
past_key_values = outputs.past_key_values
# Greedy decoding: take the most likely token at the last position.
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
generated_ids = [pred_token_idx.item()]
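# Optional peek at the cache after prefill (a minimal sketch; the exact cache
# type depends on the transformers version: a Cache object in recent releases,
# a tuple of per-layer (key, value) pairs in older ones; indexing layer 0
# works for both). The key tensor should have shape
# [batch, num_heads, prompt_len, head_dim].
k0, v0 = past_key_values[0]
print('layer-0 key cache shape after prefill:', tuple(k0.shape))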
print('Decoding/Generating Stage..')
for _ in range(max_gen_len - 1):
    # Feed only the newest token; keys/values of all earlier tokens
    # come from the cache instead of being recomputed.
    outputs = model(input_ids=pred_token_idx, past_key_values=past_key_values, use_cache=True)
    past_key_values = outputs.past_key_values  # with use_cache=False this would be None
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids.append(pred_token_idx.item())
print('generated_ids:', generated_ids)
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)  # generated_ids is already a list of ints; decode it directly
print('generated_text:', prompt_text + generated_text)
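print('>>> Sanity check (a minimal sketch): both ways agree, with or without cache')
print('Way 2 matches Way 1:', generated_ids == way1_new_ids)
# Baseline without a KV cache: every step re-feeds the full sequence, so the
# model recomputes keys/values for all earlier tokens at each iteration. The
# cache used above skips exactly that recomputation; under greedy decoding the
# generated tokens are identical either way.
no_cache_ids = input_ids
with torch.no_grad():
    for _ in range(max_gen_len):
        outputs = model(input_ids=no_cache_ids, use_cache=False)
        next_id = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        no_cache_ids = torch.cat([no_cache_ids, next_id], dim=-1)
print('no-cache greedy matches cached decoding:', no_cache_ids[0, prompt_len:].tolist() == generated_ids)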