diff --git a/pyproject.toml b/pyproject.toml index b9682d2..b6487dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,7 @@ dependencies = [ "timecopilot-granite-tsfm>=0.1.2", "timecopilot-timesfm>=0.2.1", "timecopilot-tirex>=0.1.0 ; python_full_version >= '3.11'", + "timecopilot-tirex>=0.1.1", "timecopilot-toto-2>=0.1.1", "timecopilot-toto>=0.1.6", "timecopilot-uni2ts>=0.1.2 ; python_full_version < '3.14'", diff --git a/timecopilot/models/foundation/tirex.py b/timecopilot/models/foundation/tirex.py index e370c70..264e8a3 100644 --- a/timecopilot/models/foundation/tirex.py +++ b/timecopilot/models/foundation/tirex.py @@ -15,6 +15,8 @@ from ..utils.forecaster import Forecaster, QuantileConverter from .utils import TimeSeriesDataset +DEFAULT_QUANTILES_TIREX = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + class TiRex(Forecaster): """ @@ -86,31 +88,20 @@ def _forecast( quantiles: list[float] | None, ) -> tuple[np.ndarray, np.ndarray | None]: """handles distinction between quantiles and no quantiles""" - if quantiles is not None: - fcsts = [ - model.forecast( - batch, - prediction_length=h, - quantile_levels=quantiles, - output_type="numpy", - ) - for batch in tqdm(dataset) - ] # list of tuples - fcsts_quantiles, fcsts_mean = zip(*fcsts, strict=False) - fcsts_quantiles_np = np.concatenate(fcsts_quantiles) - fcsts_mean_np = np.concatenate(fcsts_mean) - else: - fcsts = [ - model.forecast( - batch, - prediction_length=h, - output_type="numpy", - ) - for batch in tqdm(dataset) - ] - fcsts_quantiles, fcsts_mean = zip(*fcsts, strict=False) - fcsts_mean_np = np.concatenate(fcsts_mean) - fcsts_quantiles_np = None + fcsts = [ + model.forecast( + batch, + prediction_length=h, + output_type="numpy", + ) + for batch in tqdm(dataset) + ] # list of tuples + fcsts_quantiles, fcsts_mean = zip(*fcsts, strict=False) + fcsts_mean_np = np.concatenate(fcsts_mean) + fcsts_quantiles_np = ( + None if quantiles is None else np.concatenate(fcsts_quantiles) + ) + return fcsts_mean_np, fcsts_quantiles_np def forecast( @@ -164,12 +155,23 @@ def forecast( - prediction intervals if `level` is specified. - quantile forecasts if `quantiles` is specified. - For multi-series data, the output retains the same unique - identifiers as the input DataFrame. + For multi-series data, the output retains the + same unique identifiers as the input DataFrame. """ freq = self._maybe_infer_freq(df, freq) qc = QuantileConverter(level=level, quantiles=quantiles) - dataset = TimeSeriesDataset.from_df(df, batch_size=self.batch_size) + if qc.quantiles is not None and len(qc.quantiles) != len( + DEFAULT_QUANTILES_TIREX + ): + raise ValueError( + "TiRex only supports the default quantiles, " + "please use the default quantiles or default level, " + ) + dataset = TimeSeriesDataset.from_df( + df, + batch_size=self.batch_size, + ) + fcst_df = dataset.make_future_dataframe(h=h, freq=freq) with self._get_model() as model: fcsts_mean_np, fcsts_quantiles_np = self._forecast( diff --git a/uv.lock b/uv.lock index bf082d2..5c27329 100644 --- a/uv.lock +++ b/uv.lock @@ -7203,7 +7203,6 @@ dependencies = [ {extra = ["torch"], name = "gluonts"}, {marker = "python_full_version < '3.13'", name = "tabpfn-time-series"}, {marker = "python_full_version < '3.14'", name = "timecopilot-uni2ts"}, - {marker = "python_full_version >= '3.11'", name = "timecopilot-tirex"}, {marker = "python_full_version >= '3.13'", name = "pandas", source = {registry = "https://pypi.org/simple"}, version = "2.3.3"}, {name = "accelerate"}, {name = "arch"}, @@ -7239,6 +7238,7 @@ dependencies = [ {name = "timecopilot-chronos-forecasting"}, {name = "timecopilot-granite-tsfm"}, {name = "timecopilot-timesfm"}, + {name = "timecopilot-tirex"}, {name = "timecopilot-toto"}, {name = "timecopilot-toto-2"}, {name = "torchmetrics"}, @@ -7326,6 +7326,7 @@ requires-dist = [ {name = "timecopilot-chronos-forecasting", specifier = ">=0.2.1"}, {name = "timecopilot-granite-tsfm", specifier = ">=0.1.2"}, {name = "timecopilot-timesfm", specifier = ">=0.2.1"}, + {name = "timecopilot-tirex", specifier = ">=0.1.1"}, {name = "timecopilot-toto", specifier = ">=0.1.6"}, {name = "timecopilot-toto-2", specifier = ">=0.1.1"}, {name = "torchmetrics", specifier = ">=1.8.2"}, @@ -7451,26 +7452,22 @@ wheels = [ [[package]] dependencies = [ - {marker = "python_full_version >= '3.11' and python_full_version < '3.13'", name = "numpy", source = {registry = "https://pypi.org/simple"}, version = "1.26.4"}, - {marker = "python_full_version >= '3.11' and python_full_version < '3.13'", name = "pandas", source = {registry = "https://pypi.org/simple"}, version = "2.1.4"}, - {marker = "python_full_version >= '3.11'", name = "dacite"}, - {marker = "python_full_version >= '3.11'", name = "einops"}, - {marker = "python_full_version >= '3.11'", name = "huggingface-hub"}, - {marker = "python_full_version >= '3.11'", name = "lightning"}, - {marker = "python_full_version >= '3.11'", name = "ninja"}, - {marker = "python_full_version >= '3.11'", name = "torch"}, - {marker = "python_full_version >= '3.11'", name = "torchvision"}, - {marker = "python_full_version >= '3.11'", name = "tqdm"}, - {marker = "python_full_version >= '3.11'", name = "xlstm"}, + {marker = "python_full_version < '3.11'", name = "xlstm", source = {registry = "https://pypi.org/simple"}, version = "2.0.0"}, + {marker = "python_full_version < '3.13'", name = "numpy", source = {registry = "https://pypi.org/simple"}, version = "1.26.4"}, + {marker = "python_full_version < '3.13'", name = "scikit-learn", source = {registry = "https://pypi.org/simple"}, version = "1.6.1"}, + {marker = "python_full_version >= '3.11'", name = "xlstm", source = {registry = "https://pypi.org/simple"}, version = "2.0.5"}, {marker = "python_full_version >= '3.13'", name = "numpy", source = {registry = "https://pypi.org/simple"}, version = "2.1.3"}, - {marker = "python_full_version >= '3.13'", name = "pandas", source = {registry = "https://pypi.org/simple"}, version = "2.3.3"}, + {marker = "python_full_version >= '3.13'", name = "scikit-learn", source = {registry = "https://pypi.org/simple"}, version = "1.7.2"}, + {name = "huggingface-hub"}, + {name = "ninja"}, + {name = "torch"}, ] name = "timecopilot-tirex" -sdist = {hash = "sha256:1c1b972e4705fce494d2c205d59512980bdad7bdd2124a298c9a6eee45ff3db2", size = 23952, upload-time = "2025-07-14T18:02:28.981Z", url = "https://files.pythonhosted.org/packages/04/5c/d14a58b48a52c00d3d66b4fbc380b582345c58765a7fe564cecb4270e001/timecopilot_tirex-0.1.0.tar.gz"} +sdist = {hash = "sha256:c493aff2d29d1fb7a4c3a5a2642df67030691744cc1ae79d5ea9583f4b5d8895", size = 56461, upload-time = "2026-06-09T23:24:46.477Z", url = "https://files.pythonhosted.org/packages/4a/06/2574fe82fb59ba11a8bf589677243cff349061e95912a762558ca46813aa/timecopilot_tirex-0.1.1.tar.gz"} source = {registry = "https://pypi.org/simple"} -version = "0.1.0" +version = "0.1.1" wheels = [ - {hash = "sha256:810ca0a410868010cb3bd8190e2539574af8df86f7cb64bf3f9363811d899f51", size = 25532, upload-time = "2025-07-14T18:02:27.849Z", url = "https://files.pythonhosted.org/packages/f0/b0/6dd7696d01a6ad7a4dc0095316aab073b900162208af432aaa459e17fc6c/timecopilot_tirex-0.1.0-py3-none-any.whl"}, + {hash = "sha256:7b2403db5274d92d223c178f8c5c3ec7210645b4c884b44d78e7f4c9b8d571ae", size = 61001, upload-time = "2026-06-09T23:24:44.581Z", url = "https://files.pythonhosted.org/packages/e5/f3/061de0fb9acfeb9866750af60531ae4d5ef6474ccdfa6e37e6b0eff19eed/timecopilot_tirex-0.1.1-py3-none-any.whl"}, ] [[package]] @@ -7724,39 +7721,6 @@ wheels = [ {hash = "sha256:08382fd96b923e39e904c4d570f3d49e2cc71ccabd2a94e0f895d1f0dac86242", size = 983161, upload-time = "2025-09-03T14:00:51.921Z", url = "https://files.pythonhosted.org/packages/02/21/aa0f434434c48490f91b65962b1ce863fdcce63febc166ca9fe9d706c2b6/torchmetrics-1.8.2-py3-none-any.whl"}, ] -[[package]] -dependencies = [ - {marker = "python_full_version >= '3.11' and python_full_version < '3.13'", name = "numpy", source = {registry = "https://pypi.org/simple"}, version = "1.26.4"}, - {marker = "python_full_version >= '3.11'", name = "pillow"}, - {marker = "python_full_version >= '3.11'", name = "torch"}, - {marker = "python_full_version >= '3.13'", name = "numpy", source = {registry = "https://pypi.org/simple"}, version = "2.1.3"}, -] -name = "torchvision" -source = {registry = "https://pypi.org/simple"} -version = "0.23.0" -wheels = [ - {hash = "sha256:01dc33ee24c79148aee7cdbcf34ae8a3c9da1674a591e781577b716d233b1fa6", size = 2395543, upload-time = "2025-08-06T14:58:04.373Z", url = "https://files.pythonhosted.org/packages/dd/14/7b44fe766b7d11e064c539d92a172fa9689a53b69029e24f2f1f51e7dc56/torchvision-0.23.0-cp311-cp311-manylinux_2_28_aarch64.whl"}, - {hash = "sha256:07d069cb29691ff566e3b7f11f20d91044f079e1dbdc9d72e0655899a9b06938", size = 1600749, upload-time = "2025-08-06T14:58:10.719Z", url = "https://files.pythonhosted.org/packages/1f/e4/028a27b60aa578a2fa99d9d7334ff1871bb17008693ea055a2fdee96da0d/torchvision-0.23.0-cp313-cp313-win_amd64.whl"}, - {hash = "sha256:09bfde260e7963a15b80c9e442faa9f021c7e7f877ac0a36ca6561b367185013", size = 1600741, upload-time = "2025-08-06T14:57:59.158Z", url = "https://files.pythonhosted.org/packages/93/40/3415d890eb357b25a8e0a215d32365a88ecc75a283f75c4e919024b22d97/torchvision-0.23.0-cp311-cp311-win_amd64.whl"}, - {hash = "sha256:1c37e325e09a184b730c3ef51424f383ec5745378dc0eca244520aca29722600", size = 1856886, upload-time = "2025-08-06T14:58:05.491Z", url = "https://files.pythonhosted.org/packages/91/37/45a5b9407a7900f71d61b2b2f62db4b7c632debca397f205fdcacb502780/torchvision-0.23.0-cp313-cp313-macosx_11_0_arm64.whl"}, - {hash = "sha256:2a3299d2b1d5a7aed2d3b6ffb69c672ca8830671967eb1cee1497bacd82fe47b", size = 2359295, upload-time = "2025-08-06T14:58:17.469Z", url = "https://files.pythonhosted.org/packages/1d/9d/406cea60a9eb9882145bcd62a184ee61e823e8e1d550cdc3c3ea866a9445/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_aarch64.whl"}, - {hash = "sha256:2df618e1143805a7673aaf82cb5720dd9112d4e771983156aaf2ffff692eebf9", size = 2047192, upload-time = "2025-08-06T14:58:11.813Z", url = "https://files.pythonhosted.org/packages/05/35/72f91ad9ac7c19a849dedf083d347dc1123f0adeb401f53974f84f1d04c8/torchvision-0.23.0-cp313-cp313t-macosx_11_0_arm64.whl"}, - {hash = "sha256:2f7fd6c15f3697e80627b77934f77705f3bc0e98278b989b2655de01f6903e1d", size = 2353112, upload-time = "2025-08-06T14:58:26.265Z", url = "https://files.pythonhosted.org/packages/ac/da/a06c60fc84fc849377cf035d3b3e9a1c896d52dbad493b963c0f1cdd74d0/torchvision-0.23.0-cp313-cp313-manylinux_2_28_aarch64.whl"}, - {hash = "sha256:31c583ba27426a3a04eca8c05450524105c1564db41be6632f7536ef405a6de2", size = 2394251, upload-time = "2025-08-06T14:58:01.725Z", url = "https://files.pythonhosted.org/packages/25/44/ddd56d1637bac42a8c5da2c8c440d8a28c431f996dd9790f32dd9a96ca6e/torchvision-0.23.0-cp310-cp310-manylinux_2_28_aarch64.whl"}, - {hash = "sha256:35c27941831b653f5101edfe62c03d196c13f32139310519e8228f35eae0e96a", size = 8628388, upload-time = "2025-08-06T14:58:07.802Z", url = "https://files.pythonhosted.org/packages/79/9c/fcb09aff941c8147d9e6aa6c8f67412a05622b0c750bcf796be4c85a58d4/torchvision-0.23.0-cp311-cp311-manylinux_2_28_x86_64.whl"}, - {hash = "sha256:3932bf67256f2d095ce90a9f826f6033694c818856f4bb26794cf2ce64253e53", size = 8627497, upload-time = "2025-08-06T14:58:09.317Z", url = "https://files.pythonhosted.org/packages/93/f3/3cdf55bbf0f737304d997561c34ab0176222e0496b6743b0feab5995182c/torchvision-0.23.0-cp310-cp310-manylinux_2_28_x86_64.whl"}, - {hash = "sha256:49aa20e21f0c2bd458c71d7b449776cbd5f16693dd5807195a820612b8a229b7", size = 1856884, upload-time = "2025-08-06T14:58:00.237Z", url = "https://files.pythonhosted.org/packages/f0/d7/15d3d7bd8d0239211b21673d1bac7bc345a4ad904a8e25bb3fd8a9cf1fbc/torchvision-0.23.0-cp311-cp311-macosx_11_0_arm64.whl"}, - {hash = "sha256:4e7d31c43bc7cbecbb1a5652ac0106b436aa66e26437585fc2c4b2cf04d6014c", size = 8627108, upload-time = "2025-08-06T14:58:12.956Z", url = "https://files.pythonhosted.org/packages/e4/b5/3e580dcbc16f39a324f3dd71b90edbf02a42548ad44d2b4893cc92b1194b/torchvision-0.23.0-cp312-cp312-manylinux_2_28_x86_64.whl"}, - {hash = "sha256:6dd7c4d329a0e03157803031bc856220c6155ef08c26d4f5bbac938acecf0948", size = 2396614, upload-time = "2025-08-06T14:58:03.116Z", url = "https://files.pythonhosted.org/packages/e2/00/2f6454decc0cd67158c7890364e446aad4b91797087a57a78e72e1a8f8bc/torchvision-0.23.0-cp312-cp312-manylinux_2_28_aarch64.whl"}, - {hash = "sha256:7266871daca00ad46d1c073e55d972179d12a58fa5c9adec9a3db9bbed71284a", size = 1856885, upload-time = "2025-08-06T14:57:55.024Z", url = "https://files.pythonhosted.org/packages/4d/49/5ad5c3ff4920be0adee9eb4339b4fb3b023a0fc55b9ed8dbc73df92946b8/torchvision-0.23.0-cp310-cp310-macosx_11_0_arm64.whl"}, - {hash = "sha256:76bc4c0b63d5114aa81281390f8472a12a6a35ce9906e67ea6044e5af4cab60c", size = 8800474, upload-time = "2025-08-06T14:58:22.53Z", url = "https://files.pythonhosted.org/packages/2b/f4/34662f71a70fa1e59de99772142f22257ca750de05ccb400b8d2e3809c1d/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_x86_64.whl"}, - {hash = "sha256:83ee5bf827d61a8af14620c0a61d8608558638ac9c3bac8adb7b27138e2147d1", size = 1600760, upload-time = "2025-08-06T14:57:56.783Z", url = "https://files.pythonhosted.org/packages/97/90/02afe57c3ef4284c5cf89d3b7ae203829b3a981f72b93a7dd2a3fd2c83c1/torchvision-0.23.0-cp310-cp310-win_amd64.whl"}, - {hash = "sha256:a2e45272abe7b8bf0d06c405e78521b5757be1bd0ed7e5cd78120f7fdd4cbf35", size = 1600723, upload-time = "2025-08-06T14:57:57.986Z", url = "https://files.pythonhosted.org/packages/82/c1/c2fe6d61e110a8d0de2f94276899a2324a8f1e6aee559eb6b4629ab27466/torchvision-0.23.0-cp312-cp312-win_amd64.whl"}, - {hash = "sha256:a76fafe113b2977be3a21bf78f115438c1f88631d7a87203acb3dd6ae55889e6", size = 8627658, upload-time = "2025-08-06T14:58:15.999Z", url = "https://files.pythonhosted.org/packages/a0/27/5ce65ba5c9d3b7d2ccdd79892ab86a2f87ac2ca6638f04bb0280321f1a9c/torchvision-0.23.0-cp313-cp313-manylinux_2_28_x86_64.whl"}, - {hash = "sha256:b9e2dabf0da9c8aa9ea241afb63a8f3e98489e706b22ac3f30416a1be377153b", size = 1527667, upload-time = "2025-08-06T14:58:14.446Z", url = "https://files.pythonhosted.org/packages/6e/f5/b5a2d841a8d228b5dbda6d524704408e19e7ca6b7bb0f24490e081da1fa1/torchvision-0.23.0-cp313-cp313t-win_amd64.whl"}, - {hash = "sha256:e0e2c04a91403e8dd3af9756c6a024a1d9c0ed9c0d592a8314ded8f4fe30d440", size = 1856885, upload-time = "2025-08-06T14:58:06.503Z", url = "https://files.pythonhosted.org/packages/df/1d/0ea0b34bde92a86d42620f29baa6dcbb5c2fc85990316df5cb8f7abb8ea2/torchvision-0.23.0-cp312-cp312-macosx_11_0_arm64.whl"}, -] - [[package]] name = "tornado" sdist = {hash = "sha256:ab53c8f9a0fa351e2c0741284e06c7a45da86afb544133201c5cc8578eb076a0", size = 510821, upload-time = "2025-08-08T18:27:00.78Z", url = "https://files.pythonhosted.org/packages/09/ce/1eb500eae19f4648281bb2186927bb062d2438c2e5093d1360391afd2f90/tornado-6.5.2.tar.gz"} @@ -8477,6 +8441,20 @@ wheels = [ {hash = "sha256:eabbd40d474b8dbf6cb3536325f9150b9e6f0db32d18de9914fb3227d0bef5b7", size = 2328527, upload-time = "2026-02-10T10:51:17.502Z", url = "https://files.pythonhosted.org/packages/93/f1/c09ef1add609453aa3ba5bafcd0d1c1a805c1263c0b60138ec968f8ec296/xgboost-3.2.0-py3-none-macosx_12_0_arm64.whl"}, ] +[[package]] +name = "xlstm" +resolution-markers = [ + "(python_full_version < '3.11' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version < '3.11' and sys_platform == 'linux'", +] +sdist = {hash = "sha256:4e25dd699144d90520113d2812d3243e3a72f1904d53b9e371a1bcf6bd99562e", size = 67094, upload-time = "2024-12-10T23:03:48.622Z", url = "https://files.pythonhosted.org/packages/21/2a/bee0c59f2e86593ea29cbf3722a312ef1fe27af556dd138e7c771f9b9112/xlstm-2.0.0.tar.gz"} +source = {registry = "https://pypi.org/simple"} +version = "2.0.0" +wheels = [ + {hash = "sha256:54ae4fc7433d0b281a15a136f2b8cafe17748e73c29e1c0cc8b1afc41c8e63b4", size = 89850, upload-time = "2024-12-10T23:03:46.177Z", url = "https://files.pythonhosted.org/packages/42/7d/0c95c6ff7f40c62614411f4f29dc7f8c52cafbe049713a94d573ae243eb4/xlstm-2.0.0-py3-none-any.whl"}, +] + [[package]] dependencies = [ {marker = "python_full_version >= '3.11' and python_full_version < '3.13'", name = "numpy", source = {registry = "https://pypi.org/simple"}, version = "1.26.4"}, @@ -8501,6 +8479,20 @@ dependencies = [ {marker = "python_full_version >= '3.13'", name = "rich", source = {registry = "https://pypi.org/simple"}, version = "14.1.0"}, ] name = "xlstm" +resolution-markers = [ + "(python_full_version == '3.11.*' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')", + "(python_full_version == '3.12.*' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux')", + "(python_full_version == '3.13.*' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'linux')", + "(python_full_version >= '3.14' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version == '3.11.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.12.*' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and sys_platform == 'linux'", +] sdist = {hash = "sha256:24a5572be44207fc15ed5dea6b805c4bcd450a8f2728320cf21b8082c535b60e", size = 71129, upload-time = "2025-08-24T14:38:49.493Z", url = "https://files.pythonhosted.org/packages/7f/4d/05efa4c76b8ade8cbd638e2b0329694a767146616186ef786107740e4e89/xlstm-2.0.5.tar.gz"} source = {registry = "https://pypi.org/simple"} version = "2.0.5"