Skip to content

Commit 20ce4f0

Browse files
lizhouyufacebook-github-bot
authored andcommitted
Add tensorboard to display training and evaluation metrics and revise implementation to support DLRMv2 (#3163)
Summary: Pull Request resolved: #3163 ### Major changes - Add tensorboard to the benchmark testbed, specifically in `benchmark_zch.py`. - Count the number of unique values received by each rank in each epoch by revising `benchmark_zch_utils.py`. - Revise `data/nonzch_remapper.py` so that it no longer depends on the `batch.get_dict()` method; instead, it fetches the attributes of the dataclass `batch` with the built-in `vars()` function. - Revise DLRMv2 model EBC config initialization to make the table name identical to the feature name. - Revise DLRMv2 configuration yaml file to set the table size for each feature. - Revise the default value of the "num_embeddings" parameter in `arguments.py` to None. Rollback Plan: Differential Revision: D77841795
1 parent f40295d commit 20ce4f0

File tree

9 files changed

+359
-122
lines changed

9 files changed

+359
-122
lines changed

torchrec/distributed/benchmark/benchmark_zch/arguments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
2525
parser.add_argument(
2626
"--num_embeddings", # ratio of feature ids to embedding table size # 3 axis: x-bath_idx; y-collisions; zembedding table sizes
2727
type=int,
28-
default=100_000,
28+
default=None,
2929
help="max_ind_size. The number of embeddings in each embedding table. Defaults"
3030
" to 100_000 if num_embeddings_per_feature is not supplied.",
3131
)

torchrec/distributed/benchmark/benchmark_zch/benchmark_zch.py

Lines changed: 255 additions & 54 deletions
Large diffs are not rendered by default.

torchrec/distributed/benchmark/benchmark_zch/benchmark_zch_utils.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,13 @@
1-
import argparse
2-
import copy
31
import json
42
import logging
53
import os
6-
from typing import Any, Dict
4+
from typing import Any, Dict, Set
75

86
import numpy as np
97

108
import torch
119
import torch.nn as nn
12-
import yaml
1310
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
14-
from torchrec.modules.mc_modules import (
15-
DistanceLFU_EvictionPolicy,
16-
ManagedCollisionCollection,
17-
MCHManagedCollisionModule,
18-
)
19-
20-
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
2111

2212

2313
def get_module_from_instance(
@@ -104,6 +94,7 @@ def __init__(
10494
self._mch_stats: Dict[str, Any] = (
10595
{}
10696
) # dictionary of {table_name [str]: {metric_name [str]: metric_value [int]}}
97+
self.feature_name_unique_queried_values_set_dict: Dict[str, Set[int]] = {}
10798

10899
# record mcec state to file
109100
def record_mcec_state(self, stage: str) -> None:
@@ -260,6 +251,7 @@ def update(self) -> None:
260251
"collision_cnt": 0,
261252
"rank_total_cnt": 0,
262253
"num_empty_slots": 0,
254+
"num_unique_queries": 0,
263255
}
264256
# get the input faeture values
265257
input_feature_values = np.array(rank_feature_value_before_fwd[feature_name])
@@ -313,4 +305,16 @@ def update(self) -> None:
313305
this_rank_total_count - this_rank_hits_count - this_rank_insert_count
314306
)
315307
batch_stats[feature_name]["collision_cnt"] += int(this_rank_collision_count)
308+
# get the unique values in the input feature values
309+
if feature_name not in self.feature_name_unique_queried_values_set_dict:
310+
self.feature_name_unique_queried_values_set_dict[feature_name] = set(
311+
input_feature_values.tolist()
312+
)
313+
else:
314+
self.feature_name_unique_queried_values_set_dict[feature_name].update(
315+
set(input_feature_values.tolist())
316+
)
317+
batch_stats[feature_name]["num_unique_queries"] = len(
318+
self.feature_name_unique_queried_values_set_dict[feature_name]
319+
)
316320
self._mch_stats = batch_stats
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1-
dataset_path: "/home/lizhouyu/oss_github/dlrm/torchrec_dlrm/criteo_1tb/criteo_kaggle_processed"
1+
dataset_path: "/home/lizhouyu/datasets/criteo_kaggle_processed"
22
batch_size: 4096
33
seed: 0
4+
multitask_configs:
5+
- task_name: is_click
6+
task_weight: 1
7+
task_type: classification

torchrec/distributed/benchmark/benchmark_zch/data/configs/kuairand_1k.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
dataset_path: "/home/lizhouyu/oss_github/generative-recommenders/generative_recommenders/dlrm_v3/data/KuaiRand-1K/data"
1+
dataset_path: "/home/lizhouyu/datasets/kuairand-1k/data"
22
batch_size: 16
33
train_split_percentage: 0.75
44
num_workers: 4

torchrec/distributed/benchmark/benchmark_zch/data/nonzch_remapper.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,24 @@ def __init__(
9999
)
100100
self._input_hash_size = input_hash_size
101101

102+
def get_batch_kjt_dict(self, batch: Batch) -> Dict[str, KeyedJaggedTensor]:
    """
    Collect the KeyedJaggedTensor-valued attributes of a batch.

    Parameters:
        batch: the batch whose instance attributes are inspected with ``vars()``
    Returns:
        a dictionary of [batch_attribute_name: KeyedJaggedTensor] that keeps
        only the attributes whose value is a KeyedJaggedTensor instance;
        every other attribute of the batch is left out.
    """
    return {
        attr_name: attr_value
        for attr_name, attr_value in vars(batch).items()
        if isinstance(attr_value, KeyedJaggedTensor)  # drop non-KJT attributes
    }
119+
102120
def remap(self, batch: Batch) -> Batch:
103121
# for all the attributes under batch, like batch.uih_features, batch.candidates_features,
104122
# get the kjt as a dict, and remap the kjt
@@ -118,7 +136,7 @@ def remap(self, batch: Batch) -> Batch:
118136
# candidates_features: KeyedJaggedTensor
119137

120138
# for every attribute in batch, remap the kjt
121-
for attr_name, feature_kjt_dict in batch.get_dict().items():
139+
for attr_name, feature_kjt_dict in self.get_batch_kjt_dict(batch).items():
122140
# separate feature kjt with {feature_name_1: feature_kjt_1, feature_name_2: feature_kjt_2, ...}
123141
# to multiple dict with {feature_name_1: jt_1}, {feature_name_2: jt_2}, ...
124142
attr_feature_jt_dict = {}

torchrec/distributed/benchmark/benchmark_zch/models/configs/dlrmv2.yaml

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,31 @@ over_arch_layer_sizes:
99
- 1
1010
embedding_dim: 64
1111
num_embeddings_per_feature:
12-
cat_0: 100000
13-
cat_1: 100000
14-
cat_2: 100000
15-
cat_3: 100000
16-
cat_4: 100000
17-
cat_5: 100000
18-
cat_6: 100000
19-
cat_7: 100000
20-
cat_8: 100000
21-
cat_9: 100000
22-
cat_10: 100000
23-
cat_11: 100000
24-
cat_12: 100000
25-
cat_13: 100000
26-
cat_14: 100000
27-
cat_15: 100000
28-
cat_16: 100000
29-
cat_17: 100000
30-
cat_18: 100000
31-
cat_19: 100000
32-
cat_20: 100000
33-
cat_21: 100000
34-
cat_22: 100000
35-
cat_23: 100000
36-
cat_24: 100000
37-
cat_25: 100000
12+
cat_0: 40000000
13+
cat_1: 39060
14+
cat_2: 17295
15+
cat_3: 7424
16+
cat_4: 20265
17+
cat_5: 3
18+
cat_6: 7122
19+
cat_7: 1543
20+
cat_8: 63
21+
cat_9: 40000000
22+
cat_10: 3067956
23+
cat_11: 405282
24+
cat_12: 10
25+
cat_13: 2209
26+
cat_14: 11938
27+
cat_15: 155
28+
cat_16: 4
29+
cat_17: 976
30+
cat_18: 14
31+
cat_19: 40000000
32+
cat_20: 40000000
33+
cat_21: 40000000
34+
cat_22: 590152
35+
cat_23: 12973
36+
cat_24: 108
37+
cat_25: 36
3838
embedding_module_attribute_path: "dlrm.sparse_arch.embedding_bag_collection" # the attribute path after model
3939
managed_collision_module_attribute_path: "module.dlrm.sparse_arch.embedding_bag_collection.mc_embedding_bag_collection._managed_collision_collection._managed_collision_modules" # the attribute path of managed collision module after model

torchrec/distributed/benchmark/benchmark_zch/models/models/dlrmv2.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def __init__(
3535
dense_device=dense_device,
3636
)
3737
self.train_model = DLRMTrain(self.dlrm)
38+
self.table_configs: List[EmbeddingBagConfig] = list(
39+
embedding_bag_collection.embedding_bag_configs()
40+
)
3841

3942
def forward(
4043
self, batch: Batch
@@ -55,10 +58,10 @@ def make_model_dlrmv2(
5558
) -> nn.Module:
5659
ebc_configs = [
5760
EmbeddingBagConfig(
58-
name=f"t_{feature_name}",
61+
name=f"{feature_name}",
5962
embedding_dim=configs["embedding_dim"],
6063
num_embeddings=(
61-
none_throws(configs["num_embeddings_per_feature"])[feature_idx]
64+
none_throws(configs["num_embeddings_per_feature"])[feature_name]
6265
if args.num_embeddings is None
6366
else args.num_embeddings
6467
),
@@ -76,8 +79,9 @@ def make_model_dlrmv2(
7679
input_hash_size=args.input_hash_size,
7780
device=torch.device("meta"),
7881
world_size=get_local_size(),
79-
use_mpzch=True,
82+
zch_method="mpzch",
8083
mpzch_num_buckets=args.num_buckets,
84+
mpzch_max_probe=args.max_probe,
8185
)
8286
)
8387
else:

torchrec/modules/mc_adapter.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,11 @@ def __init__(
148148
world_size: int,
149149
eviction_interval: int = 1,
150150
allow_in_place_embed_weight_update: bool = False,
151-
use_mpzch: bool = False,
152-
mpzch_num_buckets: Optional[int] = None,
153-
mpzch_max_probe: Optional[int] = None,
151+
zch_method: str = "", # method for managing collisions, one of ["", "mpzch", "sort_zch"]
152+
mpzch_num_buckets: Optional[int] = 80,
153+
mpzch_max_probe: Optional[
154+
int
155+
] = 100, # max_probe for HashZchManagedCollisionModule
154156
) -> None:
155157
"""
156158
Initialize an EmbeddingBagCollectionAdapter.
@@ -173,38 +175,44 @@ def __init__(
173175
mc_modules = {}
174176
for table_config in ebc.embedding_bag_configs():
175177
table_name = table_config.name
176-
if use_mpzch:
178+
if zch_method == "mpzch":
177179
# if use MPZCH, create a HashZchManagedCollisionModule
180+
num_buckets = mpzch_num_buckets if mpzch_num_buckets else world_size
181+
max_probe = (
182+
min(
183+
mpzch_max_probe,
184+
table_config.num_embeddings // world_size // num_buckets,
185+
)
186+
if mpzch_max_probe
187+
else table_config.num_embeddings // world_size // num_buckets
188+
)
178189
mc_modules[table_name] = HashZchManagedCollisionModule( # MPZCH
179190
is_inference=False,
180191
zch_size=(table_config.num_embeddings),
181192
input_hash_size=input_hash_size,
182193
device=device,
183-
total_num_buckets=(
184-
mpzch_num_buckets if mpzch_num_buckets else world_size
185-
), # total_num_buckets if not passed, use world_size, WORLD_SIZE should be a factor of total_num_buckets
194+
total_num_buckets=num_buckets, # total_num_buckets if not passed, use world_size, WORLD_SIZE should be a factor of total_num_buckets
195+
max_probe=max_probe,
186196
eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION, # defaultly using single ttl eviction policy
187197
eviction_config=HashZchEvictionConfig(
188198
features=table_config.feature_names,
189199
single_ttl=eviction_interval,
190200
),
191-
max_probe=(
192-
mpzch_max_probe
193-
if mpzch_max_probe is not None
194-
and mpzch_max_probe
195-
< (table_config.num_embeddings // world_size)
196-
else table_config.num_embeddings // world_size
197-
), # max_probe for HashZchManagedCollisionModule
198201
)
199-
else: # if not use MPZCH, create a MCHManagedCollisionModule using the sort ZCH algorithm
202+
elif (
203+
zch_method == "sort_zch"
204+
): # if not use MPZCH, create a MCHManagedCollisionModule using the sort ZCH algorithm
200205
mc_modules[table_name] = MCHManagedCollisionModule( # sort ZCH
201206
zch_size=table_config.num_embeddings,
202207
device=device,
203208
input_hash_size=input_hash_size,
204209
eviction_interval=eviction_interval,
205210
eviction_policy=DistanceLFU_EvictionPolicy(),
206211
) # NOTE: the benchmark for sort ZCH is not implemented yet
207-
# TODO: add the pure hash module here
212+
else: # if not use MPZCH, create a MCHManagedCollisionModule using the sort ZCH
213+
raise NotImplementedError(
214+
f"zc method {zch_method} is not supported yet"
215+
)
208216

209217
# create the mcebc module with the mc modules and the original ebc
210218
self.mc_embedding_bag_collection = (
@@ -219,19 +227,14 @@ def __init__(
219227
)
220228
)
221229

222-
self.remapped_ids: Optional[Dict[str, torch.Tensor]] = (
223-
None # to store remapped ids
224-
)
225-
226230
def forward(self, input_kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]:
    """
    Run the wrapped managed-collision embedding bag collection on the input.

    Args:
        input_kjt (KeyedJaggedTensor): KJT of form [F X B X L].
    Returns:
        Dict[str, JaggedTensor]: dictionary of {'feature_name': JaggedTensor}
    """
    mc_ebc_out, per_table_remapped_id = self.mc_embedding_bag_collection(input_kjt)
    # Keep the remapped ids of the most recent call on the instance:
    # `get_per_table_remapped_id` reads `self.per_table_remapped_id`, so
    # dropping the value here would leave that accessor with nothing to return.
    self.per_table_remapped_id = per_table_remapped_id
    return mc_ebc_out
236239

237240
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
@@ -240,12 +243,15 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
240243
recurse=recurse
241244
)
242245

243-
def embedding_bag_configs(self) -> List[EmbeddingBagConfig]:
246+
def embedding_bag_configs(self) -> List[EmbeddingConfig]:
244247
"""
245248
Returns:
246249
Dict[str, EmbeddingConfig]: dictionary of {'feature_name': EmbeddingConfig}
247250
"""
251+
# pyre-ignore[16]: `ManagedCollisionEmbeddingBagCollection` has no attribute `_embedding_module`
248252
return (
249-
# pyre-ignore [29] # NOTE: the function "embedding_configs" returns the _embedding_module attribute of the EmbeddingCollection
250253
self.mc_embedding_bag_collection._embedding_module.embedding_bag_configs()
251254
)
255+
256+
def get_per_table_remapped_id(self) -> Dict[str, JaggedTensor]:
    """
    Return the remapped ids stored on this adapter.

    Returns:
        Dict[str, JaggedTensor]: the value of `self.per_table_remapped_id`.
    Raises:
        AttributeError: if `per_table_remapped_id` has not been stored on the
            instance yet.
    """
    try:
        return self.per_table_remapped_id
    except AttributeError:
        # Same exception type as the bare attribute access, but with a message
        # that tells the caller what is missing and how it gets populated.
        raise AttributeError(
            "per_table_remapped_id is not set on this adapter; it must be "
            "stored by a forward pass before it can be queried"
        ) from None

0 commit comments

Comments
 (0)