Shortcuts

Source code for flash.core.serve.core

import dataclasses
from dataclasses import dataclass, field, make_dataclass
from pathlib import Path
from typing import Dict, Iterator, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union

import pytorch_lightning as pl
import torch

from flash.core.serve.types.base import BaseType
from flash.core.serve.utils import download_file
from flash.core.utilities.imports import _PYDANTIC_AVAILABLE, requires

if _PYDANTIC_AVAILABLE:
    from pydantic import FilePath, HttpUrl, ValidationError, parse_obj_as
else:
    FilePath, HttpUrl, parse_obj_as, ValidationError = None, None, None, None

# -------------------------------- Endpoint -----------------------------------


@dataclass
class Endpoint:
    """An endpoint maps a route and request/response payload to components.

    Parameters
    ----------
    route
        The API route name to construct as the servicing POST endpoint. Must be a ``str`` which begins
        with ``/``.
    inputs
        Dictionary whose values are the full name of a component input. Typically, specified by just
        passing in the component parameter attribute (i.e. ``component.inputs.foo``).
    outputs
        Dictionary whose values are the full name of a component output. Typically, specified by just
        passing in the component parameter attribute (i.e. ``component.outputs.bar``).

    Raises
    ------
    TypeError
        If ``route`` is not a ``str``, or if any value of ``inputs`` / ``outputs`` is neither a
        ``Parameter`` nor a ``str``.
    ValueError
        If ``route`` does not begin with ``/``.

    Notes
    -----
    * After construction every value of ``inputs`` and ``outputs`` is a ``str``. The conversion is done
      in place, i.e. the dictionaries handed to the constructor are the ones that get updated.
    """

    route: str
    inputs: Dict[str, str]
    outputs: Dict[str, str]

    def __post_init__(self):
        if not isinstance(self.route, str):
            raise TypeError(
                f"route parameter must be type={str}, received "
                f"route={self.route} of type={type(self.route)}"
            )
        if not self.route.startswith("/"):
            raise ValueError("route must begin with a `slash` character (ie `/`).")

        self._stringify_values(self.inputs, "inputs")
        self._stringify_values(self.outputs, "outputs")

    @staticmethod
    def _stringify_values(mapping: Dict[str, str], label: str) -> None:
        """Validate each value of ``mapping`` and replace it, in place, with ``str(value)``.

        ``label`` (``"inputs"`` or ``"outputs"``) only prefixes the error message so that the failing
        mapping can be identified.
        """
        # Iterate over a snapshot of the keys: the values are re-assigned inside the loop.
        for key in tuple(mapping.keys()):
            value = mapping[key]
            if not isinstance(value, (Parameter, str)):
                raise TypeError(
                    f"{label} k={key}, v={value}, is not {Parameter} or {str}. type(v)={type(value)}"
                )
            mapping[key] = str(value)
# -------------------------------- Servable ---------------------------------


class FlashServeScriptLoader:
    """Callable wrapper around an asset loaded with ``torch.jit.load``.

    Parameters
    ----------
    location
        Path of the asset; it is stored on the instance and handed to ``torch.jit.load``.
    """

    __slots__ = ("location", "instance")

    def __init__(self, location: FilePath):
        self.location = location
        self.instance = torch.jit.load(location)

    def __call__(self, *args, **kwargs):
        """Forward all arguments to the loaded module and return its result."""
        # NOTE: a debug ``print`` of the module and of every argument used to live here; it wrote the
        # full repr of each input to stdout on every call and has been removed.
        return self.instance(*args, **kwargs)


