Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
GuyPerets106 committed Jul 21, 2024
1 parent e0ad3ee commit f693df2
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 8 deletions.
4 changes: 0 additions & 4 deletions src_erl/NerlnetApp/src/Bridge/onnWorkers/workerGeneric.erl
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,11 @@ idle(cast, _Param, State = #workerGeneric_state{myName = _MyName}) ->
%% Waiting to receive results or the loss value
%% Got nan or inf from the loss function - error: the loss value is too large for a double
wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, distributedBehaviorFunc = DistributedBehaviorFunc}) ->
io:format("~p got loss tensor nan for batch ~p~n", [MyName, BatchID]),
stats:increment_by_value(get(worker_stats_ets), nan_loss_count, 1),
WorkerToken = ets:lookup_element(get(generic_worker_ets), distributed_system_token, ?ETS_KEYVAL_VAL_IDX),
gen_statem:cast(get(client_pid),{loss, MyName , SourceName ,nan , TrainTime, WorkerToken ,BatchID}),
NextStateBehavior = DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients
EndStreamWaitingList = ets:lookup_element(get(generic_worker_ets), end_streams_waiting_list, ?ETS_KEYVAL_VAL_IDX),
io:format("EndStreamWaitingList: ~p~n",[EndStreamWaitingList]),
case length(EndStreamWaitingList) of
0 -> ok;
_ ->
Expand All @@ -193,13 +191,11 @@ wait(cast, {loss, nan , TrainTime , BatchID , SourceName}, State = #workerGeneri


wait(cast, {loss, {LossTensor, LossTensorType} , TrainTime , BatchID , SourceName}, State = #workerGeneric_state{myName = MyName, modelID=_ModelID, distributedBehaviorFunc = DistributedBehaviorFunc}) ->
io:format("~p got loss tensor for batch ~p~n", [MyName, BatchID]),
BatchTimeStamp = erlang:system_time(nanosecond),
WorkerToken = ets:lookup_element(get(generic_worker_ets), distributed_system_token, ?ETS_KEYVAL_VAL_IDX),
gen_statem:cast(get(client_pid),{loss, MyName, SourceName ,{LossTensor, LossTensorType} , TrainTime , WorkerToken, BatchID , BatchTimeStamp}),
NextStateBehavior = DistributedBehaviorFunc(post_train, {get(generic_worker_ets),[]}), %% First call sends empty list , then it will be updated by the federated server and clients
EndStreamWaitingList = ets:lookup_element(get(generic_worker_ets), end_streams_waiting_list, ?ETS_KEYVAL_VAL_IDX),
io:format("EndStreamWaitingList: ~p~n",[EndStreamWaitingList]),
case length(EndStreamWaitingList) of
0 -> ok;
_ ->
Expand Down
4 changes: 0 additions & 4 deletions src_erl/NerlnetApp/src/Client/clientStatem.erl
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,6 @@ training(cast, In = {start_stream , Data}, State = #client_statem_state{etsRef =

training(cast, In = {end_stream , Data}, State = #client_statem_state{etsRef = EtsRef}) ->
{SourceName, _ClientName, WorkerName} = binary_to_term(Data),
% io:format("~p send end_stream to worker ~p~n",[SourceName, WorkerName]),
ClientStatsEts = get(client_stats_ets),
stats:increment_messages_received(ClientStatsEts),
stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)),
Expand All @@ -294,7 +293,6 @@ training(cast, In = {stream_ended , Pair}, State = #client_statem_state{etsRef =
ListOfActiveWorkersSources = ets:lookup_element(EtsRef, active_workers_streams, ?DATA_IDX),
UpdatedListOfActiveWorkersSources = ListOfActiveWorkersSources -- [Pair],
ets:update_element(EtsRef, active_workers_streams, {?DATA_IDX, UpdatedListOfActiveWorkersSources}),
io:format("Updated List of Active Workers ~p~n",[UpdatedListOfActiveWorkersSources]),
case length(UpdatedListOfActiveWorkersSources) of
0 -> ets:update_element(EtsRef, all_workers_done, {?DATA_IDX, true});
_ -> ok
Expand All @@ -308,7 +306,6 @@ training(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef
stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)),
MessageToCast = {idle},
WorkersDone = ets:lookup_element(EtsRef , all_workers_done , ?DATA_IDX),
% io:format("Client ~p Workers Done? ~p~n",[MyName, WorkersDone]),
case WorkersDone of
true -> cast_message_to_workers(EtsRef, MessageToCast),
Workers = clientWorkersFunctions:get_workers_names(EtsRef),
Expand Down Expand Up @@ -408,7 +405,6 @@ predict(cast, In = {idle}, State = #client_statem_state{myName = MyName, etsRef
stats:increment_bytes_received(ClientStatsEts , nerl_tools:calculate_size(In)),
MessageToCast = {idle},
WorkersDone = ets:lookup_element(EtsRef , all_workers_done , ?DATA_IDX),
% io:format("Client ~p Workers Done? ~p~n",[MyName, WorkersDone]),
case WorkersDone of
true -> cast_message_to_workers(EtsRef, MessageToCast),
Workers = clientWorkersFunctions:get_workers_names(EtsRef),
Expand Down

0 comments on commit f693df2

Please sign in to comment.