Skip to content

Commit 2d13e4c

Browse files
Chenyu Zhang and facebook-github-bot
authored and committed
kvzch use new operator in model publish (#3108)
Summary: Pull Request resolved: #3108 Publish change to enable KVEmbeddingInference when use_virtual_table is set to true Differential Revision: D75321284
1 parent 6f583af commit 2d13e4c

File tree

4 files changed

+59
-40
lines changed

4 files changed

+59
-40
lines changed

torchrec/distributed/embedding_types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,12 @@ def embedding_shard_metadata(self) -> List[Optional[ShardMetadata]]:
299299
embedding_shard_metadata.append(table.local_metadata)
300300
return embedding_shard_metadata
301301

302+
def is_using_virtual_table(self) -> bool:
303+
return self.compute_kernel in [
304+
EmbeddingComputeKernel.SSD_VIRTUAL_TABLE,
305+
EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE,
306+
]
307+
302308

303309
F = TypeVar("F", bound=Multistreamable)
304310
T = TypeVar("T")

torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ def create_sharding_infos_by_sharding_device_group(
292292
getattr(config, "num_embeddings_post_pruning", None)
293293
# TODO: Need to check if attribute exists for BC
294294
),
295+
use_virtual_table=config.use_virtual_table,
295296
),
296297
param_sharding=parameter_sharding,
297298
param=param,

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
PoolingMode,
2121
rounded_row_size_in_bytes,
2222
)
23+
from fbgemm_gpu.tbe.cache import KVEmbeddingInference
2324
from torchrec.distributed.batched_embedding_kernel import (
2425
BaseBatchedEmbedding,
2526
BaseBatchedEmbeddingBag,
@@ -284,6 +285,8 @@ def __init__(
284285

285286
if self.lengths_to_tbe:
286287
tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength
288+
elif config.is_using_virtual_table():
289+
tbe_clazz = KVEmbeddingInference
287290
else:
288291
tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen
289292

@@ -465,37 +468,40 @@ def __init__(
465468
)
466469
# 16 for CUDA, 1 for others like CPU and MTIA.
467470
self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
468-
self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
469-
IntNBitTableBatchedEmbeddingBagsCodegen(
470-
embedding_specs=[
471+
embedding_clazz = (
472+
KVEmbeddingInference
473+
if config.is_using_virtual_table()
474+
else IntNBitTableBatchedEmbeddingBagsCodegen
475+
)
476+
self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = embedding_clazz(
477+
embedding_specs=[
478+
(
479+
table.name,
480+
local_rows,
471481
(
472-
table.name,
473-
local_rows,
474-
(
475-
local_cols
476-
if self._quant_state_dict_split_scale_bias
477-
else table.embedding_dim
478-
),
479-
data_type_to_sparse_type(table.data_type),
480-
location,
481-
)
482-
for local_rows, local_cols, table, location in zip(
483-
self._local_rows,
484-
self._local_cols,
485-
config.embedding_tables,
486-
managed,
487-
)
488-
],
489-
device=device,
490-
pooling_mode=PoolingMode.NONE,
491-
feature_table_map=self._feature_table_map,
492-
row_alignment=self._tbe_row_alignment,
493-
uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
494-
feature_names_per_table=[
495-
table.feature_names for table in config.embedding_tables
496-
],
497-
**(tbe_fused_params(fused_params) or {}),
498-
)
482+
local_cols
483+
if self._quant_state_dict_split_scale_bias
484+
else table.embedding_dim
485+
),
486+
data_type_to_sparse_type(table.data_type),
487+
location,
488+
)
489+
for local_rows, local_cols, table, location in zip(
490+
self._local_rows,
491+
self._local_cols,
492+
config.embedding_tables,
493+
managed,
494+
)
495+
],
496+
device=device,
497+
pooling_mode=PoolingMode.NONE,
498+
feature_table_map=self._feature_table_map,
499+
row_alignment=self._tbe_row_alignment,
500+
uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
501+
feature_names_per_table=[
502+
table.feature_names for table in config.embedding_tables
503+
],
504+
**(tbe_fused_params(fused_params) or {}),
499505
)
500506
if device is not None:
501507
self._emb_module.initialize_weights()

torchrec/quant/embedding_modules.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
IntNBitTableBatchedEmbeddingBagsCodegen,
3131
PoolingMode,
3232
)
33+
from fbgemm_gpu.tbe.cache import KVEmbeddingInference
3334
from torch import Tensor
3435
from torchrec.distributed.utils import none_throws
3536
from torchrec.modules.embedding_configs import (
@@ -357,7 +358,7 @@ def __init__(
357358
self._is_weighted = is_weighted
358359
self._embedding_bag_configs: List[EmbeddingBagConfig] = tables
359360
self._key_to_tables: Dict[
360-
Tuple[PoolingType, DataType, bool], List[EmbeddingBagConfig]
361+
Tuple[PoolingType, bool], List[EmbeddingBagConfig]
361362
] = defaultdict(list)
362363
self._feature_names: List[str] = []
363364
self._feature_splits: List[int] = []
@@ -383,15 +384,13 @@ def __init__(
383384
key = (table.pooling, table.use_virtual_table)
384385
else:
385386
key = (table.pooling, False)
386-
# pyre-ignore
387387
self._key_to_tables[key].append(table)
388388

389389
location = (
390390
EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
391391
)
392392

393-
for key, emb_configs in self._key_to_tables.items():
394-
pooling = key[0]
393+
for (pooling, use_virtual_table), emb_configs in self._key_to_tables.items():
395394
embedding_specs = []
396395
weight_lists: Optional[
397396
List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -420,7 +419,12 @@ def __init__(
420419
)
421420
feature_table_map.extend([idx] * table.num_features())
422421

423-
emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
422+
embedding_clazz = (
423+
KVEmbeddingInference
424+
if use_virtual_table
425+
else IntNBitTableBatchedEmbeddingBagsCodegen
426+
)
427+
emb_module = embedding_clazz(
424428
embedding_specs=embedding_specs,
425429
pooling_mode=pooling_type_to_pooling_mode(pooling),
426430
weight_lists=weight_lists,
@@ -790,8 +794,7 @@ def __init__( # noqa C901
790794
key = (table.data_type, False)
791795
self._key_to_tables[key].append(table)
792796
self._feature_splits: List[int] = []
793-
for key, emb_configs in self._key_to_tables.items():
794-
data_type = key[0]
797+
for (data_type, use_virtual_table), emb_configs in self._key_to_tables.items():
795798
embedding_specs = []
796799
weight_lists: Optional[
797800
List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
@@ -816,10 +819,13 @@ def __init__( # noqa C901
816819
table_name_to_quantized_weights[table.name]
817820
)
818821
feature_table_map.extend([idx] * table.num_features())
819-
# move to here to make sure feature_names order is consistent with the embedding groups
820822
self._feature_names.extend(table.feature_names)
821-
822-
emb_module = IntNBitTableBatchedEmbeddingBagsCodegen(
823+
embedding_clazz = (
824+
KVEmbeddingInference
825+
if use_virtual_table
826+
else IntNBitTableBatchedEmbeddingBagsCodegen
827+
)
828+
emb_module = embedding_clazz(
823829
embedding_specs=embedding_specs,
824830
pooling_mode=PoolingMode.NONE,
825831
weight_lists=weight_lists,

0 commit comments

Comments
 (0)