|
1 | 1 | import asyncio
|
| 2 | +import inspect |
2 | 3 | import re
|
3 | 4 | import traceback
|
4 | 5 | from typing import Iterable, Iterator
|
|
9 | 10 | from dcim.models import Device, DeviceRole, Site
|
10 | 11 | from django.db.models import Q
|
11 | 12 | from extras.models import CustomField
|
12 |
| -from extras.scripts import MultiObjectVar, ObjectVar, TextVar |
| 13 | +from extras.scripts import MultiObjectVar, ObjectVar, ScriptVariable, TextVar |
13 | 14 | from jinja2.exceptions import TemplateError
|
14 | 15 | from netutils.config.compliance import diff_network_config
|
15 | 16 | from utilities.exceptions import AbortScript
|
@@ -68,13 +69,35 @@ class ConfigDiffBase(SecretsMixin):
|
68 | 69 | "Reference the object as <code>{{ object }}</code>.",
|
69 | 70 | )
|
70 | 71 |
|
71 |
| - def __init__(self, *args, **kwargs): |
72 |
| - super().__init__(*args, **kwargs) |
73 |
| - self.custom_field.query_params["object_type_id"] = self._get_device_object_type_id() |
74 |
| - |
75 |
| - def _get_device_object_type_id(self) -> list[int]: |
| 72 | + @classmethod |
| 73 | + def _get_device_object_type_id(cls) -> list[int]: |
76 | 74 | return list(ObjectType.objects.filter(app_label="dcim", model="device").values_list("id", flat=True))
|
77 | 75 |
|
| 76 | + @classmethod |
| 77 | + def _get_vars(cls): |
| 78 | + vars = {} |
| 79 | + device_id = cls._get_device_object_type_id() |
| 80 | + |
| 81 | + # Iterate all base classes looking for ScriptVariables |
| 82 | + for base_class in inspect.getmro(cls): |
| 83 | + # When object is reached there's no reason to continue |
| 84 | + if base_class is object: |
| 85 | + break |
| 86 | + |
| 87 | + for name, attr in base_class.__dict__.items(): |
| 88 | + if name not in vars and issubclass(attr.__class__, ScriptVariable): |
| 89 | + if name == "custom_field": |
| 90 | + attr.field_attrs["query_params"]["object_type_id"] = device_id |
| 91 | + vars[name] = attr |
| 92 | + |
| 93 | + # Order variables according to field_order |
| 94 | + if not cls.field_order: |
| 95 | + return vars |
| 96 | + ordered_vars = {field: vars.pop(field) for field in cls.field_order if field in vars} |
| 97 | + ordered_vars.update(vars) |
| 98 | + |
| 99 | + return ordered_vars |
| 100 | + |
78 | 101 | def run_script(self, data: dict) -> None:
|
79 | 102 | devices = self.validate_data(data)
|
80 | 103 | devices = list(self.get_devices_with_rendered_configs(devices))
|
|
0 commit comments