Skip to content

Commit f815ed2

Browse files
lizhouyufacebook-github-bot
authored andcommitted
Benchmark framework for torchrec (#3072)
Summary: Pull Request resolved: #3072 # Benchmark framework for MPZCH ### Major changes - Add a `benchmark prober` in `torchrec/distributed/benchmark/benchmark_zch_utils.py` to collect and calculate zero-collision-hash-related metrics such as hit count, insert count, and collision count. - Implement a `benchmark_zch_dlrmv2` local testbed in `torchrec/distributed/benchmark/benchmark_zch_dlrmv2.py`, which allows profiling a DLRMv2 model with and without MPZCH enabled, and records metrics including ZCH-related metrics, QPS, NE, and AUROC. - Add `mc_adapter` modules in `torchrec/modules/mc_adapter.py`. These modules enable seamless replacement of embedding collection and embedding bag collection modules with the managed collision version. - Add two dictionaries, `self.table_name_on_device_remapped_ids_dict` and `self.table_name_on_device_input_ids_dict`, to the `HashZchManagedCollisionModule` module in `torchrec/modules/hash_mc_modules.py` to record, respectively, the remapped identities and the input feature values of the MPZCH module on the current rank after input mapping. - Add a `count_non_zch_collision.py` script to count the collision rate of non-ZCH modules after performing `murmur_hash3`. - Add the Criteo Kaggle dataset data loader in `torchrec/distributed/benchmark/data` and revise the `hashes` attribute of the data pipeline in the `_get_in_memory_dataloader` function in the `torchrec/distributed/benchmark/data/dlrm_dataloader.py` file to pre-hash the input feature values to the passed-in argument `input_hash_size` (100000 by default). - Note that we can change the `single_ttl` in the `HashZchSingleTtlScorer` module of `torchrec/modules/hash_mc_evictions.py` to change the evictability of identities in each `HashZchManagedCollisionModule` module, since the existing benchmark workflow only takes several minutes on the subset of the Criteo Kaggle dataset. By default the identities become evictable after one hour. 
This discrepancy leads to non-eviction during the profiling process. ### Dataset - [Criteo Kaggle Small](https://drive.google.com/file/d/1__rPcUSa45FHkmnBwivuM7K4nMYWD7b7/view?usp=sharing) - [Criteo Kaggle](https://drive.google.com/file/d/1_lAbXTEOk5vlPGXd4UvTrxGV6sCPer_R/view?usp=drive_link) Differential Revision: D76150895
1 parent eb5a752 commit f815ed2

File tree

9 files changed

+2052
-1
lines changed

9 files changed

+2052
-1
lines changed

torchrec/distributed/benchmark/benchmark_zch_dlrmv2.py

Lines changed: 936 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import copy
2+
import json
3+
import os
4+
from typing import Dict
5+
6+
import numpy as np
7+
8+
import torch
9+
import torch.nn as nn
10+
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
11+
from torchrec.modules.mc_modules import (
12+
DistanceLFU_EvictionPolicy,
13+
ManagedCollisionCollection,
14+
MCHManagedCollisionModule,
15+
)
16+
17+
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
18+
19+
20+
class BenchmarkMCProbe(nn.Module):
    """Probe that snapshots managed-collision (MC) module state around a forward
    pass and derives per-table ZCH statistics (hit / insert / collision counts).

    Typical per-batch usage:
        probe.record_mcec_state(stage="before_fwd")
        # ... run the model forward ...
        probe.record_mcec_state(stage="after_fwd")
        probe.update()
        stats = probe.get_mch_stats()

    The probe only reads the wrapped MC modules; it never mutates them.
    """

    def __init__(
        self,
        mcec: Dict[str, ManagedCollisionEmbeddingCollection],
        mc_method: str,  # method for managing collisions, one of ["zch", "mpzch"]
        rank: int,  # rank of the current model shard
        log_file_folder: str = "benchmark_logs",  # folder to store the logging file
    ) -> None:
        """Initialize the probe.

        params:
            mcec (Dict[str, ManagedCollisionEmbeddingCollection]):
                mapping of table name to the MC module whose state is probed.
                NOTE(review): the mpzch paths below access attributes such as
                `_output_global_offset_tensor`, `_hash_zch_identities`, and
                `table_name_on_device_*_dict` on the dict values — presumably
                these are the per-table managed-collision modules; confirm
                against the caller.
            mc_method (str): collision-management method, "zch" or "mpzch".
            rank (int): rank of the current model shard; used to name the log file.
            log_file_folder (str): folder for the per-rank JSON log file
                (created if missing).
        """
        super().__init__()
        # self._mcec is a pointer to the mcec object passed in (not a copy)
        self._mcec = mcec
        # record the mc_method
        self._mc_method = mc_method
        # initialize the logging file handler; one JSON file per rank
        os.makedirs(log_file_folder, exist_ok=True)
        self._log_file_path = os.path.join(log_file_folder, f"rank_{rank}.json")
        self._rank = rank  # record the rank of the current model shard
        # get the output_offsets of the mcec (only populated for mpzch)
        self.per_table_output_offsets = (
            {}
        )  # dict of {table_name [str]: output_offsets [torch.Tensor]} TODO: find out relationship between table_name and feature_name
        if self._mc_method == "mpzch":
            for table_name, mcec_module in self._mcec.items():
                self.per_table_output_offsets[table_name] = (
                    mcec_module._output_global_offset_tensor
                )
        # create a dictionary to store the state of mcec modules,
        # keyed by stage ("before_fwd" / "after_fwd"); see record_mcec_state
        self.mcec_state = {}
        # create a dictionary to store the statistics of mch modules,
        # recomputed from scratch on every update() call
        self._mch_stats = (
            {}
        )  # dictionary of {table_name [str]: {metric_name [str]: metric_value [int]}}

    # record mcec state in memory (written to disk only by save_mcec_state)
    def record_mcec_state(self, stage: str) -> None:
        """
        Record the current state of the mcec modules under ``self.mcec_state``.
        The recorded state is a dictionary of
        {{stage: {table_name: {metric_name: state}}}}
        It only covers the current batch.

        For stage "before_fwd" only the remapping table is snapshotted; for
        "after_fwd" the remapped and input feature values are recorded as well
        (mpzch only), and "before_fwd" must have been recorded first.

        params:
            stage (str): before_fwd, after_fwd
        return:
            None
        raises:
            NotImplementedError: for mc methods other than "zch"/"mpzch", or
                for feature-value recording with any method other than "mpzch".
        """
        # check if the stage is one of the desired options
        assert stage in (
            "before_fwd",
            "after_fwd",
        ), f"stage {stage} is not supported, valid options are before_fwd, after_fwd"
        # create a dictionary to store the state of mcec modules
        if stage not in self.mcec_state:
            self.mcec_state[stage] = {}  # dict of {table_name: {metric_name: state}}
        # if the stage is before_fwd, only record the remapping_table
        # save the mcec table state for each embedding table
        self.mcec_state[stage][
            "table_state"
        ] = {}  # dict of {table_name: {"remapping_table": state}}
        for table_name, mc_module in self._mcec.items():
            self.mcec_state[stage]["table_state"][table_name] = {}
            # snapshot the remapping table for this method
            if self._mc_method == "zch":
                # NOTE(review): stored without .clone(), so the "before_fwd"
                # and "after_fwd" snapshots may alias the same live tensor —
                # confirm before relying on the zch before/after comparison
                self.mcec_state[stage]["table_state"][table_name][
                    "remapping_table"
                ] = mc_module._mch_sorted_raw_ids
            elif self._mc_method == "mpzch":
                # clone + densify so the snapshot is decoupled from the module
                self.mcec_state[stage]["table_state"][table_name]["remapping_table"] = (
                    mc_module._hash_zch_identities.clone()
                    .to_dense()
                    .squeeze()
                    .cpu()
                    .numpy()
                    .tolist()
                )
            else:
                raise NotImplementedError(
                    f"mc method {self._mc_method} is not supported yet"
                )
        # for before_fwd, we only need to record the remapping_table
        if stage == "before_fwd":
            return
        # for after_fwd, we need to record the feature values
        # check if the "before_fwd" stage is recorded
        assert (
            "before_fwd" in self.mcec_state
        ), "before_fwd stage is not recorded, please call record_mcec_state before calling record_mcec_state after_fwd"
        # create the dictionary to store the mcec feature values before forward
        self.mcec_state["before_fwd"]["feature_values"] = {}
        # create the dictionary to store the mcec feature values after forward
        self.mcec_state[stage]["feature_values"] = {}  # dict of {table_name: state}
        # save the mcec feature values for each embedding table
        for table_name, mc_module in self._mcec.items():
            # record the remapped feature values
            if self._mc_method == "mpzch":  # when using mpzch mc modules
                # record the remapped feature values first
                self.mcec_state[stage]["feature_values"][table_name] = (
                    mc_module.table_name_on_device_remapped_ids_dict[table_name]
                    .cpu()
                    .numpy()
                    .tolist()
                )
                # record the input feature values
                self.mcec_state["before_fwd"]["feature_values"][table_name] = (
                    mc_module.table_name_on_device_input_ids_dict[table_name]
                    .cpu()
                    .numpy()
                    .tolist()
                )
                # check if the input feature values list is empty
                if (
                    len(self.mcec_state["before_fwd"]["feature_values"][table_name])
                    == 0
                ):
                    # if the input feature values list is empty, make it a list
                    # of the sentinel -2 with the same length as the remapped
                    # feature values, so update() can still compare elementwise
                    # (-2 never equals a real id or the empty-slot marker -1)
                    self.mcec_state["before_fwd"]["feature_values"][table_name] = [
                        -2
                    ] * len(self.mcec_state[stage]["feature_values"][table_name])
            else:  # when using other zch mc modules # TODO: implement the feature value recording for zch
                raise NotImplementedError(
                    f"zc method {self._mc_method} is not supported yet"
                )
        return

    def get_mcec_state(self) -> Dict[str, Dict[str, Dict[str, Dict[str, int]]]]:
        """
        get the state of mcec modules
        the state is a dictionary of
        {{stage: {table_name: {data_name: state}}}}

        NOTE(review): the leaf values are lists of ids (see record_mcec_state),
        so the declared return annotation looks too narrow — verify.
        """
        return self.mcec_state

    def save_mcec_state(self) -> None:
        """
        save the state of mcec modules to the per-rank JSON log file
        (path chosen in __init__ from log_file_folder and rank)
        """
        with open(self._log_file_path, "w") as f:
            json.dump(self.mcec_state, f, indent=4)

    def get_mch_stats(self) -> Dict[str, Dict[str, int]]:
        """
        get the statistics of mch modules computed by the last update() call
        the statistics is a dictionary of
        {{table_name: {metric_name: metric_value}}}
        """
        return self._mch_stats

    def update(self) -> None:
        """
        Update the ZCH statistics for the current batch
        Params:
            None
        Return:
            None
        Require:
            self.mcec_state is not None and has recorded both "before_fwd" and "after_fwd" for a batch
        Update:
            self._mch_stats (replaced wholesale with this batch's stats)
        """
        # create a dictionary to store the statistics for each batch
        batch_stats = (
            {}
        )  # table_name: {hit_cnt: 0, total_cnt: 0, insert_cnt: 0, collision_cnt: 0}
        # calculate the statistics for each rank
        # get the remapping id table before forward pass and the input feature values
        rank_feature_value_before_fwd = self.mcec_state["before_fwd"]["feature_values"]
        # get the remapping id table after forward pass and the remapped feature ids
        rank_feature_value_after_fwd = self.mcec_state["after_fwd"]["feature_values"]
        # for each feature table in the remapped information
        for (
            feature_name,
            remapped_feature_ids,
        ) in rank_feature_value_after_fwd.items():
            # create a new dict for the feature table if not created
            if feature_name not in batch_stats:
                batch_stats[feature_name] = {
                    "hit_cnt": 0,
                    "total_cnt": 0,
                    "insert_cnt": 0,
                    "collision_cnt": 0,
                    "rank_total_cnt": 0,
                }
            # get the input feature values
            input_feature_values = np.array(rank_feature_value_before_fwd[feature_name])
            # get the values stored in the remapping table, before the forward
            # pass, at each remapped feature id slot
            prev_remapped_values = np.array(
                self.mcec_state["before_fwd"]["table_state"][f"{feature_name}"][
                    "remapping_table"
                ]
            )[remapped_feature_ids]
            # get the values stored in the remapping table, after the forward
            # pass, at the same slots
            after_remapped_values = np.array(
                self.mcec_state["after_fwd"]["table_state"][f"{feature_name}"][
                    "remapping_table"
                ]
            )[remapped_feature_ids]
            # count the number of same values in prev_remapped_values and input_feature_values
            # hit count = number of inputs whose id already occupied its slot
            # in the remapping table before the forward pass
            this_rank_hits_count = np.sum(prev_remapped_values == input_feature_values)
            batch_stats[feature_name]["hit_cnt"] += int(this_rank_hits_count)
            # count the number of insertions
            ## insert count = the decreased number of empty slots (-1 marks an
            ## empty slot) in the remapping table before vs after forward pass
            this_rank_insert_count = np.sum(prev_remapped_values == -1) - np.sum(
                after_remapped_values == -1
            )
            batch_stats[feature_name]["insert_cnt"] += int(this_rank_insert_count)
            # count the number of total values
            ## total count = the number of remapped values in the remapping table after forward pass
            this_rank_total_count = int(len(remapped_feature_ids))
            # count the number of values redirected to the rank
            # (overwritten, not accumulated: per-batch value for this rank)
            batch_stats[feature_name]["rank_total_cnt"] = this_rank_total_count
            batch_stats[feature_name]["total_cnt"] += this_rank_total_count
            # count the number of collisions
            # collision count = total count - hit count - insert count
            this_rank_collision_count = (
                this_rank_total_count - this_rank_hits_count - this_rank_insert_count
            )
            batch_stats[feature_name]["collision_cnt"] += int(this_rank_collision_count)
        self._mch_stats = batch_stats

0 commit comments

Comments
 (0)