10#ifndef XTENSOR_STRIDES_HPP
11#define XTENSOR_STRIDES_HPP
18#include <xtl/xsequence.hpp>
20#include "../core/xshape.hpp"
21#include "../core/xtensor_config.hpp"
22#include "../core/xtensor_forward.hpp"
23#include "../utils/xexception.hpp"
28 template <
class shape_type>
29 std::size_t compute_size(
const shape_type& shape)
noexcept;
39 template <
class offset_type,
class S>
40 offset_type data_offset(
const S&
strides)
noexcept;
66 template <
class offset_type,
class S,
class Arg,
class... Args>
67 offset_type data_offset(
const S&
strides, Arg
arg, Args... args)
noexcept;
70 offset_type unchecked_data_offset(
const S&
strides, Args... args)
noexcept;
72 template <
class offset_type,
class S,
class It>
73 offset_type element_offset(
const S&
strides, It first, It last)
noexcept;
88 template <layout_type L = layout_type::dynamic,
class shape_type,
class str
ides_type>
91 template <layout_type L = layout_type::dynamic,
class shape_type,
class str
ides_type,
class backstr
ides_type>
95 template <
class shape_type,
class str
ides_type>
96 void adapt_strides(
const shape_type& shape, strides_type&
strides)
noexcept;
98 template <
class shape_type,
class str
ides_type,
class backstr
ides_type>
99 void adapt_strides(
const shape_type& shape, strides_type&
strides, backstrides_type& backstrides)
noexcept;
112 template <
class S,
class T>
113 std::vector<get_strides_t<S>>
120 template <
class S,
class size_type>
121 S uninitialized_shape(size_type size);
123 template <
class S1,
class S2>
124 bool broadcast_shape(
const S1& input, S2& output);
126 template <
class S1,
class S2>
127 bool broadcastable(
const S1& s1, S2& s2);
133 template <layout_type L>
148 template <
class S,
class... Args>
163 template <
class S,
class... Args>
170 template <
class C,
class It,
class size_type>
171 It strided_data_end(
const C& c, It begin,
layout_type l, size_type offset)
173 using difference_type =
typename std::iterator_traits<It>::difference_type;
174 if (c.size() == 0 || std::find(c.shape().cbegin(), c.shape().cend(), size_type(0)) != c.shape().cend())
178 if (c.dimension() == 0)
184 for (std::size_t i = 0; i != c.dimension(); ++i)
186 begin += c.strides()[i] * difference_type(c.shape()[i] - 1);
190 begin += c.strides().back();
196 begin += c.strides().front();
209 template <
class return_type,
class S,
class T,
class D>
210 inline return_type compute_stride_impl(
layout_type layout,
const S& shape, T axis, D default_stride)
214 return std::accumulate(
215 shape.cbegin() + axis + 1,
217 static_cast<return_type
>(1),
218 std::multiplies<return_type>()
223 return std::accumulate(
225 shape.cbegin() + axis,
226 static_cast<return_type
>(1),
227 std::multiplies<return_type>()
230 return default_stride;
256 using strides_type =
typename E::strides_type;
257 using return_type =
typename strides_type::value_type;
258 strides_type ret = e.strides();
259 auto shape = e.shape();
266 for (std::size_t i = 0; i < ret.size(); ++i)
270 ret[i] = detail::compute_stride_impl<return_type>(e.layout(), shape, i, ret[i]);
276 return_type f =
static_cast<return_type
>(
sizeof(
typename E::value_type));
302 using strides_type =
typename E::strides_type;
303 using return_type =
typename strides_type::value_type;
305 return_type ret = e.strides()[axis];
314 if (e.shape(axis) == 1)
316 ret = detail::compute_stride_impl<return_type>(e.layout(), e.shape(), axis, ret);
322 return_type f =
static_cast<return_type
>(
sizeof(
typename E::value_type));
335 template <
class shape_type>
336 inline std::size_t compute_size_impl(
const shape_type& shape, std::true_type )
338 using size_type = std::decay_t<typename shape_type::value_type>;
339 return static_cast<std::size_t
>(std::abs(
340 std::accumulate(shape.cbegin(), shape.cend(), size_type(1), std::multiplies<size_type>())
344 template <
class shape_type>
345 inline std::size_t compute_size_impl(
const shape_type& shape, std::false_type )
347 using size_type = std::decay_t<typename shape_type::value_type>;
348 return static_cast<std::size_t
>(
349 std::accumulate(shape.cbegin(), shape.cend(), size_type(1), std::multiplies<size_type>())
354 template <
class shape_type>
355 inline std::size_t compute_size(
const shape_type& shape)
noexcept
357 return detail::compute_size_impl(
359 xtl::is_signed<std::decay_t<
typename std::decay_t<shape_type>::value_type>>()
366 template <std::
size_t dim,
class S>
367 inline auto raw_data_offset(
const S&)
noexcept
369 using strides_value_type = std::decay_t<decltype(std::declval<S>()[0])>;
370 return strides_value_type(0);
373 template <std::
size_t dim,
class S>
374 inline auto raw_data_offset(
const S&, missing_type)
noexcept
376 using strides_value_type = std::decay_t<decltype(std::declval<S>()[0])>;
377 return strides_value_type(0);
380 template <std::size_t dim,
class S,
class Arg,
class... Args>
381 inline auto raw_data_offset(
const S&
strides, Arg
arg, Args... args)
noexcept
383 return static_cast<std::ptrdiff_t
>(
arg) *
strides[dim] + raw_data_offset<dim + 1>(
strides, args...);
386 template <layout_type L, std::ptrdiff_t static_dim>
387 struct layout_data_offset
389 template <std::size_t dim,
class S,
class Arg,
class... Args>
390 inline static auto run(
const S&
strides, Arg
arg, Args... args)
noexcept
392 return raw_data_offset<dim>(
strides,
arg, args...);
396 template <std::ptrdiff_t static_dim>
399 using self_type = layout_data_offset<layout_type::row_major, static_dim>;
401 template <std::
size_t dim,
class S,
class Arg>
402 inline static auto run(
const S&
strides, Arg
arg)
noexcept
404 if (std::ptrdiff_t(dim) + 1 == static_dim)
414 template <std::size_t dim,
class S,
class Arg,
class... Args>
415 inline static auto run(
const S&
strides, Arg
arg, Args... args)
noexcept
421 template <std::ptrdiff_t static_dim>
424 using self_type = layout_data_offset<layout_type::column_major, static_dim>;
426 template <std::
size_t dim,
class S,
class Arg>
427 inline static auto run(
const S&
strides, Arg
arg)
noexcept
439 template <std::size_t dim,
class S,
class Arg,
class... Args>
440 inline static auto run(
const S&
strides, Arg
arg, Args... args)
noexcept
444 return arg + self_type::template run<dim + 1>(
strides, args...);
454 template <
class offset_type,
class S>
455 inline offset_type data_offset(
const S&)
noexcept
457 return offset_type(0);
460 template <
class offset_type,
class S,
class Arg,
class... Args>
461 inline offset_type data_offset(
const S&
strides, Arg
arg, Args... args)
noexcept
463 constexpr std::size_t nargs =
sizeof...(Args) + 1;
467 return static_cast<offset_type
>(detail::raw_data_offset<0>(
strides,
arg, args...));
469 else if (nargs >
strides.size())
472 return data_offset<offset_type, S>(
strides, args...);
474 else if (detail::last_type_is_missing<Args...>)
477 return static_cast<offset_type
>(detail::raw_data_offset<0>(
strides,
arg, args...));
483 return static_cast<offset_type
>(detail::raw_data_offset<0>(
view,
arg, args...));
487 template <
class offset_type,
layout_type L,
class S,
class... Args>
488 inline offset_type unchecked_data_offset(
const S&
strides, Args... args)
noexcept
490 return static_cast<offset_type
>(
491 detail::layout_data_offset<L, static_dimension<S>::value>::template run<0>(
strides.cbegin(), args...)
495 template <
class offset_type,
class S,
class It>
496 inline offset_type element_offset(
const S&
strides, It first, It last)
noexcept
498 using difference_type =
typename std::iterator_traits<It>::difference_type;
499 auto size =
static_cast<difference_type
>(
500 (std::min)(
static_cast<typename S::size_type
>(std::distance(first, last)),
strides.size())
502 return std::inner_product(last - size, last,
strides.cend() - size, offset_type(0));
507 template <
class shape_type,
class str
ides_type,
class bs_ptr>
508 inline void adapt_strides(
509 const shape_type& shape,
512 typename strides_type::size_type i
519 (*backstrides)[i] =
strides[i] * std::ptrdiff_t(shape[i] - 1);
522 template <
class shape_type,
class str
ides_type>
523 inline void adapt_strides(
524 const shape_type& shape,
527 typename strides_type::size_type i
536 template <layout_type L,
class shape_type,
class str
ides_type,
class bs_ptr>
538 compute_strides(
const shape_type& shape,
layout_type l, strides_type&
strides, bs_ptr bs)
540 using strides_value_type =
typename std::decay_t<strides_type>::value_type;
541 strides_value_type data_size = 1;
543#if defined(_MSC_VER) && (1931 <= _MSC_VER)
545 if (0 == shape.size())
547 return static_cast<std::size_t
>(data_size);
553 for (std::size_t i = shape.size(); i != 0; --i)
556 data_size =
strides[i - 1] *
static_cast<strides_value_type
>(shape[i - 1]);
557 adapt_strides(shape,
strides, bs, i - 1);
562 for (std::size_t i = 0; i < shape.size(); ++i)
565 data_size =
strides[i] *
static_cast<strides_value_type
>(shape[i]);
566 adapt_strides(shape,
strides, bs, i);
569 return static_cast<std::size_t
>(data_size);
573 template <layout_type L,
class shape_type,
class str
ides_type>
576 return detail::compute_strides<L>(shape, l,
strides,
nullptr);
579 template <layout_type L,
class shape_type,
class str
ides_type,
class backstr
ides_type>
583 return detail::compute_strides<L>(shape, l,
strides, &backstrides);
586 template <
class T1,
class T2>
588 stride_match_condition(
const T1& stride,
const T2& shape,
const T1& data_size,
bool zero_strides)
590 return (shape == T2(1) && stride == T1(0) && zero_strides) || (stride == data_size);
594 template <
class shape_type,
class str
ides_type>
596 do_strides_match(
const shape_type& shape,
const strides_type&
strides,
layout_type l,
bool zero_strides)
598 using value_type =
typename strides_type::value_type;
599 value_type data_size = 1;
602 for (std::size_t i =
strides.size(); i != 0; --i)
604 if (!stride_match_condition(
strides[i - 1], shape[i - 1], data_size, zero_strides))
608 data_size *=
static_cast<value_type
>(shape[i - 1]);
614 for (std::size_t i = 0; i <
strides.size(); ++i)
616 if (!stride_match_condition(
strides[i], shape[i], data_size, zero_strides))
620 data_size *=
static_cast<value_type
>(shape[i]);
630 template <
class shape_type,
class str
ides_type>
631 inline void adapt_strides(
const shape_type& shape, strides_type&
strides)
noexcept
633 for (
typename shape_type::size_type i = 0; i < shape.size(); ++i)
635 detail::adapt_strides(shape,
strides,
nullptr, i);
639 template <
class shape_type,
class str
ides_type,
class backstr
ides_type>
641 adapt_strides(
const shape_type& shape, strides_type&
strides, backstrides_type& backstrides)
noexcept
643 for (
typename shape_type::size_type i = 0; i < shape.size(); ++i)
645 detail::adapt_strides(shape,
strides, &backstrides, i);
652 inline S unravel_noexcept(
typename S::value_type idx,
const S&
strides,
layout_type l)
noexcept
654 using value_type =
typename S::value_type;
655 using size_type =
typename S::size_type;
656 S result = xtl::make_sequence<S>(
strides.size(), 0);
659 for (size_type i = 0; i <
strides.size(); ++i)
662 value_type quot = str != 0 ? idx / str : 0;
663 idx = str != 0 ? idx % str : idx;
669 for (size_type i =
strides.size(); i != 0; --i)
671 value_type str =
strides[i - 1];
672 value_type quot = str != 0 ? idx / str : 0;
673 idx = str != 0 ? idx % str : idx;
674 result[i - 1] = quot;
682 inline S unravel_from_strides(
typename S::value_type index,
const S&
strides,
layout_type l)
686 XTENSOR_THROW(std::runtime_error,
"unravel_index: dynamic layout not supported");
688 return detail::unravel_noexcept(index,
strides, l);
691 template <
class S,
class T>
692 inline get_value_type_t<T> ravel_from_strides(
const T& index,
const S&
strides)
694 return element_offset<get_value_type_t<T>>(
strides, index.begin(), index.end());
698 inline get_strides_t<S> unravel_index(
typename S::value_type index,
const S& shape,
layout_type l)
700 using strides_type = get_strides_t<S>;
701 using strides_value_type =
typename strides_type::value_type;
702 strides_type
strides = xtl::make_sequence<strides_type>(shape.size(), 0);
704 return unravel_from_strides(
static_cast<strides_value_type
>(index),
strides, l);
707 template <
class S,
class T>
708 inline std::vector<get_strides_t<S>> unravel_indices(
const T& idx,
const S& shape,
layout_type l)
710 using strides_type = get_strides_t<S>;
711 using strides_value_type =
typename strides_type::value_type;
712 strides_type
strides = xtl::make_sequence<strides_type>(shape.size(), 0);
714 std::vector<get_strides_t<S>> out(idx.size());
715 auto out_iter = out.begin();
716 auto idx_iter = idx.begin();
717 for (; out_iter != out.end(); ++out_iter, ++idx_iter)
719 *out_iter = unravel_from_strides(
static_cast<strides_value_type
>(*idx_iter),
strides, l);
724 template <
class S,
class T>
725 inline get_value_type_t<T> ravel_index(
const T& index,
const S& shape,
layout_type l)
727 using strides_type = get_strides_t<S>;
728 strides_type
strides = xtl::make_sequence<strides_type>(shape.size(), 0);
730 return ravel_from_strides(index,
strides);
733 template <
class S,
class stype>
734 inline S uninitialized_shape(stype size)
736 using value_type =
typename S::value_type;
737 using size_type =
typename S::size_type;
738 return xtl::make_sequence<S>(
static_cast<size_type
>(size), std::numeric_limits<value_type>::max());
741 template <
class S1,
class S2>
742 inline bool broadcast_shape(
const S1& input, S2& output)
744 bool trivial_broadcast = (input.size() == output.size());
746 using value_type =
typename S2::value_type;
747 auto output_index = output.size();
748 auto input_index = input.size();
750 if (output_index < input_index)
752 throw_broadcast_error(output, input);
754 for (; input_index != 0; --input_index, --output_index)
759 if (output[output_index - 1] == std::numeric_limits<value_type>::max())
761 output[output_index - 1] =
static_cast<value_type
>(input[input_index - 1]);
765 else if (output[output_index - 1] == 1)
767 output[output_index - 1] =
static_cast<value_type
>(input[input_index - 1]);
768 trivial_broadcast = trivial_broadcast && (input[input_index - 1] == 1);
772 else if (input[input_index - 1] == 1)
774 trivial_broadcast =
false;
778 else if (
static_cast<value_type
>(input[input_index - 1]) != output[output_index - 1])
780 throw_broadcast_error(output, input);
783 return trivial_broadcast;
786 template <
class S1,
class S2>
787 inline bool broadcastable(
const S1& src_shape,
const S2& dst_shape)
789 auto src_iter = src_shape.crbegin();
790 auto dst_iter = dst_shape.crbegin();
791 bool res = dst_shape.size() >= src_shape.size();
792 for (; src_iter != src_shape.crend() && res; ++src_iter, ++dst_iter)
794 res = (
static_cast<std::size_t
>(*src_iter) ==
static_cast<std::size_t
>(*dst_iter))
803 template <
class S1,
class S2>
804 static std::size_t get(
const S1& s1,
const S2& s2)
806 using value_type =
typename S1::value_type;
808 auto s1_index = s1.size();
809 auto s2_index = s2.size();
811 for (; s2_index != 0; --s1_index, --s2_index)
813 if (
static_cast<value_type
>(s1[s1_index - 1]) !=
static_cast<value_type
>(s2[s2_index - 1]))
825 template <
class S1,
class S2>
826 static std::size_t get(
const S1& s1,
const S2& s2)
829 using size_type =
typename S1::size_type;
830 using value_type =
typename S1::value_type;
835 if (s1.size() != s2.size())
840 auto size = s2.size();
842 for (; index < size; ++index)
844 if (
static_cast<value_type
>(s1[index]) !=
static_cast<value_type
>(s2[index]))
855 template <
class S, std::
size_t dim>
856 inline bool check_in_bounds_impl(
const S&)
861 template <
class S, std::
size_t dim>
862 inline bool check_in_bounds_impl(
const S&, missing_type)
867 template <
class S, std::size_t dim,
class T,
class... Args>
868 inline bool check_in_bounds_impl(
const S& shape, T&
arg, Args&... args)
870 if (
sizeof...(Args) + 1 > shape.size())
872 return check_in_bounds_impl<S, dim>(shape, args...);
877 && check_in_bounds_impl<S, dim + 1>(shape, args...);
882 template <
class S,
class... Args>
883 inline bool check_in_bounds(
const S& shape, Args&... args)
885 return detail::check_in_bounds_impl<S, 0>(shape, args...);
890 template <
class S, std::
size_t dim>
891 inline void normalize_periodic_impl(
const S&)
895 template <
class S, std::
size_t dim>
896 inline void normalize_periodic_impl(
const S&, missing_type)
900 template <
class S, std::size_t dim,
class T,
class... Args>
901 inline void normalize_periodic_impl(
const S& shape, T&
arg, Args&... args)
903 if (
sizeof...(Args) + 1 > shape.size())
905 normalize_periodic_impl<S, dim>(shape, args...);
909 T n =
static_cast<T
>(shape[dim]);
910 arg = (n + (
arg % n)) % n;
911 normalize_periodic_impl<S, dim + 1>(shape, args...);
916 template <
class S,
class... Args>
919 check_dimension(shape, args...);
920 detail::normalize_periodic_impl<S, 0>(shape, args...);
auto arg(E &&e) noexcept
Calculates the phase angle (in radians) elementwise for the complex numbers in e.
std::size_t compute_strides(const shape_type &shape, layout_type l, strides_type &strides)
Compute the strides given the shape and the layout of an array.
auto strides(const E &e, stride_type type=stride_type::normal) noexcept
Get strides of an object.
stride_type
Choose stride type.
void normalize_periodic(const S &shape, Args &... args)
Normalise an index of a periodic array.
@ bytes
Normal stride in bytes.
@ internal
As used internally (with stride(axis) == 0 if shape(axis) == 1)
@ normal
Normal stride corresponding to storage.
standard mathematical functions for xexpressions
bool in_bounds(const S &shape, Args &... args)
Check if the index is within the bounds of the array.
auto view(E &&e, S &&... slices)
Constructs and returns a view on the specified xexpression.