Local variable serialization for workload training errors #4084

Open
@landenai

Problem Statement

When an error is raised in a training workload (e.g. on a GPU), the captured stack frames often contain local variables that are unserializable or whose repr/str methods call out to a device. Attaching these local variables to Sentry events results in segfaults. I do have the option of excluding all local variables from my error event payloads, but this is an all-or-nothing approach that can lose valuable debug context.

Solution Brainstorm

Add framework-specific logic to the SDK, with an accompanying flag, so that only tensor metadata is extracted instead of the tensor contents. Here's an example:

import sys


def extract_tensor_metadata(obj):
    """Safely extract metadata from tensor objects across frameworks"""
    metadata = {'type': type(obj).__name__}
    
    # Framework-specific safe extraction methods
    if 'torch' in sys.modules and isinstance(obj, sys.modules['torch'].Tensor):
        # PyTorch tensors
        try:
            metadata['shape'] = tuple(obj.shape)
            metadata['dtype'] = str(obj.dtype)
            metadata['device'] = str(obj.device)
            metadata['requires_grad'] = bool(obj.requires_grad)
            # Don't access .data or tensor contents
        except Exception as e:
            metadata['extraction_error'] = str(e)
    
    elif 'jax.numpy' in sys.modules and isinstance(obj, sys.modules['jax.numpy'].ndarray):
        # JAX arrays
        try:
            metadata['shape'] = obj.shape
            metadata['dtype'] = str(obj.dtype)
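            # Safely extract device info. Newer jax.Array objects expose
            # devices(); device_buffer only exists on older JAX versions.
            try:
                if hasattr(obj, 'devices'):
                    metadata['device'] = str(obj.devices())
                elif hasattr(obj, 'device_buffer'):
                    metadata['device'] = str(obj.device_buffer.device())
            except Exception:
                metadata['device'] = 'unknown_jax_device'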
        except Exception as e:
            metadata['extraction_error'] = str(e)
    
    elif 'tensorflow' in sys.modules and isinstance(obj, sys.modules['tensorflow'].Tensor):
        # TensorFlow tensors
        try:
            metadata['shape'] = tuple(obj.shape)
            metadata['dtype'] = str(obj.dtype)
            # TF device extraction can be complicated
            try:
                if hasattr(obj, 'device'):
                    metadata['device'] = str(obj.device)
                elif hasattr(obj, '_device'):
                    metadata['device'] = str(obj._device)
            except Exception:
                metadata['device'] = 'unknown_tf_device'
        except Exception as e:
            metadata['extraction_error'] = str(e)
    
    # Add fallbacks for common attributes if not already set
    for attr in ['shape', 'dtype']:
        if attr not in metadata and hasattr(obj, attr):
            try:
                val = getattr(obj, attr)
                metadata[attr] = str(val) if val is not None else None
            except Exception:
                pass
    
    return metadata
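
One way the proposed flag might be wired up is to run an extractor over a frame's locals before anything calls repr() on them. The sketch below is an assumption, not SDK code: `FakeTensor`, `looks_like_tensor`, `summarize_locals`, and `shape_and_dtype` are hypothetical names, and `shape_and_dtype` stands in for `extract_tensor_metadata` so the block runs without torch, JAX, or TensorFlow installed.

```python
class FakeTensor:
    """Hypothetical stand-in for a device-backed tensor (not a real framework type)."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

    def __repr__(self):
        # Simulates a repr that would have to touch the device after a fault.
        raise RuntimeError("device unavailable")


def looks_like_tensor(obj):
    """Duck-typing check that never calls repr()/str() on obj itself."""
    return hasattr(obj, "shape") and hasattr(obj, "dtype")


def shape_and_dtype(obj):
    """Minimal extractor (assumption) so the sketch is self-contained."""
    return {
        "type": type(obj).__name__,
        "shape": tuple(obj.shape),
        "dtype": str(obj.dtype),
    }


def summarize_locals(frame_locals, extract):
    """Replace tensor-like locals with metadata; pass other values through."""
    return {
        name: extract(value) if looks_like_tensor(value) else value
        for name, value in frame_locals.items()
    }


frame_locals = {"step": 120, "batch": FakeTensor((32, 128), "float32")}
summary = summarize_locals(frame_locals, shape_and_dtype)
print(summary)
# {'step': 120, 'batch': {'type': 'FakeTensor', 'shape': (32, 128), 'dtype': 'float32'}}
```

In practice `extract_tensor_metadata` would be passed as `extract`. The stand-in raises an exception where a real tensor might segfault; since a segfault cannot be caught from Python, the point of the sketch is that tensor-like values are swapped for metadata before they ever reach repr().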
