Shortcuts

Source code for flash.tabular.forecasting.data

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Type, Union

from torch.utils.data.sampler import Sampler

from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING
from flash.core.utilities.stages import RunningStage
from flash.tabular.forecasting.input import TabularForecastingDataFrameInput

if not _PANDAS_AVAILABLE:
    # pandas is optional: bind a placeholder so the ``Optional[DataFrame]`` annotations used below still
    # evaluate when pandas is not installed.
    DataFrame = object
else:
    from pandas.core.frame import DataFrame


# When the tabular testing requirements are missing, mark this module's doctests as skipped.
if not _TABULAR_TESTING:
    __doctest_skip__ = ["TabularForecastingData", "TabularForecastingData.*"]


class TabularForecastingData(DataModule):
    """The ``TabularForecastingData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of
    classmethods for loading data for tabular forecasting."""

    input_transform_cls = InputTransform

    @property
    def parameters(self) -> Optional[Dict[str, Any]]:
        """The ``parameters`` dictionary from the ``TimeSeriesDataSet`` object created from the train data when
        constructing the ``TabularForecastingData`` object.

        Returns ``None`` when there is no train dataset or it does not expose a ``parameters`` attribute.
        """
        return getattr(self.train_dataset, "parameters", None)

    @classmethod
    def from_data_frame(
        cls,
        time_idx: Optional[str] = None,
        target: Optional[Union[str, List[str]]] = None,
        group_ids: Optional[List[str]] = None,
        parameters: Optional[Dict[str, Any]] = None,
        train_data_frame: Optional[DataFrame] = None,
        val_data_frame: Optional[DataFrame] = None,
        test_data_frame: Optional[DataFrame] = None,
        predict_data_frame: Optional[DataFrame] = None,
        input_cls: Type[Input] = TabularForecastingDataFrameInput,
        transform: INPUT_TRANSFORM_TYPE = InputTransform,
        transform_kwargs: Optional[Dict] = None,
        data_fetcher: Optional[BaseDataFetcher] = None,
        val_split: Optional[float] = None,
        batch_size: Optional[int] = None,
        num_workers: int = 0,
        sampler: Optional[Type[Sampler]] = None,
        pin_memory: bool = True,
        persistent_workers: bool = True,
        **input_kwargs: Any,
    ) -> "TabularForecastingData":
        """Creates a :class:`~flash.tabular.forecasting.data.TabularForecastingData` object from the given data
        frames.

        .. note::

            The ``time_idx``, ``target``, and ``group_ids`` do not need to be provided if ``parameters`` are passed
            instead. These can be obtained from the
            :attr:`~flash.tabular.forecasting.data.TabularForecastingData.parameters` attribute of the
            :class:`~flash.tabular.forecasting.data.TabularForecastingData` object that contains your training data.

        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            time_idx: Column denoting the time index of each observation.
            target: Column denoting the target or list of columns denoting the target.
            group_ids: List of column names identifying a time series. This means that the group_ids identify a
                sample together with the time_idx. If you have only one timeseries, set this to the name of a column
                that is constant.
            parameters: Parameters to use for the timeseries if ``time_idx``, ``target``, and ``group_ids`` are not
                provided (e.g. when loading data for inference or validation). When training data is given, the
                parameters derived from it are used for the validation, test, and predict stages instead.
            train_data_frame: The pandas DataFrame to use when training.
            val_data_frame: The pandas DataFrame to use when validating.
            test_data_frame: The pandas DataFrame to use when testing.
            predict_data_frame: The pandas DataFrame to use when predicting.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
                :class:`~flash.core.data.data_module.DataModule`.
            val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
            batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
            num_workers: The ``num_workers`` argument to pass to the
                :class:`~flash.core.data.data_module.DataModule`.
            sampler: The ``sampler`` type to pass to the :class:`~flash.core.data.data_module.DataModule`.
            pin_memory: The ``pin_memory`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
            persistent_workers: The ``persistent_workers`` argument to pass to the
                :class:`~flash.core.data.data_module.DataModule`.
            input_kwargs: Additional keyword arguments to be used when creating the TimeSeriesDataset.

        Returns:
            The constructed :class:`~flash.tabular.forecasting.data.TabularForecastingData`.

        Examples
        ________

        .. testsetup::

            >>> from pytorch_forecasting.data.examples import generate_ar_data
            >>> data = generate_ar_data(seasonality=10.0, timesteps=100, n_series=5, seed=42)

        We have a DataFrame `data` with the following contents:

        .. doctest::

            >>> data.head(3)
               series  time_idx     value
            0       0         0 -0.000000
            1       0         1  0.141552
            2       0         2  0.232782

        .. doctest::

            >>> from pandas import DataFrame
            >>> from flash import Trainer
            >>> from flash.tabular import TabularForecaster, TabularForecastingData
            >>> datamodule = TabularForecastingData.from_data_frame(
            ...     "time_idx",
            ...     "value",
            ...     ["series"],
            ...     train_data_frame=data,
            ...     predict_data_frame=DataFrame.from_dict(
            ...         {
            ...             "time_idx": list(range(50)),
            ...             "value": [0.0] * 50,
            ...             "series": [0] * 50,
            ...         }
            ...     ),
            ...     time_varying_unknown_reals=["value"],
            ...     max_encoder_length=30,
            ...     max_prediction_length=20,
            ...     batch_size=32,
            ... )
            >>> model = TabularForecaster(
            ...     datamodule.parameters,
            ...     backbone="n_beats",
            ...     backbone_kwargs={"widths": [16, 256]},
            ... )
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...

        .. testcleanup::

            >>> del data
        """
        ds_kw = dict(
            time_idx=time_idx,
            group_ids=group_ids,
            target=target,
            parameters=parameters,
            **input_kwargs,
        )

        train_input = input_cls(RunningStage.TRAINING, train_data_frame, **ds_kw)

        # Propagate the parameters computed from the training data (when a training input exists) to the
        # remaining stages. This is why ``time_idx``/``target``/``group_ids`` are optional for those stages.
        ds_kw["parameters"] = train_input.parameters if train_input else parameters

        return cls(
            train_input,
            input_cls(RunningStage.VALIDATING, val_data_frame, **ds_kw),
            input_cls(RunningStage.TESTING, test_data_frame, **ds_kw),
            input_cls(RunningStage.PREDICTING, predict_data_frame, **ds_kw),
            transform=transform,
            transform_kwargs=transform_kwargs,
            data_fetcher=data_fetcher,
            val_split=val_split,
            batch_size=batch_size,
            num_workers=num_workers,
            sampler=sampler,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

© Copyright 2020-2021, PyTorch Lightning. Revision da42a635.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
Versions
latest
stable
0.8.0
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.