Skip to content

Commit 1f2d549

Browse files
author
slin96
authored
baseline functionality (#7)
* very basic skelton for usecase1 * develop and test basic functionality to insert json to mongo db and retrieve by id * remove @classmethod * test and dev add attribute * dev and test save model * added two placeholder savemodel methods * make mongo service more genral * moved mongo servcie to separate file * dev and test saved_ids * restore model * typo * fix test * extract mongo id helper * save matehod with pickle and code zipped * save and restore working * save and restore all test running * equals -> equal * script to build mmlib in docker container * script to build mmlib in docker container * added documentation: build with docker * fix probe example * added section on repo content * documented examples * documented examples * added comment for model * use python warning istead of custom method * updated requirements * added autor email * updated setup.py * restructuring * added comments to save.py * introduced abstract classes for Save and Recover Service * adjusted comments * added types for helper * added types for public methods from model_equal and probe * updated Readme * updated Readme * fixed str vs ObjectIds * include util in library * refactoring of equal methods * added TODOs to use protocols * simple version of use model * readme adjustments * rename ProbeTest -> SaveTest
1 parent c91d921 commit 1f2d549

25 files changed

+855
-96
lines changed

README.md

+32-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,34 @@
1-
# MM-LIB
1+
# mmlib
2+
3+
- A library for model management and related tasks.
4+
5+
## Installation
6+
7+
### Option 1: Docker
8+
9+
- **Requirements**: Docker installed
10+
- **Build Library**
11+
- clone this repo
12+
- run the script `generate-archives-docker.sh`
13+
- it runs a docker container and builds the *mmlib* in it.
14+
- the created `dist` directory is copied back to repository root
15+
- it contains the `.whl` file that can be used to install the library with pip (see below)
16+
- **Install**
17+
- to install mmlib run: `pip install <PATH>/dist/mmlib-0.0.1-py3-none-any.whl`
18+
19+
### Option 2: Local Build
20+
21+
- **Requirements**: Python 3.8
22+
- **Build Library**
23+
- run the script `generate-archives.sh`
24+
- it creates a virtual environment, activates it, and installs all requirements
25+
- afterward it builds the library, and a `dist` directory containing the `.whl` file is created
26+
- **Install**
27+
- to install mmlib run: `pip install <PATH>/dist/mmlib-0.0.1-py3-none-any.whl`
28+
29+
## Examples
30+
31+
- For examples on how to use mmlib checkout the [examples](./examples) directory.
32+
233

3-
## installation
434

5-
- to build the lib run: `generate-archives.sh`
6-
- to install it run: `pip install <PATH>/dist/mmlib-0.0.1-py3-none-any.whl
7-
`

examples/README.md

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Examples
2+
3+
This directory contains examples of how to use the functionality offered by the *mmlib*.
4+
5+
- *probe_store.py* - Creates and stores a probe summary of the training process of a GoogLeNet.
6+
- execution: `python probe_store.py --path <optional path to store probe summary>`
7+
- *probe_load_compare.py* - Creates a probe summary of the training process of a GoogLeNet and compares it to a stored
8+
probe summary
9+
- execution: `python probe_load_compare.py --path <path to the already stored probe summary>`
10+
- note: To generate and store a probe summary to compare to use the *probe_store.py* script.
11+
- *probe_example.py* - Shows extensively how the probe functionality offered by the *mmlib* can be used to make the
12+
PyTorch implementation of GoogLeNet reproducible. It runs the following steps:
13+
- simple summary
14+
- creates a probe summary for the inference mode and prints the representation
15+
- probe inference
16+
- creates two instances of the same model
17+
- creates inference mode probe summaries (covering forward path) for them
18+
- compares the probe summaries
19+
- probe training
20+
- creates two instances of the same model
21+
- creates training mode probe summaries (covering forward and backward path)
22+
- compares the probe summaries
23+
- probe reproducible training
24+
- creates two instances of the same model
25+
- uses *set_deterministic* functionality offered by the *mmlib* to make the training process of both models
26+
reproducible
27+
- creates training mode probe summaries (covering forward and backward path)
28+
- compares the probe summaries
29+
- compares both models using the methods *blackbox_model_equal*, *whitebox_model_equal*, and *model_equal* offered by the *mmlib*.

examples/probe_example.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mmlib.deterministic import set_deterministic
66
from mmlib.helper import imagenet_input, imagenet_target
7-
from mmlib.model_equals import equals, whitebox_equals, blackbox_equals
7+
from mmlib.equal import model_equal, blackbox_model_equal, whitebox_model_equal
88
from mmlib.probe import ProbeInfo, probe_inference, probe_training
99

1010
MODEL = models.googlenet
@@ -96,13 +96,13 @@ def deterministic_backward_compare(device, forward_indices=None):
9696
summary1.compare_to(summary2, common, compare)
9797

9898
# also the models should be equal
99-
blackbox_equal = blackbox_equals(model1, model2, imagenet_input)
100-
whitebox_equal = whitebox_equals(model1, model2)
101-
models_are_equal = equals(model1, model2, imagenet_input)
99+
blackbox_eq = blackbox_model_equal(model1, model2, imagenet_input)
100+
whitebox_eq = whitebox_model_equal(model1, model2)
101+
models_are_equal = model_equal(model1, model2, imagenet_input)
102102
print()
103103
print('Also the models should be the same - compare the models')
104-
print('models_are_equal (blackbox): {}'.format(blackbox_equal))
105-
print('models_are_equal (whitebox): {}'.format(whitebox_equal))
104+
print('models_are_equal (blackbox): {}'.format(blackbox_eq))
105+
print('models_are_equal (whitebox): {}'.format(whitebox_eq))
106106
print('models_are_equal: {}'.format(models_are_equal))
107107

108108

examples/probe_store.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def _generate_probe_training_summary():
2222
dummy_input = imagenet_input()
2323
dummy_target = imagenet_target(dummy_input)
2424
loss_func = nn.CrossEntropyLoss()
25-
model = models.alexnet(pretrained=True)
25+
model = models.googlenet(pretrained=True)
2626
optimizer = torch.optim.SGD(model.parameters(), 1e-3)
2727
summary = probe_training(model, dummy_input, optimizer, loss_func, dummy_target)
2828
return summary

generate-archives-docker.sh

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/bin/bash
2+
3+
cd "$(dirname "$0")"
4+
5+
CONTAINER_NAME=mmlib-python
6+
7+
docker run --rm --name $CONTAINER_NAME -it -d python:3.8
8+
docker cp ../mmlib $CONTAINER_NAME:/
9+
docker exec $CONTAINER_NAME /mmlib/generate-archives.sh
10+
docker cp $CONTAINER_NAME:/mmlib/dist ./
11+
docker kill $CONTAINER_NAME

generate-archives.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ python3 -m pip install --upgrade pip
1111

1212
# install requirements
1313
python3 -m pip install --upgrade setuptools wheel
14-
python3 -m pip install -r requirements-tests.txt
14+
python3 -m pip install -r requirements.txt
1515

1616
python3 setup.py sdist bdist_wheel

mmlib/model_equals.py mmlib/equal.py

+31-11
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from typing import Callable
2+
13
import torch
24

3-
from mmlib.helper import _get_device
5+
from util.helper import get_device
46

57

6-
def blackbox_equals(m1, m2, produce_input, device: torch.device = None):
8+
def blackbox_model_equal(m1: torch.nn.Module, m2: torch.nn.Module, produce_input: Callable[[], torch.tensor],
9+
device: torch.device = None) -> bool:
710
"""
811
Compares two models in a blackbox manner meaning if the models are equal is determined only by comparing inputs and
912
outputs.
@@ -14,7 +17,10 @@ def blackbox_equals(m1, m2, produce_input, device: torch.device = None):
1417
:return: Returns if the two given models are equal.
1518
"""
1619

17-
device = _get_device(device)
20+
assert isinstance(m1, torch.nn.Module)
21+
assert isinstance(m2, torch.nn.Module)
22+
23+
device = get_device(device)
1824

1925
inp = produce_input()
2026

@@ -31,7 +37,7 @@ def blackbox_equals(m1, m2, produce_input, device: torch.device = None):
3137
return torch.equal(out1, out2)
3238

3339

34-
def whitebox_equals(m1, m2, device: torch.device = None):
40+
def whitebox_model_equal(m1: torch.nn.Module, m2: torch.nn.Module, device: torch.device = None) -> bool:
3541
"""
3642
Compares two models in a whitebox manner meaning we compare the model weights.
3743
:param m1: The first model to compare.
@@ -40,15 +46,18 @@ def whitebox_equals(m1, m2, device: torch.device = None):
4046
:return: Returns if the two given models are equal.
4147
"""
4248

43-
device = _get_device(device)
49+
assert isinstance(m1, torch.nn.Module)
50+
assert isinstance(m2, torch.nn.Module)
51+
52+
device = get_device(device)
4453

4554
state1 = m1.state_dict()
4655
state2 = m2.state_dict()
4756

48-
return state_dict_equals(state1, state2, device)
57+
return state_dict_equal(state1, state2, device)
4958

5059

51-
def state_dict_equals(d1, d2, device: torch.device = None):
60+
def state_dict_equal(d1: dict, d2: dict, device: torch.device = None) -> bool:
5261
"""
5362
Compares two given state dicts.
5463
:param d1: The first state dict.
@@ -57,7 +66,7 @@ def state_dict_equals(d1, d2, device: torch.device = None):
5766
:return: Returns if the given state dicts are equal.
5867
"""
5968

60-
device = _get_device(device)
69+
device = get_device(device)
6170

6271
for item1, item2 in zip(d1.items(), d2.items()):
6372
layer_name1, weight_tensor1 = item1
@@ -72,7 +81,8 @@ def state_dict_equals(d1, d2, device: torch.device = None):
7281
return True
7382

7483

75-
def equals(m1, m2, produce_input, device: torch.device = None):
84+
def model_equal(m1: torch.nn.Module, m2: torch.nn.Module, produce_input: Callable[[], torch.tensor],
85+
device: torch.device = None) -> bool:
7686
"""
7787
An equals method to compare two given models by making use of whitebox and blackbox equals.
7888
:param m1: The first model to compare.
@@ -81,8 +91,18 @@ def equals(m1, m2, produce_input, device: torch.device = None):
8191
:param device: The device to execute on
8292
:return: Returns if the two given models are equal.
8393
"""
84-
device = _get_device(device)
94+
device = get_device(device)
8595

8696
# whitebox and blackbox check should be redundant,
8797
# but this way we have an extra safety net in case we forgot a special case
88-
return whitebox_equals(m1, m2, device) and blackbox_equals(m1, m2, produce_input, device)
98+
return whitebox_model_equal(m1, m2, device) and blackbox_model_equal(m1, m2, produce_input, device)
99+
100+
101+
def tensor_equal(tensor1: torch.tensor, tensor2: torch.tensor):
102+
"""
103+
Compares to given Pytorch tensors.
104+
:param tensor1: The first tensor to be compared.
105+
:param tensor2: The second tensor to be compared.
106+
:return: Returns if the two given tensors are equal.
107+
"""
108+
return torch.equal(tensor1, tensor2)

mmlib/helper.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,7 @@
33
import torch
44

55

6-
def _get_device(device):
7-
if device is None:
8-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
9-
return device
10-
11-
12-
def imagenet_input(batch_size=10):
6+
def imagenet_input(batch_size: int = 10) -> torch.tensor:
137
"""
148
Generates a batch of dummy imputes for models processing imagenet data.
159
:param batch_size: The size of the batch.
@@ -21,7 +15,7 @@ def imagenet_input(batch_size=10):
2115
return torch.stack(batch)
2216

2317

24-
def imagenet_target(dummy_input):
18+
def imagenet_target(dummy_input: torch.tensor) -> torch.tensor:
2519
"""
2620
Creates a batch of random labels for imagenet data based on a given input data.
2721
:param dummy_input: The input to a potential model for the the target values should be produced.

mmlib/log.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def use_model(model_id):
2+
print('use model with model_id: {}'.format(str(model_id)))

mmlib/probe.py

+22-17
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import warnings
12
from enum import Enum
23

34
import torch
45
import torch.nn as nn
56
from colorama import Fore, Style
67

7-
from mmlib.helper import _get_device
8-
from mmlib.util import _print_info, _print_warning
8+
from util.helper import print_info, get_device
99

1010

1111
class ProbeInfo(Enum):
@@ -41,7 +41,10 @@ class ProbeSummary:
4141
DIFF = 'diff'
4242
SAME = 'same'
4343

44-
def __init__(self, summary_path=None):
44+
def __init__(self, summary_path: str = None):
45+
"""
46+
:param summary_path: Path to load a summary from
47+
"""
4548
if summary_path:
4649
self.load(summary_path)
4750
else:
@@ -105,14 +108,14 @@ def compare_to(self, other_summary, common: [ProbeInfo], compare: [ProbeInfo]):
105108
for layer_key, layer_info in self.summary.items():
106109
self._print_compare_layer(common, compare, layer_info, other_summary)
107110

108-
def save(self, path):
111+
def save(self, path: str):
109112
"""
110113
Saves an object to a disk file.
111114
:param path: The path to store to.
112115
"""
113116
torch.save(self.summary, path)
114117

115-
def load(self, path):
118+
def load(self, path: str):
116119
"""
117120
Loads an object saved with :func:`mmlib.probe.save` from a file.
118121
:param path: The path to load from.
@@ -216,7 +219,8 @@ def _compare_values(self, v1, v2):
216219
return v1 == v2
217220

218221

219-
def probe_inference(model, inp, device: torch.device = None, forward_indices=None):
222+
def probe_inference(model: torch.nn.Module, inp: torch.tensor, device: torch.device = None,
223+
forward_indices: [int] = None) -> ProbeSummary:
220224
"""
221225
Probes the inference of a given model.
222226
:param model: The model to probe.
@@ -233,7 +237,8 @@ def probe_inference(model, inp, device: torch.device = None, forward_indices=Non
233237
return _probe_reproducibility(model, inp, ProbeMode.INFERENCE, device, forward_indices=forward_indices)
234238

235239

236-
def probe_training(model, inp, optimizer, loss_func, target, device: torch.device = None, forward_indices=None):
240+
def probe_training(model: torch.nn.Module, inp: torch.tensor, optimizer: torch.optim.Optimizer, loss_func,
241+
target: torch.tensor, device: torch.device = None, forward_indices: [int] = None) -> ProbeSummary:
237242
"""
238243
Probes the training of a given model.
239244
:param model: The model to probe.
@@ -259,7 +264,7 @@ def _probe_reproducibility(model, inp, mode, device, optimizer=None, loss_func=N
259264

260265
_forward_indices_warning(forward_indices)
261266

262-
device = _get_device(device)
267+
device = get_device(device)
263268

264269
def register_forward_hook(module, ):
265270

@@ -362,22 +367,22 @@ def _shape_list(tensor_tuple):
362367

363368
def _forward_indices_warning(forward_indices):
364369
if forward_indices is not None:
365-
_print_warning("You set the forward_indices argument. "
366-
"This means not all layers will be included in the summary.")
370+
print_info("You set the forward_indices argument. "
371+
"This means not all layers will be included in the summary.")
367372
else:
368-
_print_warning("You did not set the forward_indices argument. "
369-
"Every layer will be included in the summary. This might lead to very high memory consumption.")
373+
warnings.warn("You did not set the forward_indices argument."
374+
"Every layer will be included in the summary. This might lead to very high memory consumption.")
370375

371376

372377
def _hashwarning(fields: [ProbeInfo]):
373378
# If we print tensors or shapes it is likely that they are to long. In this case we print a hash instead.
374379
# Warn the user that for example for long tensors same hash values do not guarantee the same values.
375380
if any('shape' in x.value or 'tensor' in x.value for x in fields):
376-
_print_warning("Same hashes don\'t have to mean that values are exactly the same (especially for tensors)."
377-
" Hashes should be seen as an indicator.")
381+
print_info("Same hashes don\'t have to mean that values are exactly the same (especially for tensors)."
382+
" Hashes should be seen as an indicator.")
378383

379384

380385
def _inference_info():
381-
_print_info("You are probing in inference mode so the model will be in eval mode."
382-
"\nSince layers like dropout are switched off in this mode you won't find factors that produce "
383-
"non-reproducibility by these kind of layers.")
386+
print_info("You are probing in inference mode so the model will be in eval mode."
387+
"\nSince layers like dropout are switched off in this mode you won't find factors that produce "
388+
"non-reproducibility by these kind of layers.")

0 commit comments

Comments
 (0)