Remove merge_into and just have merged which copies inputs to avoid footguns

This commit is contained in:
Olivier Wilkinson (reivilibre) 2024-01-17 15:02:07 +00:00
parent 29541fd994
commit c91ab4bc55

View file

@ -55,6 +55,7 @@ import subprocess
import sys import sys
from argparse import ArgumentParser from argparse import ArgumentParser
from collections import defaultdict from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import chain from itertools import chain
from pathlib import Path from pathlib import Path
@ -321,37 +322,42 @@ def flush_buffers() -> None:
sys.stderr.flush() sys.stderr.flush()
def merged(a: Any, b: Any) -> Any:
    """
    Merges `a` and `b` together, returning the result.

    Neither `a` nor `b` is modified, and every mutable value in the result is
    a deep copy, so the result never aliases the inputs.

    The merge is performed with the following rules:

    - dicts: values with the same key will be merged recursively. The result
      keeps the keys of `a` in their original order, followed by the keys
      that only appear in `b`, so the output order is deterministic.
    - lists: the result is a copy of `a` followed by a copy of `b`
    - primitives: they will be checked for equality and inequality will result
      in a ValueError

    It is an error for `a` and `b` to be of different types.

    Raises:
        TypeError: if `a` and `b` have different types, or if they are equal
            values of a type that is neither a dict, a list nor an immutable
            primitive (str, int, float, bool, None).
        ValueError: if `a` and `b` are non-dict, non-list values of the same
            type that are not equal.
    """
    if isinstance(a, dict) and isinstance(b, dict):
        result = {}
        # Walk the dicts themselves rather than a set of their keys: set
        # iteration order depends on string hash randomisation, which would
        # make the key order of the merged dict vary from run to run.
        for key, a_value in a.items():
            if key in b:
                result[key] = merged(a_value, b[key])
            else:
                result[key] = deepcopy(a_value)
        for key, b_value in b.items():
            if key not in a:
                result[key] = deepcopy(b_value)
        return result

    if isinstance(a, list) and isinstance(b, list):
        return deepcopy(a) + deepcopy(b)

    # Exact type comparison is deliberate: e.g. `True` and `1` compare equal
    # but must not be silently merged.
    if type(a) is not type(b):
        raise TypeError(f"Cannot merge {type(a).__name__} and {type(b).__name__}")

    if a != b:
        raise ValueError(f"Cannot merge primitive values: {a!r} != {b!r}")

    # `a` is returned as-is (not copied), so only allow types that are known
    # to be immutable; anything else could leak shared mutable state.
    if type(a) not in {str, int, float, bool, type(None)}:
        raise TypeError(
            f"Cannot use `merged` on type {type(a).__name__} as it may not be safe (must either be an immutable primitive or must have special copy/merge logic)"
        )
    return a
@ -454,10 +460,10 @@ def instantiate_worker_template(
Returns: worker configuration dictionary Returns: worker configuration dictionary
""" """
worker_config_dict = dataclasses.asdict(template) worker_config_dict = dataclasses.asdict(template)
stream_writers_dict = { stream_writers_dict = {writer: worker_name for writer in template.stream_writers}
writer: worker_name for writer in template.stream_writers worker_config_dict["shared_extra_conf"] = merged(
} template.shared_extra_conf(worker_name), stream_writers_dict
worker_config_dict["shared_extra_conf"] = merged(template.shared_extra_conf(worker_name), stream_writers_dict) )
worker_config_dict["endpoint_patterns"] = sorted(template.endpoint_patterns) worker_config_dict["endpoint_patterns"] = sorted(template.endpoint_patterns)
worker_config_dict["listener_resources"] = sorted(template.listener_resources) worker_config_dict["listener_resources"] = sorted(template.listener_resources)
return worker_config_dict return worker_config_dict
@ -786,7 +792,7 @@ def generate_worker_files(
) )
# Update the shared config with any options needed to enable this worker. # Update the shared config with any options needed to enable this worker.
merge_into(shared_config, worker_config["shared_extra_conf"]) shared_config = merged(shared_config, worker_config["shared_extra_conf"])
if using_unix_sockets: if using_unix_sockets:
healthcheck_urls.append( healthcheck_urls.append(