diff --git a/few_shot_tests.py b/few_shot_tests.py index 1658fe1..8495dd3 100755 --- a/few_shot_tests.py +++ b/few_shot_tests.py @@ -61,12 +61,10 @@ def idx_to_big(self, phase, idx): cathegories = 5 dataloader = OurMiniImageNetDataLoader(shot_num=5 * 2, way_num=cathegories, episode_test_sample_num=args.episode_test_sample_num, shuffle_images = args.shuffle_images) #twice shot_num is because one might be uses as the base for the samples -dataloader.generate_data_list(phase='train') -dataloader.generate_data_list(phase='val') -dataloader.generate_data_list(phase='test') +dataloader.generate_data_list(phase=args.dataset) print('mode is',args.dataset) -dataloader.load_list('all') +dataloader.load_list(args.dataset) #print('train',dataloader.train_filenames) #print('val',dataloader.val_filenames) @@ -74,7 +72,7 @@ def idx_to_big(self, phase, idx): base_train_img, base_train_label, base_test_img, base_test_label = \ - dataloader.get_batch(phase='train', idx=0) + dataloader.get_batch(phase=args.dataset, idx=0) train_epoch_size = base_train_img.shape[0] if not args.train_indep_and_dependent: