diff --git a/.gitignore b/.gitignore index b7faf40..f58e991 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ __pycache__/ # C extensions *.so +*.bin # Distribution / packaging .Python @@ -25,6 +26,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +.venv/ # PyInstaller # Usually these files are written by a python script from a template @@ -182,9 +184,9 @@ cython_debug/ .abstra/ # Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, +# and can be added to the global gitignore or merged into this file. However, if you prefer, # you could uncomment the following to ignore the entire vscode folder # .vscode/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f5a1230 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +exclude: ^tests/generic_tests/targets/ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-added-large-files + - id: check-merge-conflict + - id: check-yaml + - id: end-of-file-fixer + - id: mixed-line-ending + args: ["--fix=no"] + - id: trailing-whitespace + - repo: https://github.com/python-jsonschema/check-jsonschema + rev: 0.27.0 + hooks: + - id: check-dependabot + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.8 + hooks: + - id: ruff-format diff --git a/patchery/__init__.py b/patchery/__init__.py index 1d476fc..43a488c 100644 --- a/patchery/__init__.py +++ b/patchery/__init__.py @@ -1,6 +1,7 @@ __version__ = "0.0.0" import logging + logging.getLogger("patchery").addHandler(logging.NullHandler()) from .logger import Loggers @@ -8,10 +9,7 @@ del Loggers import os + # stop LiteLLM from querying at all to the remote server # https://github.com/BerriAI/litellm/blob/4d29c1fb6941e49191280c4fd63961dec1a1e7c5/litellm/__init__.py#L286C20-L286C48 os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" -from .data import Patch -from .generator import LLMPatchGenerator -#from .verifier import PatchVerifier -#from .patcher import Patcher diff --git a/patchery/aicc_patcher.py b/patchery/aicc_patcher.py index 16cbbe2..b0a08c6 100644 --- a/patchery/aicc_patcher.py +++ b/patchery/aicc_patcher.py @@ -7,25 +7,37 @@ import yaml -from shellphish_crs_utils.models.crs_reports import RepresentativeFullPoVReport, POIReport -from shellphish_crs_utils.models.patch import PatchMetaData -from shellphish_crs_utils.oss_fuzz.project import OSSFuzzProject +from patchery.data.models.crs_reports import RepresentativeFullPoVReport, POIReport +from patchery.data.models.patch import PatchMetaData +# from shellphish_crs_utils.oss_fuzz.project import OSSFuzzProject import patchery from patchery import Patcher, LLMPatchGenerator -from patchery.utils import absolute_path_finder, read_src_from_file, find_src_root_from_commit, llm_model_name +from patchery.utils import ( + absolute_path_finder, + read_src_from_file, + find_src_root_from_commit, + llm_model_name, +) from patchery.kumushi.root_cause_analyzer import RootCauseAnalyzer from patchery.kumushi.rca_mode import RCAMode from patchery.kumushi.aixcc import AICCProgram -from patchery.data import ProgramInput, ProgramInputType, PoI, PoICluster, PoISource, Program +from patchery.data 
import ( + ProgramInput, + ProgramInputType, + PoI, + PoICluster, + PoISource, + Program, +) from patchery.kumushi.util import load_clusters_from_yaml _l = logging.getLogger(__name__) class AICCPatcher(Patcher): - DEFAULT_LLM_MODEL = 'claude-3.7-sonnet' + DEFAULT_LLM_MODEL = "claude-3.7-sonnet" def __init__( self, @@ -35,9 +47,8 @@ def __init__( patch_metadata_output_dir=None, local_run: bool = False, kumushi_clusters: list | None = None, - **kwargs + **kwargs, ): - # private api self._patch_output_dir = patch_output_dir self._patch_metadata_output_dir = patch_metadata_output_dir @@ -48,7 +59,9 @@ def __init__( self.is_local_run = local_run self.pois = [] - super().__init__(program, llm_model_name(model=self.DEFAULT_LLM_MODEL), **kwargs) + super().__init__( + program, llm_model_name(model=self.DEFAULT_LLM_MODEL), **kwargs + ) # generate pois for patching self.pois = self.poi_clusters_from_kumushi() @@ -60,13 +73,15 @@ def __init__( def poi_clusters_from_kumushi(self, kumushi_report=None): if not self._kumushi_clusters: - _l.info("No KumuShi report provided, generating PoIs from local KumuShi run...") - rca = RootCauseAnalyzer( - self.program_info, - rca_mode=RCAMode.WEIGHTLESS + _l.info( + "No KumuShi report provided, generating PoIs from local KumuShi run..." ) + rca = RootCauseAnalyzer(self.program_info, rca_mode=RCAMode.WEIGHTLESS) poi_clusters = rca.weightless_pois - _l.info(f"Since we are using KumuShi in weightless, we will limit attempts to only %d.", self._weightless_limited_attempts) + _l.info( + f"Since we are using KumuShi in weightless, we will limit attempts to only %d.", + self._weightless_limited_attempts, + ) self.max_attempts = self._weightless_limited_attempts self.program_info.code.reinit_or_get_function_resolver() else: @@ -79,8 +94,14 @@ def poi_clusters_from_kumushi(self, kumushi_report=None): def _update_patch_output_locations(self) -> tuple[Path, Path]: # patch output location patch_name = hashlib.md5(os.urandom(16)).hexdigest() - patch_output_dir = Path(self._patch_output_dir) if self._patch_output_dir else None - patch_metadata_output_dir = Path(self._patch_metadata_output_dir) if self._patch_metadata_output_dir else None + patch_output_dir = ( + Path(self._patch_output_dir) if self._patch_output_dir else None + ) + patch_metadata_output_dir = ( + Path(self._patch_metadata_output_dir) + if self._patch_metadata_output_dir + else None + ) assert patch_output_dir.exists() assert patch_metadata_output_dir.exists() return patch_output_dir / patch_name, patch_metadata_output_dir / patch_name @@ -93,62 +114,76 @@ def generate_verified_patches(self, *args, **kwargs): verified_patches = super().generate_verified_patches(self.pois, **kwargs) if verified_patches: for patch_group in verified_patches: - for patch in patch_group['patches']: + for patch in patch_group["patches"]: patch_diff = self.program_info.git_diff(patch) - patch_output_file, patch_metadata_output_file = self._update_patch_output_locations() - build_request = patch.metadata.get('build_request_id', None) - summary = patch.metadata.get('summary', None) + patch_output_file, patch_metadata_output_file = ( + self._update_patch_output_locations() + ) + build_request = patch.metadata.get("build_request_id", None) + summary = patch.metadata.get("summary", None) if build_request is None: - _l.critical("No build request ID found in patch metadata, using crash report ID instead.") + _l.critical( + "No build request ID found in patch metadata, using crash report ID instead." 
+ ) with open(patch_metadata_output_file, "w") as f: patch_metadata: PatchMetaData = PatchMetaData( patcher_name=patcher_name, - total_cost=patch_group['cost'], + total_cost=patch_group["cost"], poi_report_id=self.program_info.poi_report.crash_report_id, pdt_project_id=self.program_info.poi_report.project_id, pdt_project_name=self.program_info.poi_report.project_name, pdt_harness_info_id=self.program_info.poi_report.harness_info_id, build_request_id=build_request, ) - yaml.safe_dump(patch_metadata.model_dump(), f, default_flow_style=False, sort_keys=False) + yaml.safe_dump( + patch_metadata.model_dump(), + f, + default_flow_style=False, + sort_keys=False, + ) with open(patch_output_file, "w") as f: f.write(patch_diff) - _l.info(f'Patch data saved! Patch: %s | Metadata: %s', patch_output_file, patch_metadata_output_file) + _l.info( + f"Patch data saved! Patch: %s | Metadata: %s", + patch_output_file, + patch_metadata_output_file, + ) _l.info(f"💸 The total cost of this patch was {self.total_cost} dollars.") else: - _l.info(f"💸 We could not make a patch. The total cost was {self.total_cost} dollars.") + _l.info( + f"💸 We could not make a patch. The total cost was {self.total_cost} dollars." + ) _l.error("Failed to generate any verified patches.") return verified_patches @classmethod def from_files( - cls, - *args, - target_root: Path = None, - source_root: Path = None, - report_yaml_path: Path = None, - project_metadata_path=None, - raw_report_path=None, - function_json_dir=None, - function_indices=None, - alerting_inputs_path=None, - patch_output_dir=None, - patch_metadata_output_dir=None, - crashing_commit=None, - indices_by_commit=None, - changed_func_by_commit=None, - patch_planning=None, - local_run=False, - kumushi_report_path=None, - delta_mode=False, - coverage_build_project_path: Path=None, - patch_request_meta: Path = None, - bypassing_inputs: str = None, - **kwargs + cls, + *args, + target_root: Path = None, + source_root: Path = None, + report_yaml_path: Path = None, + project_metadata_path=None, + raw_report_path=None, + function_json_dir=None, + function_indices=None, + alerting_inputs_path=None, + patch_output_dir=None, + patch_metadata_output_dir=None, + crashing_commit=None, + indices_by_commit=None, + changed_func_by_commit=None, + patch_planning=None, + local_run=False, + kumushi_report_path=None, + delta_mode=False, + coverage_build_project_path: Path = None, + patch_request_meta: Path = None, + bypassing_inputs: str = None, + **kwargs, ) -> "AICCPatcher": - # validate outputs locations exists if patch_output_dir is not None: Path(patch_output_dir).mkdir(exist_ok=True) @@ -178,17 +213,17 @@ def from_files( with raw_report_path.open("r") as f: rep = yaml.safe_load(f) - #rep["dedup_crash_report"]["dedup_tokens_shellphish"] = {} - #rep["run_pov_result"]["pov"]["organizer_crash_eval"] = {} - #rep["run_pov_result"]["pov"]["dedup_crash_report"]["dedup_tokens_shellphish"] = {} - #rep["run_pov_result"]["pov"]["organizer_crash_eval"]["code_label"] = "" - #rep["run_pov_result"]["pov"]["organizer_crash_eval"]["significance"] = 0 - #rep["run_pov_result"]["pov"]["organizer_crash_eval"]["significance_message"] = "" - #rep["run_pov_result"]["pov"]["organizer_crash_eval"]["crash_state"] = "" - #rep["run_pov_result"]["pov"]["dedup_crash_report"]["dedup_tokens_shellphish"]["code_label"] = "" - #rep["run_pov_result"]["pov"]["dedup_crash_report"]["dedup_tokens_shellphish"]["significance"] = "" - 
#rep["run_pov_result"]["pov"]["dedup_crash_report"]["dedup_tokens_shellphish"]["significance_message"] = "" - #rep["run_pov_result"]["pov"]["dedup_crash_report"]["dedup_tokens_shellphish"]["crash_state"] = "" + # rep["dedup_crash_report"]["dedup_tokens_shellphish"] = {} + # rep["run_pov_result"]["pov"]["organizer_crash_eval"] = {} + # rep["run_pov_result"]["pov"]["dedup_crash_report"]["dedup_tokens_shellphish"] = {} + # rep["run_pov_result"]["pov"]["organizer_crash_eval"]["code_label"] = "" + # rep["run_pov_result"]["pov"]["organizer_crash_eval"]["significance"] = 0 + # rep["run_pov_result"]["pov"]["organizer_crash_eval"]["significance_message"] = "" + # rep["run_pov_result"]["pov"]["organizer_crash_eval"]["crash_state"] = "" + # rep["run_pov_result"]["pov"]["dedup_crash_report"]["dedup_tokens_shellphish"]["code_label"] = "" + # rep["run_pov_result"]["pov"]["dedup_crash_report"]["dedup_tokens_shellphish"]["significance"] = "" + # rep["run_pov_result"]["pov"]["dedup_crash_report"]["dedup_tokens_shellphish"]["significance_message"] = "" + # rep["run_pov_result"]["pov"]["dedup_crash_report"]["dedup_tokens_shellphish"]["crash_state"] = "" pov_report = RepresentativeFullPoVReport.model_validate(rep) @@ -224,7 +259,9 @@ def from_files( if kumushi_report_path: kumushi_report_path = Path(kumushi_report_path) if kumushi_report_path.exists() and kumushi_report_path.is_file(): - kumushi_clusters = load_clusters_from_yaml(kumushi_report_path, aicc_program) + kumushi_clusters = load_clusters_from_yaml( + kumushi_report_path, aicc_program + ) patcher = cls( aicc_program, diff --git a/patchery/data/__init__.py b/patchery/data/__init__.py index 18e4def..d8bfaa1 100644 --- a/patchery/data/__init__.py +++ b/patchery/data/__init__.py @@ -4,5 +4,16 @@ from .program_input import ProgramInput, ProgramInputType from .program_alert import ProgramAlert, ProgramExitType from .program import Program +from .models import ( + PatchRequestMeta, + POIReport, + RootCauseReport, + RepresentativeFullPoVReport, +) +from .function_resolver import ( + FunctionResolver, + LocalFunctionResolver, + RemoteFunctionResolver, +) JAZZER_CMD_INJECT_STR = "OS Command Injection" diff --git a/patchery/data/function_resolver.py b/patchery/data/function_resolver.py new file mode 100644 index 0000000..c869bfc --- /dev/null +++ b/patchery/data/function_resolver.py @@ -0,0 +1,1962 @@ +import hashlib +from abc import abstractmethod + +from collections import defaultdict +from dataclasses import dataclass +from enum import Enum +from functools import lru_cache +import json +import threading +import jq +import logging +import re +import shutil +import subprocess +import tempfile +import requests +import os +from pathlib import Path +import time +from typing import Dict, Iterator, List, Optional, Tuple, Union, Literal +from patchery.data.models.constraints import PDT_ID +from patchery.data.models.symbols import RelativePathKind, SourceLocation +from patchery.data.models.indexer import ( + FunctionIndex, + FUNCTION_INDEX_KEY, +) +from patchery.data.models.coverage import ( + CoverageLine, + LinesCoverage, + FunctionCoverageMap, + FileCoverageMap, +) +from patchery.data.models.target import VALID_SOURCE_FILE_SUFFIXES +from patchery.utils import artiphishell_should_fail_on_error +import yaml + +log = logging.getLogger(__name__) +log.setLevel(logging.INFO) + + +class MatchKind(Enum): + DEFINITELY = 1 + MAYBE = 2 + DEFINITELY_NOT_IN_WELL_BEHAVED_SETTINGS = 3 # heck you sqlite3 + + +@dataclass +class FunctionIndexRanking: + match_kind: MatchKind + 
match_value: float + + +# for context, we hardcode sqlite3.c here to be compatible with the sqlite3 amalgamation. We should probably be okay with the ranking +# even without this customization, but it helps to be sure for a target we know is likely involved + + +def get_function_name_match( + source_location: SourceLocation, function_index_entry: FunctionIndex +) -> Optional[FunctionIndexRanking]: + if not source_location.function_name: + return None + if source_location.function_name == function_index_entry.funcname: + return FunctionIndexRanking(MatchKind.DEFINITELY, 1.0) + if ( + source_location.function_name.startswith("OSS_FUZZ_") + and source_location.function_name[len("OSS_FUZZ_") :] + == function_index_entry.funcname + ): + # special case function prefixes of OSS_FUZZ_ as e.g., used in libpng + return FunctionIndexRanking(MatchKind.DEFINITELY, 1.0) + if source_location.function_name.endswith( + "::" + function_index_entry.funcname + ) or source_location.function_name.endswith("." + function_index_entry.funcname): + return FunctionIndexRanking(MatchKind.MAYBE, 0.8) + if function_index_entry.funcname.endswith( + "::" + source_location.function_name + ) or function_index_entry.funcname.endswith("." + source_location.function_name): + return FunctionIndexRanking(MatchKind.MAYBE, 0.8) + if function_index_entry.funcname in source_location.function_name: + return FunctionIndexRanking(MatchKind.MAYBE, 0.6) + if source_location.function_name in function_index_entry.funcname: + char_before = ( + source_location.function_name.index(function_index_entry.funcname) - 1 + ) + char_after = char_before + len(function_index_entry.funcname) + char_before = ( + function_index_entry.funcname[char_before] if char_before >= 0 else " " + ) + char_after = ( + function_index_entry.funcname[char_after] + if char_after < len(function_index_entry.funcname) + else " " + ) + if char_before in ["_", " ", ":", "."] and char_after in ["_", " ", ":", "."]: + # underscores are a last resort for stuff like OSS_FUZZ_libpng_read_row and should be ranked lower than non-underscore matches + return FunctionIndexRanking( + MatchKind.MAYBE, + 0.4 if char_before != "_" and char_after != "_" else 0.3, + ) + elif function_index_entry.funcname in source_location.function_name: + char_before = ( + function_index_entry.funcname.index(source_location.function_name) - 1 + ) + char_after = char_before + len(source_location.function_name) + char_before = ( + source_location.function_name[char_before] if char_before >= 0 else " " + ) + char_after = ( + source_location.function_name[char_after] + if char_after < len(source_location.function_name) + else " " + ) + if char_before in ["_", " ", ":", "."] and char_after in ["_", " ", ":", "."]: + return FunctionIndexRanking( + MatchKind.MAYBE, + 0.4 if char_before != "_" and char_after != "_" else 0.3, + ) + + if ( + function_index_entry.funcname in source_location.function_name + or source_location.function_name in function_index_entry.funcname + ): + # if it's included in ANY way possible, it is at least a bit more of a match than no match, but only barely + return FunctionIndexRanking(MatchKind.MAYBE, 0.1) + + return FunctionIndexRanking(MatchKind.DEFINITELY_NOT_IN_WELL_BEHAVED_SETTINGS, 0.0) + + +def count_matching_final_path_parts(path_a, path_b): + parts_a = path_a.parts + parts_b = path_b.parts + count = 0 + for i, (part_a, part_b) in enumerate(zip(reversed(parts_a), reversed(parts_b))): + if part_a == part_b: + count += 1 + else: + break + return count, min(len(parts_a), 
len(parts_b)) + + +def get_relative_filename_match( + source_location: SourceLocation, function_index_entry: FunctionIndex +) -> Optional[FunctionIndexRanking]: + if ( + source_location.focus_repo_relative_path + and function_index_entry.focus_repo_relative_path + and source_location.focus_repo_relative_path + == function_index_entry.focus_repo_relative_path + ): + return FunctionIndexRanking( + MatchKind.DEFINITELY, 1.0 + ) # if the focus repo relative path is the same, we have a perfect match + if source_location.relative_path: + if ( + source_location.relative_path.name == "sqlite3.c" + ): # fork the amalgamation, holy shirt + return FunctionIndexRanking(MatchKind.DEFINITELY, 1.0) + + if ( + function_index_entry.focus_repo_relative_path + and source_location.relative_path + == function_index_entry.focus_repo_relative_path + ): + return FunctionIndexRanking(MatchKind.DEFINITELY, 1.0) + if str(function_index_entry.target_container_path).endswith( + str(source_location.relative_path) + ): + return FunctionIndexRanking(MatchKind.DEFINITELY, 1.0) + # if function_index_entry.focus_repo_relative_path and ( + # str(source_location.relative_path) in str(function_index_entry.focus_repo_relative_path) + # or + # str(function_index_entry.focus_repo_relative_path) in str(source_location.relative_path) + # ): + # matching, min_match = count_matching_final_path_parts(source_location.relative_path, function_index_entry.focus_repo_relative_path) + # return FunctionIndexRanking(MatchKind.MAYBE, 0.8 * matching / min_match) + + if str(source_location.relative_path) in str( + function_index_entry.target_container_path + ): + matching, min_match = count_matching_final_path_parts( + source_location.relative_path, + function_index_entry.target_container_path, + ) + return FunctionIndexRanking(MatchKind.MAYBE, 0.5 * matching / min_match) + + +def get_full_file_path_match( + source_location: SourceLocation, function_index_entry: FunctionIndex +) -> Optional[FunctionIndexRanking]: + if source_location.full_file_path: + if function_index_entry.target_container_path == source_location.full_file_path: + return FunctionIndexRanking(MatchKind.DEFINITELY, 1.0) + if ( + source_location.full_file_path.name == "sqlite3.c" + ): # fork the amalgamation, holy shirt + return FunctionIndexRanking(MatchKind.DEFINITELY, 1.0) + matching, min_match = count_matching_final_path_parts( + source_location.full_file_path, function_index_entry.target_container_path + ) + return FunctionIndexRanking(MatchKind.MAYBE, 0.8 * matching / min_match) + + +def get_filename_match( + source_location: SourceLocation, function_index_entry: FunctionIndex +) -> Optional[FunctionIndexRanking]: + if source_location.file_name: + if source_location.file_name == function_index_entry.filename: + return FunctionIndexRanking(MatchKind.DEFINITELY, 1.0) + if ( + source_location.file_name == "sqlite3.c" + ): # fork the amalgamation, holy shirt + return FunctionIndexRanking(MatchKind.DEFINITELY, 1.0) + if ( + function_index_entry.focus_repo_relative_path + and source_location.file_name + == function_index_entry.focus_repo_relative_path.name + ): + return FunctionIndexRanking(MatchKind.DEFINITELY, 1.0) + if ( + source_location.file_name.name + == function_index_entry.target_container_path.name + ): + return FunctionIndexRanking(MatchKind.DEFINITELY, 1.0) + return FunctionIndexRanking( + MatchKind.DEFINITELY_NOT_IN_WELL_BEHAVED_SETTINGS, 0.0 + ) + + +def get_line_number_match( + source_location: SourceLocation, function_index_entry: FunctionIndex +) -> 
Optional[FunctionIndexRanking]: + if source_location.line_number is not None: + if source_location.line_number == function_index_entry.start_line: + return FunctionIndexRanking(MatchKind.DEFINITELY, 1.0) + + # sometimes the first line is a bit off (allow wiggle room of 3 lines) + if ( + function_index_entry.start_line - 3 + <= source_location.line_number + <= function_index_entry.start_line + 3 + ): + return FunctionIndexRanking(MatchKind.MAYBE, 0.8) + + if ( + function_index_entry.start_line - 3 + <= source_location.line_number + <= function_index_entry.end_line + 3 + ): + return FunctionIndexRanking(MatchKind.MAYBE, 0.5) + + # line number is so far away that it's almost certainly not the same function in well-behaved apps. *cough* not sqlite3 *cough* + return FunctionIndexRanking( + MatchKind.DEFINITELY_NOT_IN_WELL_BEHAVED_SETTINGS, 0.0 + ) + + +def get_java_info_match( + source_location: SourceLocation, function_index_entry: FunctionIndex +) -> Optional[FunctionIndexRanking]: + matching = 0 + total = 0 + if not source_location.java_info: + return None + if source_location.java_info.class_name: + total += 1 + cur = 0 + if ( + source_location.java_info.class_name + ".java" + == function_index_entry.target_container_path.name + ): + cur += 1 + if ( + source_location.java_info.class_name == function_index_entry.class_name + or function_index_entry.class_name.endswith( + source_location.java_info.class_name + ) + ): + cur += 1 + + class_name_root = function_index_entry.target_container_path.name.split( + ".java" + )[0] + if source_location.java_info.class_name.startswith(class_name_root): + cur += ( + 0.8 + if source_location.java_info.class_name.startswith( + class_name_root + "$" + ) + else 0.5 + ) + matching += cur / 3 + + if source_location.java_info.method_name: + total += 1 + if source_location.java_info.method_name == function_index_entry.funcname: + matching += 1 + + if source_location.java_info.package: + total += 1 + if source_location.java_info.package == function_index_entry.package: + matching += 1 + + return ( + FunctionIndexRanking(MatchKind.MAYBE, matching / total) if total > 0 else None + ) + + +def prepare_for_json(obj): + if isinstance(obj, Path): + return str(obj) + if isinstance(obj, dict): + return {k: prepare_for_json(v) for k, v in obj.items()} + if isinstance(obj, list): + return [prepare_for_json(v) for v in obj] + return obj + + +def function_index_to_source_location( + key: FUNCTION_INDEX_KEY, index_obj: FunctionIndex +) -> SourceLocation: + # this is questionable + return SourceLocation( + focus_repo_relative_path=index_obj.focus_repo_relative_path, + function_name=index_obj.funcname, + line_number=index_obj.start_line, + raw_signature=index_obj.signature or key.split("::", 1)[-1], + function_index_key=key, + function_index_signature=key, + file_name=Path(index_obj.filename), + full_file_path=key.split(":", 1)[0], + ) + + +class FunctionResolver: + def __init__(self, focus_repo_container_path: Optional[Path] = None): + self.cached_code_lines = {} + self.cached_jq_filter_expression_keys = {} + self.focus_repo_container_path = focus_repo_container_path + self.cached_hashes = {} + + @abstractmethod + def is_ready(self) -> bool: + raise NotImplementedError + + @abstractmethod + def get_funcname(self, key: FUNCTION_INDEX_KEY) -> str: + raise NotImplementedError + + @abstractmethod + def get_focus_repo_relative_path(self, key: FUNCTION_INDEX_KEY) -> Optional[Path]: + raise NotImplementedError + + @abstractmethod + def get_target_container_path(self, key: 
FUNCTION_INDEX_KEY) -> Path: + raise NotImplementedError + + @abstractmethod + def get_code( + self, key: FUNCTION_INDEX_KEY + ) -> Tuple[Optional[Path], Path, int, str]: + raise NotImplementedError + + @abstractmethod + def get_function_boundary(self, key: FUNCTION_INDEX_KEY) -> Tuple[int, int]: + raise NotImplementedError + + @abstractmethod + def _find_matching_indices(self, s: str) -> Iterator[FUNCTION_INDEX_KEY]: + raise NotImplementedError + + @abstractmethod + def resolve_with_leniency(self, name: str) -> Iterator[FUNCTION_INDEX_KEY]: + raise NotImplementedError + + @abstractmethod + def find_by_funcname(self, s: str) -> Iterator[FUNCTION_INDEX_KEY]: + raise NotImplementedError + + @abstractmethod + def find_by_filename(self, s: Union[Path, str]) -> Iterator[FUNCTION_INDEX_KEY]: + raise NotImplementedError + + @abstractmethod + def find_matching_indices( + self, + indices: List[FUNCTION_INDEX_KEY], + scope: Literal["all", "focus", "non-focus", "compiled"] = "focus", + can_include_self: bool = True, + can_include_build_generated: bool = True, + ) -> Tuple[Dict[FUNCTION_INDEX_KEY, FUNCTION_INDEX_KEY], List[FUNCTION_INDEX_KEY]]: + """ + For each index: + Finds the first "other" index which matches the given index. + For example if given an index from the source directory, + it can find the matching copy in the build directory + Returns a dict of the form {index: matching_index} + and a list of indices that were not found + """ + raise NotImplementedError + + def find_matching_index( + self, + index: FUNCTION_INDEX_KEY, + scope: Literal["all", "focus", "non-focus", "compiled"] = "focus", + can_include_self: bool = True, + can_include_build_generated: bool = True, + ) -> Optional[FUNCTION_INDEX_KEY]: + """ + Finds the first "other" index which matches the given index. + For example if given an index from the source directory, + it can find the matching copy in the build directory + Returns the matching index or None if no match "other" was found + """ + matches, missing = self.find_matching_indices( + [index], scope, can_include_self, can_include_build_generated + ) + if len(matches) == 1: + return list(matches.values())[0] + return None + + def get(self, key: FUNCTION_INDEX_KEY) -> FunctionIndex: + raise NotImplementedError + + def find_functions_with_annotation( + self, annotation: str + ) -> Iterator[FUNCTION_INDEX_KEY]: + raise NotImplementedError( + "This method should be implemented in the subclass. It is not supported in the base FunctionResolver class." 
+ ) + + def get_many( + self, keys: List[FUNCTION_INDEX_KEY] + ) -> Dict[FUNCTION_INDEX_KEY, FunctionIndex]: + return {key: self.get(key) for key in keys} + + def get_with_default( + self, key: FUNCTION_INDEX_KEY, default=None + ) -> Optional[FunctionIndex]: + try: + return self.get(key) + except KeyError: + return default + + def get_full_hash(self, key: FUNCTION_INDEX_KEY, do_cache: bool = True) -> str: + hk = (key, "full") + if hk in self.cached_hashes: + return self.cached_hashes[hk] + + hash = self.get(key).hash + self.cached_hashes[hk] = hash + return hash + + def get_code_line_hash(self, key: FUNCTION_INDEX_KEY, do_cache: bool = True) -> str: + hk = (key, "code_line") + if hk in self.cached_hashes: + return self.cached_hashes[hk] + + meta = self.get(key) + h = hashlib.md5() + h.update(meta.code.encode("utf-8")) + h.update(f"{meta.start_line}".encode("utf-8")) + hash = h.hexdigest() + self.cached_hashes[hk] = hash + return hash + + def get_code_hash(self, key: FUNCTION_INDEX_KEY, do_cache: bool = True) -> str: + hk = (key, "code") + if hk in self.cached_hashes: + return self.cached_hashes[hk] + + meta = self.get(key) + code = meta.code + + hash = hashlib.md5(code.encode("utf-8")).hexdigest() + self.cached_hashes[hk] = hash + return hash + + def get_focus_repo_keys( + self, focus_repo_container_path: Optional[Union[Path, str]] + ) -> List[FUNCTION_INDEX_KEY]: + if not focus_repo_container_path and self.focus_repo_container_path: + focus_repo_container_path = self.focus_repo_container_path + if not focus_repo_container_path: + raise ValueError( + "No focus repo container path provided and no default set." + ) + return [ + key for key in self.keys() if key.startswith(str(focus_repo_container_path)) + ] + + def get_filtered_keys( + self, key_only_jq_filter_expression: str + ) -> List[FUNCTION_INDEX_KEY]: + """ + Get a filtered list of function indices based on a jq filter expression. The jq filter expression operates directly on the function index key. + """ + if key_only_jq_filter_expression not in self.cached_jq_filter_expression_keys: + jq_filter = jq.compile(key_only_jq_filter_expression) + all_keys = self.keys() + is_valid = jq_filter.input_values(self.keys()).all() + filtered_keys = [k for k, valid in zip(all_keys, is_valid) if valid] + self.cached_jq_filter_expression_keys[key_only_jq_filter_expression] = ( + filtered_keys + ) + return self.cached_jq_filter_expression_keys[key_only_jq_filter_expression] + + def get_filtered( + self, key_only_filter_expression: str, full_filter_expression: str + ) -> List[FUNCTION_INDEX_KEY]: + """ + Get a filtered list of function indices based on a jq filter expression. Allows you to retrieve keys that are + filtered based on the key and then filtered again based on the full function index values themselves. + + E.g. `key_only_filter_expression` could be `true` and `full_filter_expression` could be `.focus_repo_relative_path != null` to get all functions that are in the focus repo. 
+ """ + if ( + key_only_filter_expression, + full_filter_expression, + ) not in self.cached_jq_filter_expression_keys: + filtered_keys = self.get_filtered_keys(key_only_filter_expression) + full_jq_filter = jq.compile(full_filter_expression) + + vals_to_filter = [ + {"key": str(key), "value": prepare_for_json(self.get(key).model_dump())} + for key in filtered_keys + ] + is_valid = full_jq_filter.input_values(vals_to_filter).all() + filtered_vals = { + k["key"]: k["value"] + for k, valid in zip(vals_to_filter, is_valid) + if valid + } + + self.cached_jq_filter_expression_keys[ + (key_only_filter_expression, full_filter_expression) + ] = filtered_vals + return self.cached_jq_filter_expression_keys[ + (key_only_filter_expression, full_filter_expression) + ] + + def get_function_code_line(self, key: FUNCTION_INDEX_KEY, line_no: int) -> str: + if (key, line_no) not in self.cached_code_lines: + idx = self.get(key) + try: + code_line = ( + idx.code.split("\n")[line_no - idx.start_line] + if idx is not None + else None + ) + except Exception as e: + log.warning(f"Error getting function code line for {key}: {e}") + code_line = None + self.cached_code_lines[(key, line_no)] = code_line + return self.cached_code_lines[(key, line_no)] + + def get_function_coverage_for_file( + self, + path: Union[Path, str], + lines: LinesCoverage, + function_keys_of_interest=None, + ) -> FunctionCoverageMap: + if path not in self.cached_lines_to_function: + start_time = time.time() + self.cached_lines_to_function[path] = {} + for key in self.find_by_filename(path): + start, end = self.get_function_boundary(key) + self.cached_lines_to_function[path].update( + {line: key for line in range(start, end + 1)} + ) + + # print(f"Loaded {len(self.cached_lines_to_function[path])} lines to function mappings for {path} in {time.time() - start_time:.2f}s") + + start_time = time.time() + res = defaultdict(list) + for line in lines: + if line.line_number not in self.cached_lines_to_function[path]: + continue + containing_function_key = self.cached_lines_to_function[path][ + line.line_number + ] + if ( + function_keys_of_interest is not None + and containing_function_key not in function_keys_of_interest + ): + continue + new = CoverageLine( + line_number=line.line_number, + count_covered=line.count_covered, + code=self.get_function_code_line( + containing_function_key, line.line_number + ), + ) + res[self.cached_lines_to_function[path][line.line_number]].append(new) + + # print(f"Resolved {len(res)} functions for {path} in {time.time() - start_time:.2f}s") + return res + + def get_function_coverage( + self, + file_coverage: FileCoverageMap, + path_suffixes_of_interest=None, + function_keys_of_interest=None, + ) -> FunctionCoverageMap: + res = {} + if function_keys_of_interest is not None: + assert ( + path_suffixes_of_interest is None + ), "Cannot specify both function keys and path suffixes of interest" + path_suffixes_of_interest = [ + os.path.basename(self.get_target_container_path(key)) + for key in function_keys_of_interest + ] + for path, lines in file_coverage.items(): + if path_suffixes_of_interest is not None and not any( + str(path).endswith(suffix) for suffix in path_suffixes_of_interest + ): + continue + res.update( + self.get_function_coverage_for_file( + path, lines, function_keys_of_interest=function_keys_of_interest + ) + ) + return res + + def get_function_coverage_report( + self, + inputs, + function_coverage: FunctionCoverageMap, + keys_of_interest: List[FUNCTION_INDEX_KEY] = None, + ): + reports = [] + if not 
keys_of_interest: + keys_of_interest = list(function_coverage.keys()) + for key in keys_of_interest: + focus_repo_rel_path, target_container_path, func_start_line, func_code = ( + self.get_code(key) + ) + + report = f"# Coverage Report ({len(inputs)} unique inputs)\n" + report += f"## {key}\n" + report += f"## {target_container_path}:{func_start_line}\n" + if key not in function_coverage: + report += "No coverage was reached in this function.\n" + continue + + func_cov_lines = list( + sorted(function_coverage[key], key=lambda x: x.line_number) + ) + report += f'Line | {"Count":8} | Code\n' + for i, line in enumerate(func_code.split("\n")): + count = None + if ( + func_cov_lines + and func_cov_lines[0].line_number == i + func_start_line + ): + cur = func_cov_lines.pop(0) + count = cur.count_covered + + report += f'{i+func_start_line:4} | {count if count is not None else "":8} | {line}\n' + + reports.append(report) + + report = "\n\n# Function coverage (for the requested functions)\n" + report += "\n\n".join( + reports + if reports + else [ + "ERROR: No coverage was reached in any of the requested functions. You should probably check the coverage of the harness or earlier functions to see where you are getting stuck." + ] + ) + + return report + + def resolve_source_location( + self, + srcloc: SourceLocation, + num_top_matches: int = 3, + allow_build_generated: bool = False, + focus_repo_only: bool = False, + ) -> List[Tuple[FUNCTION_INDEX_KEY, List[FunctionIndexRanking]]]: + contenders_to_rank = [] + + if srcloc.function_name: + # we have a function name. First, if we have a perfect match and there's only one of them, return the perfect match + perfect_function_matches = [] + imperfect_function_matches = [] + for key in self.find_by_funcname(srcloc.function_name): + index_entry = self.get(key) + if index_entry.is_generated_during_build and not allow_build_generated: + continue + if focus_repo_only and self.get(key).focus_repo_relative_path is None: + continue + if ( + index_entry.class_name + and srcloc.java_info + and srcloc.java_info.class_path + ): # java info should be high-confidence, this cannot be a match + # we have a java function, so we need to check the java info as well + if srcloc.java_info.class_path != index_entry.class_name: + continue + ranking = get_function_name_match(srcloc, self.get(key)) + assert ranking, f"Match is None for {srcloc} and {key}" + if ranking.match_kind == MatchKind.DEFINITELY: + perfect_function_matches.append((key, ranking)) + else: + imperfect_function_matches.append((key, ranking)) + + if perfect_function_matches: + if len(perfect_function_matches) == 1: + return [ + ( + perfect_function_matches[0][0], + [perfect_function_matches[0][1]], + ) + ] + + contenders_to_rank = ( + perfect_function_matches + if perfect_function_matches + else imperfect_function_matches + ) + + else: + # we don't have a function name, first, check by filename to find the correct entries + if artiphishell_should_fail_on_error(): + raise NotImplementedError("This is not implemented yet") + return None + + return self._rank_contenders(srcloc, contenders_to_rank, num_top_matches) + + def _rank_contenders( + self, + srcloc: SourceLocation, + contenders_to_rank: List[Tuple[FUNCTION_INDEX_KEY, List[FunctionIndexRanking]]], + num_top_matches: int = 3, + ) -> List[Tuple[FUNCTION_INDEX_KEY, List[FunctionIndexRanking]]]: + if not contenders_to_rank: + return None + + rankings = [] + for key, ranking in contenders_to_rank: + # okay, we aggregate the filename, line number, etc. 
rankings by just summing them up + total = 0 + rank_vals = [] + # first, we check the relative path + if relative_path_ranking := get_relative_filename_match( + srcloc, self.get(key) + ): + total += relative_path_ranking.match_value + rank_vals.append(("relative_path", relative_path_ranking)) + # then, we check the full file path + if full_file_path_ranking := get_full_file_path_match( + srcloc, self.get(key) + ): + total += full_file_path_ranking.match_value + rank_vals.append(("full_file_path", full_file_path_ranking)) + # then, we check the filename + if filename_ranking := get_filename_match(srcloc, self.get(key)): + total += filename_ranking.match_value + rank_vals.append(("filename", filename_ranking)) + # then, we check the line number + if line_number_ranking := get_line_number_match(srcloc, self.get(key)): + total += line_number_ranking.match_value + rank_vals.append(("line_number", line_number_ranking)) + # then, we check the java info + if java_info_ranking := get_java_info_match(srcloc, self.get(key)): + total += java_info_ranking.match_value + rank_vals.append(("java_info", java_info_ranking)) + + rankings.append((key, total, rank_vals)) + + # check the sorted rankings + rankings = sorted( + rankings, key=lambda x: x[1], reverse=True + ) # sort by highest ranking first + log.debug(f"Rankings for {srcloc}:") + for i, (key, rank, rank_vals) in enumerate(rankings[:5]): + log.debug(f"{i+1}. {key}: {rank} with rank values: {rank_vals}") + log.debug( + f"Returning the highest {num_top_matches} rankings: {rankings[:num_top_matches]}" + ) + return [ + (r[0], [v[1] for v in r[2]]) for r in rankings[:num_top_matches] + ] # return the key and the rank values + + +class LocalFunctionResolver(FunctionResolver): + def __init__(self, functions_index_path: str, functions_jsons_path: str): + super().__init__() + + self.functions_index_path = Path(functions_index_path) + self.functions_jsons_path = Path(functions_jsons_path) + + self.function_full_hashes_write_lock = threading.Lock() + self.function_full_hashes = None + self.function_code_line_hashes = None + self.function_code_hashes = None + + self.cached_func_names = {} + + self.cached_func_codes = {} + self.cached_focus_repo_relative_paths = {} + self.cached_target_container_paths = {} + self.cached_function_boundaries = {} + self.cached_lines_to_function = {} + self.cached_by_filename = {} + self.cached_by_funcname = {} + self.cached_matching_indices = {} + self.cached_code_lines = {} + self.cached_leniency_resolutions = {} + self.cached_with_annotation = {} + + # This has to be done here to avoid leaking memory from the lru_cache, see https://rednafi.com/python/lru_cache_on_methods/ + self.get = lru_cache(maxsize=2048)(self._get) + + # NOTE: since we want to use the LocalFunctionResolver for a commit index, + # we need to detect what we are looking at + try: + # NOTE: For the base case: I think it's faster to fail here than to try to validate the json with pydantic + with open(self.functions_index_path, "r") as infile: + self.functions_index: Dict[str, Path] = { + k: Path(v) for k, v in json.load(infile).items() + } + except Exception: + log.warning("[INFO] 🔄 Not a full index, trying to load as a commit index.") + log.warning("[INFO] 🔄 Attempting loading a commit index...") + # However, if we fail, I want to make sure you are passing me a CommitToFunctionIndex + with open(self.functions_index_path, "r") as infile: + # FIXME: currently this breaks + # _ = CommitToFunctionIndex.model_validate(yaml.safe_load(infile.read())) + + # NOTE: + # 
If we don't crash, we are looking at a CommitToFunctionIndex! :D + # The structure of this report is: + # { '1_hash' : {'func_sig': 'index'}} + # Since we are gonna have only 1 commit, let's just extract the internal dict and call + # it a day. + try: + # FIX: can we have multiple projects' names here? + # WARNING: Next two lines are absolutely gorgeous 💋 + thedata = json.load(infile) + thedata = { + key: value + for commit, commit_funcs_dict in thedata.items() + for key, value in commit_funcs_dict.items() + } + self.functions_index: Dict[str, Path] = { + k: Path(v) for k, v in thedata.items() + } + except Exception as e: + log.critical( + "[CRITICAL] 🤯 Could not load a function index nor a commit index. Exiting." + ) + log.critical(e) + import traceback + + traceback.print_exc() + raise ValueError( + "Could not load a function index nor a commit index. Exiting." + ) + + for k, v in self.functions_index.items(): + fname = k.split(":")[0] + # TODO: this is a hack to ignore invalid function names, but we should fix this in the future + if not any(fname.endswith(suffix) for suffix in VALID_SOURCE_FILE_SUFFIXES): + log.warning(f"Invalid function name: {fname} in {k!r}") + continue + basename = os.path.basename(fname) + if basename not in self.cached_by_filename: + self.cached_by_filename[basename] = [] + self.cached_by_filename[basename].append(k) + + def is_ready(self) -> bool: + return True + + def keys(self) -> List[FUNCTION_INDEX_KEY]: + return list(self.functions_index.keys()) + + def _get(self, key: FUNCTION_INDEX_KEY) -> FunctionIndex: + if key not in self.functions_index: + raise KeyError(f"Function {key} not found in index") + + if not (self.functions_jsons_path / self.functions_index[key]).exists(): + raise ValueError( + f"Function jsons entry {self.functions_index[key]} does not exist for {key} at {self.functions_jsons_path}: {os.listdir(self.functions_jsons_path)}" + ) + + with open(self.functions_jsons_path / self.functions_index[key], "r") as infile: + result = FunctionIndex.model_validate(json.load(infile)) + + return result + + def __full_scan_of_doom_and_destruction__load_all_hashes_if_needed( + self, scope: Literal["all", "focus", "non-focus", "compiled"] = "all" + ): + # Our alg requires that we have EVERY possible function hash loaded so we can search through them + + # TODO we can split this into sub-sets based on the repo scope + with self.function_full_hashes_write_lock: + if self.function_full_hashes: + return + + self.function_full_hashes = defaultdict(list) + self.function_code_line_hashes = defaultdict(list) + self.function_code_hashes = defaultdict(list) + + # Load every single goddamn function so that we have the hashes for them + + for key in self.functions_index.keys(): + self.function_full_hashes[ + self.get_full_hash(key, do_cache=False) + ].append(key) + self.function_code_line_hashes[ + self.get_code_line_hash(key, do_cache=False) + ].append(key) + self.function_code_hashes[ + self.get_code_hash(key, do_cache=False) + ].append(key) + + def get_funcname(self, key: FUNCTION_INDEX_KEY) -> str: + if key not in self.cached_func_names: + self.cached_func_names[key] = self.get(key).funcname + return self.cached_func_names[key] + + def get_full_hash(self, key: FUNCTION_INDEX_KEY, do_cache: bool = True) -> str: + if self.function_full_hashes and key in self.function_full_hashes: + return self.function_full_hashes[key] + + return super().get_full_hash(key, do_cache) + + def get_code_line_hash(self, key: FUNCTION_INDEX_KEY, do_cache: bool = True) -> str: + if 
self.function_code_line_hashes and key in self.function_code_line_hashes: + return self.function_code_line_hashes[key] + + return super().get_code_line_hash(key, do_cache) + + def get_code_hash(self, key: FUNCTION_INDEX_KEY, do_cache: bool = True) -> str: + if self.function_code_hashes and key in self.function_code_hashes: + return self.function_code_hashes[key] + + return super().get_code_hash(key, do_cache) + + def get_focus_repo_relative_path(self, key: FUNCTION_INDEX_KEY) -> Optional[Path]: + if key not in self.cached_focus_repo_relative_paths: + res = self.get(key) + self.cached_focus_repo_relative_paths[key] = res.focus_repo_relative_path + return self.cached_focus_repo_relative_paths[key] + + def get_target_container_path(self, key: FUNCTION_INDEX_KEY) -> Path: + if key not in self.cached_target_container_paths: + res = self.get(key) + self.cached_target_container_paths[key] = res.target_container_path + return self.cached_target_container_paths[key] + + def get_code(self, key: FUNCTION_INDEX_KEY) -> Tuple[Path, Path, int, str]: + if key not in self.cached_func_codes: + idx = self.get(key) + self.cached_func_codes[key] = ( + idx.focus_repo_relative_path, + idx.target_container_path, + idx.start_line, + idx.code, + ) + return self.cached_func_codes[key] + + def get_function_boundary(self, key: FUNCTION_INDEX_KEY) -> Tuple[int, int]: + if key not in self.cached_function_boundaries: + idx = self.get(key) + self.cached_function_boundaries[key] = (idx.start_line, idx.end_line) + return self.cached_function_boundaries[key] + + def find_by_funcname(self, s: str) -> Iterator[FUNCTION_INDEX_KEY]: + if s not in self.cached_by_funcname: + self.cached_by_funcname[s] = [ + key + for key in self._find_matching_indices(s) + if self.get_funcname(key).split("::")[-1] == s + ] + for key in self.cached_by_funcname[s]: + yield key + + def find_functions_with_annotation( + self, annotation: str + ) -> Iterator[FUNCTION_INDEX_KEY]: + if annotation not in self.cached_with_annotation.keys(): + self.cached_with_annotation[annotation] = [] + for key in self.functions_index.keys(): + func = self.get(key) + if ( + func.language_specific_info + and "annotations" in func.language_specific_info.keys() + ): + for found_annotation in func.language_specific_info["annotations"]: + if found_annotation["identifier"] == annotation: + self.cached_with_annotation[annotation].append(key) + for key in self.cached_with_annotation[annotation]: + yield key + + def find_matching_indices( + self, + indices: List[FUNCTION_INDEX_KEY], + scope: Literal["all", "focus", "non-focus", "compiled"] = "focus", + can_include_self: bool = True, + can_include_build_generated: bool = True, + ) -> Tuple[Dict[FUNCTION_INDEX_KEY, FUNCTION_INDEX_KEY], List[FUNCTION_INDEX_KEY]]: + assert scope in [ + "all", + "focus", + "non-focus", + "compiled", + ], f"Invalid scope: {scope}" + if scope == "all": + assert ( + not can_include_self + ), "can_include_self=true on `all` scope will always return self..." 
+ + cache = self.cached_matching_indices + out_map = {} + + if len(indices) == 0: + return out_map, [] + + if len(indices) == 1: + # Rather than doing the more complex Aho-Corasick search, we can just do a simple lookup by funcname + found = self.get(indices[0]) + if not found: + return out_map, indices + + # Loads the cache with all the matches for this funcname + self.find_by_funcname(found.funcname) + + missing = set() + + for goal_k in indices: + cache_key = (goal_k, scope, can_include_self) + # Check if we have found this match before + if goal_k in cache: + cached_v = cache[goal_k] + if cached_v is None: + missing.add(goal_k) + else: + out_map[goal_k] = cached_v + continue + + to_find = [k for k in indices if k not in out_map and k not in missing] + + if not to_find: + return out_map, list(missing) + + log.warning(f"Before full scan {time.perf_counter()}") + + self.__full_scan_of_doom_and_destruction__load_all_hashes_if_needed(scope) + + log.warning(f"After full scan {time.perf_counter()}") + + for goal_key in to_find: + matches = set() + + full_hash = self.get_full_hash(goal_key) + matches |= set(self.function_full_hashes.get(full_hash, [])) + + code_line_hash = self.get_code_line_hash(goal_key) + matches |= set(self.function_code_line_hashes.get(code_line_hash, [])) + + code_hash = self.get_code_hash(goal_key) + matches |= set(self.function_code_hashes.get(code_hash, [])) + + meta = self.get(goal_key) + + # log.warning(f"--- Looking for {goal_key} matches (focus repo relative path: {meta.focus_repo_relative_path})") + def allow_build_generated(key: FUNCTION_INDEX_KEY) -> bool: + if can_include_build_generated: + return True + return not self.get(key).is_generated_during_build + + # log.warning(f"Matches: {matches}") + def is_in_scope(key: FUNCTION_INDEX_KEY) -> bool: + if scope == "focus": + return self.get_focus_repo_relative_path(key) is not None + elif scope == "non-focus": + return self.get_focus_repo_relative_path(key) is None + elif scope == "compiled": + return self.get(key).was_directly_compiled + else: + return True + + # filter down the matches based on the scope + log.warning(f"Matches before scope filter: {matches}") + matches = { + key + for key in matches + if is_in_scope(key) and allow_build_generated(key) + } + log.warning(f"Matches after scope filter: {matches}") + + should_have_self = can_include_self and is_in_scope(goal_key) + # log.warning(f"Should have self: {should_have_self}") + + if should_have_self: + matches.add(goal_key) + else: + matches.discard(goal_key) + + if not matches: + missing.add(goal_key) + continue + + if len(matches) == 1: + best_match_key = list(matches)[0] + else: + # If we have multiple matches, we need to rank them + srcloc = function_index_to_source_location(goal_key, meta) + + contenders = [(key, None) for key in matches] + + ranking = self._rank_contenders(srcloc, contenders, num_top_matches=1) + best_match_key = ranking[0][0] + + assert ( + best_match_key is not None + ), f"Could not find a best match for {goal_key}" + + out_map[goal_key] = best_match_key + cache_key = (goal_key, scope, can_include_self) + cache[cache_key] = best_match_key + + if missing: + log.warning(f"Could not find matches for {len(missing)} indices") + # log.warning(f"Remaining: {missing}") + for k in missing: + cache_key = (k, scope, can_include_self) + cache[cache_key] = None + + return out_map, list(missing) + + def find_by_filename(self, s: Union[Path, str]) -> Iterator[FUNCTION_INDEX_KEY]: + basename = os.path.basename(s) + if basename not in 
self.cached_by_filename: + self.cached_by_filename[basename] = list( + self._find_matching_indices(basename) + ) + + for key in self.cached_by_filename[basename]: + rel_path = self.get_target_container_path(key).relative_to("/") + if str(rel_path).endswith(str(s)) or str(s).endswith(str(rel_path)): + yield key + + def _find_matching_indices(self, s: str) -> Iterator[FUNCTION_INDEX_KEY]: + for key in self.functions_index.keys(): + if s in key: + yield key + return None + + def resolve_with_leniency(self, name: str) -> Iterator[FUNCTION_INDEX_KEY]: + if not name: + return + + if name in self.functions_index: + yield name + return + + if re.fullmatch(r".*:\d+", name): + # okay, we have a filename with a line number, let's just return it + filename, line_number = name.rsplit(":", 1) + line_number = int(line_number) + for key in self.find_by_filename(filename): + start, end = self.get_function_boundary(key) + if start <= line_number <= end: + yield key + return + raise ValueError(f"Could not find any function matching {name}.") + + if name in self.cached_leniency_resolutions: + yield from self.cached_leniency_resolutions[name] + return + + func = list(self.find_by_funcname(name)) + if len(func) >= 1: + yield from func + return + + func = list(self._find_matching_indices(name)) + if len(func) >= 1: + yield from func + return + + if "(" in name: + # okay, try for java to split the path + # import ipdb; ipdb.set_trace() + no_signature = name.rsplit("(", 1)[0] + yield from self.resolve_with_leniency(no_signature) + return + + if "." in name: + # okay, try for java to split the path + # import ipdb; ipdb.set_trace() + class_name, name = name.rsplit(".", 1) + class_name = class_name.replace(".", "/") + ".java" + func_keys = list(self.find_by_filename(class_name)) + func_keys = [k for k in func_keys if self.get_funcname(k) == name] + if len(func_keys) > 0: + yield from func_keys + return + + func_keys = list(self.find_by_funcname(name)) + if len(func_keys) > 0: + yield from func_keys + return + if match := re.fullmatch(r"source:(.+):(\d+):(\d+)::", name): + # it looks like this is supposed to be a function index key. + path_match = match.group(1) + line_match = match.group(2) + possible_keys = list(self.find_by_filename(path_match)) + line_match = int(line_match) + filtered = [ + k + for k in possible_keys + if self.get_function_boundary(k)[0] + <= line_match + <= self.get_function_boundary(k)[1] + ] + if len(filtered) > 0: + yield from filtered + return + # otherwise, it's clearly trying to hallucinate a key. Tell it go kick rocks. + raise ValueError( + f"This looks like a function index key but we could not find any function matching {name}. This function key does not exist, please move on." 
+ ) + + if "::" in name: + # okay, try for cpp try to split the path + # import ipdb; ipdb.set_trace() + class_name, name = name.rsplit("::", 1) + possible_keys = list(self._find_matching_indices(class_name)) + # import ipdb; ipdb.set_trace() + filtered = [ + k + for k in possible_keys + if self.get_funcname(k) == name + and all(sub in k for sub in class_name.split("::")) + ] + if len(filtered) > 0: + yield from filtered + return + if name.upper().startswith("OSS_FUZZ_"): + # okay, try for cpp to split the path + # import ipdb; ipdb.set_trace() + name = name[len("OSS_FUZZ_") :] + func_keys = list(self.find_by_funcname(name)) + if len(func_keys) > 0: + yield from func_keys + return + + raise ValueError(f"Could not find any function matching {name}.") + + +class RemoteFunctionResolver(FunctionResolver): + def __init__(self, cp_name: str, project_id: Union[str, PDT_ID]): + super().__init__() + + self.url = os.getenv("FUNC_RESOLVER_URL", None) + if os.getenv("CRS_TASK_NUM"): + self.url = self.url.replace("TASKNUM", os.getenv("CRS_TASK_NUM")) + else: + if "TASKNUM" in self.url: + raise ValueError( + "Env CRS_TASK_NUM is not set but FUNC_RESOLVER_URL contains TASKNUM" + ) + + if self.url is None: + raise ValueError("FUNC_RESOLVER_URL is not set") + self.cp_name = cp_name + self.project_id = project_id + + self.cached_func_names = {} + self.cached_func_codes = {} + self.cached_focus_repo_relative_paths = {} + self.cached_target_container_paths = {} + self.cached_function_boundaries = {} + self.cached_lines_to_function = {} + self.cached_by_filename = {} + self.cached_by_funcname = {} + self.cached_code_lines = {} + self.cached_matching_indices = {} + self.cached_leniency_resolutions = {} + self.cached_with_annotation = {} + + self.get = lru_cache(maxsize=512)(self._get) + + def is_ready(self) -> bool: + r = requests.get( + f"{self.url}/health", + params={ + "cp_name": self.cp_name, + "project_id": self.project_id, + }, + ) + + if r.status_code != 200: + return False + + result = r.json() + if ( + result.get("status", None) == "error" + and result.get("data", None) == "Server not initialized" + ): + return False + + return True + + def _make_request(self, endpoint: str, data: dict) -> dict: + while True: + r = requests.post(f"{self.url}/{endpoint}", data=data) + if r.status_code != 200: + # These are always critical errors we must fix + assert False, f"Internal Server Error in /{endpoint} : {r.text}" + result = r.json() + if ( + result.get("status", None) == "error" + and result.get("data", None) == "Server not initialized" + ): + log.warning( + f"Function resolver server not initialized, waiting 30 seconds before retrying /{endpoint}" + ) + time.sleep(30) + continue + return result + + def keys(self): + data = {"cp_name": self.cp_name, "project_id": self.project_id} + result = self._make_request("keys", data) + + api_status = result.get("status", None) + assert api_status is not None, f"API status code is None: {result}" + if api_status == "error": + # This means the function was not found in the index. + # Users will have to handle this Exception themselves + raise KeyError(f"Function keys not found in index: {result}") + else: + assert ( + api_status == "success" + ), f"API status code is not success|error: {result}" + # This means the function was found in the index. 
+ # The response should be a dict with the function index + return result.get("data", []) + + def _get(self, key: FUNCTION_INDEX_KEY) -> FunctionIndex: + data = {"cp_name": self.cp_name, "project_id": self.project_id, "key": key} + + result = self._make_request("get", data) + + api_status = result.get("status", None) + assert api_status is not None, f"API status code is None for {key}: {result}" + if api_status == "error": + # This means the function was not found in the index. + # Users will have to handle this Exception themselves + raise KeyError(f"Function {key} not found in index: {result}") + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {key}: {result}" + # This means the function was found in the index. + # The response should be a dict with the function index + return FunctionIndex.model_validate(result.get("data", None)) + + def get_many(self, keys): + # optimized implementation for getting many keys in one request + data = {"cp_name": self.cp_name, "project_id": self.project_id, "keys": keys} + + result = self._make_request("get_many", data) + + api_status = result.get("status", None) + assert api_status is not None, f"API status code is None for {keys}: {r.text}" + if api_status == "error": + # This means the function was not found in the index. + # Users will have to handle this Exception themselves + raise KeyError(f"Function {keys} not found in index: {result}") + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {keys}: {result}" + # This means the function was found in the index. + # The response should be a dict with the function index + return { + key: FunctionIndex.model_validate(value) + for key, value in result.get("data", {}).items() + } + + def get_focus_repo_keys(self, focus_repo_container_path): + data = { + "cp_name": self.cp_name, + "project_id": self.project_id, + "focus_repo_container_path": focus_repo_container_path, + } + + result = self._make_request("get_focus_repo_keys", data) + + api_status = result.get("status", None) + assert ( + api_status is not None + ), f"API status code is None for {focus_repo_container_path}: {result}" + if api_status == "error": + raise KeyError(f"Function keys not found in index: {result}") + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {focus_repo_container_path}: {result}" + # This means the function was found in the index. + return result.get("data", []) + + def get_filtered_keys( + self, key_only_jq_filter_expression: str + ) -> List[FUNCTION_INDEX_KEY]: + """ + Get a filtered list of function indices based on a jq filter expression. The jq filter expression operates directly on the function index key. + """ + data = { + "cp_name": self.cp_name, + "project_id": self.project_id, + "key_only_filter_expression": key_only_jq_filter_expression, + } + + result = self._make_request("get_filtered_keys", data) + + api_status = result.get("status", None) + assert ( + api_status is not None + ), f"API status code is None for {key_only_jq_filter_expression}: {result}" + if api_status == "error": + raise KeyError(f"Function keys not found in index: {result}") + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {key_only_jq_filter_expression}: {result}" + # This means the function was found in the index. 
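+ # The response data is the list of matching function index keys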
+ return result.get("data", []) + + def get_filtered( + self, key_only_filter_expression: str, full_filter_expression: str + ) -> Dict[FUNCTION_INDEX_KEY, FunctionIndex]: + """ + Get a filtered list of function indices based on a jq filter expression. The jq filter expression operates directly on the function index key. + """ + data = { + "cp_name": self.cp_name, + "project_id": self.project_id, + "key_only_filter_expression": key_only_filter_expression, + "full_filter_expression": full_filter_expression, + } + result = self._make_request("get_filtered", data) + + api_status = result.get("status", None) + assert ( + api_status is not None + ), f"API status code is None for {key_only_filter_expression}: {result}" + if api_status == "error": + raise KeyError(f"Function keys not found in index: {result}") + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {key_only_filter_expression}: {result}" + # This means the function was found in the index. + return { + k: FunctionIndex.model_validate(v) + for k, v in result.get("data", {}).items() + } + + def get_funcname(self, key: FUNCTION_INDEX_KEY) -> str: + if key in self.cached_func_names: + return self.cached_func_names[key] + else: + data = {"cp_name": self.cp_name, "project_id": self.project_id, "key": key} + + result = self._make_request("get_funcname", data) + + api_status = result.get("status", None) + assert ( + api_status is not None + ), f"API status code is None for {key}: {result}" + + if api_status == "error": + # This means the function was not found in the index. + # Users will have to handle this Exception themselves + raise KeyError(f"Function name for {key} not found in index: {result}") + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {key}: {result}" + # This means the function was found in the index. + func_name = result.get("data", None) + assert ( + func_name is not None + ), f"Function name is None for {key}: {result}" + + self.cached_func_names[key] = func_name + return self.cached_func_names[key] + + def get_focus_repo_relative_path(self, key: FUNCTION_INDEX_KEY) -> Optional[Path]: + if key in self.cached_focus_repo_relative_paths: + return self.cached_focus_repo_relative_paths[key] + else: + data = {"cp_name": self.cp_name, "project_id": self.project_id, "key": key} + result = self._make_request("get_focus_repo_relative_path", data) + + api_status = result.get("status", None) + assert ( + api_status is not None + ), f"API status code is None for {key}: {result}" + if api_status == "error": + raise KeyError(f"Relative path for {key} not found in index: {result}") + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {key}: {result}" + # This means the function was found in the index. 
+ rel_path = result.get("data", None) + assert rel_path is not None, f"rel_path is None for {key}: {result}" + + # Convert to a tuple RelativePathKind and Path objects + self.cached_focus_repo_relative_paths[key] = ( + RelativePathKind(rel_path[0]), + Path(rel_path[1]), + ) + + return self.cached_focus_repo_relative_paths[key] + + def get_target_container_path(self, key: FUNCTION_INDEX_KEY) -> Path: + if key in self.cached_target_container_paths: + return self.cached_target_container_paths[key] + + data = {"cp_name": self.cp_name, "project_id": self.project_id, "key": key} + result = self._make_request("get_target_container_path", data) + + api_status = result.get("status", None) + assert api_status is not None, f"API status code is None for {key}: {result}" + if api_status == "error": + raise KeyError( + f"Target container path for {key} not found in index: {result}" + ) + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {key}: {result}" + + target_container_path = result.get("data", None) + assert ( + target_container_path is not None + ), f"Target container path is None for {key}: {r.text} {result}" + + self.cached_target_container_paths[key] = Path(target_container_path) + return self.cached_target_container_paths[key] + + def get_code(self, key: FUNCTION_INDEX_KEY) -> Tuple[Path, Path, int, str]: + if key in self.cached_func_codes: + return self.cached_func_codes[key] + else: + data = {"cp_name": self.cp_name, "project_id": self.project_id, "key": key} + + result = self._make_request("get_code", data) + + api_status = result.get("status", None) + assert ( + api_status is not None + ), f"API status code is None for {key}: {result}" + if api_status == "error": + # This means the function was not found in the index. + # Users will have to handle this Exception themselves + raise KeyError(f"Code for {key} not found in index: {result}") + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {key}: {result}" + # This means the function was found in the index. + # The response should be a list of 4 elements: RelativePathKind, Path, start_line, code + code = result.get("data", None) + assert code is not None, f"Code is None for {key}: {result}" + assert ( + len(code) == 4 + ), f"Code is not a list of 4 elements for {key}: {result}" + # Convert to a tuple RelativePathKind and Path objects + self.cached_func_codes[key] = ( + Path(code[0]) if code[0] else None, + Path(code[1]) if code[1] else None, + code[2], + code[3], + ) + return self.cached_func_codes[key] + + def get_function_boundary(self, key: FUNCTION_INDEX_KEY) -> Tuple[int, int]: + if key in self.cached_function_boundaries: + return self.cached_function_boundaries[key] + else: + data = {"cp_name": self.cp_name, "project_id": self.project_id, "key": key} + + result = self._make_request("get_function_boundary", data) + + api_status = result.get("status", None) + assert ( + api_status is not None + ), f"API status code is None for {key}: {result}" + if api_status == "error": + # This means the function was not found in the index. + # Users will have to handle this Exception themselves + raise KeyError( + f"Function boundary for {key} not found in index: {result}" + ) + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {key}: {result}" + # This means the function was found in the index. 
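+ # The response data should be a (start_line, end_line) pair of line numbers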
+ func_boundaries = result.get("data", None) + self.cached_function_boundaries[key] = ( + func_boundaries[0], + func_boundaries[1], + ) + return self.cached_function_boundaries[key] + + def find_functions_with_annotation( + self, annotation: str + ) -> Iterator[FUNCTION_INDEX_KEY]: + if not annotation: + return + + if annotation not in self.cached_with_annotation: + data = { + "cp_name": self.cp_name, + "project_id": self.project_id, + "annotation": annotation, + } + result = self._make_request("find_functions_with_annotation", data) + + api_status = result.get("status", None) + assert ( + api_status is not None + ), f"API status code is None for {annotation}: {result}" + if api_status == "error": + # This means the function was not found in the index. + # Users will have to handle this Exception themselves + raise KeyError( + f"No results for name {annotation} in function index: {result}" + ) + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {annotation}: {result}" + # This means the function was found in the index. + # The response should be a list of function keys + matches = result.get("data", None) + assert ( + matches is not None + ), f"matches is None for {annotation}: {result}" + + self.cached_with_annotation[annotation] = matches + yield from self.cached_with_annotation[annotation] + + def find_by_funcname(self, s: str) -> Iterator[FUNCTION_INDEX_KEY]: + if not s: + return + + if s not in self.cached_by_funcname: + data = {"cp_name": self.cp_name, "project_id": self.project_id, "s": s} + + result = self._make_request("find_by_funcname", data) + + api_status = result.get("status", None) + assert api_status is not None, f"API status code is None for {s}: {result}" + if api_status == "error": + # This means the function was not found in the index. + # Users will have to handle this Exception themselves + raise KeyError(f"No results for name {s} in function index: {result}") + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {s}: {result}" + # This means the function was found in the index. + # The response should be a list of function keys + matches = result.get("data", None) + assert matches is not None, f"matches is None for {s}: {result}" + + self.cached_by_funcname[s] = matches + + yield from self.cached_by_funcname[s] + + def find_matching_indices( + self, + indices: List[FUNCTION_INDEX_KEY], + scope: Literal["all", "focus", "non-focus", "compiled"] = "focus", + can_include_self: bool = True, + can_include_build_generated: bool = True, + ) -> Tuple[Dict[FUNCTION_INDEX_KEY, FUNCTION_INDEX_KEY], List[FUNCTION_INDEX_KEY]]: + if scope == "all": + assert ( + not can_include_self + ), "can_include_self=true on `all` scope will always return self..." 
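+ # Serve what we can from the local cache first; a cached value of None marks an
+ # index the server previously reported as missing, so it is not re-queried.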
+ + cached_values = {} + cached_missing = [] + uncached_indices = [] + for k in indices: + cache_key = (k, scope, can_include_self) + if cache_key in self.cached_matching_indices: + cached_val = self.cached_matching_indices[cache_key] + if cached_val is not None: + cached_values[k] = cached_val + else: + cached_missing.append(k) + else: + uncached_indices.append(k) + + if len(uncached_indices) == 0: + return cached_values, cached_missing + + data = { + "cp_name": self.cp_name, + "project_id": self.project_id, + "indices": indices, + "scope": scope, + "can_include_self": can_include_self, + } + + result = self._make_request("find_matching_indices", data) + + api_status = result.get("status", None) + assert ( + api_status is not None + ), f"API status code is None for {len(indices)} strings: {result}" + if api_status == "error": + raise Exception( + f"Error in find_matching_indices for {len(indices)} strings: {result}" + ) + + assert ( + api_status == "success" + ), f"API status code is not success|error for {len(indices)} strings: {result}" + + # The response should be a list of function keys + matches = result.get("matching", None) + for k, v in matches.items(): + self.cached_matching_indices[k] = v + + missing = result.get("missing", None) + for k in missing: + self.cached_matching_indices[k] = None + + # combine the previous cached values with the new ones + cached_values.update(matches) + cached_missing.extend(missing) + + return cached_values, cached_missing + + def find_by_filename(self, s: Union[Path, str]) -> Iterator[FUNCTION_INDEX_KEY]: + if not s: + return + + if s not in self.cached_by_filename: + data = {"cp_name": self.cp_name, "project_id": self.project_id, "s": str(s)} + result = self._make_request("find_by_filename", data) + + api_status = result.get("status", None) + assert api_status is not None, f"API status code is None for {s}: {result}" + if api_status == "error": + # This means the function was not found in the index. + # Users will have to handle this Exception themselves + raise KeyError( + f"No results for filename {s} in function index: {result}" + ) + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {s}: {result}" + # This means the function was found in the index. + # The response should be a list of function keys + matches = result.get("data", None) + assert matches is not None, f"matches is None for {s}: {result}" + + self.cached_by_filename[s] = matches + + yield from self.cached_by_filename[s] + + def resolve_with_leniency(self, name: str) -> Iterator[FUNCTION_INDEX_KEY]: + if not name: + return + + if name not in self.cached_leniency_resolutions: + data = { + "cp_name": self.cp_name, + "project_id": self.project_id, + "name": name, + } + + result = self._make_request("resolve_with_leniency", data) + + api_status = result.get("status", None) + assert ( + api_status is not None + ), f"API status code is None for {name}: {result}" + if api_status == "error": + # This means the function was not found in the index. + # Users will have to handle this Exception themselves + raise KeyError( + f"No results for name {name} in function index: {result}" + ) + else: + assert ( + api_status == "success" + ), f"API status code is not success|error for {name}: {result}" + # This means the function was found in the index. 
+ # The response should be a list of function keys + matches = result.get("data", None) + assert matches is not None, f"matches is None for {name}: {result}" + self.cached_leniency_resolutions[name] = matches + + yield from self.cached_leniency_resolutions[name] + + def upload( + self, full_functions_index_path: Path, full_functions_index_jsons_dir: Path + ): + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = Path(temp_dir) + (temp_dir / "functions_index").mkdir() + (temp_dir / "functions_jsons").mkdir() + shutil.copy(full_functions_index_path, temp_dir / self.project_id) + subprocess.run( + ["tar", "-cvf", "functions_index/functions_index.tar", self.project_id], + check=True, + cwd=temp_dir, + ) + os.unlink(temp_dir / self.project_id) + + shutil.copytree(full_functions_index_jsons_dir, temp_dir / self.project_id) + subprocess.run( + [ + "tar", + "-cvf", + os.path.join(temp_dir, "functions_jsons/functions_jsons.tar"), + ".", + ], + check=True, + cwd=temp_dir / self.project_id, + ) + shutil.rmtree(temp_dir / self.project_id) + + subprocess.run( + [ + "tar", + "-cvf", + "data.tar", + "functions_index/functions_index.tar", + "functions_jsons/functions_jsons.tar", + ], + check=True, + cwd=temp_dir, + ) + + data = { + "cp_name": self.cp_name, + "project_id": self.project_id, + } + with open(temp_dir / "data.tar", "rb") as f: + response = requests.post( + f"{self.url}/init_server", data=data, files={"data": f} + ) + response.raise_for_status() + return response.json() + + +def function_resolver_upload(): + import argparse + + parser = argparse.ArgumentParser( + description="Upload function index to the function resolver server" + ) + parser.add_argument( + "project_name", type=str, help="Name of the project (aka cp_name)" + ) + parser.add_argument("project_id", type=str, help="Project ID") + parser.add_argument( + "full_functions_index_path", type=Path, help="Path to the full functions index" + ) + parser.add_argument( + "full_functions_index_jsons_dir", + type=Path, + help="Path to the full functions index jsons dir", + ) + args = parser.parse_args() + + resolver = RemoteFunctionResolver(args.project_name, args.project_id) + result = resolver.upload( + args.full_functions_index_path, args.full_functions_index_jsons_dir + ) + print(result) + + +def function_resolver_upload_backup(): + import argparse + + parser = argparse.ArgumentParser( + description="Upload function index to the function resolver server" + ) + parser.add_argument("backup_dir", type=Path, help="Path to the backup directory") + args = parser.parse_args() + + project_ids = [] + for f in os.listdir( + args.backup_dir / "generate_full_function_index.target_functions_index" + ): + assert ( + f.split(".")[0] == f + ), f"the type of target_functions_index has changed: {f}" + project_ids.append(f.split(".")[0]) + + for project_id in project_ids: + with tempfile.TemporaryDirectory() as tempdir: + if not os.path.isdir( + args.backup_dir + / "generate_full_function_index.target_functions_jsons_dir" + / project_id + ): + tar_path = ( + args.backup_dir + / "generate_full_function_index.target_functions_jsons_dir" + / f"{project_id}.tar.gz" + ) + # extract the tar + subprocess.check_call( + ["tar", "-xvf", tar_path], + cwd=tempdir, + ) + else: + subprocess.check_call( + [ + "rsync", + "-ra", + str( + args.backup_dir + / "generate_full_function_index.target_functions_jsons_dir" + / project_id + ) + + "/", + tempdir, + ], + ) + + full_functions_index_jsons_dir = Path(tempdir) + full_functions_index_path = ( + args.backup_dir + / 
"generate_full_function_index.target_functions_index" + / project_id + ) + crs_task = ( + args.backup_dir + / "generate_full_function_index.crs_task" + / f"{project_id}.yaml" + ) + with open(crs_task, "r") as f: + project_name = yaml.safe_load(f)["project_name"] + resolver = RemoteFunctionResolver(project_name, project_id) + result = resolver.upload( + full_functions_index_path, full_functions_index_jsons_dir + ) + print(result) diff --git a/patchery/data/models/__init__.py b/patchery/data/models/__init__.py new file mode 100644 index 0000000..c0ad581 --- /dev/null +++ b/patchery/data/models/__init__.py @@ -0,0 +1,6 @@ +from .base import ShellphishBaseModel +from .indexer import * +from .symbols import * +from .crash_reports import * +from .target import * +from .crs_reports import * diff --git a/patchery/data/models/aixcc_api.py b/patchery/data/models/aixcc_api.py new file mode 100644 index 0000000..9ef8adb --- /dev/null +++ b/patchery/data/models/aixcc_api.py @@ -0,0 +1,376 @@ +# 99e6bb03e45b2717208d1467162f521f1993afa1 +# AUTOGENERATED BY afc-api-schema, but you can change the descriptions if you like + +from typing import Optional, List, Dict, Any +from pydantic import Field, AliasChoices +from shellphish_crs_utils.models.base import ShellphishBaseModel +from shellphish_crs_utils.models.constraints import PDT_ID +from uuid import UUID +from enum import Enum + + +class SourceType(str, Enum): + """Type of source being provided""" + + SourceTypeRepo = "repo" + SourceTypeFuzzTooling = "fuzz-tooling" + SourceTypeDiff = "diff" + + +class TaskType(str, Enum): + """Type of task being requested""" + + TaskTypeFull = "full" + TaskTypeDelta = "delta" + + +class Assessment(str, Enum): + """Assessment of a SARIF report""" + + AssessmentCorrect = "correct" + AssessmentIncorrect = "incorrect" + + +class Architecture(str, Enum): + """Target architecture""" + + ArchitectureX8664 = "x86_64" + + +class SubmissionStatus(str, Enum): + """Status of a submission""" + + SubmissionStatusAccepted = "accepted" # Successfully submitted + SubmissionStatusPassed = "passed" # Successfully evaluated submission + SubmissionStatusFailed = "failed" # Submission failed testing + SubmissionStatusDeadlineExceeded = "deadline_exceeded" # Task deadline exceeded. All submissions marked accepted before the deadline will be evaluated. 
+ SubmissionStatusErrored = "errored" # Server side error when testing submission + SubmissionInconclusive = "inconclusive" # Submission was inconclusive + + +class SourceDetail(ShellphishBaseModel): + """Details about a source to analyze""" + + sha256: str = Field(description="Integrity hash of the gzipped tarball") + type: SourceType = Field(description="Type of the source") + url: str = Field(description="URL to fetch the source gzipped tarball") + + +class StatusTasksState(ShellphishBaseModel): + """State counts for tasks""" + + canceled: int = Field( + description="Number of tasks that competition infrastructure has cancelled" + ) + errored: int = Field( + description="Number of tasks that the CRS encountered an unrecoverable issue for" + ) + failed: int = Field( + description="Number of submissions that the competition infrastructure marked failed" + ) + pending: int = Field( + description="Number of tasks that the CRS has not started work on" + ) + processing: int = Field( + description="Number of tasks that the CRS is currently processing" + ) + succeeded: int = Field( + description="Number of submissions that the competition infrastructure marked passed" + ) + waiting: int = Field( + description="Number of submissions that the competition infrastructure is currently testing" + ) + + +class StatusState(ShellphishBaseModel): + """Overall state information""" + + tasks: StatusTasksState = Field(description="Task state information") + + +class Status(ShellphishBaseModel): + """Status of the CRS""" + + details: Optional[Dict[str, str]] = Field( + description="This is optional arbitrary content that may be logged in error cases, but is mainly for interactive troubleshooting.\n\nKeep in mind this endpoint is unauthenticated. Do not place sensitive details in this object.", + default=None, + ) + ready: bool = Field( + description="Boolean indicating if the CRS is prepared to work on tasks. Do not return true unless you have successfully tested connectivity to the Competition API via /v1/ping/" + ) + since: int = Field(description="Last time task and submission stats were reset") + state: StatusState = Field( + description="State of the currently running tasks and submissions" + ) + version: str = Field( + description="Version string for verification and reproducibility\n\n- git commit\n\n- SemVer\n\n- etc" + ) + + +class TaskDetail(ShellphishBaseModel): + """Details of a specific task""" + + deadline: int = Field( + description="UNIX timestamp by which any submissions for this task must be in" + ) + source: List[SourceDetail] = Field( + description="List of sources needed to evaluate a task" + ) + task_id: UUID = Field( + description="Unique identifier for the task", + validation_alias=AliasChoices("task.id", "task_id"), + ) + type: TaskType = Field(description="Type of task") + focus: str = Field( + description="Because the challenge task may contain multiple repositories, the folder in the type repo source tarball containing the main project.\nIt is still set when there is only one repository.\n\nThis is the project the CRS is meant to submit vulns, patches, and SARIF assessments against." 
+ ) + project_name: str = Field(description="OSS Fuzz project name") + metadata: Dict[str, str] = Field( + description="String to string map containing data that should be attached to outputs like log messages and OpenTelemetry trace attributes for traceability" + ) + harnesses_included: bool = Field(description="Indicates if harnesses are included") + + +class Task(ShellphishBaseModel): + """Task message from DARPA""" + + message_id: UUID = Field(description="Unique identifier for the message") + message_time: int = Field(description="Unix timestamp of message creation") + tasks: List[TaskDetail] = Field(description="List of tasks to perform") + + +class SARIFBroadcastDetail(ShellphishBaseModel): + """Details of a SARIF broadcast""" + + metadata: Dict[str, str] = Field( + description="String to string map containing data that should be attached to outputs like log messages and OpenTelemetry trace attributes for traceability" + ) + sarif: Dict[str, Any] = Field( + description="SARIF Report compliant with provided schema" + ) + sarif_id: UUID = Field(description="Unique identifier for the SARIF report") + task_id: UUID = Field(description="ID of task this SARIF report was generated for") + + +class SARIFMetadata(ShellphishBaseModel): + """Metadata for a SARIF report""" + + metadata: Dict[str, str] = Field( + description="String to string map containing data that should be attached to outputs like log messages and OpenTelemetry trace attributes for traceability" + ) + sarif_id: UUID = Field(description="Unique identifier for the SARIF report") + task_id: UUID = Field(description="ID of task this SARIF report was generated for") + pdt_sarif_id: PDT_ID = Field(description="PDT identifier for the SARIF report") + pdt_task_id: PDT_ID = Field(description="PDT identifier for the task") + assessment: Assessment = Field( + description="Assessment verdict", default=Assessment.AssessmentCorrect + ) + description: Optional[str] = Field( + description="Optional plain text reasoning for the assessment\n\n128KiB max size", + default="", + ) + + +class SARIFBroadcast(ShellphishBaseModel): + """SARIF broadcast message""" + + broadcasts: List[SARIFBroadcastDetail] = Field( + description="List of SARIF broadcasts" + ) + message_id: UUID = Field(description="Unique identifier for the message") + message_time: int = Field(description="Unix timestamp of message creation") + + +class Error(ShellphishBaseModel): + """Error response""" + + fields: Optional[Dict[str, str]] = Field( + description="Field-specific error messages", default=None + ) + message: str = Field(description="Error message") + + +class PatchSubmission(ShellphishBaseModel): + """Patch submission for a vulnerability""" + + description: Optional[str] = Field( + description="Optional plain text reasoning for the assessment\n\n128KiB max size", + max_length=131072, + default=None, + ) + patch: str = Field( + description="Base64 encoded patch in unified diff format\n\n100KiB max size before Base64 encoding", + json_schema_extra={"format": "base64"}, + ) + sarif_id: Optional[UUID] = Field( + description="Optional ID of SARIF Broadcast this patch is associated with", + default=None, + ) + vuln_id: Optional[UUID] = Field( + description="Optional ID of Vuln this patch is associated with", default=None + ) + + +class PatchSubmissionResponse(ShellphishBaseModel): + """Response to a patch submission""" + + patch_id: UUID = Field(description="Unique identifier for the patch") + status: SubmissionStatus = Field(description="Status of the submission") + 
functionality_tests_passing: Optional[bool] = Field( + description="null indicates the tests have not been run", default=None + ) + project_id: PDT_ID | None = Field( + description="PDT identifier for the project", default=None + ) + + +class SarifAssessmentSubmission(ShellphishBaseModel): + """Assessment submission for a SARIF report""" + + assessment: Assessment = Field(description="Assessment verdict") + description: str = Field( + description="Plain text reasoning for the assessment\n\n128KiB max size", + max_length=131072, + ) + + +class SarifAssessmentResponse(ShellphishBaseModel): + """Response to a SARIF assessment submission""" + + status: SubmissionStatus = Field(description="Status of the assessment submission") + project_id: PDT_ID | None = Field( + description="PDT identifier for the project", default=None + ) + + +class ExtendedSarifAssessmentResponse(SarifAssessmentResponse): + assessment: Assessment = Field(description="Assessment verdict") + + +class POVSubmission(ShellphishBaseModel): + """POV submission""" + + architecture: Architecture = Field(description="Target architecture") + testcase: str = Field( + description="Base64 encoded vuln trigger\n\n2MiB max size before Base64 encoding", + json_schema_extra={"format": "base64"}, + ) + fuzzer_name: str = Field( + description="Fuzz Tooling fuzzer that exercises this vuln\n\n4KiB max size", + max_length=4096, + ) + sanitizer: str = Field( + description="Fuzz Tooling Sanitizer that exercises this vuln\n\n4KiB max size", + max_length=4096, + ) + engine: str = Field( + description="Fuzz Tooling Engine that exercises this vuln. Allowable engine values are specified in project.yaml.\n\n4KiB max size", + max_length=4096, + ) + + +class POVSubmissionResponse(ShellphishBaseModel): + """Response to a POV submission""" + + status: SubmissionStatus = Field(description="Status of the submission") + pov_id: UUID = Field(description="Unique identifier for the POV") + project_id: PDT_ID | None = Field( + description="PDT identifier for the project", default=None + ) + + +class SARIFSubmission(ShellphishBaseModel): + """SARIF submission""" + + sarif: Dict[str, Any] = Field( + description="SARIF object compliant with the provided schema" + ) + + +class SARIFSubmissionResponse(ShellphishBaseModel): + """Response to a SARIF submission""" + + status: SubmissionStatus = Field(description="Status of the submission") + submitted_sarif_id: UUID = Field( + description="Unique identifier for the submitted SARIF" + ) + project_id: PDT_ID | None = Field( + description="PDT identifier for the project", default=None + ) + + +class FreeformSubmission(ShellphishBaseModel): + """Freeform submission""" + + submission: str = Field( + description="Base64 encoded arbitrary data\n\n2MiB max size before Base64 encoding" + ) + + +class FreeformResponse(ShellphishBaseModel): + """Response to a freeform submission""" + + freeform_id: UUID = Field( + description="Unique identifier for the freeform submission" + ) + status: SubmissionStatus = Field(description="Status of the submission") + + +class BundleSubmission(ShellphishBaseModel): + """Bundle submission""" + + broadcast_sarif_id: Optional[UUID] = Field( + description="ID of the broadcast SARIF", default=None + ) + description: Optional[str] = Field( + description="optional plaintext description of the components of the bundle, such as would be found in a pull request description or bug report", + default=None, + ) + patch_id: Optional[UUID] = Field(description="ID of the patch", default=None) + pov_id: 
Optional[UUID] = Field(description="ID of the POV", default=None) + submitted_sarif_id: Optional[UUID] = Field( + description="ID of the submitted SARIF", default=None + ) + freeform_id: Optional[UUID] = Field( + description="ID of the freeform submission", default=None + ) + + +class BundleSubmissionResponse(ShellphishBaseModel): + """Response to a bundle submission""" + + bundle_id: UUID = Field(description="Unique identifier for the bundle") + status: SubmissionStatus = Field(description="Status of the submission") + project_id: PDT_ID | None = Field( + description="PDT identifier for the project", default=None + ) + + +class BundleSubmissionResponseVerbose(ShellphishBaseModel): + """Verbose response to a bundle submission""" + + broadcast_sarif_id: Optional[UUID] = Field( + description="ID of the broadcast SARIF", default=None + ) + bundle_id: UUID = Field(description="Unique identifier for the bundle") + description: Optional[str] = Field( + description="Description of the bundle", default=None + ) + patch_id: Optional[UUID] = Field(description="ID of the patch", default=None) + pov_id: Optional[UUID] = Field(description="ID of the POV", default=None) + status: SubmissionStatus = Field(description="Status of the submission") + submitted_sarif_id: Optional[UUID] = Field( + description="ID of the submitted SARIF", default=None + ) + freeform_id: Optional[UUID] = Field( + description="ID of the freeform submission", default=None + ) + project_id: PDT_ID | None = Field( + description="PDT identifier for the project", default=None + ) + + +class PingResponse(ShellphishBaseModel): + """Response to a ping request""" + + status: str = Field(description="Status of the ping") diff --git a/patchery/data/models/base.py b/patchery/data/models/base.py new file mode 100644 index 0000000..18a7e07 --- /dev/null +++ b/patchery/data/models/base.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel, ConfigDict + + +class ShellphishBaseModel(BaseModel): + # THIS IS CRITICAL, do not remove without checking with @clasm, @honululu or @Amy-B. + # If you are overriding this, make sure to set extra='forbid'. But still check with us. 
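+ # extra="forbid" makes pydantic raise a ValidationError for any unknown field
+ # passed to a model, instead of silently ignoring it.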
+ + model_config = ConfigDict(extra="forbid") diff --git a/patchery/data/models/constraints.py b/patchery/data/models/constraints.py new file mode 100644 index 0000000..7895336 --- /dev/null +++ b/patchery/data/models/constraints.py @@ -0,0 +1,23 @@ +from typing import Union +from pydantic import StringConstraints + +PDT_ID = Union[str, int] + +ID_REGEX = r"^id_[0-9]+$" +ID_CONSTRAINTS = StringConstraints(strip_whitespace=True, pattern=ID_REGEX) + +SHA1_REGEX = r"[0-9a-f]{40}" +SHA1_CONSTRAINTS = StringConstraints( + strip_whitespace=True, + pattern=SHA1_REGEX, + max_length=40, + min_length=40, +) + +MD5_REGEX = r"[0-9a-f]{32}" +MD5_CONSTRAINTS = StringConstraints( + strip_whitespace=True, + pattern=MD5_REGEX, + max_length=32, + min_length=32, +) diff --git a/patchery/data/models/coverage.py b/patchery/data/models/coverage.py new file mode 100644 index 0000000..410cd60 --- /dev/null +++ b/patchery/data/models/coverage.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Dict, List, Optional, Tuple, TypeAlias, Union +from patchery.data.models.indexer import FUNCTION_INDEX_KEY + + +# THIS IS EXTREMELY PERFORMANCE CRITICAL for coverage collection +# DO NOT change this to a pydantic model, it doesn't benefit much and it's SIGNIFICANTLY slower +@dataclass +class CoverageLine: + line_number: int + count_covered: Optional[int] = None + code: Optional[str] = None + + @property + def count(self): + return self.count_covered + + @count.setter + def count(self, value): + self.count_covered = value + + def can_be_covered(self): + return self.count_covered is not None + + def as_tuple(self): + return (self.line_number, self.count_covered, self.code) + + +class SeedCoverageExitStatus(Enum): + CRASH = "crash" + TIMEOUT = "timeout" + SUCCESS = "success" + UNKNOWN = "unknown" + + +LinesCoverage: TypeAlias = List[CoverageLine] +FileCoverage: TypeAlias = Tuple[Union[Path, str], LinesCoverage] +FileCoverageMap: TypeAlias = Dict[Union[Path, str], LinesCoverage] +FunctionCoverage: TypeAlias = Tuple[FUNCTION_INDEX_KEY, LinesCoverage] +FunctionCoverageMap: TypeAlias = Dict[FUNCTION_INDEX_KEY, LinesCoverage] diff --git a/patchery/data/models/crash_reports.py b/patchery/data/models/crash_reports.py new file mode 100644 index 0000000..f12e40a --- /dev/null +++ b/patchery/data/models/crash_reports.py @@ -0,0 +1,511 @@ +from enum import Enum +import logging +from pathlib import Path +import re +from typing import Any, Dict, List, Optional +from pydantic import Field, model_validator +from patchery.data.function_resolver import FunctionResolver +from patchery.data.models import JavaInfo +from patchery.data.models.base import ShellphishBaseModel +from patchery.data.models.symbols import BinaryLocation, SourceLocation +from patchery.utils import artiphishell_should_fail_on_error + +log = logging.getLogger(__name__) + + +class BacktraceType(Enum): + source = "source" + binary = "binary" + entrypoint = "entrypoint" + instrumentation = "instrumentation" + sanitizer_instrumentation = "sanitizer_instrumentation" + unknown = "unknown" + + +class CallTraceEntry(ShellphishBaseModel): + depth: int + type: BacktraceType = Field( + description="The type of the backtrace entry. Describes what this entry refers to (as best we can tell)." 
+ ) + + trace_line: Optional[str] = Field( + description="Either the function signature or the error line (Not currently used)" + ) + + source_location: Optional[SourceLocation] = None + binary_location: Optional[BinaryLocation] = None + + def enhance_with_function_resolver( + self, + function_resolver: FunctionResolver, + ): + if function_resolver is None: + log.warning("No function resolver provided, skipping enhancement") + return + try: + source_location = None + log.info("##### Enhancing call trace entry: %s", self) + if self.source_location: + source_location = self.source_location + elif self.binary_location: + func_name = self.binary_location.function_name + source_location = SourceLocation( + function_name=func_name, + java_info=JavaInfo( + package=self.binary_location.package, + ) + if self.binary_location.package + else None, + ) + else: + log.warning( + "No source or binary location found for call trace entry: %s", self + ) + return self + + try: + results = function_resolver.resolve_source_location( + source_location, num_top_matches=3, allow_build_generated=False + ) + except Exception: + log.error( + "Error resolving source location: %s", + source_location, + exc_info=True, + ) + if artiphishell_should_fail_on_error(): + raise + return + + if not results: + return + + log.error( + "Results found for source location: loc=%s: results=%s", + source_location, + results, + ) + for key, ranking in results: + log.info("Ranking: %s: %s", key, ranking) + try: + focus_repo_keys = [ + x[0] + for x in results + if function_resolver.get(x[0]).focus_repo_relative_path + ] + except Exception: + log.error("Error getting focus repo keys: %s", results, exc_info=True) + if artiphishell_should_fail_on_error(): + raise + return + + if len(focus_repo_keys) > 1: + # WARNING: multiple focus repo paths found for function name + log.warning( + "Multiple focus repo paths found for function name: %s => %s", + source_location.function_name, + results, + ) + log.warning("Picking the best one, but this is sus") + if artiphishell_should_fail_on_error(): + raise ValueError( + "Multiple focus repo paths found for function name: %s => %s" + % (source_location.function_name, results) + ) + key = None + if focus_repo_keys: + key = focus_repo_keys[0] + + if not key and len(results) >= 2: + # WARNING: multiple results found for function name + log.warning( + "Multiple results found for function name: %s => %s", + source_location.function_name, + results, + ) + log.warning("Picking the best one, but this is sus") + if artiphishell_should_fail_on_error(): + raise ValueError( + "Multiple results found for function name: %s => %s" + % (source_location.function_name, results) + ) + + if not key and len(results): + key, rankings = results[0] + log.info( + f"Lookup result function index for {source_location=} {key=} {rankings=}" + ) + + log.info(f"Lookup result function index for {source_location=} {key=}") + func_index = function_resolver.get_with_default(key, default=None) + + if func_index is not None: + log.info("Found function index: %s", func_index) + + line_no = ( + source_location.line_number + if ( + source_location.line_number is not None + and ( + func_index.start_line + <= source_location.line_number + <= func_index.end_line + ) + ) + else None + ) + line_map = { + func_index.start_line + i: code_line + for i, code_line in enumerate(func_index.code.split("\n")) + } + source_location = SourceLocation( + focus_repo_relative_path=func_index.focus_repo_relative_path, + 
relative_path=func_index.focus_repo_relative_path, + full_file_path=func_index.target_container_path, + file_name=Path(func_index.target_container_path.name) + if func_index.target_container_path + else None, + function_name=func_index.funcname, + line_number=line_no, + line_text=line_map[line_no] if line_no in line_map else None, + symbol_offset=self.source_location.symbol_offset + if self.source_location + else ( + self.binary_location.symbol_offset + if self.binary_location + else None + ), + symbol_size=self.source_location.symbol_size + if self.source_location + else ( + self.binary_location.symbol_size + if self.binary_location + else None + ), + raw_signature=func_index.signature, + function_index_signature=key, + function_index_key=key, + java_info=( + self.source_location.java_info if self.source_location else None + ) + or JavaInfo( + package=func_index.package, + class_name=func_index.class_name, + ), + ) + self.source_location = source_location + if self.binary_location: + self.binary_location.function_index_key = key + self.binary_location.function_index_signature = key + except Exception: + log.error("Error enhancing call trace entry: %s", self, exc_info=True) + if artiphishell_should_fail_on_error(): + raise + + +class CallTrace(ShellphishBaseModel): + reason: Optional[str] = Field( + description="The reason for the call trace, normally a description of the crash" + ) + dedup_token: Optional[str] = Field( + default=None, description="A token that can be used to deduplicate call traces" + ) + call_locations: List[CallTraceEntry] = Field( + description="The call locations of the crash", default_factory=list + ) + + def get_dedup_token(self): + if self.dedup_token: + return self.dedup_token + + return self.get_dedup_token_oss_fuzz() + + def get_dedup_token_full(self): + dedup_vals = [] + for entry in self.call_locations: + if entry.source_location: + if entry.source_location.function_name: + dedup_vals.append(str(entry.source_location.function_name)) + elif entry.source_location.file_name: + dedup_vals.append(str(entry.source_location.file_name)) + else: + dedup_vals.append(f"UNKNOWN_{entry.type}") + elif entry.binary_location: + if entry.binary_location.function_name: + dedup_vals.append(str(entry.binary_location.function_name)) + elif entry.binary_location.file_name: + dedup_vals.append(str(entry.binary_location.file_name)) + else: + dedup_vals.append(f"UNKNOWN_{entry.type}") + else: + dedup_vals.append(f"UNKNOWN_{entry.type}") + return "--".join(dedup_vals) + + def get_dedup_token_oss_fuzz(self, num_entries=3): + dedup_vals = [] + for entry in self.call_locations[:num_entries]: + if entry.source_location and entry.source_location.function_name: + dedup_vals.append(entry.source_location.function_name) + elif entry.binary_location and entry.binary_location.function_name: + dedup_vals.append(entry.binary_location.function_name) + else: + dedup_vals.append(f"UNKNOWN_LOCATION_{entry.type}_{entry.trace_line}") + return "--".join(dedup_vals) + + def get_dedup_token_shellphish(self, num_entries=3): + dedup_vals = [] + i = 0 + while len(dedup_vals) < num_entries and i < len(self.call_locations): + entry = self.call_locations[i] + if entry.source_location and ( + entry.source_location.focus_repo_relative_path + or entry.source_location.function_index_signature + ): + dedup_vals.append(entry.source_location.function_name) + i += 1 + return "--".join(dedup_vals) + + +class LosanSanitizerEnum(str, Enum): + OSCommandInjection = "OSCommandInjection" + SQLInjection = "SQLInjection" + 
FilePathTraversal = "FilePathTraversal" + ScriptEngineInjection = "ScriptEngineInjection" + ServerSideRequestForgery = "ServerSideRequestForgery" + ServerSideTemplateInjection = "ServerSideTemplateInjection" + DeserializationVulnerability = "DeserializationVulnerability" + ExpressionLanguageInjection = "ExpressionLanguageInjection" + + +class LoSanMetaData(ShellphishBaseModel): + sanitizer_type: LosanSanitizerEnum = Field( + description="The type of losan sanitizer crash that was triggered", + required=True, + ) + found_string: Optional[bytes] = None + expected_string: Optional[bytes] = None + + def clean_invalid_utf8(self): + if self.found_string is not None and isinstance(self.found_string, bytes): + self.found_string = self.found_string.decode( + "utf-8", errors="ignore" + ).encode("utf-8") + if self.expected_string is not None and isinstance(self.expected_string, bytes): + self.expected_string = self.expected_string.decode( + "utf-8", errors="ignore" + ).encode("utf-8") + + def description(self) -> str: + return f"The {self.sanitizer_type} sanitizer expected to find `{self.expected_string!r}` but found `{self.found_string!r}`" + + +class SanitizerReport(ShellphishBaseModel): + raw_report: bytes + summary: str + crash_type: str + sanitizer: str + internal_crash_type: Optional[str] = None + stack_traces: Dict[str, CallTrace] + crash_info: Dict[str, Any] + extra_context: Optional[str] = None + losan: bool = False + losan_metadata: Optional[LoSanMetaData] = None + + def clean_invalid_utf8(self): + if self.raw_report is not None and isinstance(self.raw_report, bytes): + self.raw_report = self.raw_report.decode("utf-8", errors="ignore").encode( + "utf-8" + ) + if self.losan_metadata is not None: + self.losan_metadata.clean_invalid_utf8() + + @model_validator(mode="after") + def sanity_check_model(self) -> "SourceLocation": + # print(self) + # import ipdb; ipdb.set_trace() + assert self.losan == ( + self.losan_metadata is not None + ), "If losan is set, losan_metadata must be set" + assert self.losan == ( + b"[LOSAN]" in self.raw_report + ), "losan should be set if and only if the report contains [LOSAN]" + if self.losan: + # if there's a losan report, we should *always* have a stack trace + assert self.stack_traces, "If losan is set, we should have stack traces" + assert ( + "[LOSAN]" in self.summary + ), "If losan is set, we should have [LOSAN] in the summary" + + return self + + @property + def final_crash_type(self): + if self.internal_crash_type == None: + return self.crash_type + return self.internal_crash_type + + @property + def final_sanitizer_type(self): + return self.sanitizer + ": " + self.crash_type + + @classmethod + def from_asan_report(cls, report: bytes): + return cls( + raw_report=report, + summary="", + crash_type="UNKNOWN", + sanitizer="asan", + stack_traces={}, + crash_info={}, + ) + + def enhance_with_function_resolver(self, function_resolver: FunctionResolver): + if not function_resolver: + log.warning("No function resolver provided, skipping enhancement") + return + assert ( + function_resolver is not None + ), "Function resolver is required to enhance sanitizer report" + log.info("Enhancing sanitizer report: %s", self) + for stack_trace_name, stack_trace in self.stack_traces.items(): + for cte in stack_trace.call_locations: + try: + cte.enhance_with_function_resolver(function_resolver) + except Exception as e: + log.error( + "Error enhancing call trace entry: %s", cte, exc_info=True + ) + log.error("Error: %s", e) + if artiphishell_should_fail_on_error(): + raise + + 
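For illustration, a minimal sketch (not part of the patch) of how the SanitizerReport helpers above behave for a bare ASan report; the raw report bytes are made up, and the expected values follow from the defaults that from_asan_report sets.

from patchery.data.models.crash_reports import SanitizerReport

# from_asan_report fills placeholder fields: summary="", crash_type="UNKNOWN", sanitizer="asan".
report = SanitizerReport.from_asan_report(
    b"==1234==ERROR: AddressSanitizer: heap-buffer-overflow ..."
)

# internal_crash_type is unset, so final_crash_type falls back to crash_type,
# and final_sanitizer_type prefixes it with the sanitizer name.
assert report.final_crash_type == "UNKNOWN"
assert report.final_sanitizer_type == "asan: UNKNOWN"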
+def clean_java_report_str(report: str): + report_start = report.split("\n")[0] + report_start = report_start.split(":")[:3] + report_start = ":".join(report_start) + report = report_start + "\n" + "\n".join(report.split("\n", 1)[1:]) + report = re.sub(r"collections took \d+m?s", "collections took ", report) + report = re.sub(r"\(use .* to reproduce\)", "", report).strip() + report = re.sub(r"File path traversal: .*", "File path traversal", report) + report = re.sub( + r'FOUND:\s+"(.*?)"\s+and\s+EXPECTED:\s+"(.*?)"', + "FOUND: and EXPECTED: ", + report, + ) + report = re.sub(r"-Xmx\d+[mMgGkK]", "-Xmx", report) + report = re.sub( + r'java.io.IOException: Cannot run program ".*', + "java.io.IOException: Cannot run program", + report, + ) + report = re.sub( + r"PS Scavenge: \d+ collections took", + "PS Scavenge: collections took", + report, + ) + report = re.sub( + r"Attempted connection to: .*", "Attempted connection to: ", report + ) + report = re.sub( + r"Index \d+ out of bounds for length \d+", + "Index out of bounds for length ", + report, + ) + report = re.sub(r" \d+ seconds", " seconds", report) + return report + + +def clean_report(report: bytes): + # first, find the line starting with "SUMMARY: " and remove every line after that + if b"SUMMARY:" in report: + report, rest = report.split(b"SUMMARY:", 1) + report += b"SUMMARY:" + rest.split(b"\n", 1)[0] + + # if 'libFuzzer: timeout' is in the report, replace the entire report since the stack-trace is non-deterministically ordered in the report. + if b"libFuzzer: timeout" in report: + return b"\nSUMMARY: libFuzzer: timeout" + + report = report.decode("utf-8", errors="ignore") + report = re.sub(r"\+0x[0-9a-fA-F]*", "+0x", report) + report = re.sub(r"==\d+==", "==MARKER==", report) + report = re.sub(r"of size \d+", "of size ", report) + report = re.sub(r"is located \d+ bytes", "is located bytes", report) + report = re.sub(r"inside of \d+-byte", "inside of -byte", report) + report = re.sub(r"after \d+-byte region", "after -byte region", report) + report = re.sub(r"SCARINESS: \d+", "SCARINESS: ", report) + report = re.sub(r"\d+-byte-write", "-byte-write", report) + report = re.sub(r"\d+-byte-read", "-byte-read", report) + report = re.sub(r"multi-byte-write", "multi-byte-write", report) + report = re.sub(r"multi-byte-read", "multi-byte-read", report) + report = re.sub(r"0x[0-9a-fA-F]{8,}", "0x", report) + report = re.sub(r" is ascii string \'[^\']+\'", "", report) + report = re.sub(r"0x[0-9a-fA-F]{8,}", "0x", report) + + report = re.sub( + r"SHELLPHISH_FOUND_LOSAN: \"(.*?)\"", + 'SHELLPHISH_FOUND_LOSAN: ""', + report, + ) + report = re.sub( + r"SHELLPHISH_EXPECTED_LOSAN: \"(.*?)\"", + 'SHELLPHISH_EXPECTED_LOSAN: ""', + report, + ) + + return clean_java_report_str(report) + + +class DedupSanitizerReport(ShellphishBaseModel): + cleaned_report: str + dedup_tokens_shellphish: Dict[str, str] + dedup_tokens_full: Dict[str, str] + dedup_tokens: Dict[str, str] + crash_type: str + sanitizer: str + internal_crash_type: Optional[str] = None + stack_traces: Dict[str, CallTrace] + losan: bool + + @classmethod + def from_sanitizer_report(cls, report: SanitizerReport): + dedup_tokens = {k: v.get_dedup_token() for k, v in report.stack_traces.items()} + dedup_tokens_full = { + k: v.get_dedup_token_full() for k, v in report.stack_traces.items() + } + dedup_tokens_shellphish = { + k: v.get_dedup_token_shellphish() for k, v in report.stack_traces.items() + } + stack_traces = {} + for k, v in report.stack_traces.items(): + call_locations = [] + for entry in 
v.call_locations: + ent = CallTraceEntry(**entry.model_dump()) + ent.trace_line = clean_report(ent.trace_line.encode("utf-8")) + call_locations.append(ent) + stack_traces[k] = CallTrace( + reason=v.reason, + dedup_token=v.dedup_token, + call_locations=call_locations, + ) + return cls( + cleaned_report=clean_report(report.raw_report), + dedup_tokens_full=dedup_tokens_full, + dedup_tokens=dedup_tokens, + dedup_tokens_shellphish=dedup_tokens_shellphish, + crash_type=report.crash_type, + sanitizer=report.sanitizer, + internal_crash_type=report.internal_crash_type, + stack_traces=stack_traces, + losan=report.losan, + ) + + @property + def final_crash_type(self): + if self.internal_crash_type == None: + return self.crash_type + return self.internal_crash_type + + @property + def final_sanitizer_type(self): + return self.sanitizer + ": " + self.crash_type diff --git a/patchery/data/models/crs_reports.py b/patchery/data/models/crs_reports.py new file mode 100644 index 0000000..0c295bc --- /dev/null +++ b/patchery/data/models/crs_reports.py @@ -0,0 +1,576 @@ +from enum import Enum +import hashlib +from pydantic import Field +from annotated_types import Len +from typing import List, Optional, Any, Annotated, Literal, Dict, Union +from pathlib import Path + +from patchery.data.models.base import ShellphishBaseModel +from patchery.data.models.constraints import ( + PDT_ID, + MD5_CONSTRAINTS, + SHA1_CONSTRAINTS, +) +from patchery.data.models.crash_reports import ( + CallTrace, + DedupSanitizerReport, + SanitizerReport, +) +from patchery.data.models.symbols import POI +from patchery.data.models.target import HarnessInfo, CrashingInputMetadata +from patchery.data.models.indexer import GlobalVariableReference +from patchery.data.models.organizer_evaluation import ( + OrganizerCrashEvaluation, + SignificanceEnum, +) +import yaml + + +class DedupInfoKind(Enum): + FULL = "full" + OSS_FUZZ = "oss-fuzz" + SHELLPHISH = "shellphish" + ORGANIZERS = "organizers" + + +class DedupInfo(ShellphishBaseModel): + kind: DedupInfoKind + pdt_project_id: PDT_ID = Field( + description="The pydatatask project id this deduplication information is associated with" + ) + consistent_sanitizers: Annotated[List[str], Len(min_length=1)] + tokens: Dict[str, str] + + def hash(self) -> str: + """ + Get the hash of the deduplication information. + :return: A string representing the hash of the deduplication information. + """ + return hashlib.md5(self.identifier().encode("utf-8")).hexdigest() + + def identifier(self) -> str: + s = f"{self.pdt_project_id}--{self.kind.value}--{','.join(sorted(self.consistent_sanitizers))}===" + s += "===".join(f"{k}--{v}" for k, v in sorted(self.tokens.items())) + return s + + def canonical_representation(self) -> str: + """ + Get the canonical representation of the deduplication information. + :return: A string representing the canonical representation of the deduplication information. 
+ """ + return yaml.safe_dump( + { + "kind": self.kind.value, + "pdt_project_id": self.pdt_project_id, + "consistent_sanitizers": self.consistent_sanitizers, + "tokens": self.tokens, + }, + sort_keys=True, + default_flow_style=False, + ) + + +class POIReport(HarnessInfo): + harness_info_id: str = Field(description="The pydatatask harness info id") + organizer_crash_eval: OrganizerCrashEvaluation = Field( + description="The crash evaluation generated by the organizer's interaction scripts" + ) + detection_strategy: str = Field( + description="The detection strategy used to find the crash" + ) + fuzzer: str = Field(description="The fuzzer used to find the crash") + consistent_sanitizers: Annotated[List[str], Len(min_length=1)] = Field( + description="The sanitizers that are consistently triggered in the crash" + ) + inconsistent_sanitizers: List[str] = Field( + description="The sanitizers that are inconsistently triggered in the crash" + ) + # sanitizer_history: Annotated[List[Annotated[List[Annotated[str, ID_CONSTRAINTS]], Len(min_length=1)]], Len(min_length=1)] = Field(description="The history of sanitizers triggered in the crash") + crash_report_id: str = Field( + description="The pydatatask id of the representative crash report (md5sum)" + ) + crash_reason: str = Field(description="The reason for the crash") + pois: Annotated[List[POI], Len(min_length=1)] = Field( + description="The points of interest in the crash" + ) + stack_traces: Dict[str, CallTrace] = Field( + description="The call traces of the crash (Currently mostly identical to the pois in use). This maps the call trace type (e.g. main, alloc, free, etc.) to the call trace.", + default_factory=dict, + ) + extra_context: Optional[str] = Field( + description="Extra context about the crash (e.g. other reports than the fatal one. Usually happens with ubsan reports.)", + default=None, + ) + additional_information: Any = Field( + description="Additional information about the crash (None in all current cases)", + default=None, + ) + + def get_dedup_info(self, kind: Union[DedupInfoKind, str]) -> DedupInfo: + """ + Get the deduplication information for the OSS-Fuzz deduplication strategy. + :return: A DedupInfo object. 
+ """ + if isinstance(kind, str): + kind = DedupInfoKind(kind) + + dedup_tokens = {} + if kind in (DedupInfoKind.FULL, DedupInfoKind.SHELLPHISH, DedupInfoKind.FULL): + for reason, ct in self.stack_traces.items(): + getter = { + DedupInfoKind.OSS_FUZZ: lambda: ct.get_dedup_token(), + DedupInfoKind.SHELLPHISH: lambda: ct.get_dedup_token_shellphish( + num_entries=3 + ), + DedupInfoKind.FULL: lambda: ct.get_dedup_token_full(), + }.get(kind or DedupInfoKind.OSS_FUZZ, None) + assert getter is not None + dedup_tokens[reason] = getter() + + elif kind == DedupInfoKind.ORGANIZERS: + dedup_tokens["significance"] = str( + self.organizer_crash_eval.significance.value + ) + dedup_tokens["crash-state"] = self.organizer_crash_eval.crash_state + if self.organizer_crash_eval.instrumentation_key: + dedup_tokens["instrumentation-key"] = ( + self.organizer_crash_eval.instrumentation_key + ) + + else: + raise ValueError(f"Unknown DedupInfoKind: {kind}") + + return DedupInfo( + kind=kind, + pdt_project_id=self.project_id, + consistent_sanitizers=self.consistent_sanitizers, + tokens=dedup_tokens, + ) + + def get_dedup_info_full(self) -> DedupInfo: + return self.get_dedup_info(DedupInfoKind.FULL) + + def get_dedup_info_oss_fuzz(self) -> DedupInfo: + return self.get_dedup_info(DedupInfoKind.OSS_FUZZ) + + def get_dedup_info_shellphish(self) -> DedupInfo: + return self.get_dedup_info(DedupInfoKind.SHELLPHISH) + + def get_dedup_info_organizers(self) -> DedupInfo: + return self.get_dedup_info(DedupInfoKind.ORGANIZERS) + + def get_all_dedup_infos(self) -> List[DedupInfo]: + return [self.get_dedup_info(kind) for kind in DedupInfoKind] + + +# class JazzerStackFrame(ShellphishBaseModel): +# text: str = Field(description="The text of the stack frame line in jazzer output", default="") +# package: Optional[str] = Field(description="The package in the stack frame line", default=None) +# file: Optional[str] = Field(description="The file in the stack frame line", default=None) +# line: Optional[int] = Field(description="The line in the stack frame line", default=None) +# function: Optional[str] = Field(description="The function in the stack frame line", default=None) +# cls: Optional[str] = Field(description="The class in the stack frame line", default=None) + +# class RawJazzerReport(ShellphishBaseModel): +# triggered_sanitizers: List[str] = Field(description="The sanitizers triggered in the crash") +# report: bytes = Field(description="The raw jazzer report") +# error_line: str = Field(description="The error line in the jazzer report") +# argument: Optional[str] = Field(description="The argument in the jazzer report linked to the crash") +# stack_trace: List[JazzerStackFrame] = Field(description="The stack trace in the jazzer report", default_factory=list) + +# class RawAsanReport(ShellphishBaseModel): +# triggered_sanitizers: List[Annotated[str, ID_CONSTRAINTS]] = Field(description="The sanitizers triggered in the crash") +# report: bytes = Field(description="The raw asan report") +# error_line: str = Field(description="The asan specified error line") + +# class RawKasanReport(ShellphishBaseModel): +# triggered_sanitizers: List[Annotated[str, ID_CONSTRAINTS]] = Field(description="The sanitizers triggered in the crash") +# report: bytes = Field(description="The raw kasan report") + +# class CrashReport(ShellphishBaseModel): +# reports: List[Union[RawJazzerReport, RawAsanReport, RawKasanReport]] = Field(description="The raw reports of the crash") + + +class AIxCCDedupValues(ShellphishBaseModel): + significance: 
+    significance: SignificanceEnum = Field(
+        description="The significance code of the deduplication values"
+    )
+    crash_state: List[str] = Field(
+        description="The crash state of the deduplication values"
+    )
+    instrumentation_state: List[str] = Field(
+        description="The instrumentation state of the deduplication values"
+    )
+
+
+class RawPoVReport(ShellphishBaseModel):
+    parser: Literal["failed", "jazzer", "asan", "kasan"] = Field(
+        description="The sanitizer report parser used for the PoV"
+    )
+    exception: Optional[str] = Field(
+        description="The exception raised during report parsing", default=None
+    )
+    traceback: Optional[str] = Field(
+        description="The traceback of the exception raised during report parsing",
+        default=None,
+    )
+    unparsed: Optional[bytes] = Field(
+        description="The raw unparsed report (None unless an error occurs)",
+        default=None,
+    )
+    extra_context: Optional[str] = Field(
+        description="Extra context about the crash (e.g. other reports than the fatal one. Usually happens with ubsan reports.)",
+        default=None,
+    )
+    organizer_crash_eval: OrganizerCrashEvaluation = Field(
+        description="The crash evaluation generated by the organizer's interaction scripts",
+    )
+    crash_report: Optional[SanitizerReport] = Field(
+        description="The parsed crash report (None if a parsing error occurs or no crash was found)",
+        default=None,
+    )
+    dedup_crash_report: Optional[DedupSanitizerReport] = Field(
+        description="The deduplicated crash report (None if a parsing error occurs or no crash was found)",
+        default=None,
+    )
+    triggered_sanitizers: List[str] = Field(
+        description="All sanitizers triggered in the crash"
+    )
+
+    def clean_invalid_utf8(self):
+        """
+        Clean invalid UTF-8 sequences from the unparsed bytes field.
+        Decodes as UTF-8 with invalid byte sequences dropped, then re-encodes.
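+        Also recurses into the parsed crash report when one is present.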
+ """ + + if self.crash_report is not None: + self.crash_report.clean_invalid_utf8() + + if self.unparsed is not None and isinstance(self.unparsed, bytes): + self.unparsed = self.unparsed.decode("utf-8", errors="ignore").encode( + "utf-8" + ) + + +class RunImageResult(ShellphishBaseModel): + task_success: bool = Field(description="Whether the task was successful") + run_exit_code: Optional[int] = Field( + description="The exit code of the command run in the background", default=None + ) + time_scheduled: float = Field(description="The time the command was scheduled") + time_start: float = Field(description="The time the command started") + time_end: float = Field(description="The time the command terminated") + time_taken: float = Field(description="The time taken to run the command") + stdout: bytes = Field( + description="The stdout of the process inside the docker run by run.sh" + ) + stderr: bytes = Field( + description="The stderr of the process inside the docker run by run.sh" + ) + + container_id: Optional[str] = Field( + description="The container id of the docker run by run.sh for a local run", + default=None, + ) + container_name: Optional[str] = Field( + description="The container name of the docker run by run.sh for a local run", + default=None, + ) + out_dir: Optional[Path] = Field( + description="The output directory of the docker run by run.sh for a local run", + default=None, + ) + build_job_pdt_id: Optional[PDT_ID] = Field( + description="The pydatatask id of the build job that was run to build the image", + default=None, + ) + + +class BuildTargetResult(RunImageResult): + build_success: bool = Field(description="Whether the build was successful") + build_request_id: Optional[PDT_ID] = Field( + description="The request id of the build job that was run to build the image", + default=None, + ) + + +class RunImageInBackgroundResult(ShellphishBaseModel): + task_success: bool = Field(description="Whether the task was successful") + run_exit_code: int = Field( + description="The exit code of the command run in the background" + ) + time_scheduled: float = Field(description="The time the command was scheduled") + time_start: float = Field(description="The time the command started") + + container_id: Optional[str] = Field( + description="The container id of the docker run by run.sh for a local run", + default=None, + ) + container_name: Optional[str] = Field( + description="The container name of the docker run by run.sh for a local run", + default=None, + ) + out_dir: Optional[Path] = Field( + description="The output directory of the docker run by run.sh for a local run", + default=None, + ) + + +class RunPoVResult(RunImageResult): + # exitcode: int = Field(description="The exit code of the run_pov") + pov: RawPoVReport = Field( + description="The raw PoV report generated during the run_pov" + ) + + +class DedupPoVReportRepresentativeMetadata(HarnessInfo): + original_crash_id: PDT_ID = Field( + description="The pydatatask crash harness-input-id of the first crash that had this deduplicated report." 
+ ) + consistent_sanitizers: List[str] = Field( + description="The sanitizers that are consistently triggered in the crash" + ) + harness_info_id: PDT_ID = Field( + description="The pydatatask harness info id of the representative crash report" + ) + + +class PoVReport(CrashingInputMetadata, RawPoVReport): + consistent_sanitizers: Annotated[List[str], Len(min_length=1)] = Field( + description="The sanitizers that are consistently triggered in the crash" + ) + inconsistent_sanitizers: List[str] = Field( + description="The sanitizers that are inconsistently triggered in the crash" + ) + + def get_dedup_info(self, kind: Union[DedupInfoKind, str]) -> DedupInfo: + """ + Get the deduplication information for the OSS-Fuzz deduplication strategy. + :return: A DedupInfo object. + """ + if isinstance(kind, str): + kind = DedupInfoKind(kind) + + dedup_tokens = {} + if kind in ( + DedupInfoKind.FULL, + DedupInfoKind.SHELLPHISH, + DedupInfoKind.OSS_FUZZ, + ): + stack_traces = {} + if self.dedup_crash_report and self.dedup_crash_report.stack_traces: + # Use the dedup_crash_report stack traces if available + stack_traces = self.dedup_crash_report.stack_traces + for reason, ct in stack_traces.items(): + getter = { + DedupInfoKind.OSS_FUZZ: lambda: ct.get_dedup_token(), + DedupInfoKind.SHELLPHISH: lambda: ct.get_dedup_token_shellphish( + num_entries=3 + ), + DedupInfoKind.FULL: lambda: ct.get_dedup_token_full(), + }.get(kind or DedupInfoKind.OSS_FUZZ, None) + assert getter is not None + dedup_tokens[reason] = getter() + elif kind == DedupInfoKind.ORGANIZERS: + dedup_tokens["significance"] = str( + self.organizer_crash_eval.significance.value + ) + dedup_tokens["crash-state"] = self.organizer_crash_eval.crash_state + if self.organizer_crash_eval.instrumentation_key: + dedup_tokens["instrumentation-key"] = ( + self.organizer_crash_eval.instrumentation_key + ) + else: + raise ValueError(f"Unknown DedupInfoKind: {kind}") + + return DedupInfo( + kind=kind, + pdt_project_id=self.project_id, + consistent_sanitizers=self.consistent_sanitizers, + tokens=dedup_tokens, + ) + + def get_dedup_info_full(self) -> DedupInfo: + return self.get_dedup_info(DedupInfoKind.FULL) + + def get_dedup_info_oss_fuzz(self) -> DedupInfo: + return self.get_dedup_info(DedupInfoKind.OSS_FUZZ) + + def get_dedup_info_shellphish(self) -> DedupInfo: + return self.get_dedup_info(DedupInfoKind.SHELLPHISH) + + def get_all_dedup_infos(self) -> List[DedupInfo]: + return [self.get_dedup_info(kind) for kind in DedupInfoKind] + + +class RepresentativeFullPoVReport(PoVReport): + run_pov_result: RunPoVResult = Field(description="The run_pov result of the crash") + original_crash_id: PDT_ID = Field( + description="The non-deduplicated pydatatask crash id" + ) + crash_report_id: Annotated[str, MD5_CONSTRAINTS] = Field( + description="The md5sum of the crash report" + ) + sanitizer_history: Annotated[ + List[Annotated[List[str], Len(min_length=1)]], Len(min_length=1) + ] = Field(description="The history of sanitizers triggered in the crash") + + +# class ASANReport(HarnessInfo): +# harness_info_id: PDT_ID = Field(description="The pydatatask harness info id") +# crash_report_id: PDT_ID = Field(description="The pydatatask id of the representative crash report (md5sum)") +# consistent_sanitizers: Annotated[List[Annotated[str, ID_CONSTRAINTS]], Len(min_length=1)] = Field(description="The sanitizers that are consistently triggered in the crash") +# inconsistent_sanitizers: List[str] = Field(description="The sanitizers that are inconsistently triggered in the 
crash", default_factory=list) +# sanitizer_history: Annotated[List[Annotated[List[Annotated[str, ID_CONSTRAINTS]], Len(min_length=1)]], Len(min_length=1)] = Field(description="The history of sanitizers triggered in the crash") +# fuzzer: str = Field(description="The fuzzer used to find the crash") +# sanitizer: str = Field(description="The sanitizer present in the crash") +# crash_type: str = Field(description="The type of crash") +# stack_traces: Dict[str, List[ASANStackTrace]] = Field(description="The stack traces of the crash", default_factory=dict) + + +class DedupedFirstCrashCommitReport(ShellphishBaseModel): + crashing_commit: Annotated[str, SHA1_CONSTRAINTS] = Field( + description="The introducing commit that caused the crash" + ) + sanitizer_ids: Annotated[List[str], Len(min_length=1)] = Field( + description="The sanitizers triggered in the crash" + ) + cp_harness_name: str = Field( + description="The challenge project harness name of the crash" + ) + project_id: PDT_ID = Field(description="The pydatatask id of the project") + + +class FirstCrashCommitReport(ShellphishBaseModel): + crashing_commit: Annotated[str, SHA1_CONSTRAINTS] = Field( + description="The introducing commit that caused the crash" + ) + cp_source: str = Field( + description="The challenge project source linked to the crash" + ) + sanitizer_ids: Annotated[List[str], Len(min_length=1)] = Field( + description="The sanitizers triggered in the crash" + ) + crash_report_id: PDT_ID = Field( + description="The pydatatask id of the representative crash report (md5sum)" + ) + crash_id: PDT_ID = Field(description="The pydatatask id of the crash") + cp_harness_name: str = Field( + description="The challenge project harness name of the crash" + ) + + +class LocationOfInterest(ShellphishBaseModel): + file: str = Field(description="The file containing the point of interest") + function: str = Field(description="The function containing the point of interest") + start_line: int = Field( + description="The start line number of the point of interest" + ) + end_line: int = Field(description="The end line number of the point of interest") + signature: str = Field( + description="The signature of the point of interest", default="" + ) + + +class SuggestedPatch(ShellphishBaseModel): + description: str = Field(description="The description of the patch") + + +class RootCauseReport(ShellphishBaseModel): + found_root_cause: bool = Field(description="Whether the root cause was found") + description: str = Field(description="The root cause of the crash") + dataflow: str = Field(description="The dataflow of the crash") + bug_locations: List[LocationOfInterest] = Field( + description="The location of the bug in the code" + ) + bug_classes: List[str] = Field(description="The potential classes of the bug") + root_cause_locations: List[LocationOfInterest] = Field( + description="The root cause locations of the crash" + ) + patches: List[SuggestedPatch] = Field(description="The patches to fix the crash") + errored: bool = Field( + description="Whether there was an error in the root cause analysis" + ) + + +class KumushiCodeFunction(ShellphishBaseModel): + name: str = Field(description="The name of the function") + start_line: int = Field(description="The start line of the function") + end_line: int = Field(description="The end line of the function") + file_path: str = Field(description="The path to the file") + code: str | None = Field(default=None, description="The code of the function") + global_vars: list[GlobalVariableReference] | None = 
Field(
+        default=None, description="The global variables of the function"
+    )
+    version: str | None = Field(default=None, description="The version of the function")
+
+
+class KumushiPOI(ShellphishBaseModel):
+    sources: list[int] = Field(description="The sources we collected the POI from")
+    crash_line_number: int | None = Field(
+        default=None, description="The line number of the crash"
+    )
+    crash_line: str | None = Field(
+        default=None, description="The line text of the crash"
+    )
+    code_function: KumushiCodeFunction = Field(
+        description="The code function of the POI"
+    )
+
+
+class KumushiPOICluster(ShellphishBaseModel):
+    poi_cluster: List[KumushiPOI] = Field(description="The Kumushi POI cluster")
+    reasoning: str | None = Field(
+        default=None, description="The reasoning behind the cluster"
+    )
+
+
+class KumushiRootCauseReport(ShellphishBaseModel):
+    poi_clusters: List[KumushiPOICluster] = Field(
+        description="The Kumushi POI clusters"
+    )
+    rca_hash: str = Field(description="The hash of the RCA")
+
+
+class PatchRequestMeta(ShellphishBaseModel):
+    request_type: Literal["refine", "patch"] = Field(
+        description="The type of this request (refine/patch)"
+    )
+    poi_report_id: str = Field(description="The POI report ID")
+    failed_functionality: bool = Field(
+        description="Whether the patch failed functionality tests", default=False
+    )
+    patcher_name: str | None = Field(
+        description="The patcher that created the previous patch", default=None
+    )
+    patch_id: str | None = Field(
+        description="The id of the patch that needs to be refined", default=None
+    )
+    bucket_id: str | None = Field(description="The bucket ID", default=None)
+
+
+class RunTestsResult(ShellphishBaseModel):
+    tests_exist: bool = Field(
+        description="True when the tests exist. When false, all other data is invalid"
+    )
+    timedout: bool = Field(
+        description="True when the tests never finished. Normally should be False.",
+        default=False,
+    )
+    all_passed: bool = Field(description="All testcases passed")
+    stdout: str = Field(description="Tests stdout", default="")
+    stderr: str = Field(description="Tests stderr", default="")
+
+
+class RunOssFuzzBuildCheckResult(ShellphishBaseModel):
+    all_passed: bool = Field(description="All testcases passed")
+    timedout: bool = Field(
+        description="True when the tests never finished. 
Normally should be False.", + default=False, + ) + internal_error: bool = Field( + description="True when an internal error occurred during the build check", + default=False, + ) + stdout: str = Field(description="Tests stdout", default="") + stderr: str = Field(description="Tests stderr", default="") diff --git a/patchery/data/models/extended_aixcc_api.py b/patchery/data/models/extended_aixcc_api.py new file mode 100644 index 0000000..eabdc5f --- /dev/null +++ b/patchery/data/models/extended_aixcc_api.py @@ -0,0 +1,16 @@ +from typing import List, Optional +from shellphish_crs_utils.models.aixcc_api import TaskDetail, Task +from shellphish_crs_utils.models.constraints import PDT_ID +from uuid import UUID + + +class ExtendedTaskDetail(TaskDetail): + task_uuid: UUID + task_sanitizer: str + pdt_task_id: PDT_ID + concurrent_target_num: Optional[int] = None + fuzzing_pool_name: Optional[str] = None + + +class ExtendedTask(Task): + tasks: List[ExtendedTaskDetail] diff --git a/patchery/data/models/indexer.py b/patchery/data/models/indexer.py new file mode 100644 index 0000000..1ed4703 --- /dev/null +++ b/patchery/data/models/indexer.py @@ -0,0 +1,194 @@ +from typing import List, Optional, Any, Dict, TypeAlias +from pathlib import Path + +from pydantic import Field +from patchery.data.models.base import ShellphishBaseModel + +FUNCTION_INDEX_KEY: TypeAlias = str + + +class ReferenceBase(ShellphishBaseModel): + unique_identifier: str = Field(description="A unique identifier of the reference") + name: str = Field(description="The name of the reference") + + +class GlobalVariableReference(ReferenceBase): + declaration: str = Field(description="The declaration of the global variable") + raw_comment: Optional[str] = Field( + description="The raw comment of the global variable" + ) + declaration_start_line: int = Field( + description="The starting line of the declaration" + ) + declaration_end_line: int = Field(description="The ending line of the declaration") + declaration_start_offset: int = Field( + description="The starting offset of the declaration" + ) + declaration_end_offset: int = Field( + description="The ending offset of the declaration" + ) + declaration_start_column: int = Field( + description="The starting column of the declaration" + ) + declaration_end_column: int = Field( + description="The ending column of the declaration" + ) + type: str = Field(description="The type of the global variable") + + +class FunctionReference(ReferenceBase): + pass + + +class FunctionReferenceCall(FunctionReference): + type: str = Field(description="The type of the function call") + + +class FunctionBase(ShellphishBaseModel): + target_compile_args: Dict[str, Any] = Field( + description="The compile arguments of the file this function resides in" + ) + was_directly_compiled: bool = Field( + description="Whether the function was in a file that was compiled or not. If false, this file might have been excluded by the build-system or post-processed elsewhere before building." + ) + is_generated_during_build: bool = Field( + description="Whether the function is generated during build time or already exists in the project sources. 
This is really only set for things in the focus_repo.", + default=False, + ) + + unique_identifier: str = Field(description="A unique identifier of the function") + code: str = Field(description="The source code of the function") + hash: str = Field(description="The hash of the function source code") + raw_comment: Optional[str] = Field(description="The raw comment of the function") + start_line: int = Field(description="The starting line of the function") + end_line: int = Field(description="The ending line of the function") + start_offset: int = Field(description="The starting offset of the function") + end_offset: int = Field(description="The ending offset of the function") + start_column: int = Field(description="The starting column of the function") + end_column: int = Field(description="The ending column of the function") + global_variables: List[GlobalVariableReference] = Field( + description="The global variables used in the function" + ) + signature: Optional[str] = Field(description="The signature of the function") + target_container_path: Optional[Path] = Field( + description="The path to the file containing the function. This is the absolute path as seen from in the target container.", + examples=["/src/hiredis/hiredis.c"], + default=None, + ) + focus_repo_relative_path: Optional[Path] = Field( + description="The path to the file containing the function. This is the relative path as seen from the focus directory. If this is null, the function is *not* inside the focus repo.", + examples=["hiredis.c"], + default=None, + ) + + +class FunctionInfo(FunctionBase): + name: str = Field(description="The name of the function") + mangled_name: str = Field(description="The mangled name of the function") + comment: Optional[str] = Field( + description="The comments in the function (Not Implemented)" + ) + calls: List[FunctionReferenceCall] = Field( + description="The list of function calls in the function" + ) + func_return_type: str = Field(description="The return type of the function") + + +class MethodInfo(FunctionBase): + mangled_name: str = Field(description="The mangled name of the method") + full_name: str = Field( + description="The full name of the method (class_name::method_name)" + ) + method_name: str = Field(description="The name of the method") + comment: Optional[str] = Field( + description="The comments in the method (Not Implemented)" + ) + calls: List[FunctionReferenceCall] = Field( + description="The list of function calls in the method" + ) + + +class MacroInfo(FunctionBase): + name: str = Field(description="The name of the macro") + + +class FunctionIndex(FunctionBase): + funcname: str = Field(description="The name of the function") + full_funcname: str = Field( + description="The full name of the function (class_name::function_name if applicable)" + ) + func_return_type: str = Field(description="The return type of the function") + signature: Optional[FUNCTION_INDEX_KEY] = Field( + description="Raw function signature", default=None + ) + arguments: List[str] = Field(description="The arguments of the function") + local_variables: List[str] = Field( + description="The local variables used in the function" + ) + func_calls_in_func_with_fullname: List[Any] = Field( + description="The list of function calls in the function" + ) + filename: str = Field(description="The name of the file containing the function") + class_name: Optional[str] = Field( + description="The name of the class containing the function", default=None + ) + comments: List[str] = Field( + description="The 
comments in the function (Not Implemented)", + default_factory=list, + ) + cfg: Optional[str] = Field(description="Not Implemented", default=None) + package: Optional[str] = Field( + description="The package of the function (Java Only)", default=None + ) + language_specific_info: Optional[Dict[str, Any]] = Field( + description="The language specific information of the function", default=None + ) + + +class CommitToFunctionIndex(ShellphishBaseModel): + commit_to_index_info: Dict[str, Dict[FUNCTION_INDEX_KEY, Path]] = Field( + description="The mapping of commit sha (e.g. '1_9faebc...') to function index information" + ) + + +class SignatureToFile(ShellphishBaseModel): + sig_to_file: Dict[FUNCTION_INDEX_KEY, Path] = Field( + description="The mapping of function signature to jsons directory file path" + ) + + +class ReducedFunctionIndex(ShellphishBaseModel): + func_name: str = Field(description="The name of the function") + function_signature: FUNCTION_INDEX_KEY = Field( + description="The signature of the function (filename:start_line:start_column::signature)" + ) + filename: str = Field(description="The name of the file containing the function") + indexed_jsons_relative_filepath: Path = Field( + description="The relative path to the source code of the function inside the functions jsons produced by the indexers." + ) + start_line: int = Field(description="The starting line of the function") + end_line: int = Field(description="The ending line of the function") + start_column: int = Field(description="The starting column of the function") + end_column: int = Field(description="The ending column of the function") + start_offset: int = Field(description="The starting offset of the function") + end_offset: int = Field(description="The ending offset of the function") + line_map: Optional[Dict[int, str]] = Field( + description="The mapping of line number to source code line (Contains full source of the function)", + default=None, + ) + target_container_path: Optional[Path] = Field( + description="The path to the file containing the function. This is the absolute path as seen from in the target container.", + examples=["/src/hiredis/hiredis.c"], + default=None, + ) + focus_repo_relative_path: Optional[Path] = Field( + description="The path to the file containing the function. This is the relative path as seen from the focus repository. 
If this is null, the function is *not* inside the focus repo.", + examples=["hiredis.c"], + default=None, + ) + + +class FunctionsByFile(ShellphishBaseModel): + func_by_file: Dict[Path, List[ReducedFunctionIndex]] = Field( + description="The mapping of file path to list of functions" + ) diff --git a/patchery/data/models/llvm_symbolizer.py b/patchery/data/models/llvm_symbolizer.py new file mode 100644 index 0000000..e412af6 --- /dev/null +++ b/patchery/data/models/llvm_symbolizer.py @@ -0,0 +1,135 @@ +# [ +# { +# "Address": "0x2915e0", +# "ModuleName": "/out/njs_process_script_fuzzer", +# "Symbol": [ +# { +# "Column": 0, +# "Discriminator": 0, +# "FileName": "/src/njs/external/njs_shell.c", +# "FunctionName": "LLVMFuzzerTestOneInput", +# "Line": 855, +# "StartAddress": "0x2915e0", +# "StartFileName": "", +# "StartLine": 0 +# } +# ] +# } +# ] + +# [ +# { +# "Address": "0x291915", +# "ModuleName": "./njs_process_script_fuzzer", +# "Symbol": [ +# { +# "Column": 5, +# "Discriminator": 0, +# "FileName": "/src/njs/external/njs_shell.c", +# "FunctionName": "njs_read_file", +# "Line": 3186, +# "StartAddress": "", +# "StartFileName": "", +# "StartLine": 0 +# }, +# { +# "Column": 11, +# "Discriminator": 0, +# "FileName": "/src/njs/external/njs_shell.c", +# "FunctionName": "njs_process_file", +# "Line": 3285, +# "StartAddress": "", +# "StartFileName": "", +# "StartLine": 0 +# }, +# { +# "Column": 15, +# "Discriminator": 0, +# "FileName": "/src/njs/external/njs_shell.c", +# "FunctionName": "njs_main", +# "Line": 458, +# "StartAddress": "", +# "StartFileName": "", +# "StartLine": 0 +# }, +# { +# "Column": 12, +# "Discriminator": 0, +# "FileName": "/src/njs/external/njs_shell.c", +# "FunctionName": "LLVMFuzzerTestOneInput", +# "Line": 869, +# "StartAddress": "0x2915e0", +# "StartFileName": "", +# "StartLine": 0 +# } +# ] +# } +# ] + +import json +import os +from typing import List, Tuple, TypeAlias +from patchery.data.models.base import ShellphishBaseModel +from patchery.data.models.symbols import BinaryLocation, SourceLocation + + +class LLVMSymbolizerSymbol(ShellphishBaseModel): + Column: int + Discriminator: int + FileName: str + FunctionName: str + Line: int + StartAddress: str + StartFileName: str + StartLine: int + + @property + def was_inlined(self) -> bool: + return not bool(self.StartAddress) + + def to_location(self) -> SourceLocation: + return SourceLocation( + full_file_path=self.FileName, + file_name=os.path.basename(self.FileName), + function_name=self.FunctionName, + line_number=self.Line, + symbol_offset=int(self.StartAddress, 16), + ) + + +class LLVMSymbolizerEntry(ShellphishBaseModel): + Address: str + ModuleName: str + + # in the case of inlines there can be multiple symbols for a given binary location + Symbol: List[LLVMSymbolizerSymbol] + + def get_locations(self) -> Tuple[BinaryLocation, List[SourceLocation]]: + binary_location = BinaryLocation.create( + full_binary_path=self.ModuleName, + offset=int(self.Address, 16), + ) + source_locs = [symbol.to_location() for symbol in self.Symbol] + assert all( + loc.was_inlined for loc in self.Symbol[:-1] + ), f"Only the last symbol can be non-inlined: {self.Symbol}, {source_locs}" + assert ( + not source_locs or not self.Symbol[-1].was_inlined + ), f"The last symbol should never be inlined: {self.Symbol}, {source_locs}" + return binary_location, source_locs + + +LLVMSymbolizerList: TypeAlias = List[LLVMSymbolizerEntry] + + +def parse_llvm_symbolizer_json_output_string(output: str) -> LLVMSymbolizerList: + parsed = json.loads(output) + 
symbols = [LLVMSymbolizerEntry.model_validate(entry) for entry in parsed] + return symbols + + +def parse_llvm_symbolizer_json_output_file(file_path: str) -> LLVMSymbolizerList: + with open(file_path, "r") as f: + parsed = json.load(f) + symbols = [LLVMSymbolizerEntry.model_validate(entry) for entry in parsed] + return symbols diff --git a/patchery/data/models/organizer_evaluation.py b/patchery/data/models/organizer_evaluation.py new file mode 100644 index 0000000..e304770 --- /dev/null +++ b/patchery/data/models/organizer_evaluation.py @@ -0,0 +1,46 @@ +from enum import Enum +import hashlib +from typing import Optional +from pydantic import Field +from patchery.data.models.base import ShellphishBaseModel + + +class SignificanceEnum(Enum): + """ + Enum for significance levels. + """ + + NoSignificantCrashRecognized = 0 + RecognizedSanitizerCrash = 211 + RecognizedNonSanitizerNotableCrash = 212 + RecognizedSanitizerSignatureDespite0ReturnCode = 213 + RecognizedErrorInReproducing = 214 + + +class OrganizerCrashEvaluation(ShellphishBaseModel): + """ + Model for the organizer's crash evaluation results. + """ + + code_label: str + significance: SignificanceEnum + significance_message: str + crash_state: str + instrumentation_key: Optional[str] = Field( + default=None, + description="The instrumentation key for the crash, if available.", + ) + + def plaintext_identifier(self) -> str: + """ + Generate a unique plaintext identifier for the crash evaluation. + :return: A string identifier based on the evaluation content. + """ + return f"{self.code_label}--{self.significance.value}--{self.significance_message}--{self.crash_state}--{self.instrumentation_key or ''}" + + def hashed_identifier(self) -> str: + """ + Generate a unique identifier for the crash evaluation based on its content. + :return: A SHA256 hash of the evaluation content. 
+ """ + return hashlib.sha256(self.plaintext_identifier().encode("utf-8")).hexdigest() diff --git a/patchery/data/models/oss_fuzz.py b/patchery/data/models/oss_fuzz.py new file mode 100644 index 0000000..0c28b55 --- /dev/null +++ b/patchery/data/models/oss_fuzz.py @@ -0,0 +1,176 @@ +from configparser import ConfigParser +from enum import Enum +from typing import List, Optional, Any, Tuple, Dict, Union +from pathlib import Path + +from pydantic import field_validator, Field, HttpUrl +from patchery.data.models.base import ShellphishBaseModel + +from patchery.data.models.symbols import SourceLocation + + +class LanguageEnum(str, Enum): + c = "c" + cpp = "c++" + go = "go" + rust = "rust" + python = "python" + jvm = "jvm" + swift = "swift" + javascript = "javascript" + ruby = "ruby" + + +class SanitizerEnum(str, Enum): + address = "address" + memory = "memory" + undefined = "undefined" + thread = "thread" + coverage = "coverage" + none = "none" + + +class SanitizerConfig(ShellphishBaseModel): + experimental: Optional[bool] = False + + +SanitizerWithConfig = Dict[SanitizerEnum, SanitizerConfig] + + +class ArchitectureEnum(str, Enum): + x86_64 = "x86_64" + i386 = "i386" + aarch64 = "aarch64" + + +class FuzzingEngineEnum(str, Enum): + none = "none" + libfuzzer = "libfuzzer" + afl = "afl" + honggfuzz = "honggfuzz" + centipede = "centipede" + wycheproof = "wycheproof" + + +class ViewRestrictionsEnum(str, Enum): + none = "none" + + +class OSSFuzzProjectYAML(ShellphishBaseModel): + language: LanguageEnum + homepage: Optional[str] = None + primary_contact: Optional[str] = None # EmailStr + auto_ccs: Optional[Union[str, List[str]]] = None # EmailStr + main_repo: Optional[str] = ( + None # TODO: bring Url back, but has to handle git@github.com:asdf urls + ) + vendor_ccs: Optional[List[str]] = None # EmailStr + sanitizers: Optional[List[Union[SanitizerEnum, SanitizerWithConfig]]] = Field( + default=["address", "undefined"], + description="Sanitizers the project supports, can opt-in to memory sanitizer or opt out of either ubsan or asan.", + ) + architectures: Optional[List[ArchitectureEnum]] = Field( + default=["x86_64"], + description="Architectures the project supports, can opt-in to i386 or aarch64.", + ) + fuzzing_engines: Optional[List[FuzzingEngineEnum]] = Field( + default=["libfuzzer", "afl", "honggfuzz", "centipede"], + description="Fuzzing engines the project supports, can opt-in to afl or honggfuzz.", + ) + help_url: Optional[HttpUrl] = None + builds_per_day: Optional[int] = None + file_github_issue: Optional[bool] = None + coverage_extra_args: Optional[str] = None + disabled: Optional[bool] = False + blackbox: Optional[bool] = False + + run_tests: Optional[bool] = True + + view_restrictions: Optional[ViewRestrictionsEnum] = None + + labels: Dict[str, Any] = Field(default_factory=dict) + + # allegedly this is a thing, used by bitcoin-core + selective_unpack: Optional[bool] = ( + False # Required to avoid out-of-space when executing AFL on clusterfuzz bots + ) + + shellphish_docker_image: Optional[str] = None + shellphish_project_name: Optional[str] = None + + @field_validator("builds_per_day") + @classmethod + def check_builds_per_day(cls, v): + if v is not None and (v < 1 or v > 4): + raise ValueError("builds_per_day must be between 1 and 4") + return v + + def is_prebuilt(self) -> bool: + return self.shellphish_docker_image is not None + + def get_docker_image_name(self, project_name: str) -> str: + return self.shellphish_docker_image or f"oss-fuzz-{project_name}" + + def 
get_project_name(self) -> str:
+        if self.shellphish_project_name:
+            return self.shellphish_project_name
+        else:
+            assert False, "Project name must be set in the shellphish_project_name field or derived from the project directory name."
+
+
+class ShellphishMetadata(ShellphishBaseModel):
+    fuzzing_engine: FuzzingEngineEnum
+    project_name: str
+    sanitizer: SanitizerEnum
+    source_repo_path: str
+    architecture: ArchitectureEnum
+    harnesses: List[str] = Field(default_factory=list)
+    harness_source_locations: Dict[str, SourceLocation] = Field(default_factory=dict)
+    known_sources: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="A mapping of known target sources to their contents",
+    )
+    files_by_type: Dict[str, int] = Field(
+        default_factory=dict,
+        description="A mapping of known file types to the count of them in the source repo",
+    )
+
+
+class AugmentedProjectMetadata(OSSFuzzProjectYAML):
+    shellphish: ShellphishMetadata = Field(default_factory=dict)
+
+    @property
+    def harnesses(self) -> List[str]:
+        return self.shellphish.harnesses
+
+    @property
+    def harness_source_locations(self) -> Dict[str, SourceLocation]:
+        return self.shellphish.harness_source_locations
+
+    @property
+    def source_repo_path(self) -> Path:
+        return Path(self.shellphish.source_repo_path)
+
+
+class HarnessOptions(ShellphishBaseModel):
+    libfuzzer: Optional[List[Tuple[str, str]]] = None
+    afl: Optional[List[Tuple[str, str]]] = None
+    honggfuzz: Optional[List[Tuple[str, str]]] = None
+    centipede: Optional[List[Tuple[str, str]]] = None
+    wycheproof: Optional[List[Tuple[str, str]]] = None
+    none: Optional[List[Tuple[str, str]]] = None
+
+
+class Harness(ShellphishBaseModel):
+    name: str
+    dict_path: Optional[Path] = None
+    options: Optional[HarnessOptions] = None
+    seed_corpus_tar_path: Optional[Path] = None
+
+    @staticmethod
+    def from_project(project_dir: Path, harness_name: str) -> "Harness":
+        harness = Harness(name=harness_name)
+        options_path = project_dir / "out" / f"{harness_name}.options"
+        if options_path.exists():
+            # .options files are INI-style; parse each section into (key, value) pairs
+            parser = ConfigParser()
+            parser.read(options_path)
+            harness.options = HarnessOptions.model_validate(
+                {section: list(parser.items(section)) for section in parser.sections()}
+            )
+        return harness
diff --git a/patchery/data/models/oss_fuzz_runner_service.py b/patchery/data/models/oss_fuzz_runner_service.py
new file mode 100644
index 0000000..680a18f
--- /dev/null
+++ b/patchery/data/models/oss_fuzz_runner_service.py
@@ -0,0 +1,200 @@
+import base64
+import hashlib
+from typing import (
+    List,
+    Optional,
+    Dict,
+    Union,
+)
+import uuid
+
+from pydantic import BaseModel, ConfigDict, field_validator, Field
+from patchery.data.models.base import ShellphishBaseModel
+
+from patchery.data.models.constraints import PDT_ID
+from patchery.data.models.oss_fuzz import LanguageEnum, SanitizerEnum
+# from shellphish_crs_utils.pydatatask import PDClient
+
+
+class Base64Bytes(BaseModel):
+    model_config = ConfigDict(
+        extra="forbid", json_encoders={bytes: lambda v: base64.b64encode(v).decode()}
+    )
+    content: bytes
+
+    @classmethod
+    def from_bytes(cls, content: bytes):
+        return cls.model_validate({"content": content})
+
+    @field_validator("content", mode="before")
+    def decode_base64(cls, value):
+        if isinstance(value, str):
+            return base64.b64decode(value)
+        return value
+
+
+class ProjectTaskRequest(ShellphishBaseModel):
+    """
+    This model is used to request oss-fuzz container tasks in the current cluster
+    """
+
+    # request_id: PDT_ID = Field(default_factory=lambda: str(uuid.uuid4()).replace("-", ""))
+    # """
+    # Unique identifier for the task request.
+    # Can be used later to retrieve artifacts or trigger a run with the build results.
+    # """
+    def compute_request_id(self) -> PDT_ID:
+        jsonned = self.model_dump_json()
+        return hashlib.sha256(jsonned.encode()).hexdigest()[:32] + "11"
+
+    project_id: PDT_ID
+    """
+    The project id must be provided and match the project id in the pdt repos
+    """
+
+    docker_image: str
+    """
+    The docker image should be the target image for the given oss project.
+    It may be a custom image which builds on top of the target oss-fuzz image.
+
+    If the image is in a registry, the registry should be included.
+
+    If in k8s, the registry is given via the env `DOCKER_IMAGE_PREFIX`, i.e. `foo.com/`.
+    If running in a local docker daemon, no registry is needed.
+    """
+
+    quota: dict[str, str] = Field(
+        default_factory=lambda: {
+            "cpu": "6",
+            "mem": "26Gi",
+        }
+    )
+    """
+    The quota to use for the task, should be a dict with either:
+    - `max`: a float % of total resources available to the cluster
+    - `cpu` and `mem`: Explicit values for cpu and memory. Memory expects units of Gi
+    """
+
+    resource_limits: Optional[dict[str, str]] = Field(
+        default_factory=lambda: {
+            "cpu": "10",
+            "mem": "40Gi",
+        }
+    )
+    """
+    The resource limits to use for the task, should be a dict with:
+    - `cpu` and `mem`: Explicit values for cpu and memory. Memory expects units of Gi
+    """
+
+    priority: Union[float, str] = 2.0
+    """
+    The pipeline priority of the task, defaults to 2.0. Higher priorities are scheduled first.
+    """
+
+    project_language: str
+    """
+    The language of the project
+    """
+
+    sanitizer: SanitizerEnum = "address"
+    fuzzing_engine: str = "libfuzzer"
+
+    extra_files: Optional[Dict[str, Base64Bytes]] = None
+    """
+    Extra files to upload to the container
+    """
+
+    env: Dict[str, str] = Field(default_factory=dict)
+    """
+    Extra environment variables to set in the container during the build
+    """
+
+    timeout: Optional[int] = None
+
+    command: List[str]
+    """
+    The command to run in the container
+    """
+
+    nonce: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()))
+    """
+    Used to allow re-running the same request multiple times.
+    Set to None if you want to only get cached results (including failures)
+    """
+
+
+class ProjectBuildRequest(ProjectTaskRequest):
+    """
+    Used to run a build task in the current cluster
+    """
+
+    patch: Optional[str] = None
+    """
+    A patch to apply to the project source code before building
+    """
+
+    command: List[str] = ["compile"]
+    """
+    The command to run in the container, defaults to the oss-fuzz compile command
+    """
+
+    preserve_built_src_dir: Optional[bool] = False
+    """
+    Whether or not the built source directory should be preserved as part of the artifacts in /out/src/
+    """
+
+    git_ref: Optional[str] = None
+    """
+    The git ref to check out before building
+    """
+
+
+class ProjectBuildResult(ShellphishBaseModel):
+    """
+    The result of a build task
+    """
+
+    request_id: PDT_ID
+    project_id: PDT_ID
+    sanitizer: SanitizerEnum
+    fuzzing_engine: str
+    language: LanguageEnum
+    build_success: bool
+
+
+class ProjectRunTaskRequest(ProjectTaskRequest):
+    """
+    Used to run a run task in the current cluster
+    """
+
+    command: List[str]
+    """
+    Arbitrary command to run in the container
+    """
+
+    volumes_id: Optional[PDT_ID]
+    """
+    The ID of the volumes to mount for this task. 
These should be uploaded to the pdt repo `oss_fuzz_project_run.project_volumes` + """ + + collect_artifacts: Optional[bool] = False + """ + If true, the output artifacts will be collected from the container and matched against the glob pattern provided in `output_artifacts_glob` + """ + + output_artifacts_globs: Optional[List[str]] = None + """ + If provided, the output artifacts will be collected from the container and matched against this glob pattern + """ + + +class ProjectRunTaskResult(ShellphishBaseModel): + """ + The result of a run task + """ + + run_success: bool + run_exit_code: int + + request_id: PDT_ID + project_id: PDT_ID + input_volumes_id: PDT_ID diff --git a/patchery/data/models/patch.py b/patchery/data/models/patch.py new file mode 100644 index 0000000..93478a8 --- /dev/null +++ b/patchery/data/models/patch.py @@ -0,0 +1,86 @@ +from pydantic import Field +from typing import Optional +from patchery.data.models.base import ShellphishBaseModel +from patchery.data.models.constraints import PDT_ID + + +class PatchMetaData(ShellphishBaseModel): + """ + Metadata for a patch. + """ + + patcher_name: str = Field(description="The name of the patcher") + total_cost: float = Field(description="The total cost of the patch") + poi_report_id: Optional[str] = Field( + description="The id of poi report", default=None + ) + pdt_harness_info_id: Optional[str] = Field( + description="The id of harness info", default=None + ) + pdt_project_id: Optional[str] = Field( + description="The id of oss fuzz project", default=None + ) + pdt_project_name: Optional[str] = Field( + description="The name of the oss fuzz project", default=None + ) + build_request_id: Optional[PDT_ID] = Field( + description="The id of the build request", default=None + ) + + +class PatchBypassRequestMeta(ShellphishBaseModel): + """ + Metadata for a patch bypass request. + """ + + project_id: str = Field(description="The id of the project") + harness_id: str = Field(description="The id of the harness") + sanitizer_name: Optional[str] = Field(description="The id of the sanitizer") + patch_id: str = Field(description="The id of the patch we just created") + mitigated_poi_report_id: str = Field( + description="The id of the poi report mitigated by this patch" + ) + patcher_name: str = Field( + description="The name of the patcher that asked for this bypass" + ) + build_request_id: str = Field(description="The id of the build request") + patch_description: Optional[str] = Field( + description="The description of the patch we just created", default=None + ) + sarif_id: Optional[str] = Field( + description="The id of the sarif report generated for this patch", default=None + ) + + +class BypassResultMeta(ShellphishBaseModel): + """ + Metadata for a patch bypass result. 
+ """ + + patch_id: str = Field(description="The id of the patch that was bypass by discoguy") + summary: str = Field(description="The summary of how the patch was bypassed") + crashing_input_id: Optional[str] = Field( + description="The id of the new crashing input that bypassed the patch", + default=None, + ) + + +class PatchBucketRanking(ShellphishBaseModel): + bucket: list[str] = Field(description="List of patch ids in the bucket") + patch_info: dict[str, float] = Field( + description="Dictionary of patch ids and their scores" + ) + poi_report_ids: list[str] = Field( + description="List of POI report IDs associated with the patches in the bucket" + ) + ranks: list[str] = Field( + description="Sorted list of patch ids in the order of their scores" + ) + timestamp: int = Field(description="Timestamp of the bucket ranking") + + +class PatchRankings(ShellphishBaseModel): + buckets: list[PatchBucketRanking] = Field( + description="List of patch buckets with their rankings" + ) + timestamp: int = Field(description="Timestamp of the rankings file being created") diff --git a/patchery/data/models/submission.py b/patchery/data/models/submission.py new file mode 100644 index 0000000..9137a87 --- /dev/null +++ b/patchery/data/models/submission.py @@ -0,0 +1,51 @@ +from typing import Annotated, List +from pydantic import StringConstraints, Field +from patchery.data.models.base import ShellphishBaseModel +from patchery.data.models.constraints import PDT_ID, ID_CONSTRAINTS + +SHA1_REGEX = r"[0-9a-f]{40}" +SHA1_CONSTRAINTS = StringConstraints( + strip_whitespace=True, + to_upper=True, + pattern=SHA1_REGEX, + max_length=40, + min_length=40, +) + + +class CrashingCommitReport(ShellphishBaseModel): + cp_source: str = Field( + description="The challenge project source linked to crashing commit" + ) + crashing_commit: Annotated[str, SHA1_CONSTRAINTS] = Field( + description="The crashing commit sha" + ) + sanitizer_ids: List[Annotated[str, ID_CONSTRAINTS]] = Field( + description="The sanitizer ids in the crash" + ) + crash_report_id: PDT_ID = Field(description="The pydatatask pov report id") + crash_id: PDT_ID = Field(description="The pydatatask crashing input id") + harness_id: PDT_ID = Field(description="The pydatatask harness info id") + + +class PatchVerificationRequest(ShellphishBaseModel): + project_id: PDT_ID = Field(description="The pydatatask target id") + harness_id: PDT_ID = Field(description="The pydatatask harness id") + patch_id: PDT_ID = Field(description="The pydatatask patch id") + crashing_commit_sha: Annotated[str, SHA1_CONSTRAINTS] = Field( + description="The crashing commit sha" + ) + crashing_commit_report_id: PDT_ID = Field( + description="The pydatatask crashing commit id" + ) + crash_report_representative_crashing_input_id: PDT_ID = Field( + description="The pydatatask representative crashing input id" + ) + sanitizer_id: str = Field( + description="The sanitizer id reported for triggering the crash" + ) + + +class PatchVerificationResult(ShellphishBaseModel): + patch_id: PDT_ID = Field(description="The pydatatask patch id") + still_crashing: bool = Field(description="Whether the patch still crashes") diff --git a/patchery/data/models/symbols.py b/patchery/data/models/symbols.py new file mode 100644 index 0000000..5cfc6d6 --- /dev/null +++ b/patchery/data/models/symbols.py @@ -0,0 +1,460 @@ +from enum import Enum +from pydantic import ( + field_validator, + ValidationInfo, + Field, + model_validator, +) +from typing import Optional +from pathlib import Path + +from 
patchery.data.models.base import ShellphishBaseModel + + +class RelativePathKind(str, Enum): + ARTIFACTS_DIR = "artifacts" + TARGET_ROOT = "target_root" + SOURCE_REPO = "source" + OSS_FUZZ = "oss-fuzz" + + +class BinaryLocation(ShellphishBaseModel): + file_name: Optional[Path] = Field( + description="The name of the binary occurred", default=None + ) + full_binary_path: Path = Field(description="The full path to the binary occurred") + + package: Optional[str] = Field( + default=None, description="The package (for java, python, etc.)" + ) + offset: Optional[int] = Field( + default=None, description="The offset of the symbol in the binary" + ) + function_name: Optional[str] = Field( + default=None, description="The name of the function" + ) + build_id: Optional[str] = Field( + default=None, description="The build id of the binary" + ) + raw_signature: Optional[str] = Field( + default=None, description="The signature of the function (if available)" + ) + + symbol_offset: Optional[int] = Field( + default=None, description="The offset of the symbol (Not currently used)" + ) + symbol_size: Optional[int] = Field( + default=None, description="The size of the symbol (Not currently used)" + ) + function_index_signature: Optional[str] = Field( + default=None, + description="Function signature defined by the function indexer convention (filename:start_line:start_column::signature)", + ) + function_index_key: Optional[str] = Field( + default=None, + description="The key index of the function in the function index (The same as the function signature)", + ) + + @model_validator(mode="after") + def sanity_check_model(self) -> "BinaryLocation": + if not self.file_name and not self.full_binary_path: + raise ValueError("Neither file name nor full binary path is available??") + if not self.file_name and self.full_binary_path: + raise ValueError( + "File name is not available despite full binary path being known" + ) + + return self + + # create a binary location + @classmethod + def create( + cls, + full_binary_path=None, + file_name=None, + package=None, + offset=None, + function_name=None, + build_id=None, + raw_signature=None, + symbol_offset=None, + symbol_size=None, + function_index_signature=None, + function_index_key=None, + ): + # create a binary location + v = {} + if full_binary_path: + v["full_binary_path"] = Path(full_binary_path) + if file_name: + v["file_name"] = Path(file_name) + elif full_binary_path: + v["file_name"] = Path(full_binary_path).name + if package: + v["package"] = package + if offset: + v["offset"] = offset + if function_name: + v["function_name"] = function_name + if build_id: + v["build_id"] = build_id + if raw_signature: + v["raw_signature"] = raw_signature + if symbol_offset: + v["symbol_offset"] = symbol_offset + if symbol_size: + v["symbol_size"] = symbol_size + if function_index_signature: + v["function_index_signature"] = function_index_signature + if function_index_key: + v["function_index_key"] = function_index_key + return cls(**v) + + +class JavaInfo(ShellphishBaseModel): + full_method_path: Optional[str] = Field( + examples=[ + "net.lingala.zip4j.model.AbstractFileHeader.getZip64ExtendedInfo", + "java.lang.ProcessBuilder.start", + ], + default=None, + description="The full method path", + ) + package: Optional[str] = Field( + examples=["net.lingala.zip4j.model", "java.lang"], + default=None, + description="The package (not including the class name.", + ) + class_path: Optional[str] = Field( + examples=[ + "net.lingala.zip4j.model.AbstractFileHeader", + 
"java.lang.ProcessBuilder", + ], + default=None, + description="The full class path", + ) + class_name: Optional[str] = Field( + examples=["AbstractFileHeader", "ProcessBuilder"], + default=None, + description="Only the class name", + ) + method_name: Optional[str] = Field( + examples=["getZip64ExtendedInfo", "start"], + default=None, + description="The method name", + ) + package_prefix: Optional[str] = Field( + examples=[None, "java.base"], + default=None, + description="The package prefix (java.base, java.xml, app/)", + ) + method_descriptor: Optional[str] = Field( + examples=[ + "()V", + "(Z)V", + "()Ljava/lang/Exception;", + "(Lnet/lingala/zip4j/progress/ProgressMonitor$Task;)V", + ], + default=None, + description="The method descriptor", + ) + is_native_method: Optional[bool] = Field( + default=False, description="Whether the method is a native method" + ) + + @model_validator(mode="after") + def sanity_check_model(self) -> "JavaInfo": + if self.full_method_path: + if not self.package and self.full_method_path.count(".") > 1: + raise ValueError( + "Package is not set despite full method path being known" + ) + if not self.class_path: + raise ValueError( + "Class path is not set despite full method path being known" + ) + if not self.class_name: + raise ValueError( + "Class name is not set despite full method path being known" + ) + if not self.method_name: + raise ValueError( + "Method name is not set despite full method path being known" + ) + if self.class_path: + # if we know the classpath we should at least know the package and class name + if not self.package and "." in self.class_path: + raise ValueError( + f"Package is not set despite class path being known, class path: {self.class_path}" + ) + if not self.class_name: + raise ValueError("Class name is not set despite class path being known") + + return self + + @field_validator("full_method_path") + def check_valid_full_method_path(cls, value: str, info: ValidationInfo): + # ensure the full_method_path is always in the format of "package.class.method" + if value is None: + return value + + if "/" in value: + # no slashes allowed in the full method path + raise ValueError("Full method path cannot contain slashes") + + return value + + @field_validator("package_prefix") + def check_valid_package_prefix(cls, value: str, info: ValidationInfo): + # Ensure the package prefix is always "java.base" + if value is None: + return value + + # if value != "java.base": + # raise ValueError("Package prefix must be 'java.base'") + + return value + + @field_validator("class_path") + def check_valid_class_path(cls, value: str, info: ValidationInfo): + # Ensure the class path always contains dots, not slashes + if value is None: + return value + + if "/" in value: + raise ValueError("Class path cannot contain slashes") + + return value + + @field_validator("package") + def check_valid_package(cls, value: str, info: ValidationInfo): + # Ensure the package always contains dots, not slashes + if value is None: + return value + + if "/" in value: + raise ValueError("Package cannot contain slashes") + + return value + + @field_validator("class_name") + def check_valid_class_name(cls, value: str, info: ValidationInfo): + # Ensure the class name does not contain dots + if value is None: + return value + + if "." 
in value: + raise ValueError("Class name cannot contain dots") + + return value + + @field_validator("method_name") + def check_valid_method_name(cls, value: str, info: ValidationInfo): + # Ensure the method name does not contain dots + if value is None: + return value + + if "." in value: + raise ValueError("Method name cannot contain dots") + + return value + + @field_validator("method_descriptor") + def check_valid_method_descriptor(cls, value: str, info: ValidationInfo): + # Ensure the method descriptor is always in the format of "()V", "(Z)V", "()Ljava/lang/Exception;", "(Lnet/lingala/zip4j/progress/ProgressMonitor$Task;)V" + if value is None: + return value + + open_paren = value.count("(") + close_paren = value.count(")") + if open_paren != 1 or close_paren != 1: + raise ValueError( + "Method descriptor must contain exactly one open and close parenthesis" + ) + + if value[0] != "(": + raise ValueError("Method descriptor must start with an open parenthesis") + + if value[-1] == ")": + raise ValueError( + "Method descriptor must not end with a close parenthesis, as it should always be followed by the return type" + ) + + return value + + +class SourceLocation(ShellphishBaseModel): + # CAUTION: You should (if possible) always use source_relative_file_path. + # However, if that is not available, you can attempt to use full_file_path. + # For java crashes, the full file path is not available, in which case you should be able to at least use the file_name + # together with the method info and such from the JavaInfo object to locate the source. + focus_repo_relative_path: Optional[Path] = Field( + default=None, + description="If we know for sure this is in the focus repo, this is the path relative to the source repo.", + ) + + relative_path: Optional[Path] = Field( + description="The path to the source code of the method. The root of this path is unspecified.", + default=None, + ) + + full_file_path: Optional[Path] = Field( + default=None, + description="The full path to the file where the crash occurred. This is not very reusable as it might contain run-specific paths. Only ever use this if focus_repo_relative_path and relative_path are not available.", + ) + file_name: Optional[Path] = Field( + default=None, + description="The name of the file where the crash occurred. 
This might be set even if the full and relative paths are not known.", + ) + function_name: Optional[str] = Field( + default=None, description="The name of the function " + ) + + line_text: Optional[str] = Field(default=None, description="The line of code") + line_number: Optional[int] = Field(default=None, description="The line number") + symbol_offset: Optional[int] = Field( + default=None, description="The offset of the symbol (Not currently used)" + ) + symbol_size: Optional[int] = Field( + default=None, description="The size of the symbol (Not currently used)" + ) + + raw_signature: Optional[str] = Field( + default=None, description="The signature of the function (if available)" + ) + + function_index_signature: Optional[str] = Field( + default=None, + description="Function signature defined by the function indexer convention (filename:start_line:start_column::signature)", + ) + function_index_key: Optional[str] = Field( + default=None, + description="The key index of the function in the function index (The same as the function signature)", + ) + + java_info: Optional[JavaInfo] = Field( + default=None, description="Java specific information" + ) + + def __hash__(self): + return hash(self.model_dump_json()) + + def __eq__(self, other): + return self.model_dump() == other.model_dump() + + @classmethod + def create( + cls, + full_file_path=None, + relative_path=None, + file_name=None, + function_name=None, + line_number=None, + line_text=None, + symbol_offset=None, + symbol_size=None, + raw_signature=None, + focus_repo_container_path=None, + focus_repo_relative_path=None, + function_index_signature=None, + function_index_key=None, + java_info=None, + ): + v = {} + if full_file_path: + v["full_file_path"] = Path(full_file_path) + if relative_path: + v["relative_path"] = Path(relative_path) + if file_name or relative_path or full_file_path: + v["file_name"] = Path( + file_name + or (relative_path.name if relative_path else None) + or full_file_path.name + ) + if function_name: + v["function_name"] = function_name + if line_number: + v["line_number"] = line_number + if line_text: + v["line_text"] = line_text + if symbol_offset: + v["symbol_offset"] = symbol_offset + if symbol_size: + v["symbol_size"] = symbol_size + if raw_signature: + v["raw_signature"] = raw_signature + if focus_repo_relative_path: + v["focus_repo_relative_path"] = Path(focus_repo_relative_path) + elif focus_repo_container_path and full_file_path: + v["focus_repo_relative_path"] = Path(full_file_path).relative_to( + focus_repo_container_path + ) + return cls(**v) + + @model_validator(mode="after") + def sanity_check_model(self) -> "SourceLocation": + if not self.file_name and (self.full_file_path or self.relative_path): + raise ValueError( + "File name is not available despite full and/or relative paths being known" + ) + if ( + not self.file_name + and self.java_info + and self.java_info.class_name + and not self.java_info.is_native_method + ): + raise ValueError( + "File name is not available despite Java class name being known" + ) + if "/" in str(self.file_name): + raise ValueError("File name contains a slash") + + if self.file_name and self.relative_path: + assert ( + self.file_name.name == self.relative_path.name + ), f"File name {self.file_name} does not match the name of the relative path {self.relative_path.name}" + if self.file_name and self.full_file_path: + assert ( + self.file_name.name == self.full_file_path.name + ), f"File name {self.file_name} does not match the name of the full file path 
{self.full_file_path.name}" + + if self.function_index_key and not self.function_index_signature: + raise ValueError( + "Function index key is set but function index signature is not" + ) + if self.function_index_signature and not self.function_index_key: + raise ValueError( + "Function index signature is set but function index key is not" + ) + + if self.function_index_key: + # if we have the function index key we should know a bunch of things + if not self.file_name: + raise ValueError( + "File name is not set despite function index key being known" + ) + # if not self.line_number: + # raise ValueError('Line number is not set despite function index key being known') + if not self.function_name: + raise ValueError( + "Function name is not set despite function index key being known" + ) + if not self.full_file_path: + raise ValueError( + "Full file path is not set despite function index key being known" + ) + if not self.raw_signature: + raise ValueError( + "Raw signature is not set despite function index key being known" + ) + return self + + +class POI(ShellphishBaseModel): + reason: Optional[str] = Field( + description="The reason for the POI, normally a description of the crash", + default=None, + ) + source_location: SourceLocation = Field( + description="The source location of the crash" + ) diff --git a/patchery/data/models/target.py b/patchery/data/models/target.py new file mode 100644 index 0000000..e5a2ccc --- /dev/null +++ b/patchery/data/models/target.py @@ -0,0 +1,117 @@ +from pydantic import ( + Field, +) +from typing import Optional, TypeAlias +from pathlib import Path + +from patchery.data.models.base import ShellphishBaseModel + +from patchery.data.models.constraints import PDT_ID +from patchery.data.models.oss_fuzz import ArchitectureEnum, SanitizerEnum + +VALID_SOURCE_FILE_SUFFIXES_C = [ + ".c", + ".cpp", + ".cc", + ".cxx", + ".c++", + ".h", + ".hpp", + ".hh", + ".hxx", + ".h++", + ".inl", +] +VALID_SOURCE_FILE_SUFFIXES_JVM = [ + ".java", + #'.kt', '.scala', + #'.groovy', '.clj', '.cljs', '.cljc', '.edn', +] +VALID_SOURCE_FILE_SUFFIXES = ( + VALID_SOURCE_FILE_SUFFIXES_C + VALID_SOURCE_FILE_SUFFIXES_JVM +) + + +class ProjectInfoMixin: + project_id: PDT_ID = Field(description="The pydatatask target id") + project_name: str = Field(description="The oss fuzz project name") + + @property + def project_info(self): + return { + "project_id": self.project_id, + "project_name": self.project_name, + } + + +class BuildInfoMixin: + sanitizer: SanitizerEnum = Field( + description="The sanitizer used in this target configuration" + ) + architecture: ArchitectureEnum = Field( + description="The architecture for this target configuration" + ) + + @property + def build_info(self): + return { + "sanitizer": self.sanitizer, + "architecture": self.architecture, + } + + +class HarnessInfoMixin: + cp_harness_name: str = Field(description="The challenge project harness name") + cp_harness_binary_path: Path = Field( + description="The challenge project harness binary path" + ) + entrypoint_function: Optional[str] = Field( + description="The entrypoint function of the harness", default=None + ) + source_entrypoint: Optional[Path] = Field( + description="The source file which contains the entrypoint of the harness", + default=None, + ) + + @property + def harness_info(self): + result = { + "cp_harness_name": self.cp_harness_name, + "cp_harness_binary_path": self.cp_harness_binary_path, + } + if self.entrypoint_function: + result["entrypoint_function"] = self.entrypoint_function + if 
self.source_entrypoint: + result["source_entrypoint"] = self.source_entrypoint + return result + + +class ProjectHarnessMetadata(ShellphishBaseModel, ProjectInfoMixin, HarnessInfoMixin): + pass + + +class BuildConfiguration(ShellphishBaseModel, ProjectInfoMixin, BuildInfoMixin): + pass + + +class HarnessInfo( + ShellphishBaseModel, ProjectInfoMixin, BuildInfoMixin, HarnessInfoMixin +): + build_configuration_id: PDT_ID = Field( + description="The pydatatask build configuration id" + ) + project_harness_metadata_id: Optional[PDT_ID] = Field( + default=None, description="The pydatatask project harness metadata id" + ) + + +class CrashingInputMetadata(HarnessInfo): + harness_info_id: PDT_ID = Field(description="The pydatatask harness info id") + fuzzer: str = Field(description="The fuzzer used to generate the crashing input") + generated_by_sarif: Optional[str] = Field( + description="The SARIF file used to generate the crashing input", default=None + ) + + +HARNESS_NAME: TypeAlias = str +PROJECT_NAME: TypeAlias = str diff --git a/patchery/data/program.py b/patchery/data/program.py index cbc9061..d0d57b8 100644 --- a/patchery/data/program.py +++ b/patchery/data/program.py @@ -15,17 +15,22 @@ HAS_AIXCC = True try: - from shellphish_crs_utils.function_resolver import FunctionResolver, LocalFunctionResolver, RemoteFunctionResolver + from patchery.data.function_resolver import ( + FunctionResolver, + LocalFunctionResolver, + RemoteFunctionResolver, + ) except ImportError: HAS_AIXCC = False _l = logging.getLogger(__name__) + class Program: def __init__( self, source_root: Path, - function_resolver = None, + function_resolver=None, crashing_inputs: list[ProgramInput] | None = None, language=None, should_init_resolver: bool = False, @@ -34,7 +39,9 @@ def __init__( self.language = language self._crashing_inputs = crashing_inputs or [] self._should_init_resolver = should_init_resolver - self.function_resolver = function_resolver if self._should_init_resolver else None + self.function_resolver = ( + function_resolver if self._should_init_resolver else None + ) # save the args to recreate the function resolver later (if needed for pickling) if HAS_AIXCC: if isinstance(function_resolver, RemoteFunctionResolver): @@ -50,7 +57,11 @@ def __init__( ) self._saved_resolver_cls = LocalFunctionResolver - self._git_repo = git.Repo(str(self.source_root)) if (self.source_root / ".git").exists() else None + self._git_repo = ( + git.Repo(str(self.source_root)) + if (self.source_root / ".git").exists() + else None + ) self._versioned_code = {} self._latest_code = None self.crashing_function: typing.Optional[str] = None @@ -66,8 +77,7 @@ def copy(self, **kwargs) -> "Program": temp_dir = Path(tempfile.mkdtemp()) shutil.copytree(self.source_root, temp_dir, dirs_exist_ok=True) return Program( - temp_dir, crashing_inputs=self._crashing_inputs, - language=self.language + temp_dir, crashing_inputs=self._crashing_inputs, language=self.language ) def cleanup(self): @@ -79,7 +89,6 @@ def cleanup(self): if self._git_repo is not None: self._git_repo.close() - def setup_program(self): pass @@ -169,7 +178,9 @@ def git_diff(self, patch: "Patch"): for patched_func in patched_funcs: grouped_funcs[patched_func.file].append(patched_func) for patched_file in grouped_funcs: - new_code = self.file_patch_to_new_file(grouped_funcs[patched_file], lang=self.language) + new_code = self.file_patch_to_new_file( + grouped_funcs[patched_file], lang=self.language + ) with open(patched_file, "w") as f: f.write(new_code) @@ -183,7 +194,9 @@ def 
git_diff(self, patch: "Patch"): return patch.diff - def update_pois_for_src_path(self, poi_clusters: list[PoICluster]) -> list[PoICluster]: + def update_pois_for_src_path( + self, poi_clusters: list[PoICluster] + ) -> list[PoICluster]: new_clusters = [] for cluster in poi_clusters: updated_pois = [] @@ -241,7 +254,11 @@ def file_patch_to_new_file(grouped_funcs: "List[PatchedFunction]", lang="C") -> line_offset += new_lines_count - old_lines_count # Apply the replacement to the old_code_lines - old_code_lines = old_code_lines[:adjusted_start - 1] + new_code + old_code_lines[adjusted_end:] + old_code_lines = ( + old_code_lines[: adjusted_start - 1] + + new_code + + old_code_lines[adjusted_end:] + ) # Return the modified code as a string return "\n".join(old_code_lines) @@ -250,7 +267,9 @@ def file_patch_to_new_file(grouped_funcs: "List[PatchedFunction]", lang="C") -> def load_inputs_from_dir(input_dir: Path): inputs = [] for input_file in input_dir.iterdir(): - inputs.append(ProgramInput(input_file.absolute().read_bytes(), ProgramInputType.FILE)) + inputs.append( + ProgramInput(input_file.absolute().read_bytes(), ProgramInputType.FILE) + ) return inputs @@ -258,7 +277,13 @@ def load_inputs_from_dir(input_dir: Path): # Compilation & Building # - def compile(self, patch: typing.Optional["Patch"] = None, edited_in_place=False, flags=None, **kwargs) -> tuple[bool, str]: + def compile( + self, + patch: typing.Optional["Patch"] = None, + edited_in_place=False, + flags=None, + **kwargs, + ) -> tuple[bool, str]: git_diff = None source_path = str(Path(self.source_root).resolve().absolute()) if patch is None: @@ -278,13 +303,17 @@ def compile(self, patch: typing.Optional["Patch"] = None, edited_in_place=False, git_diff = self.git_diff(patch) if git_diff is None: - raise ValueError("Failed to create diff in a scenario where we either had no Patch or no edits") + raise ValueError( + "Failed to create diff in a scenario where we either had no Patch or no edits" + ) with tempfile.NamedTemporaryFile(delete=False) as f: f.write(git_diff.encode()) f.seek(0) patch_path = Path(f.name) - compile_res = self._compile_core(patch_path=patch_path, patch_obj=patch, flags=flags, **kwargs) + compile_res = self._compile_core( + patch_path=patch_path, patch_obj=patch, flags=flags, **kwargs + ) return compile_res @@ -301,28 +330,38 @@ def rsync_copy_dir(source: Path, dest: Path): # Execution # - def _check_functionality_core(self, **kwargs) -> tuple[ProgramExitType, typing.Optional[str]]: + def _check_functionality_core( + self, **kwargs + ) -> tuple[ProgramExitType, typing.Optional[str]]: raise NotImplementedError("Subclasses must implement this method.") - def check_functionality(self, patch: typing.Optional["Patch"] = None, **kwargs) -> tuple[ProgramExitType, typing.Optional[str]]: + def check_functionality( + self, patch: typing.Optional["Patch"] = None, **kwargs + ) -> tuple[ProgramExitType, typing.Optional[str]]: # we need to make the patch in an acceptable form git_diff = self.git_diff(patch) if patch is not None else None if patch is not None and git_diff is None: - raise ValueError("Failed to create diff in a scenario where we either had no Patch or no edits") + raise ValueError( + "Failed to create diff in a scenario where we either had no Patch or no edits" + ) if git_diff: with tempfile.NamedTemporaryFile(delete=False) as f: f.write(git_diff.encode()) f.seek(0) patch_path = Path(f.name) - return self._check_functionality_core(patch_path=patch_path, patch_obj=patch, **kwargs) + return 
self._check_functionality_core( + patch_path=patch_path, patch_obj=patch, **kwargs + ) else: return self._check_functionality_core(**kwargs) def execute(self, prog_input: ProgramInput): raise NotImplementedError - def generates_alerts(self, prog_input: ProgramInput) -> Tuple[ProgramExitType, str | None, list]: + def generates_alerts( + self, prog_input: ProgramInput + ) -> Tuple[ProgramExitType, str | None, list]: raise NotImplementedError def triggers_alert(self, prog_input: ProgramInput) -> bool: diff --git a/patchery/kumushi/aixcc/aicc_program.py b/patchery/kumushi/aixcc/aicc_program.py index a125031..d0982ed 100644 --- a/patchery/kumushi/aixcc/aicc_program.py +++ b/patchery/kumushi/aixcc/aicc_program.py @@ -1,27 +1,33 @@ import json import logging +import os import shutil import tempfile import typing from pathlib import Path from typing import List, Optional, Tuple -import requests -import os +import requests import yaml -from git import Repo, InvalidGitRepositoryError, NoSuchPathError -from shellphish_crs_utils.models import POIReport, PatchRequestMeta, RootCauseReport -from shellphish_crs_utils.models.oss_fuzz import AugmentedProjectMetadata -from shellphish_crs_utils.models.testguy import TestGuyLibMetaData - +from git import InvalidGitRepositoryError, NoSuchPathError, Repo +from patchery.data.function_resolver import ( + FunctionResolver, + LocalFunctionResolver, + RemoteFunctionResolver, +) +from patchery.data.models import PatchRequestMeta, POIReport, RootCauseReport +from patchery.data.models.oss_fuzz import AugmentedProjectMetadata +# from shellphish_crs_utils.models.testguy import TestGuyLibMetaData +# from shellphish_crs_utils.oss_fuzz.project import OSSFuzzProject +from patchery.data.program import Program +from patchery.data.program_alert import ProgramAlert, ProgramExitType from patchery.data.program_input import ProgramInput, ProgramInputType -from shellphish_crs_utils.oss_fuzz.project import OSSFuzzProject -from shellphish_crs_utils.function_resolver import LocalFunctionResolver, RemoteFunctionResolver, FunctionResolver -from patchery.data.program import Program -from patchery.data.program_alert import ProgramExitType, ProgramAlert +class OSSFuzzProject: + pass + if typing.TYPE_CHECKING: from . 
import AICCProgram @@ -51,7 +57,8 @@ def __init__( function_resolver: FunctionResolver = None, functions_by_commit_jsons_dir: Path = None, indices_by_commit_path: Path = None, - diffguy_funcs: list = None,patch_request_metadata: PatchRequestMeta = None, + diffguy_funcs: list = None, + patch_request_metadata: PatchRequestMeta = None, crashing_input_dir: Path = None, previously_built: bool = False, dyva_report: RootCauseReport = None, @@ -60,7 +67,12 @@ def __init__( build_checker_works: bool = False, **kwargs, ): - super().__init__(source_root, function_resolver=function_resolver, should_init_resolver=should_init_resolver, **kwargs) + super().__init__( + source_root, + function_resolver=function_resolver, + should_init_resolver=should_init_resolver, + **kwargs, + ) self.target_project = target_project self.harness_name = harness_name self.sanitizer_string = sanitizer_string @@ -76,7 +88,9 @@ def __init__( self.indices_by_commit_path = indices_by_commit_path self.diffguy_funcs = diffguy_funcs self.patch_request_metadata = patch_request_metadata - self.crashing_input_dir = Path(crashing_input_dir) if crashing_input_dir else None + self.crashing_input_dir = ( + Path(crashing_input_dir) if crashing_input_dir else None + ) self.dyva_report = dyva_report self.crashing_function = self._recover_crashing_function_name() self.bypassing_input_path = bypassing_input_path @@ -87,23 +101,37 @@ def _recover_crashing_function_name(self) -> Optional[str]: stack_trace = self.poi_report.stack_traces.get("main", None) if stack_trace: call_locations = stack_trace.call_locations - if call_locations and call_locations[0].source_location is not None and call_locations[0].source_location.function_name: + if ( + call_locations + and call_locations[0].source_location is not None + and call_locations[0].source_location.function_name + ): return call_locations[0].source_location.function_name - _l.critical("Failed to recover crashing function name from POI report. This is unexpected!") + _l.critical( + "Failed to recover crashing function name from POI report. This is unexpected!" 
+ ) return None def copy(self, pre_built=False, **kwargs) -> "AICCProgram": - Path(f"/shared/patchery/{self.poi_report.project_id}").mkdir(parents=True, exist_ok=True) + Path(f"/shared/patchery/{self.poi_report.project_id}").mkdir( + parents=True, exist_ok=True + ) # first make a central folder for the new source and the new oss fuzz project - new_dir = Path(tempfile.mkdtemp(dir=f"/shared/patchery/{self.poi_report.project_id}/")) + new_dir = Path( + tempfile.mkdtemp(dir=f"/shared/patchery/{self.poi_report.project_id}/") + ) # copy the oss fuzz project new_oss_fuzz_project_path = new_dir / f"{self.target_project.project_path.name}" new_oss_fuzz_project_path.mkdir(parents=True, exist_ok=True) - shutil.copytree(self.target_project.project_path, new_oss_fuzz_project_path, dirs_exist_ok=True) + shutil.copytree( + self.target_project.project_path, + new_oss_fuzz_project_path, + dirs_exist_ok=True, + ) # copy the source code as well - new_source_path = new_dir / 'source-root' + new_source_path = new_dir / "source-root" new_source_path.mkdir(parents=True, exist_ok=True) shutil.copytree(self.source_root, new_source_path, dirs_exist_ok=True) @@ -159,29 +187,36 @@ def cleanup(self): def check_and_set_build_checker_works(self): _l.info("Checking build checker works") if not self.poi_report or not self.poi_report.build_configuration_id: - _l.warning("POI report or build configuration ID is missing, cannot check build checker") + _l.warning( + "POI report or build configuration ID is missing, cannot check build checker" + ) self.build_checker_works = False return build_configuration_id = self.poi_report.build_configuration_id try: import yaml + resp = requests.get( - f'{os.environ.get("PDT_AGENT_URL")}/data/verify_build_check_works/build_check_success/{build_configuration_id}', - timeout=180) + f"{os.environ.get('PDT_AGENT_URL')}/data/verify_build_check_works/build_check_success/{build_configuration_id}", + timeout=180, + ) if resp.status_code == 200: check_data = yaml.safe_load(resp.text) - check_success = check_data.get('runs', None) + check_success = check_data.get("runs", None) self.build_checker_works = check_success is True else: self.build_checker_works = False except Exception as e: import traceback + traceback.print_exc() self.build_checker_works = False if not self.build_checker_works: - _l.warning("Build checker does not work, will not use it for build checking") + _l.warning( + "Build checker does not work, will not use it for build checking" + ) def _build_containers(self): if not self._previously_built: @@ -190,7 +225,9 @@ def _build_containers(self): self.target_project.build_runner_image() self._previously_built = True - def _compile_core(self, patch_path: Optional[Path] = None, patch_obj = None, flags=None, **kwargs) -> Tuple[bool, str]: + def _compile_core( + self, patch_path: Optional[Path] = None, patch_obj=None, flags=None, **kwargs + ) -> Tuple[bool, str]: self._build_containers() print_output = kwargs.get("print_output", False) if patch_path is not None: @@ -203,8 +240,12 @@ def _compile_core(self, patch_path: Optional[Path] = None, patch_obj = None, fla extra_env = {"CFLAGS": flags} build_result = self.target_project.build_target( - patch_path=str(patch_path), sanitizer=self.sanitizer_string, print_output=print_output, preserve_built_src_dir=True, - extra_env=extra_env, get_cached_build=get_cached_build + patch_path=str(patch_path), + sanitizer=self.sanitizer_string, + print_output=print_output, + preserve_built_src_dir=True, + extra_env=extra_env, + 
get_cached_build=get_cached_build, ) if patch_obj is not None and hasattr(patch_obj, "metadata"): patch_obj.metadata["build_request_id"] = build_result.build_request_id @@ -213,12 +254,12 @@ def _compile_core(self, patch_path: Optional[Path] = None, patch_obj = None, fla build_passed = build_result.build_success # in local run, task success mand build pass mean the same if not build_passed: - stdout = build_result.stdout.decode(errors='ignore') - stderr = build_result.stderr.decode(errors='ignore') + stdout = build_result.stdout.decode(errors="ignore") + stderr = build_result.stderr.decode(errors="ignore") _l.debug(f"Compilation failed: stdout {stdout}") _l.debug(f"Compilation failed: stderr {stderr}") # FIXME: actual compilation output is saved in self.target_project.artifacts_dir_docker - reason = f"Compilation failed.\n" + f'{stderr}' + reason = f"Compilation failed.\n" + f"{stderr}" if self.language == "jvm": reason = "" lines = stdout.replace("\\n", "\n").split("\n") @@ -232,7 +273,10 @@ def _compile_core(self, patch_path: Optional[Path] = None, patch_obj = None, fla if patch_path is not None: patch_path = Path(patch_path).absolute() build_result = self.target_project.build_target( - patch_path=str(patch_path), sanitizer=self.sanitizer_string, print_output=False, preserve_built_src_dir=True + patch_path=str(patch_path), + sanitizer=self.sanitizer_string, + print_output=False, + preserve_built_src_dir=True, ) task_success = build_result.task_success build_passed = build_result.build_success @@ -244,10 +288,19 @@ def _compile_core(self, patch_path: Optional[Path] = None, patch_obj = None, fla return build_passed, reason def setup_program(self): - Path(f"/shared/patchery/{self.poi_report.project_id}").mkdir(parents=True, exist_ok=True) - assert self.target_project.project_source and self.target_project.project_source.exists(), f"Missing project source: {self.target_project.project_source}" - self.target_project.project_source = Path(self.target_project.project_source).absolute() - assert self.target_project.project_source.is_dir(), f"Project source is not a directory: {self.target_project.project_source}" + Path(f"/shared/patchery/{self.poi_report.project_id}").mkdir( + parents=True, exist_ok=True + ) + assert ( + self.target_project.project_source + and self.target_project.project_source.exists() + ), f"Missing project source: {self.target_project.project_source}" + self.target_project.project_source = Path( + self.target_project.project_source + ).absolute() + assert ( + self.target_project.project_source.is_dir() + ), f"Project source is not a directory: {self.target_project.project_source}" try: worked = Repo(self.target_project.project_source).git_dir @@ -259,17 +312,26 @@ def setup_program(self): try: Repo.init(self.target_project.project_source) except Exception as e: - raise Exception(f"Failed to initialize git repository at {self.target_project.project_source}: {e}") + raise Exception( + f"Failed to initialize git repository at {self.target_project.project_source}: {e}" + ) self._build_containers() - def generates_alerts(self, prog_input: "ProgramInput") -> Tuple[ProgramExitType, str | None, list[str]]: - raw_data = prog_input.data.encode() if isinstance(prog_input.data, str) else prog_input.data + def generates_alerts( + self, prog_input: "ProgramInput" + ) -> Tuple[ProgramExitType, str | None, list[str]]: + raw_data = ( + prog_input.data.encode() + if isinstance(prog_input.data, str) + else prog_input.data + ) run_pov_res = self.target_project.run_pov( - 
harness=self.harness_name, data=raw_data, + harness=self.harness_name, + data=raw_data, sanitizer=self.sanitizer_string, fuzzing_engine=self.project_metadata.shellphish.fuzzing_engine.value, - timeout=30 + timeout=30, ) pov = run_pov_res.pov pov_report_data = None @@ -283,12 +345,16 @@ def generates_alerts(self, prog_input: "ProgramInput") -> Tuple[ProgramExitType, main_stack_trace = pov.crash_report.stack_traces.get("main", None) if main_stack_trace is not None and main_stack_trace.call_locations: for call_location in main_stack_trace.call_locations: - if call_location.source_location and call_location.source_location.function_name: - stack_trace_functions.append(call_location.source_location.function_name) + if ( + call_location.source_location + and call_location.source_location.function_name + ): + stack_trace_functions.append( + call_location.source_location.function_name + ) else: stack_trace_functions.append("") - if pov.triggered_sanitizers: pov_report_data = pov.crash_report.raw_report or pov.unparsed if isinstance(pov_report_data, bytes): @@ -300,14 +366,23 @@ def generates_alerts(self, prog_input: "ProgramInput") -> Tuple[ProgramExitType, return alert._exit_type, pov_report_data, stack_trace_functions def execute(self, prog_input: "ProgramInput") -> tuple[str, str]: - raw_data = prog_input.data.encode() if isinstance(prog_input.data, str) else prog_input.data + raw_data = ( + prog_input.data.encode() + if isinstance(prog_input.data, str) + else prog_input.data + ) run_pov_res = self.target_project.run_pov( - harness=self.harness_name, data=raw_data, print_output=False, sanitizer=self.sanitizer_string, - fuzzing_engine=self.project_metadata.shellphish.fuzzing_engine.value + harness=self.harness_name, + data=raw_data, + print_output=False, + sanitizer=self.sanitizer_string, + fuzzing_engine=self.project_metadata.shellphish.fuzzing_engine.value, ) return run_pov_res.stdout, run_pov_res.stderr - def _check_functionality_core(self, patch_path: Optional[Path] = None, **kwargs) -> tuple[ProgramExitType, Optional[str]]: + def _check_functionality_core( + self, patch_path: Optional[Path] = None, **kwargs + ) -> tuple[ProgramExitType, Optional[str]]: run_result = self.target_project.run_tests( patch_path=patch_path, sanitizer=self.sanitizer_string, @@ -315,7 +390,11 @@ def _check_functionality_core(self, patch_path: Optional[Path] = None, **kwargs) ) if run_result.tests_exist: output = run_result.stderr or run_result.stdout - return (ProgramExitType.NORMAL, None) if run_result.all_passed else (ProgramExitType.TEST_FAILED, output) + return ( + (ProgramExitType.NORMAL, None) + if run_result.all_passed + else (ProgramExitType.TEST_FAILED, output) + ) else: return ProgramExitType.NORMAL, None @@ -336,67 +415,77 @@ def apply_refine_patch(self) -> bool | None: @classmethod def from_files( - cls, - source_root: Path, - # artiphishell generated files - ossfuzz_project_root: Path, - metadata_path: Path, - poi_report_path: Path, - function_indices: Path, - function_json_dir: Path, - indices_by_commit: Path | None = None, - functions_by_commit_jsons_dir: Path | None = None, - delta_mode: bool = False, - # general - crashing_input_paths: List[Path] = None, - benign_input_paths: List[Path] = None, - # coverage - coverage_build_project_path: Path = None, - aflpp_build_project_path: Path = None, - local_run: bool = False, - # diffguy - diffguy_report_path: Path = None, - patch_request_meta: Path = None, - # crash_exploration - crashing_input_dir: Path = None, - debug_build_project_path: Path = None, - # 
dyva - dyva_report_path: Path = None, - # bypassings inputs - bypassing_input_path: Path = None, - should_init_resolver: bool = False, - **kwargs, + cls, + source_root: Path, + # artiphishell generated files + ossfuzz_project_root: Path, + metadata_path: Path, + poi_report_path: Path, + function_indices: Path, + function_json_dir: Path, + indices_by_commit: Path | None = None, + functions_by_commit_jsons_dir: Path | None = None, + delta_mode: bool = False, + # general + crashing_input_paths: List[Path] = None, + benign_input_paths: List[Path] = None, + # coverage + coverage_build_project_path: Path = None, + aflpp_build_project_path: Path = None, + local_run: bool = False, + # diffguy + diffguy_report_path: Path = None, + patch_request_meta: Path = None, + # crash_exploration + crashing_input_dir: Path = None, + debug_build_project_path: Path = None, + # dyva + dyva_report_path: Path = None, + # bypassings inputs + bypassing_input_path: Path = None, + should_init_resolver: bool = False, + **kwargs, ): # fix paths and assert they exist when mandatory source_root = Path(source_root).absolute() assert source_root.exists(), f"Source root does not exist: {source_root}" ossfuzz_project_root = Path(ossfuzz_project_root).absolute() - assert ossfuzz_project_root.exists(), f"OSSFuzz project root does not exist: {ossfuzz_project_root}" + assert ( + ossfuzz_project_root.exists() + ), f"OSSFuzz project root does not exist: {ossfuzz_project_root}" metadata_path = Path(metadata_path).absolute() assert metadata_path.exists(), f"Metadata path does not exist: {metadata_path}" poi_report_path = Path(poi_report_path).absolute() - assert poi_report_path.exists(), f"POI report path does not exist: {poi_report_path}" + assert ( + poi_report_path.exists() + ), f"POI report path does not exist: {poi_report_path}" if indices_by_commit is not None: indices_by_commit = Path(indices_by_commit).absolute() - assert indices_by_commit.exists(), f"indices_by_commit path does not exist: {indices_by_commit}" + assert ( + indices_by_commit.exists() + ), f"indices_by_commit path does not exist: {indices_by_commit}" if functions_by_commit_jsons_dir is not None: - functions_by_commit_jsons_dir = Path(functions_by_commit_jsons_dir).absolute() + functions_by_commit_jsons_dir = Path( + functions_by_commit_jsons_dir + ).absolute() assert functions_by_commit_jsons_dir.exists(), f"functions_by_commit_jsons_dir path does not exist: {functions_by_commit_jsons_dir}" _l.info("Loading AICCProgram from files") # read the project metadata with metadata_path.open("r") as f: - project_metadata_data = AugmentedProjectMetadata.model_validate(yaml.safe_load(f)) + project_metadata_data = AugmentedProjectMetadata.model_validate( + yaml.safe_load(f) + ) # read the poi report data with open(poi_report_path, "r") as f: # rep = yaml.safe_load(f) - #rep["organizer_crash_eval"] = {} - #rep["organizer_crash_eval"]["code_label"] = "" - #rep["organizer_crash_eval"]["significance"] = 0 - #rep["organizer_crash_eval"]["significance_message"] = "" - #rep["organizer_crash_eval"]["crash_state"] = "" + # rep["organizer_crash_eval"] = {} + # rep["organizer_crash_eval"]["code_label"] = "" + # rep["organizer_crash_eval"]["significance"] = 0 + # rep["organizer_crash_eval"]["significance_message"] = "" + # rep["organizer_crash_eval"]["crash_state"] = "" poi_report_data = POIReport.model_validate(yaml.safe_load(f)) # read all the input files @@ -409,14 +498,20 @@ def from_files( benign_inputs.append(ProgramInput(f.read(), ProgramInputType.FILE)) if crashing_input_paths 
is not None: if isinstance(crashing_input_paths, list): - crashing_input_paths = [Path(p).absolute() for p in crashing_input_paths] + crashing_input_paths = [ + Path(p).absolute() for p in crashing_input_paths + ] for input_file in crashing_input_paths: with open(input_file, "rb") as f: - crashing_inputs.append(ProgramInput(f.read(), ProgramInputType.FILE)) + crashing_inputs.append( + ProgramInput(f.read(), ProgramInputType.FILE) + ) elif isinstance(crashing_input_paths, str): crashing_input_paths = Path(crashing_input_paths).absolute() with open(crashing_input_paths, "rb") as f: - crashing_inputs.append(ProgramInput(f.read(), ProgramInputType.FILE)) + crashing_inputs.append( + ProgramInput(f.read(), ProgramInputType.FILE) + ) # load the ossfuzz project oss_fuzz_project = OSSFuzzProject( @@ -425,17 +520,25 @@ def from_files( project_source=source_root, use_task_service=not local_run, ) - #oss_fuzz_project.project_metadata.shellphish_project_name = "nginx" + # oss_fuzz_project.project_metadata.shellphish_project_name = "nginx" # read the clang info if local_run: function_indices = Path(function_indices).absolute() - assert function_indices.exists(), f"Function indices path does not exist: {function_indices}" + assert ( + function_indices.exists() + ), f"Function indices path does not exist: {function_indices}" function_json_dir = Path(function_json_dir).absolute() - assert function_json_dir.exists(), f"Function JSON directory does not exist: {function_json_dir}" - function_resolver = LocalFunctionResolver(str(function_indices.resolve()), str(function_json_dir.resolve())) + assert ( + function_json_dir.exists() + ), f"Function JSON directory does not exist: {function_json_dir}" + function_resolver = LocalFunctionResolver( + str(function_indices.resolve()), str(function_json_dir.resolve()) + ) else: - function_resolver = RemoteFunctionResolver(poi_report_data.project_name, poi_report_data.project_id) + function_resolver = RemoteFunctionResolver( + poi_report_data.project_name, poi_report_data.project_id + ) # read diff guy report diffguy_funcs = None @@ -443,23 +546,23 @@ def from_files( with open(diffguy_report_path, "r") as f: tmp = json.load(f) diffguy_funcs = [] - if 'overlap' in tmp and tmp['overlap']: - diffguy_funcs = tmp['overlap'] - elif 'heuristic' in tmp and tmp['heuristic']: - diffguy_funcs = tmp['heuristic'] - elif 'union' in tmp and tmp['union']: - diffguy_funcs = tmp['union'] + if "overlap" in tmp and tmp["overlap"]: + diffguy_funcs = tmp["overlap"] + elif "heuristic" in tmp and tmp["heuristic"]: + diffguy_funcs = tmp["heuristic"] + elif "union" in tmp and tmp["union"]: + diffguy_funcs = tmp["union"] patch_request_metadata = None if patch_request_meta is not None and not Path(patch_request_meta).is_dir(): with open(patch_request_meta, "r") as f: data = yaml.safe_load(f) - #data["bucket_id"] = "dank" - #if "crashing_inputs_keys" in data: + # data["bucket_id"] = "dank" + # if "crashing_inputs_keys" in data: # del data["crashing_inputs_keys"] - #del data["harness_info_id"] - #del data["project_id"] - #del data["project_name"] + # del data["harness_info_id"] + # del data["project_id"] + # del data["project_name"] patch_request_metadata = PatchRequestMeta.model_validate(data) dyva_report_data = None diff --git a/patchery/kumushi/util.py b/patchery/kumushi/util.py index 9d1515b..bf2b810 100644 --- a/patchery/kumushi/util.py +++ b/patchery/kumushi/util.py @@ -4,15 +4,22 @@ from pathlib import Path import yaml -from shellphish_crs_utils.models.crs_reports import KumushiRootCauseReport, 
KumushiPOICluster, KumushiPOI, KumushiCodeFunction +from patchery.data.models.crs_reports import ( + KumushiRootCauseReport, + KumushiPOICluster, + KumushiPOI, + KumushiCodeFunction, +) from patchery.kumushi.code_parsing import CodeFunction from patchery.data import PoI, PoICluster, PoISource, Program -_l = (logging.getLogger(__name__)) + +_l = logging.getLogger(__name__) TMP_POI_DIR = Path("/tmp/kumushi_poi") + class timeout: - def __init__(self, seconds=1, error_message='Timeout'): + def __init__(self, seconds=1, error_message="Timeout"): self.seconds = seconds self.error_message = error_message @@ -68,7 +75,7 @@ def absolute_path_finder(src_root: Path, relative_file_path: Path) -> Path | Non if full_path.exists(): _l.critical( f"Found the file by hacking the path: %s! Clang Indexer likely failed earlier!", - relative_file_path + relative_file_path, ) return full_path @@ -85,7 +92,7 @@ def read_src_from_file(src_file, start_line, end_line, backup_code=None): with open(src_file, "r") as f: lines = f.readlines() - return "".join(lines[start_line - 1:end_line]) + return "".join(lines[start_line - 1 : end_line]) class WorkDirContext: @@ -106,14 +113,18 @@ def convert_poi_to_kumushi_poi(poi: PoI) -> KumushiPOI: kumushi_function = KumushiCodeFunction.model_validate(poi.function.to_dict()) # Create KumushiPOI return KumushiPOI( - sources=poi.sources if poi.sources else [PoISource.UNKNOWN], # You might want to adjust this default + sources=poi.sources + if poi.sources + else [PoISource.UNKNOWN], # You might want to adjust this default crash_line_number=poi.crash_line_num, crash_line=poi.crash_line, - code_function=kumushi_function + code_function=kumushi_function, ) -def convert_poi_clusters_to_kumushi_report(poi_clusters: list[PoICluster], rca_hash: str) -> KumushiRootCauseReport: +def convert_poi_clusters_to_kumushi_report( + poi_clusters: list[PoICluster], rca_hash: str +) -> KumushiRootCauseReport: # Convert each PoICluster to KumushiPOICluster kumushi_clusters = [] @@ -123,8 +134,7 @@ def convert_poi_clusters_to_kumushi_report(poi_clusters: list[PoICluster], rca_h # Create KumushiPOICluster kumushi_cluster = KumushiPOICluster( - poi_cluster=kumushi_pois, - reasoning=cluster.reasoning + poi_cluster=kumushi_pois, reasoning=cluster.reasoning ) kumushi_clusters.append(kumushi_cluster) @@ -144,18 +154,22 @@ def convert_kumushi_poi_to_poi(kumushi_poi: KumushiPOI) -> PoI: file_path=kumushi_poi.code_function.file_path, code=kumushi_poi.code_function.code, global_vars=kumushi_poi.code_function.global_vars, - version=kumushi_poi.code_function.version + version=kumushi_poi.code_function.version, ) # Create PoI return PoI( - sources=kumushi_poi.sources if kumushi_poi.sources else [PoISource.UNKNOWN], # You might want to adjust this default + sources=kumushi_poi.sources + if kumushi_poi.sources + else [PoISource.UNKNOWN], # You might want to adjust this default crash_line_num=kumushi_poi.crash_line_number, crash_line=kumushi_poi.crash_line, - function=function + function=function, ) -def convert_kumushi_report_to_poi_clusters(kumushi_report: KumushiRootCauseReport) -> list[PoICluster]: +def convert_kumushi_report_to_poi_clusters( + kumushi_report: KumushiRootCauseReport, +) -> list[PoICluster]: poi_clusters = [] for kumushi_cluster in kumushi_report.poi_clusters: @@ -165,19 +179,28 @@ def convert_kumushi_report_to_poi_clusters(kumushi_report: KumushiRootCauseRepor return poi_clusters -def save_clusters_to_yaml(poi_clusters: list[PoICluster], output_file: Path, rca_hash: str, program: Program): +def 
save_clusters_to_yaml( + poi_clusters: list[PoICluster], output_file: Path, rca_hash: str, program: Program +): # update pois to be source relative new_clusters = [] for cluster in poi_clusters: new_pois = [] for poi in cluster.pois: try: - poi.function.file_path = poi.function.file_path.relative_to(program.source_root) + poi.function.file_path = poi.function.file_path.relative_to( + program.source_root + ) except Exception as e: - _l.warning("Failed to make the path relative to the source root:", exc_info=True) + _l.warning( + "Failed to make the path relative to the source root:", + exc_info=True, + ) new_pois.append(poi) - new_clusters.append(PoICluster(new_pois, reasoning=cluster.reasoning, source=cluster.source)) + new_clusters.append( + PoICluster(new_pois, reasoning=cluster.reasoning, source=cluster.source) + ) # Convert to Kumushi format kumushi_report = convert_poi_clusters_to_kumushi_report(new_clusters, rca_hash) @@ -186,15 +209,17 @@ def save_clusters_to_yaml(poi_clusters: list[PoICluster], output_file: Path, rca report_dict = kumushi_report.model_dump() # Save to YAML - with open(output_file, 'w') as f: + with open(output_file, "w") as f: yaml.safe_dump(report_dict, f, default_flow_style=False, sort_keys=False) def load_clusters_from_yaml(yaml_path: Path, program: Program) -> list[PoICluster]: - with open(yaml_path, 'r') as f: + with open(yaml_path, "r") as f: report_dict = yaml.safe_load(f) - kumushi_report: KumushiRootCauseReport = KumushiRootCauseReport.model_validate(report_dict) + kumushi_report: KumushiRootCauseReport = KumushiRootCauseReport.model_validate( + report_dict + ) poi_clusters = convert_kumushi_report_to_poi_clusters(kumushi_report) # update pois to be source relative @@ -204,10 +229,13 @@ def load_clusters_from_yaml(yaml_path: Path, program: Program) -> list[PoICluste for poi in cluster.pois: poi.function.file_path = program.source_root / poi.function.file_path new_pois.append(poi) - new_clusters.append(PoICluster(new_pois, reasoning=cluster.reasoning, source=cluster.source)) + new_clusters.append( + PoICluster(new_pois, reasoning=cluster.reasoning, source=cluster.source) + ) return new_clusters + def save_clusters_to_file(clusters: list["PoICluster"], file_path: Path) -> None: """ Save a list of PoICluster instances to a file using pickle. @@ -232,6 +260,7 @@ def save_clusters_to_file(clusters: list["PoICluster"], file_path: Path) -> None except OSError as e: raise OSError(f"Failed to save list of PoIClusters to {file_path}: {str(e)}") + def load_clusters_from_file(file_path: Path) -> list["PoICluster"]: """ Load a list of PoICluster instances from a pickle file. 
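Aside: the pickle helpers in this hunk form a simple round-trip pair. A minimal usage sketch follows; it is not part of the patch, the temporary path and the empty cluster list are illustrative only, and only the two helper signatures shown above are assumed.

```python
import tempfile
from pathlib import Path

from patchery.kumushi.util import load_clusters_from_file, save_clusters_to_file

# An empty list is a valid list[PoICluster]; in practice the clusters would come
# from the root-cause analysis step or from load_clusters_from_yaml().
clusters = []

# Illustrative location; save_clusters_to_file pickles the list and raises OSError on failure.
cache_path = Path(tempfile.mkdtemp()) / "poi_clusters.pkl"
save_clusters_to_file(clusters, cache_path)

# load_clusters_from_file returns [] (after logging) if the file cannot be read.
restored = load_clusters_from_file(cache_path)
assert restored == clusters
```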
@@ -256,10 +285,12 @@ def load_clusters_from_file(file_path: Path) -> list["PoICluster"]: with open(file_path, "rb") as f: clusters = pickle.load(f) if not isinstance(clusters, list): - raise pickle.UnpicklingError(f"Expected a list of PoICluster objects, got {type(clusters)}") + raise pickle.UnpicklingError( + f"Expected a list of PoICluster objects, got {type(clusters)}" + ) return clusters except OSError as e: _l.info(f"Failed to load PoIClusters from {file_path}: {str(e)}") return [] except pickle.UnpicklingError as e: - raise pickle.UnpicklingError(f"Invalid pickle data in {file_path}: {str(e)}") \ No newline at end of file + raise pickle.UnpicklingError(f"Invalid pickle data in {file_path}: {str(e)}") diff --git a/patchery/ranker.py b/patchery/ranker.py index 67914d0..852878a 100644 --- a/patchery/ranker.py +++ b/patchery/ranker.py @@ -10,7 +10,7 @@ from patchery.deduplicator import PatchDeduplicator from patchery.data import Program -from shellphish_crs_utils.models.patch import PatchRankings +from patchery.data.models.patch import PatchRankings _l = logging.getLogger(__name__) @@ -51,7 +51,9 @@ def __init__( if still_crashing_percent: self.still_crashing_percent.update(still_crashing_percent) - self._rank_output_dir = Path(rank_output_dir) if rank_output_dir is not None else None + self._rank_output_dir = ( + Path(rank_output_dir) if rank_output_dir is not None else None + ) # output of the ranking self.scored_patches = {} @@ -71,13 +73,20 @@ def score_patches(self): for patch in self._patches: self.scored_patches[patch] = self.score_patch_badness(patch) - self.ranked_patches = sorted(self.scored_patches, key=lambda x: self.scored_patches[x]) + self.ranked_patches = sorted( + self.scored_patches, key=lambda x: self.scored_patches[x] + ) timestamp = int(time.time_ns()) output_yaml_data = { "ranks": [str(Path(p.file_path).stem) for p in self.ranked_patches], - "patch_info": {str(Path(p.file_path).stem): self.scored_patches[p] for p in self.ranked_patches}, + "patch_info": { + str(Path(p.file_path).stem): self.scored_patches[p] + for p in self.ranked_patches + }, "timestamp": timestamp, - "poi_report_ids": list(set(p.metadata["poi_report_id"] for p in self._patches)), + "poi_report_ids": list( + set(p.metadata["poi_report_id"] for p in self._patches) + ), } return output_yaml_data @@ -106,10 +115,12 @@ def score_patch_badness(self, patch: Patch) -> float: # than this is actually an invalid patch, which gets a penalty score that is HUGE # TODO: make this generic again after submission # right now we assume everything in this is a crashing input on the patched binay - #prev_still_crashing_inputs = self._prev_crash_inputs[patch] + # prev_still_crashing_inputs = self._prev_crash_inputs[patch] still_crashing_percent = self.still_crashing_percent[patch] if still_crashing_percent > 0: - _l.warning(f"Patch {patch.file_path.name} is likely invalid because old crashes still crashes the patched binary!") + _l.warning( + f"Patch {patch.file_path.name} is likely invalid because old crashes still crashes the patched binary!" + ) final_score += self.INVALID_PATCH_PENALTY * still_crashing_percent self.invalidated_patches.add(patch) @@ -149,12 +160,22 @@ def rank_many_aicc_patch_dirs( """ if not patches_dir.exists(): patches_dir.mkdir(parents=True) - _l.warning(f"Created patches directory {patches_dir} because it did not exist before!") + _l.warning( + f"Created patches directory {patches_dir} because it did not exist before!" 
+ ) # normalize paths patches_dir = Path(patches_dir).absolute() - patch_metadatas_dir = Path(patch_metadatas_dir).absolute() if patch_metadatas_dir is not None else None - previous_crashes_dir = Path(previous_crashes_dir).absolute() if previous_crashes_dir is not None else None + patch_metadatas_dir = ( + Path(patch_metadatas_dir).absolute() + if patch_metadatas_dir is not None + else None + ) + previous_crashes_dir = ( + Path(previous_crashes_dir).absolute() + if previous_crashes_dir is not None + else None + ) patch_crash_percent = {} for patch_metadata_file in patch_metadatas_dir.iterdir(): @@ -166,7 +187,9 @@ def rank_many_aicc_patch_dirs( try: metadata = yaml.safe_load(patch_metadata_file.read_text()) except Exception as e: - _l.error(f"Error loading metadata file {patch_metadata_file}: {e}, skipped for ranking!") + _l.error( + f"Error loading metadata file {patch_metadata_file}: {e}, skipped for ranking!" + ) continue patch = Patch.from_git_diff(patch_file, metadata=metadata) @@ -178,7 +201,9 @@ def rank_many_aicc_patch_dirs( patch_buckets = deduplicator.deduplicate() output = {"buckets": [], "timestamp": None} for bucket in patch_buckets: - crash_perc_by_patch = {patch: patch_crash_percent[patch] for patch in bucket} + crash_perc_by_patch = { + patch: patch_crash_percent[patch] for patch in bucket + } ranker: PatchRanker = cls( bucket, continuous=continuous, @@ -195,11 +220,18 @@ def rank_many_aicc_patch_dirs( timestamp = int(time.time_ns()) output["timestamp"] = timestamp parsed_model = PatchRankings.model_validate(output) - output_file = rank_output_dir / f"{PatchRanker.RANK_FILE_PREFIX}{timestamp}.yaml" + output_file = ( + rank_output_dir / f"{PatchRanker.RANK_FILE_PREFIX}{timestamp}.yaml" + ) # now dump the parsed model to a yaml file with open(output_file, "w") as fp: - yaml.safe_dump(parsed_model.model_dump(), fp, default_flow_style=False, sort_keys=False) + yaml.safe_dump( + parsed_model.model_dump(), + fp, + default_flow_style=False, + sort_keys=False, + ) _l.info(f"Ranking output written to {output_file}") diff --git a/patchery/utils.py b/patchery/utils.py index 7a05e7d..c4f59f6 100644 --- a/patchery/utils.py +++ b/patchery/utils.py @@ -17,7 +17,7 @@ _l = logging.getLogger(__name__) -MULTITHREAD_LOG_FOLDER_BASE = '/tmp/patchery/thread_logs' +MULTITHREAD_LOG_FOLDER_BASE = "/tmp/patchery/thread_logs" class WorkDirContext: @@ -56,8 +56,8 @@ def read_src_from_file(src_file, start_line, end_line, backup_code=None): with open(src_file, "r") as f: lines = f.readlines() - - return "".join(lines[start_line-1:end_line]) + + return "".join(lines[start_line - 1 : end_line]) def find_src_root_from_commit(target_root: Path, commit: str) -> Optional[Path]: @@ -88,6 +88,7 @@ def find_src_root_from_commit(target_root: Path, commit: str) -> Optional[Path]: # Hashing # + def md5_hash(bstring: bytes) -> str: hasher = hashlib.md5() hasher.update(bstring) @@ -144,7 +145,7 @@ def absolute_path_finder(src_root: Path, relative_file_path: Path) -> Path | Non if full_path.exists(): _l.critical( f"Found the file by hacking the path: %s! 
Clang Indexer likely failed earlier!", - relative_file_path + relative_file_path, ) return full_path @@ -166,41 +167,82 @@ def _normalize_hash_score(score: int): "claude-3.7-sonnet": "claude-3.7-sonnet", "o3-mini": "o3-mini", "oai-gpt-o3-mini": "o3-mini", - 'o4-mini': 'o4-mini', - 'oai-gpt-o4-mini': 'o4-mini', - 'o3': 'o3', - 'oai-gpt-o3': 'o3', - 'gpt-4.1': 'gpt-4.1', - 'claude-4-sonnet': 'claude-4-sonnet', + "o4-mini": "o4-mini", + "oai-gpt-o4-mini": "o4-mini", + "o3": "o3", + "oai-gpt-o3": "o3", + "gpt-4.1": "gpt-4.1", + "claude-4-sonnet": "claude-4-sonnet", } -def llm_model_name(model: str = "", agentlib = False) -> str: +def llm_model_name(model: str = "", agentlib=False) -> str: if model.strip() == "": model = os.getenv("LLM_MODEL_NAME", "claude-3.7-sonnet") return LLM_MAPPING.get(model) if model not in LLM_MAPPING.keys(): - raise ValueError(f"Invalid LLM model name: {model}, you should use one of {LLM_MAPPING.keys()}") + raise ValueError( + f"Invalid LLM model name: {model}, you should use one of {LLM_MAPPING.keys()}" + ) if agentlib: return model return LLM_MAPPING.get(model) -def llm_cost(model_name: str, prompt_tokens: int, completion_tokens: int, cached_prompt_tokens: int = 0): +def llm_cost( + model_name: str, + prompt_tokens: int, + completion_tokens: int, + cached_prompt_tokens: int = 0, +): # these are the $x per Million tokens cost = { "oai-gpt-4-turbo": {"prompt_price": 10, "completion_price": 30}, "oai-gpt-4": {"prompt_price": 30, "completion_price": 60}, - "oai-gpt-4o": {"prompt_price": 2.5, "cached_prompt_price": 1.25, "completion_price": 10}, - "oai-gpt-o1-preview": {"prompt_price": 15, "cached_prompt_price": 7.5, "completion_price": 60}, - "oai-gpt-o3-mini": {"prompt_price": 1.1, "cached_prompt_price": 0.55, "completion_price": 4.4}, + "oai-gpt-4o": { + "prompt_price": 2.5, + "cached_prompt_price": 1.25, + "completion_price": 10, + }, + "oai-gpt-o1-preview": { + "prompt_price": 15, + "cached_prompt_price": 7.5, + "completion_price": 60, + }, + "oai-gpt-o3-mini": { + "prompt_price": 1.1, + "cached_prompt_price": 0.55, + "completion_price": 4.4, + }, "claude-3-5-sonnet-20241022": {"prompt_price": 3, "completion_price": 15}, "claude-3-7-sonnet-20250219": {"prompt_price": 3, "completion_price": 15}, } - llm_price = cost.get(model_name, cost.get(f'oai-{model_name}')) - prompt_price = ( (prompt_tokens - cached_prompt_tokens) / 1000000) * llm_price["prompt_price"] + llm_price = cost.get(model_name, cost.get(f"oai-{model_name}")) + prompt_price = ((prompt_tokens - cached_prompt_tokens) / 1000000) * llm_price[ + "prompt_price" + ] completion_price = (completion_tokens / 1000000) * llm_price["completion_price"] - cached_prompt_price = (cached_prompt_tokens / 1000000) * llm_price.get("cached_prompt_price", 0) + cached_prompt_price = (cached_prompt_tokens / 1000000) * llm_price.get( + "cached_prompt_price", 0 + ) cost = round(prompt_price + completion_price + cached_prompt_price, 5) return cost + + +def is_true_value(value): + if value is None: + return False + + elif value.lower() in ["true", "1", "yes", "y"]: + return True + + elif value.lower() in ["false", "0", "no", "n"]: + return False + + else: + raise ValueError(f"Invalid value for boolean conversion: {value}") + + +def artiphishell_should_fail_on_error(): + return is_true_value(os.environ.get("ARTIPHISHELL_FAIL_EARLY", None)) diff --git a/patchery/verifier/patch_verifier.py b/patchery/verifier/patch_verifier.py index 67f02ed..14826c9 100644 --- a/patchery/verifier/patch_verifier.py +++ 
b/patchery/verifier/patch_verifier.py @@ -10,15 +10,16 @@ BaseVerificationPass, CompileVerificationPass, AlertEliminationVerificationPass, - SyzCallerVerificationPass, - FunctionalityVerificationPass, DuplicateVerificationPass, NewCodeCheckPass, - RegressionPass + # RegressionPass, + # SyzCallerVerificationPass, + # FunctionalityVerificationPass, ) -from .verification_passes.fuzz_pass import FuzzVerificationPass -from .verification_passes.ossfuzz_build_check_pass import OssFuzzBuildCheckPass -from .. import Patch + +# from .verification_passes.fuzz_pass import FuzzVerificationPass +# from .verification_passes.ossfuzz_build_check_pass import OssFuzzBuildCheckPass +from patchery.data.patch import Patch if typing.TYPE_CHECKING: from patchery.patcher import Patcher @@ -31,17 +32,26 @@ class PatchVerifier: (DuplicateVerificationPass, True), (NewCodeCheckPass, True), (CompileVerificationPass, True), - (OssFuzzBuildCheckPass, True), + # (OssFuzzBuildCheckPass, True), (AlertEliminationVerificationPass, True), - (RegressionPass, True), - (FunctionalityVerificationPass, True), - (SyzCallerVerificationPass, False), - (FuzzVerificationPass, True), + # (RegressionPass, False), + # (FunctionalityVerificationPass, False), + # (SyzCallerVerificationPass, False), + # (FuzzVerificationPass, False), ] - def __init__(self, prog_info: AICCProgram, initial_failure_heat=0.0, passes=None, smart_mode=False, patcher=None): + def __init__( + self, + prog_info: AICCProgram, + initial_failure_heat=0.0, + passes=None, + smart_mode=False, + patcher=None, + ): self._prog_info = prog_info - self._passes: List[Tuple[Type[BaseVerificationPass], bool]] = passes or self.DEFAULT_PASSES + self._passes: List[Tuple[Type[BaseVerificationPass], bool]] = ( + passes or self.DEFAULT_PASSES + ) self.smart_mode = smart_mode self._patcher: "Patcher" = patcher @@ -50,9 +60,14 @@ def __init__(self, prog_info: AICCProgram, initial_failure_heat=0.0, passes=None self.failed_patches = set() # shared data for passes - os.makedirs(f"/shared/patchery/{self._prog_info.poi_report.project_id}", exist_ok=True) + os.makedirs( + f"/shared/patchery/{self._prog_info.poi_report.project_id}", exist_ok=True + ) self.regression_fuzzing_dir = Path( - tempfile.TemporaryDirectory(dir=f"/shared/patchery/{self._prog_info.poi_report.project_id}", prefix="regression_fuzz_").name + tempfile.TemporaryDirectory( + dir=f"/shared/patchery/{self._prog_info.poi_report.project_id}", + prefix="regression_fuzz_", + ).name ) def verify(self, patch: Patch) -> Tuple[bool, Any]: @@ -60,12 +75,16 @@ def verify(self, patch: Patch) -> Tuple[bool, Any]: reasoning = None for pass_cls, should_run in self._passes: if self._patcher and not self._patcher.should_work: - _l.warning("The patcher is shutting down all threads! Stopping verification...") + _l.warning( + "The patcher is shutting down all threads! Stopping verification..." + ) verified = False reasoning = "Patcher is shutting down" break - verifier = pass_cls(self._prog_info, patch, verifier=self, smart_mode=self.smart_mode) + verifier = pass_cls( + self._prog_info, patch, verifier=self, smart_mode=self.smart_mode + ) force_skip, skip_reason = verifier.should_skip() if not should_run or force_skip: skip_reason = skip_reason if force_skip else "Pass disabled" @@ -76,7 +95,12 @@ def verify(self, patch: Patch) -> Tuple[bool, Any]: try: verifier.verify() except Exception as e: - _l.error("❌ %s failed with an exception: %s... 
skipping and assuming pass.", pass_cls.__name__, e, exc_info=True) + _l.error( + "❌ %s failed with an exception: %s... skipping and assuming pass.", + pass_cls.__name__, + e, + exc_info=True, + ) if not verifier.FAIL_ON_EXCEPTION: continue # exception had an internal error, but is dangerous enough to fail the verification process @@ -94,5 +118,9 @@ def verify(self, patch: Patch) -> Tuple[bool, Any]: _l.info(f"✅ {pass_cls.__name__} passed") - _l.info("✅ 🎉 Patch is verified!!!!" if verified else f"❌ 🤡 Patch is NOT verified: {reasoning}") + _l.info( + "✅ 🎉 Patch is verified!!!!" + if verified + else f"❌ 🤡 Patch is NOT verified: {reasoning}" + ) return verified, reasoning diff --git a/patchery/verifier/verification_passes/__init__.py b/patchery/verifier/verification_passes/__init__.py index 195c6af..8857bd4 100644 --- a/patchery/verifier/verification_passes/__init__.py +++ b/patchery/verifier/verification_passes/__init__.py @@ -2,7 +2,7 @@ from .compile_pass import CompileVerificationPass from .alert_elim_pass import AlertEliminationVerificationPass from .syz_caller_pass import SyzCallerVerificationPass -from .func_verification_pass import FunctionalityVerificationPass from .duplicate_check_pass import DuplicateVerificationPass from .new_code_check_pass import NewCodeCheckPass -from .regression_pass import RegressionPass +# from .regression_pass import RegressionPass +# from .func_verification_pass import FunctionalityVerificationPass diff --git a/patchery/verifier/verification_passes/func_verification_pass.py b/patchery/verifier/verification_passes/func_verification_pass.py index 0f114de..fc630b2 100644 --- a/patchery/verifier/verification_passes/func_verification_pass.py +++ b/patchery/verifier/verification_passes/func_verification_pass.py @@ -4,12 +4,12 @@ from patchery.kumushi.aixcc import AICCProgram from .base_verification_pass import BaseVerificationPass -from shellphish_crs_utils.models.testguy import TestGuyMetaData from patchery.data import ProgramExitType _l = logging.getLogger(__name__) + class FunctionalityVerificationPass(BaseVerificationPass): def __init__(self, *args, requires_executor=True, **kwargs): super().__init__(*args, requires_executor=requires_executor, **kwargs) diff --git a/patchery/verifier/verification_passes/fuzz_pass.py b/patchery/verifier/verification_passes/fuzz_pass.py index a83b974..e4006f4 100644 --- a/patchery/verifier/verification_passes/fuzz_pass.py +++ b/patchery/verifier/verification_passes/fuzz_pass.py @@ -11,14 +11,14 @@ from pathlib import Path from typing import Optional -from shellphish_crs_utils.oss_fuzz.instrumentation.jazzer import JazzerInstrumentation +# from shellphish_crs_utils.oss_fuzz.instrumentation.jazzer import JazzerInstrumentation from patchery.data import ProgramExitType from .base_verification_pass import BaseVerificationPass from patchery.data.program_input import ProgramInput, ProgramInputType -from shellphish_crs_utils.oss_fuzz.project import InstrumentedOssFuzzProject -from shellphish_crs_utils.oss_fuzz.instrumentation.aflpp import AFLPPInstrumentation -from shellphish_crs_utils.oss_fuzz.instrumentation.aijon import AIJONInstrumentation +# from shellphish_crs_utils.oss_fuzz.project import InstrumentedOssFuzzProject +# from shellphish_crs_utils.oss_fuzz.instrumentation.aflpp import AFLPPInstrumentation +# from shellphish_crs_utils.oss_fuzz.instrumentation.aijon import AIJONInstrumentation if typing.TYPE_CHECKING: from patchery.verifier import PatchVerifier @@ -28,6 +28,7 @@ from ...data import JAZZER_CMD_INJECT_STR import litellm 
+ litellm.set_verbose = False _l = logging.getLogger(__name__) @@ -38,7 +39,7 @@ This indicates that the patch introduced a bug or instability in the program. This was the patch that was applied: -### Patch +### Patch ``` %s ``` @@ -65,27 +66,34 @@ FUZZING_LOCK = threading.Lock() + class FuzzVerificationPass(BaseVerificationPass): - TIMEOUT = 60*15 # 15 minutes + TIMEOUT = 60 * 15 # 15 minutes - TOTAL_FUZZING_TIME = 60*5 # 5 minutes + TOTAL_FUZZING_TIME = 60 * 5 # 5 minutes THREAD_FUZZING = False USE_AIJON = False SAVE_CRASHES = True + def __init__(self, *args, verifier: "PatchVerifier" = None, **kwargs): self._verifier = verifier assert self._verifier is not None, "FuzzVerificationPass requires a verifier" super().__init__(*args, **kwargs) self.save_folder = f"/shared/patchery/{self._prog_info.poi_report.project_id}" - def _prepare_seeds(self, make_dummy_input=False, save_loc: Path = None, zip_seeds=True) -> Path: + def _prepare_seeds( + self, make_dummy_input=False, save_loc: Path = None, zip_seeds=True + ) -> Path: # Create a seed corpus from crashing inputs os.makedirs(self.save_folder, exist_ok=True) corpus_dir = tempfile.mkdtemp(dir=self.save_folder) # Collect inputs and ensure we have at least 3 inputs = [] - if hasattr(self._prog_info, "_crashing_inputs") and self._prog_info._crashing_inputs: + if ( + hasattr(self._prog_info, "_crashing_inputs") + and self._prog_info._crashing_inputs + ): inputs = self._prog_info._crashing_inputs # If we have fewer than 3 inputs, duplicate some to reach 3 @@ -94,19 +102,23 @@ def _prepare_seeds(self, make_dummy_input=False, save_loc: Path = None, zip_seed inputs.append(original_inputs[0]) if make_dummy_input: - dummy_input = ProgramInput(b'fuzz', ProgramInputType.STDIN) + dummy_input = ProgramInput(b"fuzz", ProgramInputType.STDIN) inputs.append(dummy_input) # Save inputs to files and add to zip if zip_seeds: import zipfile - corpus_zip = Path(corpus_dir) / f"{self._prog_info.poi_report.cp_harness_name}_seed_corpus.zip" + + corpus_zip = ( + Path(corpus_dir) + / f"{self._prog_info.poi_report.cp_harness_name}_seed_corpus.zip" + ) save_loc = corpus_zip - with zipfile.ZipFile(corpus_zip, 'w') as zipf: + with zipfile.ZipFile(corpus_zip, "w") as zipf: for i, input_data in enumerate(inputs): input_path = Path(corpus_dir) / f"input_{i}" try: - with open(input_path, 'wb') as f: + with open(input_path, "wb") as f: f.write(input_data.data) zipf.write(input_path, arcname=f"input_{i}") except Exception as e: @@ -117,17 +129,24 @@ def _prepare_seeds(self, make_dummy_input=False, save_loc: Path = None, zip_seed for i, input_data in enumerate(inputs): input_path = Path(save_loc) / f"input_{i}" try: - with open(input_path, 'wb') as f: + with open(input_path, "wb") as f: f.write(input_data.data) except Exception as e: _l.error(f"Failed to save input {i}: {e}") _l.info(f"Created seed corpus at {save_loc} with {len(inputs)} inputs") else: - raise ValueError("Either zip_seeds must be True or save_loc must be provided.") + raise ValueError( + "Either zip_seeds must be True or save_loc must be provided." 
+ ) return save_loc - def _setup_fuzzer(self, timeout=TOTAL_FUZZING_TIME, use_aijon: bool = USE_AIJON, sync_inst_dir: Path = None) -> tuple[InstrumentedOssFuzzProject, Path, dict]: + def _setup_fuzzer( + self, + timeout=TOTAL_FUZZING_TIME, + use_aijon: bool = USE_AIJON, + sync_inst_dir: Path = None, + ) -> tuple[InstrumentedOssFuzzProject, Path, dict]: is_java = self._prog_info.language in {"java", "jvm"} is_binary = self._prog_info.language in BINARY_LANGS @@ -138,66 +157,85 @@ def _setup_fuzzer(self, timeout=TOTAL_FUZZING_TIME, use_aijon: bool = USE_AIJON, project_dir = Path(tempfile.TemporaryDirectory(dir=self.save_folder).name) inter_sync_dir = Path(tempfile.mkdtemp(dir=sync_inst_dir, suffix="inter_sync")) fuzz_envs = { - 'ARTIPHISHELL_DO_NOT_CREATE_INPUT': '1', - 'FORCED_CREATE_INITIAL_INPUT': '1', - 'FORCED_FUZZER_TIMEOUT': '4', - 'FORCED_DO_CMPLOG': '1', - 'FORCED_USE_CUSTOM_MUTATOR': '1', - 'FORCED_USE_AFLPP_DICT': '0', - 'FORCED_USE_CORPUSGUY_DICT': '0', + "ARTIPHISHELL_DO_NOT_CREATE_INPUT": "1", + "FORCED_CREATE_INITIAL_INPUT": "1", + "FORCED_FUZZER_TIMEOUT": "4", + "FORCED_DO_CMPLOG": "1", + "FORCED_USE_CUSTOM_MUTATOR": "1", + "FORCED_USE_AFLPP_DICT": "0", + "FORCED_USE_CORPUSGUY_DICT": "0", "ARTIPHISHELL_AFL_TIMEOUT": str(timeout), - 'ARTIPHISHELL_INTER_HARNESS_SYNC_DIR': str(inter_sync_dir), + "ARTIPHISHELL_INTER_HARNESS_SYNC_DIR": str(inter_sync_dir), } if is_java: # corpus dir needs to be manually created for Jazzer corp_save_loc = sync_inst_dir / "corpus" corp_save_loc.mkdir(parents=True, exist_ok=True) - corpus_location = self._prepare_seeds(zip_seeds=False, save_loc=corp_save_loc) + corpus_location = self._prepare_seeds( + zip_seeds=False, save_loc=corp_save_loc + ) - fuzz_envs['FUZZING_ENGINE'] = "libfuzzer" - fuzz_envs['ARTIPHISHELL_JAZZER_BENIGN_SEEDS'] = str(corp_save_loc) - fuzz_envs['ARTIPHISHELL_JAZZER_CRASHING_SEEDS'] = str(sync_inst_dir / "crashes") + fuzz_envs["FUZZING_ENGINE"] = "libfuzzer" + fuzz_envs["ARTIPHISHELL_JAZZER_BENIGN_SEEDS"] = str(corp_save_loc) + fuzz_envs["ARTIPHISHELL_JAZZER_CRASHING_SEEDS"] = str( + sync_inst_dir / "crashes" + ) if use_aijon: - raise ValueError("AIJON instrumentation is not supported for Java projects.") + raise ValueError( + "AIJON instrumentation is not supported for Java projects." + ) elif is_binary: corpus_location = self._prepare_seeds(zip_seeds=True) - fuzz_envs['FUZZING_ENGINE'] = "shellphish_aflpp" if not use_aijon else "shellphish_aijon" + fuzz_envs["FUZZING_ENGINE"] = ( + "shellphish_aflpp" if not use_aijon else "shellphish_aijon" + ) else: raise ValueError("Unsupported language!") subprocess.run( - f'rsync -a --delete --ignore-missing-args {self._prog_info.target_project.project_path} {project_dir}', - shell=True) + f"rsync -a --delete --ignore-missing-args {self._prog_info.target_project.project_path} {project_dir}", + shell=True, + ) project_dir = project_dir / self._prog_info.target_project.project_path.name shutil.rmtree(project_dir / "artifacts", ignore_errors=True) subprocess.run(f"mkdir -p artifacts", shell=True, cwd=project_dir) # write the patch to a temporary file patch_dir = Path(tempfile.mkdtemp(dir=self.save_folder)) - with open(patch_dir / 'patch', 'w') as f: + with open(patch_dir / "patch", "w") as f: f.write(self._patch.diff) - patch_path = patch_dir / 'patch' + patch_path = patch_dir / "patch" if use_aijon: if is_java: - raise ValueError("AIJON instrumentation is not supported for Java projects.") + raise ValueError( + "AIJON instrumentation is not supported for Java projects." 
+ ) _l.info("Using AIJON instrumentation for fuzzing") instrumentation = AIJONInstrumentation() patch_path = Path( annotate_from_patch( - patch_path, self._prog_info.source_root, self._prog_info.code._function_resolver, - language=self._prog_info.language + patch_path, + self._prog_info.source_root, + self._prog_info.code._function_resolver, + language=self._prog_info.language, ) ) patch_data = patch_path.read_text() if not patch_data: - _l.warning("AIJON instrumentation was requested, but the patch is empty.") + _l.warning( + "AIJON instrumentation was requested, but the patch is empty." + ) raise RuntimeError("Empty patch after AIJON annotation.") if "IJON" not in patch_data: - _l.warning("AIJON instrumentation was requested, but the patch does not contain AIJON annotations.") - raise RuntimeError("AIJON annotations not found in the patch after AIJON annotation.") + _l.warning( + "AIJON instrumentation was requested, but the patch does not contain AIJON annotations." + ) + raise RuntimeError( + "AIJON annotations not found in the patch after AIJON annotation." + ) elif is_binary: _l.info("Using AFL++ instrumentation for fuzzing") instrumentation = AFLPPInstrumentation() @@ -211,14 +249,16 @@ def _setup_fuzzer(self, timeout=TOTAL_FUZZING_TIME, use_aijon: bool = USE_AIJON, project_source=self._prog_info.target_project.project_source, project_id=self._prog_info.target_project.project_id, use_task_service=self._prog_info.target_project.use_task_service, - augmented_metadata=self._prog_info.target_project.augmented_metadata + augmented_metadata=self._prog_info.target_project.augmented_metadata, + ) + build_res = instr_project.build_target( + patch_path=patch_path, sanitizer=self._prog_info.poi_report.sanitizer ) - build_res = instr_project.build_target(patch_path=patch_path, sanitizer=self._prog_info.poi_report.sanitizer) if build_res.build_success and is_binary: # we copy over a zip in binary mode - shutil.copy(corpus_location, project_dir / "artifacts" / 'out') + shutil.copy(corpus_location, project_dir / "artifacts" / "out") - seed_corpus_dir = project_dir / "artifacts" / 'out' + seed_corpus_dir = project_dir / "artifacts" / "out" return instr_project, seed_corpus_dir, fuzz_envs def _save_crash_dir(self, crash_dir: Path): @@ -231,9 +271,18 @@ def _save_crash_dir(self, crash_dir: Path): shutil.copytree(crash_dir, dir_path, dirs_exist_ok=True) _l.info("Saved crash inputs to %s", dir_path) - def _fuzz_core(self, instance_name, sync_dir, timeout=TOTAL_FUZZING_TIME, use_aijon: bool = USE_AIJON, threaded=False) -> Optional[threading.Thread]: + def _fuzz_core( + self, + instance_name, + sync_dir, + timeout=TOTAL_FUZZING_TIME, + use_aijon: bool = USE_AIJON, + threaded=False, + ) -> Optional[threading.Thread]: sync_inst_dir = sync_dir / instance_name - fuzzer, seed_corpus_dir, fuzz_envs = self._setup_fuzzer(timeout=timeout, use_aijon=use_aijon, sync_inst_dir=sync_inst_dir) + fuzzer, seed_corpus_dir, fuzz_envs = self._setup_fuzzer( + timeout=timeout, use_aijon=use_aijon, sync_inst_dir=sync_inst_dir + ) def _run_fuzzer(): fuzzer.fuzz_harness( @@ -248,11 +297,14 @@ def _run_fuzzer(): ) if self.SAVE_CRASHES: # save all crashing inputs regardless of whether we found a valid crash or not - crash_dir_path = sync_inst_dir / 'crashes' + crash_dir_path = sync_inst_dir / "crashes" if crash_dir_path.exists() and crash_dir_path.is_dir(): self._save_crash_dir(crash_dir_path) else: - _l.error("Somehow the crash directory %s does not exist or is not a directory.", crash_dir_path) + _l.error( + "Somehow the crash 
directory %s does not exist or is not a directory.", + crash_dir_path, + ) if threaded: thread = threading.Thread(target=_run_fuzzer, daemon=True) @@ -266,8 +318,9 @@ def _run_fuzzer(): return thread - def _wait_for_valid_crash(self, crash_dir_path, fuzzing_thread, timeout=TOTAL_FUZZING_TIME) -> tuple[Optional[Path], Optional[str]]: - + def _wait_for_valid_crash( + self, crash_dir_path, fuzzing_thread, timeout=TOTAL_FUZZING_TIME + ) -> tuple[Optional[Path], Optional[str]]: crash_found = False crash_file = None crash_info = None @@ -280,7 +333,7 @@ def _wait_for_valid_crash(self, crash_dir_path, fuzzing_thread, timeout=TOTAL_FU crash_inputs = list(crash_dir_path.iterdir()) _l.info(f"Found {len(crash_inputs)} crash inputs in {crash_dir_path}") for crash_file in crash_inputs: - if crash_file.suffix == '.txt': + if crash_file.suffix == ".txt": continue if crash_file in tested_crashes: @@ -291,7 +344,9 @@ def _wait_for_valid_crash(self, crash_dir_path, fuzzing_thread, timeout=TOTAL_FU crashes, crash_info, stack_trace = self.run_pov(crash_file) if crashes and stack_trace: if self._crash_in_relevant_location(stack_trace): - _l.info(f"Found a crash input that reproduces the issue: {crash_file}") + _l.info( + f"Found a crash input that reproduces the issue: {crash_file}" + ) crash_found = True break @@ -305,51 +360,77 @@ def _wait_for_valid_crash(self, crash_dir_path, fuzzing_thread, timeout=TOTAL_FU _l.info("Finished waiting for valid crashes!") return (crash_file, crash_info) if crash_found else (None, None) - def _fuzz_for_crashes(self, timeout=TOTAL_FUZZING_TIME): instance_name = f"patcher-fuzz-{str(uuid.uuid4())[:8]}" sync_dir = Path(tempfile.mkdtemp(dir=self.save_folder)) # make sure that crashes directory exists - crash_dir_path = sync_dir / instance_name / 'crashes' + crash_dir_path = sync_dir / instance_name / "crashes" crash_dir_path.mkdir(parents=True, exist_ok=True) _l.info(f"Fuzzing crash dir path: %s", crash_dir_path) - _l.info("Fuzzing patch %s with fuzzer %s for %d seconds", self._patch.diff, self._prog_info.poi_report.fuzzer, timeout) + _l.info( + "Fuzzing patch %s with fuzzer %s for %d seconds", + self._patch.diff, + self._prog_info.poi_report.fuzzer, + timeout, + ) with FUZZING_LOCK: if self._verifier._patcher and not self._verifier._patcher.should_work: - _l.warning("The patcher is shutting down all threads! Stopping fuzzing...") + _l.warning( + "The patcher is shutting down all threads! Stopping fuzzing..." + ) return False, "Patcher is shutting down" fuzz_failed = False use_aijon = self.USE_AIJON and AIJON_AVAILABLE if use_aijon: try: - fuzz_thread = self._fuzz_core(instance_name, sync_dir, timeout=timeout, use_aijon=True, threaded=self.THREAD_FUZZING) + fuzz_thread = self._fuzz_core( + instance_name, + sync_dir, + timeout=timeout, + use_aijon=True, + threaded=self.THREAD_FUZZING, + ) except Exception as e: fuzz_failed = True - _l.critical(f"Failed to fuzz patch {instance_name}: {e} with AIJON instrumentation. Falling back to AFL++...") + _l.critical( + f"Failed to fuzz patch {instance_name}: {e} with AIJON instrumentation. Falling back to AFL++..." 
+ ) if not use_aijon or fuzz_failed: - fuzz_thread = self._fuzz_core(instance_name, sync_dir, timeout=timeout, use_aijon=False, threaded=self.THREAD_FUZZING) + fuzz_thread = self._fuzz_core( + instance_name, + sync_dir, + timeout=timeout, + use_aijon=False, + threaded=self.THREAD_FUZZING, + ) _l.info("Waiting for a valid fuzzing crash...") - valid_crash_path, crash_info = self._wait_for_valid_crash(crash_dir_path, fuzz_thread, timeout=(timeout if self.THREAD_FUZZING else 4)) + valid_crash_path, crash_info = self._wait_for_valid_crash( + crash_dir_path, + fuzz_thread, + timeout=(timeout if self.THREAD_FUZZING else 4), + ) if self.SAVE_CRASHES: # save only the reproducing crash input to the regression dir if self._verifier.regression_fuzzing_dir and valid_crash_path is not None: - with open(valid_crash_path, 'rb') as crash_file: + with open(valid_crash_path, "rb") as crash_file: crash_data = crash_file.read() md5_hash = hashlib.md5(crash_data).hexdigest() new_crash_save = self._verifier.regression_fuzzing_dir / f"{md5_hash}" - with open(new_crash_save, 'wb') as f: + with open(new_crash_save, "wb") as f: f.write(crash_data) # congratz, no crashes found if valid_crash_path is None: if len(list(crash_dir_path.iterdir())) != 0: - _l.warning("Unable to reproduce a crash with any fuzzer found crash. The patch may be unstable, but no valid crash inputs were found.") + _l.warning( + "Unable to reproduce a crash with any fuzzer found crash. The patch may be unstable, but no valid crash inputs were found." + ) return True, "No valid crash inputs found (no reproduce)." return True, "No crashes found." @@ -359,35 +440,54 @@ def _fuzz_for_crashes(self, timeout=TOTAL_FUZZING_TIME): def _crash_in_relevant_location(self, stack_trace: list[str]) -> bool: stack_trace_slice = stack_trace[:3] - patched_functions = [f.function_name for f in self._patch.patched_functions if f and f.function_name] + patched_functions = [ + f.function_name + for f in self._patch.patched_functions + if f and f.function_name + ] # if any intersection exists between the patched functions and the stack trace, we consider it a relevant crash if any(func in stack_trace_slice for func in patched_functions): - _l.info("Fuzzer discovered crash in a patched function: %s", stack_trace_slice) + _l.info( + "Fuzzer discovered crash in a patched function: %s", stack_trace_slice + ) return True - if stack_trace_slice and self._prog_info.crashing_function == stack_trace_slice[0]: - _l.info("Fuzzer discovered crash in the original crashing function: %s", self._prog_info.crashing_function) + if ( + stack_trace_slice + and self._prog_info.crashing_function == stack_trace_slice[0] + ): + _l.info( + "Fuzzer discovered crash in the original crashing function: %s", + self._prog_info.crashing_function, + ) return True return False def run_pov(self, pov_file: Path) -> tuple[bool, str, list[str]]: - with open(pov_file, 'rb') as pov_file: + with open(pov_file, "rb") as pov_file: input_obj = ProgramInput(pov_file.read(), ProgramInputType.STDIN) - exit_type, pov_report, stack_trace_funcs = self._prog_info.generates_alerts(input_obj) + exit_type, pov_report, stack_trace_funcs = self._prog_info.generates_alerts( + input_obj + ) if exit_type == ProgramExitType.TRIGGERED: san_info = "unknown" - if ("AICC" in str(self._prog_info) + if ( + "AICC" in str(self._prog_info) and self._prog_info.sanitizer_string is not None and JAZZER_CMD_INJECT_STR not in self._prog_info.sanitizer_string ): san_info = self._prog_info.sanitizer_string - crash_info = f"Bug still triggered 
after patching with sanitizer: {san_info}\n" + crash_info = ( + f"Bug still triggered after patching with sanitizer: {san_info}\n" + ) crash_info += f"\n {pov_report}" reasoning = REASONING % (self._patch.diff, crash_info) if input_obj.is_human_readable() and len(input_obj.data) < 5000: - reasoning += INPUT_INFO % input_obj.data.decode('utf-8', errors='replace') + reasoning += INPUT_INFO % input_obj.data.decode( + "utf-8", errors="replace" + ) crashes = True elif exit_type == ProgramExitType.INTERNAL_ERROR: @@ -408,10 +508,12 @@ def _verify(self): def should_skip(self): if not self._prog_info.crashing_function: - _l.critical("No crashing function set in program info, cannot run fuzz verification pass.") + _l.critical( + "No crashing function set in program info, cannot run fuzz verification pass." + ) return True, "No crashing function set in program info." if not self.smart_mode: return True, "Fuzz verification pass is only applicable to smart modes." - return super().should_skip() \ No newline at end of file + return super().should_skip() diff --git a/pyproject.toml b/pyproject.toml index 1eb23cb..f553cdc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ requires-python = ">= 3.10" dependencies = [ "tqdm", "pyyaml", - "gitpython", + "GitPython", "tree-sitter==0.21.3", "tree-sitter-languages", "PyYAML", @@ -24,11 +24,16 @@ dependencies = [ "jinja2", 'py-tlsh', 'openai', + "langchain>=1.0.0a7", 'requests', 'unidiff', 'numpy', - 'litellm==1.63.14' + "pre-commit", + "pydantic", + 'jq', + "ruff" ] + dynamic = ["version"] [project.readme] @@ -58,3 +63,24 @@ find = {namespaces = false} [tool.setuptools.dynamic] version = {attr = "patchery.__version__"} + +[tool.ruff] +target-version = "py310" + +[tool.ruff.lint] +# Enable basic rule sets +select = [ + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort (import sorting), + "UP", # Pyupgrade + "D", # Pydocstyle +] +# Ignore specific rules +ignore = [ + "E501", # line too long (let formatter handle it) +] + +[tool.ruff.lint.per-file-ignores] +# Allow unused imports in __init__.py files +"__init__.py" = ["F401"] diff --git a/tests/generic_tests/targets/hamlin/README.md b/tests/generic_tests/targets/hamlin/README.md new file mode 100644 index 0000000..5200313 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/README.md @@ -0,0 +1,9 @@ +# AIxCC Hamlin Edition +This version of Hamlin has been modified to work with the AIxCC pipeline. +The main changes are the following: + +* added a run.sh script which applies patches, builds, and allows running the program in the AIxCC format +* updated the source Makefile to use clang, with `-g` and `-fsanitize=address` + +## Running +Like all AIxCC challenges, just use the `run.sh` located in the root of the `challenge` dir. 
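
For reference, a minimal local session with the `run.sh` described in this README might look like the sketch below. The patch filename is illustrative (not part of the target), and only the `build` and `run` subcommands are actually wired up in the script's case statement, even though the usage text also advertises `pull_source`, `run_pov`, and `run_tests`:

```bash
# Run from tests/generic_tests/targets/hamlin/challenge/ (paths are illustrative)

# Build hamlin.bin, optionally applying a patch file to the source tree first.
# "my_fix.patch" is a hypothetical example patch, not a file in this repo.
./run.sh build ../my_fix.patch

# Feed an input blob to the sanitized binary; exit status 37 means
# AddressSanitizer reported an error, exit status 0 means a clean run.
./run.sh run ../alerting_inputs/crash_input
echo "exit: $?"
```
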
diff --git a/tests/generic_tests/targets/hamlin/alerting_inputs/crash_input b/tests/generic_tests/targets/hamlin/alerting_inputs/crash_input new file mode 100644 index 0000000..10196f6 Binary files /dev/null and b/tests/generic_tests/targets/hamlin/alerting_inputs/crash_input differ diff --git a/tests/generic_tests/targets/hamlin/benign_inputs/id_000583 b/tests/generic_tests/targets/hamlin/benign_inputs/id_000583 new file mode 100644 index 0000000..cbdc64c --- /dev/null +++ b/tests/generic_tests/targets/hamlin/benign_inputs/id_000583 @@ -0,0 +1 @@ +624157469507472699 diff --git a/tests/generic_tests/targets/hamlin/challenge/run.sh b/tests/generic_tests/targets/hamlin/challenge/run.sh new file mode 100755 index 0000000..f985c7c --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/run.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +# +# Based on the original AIxCC run.sh script +# + +# get directory containing the script +SCRIPT_DIR="$(dirname $(realpath $0))" +SRC="${SCRIPT_DIR}/src" + +warn() { + echo "$*" >&2 +} + +# kill the script with an error message +die() { + warn "$*" + exit 1 +} + +print_usage() { + warn "A helper script for CP interactions." + warn + warn "Usage: ${SCRIPT_FILE} pull_source|build|run_pov|run_test" + warn + warn "Subcommands:" + warn " pull_source Pull the CP source code into the src/ directory; will overwrite existing source" + warn " build [ ] Build the CP (an optional patch file for a given source repo can be supplied)" + warn " run_pov Run the binary data blob against specified harness" + warn " run_tests Run functionality tests" + die +} + +## execute commands +CMD_NAME=$1 +shift +case ${CMD_NAME,,} in + "build") + ##### Run patch command if patch file was supplied ##### + + if [ -n "$1" ]; then + PATCH_FILE=$1 + SOURCE_TARGET="./src" + + if [ ! 
-d "${SRC}/${SOURCE_TARGET}" ]; then + echo "Source repository not found: ${SRC}/${SOURCE_TARGET}" + echo "Valid source names: ${CP_SOURCE_NAMES[*]}" + fi + + # check validity of patch file provided + PATCH_FILE=$(realpath "${PATCH_FILE}") + [[ -f "${PATCH_FILE}" ]] || die "Patch file not found: ${PATCH_FILE}" + + # apply patch + # shellcheck disable=SC2086 + git -C "${SRC}/${SOURCE_TARGET}" apply \ + ${PATCH_EXTRA_ARGS} \ + "${PATCH_FILE}" || die "Patching failed using: ${PATCH_FILE}" + fi + + ( + cd src/ && \ + make clean && \ + make -j8 && \ + cp build/hamlin.bin $SCRIPT_DIR/ + ) + + ;; + + "run") + ##### Run based on a blob ##### + IN_FILE=$1 + BIN_FILE="$SCRIPT_DIR/hamlin.bin" + + export CHESS=1 + output=$($BIN_FILE < $IN_FILE 2>&1) + echo "$output" + if echo "$output" | grep -q "ERROR: AddressSanitizer"; then + exit 37 + fi + exit 0 + ;; + *) + echo "Invalid command $CMD_NAME" + print_usage + ;; +esac diff --git a/tests/generic_tests/targets/hamlin/challenge/src/Makefile b/tests/generic_tests/targets/hamlin/challenge/src/Makefile new file mode 100644 index 0000000..56f777a --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/Makefile @@ -0,0 +1,35 @@ +BUILD_DIR ?= ./build +SRC_DIR ?= ./src +PRIV_DIR ?= ./priv + +TARGET ?= hamlin.bin + +CC ?= clang +CXX ?= clang++ + +CPPFLAGS += -MD -g +CFLAGS += -Wall -O2 -g -fsanitize=address +CXXFLAGS += -Wall -std=c++17 -O2 -g -fsanitize=address +LDFLAGS += -std=c++17 -lstdc++fs -O3 -lasan + +SRCS := $(wildcard $(SRC_DIR)/*.cpp) $(wildcard $(SRC_DIR)/**/*.cpp) +OBJS := $(SRCS:$(SRC_DIR)/%.cpp=$(BUILD_DIR)/%.o) +DEPS := $(OBJS:.o=.d) + +$(BUILD_DIR)/$(TARGET): $(OBJS) + $(CXX) $(filter-out ./build/pngtest.o,$(OBJS)) -o $@ $(LDFLAGS) + #strip -s $@ + +$(BUILD_DIR)/%.o: $(SRC_DIR)/%.cpp + mkdir -p $(dir $@) + $(CXX) $(CPPFLAGS) $(CXXFLAGS) -c $< -o $@ + +.PHONY: clean docker + +run: $(BUILD_DIR)/$(TARGET) + $< + +clean: + rm -rf $(BUILD_DIR) + +-include $(DEPS) diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/assert.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/assert.cpp new file mode 100755 index 0000000..82a4fcd --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/assert.cpp @@ -0,0 +1,24 @@ +#include +#include + +#include "assert.hpp" + +void __assert_fail(const char* assertion, + const char* file, + unsigned int line, + const char* function) { + std::cerr << std::dec << file << ":" << line << ": " << function << + ": Assertion `" << assertion << "` failed." << std::endl; + + std::exit(-1); +} + +void __assert_zero_fail(const char* assertion, + const char* file, + unsigned int line, + const char* function) { + std::cerr << file << ":" << line << ": " << function << + ": Non-zero assertion `" << assertion << "` failed." << std::endl; + + std::exit(-1); +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/assert.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/assert.hpp new file mode 100755 index 0000000..2337296 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/assert.hpp @@ -0,0 +1,23 @@ +#pragma once + +#undef assert + +#define assert(expr) \ + (static_cast(expr) ? void(0) : \ + __assert_fail(#expr, __FILE__, __LINE__, \ + __extension__ __PRETTY_FUNCTION__)) + +#define assert_zero(expr) \ + (static_cast(0 == expr) ? 
void(0) : \ + __assert_zero_fail(#expr, __FILE__, __LINE__, \ + __extension__ __PRETTY_FUNCTION__)) + +[[ noreturn ]] void __assert_fail(const char* assertion, + const char* file, + unsigned int line, + const char* function); + +[[ noreturn ]] void __assert_zero_fail(const char* assertion, + const char* file, + unsigned int line, + const char* function); diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/crc32.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/crc32.cpp new file mode 100755 index 0000000..7d8b3f5 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/crc32.cpp @@ -0,0 +1,21 @@ +#include "hton.hpp" + +#include "crc32.hpp" + +using namespace crc32; + +crc_t crc32::calculate(const char* in, std::size_t count) { + return constexp::calculate(in, count); +} + +crc_t crc32::calculate_begin(const char* in, std::size_t count) { + return constexp::calculate_begin(in, count); +} + +crc_t crc32::calculate_inter(crc_t inter, const char* in, std::size_t count) { + return constexp::calculate_inter(inter, in, count); +} + +crc_t crc32::calculate_final(crc_t inter) { + return constexp::calculate_final(inter); +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/crc32.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/crc32.hpp new file mode 100755 index 0000000..cb31424 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/crc32.hpp @@ -0,0 +1,77 @@ +#pragma once + +#include +#include +#include +#include + +#include "hbytes.hpp" + +namespace crc32 { + constexpr std::size_t _table_size = 256; + + using crc_t = uint32_t; + using table_t = std::array; + + constexpr table_t _make_table() { + table_t tbl{}; + + for (std::size_t n = 0; n < _table_size; n++) { + crc_t c = n; + + for (byte k = 0; k < 8; k++) { + if (1 & c) { + c = 0xedb88320L ^ (c >> 1); + } else { + c = c >> 1; + } + } + tbl[n] = c; + } + + return tbl; + } + + static constexpr table_t table = _make_table(); + + static_assert(0x00000000 == table[ 0]); + static_assert(0x77073096 == table[ 1]); + static_assert(0x2d02ef8d == table[255]); + + crc_t calculate(const char* in, std::size_t count); + crc_t calculate_begin(const char* in, std::size_t count); + crc_t calculate_inter(crc_t inter, const char* in, std::size_t count); + crc_t calculate_final(crc_t inter); + + namespace constexp { + constexpr crc_t calculate_inter(crc_t inter, + const char* in, + std::size_t count) { + crc_t crc = inter; + + for (std::size_t i = 0; i < count; i++) { + byte b = in[i]; + crc_t cursor = (crc ^ b) & 0xFF; + crc = table[cursor] ^ (crc >> 8); + } + + return crc; + } + + constexpr crc_t calculate_begin(const char *in, std::size_t count) { + return calculate_inter(0xFFFFFFFF, in, count); + } + + constexpr crc_t calculate_final(crc_t inter) { + return ~inter; + } + + constexpr crc_t calculate(const char* in, std::size_t count) { + crc_t intermediate = calculate_begin(in, count); + return calculate_final(intermediate); + } + + static_assert(0 == calculate("", 0)); + static_assert(0x6783f342 == calculate("hamlin", 6)); + } +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/array_history.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/array_history.cpp new file mode 100755 index 0000000..557368c --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/array_history.cpp @@ -0,0 +1,34 @@ +#include "array_history.hpp" + +using namespace deflate; + +ArrayHistory::ArrayHistory(ZlibHeader& zlh) : header(zlh) { + 
assert(header.window_size() == buf.size()); +} + +void ArrayHistory::append(byte b) { + buf[cursor] = b; + if (cursor +1 >= buf.size()) { + wrapped = true; + } + cursor = (cursor + 1) % buf.size(); +} + +std::vector ArrayHistory::copy(uint32_t dist, uint16_t count) { + std::ptrdiff_t start_cur = (cursor - dist); + + std::vector cpy{}; + cpy.reserve(count); + + for (std::size_t n = 0; n < count; n++) { + std::ptrdiff_t pos = n + start_cur; + if (wrapped) { + pos = pos % buf.size(); + } + byte b = buf[pos]; + append(b); + cpy.push_back(b); + } + + return cpy; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/array_history.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/array_history.hpp new file mode 100755 index 0000000..77d68dd --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/array_history.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include + +#include "history.hpp" + +namespace deflate { + class ArrayHistory : public History { + public: + ArrayHistory(ZlibHeader& zlh); + virtual ~ArrayHistory() override {}; + + virtual void append(byte b) override; + virtual std::vector copy(uint32_t dist, uint16_t count) override; + + private: + ZlibHeader& header; + std::array buf = {}; + std::size_t cursor = 0; + bool wrapped = false; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/bit_vector.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/bit_vector.cpp new file mode 100755 index 0000000..6e2515a --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/bit_vector.cpp @@ -0,0 +1,95 @@ +#include + +#include "../assert.hpp" +#include "../hbytes.hpp" + +#include "bit_vector.hpp" + +using namespace deflate; + +byte BitVector::read_bit() { + uint32_t got = read_bits(1); + assert(got <= 1); + return (byte)got; +} + +uint32_t BitVector::read_bits(byte count) { + if (0 == count) return 0; + + assert(count <= 32); + + byte remaining_bits_this_byte = 8 - bit_cursor; + + if (count > remaining_bits_this_byte) { + return read_bits_not_enough(count); + } + + if (count == remaining_bits_this_byte) { + byte got = backing[byte_cursor] >> bit_cursor; + byte_cursor++; + bit_cursor = 0; + return got; + } + + assert(count < remaining_bits_this_byte); + + byte got_too_much = backing[byte_cursor] >> bit_cursor; + byte down_to_size = got_too_much & (( 1 << count) - 1); + + bit_cursor += count; + + return down_to_size; +} + +uint32_t BitVector::read_bits_not_enough(byte count) { + uint32_t accumulator = 0; + byte shift = 0; + + byte remain_this_byte = 8 - bit_cursor; + + accumulator |= read_bits(remain_this_byte); + shift += remain_this_byte; + + while (shift < count) { + byte chonk_count = ((count - shift) > 8) ? 
8 : (count - shift); + uint32_t chonk = read_bits(chonk_count); + + accumulator |= (chonk << shift); + shift += chonk_count; + } + + return accumulator; +} + +byte BitVector::finish_byte() { + uint32_t got_big = read_bits(8 - bit_cursor); + assert(got_big <= 255); + + return (byte) got_big; +} + +void BitVector::inspect(std::ostream& o) const { + o << std::hex << "BitVector byte(" << byte_cursor << ") bit(" + << (int)bit_cursor << ")" << std::endl; +} + +uint16_t BitVector::read_u16() { + uint32_t got_big = read_bits(16); + assert(got_big <= 0xFFFF); + + return (uint16_t) got_big; +} + +std::vector BitVector::read_bytes(std::size_t count) { + assert(0 == bit_cursor); + assert((byte_cursor + count) <= backing.size()); + + std::vector ret = {}; + auto slice_start = backing.begin() + byte_cursor; + auto slice_end = slice_start + count; + ret.insert(ret.end(), slice_start, slice_end); + + byte_cursor += count; + + return ret; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/bit_vector.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/bit_vector.hpp new file mode 100755 index 0000000..3589798 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/bit_vector.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +#include "../hbytes.hpp" + +namespace deflate { + class BitVector { + public: + BitVector(const std::vector v) : + backing(v), byte_cursor(0), bit_cursor(0) {}; + + BitVector(const BitVector&) = delete; // don't copy + + byte read_bit(); + uint32_t read_bits(byte count); + + byte finish_byte(); + + uint16_t read_u16(); + std::vector read_bytes(std::size_t count); + + void inspect(std::ostream& o) const; + + private: + const std::vector backing; + std::size_t byte_cursor; + byte bit_cursor; + + uint32_t read_bits_not_enough(byte count); + }; + +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/block_type.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/block_type.cpp new file mode 100755 index 0000000..f97cc18 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/block_type.cpp @@ -0,0 +1,19 @@ +#include "block_type.hpp" + +using namespace deflate; + +std::ostream& operator<<(std::ostream& os, deflate::BlockType b) { + return os << to_string(b); +} +std::string to_string(deflate::BlockType b) { + switch (b) { + case BlockType::uncompressed: + return "uncompressed"; + case BlockType::fixed: + return "fixed"; + case BlockType::dynamic: + return "dynamic"; + case BlockType::_reserved: + return "reserved (error)"; + } +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/block_type.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/block_type.hpp new file mode 100755 index 0000000..f1b2522 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/block_type.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include + +#include "../hbytes.hpp" + +namespace deflate { + enum struct BlockType : byte { + uncompressed = 0, + fixed = 1, + dynamic = 2, + _reserved = 3 + }; +} + +std::ostream& operator<<(std::ostream& os, deflate::BlockType b); +std::string to_string(deflate::BlockType b); diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/canonical_code.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/canonical_code.cpp new file mode 100755 index 0000000..10f0204 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/canonical_code.cpp @@ 
-0,0 +1,80 @@ +#include "canonical_code.hpp" + +using namespace deflate; + +CanonicalCode::CanonicalCode(std::vector lengths) { + uint16_t next_code = 0; + max_code_len = 0; + + for (uint16_t len : lengths) { + if (len > max_code_len) max_code_len = len; + } + + for (uint16_t code_len = 1; code_len <= max_code_len; code_len++) { + next_code <<= 1; + uint16_t start_bit = 1 << code_len; + for (std::size_t symbol = 0; symbol < lengths.size(); symbol++) { + uint16_t symbol_len = lengths[symbol]; + if (symbol_len != code_len) continue; + assert(next_code < start_bit); + code_bits_to_symbol[start_bit | next_code] = symbol; + next_code ++; + } + } + + assert((1 << max_code_len) == next_code); +} + +CanonicalCode::CanonicalCode(std::vector lengths) { + std::vector dest = {}; + dest.reserve(lengths.size()); + for (byte e : lengths) { + dest.push_back(e); + } + + CanonicalCode{dest}; +} + +uint16_t CanonicalCode::get_next_symbol(BitVector& reader) { + uint16_t code_bits = 1; + for (byte len = 1; len <= max_code_len; len++) { + code_bits = (code_bits << 1) | reader.read_bit(); + auto got = code_bits_to_symbol.find(code_bits); + if (code_bits_to_symbol.end() != got) { + return got->second; + } + } + assert(false); +} + +bool CanonicalCode::need_more_bits(CodeBits bits) { + if (bits.len > max_code_len) return false; + + uint16_t prefix_code = bits.to_prefix_symbol(); + + auto got = code_bits_to_symbol.find(prefix_code); + + if (code_bits_to_symbol.end() != got) return false; + + return true; +} + +bool CanonicalCode::is_valid_code(CodeBits bits) { + if (bits.len > max_code_len) return false; + + uint16_t prefix_code = bits.to_prefix_symbol(); + + auto got = code_bits_to_symbol.find(prefix_code); + + if (code_bits_to_symbol.end() != got) return true; + + return false; +} + +uint16_t CanonicalCode::get_symbol(CodeBits bits) { + auto got = code_bits_to_symbol.find(bits.to_prefix_symbol()); + + assert(code_bits_to_symbol.end() != got); + + return got->second; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/canonical_code.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/canonical_code.hpp new file mode 100755 index 0000000..847cba6 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/canonical_code.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include + +#include "bit_vector.hpp" +#include "common.hpp" + +#include "code_bits.hpp" + +namespace deflate { + class CanonicalCode { + public: + CanonicalCode(std::vector lengths); + CanonicalCode(std::vector lengths); + CanonicalCode() : code_bits_to_symbol({}), max_code_len(0) {}; + CanonicalCode(const CanonicalCode& other) : + code_bits_to_symbol(other.code_bits_to_symbol), + max_code_len(other.max_code_len) {}; + + + template + static CanonicalCode create_from_array(std::array lengths); + + template + static CanonicalCode create_from_array(std::array lengths); + + uint16_t get_next_symbol(BitVector& reader); + + bool need_more_bits(CodeBits bits); + bool is_valid_code(CodeBits bits); + + uint16_t get_symbol(CodeBits bits); + + private: + std::map code_bits_to_symbol; + byte max_code_len; + }; + + + + template + CanonicalCode + CanonicalCode::create_from_array(std::array lengths) { + std::vector dest = {}; + dest.reserve(array_len); + for (uint16_t e : lengths) { + dest.push_back(e); + } + + return dest; + } + + template + CanonicalCode + CanonicalCode::create_from_array(std::array lengths) { + std::vector dest = {}; + dest.reserve(array_len); + for (uint16_t e : lengths) { + dest.push_back(e); + 
} + + return dest; + } + +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/code.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/code.cpp new file mode 100755 index 0000000..3932cd8 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/code.cpp @@ -0,0 +1,5 @@ +#include "common.hpp" + +#include "code.hpp" + +using namespace deflate; diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/code.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/code.hpp new file mode 100755 index 0000000..a1d0e33 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/code.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include "common.hpp" + +#include "code_bits.hpp" +#include "destination.hpp" + +namespace deflate { + using LengthArray = std::array; + using CodeArray = std::array; + + class Code { + public: + virtual const Destination& get_destination(CodeBits path) = 0; + + private: + + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/code_bits.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/code_bits.cpp new file mode 100755 index 0000000..6791018 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/code_bits.cpp @@ -0,0 +1,18 @@ +#include "code_bits.hpp" + +using namespace deflate; + +CodeBits CodeBits::operator+(byte bit) { + assert(bit <= 1); + assert(len < 255); + + return {.len = (byte)(len + 1), + .bits = (bits << 1) | bit}; +} + +uint16_t CodeBits::to_prefix_symbol() { + assert(len < 16); + assert(bits < (1l << 16l)); + uint16_t symbol = (1 << len) + bits; + return symbol; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/code_bits.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/code_bits.hpp new file mode 100755 index 0000000..ad68ece --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/code_bits.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "common.hpp" + +namespace deflate { + struct CodeBits { + byte len = 0; + uint32_t bits = 0; + + CodeBits operator+(byte bit); + + uint16_t to_prefix_symbol(); + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/common.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/common.hpp new file mode 100755 index 0000000..e2ceb40 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/common.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include + +#include "../assert.hpp" +#include "../hbytes.hpp" +#include "../log.hpp" + +namespace deflate { + constexpr uint16_t CODE_COUNT = 288; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/destination.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/destination.cpp new file mode 100755 index 0000000..b62969e --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/destination.cpp @@ -0,0 +1,81 @@ +#include "destination.hpp" + +using namespace deflate::destination; + +bool Base::operator==(const destination::Base& other) const { + assert(false); +} + +void Base::inspect(std::ostream& out) const { + out << "destination::Base" << std::endl; +} + +bool Base::operator!=(const destination::Base& other) const { + return !(operator==(other)); +} + +bool Incomplete::operator==(const destination::Base& other) const { + const Incomplete* other_ptr = dynamic_cast(&other); + if (nullptr == other_ptr) return false; + + return 
true; +} + +void Incomplete::inspect(std::ostream& out) const { + out << "destination::Incomplete" << std::endl; +} + +bool Invalid::operator==(const destination::Base& other) const { + const Invalid* other_ptr = dynamic_cast(&other); + if (nullptr == other_ptr) return false; + + return true; +} + +void Invalid::inspect(std::ostream& out) const { + out << "destination::Invalid" << std::endl; +} + +bool Literal::operator==(const destination::Base& other) const { + const Literal* other_ptr = dynamic_cast(&other); + if (nullptr == other_ptr) return false; + + if (val != other_ptr->val) return false; + + return true; +} + +void Literal::inspect(std::ostream& out) const { + out << "destination::Literal val(" << (int)val << ")" << std::endl; +} + +bool EndOfBlock::operator==(const destination::Base& other) const { + const EndOfBlock* other_ptr = dynamic_cast(&other); + if (nullptr == other_ptr) return false; + + return true; +} + +void EndOfBlock::inspect(std::ostream& out) const { + out << "destination::EndOfBlock" << std::endl; +} + +bool Backref::operator==(const destination::Base& other) const { + const Backref* other_ptr = dynamic_cast(&other); + if (nullptr == other_ptr) return false; + if (bits != other_ptr->bits) return false; + if (min != other_ptr->min) return false; + + return true; +} + +void Backref::inspect(std::ostream& out) const { + out << "destination::Backref bits(" << std::hex << (int)bits << + ") min(" << std::dec << min << ")" << std::endl; +} + +uint16_t Backref::len(BitVector bv) { + if (0 == bits) return min; + uint32_t lsb = bv.read_bits(bits); + return min + lsb; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/destination.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/destination.hpp new file mode 100755 index 0000000..a0cacbb --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/destination.hpp @@ -0,0 +1,128 @@ +#pragma once + +#include +#include + +#include "common.hpp" + +#include "bit_vector.hpp" + +namespace deflate { + namespace destination { + class Base{ + public: + virtual bool operator==(const destination::Base& other) const; + bool operator!=(const destination::Base& other) const; + + virtual void inspect(std::ostream& out) const; + }; + + class Incomplete : public Base { + public: + virtual bool operator==(const destination::Base& other) const override; + virtual void inspect(std::ostream& out) const override; + }; + + class Invalid : public Base { + public: + virtual bool operator==(const destination::Base& other) const override; + virtual void inspect(std::ostream& out) const override; + }; + + class Literal : public Base { + public: + byte val; + virtual bool operator==(const destination::Base& other) const override; + virtual void inspect(std::ostream& out) const override; + + constexpr Literal() : val(0) {}; + constexpr Literal(byte l) : val(l) {}; + }; + + class EndOfBlock : public Base { + public: + virtual bool operator==(const destination::Base& other) const override; + virtual void inspect(std::ostream& out) const override; + }; + + class Backref : public Base { + public: + byte bits; + uint16_t min; + + constexpr Backref() : bits(0), min(0) {}; + constexpr Backref(byte b, uint16_t m) : bits(b), min(m) {}; + + virtual bool operator==(const destination::Base& other) const override; + virtual void inspect(std::ostream& out) const override; + + uint16_t len(BitVector bv); + }; + + namespace identity { + constexpr Incomplete incomplete{}; + constexpr Invalid invalid{}; + 
constexpr EndOfBlock end_of_block{}; + } + + using _literal_table_t = std::array; + + constexpr _literal_table_t _make_literal_table() { + _literal_table_t tbl{}; + + for (std::size_t n = 0; n <= 255; n++) { + tbl[n] = Literal{(byte)n}; + } + + return tbl; + } + + constexpr _literal_table_t _literal_table = _make_literal_table(); + + using _backref_table_t = std::array; + + constexpr _backref_table_t _make_backref_table() { + _backref_table_t tbl{}; + + for (std::size_t n = 257; n <= 264; n++) { + tbl[n - 257] = Backref(0, n - 254); + } + + for (std::size_t n = 0; n <= 3; n++) { + tbl[n + 265 - 257] = Backref(1, (n << 1) + 11); + tbl[n + 269 - 257] = Backref(2, (n << 2) + 19); + tbl[n + 273 - 257] = Backref(3, (n << 3) + 35); + tbl[n + 277 - 257] = Backref(4, (n << 4) + 67); + tbl[n + 281 - 257] = Backref(5, (n << 5) + 131); + } + + tbl[285 - 257] = Backref(0, 258); + + return tbl; + } + + constexpr _backref_table_t _backref_table = _make_backref_table(); + + using table_t = std::array; + + constexpr table_t _make_destination_table() { + table_t tbl{}; + + for (std::size_t n = 0; n <= 255; n++) { + tbl[n] = &_literal_table[n]; + } + + tbl[256] = &identity::end_of_block; + + for (std::size_t n = 257; n <= 285; n++) { + tbl[n] = &_backref_table[n - 257]; + } + + return tbl; + } + + constexpr table_t fixed_destinations = _make_destination_table(); + } + + using Destination = destination::Base; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/distance_code.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/distance_code.hpp new file mode 100755 index 0000000..789f613 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/distance_code.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "distance_destination.hpp" + +namespace deflate { + class DistanceCode { + public: + virtual const DistanceDestination& get_destination(CodeBits path) = 0; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/distance_destination.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/distance_destination.cpp new file mode 100755 index 0000000..c964414 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/distance_destination.cpp @@ -0,0 +1,8 @@ +#include "distance_destination.hpp" + +using namespace deflate; + +void DistanceDestination::inspect(std::ostream& o) const { + o << "DistanceDestination bits(" << (int)extra << ") min(" << min << ")" + << std::endl; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/distance_destination.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/distance_destination.hpp new file mode 100755 index 0000000..678ffaf --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/distance_destination.hpp @@ -0,0 +1,61 @@ +#pragma once + +#include +#include + +#include "bit_vector.hpp" +#include "common.hpp" + +namespace deflate { + class DistanceDestination { + public: + byte extra; + uint16_t min; + + const static byte INCOMPLETE_EXTRA = 254; + const static byte INVALID_EXTRA = 255; + + constexpr DistanceDestination() : extra(0), min(0) {}; + constexpr DistanceDestination(byte e, uint16_t m) : extra(e), min(m) {}; + constexpr DistanceDestination(const DistanceDestination& other) : + extra(other.extra), min(other.min) {}; + + constexpr bool operator==(const DistanceDestination& other) const { + if (extra != other.extra) return false; + if (min != other.min) return false; + + return true; + } + 
+ constexpr bool is_complete() const { + if (extra < INCOMPLETE_EXTRA) return true; + return false; + } + + constexpr bool is_incomplete() const { + if (extra == INCOMPLETE_EXTRA) return true; + return false; + } + + constexpr bool is_invalid() const { + if (extra == INVALID_EXTRA) return true; + return false; + } + + void inspect(std::ostream& o) const; + }; + + namespace distance_destination { + namespace identity { + constexpr DistanceDestination + incomplete{DistanceDestination::INCOMPLETE_EXTRA, 0}; + constexpr DistanceDestination + invalid{DistanceDestination::INVALID_EXTRA, 0}; + + static_assert(incomplete.is_incomplete()); + static_assert(invalid.is_invalid()); + static_assert(!incomplete.is_complete()); + static_assert(!invalid.is_complete()); + } + } +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code.cpp new file mode 100755 index 0000000..4370fda --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code.cpp @@ -0,0 +1,20 @@ +#include "dynamic_code.hpp" + +using namespace deflate; +using namespace deflate::destination; + +const Destination& DynamicCode::get_destination(CodeBits path) { + CanonicalCode& code = decoder.literal_codes; + + if (code.need_more_bits(path)) return identity::incomplete; + if (!code.is_valid_code(path)) return identity::invalid; + + uint16_t symbol = code.get_symbol(path); + + assert(symbol < fixed_destinations.size()); + return *fixed_destinations[symbol]; +} + +DynamicDistanceCode DynamicCode::get_distance_code() { + return {decoder.distance_codes}; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code.hpp new file mode 100755 index 0000000..6909127 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include + +#include "code.hpp" +#include "dynamic_code_code.hpp" +#include "dynamic_distance_code.hpp" + +namespace deflate { + class DynamicCode : public Code { + public: + DynamicCode(BitVector& reader) : decoder(reader) {}; + + virtual const Destination& get_destination(CodeBits path); + + DynamicDistanceCode get_distance_code(); + private: + dynamic_code_code::DCCDecoder decoder; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code_code.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code_code.cpp new file mode 100755 index 0000000..82bda47 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code_code.cpp @@ -0,0 +1,97 @@ +#include "common.hpp" + +#include "dynamic_code_code.hpp" + +#include "canonical_code.hpp" +#include "history.hpp" + +using namespace deflate; +using namespace deflate::dynamic_code_code; + +DCCDecoder::DCCDecoder(BitVector& reader) { + hlit = reader.read_bits(5) + 257; + hdist = reader.read_bits(5) + 1; + hclen = reader.read_bits(4) + 4; + + total_codes = hlit + hdist; + + std::array code_len_code_lens{}; + code_len_code_lens.fill(0); + std::array code_len_read_order = {16, 17, 18, + 0, 8, 7, 9, 6, 10, 5, 11, + 4, 12, 3, 13, 2, 14, 1, 15}; + + for (byte n = 0; n < hclen; n++) { + std::size_t idx = code_len_read_order[n]; + code_len_code_lens[idx] = reader.read_bits(3); + } + + CanonicalCode code_len_code = + CanonicalCode::create_from_array<19>(code_len_code_lens); + + std::vector 
lit_code_lens; + lit_code_lens.reserve(hlit); + + while (lit_code_lens.size() < hlit) { + uint16_t symbol = code_len_code.get_next_symbol(reader); + if (16 > symbol) { + lit_code_lens.push_back(symbol); + } else if (16 == symbol) { + assert(lit_code_lens.size() > 0); + byte run = reader.read_bits(2) + 3; + byte prev = lit_code_lens[lit_code_lens.size() - 1]; + for(byte c = 0; c < run; c++) { + lit_code_lens.push_back(prev); + } + } else if (17 == symbol) { + byte run = reader.read_bits(3) + 3; + for (byte c = 0; c < run; c++) { + lit_code_lens.push_back(0); + } + } else if (18 == symbol) { + byte run = reader.read_bits(7) + 11; + for (byte c = 0; c < run; c++) { + lit_code_lens.push_back(0); + } + } else { + assert(symbol <= 18); + } + } + + assert(lit_code_lens.size() == hlit); + + literal_codes = {lit_code_lens}; + + std::vector dist_code_lens; + dist_code_lens.reserve(hdist); + + while (dist_code_lens.size() < hdist) { + uint16_t symbol = code_len_code.get_next_symbol(reader); + if (16 > symbol) { + dist_code_lens.push_back(symbol); + } else if (16 == symbol) { + assert(dist_code_lens.size() > 0); + byte run = reader.read_bits(2) + 3; + byte prev = dist_code_lens[dist_code_lens.size() - 1]; + for(byte c = 0; c < run; c++) { + dist_code_lens.push_back(prev); + } + } else if (17 == symbol) { + byte run = reader.read_bits(3) + 3; + for (byte c = 0; c < run; c++) { + dist_code_lens.push_back(0); + } + } else if (18 == symbol) { + byte run = reader.read_bits(7) + 11; + for (byte c = 0; c < run; c++) { + dist_code_lens.push_back(0); + } + } else { + assert(symbol <= 18); + } + } + + assert(dist_code_lens.size() == hdist); + + distance_codes = {dist_code_lens}; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code_code.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code_code.hpp new file mode 100755 index 0000000..250001d --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code_code.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "common.hpp" + +#include "code.hpp" +#include "canonical_code.hpp" + +namespace deflate { + namespace dynamic_code_code { + class DCCDecoder { + public: + DCCDecoder(BitVector& reader); + + CanonicalCode literal_codes; + CanonicalCode distance_codes; + + uint16_t get_symbol(CodeBits symbol); + private: + uint16_t hlit; + byte hdist; + byte hclen; + + uint16_t total_codes; + + }; + } +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code_code_destination.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code_code_destination.hpp new file mode 100755 index 0000000..487d680 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_code_code_destination.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "common.hpp" + +namespace deflate { + namespace dynamic_code_code { + class DCCDestination { + public: + uint16_t literal; + uint16_t length; diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_distance_code.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_distance_code.cpp new file mode 100755 index 0000000..a135d3f --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_distance_code.cpp @@ -0,0 +1,18 @@ +#include "dynamic_distance_code.hpp" + +#include "fixed_distance_code.hpp" + +using namespace deflate; +using namespace deflate::distance_destination::identity; + +const DistanceDestination& 
DynamicDistanceCode::get_destination(CodeBits path) { + if (dist_codes.need_more_bits(path)) return incomplete; + if (!dist_codes.is_valid_code(path)) return invalid; + + uint16_t distance_symbol = dist_codes.get_symbol(path); + + const DistanceDestination& dest = + fixed_distance_code::distance_table[distance_symbol]; + + return dest; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_distance_code.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_distance_code.hpp new file mode 100755 index 0000000..667f17e --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/dynamic_distance_code.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "common.hpp" + +#include "canonical_code.hpp" +#include "distance_code.hpp" + +namespace deflate { + class DynamicDistanceCode : public DistanceCode { + public: + DynamicDistanceCode(CanonicalCode dc) : dist_codes(dc) {}; + + virtual const DistanceDestination& get_destination(CodeBits path); + + private: + CanonicalCode dist_codes; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/fixed_code.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/fixed_code.cpp new file mode 100755 index 0000000..7282db9 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/fixed_code.cpp @@ -0,0 +1,19 @@ +#include "fixed_code.hpp" + +using namespace deflate; +using namespace deflate::fixed_code; + +const Destination& FixedCode::get_destination(const CodeBits path) { + if (path.len < 7) return destination::identity::incomplete; + if (path.len > 9) return destination::identity::invalid; + + uint16_t maybe_sym = code_to_sym[path.bits]; + byte maybe_len = lengths[maybe_sym]; + uint16_t back_code = codes[maybe_sym]; + + if ((path.len == maybe_len) && (path.bits == back_code)) { + return *destination::fixed_destinations[maybe_sym]; + } + + return destination::identity::incomplete; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/fixed_code.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/fixed_code.hpp new file mode 100755 index 0000000..86d38e7 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/fixed_code.hpp @@ -0,0 +1,95 @@ +#pragma once + +#include + +#include "common.hpp" + +#include "code.hpp" +#include "destination.hpp" + +namespace deflate { + namespace fixed_code { + constexpr LengthArray _make_lengths() { + LengthArray len_ary{}; + + for (size_t n = 0; n <= 143; n++) { + len_ary[n] = 8; + } + for (size_t n = 144; n <= 255; n++) { + len_ary[n] = 9; + } + for (size_t n = 256; n <= 279; n++) { + len_ary[n] = 7; + } + for (size_t n = 280; n <= 287; n++) { + len_ary[n] = 8; + } + + return len_ary; + } + + constexpr LengthArray lengths = _make_lengths(); + + constexpr CodeArray _make_codes(const LengthArray len_ary) { + CodeArray code_ary{}; + + uint16_t code = 0; + std::array bl_count{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array next_code{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + for (byte bit_count : len_ary) { + bl_count[bit_count] += 1; + } + + for (byte bits = 1; bits <= 9; bits++) { + code = (code + bl_count[bits - 1]) << 1; + next_code[bits] = code; + } + + for (uint16_t n = 0; n < CODE_COUNT; n++) { + byte len = len_ary[n]; + if (0 != len) { + code_ary[n] = next_code[len]; + next_code[len] += 1; + } + } + + return code_ary; + } + + constexpr CodeArray codes = _make_codes(lengths); + + using BackCodeArray = std::array; + + constexpr BackCodeArray 
_invert_codes(const CodeArray codes) { + BackCodeArray inv_ary{}; + + for (uint16_t n = 0; n < CODE_COUNT; n++) { + inv_ary[codes[n]] = n; + } + + return inv_ary; + } + + constexpr BackCodeArray code_to_sym = _invert_codes(codes); + + static_assert(0b10111111 == codes[143]); + static_assert(8 == lengths[143]); + static_assert(143 == code_to_sym[0b10111111]); + + static_assert(511 == codes[255]); + static_assert(9 == lengths[255]); + static_assert(255 == code_to_sym[511]); + + static_assert(0 == codes[256]); + static_assert(7 == lengths[256]); + static_assert(256 == code_to_sym[0]); + + } + + class FixedCode : public Code { + public: + + virtual const Destination& get_destination(CodeBits path); + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/fixed_distance_code.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/fixed_distance_code.cpp new file mode 100755 index 0000000..3b7316d --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/fixed_distance_code.cpp @@ -0,0 +1,23 @@ +#include "fixed_distance_code.hpp" + +using namespace deflate; +using namespace deflate::fixed_distance_code; + +const DistanceDestination& FixedDistanceCode::get_destination(CodeBits path) { + if (path.len < 5) { + return distance_destination::identity::incomplete; + } + if (path.len > 5) { + return distance_destination::identity::invalid; + } + + uint16_t maybe_sym = code_to_sym[path.bits]; + byte maybe_len = lengths[maybe_sym]; + uint16_t back_code = codes[maybe_sym]; + + if ((path.len == maybe_len) && (path.bits == back_code)) { + return distance_table[maybe_sym]; + } + + return distance_destination::identity::invalid; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/fixed_distance_code.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/fixed_distance_code.hpp new file mode 100755 index 0000000..2acbdcc --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/fixed_distance_code.hpp @@ -0,0 +1,101 @@ +#pragma once + +#include + +#include "common.hpp" +#include "code_bits.hpp" + +#include "distance_code.hpp" + +namespace deflate { + namespace fixed_distance_code { + const std::size_t DIST_CODE_COUNT = 30; + + using _distance_table_t = std::array; + + constexpr _distance_table_t _make_distance_table() { + _distance_table_t tbl{}; + + for (std::size_t n = 0; n < 4; n++) { + tbl[n] = DistanceDestination(0, n+1); + } + + for (std::size_t extra = 1; extra <= 13; extra++) { + std::size_t base_d = 1 + (2 << extra); + std::size_t second_d = base_d + (2 << (extra - 1)); + std::size_t base_n = 2 + extra + extra; + + tbl[base_n] = DistanceDestination(extra, base_d); + tbl[base_n + 1] = DistanceDestination(extra,second_d); + } + + return tbl; + } + + constexpr _distance_table_t distance_table = _make_distance_table(); + + static_assert(DistanceDestination(0, 4) == distance_table[3]); + static_assert(DistanceDestination(7, 385) == distance_table[17]); + static_assert(DistanceDestination(12, 8193) == distance_table[26]); + + using _code_table_t = std::array; + + constexpr _code_table_t _make_lengths() { + _code_table_t len_ary{}; + + for (size_t n = 0; n < len_ary.size(); n++) { + len_ary[n] = 5; + } + + return len_ary; + } + + constexpr _code_table_t lengths = _make_lengths(); + + constexpr _code_table_t _make_codes(const _code_table_t len_ary) { + _code_table_t code_ary{}; + + uint16_t code = 0; + std::array bl_count{0, 0, 0, 0, 0, 0}; + std::array next_code{0, 0, 0, 0, 0, 0}; + + for 
(byte bit_count : len_ary) { + bl_count[bit_count] += 1; + } + + for (byte bits = 1; bits <= 5; bits++) { + code = (code + bl_count[bits - 1]) << 1; + next_code[bits] = code; + } + + for (uint16_t n = 0; n < DIST_CODE_COUNT; n++) { + byte len = len_ary[n]; + if (0 != len) { + code_ary[n] = next_code[len]; + next_code[len] += 1; + } + } + + return code_ary; + } + + constexpr _code_table_t codes = _make_codes(lengths); + + constexpr _code_table_t _invert_codes(const _code_table_t codes) { + _code_table_t inv_ary{}; + + for (uint16_t n = 0; n < DIST_CODE_COUNT; n++) { + inv_ary[codes[n]] = n; + } + + return inv_ary; + } + + constexpr _code_table_t code_to_sym = _invert_codes(codes); + } + + class FixedDistanceCode : public DistanceCode { + public: + virtual const DistanceDestination& get_destination(CodeBits path); + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/history.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/history.hpp new file mode 100755 index 0000000..9a37a5f --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/history.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include +#include + +#include "common.hpp" + +#include "zlib_header.hpp" + +namespace deflate { + constexpr uint32_t MAX_HISTORY = 32768; + + class History { + public: + virtual void append(byte b) = 0; + virtual std::vector copy(uint32_t dist, uint16_t count) = 0; + + virtual ~History() {}; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/navigator.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/navigator.hpp new file mode 100755 index 0000000..c6fb0b0 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/navigator.hpp @@ -0,0 +1,13 @@ +#pragma once + +class Navigator { +public: + Navigator(Code c) code(c); + + Destination navigate(BitVector bv); + + Destinatinon destination = {}' + +private: + Code c; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/vector_history.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/vector_history.cpp new file mode 100755 index 0000000..2d0d459 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/vector_history.cpp @@ -0,0 +1,28 @@ +#include "vector_history.hpp" + +using namespace deflate; + +VectorHistory::VectorHistory(ZlibHeader& zlh) : header(zlh) { + buf.resize(header.window_size()); +} + +void VectorHistory::append(byte b) { + buf[cursor] = b; + cursor = (cursor + 1) % buf.size(); +} + +std::vector VectorHistory::copy(uint32_t dist, uint16_t count) { + std::ptrdiff_t start_cur = (cursor - dist) % buf.size(); + std::ptrdiff_t end_cur = start_cur + count; + + std::vector cpy{}; + cpy.reserve(count); + + for (std::size_t n = start_cur; n < end_cur; n++) { + byte b = buf[n % MAX_HISTORY]; + append(b); + cpy.push_back(b); + } + + return cpy; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/vector_history.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/vector_history.hpp new file mode 100755 index 0000000..2f8d423 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/vector_history.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include "history.hpp" + +namespace deflate { + class VectorHistory : public History { + public: + VectorHistory(ZlibHeader& zlh); + virtual ~VectorHistory() override {}; + + virtual void append(byte b) override; + virtual std::vector copy(uint32_t dist, uint16_t 
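// [editorial sketch: not part of the original diff] History::copy() is the LZ77
// back-reference: read `count` bytes starting `dist` positions behind the write
// cursor, re-appending each byte so overlapping matches (dist < count) repeat the
// data just copied. A minimal ring-buffer version, indexed modulo the buffer's own
// size, assuming the window size is known up front:
#include <cassert>
#include <cstdint>
#include <vector>

class RingHistory {
public:
  explicit RingHistory(std::size_t window) : buf(window, 0) {}

  void append(uint8_t b) {
    buf[cursor] = b;
    cursor = (cursor + 1) % buf.size();
  }

  std::vector<uint8_t> copy(uint32_t dist, uint16_t count) {
    std::vector<uint8_t> out;
    out.reserve(count);
    // Start `dist` bytes behind the cursor, wrapping within the window.
    std::size_t pos = (cursor + buf.size() - dist) % buf.size();
    for (uint16_t n = 0; n < count; n++) {
      uint8_t b = buf[pos];
      pos = (pos + 1) % buf.size();
      append(b);            // re-append so overlapping copies see fresh data
      out.push_back(b);
    }
    return out;
  }

private:
  std::vector<uint8_t> buf;
  std::size_t cursor = 0;
};

int main() {
  RingHistory h(32);
  for (uint8_t b : {'a', 'b', 'c'}) h.append(b);
  std::vector<uint8_t> got = h.copy(3, 6);   // overlapping copy: "abcabc"
  assert((got == std::vector<uint8_t>{'a', 'b', 'c', 'a', 'b', 'c'}));
  return 0;
}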
count) override; + + private: + ZlibHeader& header; + std::vector buf; + std::size_t cursor = 0; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/zlib_header.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/zlib_header.cpp new file mode 100755 index 0000000..82cc3cc --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/zlib_header.cpp @@ -0,0 +1,46 @@ +#include "../assert.hpp" + +#include "zlib_header.hpp" + +using namespace deflate; + +ZlibHeader::ZlibHeader(BitVector& bv) { + compression_method = (byte)bv.read_bits(4); + compression_info = (byte)bv.read_bits(4); + + byte cmf_byte = (compression_info << 4) | (compression_method); + + byte cmf_ck = (byte)bv.read_bits(5); + + preset_dictionary = (byte)bv.read_bits(1); + compression_level = (byte)bv.read_bits(2); + + byte flag_byte = + (compression_level << 6) | + (preset_dictionary << 5) | + (cmf_ck); + + uint16_t flag_check = (cmf_byte << 8) | flag_byte; + + assert((flag_check % 31) == 0); +} + +uint32_t ZlibHeader::window_size() const { + return 1 << (8 + compression_info); +} + +void ZlibHeader::validate_png() const { + assert(8 == compression_method); + assert(7 >= compression_info); + + assert(0 == preset_dictionary); +} + +void ZlibHeader::inspect(std::ostream& w) const { + w << std::hex + << "method(" << (int)compression_method + << ") info(" << (int)compression_info + << ") preset_dict(" << (bool)preset_dictionary + << ") level(" << (int)compression_level + << ")" << std::endl; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/zlib_header.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/zlib_header.hpp new file mode 100755 index 0000000..be457a2 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/deflate/zlib_header.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +#include "../hbytes.hpp" + +#include "bit_vector.hpp" + +namespace deflate { + class ZlibHeader { + public: + ZlibHeader(BitVector& bv); + + byte compression_method; + byte compression_info; + + byte preset_dictionary; + byte compression_level; + + void validate_png() const; + + uint32_t window_size() const; + + void inspect(std::ostream& w) const; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/dimensions.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/dimensions.cpp new file mode 100755 index 0000000..1bd632c --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/dimensions.cpp @@ -0,0 +1,6 @@ +#include "dimensions.hpp" + +std::ostream& operator<<(std::ostream& os, Dimensions d) { + os << d.width << "x" << d.height; + return os; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/dimensions.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/dimensions.hpp new file mode 100755 index 0000000..64273a1 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/dimensions.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +using dim = uint32_t; + +struct Dimensions { +public: + Dimensions(dim w, dim h) : width(w), height(h) {}; + + Dimensions(const Dimensions& o) : width(o.width), height(o.height) {}; + + dim width; + dim height; +}; + +std::ostream& operator<<(std::ostream& os, Dimensions d); diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/hbytes.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/hbytes.hpp new file mode 100755 index 0000000..501af09 --- /dev/null +++ 
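// [editorial sketch: not part of the original diff] ZlibHeader reassembles the CMF
// and FLG bytes from their bit fields and enforces the RFC 1950 constraint that
// CMF*256 + FLG is a multiple of 31; the window size is 2^(8 + CINFO). The same
// byte-level check, worked for the common 0x78 0x9C zlib header:
#include <cassert>
#include <cstdint>

int main() {
  uint8_t cmf = 0x78;   // CM = 8 (deflate), CINFO = 7 (32 KiB window)
  uint8_t flg = 0x9C;

  uint8_t cm    = cmf & 0x0F;
  uint8_t cinfo = cmf >> 4;
  uint32_t window_size = 1u << (8 + cinfo);

  assert(cm == 8);
  assert(window_size == 32768);
  assert(((cmf * 256u + flg) % 31u) == 0);   // FCHECK makes the pair divisible by 31
  return 0;
}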
b/tests/generic_tests/targets/hamlin/challenge/src/src/hbytes.hpp @@ -0,0 +1,3 @@ +#pragma once + +using byte = unsigned char; diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/hrl.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl.cpp new file mode 100755 index 0000000..dae436a --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl.cpp @@ -0,0 +1,158 @@ +#include + +#include "assert.hpp" +#include "log.hpp" +#include "reader.hpp" + +#include "hrl/header.hpp" +#include "hrl/pixel.hpp" +#include "hrl/run.hpp" + +#include "hrl.hpp" + +using namespace hrl; + +Hrl::Hrl(std::istream& r) { + + Header head(r); + + _width = head.width; + _height = head.height; + + _sigil = head.sigil; + + read_pixels(r); +} + +Hrl::Hrl(const Image& i) : +_width(i.width()), _height(i.height()), _pixels(i.pixels()) +{ + _sigil = determine_sigil(); +} + +Hrl::~Hrl() {}; + +void Hrl::write(std::ostream& o) { + hrl::Header head(_width, _height, _sigil); + head.write(o); + + uint64_t total_pixels = _height * _width; + uint64_t current_pixel = 0; + + while (total_pixels > current_pixel) { + Rgba cur_color = _pixels[current_pixel]; + if ((current_pixel + 1) < _pixels.size()) { + + Rgba next_color = _pixels[current_pixel + 1]; + + if (cur_color == next_color) { + byte run_count = 1; + while (cur_color == next_color) { + run_count++; + if (255 == run_count) break; + if (total_pixels <= current_pixel + run_count) break; + next_color = _pixels[current_pixel + run_count]; + } + + + + + Run r{_sigil, run_count, cur_color.r, cur_color.g, cur_color.b}; + o.write((char*)(void*)&r, sizeof(r)); + current_pixel += run_count; + continue; + } + } + + if (_sigil == cur_color.r) { + + + + Run r{_sigil, 1, cur_color.r, cur_color.g, cur_color.b}; + o.write((char*)(void*)&r, sizeof(r)); + current_pixel += 1; + continue; + } + + + Pixel p{cur_color.r, cur_color.g, cur_color.b}; + o.write((char*)(void*)&p, sizeof(p)); + current_pixel += 1; + } +} + +void Hrl::inspect(std::ostream& w) { + w << "HRL" << std::endl; + + w << "\tdimensions(" << dimensions() << ") sigil(" << _sigil << + ")" << std::endl; +} + +Dimensions Hrl::dimensions() const { + return Dimensions(_width, _height); +} + +uint32_t Hrl::width() const { + return _width; +} + +uint32_t Hrl::height() const { + return _height; +} + +const std::vector& Hrl::pixels() const { + return _pixels; +} + +byte Hrl::determine_sigil() { + std::array red_counts = {0}; + + for (Rgba px : _pixels) { + red_counts[px.r]++; + } + + uint32_t least_popular_count = _width * _height; + uint32_t least_popular_value = 0; + + for (uint32_t val = 0; val < 256; val++) { + uint32_t cnt = red_counts[val]; + if (cnt < least_popular_count) { + least_popular_count = cnt; + least_popular_value = val; + } + } + + lll("sigil %d with count %d", least_popular_value, least_popular_count); + + return least_popular_value; +} + +void Hrl::read_pixels(std::istream& r) { + uint32_t pixel_count = _width * _height; + + _pixels.clear(); + _pixels.reserve(pixel_count); + + uint32_t remaining_pixels = pixel_count; + + while (remaining_pixels > 0) { + byte sentry = static_cast(r.peek()); + + if (_sigil == sentry) { + Run run; + r.read((char*)(void*) &run, sizeof(run)); + Rgba px(run.r, run.g, run.b); + for (byte c = 0; c < run.length; c++) { + _pixels.push_back(px); + } + + remaining_pixels -= run.length; + } else { + Pixel pxl; + r.read((char*)(void*) &pxl, sizeof(pxl)); + _pixels.push_back({pxl.r, pxl.g, pxl.b}); + + remaining_pixels -= 1; + } + } +} diff --git 
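// [editorial sketch: not part of the original diff] Hrl::write() uses a sigil byte
// as a run marker: runs of equal pixels become a 5-byte Run record, a lone pixel
// whose red channel happens to equal the sigil is escaped as a length-1 run, and
// everything else is stored as a raw 3-byte Pixel. The shape of that scheme over a
// plain byte stream (hypothetical, single-channel, for illustration only):
#include <cstdint>
#include <vector>

std::vector<uint8_t> rle_encode(const std::vector<uint8_t>& in, uint8_t sigil) {
  std::vector<uint8_t> out;
  for (std::size_t i = 0; i < in.size();) {
    std::size_t run = 1;
    while (i + run < in.size() && in[i + run] == in[i] && run < 255) run++;
    if (run > 1 || in[i] == sigil) {
      // Run record: marker, count, value (also escapes the sigil as a run of one).
      out.push_back(sigil);
      out.push_back(static_cast<uint8_t>(run));
      out.push_back(in[i]);
    } else {
      out.push_back(in[i]);   // literal byte
    }
    i += run;
  }
  return out;
}

int main() {
  std::vector<uint8_t> data{7, 7, 7, 3, 9, 9};
  std::vector<uint8_t> packed = rle_encode(data, 0xAB);
  // Expected: run(7 x3), literal 3, run(9 x2) -> AB 03 07 03 AB 02 09
  return packed.size() == 7 ? 0 : 1;
}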
a/tests/generic_tests/targets/hamlin/challenge/src/src/hrl.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl.hpp new file mode 100755 index 0000000..b7c1dfc --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include "dimensions.hpp" +#include "image.hpp" +#include "rgba.hpp" +#include "reader.hpp" + +class Hrl : public Image { +public: + Hrl(std::istream& r); + Hrl(const Image& i); + + ~Hrl() override; + + void inspect(std::ostream& w) override; + void write(std::ostream& w) override; + + Dimensions dimensions() const; + + uint32_t width() const override; + uint32_t height() const override; + const std::vector& pixels() const override; + +private: + uint32_t _width; + uint32_t _height; + + std::vector _pixels; + + byte _sigil; + + byte determine_sigil(); + + void read_pixels(std::istream& r); +}; diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/common.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/common.hpp new file mode 100755 index 0000000..767a568 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/common.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include +#include +#include + +#include "../assert.hpp" +#include "../hbytes.hpp" +#include "../hton.hpp" +#include "../log.hpp" diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/header.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/header.cpp new file mode 100755 index 0000000..84594a6 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/header.cpp @@ -0,0 +1,31 @@ +#include + +#include "common.hpp" + +#include "header.hpp" + +using namespace hrl; + +const uint32_t expected_magic = 'HRLe'; + +Header::Header(std::istream& r) { + Pack p; + r.read((char*)(void*)(&p), sizeof(p)); + + lll("hrl got magic (net %x host %x) want %x", p.magic, ntoh(p.magic), expected_magic); + + assert(expected_magic == ntohl(p.magic)); + width = ntohl(p.width); + height = ntohl(p.height); + sigil = p.sigil; +} + +void Header::write(std::ostream& o) { + Pack p; + p.magic = htonl(expected_magic); + p.width = htonl(width); + p.height = htonl(height); + p.sigil = sigil; + + o.write((char*)(void*)(&p), sizeof(p)); +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/header.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/header.hpp new file mode 100755 index 0000000..11471fe --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/header.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include "common.hpp" + +namespace hrl { + class Header { + public: + Header(std::istream& r); + Header(uint32_t w, uint32_t h, uint32_t s) : + width(w), height(h), sigil(s) {}; + Header(const Header& o) : + width(o.width), height(o.height), sigil(o.sigil) {}; + + uint32_t width; + uint32_t height; + byte sigil; + + void write(std::ostream& o); + + private: + struct __attribute__((packed)) Pack { + uint32_t magic; + uint32_t width; + uint32_t height; + byte sigil; + }; + + static_assert(sizeof(Pack) == 13); + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/pixel.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/pixel.hpp new file mode 100755 index 0000000..4372c95 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/pixel.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "common.hpp" + +namespace hrl { + struct __attribute__((packed)) Pixel { + byte r; + byte g; + byte b; + }; + + 
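// [editorial sketch: not part of the original diff] hrl::Header reads the packed
// 13-byte record straight into a struct and byte-swaps the three 32-bit fields with
// ntohl. An equivalent, layout-independent decode from the raw wire bytes (the
// width/height/sigil values below are made up for the example):
#include <array>
#include <cassert>
#include <cstdint>

static uint32_t be32(const uint8_t* p) {
  return (uint32_t(p[0]) << 24) | (uint32_t(p[1]) << 16) |
         (uint32_t(p[2]) << 8)  |  uint32_t(p[3]);
}

int main() {
  // magic "HRLe", width 2, height 3, sigil 0x2A: 13 bytes on the wire.
  std::array<uint8_t, 13> wire{'H', 'R', 'L', 'e',
                               0, 0, 0, 2,
                               0, 0, 0, 3,
                               0x2A};
  uint32_t magic = (uint32_t('H') << 24) | (uint32_t('R') << 16) |
                   (uint32_t('L') << 8)  |  uint32_t('e');
  assert(be32(&wire[0]) == magic);
  assert(be32(&wire[4]) == 2);    // width
  assert(be32(&wire[8]) == 3);    // height
  assert(wire[12] == 0x2A);       // sigil
  return 0;
}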
static_assert(sizeof(Pixel) == 3); +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/run.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/run.hpp new file mode 100755 index 0000000..3287ca7 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/hrl/run.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "common.hpp" + +namespace hrl { + struct __attribute__((packed)) Run { + byte sigil; + byte length; + byte r; + byte g; + byte b; + }; + + static_assert(sizeof(Run) == 5); +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/hton.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/hton.hpp new file mode 100755 index 0000000..8648a64 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/hton.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include "hbytes.hpp" + +template +constexpr T +hton(const T &host_i) +{ + T net_i = 0; + const std::size_t byte_count = sizeof(host_i); + byte* bytes = (byte*)(void*) &net_i; + + for (std::size_t idx = 0; idx < byte_count; idx++) { + bytes[idx] = host_i >> ((byte_count - idx - 1) * 8); + } + + return net_i; +} + + +// yeah it's the same +template +constexpr T +ntoh(const T &host_i) +{ + T net_i = 0; + const std::size_t byte_count = sizeof(host_i); + byte* bytes = (byte*)(void*) &net_i; + + for (std::size_t idx = 0; idx < byte_count; idx++) { + bytes[idx] = host_i >> ((byte_count - idx - 1) * 8); + } + + return net_i; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/image.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/image.cpp new file mode 100755 index 0000000..fc71605 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/image.cpp @@ -0,0 +1,3 @@ +#include "image.hpp" + +Image::~Image() {}; diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/image.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/image.hpp new file mode 100755 index 0000000..995ec1c --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/image.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include + +#include "rgba.hpp" + +class Image { +public: + virtual ~Image(); + + virtual uint32_t width() const = 0; + virtual uint32_t height() const = 0; + virtual const std::vector& pixels() const = 0; + + virtual void inspect(std::ostream& w) = 0; + virtual void write(std::ostream& w) = 0; +}; diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/inflate.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/inflate.cpp new file mode 100755 index 0000000..18cf630 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/inflate.cpp @@ -0,0 +1,158 @@ +#include +#include + +#include "assert.hpp" +#include "log.hpp" + +#include "inflate.hpp" +#include "deflate/block_type.hpp" +#include "deflate/code.hpp" +#include "deflate/destination.hpp" +#include "deflate/distance_destination.hpp" +#include "deflate/dynamic_code.hpp" +#include "deflate/fixed_code.hpp" +#include "deflate/fixed_distance_code.hpp" +#include "deflate/history.hpp" + +using namespace deflate; + +Inflate::Inflate(ZlibHeader zlh, BitVector& b) : + header(zlh), compressed(b) { + if (MAX_HISTORY == header.window_size()) { + history = new ArrayHistory(header); + } else { + history = new VectorHistory(header); + } +} + +std::vector Inflate::inflated() { + if (_inflated.size() > 0) return _inflated; + + inflate(); + + assert(_inflated.size() > 0); + + return _inflated; +} + +void Inflate::inflate() { + byte is_final = 0; 
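// [editorial sketch: not part of the original diff] hton()/ntoh() in hton.hpp build
// the network-order value by writing the host value's bytes most-significant first
// into the result's storage, so serializing the result yields big-endian output
// regardless of host endianness. A quick demonstration of that property:
#include <cassert>
#include <cstdint>
#include <cstring>

template <typename T>
T hton_sketch(T host) {
  T net = 0;
  unsigned char* bytes = reinterpret_cast<unsigned char*>(&net);
  for (std::size_t i = 0; i < sizeof(T); i++) {
    bytes[i] = static_cast<unsigned char>(host >> ((sizeof(T) - i - 1) * 8));
  }
  return net;
}

int main() {
  uint32_t v = 0x01020304;
  uint32_t n = hton_sketch(v);
  unsigned char out[4];
  std::memcpy(out, &n, sizeof(n));
  // Whatever the host byte order, the serialized form is big-endian.
  assert(out[0] == 0x01 && out[1] == 0x02 && out[2] == 0x03 && out[3] == 0x04);
  return 0;
}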
+ + while (1 != is_final) { + + is_final = (byte)compressed.read_bits(1); + BlockType blk_type = (BlockType)compressed.read_bits(2); + + lll("final %d block %s(%d)", + is_final, to_string(blk_type).c_str(), (int)blk_type); + + switch(blk_type) { + case BlockType::uncompressed: + inflate_uncompressed(); + break; + case BlockType::fixed: + inflate_fixed(); + break; + case BlockType::dynamic: + inflate_dynamic(); + break; + default: + assert(false); + } + } +} + +void Inflate::inflate_uncompressed() { + compressed.finish_byte(); + uint16_t len = compressed.read_u16(); + uint16_t nlen = compressed.read_u16(); + assert(len == (~nlen & 0xffff)); + std::vector got = compressed.read_bytes(len); + _inflated.insert(_inflated.end(), + got.begin(), got.end()); +} + +void Inflate::inflate_dynamic() { + DynamicCode dynamic_code = DynamicCode(compressed); + DynamicDistanceCode dyn_dist_code = dynamic_code.get_distance_code(); + inflate_code(dynamic_code, dyn_dist_code); +} + +void Inflate::inflate_fixed() { + FixedCode fc{}; + FixedDistanceCode dist_code{}; + inflate_code(fc, dist_code); +} + +void Inflate::inflate_code(Code& code, DistanceCode& dist_code) { + CodeBits path{}; + + while (true) { + path = path + compressed.read_bit(); + + const Destination& dest = code.get_destination(path); + + if (dest == destination::identity::incomplete) { + continue; + } + + if (dest == destination::identity::end_of_block) { + break; + } + + if (dest == destination::identity::invalid) { + assert(false); + } + + + + const destination::Literal* lit_dest = + dynamic_cast(&dest); + + if (nullptr != lit_dest) { + history->append(lit_dest->val); + _inflated.push_back(lit_dest->val); + + path = CodeBits{}; + continue; + } + + const destination::Backref* backref_dest = + dynamic_cast(&dest); + + assert(nullptr != backref_dest); + const destination::Backref& backref = *backref_dest; + + uint32_t extra_len = compressed.read_bits(backref.bits); + uint32_t total_len = backref.min + extra_len; + + assert(total_len <= (2 << 15) - 1); + + + DistanceDestination const* dist_dest = nullptr; + + CodeBits dist_path{}; + do { + dist_path = dist_path + compressed.read_bit(); + dist_dest = &dist_code.get_destination(dist_path); + + } while (dist_dest->is_incomplete()); + assert(!dist_dest->is_invalid()); + + uint32_t extra_dist = compressed.read_bits(dist_dest->extra); + uint32_t total_dist = dist_dest->min + extra_dist; + + + + + + std::vector backref_contents = history->copy(total_dist, + (uint16_t)total_len); + assert(backref_contents.size() == total_len); + _inflated.insert(_inflated.end(), + backref_contents.begin(), + backref_contents.end()); + + path = CodeBits{}; + } +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/inflate.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/inflate.hpp new file mode 100755 index 0000000..5efff9d --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/inflate.hpp @@ -0,0 +1,40 @@ +#pragma once + +#include + +#include "deflate/bit_vector.hpp" +#include "deflate/code.hpp" +#include "deflate/distance_code.hpp" +#include "deflate/history.hpp" +#include "deflate/array_history.hpp" +#include "deflate/vector_history.hpp" +#include "deflate/zlib_header.hpp" + +using deflate::BitVector; +using deflate::Code; +using deflate::DistanceCode; +using deflate::History; +using deflate::ZlibHeader; + +class Inflate { +public: + Inflate(ZlibHeader zlh, BitVector& b); + ~Inflate() { + if (nullptr != history) delete history; + } + + std::vector inflated(); +private: + 
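// [editorial sketch: not part of the original diff] inflate_uncompressed() relies on
// the stored-block header of RFC 1951 section 3.2.4: after aligning to a byte
// boundary, LEN and NLEN are two 16-bit fields and NLEN must be the one's complement
// of LEN. The guard in isolation:
#include <cassert>
#include <cstdint>

int main() {
  uint16_t len  = 0x0102;
  uint16_t nlen = static_cast<uint16_t>(~len);   // 0xFEFD as stored in the block
  assert(len == (~nlen & 0xFFFF));               // checked before copying LEN raw bytes
  return 0;
}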
ZlibHeader header; + BitVector& compressed; + std::vector _inflated = {}; + History* history = nullptr; + + void inflate(); + + void inflate_uncompressed(); + void inflate_fixed(); + void inflate_dynamic(); + + void inflate_code(Code& code, DistanceCode& dist_code); +}; diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/log.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/log.cpp new file mode 100755 index 0000000..ad420db --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/log.cpp @@ -0,0 +1,18 @@ +#ifndef NO_LOG +#include +#include + +#include "log.hpp" + +void logloglog(const char* file, + unsigned int line, + const char* function, + std::string message, ...) { + fprintf(stderr, "%s:%d: %s: ", file, line, function); + va_list args; + va_start(args, message); + vfprintf(stderr, message.c_str(), args); + va_end(args); + fprintf(stderr, "\n"); +} +#endif diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/log.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/log.hpp new file mode 100755 index 0000000..4d9657a --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/log.hpp @@ -0,0 +1,15 @@ +#pragma once + + +#ifdef NO_LOG +#define lll(...) /* noop( __VA_ARGS__) */ +#else +#include + +#define lll(...) logloglog(__FILE__, __LINE__, __extension__ __PRETTY_FUNCTION__, __VA_ARGS__) + +void logloglog(const char* file, + unsigned int line, + const char* function, + std::string message, ...); +#endif diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/main.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/main.cpp new file mode 100755 index 0000000..2ffd439 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/main.cpp @@ -0,0 +1,146 @@ +#include +#include +#include +#include +#include + +#include "assert.hpp" +#include "log.hpp" +#include "hrl.hpp" +#include "hton.hpp" +#include "metered_in.hpp" +#include "net.hpp" +#include "png.hpp" +#include "ppm.hpp" +#include "reader.hpp" + +using std::string; + +const uint32_t png_format = '\0PNG'; +const uint32_t ppm_format = '\0PPM'; +const uint32_t hrl_format = '\0HRL'; + +void expect_good(std::istream& in) { + if (in.good()) return; + + if (in.eof()) { + lll("got EOF, exiting"); + std::exit(0); + } + + lll("expected in to be good, but it wasn't"); + std::exit(-1); +} + +int main() { + std::set_terminate([](){ + lll("Unhandled exception, terminating"); + std::exit(-1); + }); + + #ifndef NO_TESTBED + char *b = getenv("CHESS"); + if ( b == NULL ) { + std::cout << "[TESTBED] ENV variable check failed" << std::endl; + exit(0); + } + #endif + + std::istream* in_p; + std::ostream* out_p; + + if (std::getenv("PORT")) { + Net handler{(uint16_t)std::atoi(std::getenv("PORT"))}; + in_p = handler.get_in(); + out_p = handler.get_out(); + } else { + in_p = &std::cin; + out_p = &std::cout; + } + + std::istream& in = *in_p; + std::ostream& out = *out_p; + + expect_good(in); + + while (true) { + + uint32_t in_format; + in.read((char*)(void*) &in_format, sizeof(in_format)); + + expect_good(in); + + + uint32_t out_format; + in.read((char*)(void*) &out_format, sizeof(out_format)); + + expect_good(in); + + + uint64_t in_len_net; + in.read((char*)(void*) &in_len_net, sizeof(in_len_net)); + uint64_t in_len = ntoh(in_len_net); + + expect_good(in); + + lll("formats in(net %x host %x) out(net %x host %x)", + in_format, ntoh(in_format), out_format, ntoh(out_format)); + lll("len %d", in_len); + + + std::ios_base::iostate existing_exceptions = 
in.exceptions(); + in.exceptions(std::ios::eofbit | existing_exceptions); + + Image* in_image; + switch (ntoh(in_format)) { + case ppm_format: + in_image = new Ppm(in); + break; + case hrl_format: + in_image = new Hrl(in); + break; + + case png_format: + in_image = new Png(in); + break; + + default: + lll("unknown format %lx", ntoh(in_format)); + assert(false); + } + + in.exceptions(existing_exceptions); + + + Image* out_image; + switch (ntoh(out_format)) { + case ppm_format: + out_image = new Ppm(*in_image); + break; + case hrl_format: + out_image = new Hrl(*in_image); + break; + + default: + lll("unknown format %lx", ntoh(out_format)); + assert(false); + } + + delete in_image; + + std::ostringstream out_buf; + out_image->write(out_buf); + + delete out_image; + + + std::string out_str = out_buf.str(); + uint64_t out_len_net = hton(out_str.size()); + out.write((char*)(void*) &out_len_net, sizeof(out_len_net)); + + out.write(out_str.c_str(), out_str.size()); + out.flush(); + } + + return 0; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/metered_in.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/metered_in.cpp new file mode 100755 index 0000000..5a97931 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/metered_in.cpp @@ -0,0 +1,24 @@ +#include +#include + +#include "assert.hpp" + +#include "metered_in.hpp" + +MeteredIn::MeteredIn(std::istream& base, uint64_t limit) : + std::istream(base.rdbuf()), remain(limit) +{ + +} + +MeteredIn& MeteredIn::read(char* dest, std::streamsize count) { + assert(count > remain); + remain -= count; + std::istream::read(dest, count); + return *this; +} + +char MeteredIn::peek() { + assert(remain >= 1); + return std::istream::peek(); +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/metered_in.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/metered_in.hpp new file mode 100755 index 0000000..1a51a2f --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/metered_in.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +class MeteredIn : public std::istream { +public: + MeteredIn(std::istream& base, uint64_t limit); + + MeteredIn& read(char* dest, std::streamsize count); + + char peek(); +private: + uint64_t remain; +}; diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/net.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/net.cpp new file mode 100755 index 0000000..2717a06 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/net.cpp @@ -0,0 +1,43 @@ +#include +#include +#include +#include +#include + +#include "assert.hpp" +#include "log.hpp" +#include "hton.hpp" +#include "net.hpp" + +Net::Net(uint16_t p) : port(p) { + lll("preparing to listen on port %d", port); + int sock = socket(AF_INET, SOCK_STREAM, 0); + int opt = 1; + int sock_got = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, + &opt, sizeof(opt)); + assert_zero(sock_got); + + struct sockaddr_in address; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = hton(port); + + int bind_got = bind(sock, (const struct sockaddr*)&address, sizeof(address)); + assert_zero(bind_got); + + int listen_got = listen(sock, 1); + assert_zero(listen_got); + + struct sockaddr client_address; + socklen_t client_address_len = sizeof(client_address); + + client_fd = accept(sock, &client_address, &client_address_len); +} + +std::istream* Net::get_in() { + return &std::cin; +} + +std::ostream* Net::get_out() { + return &std::cout; 
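// [editorial sketch: not part of the original diff] main() frames each request as
// two 4-byte format tags followed by a big-endian 64-bit payload length, then the
// payload; responses are written back the same way, length prefix first. A helper
// that lays out the 16-byte request prefix (tag values here are only illustrative):
#include <array>
#include <cstdint>
#include <cstring>

std::array<uint8_t, 16> make_request_prefix(uint32_t in_tag_wire, uint32_t out_tag_wire,
                                            uint64_t payload_len) {
  std::array<uint8_t, 16> hdr{};
  std::memcpy(hdr.data(), &in_tag_wire, 4);       // input format tag, as sent on the wire
  std::memcpy(hdr.data() + 4, &out_tag_wire, 4);  // output format tag
  for (int i = 0; i < 8; i++) {                   // length, most significant byte first
    hdr[8 + i] = static_cast<uint8_t>(payload_len >> ((7 - i) * 8));
  }
  return hdr;
}

int main() {
  std::array<uint8_t, 16> hdr = make_request_prefix(0, 0, 1234);
  return hdr[15] == (1234 & 0xFF) ? 0 : 1;   // low byte of the length lands last
}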
+} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/net.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/net.hpp new file mode 100755 index 0000000..8daa9f8 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/net.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include + +class Net { +public: + Net(uint16_t p); + + std::istream* get_in(); + std::ostream* get_out(); + + uint16_t port; + int client_fd; +}; diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png.cpp new file mode 100755 index 0000000..5e6a94e --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png.cpp @@ -0,0 +1,62 @@ +#include +#include +#include +#include +#include +#include + +#include "assert.hpp" +#include "png/chunk_factory.hpp" +#include "dimensions.hpp" +#include "hbytes.hpp" +#include "log.hpp" +#include "reader.hpp" + +#include "png.hpp" + +using namespace png; + +Png::Png(std::istream& r) : file(r), factory(r) { + std::vector raw_idat = image_data(); + + factory.ihdr()->inspect(std::cerr); + + coder.load_image_data(raw_idat, *factory.ihdr(), factory.plte()); +} + +Png::~Png() {} + +void Png::write(std::ostream& w) { + w << "blah"; +} + +void Png::inspect(std::ostream& w) { + w << "PNG" << std::endl; + ihdr().inspect(w); + factory.inspect(w); + coder.inspect(w); +} + +Dimensions Png::dimensions() const { + return Dimensions(ihdr().cols, ihdr().rows); +} + +Ihdr Png::ihdr() const { + return *(factory.ihdr()); +} + +uint32_t Png::width() const { + return ihdr().cols; +} + +uint32_t Png::height() const { + return ihdr().rows; +} + +std::vector Png::image_data() const { + return factory.image_data(); +} + +const std::vector& Png::pixels() const { + return coder.to_pixels(); +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png.hpp new file mode 100755 index 0000000..d41288f --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include +#include + +#include "dimensions.hpp" +#include "hbytes.hpp" +#include "image.hpp" +#include "reader.hpp" + +#include "png/chunk.hpp" +#include "png/chunk_factory.hpp" +#include "png/file.hpp" +#include "png/idat.hpp" +#include "png/ihdr.hpp" +#include "png/image_coder.hpp" + +using Chunk = png::Chunk; +using ChunkFactory = png::ChunkFactory; +using File = png::File; +using Idat = png::Idat; +using Ihdr = png::Ihdr; +using ImageCoder = png::ImageCoder; + +class Png : public Image { +public: + Png(std::istream& r); + Png(const Image& i); + + ~Png() override; + + void inspect(std::ostream& w) override; + void write(std::ostream& w) override; + byte bytes_per_pixel() const; + + Dimensions dimensions() const; + + uint32_t width() const override; + uint32_t height() const override; + const std::vector& pixels() const override; + + std::vector image_data() const; + +private: + + File file; + ChunkFactory factory; + ImageCoder coder; + Ihdr ihdr() const; + std::vector idat; +}; diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/chunk.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/chunk.cpp new file mode 100755 index 0000000..ead6192 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/chunk.cpp @@ -0,0 +1,58 @@ +#include "chunk.hpp" + +using namespace png; + +Chunk::Chunk(std::istream& r) { + uint32_t len_n; + r.read((char*)(void*)&len_n, sizeof(len_n)); + length 
= ntoh(len_n); + + uint32_t type_n; + r.read((char*)(void*)&type_n, sizeof(type_n)); + type = ntoh(type_n); + + data.resize(length); + if (0 != length) { + r.read((char*)(void*)data.data(), length); + } + + uint32_t crc_n; + r.read((char*)(void*)&crc_n, sizeof(crc_n)); + crc = ntoh(crc_n); + + crc32::crc_t calc_crc = crc32::calculate_begin((char*)(void*)&type_n, + sizeof(type_n)); + calc_crc = crc32::calculate_inter(calc_crc, (char*)data.data(), data.size()); + calc_crc = crc32::calculate_final(calc_crc); + + lll("len %x type %c%c%c%c (%x)", length, + 0xFF & (type >> 24), + 0xFF & (type >> 16), + 0xFF & (type >> 8), + 0xFF & (type), + type); + lll("crc at got %x expected %x", calc_crc, crc); + + // assert(calc_crc == crc); +} + +void Chunk::inspect(std::ostream& w) { + w << "Chunk length(" << length << ") type(" << type << ") crc(" << crc << ") " + << std::endl; +} + +void Chunk::assert_crc() { + +} + + +std::string Chunk::type_string() { + std::string dest = "1234"; + sprintf(dest.data(), "%c%c%c%c", + 0xFF & (type >> 24), + 0xFF & (type >> 16), + 0xFF & (type >> 8), + 0xFF & (type >> 0)); + + return dest; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/chunk.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/chunk.hpp new file mode 100755 index 0000000..2b0753e --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/chunk.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +#include "common.hpp" + +namespace png { + class Chunk { + public: + Chunk(std::istream& r); + Chunk(const Chunk& c) : length(c.length), type(c.type), + data(c.data), crc(c.crc) {}; + Chunk() : length(0), type('__PH'), + data({}), crc(0) {}; + virtual ~Chunk() {}; + + virtual void inspect(std::ostream& w); + + void assert_crc(); + + std::string type_string(); + + uint32_t length = 0; + uint32_t type; + std::vector data; + uint32_t crc; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/chunk_factory.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/chunk_factory.cpp new file mode 100755 index 0000000..7fc7076 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/chunk_factory.cpp @@ -0,0 +1,123 @@ +#include +#include + +#include "common.hpp" + +#include "chunk_factory.hpp" + +using namespace png; + +ChunkFactory::ChunkFactory(std::istream& r) { + Chunk* read_chunk = {}; + + do { + assert(r.good()); + assert(!r.eof()); + read_chunk = new Chunk(r); + chunks.push_back(read_chunk); + } while ('IEND' != read_chunk->type); + + materialize_chunks(); +} + +ChunkFactory::~ChunkFactory() { + for (Chunk* c : chunks) { + delete c; + } +} + +void ChunkFactory::inspect(std::ostream& w) { + w << "ChunkFactory chunks(" << chunks.size() << ") {" << std::endl; + for (Chunk* c : chunks) { + w << "\t" << c->type_string() << ": len " + << c->length << std::endl; + } + w << "}" << std::endl; +} + +Ihdr* ChunkFactory::ihdr() const { + Ihdr* ih = dynamic_cast(chunks[0]); + assert(0 != ih); + return ih; +} + +std::vector ChunkFactory::idat() const { + std::vector dats = {}; + + for(Chunk* c : chunks) { + Idat* i; + if ((i = dynamic_cast(c))) { + dats.push_back(i); + } + } + + return dats; +} + +Plte* ChunkFactory::plte() const { + for(Chunk* c : chunks) { + Plte* p; + if ((p = dynamic_cast(c))) { + return p; + } + } + + return nullptr; +} + +std::vector ChunkFactory::image_data() const { + std::vector id = {}; + std::vector idcx = idat(); + + assert(idcx.size() > 0); + + auto sizer = [](std::size_t accumulator, Idat* 
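// [editorial sketch: not part of the original diff] Chunk::Chunk() reads the
// standard PNG chunk layout: 4-byte big-endian length, 4-byte type, `length` data
// bytes, then a 4-byte CRC computed over the type and data (not the length). A
// minimal parser over an in-memory buffer, with CRC verification left to the caller:
#include <cassert>
#include <cstdint>
#include <vector>

struct RawChunk {
  uint32_t length;
  uint32_t type;
  std::vector<uint8_t> data;
  uint32_t crc;
};

static uint32_t be32_at(const uint8_t* p) {
  return (uint32_t(p[0]) << 24) | (uint32_t(p[1]) << 16) |
         (uint32_t(p[2]) << 8)  |  uint32_t(p[3]);
}

RawChunk parse_chunk(const uint8_t* p, std::size_t avail) {
  RawChunk c{};
  assert(avail >= 12);
  c.length = be32_at(p);
  c.type   = be32_at(p + 4);
  assert(avail >= 12 + c.length);
  c.data.assign(p + 8, p + 8 + c.length);
  c.crc = be32_at(p + 8 + c.length);
  return c;
}

int main() {
  // A zero-length IEND chunk: 00000000 'IEND' AE426082.
  const uint8_t iend[] = {0, 0, 0, 0, 'I', 'E', 'N', 'D', 0xAE, 0x42, 0x60, 0x82};
  RawChunk c = parse_chunk(iend, sizeof(iend));
  assert(c.length == 0);
  assert(c.type == ((uint32_t('I') << 24) | (uint32_t('E') << 16) |
                    (uint32_t('N') << 8)  |  uint32_t('D')));
  assert(c.crc == 0xAE426082);
  return 0;
}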
i) { + return accumulator + i->length; + }; + + id.reserve(std::accumulate(std::next(idcx.begin()), idcx.end(), + 0, + sizer)); + + for (Idat* cur : idcx) { + std::vector& to_insert = cur->data; + id.insert(id.end(), to_insert.begin(), to_insert.end()); + } + + return id; +} + +void ChunkFactory::materialize_chunks() { + for(std::size_t i = 0; i < chunks.size(); i++) { + Chunk* c = chunks[i]; + + + if (0 == i) { + assert('IHDR' == c->type); + } + if ((chunks.size() - 1) == i) { + assert('IEND' == c->type); + } + + switch (c->type) { + case 'IHDR': + assert(0 == i); + chunks[i] = new Ihdr(*c); + delete c; + break; + case 'PLTE': + chunks[i] = new Plte(*c); + delete c; + break; + case 'IDAT': + chunks[i] = new Idat(*c); + delete c; + break; + case 'IEND': + assert((chunks.size() - 1) == i); + chunks[i] = new Iend(*c); + delete c; + break; + } + } +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/chunk_factory.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/chunk_factory.hpp new file mode 100755 index 0000000..0876d70 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/chunk_factory.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include + +#include "common.hpp" + +#include "chunk.hpp" +#include "ihdr.hpp" +#include "idat.hpp" +#include "iend.hpp" +#include "plte.hpp" + +namespace png { + class ChunkFactory { + public: + ChunkFactory(std::istream& r); + ~ChunkFactory(); + + void inspect(std::ostream& w); + + std::vector chunks; + + Ihdr* ihdr() const; + std::vector idat() const; + + Plte* plte() const; + + std::vector image_data() const; + + private: + void materialize_chunks(); + + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/colorizer.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/colorizer.cpp new file mode 100755 index 0000000..c7c9cab --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/colorizer.cpp @@ -0,0 +1,20 @@ +#include "colorizer.hpp" +#include "rgb_colorizer.hpp" +#include "palette_colorizer.hpp" + +#include "plte.hpp" + +using namespace png; +using namespace png::colorizer; + +colorizer::Base* png::colorizer::get_colorizer(Ihdr ihdr, Plte* plte) { + switch (ihdr.color_type) { + case ColorType::rgb: + return new RgbColorizer(); + case ColorType::palette: + assert(nullptr != plte); + return new PaletteColorizer(*plte); + default: + assert(false); + } +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/colorizer.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/colorizer.hpp new file mode 100755 index 0000000..daa772f --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/colorizer.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include + +#include "common.hpp" +#include "ihdr.hpp" +#include "plte.hpp" + +namespace png { + namespace colorizer { + class Base { + public: + virtual std::vector colorize(std::vector image_bytes) + const = 0; + + virtual ~Base() {}; + }; + + Base* get_colorizer(Ihdr ihdr, Plte* plte); + } + using Colorizer = colorizer::Base; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/colortype.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/colortype.cpp new file mode 100755 index 0000000..4c4c82b --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/colortype.cpp @@ -0,0 +1,28 @@ +#include "colortype.hpp" + +using namespace png; + +std::ostream& operator<<(std::ostream& os, ColorType c) { + switch (c) { + case greyscale: + os << 
"greyscale"; + break; + case rgb: + os << "rgb"; + break; + case palette: + os << "palette"; + break; + case greyscale_alpha: + os << "greyscale_alpha"; + break; + case rgba: + os << rgba; + break; + default: + os << "unknown colortype " << (uint8_t)c; + break; + } + + return os; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/colortype.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/colortype.hpp new file mode 100755 index 0000000..b8ebf26 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/colortype.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +namespace png { + enum ColorType : uint8_t { + greyscale = 0, + rgb = 2, + palette = 3, + greyscale_alpha = 4, + rgba = 6, + _invalid = 255 + + }; +} + +std::ostream& operator<<(std::ostream& os, png::ColorType c); diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/common.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/common.hpp new file mode 100755 index 0000000..d36a78a --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/common.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include + +#include "../assert.hpp" +#include "../crc32.hpp" +#include "../dimensions.hpp" +#include "../hbytes.hpp" +#include "../hton.hpp" +#include "../log.hpp" +#include "../rgba.hpp" diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/file.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/file.cpp new file mode 100755 index 0000000..5581662 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/file.cpp @@ -0,0 +1,19 @@ +#include "common.hpp" +#include "file.hpp" + +using namespace png; + +const std::array file_header_expectation = + {0x89, + 0x50, 0x4e, 0x47, + 0x0d, 0x0a, + 0x1a, + 0x0a}; +const uint8_t file_header_size = file_header_expectation.size(); + +File::File(std::istream& r) { + std::array signature; + r.read((char*)(void*) signature.data(), signature.size()); + + assert(file_header_expectation == signature); +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/file.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/file.hpp new file mode 100755 index 0000000..e368026 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/file.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "common.hpp" + +namespace png { + class File { + public: + File(std::istream& r); + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/filter.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/filter.cpp new file mode 100755 index 0000000..61acb42 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/filter.cpp @@ -0,0 +1,215 @@ +#include + +#include "common.hpp" + +#include "filter.hpp" + +using namespace png; +using namespace png::filter; +using namespace png::filter::identity; + +const Base& png::filter::get_filter(byte filter_sigil) { + FilterType type = (FilterType)filter_sigil; + switch (type) { + case FilterType::none: + return none; + case FilterType::sub: + return sub; + case FilterType::up: + return up; + case FilterType::average: + return average; + case FilterType::paeth: + return paeth; + default: + assert(false); + } +} + +std::string to_string(png::FilterType ft) { + switch (ft) { + case FilterType::none: + return "none"; + case FilterType::sub: + return "sub"; + case FilterType::up: + return "up"; + case FilterType::average: + return "average"; + case FilterType::paeth: + return 
"paeth"; + default: + return "unknown filtertype"; + } +} + +std::ostream& operator<<(std::ostream& os, png::FilterType ft) { + os << to_string(ft); + return os; +} + +ByteVec None::decode_first_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel) const { + ByteVec out_px{}; + out_px.reserve(cols * bytes_per_pixel); + + for (uint32_t i = 0; i < cols; i++) { + std::size_t offset = i * bytes_per_pixel; + for (byte j = 0; j < bytes_per_pixel; j++) { + out_px.push_back(row[offset+j]); + } + } + + return out_px; +} + +ByteVec None::decode_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel, + ByteVec _prev_row) const { + return decode_first_row(row, cols, bytes_per_pixel); +} + +ByteVec Sub::decode_first_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel) const { + ByteVec out_px{}; + out_px.reserve(cols * bytes_per_pixel); + + for (uint32_t i = 0; i < cols; i++) { + std::size_t offset = i * bytes_per_pixel; + for (byte j = 0; j < bytes_per_pixel; j++) { + std::size_t inner_offset = offset + j; + byte diff = row[inner_offset]; + byte orig = 0; + + if (0 != i) { + orig = out_px[inner_offset - bytes_per_pixel]; + } + + out_px.push_back(diff + orig); + } + } + + return out_px; +} + +ByteVec Sub::decode_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel, + ByteVec _prev_row) const { + return decode_first_row(row, cols, bytes_per_pixel); +} + +ByteVec Up::decode_first_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel) const { + ByteVec fake_prev_row{}; + fake_prev_row.resize(cols * bytes_per_pixel, 0); + return decode_row(row, cols, bytes_per_pixel, fake_prev_row); +} + +ByteVec Up::decode_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel, + ByteVec prev_row) const { + ByteVec out_px{}; + out_px.reserve(cols * bytes_per_pixel); + + for (uint32_t i = 0; i < cols; i++) { + std::size_t offset = (i * bytes_per_pixel); + for (byte j = 0; j < bytes_per_pixel; j++) { + std::size_t inner_offset = offset + j; + byte diff = row[inner_offset]; + byte orig = prev_row[inner_offset]; + + out_px.push_back(diff + orig); + } + } + + return out_px; +} + +ByteVec Average::decode_first_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel) const { + ByteVec fake_prev_row{}; + fake_prev_row.resize(cols * bytes_per_pixel, 0); + return decode_row(row, cols, bytes_per_pixel, fake_prev_row); +} + +ByteVec Average::decode_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel, + ByteVec prev_row) const { + ByteVec out_px{}; + out_px.reserve(cols * bytes_per_pixel); + + for (uint32_t i = 0; i < cols; i++) { + std::size_t offset = (i * bytes_per_pixel); + for (byte j = 0; j < bytes_per_pixel; j++) { + std::size_t inner_offset = offset + j; + uint16_t diff = row[inner_offset]; + uint16_t up = prev_row[inner_offset]; + uint16_t left = 0; + if (0 != i) { + left = out_px[inner_offset - bytes_per_pixel]; + } + + uint16_t calc = (up + left) >> 1; + assert(256 > calc); + byte result = (byte)(diff + calc); + out_px.push_back(result); + } + } + return out_px; +} + +byte paeth_predictor(byte a, byte b, byte c) { + int16_t p = a + b - c; + int16_t pa = std::abs((int16_t)p - (int16_t)a); + int16_t pb = std::abs((int16_t)p - (int16_t)b); + int16_t pc = std::abs((int16_t)p - (int16_t)c); + + if ((pa <= pb) && (pa <= pc)) return a; + if (pb <= pc) return b; + return c; +} + +ByteVec Paeth::decode_first_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel) const { + ByteVec fake_prev_row{}; + fake_prev_row.resize(cols * bytes_per_pixel, 0); + return decode_row(row, cols, bytes_per_pixel, 
fake_prev_row); +} + +ByteVec Paeth::decode_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel, + ByteVec prev_row) const { + ByteVec out_px{}; + out_px.reserve(cols * bytes_per_pixel); + + for (uint32_t i = 0; i < cols; i++) { + std::size_t offset = (i * bytes_per_pixel); + for (byte j = 0; j < bytes_per_pixel; j++) { + std::size_t inner_offset = offset + j; + byte diff = row[inner_offset]; + byte up = prev_row[inner_offset]; + byte left = 0; + byte upleft = 0; + if (0 != i) { + std::size_t left_offset = inner_offset - bytes_per_pixel; + left = out_px[left_offset]; + upleft = prev_row[left_offset]; + } + + byte calc = paeth_predictor(left, up, upleft); + + out_px.push_back(diff + calc); + } + } + return out_px; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/filter.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/filter.hpp new file mode 100755 index 0000000..693b549 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/filter.hpp @@ -0,0 +1,102 @@ +#pragma once + +#include + +#include "common.hpp" + +namespace png { + enum class FilterType : byte { + none = 0, + sub = 1, + up = 2, + average = 3, + paeth = 4, + _unknown = 255, + }; + + namespace filter { + using PixelVec = std::vector; + using ByteVec = std::vector; + + class Base { + public: + virtual ByteVec decode_first_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel) const = 0; + virtual ByteVec decode_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel, + ByteVec prev_row) const = 0; + }; + + class None : public Base { + public: + virtual ByteVec decode_first_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel) const override; + virtual ByteVec decode_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel, + ByteVec prev_row) const override; + }; + + class Sub : public Base { + public: + virtual ByteVec decode_first_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel) const override; + virtual ByteVec decode_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel, + ByteVec prev_row) const override; + }; + + class Up : public Base { + public: + virtual ByteVec decode_first_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel) const override; + virtual ByteVec decode_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel, + ByteVec prev_row) const override; + }; + + class Average : public Base { + public: + virtual ByteVec decode_first_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel) const override; + virtual ByteVec decode_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel, + ByteVec prev_row) const override; + }; + + class Paeth : public Base { + public: + virtual ByteVec decode_first_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel) const override; + virtual ByteVec decode_row(ByteVec row, + uint32_t cols, + byte bytes_per_pixel, + ByteVec prev_row) const override; + }; + + const Base& get_filter(byte filter_sigil); + + namespace identity { + constexpr None none = None{}; + constexpr Sub sub = Sub{}; + constexpr Up up = Up{}; + constexpr Average average = Average{}; + constexpr Paeth paeth = Paeth{}; + } + } + + using Filter = filter::Base; +} + +std::string to_string(png::FilterType ft); +std::ostream& operator<<(std::ostream& os, png::FilterType ft); diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/idat.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/idat.cpp new file mode 100755 index 0000000..dbd3bd5 --- /dev/null +++ 
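// [editorial sketch: not part of the original diff] The Paeth filter predicts each
// byte from its left, up, and upper-left neighbours, picking whichever is closest to
// left + up - upleft; the decoder then adds the stored difference modulo 256. The
// predictor on its own, with a couple of spot checks:
#include <cassert>
#include <cstdint>
#include <cstdlib>

static uint8_t paeth_predict(uint8_t left, uint8_t up, uint8_t upleft) {
  int p  = int(left) + int(up) - int(upleft);
  int pa = std::abs(p - int(left));
  int pb = std::abs(p - int(up));
  int pc = std::abs(p - int(upleft));
  if (pa <= pb && pa <= pc) return left;
  if (pb <= pc) return up;
  return upleft;
}

int main() {
  assert(paeth_predict(0, 0, 0) == 0);          // first pixel of the first row
  assert(paeth_predict(10, 200, 10) == 200);    // up is nearest to p = 200
  // Reconstruction: recon = (stored_diff + predictor) & 0xFF.
  uint8_t stored = 5;
  uint8_t recon = uint8_t(stored + paeth_predict(10, 200, 10));
  assert(recon == 205);
  return 0;
}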
b/tests/generic_tests/targets/hamlin/challenge/src/src/png/idat.cpp @@ -0,0 +1,11 @@ +#include "idat.hpp" + +using namespace png; + +Idat::Idat(Chunk c) : Chunk(c) { + assert('IDAT' == type); +} + +void Idat::inspect(std::ostream& w) { + w << "\tIDAT bytes(" << length << ")" << std::endl; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/idat.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/idat.hpp new file mode 100755 index 0000000..e860051 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/idat.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "common.hpp" + +#include "chunk.hpp" + +namespace png { + class Idat : public Chunk { + public: + Idat(Chunk c); + + Idat(const Idat& i) : Chunk(i) {}; + + virtual ~Idat() {}; + + void inspect(std::ostream& w) override; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/iend.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/iend.cpp new file mode 100755 index 0000000..39e1e36 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/iend.cpp @@ -0,0 +1,8 @@ +#include "iend.hpp" + +using namespace png; + +Iend::Iend(Chunk c) : Chunk(c) { + assert('IEND' == type); + assert(0 == length); +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/iend.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/iend.hpp new file mode 100755 index 0000000..0805507 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/iend.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "common.hpp" + +#include "chunk.hpp" + +namespace png { + class Iend : public Chunk { + public: + Iend(Chunk c); + + virtual ~Iend() {}; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/ihdr.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/ihdr.cpp new file mode 100755 index 0000000..e634248 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/ihdr.cpp @@ -0,0 +1,35 @@ +#include + +#include "ihdr.hpp" + +using namespace png; + +Ihdr::Ihdr(Chunk c) : Chunk(c) { + assert(13 == length); + assert('IHDR' == type); + assert(sizeof(Pack) == data.size()); + + Pack* pd = (Pack*)(void*) data.data(); + + cols = ntohl(pd->cols); + rows = ntohl(pd->rows); + bit_depth = pd->bit_depth; + color_type = (ColorType) pd->color_type; + compression = pd->compression; + filter = pd->filter; + interlace = pd->interlace; + +} + +void Ihdr::inspect(std::ostream& w) { + Chunk::inspect(w); + w << "\tIHDR dimensions(" << dimensions() << ") bit depth(" << + (int)bit_depth << ") color_type(" << color_type << ") compression(" << + (int)compression << ") filter(" << (int)filter << + ") interlace(" << (int)interlace << + ")" << std::endl; +} + +Dimensions Ihdr::dimensions() { + return Dimensions{cols, rows}; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/ihdr.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/ihdr.hpp new file mode 100755 index 0000000..1eff16c --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/ihdr.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include "common.hpp" + +#include "chunk.hpp" +#include "colortype.hpp" + +namespace png { + class Ihdr : public Chunk { + public: + Ihdr(Chunk c); + + Ihdr(const Ihdr& i) : Chunk(i), cols(i.cols), rows(i.rows), + bit_depth(i.bit_depth), color_type(i.color_type), + compression(i.compression), filter(i.filter), + interlace(i.interlace) {}; + + Ihdr() : Chunk(), cols(0), rows(0), 
bit_depth(0), color_type(_invalid), + compression(255), filter(255), interlace(255) {}; + + virtual ~Ihdr() {}; + + void inspect(std::ostream& w) override; + + Dimensions dimensions(); + + uint32_t cols; // width + uint32_t rows; // height + uint8_t bit_depth; + png::ColorType color_type; + uint8_t compression; + uint8_t filter; + uint8_t interlace; + + private: + struct __attribute__((packed)) Pack { + uint32_t cols; // width + uint32_t rows; // height + uint8_t bit_depth; + uint8_t color_type; + uint8_t compression; + uint8_t filter; + uint8_t interlace; + }; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/image_coder.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/image_coder.cpp new file mode 100755 index 0000000..6201ebc --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/image_coder.cpp @@ -0,0 +1,113 @@ +#include + +#include "common.hpp" + +#include "colorizer.hpp" +#include "image_coder.hpp" +#include "filter.hpp" +#include "../inflate.hpp" +#include "../deflate/zlib_header.hpp" + +using namespace png; + +using ZlibHeader = deflate::ZlibHeader; + +void ImageCoder::load_image_data(std::vector image_data, + Ihdr ihdr, + Plte* plte) { + BitVector bv = {image_data}; + zlh = new ZlibHeader(bv); + zlh->inspect(std::cerr); + bv.inspect(std::cerr); + + Inflate inflater = {*zlh, bv}; + + std::vector inflated_data = inflater.inflated(); + + std::ofstream inflated_dump("/mnt/challenge/tmp/idat.idat", + std::ios::binary | std::ios::trunc); + inflated_dump.write((char*)(void*) inflated_data.data(), + inflated_data.size()); + inflated_dump.close(); + + assert(0 == (ihdr.bit_depth % 8)); + + byte bytes_per_channel = ihdr.bit_depth / 8; + + byte bytes_per_pixel = 0; + + switch (ihdr.color_type) { + case ColorType::rgb: + bytes_per_pixel = 3 * bytes_per_channel; + break; + case ColorType::rgba: + bytes_per_pixel = 4 * bytes_per_channel; + break; + case ColorType::palette: + bytes_per_pixel = 1; + break; + default: + assert(false); + } + + std::size_t bytes_per_row = 1 + (ihdr.cols * bytes_per_pixel); + std::size_t expected_total_bytes = bytes_per_row * ihdr.rows; + + assert(inflated_data.size() == expected_total_bytes); + + std::vector prev_row{}; + std::vector pixel_row{}; + + prev_row.reserve(ihdr.cols * bytes_per_pixel); + pixel_row.reserve(ihdr.cols); + + Colorizer* colorizer = colorizer::get_colorizer(ihdr, plte); + + for(int r = 0; r < ihdr.rows; r++) { + std::size_t offset = r * bytes_per_row; + std::vector row_data{}; + + byte filter_sigil = inflated_data[offset]; + + auto offset_start = inflated_data.begin() + offset + 1; + auto offset_end = offset_start + bytes_per_row - 1; + + row_data.insert(row_data.end(), offset_start, offset_end); + + const Filter& f = filter::get_filter(filter_sigil); + + + + + std::vector cur_row{}; + + + if (0 == r) { + cur_row = f.decode_first_row(row_data, ihdr.cols, bytes_per_pixel); + } else { + cur_row = f.decode_row(row_data, ihdr.cols, bytes_per_pixel, prev_row); + } + + pixel_row = colorizer->colorize(cur_row); + + pixels.insert(pixels.end(), pixel_row.begin(), pixel_row.end()); + prev_row = cur_row; + } + + delete colorizer; +} + +void ImageCoder::inspect(std::ostream& w) { + w << "ImageCoder" << std::endl; + w << "\timage_data(" << image_data.size() + << ") pixels(" << pixels.size() << ")" << std::endl; + if (nullptr != zlh) { + w << "\tZlibHeader "; + zlh->inspect(w); + w << std::endl; + } +} + +const std::vector& ImageCoder::to_pixels() const { + return pixels; +} diff --git 
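// [editorial sketch: not part of the original diff] ImageCoder::load_image_data()
// expects the inflated IDAT stream to hold, per row, one filter-type byte followed
// by cols * bytes_per_pixel filtered bytes. The size arithmetic it asserts, worked
// for a small 8-bit truecolour image:
#include <cassert>
#include <cstdint>

int main() {
  uint32_t cols = 4, rows = 3;
  uint32_t bytes_per_channel = 1;                       // bit depth 8
  uint32_t bytes_per_pixel = 3 * bytes_per_channel;     // ColorType::rgb

  uint32_t bytes_per_row = 1 + cols * bytes_per_pixel;  // +1 for the filter sigil
  uint32_t expected_total = bytes_per_row * rows;

  assert(bytes_per_row == 13);
  assert(expected_total == 39);
  return 0;
}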
a/tests/generic_tests/targets/hamlin/challenge/src/src/png/image_coder.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/image_coder.hpp new file mode 100755 index 0000000..b621657 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/image_coder.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "common.hpp" +#include "ihdr.hpp" +#include "plte.hpp" +#include "../deflate/zlib_header.hpp" + +namespace png { + class ImageCoder { + public: + void load_image_data(std::vector image_data, + Ihdr ihdr, + Plte* plte); + void load_pixels(std::vector pixels, + Ihdr ihdr); + + std::vector to_image_data(); + const std::vector& to_pixels() const; + + void inspect(std::ostream& w); + + private: + std::vector image_data = {}; + std::vector pixels = {}; + + deflate::ZlibHeader* zlh = nullptr; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/palette_colorizer.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/palette_colorizer.cpp new file mode 100755 index 0000000..23b760f --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/palette_colorizer.cpp @@ -0,0 +1,18 @@ +#include "palette_colorizer.hpp" + +using namespace png; +using namespace png::colorizer; + +std::vector PaletteColorizer::colorize(std::vector image_bytes) + const { + std::vector out_px{}; + out_px.reserve(image_bytes.size()); + + for (std::size_t i = 0; i < image_bytes.size(); i++) { + byte pal_idx = image_bytes[i]; + Rgba color = palette[pal_idx]; + out_px.push_back(color); + } + + return out_px; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/palette_colorizer.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/palette_colorizer.hpp new file mode 100755 index 0000000..0fcd7ad --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/palette_colorizer.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include "colorizer.hpp" +#include "plte.hpp" + +namespace png { + namespace colorizer { + class PaletteColorizer : public Colorizer { + public: + PaletteColorizer(Plte plte) : palette(plte) {}; + virtual ~PaletteColorizer() {}; + + virtual std::vector colorize(std::vector image_bytes) + const override; + + private: + Plte palette; + }; + } +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/plte.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/plte.cpp new file mode 100755 index 0000000..9c1d615 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/plte.cpp @@ -0,0 +1,42 @@ +#include "plte.hpp" + +using namespace png; + +Plte::Plte(Chunk c) : Chunk(c) { + assert(0 == (length % 3)); + assert('PLTE' == type); + + entry_count = length / 3; + + assert(entry_count >= 1); + assert(entry_count <= 256); + + palette.resize(entry_count, {0, 0, 0}); + + for (std::size_t n = 0; n < entry_count; n++) { + std::size_t offset = n * 3; + + palette[n] = { + data[offset + 0], + data[offset + 1], + data[offset + 2] + }; + } +} + +Rgba Plte::operator[](byte pal_idx) const { + assert(pal_idx < entry_count); + + return palette[pal_idx]; +} + +void Plte::inspect(std::ostream& w) { + Chunk::inspect(w); + w << "\tPLTE entry_count(" << entry_count << ") entries(" << std::endl; + for (std::size_t n = 0; n < entry_count; n++) { + w << "\t\t" << n << ": {"; + w << (int)palette[n].r << " " + << (int)palette[n].g << " " + << (int)palette[n].b << "}" << std::endl; + } +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/plte.hpp 
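// [editorial sketch: not part of the original diff] Plte stores length/3 RGB entries
// and PaletteColorizer maps each decoded index byte to palette[index]. The lookup in
// miniature, with the same bounds check Plte::operator[] performs:
#include <cassert>
#include <cstdint>
#include <vector>

struct Rgb { uint8_t r, g, b; };

std::vector<Rgb> colorize_indexed(const std::vector<uint8_t>& indices,
                                  const std::vector<Rgb>& palette) {
  std::vector<Rgb> out;
  out.reserve(indices.size());
  for (uint8_t idx : indices) {
    assert(idx < palette.size());    // reject indices past the palette's entry count
    out.push_back(palette[idx]);
  }
  return out;
}

int main() {
  std::vector<Rgb> palette{{255, 0, 0}, {0, 255, 0}};
  std::vector<Rgb> px = colorize_indexed({0, 1, 1, 0}, palette);
  assert(px.size() == 4 && px[3].r == 255);
  return 0;
}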
b/tests/generic_tests/targets/hamlin/challenge/src/src/png/plte.hpp new file mode 100755 index 0000000..b82a7e5 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/plte.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include "common.hpp" + +#include "chunk.hpp" +#include "plte.hpp" + +namespace png { + class Plte : public Chunk { + public: + Plte(Chunk c); + + Plte(const Plte& i) : + Chunk(i), palette(i.palette), entry_count(i.entry_count) {}; + + virtual ~Plte() {}; + + Rgba operator[](byte pal_idx) const; + + void inspect(std::ostream& w) override; + + private: + std::vector palette = {}; + std::size_t entry_count = 0; + }; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/rgb_colorizer.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/rgb_colorizer.cpp new file mode 100755 index 0000000..0c944a9 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/rgb_colorizer.cpp @@ -0,0 +1,22 @@ +#include "rgb_colorizer.hpp" + +using namespace png::colorizer; + +std::vector RgbColorizer::colorize(std::vector image_bytes) const { + std::vector out_px{}; + assert((image_bytes.size() % 3) == 0); + std::size_t cols = image_bytes.size() / 3; + out_px.reserve(cols); + + for (std::size_t i = 0; i < cols; i++) { + std::size_t offset = (i * 3); + + byte pal_idx_r = image_bytes[offset + 0]; + byte pal_idx_g = image_bytes[offset + 1]; + byte pal_idx_b = image_bytes[offset + 2]; + + out_px.push_back({pal_idx_r, pal_idx_g, pal_idx_b}); + } + + return out_px; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/png/rgb_colorizer.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/png/rgb_colorizer.hpp new file mode 100755 index 0000000..1d3f201 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/png/rgb_colorizer.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "colorizer.hpp" + +namespace png { + namespace colorizer { + class RgbColorizer : public Colorizer { + virtual ~RgbColorizer() {}; + virtual std::vector colorize(std::vector image_bytes) + const override; + }; + } +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/pngtest.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/pngtest.cpp new file mode 100755 index 0000000..161cfd5 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/pngtest.cpp @@ -0,0 +1,41 @@ +#include +#include +#include +#include + +#include "assert.hpp" +#include "log.hpp" +#include "hrl.hpp" +#include "hton.hpp" +#include "inflate.hpp" +#include "png.hpp" +#include "ppm.hpp" + +using std::string; + +int main() { + std::ifstream in("/mnt/pov/images/fun.png", + std::ios::binary | std::ios::ate); + in.seekg(0); + + Png png = Png(in); + + png.inspect(std::cerr); + + std::ofstream out("/mnt/challenge/tmp/idat.deflate", + std::ios::binary | std::ios::trunc); + std::vector id = png.image_data(); + out.write((char*)(void*)id.data(), id.size()); + out.close(); + + Ppm hrl = Ppm(png); + + hrl.inspect(std::cerr); + + std::ofstream ppmout("/mnt/challenge/tmp/pngtest.ppm", + std::ios::binary | std::ios::trunc); + hrl.write(ppmout); + ppmout.close(); + + return 0; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/ppm.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/ppm.cpp new file mode 100755 index 0000000..8935824 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/ppm.cpp @@ -0,0 +1,103 @@ +#include +#include +#include +#include +#include +#include + +#include "assert.hpp" +#include 
"dimensions.hpp" +#include "hbytes.hpp" +#include "log.hpp" +#include "reader.hpp" +#include "rgba.hpp" + +#include "ppm.hpp" + +Ppm::Ppm(std::istream& r) { + std::string magic; + r >> magic; + + lll("ppm got magic %s want %s", magic.c_str(), "P6"); + assert("P6" == magic); + + r >> _width >> _height; + r >> _maxval; + + r.get(); + + assert(255 >= _maxval); + + _clamp = (byte)((uint32_t)256 / (_maxval + 1)); + + lll("w %d h %d m %d c %d", _width, _height, _maxval, _clamp); + + load_pixels(r); +} + +Ppm::Ppm(const Image& i) { + _width = i.width(); + _height = i.height(); + _pixels = i.pixels(); + _maxval = 255; + _clamp = 1; +} + +uint32_t Ppm::width() const { + return _width; +} + +uint32_t Ppm::height() const { + return _height; +} + +const std::vector& Ppm::pixels() const { + return _pixels; +} + +Dimensions Ppm::dimensions() { + return Dimensions{_width, _height}; +} + +void Ppm::inspect(std::ostream& w) { + w << "PPM" << std::endl; + + w << "\tdimensions(" << dimensions() << ") maxval(" << + _maxval << ") pixel_count(" << _pixels.size() << ")" << std::endl; +} + +void Ppm::write(std::ostream& w) { + w << "P6" << std::endl << _width << " " << _height << std::endl + << _maxval << std::endl; + + w.flush(); + + for (Rgba px : _pixels) { + w << px.r << px.g << px.b; + } + w << std::flush; +} + +void Ppm::load_pixels(std::istream& r) { + _pixels.reserve(_width * _height); + lll("px %d", _pixels.size()); + for (uint32_t y = 0; y < _height; y++) { + for (uint32_t x = 0; x < _width; x++) { + _pixels.push_back(read_pixel(r)); + } + } +} + +Rgba Ppm::read_pixel(std::istream& r) { + std::array px; + r.read((char*)(void*) px.data(), px.size()); + return Rgba{declamp(px)}; +} + +std::array Ppm::declamp(std::array b) { + return {declamp(b[0]), declamp(b[1]), declamp(b[2])}; +} + +byte Ppm::declamp(byte b) { + return (byte)(b * _clamp); +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/ppm.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/ppm.hpp new file mode 100755 index 0000000..2b2f034 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/ppm.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include +#include + +#include "dimensions.hpp" +#include "hbytes.hpp" +#include "image.hpp" +#include "reader.hpp" +#include "rgba.hpp" + +class Ppm : public Image /* limited */ { +public: + Ppm(std::istream& r); + Ppm(const Image& i); + + ~Ppm() override {}; + + void inspect(std::ostream& w) override; + void write(std::ostream& w) override; + + Dimensions dimensions(); + + uint32_t width() const override; + uint32_t height() const override; + const std::vector& pixels() const override; + + +private: + uint32_t _width; + uint32_t _height; + uint16_t _maxval; + + byte _clamp; + + std::vector _pixels; + + void load_pixels(std::istream& r); + + std::array declamp(std::array); + byte declamp(byte i); + + Rgba read_pixel(std::istream& r); +}; diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/reader.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/reader.cpp new file mode 100755 index 0000000..2cf0975 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/reader.cpp @@ -0,0 +1,30 @@ +#include +#include + +#include "hbytes.hpp" + +#include "reader.hpp" + +template +void Reader::read_ba(std::array& buf) { + char* cb = (char*)(void*) buf.data(); + reader.read(cb, count); +} + +uint32_t Reader::read_l() { + uint32_t buf; + reader.read((char*)(void*) &buf, sizeof(buf)); + return ntohl(buf); +} + +std::vector 
Reader::read_bv(uint32_t length) { + std::vector buf(length); + reader.read((char*)(void*) buf.data(), length); + + return buf; +} + +bool Reader::eof() { + reader.peek(); + return reader.eof(); +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/reader.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/reader.hpp new file mode 100755 index 0000000..ba888f7 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/reader.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include +#include + +#include "hbytes.hpp" + +class Reader { +public: + Reader(std::istream& r) : reader(r) {}; + + template + void read_ba(std::array& buf); + + uint32_t read_l(); + std::vector read_bv(uint32_t length); + + bool eof(); + + template + friend Reader operator>>(const Reader& lhs, const T rhs); + +private: + std::istream& reader; +}; + +template +Reader operator>>(const Reader& lhs, const T rhs) { + lhs.reader >> rhs; + return lhs; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/rgba.cpp b/tests/generic_tests/targets/hamlin/challenge/src/src/rgba.cpp new file mode 100755 index 0000000..3fb33be --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/rgba.cpp @@ -0,0 +1,18 @@ +#include + +#include "hbytes.hpp" + +#include "rgba.hpp" + +std::array Rgba::to_pxarray() { + return {r, g, b, a}; +} + +bool operator==(const Rgba& l, const Rgba& r) { + if (l.r != r.r) return false; + if (l.g != r.g) return false; + if (l.b != r.b) return false; + if (l.a != r.a) return false; + + return true; +} diff --git a/tests/generic_tests/targets/hamlin/challenge/src/src/rgba.hpp b/tests/generic_tests/targets/hamlin/challenge/src/src/rgba.hpp new file mode 100755 index 0000000..8ca4ac2 --- /dev/null +++ b/tests/generic_tests/targets/hamlin/challenge/src/src/rgba.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "hbytes.hpp" + +class Rgba { +public: + Rgba(byte red, byte green, byte blue) : r(red), g(green), b(blue) {}; + Rgba(byte red, byte green, byte blue, byte alpha) : + r(red), g(green), b(blue), a(alpha) {}; + + Rgba(std::array px) : r(px[0]), g(px[1]), b(px[2]) {}; + Rgba(std::array px) : r(px[0]), g(px[1]), b(px[2]), a(px[3]) {}; + + std::array to_pxarray(); + + byte r; + byte g; + byte b; + byte a = 255; +}; + +bool operator==(const Rgba& l, const Rgba& r); diff --git a/tests/DISABLED_test_core.py b/tests/test_core.py similarity index 68% rename from tests/DISABLED_test_core.py rename to tests/test_core.py index 6f0f605..d28148d 100644 --- a/tests/DISABLED_test_core.py +++ b/tests/test_core.py @@ -1,42 +1,47 @@ +import logging import os -import sys +import shutil import subprocess +import sys +import tempfile import textwrap import unittest -import tempfile from pathlib import Path -import shutil +from subprocess import CalledProcessError, run from tempfile import NamedTemporaryFile -from subprocess import run, CalledProcessError from typing import Optional import git -import logging - -from patchery.report import Report - logging.basicConfig(level=logging.DEBUG) +from common import GENERIC_TEST_DIR, PATCHES, TARGETS + import patchery -from patchery import Patcher -from patchery.verifier.verification_passes import CompileVerificationPass +from patchery.patcher import Patcher from patchery.kumushi.code_parsing import CodeParser from patchery.utils import WorkDirContext - -from common import TARGETS, PATCHES, GENERIC_TEST_DIR +from patchery.verifier.verification_passes import CompileVerificationPass +from 
patchery.data.patched_function import PatchedFunction +from patchery.data.program import Program +from patchery.data.poi import PoI +from patchery.data.patch import Patch +from patchery.data.program_input import ProgramInput, ProgramInputType # # Testing Utils # FAKE_GIT_REPOS = [ - TARGETS / "adams", + # TARGETS / "adams", TARGETS / "hamlin/challenge/src", ] -GIT_REPOS = [TARGETS / "jenkins/src/plugins/pipeline-util-plugin", TARGETS / "kernel/src"] +GIT_REPOS = [ + TARGETS / "jenkins/src/plugins/pipeline-util-plugin", + TARGETS / "kernel/src", +] def apply_patch_text(repo_path, target_file, patch_file_path) -> str: @@ -60,14 +65,18 @@ def apply_patch_text(repo_path, target_file, patch_file_path) -> str: return patched_code -def patch_func_from_patch_file(repo_path, target_file, func_name, patch_file_path, lang="C") -> list[PatchedFunction]: +def patch_func_from_patch_file( + repo_path, target_file, func_name, patch_file_path, lang="C" +) -> list[PatchedFunction]: """ Applies a patch to a file and extracts the target function. """ new_file_text = apply_patch_text(repo_path, target_file, patch_file_path) parser = CodeParser.from_code_string(new_file_text, func_name, lang=lang) new_func_code = parser.func_code(func_name) - func_patch = PatchedFunction(function_name=func_name, file=target_file, new_code=new_func_code) + func_patch = PatchedFunction( + function_name=func_name, file=target_file, new_code=new_func_code + ) return [func_patch] @@ -87,6 +96,7 @@ def teardown_testcase(): if git_dir.exists(): repo = git.Repo(str(repo_dir)) repo.git.reset("--hard") + repo.git.clean("-fd") shutil.rmtree(git_dir) for repo_dir in GIT_REPOS: @@ -104,45 +114,53 @@ def reset_repo_path(repo_dir: Path): # -class SimpleExecutor(Executor): - def __init__(self, run_script_path: Path, **kwargs): - self._runner_path: Path = Path(run_script_path).resolve().absolute() - super().__init__(**kwargs) +# class SimpleExecutor(): +# def __init__(self, run_script_path: Path, **kwargs): +# self._runner_path: Path = Path(run_script_path).resolve().absolute() +# super().__init__(**kwargs) - def generates_alerts(self, prog_input: ProgramInput, *args) -> bool: - with NamedTemporaryFile(delete=False) as input_file: - input_file.write(prog_input.data) - input_file.close() +# def generates_alerts(self, prog_input: ProgramInput, *args) -> bool: +# with NamedTemporaryFile(delete=False) as input_file: +# input_file.write(prog_input.data) +# input_file.close() - with WorkDirContext(self._runner_path.parent): - try: - proc = run(["./run.sh", "run", input_file.name], capture_output=True, check=True) - crash = False - except CalledProcessError as e: - crash = True +# with WorkDirContext(self._runner_path.parent): +# try: +# proc = run( +# ["./run.sh", "run", input_file.name], +# capture_output=True, +# check=True, +# ) +# crash = False +# except CalledProcessError as e: +# crash = True - return crash +# return crash - def check_functionality(self) -> ProgramExitType: - return ProgramExitType.NORMAL +# def check_functionality(self) -> ProgramExitType: +# return ProgramExitType.NORMAL class SimpleProgram(Program): def __init__(self, run_script_path: Path, *args, **kwargs): super().__init__(*args, **kwargs) - self.executor = SimpleExecutor(run_script_path) + # self.executor = SimpleExecutor(run_script_path) self._runner_path = Path(run_script_path).resolve().absolute() - def _compile_core(self, patch_path: Optional[Path] = None): + def compile(self, patch: Patch): + patch_path = patch.file_path if patch_path is not None: patch_path = 
Path(patch_path).absolute() + with WorkDirContext(self._runner_path.parent): compile_cmd = f"./run.sh build " if patch_path is not None: compile_cmd += f"{patch_path}" failed = False try: - proc = subprocess.run(compile_cmd.split(), capture_output=True, text=True) + proc = subprocess.run( + compile_cmd.split(), capture_output=True, text=True + ) except Exception as e: print(f"Compilation failed: {e}") failed = True @@ -173,6 +191,7 @@ def test_cli(self): version = output.stdout.decode().strip() assert version == patchery.__version__ + @unittest.skip("Skipping patch diffing test") def test_patch_diffing(self): # This test verifies that after an agent has generated a full function patch for a targeted function in a file, # that we can create a valid AIxCC patch from it, which is a Git diff. @@ -186,7 +205,11 @@ def test_patch_diffing(self): # an AI agent, but for this we are just using the source code and a pre-computed patch. perfect_patch = Patch( patch_func_from_patch_file( - source_root, target_file, "handle_AUTH", PATCHES / "adams_good.patch", lang=prog_info.lang + source_root, + target_file, + "handle_AUTH", + PATCHES / "adams_good.patch", + lang=prog_info.language, ), reasoning="Perfect patch.", ) @@ -202,36 +225,44 @@ def test_patch_diffing(self): temp_patch.write(generated_patch_diff.encode()) temp_patch.seek(0) - proc = subprocess.run(["git", "-C", str(source_root), "apply", temp_patch.name], cwd=GENERIC_TEST_DIR) + proc = subprocess.run( + ["git", "-C", str(source_root), "apply", temp_patch.name], + cwd=GENERIC_TEST_DIR, + ) assert proc.returncode == 0 def test_valid_patch_compile(self): source_root = TARGETS / "hamlin/challenge/src" target_file = source_root / "src/deflate/array_history.cpp" - poi = PoI(target_file, "ArrayHistory::copy", 26) prog_info = SimpleProgram( TARGETS / "hamlin/challenge/run.sh", source_root=source_root, - lang="C++", + language="C++", ) # load a pre-computed perfect patch perfect_patch = Patch( patch_func_from_patch_file( - source_root, target_file, "ArrayHistory::copy", PATCHES / "hamlin_good.patch", lang=prog_info.lang + source_root, + target_file, + "ArrayHistory::copy", + PATCHES / "hamlin_good.patch", + lang=prog_info.language, ), reasoning="Perfect patch.", ) - # directly call the compile checker in verified comp_pass = CompileVerificationPass(prog_info, perfect_patch) comp_pass.verify() assert comp_pass.verified is True, comp_pass.reasoning + @unittest.skip("Skipping alert generation test") def test_alert_generation(self): hamlin_chall = TARGETS / "hamlin/challenge" source_root = TARGETS / "hamlin/challenge/src" - prog_info = SimpleProgram(hamlin_chall / "run.sh", source_root=source_root, lang="C++") + prog_info = SimpleProgram( + hamlin_chall / "run.sh", source_root=source_root, language="C++" + ) # create a binary for execution compiled, _ = prog_info.compile() if not compiled: @@ -242,9 +273,17 @@ def test_alert_generation(self): alerting_input = f.read() benign_input = b"benign" - assert prog_info.triggers_alert(ProgramInput(alerting_input, ProgramInputType.FILE)) is True - assert prog_info.triggers_alert(ProgramInput(benign_input, ProgramInputType.FILE)) is False - + assert ( + prog_info.triggers_alert( + ProgramInput(alerting_input, ProgramInputType.FILE) + ) + is True + ) + assert ( + prog_info.triggers_alert(ProgramInput(benign_input, ProgramInputType.FILE)) + is False + ) + @unittest.skip("Skipping end to end hamlin test") def test_end_to_end_hamlin(self): source_root = TARGETS / "hamlin/challenge/src" poi = PoI( @@ -256,7 +295,9 @@ def 
test_end_to_end_hamlin(self): prog_info = SimpleProgram( TARGETS / "hamlin/challenge/run.sh", source_root=source_root, - alerting_inputs=AICCProgram.load_inputs_from_dir(TARGETS / "hamlin/alerting_inputs"), + alerting_inputs=AICCProgram.load_inputs_from_dir( + TARGETS / "hamlin/alerting_inputs" + ), lang="C++", ) @@ -268,19 +309,31 @@ def test_end_to_end_hamlin(self): hamlin_bin = TARGETS / "hamlin/challenge/hamlin.bin" with open(TARGETS / "hamlin/alerting_inputs/crash_input", "rb") as f: crashing_input_data = f.read() - proc = subprocess.run([hamlin_bin], capture_output=True, input=crashing_input_data, env=env, text=False) + proc = subprocess.run( + [hamlin_bin], + capture_output=True, + input=crashing_input_data, + env=env, + text=False, + ) assert b"ERROR: AddressSanitizer" in proc.stderr model = os.getenv("AIXCC_MODEL_LLM", default="oai-gpt-o1-preview") patcher = Patcher(prog_info, max_patches=1, max_attempts=5, model=model) crashing_input = ProgramInput(crashing_input_data, ProgramInputType.FILE) - #bug_info = BugInfo(crashing_input, poi_clusters=[PoICluster.from_pois([poi])], reports=[Report(poi.report)]) + # bug_info = BugInfo(crashing_input, poi_clusters=[PoICluster.from_pois([poi])], reports=[Report(poi.report)]) verified_patches = patcher.generate_verified_patches() assert bool(verified_patches) verified_patch = verified_patches[0] prog_info.compile(verified_patch) # validate the patch actually fixed the issue - proc = subprocess.run([hamlin_bin], capture_output=True, input=crashing_input.data, env=env, text=False) + proc = subprocess.run( + [hamlin_bin], + capture_output=True, + input=crashing_input.data, + env=env, + text=False, + ) assert b"ERROR: AddressSanitizer" not in proc.stderr @unittest.skip("Does not work while clang_indexer is disabled") @@ -300,4 +353,4 @@ def test_code_parsing(self): if __name__ == "__main__": - unittest.main(argv=sys.argv, buffer=True) \ No newline at end of file + unittest.main(argv=sys.argv, buffer=True)