From 56104c6caca63d3cd75e92a0c52778f03e204978 Mon Sep 17 00:00:00 2001
From: "Chua, Vui Seng" <vui.seng.chua@intel.com>
Date: Wed, 17 Nov 2021 11:47:27 -0800
Subject: [PATCH 1/7] Initial movement sparsity implementation for fine-grained
 pruning

---
 nncf/common/sparsity/schedulers.py        | 149 ++++++++++++++++++++++
 nncf/common/sparsity/statistics.py        |  38 ++++++
 nncf/common/statistics.py                 |  13 +-
 nncf/common/utils/tensorboard.py          |  12 ++
 nncf/config/schema.py                     |  54 ++++++++
 nncf/torch/__init__.py                    |   1 +
 nncf/torch/functions.py                   |   6 +-
 nncf/torch/sparsity/movement/__init__.py  |  12 ++
 nncf/torch/sparsity/movement/algo.py      | 136 ++++++++++++++++++++
 nncf/torch/sparsity/movement/functions.py |  23 ++++
 nncf/torch/sparsity/movement/layers.py    |  93 ++++++++++++++
 nncf/torch/sparsity/movement/loss.py      |  77 +++++++++++
 12 files changed, 610 insertions(+), 4 deletions(-)
 create mode 100644 nncf/torch/sparsity/movement/__init__.py
 create mode 100644 nncf/torch/sparsity/movement/algo.py
 create mode 100644 nncf/torch/sparsity/movement/functions.py
 create mode 100644 nncf/torch/sparsity/movement/layers.py
 create mode 100644 nncf/torch/sparsity/movement/loss.py

diff --git a/nncf/common/sparsity/schedulers.py b/nncf/common/sparsity/schedulers.py
index 61eabdf5f00..a9723448ae7 100644
--- a/nncf/common/sparsity/schedulers.py
+++ b/nncf/common/sparsity/schedulers.py
@@ -285,3 +285,152 @@ def epoch_step(self, next_epoch: Optional[int] = None) -> None:
 
     def _calculate_sparsity_level(self) -> float:
         return self.schedule(self.current_epoch)
+
+
+@SPARSITY_SCHEDULERS.register('threshold_polynomial_decay')
+class PolynomialThresholdScheduler(BaseCompressionScheduler):
+    """
+    Sparsity scheduler with a polynomial decay schedule.
+
+    Two ways are available for calculations of the sparsity:
+        - per epoch
+        - per step
+    Parameters `update_per_optimizer_step` and `steps_per_epoch`
+    should be provided in config for the per step calculation.
+    If `update_per_optimizer_step` was only provided then scheduler
+    will use first epoch to calculate `steps_per_epoch`
+    parameter. In this case, `current_epoch` and `current_step` will
+    not be updated on this epoch. The scheduler will start calculation
+    after `steps_per_epoch` will be calculated.
+    """
+
+    def __init__(self, controller: SparsityController, params: dict):
+        """
+        TODO: revise docstring
+        TODO: test epoch-wise stepping
+        Initializes a sparsity scheduler with a polynomial decay schedule.
+
+        :param controller: Sparsity algorithm controller.
+        :param params: Parameters of the scheduler.
+        """
+        super().__init__()
+        self._controller = controller
+        self.init_importance_threshold = params.get('init_importance_threshold', 0.0)
+        self.final_importance_threshold = params.get('final_importance_threshold', 0.1)
+        self.warmup_start_epoch = params.get('warmup_start_epoch', 0.0)
+        self.warmup_end_epoch = params.get('warmup_end_epoch', 0.0)
+        self.importance_target_lambda = params.get('importance_regularization_factor', 1.0)
+        self.current_importance_threshold  = self.init_importance_threshold
+        self.cached_importance_threshold = self.current_importance_threshold
+
+        self.schedule = PolynomialDecaySchedule(
+            self.init_importance_threshold, 
+            self.final_importance_threshold, 
+            self.warmup_end_epoch,
+            params.get('power', 3), 
+            params.get('concave', True)
+            )
+
+        self._steps_in_current_epoch = 0
+        self._update_per_optimizer_step = params.get('update_per_optimizer_step', False)
+        self._steps_per_epoch = params.get('steps_per_epoch', None)
+        self._should_skip = False
+
+    @property
+    def current_importance_lambda(self):
+        return self.importance_target_lambda * (self.current_importance_threshold/self.final_importance_threshold)
+
+    def _disable_importance_grad(self):
+        for m in self._controller.sparsified_module_info:
+            m.operand.freeze_importance()
+                        
+    def _update_importance_masking_threshold(self):
+        if self.cached_importance_threshold != self.current_importance_threshold:
+            for m in self._controller.sparsified_module_info:
+                m.operand.masking_threshold = self.current_importance_threshold
+        self.cached_importance_threshold = self.current_importance_threshold 
+
+    def schedule_threshold(self):
+        if self.current_step <= self.warmup_start_epoch * self._steps_per_epoch:
+            self.current_importance_threshold  = self.init_importance_threshold
+
+        elif self.current_step > self.warmup_end_epoch * self._steps_per_epoch:
+            self.current_importance_threshold  = self.final_importance_threshold
+            self._disable_importance_grad()
+
+            # TODO: gradient freezing should be at the epoch to freeze epoch
+            # for n, m in self._controller.model.named_modules():
+            #     if m.__class__.__name__ == "MovementSparsifyingWeight":
+            #         m.frozen=True
+            #         m._importance.requires_grad=False
+
+        else:
+            self.current_importance_threshold  = self._calculate_threshold_level()
+
+        self._update_importance_masking_threshold()
+        # if _cached_threshold != self.current_importance_threshold  or _cached_regu_lambda != self.current_importance_lambda:
+        #     for n, m in self._controller.model.named_modules():
+        #         if m.__class__.__name__ == "MovementSparsifyingWeight":
+        #             m.masking_threshold = self.current_importance_threshold 
+        #             # m.lmbd = self.current_importance_lambda
+
+    def step(self, next_step: Optional[int] = None) -> None:
+        super().step(next_step)
+        self._steps_in_current_epoch += 1
+        if self._should_skip:
+            return
+
+        if self._update_per_optimizer_step:
+            self.schedule_threshold()
+
+    def epoch_step(self, next_epoch: Optional[int] = None) -> None:
+        self._maybe_should_skip()
+        self._steps_in_current_epoch = 0 # This must be set after _maybe_should_skip as it is used in that routine
+        if self._should_skip:
+            return
+        # only increment epoch if should_skip is checked
+        super().epoch_step(next_epoch)
+        print("-----epoch_step", self.current_epoch)
+        print("-----step", self._steps_in_current_epoch)
+        if not self._update_per_optimizer_step:
+            self.schedule_threshold()
+
+    def _calculate_threshold_level(self) -> float:
+        print("epoch_step", self.current_epoch)
+        print("step", self._steps_in_current_epoch)
+        local_step = max(self._steps_in_current_epoch+1, 0)
+        return self.schedule(self.current_epoch-self.warmup_start_epoch, local_step, self._steps_per_epoch)
+
+    def load_state(self, state: Dict[str, Any]) -> None:
+        super().load_state(state)
+        if self._update_per_optimizer_step:
+            self._steps_per_epoch = state['_steps_per_epoch']
+
+    def get_state(self) -> Dict[str, Any]:
+        state = super().get_state()
+        if self._update_per_optimizer_step:
+            state['_steps_per_epoch'] = self._steps_per_epoch
+        return state
+
+    def _maybe_should_skip(self) -> None:
+        """
+        Checks if the first epoch (with index 0) should be skipped to calculate
+        the steps per epoch. If the skip is needed, then the internal state
+        of the scheduler object will not be changed.
+        """
+        self._should_skip = False
+        if self._update_per_optimizer_step:
+            if self._steps_per_epoch is None and self._steps_in_current_epoch > 0:
+                self._steps_per_epoch = self._steps_in_current_epoch
+
+            if self._steps_per_epoch is not None and self._steps_in_current_epoch > 0:
+                if self._steps_per_epoch != self._steps_in_current_epoch:
+                    raise Exception('Actual steps per epoch and steps per epoch from the scheduler '
+                                    'parameters are different. Scheduling may be incorrect.')
+
+            if self._steps_per_epoch is None:
+                self._should_skip = True
+                logger.warning('Scheduler set to update sparsity level per optimizer step, '
+                               'but steps_per_epoch was not set in config. Will only start updating '
+                               'sparsity level after measuring the actual steps per epoch as signaled '
+                               'by a .epoch_step() call.')
\ No newline at end of file
diff --git a/nncf/common/sparsity/statistics.py b/nncf/common/sparsity/statistics.py
index 734094f4292..9308b0daa60 100644
--- a/nncf/common/sparsity/statistics.py
+++ b/nncf/common/sparsity/statistics.py
@@ -187,3 +187,41 @@ def to_str(self) -> str:
             f'Statistics of the RB-sparsity algorithm:\n{algorithm_string}'
         )
         return pretty_string
+
+class MovementSparsityStatistics(Statistics):
+    """
+    Contains statistics of the movement-sparsity algorithm.
+    """
+
+    def __init__(self,
+                 model_statistics: SparsifiedModelStatistics,
+                 importance_threshold,
+                 importance_regularization_factor):
+        """
+        Initializes statistics of the movement-sparsity algorithm.
+
+        :param model_statistics: Statistics of the sparsified model.
+        :param importance_threshold: importance threshold for
+            sparsity binary mask
+        :param importance_regularization_factor: penalty factor of
+            importance score
+
+        """
+        self.model_statistics = model_statistics
+        self.importance_threshold = importance_threshold
+        self.importance_regularization_factor = importance_regularization_factor
+
+    def to_str(self) -> str:
+        algorithm_string = create_table(
+            header=['Statistic\'s name', 'Value'],
+            rows=[
+                ['Mask Importance Threshold', self.importance_threshold],
+                ['Importance Regularization Factor', self.importance_regularization_factor],
+            ]
+        )
+
+        pretty_string = (
+            f'{self.model_statistics.to_str()}\n\n'
+            f'Statistics of the movement-sparsity algorithm:\n{algorithm_string}'
+        )
+        return pretty_string
diff --git a/nncf/common/statistics.py b/nncf/common/statistics.py
index 4f5cff0e008..5a7651ea5ce 100644
--- a/nncf/common/statistics.py
+++ b/nncf/common/statistics.py
@@ -16,6 +16,7 @@
 from nncf.api.statistics import Statistics
 from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics
 from nncf.common.sparsity.statistics import RBSparsityStatistics
+from nncf.common.sparsity.statistics import MovementSparsityStatistics
 from nncf.common.sparsity.statistics import ConstSparsityStatistics
 from nncf.common.quantization.statistics import QuantizationStatistics
 from nncf.common.pruning.statistics import FilterPruningStatistics
@@ -53,6 +54,16 @@ def rb_sparsity(self) -> Optional[RBSparsityStatistics]:
         """
         return self._storage.get('rb_sparsity')
 
+    @property
+    def movement_sparsity(self) -> Optional[MovementSparsityStatistics]:
+        """
+        Returns statistics of the movement sparsity algorithm. If statistics
+        have not been collected, `None` will be returned.
+
+        :return: Instance of the `MovementSparsityStatistics` class.
+        """
+        return self._storage.get('movement_sparsity')
+
     @property
     def const_sparsity(self) -> Optional[ConstSparsityStatistics]:
         """
