
Commit dee1207
it's not wandb. Try disable all self.log
mwalmsley committed Nov 21, 2023
1 parent 3147b93 commit dee1207
Showing 2 changed files with 14 additions and 17 deletions.
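
This commit is a bisection step: with wandb ruled out as the cause of the problem, every Lightning self.log call is disabled (the logging helpers fall through to pass) to test whether the logging path itself is responsible. The same pattern as a minimal runnable sketch with a toy module (the names below are illustrative, not taken from zoobot):

import torch
import pytorch_lightning as pl

# Illustrative stand-in, not repo code: a LightningModule whose every
# self.log call is commented out, so any remaining hang or error cannot
# be coming from Lightning's logging machinery.
class SilentModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        # self.log('train/epoch_loss', loss, on_epoch=True, on_step=False)  # disabled while debugging
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())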
zoobot/pytorch/estimators/define_model.py (17 changes: 7 additions & 10 deletions)
@@ -239,24 +239,21 @@ def configure_optimizers(self):
 
 
     def log_outputs(self, outputs, step_name):
-        self.log("{}/epoch_loss".format(step_name), outputs['loss'], on_epoch=True, on_step=False,prog_bar=True, logger=True, rank_zero_only=True)
-        # if self.log_on_step:
-        #     # seperate call to allow for different name, to allow for consistency with TF.keras auto-names
+        # self.log("{}/epoch_loss".format(step_name), outputs['loss'], on_epoch=True, on_step=False,prog_bar=True, logger=True, rank_zero_only=True)
+        # if outputs['predictions'].shape[1] == 2: # will only do for binary classifications
         #     self.log(
-        #         "{}/step_loss".format(step_name), outputs['loss'], on_epoch=False, on_step=True, prog_bar=True, logger=True, rank_zero_only=True)
-        if outputs['predictions'].shape[1] == 2: # will only do for binary classifications
-            # logging.info(predictions.shape, labels.shape)
-            self.log(
-                "{}_accuracy".format(step_name), self.train_accuracy(outputs['predictions'], torch.argmax(outputs['labels'], dim=1, keepdim=False)), prog_bar=True, rank_zero_only=True)
+        #         "{}_accuracy".format(step_name), self.train_accuracy(outputs['predictions'], torch.argmax(outputs['labels'], dim=1, keepdim=False)), prog_bar=True, rank_zero_only=True)
+        pass
 
 
     def log_loss_per_question(self, multiq_loss, prefix):
         # log questions individually
         # TODO need schema attribute or similar to have access to question names, this will do for now
         # unlike Finetuneable..., does not use TorchMetrics, simply logs directly
         # TODO could use TorchMetrics and for q in schema, self.q_metric loop
-        for question_n in range(multiq_loss.shape[1]):
-            self.log(f'{prefix}/epoch_questions/question_{question_n}_loss:0', torch.mean(multiq_loss[:, question_n]), on_epoch=True, on_step=False, rank_zero_only=True)
+        # for question_n in range(multiq_loss.shape[1]):
+        #     self.log(f'{prefix}/epoch_questions/question_{question_n}_loss:0', torch.mean(multiq_loss[:, question_n]), on_epoch=True, on_step=False, rank_zero_only=True)
+        pass



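After this hunk, log_outputs and log_loss_per_question are both no-ops: every statement in each body is commented out and the methods fall through to pass, so whatever calls them runs unchanged but records nothing. A lighter-touch alternative, sketched under the assumption of a hypothetical enable_logging flag that zoobot does not actually have, would guard the calls rather than comment them out:

# Hypothetical variant, not in the repo: gate logging behind a constructor
# flag so the experiment can be toggled without editing method bodies.
def log_outputs(self, outputs, step_name):
    if not getattr(self, 'enable_logging', False):  # assumed attribute
        return
    self.log("{}/epoch_loss".format(step_name), outputs['loss'],
             on_epoch=True, on_step=False, prog_bar=True, logger=True)
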
zoobot/pytorch/training/train_with_pytorch_lightning.py (14 changes: 7 additions & 7 deletions)
@@ -275,9 +275,9 @@ def train_default_zoobot_from_scratch(
         save_top_k=save_top_k
     )
 
-    early_stopping_callback = EarlyStopping(monitor='validation/epoch_loss', patience=patience, check_finite=True)
-
-    callbacks = [checkpoint_callback, early_stopping_callback] + extra_callbacks
+    # early_stopping_callback = EarlyStopping(monitor='validation/epoch_loss', patience=patience, check_finite=True)
+    # , early_stopping_callback
+    callbacks = [checkpoint_callback] + extra_callbacks
 
     trainer = pl.Trainer(
         log_every_n_steps=150,  # at batch 512 (A100 MP max), DR5 has ~161 train steps
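
Commenting out EarlyStopping here is forced by the first file: the callback monitors the logged 'validation/epoch_loss' key, and with every self.log call disabled that metric is never recorded, so Lightning would raise an error about conditioning on an unavailable metric. The dependency in sketch form (standard Lightning behaviour, values illustrative):

from pytorch_lightning.callbacks import EarlyStopping

# EarlyStopping can only watch a key that some self.log call produces.
# Once logging is disabled, 'validation/epoch_loss' never exists, so the
# callback has to go too, exactly as this hunk does.
early_stopping_callback = EarlyStopping(
    monitor='validation/epoch_loss',
    patience=8,  # illustrative; the repo passes a patience argument
    check_finite=True,
)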
@@ -290,12 +290,12 @@ def train_default_zoobot_from_scratch(
         callbacks=callbacks,
         max_epochs=epochs,
         default_root_dir=save_dir,
-        plugins=plugins,
-        use_distributed_sampler=use_distributed_sampler
+        plugins=plugins
+        # use_distributed_sampler=use_distributed_sampler
     )
 
-    logging.info((trainer.strategy, trainer.world_size,
-        trainer.local_rank, trainer.global_rank, trainer.node_rank))
+    # logging.info((trainer.strategy, trainer.world_size,
+    #     trainer.local_rank, trainer.global_rank, trainer.node_rank))
 
     trainer.fit(lightning_model, datamodule)  # uses batch size of datamodule

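Taken together, the run that remains is checkpoint-only: no early stopping, no epoch-loss or per-question logging, no rank/strategy logging.info, and the distributed sampler override left at Lightning's default. A consolidated sketch of the post-commit setup (illustrative values standing in for the function's epochs, save_dir and plugins arguments):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Sketch of the stripped-down trainer this commit leaves behind.
checkpoint_callback = ModelCheckpoint(save_top_k=-1)  # illustrative; keeps every checkpoint
callbacks = [checkpoint_callback]  # early stopping removed along with its metric

trainer = pl.Trainer(
    log_every_n_steps=150,
    callbacks=callbacks,
    max_epochs=2,          # illustrative; the repo passes epochs
    default_root_dir='.',  # illustrative; the repo passes save_dir
)
# trainer.fit(lightning_model, datamodule) would then run as before.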
