Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 2ae5a8a

Browse files
jarno-rdsmilkov
authored andcommitted
Fixed division by zero in QR decomposition. Issue #1058 (#1473)
tensorflow/tfjs#1058 The sign() function returns 0 on 0, which causes a division by zero in the QR decomposition function qr() if there is a zero on the diagonal. BUG
1 parent b484b28 commit 2ae5a8a

File tree

5 files changed

+49
-30
lines changed

5 files changed

+49
-30
lines changed

‎src/io/passthrough_test.ts‎

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ describeWithFlags('Passthrough Saver', BROWSER_ENVS, () => {
115115

116116
describeWithFlags('Passthrough Loader', BROWSER_ENVS, () => {
117117
it('load topology and weights: legacy signature', async () => {
118-
const passthroughHandler =tf.io.fromMemory(
119-
modelTopology1, weightSpecs1, weightData1);
118+
const passthroughHandler =
119+
tf.io.fromMemory(modelTopology1, weightSpecs1, weightData1);
120120
const modelArtifacts = await passthroughHandler.load();
121121
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
122122
expect(modelArtifacts.weightSpecs).toEqual(weightSpecs1);
@@ -147,9 +147,8 @@ describeWithFlags('Passthrough Loader', BROWSER_ENVS, () => {
147147
});
148148

149149
it('load model topology only', async () => {
150-
const passthroughHandler = tf.io.fromMemory({
151-
modelTopology: modelTopology1
152-
});
150+
const passthroughHandler =
151+
tf.io.fromMemory({modelTopology: modelTopology1});
153152
const modelArtifacts = await passthroughHandler.load();
154153
expect(modelArtifacts.modelTopology).toEqual(modelTopology1);
155154
expect(modelArtifacts.weightSpecs).toEqual(undefined);

‎src/jasmine_util.ts‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ export const SYNC_BACKEND_ENVS: Constraints = {
4040
};
4141

4242
export const HAS_WORKER = {
43-
predicate: () => typeof(Worker) !== 'undefined'
44-
&&typeof(Blob) !== 'undefined' && typeof(URL) !== 'undefined'
43+
predicate: () => typeof(Worker) !== 'undefined'&&
44+
typeof(Blob) !== 'undefined' && typeof(URL) !== 'undefined'
4545
};
4646

4747
export const HAS_NODE_WORKER = {
@@ -52,7 +52,7 @@ export const HAS_NODE_WORKER = {
5252
} catch {
5353
hasWorker = false;
5454
}
55-
return typeof(process) !== 'undefined' && hasWorker;
55+
return typeof(process) !== 'undefined' && hasWorker;
5656
}
5757
};
5858

‎src/ops/concat_test.ts‎

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ describeWithFlags('concat1d', ALL_ENVS, () => {
8888
expectArraysClose(await result.data(), expected);
8989
});
9090

91-
it('concat complex input', async() => {
91+
it('concat complex input', async() => {
9292
// [1+1j, 2+2j]
9393
const c1 = tf.complex([1, 2], [1, 2]);
9494
// [3+3j, 4+4j]
@@ -234,7 +234,7 @@ describeWithFlags('concat2d', ALL_ENVS, () => {
234234
expectArraysEqual(await res2.data(), []);
235235
});
236236

237-
it('concat complex input axis=0', async() => {
237+
it('concat complex input axis=0', async() => {
238238
// [[1+1j, 2+2j], [3+3j, 4+4j]]
239239
const c1 = tf.complex([[1, 2], [3, 4]], [[1, 2], [3, 4]]);
240240
// [[5+5j, 6+6j], [7+7j, 8+8j]]
@@ -247,7 +247,7 @@ describeWithFlags('concat2d', ALL_ENVS, () => {
247247
expectArraysClose(await result.data(), expected);
248248
});
249249

250-
it('concat complex input axis=1', async() => {
250+
it('concat complex input axis=1', async() => {
251251
// [[1+1j, 2+2j], [3+3j, 4+4j]]
252252
const c1 = tf.complex([[1, 2], [3, 4]], [[1, 2], [3, 4]]);
253253
// [[5+5j, 6+6j], [7+7j, 8+8j]]
@@ -500,50 +500,56 @@ describeWithFlags('concat3d', ALL_ENVS, () => {
500500
expectArraysClose(await values.data(), [1, 2, 3, 4, 5, 6]);
501501
});
502502

503-
it('concat complex input axis=0', async() => {
503+
it('concat complex input axis=0', async() => {
504504
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
505-
const c1 =tf.complex(
506-
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
505+
const c1 =
506+
tf.complex([[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
507507
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
508508
const c2 = tf.complex(
509-
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
509+
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
510510

511511
const axis = 0;
512512
const result = tf.concat([c1, c2], axis);
513-
const expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
514-
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12];
513+
const expected = [
514+
1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
515+
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12
516+
];
515517
expect(result.dtype).toEqual('complex64');
516518
expectArraysClose(await result.data(), expected);
517519
});
518520

519-
it('concat complex input axis=1', async() => {
521+
it('concat complex input axis=1', async() => {
520522
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
521-
const c1 =tf.complex(
522-
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
523+
const c1 =
524+
tf.complex([[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
523525
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
524526
const c2 = tf.complex(
525-
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
527+
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
526528

527529
const axis = 1;
528530
const result = tf.concat([c1, c2], axis);
529-
const expected = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
530-
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12];
531+
const expected = [
532+
1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
533+
7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12
534+
];
531535
expect(result.dtype).toEqual('complex64');
532536
expectArraysClose(await result.data(), expected);
533537
});
534538

535-
it('concat complex input axis=1', async() => {
539+
it('concat complex input axis=1', async() => {
536540
// [[[1+1j, 2+2j], [3+3j, 4+4j], [5+5j, 6+6j]]]
537-
const c1 =tf.complex(
538-
[[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
541+
const c1 =
542+
tf.complex([[[1, 2], [3, 4], [5, 6]]], [[[1, 2], [3, 4], [5, 6]]]);
539543
// [[[7+7j, 8+8j], [9+9j, 10+10j], [11+11j, 12+12j]]]
540544
const c2 = tf.complex(
541-
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
545+
[[[7, 8], [9, 10], [11, 12]]], [[[7, 8], [9, 10], [11, 12]]]);
542546

543547
const axis = 2;
544548
const result = tf.concat([c1, c2], axis);
545-
const expected = [1, 1, 2, 2, 7, 7, 8, 8, 3, 3, 4, 4,
546-
9, 9, 10, 10, 5, 5, 6, 6, 11, 11, 12, 12];
549+
const expected = [
550+
1, 1, 2, 2, 7, 7, 8, 8, 3, 3, 4, 4,
551+
9, 9, 10, 10, 5, 5, 6, 6, 11, 11, 12, 12
552+
];
547553
expect(result.dtype).toEqual('complex64');
548554
expectArraysClose(await result.data(), expected);
549555
});

‎src/ops/linalg_ops.ts‎

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,10 @@ function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] {
215215
const rjEnd1 = r.slice([j, j], [m - j, 1]);
216216
const normX = rjEnd1.norm();
217217
const rjj = r.slice([j, j], [1, 1]);
218-
const s = rjj.sign().neg() as Tensor2D;
218+
219+
// The sign() function returns 0 on 0, which causes division by zero.
220+
const s = tensor2d([[-1]]).where(rjj.greater(0), tensor2d([[1]]));
221+
219222
const u1 = rjj.sub(s.mul(normX)) as Tensor2D;
220223
const wPre = rjEnd1.div(u1);
221224
if (wPre.shape[0] === 1) {

‎src/ops/linalg_ops_test.ts‎

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,17 @@ describeWithFlags('qr', ALL_ENVS, () => {
140140
[[-8.3066, 8.3066, -2.4077], [0, 4.5826, -2.1822], [0, 0, 7.6447]]);
141141
});
142142

143+
it('3x3, zero on diagonal', async () => {
144+
const x = tensor2d([[0, 2, 2], [1, 1, 1], [0, 1, 2]], [3, 3]);
145+
const [q, r] = tf.linalg.qr(x);
146+
expectArraysClose(await q.data(), [
147+
[0., -0.89442719, 0.4472136], [1., 0., 0.], [0., -0.4472136, -0.89442719]
148+
]);
149+
expectArraysClose(
150+
await r.data(),
151+
[[1., 1., 1.], [0., -2.23606798, -2.68328157], [0., 0., -0.89442719]]);
152+
});
153+
143154
it('3x2, fullMatrices = default false', async () => {
144155
const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]);
145156
const [q, r] = tf.linalg.qr(x);

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /