Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 6311b21

Browse files
author
Nikhil Thorat
authored
Fix unit test to not use tf.tidy in boolean mask test. (#1871)
DEV Also rename booleanMask to booleanMaskAsync.
1 parent cac5b15 commit 6311b21

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

‎src/ops/boolean_mask.ts‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import {gather} from './segment_ops';
2929
* ```js
3030
* const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
3131
* const mask = tf.tensor1d([1, 0, 1], 'bool');
32-
* const result = await tf.booleanMask(tensor, mask);
32+
* const result = await tf.booleanMaskAsync(tensor, mask);
3333
* result.print();
3434
* ```
3535
*
@@ -40,7 +40,7 @@ import {gather} from './segment_ops';
4040
* Otherwise K + axis <= N.
4141
*/
4242
/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */
43-
async function booleanMask_(
43+
async function booleanMaskAsync_(
4444
tensor: Tensor|TensorLike, mask: Tensor|TensorLike,
4545
axis?: number): Promise<Tensor> {
4646
const $tensor = convertToTensor(tensor, 'tensor', 'boolMask');
@@ -84,4 +84,4 @@ async function booleanMask_(
8484
return res;
8585
}
8686

87-
export const booleanMask = booleanMask_;
87+
export const booleanMaskAsync = booleanMaskAsync_;

‎src/ops/boolean_mask_test.ts‎

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717

1818
import * as tf from '../index';
1919
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
20-
import {Tensor} from '../tensor';
2120
import {expectArraysClose} from '../test_util';
2221

23-
describeWithFlags('booleanMask', ALL_ENVS, () => {
22+
describeWithFlags('booleanMaskAsync', ALL_ENVS, () => {
2423
it('1d array, 1d mask, default axis', async () => {
2524
const array = tf.tensor1d([1, 2, 3]);
2625
const mask = tf.tensor1d([1, 0, 1], 'bool');
27-
const result = await tf.booleanMask(array, mask);
26+
const result = await tf.booleanMaskAsync(array, mask);
2827
expect(result.shape).toEqual([2]);
2928
expect(result.dtype).toBe('float32');
3029
expectArraysClose(await result.data(), [1, 3]);
@@ -33,7 +32,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
3332
it('2d array, 1d mask, default axis', async () => {
3433
const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
3534
const mask = tf.tensor1d([1, 0, 1], 'bool');
36-
const result = await tf.booleanMask(array, mask);
35+
const result = await tf.booleanMaskAsync(array, mask);
3736
expect(result.shape).toEqual([2, 2]);
3837
expect(result.dtype).toBe('float32');
3938
expectArraysClose(await result.data(), [1, 2, 5, 6]);
@@ -42,7 +41,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
4241
it('2d array, 2d mask, default axis', async () => {
4342
const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
4443
const mask = tf.tensor2d([1, 0, 1, 0, 1, 0], [3, 2], 'bool');
45-
const result = await tf.booleanMask(array, mask);
44+
const result = await tf.booleanMaskAsync(array, mask);
4645
expect(result.shape).toEqual([3]);
4746
expect(result.dtype).toBe('float32');
4847
expectArraysClose(await result.data(), [1, 3, 5]);
@@ -52,7 +51,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
5251
const array = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]);
5352
const mask = tf.tensor1d([0, 1], 'bool');
5453
const axis = 1;
55-
const result = await tf.booleanMask(array, mask, axis);
54+
const result = await tf.booleanMaskAsync(array, mask, axis);
5655
expect(result.shape).toEqual([3, 1]);
5756
expect(result.dtype).toBe('float32');
5857
expectArraysClose(await result.data(), [2, 4, 6]);
@@ -61,7 +60,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
6160
it('accepts tensor-like object as array or mask', async () => {
6261
const array = [[1, 2], [3, 4], [5, 6]];
6362
const mask = [1, 0, 1];
64-
const result = await tf.booleanMask(array, mask);
63+
const result = await tf.booleanMaskAsync(array, mask);
6564
expect(result.shape).toEqual([2, 2]);
6665
expect(result.dtype).toBe('float32');
6766
expectArraysClose(await result.data(), [1, 2, 5, 6]);
@@ -72,13 +71,8 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
7271

7372
const array = tf.tensor1d([1, 2, 3]);
7473
const mask = tf.tensor1d([1, 0, 1], 'bool');
75-
let resultPromise: Promise<Tensor> = null;
7674

77-
tf.tidy(() => {
78-
resultPromise = tf.booleanMask(array, mask);
79-
});
80-
81-
const result = await resultPromise;
75+
const result = await tf.booleanMaskAsync(array, mask);
8276
expect(result.shape).toEqual([2]);
8377
expect(result.dtype).toBe('float32');
8478
expectArraysClose(await result.data(), [1, 3]);
@@ -95,7 +89,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
9589
const mask = tf.scalar(1, 'bool');
9690
let errorMessage = 'No error thrown.';
9791
try {
98-
await tf.booleanMask(array, mask);
92+
await tf.booleanMaskAsync(array, mask);
9993
} catch (error) {
10094
errorMessage = error.message;
10195
}
@@ -107,7 +101,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
107101
const mask = tf.tensor2d([1, 0], [1, 2], 'bool');
108102
let errorMessage = 'No error thrown.';
109103
try {
110-
await tf.booleanMask(array, mask);
104+
await tf.booleanMaskAsync(array, mask);
111105
} catch (error) {
112106
errorMessage = error.message;
113107
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /