-
Notifications
You must be signed in to change notification settings - Fork 223
Expand file tree
/
Copy pathcommon_types.py
More file actions
48 lines (42 loc) · 1.97 KB
/
Copy pathcommon_types.py
File metadata and controls
48 lines (42 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Copyright 2020 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common types in tf.transform."""
from typing import Any, Dict, Iterable, List, Optional, TypeVar, Union
import numpy as np
import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2
from typing_extensions import Literal
# Demonstrational per-row data formats.
#
# Scalar Python values a single feature of one row may hold.
PrimitiveType = Union[
    str,
    bytes,
    float,
    int,
]
# A single feature value within one row. It may be a NumPy array or scalar,
# a primitive, a plain list, or None (the Optional wrapper).
InstanceValueType = Optional[
    Union[
        np.ndarray,
        np.generic,
        PrimitiveType,
        List[Any],
    ]
]
# One row: string keys (presumably feature names) mapped to their values.
InstanceDictType = Dict[str, InstanceValueType]
# TODO(b/185719271): Define BucketBoundariesType at module level of mappers.py.
# Bucket boundaries are either a `tf.Tensor` or any iterable of ints/floats.
BucketBoundariesType = Union[
    tf.Tensor,
    Iterable[Union[int, float]],
]
# Any one of the `tf.io` feature configuration classes: fixed-length,
# variable-length, sparse, or ragged.
FeatureSpecType = Union[
    tf.io.FixedLenFeature,
    tf.io.VarLenFeature,
    tf.io.SparseFeature,
    tf.io.RaggedFeature,
]

# Any one of the per-type domain protos from the TF Metadata schema
# (integer, float, or string domain).
DomainType = Union[
    schema_pb2.IntDomain,
    schema_pb2.FloatDomain,
    schema_pb2.StringDomain,
]
# Any of the three TensorFlow tensor kinds: dense, sparse, or ragged.
TensorType = Union[
    tf.Tensor,
    tf.SparseTensor,
    tf.RaggedTensor,
]

# Constrained TypeVar over the same three tensor kinds. When a signature uses
# it for both a parameter and the return value, a type checker requires both
# to resolve to the same one of the three classes.
ConsistentTensorType = TypeVar(  # pylint: disable=invalid-name
    "ConsistentTensorType",
    tf.Tensor,
    tf.SparseTensor,
    tf.RaggedTensor,
)

# Sparse data given either as a tensor or as the TF1 `SparseTensorValue`.
SparseTensorValueType = Union[
    tf.SparseTensor,
    tf.compat.v1.SparseTensorValue,
]
# Ragged data given either as a tensor or as the TF1 `RaggedTensorValue`.
RaggedTensorValueType = Union[
    tf.RaggedTensor,
    tf.compat.v1.ragged.RaggedTensorValue,
]

# Any tensor-like value: a dense tensor, a NumPy array, or one of the
# sparse/ragged forms defined just above.
TensorValueType = Union[
    tf.Tensor,
    np.ndarray,
    SparseTensorValueType,
    RaggedTensorValueType,
]
# Either a plain `tf.Tensor` or a `tf.saved_model.Asset`.
TemporaryAnalyzerOutputType = Union[
    tf.Tensor,
    tf.saved_model.Asset,
]

# Closed set of accepted vocabulary file format strings.
VocabularyFileFormatType = Literal[
    "text",
    "tfrecord_gzip",
]