Sha256: e280973068f80e655e04b6411683a555ba63c0225917e1e81e57c12ae6701175

Contents?: true

Size: 873 Bytes

Versions: 4

Stored size: 873 Bytes

Contents

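import math

import torch


def spiral_data(N=1000, D=2, C=3):
    """Generate a 2-D spiral dataset with C interleaved arms of N points each.

    Returns X of shape (N * C, D) with float coordinates and y of shape
    (N * C,) with integer class labels in [0, C).
    """
    # Each point is built from a (sin, cos) pair, so only D = 2 can be filled.
    if D != 2:
        raise ValueError("spiral_data only supports D=2")
    X = torch.zeros(N * C, D)
    y = torch.zeros(N * C, dtype=torch.long)
    for c in range(C):
        # Radius grows linearly from 0 (centre) to 1 (outer end of the arm).
        t = torch.linspace(0, 1, N)
        # Angle used inside sin() and cos(): arm c starts at 2*pi*c/C (t = 0)
        # and ends at 2*pi*(c + 2)/C (t = 1); Gaussian noise with std 0.2
        # is added so that the arms are not perfectly smooth curves.
        inner_var = torch.linspace(
            (2 * math.pi / C) * c,
            (2 * math.pi / C) * (2 + c),
            N,
        ) + torch.randn(N) * 0.2

        # index runs 0..N-1 within the arm; ix is the row in X and y.
        for index, ix in enumerate(range(N * c, N * (c + 1))):
            X[ix] = t[index] * torch.tensor(
                (math.sin(inner_var[index]), math.cos(inner_var[index]))
            )
            y[ix] = c

    return (X, y)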
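The function above fills X one row at a time in a Python loop, which is slow for large N. The following is a minimal vectorized sketch of the same construction, assuming D = 2 as in the original; `spiral_data_vectorized` is a hypothetical name and not part of rbbt-dm. It draws the same `torch.randn(N)` noise per arm in the same order, so with the same seed it should produce the same points up to floating-point rounding.

```python
import math

import torch


def spiral_data_vectorized(N=1000, D=2, C=3):
    # Hypothetical vectorized variant of spiral_data (not part of rbbt-dm).
    if D != 2:
        raise ValueError("spiral_data_vectorized only supports D=2")
    X = torch.zeros(N * C, D)
    y = torch.zeros(N * C, dtype=torch.long)
    t = torch.linspace(0, 1, N)  # radius, identical for every arm
    for c in range(C):
        # Same angle range and noise as spiral_data: 2*pi*c/C .. 2*pi*(c+2)/C.
        inner_var = torch.linspace(
            (2 * math.pi / C) * c,
            (2 * math.pi / C) * (2 + c),
            N,
        ) + torch.randn(N) * 0.2
        rows = slice(N * c, N * (c + 1))
        # (N, 1) radii times (N, 2) unit vectors -> (N, 2) points in one step.
        X[rows] = t.unsqueeze(1) * torch.stack(
            (torch.sin(inner_var), torch.cos(inner_var)), dim=1
        )
        y[rows] = c
    return X, y


if __name__ == "__main__":
    torch.manual_seed(0)
    X, y = spiral_data_vectorized(N=100, C=3)
    print(X.shape, torch.bincount(y).tolist())
```

Because every point is its radius t in [0, 1] times a unit vector, all points lie inside the unit circle, and each class contributes exactly N rows.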

Version data entries

4 entries across 4 versions & 1 rubygem

Version          Path
rbbt-dm-1.3.2    python/rbbt_dm/atcold/spiral.py
rbbt-dm-1.3.0    python/rbbt_dm/atcold/spiral.py
rbbt-dm-1.2.10   python/rbbt_dm/atcold/spiral.py
rbbt-dm-1.2.9    python/rbbt_dm/atcold/spiral.py