Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit b455557

Browse files
author
Rue Yokaze
committed
#1 fix set function
1 parent 0640cc0 commit b455557

File tree

1 file changed

+25
-62
lines changed

1 file changed

+25
-62
lines changed

‎python_multi_array.cpp

Lines changed: 25 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -400,79 +400,42 @@ namespace python_multi_array
400400
{
401401
throw std::invalid_argument("nd");
402402
}
403-
404-
size_t s[8];
405-
std::fill(std::begin(s), std::end(s), 0);
403+
size_t s[N];
406404
std::copy(This->shape(), This->shape() + N, s);
405+
size_t ix[N];
406+
std::fill(ix, ix + N, 0);
407+
408+
size_t boost_strides[N];
409+
std::copy(This->strides(), This->strides() + N, boost_strides);
407410

408-
size_t dt[8];
409-
std::fill(std::begin(dt), std::end(dt), 0);
410-
std::copy(This->strides(), This->strides() + N, dt);
411+
size_t numpy_strides[N];
412+
std::transform(nd.get_strides(), nd.get_strides() + N, numpy_strides, [](auto input) { return input / sizeof(S); });
411413

412-
size_t dnd[8];
413-
std::fill(std::begin(dnd), std::end(dnd), 0);
414-
std::transform(nd.get_strides(), nd.get_strides() + N, dnd, [](auto input) { return input / sizeof(S); });
414+
T* p_boost_origin = This->origin();
415+
const S* p_numpy_origin = reinterpret_cast<S*>(nd.get_data());
415416

416-
for (size_t i = 0; i < N; ++i)
417+
while (ix[0] < s[0])
417418
{
418-
if (nd.get_shape()[i] == 1)
419+
T* p_boost_element = p_boost_origin;
420+
const S* p_numpy_element = p_numpy_origin;
421+
for (size_t d = 0; d < (N - 1); ++d)
419422
{
420-
dnd[i] = 0;
423+
p_boost_element += ix[d] * boost_strides[d];
424+
p_numpy_element += ix[d] * numpy_strides[d];
421425
}
422-
elseif (s[i] != nd.get_shape()[i])
426+
while (ix[N - 1] < s[N - 1])
423427
{
424-
throw std::invalid_argument("nd");
428+
*p_boost_element = static_cast<T>(*p_numpy_element);
429+
p_boost_element += boost_strides[N - 1];
430+
p_numpy_element += numpy_strides[N - 1];
431+
++(ix[N - 1]);
425432
}
426-
}
427-
428-
T* pt = This->origin();
429-
const S* pnd = reinterpret_cast<S*>(nd.get_data());
430-
431-
for (size_t i0 = 0; i0 < s[0]; ++i0)
432-
{
433-
T* pt1 = pt + i0 * dt[0];
434-
const S* pnd1 = pnd + i0 * dnd[0];
435-
*pt1 = static_cast<T>(*pnd1);
436-
for (size_t i1 = 1; i1 < s[1]; ++i1)
433+
for (size_t d = N - 1; d > 0; --d)
437434
{
438-
T* pt2 = pt1 + i1 * dt[1];
439-
const S* pnd2 = pnd1 + i1 * dnd[1];
440-
*pt2 = static_cast<T>(*pnd2);
441-
for (size_t i2 = 1; i2 < s[2]; ++i2)
435+
if (s[d] <= ix[d])
442436
{
443-
T* pt3 = pt2 + i2 * dt[2];
444-
const S* pnd3 = pnd2 + i2 * dnd[2];
445-
*pt3 = static_cast<T>(*pnd3);
446-
for (size_t i3 = 1; i3 < s[3]; ++i3)
447-
{
448-
T* pt4 = pt3 + i3 * dt[3];
449-
const S* pnd4 = pnd3 + i3 * dnd[3];
450-
*pt4 = static_cast<T>(*pnd4);
451-
for (size_t i4 = 1; i4 < s[4]; ++i4)
452-
{
453-
T* pt5 = pt4 + i4 * dt[4];
454-
const S* pnd5 = pnd4 + i4 * dnd[4];
455-
*pt5 = static_cast<T>(*pnd5);
456-
for (size_t i5 = 1; i5 < s[5]; ++i5)
457-
{
458-
T* pt6 = pt5 + i5 * dt[5];
459-
const S* pnd6 = pnd5 + i5 * dnd[5];
460-
*pt6 = static_cast<T>(*pnd6);
461-
for (size_t i6 = 1; i6 < s[6]; ++i6)
462-
{
463-
T* pt7 = pt6 + i6 * dt[6];
464-
const S* pnd7 = pnd6 + i6 * dnd[6];
465-
*pt7 = static_cast<T>(*pnd7);
466-
for (size_t i7 = 1; i7 < s[7]; ++i7)
467-
{
468-
T* pt8 = pt7 + i7 * dt[7];
469-
const S* pnd8 = pnd7 + i7 * dnd[7];
470-
*pt8 = static_cast<T>(*pnd8);
471-
}
472-
}
473-
}
474-
}
475-
}
437+
ix[d] = 0;
438+
++(ix[d - 1]);
476439
}
477440
}
478441
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /