1+ #include  " core/op/naive/ops.h" 
2+ #include  " core/test/common/factory.h" 
3+ #include  " core/test/common/utils.h" 
4+ #include  " gtest/gtest.h" 
5+ 6+ using  namespace  nncore ; 
7+ using  namespace  test ; 
8+ using  namespace  opr ; 
9+ using  namespace  opr ::naive; 
10+ 11+ using  F = NDArrayFactory;
12+ using  Param = param::sum;
13+ 14+ TEST (Naive, Sum) {
15+  OpBase* oprs = OpNaiveImpl::get_instance ();
16+ 17+  //  Group 1
18+  Tensor inp1 = F::from_list (
19+  {19 , 25 , 12 , 16 , 22 , 1 , 37 , 3 , 22 , 2 , 27 , 9 , 21 , 29 , 13 , 13 , 10 , 14 ,
20+  33 , 44 , 15 , 6 , 48 , 25 , 15 , 37 , 6 , 30 , 6 , 14 , 26 , 0 , 30 , 4 , 17 , 7 ,
21+  26 , 28 , 28 , 27 , 42 , 39 , 38 , 47 , 25 , 31 , 34 , 25 , 11 , 22 , 26 , 6 , 47 , 35 ,
22+  13 , 46 , 40 , 21 , 39 , 18 , 10 , 10 , 40 , 40 , 13 , 31 , 24 , 11 , 19 , 31 , 5 , 16 },
23+  {3 , 4 , 2 , 3 }, dtype::Int32 ());
24+  Tensor truth1 = F::from_list ({273 , 193 , 306 , 276 , 272 , 302 }, {3 , 1 , 2 , 1 },
25+  dtype::Int32 ());
26+  Param p1 ({false , true , false , true }, true );
27+ 28+  Tensor pred1;
29+  ASSERT_TRUE (oprs->sum (inp1, pred1, p1).is_ok ());
30+  assert_same_data<nn_int32>(pred1, truth1, 0 .0001f );
31+ 32+  //  Group 2
33+  Tensor inp2 = F::from_list (
34+  {36 , 33 , 3 , 25 , 9 , 19 , 18 , 45 , 21 , 41 , 32 , 19 , 23 , 10 , 3 , 17 , 26 , 33 ,
35+  37 , 25 , 44 , 6 , 7 , 27 , 46 , 29 , 42 , 19 , 42 , 42 , 0 , 43 , 4 , 49 , 27 , 22 ,
36+  10 , 2 , 5 , 21 , 47 , 39 , 21 , 32 , 5 , 1 , 43 , 40 , 9 , 30 , 10 , 29 , 36 , 46 ,
37+  47 , 27 , 36 , 39 , 26 , 28 , 5 , 49 , 21 , 45 , 45 , 7 , 44 , 16 , 44 , 44 , 1 , 45 },
38+  {3 , 4 , 2 , 3 }, dtype::Int32 ());
39+  Tensor truth2 = F::from_list ({238 , 267 , 241 , 283 , 128 , 280 , 268 , 214 },
40+  {1 , 4 , 2 , 1 }, dtype::Int32 ());
41+  Param p2 ({true , false , false , true }, true );
42+ 43+  Tensor pred2;
44+  ASSERT_TRUE (oprs->sum (inp2, pred2, p2).is_ok ());
45+  assert_same_data<nn_int32>(pred2, truth2, 0 .0001f );
46+ 47+  //  Group 3
48+  Tensor inp3 = F::from_list (
49+  {18 , 4 , 35 , 22 , 34 , 18 , 15 , 5 , 13 , 49 , 41 , 3 , 4 , 35 , 14 , 22 , 38 , 29 ,
50+  42 , 47 , 30 , 4 , 9 , 0 , 35 , 0 , 43 , 27 , 44 , 29 , 2 , 17 , 16 , 6 , 36 , 41 ,
51+  1 , 22 , 3 , 30 , 35 , 13 , 41 , 45 , 38 , 44 , 38 , 21 , 42 , 41 , 19 , 24 , 40 , 39 ,
52+  37 , 24 , 28 , 17 , 22 , 21 , 15 , 23 , 35 , 16 , 31 , 3 , 29 , 28 , 0 , 15 , 17 , 10 },
53+  {3 , 4 , 2 , 3 }, dtype::Int32 ());
54+  Tensor truth3 = F::from_list ({1734 }, {1 , 1 , 1 , 1 }, dtype::Int32 ());
55+  Param p3 ({true , true , true , true }, true );
56+ 57+  Tensor pred3;
58+  ASSERT_TRUE (oprs->sum (inp3, pred3, p3).is_ok ());
59+  assert_same_data<nn_int32>(pred3, truth3, 0 .0001f );
60+ }
0 commit comments