diff --git a/crytic_compile/crytic_compile.py b/crytic_compile/crytic_compile.py index 605a20aa..534ac67f 100644 --- a/crytic_compile/crytic_compile.py +++ b/crytic_compile/crytic_compile.py @@ -27,6 +27,7 @@ from crytic_compile.platform.all_export import PLATFORMS_EXPORT from crytic_compile.platform.solc import Solc from crytic_compile.platform.standard import export_to_standard +from crytic_compile.utils.libraries import generate_library_addresses, get_deployment_order from crytic_compile.utils.naming import Filename from crytic_compile.utils.npm import get_package_name from crytic_compile.utils.zip import load_from_zip @@ -206,6 +207,10 @@ def __init__(self, target: Union[str, AbstractPlatform], **kwargs: str) -> None: self._bytecode_only = False + self._autolink: bool = kwargs.get("compile_autolink", False) # type: ignore + + self._autolink_deployment_order: Optional[List[str]] = None + self.libraries: Optional[Dict[str, int]] = _extract_libraries(kwargs.get("compile_libraries", None)) # type: ignore self._compile(**kwargs) @@ -632,12 +637,61 @@ def _compile(self, **kwargs: str) -> None: self._platform.clean(**kwargs) self._platform.compile(self, **kwargs) + # Handle autolink after compilation + if self._autolink: + self._apply_autolink() + remove_metadata = kwargs.get("compile_remove_metadata", False) if remove_metadata: for compilation_unit in self._compilation_units.values(): for source_unit in compilation_unit.source_units.values(): source_unit.remove_metadata() + def _apply_autolink(self) -> None: + """Apply automatic library linking with sequential addresses""" + + # Collect all libraries that need linking and compute deployment info + all_libraries_needed: Set[str] = set() + all_dependencies: Dict[str, List[str]] = {} + all_target_contracts: List[str] = [] + + for compilation_unit in self._compilation_units.values(): + # Build dependency graph for this compilation unit + for source_unit in compilation_unit.source_units.values(): + all_target_contracts.extend(source_unit.contracts_names_without_libraries) + + for contract_name in source_unit.contracts_names: + deps = source_unit.libraries_names(contract_name) + if deps: + all_dependencies[contract_name] = deps + all_libraries_needed.update(deps) + + all_target_contracts = [c for c in all_target_contracts if c not in all_libraries_needed] + + # Calculate deployment order globally + deployment_order, _ = get_deployment_order(all_dependencies, all_target_contracts) + self._autolink_deployment_order = deployment_order + + if all_libraries_needed: + # Apply the library linking (similar to compile_libraries but auto-generated) + library_addresses = generate_library_addresses(all_libraries_needed) + + if self.libraries is None: + self.libraries = {} + + # Respect any user-provided addresses through compile_libraries + library_addresses.update(self.libraries) + self.libraries = library_addresses + + @property + def deployment_order(self) -> Optional[List[str]]: + """Return the library deployment order. + + Returns: + Optional[List[str]]: Library deployment order + """ + return self._autolink_deployment_order + @staticmethod def _run_custom_build(custom_build: str) -> None: """Run a custom build diff --git a/crytic_compile/cryticparser/cryticparser.py b/crytic_compile/cryticparser/cryticparser.py index 682cfec5..0ae26c6f 100755 --- a/crytic_compile/cryticparser/cryticparser.py +++ b/crytic_compile/cryticparser/cryticparser.py @@ -35,6 +35,13 @@ def init(parser: ArgumentParser) -> None: default=DEFAULTS_FLAG_IN_CONFIG["compile_libraries"], ) + group_compile.add_argument( + "--compile-autolink", + help="Automatically link all found libraries with sequential addresses starting from 0xa070", + action="store_true", + default=DEFAULTS_FLAG_IN_CONFIG["compile_autolink"], + ) + group_compile.add_argument( "--compile-remove-metadata", help="Remove the metadata from the bytecodes", diff --git a/crytic_compile/cryticparser/defaults.py b/crytic_compile/cryticparser/defaults.py index f6719149..2205f22b 100755 --- a/crytic_compile/cryticparser/defaults.py +++ b/crytic_compile/cryticparser/defaults.py @@ -48,4 +48,5 @@ "foundry_compile_all": False, "export_dir": "crytic-export", "compile_libraries": None, + "compile_autolink": False, } diff --git a/crytic_compile/platform/solc.py b/crytic_compile/platform/solc.py index 373aea30..130b4061 100644 --- a/crytic_compile/platform/solc.py +++ b/crytic_compile/platform/solc.py @@ -57,9 +57,46 @@ def _build_contract_data(compilation_unit: "CompilationUnit") -> Dict: return contracts +def _export_link_info(compilation_unit: "CompilationUnit", key: str, export_dir: str) -> str: + """Export linking information to a separate file. + + Args: + compilation_unit (CompilationUnit): Compilation unit to export + key (str): Filename Id + export_dir (str): Export directory + + Returns: + str: path to the generated file""" + + autolink_path = os.path.join(export_dir, f"{key}.link") + + # Get library addresses if they exist + library_addresses = {} + if compilation_unit.crytic_compile.libraries: + library_addresses = { + name: f"0x{addr:040x}" + for name, addr in compilation_unit.crytic_compile.libraries.items() + } + + # Filter deployment order to only include libraries that have addresses + full_deployment_order = compilation_unit.crytic_compile.deployment_order or [] + filtered_deployment_order = [lib for lib in full_deployment_order if lib in library_addresses] + + # Create autolink output with deployment order and library addresses + autolink_output = { + "deployment_order": filtered_deployment_order, + "library_addresses": library_addresses, + } + + with open(autolink_path, "w", encoding="utf8") as file_desc: + json.dump(autolink_output, file_desc, indent=2) + + return autolink_path + + def export_to_solc_from_compilation_unit( compilation_unit: "CompilationUnit", key: str, export_dir: str -) -> Optional[str]: +) -> Optional[List[str]]: """Export the compilation unit to the standard solc output format. The exported file will be $key.json @@ -69,7 +106,7 @@ def export_to_solc_from_compilation_unit( export_dir (str): Export directory Returns: - Optional[str]: path to the file generated + Optional[List[str]]: path to the files generated """ contracts = _build_contract_data(compilation_unit) @@ -88,7 +125,15 @@ def export_to_solc_from_compilation_unit( with open(path, "w", encoding="utf8") as file_desc: json.dump(output, file_desc) - return path + + paths = [path] + + # Export link info if compile_autolink or compile_libraries was used + if compilation_unit.crytic_compile.libraries: + link_path = _export_link_info(compilation_unit, key, export_dir) + paths.append(link_path) + + return paths return None @@ -110,17 +155,18 @@ def export_to_solc(crytic_compile: "CryticCompile", **kwargs: str) -> List[str]: if len(crytic_compile.compilation_units) == 1: compilation_unit = list(crytic_compile.compilation_units.values())[0] - path = export_to_solc_from_compilation_unit(compilation_unit, "combined_solc", export_dir) - if path: - return [path] + paths = export_to_solc_from_compilation_unit(compilation_unit, "combined_solc", export_dir) + if paths: + return paths return [] - paths = [] + all_paths = [] for key, compilation_unit in crytic_compile.compilation_units.items(): - path = export_to_solc_from_compilation_unit(compilation_unit, key, export_dir) - if path: - paths.append(path) - return paths + paths = export_to_solc_from_compilation_unit(compilation_unit, key, export_dir) + if paths: + all_paths.extend(paths) + + return all_paths class Solc(AbstractPlatform): diff --git a/crytic_compile/utils/libraries.py b/crytic_compile/utils/libraries.py new file mode 100644 index 00000000..25a916b3 --- /dev/null +++ b/crytic_compile/utils/libraries.py @@ -0,0 +1,94 @@ +""" +Library utilities for dependency resolution and auto-linking +""" +from typing import Dict, List, Set, Tuple + + +def get_deployment_order( + dependencies: Dict[str, List[str]], target_contracts: List[str] +) -> Tuple[List[str], Set[str]]: + """Get deployment order using topological sorting (Kahn's algorithm) + + Args: + dependencies: Dict mapping contract_name -> [required_libraries] + target_contracts: List of target contracts to prioritize + + Raises: + ValueError: if a circular dependency is identified + + Returns: + Tuple of (deployment_order, libraries_needed) + """ + # Build complete dependency graph + all_contracts = set(dependencies.keys()) + for deps in dependencies.values(): + all_contracts.update(deps) + + # Calculate in-degrees + in_degree = {contract: 0 for contract in all_contracts} + for contract, deps in dependencies.items(): + for dep in deps: + if dep in in_degree: + in_degree[contract] += 1 + + # Initialize queue with nodes that have no dependencies + queue = [contract for contract in all_contracts if in_degree[contract] == 0] + + result = [] + libraries_needed = set() + + deployment_order = [] + + while queue: + # Sort queue to prioritize libraries first, then target contracts in order + queue.sort( + key=lambda x: ( + x in target_contracts, # Libraries (False) come before targets (True) + target_contracts.index(x) if x in target_contracts else 0, # Target order + ) + ) + + current = queue.pop(0) + result.append(current) + + # Check if this is a library (not in target contracts but required by others) + if current not in target_contracts: + libraries_needed.add(current) + deployment_order.append(current) # Only add libraries to deployment order + + # Update in-degrees for dependents + for contract, deps in dependencies.items(): + if current in deps: + in_degree[contract] -= 1 + if in_degree[contract] == 0 and contract not in result: + queue.append(contract) + + # Check for circular dependencies + if len(result) != len(all_contracts): + remaining = all_contracts - set(result) + raise ValueError(f"Circular dependency detected involving: {remaining}") + + return deployment_order, libraries_needed + + +def generate_library_addresses( + libraries_needed: Set[str], start_address: int = 0xA070 +) -> Dict[str, int]: + """Generate sequential addresses for libraries + + Args: + libraries_needed: Set of library names that need addresses + start_address: Starting address (default 0xa070, resembling "auto") + + Returns: + Dict mapping library_name -> address + """ + library_addresses = {} + current_address = start_address + + # Sort libraries for consistent ordering + for library in sorted(libraries_needed): + library_addresses[library] = current_address + current_address += 1 + + return library_addresses diff --git a/tests/library_dependency_test.sol b/tests/library_dependency_test.sol new file mode 100644 index 00000000..ad059939 --- /dev/null +++ b/tests/library_dependency_test.sol @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +// Simple library with no dependencies +library MathLib { + function add(uint256 a, uint256 b) external pure returns (uint256) { + return a + b; + } + + function multiply(uint256 a, uint256 b) external pure returns (uint256) { + return a * b; + } +} + +// Library that depends on MathLib +library AdvancedMath { + function square(uint256 a) external pure returns (uint256) { + return MathLib.multiply(a, a); + } + + function addAndSquare(uint256 a, uint256 b) external pure returns (uint256) { + uint256 sum = MathLib.add(a, b); + return MathLib.multiply(sum, sum); + } +} + +// Library that depends on both MathLib and AdvancedMath +library ComplexMath { + function complexOperation(uint256 a, uint256 b) external pure returns (uint256) { + uint256 squared = AdvancedMath.square(a); + return MathLib.add(squared, b); + } + + function megaOperation(uint256 a, uint256 b, uint256 c) external pure returns (uint256) { + uint256 result1 = AdvancedMath.addAndSquare(a, b); + uint256 result2 = MathLib.multiply(result1, c); + return result2; + } +} + +// Contract that uses ComplexMath (which transitively depends on others) +contract TestComplexDependencies { + uint256 public result; + + constructor() { + result = 0; + } + + function performComplexCalculation(uint256 a, uint256 b, uint256 c) public { + result = ComplexMath.megaOperation(a, b, c); + } + + function performSimpleCalculation(uint256 a, uint256 b) public { + result = ComplexMath.complexOperation(a, b); + } + + function getResult() public view returns (uint256) { + return result; + } +} + +// Another contract that only uses MathLib directly +contract SimpleMathContract { + uint256 public value; + + constructor(uint256 _initial) { + value = _initial; + } + + function addValue(uint256 _amount) public { + value = MathLib.add(value, _amount); + } + + function multiplyValue(uint256 _factor) public { + value = MathLib.multiply(value, _factor); + } +} + +// Contract that uses multiple libraries at the same level +contract MultiLibraryContract { + uint256 public simpleResult; + uint256 public advancedResult; + + function calculate(uint256 a, uint256 b) public { + simpleResult = MathLib.add(a, b); + advancedResult = AdvancedMath.square(a); + } +} \ No newline at end of file diff --git a/tests/test_auto_library_linking.py b/tests/test_auto_library_linking.py new file mode 100644 index 00000000..cbcfc607 --- /dev/null +++ b/tests/test_auto_library_linking.py @@ -0,0 +1,157 @@ +""" +Test auto library linking functionality +""" +import json +import os +from pathlib import Path +import shutil + +from crytic_compile.crytic_compile import CryticCompile +from crytic_compile.utils.libraries import get_deployment_order + +TEST_DIR = Path(__file__).resolve().parent + + +def test_dependency_resolution(): + """Test that library dependencies are resolved correctly""" + cc = CryticCompile(Path(TEST_DIR / "library_dependency_test.sol").as_posix()) + + compilation_unit = list(cc.compilation_units.values())[0] + source_unit = list(compilation_unit.source_units.values())[0] + + # Check dependencies for TestComplexDependencies + deps = source_unit.libraries_names("TestComplexDependencies") + assert "ComplexMath" in deps, "TestComplexDependencies should depend on ComplexMath" + + +def test_deployment_order(): + """Test that deployment order is calculated correctly""" + # Create a simple dependency graph for testing + dependencies = { + "TestComplexDependencies": ["ComplexMath"], + "ComplexMath": ["AdvancedMath", "MathLib"], + "AdvancedMath": ["MathLib"], + "MathLib": [], + "SimpleMathContract": ["MathLib"], + } + + target_contracts = ["TestComplexDependencies", "SimpleMathContract"] + + deployment_order, libraries_needed = get_deployment_order(dependencies, target_contracts) + + # Check that deployment order only contains libraries, not target contracts + assert ( + "TestComplexDependencies" not in deployment_order + ), "Target contracts should not be in deployment order" + assert ( + "SimpleMathContract" not in deployment_order + ), "Target contracts should not be in deployment order" + + # MathLib should come first (no dependencies) + assert deployment_order.index("MathLib") < deployment_order.index("AdvancedMath") + assert deployment_order.index("MathLib") < deployment_order.index("ComplexMath") + assert deployment_order.index("AdvancedMath") < deployment_order.index("ComplexMath") + + # Check that libraries are identified correctly + expected_libraries = {"MathLib", "AdvancedMath", "ComplexMath"} + assert libraries_needed == expected_libraries + + +def test_circular_dependency_detection(): + """Test that circular dependencies are detected""" + # Create a circular dependency graph + dependencies = { + "A": ["B"], + "B": ["C"], + "C": ["A"], # Circular dependency + } + + target_contracts = ["A"] + + try: + get_deployment_order(dependencies, target_contracts) + assert False, "Should have raised ValueError for circular dependency" + except ValueError as e: + assert "Circular dependency" in str(e) + + +def test_no_autolink_without_flag(): + """Test that autolink features don't activate without the flag""" + cc = CryticCompile(Path(TEST_DIR / "library_dependency_test.sol").as_posix()) + + # Check that autolink did not generate library addresses + assert ( + cc.libraries is None or len(cc.libraries) == 0 + ), "Autolink should not generate library addresses without flag" + + # Export and check that no autolink file is created + export_files = cc.export(export_format="solc", export_dir="test_no_autolink_output") + + autolink_file_found = False + for export_file in export_files: + filename = os.path.basename(export_file) + if "autolink" in filename: + autolink_file_found = True + break + + assert not autolink_file_found, "No autolink file should be created without the flag" + + # Clean up + if os.path.exists("test_no_autolink_output"): + shutil.rmtree("test_no_autolink_output") + + +def test_autolink_functionality(): + """Test the autolink functionality""" + cc = CryticCompile( + Path(TEST_DIR / "library_dependency_test.sol").as_posix(), compile_autolink=True + ) + + # Check that autolink generated library addresses + assert cc.libraries is not None, "Autolink should generate library addresses" + assert len(cc.libraries) > 0, "Should have detected libraries to link" + + expected_libs = ["MathLib", "AdvancedMath", "ComplexMath"] + for lib in expected_libs: + assert lib in cc.libraries, f"Library {lib} should be auto-linked" + + # Export and check that autolink file is created + export_files = cc.export(export_format="solc", export_dir="test_autolink_output") + + # Check that autolink file was created + autolink_file = None + for export_file in export_files: + filename = os.path.basename(export_file) + if filename.endswith(".link"): + autolink_file = export_file + break + + assert autolink_file is not None, "Autolink file should be created" + + with open(autolink_file, "r", encoding="utf8") as f: + autolink_data = json.load(f) + + # Check autolink file structure + assert "deployment_order" in autolink_data, "Autolink file should contain deployment_order" + assert "library_addresses" in autolink_data, "Autolink file should contain library_addresses" + assert ( + len(autolink_data["library_addresses"]) > 0 + ), "Should have library addresses in autolink file" + + # Check deployment order contains expected contracts + deployment_order = autolink_data["deployment_order"] + assert "MathLib" in deployment_order, "Deployment order should contain MathLib" + assert "ComplexMath" in deployment_order, "Deployment order should contain ComplexMath" + + # Clean up + if os.path.exists("test_autolink_output"): + shutil.rmtree("test_autolink_output") + + +if __name__ == "__main__": + test_dependency_resolution() + test_deployment_order() + test_circular_dependency_detection() + test_no_autolink_without_flag() + test_autolink_functionality() + print("All tests passed!")