Skip to content
Merged
16 changes: 8 additions & 8 deletions docarray/index/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from docarray.array.any_array import AnyDocArray
from docarray.typing import ID, AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal._typing import is_tensor_union
from docarray.utils._internal._typing import is_tensor_union, safe_issubclass
from docarray.utils._internal.misc import import_library
from docarray.utils.find import (
FindResult,
Expand Down Expand Up @@ -390,7 +390,7 @@ def __delitem__(self, key: Union[str, Sequence[str]]):
for field_name, type_, _ in self._flatten_schema(
cast(Type[BaseDoc], self._schema)
):
if issubclass(type_, AnyDocArray):
if safe_issubclass(type_, AnyDocArray):
for doc_id in key:
nested_docs_id = self._subindices[field_name]._filter_by_parent_id(
doc_id
Expand Down Expand Up @@ -776,7 +776,7 @@ def _update_subindex_data(
for field_name, type_, _ in self._flatten_schema(
cast(Type[BaseDoc], self._schema)
):
if issubclass(type_, AnyDocArray):
if safe_issubclass(type_, AnyDocArray):
for doc in docs:
_list = getattr(doc, field_name)
for i, nested_doc in enumerate(_list):
Expand Down Expand Up @@ -857,11 +857,11 @@ def _flatten_schema(
raise ValueError(
f'Union type {t_} is not supported. Only Union of subclasses of AbstractTensor or Union[type, None] are supported.'
)
elif issubclass(t_, BaseDoc):
elif safe_issubclass(t_, BaseDoc):
names_types_fields.extend(
cls._flatten_schema(t_, name_prefix=inner_prefix)
)
elif issubclass(t_, AbstractTensor):
elif safe_issubclass(t_, AbstractTensor):
names_types_fields.append(
(name_prefix + field_name, AbstractTensor, field_)
)
Expand All @@ -879,7 +879,7 @@ def _create_column_infos(self, schema: Type[BaseDoc]) -> Dict[str, _ColumnInfo]:
column_infos: Dict[str, _ColumnInfo] = dict()
for field_name, type_, field_ in self._flatten_schema(schema):
# Union types are handle in _flatten_schema
if issubclass(type_, AnyDocArray):
if safe_issubclass(type_, AnyDocArray):
column_infos[field_name] = _ColumnInfo(
docarray_type=type_, db_type=None, config=dict(), n_dim=None
)
Expand Down Expand Up @@ -921,7 +921,7 @@ def _init_subindex(
):
"""Initialize subindices if any column is subclass of AnyDocArray."""
for col_name, col in self._column_infos.items():
if issubclass(col.docarray_type, AnyDocArray):
if safe_issubclass(col.docarray_type, AnyDocArray):
sub_db_config = copy.deepcopy(self._db_config)
sub_db_config.index_name = f'{self.index_name}__{col_name}'
self._subindices[col_name] = self.__class__[col.docarray_type.doc_type]( # type: ignore
Expand Down Expand Up @@ -1087,7 +1087,7 @@ def _index_subindex(self, column_to_data: Dict[str, Generator[Any, None, None]])
:param column_to_data: A dictionary from column name to a generator
"""
for col_name, col in self._column_infos.items():
if issubclass(col.docarray_type, AnyDocArray):
if safe_issubclass(col.docarray_type, AnyDocArray):
docs = [
doc for doc_list in column_to_data[col_name] for doc in doc_list
]
Expand Down
7 changes: 4 additions & 3 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.ndarray import NdArray
from docarray.utils._internal.misc import import_library, is_np_int
from docarray.utils._internal._typing import safe_issubclass
from docarray.utils.find import _FindResult, _FindResultBatched

if TYPE_CHECKING:
Expand Down Expand Up @@ -98,7 +99,7 @@ def __init__(self, db_config=None, **kwargs):
}
self._hnsw_indices = {}
for col_name, col in self._column_infos.items():
if issubclass(col.docarray_type, AnyDocArray):
if safe_issubclass(col.docarray_type, AnyDocArray):
continue
if not col.config:
# non-tensor type; don't create an index
Expand Down Expand Up @@ -200,7 +201,7 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
or None if ``python_type`` is not supported.
"""
for allowed_type in HNSWLIB_PY_VEC_TYPES:
if issubclass(python_type, allowed_type):
if safe_issubclass(python_type, allowed_type):
return np.ndarray

return None # all types allowed, but no db type needed
Expand Down Expand Up @@ -350,7 +351,7 @@ def _del_items(self, doc_ids: Sequence[str]):
for field_name, type_, _ in self._flatten_schema(
cast(Type[BaseDoc], self._schema)
):
if issubclass(type_, AnyDocArray):
if safe_issubclass(type_, AnyDocArray):
for id in doc_ids:
doc = self.__getitem__(id)
sub_ids = [sub_doc.id for sub_doc in getattr(doc, field_name)]
Expand Down
18 changes: 17 additions & 1 deletion docarray/utils/_internal/_typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Optional

from typing_inspect import get_args, is_union_type
from typing_extensions import get_origin
from typing_inspect import get_args, is_typevar, is_union_type

from docarray.typing.tensor.abstract_tensor import AbstractTensor

Expand Down Expand Up @@ -32,3 +33,18 @@ def change_cls_name(cls: type, new_name: str, scope: Optional[dict] = None) -> N
scope[new_name] = cls
cls.__qualname__ = cls.__qualname__[: -len(cls.__name__)] + new_name
cls.__name__ = new_name


def safe_issubclass(x: type, a_tuple: type) -> bool:
"""
This is a modified version of the built-in 'issubclass' function to support non-class input.
Traditional 'issubclass' calls can result in a crash if the input is non-class type (e.g. list/tuple).

:param x: A class 'x'
:param a_tuple: A class, or a tuple of classes.
:return: A boolean value - 'True' if 'x' is a subclass of 'A_tuple', 'False' otherwise.
Note that if the origin of 'x' is a list or tuple, the function immediately returns 'False'.
"""
if (get_origin(x) in (list, tuple, dict, set)) or is_typevar(x):
return False
return issubclass(x, a_tuple)
30 changes: 30 additions & 0 deletions tests/index/hnswlib/test_index_get_del.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,36 @@ class TfDoc(BaseDoc):
assert index.get_current_count() == 10


def test_index_lst_str(tmp_path):
from typing import List

class ListDoc(BaseDoc):
list_str: List[str]

docs = [ListDoc(list_str=[str(i) for i in range(10)]) for _ in range(10)]
assert isinstance(docs[0].list_str, List)

index = HnswDocumentIndex[ListDoc](work_dir=str(tmp_path))
index.index(docs)
assert index.num_docs() == 10
for index in index._hnsw_indices.values():
assert index.get_current_count() == 10


def test_index_typevar(tmp_path):
from typing import TypeVar

T = TypeVar("T")

class TypeDoc(BaseDoc):
list_str: T

index = HnswDocumentIndex[TypeDoc](work_dir=str(tmp_path))
docs = [TypeDoc(list_str=10) for _ in range(10)]
index.index(docs)
assert index.num_docs() == 10


def test_index_builtin_docs(tmp_path):
# TextDoc
class TextSchema(TextDoc):
Expand Down
22 changes: 20 additions & 2 deletions tests/units/util/test_typing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Set, Tuple, Union

import pytest

from docarray.typing import NdArray, TorchTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal._typing import is_tensor_union, is_type_tensor
from docarray.utils._internal._typing import (
is_tensor_union,
is_type_tensor,
safe_issubclass,
)
from docarray.utils._internal.misc import is_tf_available

tf_available = is_tf_available()
Expand Down Expand Up @@ -73,3 +77,17 @@ def test_is_union_type_tensor(type_, is_union_tensor):
)
def test_is_union_type_tensor_with_tf(type_, is_union_tensor):
assert is_tensor_union(type_) == is_union_tensor


@pytest.mark.parametrize(
'type_, cls, is_subclass',
[
(List[str], object, False),
(List[List[int]], object, False),
(Set[str], object, False),
(Dict, object, False),
(Tuple[int, int], object, False),
],
)
def test_safe_issubclass(type_, cls, is_subclass):
assert safe_issubclass(type_, cls) == is_subclass