Skip to content

Commit f99863a

Browse files
committed
add linr
Signed-off-by: weiwee <[email protected]>
1 parent b9a72e5 commit f99863a

File tree

6 files changed

+444
-0
lines changed

6 files changed

+444
-0
lines changed
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
/*
2+
* Copyright 2019 The FATE Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.webank.ai.fate.serving.federatedml.model;
18+
19+
import com.google.common.collect.Lists;
20+
import com.google.common.collect.Sets;
21+
import com.webank.ai.fate.core.mlmodel.buffer.LinRModelParamProto.LinRModelParam;
22+
import com.webank.ai.fate.serving.core.bean.Dict;
23+
import com.webank.ai.fate.serving.core.bean.MetaInfo;
24+
import org.slf4j.Logger;
25+
import org.slf4j.LoggerFactory;
26+
27+
import java.util.*;
28+
import java.util.concurrent.ForkJoinTask;
29+
import java.util.concurrent.RecursiveTask;
30+
31+
public abstract class HeteroLinR extends BaseComponent {
32+
private static final Logger logger = LoggerFactory.getLogger(HeteroLinR.class);
33+
private Map<String, Double> weight;
34+
private Double intercept;
35+
LinRModelParam linrModelParam;
36+
37+
@Override
38+
public int initModel(byte[] protoMeta, byte[] protoParam) {
39+
logger.info("start init HeteroLR class");
40+
try {
41+
linrModelParam = this.parseModel(LinRModelParam.parser(), protoParam);
42+
this.weight = linrModelParam.getWeightMap();
43+
this.intercept = linrModelParam.getIntercept();
44+
} catch (Exception ex) {
45+
ex.printStackTrace();
46+
return ILLEGALDATA;
47+
}
48+
logger.info("Finish init HeteroLR class, model weight is {}", this.weight);
49+
return OK;
50+
}
51+
52+
Map<String, Double> forward(List<Map<String, Object>> inputDatas) {
53+
Map<String, Object> inputData = inputDatas.get(0);
54+
int hitCount = 0;
55+
int weightNum = this.weight.size();
56+
int inputFeaturesNum = inputData.size();
57+
if (logger.isDebugEnabled()) {
58+
logger.debug("model weight number:{}", weightNum);
59+
logger.debug("input data features number:{}", inputFeaturesNum);
60+
}
61+
double score = 0;
62+
for (String key : inputData.keySet()) {
63+
if (this.weight.containsKey(key)) {
64+
Double x = new Double(inputData.get(key).toString());
65+
Double w = new Double(this.weight.get(key).toString());
66+
score += w * x;
67+
hitCount += 1;
68+
if (logger.isDebugEnabled()) {
69+
logger.debug("key {} weight is {}, value is {}", key, this.weight.get(key), inputData.get(key));
70+
}
71+
}
72+
}
73+
score += this.intercept;
74+
double modelWeightHitRate = -1.0;
75+
double inputDataHitRate = -1.0;
76+
try {
77+
modelWeightHitRate = (double) hitCount / weightNum;
78+
inputDataHitRate = (double) hitCount / inputFeaturesNum;
79+
} catch (Exception ex) {
80+
ex.printStackTrace();
81+
}
82+
if (logger.isDebugEnabled()) {
83+
logger.debug("model weight hit rate:{}", modelWeightHitRate);
84+
logger.debug("input data features hit rate:{}", inputDataHitRate);
85+
}
86+
Map<String, Double> ret = new HashMap<>(8);
87+
ret.put(Dict.SCORE, score);
88+
ret.put(Dict.MODEL_WRIGHT_HIT_RATE, modelWeightHitRate);
89+
ret.put(Dict.INPUT_DATA_HIT_RATE, inputDataHitRate);
90+
return ret;
91+
}
92+
93+
Map<String, Double> forwardParallel(List<Map<String, Object>> inputDatas) {
94+
Map<String, Object> inputData = inputDatas.get(0);
95+
Map<String, Double> ret = new HashMap<>(8);
96+
double modelWeightHitRate = -1.0;
97+
double inputDataHitRate = -1.0;
98+
Set<String> inputKeys = inputData.keySet();
99+
Set<String> weightKeys = weight.keySet();
100+
Set<String> joinKeys = Sets.newHashSet();
101+
for(String key : inputKeys) {
102+
if(weightKeys.contains(key)){
103+
joinKeys.add(key);
104+
}
105+
}
106+
int modelWeightHitCount = 0;
107+
int inputDataHitCount = 0;
108+
int weightNum = this.weight.size();
109+
int inputFeaturesNum = inputData.size();
110+
if (logger.isDebugEnabled()) {
111+
logger.debug("model weight number:{}", weightNum);
112+
logger.debug("input data features number:{}", inputFeaturesNum);
113+
}
114+
double score = 0;
115+
ForkJoinTask<LinRTaskResult> result = forkJoinPool.submit(new LinRTask(weight, inputData, Lists.newArrayList(joinKeys)));
116+
if (result != null) {
117+
try {
118+
LinRTaskResult lrTaskResult = result.get();
119+
score = lrTaskResult.score;
120+
modelWeightHitCount = lrTaskResult.modelWeightHitCount;
121+
inputDataHitCount = lrTaskResult.inputDataHitCount;
122+
score += this.intercept;
123+
ret.put(Dict.SCORE, score);
124+
modelWeightHitRate = (double) modelWeightHitCount / weightNum;
125+
inputDataHitRate = (double) inputDataHitCount / inputFeaturesNum;
126+
ret.put(Dict.MODEL_WRIGHT_HIT_RATE, modelWeightHitRate);
127+
ret.put(Dict.INPUT_DATA_HIT_RATE, inputDataHitRate);
128+
} catch (Exception e) {
129+
throw new RuntimeException(e);
130+
}
131+
}
132+
return ret;
133+
}
134+
135+
public class LinRTask extends RecursiveTask<LinRTaskResult> {
136+
137+
double modelWeightHitRate = -1.0;
138+
double inputDataHitRate = -1.0;
139+
int splitSize = MetaInfo.PROPERTY_LR_SPLIT_SIZE;
140+
List<String> keys;
141+
Map<String, Object> inputData;
142+
Map<String, Double> weight;
143+
144+
public LinRTask(Map<String, Double> weight, Map<String, Object> inputData, List<String> keys) {
145+
this.keys = keys;
146+
this.inputData = inputData;
147+
this.weight = weight;
148+
}
149+
150+
@Override
151+
protected LinRTaskResult compute() {
152+
double score = 0;
153+
int modelWeightHitCount = 0;
154+
int inputDataHitCount = 0;
155+
if (keys.size() <= splitSize) {
156+
for (String key : keys) {
157+
inputData.get(key);
158+
if (this.weight.containsKey(key)) {
159+
Double x = new Double(inputData.get(key).toString());
160+
Double w = new Double(this.weight.get(key).toString());
161+
score += w * x;
162+
modelWeightHitCount += 1;
163+
inputDataHitCount += 1;
164+
if (logger.isDebugEnabled()) {
165+
logger.debug("key {} weight is {}, value is {}", key, this.weight.get(key), inputData.get(key));
166+
}
167+
}
168+
}
169+
} else {
170+
List<List<Integer>> splits = new ArrayList<List<Integer>>();
171+
int size = keys.size();
172+
int count = (size + splitSize - 1) / splitSize;
173+
List<LinRTask> subJobs = Lists.newArrayList();
174+
for (int i = 0; i < count; i++) {
175+
List<String> subList = keys.subList(i * splitSize, (Math.min((i + 1) * splitSize, size)));
176+
LinRTask subLRTask = new LinRTask(weight, inputData, subList);
177+
subLRTask.fork();
178+
subJobs.add(subLRTask);
179+
}
180+
for (LinRTask lrTask : subJobs) {
181+
LinRTaskResult subResult = lrTask.join();
182+
if (subResult != null) {
183+
score = score + subResult.score;
184+
modelWeightHitCount = modelWeightHitCount + subResult.modelWeightHitCount;
185+
inputDataHitCount = inputDataHitCount + subResult.inputDataHitCount;
186+
}
187+
}
188+
}
189+
return new LinRTaskResult(score, modelWeightHitCount, inputDataHitCount);
190+
}
191+
}
192+
193+
public class LinRTaskResult {
194+
double score = 0;
195+
int modelWeightHitCount = 0;
196+
int inputDataHitCount = 0;
197+
198+
public LinRTaskResult(double score, int modelWeightHitCount, int inputDataHitCount) {
199+
this.score = score;
200+
this.modelWeightHitCount = modelWeightHitCount;
201+
this.inputDataHitCount = inputDataHitCount;
202+
}
203+
}
204+
205+
@Override
206+
public Object getParam() {
207+
return linrModelParam;
208+
}
209+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Copyright 2019 The FATE Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.webank.ai.fate.serving.federatedml.model;
18+
19+
import com.webank.ai.fate.serving.common.model.MergeInferenceAware;
20+
import com.webank.ai.fate.serving.core.bean.Context;
21+
import com.webank.ai.fate.serving.core.bean.Dict;
22+
import com.webank.ai.fate.serving.core.constant.StatusCode;
23+
import com.webank.ai.fate.serving.core.exceptions.GuestMergeException;
24+
import org.apache.commons.collections4.CollectionUtils;
25+
import org.slf4j.Logger;
26+
import org.slf4j.LoggerFactory;
27+
28+
import java.util.HashMap;
29+
import java.util.List;
30+
import java.util.Map;
31+
import java.util.concurrent.atomic.AtomicReference;
32+
33+
public class HeteroLinRGuest extends HeteroLR implements MergeInferenceAware, Returnable {
34+
35+
private static final Logger logger = LoggerFactory.getLogger(HeteroLinRGuest.class);
36+
37+
@Override
38+
public Map<String, Object> localInference(Context context, List<Map<String, Object>> input
39+
) {
40+
Map<String, Object> result = new HashMap<>(8);
41+
Map<String, Double> forwardRet = forward(input);
42+
double score = forwardRet.get(Dict.SCORE);
43+
result.put(Dict.SCORE, score);
44+
return result;
45+
}
46+
47+
@Override
48+
public Map<String, Object> mergeRemoteInference(Context context, List<Map<String, Object>> guestData,
49+
Map<String, Object> hostData) {
50+
Map<String, Object> result = this.handleRemoteReturnData(hostData);
51+
if ((int) result.get(Dict.RET_CODE) == StatusCode.SUCCESS) {
52+
if (CollectionUtils.isNotEmpty(guestData)) {
53+
AtomicReference<Double> score = new AtomicReference<>((double) 0);
54+
Map<String, Object> tempMap = guestData.get(0);
55+
Map<String, Object> componentData = (Map<String, Object>) tempMap.get(this.getComponentName());
56+
double localScore = 0;
57+
if (componentData != null && componentData.get(Dict.SCORE) != null) {
58+
localScore = ((Number) componentData.get(Dict.SCORE)).doubleValue();
59+
} else {
60+
throw new GuestMergeException("local result is invalid ");
61+
}
62+
score.set(localScore);
63+
64+
hostData.forEach((k, v) -> {
65+
Map<String, Object> onePartyData = (Map<String, Object>) v;
66+
67+
Map<String, Object> remoteComponentData = (Map<String, Object>) onePartyData.get(this.getComponentName());
68+
double remoteScore;
69+
if (remoteComponentData != null) {
70+
remoteScore = ((Number) remoteComponentData.get(Dict.SCORE)).doubleValue();
71+
} else {
72+
if (onePartyData.get(Dict.PROB) != null) {
73+
remoteScore = ((Number) onePartyData.get(Dict.PROB)).doubleValue();
74+
} else {
75+
throw new GuestMergeException("host data score is null");
76+
}
77+
}
78+
score.updateAndGet(v1 -> new Double((double) (v1 + remoteScore)));
79+
});
80+
result.put(Dict.SCORE, score);
81+
82+
}
83+
}
84+
return result;
85+
}
86+
87+
88+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright 2019 The FATE Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.webank.ai.fate.serving.federatedml.model;
18+
19+
import com.webank.ai.fate.serving.core.bean.Context;
20+
import com.webank.ai.fate.serving.core.bean.Dict;
21+
import com.webank.ai.fate.serving.core.bean.MetaInfo;
22+
import org.slf4j.Logger;
23+
import org.slf4j.LoggerFactory;
24+
25+
import java.util.HashMap;
26+
import java.util.List;
27+
import java.util.Map;
28+
29+
public class HeteroLinRHost extends HeteroLinR implements Returnable {
30+
31+
private static final Logger logger = LoggerFactory.getLogger(HeteroLinRHost.class);
32+
33+
@Override
34+
public Map<String, Object> localInference(Context context, List<Map<String, Object>> inputData) {
35+
HashMap<String, Object> result = new HashMap<>(8);
36+
Map<String, Double> ret = MetaInfo.PROPERTY_LR_USE_PARALLEL ? forwardParallel(inputData) : forward(inputData);
37+
result.put(Dict.SCORE, ret.get(Dict.SCORE));
38+
return result;
39+
}
40+
}

proto/linr-model-meta.proto

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright 2019 The FATE Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
syntax = "proto3";
18+
19+
package com.webank.ai.fate.core.mlmodel.buffer;
20+
option java_outer_classname = "LinRModelMetaProto";
21+
22+
message LinRModelMeta {
23+
string penalty = 1;
24+
double tol = 2;
25+
double alpha = 3;
26+
string optimizer = 4;
27+
int64 batch_size = 5;
28+
double learning_rate = 6;
29+
int64 max_iter = 7;
30+
string early_stop = 8;
31+
bool fit_intercept = 9;
32+
string reveal_strategy = 10;
33+
}

0 commit comments

Comments
 (0)