Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 5cc5267

Browse files
Fuse prelu activation. (#1867)
FEATURE PERF
1 parent 5aa35a3 commit 5cc5267

File tree

8 files changed

+454
-108
lines changed

8 files changed

+454
-108
lines changed

‎src/backends/backend.ts‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717

1818
import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
19-
import {Activation} from '../ops/fused_util';
19+
import {Activation, FusedBatchMatMulConfig} from '../ops/fused_util';
2020
import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
2121
import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types';
2222

@@ -132,8 +132,8 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
132132
}
133133

134134
fusedBatchMatMul(
135-
a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
136-
bias?: Tensor, activation?: Activation): Tensor3D {
135+
{a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
136+
FusedBatchMatMulConfig): Tensor3D {
137137
throw new Error('Not yet implemented');
138138
}
139139

@@ -413,7 +413,7 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
413413

414414
fusedConv2d(
415415
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
416-
activation?: Activation): Tensor4D {
416+
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
417417
throw new Error('Not yet implemented');
418418
}
419419

‎src/backends/cpu/backend_cpu.ts‎

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import * as broadcast_util from '../../ops/broadcast_util';
2626
import * as concat_util from '../../ops/concat_util';
2727
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
2828
import * as erf_util from '../../ops/erf_util';
29-
import {Activation} from '../../ops/fused_util';
29+
import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
3030
import * as gather_nd_util from '../../ops/gather_nd_util';
3131
import * as ops from '../../ops/ops';
3232
import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops';
@@ -47,11 +47,14 @@ import {topkImpl} from '../topk_impl';
4747
import {whereImpl} from '../where_impl';
4848

4949
function mapActivation(
50-
backend: MathBackendCPU, activation: Activation, x: Tensor): Tensor {
50+
backend: MathBackendCPU, x: Tensor, activation: Activation,
51+
preluActivationWeights?: Tensor): Tensor {
5152
if (activation === 'linear') {
5253
return backend.linear(x);
5354
} else if (activation === 'relu') {
5455
return backend.relu(x);
56+
} else if (activation === 'prelu') {
57+
return backend.prelu(x, preluActivationWeights);
5558
}
5659
throw new Error(
5760
`Activation ${activation} has not been implemented for the CPU backend.`);
@@ -522,14 +525,16 @@ export class MathBackendCPU implements KernelBackend {
522525
}
523526

524527
fusedBatchMatMul(
525-
a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
526-
bias?: Tensor, activation?: Activation): Tensor3D {
528+
{a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
529+
FusedBatchMatMulConfig): Tensor3D {
527530
let result = this.batchMatMul(a, b, transposeA, transposeB);
528531
if (bias) {
529532
result = this.add(result, bias) as Tensor3D;
530533
}
531534
if (activation) {
532-
result = mapActivation(this, activation, result) as Tensor3D;
535+
result =
536+
mapActivation(this, result, activation, preluActivationWeights) as
537+
Tensor3D;
533538
}
534539
return result;
535540
}
@@ -1515,14 +1520,16 @@ export class MathBackendCPU implements KernelBackend {
15151520

15161521
fusedConv2d(
15171522
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
1518-
activation?: Activation): Tensor4D {
1523+
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
15191524
let result = this.conv2d(x, filter, convInfo);
15201525

15211526
if (bias) {
15221527
result = this.add(result, bias) as Tensor4D;
15231528
}
15241529
if (activation) {
1525-
result = mapActivation(this, activation, result) as Tensor4D;
1530+
result =
1531+
mapActivation(this, result, activation, preluActivationWeights) as
1532+
Tensor4D;
15261533
}
15271534
return result;
15281535
}

‎src/backends/webgl/backend_webgl.ts‎

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import * as array_ops_util from '../../ops/array_ops_util';
2828
import * as axis_util from '../../ops/axis_util';
2929
import {computeOutShape} from '../../ops/concat_util';
3030
import {Conv2DInfo, Conv3DInfo} from '../../ops/conv_util';
31-
import {Activation} from '../../ops/fused_util';
31+
import {Activation, FusedBatchMatMulConfig} from '../../ops/fused_util';
3232
import * as gather_nd_util from '../../ops/gather_nd_util';
3333
import * as reduce_util from '../../ops/reduce_util';
3434
import * as scatter_nd_util from '../../ops/scatter_nd_util';
@@ -174,6 +174,11 @@ function mapActivationToShaderProgram(
174174
return unary_packed_op.RELU;
175175
}
176176
return unary_op.RELU;
177+
} else if (activation === 'prelu') {
178+
if (packed) {
179+
return binaryop_packed_gpu.PRELU;
180+
}
181+
return binaryop_gpu.PRELU;
177182
}
178183
throw new Error(`Activation ${
179184
activation} has not been implemented for the WebGL backend.`);
@@ -865,26 +870,30 @@ export class MathBackendWebGL implements KernelBackend {
865870
}
866871

867872
fusedBatchMatMul(
868-
a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean,
869-
bias?: Tensor,activation?: Activation): Tensor3D {
873+
{a, b, transposeA, transposeB, bias, activation, preluActivationWeights}:
874+
FusedBatchMatMulConfig): Tensor3D {
870875
const outerShapeA = transposeA ? a.shape[2] : a.shape[1];
871876
const outerShapeB = transposeB ? b.shape[1] : b.shape[2];
872877
const [batch, , ] = a.shape;
873878

874879
const dtype = upcastType(a.dtype, b.dtype);
875880

876881
const hasBias = bias != null;
882+
const hasPreluActivationWeights = preluActivationWeights != null;
877883
const fusedActivation =
878884
activation ? mapActivationToShaderProgram(activation, true) : null;
879885
const program = new MatMulPackedProgram(
880886
a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB,
881-
hasBias, fusedActivation);
887+
hasBias, fusedActivation, hasPreluActivationWeights);
882888
const output =
883889
this.makePackedTensor(program.outputShape, dtype) as Tensor3D;
884890
const inputs: TensorHandle[] = [a, b];
885891
if (bias) {
886892
inputs.push(bias);
887893
}
894+
if (preluActivationWeights) {
895+
inputs.push(preluActivationWeights);
896+
}
888897
return this.compileAndRun<Tensor3D>(program, inputs, output);
889898
}
890899

@@ -1819,7 +1828,7 @@ export class MathBackendWebGL implements KernelBackend {
18191828

18201829
private conv2dByMatMul(
18211830
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
1822-
activation?: Activation): Tensor4D {
1831+
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
18231832
// Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
18241833
// result from 2D to 4D.
18251834
const xShape = x.shape;
@@ -1850,9 +1859,15 @@ export class MathBackendWebGL implements KernelBackend {
18501859
Tensor3D;
18511860

18521861
return this.reshape<Rank.R4>(
1853-
this.fusedBatchMatMul(
1854-
xReshaped, filterReshaped, transposeA, transposeB, bias,
1855-
activation),
1862+
this.fusedBatchMatMul({
1863+
a: xReshaped,
1864+
b: filterReshaped,
1865+
transposeA,
1866+
transposeB,
1867+
bias,
1868+
activation,
1869+
preluActivationWeights
1870+
}),
18561871
convInfo.outShape);
18571872
}
18581873

@@ -1888,8 +1903,15 @@ export class MathBackendWebGL implements KernelBackend {
18881903
this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]) as
18891904
Tensor3D;
18901905

1891-
const pointwiseConv = this.fusedBatchMatMul(
1892-
xReshaped, filterReshaped, transposeA, transposeB, bias, activation);
1906+
const pointwiseConv = this.fusedBatchMatMul({
1907+
a: xReshaped,
1908+
b: filterReshaped,
1909+
transposeA,
1910+
transposeB,
1911+
bias,
1912+
activation,
1913+
preluActivationWeights
1914+
});
18931915
const pointwiseConvTexData = this.texData.get(pointwiseConv.dataId);
18941916
util.assert(
18951917
pointwiseConvTexData.isPacked,
@@ -1906,7 +1928,7 @@ export class MathBackendWebGL implements KernelBackend {
19061928

19071929
private conv2dWithIm2Row(
19081930
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
1909-
activation?: Activation): Tensor4D {
1931+
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
19101932
// Rearranges conv2d input so each block to be convolved over forms the
19111933
// column of a new matrix with shape [filterWidth * filterHeight *
19121934
// inChannels, outHeight * outWidth]. The filter is also rearranged so each
@@ -1938,42 +1960,53 @@ export class MathBackendWebGL implements KernelBackend {
19381960
]) as Tensor3D;
19391961

19401962
const hasBias = bias != null;
1963+
const hasPreluActivationWeights = preluActivationWeights != null;
19411964
const fusedActivation =
19421965
activation ? mapActivationToShaderProgram(activation, true) : null;
19431966
const matmulProgram = new MatMulPackedProgram(
19441967
im2Col.shape, [1, numCols, convInfo.outChannels], transposeA,
1945-
transposeB, hasBias, fusedActivation);
1968+
transposeB, hasBias, fusedActivation, hasPreluActivationWeights);
19461969
const inputs: TensorHandle[] = [im2Col, w2Row];
19471970
if (bias) {
19481971
inputs.push(bias);
19491972
}
1973+
if (hasPreluActivationWeights) {
1974+
inputs.push(preluActivationWeights);
1975+
}
19501976
const product = this.compileAndRun<Tensor4D>(matmulProgram, inputs);
19511977

19521978
return product.reshape([1, outHeight, outWidth, convInfo.outChannels]);
19531979
}
19541980

19551981
fusedConv2d(
19561982
x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo, bias?: Tensor4D,
1957-
activation?: Activation): Tensor4D {
1983+
activation?: Activation, preluActivationWeights?: Tensor): Tensor4D {
19581984
if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
19591985
convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
19601986
convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
19611987
(convInfo.padInfo.type === 'SAME' ||
19621988
convInfo.padInfo.type === 'VALID')) {
1963-
return this.conv2dByMatMul(x, filter, convInfo, bias, activation);
1989+
return this.conv2dByMatMul(
1990+
x, filter, convInfo, bias, activation, preluActivationWeights);
19641991
}
19651992
if (ENV.getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
1966-
return this.conv2dWithIm2Row(x, filter, convInfo, bias, activation);
1993+
return this.conv2dWithIm2Row(
1994+
x, filter, convInfo, bias, activation, preluActivationWeights);
19671995
}
19681996

19691997
const hasBias = bias != null;
1998+
const hasPreluActivationWeights = preluActivationWeights != null;
19701999
const fusedActivation =
19712000
activation ? mapActivationToShaderProgram(activation, false) : null;
1972-
const program = new Conv2DProgram(convInfo, hasBias, fusedActivation);
2001+
const program = new Conv2DProgram(
2002+
convInfo, hasBias, fusedActivation, hasPreluActivationWeights);
19732003
const inputs: TensorHandle[] = [x, filter];
19742004
if (bias) {
19752005
inputs.push(bias);
19762006
}
2007+
if (preluActivationWeights) {
2008+
inputs.push(preluActivationWeights);
2009+
}
19772010
return this.compileAndRun(program, inputs);
19782011
}
19792012

‎src/backends/webgl/conv_gpu.ts‎

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ export class Conv2DProgram implements GPGPUProgram {
2424
userCode: string;
2525

2626
constructor(
27-
convInfo: Conv2DInfo, addBias = false, activation: string = null) {
27+
convInfo: Conv2DInfo, addBias = false, activation: string = null,
28+
hasPreluActivationWeights = false) {
2829
this.outputShape = convInfo.outShape;
2930
const padTop = convInfo.padInfo.top;
3031
const padLeft = convInfo.padInfo.left;
@@ -40,11 +41,18 @@ export class Conv2DProgram implements GPGPUProgram {
4041

4142
let activationSnippet = '', applyActivationSnippet = '';
4243
if (activation) {
43-
activationSnippet = `
44-
float activation(float x) {
44+
if (hasPreluActivationWeights) {
45+
activationSnippet = `float activation(float a) {
46+
float b = getPreluActivationWeightsAtOutCoords();
4547
${activation}
46-
}
47-
`;
48+
}`;
49+
} else {
50+
activationSnippet = `
51+
float activation(float x) {
52+
${activation}
53+
}
54+
`;
55+
}
4856

4957
applyActivationSnippet = `result = activation(result);`;
5058
}
@@ -54,6 +62,10 @@ export class Conv2DProgram implements GPGPUProgram {
5462
this.variableNames.push('bias');
5563
}
5664

65+
if (hasPreluActivationWeights) {
66+
this.variableNames.push('preluActivationWeights');
67+
}
68+
5769
this.userCode = `
5870
${activationSnippet}
5971

‎src/backends/webgl/mulmat_packed_gpu.ts‎

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ export class MatMulPackedProgram implements GPGPUProgram {
2626
constructor(
2727
aShape: [number, number, number], outputShape: [number, number, number],
2828
transposeA = false, transposeB = false, addBias = false,
29-
activation: string = null) {
29+
activation: string = null, hasPreluActivation = false) {
3030
this.outputShape = outputShape;
3131

3232
const sharedDim = transposeA ? aShape[1] : aShape[2];
@@ -39,9 +39,16 @@ export class MatMulPackedProgram implements GPGPUProgram {
3939

4040
let activationSnippet = '', applyActivationSnippet = '';
4141
if (activation) {
42-
activationSnippet = `vec4 activation(vec4 x) {
43-
${activation}
44-
}`;
42+
if (hasPreluActivation) {
43+
activationSnippet = `vec4 activation(vec4 a) {
44+
vec4 b = getPreluActivationWeightsAtOutCoords();
45+
${activation}
46+
}`;
47+
} else {
48+
activationSnippet = `vec4 activation(vec4 x) {
49+
${activation}
50+
}`;
51+
}
4552

4653
applyActivationSnippet = `result = activation(result);`;
4754
}
@@ -51,6 +58,10 @@ export class MatMulPackedProgram implements GPGPUProgram {
5158
this.variableNames.push('bias');
5259
}
5360

61+
if (hasPreluActivation) {
62+
this.variableNames.push('preluActivationWeights');
63+
}
64+
5465
this.userCode = `
5566
${activationSnippet}
5667
@@ -82,4 +93,4 @@ export class MatMulPackedProgram implements GPGPUProgram {
8293
}
8394
`;
8495
}
85-
}
96+
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /