@@ -87,30 +87,31 @@ def __init__(self, model, model_old, device, opts, trainer_state=None, classes=N
87
87
self .lkd_flag = self .lkd > 0. and model_old is not None
88
88
self .kd_need_labels = False
89
89
self .lgkd_flag = opts .lgkd
90
- if opts .unkd :
91
- self .lkd_loss = UnbiasedKnowledgeDistillationLoss (reduction = "none" , alpha = opts .alpha )
92
- elif opts .lgkd :
93
- self .lkd_loss = LabelGuidedKnowledgeDistillationLoss (alpha = opts .alpha ,
94
- prev_kd = opts .prev_kd ,
95
- novel_kd = opts .novel_kd )
96
- elif opts .kd_bce_sig :
97
- self .lkd_loss = BCESigmoid (reduction = "none" , alpha = opts .alpha , shape = opts .kd_bce_sig_shape )
98
- elif opts .exkd_gt and self .old_classes > 0 and self .step > 0 :
99
- self .lkd_loss = ExcludedKnowledgeDistillationLoss (
100
- reduction = 'none' , index_new = self .old_classes , new_reduction = "gt" ,
101
- initial_nb_classes = opts .inital_nb_classes ,
102
- temperature_semiold = opts .temperature_semiold
103
- )
104
- self .kd_need_labels = True
105
- elif opts .exkd_sum and self .old_classes > 0 and self .step > 0 :
106
- self .lkd_loss = ExcludedKnowledgeDistillationLoss (
107
- reduction = 'none' , index_new = self .old_classes , new_reduction = "sum" ,
108
- initial_nb_classes = opts .inital_nb_classes ,
109
- temperature_semiold = opts .temperature_semiold
110
- )
111
- self .kd_need_labels = True
112
- else :
113
- self .lkd_loss = KnowledgeDistillationLoss (alpha = opts .alpha )
90
+ if self .step > 0 :
91
+ if opts .unkd :
92
+ self .lkd_loss = UnbiasedKnowledgeDistillationLoss (reduction = "none" , alpha = opts .alpha )
93
+ elif opts .lgkd :
94
+ self .lkd_loss = LabelGuidedKnowledgeDistillationLoss (alpha = opts .alpha ,
95
+ prev_kd = opts .prev_kd ,
96
+ novel_kd = opts .novel_kd )
97
+ elif opts .kd_bce_sig :
98
+ self .lkd_loss = BCESigmoid (reduction = "none" , alpha = opts .alpha , shape = opts .kd_bce_sig_shape )
99
+ elif opts .exkd_gt and self .old_classes > 0 :
100
+ self .lkd_loss = ExcludedKnowledgeDistillationLoss (
101
+ reduction = 'none' , index_new = self .old_classes , new_reduction = "gt" ,
102
+ initial_nb_classes = opts .inital_nb_classes ,
103
+ temperature_semiold = opts .temperature_semiold
104
+ )
105
+ self .kd_need_labels = True
106
+ elif opts .exkd_sum and self .old_classes > 0 :
107
+ self .lkd_loss = ExcludedKnowledgeDistillationLoss (
108
+ reduction = 'none' , index_new = self .old_classes , new_reduction = "sum" ,
109
+ initial_nb_classes = opts .inital_nb_classes ,
110
+ temperature_semiold = opts .temperature_semiold
111
+ )
112
+ self .kd_need_labels = True
113
+ else :
114
+ self .lkd_loss = KnowledgeDistillationLoss (alpha = opts .alpha )
114
115
115
116
# ICARL
116
117
self .icarl_combined = False
0 commit comments