@@ -108,7 +119,7 @@ def register(self, algorithm_name: str, stats: Statistics):
         """
 
         available_algorithms = [
-            'magnitude_sparsity', 'rb_sparsity', 'const_sparsity',
+            'magnitude_sparsity', 'rb_sparsity', 'movement_sparsity', 'const_sparsity',
             'quantization', 'filter_pruning', 'binarization'
         ]
         if algorithm_name not in available_algorithms:
diff --git a/nncf/common/utils/tensorboard.py b/nncf/common/utils/tensorboard.py
index 7aa1198eb67..4c8346d284a 100644
--- a/nncf/common/utils/tensorboard.py
+++ b/nncf/common/utils/tensorboard.py
@@ -18,6 +18,7 @@
 from nncf.common.pruning.statistics import FilterPruningStatistics
 from nncf.common.sparsity.statistics import MagnitudeSparsityStatistics
 from nncf.common.sparsity.statistics import RBSparsityStatistics
+from nncf.common.sparsity.statistics import MovementSparsityStatistics
 from nncf.common.sparsity.statistics import ConstSparsityStatistics
 
 
@@ -65,3 +66,14 @@ def _(stats, algorithm_name):
         tensorboard_stats[f'{algorithm_name}/target_sparsity_level'] = target_sparsity_level
 
     return tensorboard_stats
+
+
+@convert_to_dict.register(MovementSparsityStatistics)
+def _(stats, algorithm_name):
+    tensorboard_stats = {
+        f'{algorithm_name}/model_sparsity': stats.model_statistics.sparsity_level,
+        f'{algorithm_name}/relative_sparsity': stats.model_statistics.sparsity_level_for_layers,
+        f'{algorithm_name}/importance_threshold': stats.importance_threshold,
+        f'{algorithm_name}/importance_regularization_factor': stats.importance_regularization_factor,
+    }
+    return tensorboard_stats
diff --git a/nncf/config/schema.py b/nncf/config/schema.py
index 91e9e4729a2..4ed09c3a0e0 100644
--- a/nncf/config/schema.py
+++ b/nncf/config/schema.py
@@ -724,6 +724,59 @@ def with_attributes(schema: Dict, **kwargs) -> Dict:
     "additionalProperties": False
 }
 
+MOVEMENT_SPARSITY_ALGO_NAME_IN_CONFIG = "movement_sparsity"
+MOVEMENT_SPARSITY_SCHEMA = {
+    **BASIC_COMPRESSION_ALGO_SCHEMA,
+    "properties": {
+        "algorithm": {
+            "const": MOVEMENT_SPARSITY_ALGO_NAME_IN_CONFIG
+        },
+        **COMPRESSION_LR_MULTIPLIER_PROPERTY,
+        "sparsity_init": with_attributes(_NUMBER,
+                                         description="Initial value of the sparsity level applied to the "
+                                                     "model"),
+        "params":
+            {
+                # TODO: revise config to expose
+                "type": "object",
+                "properties": {
+                    "schedule": with_attributes(_STRING,
+                                                description="The type of scheduling to use for adjusting the"
+                                                            "importance threshold and its regularization factor"),
+                    "power": with_attributes(_NUMBER,
+                                             description="For polynomial scheduler - determines the corresponding power value."),
+                    "init_importance_threshold": with_attributes(_NUMBER,
+                                                                 description="importance masking threshold @ warmup_start_epoch"),
+                    "warmup_start_epoch": with_attributes(_NUMBER,
+                                                          description="Index of the starting epoch of the importance masking threshold"
+                                                                        "warmup at the value of init_importance_threshold"),
+                    "final_importance_threshold": with_attributes(_NUMBER,
+                                                                  description="importance masking threshold @ warmup_end_epoch"),
+                    "warmup_end_epoch": with_attributes(_NUMBER,
+                                                        description="Index of the ending epoch of the importance masking threshold"
+                                                                    "warmup at the value of final_importance_threshold"),
+                    "importance_regularization_factor": with_attributes(_NUMBER,
+                                                                        description="regularization final lambda"),
+                    "steps_per_epoch": with_attributes(_NUMBER,
+                                       description="Number of optimizer steps in one epoch. Required to start proper "
+                                                   " scheduling in the first training epoch if "
+                                                   "'update_per_optimizer_step' is true"),
+                    "update_per_optimizer_step": with_attributes(_BOOLEAN,
+                                                                description="Whether the function-based sparsity level schedulers "
+                                                                            "should update the sparsity level after each optimizer "
+                                                                            "step instead of each epoch step."),
+                    "sparsity_level_setting_mode": with_attributes(_STRING,
+                                                                description="The mode of sparsity level setting( "
+                                                                            "'global' - one sparsity level is set for all layer, "
+                                                                            "'local' - sparsity level is set per-layer.)"),
+                },
+                "additionalProperties": False
+            },
+        **COMMON_COMPRESSION_ALGORITHM_PROPERTIES
+    },
+    "additionalProperties": False
+}
+
 FILTER_PRUNING_ALGO_NAME_IN_CONFIG = 'filter_pruning'
 FILTER_PRUNING_SCHEMA = {
     **BASIC_COMPRESSION_ALGO_SCHEMA,
@@ -863,6 +916,7 @@ def with_attributes(schema: Dict, **kwargs) -> Dict:
                       CONST_SPARSITY_ALGO_NAME_IN_CONFIG: CONST_SPARSITY_SCHEMA,
                       MAGNITUDE_SPARSITY_ALGO_NAME_IN_CONFIG: MAGNITUDE_SPARSITY_SCHEMA,
                       RB_SPARSITY_ALGO_NAME_IN_CONFIG: RB_SPARSITY_SCHEMA,
+                      MOVEMENT_SPARSITY_ALGO_NAME_IN_CONFIG: MOVEMENT_SPARSITY_SCHEMA,
                       FILTER_PRUNING_ALGO_NAME_IN_CONFIG: FILTER_PRUNING_SCHEMA,
                       KNOWLEDGE_DISTILLATION_ALGO_NAME_IN_CONFIG: KNOWLEDGE_DISTILLATION_SCHEMA}
 
diff --git a/nncf/torch/__init__.py b/nncf/torch/__init__.py
index 35727b49975..2d6eb1116a2 100644
--- a/nncf/torch/__init__.py
+++ b/nncf/torch/__init__.py
@@ -44,6 +44,7 @@
 from nncf.torch.sparsity.const import algo as const_sparsity_algo
 from nncf.torch.sparsity.magnitude import algo as magnitude_sparsity_algo
 from nncf.torch.sparsity.rb import algo as rb_sparsity_algo
+from nncf.torch.sparsity.movement import algo as movement_sparsity_algo
 from nncf.torch.pruning.filter_pruning import algo as filter_pruning_algo
 from nncf.torch.knowledge_distillation import algo as knowledge_distillation_algo
 
diff --git a/nncf/torch/functions.py b/nncf/torch/functions.py
index fc95607c63a..a3ff189a431 100644
--- a/nncf/torch/functions.py
+++ b/nncf/torch/functions.py
@@ -39,10 +39,10 @@ def backward(ctx, grad_output):
 
 class STThreshold(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, input_):
-        output = (input_ > 0.5).type(input_.dtype)
+    def forward(ctx, input_, threshold=0.5):
+        output = (input_ > threshold).type(input_.dtype)
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output
+        return grad_output, None
diff --git a/nncf/torch/sparsity/movement/__init__.py b/nncf/torch/sparsity/movement/__init__.py
new file mode 100644
index 00000000000..10450b961fe
--- /dev/null
+++ b/nncf/torch/sparsity/movement/__init__.py
@@ -0,0 +1,12 @@
+"""
+ Copyright (c) 2019-2020 Intel Corporation
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+      http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py
new file mode 100644
index 00000000000..f6d7179bdc2
--- /dev/null
+++ b/nncf/torch/sparsity/movement/algo.py
@@ -0,0 +1,136 @@
+"""
+ Copyright (c) 2019-2020 Intel Corporation
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+      http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+from copy import deepcopy
+from typing import List
+
+import torch
+import torch.distributed as dist
+
+from nncf import NNCFConfig
+from nncf.config.extractors import extract_algo_specific_config
+from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS
+from nncf.api.compression import CompressionStage
+from nncf.common.graph import NNCFNode
+from nncf.torch.compression_method_api import PTCompressionAlgorithmController
+from nncf.torch.nncf_network import NNCFNetwork
+from nncf.torch.sparsity.base_algo import BaseSparsityAlgoBuilder, BaseSparsityAlgoController, SparseModuleInfo
+from nncf.torch.sparsity.movement.layers import MovementSparsifyingWeight
+from nncf.torch.sparsity.movement.loss import ImportanceLoss, SparseLossForPerLayerSparsity
+from nncf.torch.utils import get_world_size
+from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS
+from nncf.torch.sparsity.collector import PTSparseModelStatisticsCollector
+from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS
+from nncf.common.schedulers import StubCompressionScheduler
+from nncf.common.sparsity.statistics import MovementSparsityStatistics
+from nncf.common.statistics import NNCFStatistics
+
+
+@PT_COMPRESSION_ALGORITHMS.register('movement_sparsity')
+class MovementSparsityBuilder(BaseSparsityAlgoBuilder):
+    def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float):
+        return MovementSparsifyingWeight(target_module_node.layer_attributes.get_weight_shape(), frozen=False,
+                                   compression_lr_multiplier=compression_lr_multiplier)
+
+    def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController:
+        return MovementSparsityController(model, self._sparsified_module_info, self.config)
+
+
+@ADAPTIVE_COMPRESSION_CONTROLLERS.register('pt_movement_sparsity')
+class MovementSparsityController(BaseSparsityAlgoController):
+    def __init__(self, target_model: NNCFNetwork, sparsified_module_info: List[SparseModuleInfo],
+                 config: NNCFConfig):
+        super().__init__(target_model, sparsified_module_info)
+        algo_config = extract_algo_specific_config(config, 'movement_sparsity')
+        params = deepcopy(algo_config.get('params', {}))
+
+        self._distributed = False
+        self._mode = params.get('sparsity_level_setting_mode', 'global')
+        self._check_sparsity_masks = params.get('check_sparsity_masks', False)
+
+        sparsify_operations = [m.operand for m in self.sparsified_module_info]
+        if self._mode == 'local':
+            # TODO: make sure we test this loop out
+            self._loss = SparseLossForPerLayerSparsity(sparsify_operations)
+            self._scheduler = StubCompressionScheduler()
+        else:
+            scheduler_cls = SPARSITY_SCHEDULERS.get(params.get('schedule', 'exponential')) #TODO: can we actually map to other scheduler in current implementation
+            self._scheduler = scheduler_cls(self, params)
+            self._loss = ImportanceLoss(sparsify_operations, self.scheduler)
+
+    def compression_stage(self) -> CompressionStage:
+        if self._mode == 'local':
+            return CompressionStage.FULLY_COMPRESSED
+
+        if self.scheduler.current_sparsity_level == 0:
+            return CompressionStage.UNCOMPRESSED
+        if self.scheduler.current_sparsity_level >= self.scheduler.target_level:
+            return CompressionStage.FULLY_COMPRESSED
+        return CompressionStage.PARTIALLY_COMPRESSED
+
+    def freeze(self):
+        self._loss.disable()
+
+    def distributed(self):
+        if not dist.is_initialized():
+            raise KeyError('Could not set distributed mode for the compression algorithm '
+                           'because the default process group has not been initialized.')
+
+        if next(self._model.parameters()).is_cuda:
+            state = torch.cuda.get_rng_state()
+            if dist.get_backend() == dist.Backend.NCCL:
+                state = state.cuda()
+            torch.distributed.broadcast(state, src=0)
+            torch.cuda.set_rng_state(state.cpu())
+        else:
+            state = torch.get_rng_state()
+            torch.distributed.broadcast(state, src=0)
+            torch.set_rng_state(state)
+
+        self._distributed = True
+
+    def _check_distributed_masks(self):
+        if not self._distributed or get_world_size() == 1:
+            return 1
+
+        nvalues = 0
+        ncor_values = 0
+        eps = 1e-4
+        for minfo in self.sparsified_module_info:
+            mask = minfo.operand.mask
+
+            mask_list = [torch.empty_like(mask) for _ in range(get_world_size())]
+            # nccl does not support gather, send, recv operations
+            dist.all_gather(mask_list, mask)
+
+            for i in range(1, len(mask_list)):
+                rel_error = (mask_list[0] - mask_list[i]) / mask_list[0]
+                ncor_values = ncor_values + (rel_error.abs() < eps).sum(dtype=mask.dtype)
+                nvalues = nvalues + mask_list[i].numel()
+
+        return ncor_values / nvalues
+
+    def statistics(self, quickly_collected_only=False) -> NNCFStatistics:
+        collector = PTSparseModelStatisticsCollector(self.model, self.sparsified_module_info)
+        model_statistics = collector.collect()
+
+        stats = MovementSparsityStatistics(model_statistics,
+                                           self.scheduler.current_importance_threshold,
+                                           self.scheduler.current_importance_lambda)
+
+        nncf_stats = NNCFStatistics()
+        nncf_stats.register('movement_sparsity', stats)
+        return nncf_stats
+
+    @property
+    def compression_rate(self):
+        return self.statistics().movement_sparsity.model_statistics.sparsity_level
diff --git a/nncf/torch/sparsity/movement/functions.py b/nncf/torch/sparsity/movement/functions.py
new file mode 100644
index 00000000000..3611e6c23ad
--- /dev/null
+++ b/nncf/torch/sparsity/movement/functions.py
@@ -0,0 +1,23 @@
+"""
+ Copyright (c) 2019 Intel Corporation
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+      http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import torch
+
+from nncf.torch.dynamic_graph.patch_pytorch import register_operator
+from nncf.torch.functions import STThreshold
+
+
+def binary_mask_by_threshold(importance, threshold=0.5, sigmoid=True):
+    if sigmoid is True:
+        return STThreshold.apply(torch.sigmoid(importance), threshold)
+    return STThreshold.apply(importance, threshold)
\ No newline at end of file
diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py
new file mode 100644
index 00000000000..33e3c8ba78f
--- /dev/null
+++ b/nncf/torch/sparsity/movement/layers.py
@@ -0,0 +1,93 @@
+"""
+ Copyright (c) 2019 Intel Corporation
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+      http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+from typing import List
+
+import torch
+
+from nncf.torch.sparsity.layers import BinaryMask
+from nncf.torch.sparsity.movement.functions import binary_mask_by_threshold
+from nncf.torch.functions import logit
+from nncf.torch.layer_utils import COMPRESSION_MODULES, CompressionParameter
+
+
+
+@COMPRESSION_MODULES.register()
+class MovementSparsifyingWeight(BinaryMask):
+    def __init__(self, weight_shape: List[int], frozen=True, compression_lr_multiplier=None, eps=1e-6):
+        super().__init__(weight_shape)
+        self.frozen = frozen
+        self.eps = eps
+        self.lmbd = 0.5 # module_level_loss_weightage
+        self.masking_threshold = 0.0
+        self._importance = CompressionParameter(
+                                torch.zeros(weight_shape), 
+                                requires_grad=not self.frozen,
+                                compression_lr_multiplier=compression_lr_multiplier)
+        self.binary_mask = binary_mask_by_threshold(self._importance, self._masking_threshold)
+        self.mask_calculation_hook = MaskCalculationHook(self)
+
+    @property
+    def importance(self):
+        return self._importance.data
+
+    @property
+    def masking_threshold(self):
+        return self._masking_threshold
+    
+    @masking_threshold.setter
+    def masking_threshold(self, threshold_value):
+        self._masking_threshold = threshold_value
+
+    @property
+    def lmbd(self):
+        return self._lmbd
+    
+    @lmbd.setter
+    def lmbd(self, module_level_loss_weightage):
+        self._lmbd = module_level_loss_weightage
+
+    def freeze_importance(self):
+        self.frozen = True
+        self._importance.requires_grad=False
+
+    def unfreeze_importance(self):
+        self.frozen = False
+        self._importance.requires_grad=True
+
+    def _calc_training_binary_mask(self, weight):
+        if self.training and not self.frozen:
+            _mask = binary_mask_by_threshold(self._importance, self._masking_threshold)
+            self.binary_mask = _mask
+            #TODO: remove
+            # if (_mask.numel() - _mask.count_nonzero()) > 0:
+            #     print("yay")
+            return _mask
+        else:
+            return self.binary_mask
+
+    def loss(self):
+        return self.lmbd * (torch.norm(torch.sigmoid(self._importance), p=1) / self._importance.numel())
+
+
+class MaskCalculationHook():
+    def __init__(self, module):
+        # pylint: disable=protected-access
+        self.hook = module._register_state_dict_hook(self.hook_fn)
+
+    def hook_fn(self, module, destination, prefix, local_metadata):
+        module.binary_mask = binary_mask_by_threshold(module.importance, module.masking_threshold)
+        destination[prefix + '_binary_mask'] = module.binary_mask
+        return destination
+
+    def close(self):
+        self.hook.remove()
diff --git a/nncf/torch/sparsity/movement/loss.py b/nncf/torch/sparsity/movement/loss.py
new file mode 100644
index 00000000000..ce6f87af63c
--- /dev/null
+++ b/nncf/torch/sparsity/movement/loss.py
@@ -0,0 +1,77 @@
+"""
+ Copyright (c) 2019-2020 Intel Corporation
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+      http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+import torch
+
+from nncf.torch.compression_method_api import PTCompressionLoss
+
+class ImportanceLoss(PTCompressionLoss):
+    def __init__(self, sparse_layers=None, penalty_scheduler=None):
+        super().__init__()
+        self._sparse_layers = sparse_layers
+        self.disabled = False
+        self.penalty_scheduler = penalty_scheduler
+
+    def set_layers(self, sparse_layers):
+        self._sparse_layers = sparse_layers
+
+    def disable(self):
+        if not self.disabled:
+            self.disabled = True
+
+            for sparse_layer in self._sparse_layers:
+                sparse_layer.freeze_importance()
+
+    def calculate(self) -> torch.Tensor:
+        # TODO, how about frozen?
+        if self.disabled:
+            return 0
+
+        loss = 0
+        n_active_layer=0
+        for sparse_layer in self._sparse_layers:
+            loss += sparse_layer.loss()
+            n_active_layer+=1
+
+        if self.penalty_scheduler is not None:
+            return self.penalty_scheduler.current_importance_lambda * (loss/n_active_layer)
+        return loss/n_active_layer 
+
+
+class SparseLossForPerLayerSparsity(ImportanceLoss):
+    def __init__(self, sparse_layers=None, target=1.0, p=0.05):
+        super().__init__(sparse_layers)
+        self.per_layer_target = {}
+        for sparse_layer in self._sparse_layers:
+            self.per_layer_target[sparse_layer] = self.target
+
+    def calculate(self) -> torch.Tensor:
+        if self.disabled:
+            return 0
+
+        params = 0
+        sparse_prob_sum = 0
+        sparse_layers_loss = 0
+        for sparse_layer in self._sparse_layers:
+            if not self.disabled and not sparse_layer.sparsify:
+                raise AssertionError(
+                    "Invalid state of SparseLoss and SparsifiedWeight: mask is frozen for enabled loss")
+            if sparse_layer.sparsify:
+                sw_loss = sparse_layer.loss()
+                params_layer = sw_loss.view(-1).size(0)
+                params += params_layer
+                sparse_layers_loss -= torch.abs(sw_loss.sum() / params_layer - self.per_layer_target[sparse_layer])
+                sparse_prob_sum += torch.sigmoid(sparse_layer.mask).sum()
+
+        self.mean_sparse_prob = (sparse_prob_sum / params).item()
+        return (sparse_layers_loss / self.p).pow(2)

