Skip to content

Commit

Permalink
Merge pull request #379 from leondavi/multiple_sources_stream_fix
Browse files Browse the repository at this point in the history
Multiple_sources_stream_fix
  • Loading branch information
leondavi authored Jul 22, 2024
2 parents e46270b + f693df2 commit 4f3bd0c
Show file tree
Hide file tree
Showing 9 changed files with 358 additions and 58 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"connectionsMap":
{
"r1":["mainServer", "c1", "r2"],
"r2":["c2", "s1", "r3"],
"r3":["c3", "r4", "r1"],
"r4":["s2", "r5", "r2"],
"r5":["s3", "r1"]
}
}
174 changes: 174 additions & 0 deletions inputJsonsFiles/DistributedConfig/dc_EEG_8d_3c_3s_5r_3w_RR.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
{
"nerlnetSettings": {
"frequency": "100",
"batchSize": "5"
},
"mainServer": {
"port": "8900",
"args": ""
},
"apiServer": {
"port": "8901",
"args": ""
},
"devices": [
{
"name": "c0VM5",
"ipv4": "10.0.0.11",
"entities": "apiServer,mainServer"
},
{
"name": "c0VM6",
"ipv4": "10.0.0.8",
"entities": "c1,r3"
},
{
"name": "c0VM7",
"ipv4": "10.0.0.12",
"entities": "c2,r4"
},
{
"name": "c0VM4",
"ipv4": "10.0.0.10",
"entities": "c3"
},
{
"name": "nerlSpilke0",
"ipv4": "10.0.0.32",
"entities": "s2,r5"
},
{
"name": "nerlSpilke1",
"ipv4": "10.0.0.33",
"entities": "s3"
},
{
"name": "nerlSpilke2",
"ipv4": "10.0.0.34",
"entities": "s1,r2"
},
{
"name": "nerlSpilke3",
"ipv4": "10.0.0.35",
"entities": "r1"
}
],
"routers": [
{
"name": "r1",
"port": "8902",
"policy": "0"
},
{
"name": "r2",
"port": "8903",
"policy": "0"
},
{
"name": "r3",
"port": "8910",
"policy": "0"
},
{
"name": "r4",
"port": "8911",
"policy": "0"
},
{
"name": "r5",
"port": "8912",
"policy": "0"
}
],
"sources": [
{
"name": "s1",
"port": "8904",
"frequency": "100",
"policy": "1",
"epochs": "1",
"type": "0"
},
{
"name": "s2",
"port": "8905",
"frequency": "100",
"policy": "1",
"epochs": "1",
"type": "0"
},
{
"name": "s3",
"port": "8906",
"frequency": "100",
"policy": "1",
"epochs": "1",
"type": "0"
}
],
"clients": [
{
"name": "c1",
"port": "8907",
"workers": "w1"
},
{
"name": "c2",
"port": "8908",
"workers": "w2"
},
{
"name": "c3",
"port": "8909",
"workers": "w3"
}
],
"workers": [
{
"name": "w1",
"model_sha": "d8df752e0a2e8f01de8f66e9cec941cdbc65d144ecf90ab7713e69d65e7e82aa"
},
{
"name": "w2",
"model_sha": "d8df752e0a2e8f01de8f66e9cec941cdbc65d144ecf90ab7713e69d65e7e82aa"
},
{
"name": "w3",
"model_sha": "d8df752e0a2e8f01de8f66e9cec941cdbc65d144ecf90ab7713e69d65e7e82aa"
}
],
"model_sha": {
"d8df752e0a2e8f01de8f66e9cec941cdbc65d144ecf90ab7713e69d65e7e82aa": {
"modelType": "0",
"_doc_modelType": " nn:0 | approximation:1 | classification:2 | forecasting:3 | image-classification:4 | text-classification:5 | text-generation:6 | auto-association:7 | autoencoder:8 | ae-classifier:9 |",
"modelArgs": "",
"layersSizes": "70x1x1k5x1x1x128p0s1t0,66x1x128k2x1p0s1,65x1x128k5x1x128x128p0s1t0,61x1x128k2x1p0s1,60x1x128k5x1x128x64p0s1t0,1,64,32,16,9",
"_doc_layersSizes": "List of postive integers [L0, L1, ..., LN]",
"layerTypesList": "2,4,2,4,2,9,3,3,3,3",
"_doc_LayerTypes": " Default:0 | Scaling:1 | CNN:2 | Perceptron:3 | Pooling:4 | Probabilistic:5 | LSTM:6 | Reccurrent:7 | Unscaling:8 | Flatten:9 | Bounding:10 |",
"layers_functions": "11,2,11,2,11,1,6,6,6,11",
"_doc_layers_functions_activation": " Threshold:1 | Sign:2 | Logistic:3 | Tanh:4 | Linear:5 | ReLU:6 | eLU:7 | SeLU:8 | Soft-plus:9 | Soft-sign:10 | Hard-sigmoid:11 |",
"_doc_layer_functions_pooling": " none:1 | Max:2 | Avg:3 |",
"_doc_layer_functions_probabilistic": " Binary:1 | Logistic:2 | Competitive:3 | Softmax:4 |",
"_doc_layer_functions_scaler": " none:1 | MinMax:2 | MeanStd:3 | STD:4 | Log:5 |",
"lossMethod": "2",
"_doc_lossMethod": " SSE:1 | MSE:2 | NSE:3 | MinkowskiE:4 | WSE:5 | CEE:6 |",
"lr": "0.00001",
"_doc_lr": "Positve float",
"epochs": "1",
"_doc_epochs": "Positve Integer",
"optimizer": "5",
"_doc_optimizer": " GD:0 | CGD:1 | SGD:2 | QuasiNeuton:3 | LVM:4 | ADAM:5 |",
"optimizerArgs": "",
"_doc_optimizerArgs": "String",
"infraType": "0",
"_doc_infraType": " opennn:0 | wolfengine:1 |",
"distributedSystemType": "0",
"_doc_distributedSystemType": " none:0 | fedClientAvg:1 | fedServerAvg:2 |",
"distributedSystemArgs": "",
"_doc_distributedSystemArgs": "String",
"distributedSystemToken": "none",
"_doc_distributedSystemToken": "Token that associates distributed group of workers and parameter-server"
}
}
}
71 changes: 71 additions & 0 deletions inputJsonsFiles/experimentsFlow/exp_EEG_3s_3w_half3_people_RR.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
{
"experimentName": "EEG_Valence_Recognition_DEAP",
"experimentType": "classification",
"batchSize": 5,
"csvFilePath": "/home/nerlnet/workspace/1_3_person_valence_nerlnet.csv",
"numOfFeatures": "70",
"numOfLabels": "9",
"headersNames": "1,2,3,4,5,6,7,8,9",
"Phases":

[
{
"phaseName": "training_phase",
"phaseType": "training",
"sourcePieces":
[
{
"sourceName": "s1",
"startingSample": "0",
"numOfBatches": "100",
"workers": "w1,w2,w3",
"nerltensorType": "float"
},
{
"sourceName": "s2",
"startingSample": "19520",
"numOfBatches": "100",
"workers": "w1,w2,w3",
"nerltensorType": "float"
},
{
"sourceName": "s3",
"startingSample": "39040",
"numOfBatches": "100",
"workers": "w1,w2,w3",
"nerltensorType": "float"
}
]
},
{
"phaseName": "prediction_phase",
"phaseType": "prediction",
"sourcePieces":
[
{
"sourceName": "s1",
"startingSample": "15610",
"numOfBatches": "200",
"workers": "w1,w2,w3",
"nerltensorType": "float"
},
{
"sourceName": "s2",
"startingSample": "35130",
"numOfBatches": "200",
"workers": "w1,w2,w3",
"nerltensorType": "float"
},
{
"sourceName": "s3",
"startingSample": "54650",
"numOfBatches": "200",
"workers": "w1,w2,w3",
"nerltensorType": "float"
}


]
}
]
}
1 change: 0 additions & 1 deletion src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.erl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ handle_call({?W2WCOM_ATOM, FromWorkerName, ThisWorkerName, {msg_with_event, Even
GenWorkerPid = get(gen_worker_pid),
case Event of
post_train_update -> gen_statem:cast(GenWorkerPid, {post_train_update, Data});
worker_done -> gen_statem:cast(GenWorkerPid, {worker_done, Data});
start_stream -> gen_statem:cast(GenWorkerPid, {start_stream, Data});
end_stream -> gen_statem:cast(GenWorkerPid, {end_stream, Data})
end,
Expand Down
2 changes: 1 addition & 1 deletion src_erl/NerlnetApp/src/Bridge/onnWorkers/w2wCom.hrl
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
-define(SYNC_INBOX_TIMEOUT, 30000). % 30 seconds
-define(SYNC_INBOX_TIMEOUT_NO_LIMIT, 36000000). % 36000 seconds = 10 hours , no limit
-define(DEFAULT_SYNC_INBOX_BUSY_WAITING_SLEEP, 5). % 5 milliseconds
-define(SUPPORTED_EVENTS , [post_train_update, worker_done, start_stream, end_stream]).
-define(SUPPORTED_EVENTS , [post_train_update, start_stream, end_stream]).
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ handshake(FedClientEts) ->
lists:foreach(Func, MessagesList).

start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName, State]
[_SourceName, ModelPhase] = WorkerData,
[SourceName, ModelPhase] = WorkerData,
FirstMsg = 1,
case ModelPhase of
train ->
Expand All @@ -96,14 +96,14 @@ start_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of
W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX),
ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX),
case length(ActiveStreams) of % Send to server an updater after got start_stream from the first source
FirstMsg -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, MyName); % Server gets FedWorkerName instead of SourceName
FirstMsg -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , start_stream, {MyName, SourceName}); % Server gets FedWorkerName instead of SourceName
_ -> ok
end;
predict -> ok
end.

