Shortcuts

Source code for flash.text.embedding.model

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import warnings
from typing import Any, Dict, List, Optional

import torch
from pytorch_lightning import Callback

from flash.core.model import Task
from flash.core.registry import FlashRegistry, print_provider_info
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.providers import _SENTENCE_TRANSFORMERS
from flash.text.classification.collate import TextClassificationCollate
from flash.text.embedding.backbones import HUGGINGFACE_BACKBONES
from flash.text.ort_callback import ORTCallback

# ``Pooling`` is imported lazily: the name is only defined in this module when
# ``_TEXT_AVAILABLE`` is truthy, so ``sentence_transformers`` is not a hard
# import-time requirement.
if _TEXT_AVAILABLE:
    from sentence_transformers.models import Pooling

    # Rebind ``Pooling`` to whatever ``print_provider_info`` returns for it.
    # NOTE(review): presumably a wrapper that reports the sentence-transformers
    # provider when the class is used -- confirm in ``flash.core.registry``.
    Pooling = print_provider_info("Pooling", _SENTENCE_TRANSFORMERS, Pooling)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class TextEmbedder(Task):
    """The ``TextEmbedder`` is a :class:`~flash.Task` for generating sentence embeddings.

    Only the forward / prediction path is implemented: ``training_step``, ``validation_step`` and
    ``test_step`` all raise ``NotImplementedError``. For more details, see `embeddings`.

    You can change the backbone to any compatible model from `UKPLab/sentence-transformers
    <https://github.com/UKPLab/sentence-transformers>`_ using the ``backbone`` argument.

    Args:
        backbone: backbone model to use for the task. Looked up in the ``backbones`` registry.
        max_length: maximum sequence length handed to the collate function.
        tokenizer_backbone: backbone name handed to the collate function. Defaults to ``backbone``
            when ``None``.
        tokenizer_kwargs: optional keyword arguments forwarded to the collate function.
        enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
    """

    required_extras: str = "text"

    backbones: FlashRegistry = HUGGINGFACE_BACKBONES

    def __init__(
        self,
        backbone: str = "sentence-transformers/all-MiniLM-L6-v2",
        max_length: int = 128,
        tokenizer_backbone: Optional[str] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        enable_ort: bool = False,
    ):
        # NOTE: the three statements below change process-wide state (environment variables and
        # the global warnings filter), not just this instance.
        os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
        # disable HF thousand warnings
        warnings.simplefilter("ignore")
        # set os environ variable for multiprocesses
        os.environ["PYTHONWARNINGS"] = "ignore"

        super().__init__()

        # The tokenizer follows the model backbone unless a dedicated one is requested.
        if tokenizer_backbone is None:
            tokenizer_backbone = backbone

        self.max_length = max_length
        self.collate_fn = TextClassificationCollate(
            backbone=tokenizer_backbone, max_length=max_length, tokenizer_kwargs=tokenizer_kwargs
        )
        # The registry entry is a factory: calling it builds the backbone model.
        self.model = self.backbones.get(backbone)()
        self.pooling = Pooling(self.model.config.hidden_size)
        self.enable_ort = enable_ort

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        """Unsupported for this task; always raises ``NotImplementedError``."""
        raise NotImplementedError("Training a `TextEmbedder` is not supported. Use a different text task instead.")

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        """Unsupported for this task; always raises ``NotImplementedError``."""
        raise NotImplementedError("Validating a `TextEmbedder` is not supported. Use a different text task instead.")

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        """Unsupported for this task; always raises ``NotImplementedError``."""
        raise NotImplementedError("Testing a `TextEmbedder` is not supported. Use a different text task instead.")

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute sentence embeddings for a tokenized batch.

        Adapted from sentence-transformers:
        https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/models/Transformer.py#L45

        Args:
            batch: mapping that must contain ``"input_ids"`` and ``"attention_mask"`` and may contain
                ``"token_type_ids"``. The mapping itself is not modified.

        Returns:
            The ``"sentence_embedding"`` entry produced by the pooling layer.
        """
        trans_features = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
        if "token_type_ids" in batch:
            trans_features["token_type_ids"] = batch["token_type_ids"]

        # With ``return_dict=False`` the model returns a sequence; its first element is used as
        # the per-token embeddings.
        output_states = self.model(**trans_features, return_dict=False)
        output_tokens = output_states[0]

        # Pool over a shallow copy so the caller's batch is not mutated (and does not end up
        # holding on to the token embeddings). ``attention_mask`` is carried over by the copy.
        features = dict(batch)
        features["token_embeddings"] = output_tokens
        return self.pooling(features)["sentence_embedding"]

    def configure_callbacks(self) -> List[Callback]:
        """Return the parent's callbacks, plus an ``ORTCallback`` when ``enable_ort`` is set."""
        # Copy so that appending never mutates a list owned by the parent implementation.
        callbacks = list(super().configure_callbacks() or [])
        if self.enable_ort:
            callbacks.append(ORTCallback())
        return callbacks

© Copyright 2020-2021, PyTorch Lightning. Revision 8db29e8e.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: 0.7.3
Versions
latest
stable
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
docs-fix_typing
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.