Skip to content

Commit

Permalink
feat(framework) Enable federation configuration overrides via `--fede…
Browse files Browse the repository at this point in the history
…ration-config` in `flwr` CLI (#4841)

Co-authored-by: Javier <[email protected]>
Co-authored-by: Chong Shen Ng <[email protected]>
  • Loading branch information
3 people authored Jan 23, 2025
1 parent 2d54600 commit cf3eb04
Show file tree
Hide file tree
Showing 10 changed files with 190 additions and 28 deletions.
17 changes: 15 additions & 2 deletions src/py/flwr/cli/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
import tomli
import typer

from flwr.common.config import get_fab_config, get_metadata_from_config, validate_config
from flwr.common.config import (
fuse_dicts,
get_fab_config,
get_metadata_from_config,
parse_config_args,
validate_config,
)


def get_fab_metadata(fab_file: Union[Path, bytes]) -> tuple[str, str]:
Expand Down Expand Up @@ -127,7 +133,9 @@ def process_loaded_project_config(


def validate_federation_in_project_config(
federation: Optional[str], config: dict[str, Any]
federation: Optional[str],
config: dict[str, Any],
overrides: Optional[list[str]] = None,
) -> tuple[str, dict[str, Any]]:
"""Validate the federation name in the Flower project configuration."""
federation = federation or config["tool"]["flwr"]["federations"].get("default")
Expand Down Expand Up @@ -157,6 +165,11 @@ def validate_federation_in_project_config(
)
raise typer.Exit(code=1)

# Override the federation configuration if provided
if overrides:
overrides_dict = parse_config_args(overrides, flatten=False)
federation_config = fuse_dicts(federation_config, overrides_dict)

return federation, federation_config


Expand Down
45 changes: 45 additions & 0 deletions src/py/flwr/cli/config_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,51 @@ def test_validate_federation_in_project_config() -> None:
assert federation_config == {"new_key": "new_val"}


def test_validate_federation_in_project_config_with_overrides() -> None:
"""Test that validate_federation_in_config works with overrides."""
# Prepare - Test federation is None
federation_config = {"k1": "v1", "k2": True, "grp": {"k3": 42.8}}
config: dict[str, Any] = {
"project": {
"name": "fedgpt",
"version": "1.0.0",
"description": "",
"license": "",
"authors": [],
},
"tool": {
"flwr": {
"app": {
"publisher": "flwrlabs",
"components": {
"serverapp": "flwr.cli.run:run",
"clientapp": "flwr.cli.run:run",
},
},
"federations": {
"default": "default_federation",
"default_federation": federation_config,
},
},
},
}
overrides = ["k1=false grp.k3=42.9", "k2='hello, world!'"]
federation = None

# Execute
federation, federation_config = validate_federation_in_project_config(
federation, config, overrides
)

# Assert
assert federation == "default_federation"
assert federation_config == {
"k1": False,
"k2": "hello, world!",
"grp": {"k3": 42.9},
}


def test_validate_federation_in_project_config_fail() -> None:
"""Test that validate_federation_in_config fails correctly."""

Expand Down
27 changes: 27 additions & 0 deletions src/py/flwr/cli/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2025 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constants for CLI commands."""


# The help message for `--federation-config` option
FEDERATION_CONFIG_HELP_MESSAGE = (
"Override federation configuration values in the format:\n\n"
"`--federation-config 'key1=value1 key2=value2' --federation-config "
"'key3=value3'`\n\nValues can be of any type supported in TOML, such as "
"bool, int, float, or string. Ensure that the keys (`key1`, `key2`, `key3` "
"in this example) exist in the federation configuration under the "
"`[tool.flwr.federations.<YOUR_FEDERATION>]` table of the `pyproject.toml` "
"for proper overriding."
)
19 changes: 17 additions & 2 deletions src/py/flwr/cli/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
process_loaded_project_config,
validate_federation_in_project_config,
)
from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
from flwr.common.constant import CONN_RECONNECT_INTERVAL, CONN_REFRESH_PERIOD
from flwr.common.logger import log as logger
from flwr.proto.exec_pb2 import StreamLogsRequest # pylint: disable=E0611
Expand Down Expand Up @@ -57,6 +58,8 @@ def start_stream(
logger(ERROR, "Invalid run_id `%s`, exiting", run_id)
if e.code() == grpc.StatusCode.CANCELLED:
pass
else:
raise e
finally:
channel.close()

Expand Down Expand Up @@ -123,6 +126,7 @@ def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None:
break
if e.code() == grpc.StatusCode.CANCELLED:
break
raise e
except KeyboardInterrupt:
logger(DEBUG, "Stream interrupted by user")
finally:
Expand All @@ -143,6 +147,13 @@ def log(
Optional[str],
typer.Argument(help="Name of the federation to run the app on"),
] = None,
federation_config_overrides: Annotated[
Optional[list[str]],
typer.Option(
"--federation-config",
help=FEDERATION_CONFIG_HELP_MESSAGE,
),
] = None,
stream: Annotated[
bool,
typer.Option(
Expand All @@ -158,11 +169,15 @@ def log(
config, errors, warnings = load_and_validate(path=pyproject_path)
config = process_loaded_project_config(config, errors, warnings)
federation, federation_config = validate_federation_in_project_config(
federation, config
federation, config, federation_config_overrides
)
exit_if_no_address(federation_config, "log")

_log_with_exec_api(app, federation, federation_config, run_id, stream)
try:
_log_with_exec_api(app, federation, federation_config, run_id, stream)
except Exception as err: # pylint: disable=broad-except
typer.secho(str(err), fg=typer.colors.RED, bold=True)
raise typer.Exit(code=1) from None


def _log_with_exec_api(
Expand Down
10 changes: 9 additions & 1 deletion src/py/flwr/cli/login/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
process_loaded_project_config,
validate_federation_in_project_config,
)
from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
from flwr.common.typing import UserAuthLoginDetails
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
GetLoginDetailsRequest,
Expand All @@ -45,6 +46,13 @@ def login( # pylint: disable=R0914
Optional[str],
typer.Argument(help="Name of the federation to login into."),
] = None,
federation_config_overrides: Annotated[
Optional[list[str]],
typer.Option(
"--federation-config",
help=FEDERATION_CONFIG_HELP_MESSAGE,
),
] = None,
) -> None:
"""Login to Flower SuperLink."""
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)
Expand All @@ -54,7 +62,7 @@ def login( # pylint: disable=R0914

config = process_loaded_project_config(config, errors, warnings)
federation, federation_config = validate_federation_in_project_config(
federation, config
federation, config, federation_config_overrides
)
exit_if_no_address(federation_config, "login")
channel = init_channel(app, federation_config, None)
Expand Down
12 changes: 10 additions & 2 deletions src/py/flwr/cli/ls.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
process_loaded_project_config,
validate_federation_in_project_config,
)
from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
from flwr.common.constant import FAB_CONFIG_FILE, CliOutputFormat, SubStatus
from flwr.common.date import format_timedelta, isoformat8601_utc
from flwr.common.logger import print_json_error, redirect_output, restore_output
Expand All @@ -48,7 +49,7 @@
_RunListType = tuple[int, str, str, str, str, str, str, str, str]


def ls( # pylint: disable=too-many-locals, too-many-branches
def ls( # pylint: disable=too-many-locals, too-many-branches, R0913, R0917
app: Annotated[
Path,
typer.Argument(help="Path of the Flower project"),
Expand All @@ -57,6 +58,13 @@ def ls( # pylint: disable=too-many-locals, too-many-branches
Optional[str],
typer.Argument(help="Name of the federation"),
] = None,
federation_config_overrides: Annotated[
Optional[list[str]],
typer.Option(
"--federation-config",
help=FEDERATION_CONFIG_HELP_MESSAGE,
),
] = None,
runs: Annotated[
bool,
typer.Option(
Expand Down Expand Up @@ -106,7 +114,7 @@ def ls( # pylint: disable=too-many-locals, too-many-branches
config, errors, warnings = load_and_validate(path=pyproject_path)
config = process_loaded_project_config(config, errors, warnings)
federation, federation_config = validate_federation_in_project_config(
federation, config
federation, config, federation_config_overrides
)
exit_if_no_address(federation_config, "ls")

Expand Down
30 changes: 20 additions & 10 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
process_loaded_project_config,
validate_federation_in_project_config,
)
from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
from flwr.common.config import (
flatten_dict,
parse_config_args,
Expand All @@ -57,7 +58,7 @@
CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)


# pylint: disable-next=too-many-locals
# pylint: disable-next=too-many-locals, R0913, R0917
def run(
app: Annotated[
Path,
Expand All @@ -67,16 +68,23 @@ def run(
Optional[str],
typer.Argument(help="Name of the federation to run the app on."),
] = None,
config_overrides: Annotated[
run_config_overrides: Annotated[
Optional[list[str]],
typer.Option(
"--run-config",
"-c",
help="Override configuration key-value pairs, should be of the format:\n\n"
'`--run-config \'key1="value1" key2="value2"\' '
"--run-config 'key3=\"value3\"'`\n\n"
"Note that `key1`, `key2`, and `key3` in this example need to exist "
"inside the `pyproject.toml` in order to be properly overriden.",
help="Override run configuration values in the format:\n\n"
"`--run-config 'key1=value1 key2=value2' --run-config 'key3=value3'`\n\n"
"Values can be of any type supported in TOML, such as bool, int, "
"float, or string. Ensure that the keys (`key1`, `key2`, `key3` "
"in this example) exist in `pyproject.toml` for proper overriding.",
),
] = None,
federation_config_overrides: Annotated[
Optional[list[str]],
typer.Option(
"--federation-config",
help=FEDERATION_CONFIG_HELP_MESSAGE,
),
] = None,
stream: Annotated[
Expand Down Expand Up @@ -108,20 +116,22 @@ def run(
config, errors, warnings = load_and_validate(path=pyproject_path)
config = process_loaded_project_config(config, errors, warnings)
federation, federation_config = validate_federation_in_project_config(
federation, config
federation, config, federation_config_overrides
)

if "address" in federation_config:
_run_with_exec_api(
app,
federation,
federation_config,
config_overrides,
run_config_overrides,
stream,
output_format,
)
else:
_run_without_exec_api(app, federation_config, config_overrides, federation)
_run_without_exec_api(
app, federation_config, run_config_overrides, federation
)
except (typer.Exit, Exception) as err: # pylint: disable=broad-except
if suppress_output:
restore_output()
Expand Down
10 changes: 9 additions & 1 deletion src/py/flwr/cli/stop.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
process_loaded_project_config,
validate_federation_in_project_config,
)
from flwr.cli.constant import FEDERATION_CONFIG_HELP_MESSAGE
from flwr.common.constant import FAB_CONFIG_FILE, CliOutputFormat
from flwr.common.logger import print_json_error, redirect_output, restore_output
from flwr.proto.exec_pb2 import StopRunRequest, StopRunResponse # pylint: disable=E0611
Expand All @@ -50,6 +51,13 @@ def stop( # pylint: disable=R0914
Optional[str],
typer.Argument(help="Name of the federation"),
] = None,
federation_config_overrides: Annotated[
Optional[list[str]],
typer.Option(
"--federation-config",
help=FEDERATION_CONFIG_HELP_MESSAGE,
),
] = None,
output_format: Annotated[
str,
typer.Option(
Expand All @@ -73,7 +81,7 @@ def stop( # pylint: disable=R0914
config, errors, warnings = load_and_validate(path=pyproject_path)
config = process_loaded_project_config(config, errors, warnings)
federation, federation_config = validate_federation_in_project_config(
federation, config
federation, config, federation_config_overrides
)
exit_if_no_address(federation_config, "stop")

Expand Down
Loading

0 comments on commit cf3eb04

Please sign in to comment.