Skip to content

Commit

Permalink
added t5 eval
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brady committed Apr 12, 2024
1 parent 73de285 commit 62d67b5
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions ft_model_t5_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ def get_completion_merged(input_text: str, output_text: str, model, tokenizer) -
encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model_inputs = encodeds.to(device)
generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True, pad_token_id=tokenizer.eos_token_id)
#decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# Get prompt length to remove the prepended prompt from the output
prompt_length = model_inputs.shape[1]
print(prompt_length)
decoded = tokenizer.batch_decode(generated_ids[0][prompt_length:], skip_special_tokens=True)
# TODO: Trim the prompt tokens from the decoded output so that only the newly
# generated tokens are returned — idea sketched below (vvv).
# NOTE(review): the commented-out attempt fails because the tokenizer returns a
# BatchEncoding, which has no .shape; use model_inputs["input_ids"].shape[1]
# (and slice generated_ids[0] before decoding with tokenizer.decode) — confirm.
#prompt_length = model_inputs.shape[1]
#print(prompt_length)
#decoded = tokenizer.batch_decode(generated_ids[0][prompt_length:], skip_special_tokens=True)
return decoded[0]


Expand Down

0 comments on commit 62d67b5

Please sign in to comment.