validate_ensemble(args, epoch): #83

Open
shasso2s opened this issue Oct 8, 2022 · 1 comment
Labels
question Further information is requested

Comments


shasso2s commented Oct 8, 2022

Hello, I don't understand the method validate_ensemble(args, epoch). Can you please explain it to me?

@YuanGongND
Owner

Hi there,

I assume you mean

ast/src/traintest.py

Lines 311 to 327 in 3f53567

def validate_ensemble(args, epoch):
    exp_dir = args.exp_dir
    target = np.loadtxt(exp_dir+'/predictions/target.csv', delimiter=',')
    if epoch == 1:
        cum_predictions = np.loadtxt(exp_dir + '/predictions/predictions_1.csv', delimiter=',')
    else:
        cum_predictions = np.loadtxt(exp_dir + '/predictions/cum_predictions.csv', delimiter=',') * (epoch - 1)
        predictions = np.loadtxt(exp_dir+'/predictions/predictions_' + str(epoch) + '.csv', delimiter=',')
        cum_predictions = cum_predictions + predictions
        # remove the prediction file to save storage space
        os.remove(exp_dir+'/predictions/predictions_' + str(epoch-1) + '.csv')
    cum_predictions = cum_predictions / epoch
    np.savetxt(exp_dir+'/predictions/cum_predictions.csv', cum_predictions, delimiter=',')
    stats = calculate_stats(cum_predictions, target)
    return stats

Please refer to this paper, Section VI.B.1 (Checkpoint Averaging). This function implements the algorithm described there: 1) save a model checkpoint at every epoch (or starting from a specific epoch), 2) use each checkpoint to run inference on the data, and 3) average the predictions of all checkpoint models. For this, you don't need to train multiple times; you train once and ensemble the checkpoint models. Note: although this saves training cost compared with a normal ensemble, it performs worse. See the paper for details.

-Yuan
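
The `(epoch - 1)` multiplication in the quoted function can be read as a running-mean update: the stored average is scaled back to a sum, the new epoch's predictions are added, and the result is divided by the new epoch count. Below is a minimal in-memory sketch of that update, not the repository's code. The helper name `update_running_mean`, the random data, and the array shape are assumptions made for illustration, and the CSV reading and writing is left out.

```python
import numpy as np


def update_running_mean(prev_mean, new_pred, epoch):
    """Return the mean of predictions over epochs 1..epoch.

    prev_mean is the mean over epochs 1..epoch-1 (ignored when epoch == 1).
    Hypothetical helper, written only to illustrate the update rule.
    """
    if epoch == 1:
        return new_pred.copy()
    # Undo the previous division, add the new predictions, divide again.
    return (prev_mean * (epoch - 1) + new_pred) / epoch


# Hypothetical predictions from 3 checkpoints for 4 clips and 5 classes.
rng = np.random.default_rng(0)
per_epoch_preds = [rng.random((4, 5)) for _ in range(3)]

running_mean = None
for epoch, pred in enumerate(per_epoch_preds, start=1):
    running_mean = update_running_mean(running_mean, pred, epoch)

# The incremental result should agree with averaging all epochs at once.
direct_mean = np.mean(per_epoch_preds, axis=0)
matches = np.allclose(running_mean, direct_mean)
print(matches)
```

Keeping only the running mean is why the quoted function can delete the previous epoch's prediction file: earlier predictions are already folded into `cum_predictions.csv`.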

YuanGongND added the bug and question labels and removed the bug label on Oct 9, 2022