// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using Microsoft.ML.Internal.CpuMath.Core;

namespace Microsoft.ML.Internal.CpuMath.FactorizationMachine
{
    internal static class AvxIntrinsics
    {
        private static readonly Vector256<float> _point5 = Vector256.Create(0.5f);

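        // Computes (src1 * src2) + src3, using a single fused multiply-add when FMA is available.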
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<float> src2, Vector256<float> src3)
        {
            if (Fma.IsSupported)
            {
                return Fma.MultiplyAdd(src1, src2, src3);
            }
            else
            {
                Vector256<float> product = Avx.Multiply(src1, src2);
                return Avx.Add(product, src3);
            }
        }

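        // Computes src3 - (src1 * src2), matching Fma.MultiplyAddNegated, with a plain multiply/subtract fallback when FMA is unavailable.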
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static Vector256<float> MultiplyAddNegated(Vector256<float> src1, Vector256<float> src2, Vector256<float> src3)
        {
            if (Fma.IsSupported)
            {
                return Fma.MultiplyAddNegated(src1, src2, src3);
            }
            else
            {
                Vector256<float> product = Avx.Multiply(src1, src2);
                return Avx.Subtract(src3, product);
            }
        }

        // This function implements Algorithm 1 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf.
        // Compute the output value of the field-aware factorization machine as the sum of the linear part and the latent part.
        // The linear part is the inner product of linearWeights and featureValues.
        // The latent part is the sum of all pairwise feature interactions, both within a single field f and across pairs of fields.
        public static unsafe void CalculateIntermediateVariables(int* fieldIndices, int* featureIndices, float* featureValues,
            float* linearWeights, float* latentWeights, float* latentSum, float* response, int fieldCount, int latentDim, int count)
        {
            Contracts.Assert(Avx.IsSupported);

            // The number of all possible fields.
            int m = fieldCount;
            int d = latentDim;
            int c = count;
            int* pf = fieldIndices;
            int* pi = featureIndices;
            float* px = featureValues;
            float* pw = linearWeights;
            float* pv = latentWeights;
            float* pq = latentSum;
            float linearResponse = 0;
            float latentResponse = 0;

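            // pq holds m * m accumulators of d floats each, one per ordered (field, field) pair, laid out as
            // pq[f * m * d + f' * d + k]; pv is laid out the same way per (feature, field) pair, as pv[j * m * d + f' * d + k].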
            Unsafe.InitBlock(pq, 0, (uint)(m * m * d * sizeof(float)));

            Vector256<float> y = Vector256<float>.Zero;
            Vector256<float> tmp = Vector256<float>.Zero;

            for (int i = 0; i < c; i++)
            {
                int f = pf[i];
                int j = pi[i];
                linearResponse += pw[j] * px[i];

                Vector256<float> x = Avx.BroadcastScalarToVector256(px + i);
                Vector256<float> xx = Avx.Multiply(x, x);

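                // Algorithm 1 relies on the identity sum_{j != j'} <v_j,f, v_j',f> x_j x_j' = <q_f,f, q_f,f> - sum_j x_j^2 <v_j,f, v_j,f>
                // (sums over the features assigned to field f), so the squared self-interaction terms are accumulated with a
                // negative sign here and the total is halved via _point5 below.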
                // tmp -= <v_j,f, v_j,f> * x * x
                int vBias = j * m * d + f * d;

                // j-th feature's latent vector in the f-th field hidden space.
                float* vjf = pv + vBias;

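                // Each iteration processes 8 floats; there is no scalar remainder loop, so latentDim is presumably
                // always a multiple of 8 (an assumption inferred from the loop bound, not stated in this file).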
                for (int k = 0; k + 8 <= d; k += 8)
                {
                    Vector256<float> vjfBuffer = Avx.LoadVector256(vjf + k);
                    tmp = MultiplyAddNegated(Avx.Multiply(vjfBuffer, vjfBuffer), xx, tmp);
                }

                for (int fprime = 0; fprime < m; fprime++)
                {
                    vBias = j * m * d + fprime * d;
                    int qBias = f * m * d + fprime * d;
                    float* vjfprime = pv + vBias;
                    float* qffprime = pq + qBias;

                    // q_f,f' += v_j,f' * x
                    for (int k = 0; k + 8 <= d; k += 8)
                    {
                        Vector256<float> vjfprimeBuffer = Avx.LoadVector256(vjfprime + k);
                        Vector256<float> q = Avx.LoadVector256(qffprime + k);
                        q = MultiplyAdd(vjfprimeBuffer, x, q);
                        Avx.Store(qffprime + k, q);
                    }
                }
            }

            for (int f = 0; f < m; f++)
            {
                // tmp += <q_f,f, q_f,f>
                float* qff = pq + f * m * d + f * d;
                for (int k = 0; k + 8 <= d; k += 8)
                {
                    Vector256<float> qffBuffer = Avx.LoadVector256(qff + k);

                    // Intra-field interactions.
                    tmp = MultiplyAdd(qffBuffer, qffBuffer, tmp);
                }

                // y += <q_f,f', q_f',f>, f != f'
                // This loop handles inter-field interactions because f != f'.
                for (int fprime = f + 1; fprime < m; fprime++)
                {
                    float* qffprime = pq + f * m * d + fprime * d;
                    float* qfprimef = pq + fprime * m * d + f * d;
                    for (int k = 0; k + 8 <= d; k += 8)
                    {
                        // Inter-field interaction.
                        Vector256<float> qffprimeBuffer = Avx.LoadVector256(qffprime + k);
                        Vector256<float> qfprimefBuffer = Avx.LoadVector256(qfprimef + k);
                        y = MultiplyAdd(qffprimeBuffer, qfprimefBuffer, y);
                    }
                }
            }

            y = MultiplyAdd(_point5, tmp, y);
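
            // Horizontally reduce the eight lanes of y to one scalar: swap the 128-bit halves and add,
            // then two horizontal adds leave the total in every lane, including the lowest.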
            tmp = Avx.Add(y, Avx.Permute2x128(y, y, 1));
            tmp = Avx.HorizontalAdd(tmp, tmp);
            y = Avx.HorizontalAdd(tmp, tmp);
            Sse.StoreScalar(&latentResponse, y.GetLower()); // The lowest slot is the response value.
            *response = linearResponse + latentResponse;
        }

        // This function implements Algorithm 2 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf
        // Calculate the stochastic gradient and update the model.
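        // Presumably, 'slope' is the derivative of the loss with respect to the model output for this example and
        // 'weight' is the example weight; the gradient expressions below combine them with the L2 penalties
        // lambdaLinear and lambdaLatent (an interpretation inferred from the formulas, not stated in this file).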
        public static unsafe void CalculateGradientAndUpdate(int* fieldIndices, int* featureIndices, float* featureValues, float* latentSum, float* linearWeights,
            float* latentWeights, float* linearAccumulatedSquaredGrads, float* latentAccumulatedSquaredGrads, float lambdaLinear, float lambdaLatent, float learningRate,
            int fieldCount, int latentDim, float weight, int count, float slope)
        {
            Contracts.Assert(Avx.IsSupported);

            int m = fieldCount;
            int d = latentDim;
            int c = count;
            int* pf = fieldIndices;
            int* pi = featureIndices;
            float* px = featureValues;
            float* pq = latentSum;
            float* pw = linearWeights;
            float* pv = latentWeights;
            float* phw = linearAccumulatedSquaredGrads;
            float* phv = latentAccumulatedSquaredGrads;

            Vector256<float> wei = Vector256.Create(weight);
            Vector256<float> s = Vector256.Create(slope);
            Vector256<float> lr = Vector256.Create(learningRate);
            Vector256<float> lambdav = Vector256.Create(lambdaLatent);

            for (int i = 0; i < count; i++)
            {
                int f = pf[i];
                int j = pi[i];

                // Calculate gradient of linear term w_j.
                float g = weight * (lambdaLinear * pw[j] + slope * px[i]);

                // Accumulate the gradient of the linear term.
                phw[j] += g * g;

                // Perform ADAGRAD update rule to adjust linear term.
                pw[j] -= learningRate / MathF.Sqrt(phw[j]) * g;

                // Update latent term, v_j,f', f'=1,...,m.
                Vector256<float> x = Avx.BroadcastScalarToVector256(px + i);

                for (int fprime = 0; fprime < m; fprime++)
                {
                    float* vjfprime = pv + j * m * d + fprime * d;
                    float* hvjfprime = phv + j * m * d + fprime * d;
                    float* qfprimef = pq + fprime * m * d + f * d;
                    Vector256<float> sx = Avx.Multiply(s, x);

                    for (int k = 0; k + 8 <= d; k += 8)
                    {
                        Vector256<float> v = Avx.LoadVector256(vjfprime + k);
                        Vector256<float> q = Avx.LoadVector256(qfprimef + k);

                        // Calculate L2-norm regularization's gradient.
                        Vector256<float> gLatent = Avx.Multiply(lambdav, v);

                        Vector256<float> tmp = q;

                        // Calculate loss function's gradient.
                        if (fprime == f)
                            tmp = MultiplyAddNegated(v, x, q);
                        gLatent = MultiplyAdd(sx, tmp, gLatent);
                        gLatent = Avx.Multiply(wei, gLatent);

                        // Accumulate the gradient of latent vectors.
                        Vector256<float> h = MultiplyAdd(gLatent, gLatent, Avx.LoadVector256(hvjfprime + k));

                        // Perform ADAGRAD update rule to adjust latent vector.
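                        // Note: Avx.ReciprocalSqrt is an approximate (~12-bit) reciprocal square root, so this vector
                        // update trades a little precision for speed versus the exact MathF.Sqrt used for the linear term.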
                        v = MultiplyAddNegated(lr, Avx.Multiply(Avx.ReciprocalSqrt(h), gLatent), v);
                        Avx.Store(vjfprime + k, v);
                        Avx.Store(hvjfprime + k, h);
                    }
                }
            }
        }
    }
}