Sha256: 7cee98edd5a220dfe2496cb9cfa083d444549d13507a06c743e32a4b47f0352d

Contents?: true

Size: 952 Bytes

Versions: 1

Compression:

Stored size: 952 Bytes

Contents

import numbers

from torch.utils.data import Dataset, DataLoader

class TSVDataset(Dataset):
    """Expose the rows of a table as ``(features, label)`` samples.

    The wrapped ``tsv`` object must support ``.iloc``, ``.loc`` and ``len()``
    (presumably a pandas ``DataFrame`` -- confirm against ``rbbt.tsv``).
    For every row, the last column is returned as the label and all
    preceding columns as the features.
    """

    def __init__(self, tsv):
        self.tsv = tsv

    def __getitem__(self, key):
        """Return the ``(features, label)`` pair for a single row.

        Integer keys are treated as row positions (``.iloc``); any other
        key is looked up by index label (``.loc``).
        """
        # Accept every integer-like scalar (e.g. numpy.int64 produced by an
        # index permutation), not only the builtin ``int``. ``bool`` is an
        # ``int`` subclass but is deliberately kept on the label path, which
        # is where the previous exact-type check sent it.
        positional = isinstance(key, numbers.Integral) and not isinstance(key, bool)
        row = self.tsv.iloc[key] if positional else self.tsv.loc[key]

        values = row.to_numpy()
        return values[:-1], values[-1]

    def __len__(self):
        """Return the number of rows in the wrapped table."""
        return len(self.tsv)

def tsv_dataset(filename, *args, **kwargs):
    """Load ``filename`` with ``rbbt.tsv`` and wrap the result in a TSVDataset.

    Extra positional and keyword arguments are forwarded unchanged to
    ``rbbt.tsv``.
    """
    # Imported lazily so that importing this module does not require rbbt.
    import rbbt

    table = rbbt.tsv(filename, *args, **kwargs)
    return TSVDataset(table)

def tsv(*args, **kwargs):
    """Shorthand for ``tsv_dataset``; all arguments are passed through."""
    dataset = tsv_dataset(*args, **kwargs)
    return dataset

def data_dir():
    """Return ``rbbt.path('var/rbbt_dm/data')``, the rbbt_dm data location."""
    # Imported lazily so that importing this module does not require rbbt.
    import rbbt

    location = rbbt.path('var/rbbt_dm/data')
    return location

if __name__ == "__main__":
    # Manual smoke test: load a TSV file as a dataset and print every sample.
    import sys

    # The file may be given as the first command-line argument; the
    # original hard-coded path remains the default for compatibility.
    filename = sys.argv[1] if len(sys.argv) > 1 else "/home/miki/test/numeric.tsv"
    ds = tsv(filename)

    dl = DataLoader(ds, batch_size=1)

    # A DataLoader is directly iterable; each batch holds a single sample
    # because batch_size=1, hence the [0] indexing below.
    for f, l in dl:
        print(".")
        print(f[0, :])
        print(l[0])




Version data entries

1 entries across 1 versions & 1 rubygems

Version Path
rbbt-dm-1.2.9 python/rbbt_dm/__init__.py