@@ -57,6 +57,37 @@ The following APIs are common across all frameworks.
versions of the tensor across different microbatches
(see ``StepOutput`` entry for more information).
+ The argument to the ``smp.step``-decorated function should either be a tensor
+ or an instance of list, tuple, dict or set for it to be split across
+ microbatches. If your object doesn't fall into this category, you can make
+ the library split your object by implementing the ``smp_slice`` method.
+
+ Below is an example of how to use it with PyTorch.
+
+ .. code:: python
+
+    class CustomType:
+        def __init__(self, tensor):
+            self.data = tensor
+
+        # The library will call this to invoke slicing on the object, passing in the
+        # total number of microbatches (num_mb) and the current microbatch index (mb).
+        def smp_slice(self, num_mb, mb, axis):
+            dim_size = list(self.data.size())[axis]
+
+            split_size = dim_size // num_mb
+            sliced_tensor = self.data.narrow(axis, mb * split_size, split_size)
+            return CustomType(sliced_tensor)
+
+    custom_obj = CustomType(torch.ones(4,))
+
+    @smp.step()
+    def step(custom_obj):
+        loss = model(custom_obj)
+        model.backward(loss)
+        return loss
+
+
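+ As a rough sketch of what this enables, the library can then produce one
+ ``CustomType`` per microbatch, conceptually similar to the loop below. This is
+ an illustration only, not the library's internal call sequence, and the
+ ``axis`` value of 0 is an assumption for this sketch.
+
+ .. code:: python
+
+    # Hypothetical sketch: split custom_obj into 2 microbatches along axis 0.
+    num_mb = 2
+    microbatches = [custom_obj.smp_slice(num_mb, mb, axis=0) for mb in range(num_mb)]
+    # Each resulting CustomType wraps a tensor of shape (2,),
+    # taken from the original tensor of shape (4,).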
**Important:** ``smp.step`` splits the batch into microbatches, and
executes everything inside the decorated function once per microbatch.
This might affect the behavior of batch normalization, any operation