Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit dfb2b4e

Browse files
feat(csharp/Tensor.NET): add sum method and api.
1 parent 0a1209a commit dfb2b4e

File tree

8 files changed

+258
-0
lines changed

8 files changed

+258
-0
lines changed

‎apis/numnet_c_cxx_apis.cpp‎

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,24 @@ Status *Onehot(NativeTensor *inp, NativeTensor *oup, param::onehot *param,
306306
}
307307
}
308308

309+
Status *Sum(NativeTensor *inp, NativeTensor *oup, param::sum *param,
310+
ProviderEnum provider) {
311+
Tensor t_inp, t_oup;
312+
inp->ToTensor(t_inp, false);
313+
oup->ToTensor(t_oup, true);
314+
OpBase *impl = GetImpl(provider);
315+
if (impl == nullptr) {
316+
return new Status(StatusCategory::NUMNET, StatusCode::INVALID_ARGUMENT,
317+
"Unsupported provider.");
318+
}
319+
auto status = impl->sum(t_inp, t_oup, *param);
320+
if (status.is_ok()) {
321+
return nullptr;
322+
} else {
323+
return new Status(status);
324+
}
325+
}
326+
309327
Status *Argmxx(NativeTensor *inp, NativeTensor *oup, param::argmxx *param,
310328
ProviderEnum provider) {
311329
Tensor t_inp, t_oup;

‎apis/numnet_c_cxx_apis.h‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ Status *Sort(NativeTensor *inp, NativeTensor *oup, param::sort *param,
7878
Status *Onehot(NativeTensor *inp, NativeTensor *oup, param::onehot *param,
7979
ProviderEnum provider);
8080

81+
Status *Sum(NativeTensor *inp, NativeTensor *oup, param::sum *param,
82+
ProviderEnum provider);
83+
8184
Status *Argmxx(NativeTensor *inp, NativeTensor *oup, param::argmxx *param,
8285
ProviderEnum provider);
8386

‎core/test/naive/sum.cpp‎

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#include "core/op/naive/ops.h"
2+
#include "core/test/common/factory.h"
3+
#include "core/test/common/utils.h"
4+
#include "gtest/gtest.h"
5+
6+
using namespace nncore;
7+
using namespace test;
8+
using namespace opr;
9+
using namespace opr::naive;
10+
11+
using F = NDArrayFactory;
12+
using Param = param::sum;
13+
14+
TEST(Naive, Sum) {
15+
OpBase* oprs = OpNaiveImpl::get_instance();
16+
17+
// Group 1
18+
Tensor inp1 = F::from_list(
19+
{19, 25, 12, 16, 22, 1, 37, 3, 22, 2, 27, 9, 21, 29, 13, 13, 10, 14,
20+
33, 44, 15, 6, 48, 25, 15, 37, 6, 30, 6, 14, 26, 0, 30, 4, 17, 7,
21+
26, 28, 28, 27, 42, 39, 38, 47, 25, 31, 34, 25, 11, 22, 26, 6, 47, 35,
22+
13, 46, 40, 21, 39, 18, 10, 10, 40, 40, 13, 31, 24, 11, 19, 31, 5, 16},
23+
{3, 4, 2, 3}, dtype::Int32());
24+
Tensor truth1 = F::from_list({273, 193, 306, 276, 272, 302}, {3, 1, 2, 1},
25+
dtype::Int32());
26+
Param p1({false, true, false, true}, true);
27+
28+
Tensor pred1;
29+
ASSERT_TRUE(oprs->sum(inp1, pred1, p1).is_ok());
30+
assert_same_data<nn_int32>(pred1, truth1, 0.0001f);
31+
32+
// Group 2
33+
Tensor inp2 = F::from_list(
34+
{36, 33, 3, 25, 9, 19, 18, 45, 21, 41, 32, 19, 23, 10, 3, 17, 26, 33,
35+
37, 25, 44, 6, 7, 27, 46, 29, 42, 19, 42, 42, 0, 43, 4, 49, 27, 22,
36+
10, 2, 5, 21, 47, 39, 21, 32, 5, 1, 43, 40, 9, 30, 10, 29, 36, 46,
37+
47, 27, 36, 39, 26, 28, 5, 49, 21, 45, 45, 7, 44, 16, 44, 44, 1, 45},
38+
{3, 4, 2, 3}, dtype::Int32());
39+
Tensor truth2 = F::from_list({238, 267, 241, 283, 128, 280, 268, 214},
40+
{1, 4, 2, 1}, dtype::Int32());
41+
Param p2({true, false, false, true}, true);
42+
43+
Tensor pred2;
44+
ASSERT_TRUE(oprs->sum(inp2, pred2, p2).is_ok());
45+
assert_same_data<nn_int32>(pred2, truth2, 0.0001f);
46+
47+
// Group 3
48+
Tensor inp3 = F::from_list(
49+
{18, 4, 35, 22, 34, 18, 15, 5, 13, 49, 41, 3, 4, 35, 14, 22, 38, 29,
50+
42, 47, 30, 4, 9, 0, 35, 0, 43, 27, 44, 29, 2, 17, 16, 6, 36, 41,
51+
1, 22, 3, 30, 35, 13, 41, 45, 38, 44, 38, 21, 42, 41, 19, 24, 40, 39,
52+
37, 24, 28, 17, 22, 21, 15, 23, 35, 16, 31, 3, 29, 28, 0, 15, 17, 10},
53+
{3, 4, 2, 3}, dtype::Int32());
54+
Tensor truth3 = F::from_list({1734}, {1, 1, 1, 1}, dtype::Int32());
55+
Param p3({true, true, true, true}, true);
56+
57+
Tensor pred3;
58+
ASSERT_TRUE(oprs->sum(inp3, pred3, p3).is_ok());
59+
assert_same_data<nn_int32>(pred3, truth3, 0.0001f);
60+
}

‎csharp/Tensor.NET/Native/NativeApi.cs‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ internal static class NativeApi{
3838
public static extern IntPtr Sort(IntPtr inp, IntPtr oup, IntPtr param, NativeProvider provider);
3939
[DllImport("libnumnet")]
4040
public static extern IntPtr Onehot(IntPtr inp, IntPtr oup, IntPtr param, NativeProvider provider);
41+
[DllImport("libnumnet")]
42+
public static extern IntPtr Sum(IntPtr inp, IntPtr oup, IntPtr param, NativeProvider provider);
4143

4244

4345
[DllImport("libnumnet")]

‎csharp/Tensor.NET/Native/NativeParam.cs‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ internal struct SortParam{
5252
internal struct OnehotParam{
5353
internal int maxValue;
5454
}
55+
internal struct SumParam{
56+
internal IntPtr dims;
57+
internal bool keepDims;
58+
}
5559
internal struct TypeConvertParam{
5660
internal DType targetType;
5761
}

‎csharp/Tensor.NET/Statistics/Sum.cs‎

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
using Tensornet.Native;
2+
using Tensornet.Exceptions;
3+
using Tensornet.Native.Param;
4+
5+
namespace Tensornet{
6+
public static class SumExtension{
7+
public static Tensor<T> Sum<T>(this Tensor<T> src, int[] axes, bool keepDims = false) where T : struct, IEquatable<T>, IConvertible
8+
{
9+
Tensor<T> res = new Tensor<T>(DeduceLayout(src.TLayout, axes));
10+
res.TLayout.InitContiguousLayout();
11+
bool[] boolDims = new bool[src.TLayout.NDim];
12+
var span = boolDims.AsSpan();
13+
span.Fill(false);
14+
foreach(var axis in axes){
15+
span[axis] = true;
16+
}
17+
SumInternal(src, res, boolDims, keepDims);
18+
return res;
19+
}
20+
public static Tensor<T> Sum<T>(this Tensor<T> src, int axis, bool keepDims = false) where T : struct, IEquatable<T>, IConvertible
21+
{
22+
Tensor<T> res = new Tensor<T>(DeduceLayout(src.TLayout, axis));
23+
res.TLayout.InitContiguousLayout();
24+
bool[] boolDims = new bool[src.TLayout.NDim];
25+
var span = boolDims.AsSpan();
26+
span.Fill(false);
27+
span[axis] = true;
28+
SumInternal(src, res, boolDims, keepDims);
29+
return res;
30+
}
31+
public static Tensor<T> Sum<T>(this Tensor<T> src, bool keepDims = false) where T : struct, IEquatable<T>, IConvertible
32+
{
33+
Tensor<T> res = new Tensor<T>(DeduceLayout(src.TLayout));
34+
res.TLayout.InitContiguousLayout();
35+
bool[] boolDims = new bool[src.TLayout.NDim];
36+
boolDims.AsSpan().Fill(true);
37+
SumInternal(src, res, boolDims, keepDims);
38+
return res;
39+
}
40+
private unsafe static void SumInternal<T>(Tensor<T> src, Tensor<T> dst, bool[] dims, bool keepDims) where T : struct, IEquatable<T>, IConvertible{
41+
fixed(bool* ptr = dims){
42+
SumParam p = new SumParam() { dims = new IntPtr(ptr), keepDims = keepDims };
43+
IntPtr status = NativeExecutor.Execute(NativeApi.Sum, src.TMemory, dst.TMemory, src.TLayout, dst.TLayout, new IntPtr(&p), Tensor<T>.Provider);
44+
NativeStatus.AssertOK(status);
45+
}
46+
if(!keepDims){
47+
dst.TLayout.RemoveAllDanglingAxisInplace();
48+
}
49+
}
50+
private static TensorLayout DeduceLayout(TensorLayout src, int[] axes){
51+
var res = new TensorLayout(src, true);
52+
foreach(var dim in axes){
53+
res.Shape[dim] = 1;
54+
}
55+
return res;
56+
}
57+
private static TensorLayout DeduceLayout(TensorLayout src, int axis){
58+
var res = new TensorLayout(src, true);
59+
res.Shape[axis] = 1;
60+
return res;
61+
}
62+
private static TensorLayout DeduceLayout(TensorLayout src){
63+
var res = new TensorLayout(src, true);
64+
res.Shape.AsSpan().Fill(1);
65+
return res;
66+
}
67+
}
68+
69+
public static partial class Tensor{
70+
/// <summary>
71+
/// Sum the tensor.
72+
/// </summary>
73+
/// <typeparam name="T"></typeparam>
74+
/// <param name="src"> The tensor to be sumed. </param>
75+
/// <param name="axes"> The axes to sum. </param>
76+
/// <param name="keepDims"> Whether to keep the dims after the sum. If false, the NDim of the result may be different with the input. </param>
77+
/// <returns>The Sumped tensor</returns>
78+
public static Tensor<T> Sum<T>(Tensor<T> src, int[] axes, bool keepDims = false) where T : struct, IEquatable<T>, IConvertible{
79+
return src.Sum(axes, keepDims);
80+
}
81+
/// <summary>
82+
/// Sum the tensor.
83+
/// </summary>
84+
/// <typeparam name="T"></typeparam>
85+
/// <param name="src"> The tensor to be sumed. </param>
86+
/// <param name="axis"> The axis to sum. </param>
87+
/// <param name="keepDims"> Whether to keep the dims after the sum. If false, the NDim of the result may be different with the input. </param>
88+
/// <returns>The Sumped tensor</returns>
89+
public static Tensor<T> Sum<T>(Tensor<T> src, int axis, bool keepDims = false) where T : struct, IEquatable<T>, IConvertible{
90+
return src.Sum(axis, keepDims);
91+
}
92+
/// <summary>
93+
/// Sum the tensor.
94+
/// </summary>
95+
/// <typeparam name="T"></typeparam>
96+
/// <param name="src"> The tensor to be sumed. </param>
97+
/// <param name="keepDims"> Whether to keep the dims after the sum. If false, the NDim of the result may be different with the input. </param>
98+
/// <returns>The Sumped tensor</returns>
99+
public static Tensor<T> Sum<T>(Tensor<T> src, bool keepDims = false) where T : struct, IEquatable<T>, IConvertible{
100+
return src.Sum(keepDims);
101+
}
102+
}
103+
}

‎csharp/Tensor.NET/Tensor/Common/TensorLayout.cs‎

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,14 @@ internal TensorLayout RemoveAxis(int axis){
315315
return res;
316316
}
317317

318+
internal void RemoveAllDanglingAxisInplace(){
319+
for (int i = NDim - 1; i >= 0; i--){
320+
if(Shape[i] == 1 && NDim > 1){
321+
RemoveAxisInplace(i);
322+
}
323+
}
324+
}
325+
318326
internal void AddAxisInplace(int axis, int shape, int stride) {
319327
if(NDim + 1 > MAX_NDIM){
320328
throw new InvalidArgumentException($"can not add axis at {axis} (current ndim is {NDim}, MAX_NDIM is {MAX_NDIM})");

‎csharp/TensorOpTest/Sum.cs‎

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
using NUnit.Framework;
2+
using Tensornet;
3+
4+
namespace TensorOpTest;
5+
6+
public class SumTest
7+
{
8+
[SetUp]
9+
public void Setup()
10+
{
11+
}
12+
13+
[Test]
14+
public void Test1()
15+
{
16+
var s = Tensor.FromArray<int>(
17+
new int[] { 18, 4, 35, 22, 34, 18, 15, 5, 13, 49, 41, 3,
18+
4, 35, 14, 22, 38, 29, 42, 47, 30, 4, 9, 0,
19+
35, 0, 43, 27, 44, 29, 2, 17, 16, 6, 36, 41,
20+
1, 22, 3, 30, 35, 13, 41, 45, 38, 44, 38, 21,
21+
42, 41, 19, 24, 40, 39, 37, 24, 28, 17, 22, 21,
22+
15, 23, 35, 16, 31, 3, 29, 28, 0, 15, 17, 10 }, new TensorShape(3, 4, 2, 3));
23+
var t = Tensor.FromArray<int>(
24+
new int[] { 1734 }, new TensorShape(1));
25+
var p = s.Sum();
26+
Assert.IsTrue(TensorUtils.IsValueEqual(t, p));
27+
}
28+
29+
// [Test]
30+
// public void Test2()
31+
// {
32+
// var s = Tensor.FromArray<int>(
33+
// new int[] { -151, -46, -9, -62, -158, -74, 35, -10, -123, -94, -122, 58,
34+
// -124, 139, -173, -137, -178, 116, 52, -92, -14, -176, -133, -109,
35+
// -114, -157, -186, 46, -78, -144, 155, -60, 47, 150, -133, -58,
36+
// -17, -161, -36, 11, 133, -170, -149, -155, 10, -118, -112, -103,
37+
// -110, 183, 29, 21, 189, -85, 83, -186, -114, -104, -171, -116,
38+
// -110, 88, -130, 42, 106, 120, -94, -77, 49, 74, 96, -28 }, new TensorShape(6, 3, 2, 2));
39+
// var t = Tensor.FromArray<int>(
40+
// new int[] { -151, -124, -114, -17, -110, -110, -158, -178, -78, 133, 189, 106,
41+
// -123, -14, 47, 10, -114, 49, -46, 139, -157, -161, 183, 88,
42+
// -74, 116, -144, -170, -85, 120, -94, -176, 150, -118, -104, 74,
43+
// -9, -173, -186, -36, 29, -130, 35, 52, 155, -149, 83, -94,
44+
// -122, -133, -133, -112, -171, 96, -62, -137, 46, 11, 21, 42,
45+
// -10, -92, -60, -155, -186, -77, 58, -109, -58, -103, -116, -28 }, new TensorShape(2, 2, 3, 6));
46+
// var p = s.Sum(2, 3, 1, 0);
47+
// Assert.IsTrue(TensorUtils.IsValueEqual(t, p));
48+
// }
49+
50+
// [Test]
51+
// public void Test3()
52+
// {
53+
// var s = Tensor.FromArray<double>(
54+
// new double[] { -144.88911173213074, -177.9092558668978 }, new TensorShape(1, 2));
55+
// var t = Tensor.FromArray<double>(
56+
// new double[] { -144.88911173213074, -177.9092558668978 }, new TensorShape(2, 1));
57+
// var p = s.Sum(1, 0);
58+
// Assert.IsTrue(TensorUtils.IsValueEqual(t, p));
59+
// }
60+
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /