
Commit 817484e

Creates Patch
This creates the Patch concept along with an initial set of usages. There is a more specialized ParamPatch for the standard additive parameter patches, along with Scaled, Basic, and LoRA implementations. The patches can be created directly, by comparing models, or from gradients. This is an initial step; following it, a few pieces of work could be considered:
1. A DJL Serving Python engine specific patch implementation
2. LoRA for full training
3. Making a BasicParamPatch from an Optimizer (including gradients, momentum, and learning rate)
1 parent 1397b2c commit 817484e
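
As a rough illustration of the workflow the commit message describes, the sketch below diffs a fine-tuned model against its base model into a BasicParamPatch, scales the result, and applies it to a fresh copy of the base model. It is a minimal sketch, not part of the commit: the variables baseModel, tunedModel, and freshModel are hypothetical, already-loaded models of the same architecture, and only the ai.djl.patch calls come from this change.

import ai.djl.Model;
import ai.djl.patch.BasicParamPatch;
import ai.djl.patch.ParamPatch;

public final class PatchWorkflowSketch {

    // baseModel, tunedModel, and freshModel are assumed to be loaded elsewhere
    static void transferTuning(Model baseModel, Model tunedModel, Model freshModel) {
        // Capture the target-minus-source parameter deltas as an additive patch
        try (BasicParamPatch patch = BasicParamPatch.makePatch(baseModel, tunedModel)) {
            // Soften the update; scale(1) would reproduce tunedModel exactly
            ParamPatch half = patch.scale(0.5f);
            // Add the scaled deltas into freshModel's parameters in place
            half.apply(freshModel);
        }
    }
}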

File tree

8 files changed: +552, -0 lines changed

BasicParamPatch.java (new file): 135 additions & 0 deletions

/*
 * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.patch;

import ai.djl.Model;
import ai.djl.ndarray.NDArray;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.training.GradientCollector;
import ai.djl.util.Pair;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** The basic implementation of a {@link ParamPatch}. */
public class BasicParamPatch extends ParamPatch {

    Map<String, NDArray> data;

    /**
     * Constructs a {@link BasicParamPatch} with patching data.
     *
     * @param data the patching data
     */
    public BasicParamPatch(Map<String, NDArray> data) {
        this.data = data;
    }

    /**
     * Makes a patch by comparing two models.
     *
     * @param source the source model
     * @param target the target model
     * @return a patch that would transform the source model to the target model
     */
    public static BasicParamPatch makePatch(Model source, Model target) {
        return BasicParamPatch.makePatch(source.getBlock(), target.getBlock());
    }

    /**
     * Makes a patch by comparing two blocks.
     *
     * @param source the source block
     * @param target the target block
     * @return a patch that would transform the source block to the target block
     */
    public static BasicParamPatch makePatch(Block source, Block target) {
        return BasicParamPatch.makePatch(source.getParameters(), target.getParameters());
    }

    /**
     * Makes a patch by comparing two {@link ParameterList}s.
     *
     * @param source the source {@link ParameterList}
     * @param target the target {@link ParameterList}
     * @return a patch that would transform the source {@link ParameterList} to the target {@link
     *     ParameterList}.
     */
    public static BasicParamPatch makePatch(ParameterList source, ParameterList target) {
        Map<String, NDArray> data = new ConcurrentHashMap<>(source.size());
        for (Pair<String, Parameter> sourcePair : source) {
            String key = sourcePair.getKey();
            NDArray patchValue = target.get(key).getArray().sub(sourcePair.getValue().getArray());
            data.put(key, patchValue);
        }
        return new BasicParamPatch(data);
    }

    /**
     * Makes a patch from gradients.
     *
     * <p>This does not include learning rates or any other data from the {@link
     * ai.djl.training.optimizer.Optimizer}.
     *
     * <p>Making the patch does not modify the existing gradients. After this, you can call {@link
     * GradientCollector#zeroGradients()} to clear the gradients.
     *
     * @param block the block for which to collect gradients
     * @param gradientCollector the {@link GradientCollector} of the gradients
     * @return the gradients as a {@link BasicParamPatch}.
     */
    public static BasicParamPatch makePatch(Block block, GradientCollector gradientCollector) {
        ParameterList params = block.getParameters();
        Map<String, NDArray> data = new ConcurrentHashMap<>(params.size());
        for (Pair<String, Parameter> param : params) {
            String key = param.getKey();
            // Get gradient * -1 to account for gradient being subtracted from param
            NDArray patchValue = param.getValue().getArray().getGradient().duplicate().mul(-1);
            data.put(key, patchValue);
        }
        return new BasicParamPatch(data);
    }

    /**
     * Makes a patch from gradients.
     *
     * <p>This does not include learning rates or any other data from the {@link
     * ai.djl.training.optimizer.Optimizer}.
     *
     * <p>Making the patch does not modify the existing gradients. After this, you can call {@link
     * GradientCollector#zeroGradients()} to clear the gradients.
     *
     * @param model the model for which to collect gradients
     * @param gradientCollector the {@link GradientCollector} of the gradients
     * @return the gradients as a {@link BasicParamPatch}.
     */
    public static BasicParamPatch makePatch(Model model, GradientCollector gradientCollector) {
        return makePatch(model.getBlock(), gradientCollector);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray getPatch(String paramName) {
        return data.get(paramName).duplicate();
    }

    /** {@inheritDoc} */
    @Override
    public void close() {
        for (NDArray d : data.values()) {
            d.close();
        }
    }
}
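
A minimal sketch of the gradient path described in the Javadoc above, assuming a Trainer named trainer and a loss NDArray named loss have already been produced elsewhere by a forward pass that recorded gradients (both are hypothetical here, not part of the commit). The patch captures -1 * gradient per parameter and carries no optimizer state, so applying it amounts to one plain SGD step with learning rate 1.

import ai.djl.ndarray.NDArray;
import ai.djl.nn.Block;
import ai.djl.patch.BasicParamPatch;
import ai.djl.training.GradientCollector;
import ai.djl.training.Trainer;

public final class GradientPatchSketch {

    // trainer and loss are assumed to come from an existing training setup
    static void stepWithPatch(Trainer trainer, NDArray loss) {
        Block block = trainer.getModel().getBlock();
        try (GradientCollector collector = trainer.newGradientCollector()) {
            // Backpropagate so the block's parameters hold gradients
            collector.backward(loss);
            // Snapshot -1 * gradient for every parameter; the gradients themselves are untouched
            try (BasicParamPatch gradientStep = BasicParamPatch.makePatch(block, collector)) {
                // Adding the patch is equivalent to a single SGD update with learning rate 1
                gradientStep.apply(block);
            }
        }
    }
}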
LoRA.java (new file): 58 additions & 0 deletions

/*
 * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.patch;

import ai.djl.ndarray.NDArray;
import ai.djl.util.Pair;

import java.util.Map;

/**
 * A {@link ParamPatch} based on low-rank adapters.
 *
 * <p>Based on the paper <a href="https://arxiv.org/abs/2106.09685">LoRA: Low-Rank Adaptation of
 * Large Language Models</a>.
 *
 * <p>TODO This support for LoRA is still a placeholder and needs effective code for creating and
 * training.
 */
public class LoRA extends ParamPatch {

    /** Data of type map from param name to (A, B) pair. */
    Map<String, Pair<NDArray, NDArray>> data;

    /**
     * Constructs a {@link LoRA}.
     *
     * @param data the data to patch with
     */
    public LoRA(Map<String, Pair<NDArray, NDArray>> data) {
        this.data = data;
    }

    /** {@inheritDoc} */
    @Override
    public NDArray getPatch(String paramName) {
        Pair<NDArray, NDArray> d = data.get(paramName);
        // Materialize the full-rank delta for this parameter as A x B
        return d.getKey().matMul(d.getValue());
    }

    /** {@inheritDoc} */
    @Override
    public void close() {
        for (Pair<NDArray, NDArray> d : data.values()) {
            d.getKey().close();
            d.getValue().close();
        }
    }
}
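
Since the class is still a placeholder, the sketch below only shows how a low-rank (A, B) pair could be assembled by hand for a single parameter. The parameter name "linear.weight", the 256x512 shape, and the rank of 4 are made-up values for illustration; nothing here trains the adapters.

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.patch.LoRA;
import ai.djl.util.Pair;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public final class LoraSketch {

    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            int rank = 4;
            // Hypothetical 256x512 weight, decomposed as (256 x rank) times (rank x 512).
            // B starts at zero so the initial delta A.matMul(B) is zero, as in the LoRA paper.
            NDArray a = manager.randomNormal(new Shape(256, rank));
            NDArray b = manager.zeros(new Shape(rank, 512));

            Map<String, Pair<NDArray, NDArray>> data = new ConcurrentHashMap<>();
            data.put("linear.weight", new Pair<>(a, b));

            try (LoRA lora = new LoRA(data)) {
                // getPatch materializes the full-rank delta for the named parameter
                NDArray delta = lora.getPatch("linear.weight");
                System.out.println(delta.getShape()); // (256, 512)
                delta.close();
            }
        }
    }
}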
ParamPatch.java (new file): 99 additions & 0 deletions

/*
 * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.patch;

import ai.djl.Model;
import ai.djl.ndarray.NDArray;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterList;
import ai.djl.util.Pair;

/**
 * A standard {@link Patch} that only adds to {@link Parameter}s.
 *
 * <p>To create a param patch, see {@link BasicParamPatch}.
 */
public abstract class ParamPatch extends Patch {

    /**
     * Scales the patch by a scalar multiplier.
     *
     * @param scale the scalar multiplier for each patch NDArray
     * @return a new patch that is a scaled version of this patch
     */
    public ParamPatch scale(float scale) {
        return new ScaledParamPatch(scale, this);
    }

    /**
     * Returns a new {@link ParamPatch} that is the additive inverse of this patch.
     *
     * <p>It is equivalent to scaling by -1.
     *
     * @return a new {@link ParamPatch} that is the additive inverse of this patch
     */
    public ParamPatch reverse() {
        return scale(-1);
    }

    /**
     * Returns the patch {@link NDArray} for a particular paramName.
     *
     * @param paramName the parameter path in a {@link ParameterList}.
     * @return the patch array
     */
    public abstract NDArray getPatch(String paramName);

    /**
     * Applies the part of this patch to a particular {@link Parameter}.
     *
     * @param paramName the parameter path in a {@link ParameterList}.
     * @param param the {@link Parameter} to patch
     */
    public void apply(String paramName, Parameter param) {
        NDArray p = getPatch(paramName).duplicate();
        param.getArray().addi(p);
        p.close();
    }

    /**
     * Applies this patch to a {@link ParameterList}.
     *
     * @param params the params to patch
     */
    public void apply(ParameterList params) {
        for (Pair<String, Parameter> param : params) {
            apply(param.getKey(), param.getValue());
        }
    }

    /**
     * Applies this patch to a {@link Block}.
     *
     * @param block the block to patch
     */
    public void apply(Block block) {
        apply(block.getParameters());
    }

    /**
     * Applies this patch to a {@link Model}.
     *
     * @param model the model to patch
     */
    @Override
    public void apply(Model model) {
        apply(model.getBlock());
    }
}
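
One way the scale and reverse helpers above could be used: apply a patch and then roll it back by applying its additive inverse. This is a hedged sketch; the patch and model arguments are assumed to come from elsewhere (for example, the BasicParamPatch sketch near the top).

import ai.djl.Model;
import ai.djl.patch.ParamPatch;

public final class PatchUndoSketch {

    // patch and model are assumed to already exist
    static void applyAndUndo(ParamPatch patch, Model model) {
        patch.apply(model); // add the deltas into the model parameters
        ParamPatch undo = patch.reverse(); // additive inverse, same as patch.scale(-1)
        undo.apply(model); // parameters return to their original values
    }
}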
Patch.java (new file): 34 additions & 0 deletions

/*
 * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.patch;

import ai.djl.Model;

/**
 * A method for modifying a {@link Model}.
 *
 * <p>The most standard form is the {@link ParamPatch}.
 */
public abstract class Patch implements AutoCloseable {

    /**
     * Applies this patch to a model.
     *
     * @param model the model to update with the patch
     */
    public abstract void apply(Model model);

    /** {@inheritDoc} */
    @Override
    public abstract void close();
}
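
To illustrate how the abstract base could be extended beyond additive parameter patches, here is a hypothetical composite patch that simply applies several patches in sequence. It is not part of this commit; it only shows the shape of a Patch subclass.

import ai.djl.Model;
import ai.djl.patch.Patch;

import java.util.List;

/** Hypothetical example of a Patch subclass: applies several patches in order. */
public final class CompositePatch extends Patch {

    private final List<Patch> steps;

    public CompositePatch(List<Patch> steps) {
        this.steps = steps;
    }

    /** {@inheritDoc} */
    @Override
    public void apply(Model model) {
        for (Patch step : steps) {
            step.apply(model);
        }
    }

    /** {@inheritDoc} */
    @Override
    public void close() {
        for (Patch step : steps) {
            step.close();
        }
    }
}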
ScaledParamPatch.java (new file): 55 additions & 0 deletions

/*
 * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.patch;

import ai.djl.ndarray.NDArray;

/**
 * A {@link ParamPatch} that scales another {@link ParamPatch} by a scalar multiplier.
 *
 * @see ParamPatch#scale(float)
 */
public class ScaledParamPatch extends ParamPatch {

    float scale;
    ParamPatch base;

    /**
     * Constructs a {@link ScaledParamPatch}.
     *
     * @param scale the scalar multiplier
     * @param base the {@link ParamPatch} to scale
     */
    public ScaledParamPatch(float scale, ParamPatch base) {
        if (base instanceof ScaledParamPatch) {
            ScaledParamPatch sbase = (ScaledParamPatch) base;
            this.scale = scale * sbase.scale;
            this.base = sbase.base;
        } else {
            this.scale = scale;
            this.base = base;
        }
    }

    /** {@inheritDoc} */
    @Override
    public NDArray getPatch(String paramName) {
        return base.getPatch(paramName).muli(scale);
    }

    /** {@inheritDoc} */
    @Override
    public void close() {
        base.close();
    }
}
