
Commit 78262c2

update logger
1 parent 60d1d9f commit 78262c2

File tree

3 files changed (+34, -31 lines)


train.py (+25, -24)
@@ -87,30 +87,31 @@ def __init__(self, model, model_old, device, opts, trainer_state=None, classes=N
         self.lkd_flag = self.lkd > 0. and model_old is not None
         self.kd_need_labels = False
         self.lgkd_flag = opts.lgkd
-        if opts.unkd:
-            self.lkd_loss = UnbiasedKnowledgeDistillationLoss(reduction="none", alpha=opts.alpha)
-        elif opts.lgkd:
-            self.lkd_loss = LabelGuidedKnowledgeDistillationLoss(alpha=opts.alpha,
-                                                                 prev_kd=opts.prev_kd,
-                                                                 novel_kd=opts.novel_kd)
-        elif opts.kd_bce_sig:
-            self.lkd_loss = BCESigmoid(reduction="none", alpha=opts.alpha, shape=opts.kd_bce_sig_shape)
-        elif opts.exkd_gt and self.old_classes > 0 and self.step > 0:
-            self.lkd_loss = ExcludedKnowledgeDistillationLoss(
-                reduction='none', index_new=self.old_classes, new_reduction="gt",
-                initial_nb_classes=opts.inital_nb_classes,
-                temperature_semiold=opts.temperature_semiold
-            )
-            self.kd_need_labels = True
-        elif opts.exkd_sum and self.old_classes > 0 and self.step > 0:
-            self.lkd_loss = ExcludedKnowledgeDistillationLoss(
-                reduction='none', index_new=self.old_classes, new_reduction="sum",
-                initial_nb_classes=opts.inital_nb_classes,
-                temperature_semiold=opts.temperature_semiold
-            )
-            self.kd_need_labels = True
-        else:
-            self.lkd_loss = KnowledgeDistillationLoss(alpha=opts.alpha)
+        if self.step > 0:
+            if opts.unkd:
+                self.lkd_loss = UnbiasedKnowledgeDistillationLoss(reduction="none", alpha=opts.alpha)
+            elif opts.lgkd:
+                self.lkd_loss = LabelGuidedKnowledgeDistillationLoss(alpha=opts.alpha,
+                                                                     prev_kd=opts.prev_kd,
+                                                                     novel_kd=opts.novel_kd)
+            elif opts.kd_bce_sig:
+                self.lkd_loss = BCESigmoid(reduction="none", alpha=opts.alpha, shape=opts.kd_bce_sig_shape)
+            elif opts.exkd_gt and self.old_classes > 0:
+                self.lkd_loss = ExcludedKnowledgeDistillationLoss(
+                    reduction='none', index_new=self.old_classes, new_reduction="gt",
+                    initial_nb_classes=opts.inital_nb_classes,
+                    temperature_semiold=opts.temperature_semiold
+                )
+                self.kd_need_labels = True
+            elif opts.exkd_sum and self.old_classes > 0:
+                self.lkd_loss = ExcludedKnowledgeDistillationLoss(
+                    reduction='none', index_new=self.old_classes, new_reduction="sum",
+                    initial_nb_classes=opts.inital_nb_classes,
+                    temperature_semiold=opts.temperature_semiold
+                )
+                self.kd_need_labels = True
+            else:
+                self.lkd_loss = KnowledgeDistillationLoss(alpha=opts.alpha)
 
         # ICARL
         self.icarl_combined = False
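
Note: the hunk hoists the step check. At the base step (step 0) there is presumably no old model to distill from, so none of the loss objects needs to be built; previously only the exkd_* branches guarded on self.step > 0, and the outer guard makes those conditions redundant. A minimal sketch of the resulting selection logic (pick_kd_loss and the opts stand-in are hypothetical, not this repo's API):

import types

def pick_kd_loss(opts, step, old_classes):
    # Sketch of the new control flow: nothing is built at the base step.
    if step > 0:
        if opts.unkd:
            return "UnbiasedKD"
        elif opts.lgkd:
            return "LabelGuidedKD"
        elif opts.exkd_gt and old_classes > 0:  # outer guard replaces "and step > 0"
            return "ExcludedKD(gt)"
        else:
            return "KD"
    return None  # step 0: no distillation target exists yet

opts = types.SimpleNamespace(unkd=False, lgkd=True, exkd_gt=False)
assert pick_kd_loss(opts, step=0, old_classes=0) is None
assert pick_kd_loss(opts, step=1, old_classes=15) == "LabelGuidedKD"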

utils/logger.py (+5, -6)
@@ -1,3 +1,4 @@
+import functools
 import logging
 import os
 import sys
@@ -92,24 +93,21 @@ def add_results(self, results, epoch=None):
         text += "</table>"
         self.writer.add_text(tag, text)
 
-    def setup_logger(self, output=None, distributed_rank=0, color=True, name="PLOP", abbrev_name=None):
+    @functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers
+    def setup_logger(self, output=None, distributed_rank=0, color=True, name="LGKD", abbrev_name=None):
         """
         Args:
             output (str): a file name or a directory to save log. If None, will not save log file.
                 If ends with ".txt" or ".log", assumed to be a file name.
                 Otherwise, logs will be saved to `output/log.txt`.
             name (str): the root module name of this logger
-            abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
-                Set to "" to not log the root module in logs.
-                By default, will abbreviate "detectron2" to "d2" and leave other
-                modules unchanged.
         """
         logger = logging.getLogger(name)
         logger.setLevel(logging.DEBUG)
         logger.propagate = False
 
         if abbrev_name is None:
-            abbrev_name = "plop" if name == "PLOP" else name
+            abbrev_name = name
 
         plain_formatter = logging.Formatter(
             "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
@@ -170,5 +168,6 @@ def formatMessage(self, record):
 
 # cache the opened file object, so that different calls to `setup_logger`
 # with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
 def _cached_log_stream(filename):
     return open(filename, "a")
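
For reference, functools.lru_cache memoizes on the argument tuple: the decorated body runs once per distinct set of arguments, and later calls return the cached result. That is what keeps repeated setup_logger calls from stacking up handlers, and the same mechanism on _cached_log_stream lets several calls with the same filename share one open file object. A self-contained sketch of the handler case (demo names, not this repo's code):

import functools
import logging

@functools.lru_cache()  # body runs once per distinct argument tuple
def setup_demo_logger(name="demo"):
    logger = logging.getLogger(name)
    logger.addHandler(logging.StreamHandler())  # would stack up without the cache
    return logger

setup_demo_logger()
setup_demo_logger()  # cache hit: the body is skipped, no second handler
assert len(logging.getLogger("demo").handlers) == 1

One caveat of the pattern: calls that differ in any argument (say, a different output path) bypass the cache and run the body again, which is the intended behavior here.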

utils/loss.py (+4, -1)
@@ -1,6 +1,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import logging
+
+logger = logging.getLogger('LGKD.' + __name__)
 
 
 def get_loss(loss_type):
@@ -454,7 +457,7 @@ def __init__(self, reduction='mean', alpha=1., prev_kd=10, novel_kd=1):
         self.alpha = alpha
         self.prev_kd = prev_kd
         self.novel_kd = novel_kd
-        print("prev kd: {}\t novel kd: {}".format(self.prev_kd, self.novel_kd))
+        logger.info("prev kd: {}\t novel kd: {}".format(self.prev_kd, self.novel_kd))
 
     def forward(self, new_logits, old_logits, targets):
         targets = targets.clone()
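
The module-level logger is named 'LGKD.' + __name__, which makes it a child of the "LGKD" root logger that setup_logger configures; records emitted here propagate up to that logger's handlers, so the message that used to go through print now lands in the shared log stream with timestamps and levels. A minimal propagation demo (the handler setup is simplified; in the repo it is done by setup_logger):

import logging

parent = logging.getLogger("LGKD")            # configured once at startup
parent.setLevel(logging.DEBUG)
parent.addHandler(logging.StreamHandler())

child = logging.getLogger("LGKD.utils.loss")  # e.g. 'LGKD.' + __name__
child.info("prev kd: %s\t novel kd: %s", 10, 1)  # handled by the parent's handler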
