train.py
import pathlib
from collections import OrderedDict
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import pyarrow.parquet as pq
from torchdata import dataloader2
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
import wandb
from pipeline_frame import load_data_framewise


# Define the autoencoder model with separate encoder and decoder
class Autoencoder(nn.Module):
    def __init__(self, input_size, encoding_dim):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, encoding_dim),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(encoding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, input_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
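
# A minimal shape sanity check (sketch, not executed by this script): the
# autoencoder reconstructs flattened frames, so input and output shapes match.
#
#     model = Autoencoder(input_size=3 * 543, encoding_dim=128)
#     x = torch.rand(4, 3 * 543)           # hypothetical batch of 4 frames
#     assert model(x).shape == x.shape     # decoder restores the input size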


def filter_state_dict_by_prefix(state_dict: OrderedDict[str, torch.Tensor], prefix: str) -> OrderedDict[str, torch.Tensor]:
    """
    Filters the given PyTorch state_dict to only keep keys that start with the specified prefix.

    Args:
        state_dict (OrderedDict[str, torch.Tensor]): The state dict of the PyTorch model, where keys are strings and values are PyTorch tensors.
        prefix (str): The prefix to filter keys by.

    Returns:
        OrderedDict[str, torch.Tensor]: The filtered state dict containing only keys starting with the given prefix.
    """
    filtered_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            filtered_state_dict[key] = value
    return filtered_state_dict
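
# Example usage (sketch): keep only the encoder weights from a full state dict.
#
#     full_state = Autoencoder(3 * 543, 128).state_dict()
#     encoder_state = filter_state_dict_by_prefix(full_state, 'encoder.')
#     # encoder_state now holds keys such as 'encoder.0.weight' and 'encoder.0.bias'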


def count_frames_in_csv(csv_file_path: str) -> int:
    """
    Count the total number of frames in the Parquet files specified in the CSV file.

    Args:
        csv_file_path (str or pathlib.Path): The path to the CSV file containing a 'path' column with Parquet file paths.

    Returns:
        int: The total number of frames calculated from the Parquet files.
    """
    data_path = pathlib.Path(csv_file_path).parent
    # Read the CSV file
    train_df = pd.read_csv(csv_file_path)
    num_frames = 0
    for _, row in tqdm(train_df.iterrows(), total=len(train_df)):
        file_meta = pq.read_metadata(data_path / row['path'])
        # Each frame is stored as 543 landmark rows in the Parquet file
        num_frames += file_meta.num_rows // 543
    return num_frames


def process_file(row_data):
    # Helper for the multiprocessing pool: count the frames in a single Parquet file
    file_path, num_frames_per_row = row_data
    file_meta = pq.read_metadata(file_path)
    return file_meta.num_rows // num_frames_per_row


def count_frames_in_csv_parallel(csv_file_path: str) -> int:
    """
    Count the total number of frames in the Parquet files specified in the CSV file using parallel processing.

    Args:
        csv_file_path (str or pathlib.Path): The path to the CSV file containing a 'path' column with Parquet file paths.

    Returns:
        int: The total number of frames calculated from the Parquet files.
    """
    data_path = pathlib.Path(csv_file_path).parent
    # Read the CSV file
    train_df = pd.read_csv(csv_file_path)
    num_frames_per_row = 543
    num_workers = min(cpu_count(), len(train_df))
    with Pool(num_workers) as pool:
        row_data = [(data_path / row['path'], num_frames_per_row) for _, row in train_df.iterrows()]
        results = list(tqdm(pool.imap_unordered(process_file, row_data), total=len(train_df), desc="Processing", unit="file"))
    return sum(results)


if __name__ == "__main__":
    wandb.init(project="ssl-signing")

    # Set the number of features and encoding dimension
    input_size = 3 * 543
    encoding_dim = 128

    # Create an instance of the autoencoder model
    model = Autoencoder(input_size, encoding_dim)

    # Move the model to the GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Define the loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Define the training dataset and dataloader (modify as per your data)
    base_path = pathlib.Path(__file__).parent
    data_path = base_path.parent / "effective-octo-potato" / "data"
    train_csv = data_path / "train.csv"
    batch_size = 1024
    train_pipe = load_data_framewise(csv_file=train_csv, data_path=data_path, batch_size=batch_size)
    multi_processor = dataloader2.MultiProcessingReadingService(num_workers=12)
    train_loader = dataloader2.DataLoader2(train_pipe, reading_service=multi_processor)
    num_frames = count_frames_in_csv_parallel(train_csv)

    wandb.watch(model, log="all")

    # Train the autoencoder
    num_epochs = 10
    model.train()
    for epoch in range(num_epochs):
        with tqdm(desc="Processing", unit="iter", position=0, leave=True, total=num_frames // batch_size) as pbar:
            rolling_loss = None
            for data in train_loader:
                img = data
                # Move the input data to the GPU
                img = img.to(device)

                # Forward pass
                output = model(img)
                loss = criterion(output, img)

                # Backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                wandb.log({"loss": loss.item()})
                if rolling_loss is None:
                    rolling_loss = loss.item()
                else:
                    rolling_loss = 0.9 * rolling_loss + 0.1 * loss.item()
                pbar.set_postfix({"loss": rolling_loss})
                pbar.update(1)

    # Save the encoder weights
    model_state_dict = model.state_dict()
    torch.save(filter_state_dict_by_prefix(model_state_dict, 'encoder.'), 'encoder_weights.pth')
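
# To reuse the saved weights later (sketch, assuming the same Autoencoder class):
# the saved keys keep their 'encoder.' prefix, so they load into a full
# Autoencoder with strict=False, leaving the decoder randomly initialised.
#
#     model = Autoencoder(3 * 543, 128)
#     state = torch.load('encoder_weights.pth', map_location='cpu')
#     model.load_state_dict(state, strict=False)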