web123456

Basic large-model text generation with the Transformers library, and precautions to take when using the KV cache

# Demo: greedy text generation with GPT-2, done two ways.
#   Way 1 - a single `model.generate()` call.
#   Way 2 - a hand-written loop: one "prefill" forward pass over the whole
#           prompt, then one forward pass per new token that feeds only the
#           most recent token together with the cached keys/values.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# The EOS id is passed as the pad id here and again in the generate() call below.
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)

prompt_text = "This is a nice story that makes me"
max_gen_len = 9  # number of new tokens to produce after the prompt

input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
prompt_len = input_ids.shape[-1]
print(f'length of prompt: {prompt_len}, length of generation: {max_gen_len}')

print('>>> Way 1: Use `model.generate()` to generate tokens with KV cache')
# max_length counts prompt tokens too, hence prompt_len + max_gen_len.
generated_ids = model.generate(
    input_ids,
    max_length=prompt_len + max_gen_len,
    pad_token_id=tokenizer.eos_token_id,
)
print('generated_ids:', generated_ids)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print('generated_text:', generated_text)

print('>>> Way 2: Use `for loop` to generate tokens with KV cache')
# Inference only: no gradients are needed, so keep autograd from recording
# the forward passes (and the cached key/value tensors they produce).
with torch.no_grad():
    print('Prefill Stage..')
    # First pass sees the full prompt and has no cache yet.
    outputs = model(input_ids=input_ids, past_key_values=None, use_cache=True)
    past_key_values = outputs.past_key_values
    # Greedy choice: highest-logit token at the last position, kept as a
    # (batch, 1) tensor so it can be fed straight back in as input_ids.
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    # .item() requires a single element, i.e. this loop assumes batch size 1.
    generated_ids = [pred_token_idx.item()]

    print('Decoding/Generating Stage..')
    for _ in range(max_gen_len - 1):
        # Stop once end-of-text has been produced so nothing is emitted after it.
        if generated_ids[-1] == tokenizer.eos_token_id:
            break
        # Only the newest token is fed; earlier context comes from the cache.
        outputs = model(
            input_ids=pred_token_idx,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values  # if use_cache=False, past_key_values=None
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids.append(pred_token_idx.item())

print('generated_ids:', generated_ids)
# Decode the plain list of int ids; wrapping it in torch.Tensor(...) would
# convert the ids to floats first.
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
print('generated_text:', prompt_text + generated_text)