@@ -14,9 +14,10 @@ Requirements: - python >= 3.7
We highly recommend CUDA when using torchRec. If using CUDA: - cuda >=
11.0

+ .. Should these be updated?

.. code :: python

-   # install conda to make installing pytorch with cudatoolkit 11.3 easier.
+   # install conda to make installing pytorch with cudatoolkit 11.3 easier.
    ! sudo rm Miniconda3-py37_4.9.2-Linux-x86_64.sh Miniconda3-py37_4.9.2-Linux-x86_64.sh.*
    ! sudo wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh
    ! sudo chmod +x Miniconda3-py37_4.9.2-Linux-x86_64.sh
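Not part of the diff, but a quick way to confirm the requirements above before installing TorchRec is a small check from Python. The snippet below is a minimal sketch that uses only standard ``torch`` attributes.

.. code:: python

   import sys
   import torch

   # Sanity-check the requirements listed above: python >= 3.7 and, ideally, CUDA >= 11.0.
   assert sys.version_info >= (3, 7), "TorchRec requires Python 3.7+"
   print("torch:", torch.__version__)
   print("CUDA available:", torch.cuda.is_available())
   if torch.cuda.is_available():
       # torch.version.cuda reports the CUDA toolkit version this torch build targets.
       print("CUDA toolkit:", torch.version.cuda)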
@@ -213,7 +214,7 @@ embedding table placement using planner and generate sharded model using
    )
    sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
    plan: ShardingPlan = planner.collective_plan(module, sharders, pg)
-
+
    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
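The change above only touches whitespace around the planning step. For context, here is a hedged sketch of how these pieces presumably fit together in the tutorial: the ``shard_model`` helper and its parameters are illustrative, while the planner, sharder, and ``DistributedModelParallel`` calls mirror the lines shown in the hunk.

.. code:: python

   from typing import cast

   import torch
   import torch.distributed as dist
   from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
   from torchrec.distributed.model_parallel import DistributedModelParallel
   from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
   from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan

   def shard_model(module: torch.nn.Module, pg: dist.ProcessGroup, world_size: int):
       # Propose a placement for every embedding table across the available devices.
       planner = EmbeddingShardingPlanner(
           topology=Topology(world_size=world_size, compute_device="cuda"),
       )
       sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
       # collective_plan() agrees on one plan across every rank in the process group.
       plan: ShardingPlan = planner.collective_plan(module, sharders, pg)
       # Wrap the module so embedding lookups are routed to the planned shards.
       return DistributedModelParallel(
           module,
           env=ShardingEnv.from_process_group(pg),
           plan=plan,
           sharders=sharders,
       )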
@@ -234,7 +235,7 @@ ranks.
.. code :: python

    import multiprocess
-
+
    def spmd_sharing_simulation(
        sharding_type: ShardingType = ShardingType.TABLE_WISE,
        world_size = 2,
@@ -254,7 +255,7 @@ ranks.
            )
            p.start()
            processes.append(p)
-
+
        for p in processes:
            p.join()
            assert 0 == p.exitcode
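Taken together, the two hunks above come from the per-process simulation loop. The sketch below reconstructs the overall pattern under stated assumptions: ``single_rank_execution`` is a stand-in name for the per-rank worker defined elsewhere in the tutorial, and the rendezvous environment variables are illustrative defaults.

.. code:: python

   import os

   import multiprocess
   from torchrec.distributed.types import ShardingType

   def spmd_sharing_simulation(
       sharding_type: ShardingType = ShardingType.TABLE_WISE,
       world_size: int = 2,
   ) -> None:
       # Local rendezvous settings so the spawned ranks can form a process group.
       os.environ["MASTER_ADDR"] = "localhost"
       os.environ["MASTER_PORT"] = "29500"
       processes = []
       for rank in range(world_size):
           # single_rank_execution is a placeholder for the worker that builds the
           # model, plans the requested sharding_type, and wraps it in DMP.
           p = multiprocess.Process(
               target=single_rank_execution,
               args=(rank, world_size, sharding_type),
           )
           p.start()
           processes.append(p)

       for p in processes:
           p.join()
           assert 0 == p.exitcode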
@@ -333,4 +334,3 @@ With data parallel, we will repeat the tables for all devices.

    rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}
    rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}
-
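The replicated plan shown above (every table placed on ranks [0, 1]) is presumably the result of running the simulation with the data-parallel sharding type, e.g.:

.. code:: python

   # Replicate every embedding table on all ranks instead of sharding it.
   spmd_sharing_simulation(ShardingType.DATA_PARALLEL)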