|
| 1 | +import copy |
| 2 | +import json |
| 3 | +import os |
| 4 | +from typing import Dict |
| 5 | + |
| 6 | +import numpy as np |
| 7 | + |
| 8 | +import torch |
| 9 | +import torch.nn as nn |
| 10 | +from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection |
| 11 | +from torchrec.modules.mc_modules import ( |
| 12 | + DistanceLFU_EvictionPolicy, |
| 13 | + ManagedCollisionCollection, |
| 14 | + MCHManagedCollisionModule, |
| 15 | +) |
| 16 | + |
| 17 | +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor |
| 18 | + |
| 19 | + |
| 20 | +class BenchmarkMCProbe(nn.Module): |
| 21 | + def __init__( |
| 22 | + self, |
| 23 | + mcec: Dict[str, ManagedCollisionEmbeddingCollection], |
| 24 | + mc_method: str, # method for managing collisions, one of ["zch", "mpzch"] |
| 25 | + rank: int, # rank of the current model shard |
| 26 | + log_file_folder: str = "benchmark_logs", # folder to store the logging file |
| 27 | + ) -> None: |
| 28 | + super().__init__() |
| 29 | + # self._mcec is a pointer to the mcec object passed in |
| 30 | + self._mcec = mcec |
| 31 | + # record the mc_method |
| 32 | + self._mc_method = mc_method |
| 33 | + # initialize the logging file handler |
| 34 | + os.makedirs(log_file_folder, exist_ok=True) |
| 35 | + self._log_file_path = os.path.join(log_file_folder, f"rank_{rank}.json") |
| 36 | + self._rank = rank # record the rank of the current model shard |
| 37 | + # get the output_offsets of the mcec |
| 38 | + self.per_table_output_offsets = ( |
| 39 | + {} |
| 40 | + ) # dict of {table_name [str]: output_offsets [torch.Tensor]} TODO: find out relationship between table_name and feature_name |
| 41 | + if self._mc_method == "mpzch": |
| 42 | + for table_name, mcec_module in self._mcec.items(): |
| 43 | + self.per_table_output_offsets[table_name] = ( |
| 44 | + mcec_module._output_global_offset_tensor |
| 45 | + ) |
| 46 | + # create a dictionary to store the state of mcec modules |
| 47 | + self.mcec_state = {} |
| 48 | + # create a dictionary to store the statistics of mch modules |
| 49 | + self._mch_stats = ( |
| 50 | + {} |
| 51 | + ) # dictionary of {table_name [str]: {metric_name [str]: metric_value [int]}} |
| 52 | + |
| 53 | + # record mcec state to file |
| 54 | + def record_mcec_state(self, stage: str) -> None: |
| 55 | + """ |
| 56 | + record the state of mcec modules to the log file |
| 57 | + The recorded state is a dictionary of |
| 58 | + {{stage: {table_name: {metric_name: state}}}} |
| 59 | + It only covers for the current batch |
| 60 | +
|
| 61 | + params: |
| 62 | + stage (str): before_fwd, after_fwd |
| 63 | + return: |
| 64 | + None |
| 65 | + """ |
| 66 | + # check if the stage in the desired options |
| 67 | + assert stage in ( |
| 68 | + "before_fwd", |
| 69 | + "after_fwd", |
| 70 | + ), f"stage {stage} is not supported, valid options are before_fwd, after_fwd" |
| 71 | + # create a dictionary to store the state of mcec modules |
| 72 | + if stage not in self.mcec_state: |
| 73 | + self.mcec_state[stage] = {} # dict of {table_name: {metric_name: state}} |
| 74 | + # if the stage is before_fwd, only record the remapping_table |
| 75 | + # save the mcec table state for each embedding table |
| 76 | + self.mcec_state[stage][ |
| 77 | + "table_state" |
| 78 | + ] = {} # dict of {table_name: {"remapping_table": state}} |
| 79 | + for table_name, mc_module in self._mcec.items(): |
| 80 | + self.mcec_state[stage]["table_state"][table_name] = {} |
| 81 | + # |
| 82 | + if self._mc_method == "zch": |
| 83 | + self.mcec_state[stage]["table_state"][table_name][ |
| 84 | + "remapping_table" |
| 85 | + ] = mc_module._mch_sorted_raw_ids |
| 86 | + # save t |
| 87 | + elif self._mc_method == "mpzch": |
| 88 | + self.mcec_state[stage]["table_state"][table_name]["remapping_table"] = ( |
| 89 | + mc_module._hash_zch_identities.clone() |
| 90 | + .to_dense() |
| 91 | + .squeeze() |
| 92 | + .cpu() |
| 93 | + .numpy() |
| 94 | + .tolist() |
| 95 | + ) |
| 96 | + else: |
| 97 | + raise NotImplementedError( |
| 98 | + f"mc method {self._mc_method} is not supported yet" |
| 99 | + ) |
| 100 | + # for before_fwd, we only need to record the remapping_table |
| 101 | + if stage == "before_fwd": |
| 102 | + return |
| 103 | + # for after_fwd, we need to record the feature values |
| 104 | + # check if the "before_fwd" stage is recorded |
| 105 | + assert ( |
| 106 | + "before_fwd" in self.mcec_state |
| 107 | + ), "before_fwd stage is not recorded, please call record_mcec_state before calling record_mcec_state after_fwd" |
| 108 | + # create the dirctionary to store the mcec feature values before forward |
| 109 | + self.mcec_state["before_fwd"]["feature_values"] = {} |
| 110 | + # create the dirctionary to store the mcec feature values after forward |
| 111 | + self.mcec_state[stage]["feature_values"] = {} # dict of {table_name: state} |
| 112 | + # save the mcec feature values for each embedding table |
| 113 | + for table_name, mc_module in self._mcec.items(): |
| 114 | + # record the remapped feature values |
| 115 | + if self._mc_method == "mpzch": # when using mpzch mc modules |
| 116 | + # record the remapped feature values first |
| 117 | + self.mcec_state[stage]["feature_values"][table_name] = ( |
| 118 | + mc_module.table_name_on_device_remapped_ids_dict[table_name] |
| 119 | + .cpu() |
| 120 | + .numpy() |
| 121 | + .tolist() |
| 122 | + ) |
| 123 | + # record the input feature values |
| 124 | + self.mcec_state["before_fwd"]["feature_values"][table_name] = ( |
| 125 | + mc_module.table_name_on_device_input_ids_dict[table_name] |
| 126 | + .cpu() |
| 127 | + .numpy() |
| 128 | + .tolist() |
| 129 | + ) |
| 130 | + # check if the input feature values list is empty |
| 131 | + if ( |
| 132 | + len(self.mcec_state["before_fwd"]["feature_values"][table_name]) |
| 133 | + == 0 |
| 134 | + ): |
| 135 | + # if the input feature values list is empty, make it a list of -2 with the same length as the remapped feature values |
| 136 | + self.mcec_state["before_fwd"]["feature_values"][table_name] = [ |
| 137 | + -2 |
| 138 | + ] * len(self.mcec_state[stage]["feature_values"][table_name]) |
| 139 | + else: # when using other zch mc modules # TODO: implement the feature value recording for zch |
| 140 | + raise NotImplementedError( |
| 141 | + f"zc method {self._mc_method} is not supported yet" |
| 142 | + ) |
| 143 | + return |
| 144 | + |
| 145 | + def get_mcec_state(self) -> Dict[str, Dict[str, Dict[str, Dict[str, int]]]]: |
| 146 | + """ |
| 147 | + get the state of mcec modules |
| 148 | + the state is a dictionary of |
| 149 | + {{stage: {table_name: {data_name: state}}}} |
| 150 | + """ |
| 151 | + return self.mcec_state |
| 152 | + |
| 153 | + def save_mcec_state(self) -> None: |
| 154 | + """ |
| 155 | + save the state of mcec modules to the log file |
| 156 | + """ |
| 157 | + with open(self._log_file_path, "w") as f: |
| 158 | + json.dump(self.mcec_state, f, indent=4) |
| 159 | + |
| 160 | + def get_mch_stats(self) -> Dict[str, Dict[str, int]]: |
| 161 | + """ |
| 162 | + get the statistics of mch modules |
| 163 | + the statistics is a dictionary of |
| 164 | + {{table_name: {metric_name: metric_value}}} |
| 165 | + """ |
| 166 | + return self._mch_stats |
| 167 | + |
| 168 | + def update(self) -> None: |
| 169 | + """ |
| 170 | + Update the ZCH statistics for the current batch |
| 171 | + Params: |
| 172 | + None |
| 173 | + Return: |
| 174 | + None |
| 175 | + Require: |
| 176 | + self.mcec_state is not None and has recorded both "before_fwd" and "after_fwd" for a batch |
| 177 | + Update: |
| 178 | + self._mch_stats |
| 179 | + """ |
| 180 | + # create a dictionary to store the statistics for each batch |
| 181 | + batch_stats = ( |
| 182 | + {} |
| 183 | + ) # table_name: {hit_cnt: 0, total_cnt: 0, insert_cnt: 0, collision_cnt: 0} |
| 184 | + # calculate the statistics for each rank |
| 185 | + # get the remapping id table before forward pass and the input feature values |
| 186 | + rank_feature_value_before_fwd = self.mcec_state["before_fwd"]["feature_values"] |
| 187 | + # get the remapping id table after forward pass and the remapped feature ids |
| 188 | + rank_feature_value_after_fwd = self.mcec_state["after_fwd"]["feature_values"] |
| 189 | + # for each feature table in the remapped information |
| 190 | + for ( |
| 191 | + feature_name, |
| 192 | + remapped_feature_ids, |
| 193 | + ) in rank_feature_value_after_fwd.items(): |
| 194 | + # create a new diction for the feature table if not created |
| 195 | + if feature_name not in batch_stats: |
| 196 | + batch_stats[feature_name] = { |
| 197 | + "hit_cnt": 0, |
| 198 | + "total_cnt": 0, |
| 199 | + "insert_cnt": 0, |
| 200 | + "collision_cnt": 0, |
| 201 | + "rank_total_cnt": 0, |
| 202 | + } |
| 203 | + # get the input faeture values |
| 204 | + input_feature_values = np.array(rank_feature_value_before_fwd[feature_name]) |
| 205 | + # get the values stored in the remapping table for each remapped feature id after forward pass |
| 206 | + prev_remapped_values = np.array( |
| 207 | + self.mcec_state["before_fwd"]["table_state"][f"{feature_name}"][ |
| 208 | + "remapping_table" |
| 209 | + ] |
| 210 | + )[remapped_feature_ids] |
| 211 | + # get the values stored in the remapping table for each remapped feature id before forward pass |
| 212 | + after_remapped_values = np.array( |
| 213 | + self.mcec_state["after_fwd"]["table_state"][f"{feature_name}"][ |
| 214 | + "remapping_table" |
| 215 | + ] |
| 216 | + )[remapped_feature_ids] |
| 217 | + # count the number of same values in prev_remapped_values and after_remapped_values |
| 218 | + # hit count = number of remapped values that exist in the remapping table before forward pass |
| 219 | + this_rank_hits_count = np.sum(prev_remapped_values == input_feature_values) |
| 220 | + batch_stats[feature_name]["hit_cnt"] += int(this_rank_hits_count) |
| 221 | + # count the number of insertions |
| 222 | + ## insert count = the decreased number of empty slots in the remapping table |
| 223 | + ## before and after forward pass |
| 224 | + this_rank_insert_count = np.sum(prev_remapped_values == -1) - np.sum( |
| 225 | + after_remapped_values == -1 |
| 226 | + ) |
| 227 | + batch_stats[feature_name]["insert_cnt"] += int(this_rank_insert_count) |
| 228 | + # count the number of total values |
| 229 | + ## total count = the number of remapped values in the remapping table after forward pass |
| 230 | + this_rank_total_count = int(len(remapped_feature_ids)) |
| 231 | + # count the number of values redirected to the rank |
| 232 | + batch_stats[feature_name]["rank_total_cnt"] = this_rank_total_count |
| 233 | + batch_stats[feature_name]["total_cnt"] += this_rank_total_count |
| 234 | + # count the number of collisions |
| 235 | + # collision count = total count - hit count - insert count |
| 236 | + this_rank_collision_count = ( |
| 237 | + this_rank_total_count - this_rank_hits_count - this_rank_insert_count |
| 238 | + ) |
| 239 | + batch_stats[feature_name]["collision_cnt"] += int(this_rank_collision_count) |
| 240 | + self._mch_stats = batch_stats |
0 commit comments