Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 79 additions & 9 deletions src/pxweb/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from typing import Literal, TypeAlias
Expand Down Expand Up @@ -401,6 +402,41 @@ def get_table_data(
│ 0192 Nynäshamn ┆ uppgift saknas ┆ Folkmängd ┆ 2024 ┆ 0 │
└─────────────────────┴────────────────┴────────────────┴──────┴───────┘
"""
return list(
self.get_table_data_iter(table_id, value_codes, code_list, show)
)

def get_table_data_iter(
self,
table_id: str,
value_codes: dict[str, list[str] | str] | None = None,
code_list: dict[str, str] | None = None,
show: Literal["code", "value", "code_value"] | None = None,
) -> Iterator[dict]:
"""
Like `~~.PxApi.get_table_data`, but yields row dicts one at a time
instead of materialising the full dataset in memory before returning.
When a query is split into subqueries, rows are yielded as each
subquery completes, so processing can begin before all network calls
have finished. This is useful for streaming large tables to disk
without holding the whole result in RAM.

Parameters
----------
table_id: str
An ID of a table to get data from.
value_codes: dict, optional
The value codes to use for data selection where the keys are the variable codes. You can use the `~~.PxApi.get_table_variables()` to explore what's available.
code_list: dict, optional
Any named code list to use with a variable for code selection.
show: str, optional
Set to "code_value", "code" or "value", to specify what to show in the categorical columns.

Yields
------
:
One dict per data cell, in the same format as `~~.PxApi.get_table_data`.
"""
# TODO support output_values

if show not in (valid_show := {"code", "value", "code_value", None}):
Expand All @@ -413,8 +449,8 @@ def get_table_data(
response = self._client.call(
endpoint=f"/tables/{table_id}/data",
)
dataset = unpack_table_data(response, show=show)
return dataset
yield from unpack_table_data(response, show=show)
return

# A shallow copy to avoid unexpected mutation, e.g. turning a single item into a list
value_codes = dict(value_codes)
Expand Down Expand Up @@ -492,15 +528,14 @@ def fetch(query):
value_codes, self._client.configuration["maxDataCells"]
)
]
dataset = []
if self.max_workers == 1:
# 1 worker = sequential on main thread
logger.debug(
"Fetching %s subqueries",
len(subqueries),
)
for subquery in subqueries:
dataset.extend(fetch(subquery))
yield from fetch(subquery)
else:
logger.debug(
"Fetching %s subqueries with %s workers",
Expand All @@ -513,13 +548,11 @@ def fetch(query):
) as executor:
# Map() so that we yield results in order
for result in executor.map(fetch, subqueries):
dataset.extend(result)
yield from result
else:
# No batching needed so we just go ahead with the query as is
query = build_query(value_codes, code_list)
dataset = fetch(query)

return dataset
yield from fetch(query)

def get_table_data_all(
self,
Expand All @@ -542,10 +575,47 @@ def get_table_data_all(
:
A dataset in a native format that can be loaded into a dataframe.
"""
return list(self.get_table_data_all_iter(table_id, show=show))

def get_table_data_all_iter(
self,
table_id: str,
show: Literal["code", "value", "code_value"] | None = None,
) -> Iterator[dict]:
"""
Like `~~.PxApi.get_table_data_all`, but yields row dicts one at a time
instead of materialising the full dataset in memory before returning.
Rows are yielded as each subquery completes, so processing can begin
before all network calls have finished. This makes it possible to
stream very large tables to disk without holding the whole result in
RAM.

Parameters
----------
table_id: str
An ID of a table to get data from.
show: str, optional
Set to "code_value", "code" or "value", to specify what to show in the categorical columns.

Yields
------
:
One dict per data cell, in the same format as `~~.PxApi.get_table_data_all`.

Examples
--------
Stream a large table to a newline-delimited JSON file without loading
the whole dataset into memory.

>>> import json
>>> with open("TAB6683.ndjson", "w") as f:
... for record in api.get_table_data_all_iter("TAB6683"):
... f.write(json.dumps(record, ensure_ascii=False) + "\\n")
"""
selection_all: dict[str, list[str] | str] = {
k: ["*"] for k in self.get_table_variables(table_id)
}
return self.get_table_data(
yield from self.get_table_data_iter(
table_id, value_codes=selection_all, show=show
)

Expand Down
22 changes: 22 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,28 @@ def test_get_table_data(api):
assert len(dataset) > 1


def test_get_table_data_iter(api):
    """The streaming variant yields the same rows as `get_table_data`."""
    stream = api.get_table_data_iter(table_id="TAB6471")

    # The iterator variant must not hand back an already-built list.
    assert not isinstance(stream, list)

    collected = [*stream]
    for row in collected:
        assert isinstance(row, dict)
    # Row-for-row equality with the list-returning method.
    expected = api.get_table_data(table_id="TAB6471")
    assert collected == expected


def test_get_table_data_all_iter(api):
    """The streaming variant yields the same rows as `get_table_data_all`."""
    stream = api.get_table_data_all_iter(table_id="TAB6471")

    # The iterator variant must not hand back an already-built list.
    assert not isinstance(stream, list)

    collected = [*stream]
    assert len(collected) > 1
    for row in collected:
        # Each yielded item is a dict that carries a "value" key.
        assert isinstance(row, dict) and "value" in row
    expected = api.get_table_data_all(table_id="TAB6471")
    assert collected == expected


def test_get_table_data_only_list_or_strings(api):
with pytest.raises(ValueError):
api.get_table_data(table_id="TAB6471", value_codes={"some_var": 42}) # type: ignore
Expand Down
Loading