# Accepted positional-argument shapes for ``Servable``: ``(LightningModule class, location)`` or
# ``(location,)`` where the location is either a URL or a file path.
ServableValidArgs_T = Union[
    Tuple[Type[pl.LightningModule], Union[HttpUrl, FilePath]],
    Tuple[HttpUrl],
    Tuple[FilePath],
]
class Servable:
    """Wrapper around a model object to enable serving at scale.

    Create a ``Servable`` from either (LM, LOCATION) or (LOCATION,)

    Parameters
    ----------
    *args
        A model class and path to the asset file (url or local file path) OR a singular path to a
        torchscript asset which can be loaded without the model class definition.
    download_path
        Forwarded as ``download_path`` to ``download_file`` when the location did not validate as a
        local path. NOTE(review): presumably the destination of the download -- confirm against
        ``download_file``.
    script_loader_cls
        Class used to load the asset when no model class is given (or when it is itself passed as the
        first positional argument). It is called with the resolved file path.

    Raises
    ------
    ValidationError
        If ``args`` does not match any of the shapes in ``ServableValidArgs_T`` and the first argument
        is not ``script_loader_cls``.

    TODO
    ----
    * How to handle ``__init__`` args for ``torch.nn.Module``
    * How to handle ``__init__`` args not recorded in hparams of ``pl.LightningModule``
    """

    @requires("serve")
    def __init__(
        self,
        *args: ServableValidArgs_T,
        download_path: Optional[Path] = None,
        script_loader_cls: Type[FlashServeScriptLoader] = FlashServeScriptLoader,
    ):
        try:
            loc = args[-1]  # last element in args is always loc
            parsed = parse_obj_as(ServableValidArgs_T, tuple(args))
        except ValidationError:
            # Validation is allowed to fail only when the first argument is the script loader class.
            # ``args[0]`` may just as well be a plain location string, which has no ``__qualname__``;
            # ``getattr`` keeps that case from masking the ``ValidationError`` with an
            # ``AttributeError``.
            if getattr(args[0], "__qualname__", None) != script_loader_cls.__qualname__:
                raise
            parsed = [script_loader_cls, parse_obj_as(Union[HttpUrl, FilePath], loc)]

        # A location that validated as a ``Path`` is used directly, anything else goes through
        # ``download_file`` first.
        f_path = loc if isinstance(parsed[-1], Path) else download_file(loc, download_path=download_path)

        if len(args) == 2 and getattr(args[0], "__qualname__", None) != script_loader_cls.__qualname__:
            # if this is a class and path/url...
            instance = args[0].load_from_checkpoint(f_path)
        else:
            # if this is just a path/url
            instance = script_loader_cls(f_path)

        self.instance = instance

    def __call__(self, *args, **kwargs):
        """Forward all arguments to the wrapped instance and return its result."""
        return self.instance(*args, **kwargs)

    def __repr__(self):
        return repr(self.instance)
# ------------------ Connections & Parameters (internal) ----------------------


class Connection(NamedTuple):
    """A connection maps one output to one input.

    This is a self-contained data structure, which when given in the context of the other components in
    a composition, will map input/output keys/indices between components.

    Warnings
    --------
    * This data structure should not be instantiated directly! The ``<<`` / ``>>`` operators of
      ``Parameter`` are the intended mechanisms to create a new instance.
    """

    source_component: str
    target_component: str
    source_key: str
    target_key: str

    def __repr__(self):  # pragma: no cover
        return f"Connection({str(self)})"

    def _repr_pretty_(self, p, cycle):  # pragma: no cover
        # Pretty-printer hook: ``p`` receives the text, nothing is emitted when ``cycle`` is set.
        if cycle:
            return
        res = (
            f"Connection("
            f"{self.source_component}.outputs.{self.source_key} >> "
            f"{self.target_component}.inputs.{self.target_key})"
        )
        p.text(res)

    def __str__(self):
        return (
            f"{self.source_component}.outputs.{self.source_key} >> "
            f"{self.target_component}.inputs.{self.target_key}"
        )


@dataclass
class Parameter:
    """Holder class for each grid type of component and connections from those to the types of other
    components.

    Parameters
    ----------
    name
        Name of the parameter. It's same as the dictionary key from `expose`
    datatype
        Grid type object
    component_uid
        Which component this type is associated with
    position
        Side of the component this parameter is exposed on, i.e. `inputs` or `outputs`
    """

    name: str
    datatype: "BaseType"
    component_uid: str
    position: str
    # Connections recorded by ``<<`` / ``>>``; not an ``__init__`` argument and hidden from ``repr``.
    connections: List["Connection"] = field(default_factory=list, init=False, repr=False)

    def __str__(self):
        return f"{self.component_uid}.{self.position}.{self.name}"

    def __terminate_invalid_connection_request(self, other: "Parameter", dunder_meth_called: str) -> None:
        """Verify that components can be composed.

        Parameters
        ----------
        other
            object passed into the bitshift operator. We verify if is a ``Parameter`` class and that is
            not the type of the same component
        dunder_meth_called: str
            one of ['__lshift__', '__rshift__']. we need to know the directionality of the bitshift
            method called when we verify that the directionality of the dag is always
            outputs -> inputs.

        Raises
        ------
        TypeError, RuntimeError
            if the verification fails, we throw an exception to stop the connection from being created.
        """
        # assert this is actually a class object we can compare against.
        if not isinstance(other, self.__class__) or (other.__class__ != self.__class__):
            raise TypeError(f"Can only Compose another `Parameter` class, not {type(other)}")

        # assert not same instance
        if id(other) == id(self):
            raise RuntimeError("Cannot compose a parameters of same components")

        # assert bitshift directionality is acceptable for source/target map:
        # ``target << source`` and ``source >> target``.
        called_lshift = dunder_meth_called == "__lshift__"
        source = other if called_lshift else self
        target = self if called_lshift else other

        if source.position != "outputs":
            raise TypeError(
                f"A data source component can only provide a target with data listed "
                f"as ``output``. source component: `{source.component_uid}` "
                f"key: `{source.name}`"
            )
        if target.position != "inputs":
            raise TypeError(
                f"A data target component can only accept data into keys listed as "
                f"`inputs`. components: source=`{str(source)}` target={str(target)}"
            )
        if source.component_uid == target.component_uid:
            raise RuntimeError(
                f"Cannot create cycle by creating connection between outputs and "
                f"inputs of a single component. source component: `{source.component_uid}`"
            )

    def __lshift__(self, other: "Parameter"):
        """Implements composition connecting Parameter << Parameter.

        ``other`` is the data source and ``self`` the target. The new ``Connection`` is appended to
        ``self.connections``; nothing is returned.
        """
        self.__terminate_invalid_connection_request(other, "__lshift__")
        con = Connection(
            source_component=other.component_uid,
            target_component=self.component_uid,
            source_key=other.name,
            target_key=self.name,
        )
        self.connections.append(con)

    def __rshift__(self, other: "Parameter"):
        """Implements composition connecting Parameter >> Parameter.

        ``self`` is the data source and ``other`` the target. The new ``Connection`` is appended to
        ``self.connections``; nothing is returned.
        """
        self.__terminate_invalid_connection_request(other, "__rshift__")
        con = Connection(
            source_component=self.component_uid,
            target_component=other.component_uid,
            source_key=self.name,
            target_key=other.name,
        )
        self.connections.append(con)


