Sha256: aebecba1eb30af06a3aac7e78a1c870d872370aa666903631e2a1a53e73a7e75

Contents?: true

Size: 696 Bytes

Versions: 1

Compression:

Stored size: 696 Bytes

Contents

import random
import torch
import numpy

def set_seed(seed):
    """
    Seed the Python, NumPy and PyTorch random number generators.

    The CUDA generators (current device and all devices) are seeded as well
    when CUDA is available.
    """
    seeders = [random.seed, numpy.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.extend([torch.cuda.manual_seed, torch.cuda.manual_seed_all])
    for seeder in seeders:
        seeder(seed)

def deterministic(*, strict=False):
    """
    Configure PyTorch for reproducible runs on GPU (if used).

    Forces cuDNN to select deterministic algorithms and disables cuDNN's
    benchmark autotuner. On its own this does NOT make every PyTorch
    operation deterministic; only cuDNN's choices are constrained.

    strict: when True, additionally call
        ``torch.use_deterministic_algorithms(True)``. PyTorch then uses
        deterministic implementations where available and raises a
        RuntimeError for operations that have none. Some CUDA operations
        additionally require the ``CUBLAS_WORKSPACE_CONFIG`` environment
        variable (see the PyTorch reproducibility notes). Defaults to False,
        which preserves the original behaviour.
    """
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes per input shape and may pick different
    # (possibly nondeterministic) algorithms from run to run.
    torch.backends.cudnn.benchmark = False
    if strict:
        torch.use_deterministic_algorithms(True)

def device():
    """
    Return the first CUDA device when CUDA is available, else the CPU device.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device("cuda:0")

def data_directory():
    """
    Print the user's home directory and return it as a ``pathlib.Path``.

    The original only printed the path and returned None, which made the
    value unusable by callers; it is now also returned.

    NOTE(review): this resolves to ``Path.home()``, which looks like a
    placeholder for a dedicated data directory -- confirm the intended
    location.
    """
    from pathlib import Path

    home = Path.home()
    # The print is kept for backward compatibility with existing callers.
    print(home)
    return home

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
rbbt-dm-1.2.9 python/rbbt_dm/util.py