@@ -400,79 +400,42 @@ namespace python_multi_array
400
400
{
401
401
throw std::invalid_argument (" nd" );
402
402
}
403
-
404
- size_t s[8 ];
405
- std::fill (std::begin (s), std::end (s), 0 );
403
+ size_t s[N];
406
404
std::copy (This->shape (), This->shape () + N, s);
405
+ size_t ix[N];
406
+ std::fill (ix, ix + N, 0 );
407
+
408
+ size_t boost_strides[N];
409
+ std::copy (This->strides (), This->strides () + N, boost_strides);
407
410
408
- size_t dt[8 ];
409
- std::fill (std::begin (dt), std::end (dt), 0 );
410
- std::copy (This->strides (), This->strides () + N, dt);
411
+ size_t numpy_strides[N];
412
+ std::transform (nd.get_strides (), nd.get_strides () + N, numpy_strides, [](auto input) { return input / sizeof (S); });
411
413
412
- size_t dnd[8 ];
413
- std::fill (std::begin (dnd), std::end (dnd), 0 );
414
- std::transform (nd.get_strides (), nd.get_strides () + N, dnd, [](auto input) { return input / sizeof (S); });
414
+ T* p_boost_origin = This->origin ();
415
+ const S* p_numpy_origin = reinterpret_cast <S*>(nd.get_data ());
415
416
416
- for ( size_t i = 0 ; i < N; ++i )
417
+ while (ix[ 0 ] < s[ 0 ] )
417
418
{
418
- if (nd.get_shape ()[i] == 1 )
419
+ T* p_boost_element = p_boost_origin;
420
+ const S* p_numpy_element = p_numpy_origin;
421
+ for (size_t d = 0 ; d < (N - 1 ); ++d)
419
422
{
420
- dnd[i] = 0 ;
423
+ p_boost_element += ix[d] * boost_strides[d];
424
+ p_numpy_element += ix[d] * numpy_strides[d];
421
425
}
422
- else if (s[i] != nd. get_shape ()[i ])
426
+ while (ix[N - 1 ] < s[N - 1 ])
423
427
{
424
- throw std::invalid_argument (" nd" );
428
+ *p_boost_element = static_cast <T>(*p_numpy_element);
429
+ p_boost_element += boost_strides[N - 1 ];
430
+ p_numpy_element += numpy_strides[N - 1 ];
431
+ ++(ix[N - 1 ]);
425
432
}
426
- }
427
-
428
- T* pt = This->origin ();
429
- const S* pnd = reinterpret_cast <S*>(nd.get_data ());
430
-
431
- for (size_t i0 = 0 ; i0 < s[0 ]; ++i0)
432
- {
433
- T* pt1 = pt + i0 * dt[0 ];
434
- const S* pnd1 = pnd + i0 * dnd[0 ];
435
- *pt1 = static_cast <T>(*pnd1);
436
- for (size_t i1 = 1 ; i1 < s[1 ]; ++i1)
433
+ for (size_t d = N - 1 ; d > 0 ; --d)
437
434
{
438
- T* pt2 = pt1 + i1 * dt[1 ];
439
- const S* pnd2 = pnd1 + i1 * dnd[1 ];
440
- *pt2 = static_cast <T>(*pnd2);
441
- for (size_t i2 = 1 ; i2 < s[2 ]; ++i2)
435
+ if (s[d] <= ix[d])
442
436
{
443
- T* pt3 = pt2 + i2 * dt[2 ];
444
- const S* pnd3 = pnd2 + i2 * dnd[2 ];
445
- *pt3 = static_cast <T>(*pnd3);
446
- for (size_t i3 = 1 ; i3 < s[3 ]; ++i3)
447
- {
448
- T* pt4 = pt3 + i3 * dt[3 ];
449
- const S* pnd4 = pnd3 + i3 * dnd[3 ];
450
- *pt4 = static_cast <T>(*pnd4);
451
- for (size_t i4 = 1 ; i4 < s[4 ]; ++i4)
452
- {
453
- T* pt5 = pt4 + i4 * dt[4 ];
454
- const S* pnd5 = pnd4 + i4 * dnd[4 ];
455
- *pt5 = static_cast <T>(*pnd5);
456
- for (size_t i5 = 1 ; i5 < s[5 ]; ++i5)
457
- {
458
- T* pt6 = pt5 + i5 * dt[5 ];
459
- const S* pnd6 = pnd5 + i5 * dnd[5 ];
460
- *pt6 = static_cast <T>(*pnd6);
461
- for (size_t i6 = 1 ; i6 < s[6 ]; ++i6)
462
- {
463
- T* pt7 = pt6 + i6 * dt[6 ];
464
- const S* pnd7 = pnd6 + i6 * dnd[6 ];
465
- *pt7 = static_cast <T>(*pnd7);
466
- for (size_t i7 = 1 ; i7 < s[7 ]; ++i7)
467
- {
468
- T* pt8 = pt7 + i7 * dt[7 ];
469
- const S* pnd8 = pnd7 + i7 * dnd[7 ];
470
- *pt8 = static_cast <T>(*pnd8);
471
- }
472
- }
473
- }
474
- }
475
- }
437
+ ix[d] = 0 ;
438
+ ++(ix[d - 1 ]);
476
439
}
477
440
}
478
441
}
0 commit comments