Skip to content

Commit b5ea0f9

Browse files
author
slin96
authored
Restructure (#41)
* moved examples * remove unused file * moved schema * moved util * fix basleine save test * fix update and prov test * fix file persistence * fix size tests * reformat * add utils
1 parent e5d6797 commit b5ea0f9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+108
-109
lines changed

.gitignore

+2-2
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,10 @@ cython_debug/
143143
# build lib
144144
build/
145145
dist/
146-
/examples/summary
146+
/mmlib/examples/summary
147147
/tests/filesystem-tmp/
148148
/tests/saveservice-tmp/
149149
/tests/tmp-data/
150150
/tests/example_files/data/reduced-custom-coco-data.zip
151-
/examples/filesystem-tmp/
151+
/mmlib/examples/filesystem-tmp/
152152
/tests/example_files/local-config.ini

README.md

+1-1

mmlib/equal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from util.helper import get_device
5+
from mmlib.util.helper import get_device
66

77

88
def blackbox_model_equal(m1: torch.nn.Module, m2: torch.nn.Module, produce_input: Callable[[], torch.tensor],
File renamed without changes.

examples/baseline_save.py mmlib/examples/baseline_save.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from mmlib.equal import model_equal
44
from mmlib.persistence import FileSystemPersistenceService, MongoDictPersistenceService
55
from mmlib.save import BaselineSaveService
6-
from schema.save_info_builder import ModelSaveInfoBuilder
6+
from mmlib.schema import ModelSaveInfoBuilder
7+
from mmlib.util.dummy_data import imagenet_input
78
from tests.example_files.mynets.mobilenet import mobilenet_v2
8-
from util.dummy_data import imagenet_input
99

1010
CONTAINER_NAME = 'mongo-test'
1111

12-
TARGET_FILE_SYSTEM_DIR = './filesystem-tmp'
12+
TARGET_FILE_SYSTEM_DIR = 'filesystem-tmp'
1313

1414
if __name__ == '__main__':
1515
# initialize a service to store files

examples/probe_example.py mmlib/examples/probe_example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mmlib.deterministic import set_deterministic
66
from mmlib.equal import model_equal, blackbox_model_equal, whitebox_model_equal
77
from mmlib.probe import ProbeInfo, probe_inference, probe_training
8-
from util.dummy_data import imagenet_input, imagenet_target
8+
from mmlib.util.dummy_data import imagenet_input, imagenet_target
99

1010
MODEL = models.googlenet
1111

examples/probe_load_compare.py mmlib/examples/probe_load_compare.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import argparse
22

3-
from examples.probe_store import _generate_probe_training_summary
3+
from mmlib.examples.probe_store import _generate_probe_training_summary
44
from mmlib.probe import ProbeSummary, ProbeInfo
55

66

examples/probe_store.py mmlib/examples/probe_store.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from mmlib.deterministic import set_deterministic
99
from mmlib.probe import probe_training
10-
from util.dummy_data import imagenet_input, imagenet_target
10+
from mmlib.util.dummy_data import imagenet_input, imagenet_target
1111

1212

1313
def main(args):

examples/provenance_save.py mmlib/examples/provenance_save.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88
from mmlib.equal import model_equal
99
from mmlib.persistence import FileSystemPersistenceService, MongoDictPersistenceService
1010
from mmlib.save import ProvenanceSaveService
11+
from mmlib.schema import ModelSaveInfoBuilder
12+
from mmlib.schema.restorable_object import RestorableObjectWrapper, OptimizerWrapper
1113
from mmlib.track_env import track_current_environment
12-
from schema.restorable_object import RestorableObjectWrapper, OptimizerWrapper
13-
from schema.save_info_builder import ModelSaveInfoBuilder
14+
from mmlib.util.dummy_data import imagenet_input
1415
from tests.example_files.data.custom_coco import TrainCustomCoco
1516
from tests.example_files.imagenet_train import ImagenetTrainService, DATALOADER, OPTIMIZER, ImagenetTrainWrapper, DATA
1617
from tests.example_files.mynets.resnet18 import resnet18
17-
from util.dummy_data import imagenet_input
1818

1919
CONTAINER_NAME = 'mongo-test'
20-
TARGET_FILE_SYSTEM_DIR = './filesystem-tmp'
20+
TARGET_FILE_SYSTEM_DIR = 'filesystem-tmp'
2121

2222

2323
def initi_train_service():

mmlib/log.py

-2
This file was deleted.

mmlib/persistence.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
from bson import ObjectId
99

10-
from schema.file_reference import FileReference
11-
from util.helper import find_file, log_stop, START_STOP, TIME
12-
from util.mongo import MongoService
10+
from mmlib.schema.file_reference import FileReference
11+
from mmlib.util.helper import find_file, log_stop, START_STOP, TIME
12+
from mmlib.util.mongo import MongoService
1313

1414
MMLIB_FILE_PERS = 'mmlib_file_pers'
1515

mmlib/probe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.nn as nn
66
from colorama import Fore, Style
77

8-
from util.helper import print_info, get_device
8+
from mmlib.util.helper import print_info, get_device
99

1010

1111
class ProbeInfo(Enum):

mmlib/save.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88
from mmlib.equal import tensor_equal
99
from mmlib.persistence import FilePersistenceService, DictPersistenceService
1010
from mmlib.save_info import ModelSaveInfo, ProvModelSaveInfo
11+
from mmlib.schema.dataset import Dataset
12+
from mmlib.schema.file_reference import FileReference
13+
from mmlib.schema.model_info import ModelInfo, MODEL_INFO
14+
from mmlib.schema.recover_info import FullModelRecoverInfo, WeightsUpdateRecoverInfo, ProvenanceRecoverInfo
15+
from mmlib.schema.restorable_object import RestoredModelInfo
16+
from mmlib.schema.store_type import ModelStoreType
17+
from mmlib.schema.train_info import TrainInfo
1118
from mmlib.track_env import compare_env_to_current
12-
from schema.dataset import Dataset
13-
from schema.file_reference import FileReference
14-
from schema.model_info import ModelInfo, MODEL_INFO
15-
from schema.recover_info import FullModelRecoverInfo, ProvenanceRecoverInfo, WeightsUpdateRecoverInfo
16-
from schema.restorable_object import RestoredModelInfo
17-
from schema.store_type import ModelStoreType
18-
from schema.train_info import TrainInfo
19-
from util.helper import log_start, log_stop
20-
from util.init_from_file import create_object, create_type
21-
from util.weight_dict_merkle_tree import WeightDictMerkleTree, THIS, OTHER
19+
from mmlib.util.helper import log_start, log_stop
20+
from mmlib.util.init_from_file import create_object, create_type
21+
from mmlib.util.weight_dict_merkle_tree import WeightDictMerkleTree, THIS, OTHER
2222

2323
PROVENANCE = 'provenance'
2424

mmlib/save_info.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import torch
22

3-
from schema.environment import Environment
4-
from schema.restorable_object import StateDictRestorableObjectWrapper
5-
from util.helper import class_name, source_file
3+
from mmlib.schema.environment import Environment
4+
from mmlib.schema.restorable_object import StateDictRestorableObjectWrapper
5+
from mmlib.util.helper import class_name, source_file
66

77

88
class TrainSaveInfo:
File renamed without changes.

schema/dataset.py mmlib/schema/dataset.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
22

33
from mmlib.persistence import FilePersistenceService, DictPersistenceService
4-
from schema.file_reference import FileReference
5-
from schema.schema_obj import SchemaObj
6-
from util.zip import zip_path, unzip
4+
from mmlib.schema.file_reference import FileReference
5+
from mmlib.schema.schema_obj import SchemaObj
6+
from mmlib.util.zip import zip_path, unzip
77

88
RAW_DATA = 'raw_data'
99

schema/environment.py mmlib/schema/environment.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from mmlib.persistence import FilePersistenceService, DictPersistenceService
2-
from schema.schema_obj import SchemaObj
2+
from mmlib.schema.schema_obj import SchemaObj
33

44
PYTHON_VERSION = 'python_version'
55
PYTORCH_VERSION = 'pytorch_version'
File renamed without changes.

schema/model_info.py mmlib/schema/model_info.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from mmlib.persistence import FilePersistenceService, DictPersistenceService
2-
from schema.recover_info import FullModelRecoverInfo, AbstractRecoverInfo, WeightsUpdateRecoverInfo, \
2+
from mmlib.schema.recover_info import FullModelRecoverInfo, AbstractRecoverInfo, WeightsUpdateRecoverInfo, \
33
ProvenanceRecoverInfo, RECOVER_INFO
4-
from schema.schema_obj import SchemaObj
5-
from schema.store_type import ModelStoreType
6-
from util.weight_dict_merkle_tree import WeightDictMerkleTree
4+
from mmlib.schema.schema_obj import SchemaObj
5+
from mmlib.schema.store_type import ModelStoreType
6+
from mmlib.util.weight_dict_merkle_tree import WeightDictMerkleTree
77

88
STORE_TYPE = 'store_type'
99
RECOVER_INFO_ID = 'recover_info_id'

schema/recover_info.py mmlib/schema/recover_info.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
from mmlib.constants import MMLIB_CONFIG, CURRENT_DATA_ROOT, VALUES
66
from mmlib.persistence import FilePersistenceService, DictPersistenceService
7-
from schema.dataset import Dataset
8-
from schema.environment import Environment
9-
from schema.file_reference import FileReference
10-
from schema.schema_obj import SchemaObj
11-
from schema.train_info import TrainInfo
12-
from util.helper import copy_all_data, clean
7+
from mmlib.schema.dataset import Dataset
8+
from mmlib.schema.environment import Environment
9+
from mmlib.schema.file_reference import FileReference
10+
from mmlib.schema.schema_obj import SchemaObj
11+
from mmlib.schema.train_info import TrainInfo
12+
from mmlib.util.helper import copy_all_data, clean
1313

1414
RECOVER_INFO = 'recover_info'
1515

schema/restorable_object.py mmlib/schema/restorable_object.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010
from mmlib.constants import MMLIB_CONFIG, VALUES, ID
1111
from mmlib.persistence import FilePersistenceService, DictPersistenceService
12-
from schema.file_reference import FileReference
13-
from schema.schema_obj import SchemaObj
14-
from util.helper import class_name, source_file
15-
from util.init_from_file import create_object_with_parameters
12+
from mmlib.schema.file_reference import FileReference
13+
from mmlib.schema.schema_obj import SchemaObj
14+
from mmlib.util.helper import class_name, source_file
15+
from mmlib.util.init_from_file import create_object_with_parameters
1616

1717
STATE_DICT = 'state_dict'
1818

schema/save_info_builder.py mmlib/schema/save_info_builder.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import torch
22

33
from mmlib.save_info import ModelSaveInfo, TrainSaveInfo, ProvModelSaveInfo
4-
from schema.environment import Environment
5-
from schema.restorable_object import StateDictRestorableObjectWrapper
4+
from mmlib.schema.environment import Environment
5+
from mmlib.schema.restorable_object import StateDictRestorableObjectWrapper
66

77

88
class ModelSaveInfoBuilder:

schema/schema_obj.py mmlib/schema/schema_obj.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from mmlib.constants import ID
88
from mmlib.persistence import FilePersistenceService, DictPersistenceService
9-
from util.helper import log_stop, START_STOP, START, TIME
9+
from mmlib.util.helper import log_stop, START_STOP, START, TIME
1010

1111
METADATA_SIZE = 'metadata_size'
1212

File renamed without changes.

schema/train_info.py mmlib/schema/train_info.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import os
22

33
from mmlib.persistence import FilePersistenceService, DictPersistenceService
4-
from schema.file_reference import FileReference
5-
from schema.restorable_object import StateDictRestorableObjectWrapper
6-
from schema.schema_obj import SchemaObj
7-
from util.init_from_file import create_type
4+
from mmlib.schema.file_reference import FileReference
5+
from mmlib.schema.restorable_object import StateDictRestorableObjectWrapper
6+
from mmlib.schema.schema_obj import SchemaObj
7+
from mmlib.util.init_from_file import create_type
88

99
TRAIN_SERVICE = 'train_service'
1010
TRAIN_KWARGS = 'train_kwargs'

mmlib/track_env.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
from torch.utils.collect_env import get_pretty_env_info
88

9-
from schema.environment import Environment
9+
from mmlib.schema.environment import Environment
1010

1111
ARCHITECTURE = 'architecture'
1212
MACHINE = 'machine'

mmlib/util/__init__.py

Whitespace-only changes.
File renamed without changes.

util/hash.py mmlib/util/hash.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44

55
from mmlib.deterministic import set_deterministic
6-
from util.helper import get_device
6+
from mmlib.util.helper import get_device
77

88

99
def inference_hash(model: torch.nn.Module, dummy_input_shape: [int]):
File renamed without changes.
File renamed without changes.

util/mongo.py mmlib/util/mongo.py

File renamed without changes.

util/weight_dict_merkle_tree.py mmlib/util/weight_dict_merkle_tree.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22

3-
from util.hash import hash_string, tensor_hash
3+
from mmlib.util.hash import hash_string, tensor_hash
44

55
OTHER = 'other'
66

util/zip.py mmlib/util/zip.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import zipfile
33

4-
from util.helper import zip_dir
4+
from mmlib.util.helper import zip_dir
55

66

77
def zip_path(save_path: str) -> str:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
long_description=long_description,
1818
long_description_content_type="text/markdown",
1919
url="https://github.com/slin96/mmlib",
20-
packages=['mmlib', 'util', 'schema'],
20+
packages=['mmlib', 'mmlib.schema', 'mmlib.util'],
2121
classifiers=[
2222
"Programming Language :: Python :: 3",
2323
"License :: OSI Approved :: MIT License",

tests/example_files/imagenet_optimizer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from torch.optim import SGD
33

4-
from schema.restorable_object import StateFileRestorableObject
4+
from mmlib.schema.restorable_object import StateFileRestorableObject
55

66

77
class ImagenetOptimizer(SGD, StateFileRestorableObject):

tests/example_files/imagenet_train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from mmlib.deterministic import set_deterministic
44
from mmlib.persistence import FilePersistenceService, DictPersistenceService
5-
from schema.restorable_object import TrainService, StateDictRestorableObjectWrapper, \
5+
from mmlib.schema.restorable_object import TrainService, StateDictRestorableObjectWrapper, \
66
RestorableObjectWrapper, StateFileRestorableObjectWrapper
7-
from util.init_from_file import create_object
7+
from mmlib.util.init_from_file import create_object
88

99
DATA = 'data'
1010
DATALOADER = 'dataloader'

tests/init_from_file/test_init_form_file.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
import unittest
33

44
from mmlib.equal import model_equal
5+
from mmlib.util.dummy_data import imagenet_input
6+
from mmlib.util.init_from_file import create_object, create_object_with_parameters
57
from tests.example_files.mynets.resnet18 import resnet18
68
from tests.init_from_file.dummy_classes import DummyA
7-
from util.dummy_data import imagenet_input
8-
from util.init_from_file import create_object, create_object_with_parameters
99

1010
CODE = os.path.join(os.path.dirname(os.path.realpath(__file__)), './dummy_classes.py')
1111

tests/model_equal/test_equal.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
from mmlib.deterministic import set_deterministic
77
from mmlib.equal import state_dict_equal, model_equal
8-
from util.dummy_data import imagenet_input
9-
from util.hash import state_dict_hash, tensor_hash
8+
from mmlib.util import state_dict_hash, tensor_hash
9+
from mmlib.util.dummy_data import imagenet_input
1010

1111

1212
class TestStateDictEqual(unittest.TestCase):

tests/persistence/test_file_persistence.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import unittest
55

66
from mmlib.persistence import FileSystemPersistenceService
7-
from schema.file_reference import FileReference
7+
from mmlib.schema.file_reference import FileReference
88

99

1010
class TestPersistence(unittest.TestCase):

tests/probing/test_probe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from mmlib.deterministic import set_deterministic
99
from mmlib.probe import ProbeSummary, probe_inference, probe_training
10-
from util.dummy_data import imagenet_input, imagenet_target
10+
from mmlib.util.dummy_data import imagenet_input, imagenet_target
1111

1212
TMP_SUMMARY_PATH = './tmp-summary'
1313

tests/save/test_baseline_save_servcie.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
from mmlib.equal import model_equal
66
from mmlib.persistence import FileSystemPersistenceService, MongoDictPersistenceService
77
from mmlib.save import BaselineSaveService
8+
from mmlib.schema.save_info_builder import ModelSaveInfoBuilder
89
from mmlib.track_env import track_current_environment
9-
from schema.save_info_builder import ModelSaveInfoBuilder
10+
from mmlib.util.dummy_data import imagenet_input
11+
from mmlib.util.mongo import MongoService
1012
from tests.example_files.mynets.googlenet import googlenet
1113
from tests.example_files.mynets.mobilenet import mobilenet_v2
1214
from tests.example_files.mynets.resnet18 import resnet18
13-
from util.dummy_data import imagenet_input
14-
from util.mongo import MongoService
1515

1616
FILE_PATH = os.path.dirname(os.path.realpath(__file__))
1717
NETWORK_CODE_TEMPLATE = os.path.join(FILE_PATH, '../example_files/mynets/{}.py')

0 commit comments

Comments
 (0)