17
17
18
18
import * as tf from '../index' ;
19
19
import { ALL_ENVS , describeWithFlags } from '../jasmine_util' ;
20
- import { Tensor } from '../tensor' ;
21
20
import { expectArraysClose } from '../test_util' ;
22
21
23
- describeWithFlags ( 'booleanMask ' , ALL_ENVS , ( ) => {
22
+ describeWithFlags ( 'booleanMaskAsync ' , ALL_ENVS , ( ) => {
24
23
it ( '1d array, 1d mask, default axis' , async ( ) => {
25
24
const array = tf . tensor1d ( [ 1 , 2 , 3 ] ) ;
26
25
const mask = tf . tensor1d ( [ 1 , 0 , 1 ] , 'bool' ) ;
27
- const result = await tf . booleanMask ( array , mask ) ;
26
+ const result = await tf . booleanMaskAsync ( array , mask ) ;
28
27
expect ( result . shape ) . toEqual ( [ 2 ] ) ;
29
28
expect ( result . dtype ) . toBe ( 'float32' ) ;
30
29
expectArraysClose ( await result . data ( ) , [ 1 , 3 ] ) ;
@@ -33,7 +32,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
33
32
it ( '2d array, 1d mask, default axis' , async ( ) => {
34
33
const array = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
35
34
const mask = tf . tensor1d ( [ 1 , 0 , 1 ] , 'bool' ) ;
36
- const result = await tf . booleanMask ( array , mask ) ;
35
+ const result = await tf . booleanMaskAsync ( array , mask ) ;
37
36
expect ( result . shape ) . toEqual ( [ 2 , 2 ] ) ;
38
37
expect ( result . dtype ) . toBe ( 'float32' ) ;
39
38
expectArraysClose ( await result . data ( ) , [ 1 , 2 , 5 , 6 ] ) ;
@@ -42,7 +41,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
42
41
it ( '2d array, 2d mask, default axis' , async ( ) => {
43
42
const array = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
44
43
const mask = tf . tensor2d ( [ 1 , 0 , 1 , 0 , 1 , 0 ] , [ 3 , 2 ] , 'bool' ) ;
45
- const result = await tf . booleanMask ( array , mask ) ;
44
+ const result = await tf . booleanMaskAsync ( array , mask ) ;
46
45
expect ( result . shape ) . toEqual ( [ 3 ] ) ;
47
46
expect ( result . dtype ) . toBe ( 'float32' ) ;
48
47
expectArraysClose ( await result . data ( ) , [ 1 , 3 , 5 ] ) ;
@@ -52,7 +51,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
52
51
const array = tf . tensor2d ( [ 1 , 2 , 3 , 4 , 5 , 6 ] , [ 3 , 2 ] ) ;
53
52
const mask = tf . tensor1d ( [ 0 , 1 ] , 'bool' ) ;
54
53
const axis = 1 ;
55
- const result = await tf . booleanMask ( array , mask , axis ) ;
54
+ const result = await tf . booleanMaskAsync ( array , mask , axis ) ;
56
55
expect ( result . shape ) . toEqual ( [ 3 , 1 ] ) ;
57
56
expect ( result . dtype ) . toBe ( 'float32' ) ;
58
57
expectArraysClose ( await result . data ( ) , [ 2 , 4 , 6 ] ) ;
@@ -61,7 +60,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
61
60
it ( 'accepts tensor-like object as array or mask' , async ( ) => {
62
61
const array = [ [ 1 , 2 ] , [ 3 , 4 ] , [ 5 , 6 ] ] ;
63
62
const mask = [ 1 , 0 , 1 ] ;
64
- const result = await tf . booleanMask ( array , mask ) ;
63
+ const result = await tf . booleanMaskAsync ( array , mask ) ;
65
64
expect ( result . shape ) . toEqual ( [ 2 , 2 ] ) ;
66
65
expect ( result . dtype ) . toBe ( 'float32' ) ;
67
66
expectArraysClose ( await result . data ( ) , [ 1 , 2 , 5 , 6 ] ) ;
@@ -72,13 +71,8 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
72
71
73
72
const array = tf . tensor1d ( [ 1 , 2 , 3 ] ) ;
74
73
const mask = tf . tensor1d ( [ 1 , 0 , 1 ] , 'bool' ) ;
75
- let resultPromise : Promise < Tensor > = null ;
76
74
77
- tf . tidy ( ( ) => {
78
- resultPromise = tf . booleanMask ( array , mask ) ;
79
- } ) ;
80
-
81
- const result = await resultPromise ;
75
+ const result = await tf . booleanMaskAsync ( array , mask ) ;
82
76
expect ( result . shape ) . toEqual ( [ 2 ] ) ;
83
77
expect ( result . dtype ) . toBe ( 'float32' ) ;
84
78
expectArraysClose ( await result . data ( ) , [ 1 , 3 ] ) ;
@@ -95,7 +89,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
95
89
const mask = tf . scalar ( 1 , 'bool' ) ;
96
90
let errorMessage = 'No error thrown.' ;
97
91
try {
98
- await tf . booleanMask ( array , mask ) ;
92
+ await tf . booleanMaskAsync ( array , mask ) ;
99
93
} catch ( error ) {
100
94
errorMessage = error . message ;
101
95
}
@@ -107,7 +101,7 @@ describeWithFlags('booleanMask', ALL_ENVS, () => {
107
101
const mask = tf . tensor2d ( [ 1 , 0 ] , [ 1 , 2 ] , 'bool' ) ;
108
102
let errorMessage = 'No error thrown.' ;
109
103
try {
110
- await tf . booleanMask ( array , mask ) ;
104
+ await tf . booleanMaskAsync ( array , mask ) ;
111
105
} catch ( error ) {
112
106
errorMessage = error . message ;
113
107
}
0 commit comments