Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions src/mygrad/nnet/layers/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,83 @@ def max_pool(
of ``x``.
"""
return Tensor._op(MaxPoolND, x, op_args=(pool, stride), constant=constant)


class MeanPoolND(Operation):
def __call__(self, x, pool, stride):
"""
Perform mean-pooling over the last N dimensions of a data batch.
"""
self.variables = (x,)
x = x.data

assert isinstance(pool, (tuple, list, np.ndarray)) and all(
isinstance(i, Integral) and i > 0 for i in pool
)
pool = np.asarray(pool, dtype=int)

stride = (
np.array([stride] * len(pool))
if isinstance(stride, Integral)
else np.asarray(stride, dtype=int)
)
assert len(stride) == len(pool) and all(
isinstance(s, Integral) and s >= 1 for s in stride
)

self.pool = pool
self.stride = stride

num_pool = len(pool)
num_no_pool = x.ndim - num_pool
w_shape = pool

x_shape = np.array(x.shape[num_no_pool:])
out_shape = (x_shape - w_shape) / stride + 1

if not all(i.is_integer() and i > 0 for i in out_shape):
msg = "Stride and kernel dimensions are incompatible:\n"
msg += f"Input dimensions: {tuple(x_shape)}\n"
msg += f"Stride dimensions: {tuple(stride)}\n"
msg += f"Pooling dimensions: {tuple(w_shape)}\n"
raise ValueError(msg)
pool_axes = tuple(-(i + 1) for i in range(num_pool))

sl = sliding_window_view(x, self.pool, self.stride)
meaned = np.mean(sl, axis=pool_axes)

# reorder axes to move (N0, ...) to the front again
axes = tuple(range(meaned.ndim))

out = meaned.transpose(axes[-num_no_pool:] + axes[:-num_no_pool])
return out if out.flags["C_CONTIGUOUS"] else np.ascontiguousarray(out)

def backward_var(self, grad, index, **kwargs):
"""
"""
var = self.variables[index]
x = var.data
num_pool = len(self.pool)

sl = sliding_window_view(x, self.pool, self.stride)
window_size = int(np.prod(self.pool))

dx = np.zeros_like(x)

it = np.nditer(grad, flags=["multi_index"])
while not it.finished:
idx = it.multi_index

slices = tuple(
slice(i * s, i * s + p)
for i, s, p in zip(idx[-num_pool:], self.stride, self.pool)
)
dx[(...,) + slices] += it[0] / window_size
it.iternext()

return dx

def mean_pool(x, pool, stride, *, constant: Optional[bool] = None) -> Tensor:
"""
"""
return Tensor._op(MeanPoolND, x, op_args=(pool, stride), constant=constant)
Loading