end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [SourceName]
[_SourceName, ModelPhase] = WorkerData,
[SourceName, ModelPhase] = WorkerData,
case ModelPhase of
predict -> ok;
_ -> % train/wait
Expand All @@ -113,7 +113,8 @@ end_stream({GenWorkerEts, WorkerData}) -> % WorkerData is currently a list of [S
W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX),
ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX),
case length(ActiveStreams) of % Send to server an updater after got start_stream from the first source
0 -> w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, MyName); % Mimic source behavior
0 -> io:format("Worker ~p ending stream with ~p~n", [MyName, SourceName]),
w2wCom:send_message_with_event(W2WPid, MyName, ServerName , end_stream, {MyName, SourceName}); % Mimic source behavior
_ -> ok
end
end.
Expand Down
20 changes: 13 additions & 7 deletions src_erl/NerlnetApp/src/Bridge/onnWorkers/workerFederatedServer.erl
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,21 @@ init({GenWorkerEts, WorkerData}) ->


start_stream({GenWorkerEts, WorkerData}) ->
[FedWorkerName , _ModelPhase] = WorkerData,
[Pair , _ModelPhase] = WorkerData,
FedServerEts = get_this_server_ets(GenWorkerEts),
ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX),
MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX),
gen_server:cast(ClientPid, {start_stream, {worker, MyName, FedWorkerName}}).
gen_server:cast(ClientPid, {start_stream, {worker, MyName, Pair}}).

