@@ -106,8 +106,10 @@ void main() {
106106 // Preload weight tensor
107107 for (int r = 0 ; r < 4 ; r++ ) {
108108 T qmat2[TILE_TXCOLS * 4 ];
109- VEC4_T qmat2_vec4;
110- uvec4 packed_weight_tex;
109+ $if QUANT_NBITS == 4 :
110+ uvec4 packed_weight_tex;
111+ $else :
112+ ivec4 packed_weight_tex;
111113
112114 $if QUANT_NBITS == 4 :
113115 $for c in range(0 , TILE_TXCOLS, 2 ):
@@ -119,28 +121,27 @@ void main() {
119121 packed_weight_tex = texelFetch(
120122 t_weight, ivec2 (weight_txcol + ${c}, pos + r), 0 );
121123
122- qmat2_vec4 = VEC4_T( packed_weight_tex >> 4 ) ;
123- qmat2[${c} * 4 * TILE_TXCOLS + 0 ] = qmat2_vec4.x ;
124- qmat2[${c} * 4 * TILE_TXCOLS + 1 ] = qmat2_vec4.y ;
125- qmat2[${c} * 4 * TILE_TXCOLS + 2 ] = qmat2_vec4.z ;
126- qmat2[${c} * 4 * TILE_TXCOLS + 3 ] = qmat2_vec4.w ;
127-
128- qmat2_vec4 = VEC4_T( packed_weight_tex & 0x0F) ;
129- qmat2[${c} * 4 * TILE_TXCOLS + 4 ] = qmat2_vec4.x ;
130- qmat2[${c} * 4 * TILE_TXCOLS + 5 ] = qmat2_vec4.y ;
131- qmat2[${c} * 4 * TILE_TXCOLS + 6 ] = qmat2_vec4.z ;
132- qmat2[${c} * 4 * TILE_TXCOLS + 7 ] = qmat2_vec4.w ;
124+ const uvec4 tmp1 = packed_weight_tex >> 4 ;
125+ qmat2[${c} * 4 * TILE_TXCOLS + 0 ] = T(tmp1.x) ;
126+ qmat2[${c} * 4 * TILE_TXCOLS + 1 ] = T(tmp1.y) ;
127+ qmat2[${c} * 4 * TILE_TXCOLS + 2 ] = T(tmp1.z) ;
128+ qmat2[${c} * 4 * TILE_TXCOLS + 3 ] = T(tmp1.w) ;
129+
130+ const uvec4 tmp2 = packed_weight_tex & 0x0F;
131+ qmat2[${c} * 4 * TILE_TXCOLS + 4 ] = T(tmp2.x) ;
132+ qmat2[${c} * 4 * TILE_TXCOLS + 5 ] = T(tmp2.y) ;
133+ qmat2[${c} * 4 * TILE_TXCOLS + 6 ] = T(tmp2.z) ;
134+ qmat2[${c} * 4 * TILE_TXCOLS + 7 ] = T(tmp2.w) ;
133135 $else :
134136 $for c in range(TILE_TXCOLS):
135137 $if WEIGHT_STORAGE == "buffer ":
136138 qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
137139 encoded_weight = t_weight[qmat2_bufi + ${c}];
138- packed_weight_tex = uvec4 (encoded_weight & 0xFF, (encoded_weight >> 8 ) & 0xFF, (encoded_weight >> 16 ) & 0xFF, encoded_weight >> 24 );
139- qmat2_vec4 = VEC4_T(packed_weight_tex);
140+ packed_weight_tex = ivec4 (encoded_weight & 0xFF, (encoded_weight >> 8 ) & 0xFF, (encoded_weight >> 16 ) & 0xFF, encoded_weight >> 24 );
140141 $else :
141- qmat2_vec4 = VEC4_T (texelFetch(t_weight, ivec2 (out_txcol + ${c}, pos + r), 0 ));
142+ packed_weight_tex = ivec4 (texelFetch(t_weight, ivec2 (out_txcol + ${c}, pos + r), 0 ));
142143 $for j in range(4 ):
143- qmat2[${c} * 4 + ${j}] = qmat2_vec4 [${j}];
144+ qmat2[${c} * 4 + ${j}] = T(packed_weight_tex [${j}]) ;
144145
145146 for (int tr = 0 ; tr < TILE_ROWS; ++ tr) {
146147 $for c in range(TILE_TXCOLS):
0 commit comments