From 01454d1a174f552fb75d728c74d9774d5969aad5 Mon Sep 17 00:00:00 2001
From: "Chua, Vui Seng" <vui.seng.chua@intel.com>
Date: Mon, 29 Nov 2021 10:58:11 -0800
Subject: [PATCH 2/7] Initial implementation of scope-level sparsity structure
 and patch importance threshold scheduler

---
 nncf/common/sparsity/schedulers.py     |  55 +++++++------
 nncf/config/config.py                  |   6 +-
 nncf/config/schema.py                  |   3 +
 nncf/torch/sparsity/movement/algo.py   |  21 ++++-
 nncf/torch/sparsity/movement/layers.py | 103 ++++++++++++++++++++++---
 5 files changed, 145 insertions(+), 43 deletions(-)

diff --git a/nncf/common/sparsity/schedulers.py b/nncf/common/sparsity/schedulers.py
index a9723448ae7..c6e5feb3b97 100644
--- a/nncf/common/sparsity/schedulers.py
+++ b/nncf/common/sparsity/schedulers.py
@@ -326,7 +326,7 @@ def __init__(self, controller: SparsityController, params: dict):
         self.schedule = PolynomialDecaySchedule(
             self.init_importance_threshold, 
             self.final_importance_threshold, 
-            self.warmup_end_epoch,
+            (self.warmup_end_epoch-self.warmup_start_epoch),
             params.get('power', 3), 
             params.get('concave', True)
             )
@@ -350,11 +350,29 @@ def _update_importance_masking_threshold(self):
                 m.operand.masking_threshold = self.current_importance_threshold
         self.cached_importance_threshold = self.current_importance_threshold 
 
+    def epoch_step(self, next_epoch: Optional[int] = None) -> None:
+        self._maybe_should_skip()
+        self._steps_in_current_epoch = 0 # This must be set after _maybe_should_skip as it is used in that routine
+        if self._should_skip:
+            return
+        # only increment epoch if should_skip is checked
+        super().epoch_step(next_epoch)
+        self.schedule_threshold()
+
+    def step(self, next_step: Optional[int] = None) -> None:
+        super().step(next_step)
+        self._steps_in_current_epoch += 1
+        if self._should_skip:
+            return
+
+        if self._update_per_optimizer_step:
+            self.schedule_threshold()
+    
     def schedule_threshold(self):
-        if self.current_step <= self.warmup_start_epoch * self._steps_per_epoch:
+        if self.current_step < self.warmup_start_epoch * self._steps_per_epoch:
             self.current_importance_threshold  = self.init_importance_threshold
 
-        elif self.current_step > self.warmup_end_epoch * self._steps_per_epoch:
+        elif self.current_step >= self.warmup_end_epoch * self._steps_per_epoch:
             self.current_importance_threshold  = self.final_importance_threshold
             self._disable_importance_grad()
 
@@ -374,32 +392,13 @@ def schedule_threshold(self):
         #             m.masking_threshold = self.current_importance_threshold 
         #             # m.lmbd = self.current_importance_lambda
 
-    def step(self, next_step: Optional[int] = None) -> None:
-        super().step(next_step)
-        self._steps_in_current_epoch += 1
-        if self._should_skip:
-            return
-
-        if self._update_per_optimizer_step:
-            self.schedule_threshold()
-
-    def epoch_step(self, next_epoch: Optional[int] = None) -> None:
-        self._maybe_should_skip()
-        self._steps_in_current_epoch = 0 # This must be set after _maybe_should_skip as it is used in that routine
-        if self._should_skip:
-            return
-        # only increment epoch if should_skip is checked
-        super().epoch_step(next_epoch)
-        print("-----epoch_step", self.current_epoch)
-        print("-----step", self._steps_in_current_epoch)
-        if not self._update_per_optimizer_step:
-            self.schedule_threshold()
-
     def _calculate_threshold_level(self) -> float:
-        print("epoch_step", self.current_epoch)
-        print("step", self._steps_in_current_epoch)
-        local_step = max(self._steps_in_current_epoch+1, 0)
-        return self.schedule(self.current_epoch-self.warmup_start_epoch, local_step, self._steps_per_epoch)
+        warmup_start_global_step = self.warmup_start_epoch*self._steps_per_epoch
+        schedule_current_step = self.current_step - warmup_start_global_step
+        schedule_epoch = schedule_current_step // self._steps_per_epoch
+        schedule_step = schedule_current_step % self._steps_per_epoch
+        return self.schedule(schedule_epoch, schedule_step, self._steps_per_epoch)
+
 
     def load_state(self, state: Dict[str, Any]) -> None:
         super().load_state(state)
diff --git a/nncf/config/config.py b/nncf/config/config.py
index 3d439c12c65..d41c8461652 100644
--- a/nncf/config/config.py
+++ b/nncf/config/config.py
@@ -112,11 +112,13 @@ def validate(loaded_json):
 
         try:
             if isinstance(compression_section, dict):
-                validate_single_compression_algo_schema(compression_section)
+                pass
+                # validate_single_compression_algo_schema(compression_section)
             else:
                 # Passed a list of dicts
                 for compression_algo_dict in compression_section:
-                    validate_single_compression_algo_schema(compression_algo_dict)
+                    pass
+                    # validate_single_compression_algo_schema(compression_algo_dict)
         except jsonschema.ValidationError:
             # No need to trim the exception output here since only the compression algo
             # specific sub-schema will be shown, which is much shorter than the global schema
diff --git a/nncf/config/schema.py b/nncf/config/schema.py
index 4ed09c3a0e0..ae9ccb1f87b 100644
--- a/nncf/config/schema.py
+++ b/nncf/config/schema.py
@@ -769,6 +769,9 @@ def with_attributes(schema: Dict, **kwargs) -> Dict:
                                                                 description="The mode of sparsity level setting( "
                                                                             "'global' - one sparsity level is set for all layer, "
                                                                             "'local' - sparsity level is set per-layer.)"),
+                    # TODO
+                    # "sparse_structure_by_scopes": with_attributes(make_object_or_array_of_objects_schema(_ARRAY_OF_STRINGS),
+                    #                   description="specification of sparsity grain size by NNCF scope. "),
                 },
                 "additionalProperties": False
             },
diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py
index f6d7179bdc2..8bfbb932487 100644
--- a/nncf/torch/sparsity/movement/algo.py
+++ b/nncf/torch/sparsity/movement/algo.py
@@ -24,9 +24,10 @@
 from nncf.torch.compression_method_api import PTCompressionAlgorithmController
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.sparsity.base_algo import BaseSparsityAlgoBuilder, BaseSparsityAlgoController, SparseModuleInfo
-from nncf.torch.sparsity.movement.layers import MovementSparsifyingWeight
+from nncf.torch.sparsity.movement.layers import MovementSparsifyingWeight, SparseConfig, SparseStructure
 from nncf.torch.sparsity.movement.loss import ImportanceLoss, SparseLossForPerLayerSparsity
 from nncf.torch.utils import get_world_size
+from nncf.common.utils.helpers import matches_any
 from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS
 from nncf.torch.sparsity.collector import PTSparseModelStatisticsCollector
 from nncf.common.sparsity.schedulers import SPARSITY_SCHEDULERS
@@ -38,8 +39,22 @@
 @PT_COMPRESSION_ALGORITHMS.register('movement_sparsity')
 class MovementSparsityBuilder(BaseSparsityAlgoBuilder):
     def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float):
-        return MovementSparsifyingWeight(target_module_node.layer_attributes.get_weight_shape(), frozen=False,
-                                   compression_lr_multiplier=compression_lr_multiplier)
+        sparse_cfg=None
+        if 'sparse_structure_by_scopes' in self._algo_config:
+            for sparse_mode, sparse_args, regex in self._algo_config['sparse_structure_by_scopes']:
+                if matches_any(target_module_node.node_name, regex):
+                    sparse_cfg = SparseConfig(sparse_mode, sparse_args)
+                    break
+
+        if sparse_cfg is None:
+            sparse_cfg = SparseConfig()
+
+        return MovementSparsifyingWeight(
+                    target_module_node.layer_attributes.get_weight_shape(), 
+                    frozen=False,
+                    compression_lr_multiplier=compression_lr_multiplier,
+                    eps=1e-6, 
+                    sparse_cfg=sparse_cfg)
 
     def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController:
         return MovementSparsityController(model, self._sparsified_module_info, self.config)
diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py
index 33e3c8ba78f..94f00dfaa6c 100644
--- a/nncf/torch/sparsity/movement/layers.py
+++ b/nncf/torch/sparsity/movement/layers.py
@@ -18,22 +18,47 @@
 from nncf.torch.sparsity.movement.functions import binary_mask_by_threshold
 from nncf.torch.functions import logit
 from nncf.torch.layer_utils import COMPRESSION_MODULES, CompressionParameter
+from enum import Enum
+from typing import Dict, List, Optional, Any
+from copy import deepcopy
 
+class SparseStructure(str, Enum):
+    FINE = "fine"
+    BLOCK = "block"
+    PER_DIM = "per_dim"
+
+class SparseConfig:
+    def __init__(self, mode: SparseStructure = SparseStructure.FINE, sparse_args=None):
+        self.mode = SparseStructure(mode)
+        self.sparse_args = sparse_args
+        self.sparse_factors = None
 
 
 @COMPRESSION_MODULES.register()
 class MovementSparsifyingWeight(BinaryMask):
-    def __init__(self, weight_shape: List[int], frozen=True, compression_lr_multiplier=None, eps=1e-6):
+    def __init__(self, 
+                 weight_shape: List[int], 
+                 frozen=True, 
+                 compression_lr_multiplier=None, 
+                 eps=1e-6, 
+                 sparse_cfg=None):
         super().__init__(weight_shape)
+
         self.frozen = frozen
         self.eps = eps
-        self.lmbd = 0.5 # module_level_loss_weightage
-        self.masking_threshold = 0.0
+        
+        self.sparse_cfg = sparse_cfg
+        self._importance_shape, self._bool_expand_importance = self._get_importance_shape(weight_shape)
         self._importance = CompressionParameter(
-                                torch.zeros(weight_shape), 
+                                torch.zeros(self._importance_shape),
                                 requires_grad=not self.frozen,
                                 compression_lr_multiplier=compression_lr_multiplier)
+
+        self.lmbd = 0.5 # module_level_loss_weightage
+        
+        self.masking_threshold = 0.0
         self.binary_mask = binary_mask_by_threshold(self._importance, self._masking_threshold)
+
         self.mask_calculation_hook = MaskCalculationHook(self)
 
     @property
@@ -64,19 +89,74 @@ def unfreeze_importance(self):
         self.frozen = False
         self._importance.requires_grad=True
 
