Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 707a669

Browse files
kedevkeddsmilkov
authored andcommitted
add tf.diag (#1256)
cla: yes Add tf.diag
1 parent 975e5f6 commit 707a669

File tree

8 files changed

+219
-2
lines changed

8 files changed

+219
-2
lines changed

‎src/backends/backend.ts‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,10 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
609609
throw new Error('Not yet implemented');
610610
}
611611

612+
diag(x: Tensor): Tensor {
613+
throw new Error('Not yet implemented');
614+
}
615+
612616
fill<R extends Rank>(
613617
shape: ShapeMap[R], value: number|string, dtype?: DataType): Tensor<R> {
614618
throw new Error('Not yet implemented.');

‎src/backends/cpu/backend_cpu.ts‎

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import {warn} from '../../log';
2323
import * as array_ops_util from '../../ops/array_ops_util';
2424
import * as axis_util from '../../ops/axis_util';
2525
import * as broadcast_util from '../../ops/broadcast_util';
26+
import {complex, imag, real} from '../../ops/complex_ops';
2627
import * as concat_util from '../../ops/concat_util';
2728
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
2829
import * as erf_util from '../../ops/erf_util';
@@ -45,7 +46,6 @@ import {split} from '../split_shared';
4546
import {tile} from '../tile_impl';
4647
import {topkImpl} from '../topk_impl';
4748
import {whereImpl} from '../where_impl';
48-
import {real, imag, complex} from '../../ops/complex_ops';
4949

5050
function mapActivation(
5151
backend: MathBackendCPU, x: Tensor, activation: Activation,
@@ -343,6 +343,16 @@ export class MathBackendCPU implements KernelBackend {
343343
return buffer.toTensor().reshape(shape) as T;
344344
}
345345

346+
diag(x: Tensor): Tensor {
347+
const xVals = this.readSync(x.dataId) as TypedArray;
348+
const buffer = ops.buffer([x.size, x.size], x.dtype);
349+
const vals = buffer.values;
350+
for (let i = 0; i < xVals.length; i++) {
351+
vals[i * x.size + i] = xVals[i];
352+
}
353+
return buffer.toTensor();
354+
}
355+
346356
unstack(x: Tensor, axis: number): Tensor[] {
347357
const num = x.shape[axis];
348358
const outShape: number[] = new Array(x.rank - 1);

‎src/backends/webgl/backend_webgl.ts‎

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import {warn} from '../../log';
2626
import {buffer} from '../../ops/array_ops';
2727
import * as array_ops_util from '../../ops/array_ops_util';
2828
import * as axis_util from '../../ops/axis_util';
29+
import {complex, imag, real} from '../../ops/complex_ops';
2930
import {computeOutShape} from '../../ops/concat_util';
3031
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
3132
import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
@@ -78,6 +79,7 @@ import {CumSumProgram} from './cumsum_gpu';
7879
import {DecodeMatrixProgram} from './decode_matrix_gpu';
7980
import {DecodeMatrixPackedProgram} from './decode_matrix_packed_gpu';
8081
import {DepthToSpaceProgram} from './depth_to_space_gpu';
82+
import {DiagProgram} from './diag_gpu';
8183
import {EncodeFloatProgram} from './encode_float_gpu';
8284
import {EncodeFloatPackedProgram} from './encode_float_packed_gpu';
8385
import {EncodeMatrixProgram} from './encode_matrix_gpu';
@@ -131,7 +133,6 @@ import * as unary_packed_op from './unaryop_packed_gpu';
131133
import {UnaryOpPackedProgram} from './unaryop_packed_gpu';
132134
import {UnpackProgram} from './unpack_gpu';
133135
import * as webgl_util from './webgl_util';
134-
import {real, imag, complex} from '../../ops/complex_ops';
135136

136137
type KernelInfo = {
137138
name: string; query: Promise<number>;
@@ -2208,6 +2209,11 @@ export class MathBackendWebGL implements KernelBackend {
22082209
return this.compileAndRun(program, [indices]);
22092210
}
22102211

2212+
diag(x: Tensor): Tensor {
2213+
const program = new DiagProgram(x.size);
2214+
return this.compileAndRun(program, [x]);
2215+
}
2216+
22112217
nonMaxSuppression(
22122218
boxes: Tensor2D, scores: Tensor1D, maxOutputSize: number,
22132219
iouThreshold: number, scoreThreshold: number): Tensor1D {

‎src/backends/webgl/diag_gpu.ts‎

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {GPGPUProgram} from './gpgpu_math';
19+
20+
export class DiagProgram implements GPGPUProgram {
21+
variableNames = ['X'];
22+
outputShape: number[];
23+
userCode: string;
24+
25+
constructor(size: number) {
26+
this.outputShape = [size, size];
27+
this.userCode = `
28+
void main() {
29+
ivec2 coords = getOutputCoords();
30+
float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;
31+
setOutput(val);
32+
}
33+
`;
34+
}
35+
}

‎src/ops/diag.ts‎

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {ENGINE} from '../engine';
19+
import {Tensor} from '../tensor';
20+
import {convertToTensor} from '../tensor_util_env';
21+
import {op} from './operation';
22+
23+
/**
24+
* Returns a diagonal tensor with a given diagonal values.
25+
*
26+
* Given a diagonal, this operation returns a tensor with the diagonal and
27+
* everything else padded with zeros.
28+
*
29+
* Assume the input has dimensions `[D1,..., Dk]`, then the output is a tensor
30+
* of rank 2k with dimensions `[D1,..., Dk, D1,..., Dk]`
31+
*
32+
* ```js
33+
* const x = tf.tensor1d([1, 2, 3, 4]);
34+
*
35+
* tf.diag(x).print()
36+
* ```
37+
* ```js
38+
* const x = tf.tensor1d([1, 2, 3, 4, 5, 6, 6, 8], [4, 2])
39+
*
40+
* tf.diag(x).print()
41+
* ```
42+
* @param x The input tensor.
43+
*/
44+
function diag_(x: Tensor): Tensor {
45+
const $x = convertToTensor(x, 'x', 'diag').flatten();
46+
const outShape = [...x.shape, ...x.shape];
47+
return ENGINE.runKernel(backend => backend.diag($x), {$x}).reshape(outShape);
48+
}
49+
50+
export const diag = op({diag_});

‎src/ops/diag_test.ts‎

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
import * as tf from '../index';
18+
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
19+
import {expectArraysClose, expectArraysEqual} from '../test_util';
20+
21+
describeWithFlags('diag', ALL_ENVS, () => {
22+
it('1d', async () => {
23+
const m = tf.tensor1d([5]);
24+
const diag = tf.diag(m);
25+
expect(diag.shape).toEqual([1, 1]);
26+
expectArraysClose(await diag.data(), [5]);
27+
});
28+
it('2d', async () => {
29+
const m = tf.tensor2d([8, 2, 3, 4, 5, 1], [3, 2]);
30+
const diag = tf.diag(m);
31+
expect(diag.shape).toEqual([3, 2, 3, 2]);
32+
expectArraysClose(await diag.data(), [
33+
8, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0,
34+
0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 1
35+
]);
36+
});
37+
it('3d', async () => {
38+
const m = tf.tensor3d([8, 5, 5, 7, 9, 10, 15, 1, 2, 14, 12, 3], [2, 2, 3]);
39+
const diag = tf.diag(m);
40+
expect(diag.shape).toEqual([2, 2, 3, 2, 2, 3]);
41+
expectArraysClose(await diag.data(), [
42+
8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0,
43+
0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0,
44+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
45+
0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0, 0, 0,
46+
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,
47+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0,
48+
0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3,
49+
]);
50+
});
51+
it('4d', async () => {
52+
const m = tf.tensor4d(
53+
[
54+
8, 5, 5, 7, 9, 10, 15, 1, 2, 14, 12, 3,
55+
9, 6, 6, 8, 10, 11, 16, 2, 3, 15, 13, 4
56+
],
57+
[2, 2, 3, 2]);
58+
const diag = tf.diag(m);
59+
expect(diag.shape).toEqual([2, 2, 3, 2, 2, 2, 3, 2]);
60+
expectArraysClose(await diag.data(), [
61+
8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62+
0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
63+
0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
64+
0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
65+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0,
66+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0,
67+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 0,
68+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
69+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
70+
0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
71+
0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
72+
0, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
73+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
74+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0,
75+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0,
76+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0,
77+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
78+
0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
79+
0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
80+
0, 0, 0, 0, 0, 0, 0, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
81+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
82+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0,
83+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
84+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0,
85+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
86+
13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
87+
0, 0, 0, 4
88+
]);
89+
});
90+
it('int32', async () => {
91+
const m = tf.tensor1d([5, 3], 'int32');
92+
const diag = tf.diag(m);
93+
expect(diag.shape).toEqual([2, 2]);
94+
expect(diag.dtype).toBe('int32');
95+
expectArraysEqual(await diag.data(), [5, 0, 0, 3]);
96+
});
97+
it('bool', async () => {
98+
const m = tf.tensor1d([5, 3], 'bool');
99+
const diag = tf.diag(m);
100+
expect(diag.shape).toEqual([2, 2]);
101+
expect(diag.dtype).toBe('bool');
102+
expectArraysEqual(await diag.data(), [1, 0, 0, 1]);
103+
});
104+
it('complex', () => {
105+
const real = tf.tensor1d([2.25]);
106+
const imag = tf.tensor1d([4.75]);
107+
const m = tf.complex(real, imag);
108+
expect(() => tf.diag(m)).toThrowError();
109+
});
110+
});

‎src/ops/ops.ts‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ export * from './scatter_nd';
4444
export * from './spectral_ops';
4545
export * from './sparse_to_dense';
4646
export * from './gather_nd';
47+
export * from './diag';
4748
export * from './dropout';
4849
export * from './signal_ops';
4950

‎src/tests.ts‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ import './ops/conv2d_test';
5555
import './ops/conv2d_transpose_test';
5656
import './ops/conv3d_test';
5757
import './ops/conv_util_test';
58+
import './ops/diag_test';
5859
import './ops/dropout_test';
5960
import './ops/fused_test';
6061
import './ops/gather_nd_test';

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /