+import numpy as np
+
+from hls4ml.backends import Backend
+from hls4ml.backends.template import FunctionCallTemplate
+from hls4ml.model.layers import Layer
+from hls4ml.model.optimizer import OptimizerPass
+from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer, UnaryLUT
+from hls4ml.model.types import Source
+
+
+def to_apfixed(k, b, i, RND, SAT):
+    """Build a Xilinx HLS ap_[u]fixed<W,I,RND,SAT> type string (k is the sign bit)."""
+    u = 'u' if k == 0 else ''
+    return f'ap_{u}fixed<{b},{i},AP_{RND},AP_{SAT}>'
+
+
+def to_acfixed(k, b, i, RND, SAT):
+    """Build an Intel/Quartus ac_fixed<W,I,signed,RND,SAT> type string (k is the sign bit)."""
+    k = 'false' if k == 0 else 'true'
+    return f'ac_fixed<{b},{i},{k},AC_{RND},AC_{SAT}>'
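+
+# For example:
+#   to_apfixed(1, 8, 3, 'TRN', 'WRAP') -> 'ap_fixed<8,3,AP_TRN,AP_WRAP>'
+#   to_acfixed(0, 8, 3, 'RND', 'SAT')  -> 'ac_fixed<8,3,false,AC_RND,AC_SAT>'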
+
+
+def generate_mask_fn(
+    name: str, shape: tuple[int, ...], k: np.ndarray, b: np.ndarray, i: np.ndarray, RND: str, SAT: str, backend: str
+) -> str:
+    """Generate a heterogeneous quantization mask function. ONLY works for IOType=io_parallel."""
+    assert k.shape[0] == b.shape[0] == i.shape[0] == 1
+    assert backend.lower() in ('quartus', 'vivado', 'vitis'), f'Backend {backend} not tested'
+    Ks, Bs, Is = k[0], b[0], i[0]
+    Ks, Bs, Is = np.broadcast_to(Ks, shape), np.broadcast_to(Bs, shape), np.broadcast_to(Is, shape)
+    Ks, Bs, Is = Ks.ravel(), Bs.ravel(), Is.ravel()
+    masks = []
+    to_fixed = to_acfixed if backend.lower() == 'quartus' else to_apfixed
+    for idx, (k, b, i) in enumerate(zip(Ks, Bs, Is)):
+        if b == 0:
+            # b == 0 means zero bits are kept for this element, so it is hard-wired to 0.
+            fn = f'out[{idx}] = 0;'
+        else:
+            fn = f'out[{idx}] = {to_fixed(k, b, i, RND, SAT)}(inp[{idx}]);'
+        masks.append(f'    {fn}')
+    body = "\n".join(masks)
+    mask_fn = f'''
+template<typename input_t, typename output_t>
+void {name}(input_t *inp, output_t *out) {{
+    #pragma HLS INLINE
+
+{body}
+}}
+'''
+    return mask_fn
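+
+# Illustrative only. With name='quantizer_mask', shape=(2,), RND='TRN', SAT='WRAP',
+# backend='vivado' and kbi masks [[1, 1]], [[8, 0]], [[3, 3]], this returns:
+#
+#   template<typename input_t, typename output_t>
+#   void quantizer_mask(input_t *inp, output_t *out) {
+#       #pragma HLS INLINE
+#
+#       out[0] = ap_fixed<8,3,AP_TRN,AP_WRAP>(inp[0]);
+#       out[1] = 0;
+#   }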
+
+
+class ProcessFixedPointQuantizerLayer(OptimizerPass):
+    def match(self, node: Layer):
+        return isinstance(node, FixedPointQuantizer)
+
+    def transform(self, model, node: FixedPointQuantizer):
+        if node.fusible:
+            # A fusible quantizer's effect is assumed to be captured by the
+            # surrounding layers' types, so the node is removed outright.
+            model.remove_node(node, rewire=True)
+            return True
+
+        if model.config.config['IOType'] != 'io_parallel':
+            raise NotImplementedError('Heterogeneous quantization for activations is only supported with IOType=io_parallel')
+
+        backend = model.config.config['Backend']
+
+        name = node.name
+
+        assert node.mask_kbi is not None
+        k, b, i = node.mask_kbi
+        RND = node.RND
+        SAT = node.SAT
+        mask_fn: str = generate_mask_fn(name, node.get_input_variable().shape, k, b, i, RND, SAT, backend)
+
+        node.set_attr('mask_fn_codegen', Source(mask_fn))
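+
+# The generated C++ is attached as a Source attribute; presumably the project
+# writer emits it alongside the model, while the call template below renders
+# the invocation at the layer's call site.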
+
+
+class ProcessFixedPointQuantizerCall(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(FixedPointQuantizer, include_header=[])
+        self.template = 'nnet::{name}<{input_t}, {output_t}>({input}, {output});'
+
+    def format(self, node):
+        params = self._default_function_params(node)
+
+        return self.template.format(**params)
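+
+# A rendered call might look like (hypothetical layer and variable names):
+#   nnet::fixed_point_quantizer_4<layer3_t, layer4_t>(layer3_out, layer4_out);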
+
+
+class ProcessUnaryLUTCall(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(UnaryLUT, include_header=[])
+        self.template = 'nnet::unary_lut<{input_t}, {output_t}, {config}>({input}, {output}, {table});'
+        self.include_header = [
+            'nnet_utils/nnet_activation.h',
+            'nnet_utils/nnet_activation_stream.h',
+        ]
+
+    def format(self, node):
+        params = self._default_function_params(node)
+        # The output is read straight from the table, so result_t must take table_t's precision.
+        node.attributes['result_t'].precision = node.attributes['table_t'].precision
+        params['config'] = f'unary_lut_config{node.index}'
+        params['table'] = node.get_weights('table').name
+
+        return self.template.format(**params)
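+
+# A rendered call might look like (hypothetical names):
+#   nnet::unary_lut<layer5_t, layer6_t, unary_lut_config6>(layer5_out, layer6_out, table6);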
+
+
+def register_hgq_proxy_model(backend: Backend):
+    backend.register_pass('process_fixed_point_quantizer_layer', ProcessFixedPointQuantizerLayer)
+    backend.register_template(ProcessFixedPointQuantizerCall)
+    backend.register_template(ProcessUnaryLUTCall)
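+
+# Assumed usage: each FPGA backend calls this once while setting up its
+# optimizer flows, e.g. register_hgq_proxy_model(backend), making the pass and
+# both call templates available to that backend.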