1
+ /**
2
+ * @license
3
+ * Copyright 2019 Google Inc. All Rights Reserved.
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ * =============================================================================
16
+ */
17
+
18
+ import { GPGPUContext } from './gpgpu_context' ;
19
+ import { GPGPUProgram } from './gpgpu_math' ;
20
+ import { getCoordsDataType } from './shader_compiler' ;
21
+ import { getChannels } from '../packing_util' ;
22
+
23
+ export class SlicePackedProgram implements GPGPUProgram {
24
+ variableNames = [ 'source' ] ;
25
+ usesPackedTextures = true ;
26
+ outputShape : number [ ] ;
27
+ userCode : string ;
28
+ rank : number ;
29
+
30
+ // Caching uniform location for speed.
31
+ startLoc : WebGLUniformLocation ;
32
+
33
+ constructor ( destSize : number [ ] ) {
34
+ this . outputShape = destSize ;
35
+ this . rank = destSize . length ;
36
+
37
+ const dtype = getCoordsDataType ( this . rank ) ;
38
+ const coords = getChannels ( 'coords' , this . rank ) ;
39
+ const sourceLoc = getChannels ( 'sourceLoc' , this . rank ) ;
40
+
41
+ const innerDims =
42
+ this . rank === 1 ? 'sourceLoc' : `vec2(${ sourceLoc . slice ( - 2 ) . join ( ) } )` ;
43
+ const getChannel =
44
+ `getChannel(getSource(${ sourceLoc . join ( ) } ), ${ innerDims } )` ;
45
+ const upperRow = `
46
+ result.x = ${ getChannel } ;
47
+ if (++${ coords [ this . rank - 1 ] } < ${ destSize [ this . rank - 1 ] } ) {
48
+ ++${ sourceLoc [ this . rank - 1 ] } ;
49
+ result.y = ${ getChannel } ;
50
+ --${ sourceLoc [ this . rank - 1 ] } ;
51
+ }
52
+ ` ;
53
+ const lowerRow = this . rank === 1 ? '' : `
54
+ --${ coords [ this . rank - 1 ] } ;
55
+ if (++${ coords [ this . rank - 2 ] } < ${ destSize [ this . rank - 2 ] } ) {
56
+ ++${ sourceLoc [ this . rank - 2 ] } ;
57
+ result.z = ${ getChannel } ;
58
+ if (++${ coords [ this . rank - 1 ] } < ${ destSize [ this . rank - 1 ] } ) {
59
+ ++${ sourceLoc [ this . rank - 1 ] } ;
60
+ result.w = ${ getChannel } ;
61
+ }
62
+ }
63
+ ` ;
64
+
65
+ const sourceLocSetup = this . rank <= 4 ?
66
+ `sourceLoc = coords +
67
+ ${ dtype } (${ destSize . map ( ( _ , i ) => `start[${ i } ]` ) . join ( ) } );` :
68
+ destSize . map ( ( _ , i ) => `${ sourceLoc [ i ] } = ${ coords [ i ] } + start[${ i } ];` )
69
+ . join ( '\n' ) ;
70
+ this . userCode = `
71
+ uniform int start[${ this . rank } ];
72
+ void main() {
73
+ ${ dtype } coords = getOutputCoords();
74
+ ${ dtype } sourceLoc;
75
+ ${ sourceLocSetup }
76
+ vec4 result = vec4(0.);
77
+ ${ upperRow }
78
+ ${ lowerRow }
79
+ setOutput(result);
80
+ }
81
+ ` ;
82
+ }
83
+
84
+ getCustomSetupFunc ( start : number [ ] ) {
85
+ if ( start . length !== this . rank ) {
86
+ throw Error (
87
+ `The rank (${ this . rank } ) of the program must match the ` +
88
+ `length of start (${ start . length } )` ) ;
89
+ }
90
+ return ( gpgpu : GPGPUContext , webGLProgram : WebGLProgram ) => {
91
+ if ( this . startLoc == null ) {
92
+ this . startLoc = gpgpu . getUniformLocationNoThrow ( webGLProgram , 'start' ) ;
93
+ if ( this . startLoc == null ) {
94
+ // This means the compiler has optimized and realized it doesn't need
95
+ // the uniform.
96
+ return ;
97
+ }
98
+ }
99
+ gpgpu . gl . uniform1iv ( this . startLoc , start ) ;
100
+ } ;
101
+ }
102
+ }
0 commit comments