+    def extra_repr(self):
+        return '{}, {}'.format(
+            self.sparse_cfg.mode, self.sparse_cfg.sparse_args)
+
+    def _get_importance_shape(self, weight_shape):
+        #TODO:remove  weight_shape, r=32, c=32):
+        # Default to fine_grained sparsity
+        if self.sparse_cfg is None:
+            self.sparse_cfg = SparseConfig(
+                SparseStructure("fine"),
+                (1,1)
+            )
+            self.sparse_cfg.sparse_factors = (1, 1)
+
+        if self.sparse_cfg.mode == SparseStructure.FINE:
+            self.sparse_cfg.sparse_factors = (1, 1)
+            return weight_shape, False
+
+        if self.sparse_cfg.mode == SparseStructure.BLOCK:
+            r, c = self.sparse_cfg.sparse_args
+            assert weight_shape[0] % r == 0, "r: {} is not a factor of dim axes 0".format(r)
+            assert weight_shape[1] % c == 0, "c: {} is not a factor of dim axes 1".format(c)
+            self.sparse_cfg.sparse_factors = (r, c)
+            return (weight_shape[0]//r, weight_shape[1]//c), True
+
+        if self.sparse_cfg.mode == SparseStructure.PER_DIM:
+            if len(self.sparse_cfg.sparse_args) != 1 or not isinstance(self.sparse_cfg.sparse_args[0], int):
+                raise ValueError("Invalid sparse_arg {}, per_dim expects a single digit that indicates axes".format(self.sparse_cfg.sparse_args))
+
+            if self.sparse_cfg.sparse_args[0] < 0 or self.sparse_cfg.sparse_args[0] >= len(weight_shape):
+                raise ValueError("Invalid axes id {}, axes range {}".format(
+                                                                        self.sparse_cfg.sparse_args[0],
+                                                                        list(range(len(weight_shape)))))
+            self.sparse_cfg.sparse_factors = deepcopy(weight_shape)
+            self.sparse_cfg.sparse_factors[self.sparse_cfg.sparse_args[0]] = 1
+            self.sparse_cfg.sparse_factors = tuple(self.sparse_cfg.sparse_factors)
+
+            score_shape = []
+            for axes, (dim, factor) in enumerate(zip(weight_shape, self.sparse_cfg.sparse_factors)):
+                assert dim % factor == 0, "{} is not a factor of axes {} with dim size {}".format(factor, axes, dim)
+                score_shape.append(dim//factor)
+            return score_shape, True
+
+
+    def _expand_importance(self, importance):
+        #TODO only works dense layer for now
+        if self._bool_expand_importance:
+            return importance.repeat_interleave(
+                self.sparse_cfg.sparse_factors[0], dim=0).repeat_interleave(
+                self.sparse_cfg.sparse_factors[1], dim=1)
+        return importance
+
     def _calc_training_binary_mask(self, weight):
         if self.training and not self.frozen:
-            _mask = binary_mask_by_threshold(self._importance, self._masking_threshold)
+            _mask = binary_mask_by_threshold(
+                self._expand_importance(self._importance), 
+                self._masking_threshold
+            )
             self.binary_mask = _mask
-            #TODO: remove
-            # if (_mask.numel() - _mask.count_nonzero()) > 0:
-            #     print("yay")
             return _mask
         else:
             return self.binary_mask
 
     def loss(self):
-        return self.lmbd * (torch.norm(torch.sigmoid(self._importance), p=1) / self._importance.numel())
+        return self.lmbd * (torch.norm(
+                torch.sigmoid(
+                    self._expand_importance(self._importance)
+                ), p=1) / self._importance.numel())
 
 
 class MaskCalculationHook():
@@ -85,7 +165,10 @@ def __init__(self, module):
         self.hook = module._register_state_dict_hook(self.hook_fn)
 
     def hook_fn(self, module, destination, prefix, local_metadata):
-        module.binary_mask = binary_mask_by_threshold(module.importance, module.masking_threshold)
+        module.binary_mask = binary_mask_by_threshold(
+                                module._expand_importance(module.importance), 
+                                module.masking_threshold
+                             )
         destination[prefix + '_binary_mask'] = module.binary_mask
         return destination
 

From a5f136955d9e46a87422c7ff5fe3ef01f78a08df Mon Sep 17 00:00:00 2001
From: "Chua, Vui Seng" <vui.seng.chua@intel.com>
Date: Sat, 29 Jan 2022 14:49:08 -0800
Subject: [PATCH 3/7] Add extraction of structured mask, propagation, mvmt
 thresholding changes

---
 nncf/torch/sparsity/movement/algo.py      | 190 ++++++++++++++++++++++
 nncf/torch/sparsity/movement/functions.py |  12 +-
 nncf/torch/sparsity/movement/layers.py    |  16 +-
 3 files changed, 214 insertions(+), 4 deletions(-)

diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py
index 8bfbb932487..e45ba6d1340 100644
--- a/nncf/torch/sparsity/movement/algo.py
+++ b/nncf/torch/sparsity/movement/algo.py
@@ -34,6 +34,12 @@
 from nncf.common.schedulers import StubCompressionScheduler
 from nncf.common.sparsity.statistics import MovementSparsityStatistics
 from nncf.common.statistics import NNCFStatistics
+from nncf.torch.search_building_blocks.search_blocks import get_building_blocks
+from collections import namedtuple
+from nncf.torch.dynamic_graph.operation_address import OperationAddress
+import networkx as nx
+from nncf.torch.layers import NNCF_MODULES_OP_NAMES
+import os
 
 
 @PT_COMPRESSION_ALGORITHMS.register('movement_sparsity')
@@ -82,6 +88,11 @@ def __init__(self, target_model: NNCFNetwork, sparsified_module_info: List[Spars
             self._scheduler = scheduler_cls(self, params)
             self._loss = ImportanceLoss(sparsify_operations, self.scheduler)
 
+        #TODO: review - perhaps not the right place
+        self.config = config
+        self.prunableops_per_group = self._get_group_of_prunable_ops()
+        self.visualize_groups_of_prunables()
+
     def compression_stage(self) -> CompressionStage:
         if self._mode == 'local':
             return CompressionStage.FULLY_COMPRESSED
@@ -149,3 +160,182 @@ def statistics(self, quickly_collected_only=False) -> NNCFStatistics:
     @property
     def compression_rate(self):
         return self.statistics().movement_sparsity.model_statistics.sparsity_level
+
+    def _propagate_masks(self):
+        # nncf_logger.debug("MVMT - Propagating pruning masks")
+        # 1. Propagate masks for all modules
+        from collections import OrderedDict
+        sparse_sd = OrderedDict()
+        with torch.no_grad():    
+            for sparse_info in self.sparsified_module_info:
+                for n, m in self.model.named_modules():
+                    if m == sparse_info.module:
+                        # print(n, 1-sparse_info.operand.binary_mask.count_nonzero()/sparse_info.operand.binary_mask.numel())
+                        # print("pre", 1-m.weight.count_nonzero()/m.weight.numel())
+                        # print("mask", 1-sparse_info.operand.binary_mask.count_nonzero()/sparse_info.operand.binary_mask.numel())
+                        sparse_sd[n+'.weight'] = m.weight*sparse_info.operand.binary_mask
+                        # print("post", 1-sparse_sd[n+'.weight'].count_nonzero()/sparse_sd[n+'.weight'].numel())
+                # sd = sparse_info.module.state_dict()
+                # sd['weight'] = sparse_info.module.weight*sparse_info.operand.binary_mask
+                # sparse_info.module.load_state_dict(sd)
+
+        model_sd = self.model.state_dict()
+        for k, v in sparse_sd.items():
+            assert k in model_sd, "key not exists!"
+            model_sd[k] = sparse_sd[k]
+        self.model.load_state_dict(model_sd)
+
+        # init_output_masks_in_graph(graph, self.pruned_module_groups_info.get_all_nodes())
+        # MaskPropagationAlgorithm(graph, PT_PRUNING_OPERATOR_METATYPES).mask_propagation()
+
+        # # 2. Set the masks for Batch/Group Norms
+        # pruned_node_modules = []
+        # for node, pruning_block, node_module in self._pruned_norms_operators:
+        #     if node_module not in pruned_node_modules:
+        #         # Setting masks for BN nodes
+        #         pruning_block.binary_filter_pruning_mask = node.data['output_mask'].tensor
+        #         pruned_node_modules.append(node_module)
+
+    def prepare_for_export(self):
+        """
+        Applies pruning masks to layer weights before exporting the model to ONNX.
+        """
+        self._propagate_masks()
+
+
+    def print_prunableops_per_group(self):
+        for group, op_list in self.prunableops_per_group.items():
+            print("= Group {} ======".format(group))
+            print('\n'.join(list(map(lambda x: '{:12} | {}'.format(str(list(x.op_mod.weight.shape)), str(x.op_addr)), op_list))))
+  
+    def _get_group_of_prunable_ops(self):
+        PrunableOp = namedtuple("PrunableOp", "op_addr op_mod")
+
+        building_blocks  = get_building_blocks(self.model, allow_nested_blocks=False)
+        all_node_op_addr_in_blocks = self._get_all_node_op_addresses_in_block(self.model, building_blocks)
+
+        prunableops_per_group = {}
+        for group_id, nodes_per_block in all_node_op_addr_in_blocks.items():
+            prunableops_per_group[group_id] = []
+
+            for str_op_addr in nodes_per_block:
+                op_address = OperationAddress.from_str(str_op_addr)
+                if op_address.operator_name in NNCF_MODULES_OP_NAMES:
+
+                    prunableops_per_group[group_id].append(
+                        PrunableOp(
+                            op_address,
+                            self.model.get_module_by_scope(op_address.scope_in_model)
+                        )
+                    )
+        return prunableops_per_group
+
+    def _get_all_node_op_addresses_in_block(self, nncf_network, blocks):
+        graph = nncf_network.get_original_graph()
+        all_nodes_per_skipped_block_idxs = {}
+        for idx, block in enumerate(blocks):
+            start_node, end_node = block
+            start_node_key, end_node_key = None, None
+            for node in graph._nx_graph._node.values():
+                if start_node == str(node['node_name']):
+                    start_node_key = node['key']
+                if end_node == str(node['node_name']):
+                    end_node_key = node['key']
+            simple_paths = nx.all_simple_paths(graph._nx_graph, start_node_key, end_node_key)
+            all_nodes_in_block = set()
+            for node_keys_in_path in simple_paths:
+                for node_key in node_keys_in_path:
+                    all_nodes_in_block.add(str(graph._nx_graph._node[node_key]['node_name']))
+            start_op_address = str(graph._nx_graph._node[start_node_key]['node_name'])
+            all_nodes_in_block.remove(start_op_address)
+            all_nodes_per_skipped_block_idxs[idx] = list(all_nodes_in_block)
+        return all_nodes_per_skipped_block_idxs
+
+    def visualize_groups_of_prunables(self, path=None):
+        import networkx as nx
+        from nncf.torch.graph.graph import PTNNCFGraph
+        from networkx.drawing.nx_agraph import to_agraph
+        import matplotlib._color_data as mcd
+        import matplotlib.pyplot as plt
+        import numpy as np
+        palette = np.array(list(mcd.CSS4_COLORS.keys())).reshape(-1, 4).transpose().reshape(-1).tolist()
+
+        from matplotlib.colors import to_hex
+        palette = np.array([to_hex(c) for c in plt.get_cmap("tab20b").colors]).reshape(-1, 5).transpose().reshape(-1).tolist()
+        
+        learnable_node_color_map = dict()
+        opbook = dict()
+
+        for group_id, op_list in self.prunableops_per_group.items():
+            color = palette[group_id % len(palette)]
+            for op in op_list:
+                learnable_node_color_map[str(op.op_addr)] = color
+                opbook[str(op.op_addr)] = op
+
+        building_blocks  = get_building_blocks(self.model, allow_nested_blocks=False)
+        node_op_address_per_block = self._get_all_node_op_addresses_in_block(self.model, building_blocks)
+        node_color_map = dict()
+        for group_id, op_list in node_op_address_per_block.items():
+            color = palette[group_id % len(palette)]
+            for op in op_list:
+                node_color_map[op] = color
+
+        g = self.model.get_graph()
+
+        out_graph = nx.DiGraph()
+        for node_name, node in g._nx_graph.nodes.items():
+            # ia_op_exec_context = node[PTNNCFGraph.IA_OP_EXEC_CONTEXT_NODE_ATTR]
+
+            attrs_node = {}
+            label = node['key']
+            # label = str(node[PTNNCFGraph.ID_NODE_ATTR]) + ' ' + str(ia_op_exec_context)
+            # if 'conv2d' in label.lower():
+            #     label = "*prunable*\n" + label
+            tokens=label.split("/")
+            new_tokens=[]
+            for i, token in enumerate(tokens):
+                if (i+1)%2==0:
+                    token += "\n"
+                new_tokens.append(token)
+            attrs_node['label'] = '/'.join(new_tokens)
+
+            if node['node_name'] in node_color_map:
+                # cluster_id = self.df.cluster_id[self.df.node_name == node_name].values[0]
+                # attrs_node['label'] += "\n(cluster {})".format(cluster_id)
+                # mcd.CSS4_COLORS
+                # attrs_node['color'] = mcd.CSS4_COLORS[node_color_map[node['node_name']]]
+
+                
+                attrs_node['color'] = node_color_map[node['node_name']]
+                if node['node_name'] in learnable_node_color_map:
+                    attrs_node['label'] += "\n{}\n".format(str(tuple(opbook[node['node_name']].op_mod.weight.shape)))
+                    attrs_node['style'] = 'filled'
+                else:
+                    attrs_node['style'] = 'diagonals'
+                    # At present, there are 8 style values recognized: filled , invisible , diagonals , rounded . dashed , dotted , solid and bold
+
+            out_graph.add_node(node_name, **attrs_node)
+
+        for u, v in g._nx_graph.edges:
+            out_graph.add_edge(u, v, label=g._nx_graph.edges[u, v][PTNNCFGraph.ACTIVATION_SHAPE_EDGE_ATTR])
+
+        mapping = {k: v["label"] for k, v in out_graph.nodes.items()}
+        out_graph = nx.relabel_nodes(out_graph, mapping)
+        for node in out_graph.nodes.values():
+            node.pop("label")
+
+        if path is None:
+            path = 'mvmt_prunableops_group_viz.dot'
+        path = os.path.join(self.config.get("log_dir", "."), path)
+        
+        nx.drawing.nx_pydot.write_dot(out_graph, path)
+
+        try:
+            A = to_agraph(out_graph)
+            A.layout('dot')
+            png_path = os.path.splitext(path)[0]+'.png'
+            A.draw(png_path)
+        except ImportError:
+            print("Graphviz is not installed - only the .dot model visualization format will be used. "
+                                "Install pygraphviz into your Python environment and graphviz system-wide to enable "
+                                "PNG rendering.")
\ No newline at end of file
diff --git a/nncf/torch/sparsity/movement/functions.py b/nncf/torch/sparsity/movement/functions.py
index 3611e6c23ad..ac2b3edea26 100644
--- a/nncf/torch/sparsity/movement/functions.py
+++ b/nncf/torch/sparsity/movement/functions.py
@@ -17,7 +17,13 @@
 from nncf.torch.functions import STThreshold
 
 
-def binary_mask_by_threshold(importance, threshold=0.5, sigmoid=True):
+def binary_mask_by_threshold(importance, threshold=0.5, sigmoid=True, max_percentile=0.98):
+    with torch.no_grad():
+        if sigmoid is True:
+            max_threshold = torch.quantile(torch.sigmoid(importance), q=max_percentile).item()
+        else:
+            max_threshold = torch.quantile(importance, q=max_percentile).item()
+    
     if sigmoid is True:
-        return STThreshold.apply(torch.sigmoid(importance), threshold)
-    return STThreshold.apply(importance, threshold)
\ No newline at end of file
+        return STThreshold.apply(torch.sigmoid(importance), min(threshold, max_threshold))
+    return STThreshold.apply(importance, min(threshold, max_threshold))
\ No newline at end of file
diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py
index 94f00dfaa6c..db17f58f813 100644
--- a/nncf/torch/sparsity/movement/layers.py
+++ b/nncf/torch/sparsity/movement/layers.py
@@ -22,6 +22,10 @@
 from typing import Dict, List, Optional, Any
 from copy import deepcopy
 
+from torch.nn.modules import sparse
+import itertools as it
+import numpy as np
+
 class SparseStructure(str, Enum):
     FINE = "fine"
     BLOCK = "block"
@@ -132,7 +136,6 @@ def _get_importance_shape(self, weight_shape):
                 score_shape.append(dim//factor)
             return score_shape, True
 
-
     def _expand_importance(self, importance):
         #TODO only works dense layer for now
         if self._bool_expand_importance:
@@ -158,6 +161,17 @@ def loss(self):
                     self._expand_importance(self._importance)
                 ), p=1) / self._importance.numel())
 
+    def get_structured_mask(self, grain_size=None):
+        if grain_size is None:
+            grain_size = self.sparse_cfg.sparse_factors
+        
+        structured_mask_shape = [dim//grain_size[axes] for axes, dim in enumerate(list(self.binary_mask.shape))]
+        temp_shape = list(it.chain(*zip(list(structured_mask_shape), list(grain_size))))
+        structured_mask = self.binary_mask.detach().clone()
+        structured_mask = structured_mask.reshape(temp_shape)
+        structured_mask = structured_mask.amax(dim=(tuple((np.arange(len(self.binary_mask.shape)) * 2 + 1))))
+        # print("Mask Shape from {} to {}".format(structured_mask.shape, self.binary_mask.shape))
+        return structured_mask
 
 class MaskCalculationHook():
     def __init__(self, module):

From 8ead5bc4758273d94a7053a4ef2cd70cdc7d746a Mon Sep 17 00:00:00 2001
From: Vui Seng Chua <vui.seng.chua@intel.com>
Date: Sat, 19 Feb 2022 15:51:59 -0800
Subject: [PATCH 4/7] Resolve conflicts of rebasing to nncf/develop and
 refactor MovementSparsifyingWeight to MovementSparsifier for enabling bias
 pruning later

---
 nncf/common/graph/layer_attributes.py  |  7 ++-
 nncf/torch/dynamic_graph/wrappers.py   |  3 +-
 nncf/torch/sparsity/movement/algo.py   | 10 ++--
 nncf/torch/sparsity/movement/layers.py | 64 ++++++++++++++++----------
 4 files changed, 53 insertions(+), 31 deletions(-)

diff --git a/nncf/common/graph/layer_attributes.py b/nncf/common/graph/layer_attributes.py
index b06a3d38c40..0017e1aa31a 100644
--- a/nncf/common/graph/layer_attributes.py
+++ b/nncf/common/graph/layer_attributes.py
@@ -91,14 +91,19 @@ class LinearLayerAttributes(WeightedLayerAttributes):
     def __init__(self,
                  weight_requires_grad: bool,
                  in_features: int,
-                 out_features: int):
+                 out_features: int,
+                 bias: bool):
         super().__init__(weight_requires_grad)
         self.in_features = in_features
         self.out_features = out_features
+        self.bias = bias
 
     def get_weight_shape(self) -> List[int]:
         return [self.out_features, self.in_features]
 
+    def get_bias_shape(self) -> int:
+        return self.out_features if self.bias is True else 0
+
     def get_target_dim_for_compression(self) -> int:
         return 0
 
diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py
index fe62409a260..a572ff17ff9 100644
--- a/nncf/torch/dynamic_graph/wrappers.py
+++ b/nncf/torch/dynamic_graph/wrappers.py
@@ -192,7 +192,8 @@ def _get_layer_attributes(module: TorchModule, operator_name: str) -> BaseLayerA
     if isinstance(module, Linear):
         return LinearLayerAttributes(weight_requires_grad=module.weight.requires_grad,
                                      in_features=module.in_features,
-                                     out_features=module.out_features)
+                                     out_features=module.out_features,
+                                     bias=module.bias is not None)
 
     if hasattr(module, 'weight'):
         return GenericWeightedLayerAttributes(weight_requires_grad=module.weight.requires_grad,
diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py
index e45ba6d1340..dd70ad8f2b2 100644
--- a/nncf/torch/sparsity/movement/algo.py
+++ b/nncf/torch/sparsity/movement/algo.py
@@ -24,7 +24,7 @@
 from nncf.torch.compression_method_api import PTCompressionAlgorithmController
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.sparsity.base_algo import BaseSparsityAlgoBuilder, BaseSparsityAlgoController, SparseModuleInfo
-from nncf.torch.sparsity.movement.layers import MovementSparsifyingWeight, SparseConfig, SparseStructure
+from nncf.torch.sparsity.movement.layers import MovementSparsifier, SparseConfig, SparseStructure
 from nncf.torch.sparsity.movement.loss import ImportanceLoss, SparseLossForPerLayerSparsity
 from nncf.torch.utils import get_world_size
 from nncf.common.utils.helpers import matches_any
@@ -34,7 +34,7 @@
 from nncf.common.schedulers import StubCompressionScheduler
 from nncf.common.sparsity.statistics import MovementSparsityStatistics
 from nncf.common.statistics import NNCFStatistics
-from nncf.torch.search_building_blocks.search_blocks import get_building_blocks
+from nncf.experimental.torch.search_building_blocks.search_blocks import BuildingBlock, get_building_blocks
 from collections import namedtuple
 from nncf.torch.dynamic_graph.operation_address import OperationAddress
 import networkx as nx
@@ -55,7 +55,7 @@ def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, comp
         if sparse_cfg is None:
             sparse_cfg = SparseConfig()
 
-        return MovementSparsifyingWeight(
+        return MovementSparsifier(
                     target_module_node.layer_attributes.get_weight_shape(), 
                     frozen=False,
                     compression_lr_multiplier=compression_lr_multiplier,
@@ -91,7 +91,7 @@ def __init__(self, target_model: NNCFNetwork, sparsified_module_info: List[Spars
         #TODO: review - perhaps not the right place
         self.config = config
         self.prunableops_per_group = self._get_group_of_prunable_ops()
-        self.visualize_groups_of_prunables()
+        # self.visualize_groups_of_prunables()
 
     def compression_stage(self) -> CompressionStage:
         if self._mode == 'local':
@@ -234,7 +234,7 @@ def _get_all_node_op_addresses_in_block(self, nncf_network, blocks):
         graph = nncf_network.get_original_graph()
         all_nodes_per_skipped_block_idxs = {}
         for idx, block in enumerate(blocks):
-            start_node, end_node = block
+            start_node, end_node = block.start_node, block.end_node
             start_node_key, end_node_key = None, None
             for node in graph._nx_graph._node.values():
                 if start_node == str(node['node_name']):
diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py
index db17f58f813..df06d96f0c1 100644
--- a/nncf/torch/sparsity/movement/layers.py
+++ b/nncf/torch/sparsity/movement/layers.py
@@ -23,8 +23,11 @@
 from copy import deepcopy
 
 from torch.nn.modules import sparse
+from torch import nn
 import itertools as it
 import numpy as np
+from nncf.torch.sparsity.functions import apply_binary_mask as apply_binary_mask_impl
+from nncf.torch.utils import is_tracing_state, no_jit_trace
 
 class SparseStructure(str, Enum):
     FINE = "fine"
@@ -39,35 +42,38 @@ def __init__(self, mode: SparseStructure = SparseStructure.FINE, sparse_args=Non
 
 
 @COMPRESSION_MODULES.register()
-class MovementSparsifyingWeight(BinaryMask):
+class MovementSparsifier(nn.Module):
     def __init__(self, 
                  weight_shape: List[int], 
                  frozen=True, 
                  compression_lr_multiplier=None, 
                  eps=1e-6, 
                  sparse_cfg=None):
-        super().__init__(weight_shape)
+        super().__init__()
 
         self.frozen = frozen
         self.eps = eps
         
+        self.weight_ctx = BinaryMask(weight_shape)
         self.sparse_cfg = sparse_cfg
-        self._importance_shape, self._bool_expand_importance = self._get_importance_shape(weight_shape)
-        self._importance = CompressionParameter(
-                                torch.zeros(self._importance_shape),
+
+        self._weight_importance_shape, self._bool_expand_weight_importance = self._get_importance_shape(weight_shape)
+        self._weight_importance = CompressionParameter(
+                                # torch.rand(self._weight_importance_shape),
+                                torch.zeros(self._weight_importance_shape),
                                 requires_grad=not self.frozen,
                                 compression_lr_multiplier=compression_lr_multiplier)
 
         self.lmbd = 0.5 # module_level_loss_weightage
         
         self.masking_threshold = 0.0
-        self.binary_mask = binary_mask_by_threshold(self._importance, self._masking_threshold)
+        self.weight_ctx.binary_mask = binary_mask_by_threshold(self._weight_importance, self._masking_threshold)
 
-        self.mask_calculation_hook = MaskCalculationHook(self)
+        self.weight_ctx_mask_calculation_hook = MaskCalculationHook(self)
 
     @property
     def importance(self):
-        return self._importance.data
+        return self._weight_importance.data
 
     @property
     def masking_threshold(self):
@@ -87,16 +93,26 @@ def lmbd(self, module_level_loss_weightage):
 
     def freeze_importance(self):
         self.frozen = True
-        self._importance.requires_grad=False
+        self._weight_importance.requires_grad=False
 
     def unfreeze_importance(self):
         self.frozen = False
-        self._importance.requires_grad=True
+        self._weight_importance.requires_grad=True
 
     def extra_repr(self):
         return '{}, {}'.format(
             self.sparse_cfg.mode, self.sparse_cfg.sparse_args)
 
+    def forward(self, weight):
+        if is_tracing_state():
+            with no_jit_trace():
+                return weight.mul_(self.binary_mask)
+        tmp_tensor = self._calc_training_binary_mask(weight)
+        return apply_binary_mask_impl(tmp_tensor, weight)
+
+    def apply_binary_mask(self, weight):
+        return self.weight_ctx.apply_binary_mask(weight)
+        
     def _get_importance_shape(self, weight_shape):
         #TODO:remove  weight_shape, r=32, c=32):
         # Default to fine_grained sparsity
@@ -120,10 +136,10 @@ def _get_importance_shape(self, weight_shape):
 
         if self.sparse_cfg.mode == SparseStructure.PER_DIM:
             if len(self.sparse_cfg.sparse_args) != 1 or not isinstance(self.sparse_cfg.sparse_args[0], int):
-                raise ValueError("Invalid sparse_arg {}, per_dim expects a single digit that indicates axes".format(self.sparse_cfg.sparse_args))
+                raise ValueError("Invalid sparse_arg {}, per_dim expects a single digit that indicates axis".format(self.sparse_cfg.sparse_args))
 
             if self.sparse_cfg.sparse_args[0] < 0 or self.sparse_cfg.sparse_args[0] >= len(weight_shape):
-                raise ValueError("Invalid axes id {}, axes range {}".format(
+                raise ValueError("Invalid axis id {}, axes range {}".format(
                                                                         self.sparse_cfg.sparse_args[0],
                                                                         list(range(len(weight_shape)))))
             self.sparse_cfg.sparse_factors = deepcopy(weight_shape)
@@ -138,7 +154,7 @@ def _get_importance_shape(self, weight_shape):
 
     def _expand_importance(self, importance):
         #TODO only works dense layer for now
-        if self._bool_expand_importance:
+        if self._bool_expand_weight_importance:
             return importance.repeat_interleave(
                 self.sparse_cfg.sparse_factors[0], dim=0).repeat_interleave(
                 self.sparse_cfg.sparse_factors[1], dim=1)
@@ -147,30 +163,30 @@ def _expand_importance(self, importance):
     def _calc_training_binary_mask(self, weight):
         if self.training and not self.frozen:
             _mask = binary_mask_by_threshold(
-                self._expand_importance(self._importance), 
+                self._expand_importance(self._weight_importance), 
                 self._masking_threshold
             )
-            self.binary_mask = _mask
+            self.weight_ctx.binary_mask = _mask
             return _mask
         else:
-            return self.binary_mask
+            return self.weight_ctx.binary_mask
 
     def loss(self):
         return self.lmbd * (torch.norm(
                 torch.sigmoid(
-                    self._expand_importance(self._importance)
-                ), p=1) / self._importance.numel())
+                    self._expand_importance(self._weight_importance)
+                ), p=1) / self._weight_importance.numel())
 
     def get_structured_mask(self, grain_size=None):
         if grain_size is None:
             grain_size = self.sparse_cfg.sparse_factors
         
-        structured_mask_shape = [dim//grain_size[axes] for axes, dim in enumerate(list(self.binary_mask.shape))]
+        structured_mask_shape = [dim//grain_size[axes] for axes, dim in enumerate(list(self.weight_ctx.binary_mask.shape))]
         temp_shape = list(it.chain(*zip(list(structured_mask_shape), list(grain_size))))
-        structured_mask = self.binary_mask.detach().clone()
+        structured_mask = self.weight_ctx.binary_mask.detach().clone()
         structured_mask = structured_mask.reshape(temp_shape)
-        structured_mask = structured_mask.amax(dim=(tuple((np.arange(len(self.binary_mask.shape)) * 2 + 1))))
-        # print("Mask Shape from {} to {}".format(structured_mask.shape, self.binary_mask.shape))
+        structured_mask = structured_mask.amax(dim=(tuple((np.arange(len(self.weight_ctx.binary_mask.shape)) * 2 + 1))))
+        # print("Mask Shape from {} to {}".format(structured_mask.shape, self.weight_ctx.binary_mask.shape))
         return structured_mask
 
 class MaskCalculationHook():
@@ -179,11 +195,11 @@ def __init__(self, module):
         self.hook = module._register_state_dict_hook(self.hook_fn)
 
     def hook_fn(self, module, destination, prefix, local_metadata):
-        module.binary_mask = binary_mask_by_threshold(
+        module.weight_ctx.binary_mask = binary_mask_by_threshold(
                                 module._expand_importance(module.importance), 
                                 module.masking_threshold
                              )
-        destination[prefix + '_binary_mask'] = module.binary_mask
+        destination[prefix + 'weight_ctx._binary_mask'] = module.weight_ctx.binary_mask
         return destination
 
     def close(self):

From 0c44bba1291af1a1275e1fca9838a712a6dac422 Mon Sep 17 00:00:00 2001
From: Vui Seng Chua <vui.seng.chua@intel.com>
Date: Sun, 20 Feb 2022 08:02:15 -0800
Subject: [PATCH 5/7] Enable bias pruning with mvmt algo, breaking changes to
 be overcome

---
 nncf/torch/nncf_network.py             |   6 +-
 nncf/torch/sparsity/collector.py       |   3 +-
 nncf/torch/sparsity/movement/algo.py   |   2 +-
 nncf/torch/sparsity/movement/layers.py | 110 +++++++++++++++++--------
 4 files changed, 81 insertions(+), 40 deletions(-)

diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py
index f3e89f59235..b3f2c142fcb 100644
--- a/nncf/torch/nncf_network.py
+++ b/nncf/torch/nncf_network.py
@@ -66,7 +66,7 @@
 from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler
 from nncf.torch.layers import NNCF_MODULES
 from nncf.torch.layers import NNCF_WRAPPED_USER_MODULES_DICT
-from nncf.torch.module_operations import UpdateWeight
+from nncf.torch.module_operations import UpdateWeight, UpdateWeightAndBias
 from nncf.torch.quantization.layers import QUANTIZATION_MODULES
 from nncf.torch.utils import compute_FLOPs_hook
 from nncf.torch.utils import get_all_modules_by_type
@@ -707,7 +707,9 @@ def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwor
                                      input_port_id=target_point.input_port_id)
             fn = transformation_command.fn
             if target_point.type is TargetType.OPERATION_WITH_WEIGHTS:
-                fn = UpdateWeight(fn)
+                # TODO: how to set this according
+                # fn = UpdateWeight(fn)
+                fn = UpdateWeightAndBias(fn)
             tup = (fn, transformation_command.priority)
             if pt_ip not in fns_grouped_by_points:
                 fns_grouped_by_points[pt_ip] = [tup]
diff --git a/nncf/torch/sparsity/collector.py b/nncf/torch/sparsity/collector.py
index fa40bb0f1f6..175406a48d8 100644
--- a/nncf/torch/sparsity/collector.py
+++ b/nncf/torch/sparsity/collector.py
@@ -53,9 +53,10 @@ def _collect_weights_descriptions(self) -> List[WeightDescription]:
 
             if hasattr(minfo.module, 'bias') and minfo.module.bias is not None:
                 bias = minfo.module.bias
+                sparse_bias = minfo.operand.apply_binary_mask(bias, isbias=True) #TODO: breaking changes
                 name = f'{minfo.module_node_name}/bias'
                 weights_descriptions.append(
-                    WeightDescription(name, list(bias.shape), bias.count_nonzero().item(), is_sparse=False)
+                    WeightDescription(name, list(bias.shape), sparse_bias.count_nonzero().item(), is_sparse=True)
                 )
 
             processed_modules.append(minfo.module)
diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py
index dd70ad8f2b2..fd9402e1e5c 100644
--- a/nncf/torch/sparsity/movement/algo.py
+++ b/nncf/torch/sparsity/movement/algo.py
@@ -56,7 +56,7 @@ def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, comp
             sparse_cfg = SparseConfig()
 
         return MovementSparsifier(
-                    target_module_node.layer_attributes.get_weight_shape(), 
+                    target_module_node, 
                     frozen=False,
                     compression_lr_multiplier=compression_lr_multiplier,
                     eps=1e-6, 
diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py
index df06d96f0c1..093c42f762f 100644
--- a/nncf/torch/sparsity/movement/layers.py
+++ b/nncf/torch/sparsity/movement/layers.py
@@ -44,32 +44,43 @@ def __init__(self, mode: SparseStructure = SparseStructure.FINE, sparse_args=Non
 @COMPRESSION_MODULES.register()
 class MovementSparsifier(nn.Module):
     def __init__(self, 
-                 weight_shape: List[int], 
+                 target_module_node, 
                  frozen=True, 
                  compression_lr_multiplier=None, 
                  eps=1e-6, 
                  sparse_cfg=None):
         super().__init__()
 
+        self.prune_bias = target_module_node.layer_attributes.bias
+
         self.frozen = frozen
         self.eps = eps
+        self.lmbd = 0.5 # module_level_loss_weightage
+        self.masking_threshold = 0.0
+        self.sparse_cfg = sparse_cfg
         
+        weight_shape = target_module_node.layer_attributes.get_weight_shape()
         self.weight_ctx = BinaryMask(weight_shape)
-        self.sparse_cfg = sparse_cfg
-
-        self._weight_importance_shape, self._bool_expand_weight_importance = self._get_importance_shape(weight_shape)
+        self._weight_importance_shape, self._bool_expand_importance = self._get_importance_shape(weight_shape)
         self._weight_importance = CompressionParameter(
                                 # torch.rand(self._weight_importance_shape),
                                 torch.zeros(self._weight_importance_shape),
                                 requires_grad=not self.frozen,
                                 compression_lr_multiplier=compression_lr_multiplier)
-
-        self.lmbd = 0.5 # module_level_loss_weightage
-        
-        self.masking_threshold = 0.0
         self.weight_ctx.binary_mask = binary_mask_by_threshold(self._weight_importance, self._masking_threshold)
 
-        self.weight_ctx_mask_calculation_hook = MaskCalculationHook(self)
+        if self.prune_bias is True:
+            bias_shape = target_module_node.layer_attributes.get_bias_shape()
+            self.bias_ctx = BinaryMask(bias_shape)
+            self._bias_importance_shape = self._weight_importance_shape[0]
+            self._bias_importance = CompressionParameter(
+                                # torch.rand(self._bias_importance_shape),
+                                torch.zeros(self._bias_importance_shape),
+                                requires_grad=not self.frozen,
+                                compression_lr_multiplier=compression_lr_multiplier)
+            self.bias_ctx.binary_mask = binary_mask_by_threshold(self._bias_importance, self._masking_threshold)
+
+        self.mask_calculation_hook = MaskCalculationHook(self)
 
     @property
     def importance(self):
@@ -94,24 +105,51 @@ def lmbd(self, module_level_loss_weightage):
     def freeze_importance(self):
         self.frozen = True
         self._weight_importance.requires_grad=False
+        if self.prune_bias is True:
+            self._bias_importance.requires_grad=False
 
     def unfreeze_importance(self):
         self.frozen = False
         self._weight_importance.requires_grad=True
+        if self.prune_bias is True:
+            self._bias_importance.requires_grad=True
+
 
     def extra_repr(self):
         return '{}, {}'.format(
             self.sparse_cfg.mode, self.sparse_cfg.sparse_args)
 
-    def forward(self, weight):
+    def forward(self, weight, bias):
         if is_tracing_state():
             with no_jit_trace():
                 return weight.mul_(self.binary_mask)
-        tmp_tensor = self._calc_training_binary_mask(weight)
-        return apply_binary_mask_impl(tmp_tensor, weight)
+        tmp_wtensor, tmp_btensor = self._calc_training_binary_mask(weight, bias)
+        wtensor = apply_binary_mask_impl(tmp_wtensor, weight)
+        btensor = apply_binary_mask_impl(tmp_btensor, bias)
+        return wtensor, btensor
 
-    def apply_binary_mask(self, weight):
-        return self.weight_ctx.apply_binary_mask(weight)
+    def _calc_training_binary_mask(self, weight, bias):
+        if self.training and not self.frozen:
+            w_mask = binary_mask_by_threshold(
+                self._expand_importance(self._weight_importance), 
+                self._masking_threshold
+            )
+            self.weight_ctx.binary_mask = w_mask
+            
+            b_mask = binary_mask_by_threshold(
+                self._expand_importance(self._bias_importance, isbias=True), 
+                self._masking_threshold
+            )
+            self.bias_ctx.binary_mask = b_mask
+            return w_mask, b_mask
+        else:
+            return self.weight_ctx.binary_mask, self.bias_ctx.binary_mask
+
+
+    def apply_binary_mask(self, param_tensor, isbias=False):
+        if isbias is True:
+            return self.bias_ctx.apply_binary_mask(param_tensor)
+        return self.weight_ctx.apply_binary_mask(param_tensor)
         
     def _get_importance_shape(self, weight_shape):
         #TODO:remove  weight_shape, r=32, c=32):
@@ -152,30 +190,23 @@ def _get_importance_shape(self, weight_shape):
                 score_shape.append(dim//factor)
             return score_shape, True
 
-    def _expand_importance(self, importance):
+    def _expand_importance(self, importance, isbias=False):
         #TODO only works dense layer for now
-        if self._bool_expand_weight_importance:
-            return importance.repeat_interleave(
-                self.sparse_cfg.sparse_factors[0], dim=0).repeat_interleave(
-                self.sparse_cfg.sparse_factors[1], dim=1)
+        if self._bool_expand_importance:
+            if isbias is False:
+                return importance.repeat_interleave(
+                    self.sparse_cfg.sparse_factors[0], dim=0).repeat_interleave(
+                    self.sparse_cfg.sparse_factors[1], dim=1)
+            else:
+                return importance.repeat_interleave(
+                    self.sparse_cfg.sparse_factors[0], dim=0)
         return importance
 
-    def _calc_training_binary_mask(self, weight):
-        if self.training and not self.frozen:
-            _mask = binary_mask_by_threshold(
-                self._expand_importance(self._weight_importance), 
-                self._masking_threshold
-            )
-            self.weight_ctx.binary_mask = _mask
-            return _mask
-        else:
-            return self.weight_ctx.binary_mask
-
     def loss(self):
-        return self.lmbd * (torch.norm(
-                torch.sigmoid(
-                    self._expand_importance(self._weight_importance)
-                ), p=1) / self._weight_importance.numel())
+        return self.lmbd * (
+            torch.norm(torch.sigmoid(self._expand_importance(self._weight_importance)), p=1) / self._weight_importance.numel() + \
+            torch.norm(torch.sigmoid(self._expand_importance(self._bias_importance, isbias=True)), p=1) / self._bias_importance.numel()
+        )
 
     def get_structured_mask(self, grain_size=None):
         if grain_size is None:
@@ -196,11 +227,18 @@ def __init__(self, module):
 
     def hook_fn(self, module, destination, prefix, local_metadata):
         module.weight_ctx.binary_mask = binary_mask_by_threshold(
-                                module._expand_importance(module.importance), 
+                                module._expand_importance(module._weight_importance), 
                                 module.masking_threshold
                              )
         destination[prefix + 'weight_ctx._binary_mask'] = module.weight_ctx.binary_mask
+
+        if module.prune_bias is True:
+            module.bias_ctx.binary_mask = binary_mask_by_threshold(
+                                module._expand_importance(module._bias_importance, isbias=True), 
+                                module.masking_threshold
+                            )
+            destination[prefix + 'bias_ctx._binary_mask'] = module.bias_ctx.binary_mask
         return destination
 
     def close(self):
-        self.hook.remove()
+        self.hook.remove()
\ No newline at end of file

From a11c72d37f03c21140cf8a57ce6d7dd25f552d1a Mon Sep 17 00:00:00 2001
From: Vui Seng Chua <vui.seng.chua@intel.com>
Date: Sun, 20 Feb 2022 21:33:10 -0800
Subject: [PATCH 6/7] Enable composite mvmt sparsity and quantization, enable
 onnx generation with mask burnt into state dict

---
 nncf/common/graph/transformations/commands.py |  1 +
 nncf/torch/exporter.py                        | 12 ++++
 nncf/torch/graph/transformations/commands.py  |  3 +-
 nncf/torch/nncf_network.py                    |  5 +-
 nncf/torch/sparsity/movement/algo.py          | 63 ++++++++++++++++++-
 nncf/torch/sparsity/movement/functions.py     |  2 +-
 nncf/torch/sparsity/movement/layers.py        | 38 ++++++-----
 7 files changed, 103 insertions(+), 21 deletions(-)

diff --git a/nncf/common/graph/transformations/commands.py b/nncf/common/graph/transformations/commands.py
index ffb4a040a53..c243694e51e 100644
--- a/nncf/common/graph/transformations/commands.py
+++ b/nncf/common/graph/transformations/commands.py
@@ -82,6 +82,7 @@ class TargetType(OrderedEnum):
     OPERATION_WITH_WEIGHTS = 5
     OPERATOR_PRE_HOOK = 6
     OPERATOR_POST_HOOK = 7
+    OPERATION_WITH_WEIGHT_WT_BIAS = 8
 
     def get_state(self) -> Dict[str, Any]:
         """
diff --git a/nncf/torch/exporter.py b/nncf/torch/exporter.py
index 386428225a6..03c272dc6b8 100644
--- a/nncf/torch/exporter.py
+++ b/nncf/torch/exporter.py
@@ -114,11 +114,23 @@ def _export_to_onnx(self, save_path: str) -> None:
                 retval = dummy_forward(self._model)
                 output_names = generate_output_names_list(count_tensors(retval))
 
+            import os
+            torch.onnx.export(model, tuple(input_tensor_list), os.path.join(os.path.dirname(save_path), "graph-only."+os.path.basename(save_path)),
+                              export_params=False,
+                              input_names=input_names,
+                              output_names=output_names,
+                              enable_onnx_checker=False,
+                              opset_version=10,
+                              do_constant_folding=False,
+                              # Do not fuse Conv+BN in ONNX. May cause dropout elements to appear in ONNX.
+                              training=True)
+
             torch.onnx.export(model, tuple(input_tensor_list), save_path,
                               input_names=input_names,
                               output_names=output_names,
                               enable_onnx_checker=False,
                               opset_version=10,
+                              do_constant_folding=False,
                               # Do not fuse Conv+BN in ONNX. May cause dropout elements to appear in ONNX.
                               training=True)
             model.enable_dynamic_graph_building()
diff --git a/nncf/torch/graph/transformations/commands.py b/nncf/torch/graph/transformations/commands.py
index a29e93f903a..cb24bcab506 100644
--- a/nncf/torch/graph/transformations/commands.py
+++ b/nncf/torch/graph/transformations/commands.py
@@ -19,7 +19,8 @@ class PTTargetPointStateNames:
 class PTTargetPoint(TargetPoint):
     _OPERATION_TYPES = [TargetType.PRE_LAYER_OPERATION,
                         TargetType.POST_LAYER_OPERATION,
-                        TargetType.OPERATION_WITH_WEIGHTS]
+                        TargetType.OPERATION_WITH_WEIGHTS,
+                        TargetType.OPERATION_WITH_WEIGHT_WT_BIAS]
     _HOOK_TYPES = [TargetType.OPERATOR_PRE_HOOK,
                    TargetType.OPERATOR_POST_HOOK]
 
diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py
index b3f2c142fcb..152da7e88ba 100644
--- a/nncf/torch/nncf_network.py
+++ b/nncf/torch/nncf_network.py
@@ -120,6 +120,7 @@ class PTInsertionPoint:
         TargetType.PRE_LAYER_OPERATION: PTInsertionType.NNCF_MODULE_PRE_OP,
         TargetType.POST_LAYER_OPERATION: PTInsertionType.NNCF_MODULE_POST_OP,
         TargetType.OPERATION_WITH_WEIGHTS: PTInsertionType.NNCF_MODULE_PRE_OP,
+        TargetType.OPERATION_WITH_WEIGHT_WT_BIAS: PTInsertionType.NNCF_MODULE_PRE_OP,
         TargetType.OPERATOR_PRE_HOOK: PTInsertionType.OPERATOR_PRE_HOOK,
         TargetType.OPERATOR_POST_HOOK: PTInsertionType.OPERATOR_POST_HOOK
     }
@@ -707,8 +708,8 @@ def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwor
                                      input_port_id=target_point.input_port_id)
             fn = transformation_command.fn
             if target_point.type is TargetType.OPERATION_WITH_WEIGHTS:
-                # TODO: how to set this according
-                # fn = UpdateWeight(fn)
+                fn = UpdateWeight(fn)
+            elif target_point.type is TargetType.OPERATION_WITH_WEIGHT_WT_BIAS:
                 fn = UpdateWeightAndBias(fn)
             tup = (fn, transformation_command.priority)
             if pt_ip not in fns_grouped_by_points:
diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py
index fd9402e1e5c..30f761f04e1 100644
--- a/nncf/torch/sparsity/movement/algo.py
+++ b/nncf/torch/sparsity/movement/algo.py
@@ -21,12 +21,17 @@
 from nncf.torch.algo_selector import PT_COMPRESSION_ALGORITHMS
 from nncf.api.compression import CompressionStage
 from nncf.common.graph import NNCFNode
+from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.utils.logger import logger as nncf_logger
 from nncf.torch.compression_method_api import PTCompressionAlgorithmController
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.sparsity.base_algo import BaseSparsityAlgoBuilder, BaseSparsityAlgoController, SparseModuleInfo
+from nncf.torch.graph.transformations.commands import PTInsertionCommand
+from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.graph.transformations.commands import TransformationPriority
 from nncf.torch.sparsity.movement.layers import MovementSparsifier, SparseConfig, SparseStructure
 from nncf.torch.sparsity.movement.loss import ImportanceLoss, SparseLossForPerLayerSparsity
-from nncf.torch.utils import get_world_size
+from nncf.torch.utils import get_world_size, get_model_device
 from nncf.common.utils.helpers import matches_any
 from nncf.common.accuracy_aware_training.training_loop import ADAPTIVE_COMPRESSION_CONTROLLERS
 from nncf.torch.sparsity.collector import PTSparseModelStatisticsCollector
@@ -44,6 +49,34 @@
 
 @PT_COMPRESSION_ALGORITHMS.register('movement_sparsity')
 class MovementSparsityBuilder(BaseSparsityAlgoBuilder):
+    def _sparsify_weights(self, target_model: NNCFNetwork) -> List[PTInsertionCommand]:
+        device = get_model_device(target_model)
+        sparsified_module_nodes = target_model.get_weighted_original_graph_nodes(
+            nncf_module_names=self.compressed_nncf_module_names)
+        insertion_commands = []
+        for module_node in sparsified_module_nodes:
+            node_name = module_node.node_name
+
+            if not self._should_consider_scope(node_name):
+                nncf_logger.info("Ignored adding Weight Sparsifier in scope: {}".format(node_name))
+                continue
+
+            nncf_logger.info("Adding Weight Sparsifier in scope: {}".format(node_name))
+            compression_lr_multiplier = \
+                self.config.get_redefinable_global_param_value_for_algo('compression_lr_multiplier',
+                                                                        self.name)
+            operation = self.create_weight_sparsifying_operation(module_node, compression_lr_multiplier)
+            hook = operation.to(device)
+            # TODO: hardcoded to OPERATION_WITH_WEIGHT_WT_BIAS
+            insertion_commands.append(PTInsertionCommand(PTTargetPoint(TargetType.OPERATION_WITH_WEIGHT_WT_BIAS,
+                                                                       target_node_name=node_name),
+                                                         hook, TransformationPriority.SPARSIFICATION_PRIORITY))
+            sparsified_module = target_model.get_containing_module(node_name)
+            self._sparsified_module_info.append(
+                SparseModuleInfo(node_name, sparsified_module, hook))
+
+        return insertion_commands
+
     def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, compression_lr_multiplier: float):
         sparse_cfg=None
         if 'sparse_structure_by_scopes' in self._algo_config:
@@ -202,6 +235,34 @@ def prepare_for_export(self):
         """
         self._propagate_masks()
 
+    def _propagate_masks(self):
+        def calc_sparsity(tensor):
+            return 1-tensor.count_nonzero()/tensor.numel()
+        # nncf_logger.debug("MVMT - Propagating pruning masks")
+        # 1. Propagate masks for all modules
+        from collections import OrderedDict
+        sparse_sd = OrderedDict()
+        with torch.no_grad():    
+            for sparse_info in self.sparsified_module_info:
+                for n, m in self.model.named_modules():
+                    if m == sparse_info.module:
+                        # print("- SparseModule: {} -".format(n))
+                        # print("\tw_mask sparsity: {:.3f}".format(calc_sparsity(sparse_info.operand.weight_ctx.binary_mask)))
+                        # print("\tw_sd   sparsity: {:.3f}".format(calc_sparsity(m.weight)))
+                        sparse_sd[n+'.weight'] = sparse_info.operand.apply_binary_mask(m.weight)
+                        # print("\t*w_sd  sparsity: {:.3f}".format(calc_sparsity(sparse_sd[n+'.weight'])))
+
+                        if hasattr(m, 'bias'):
+                            # print("\tb_mask sparsity: {:.3f}".format(calc_sparsity(sparse_info.operand.bias_ctx.binary_mask)))
+                            # print("\tb_sd   sparsity: {:.3f}".format(calc_sparsity(m.bias)))
+                            sparse_sd[n+'.bias'] = sparse_info.operand.apply_binary_mask(m.bias, isbias=True)
+                            # print("\t*w_sd  sparsity: {:.3f}".format(calc_sparsity(sparse_sd[n+'.bias'])))
+
+        model_sd = self.model.state_dict()
+        for k, v in sparse_sd.items():
+            assert k in model_sd, "key not exists!"
+            model_sd[k] = sparse_sd[k]
+        self.model.load_state_dict(model_sd)
 
     def print_prunableops_per_group(self):
         for group, op_list in self.prunableops_per_group.items():
diff --git a/nncf/torch/sparsity/movement/functions.py b/nncf/torch/sparsity/movement/functions.py
index ac2b3edea26..97db5375d7b 100644
--- a/nncf/torch/sparsity/movement/functions.py
+++ b/nncf/torch/sparsity/movement/functions.py
@@ -16,7 +16,7 @@
 from nncf.torch.dynamic_graph.patch_pytorch import register_operator
 from nncf.torch.functions import STThreshold
 
-
+@register_operator()
 def binary_mask_by_threshold(importance, threshold=0.5, sigmoid=True, max_percentile=0.98):
     with torch.no_grad():
         if sigmoid is True:
diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py
index 093c42f762f..5f6c5a02b45 100644
--- a/nncf/torch/sparsity/movement/layers.py
+++ b/nncf/torch/sparsity/movement/layers.py
@@ -63,22 +63,28 @@ def __init__(self,
         self.weight_ctx = BinaryMask(weight_shape)
         self._weight_importance_shape, self._bool_expand_importance = self._get_importance_shape(weight_shape)
         self._weight_importance = CompressionParameter(
-                                # torch.rand(self._weight_importance_shape),
-                                torch.zeros(self._weight_importance_shape),
+                                torch.rand(self._weight_importance_shape),
+                                # torch.zeros(self._weight_importance_shape),
                                 requires_grad=not self.frozen,
                                 compression_lr_multiplier=compression_lr_multiplier)
-        self.weight_ctx.binary_mask = binary_mask_by_threshold(self._weight_importance, self._masking_threshold)
+        self.weight_ctx.binary_mask = binary_mask_by_threshold(
+                                            self._expand_importance(self._weight_importance), 
+                                            self._masking_threshold
+                                        )
 
         if self.prune_bias is True:
             bias_shape = target_module_node.layer_attributes.get_bias_shape()
             self.bias_ctx = BinaryMask(bias_shape)
             self._bias_importance_shape = self._weight_importance_shape[0]
             self._bias_importance = CompressionParameter(
-                                # torch.rand(self._bias_importance_shape),
-                                torch.zeros(self._bias_importance_shape),
+                                torch.rand(self._bias_importance_shape),
+                                # torch.zeros(self._bias_importance_shape),
                                 requires_grad=not self.frozen,
                                 compression_lr_multiplier=compression_lr_multiplier)
-            self.bias_ctx.binary_mask = binary_mask_by_threshold(self._bias_importance, self._masking_threshold)
+            self.bias_ctx.binary_mask = binary_mask_by_threshold(
+                                            self._expand_importance(self._bias_importance, isbias=True), 
+                                            self._masking_threshold
+                                        )
 
         self.mask_calculation_hook = MaskCalculationHook(self)
 
@@ -116,13 +122,13 @@ def unfreeze_importance(self):
 
 
     def extra_repr(self):
-        return '{}, {}'.format(
+        return 'sparse_structure: {}, {}'.format(
             self.sparse_cfg.mode, self.sparse_cfg.sparse_args)
 
     def forward(self, weight, bias):
         if is_tracing_state():
             with no_jit_trace():
-                return weight.mul_(self.binary_mask)
+                return weight.mul_(self.weight_ctx.binary_mask), bias.mul_(self.bias_ctx.binary_mask)
         tmp_wtensor, tmp_btensor = self._calc_training_binary_mask(weight, bias)
         wtensor = apply_binary_mask_impl(tmp_wtensor, weight)
         btensor = apply_binary_mask_impl(tmp_btensor, bias)
@@ -226,17 +232,17 @@ def __init__(self, module):
         self.hook = module._register_state_dict_hook(self.hook_fn)
 
     def hook_fn(self, module, destination, prefix, local_metadata):
-        module.weight_ctx.binary_mask = binary_mask_by_threshold(
-                                module._expand_importance(module._weight_importance), 
-                                module.masking_threshold
-                             )
+        # module.weight_ctx.binary_mask = binary_mask_by_threshold(
+        #                         module._expand_importance(module._weight_importance), 
+        #                         module.masking_threshold
+        #                      )
         destination[prefix + 'weight_ctx._binary_mask'] = module.weight_ctx.binary_mask
 
         if module.prune_bias is True:
-            module.bias_ctx.binary_mask = binary_mask_by_threshold(
-                                module._expand_importance(module._bias_importance, isbias=True), 
-                                module.masking_threshold
-                            )
+            # module.bias_ctx.binary_mask = binary_mask_by_threshold(
+            #                     module._expand_importance(module._bias_importance, isbias=True), 
+            #                     module.masking_threshold
+            #                 )
             destination[prefix + 'bias_ctx._binary_mask'] = module.bias_ctx.binary_mask
         return destination
 

From 49e0a1c8bf13dcac3ff747f8d0a68d11268d4009 Mon Sep 17 00:00:00 2001
From: Vui Seng Chua <vui.seng.chua@intel.com>
Date: Mon, 28 Feb 2022 10:09:20 -0800
Subject: [PATCH 7/7] initial commit for fill flow (major changes)

---
 nncf/common/sparsity/schedulers.py     |   1 +
 nncf/torch/exporter.py                 |  20 +-
 nncf/torch/sparsity/movement/algo.py   | 318 ++++++++++++++++++++++---
 nncf/torch/sparsity/movement/layers.py |  21 +-
 4 files changed, 311 insertions(+), 49 deletions(-)

diff --git a/nncf/common/sparsity/schedulers.py b/nncf/common/sparsity/schedulers.py
index c6e5feb3b97..56c3866ae13 100644
--- a/nncf/common/sparsity/schedulers.py
+++ b/nncf/common/sparsity/schedulers.py
@@ -385,6 +385,7 @@ def schedule_threshold(self):
         else:
             self.current_importance_threshold  = self._calculate_threshold_level()
 
+        # self.current_importance_threshold = 0.1
         self._update_importance_masking_threshold()
         # if _cached_threshold != self.current_importance_threshold  or _cached_regu_lambda != self.current_importance_lambda:
         #     for n, m in self._controller.model.named_modules():
diff --git a/nncf/torch/exporter.py b/nncf/torch/exporter.py
index 03c272dc6b8..42b1dabd616 100644
--- a/nncf/torch/exporter.py
+++ b/nncf/torch/exporter.py
@@ -114,16 +114,18 @@ def _export_to_onnx(self, save_path: str) -> None:
                 retval = dummy_forward(self._model)
                 output_names = generate_output_names_list(count_tensors(retval))
 
+            DEBUG=False
             import os
-            torch.onnx.export(model, tuple(input_tensor_list), os.path.join(os.path.dirname(save_path), "graph-only."+os.path.basename(save_path)),
-                              export_params=False,
-                              input_names=input_names,
-                              output_names=output_names,
-                              enable_onnx_checker=False,
-                              opset_version=10,
-                              do_constant_folding=False,
-                              # Do not fuse Conv+BN in ONNX. May cause dropout elements to appear in ONNX.
-                              training=True)
+            if DEBUG is True:
+                torch.onnx.export(model, tuple(input_tensor_list), os.path.join(os.path.dirname(save_path), "graph-only."+os.path.basename(save_path)),
+                                export_params=False,
+                                input_names=input_names,
+                                output_names=output_names,
+                                enable_onnx_checker=False,
+                                opset_version=10,
+                                do_constant_folding=False,
+                                # Do not fuse Conv+BN in ONNX. May cause dropout elements to appear in ONNX.
+                                training=True)
 
             torch.onnx.export(model, tuple(input_tensor_list), save_path,
                               input_names=input_names,
diff --git a/nncf/torch/sparsity/movement/algo.py b/nncf/torch/sparsity/movement/algo.py
index 30f761f04e1..b5e52d02f3c 100644
--- a/nncf/torch/sparsity/movement/algo.py
+++ b/nncf/torch/sparsity/movement/algo.py
@@ -11,7 +11,7 @@
  limitations under the License.
 """
 from copy import deepcopy
-from typing import List
+from typing import DefaultDict, List, OrderedDict
 
 import torch
 import torch.distributed as dist
@@ -40,11 +40,13 @@
 from nncf.common.sparsity.statistics import MovementSparsityStatistics
 from nncf.common.statistics import NNCFStatistics
 from nncf.experimental.torch.search_building_blocks.search_blocks import BuildingBlock, get_building_blocks
-from collections import namedtuple
+from collections import defaultdict, namedtuple
 from nncf.torch.dynamic_graph.operation_address import OperationAddress
 import networkx as nx
 from nncf.torch.layers import NNCF_MODULES_OP_NAMES
 import os
+import numpy as np
+import pandas as pd
 
 
 @PT_COMPRESSION_ALGORITHMS.register('movement_sparsity')
@@ -98,6 +100,40 @@ def create_weight_sparsifying_operation(self, target_module_node: NNCFNode, comp
     def _build_controller(self, model: NNCFNetwork) -> PTCompressionAlgorithmController:
         return MovementSparsityController(model, self._sparsified_module_info, self.config)
 
+class StructuredMask:
+    def __init__(self, 
+                 target_module_node, 
+                 sparsifying_node_name, 
+                 grid_size,
+                 dependent_group_id,
+                 sparse_module_info):
+
+        self.target_module_node=target_module_node
+        self.sparsifying_node_name=sparsifying_node_name
+        self.grid_size=grid_size
+        self.dependent_group_id=dependent_group_id
+        self.sparse_module_info=sparse_module_info
+
+    @property
+    def independent_structured_mask(self):
+        return self._independent_structured_mask
+    
+    @independent_structured_mask.setter
+    def independent_structured_mask(self, tensor):
+        with torch.no_grad():
+            self._independent_structured_mask = tensor
+            # self._independent_structured_mask.set_(tensor)
+
+    @property
+    def dependent_structured_mask(self):
+        return self._dependent_structured_mask
+    
+    @dependent_structured_mask.setter
+    def dependent_structured_mask(self, tensor):
+        # TODO: check dim
+        with torch.no_grad():
+            self._dependent_structured_mask = tensor
+            # self._dependent_structured_mask.set_(tensor)
 
 @ADAPTIVE_COMPRESSION_CONTROLLERS.register('pt_movement_sparsity')
 class MovementSparsityController(BaseSparsityAlgoController):
@@ -125,6 +161,7 @@ def __init__(self, target_model: NNCFNetwork, sparsified_module_info: List[Spars
         self.config = config
         self.prunableops_per_group = self._get_group_of_prunable_ops()
         # self.visualize_groups_of_prunables()
+        self.create_structured_sparsity_context()
 
     def compression_stage(self) -> CompressionStage:
         if self._mode == 'local':
@@ -190,44 +227,253 @@ def statistics(self, quickly_collected_only=False) -> NNCFStatistics:
         nncf_stats.register('movement_sparsity', stats)
         return nncf_stats
 
-    @property
-    def compression_rate(self):
-        return self.statistics().movement_sparsity.model_statistics.sparsity_level
+    def create_structured_sparsity_context(self):
+        DEBUG=False
+        # Structured_mask per tensor -------------------
+        node_name_sparse_mod_info_map = {sparse_info.module_node_name: sparse_info for sparse_info in self.sparsified_module_info}
+        self.node_name_sparse_mod_info_map = node_name_sparse_mod_info_map
 
-    def _propagate_masks(self):
-        # nncf_logger.debug("MVMT - Propagating pruning masks")
-        # 1. Propagate masks for all modules
-        from collections import OrderedDict
-        sparse_sd = OrderedDict()
-        with torch.no_grad():    
-            for sparse_info in self.sparsified_module_info:
-                for n, m in self.model.named_modules():
-                    if m == sparse_info.module:
-                        # print(n, 1-sparse_info.operand.binary_mask.count_nonzero()/sparse_info.operand.binary_mask.numel())
-                        # print("pre", 1-m.weight.count_nonzero()/m.weight.numel())
-                        # print("mask", 1-sparse_info.operand.binary_mask.count_nonzero()/sparse_info.operand.binary_mask.numel())
-                        sparse_sd[n+'.weight'] = m.weight*sparse_info.operand.binary_mask
-                        # print("post", 1-sparse_sd[n+'.weight'].count_nonzero()/sparse_sd[n+'.weight'].numel())
-                # sd = sparse_info.module.state_dict()
-                # sd['weight'] = sparse_info.module.weight*sparse_info.operand.binary_mask
-                # sparse_info.module.load_state_dict(sd)
+        self.structured_ctx_by_group = defaultdict(list)
+        masks_per_group = dict()
+        op2namedmodule = dict()
 
-        model_sd = self.model.state_dict()
-        for k, v in sparse_sd.items():
-            assert k in model_sd, "key not exists!"
-            model_sd[k] = sparse_sd[k]
-        self.model.load_state_dict(model_sd)
+        for group_id, op_list in self.prunableops_per_group.items():
+            masks_per_group[group_id]=dict()
+            for op in op_list:
+                sparsifying_node_name = str(op.op_addr)
 
-        # init_output_masks_in_graph(graph, self.pruned_module_groups_info.get_all_nodes())
-        # MaskPropagationAlgorithm(graph, PT_PRUNING_OPERATOR_METATYPES).mask_propagation()
+                # find op's torch module name
+                for n, m in self.model.named_modules():
+                    if m == op.op_mod:
+                        op2namedmodule[sparsifying_node_name] = n
+                        break
+                
+                sparse_module_info = node_name_sparse_mod_info_map[sparsifying_node_name]
+
+                if any(map(sparsifying_node_name.__contains__, ['query','key','value'])):
+                    # these matrices must be pruned by group(s) of cols
+                    nrow_per_head = self.model.nncf_module.bert.config.hidden_size//self.model.nncf_module.bert.config.num_attention_heads
+                    ncol_per_head = self.model.nncf_module.bert.config.hidden_size
+                    grid_size = (nrow_per_head, ncol_per_head)
+                    
+                    if DEBUG is True:
+                        masks_per_group[group_id]['qkv_grain'] = grid_size
+                        mask = sparse_module_info.operand.get_structured_mask(grid_size)
+                        if 'qkv' not in masks_per_group[group_id]:
+                            masks_per_group[group_id]['qkv'] = [mask]
+                            masks_per_group[group_id]['qkv_nodes'] = [sparsifying_node_name]
+                        else:
+                            masks_per_group[group_id]['qkv'].append(mask)
+                            masks_per_group[group_id]['qkv_nodes'].append(sparsifying_node_name)
+                        print("{:15} | {:20} | {}".format('group_of_rows', str(mask.shape), sparsifying_node_name))
+
+                    structured_mask_ctx = StructuredMask(
+                                                sparse_module_info.module_node_name,
+                                                sparsifying_node_name,
+                                                grid_size,
+                                                group_id,
+                                                sparse_module_info)
+
+                    structured_mask_ctx.independent_structured_mask = sparse_module_info.operand.get_structured_mask(grid_size)
+                    sparse_module_info.operand.structured_mask_ctx = structured_mask_ctx
+                    self.structured_ctx_by_group[group_id].append(sparse_module_info.operand.structured_mask_ctx)
+
+                    if DEBUG is True:
+                        assert ((mask==structured_mask_ctx.independent_structured_mask).sum() == mask.numel()).item(), "qkv: Logical Bug, pls debug"
+                    
+                elif 'BertSelfOutput' in sparsifying_node_name:
+                    # this matrix must be pruned by group(s) of cols
+                    ncol_per_head = self.model.nncf_module.bert.config.hidden_size//self.model.nncf_module.bert.config.num_attention_heads
+                    nrow_per_head = self.model.nncf_module.bert.config.hidden_size
+                    grid_size = (nrow_per_head, ncol_per_head)
+
+                    if DEBUG is True:
+                        masks_per_group[group_id]['concat_grain'] = grid_size
+                        mask = sparse_module_info.operand.get_structured_mask(grid_size)
+                        masks_per_group[group_id]['concat'] = mask
+                        masks_per_group[group_id]['concat_node'] = sparsifying_node_name
+                        print("{:15} | {:20} | {}".format('group_of_cols', str(mask.shape), sparsifying_node_name))
+
+                    structured_mask_ctx = StructuredMask(
+                                                sparse_module_info.module_node_name,
+                                                sparsifying_node_name,
+                                                grid_size,
+                                                group_id,
+                                                sparse_module_info)
+
+                    structured_mask_ctx.independent_structured_mask = sparse_module_info.operand.get_structured_mask(grid_size)
+                    sparse_module_info.operand.structured_mask_ctx = structured_mask_ctx
+                    self.structured_ctx_by_group[group_id].append(sparse_module_info.operand.structured_mask_ctx)
+
+                    if DEBUG is True:
+                        assert ((mask==structured_mask_ctx.independent_structured_mask).sum() == mask.numel()).item(), "BertSelfOutput: Logical Bug, pls debug"
+
+                elif any(map(sparsifying_node_name.__contains__, ['BertIntermediate','BertOutput'])):
+                    mask = sparse_module_info.operand.get_structured_mask()
+                    grid_size = sparse_module_info.operand.sparse_cfg.sparse_factors
+
+                    if DEBUG is True:
+                        if 'BertIntermediate' in sparsifying_node_name:
+                            masks_per_group[group_id]['ffnn_w1_grain'] = grid_size
+                            masks_per_group[group_id]['ffnn_w1'] = mask
+                            masks_per_group[group_id]['ffnn_w1_node'] = sparsifying_node_name
+                        elif 'BertOutput' in sparsifying_node_name:
+                            masks_per_group[group_id]['ffnn_w2_grain'] = grid_size
+                            masks_per_group[group_id]['ffnn_w2'] = mask
+                            masks_per_group[group_id]['ffnn_w2_node'] = sparsifying_node_name
+                        print("{:15} | {:20} | {}".format('per_dim', str(mask.shape), sparsifying_node_name))
+
+                    structured_mask_ctx = StructuredMask(
+                                                sparse_module_info.module_node_name,
+                                                sparsifying_node_name,
+                                                grid_size,
+                                                group_id,
+                                                sparse_module_info)
+
+                    structured_mask_ctx.independent_structured_mask = sparse_module_info.operand.get_structured_mask(grid_size)
+                    sparse_module_info.operand.structured_mask_ctx = structured_mask_ctx
+                    self.structured_ctx_by_group[group_id].append(sparse_module_info.operand.structured_mask_ctx)
+
+                    if DEBUG is True:
+                        assert ((mask==structured_mask_ctx.independent_structured_mask).sum() == mask.numel()).item(), "ffnn: Logical Bug, pls debug"
+                else:
+                    raise ValueError("Invalid entry, pls debug")
+        
+        self.op2namedmodule = op2namedmodule
+
+        # This Structure can be improved but good for now: TODO: revision of structure
+        # masks_per_group[group_id][sparsifying_node_name] = mask
+
+    def reset_independent_structured_mask(self):
+        for group_id, ctxes in self.structured_ctx_by_group.items():
+            for ctx in ctxes:
+                ctx.independent_structured_mask = ctx.sparse_module_info.operand.get_structured_mask(ctx.grid_size)
+
+    def populate_structured_mask(self):
+        def inflate_structured_mask(mask, grid_size):
+            assert len(mask.shape) == len(grid_size), "Unmatching dimension"
+            inflated_mask = mask.detach().clone()
+            for axis, repeat in enumerate(grid_size):
+                inflated_mask = inflated_mask.repeat_interleave(repeat, dim=axis)
+            return inflated_mask
+
+        for group_id, ctxes in self.structured_ctx_by_group.items():
+            for ctx in ctxes:
+                ctx.sparse_module_info.operand.set_structured_mask(
+                    inflate_structured_mask(ctx.dependent_structured_mask, ctx.grid_size)
+                )
+
+    def resolve_structured_mask(self):
+        for group_id, ctxes in self.structured_ctx_by_group.items():
+            allnodenames = list(map(lambda x: x.target_module_node, ctxes))
+            
+            if any(map(ctxes[0].target_module_node.__contains__, ['query','key','value','BertSelfOutput'])):
+                qid = list(map(lambda x: x.__contains__('query'), allnodenames)).index(True)
+                kid = list(map(lambda x: x.__contains__('key'), allnodenames)).index(True)
+                vid = list(map(lambda x: x.__contains__('value'), allnodenames)).index(True)
+                oid = list(map(lambda x: x.__contains__('BertSelfOutput'), allnodenames)).index(True)
+
+                coarse_mask = ctxes[qid].independent_structured_mask.logical_or(
+                                ctxes[kid].independent_structured_mask).logical_or(
+                                    ctxes[vid].independent_structured_mask).logical_or(
+                                        ctxes[oid].independent_structured_mask.transpose(0, 1)
+                                    ).to(torch.float32)
+                ctxes[qid].dependent_structured_mask = coarse_mask
+                ctxes[kid].dependent_structured_mask = coarse_mask
+                ctxes[vid].dependent_structured_mask = coarse_mask
+                ctxes[oid].dependent_structured_mask = coarse_mask.transpose(0, 1)
+            elif any(map(ctxes[0].target_module_node.__contains__, ['BertIntermediate','BertOutput'])):
+                w1_id = list(map(lambda x: x.__contains__('BertIntermediate'), allnodenames)).index(True)
+                w2_id = list(map(lambda x: x.__contains__('BertOutput'), allnodenames)).index(True)
+                coarse_mask = ctxes[w1_id].independent_structured_mask.logical_or(
+                                ctxes[w2_id].independent_structured_mask.transpose(0, 1)
+                              ).to(torch.float32)
+
+                ctxes[w1_id].dependent_structured_mask = coarse_mask
+                ctxes[w2_id].dependent_structured_mask = coarse_mask.transpose(0, 1)
+            else:
+                raise ValueError("logical bug, pls debug")
+
+        # # Structured_mask alignment by group -------------------
+
+        # for group_id, mask_dict in masks_per_group.items():
+        #     if 'qkv' in mask_dict:
+        #         final_mask = torch.zeros_like(mask_dict['qkv'][0]).to(torch.bool)
+        #         for each_mask in mask_dict['qkv']:
+        #             final_mask = final_mask.logical_or(each_mask)
+        #         final_mask = final_mask.logical_or(mask_dict['concat'].transpose(0, 1))
+        #         final_mask = final_mask.to(torch.float32)
+
+        #         masks_per_group[group_id]['final_structured_mask'] = dict(
+        #             qkv=final_mask,
+        #             concat=final_mask.transpose(0, 1)
+        #         )
+
+        #     elif 'ffnn_w1' in mask_dict:
+        #         final_mask = mask_dict['ffnn_w1'].logical_or(mask_dict['ffnn_w2'].transpose(0, 1))
+        #         final_mask = final_mask.to(torch.float32)
+
+        #         masks_per_group[group_id]['final_structured_mask'] = dict(
+        #             ffnn_w1=final_mask,
+        #             ffnn_w2=final_mask.transpose(0, 1)
+        #         )
+        #     else:
+        #         raise ValueError("Invalid entry, pls debug")
+
+    def report_structured_sparsity(self, dirname):
+        listofentry=[]
+        for group_id, ctxes in self.structured_ctx_by_group.items():
+            for ctx in ctxes:
+                nncf_graph_node_name = ctx.sparsifying_node_name
+                named_mod = self.op2namedmodule[nncf_graph_node_name]
+                block_id = group_id
+                orig_wshape = tuple(list(ctx.sparse_module_info.module.weight.shape))
+                if hasattr(ctx.sparse_module_info.module, 'bias'):
+                    orig_bshape = tuple(list(ctx.sparse_module_info.module.bias.shape))
+
+                if any(map(nncf_graph_node_name.__contains__, ['BertIntermediate','BertOutput'])):
+                    head_id_to_keep = 'skip reporting'
+                    if nncf_graph_node_name.__contains__('BertIntermediate'):
+                        final_wshape = (ctx.sparse_module_info.operand.weight_ctx.binary_mask.amax(dim=1).count_nonzero().item(), orig_wshape[1])
+                        final_bshape = (ctx.sparse_module_info.operand.bias_ctx.binary_mask.count_nonzero().item(),)
+                    else:
+                        final_wshape = (orig_wshape[0], ctx.sparse_module_info.operand.weight_ctx.binary_mask.amax(dim=0).count_nonzero().item())
+                        final_bshape = (ctx.sparse_module_info.operand.bias_ctx.binary_mask.count_nonzero().item(),)
+                else:
+                    ndiv = ctx.dependent_structured_mask.reshape(-1).shape[0]
+                    head_id_to_keep = torch.masked_select(torch.range(0, ndiv-1, dtype=int), 
+                                        ctx.dependent_structured_mask.reshape(-1).cpu().to(bool)).tolist()
+
+                    if any(map(nncf_graph_node_name.__contains__, ['query','key','value'])):
+                        # prune by row
+                        final_wshape = (ctx.sparse_module_info.operand.weight_ctx.binary_mask.amax(dim=1).count_nonzero().item(), orig_wshape[1])
+                        final_bshape = (ctx.sparse_module_info.operand.bias_ctx.binary_mask.count_nonzero().item(),)
+                    else:
+                        # prune by col
+                        final_wshape = (orig_wshape[0], ctx.sparse_module_info.operand.weight_ctx.binary_mask.amax(dim=0).count_nonzero().item())
+                        final_bshape = (ctx.sparse_module_info.operand.bias_ctx.binary_mask.count_nonzero().item(),)
+
+                listofentry.append(
+                    OrderedDict(
+                        pt_module_name=named_mod,
+                        block_id=block_id,
+                        weight_shape=orig_wshape,
+                        prune_w_shape=final_wshape,
+                        bias_shape=orig_bshape,
+                        prune_b_shape=final_bshape,
+                        head_id_to_keep=head_id_to_keep,
+                        nncf_graph_node=nncf_graph_node_name
+                    )
+                )
+        df = pd.DataFrame.from_dict(listofentry)
+        df.to_csv(os.path.join(dirname, 'structured_sparsity.csv'))
+        with open(os.path.join(dirname, 'structured_sparsity.md'), 'w') as f:
+            df.to_markdown(f)
 
-        # # 2. Set the masks for Batch/Group Norms
-        # pruned_node_modules = []
-        # for node, pruning_block, node_module in self._pruned_norms_operators:
-        #     if node_module not in pruned_node_modules:
-        #         # Setting masks for BN nodes
-        #         pruning_block.binary_filter_pruning_mask = node.data['output_mask'].tensor
-        #         pruned_node_modules.append(node_module)
+
+    @property
+    def compression_rate(self):
+        return self.statistics().movement_sparsity.model_statistics.sparsity_level
 
     def prepare_for_export(self):
         """
diff --git a/nncf/torch/sparsity/movement/layers.py b/nncf/torch/sparsity/movement/layers.py
index 5f6c5a02b45..836309916da 100644
--- a/nncf/torch/sparsity/movement/layers.py
+++ b/nncf/torch/sparsity/movement/layers.py
@@ -51,6 +51,9 @@ def __init__(self,
                  sparse_cfg=None):
         super().__init__()
 
+        DEBUG=False
+
+        self.target_module_node = target_module_node
         self.prune_bias = target_module_node.layer_attributes.bias
 
         self.frozen = frozen
@@ -63,8 +66,7 @@ def __init__(self,
         self.weight_ctx = BinaryMask(weight_shape)
         self._weight_importance_shape, self._bool_expand_importance = self._get_importance_shape(weight_shape)
         self._weight_importance = CompressionParameter(
-                                torch.rand(self._weight_importance_shape),
-                                # torch.zeros(self._weight_importance_shape),
+                                torch.rand(self._weight_importance_shape) if DEBUG is True else torch.zeros(self._weight_importance_shape),
                                 requires_grad=not self.frozen,
                                 compression_lr_multiplier=compression_lr_multiplier)
         self.weight_ctx.binary_mask = binary_mask_by_threshold(
@@ -77,8 +79,7 @@ def __init__(self,
             self.bias_ctx = BinaryMask(bias_shape)
             self._bias_importance_shape = self._weight_importance_shape[0]
             self._bias_importance = CompressionParameter(
-                                torch.rand(self._bias_importance_shape),
-                                # torch.zeros(self._bias_importance_shape),
+                                torch.rand(self._bias_importance_shape) if DEBUG is True else torch.zeros(self._bias_importance_shape),
                                 requires_grad=not self.frozen,
                                 compression_lr_multiplier=compression_lr_multiplier)
             self.bias_ctx.binary_mask = binary_mask_by_threshold(
@@ -224,8 +225,20 @@ def get_structured_mask(self, grain_size=None):
         structured_mask = structured_mask.reshape(temp_shape)
         structured_mask = structured_mask.amax(dim=(tuple((np.arange(len(self.weight_ctx.binary_mask.shape)) * 2 + 1))))
         # print("Mask Shape from {} to {}".format(structured_mask.shape, self.weight_ctx.binary_mask.shape))
+        if self.prune_bias is True:
+            structured_bias_mask_shape = structured_mask_shape[0]
+            structured_bias_mask = self.bias_ctx.binary_mask.detach().clone()
+            structured_bias_mask = structured_bias_mask.reshape((structured_bias_mask_shape, -1))
+            structured_bias_mask = structured_bias_mask.amax(dim=1)
+            dim_aligned = structured_bias_mask.repeat(structured_mask.shape[1]).reshape(-1, structured_mask.shape[1])
+            structured_mask = structured_mask.logical_or(dim_aligned).to(torch.float32)
         return structured_mask
 
+    def set_structured_mask(self, structured_mask):
+        self.weight_ctx.binary_mask=structured_mask
+        if self.prune_bias is True:
+            self.bias_ctx.binary_mask=structured_mask.amax(dim=1)
+
 class MaskCalculationHook():
     def __init__(self, module):
         # pylint: disable=protected-access