diff --git a/src/vfb_connect/cross_server_tools.py b/src/vfb_connect/cross_server_tools.py index 094fef3e..65a0abdb 100644 --- a/src/vfb_connect/cross_server_tools.py +++ b/src/vfb_connect/cross_server_tools.py @@ -521,15 +521,17 @@ def get_neurons_upstream_of(self, neuron, weight, classification=None, query_by_ classification=classification, query_by_label=query_by_label, return_dataframe=return_dataframe, verbose=verbose) - def get_connected_neurons_by_type(self, weight, upstream_type=None, downstream_type=None, query_by_label=True, - return_dataframe=True, verbose=False): + def get_connected_neurons_by_type(self, weight, upstream_type=None, downstream_type=None, group_by_class=False, + query_by_label=True, exclude_dbs=['hb', 'fafb'], return_dataframe=True, verbose=False): """Get all synaptic connections between individual neurons of `upstream_type` and `downstream_type` where synapse count >= `weight`. At least one of 'upstream_type' or downstream_type must be specified. :param upstream_type: The upstream neuron type (e.g., 'GABAergic neuron'). :param downstream_type: The downstream neuron type (e.g., 'Descending neuron'). + :param group_by_class: If `True`, return connectivity results aggregated by class rather than per neuron. Default `False`. :param query_by_label: Optional. Specify neuron type by label if `True` (default) or by short_form ID if `False`. + :param exclude_dbs: Optional. List of databases (short_forms or symbols) to exclude from results. Hemibrain and catmaid FAFB excluded by default. :param return_dataframe: Optional. Returns pandas DataFrame if `True`, otherwise returns list of dicts. Default `True`. :return: A DataFrame or list of synaptic connections between specified neuron types. :rtype: pandas.DataFrame or list of dicts @@ -546,36 +548,60 @@ def get_connected_neurons_by_type(self, weight, upstream_type=None, downstream_t if upstream_type: upstream_type = self.lookup_id(dequote(upstream_type)) if downstream_type: downstream_type = self.lookup_id(dequote(downstream_type)) cypher_ql = [] - if upstream_type: - cypher_ql.append( - "MATCH (up:Class)<-[:SUBCLASSOF*0..]-(c1:Class)<-[:INSTANCEOF]-(n1:has_neuron_connectivity)" - " WHERE up.short_form = '%s' " % upstream_type) - if downstream_type: - cypher_ql.append( - "MATCH (down:Class)<-[:SUBCLASSOF*0..]-(c2:Class)<-[:INSTANCEOF]-(n2:has_neuron_connectivity)" - "WHERE down.short_form = '%s' " % downstream_type) - if not upstream_type: - cypher_ql.append( - "MATCH (c1:Class)<-[:INSTANCEOF]-(n1)-[r:synapsed_to]->(n2) WHERE r.weight[0] >= %d " % weight) - elif not downstream_type: - cypher_ql.append( - "MATCH (n1)-[r:synapsed_to]->(n2)-[:INSTANCEOF]->(c2:Class) WHERE r.weight[0] >= %d " % weight) + + cypher_ql.append( + "MATCH %s(c1:Class:Neuron) " + % ('(:Class:Neuron {short_form:"' + upstream_type + '"})<-[:SUBCLASSOF*0..]-' if upstream_type else "")) + cypher_ql.append( + "MATCH %s(c2:Class:Neuron) " + % ('(:Class:Neuron {short_form:"' + downstream_type + '"})<-[:SUBCLASSOF*0..]-' if downstream_type else "")) + + cypher_ql.append("MATCH (c1)<-[:INSTANCEOF]-(n1:Individual:Neuron:has_neuron_connectivity)-" + "[r:synapsed_to]->(n2:Individual:Neuron:has_neuron_connectivity)-[:INSTANCEOF]->(c2) " + "WHERE r.weight[0] >= %s " % weight) + + if exclude_dbs: + cypher_ql.append("MATCH (n1)-[:database_cross_reference]->(s:Individual:Site {is_data_source:[True]}) \n" + "WHERE NOT (s.short_form IN %s) \n" + "AND NOT (s.symbol[0] IN %s) " + % (exclude_dbs, exclude_dbs)) + + if not group_by_class: + cypher_ql.append("OPTIONAL MATCH (n1)-[r1:database_cross_reference]->(s1:Individual:Site {is_data_source:[True]}) \n" + "OPTIONAL MATCH (n2)-[r2:database_cross_reference]->(s2:Individual:Site {is_data_source:[True]}) \n" + "RETURN apoc.text.join(collect(distinct c1.label),'|') AS upstream_class, " + "apoc.text.join(collect(distinct c1.short_form),'|') AS upstream_class_id, " + "n1.short_form as upstream_neuron_id, n1.label as upstream_neuron_name, " + "r.weight[0] as weight, n2.short_form as downstream_neuron_id, " + "n2.label as downstream_neuron_name, " + "apoc.text.join(collect(distinct c2.label),'|') as downstream_class, " + "apoc.text.join(collect(distinct c2.short_form),'|') as downstream_class_id, " + "s1.short_form AS up_data_source, r1.accession[0] as up_accession, " + "s2.short_form AS down_data_source, r2.accession[0] AS down_accession ") + else: - cypher_ql.append("MATCH (n1)-[r:synapsed_to]->(n2) WHERE r.weight[0] >= %d " % weight) - cypher_ql.append("OPTIONAL MATCH (n1)-[r1:database_cross_reference]->(s1:Site) " - "WHERE exists(s1.is_data_source) AND s1.is_data_source = [True] ") - cypher_ql.append("OPTIONAL MATCH (n2)-[r2:database_cross_reference]->(s2:Site) " - "WHERE exists(s2.is_data_source) AND s2.is_data_source = [True] " ) - cypher_ql.append("RETURN apoc.text.join(collect(distinct c1.label),'|') AS upstream_class, " - "apoc.text.join(collect(distinct c1.short_form),'|') AS upstream_class_id, " - "n1.short_form as upstream_neuron_id, n1.label as upstream_neuron_name," - "r.weight[0] as weight, n2.short_form as downstream_neuron_id, " - "n2.label as downstream_neuron_name, " - "apoc.text.join(collect(distinct c2.label),'|') as downstream_class, " - "apoc.text.join(collect(distinct c2.short_form),'|') as downstream_class_id, " - "s1.short_form AS up_data_source, r1.accession[0] as up_accession, " - "s2.short_form AS down_source, r2.accession[0] AS down_accession") - cypher_q = ' \n'.join(cypher_ql) + cypher_ql.append("WITH c1, c2, count(*) as pairwise_connections, sum(r.weight[0]) as total_weight, " + "count(distinct n1) as connected_upstream_count \n\n" + "MATCH (c1)<-[:INSTANCEOF]-(all_n1:Individual:has_neuron_connectivity)%s \n\n" + "WITH c1, c2, pairwise_connections, total_weight, connected_upstream_count, " + "count(distinct all_n1) as total_upstream_count \n\n" + "RETURN c1.label AS upstream_class, " + "c1.short_form AS upstream_class_id, " + "c2.label AS downstream_class, " + "c2.short_form AS downstream_class_id, " + "total_upstream_count, " + "connected_upstream_count, " + "round((toFloat(connected_upstream_count)/toFloat(total_upstream_count))*100) as percent_connected, " + "pairwise_connections, " + "total_weight, " + "total_weight/pairwise_connections as average_weight " + "ORDER BY pairwise_connections DESC, average_weight DESC" + % ("-[:database_cross_reference]->(s:Individual:Site {is_data_source:[True]}) \n" + "WHERE NOT (s.short_form IN %s) \n" + "AND NOT (s.symbol[0] IN %s) " + % (exclude_dbs, exclude_dbs) if exclude_dbs else "")) + + cypher_q = ' \n\n'.join(cypher_ql) print(cypher_q) if verbose else None r = self.nc.commit_list([cypher_q]) if not r: diff --git a/src/vfb_connect/test/cross_server_tools_test.py b/src/vfb_connect/test/cross_server_tools_test.py index a4341ebc..e0a91e60 100644 --- a/src/vfb_connect/test/cross_server_tools_test.py +++ b/src/vfb_connect/test/cross_server_tools_test.py @@ -72,11 +72,13 @@ def test_get_upstream_neurons(self): def test_get_connected_neurons_by_type(self): print() - fu = self.vc.get_connected_neurons_by_type(upstream_type='Kenyon cell', downstream_type='mushroom body output neuron', weight=20) + fu = self.vc.get_connected_neurons_by_type(upstream_type='Kenyon cell', exclude_dbs=[], downstream_type='mushroom body output neuron', weight=20) self.assertTrue(len(fu) > 0) - fu = self.vc.get_connected_neurons_by_type(upstream_type='FBbt_00051488', weight=5) + fu = self.vc.get_connected_neurons_by_type(upstream_type='FBbt_00051488', weight=5, exclude_dbs=[]) self.assertTrue(len(fu) > 0) - fu = self.vc.get_connected_neurons_by_type(downstream_type='FBbt_00051488', weight=5) + fu = self.vc.get_connected_neurons_by_type(downstream_type='FBbt_00053287', weight=5) + self.assertTrue(len(fu) > 0) + fu = self.vc.get_connected_neurons_by_type(weight=10, upstream_type='C2', downstream_type='visual projection neuron', group_by_class=True) self.assertTrue(len(fu) > 0)