class DictAttrAccessBase:
    """Mixin giving a dataclass dict-like access (``[]``, ``in``, ``len``, iteration) to its fields."""

    def __grid_fields__(self) -> Iterator[str]:
        """Yield the name of every dataclass field of ``self``."""
        for fld in dataclasses.fields(self):
            yield fld.name

    def __getitem__(self, item) -> Parameter:
        return getattr(self, item)

    def __contains__(self, item):
        # Only dataclass fields count as members. A bare ``getattr`` would also report every other
        # attribute (``"__class__"``, ``"__len__"``, ...) as contained, disagreeing with ``__iter__``
        # and ``__len__``.
        return item in tuple(self.__grid_fields__())

    def __len__(self):
        return len(tuple(self.__grid_fields__()))

    def __iter__(self):
        yield from self.__grid_fields__()


ParameterContainer = TypeVar("ParameterContainer", bound=DictAttrAccessBase)


# skipcq: PYL-W1401, PYL-W0621
def make_parameter_container(data: Dict[str, Parameter]) -> ParameterContainer:
    """Create dotted dict lookup class from parameter map.

    Parameters
    ----------
    data
        mapping for ``parameter_name -> 'Parameter' instance``

    Returns
    -------
    ParameterContainer
        A representation of the parameter data dict with keys accessible via ``dotted`` attribute
        lookup.

    Notes
    -----
    * parameter name must be valid python attribute (identifier) and cannot be a builtin keyword. input
      names should have been validated by this point.
    """
    dataclass_fields = [(param_name, type(param)) for param_name, param in data.items()]
    # A dedicated frozen dataclass is generated per call, with one field per parameter name.
    container_cls = make_dataclass(
        "ParameterContainer",
        dataclass_fields,
        bases=(DictAttrAccessBase,),
        frozen=True,
        unsafe_hash=True,
    )
    return container_cls(**data)


def make_param_dict(
    inputs: Dict[str, "BaseType"], outputs: Dict[str, "BaseType"], component_uid: str
) -> Tuple[Dict[str, Parameter], Dict[str, Parameter]]:
    """Convert exposed input/outputs parameters / dtypes to parameter objects.

    Parameters
    ----------
    inputs
        mapping for ``input name -> datatype``
    outputs
        mapping for ``output name -> datatype``
    component_uid
        uid stored on every created ``Parameter``

    Returns
    -------
    Tuple[Dict[str, Parameter], Dict[str, Parameter]]
        Element[0] == Input parameter dict
        Element[1] == Output parameter dict.
    """
    flashserve_inp_params = {
        inp_key: Parameter(name=inp_key, datatype=inp_dtype, component_uid=component_uid, position="inputs")
        for inp_key, inp_dtype in inputs.items()
    }
    flashserve_out_params = {
        out_key: Parameter(name=out_key, datatype=out_dtype, component_uid=component_uid, position="outputs")
        for out_key, out_dtype in outputs.items()
    }
    return flashserve_inp_params, flashserve_out_params

© Copyright 2020-2021, PyTorch Lightning. Revision a9cedb5a.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
Versions
latest
stable
0.8.2
0.8.1.post0
0.8.1
0.8.0
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.