Skip to content

Commit 1772d4e

Browse files
authored
Merge pull request PrefectHQ#195 from PrefectHQ/user-config
Add support for user configuration files
2 parents 0e9a12c + 25384c3 commit 1772d4e

File tree

3 files changed

+85
-39
lines changed

3 files changed

+85
-39
lines changed

src/prefect/config.toml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,9 @@
1+
[general]
2+
3+
# the location of the user's config file
4+
user_config_path = "$HOME/.prefect/config.toml"
5+
6+
17
[server]
28
# the Prefect Server address
39
api_server = "http://127.0.0.1:4200"

src/prefect/configuration.py

Lines changed: 38 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,9 @@
66

77
import toml
88

9-
import prefect
109
from prefect.utilities import collections
1110

1211
DEFAULT_CONFIG = os.path.join(os.path.dirname(__file__), "config.toml")
13-
USER_CONFIG = "~/.prefect/config.toml"
1412
ENV_VAR_PREFIX = "PREFECT"
1513
INTERPOLATION_REGEX = re.compile(r"\${(.[^${}]*)}")
1614

@@ -74,7 +72,7 @@ def validate_config(config: Config) -> None:
7472
# Load configuration ----------------------------------------------------------
7573

7674

77-
def load_config_file(path: str, env_var_prefix: str = ENV_VAR_PREFIX) -> Config:
75+
def load_config_file(path: str, env_var_prefix: str = None) -> Config:
7876
"""
7977
Loads a configuration file from a path, optionally merging it into an existing
8078
configuration.
@@ -90,19 +88,22 @@ def load_config_file(path: str, env_var_prefix: str = ENV_VAR_PREFIX) -> Config:
9088
# --------------------- Interpolate env vars -----------------------
9189
# check if any env var sets a configuration value with the format:
9290
# [ENV_VAR_PREFIX]__[Section]__[Optional Sub-Sections...]__[Key] = Value
93-
for env_var in os.environ:
94-
if env_var.startswith(env_var_prefix + "__"):
91+
if env_var_prefix:
92+
for env_var in os.environ:
93+
if env_var.startswith(env_var_prefix + "__"):
9594

96-
# strip the prefix off the env var
97-
env_var_option = env_var[len(env_var_prefix + "__") :]
95+
# strip the prefix off the env var
96+
env_var_option = env_var[len(env_var_prefix + "__") :]
9897

99-
# make sure the resulting env var has at least one delimitied section and key
100-
if "__" not in env_var:
101-
continue
98+
# make sure the resulting env var has at least one delimitied section and key
99+
if "__" not in env_var:
100+
continue
102101

103-
# place the env var in the flat config as a compound key
104-
config_option = collections.CompoundKey(env_var_option.lower().split("__"))
105-
flat_config[config_option] = interpolate_env_var(os.getenv(env_var))
102+
# place the env var in the flat config as a compound key
103+
config_option = collections.CompoundKey(
104+
env_var_option.lower().split("__")
105+
)
106+
flat_config[config_option] = interpolate_env_var(os.getenv(env_var))
106107

107108
# interpolate any env vars referenced
108109
for k, v in list(flat_config.items()):
@@ -153,28 +154,39 @@ def load_config_file(path: str, env_var_prefix: str = ENV_VAR_PREFIX) -> Config:
153154

154155

155156
def load_configuration(
156-
default_config_path: str, user_config_path: str, env_var_prefix: str = None
157+
config_path: str, env_var_prefix: str = None, merge_into_config: Config = None
157158
) -> Config:
159+
"""
160+
Given a `config_path` with a toml configuration file, returns a Config object.
161+
162+
Args:
163+
- config_path (str): the path to the toml configuration file
164+
- env_var_prefix (str): if provided, environment variables starting with this prefix
165+
will be added as configuration settings.
166+
- merge_into_config (Config): if provided, the configuration loaded from
167+
`config_path` will be merged into a copy of this configuration file. The merged
168+
Config is returned.
169+
"""
158170

159171
# load default config
160-
config = load_config_file(default_config_path, env_var_prefix=env_var_prefix or "")
172+
config = load_config_file(config_path, env_var_prefix=env_var_prefix or "")
161173

162-
# if user config exists, load and merge it with default config
163-
if os.path.isfile(user_config_path):
164-
user_config = load_config_file(
165-
user_config_path, env_var_prefix=env_var_prefix or ""
166-
)
167-
config = collections.merge_dicts(config, user_config)
174+
if merge_into_config is not None:
175+
config = collections.merge_dicts(merge_into_config, config)
168176

169177
validate_config(config)
170178

171179
return config
172180

173181

174-
config = load_configuration(
175-
default_config_path=DEFAULT_CONFIG,
176-
user_config_path=USER_CONFIG,
177-
env_var_prefix=ENV_VAR_PREFIX,
178-
)
182+
config = load_configuration(config_path=DEFAULT_CONFIG, env_var_prefix=ENV_VAR_PREFIX)
183+
184+
# if user config exists, load and merge it with default config
185+
if os.path.isfile(config.get("general", {}).get("user_config_path", "")):
186+
config = load_configuration(
187+
config_path=config.general.user_config_path,
188+
env_var_prefix=ENV_VAR_PREFIX,
189+
merge_into_config=config,
190+
)
179191

180192
configure_logging(logger_name="Prefect")

tests/test_configuration.py

Lines changed: 41 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -78,22 +78,50 @@ def test_env_var_interpolation(config):
7878

7979

8080
def test_env_var_overrides(test_config_file_path):
81-
os.environ["PREFECT__ENV_VARS__TEST"] = "OVERRIDE!"
82-
assert config(test_config_file_path).env_vars.test == "OVERRIDE!"
81+
try:
82+
ev = "PREFECT__ENV_VARS__TEST"
83+
os.environ[ev] = "OVERRIDE!"
84+
config = configuration.load_config_file(
85+
test_config_file_path, env_var_prefix="PREFECT"
86+
)
87+
assert config.env_vars.test == "OVERRIDE!"
88+
finally:
89+
del os.environ[ev]
8390

8491

85-
def test_load_user_config_and_update_default(test_config_file_path):
92+
def test_merge_configurations(test_config_file_path):
93+
94+
default_config = configuration.config
95+
96+
assert default_config.logging.format != "log-format"
97+
assert default_config.flows.default_version == "1"
98+
8699
config = configuration.load_configuration(
87-
default_config_path=configuration.DEFAULT_CONFIG,
88-
user_config_path=test_config_file_path,
100+
config_path=test_config_file_path, merge_into_config=default_config
89101
)
90-
assert "logging" in config
91102

92-
# this comes from default
93-
assert config.logging.level == "INFO"
94-
# this comes from user
95103
assert config.logging.format == "log-format"
96-
97-
# this comes from user
98-
assert "general" in config
99-
assert config.general.nested.x == 1
104+
assert config.flows.default_version == "1"
105+
assert config.interpolation.value == 1
106+
107+
108+
def test_load_user_config(test_config_file_path):
109+
110+
with tempfile.NamedTemporaryFile() as user_config:
111+
user_config.write(
112+
b"""
113+
[general]
114+
x = 2
115+
116+
[user]
117+
foo = "bar"
118+
"""
119+
)
120+
user_config.seek(0)
121+
122+
test_config = configuration.load_configuration(test_config_file_path)
123+
config = configuration.load_configuration(
124+
user_config.name, merge_into_config=test_config
125+
)
126+
assert config.general.x == 2
127+
assert config.user.foo == "bar"

0 commit comments

Comments
 (0)