diff --git a/csvdedupe/csvhelpers.py b/csvdedupe/csvhelpers.py index 118c49c..8c354d8 100644 --- a/csvdedupe/csvhelpers.py +++ b/csvdedupe/csvhelpers.py @@ -62,8 +62,11 @@ def writeResults(clustered_dupes, input_file, output_file): cluster_membership = {} for cluster_id, (cluster, score) in enumerate(clustered_dupes): - for record_id in cluster: - cluster_membership[record_id] = cluster_id + for record_index, record_id in enumerate(cluster): + cluster_membership[record_id] = { + 'cluster_id': cluster_id, + 'score': score[record_index], + } unique_record_id = cluster_id + 1 @@ -73,15 +76,19 @@ def writeResults(clustered_dupes, input_file, output_file): heading_row = reader.next() heading_row.insert(0, 'Cluster ID') + heading_row.insert(1, 'Confidence') writer.writerow(heading_row) for row_id, row in enumerate(reader): if row_id in cluster_membership: - cluster_id = cluster_membership[row_id] + cluster_id = cluster_membership[row_id]['cluster_id'] + score = cluster_membership[row_id]['score'] else: cluster_id = unique_record_id unique_record_id += 1 + score = '' row.insert(0, cluster_id) + row.insert(1, score) writer.writerow(row)