end_stream({GenWorkerEts, WorkerData}) -> % Federated server takes the control of popping the stream from the active streams list
[FedWorkerName , _ModelPhase] = WorkerData,
[Pair , _ModelPhase] = WorkerData,
FedServerEts = get_this_server_ets(GenWorkerEts),
MyName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX),
ClientPid = ets:lookup_element(GenWorkerEts, client_pid, ?ETS_KEYVAL_VAL_IDX),
gen_statem:cast(ClientPid, {worker_done, {MyName, FedWorkerName}}),
gen_statem:cast(ClientPid, {stream_ended, {MyName, Pair}}),
ActiveStreams = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX),
case ActiveStreams of
[] -> ets:update_element(FedServerEts, active_streams, {?ETS_KEYVAL_VAL_IDX, none});
[] -> ets:update_element(FedServerEts, active_streams, {?ETS_KEYVAL_VAL_IDX, []});
_ -> ok
end.

Expand Down Expand Up @@ -141,7 +141,8 @@ post_train({GenWorkerEts, WeightsTensor}) ->
CurrWorkersWeightsList = ets:lookup_element(FedServerEts, weights_list, ?ETS_KEYVAL_VAL_IDX),
{WorkerWeights, _BinaryType} = WeightsTensor,
TotalWorkersWeights = CurrWorkersWeightsList ++ [WorkerWeights],
NumOfActiveWorkers = length(ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX)),
ActiveWorkersSourcesList = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX),
NumOfActiveWorkers = length([FedWorker || {_MyName, {FedWorker, _Source}} <- ActiveWorkersSourcesList]),
case length(TotalWorkersWeights) of
NumOfActiveWorkers ->
ets:update_counter(FedServerEts, total_syncs, 1),
Expand All @@ -151,14 +152,19 @@ post_train({GenWorkerEts, WeightsTensor}) ->
{CurrentModelWeights, BinaryType} = nerlNIF:call_to_get_weights(ModelID),
FedServerName = ets:lookup_element(FedServerEts, my_name, ?ETS_KEYVAL_VAL_IDX),
AllWorkersWeightsList = TotalWorkersWeights ++ [CurrentModelWeights],
io:format("GOT HERE1~n"),
AvgWeightsNerlTensor = generate_avg_weights(AllWorkersWeightsList, BinaryType),
io:format("GOT HERE2~n"),
nerlNIF:call_to_set_weights(ModelID, AvgWeightsNerlTensor), %% update self weights to new model
io:format("GOT HERE3~n"),
Func = fun(FedClient) ->
FedServerName = ets:lookup_element(ThisEts, my_name, ?ETS_KEYVAL_VAL_IDX),
W2WPid = ets:lookup_element(ThisEts, w2wcom_pid, ?ETS_KEYVAL_VAL_IDX),
w2wCom:send_message_with_event(W2WPid, FedServerName, FedClient, post_train_update, {SyncIdx, AvgWeightsNerlTensor})
end,
WorkersList = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX),
WorkersSourcesList = ets:lookup_element(GenWorkerEts, active_streams, ?ETS_KEYVAL_VAL_IDX),
WorkersList = [FedWorker || {_MyName, {FedWorker, _Source}} <- WorkersSourcesList],
% io:format("Sending new weights to workers ~p~n",[WorkersList]),
lists:foreach(Func, WorkersList),
ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, []});
_ -> ets:update_element(FedServerEts, weights_list, {?ETS_KEYVAL_VAL_IDX, TotalWorkersWeights})
Expand Down
Loading

0 comments on commit 4f3bd0c

Please sign in to comment.