@@ -142,16 +142,6 @@ class UserList(generics.ListCreateAPIView):
142
142
class LabelersListAPI (APIView ):
143
143
pagination_class = None
144
144
permission_classes = (IsAuthenticated , IsProjectUser , IsAdminUserAndWriteOnly )
145
- def render_agreement_matrix (self ):
146
- return ''
147
-
148
- def get_truth_agreement (self ):
149
- cursor = connection .cursor ()
150
-
151
- return 42
152
-
153
- def get_labelers_agreement (self ):
154
- return 42
155
145
156
146
def get (self , request , * args , ** kwargs ):
157
147
# def get_annotations
@@ -231,6 +221,7 @@ def plot_agreement_matrix(agreement):
231
221
232
222
233
223
annotations_df = get_annotations (cursor , project_id )
224
+ annotations_df .to_csv (r'C:\Users\omri.allouche\Downloads\labeler_agreement.csv' )
234
225
annotations_df = annotations_df .drop_duplicates (['document_id' , 'user_id' ])
235
226
annotations_df ['is_correct' ] = [int (x ) for x in annotations_df ['label_id' ]== annotations_df ['true_label_id' ]]
236
227
user_truth_agreement = annotations_df [ pd .notnull (annotations_df ['true_label_id' ]) ].groupby ('user_id' )['is_correct' ].agg (['count' , 'mean' ])
@@ -246,7 +237,9 @@ def plot_agreement_matrix(agreement):
246
237
users ['agreement_with_truth' ] = user_truth_agreement ['mean' ]
247
238
users = users .reset_index ()
248
239
240
+ num_truth_annotations = annotations_df ['true_label_id' ].count ()
249
241
response = {
242
+ 'num_truth_annotations' : num_truth_annotations ,
250
243
'users' : users .fillna (0 ).T .to_dict (),
251
244
'document_agreement' : documents_agreement_df .fillna (0 ).T .to_dict (),
252
245
'matrix' : plot_agreement_matrix (users_agreement_kappa ),
@@ -285,8 +278,9 @@ def get(self, request, *args, **kwargs):
285
278
server_documentannotation.label_id
286
279
FROM
287
280
server_document
288
- LEFT JOIN server_documentannotation ON server_documentannotation.document_id = server_document.id AND server_documentannotation.user_id = %s
289
- WHERE server_document.project_id = %s''' % (str (request .user .id ), str (self .kwargs ['project_id' ]))
281
+ LEFT JOIN server_documentannotation ON server_documentannotation.document_id = server_document.id
282
+ -- AND server_documentannotation.user_id = {user_id}
283
+ WHERE server_document.project_id = {project_id}''' .format (user_id = request .user .id , project_id = project_id )
290
284
291
285
doc_annotations_gold_query = '''SELECT
292
286
server_document.id,
@@ -295,17 +289,21 @@ def get(self, request, *args, **kwargs):
295
289
FROM
296
290
server_document
297
291
LEFT JOIN server_documentgoldannotation ON server_documentgoldannotation.document_id = server_document.id
298
- WHERE server_document.project_id =''' + str ( self . kwargs [ ' project_id' ] )
292
+ WHERE server_document.project_id = {project_id} ''' . format ( project_id = project_id )
299
293
300
294
if not os .path .isdir (ML_FOLDER ):
301
295
os .makedirs (ML_FOLDER )
302
- with open (os .path .join (ML_FOLDER , INPUT_FILE ), 'w' , encoding = 'utf-8' , newline = '' ) as outfile :
303
- wr = csv .writer (outfile , quoting = csv .QUOTE_ALL )
304
- wr .writerow (['document_id' , 'text' , 'label_id' ])
305
- cursor .execute (doc_annotations_query )
306
- for row in cursor .fetchall ():
307
- label_id = None
308
- wr .writerow ([row [0 ], row [1 ], row [2 ]])
296
+
297
+ cursor .execute (doc_annotations_gold_query )
298
+ gold_annotations = cursor .fetchall ()
299
+ cursor .execute (doc_annotations_query )
300
+ user_annotations = cursor .fetchall ()
301
+
302
+ annotations = gold_annotations + user_annotations
303
+
304
+ df = pd .DataFrame (annotations , columns = ['document_id' , 'text' , 'label_id' ])
305
+ df = df .drop_duplicates (['document_id' ])
306
+ df .to_csv (os .path .join (ML_FOLDER , INPUT_FILE ), encoding = 'utf-8' )
309
307
310
308
result = run_model_on_file (os .path .join (ML_FOLDER , INPUT_FILE ), os .path .join (ML_FOLDER , OUTPUT_FILE ), user_id = 0 , project_id = project_id )
311
309
@@ -349,30 +347,40 @@ def get(self, request, *args, **kwargs):
349
347
class DocumentExplainAPI (generics .RetrieveUpdateDestroyAPIView ):
350
348
project_id = 999 # TODO: Change this to the actual current project
351
349
pagination_class = None
352
- permission_classes = (IsAuthenticated , IsProjectUser , IsAdminUser )
350
+ permission_classes = (IsAuthenticated , IsProjectUser )
353
351
class_weights = None
354
352
filename = 'ml_models/ml_logistic_regression_weights_{project_id}.csv' .format (project_id = project_id )
355
353
has_weights = False
356
- if (os .path .isfile (filename )):
357
- class_weights = pd .read_csv (os .path .abspath (filename ), header = None ,
358
- names = ['term' , 'weight' ]).set_index ('term' )['weight' ]
359
- has_weights = True
360
-
354
+
355
+ def get_class_weights (self ):
356
+ if not self .has_weights :
357
+ self .set_class_weights ()
358
+ return self .class_weights
359
+
360
+ def set_class_weights (self ):
361
+ if (os .path .isfile (self .filename )):
362
+ data = pd .read_csv (os .path .abspath (self .filename ), header = None , names = ['term' , 'weight' ])
363
+ data ['term' ] = data ['term' ].str .replace ('processed_text_w_' , '' )
364
+ self .class_weights = data .set_index ('term' )['weight' ]
365
+ self .has_weights = True
366
+
361
367
def get (self , request , * args , ** kwargs ):
362
368
d = get_object_or_404 (Document , pk = self .kwargs ['doc_id' ])
363
369
doc_text_splited = d .text .split (' ' )
364
370
format_str_positive = '<span class="has-background-success">{}</span>'
365
371
format_str_negative = '<span class="has-background-danger">{}</span>'
366
372
text = []
367
- if self .has_weights :
373
+ class_weights = self .get_class_weights ()
374
+ if class_weights is not None :
368
375
for w in doc_text_splited :
369
- weight = self . class_weights .get (w .lower ().replace (',' ,'' ).replace ('.' ,'' ), 0 )
376
+ weight = class_weights .get (w .lower ().replace (',' ,'' ).replace ('.' ,'' ), 0 )
370
377
if weight < - 0.2 :
371
378
text .append (format_str_negative .format (w ))
372
379
elif weight > 0.2 :
373
380
text .append (format_str_positive .format (w ))
374
381
else :
375
382
text .append (w )
383
+
376
384
response = {'document' : ' ' .join (text )}
377
385
# doc_text_splited = [w if np.abs(self.class_weights.get(w,0))<0.2 else format_str.format(w) for w in doc_text_splited]
378
386
# doc_text_splited[0] = '<span class="has-background-primary">' + doc_text_splited[0] + '</span>'
@@ -465,7 +473,8 @@ def get_queryset(self):
465
473
466
474
if self .request .query_params .get ('is_checked' ):
467
475
is_null = self .request .query_params .get ('is_checked' ) == 'true'
468
- queryset = project .get_documents (is_null ).distinct ()
476
+ print (int (is_null ))
477
+ queryset = project .get_documents (is_null = is_null , user = self .request .user .id ).distinct ()
469
478
470
479
if (project .use_machine_model_sort ):
471
480
queryset = queryset .order_by ('doc_mlm_annotations__prob' ).filter (project = self .kwargs ['project_id' ]).exclude (doc_mlm_annotations__prob__isnull = True )
0 commit comments