Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion python/peppi_py/frame.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from dataclasses import dataclass
from pyarrow.lib import Int8Array, Int16Array, Int32Array, Int64Array, UInt8Array, UInt16Array, UInt32Array, UInt64Array, FloatArray, DoubleArray
from pyarrow.lib import (
Int8Array, Int16Array, Int32Array, Int64Array,
UInt8Array, UInt16Array, UInt32Array, UInt64Array,
FloatArray, DoubleArray,
ListArray,
)
from .util import _repr

@dataclass(slots=True)
Expand Down Expand Up @@ -53,6 +58,7 @@ class Item:
id: UInt32Array
misc: tuple[UInt8Array, UInt8Array, UInt8Array, UInt8Array] | None = None
owner: Int8Array | None = None
instance_id: UInt16Array | None = None

@dataclass(slots=True)
class Post:
Expand Down Expand Up @@ -113,3 +119,4 @@ class Frame:
__repr__ = _repr
id: object
ports: tuple[PortData]
items: Item | None = None
90 changes: 76 additions & 14 deletions python/peppi_py/parse.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import sys, types, typing
import types, typing
import pyarrow
import dataclasses as dc
import functools
from inflection import underscore
from enum import Enum
from .frame import Data, Frame, PortData
from .frame import Data, Frame, PortData, Item

T = typing.TypeVar('T')

def _repr(x):
if isinstance(x, pyarrow.Array):
Expand All @@ -20,24 +23,30 @@ def _repr(x):
else:
return repr(x)

get_origin = functools.cache(typing.get_origin)
is_dataclass = functools.cache(dc.is_dataclass)
dc_fields = functools.cache(dc.fields)
get_args = functools.cache(typing.get_args)

@functools.cache
def unwrap_union(cls):
if typing.get_origin(cls) is types.UnionType:
return typing.get_args(cls)[0]
if get_origin(cls) is types.UnionType:
return get_args(cls)[0]
else:
return cls

def field_from_sa(cls, arr):
def field_from_sa(cls: type[T], arr: pyarrow.Array | None) -> T | pyarrow.Array | None:
if arr is None:
return None
cls = unwrap_union(cls)
if dc.is_dataclass(cls):
if is_dataclass(cls):
return dc_from_sa(cls, arr)
elif typing.get_origin(cls) is tuple:
elif get_origin(cls) is tuple:
return tuple_from_sa(cls, arr)
else:
return arr

def arr_field(arr, dc_field):
def arr_field(arr, dc_field: dc.Field):
try:
return arr.field(dc_field.name)
except KeyError:
Expand All @@ -46,13 +55,57 @@ def arr_field(arr, dc_field):
else:
return dc_field.default

def dc_from_sa(cls, arr):
return cls(*(field_from_sa(f.type, arr_field(arr, f)) for f in dc.fields(cls)))
def dc_from_sa(cls: type[T], arr: pyarrow.StructArray) -> T:
return cls(*(field_from_sa(f.type, arr_field(arr, f)) for f in dc_fields(cls)))

def tuple_from_sa(cls, arr):
return tuple((field_from_sa(t, arr.field(str(idx))) for (idx, t) in enumerate(typing.get_args(cls))))
def tuple_from_sa(cls: type[tuple], arr: pyarrow.Array) -> tuple:
return cls((field_from_sa(t, arr.field(str(idx))) for (idx, t) in enumerate(get_args(cls))))

def frames_from_sa(arrow_frames):
@functools.cache
def unwrap_optional(cls: type) -> type | None:
if get_origin(cls) is not types.UnionType:
return cls

args = get_args(cls)
assert len(args) == 2
assert args[1] is types.NoneType
return args[0]

# Generic recursion on dataclasses
def map_dc(cls: type[T], fn: typing.Callable, *xs: T) -> T:
unwrapped_cls = unwrap_optional(cls)
if unwrapped_cls is not None:
# Optional fields might be missing; if so, don't recurse.
for x in xs:
if x is None:
return None
cls = unwrapped_cls

if is_dataclass(cls):
return cls(*(
map_dc(f.type, fn, *(getattr(x, f.name) for x in xs))
for f in dc_fields(cls)
))
elif get_origin(cls) is tuple:
return cls(
map_dc(t, fn, *(x[idx] for x in xs))
for (idx, t) in enumerate(get_args(cls))
)
else:
return fn(*xs)

def dc_from_la(cls: type[T], la: pyarrow.ListArray) -> T:
"""Converts ListArray of Structs into dataclass of ListArrays."""
dc_sa = dc_from_sa(cls, la.values)
return map_dc(cls, lambda arr: pyarrow.ListArray.from_arrays(la.offsets, arr), dc_sa)


class RollbackMode(Enum):
ALL = 'all' # All frames in the replay.
FIRST = 'first' # Only the first frame, as seen by the player
LAST = 'last' # Only the finalized frames; the "true" frame sequence.

def frames_from_sa(arrow_frames) -> typing.Optional[Frame]:
if arrow_frames is None:
return None
ports = []
Expand All @@ -64,7 +117,16 @@ def frames_from_sa(arrow_frames):
try: follower = dc_from_sa(Data, port.field('follower'))
except KeyError: follower = None
ports.append(PortData(leader, follower))
return Frame(arrow_frames.field('id'), tuple(ports))

# Extract items if available
items = None
try:
items_array = arrow_frames.field('item')
items = dc_from_la(Item, items_array)
except KeyError:
pass

return Frame(arrow_frames.field('id'), tuple(ports), items)

def field_from_json(cls, json):
if json is None:
Expand Down
Binary file added tests/data/items.slp
Binary file not shown.
20 changes: 19 additions & 1 deletion tests/test_peppi_py.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from math import isclose
from collections import Counter
from pathlib import Path

import pytest

from peppi_py import read_slippi, read_peppi
from peppi_py.game import *

Expand Down Expand Up @@ -92,3 +95,18 @@ def test_basic_game():
assert p1.position.y[1000].as_py() == -18.6373291015625
assert p2.position.x[1000].as_py() == 42.195167541503906
assert p2.position.y[1000].as_py() == 9.287015914916992

def test_items_support():
# Replay with a Peach
game = read_slippi(Path(__file__).parent.joinpath('data/items.slp').as_posix())
assert game.frames is not None
assert game.frames.items is not None

item_types = Counter(game.frames.items.type.values.to_numpy())

# Peach turnip appears on 513 frames.
assert len(item_types) == 1
assert item_types[99] == 513

if __name__ == '__main__':
pytest.main([__file__]) # Uncomment to run with pytest