Shortcuts

Source code for flash.core.data.io.classification_input

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional

from flash.core.data.properties import Properties
from flash.core.data.utilities.classification import TargetFormatter, get_target_formatter


class ClassificationInputMixin(Properties):
    """The ``ClassificationInputMixin`` class provides utility methods for handling classification targets.

    :class:`~flash.core.data.io.input.Input` objects that extend ``ClassificationInputMixin`` should do the following:

    * In the ``load_data`` method, include a call to ``load_target_metadata``. This will determine the format of the
      targets and store metadata like ``labels`` and ``num_classes``.
    * In the ``load_sample`` method, use ``format_target`` to convert the target to a standard format for use with our
      tasks.
    """

    # Formatter applied to each target. ``load_target_metadata`` sets this to ``None`` when it is given neither
    # targets nor an explicit formatter.
    target_formatter: TargetFormatter
    # The three attributes below are copied from ``target_formatter`` by ``load_target_metadata`` and are only
    # assigned when a formatter is available.
    multi_label: bool
    labels: list
    num_classes: int

    def load_target_metadata(
        self,
        targets: Optional[List[Any]],
        target_formatter: Optional[TargetFormatter] = None,
        add_background: bool = False,
    ) -> None:
        """Determine the target format and store the ``labels`` and ``num_classes``.

        Args:
            targets: The list of targets. May be ``None``, in which case no formatter is inferred.
            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter`
                rather than inferring from the targets.
            add_background: If ``True``, a background class will be inserted as class zero if ``labels`` and
                ``num_classes`` are being inferred.
        """
        self.target_formatter = target_formatter
        # Only infer a formatter when the caller did not supply one and there are targets to inspect.
        if target_formatter is None and targets is not None:
            self.target_formatter = get_target_formatter(targets, add_background=add_background)

        # Without a formatter there is no metadata to copy, so the attributes below are left unset.
        if self.target_formatter is not None:
            self.multi_label = self.target_formatter.multi_label
            self.labels = self.target_formatter.labels
            self.num_classes = self.target_formatter.num_classes

    def format_target(self, target: Any) -> Any:
        """Format a single target according to the previously computed target format and metadata.

        If no formatter is available (``load_target_metadata`` has not been called, or it stored ``None`` because
        it received neither targets nor a formatter), the target is returned unchanged.

        Args:
            target: The target to format.

        Returns:
            The formatted target.
        """
        formatter = getattr(self, "target_formatter", None)
        # A ``None`` formatter must be treated like a missing one: calling it would raise ``TypeError``.
        if formatter is None:
            return target
        return formatter(target)

© Copyright 2020-2021, PyTorch Lightning. Revision a374dd4f.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
0.8.2
0.8.1.post0
0.8.1
0.8.0
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.