Skip to content

Commit ba17c82

Browse files
authored
Make required tensors optional (aws#148)
* make required tensors optional * Update README.md
1 parent 82e7827 commit ba17c82

File tree

2 files changed

+54
-41
lines changed

2 files changed

+54
-41
lines changed

docs/rules/README.md

Lines changed: 54 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -321,22 +321,41 @@ from tornasole.rules import Rule
321321
class VanishingGradientRule(Rule):
322322
def __init__(self, base_trial, threshold=0.0000001):
323323
super().__init__(base_trial, other_trials=None)
324-
self.threshold = threshold
324+
self.threshold = float(threshold)
325325
```
326326

327327
Please note that apart from `base_trial` and `other_trials` (if required), we require all
328-
arguments of the rule constructor to take a string as value. This means if you want to pass
328+
arguments of the rule constructor to take a string as value. You can parse them to the type
329+
that you want from the string. This means if you want to pass
329330
a list of strings, you might want to pass them as a comma separated string. This restriction is
330-
being enforced so as to let you create and invoke rules from json using Sagemaker's APIs.
331+
being enforced so as to let you create and invoke rules from json using Sagemaker's APIs.
332+
333+
##### Function to invoke at a given step
334+
In this function you can implement the core logic of what you want to do with these tensors.
335+
You can access the `required_tensors` from here using the methods to query the required tensors.
331336

332-
##### RequiredTensors
337+
It should return a boolean value `True` or `False`.
338+
This can be used to define actions that you might want to take based on the output of the rule.
333339

334-
Next you need to implement a method which lets Tornasole know what tensors you
335-
are interested in for invocation at a given step.
340+
A simplified version of the actual invoke function for `VanishingGradientRule` is below:
341+
```
342+
def invoke_at_step(self, step):
343+
for tensor in self.req_tensors.get():
344+
abs_mean = tensor.reduction_value(step, 'mean', abs=True)
345+
if abs_mean < self.threshold:
346+
return True
347+
else:
348+
return False
349+
```
350+
351+
##### Optional: RequiredTensors
352+
353+
This is an optional construct that allows Tornasole to bulk-fetch all tensors that you need to
354+
execute the rule. This helps the rule invocation be more performant so it does not fetch tensor values from S3 one by one. To use this construct, you need to implement a method which lets Tornasole know what tensors you are interested in for invocation at a given step.
336355
This is the `set_required_tensors` method.
337356

338357
Before we look at how to define this method, let us look at the API for `RequiredTensors` class which
339-
needs to be used by this method
358+
needs to be used by this method. An object of this class is provided as a member of the rule class, so you can access it as `self.req_tensors`.
340359

341360
**[RequiredTensors](../../tornasole/rules/req_tensors.py) API**
342361

@@ -360,8 +379,19 @@ take the value of `self.base_trial` in the rule class. None is the default value
360379
In such a case, all tensor names in the trial which match that regex pattern are treated as required
361380
for the invocation of the rule at the given step.
362381

382+
***Fetching required tensors***
383+
384+
If required tensors were added inside `set_required_tensors`, during rule invocation it is
385+
automatically used to fetch all tensors at once by calling `req_tensors.fetch()`.
386+
It can raise the exceptions `TensorUnavailable` and `TensorUnavailableForStep` if the trial does not have that tensor, or if the tensor value is not available for the requested step.
387+
388+
389+
If required tensors were added elsewhere, or later, you can call the `req_tensors.fetch()` method
390+
yourself to fetch all tensors at once.
391+
363392
***Querying required tensors***
364393

394+
You can then query the required tensors
365395
*Get names of required tensors*
366396

367397
This method returns the names of the required tensors for a given trial.
@@ -392,43 +422,21 @@ take the value of `self.base_trial` in the rule class. None is the default value
392422

393423

394424
###### Declare required tensors
395-
We need to implement the `set_required_tensors` method to declare the required tensors
425+
Here, let us define the `set_required_tensors` method to declare the required tensors
396426
to execute the rule at a given `step`.
397427
If we require the gradients of the base_trial to execute the rule at a given step,
398428
then it would look as follows:
399429
```
400-
def required_tensors(self, step):
430+
def set_required_tensors(self, step):
401431
for tname in self.base_trial.tensors_in_collection('gradients'):
402432
self.req_tensors.add(tname, steps=[step])
403433
```
404434

405435
This function will be used by the rule execution engine to fetch all the
406-
required tensors from local disk or S3 before it executes the rule.
407-
If you try to retrieve the value of a tensor which was not mentioned as part of `required_tensors`,
408-
it might not be fetched from the trial directory.
409-
In such a case you might see one of the exceptions
410-
`TensorUnavailableForStep` or `TensorUnavailable`.
411-
This is because the rule invoker executes the rule with `no_refresh` mode.
412-
Refer discussion above for more on this.
413-
414-
##### Function to invoke at a given step
415-
In this function you can implement the core logic of what you want to do with these tensors.
416-
You can access the `required_tensors` from here using the methods to query the required tensors.
417-
418-
It should return a boolean value `True` or `False`.
419-
This can be used to define actions that you might want to take based on the output of the rule.
420-
421-
A simplified version of the actual invoke function for `VanishingGradientRule` is below:
422-
423-
```
424-
def invoke_at_step(self, step):
425-
for tensor in self.req_tensors.get():
426-
abs_mean = tensor.reduction_value(step, 'mean', abs=True)
427-
if abs_mean < self.threshold:
428-
return True
429-
else:
430-
return False
431-
```
436+
required tensors before it executes the rule.
437+
The rule invoker executes the `set_required_tensors` and `invoke_at_step`
438+
methods within a single `no_refresh` block, hence you are guaranteed that the
439+
tensor values or steps numbers will stay the same during multiple calls.
432440

433441
#### Executing a rule
434442
Now that you have written a rule, here's how you can execute it. We provide a function to invoke rules easily.
@@ -437,17 +445,23 @@ The invoke function has the following syntax.
437445
It takes a instance of a Rule and invokes it for a series of steps one after the other.
438446

439447
```
440-
invoke(rule_obj, start_step=0, end_step=None)
448+
from tornasole.rules import invoke_rule
449+
invoke_rule(rule_obj, start_step=0, end_step=None)
441450
```
442451

443-
For first party Rules (see below) that we provide a rule_invoker module that you can use to run them as follows
452+
You can invoking the VanishingGradientRule is
453+
```
454+
trial_obj = create_trial(trial_dir)
455+
vr = VanishingGradientRule(base_trial=trial_obj, threshold=0.0000001)
456+
invoke_rule(vr, start_step=0, end_step=1000)
457+
```
458+
459+
For first party Rules (see below) that we provide a rule_invoker module that you can use to run them as follows. You can pass any arguments that the rule takes as command line arguments.
444460

445461
```
446-
python -m tornasole.rules.rule_invoker --trial-dir ~/ts_outputs/vanishing_gradients --rule-name VanishingGradient
462+
python -m tornasole.rules.rule_invoker --trial-dir ~/ts_outputs/vanishing_gradients --rule-name VanishingGradient --threshold 0.0000000001
447463
```
448464

449-
You can pass any arguments that the rule takes as command line arguments, like below:
450-
451465
```
452466
python -m tornasole.rules.rule_invoker --trial-dir s3://tornasole-runes/trial0 --rule-name UnchangedTensor --tensor_regex .* --num_steps 10
453467
```

tornasole/rules/rule.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def __init__(self, base_trial, other_trials=None):
2222
self.logger = get_logger()
2323
self.rule_name = self.__class__.__name__
2424

25-
@abstractmethod
2625
def set_required_tensors(self, step):
2726
pass
2827

0 commit comments

Comments
 (0)