 This is an optional construct that allows Tornasole to bulk-fetch all tensors that you need to
+execute the rule. This helps the rule invocation be more performant so it does not fetch tensor values from S3 one by one. To use this construct, you need to implement a method which lets Tornasole know what tensors you are interested in for invocation at a given step.
 This is the `set_required_tensors` method.

 Before we look at how to define this method, let us look at the API for `RequiredTensors` class which
-needs to be used by this method
+needs to be used by this method. An object of this class is provided as a member of the rule class, so you can access it as `self.req_tensors`.
@@ -360,8 +379,19 @@ take the value of `self.base_trial` in the rule class. None is the default value
 In such a case, all tensor names in the trial which match that regex pattern are treated as required
 for the invocation of the rule at the given step.
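The regex behaviour described in the context lines above can be sketched in plain Python. This is only an illustration and does not call Tornasole: the helper `match_required` and the tensor names are hypothetical, and the choice of `re.match` (anchored at the start of the name) rather than `re.search` is an assumption about how the pattern is applied.

```python
import re


def match_required(tensor_names, pattern):
    """Return the names that match the regex pattern (hypothetical helper)."""
    compiled = re.compile(pattern)
    # re.match anchors at the start of each name; this is an assumption.
    return [name for name in tensor_names if compiled.match(name)]


# Hypothetical tensor names standing in for the names saved in a trial.
names = [
    "gradients/dense_1/kernel",
    "gradients/dense_2/kernel",
    "dense_1/kernel",
    "loss",
]

required = match_required(names, r"gradients/.*")
print(required)
```

With these made-up names, only the two names that start with `gradients/` are selected, so a single pattern can stand in for a whole family of tensors.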
+***Fetching required tensors***
+
+If required tensors were added inside `set_required_tensors`, they are fetched automatically,
+all at once, during rule invocation by a call to `req_tensors.fetch()`.
+This call can raise `TensorUnavailable` if the trial does not have a requested tensor, or `TensorUnavailableForStep` if the tensor value is not available for the requested step.
+
+
+If required tensors were added elsewhere, or later, you can call the `req_tensors.fetch()` method
+yourself to fetch all tensors at once.
+
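A minimal sketch of the declare-then-fetch pattern and its two failure modes follows. Everything in it is a stand-in written for illustration: `FakeRequiredTensors`, the in-memory `store`, and the two exception classes only mimic the names used in the text, and none of it is Tornasole's actual implementation.

```python
class TensorUnavailable(Exception):
    """Stand-in: the trial has no tensor with this name."""


class TensorUnavailableForStep(Exception):
    """Stand-in: the tensor exists but has no value at the requested step."""


class FakeRequiredTensors:
    """Toy stand-in for the object exposed as self.req_tensors."""

    def __init__(self, store):
        self._store = store    # {tensor name: {step: value}}
        self._required = {}    # {tensor name: set of steps}
        self.fetched = {}      # {(tensor name, step): value}

    def add(self, name, steps):
        # Only record the request; nothing is read yet.
        self._required.setdefault(name, set()).update(steps)

    def fetch(self):
        # Resolve every declared (name, step) pair in one pass.
        for name, steps in self._required.items():
            if name not in self._store:
                raise TensorUnavailable(name)
            for step in sorted(steps):
                if step not in self._store[name]:
                    raise TensorUnavailableForStep(f"{name} at step {step}")
                self.fetched[(name, step)] = self._store[name][step]


# Hypothetical saved values: one tensor with values at steps 0 and 1.
store = {"gradients/w": {0: 0.5, 1: 0.25}}

req = FakeRequiredTensors(store)
req.add("gradients/w", steps=[0, 1])
req.fetch()  # explicit call, as when tensors are added outside set_required_tensors

late = FakeRequiredTensors(store)
late.add("gradients/w", steps=[7])  # step 7 was never saved
try:
    late.fetch()
    outcome = "fetched"
except TensorUnavailableForStep:
    outcome = "missing step"
except TensorUnavailable:
    outcome = "missing tensor"
print(outcome)
```

The point of the pattern is that `add` is cheap and `fetch` does all the reading in one go, so any error handling can sit around a single call.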
 ***Querying required tensors***

+You can then query the required tensors.
 *Get names of required tensors*

 This method returns the names of the required tensors for a given trial.
@@ -392,43 +422,21 @@ take the value of `self.base_trial` in the rule class. None is the default value
 ###### Declare required tensors
-We need to implement the `set_required_tensors` method to declare the required tensors
+Here, let us define the `set_required_tensors` method to declare the required tensors
 to execute the rule at a given `step`.
 If we require the gradients of the base_trial to execute the rule at a given step,
 then it would look as follows:
 ```
-def required_tensors(self, step):
+def set_required_tensors(self, step):
     for tname in self.base_trial.tensors_in_collection('gradients'):
         self.req_tensors.add(tname, steps=[step])
 ```
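To see what the method above declares, here is a self-contained sketch that runs it against stand-in objects. `FakeTrial`, `RecordingRequiredTensors`, and `GradientRuleSketch` are hypothetical names invented for this illustration; only `set_required_tensors`, `base_trial`, `tensors_in_collection`, and `req_tensors.add(tname, steps=[step])` come from the snippet itself.

```python
class FakeTrial:
    """Hypothetical stand-in for a trial; models only tensors_in_collection."""

    def __init__(self, collections):
        self._collections = collections  # {collection name: [tensor names]}

    def tensors_in_collection(self, name):
        return self._collections.get(name, [])


class RecordingRequiredTensors:
    """Hypothetical stand-in that records each add() call."""

    def __init__(self):
        self.requests = []

    def add(self, name, steps):
        self.requests.append((name, tuple(steps)))


class GradientRuleSketch:
    """Toy rule holding the two members the snippet relies on."""

    def __init__(self, base_trial):
        self.base_trial = base_trial
        self.req_tensors = RecordingRequiredTensors()

    def set_required_tensors(self, step):
        # Same body as the method in the diff above.
        for tname in self.base_trial.tensors_in_collection('gradients'):
            self.req_tensors.add(tname, steps=[step])


trial = FakeTrial({"gradients": ["gradients/w", "gradients/b"]})
rule = GradientRuleSketch(trial)
rule.set_required_tensors(step=3)
print(rule.req_tensors.requests)
```

Each gradient tensor is requested for exactly the step being evaluated, which is what allows a later bulk fetch to read only what the rule needs.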

 This function will be used by the rule execution engine to fetch all the
-required tensors from local disk or S3 before it executes the rule.
-If you try to retrieve the value of a tensor which was not mentioned as part of `required_tensors`,
-it might not be fetched from the trial directory.
-In such a case you might see one of the exceptions
-`TensorUnavailableForStep` or `TensorUnavailable`.
-This is because the rule invoker executes the rule with `no_refresh` mode.
-Refer discussion above for more on this.
-
-##### Function to invoke at a given step
-In this function you can implement the core logic of what you want to do with these tensors.
-You can access the `required_tensors` from here using the methods to query the required tensors.
-
-It should return a boolean value `True` or `False`.
-This can be used to define actions that you might want to take based on the output of the rule.
-
-A simplified version of the actual invoke function for `VanishingGradientRule` is below:
For first-party Rules (see below), we provide a `rule_invoker` module that you can use to run them as follows. You can pass any arguments that the rule takes as command-line arguments.