Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 117 additions & 5 deletions misc/python/materialize/cli/mz_workload_anonymize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,68 @@ def keywords() -> set[str]:
return result


def _iter_sql(obj: Any, path: str = "") -> Any:
    """Yield (location, sql) for every create_sql/sql string in the workload.

    Walks nested dicts and lists depth-first, in container order. ``location``
    is the path from the root, built from ``.key`` for dict entries and
    ``[i]`` for list elements (e.g. ``.queries[0].sql``). A ``create_sql`` or
    ``sql`` entry whose value is not a string is descended into instead of
    being yielded; scalars anywhere else are ignored.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            child_path = f"{path}.{key}"
            if key in ("create_sql", "sql") and isinstance(value, str):
                yield child_path, value
            else:
                yield from _iter_sql(value, child_path)
    elif isinstance(obj, list):
        for i, value in enumerate(obj):
            yield from _iter_sql(value, f"{path}[{i}]")


def verify_anonymized(
    new: dict[str, Any], mapping: dict[str, str], args: argparse.Namespace
) -> list[str]:
    """Best-effort scan of anonymized output for data that should have been scrubbed.

    This is a backstop for the heuristic text substitution, not a proof: it
    catches whole-word survivals of original identifiers and any single-quoted
    literal that was not reduced to a 'literal_N' placeholder. It cannot detect
    sensitive data hidden in dollar-quoted strings, comments, or numeric
    literals, which the anonymizer does not handle.

    Cluster create_sql is exempt from the literal check: its literals (SIZE,
    replication factor, availability zones) are non-sensitive configuration that
    replay must preserve verbatim, so they are intentionally not anonymized.

    Args:
        new: The anonymized workload to scan.
        mapping: Original identifier -> anonymized identifier. Entries that map
            a name to itself (keywords) and empty names are not checked.
        args: Parsed options; ``args.identifiers`` enables the identifier scan
            and ``args.literals`` enables the string-literal scan.

    Returns:
        One human-readable message per finding, in traversal order; empty when
        nothing suspicious was found.
    """
    problems: list[str] = []

    # Identifiers that were actually renamed (keywords map to themselves).
    identifier_checks: list[tuple[str, re.Pattern[str]]] = []
    if args.identifiers:
        for original, anonymized in mapping.items():
            # An empty name would compile to a pattern that matches every
            # string and flag every location, so it carries no signal.
            if not original or original == anonymized:
                continue
            if re.fullmatch(r"\w+", original):
                # Plain identifiers must match as whole words so that e.g.
                # "id" is not reported inside "valid".
                pattern = re.compile(r"\b" + re.escape(original) + r"\b")
            else:
                # \b is unreliable around non-word characters; fall back to a
                # plain substring match for quoted/special identifiers.
                pattern = re.compile(re.escape(original))
            identifier_checks.append((original, pattern))

    string_literal = re.compile(r"'(?:[^']|'')*'")
    placeholder = re.compile(r"^'literal_\d+'$")
    # Only descendants of the top-level "clusters" entry are exempt; a bare
    # ".clusters" prefix would also exempt unrelated keys such as
    # ".clusters_extra".
    cluster_prefixes = (".clusters.", ".clusters[")

    for location, sql in _iter_sql(new):
        for original, pattern in identifier_checks:
            if pattern.search(sql):
                problems.append(
                    f"{location}: original identifier {original!r} survived"
                )
        if args.literals and not location.startswith(cluster_prefixes):
            for match in string_literal.finditer(sql):
                if not placeholder.fullmatch(match.group(0)):
                    problems.append(
                        f"{location}: non-anonymized string literal {match.group(0)!r}"
                    )

    return problems


def main() -> int:
parser = argparse.ArgumentParser(
prog="mz-workload-anonymize",
Expand All @@ -46,14 +108,26 @@ def main() -> int:
"--output",
type=str,
default=None,
help="Path to write the workload.yml, overrides the input file if not specified",
help="Path to write the workload.yml, or - for stdout. Required unless --in-place is given.",
)
parser.add_argument(
"--in-place",
action="store_true",
help="Overwrite the input file with the anonymized workload. Destroys the original capture.",
)
parser.add_argument(
"--identifiers", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"--literals", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"--verify",
action=argparse.BooleanOptionalAction,
default=True,
help="After anonymizing, scan the output for surviving original identifiers and "
"non-anonymized string literals, and refuse to write if any are found.",
)

parser.add_argument(
"file",
Expand Down Expand Up @@ -100,7 +174,9 @@ def set_name(name: str, new_name: str) -> str:
else:
return name

string_literal_pattern = re.compile(r"'(?:[^']*(?:'')?)*'")
# Matches a single-quoted SQL string literal, including '' escapes. Written
# without nested quantifiers to avoid catastrophic backtracking (ReDoS).
string_literal_pattern = re.compile(r"'(?:[^']|'')*'")

def anonymize_string_literal(match: re.Match[str]) -> str:
count["literals"] += 1
Expand Down Expand Up @@ -274,16 +350,25 @@ def replace_literals(d: dict[str, Any], entry: str) -> None:
replace_literals(table, "create_sql")
for typ in schema["types"].values():
replace_identifiers(typ, "create_sql")
replace_literals(typ, "create_sql")
for conn in schema["connections"].values():
replace_identifiers(conn, "create_sql")
# Connection create_sql carries hostnames, usernames, regions,
# bucket/broker URLs, etc. as string literals; anonymize them.
replace_literals(conn, "create_sql")
for source in schema["sources"].values():
for column in source.get("columns", []):
if args.identifiers and column["type"] in mapping:
column["type"] = mapping[column["type"]]
for child in source.get("children", {}).values():
if args.identifiers:
child["schema"] = mapping[child["schema"]]
child["database"] = mapping[child["database"]]
# A child's schema/database may be a builtin or otherwise
# uncaptured name that never entered the mapping; leave
# those as-is rather than crashing.
child["schema"] = mapping.get(child["schema"], child["schema"])
child["database"] = mapping.get(
child["database"], child["database"]
)
for column in child["columns"]:
if args.identifiers and column["type"] in mapping:
column["type"] = mapping[column["type"]]
Expand All @@ -299,8 +384,12 @@ def replace_literals(d: dict[str, Any], entry: str) -> None:
replace_literals(mv, "create_sql")
for index in schema["indexes"].values():
replace_identifiers(index, "create_sql")
replace_literals(index, "create_sql")
for sink in schema["sinks"].values():
replace_identifiers(sink, "create_sql")
# Sink create_sql carries topic names, broker lists, and
# bucket/path URLs as string literals; anonymize them.
replace_literals(sink, "create_sql")
for query in workload["queries"]:
if args.identifiers:
query["cluster"] = mapping.get(query["cluster"], query["cluster"])
Expand All @@ -313,7 +402,30 @@ def replace_literals(d: dict[str, Any], entry: str) -> None:
replace_literals(query, "sql")
new["queries"].append(query)

output = args.output or args.file
if args.verify:
problems = verify_anonymized(new, mapping, args)
if problems:
print(
"Refusing to write output: anonymization left sensitive data behind.\n"
"Pass --no-verify to write anyway.",
file=sys.stderr,
)
for problem in problems:
print(f" {problem}", file=sys.stderr)
return 1

if args.output:
output = args.output
elif args.in_place:
output = args.file
else:
print(
"error: specify an output with -o/--output (use '-' for stdout) "
"or pass --in-place to overwrite the input file",
file=sys.stderr,
)
return 1

if output == "-":
yaml.dump(new, sys.stdout, Dumper=yaml.CSafeDumper)
else:
Expand Down
Loading