Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
d0b536c
python AST to backend flattened AST
Sh0g0-1758 Oct 12, 2025
636f14d
fix and test flattened AST -> AST functionality
Sh0g0-1758 Oct 12, 2025
7b49965
make it work for broadcast
Sh0g0-1758 Oct 12, 2025
3b3d756
test ternary op and other binary ops
Sh0g0-1758 Oct 13, 2025
2ce7700
ast flattening for matmuls and dot
Sh0g0-1758 Oct 13, 2025
4016abf
fixes for matmul
Sh0g0-1758 Oct 13, 2025
1180b11
add support for scalars
Sh0g0-1758 Oct 13, 2025
00ce700
expand matmuls to operator on rvalues as well
Sh0g0-1758 Oct 13, 2025
1b032f1
test matmuls
Sh0g0-1758 Oct 13, 2025
26d62a5
ok
anant37289 Oct 13, 2025
34323ef
just push it
anant37289 Oct 13, 2025
1247038
Merge pull request #6 from UIUC-PPL/unary_ops
Sh0g0-1758 Oct 13, 2025
bfb8cb4
cleanup
Sh0g0-1758 Oct 13, 2025
13bf3ef
add gitignore
Sh0g0-1758 Oct 13, 2025
0af35ee
nit
Sh0g0-1758 Oct 13, 2025
90eb20e
add support for copy operation
Sh0g0-1758 Oct 14, 2025
1c57953
add axpy support for charmnumerics
anant37289 Oct 14, 2025
19d9c6d
handle scalar computation separately
Sh0g0-1758 Oct 14, 2025
1512955
fix tensor handling
Sh0g0-1758 Oct 14, 2025
c30dbdc
functionally correct conjugate gradient
Sh0g0-1758 Oct 14, 2025
4c5f45c
feat: fusing multiple ASTs using decorators (#8)
anant37289 Oct 15, 2025
62eb3ef
dealloc (#9)
Sh0g0-1758 Oct 17, 2025
f068b6c
nit
Sh0g0-1758 Oct 17, 2025
dd4d3ae
Revert "nit"
Sh0g0-1758 Oct 17, 2025
d6bf0ae
Revert "dealloc (#9)"
Sh0g0-1758 Oct 17, 2025
9820a50
Reapply "dealloc (#9)"
Sh0g0-1758 Oct 18, 2025
03e7db7
Reapply "nit"
Sh0g0-1758 Oct 18, 2025
f8460fc
Revert "Reapply "nit""
Sh0g0-1758 Oct 18, 2025
4c5cdb0
debug deletion + extra temporary saves
anant37289 Oct 18, 2025
160feb7
deferred deletions
Sh0g0-1758 Oct 18, 2025
1d24e4c
fix magic deletions
Sh0g0-1758 Oct 18, 2025
d5bd4d1
nit
Sh0g0-1758 Oct 18, 2025
25a9388
fix deletion of matmuls, copy and final cleanup
Sh0g0-1758 Oct 19, 2025
13e325f
sync + kokkos kernels for cpu
anant37289 Oct 29, 2025
884926a
update cg benchmark to remove numerically unstable init
anant37289 Oct 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Build and packaging output directories
build
charmnumeric.egg-info
# Headers following the *.decl.h / *.def.h naming scheme
# (presumably generated during the build -- confirm before relying on this)
*.decl.h
*.def.h
# Distribution artifacts
dist
# Editor-specific settings
.vscode
316 changes: 231 additions & 85 deletions charmnumeric/array.py

Large diffs are not rendered by default.

101 changes: 71 additions & 30 deletions charmnumeric/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


max_depth = 10

multiLineFuse = False

def set_max_depth(d):
global max_depth
Expand All @@ -18,55 +18,96 @@ def get_max_depth():
global max_depth
return max_depth

def charm_fuse(func):
    """Decorator that runs ``func`` with multi-line AST fusion enabled.

    While the wrapped function executes, the module-level ``multiLineFuse``
    flag is set to ``True`` and the maximum AST depth is raised to infinity.
    Both settings are put back to their previous values when the call
    finishes, including when ``func`` raises. Restoring the previous flag
    (rather than forcing ``False``) keeps an outer fused call in fused mode
    after a nested fused call returns.

    Args:
        func: The callable to execute in fused mode.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result unchanged.
    """
    # Imported locally so this block does not depend on the file's import list.
    from functools import wraps

    @wraps(func)
    def compile_wrapper(*args, **kwargs):
        global multiLineFuse
        orig_max_depth = get_max_depth()
        orig_fuse = multiLineFuse
        multiLineFuse = True
        set_max_depth(float('inf'))
        try:
            return func(*args, **kwargs)
        finally:
            # Always undo the global changes, even if func raised; otherwise
            # every later AST node would silently be built in fused mode.
            multiLineFuse = orig_fuse
            set_max_depth(orig_max_depth)
    return compile_wrapper

class ASTNode(object):
def __init__(self, name, opcode, operands, arg=0.0):
from charmtiles.array import ndarray
def __init__(self, name, opcode, operands, args=[]):
from charmnumeric.array import ndarray
global multiLineFuse
# contains opcode, operands
# operands are ndarrays
self.name = name
self.opcode = opcode
self.operands = operands
self.depth = 0
self.arg = arg
self.args = args
self.multiLineFuse = multiLineFuse
if self.opcode != 0:
for op in self.operands:
if isinstance(op, ndarray):
self.depth = max(self.depth, 1 + op.command_buffer.depth)

def get_command(self, validated_arrays, save=True):
#################################################################################################################################################################
# Marker determines whether we are dealing with a tensor, a scalar or an arithmetic type #
# Marker = 0 : arithmetic type #
# Marker = 1 : scalar type #
# Marker = 2 : tensor type #
# Encoding = | Marker | dim | shape | opcode | save_op | ID | multiLineFuse | NumArgs | Args | NumOperands | OperandEncodingSize | RecursiveOperandEncoding | #
# | 8 | 8 | 64 | 32 | 1 | 64 | 1 | 32 | 64 | 8 | 32 | ........................ | #
# NB: If opcode is 0, the encoding is limited to ID #
# Encoding = | Marker | shape | val | #
# | 8 | 64 | 64 | #
# NB: Latter encoding for double constants #
#################################################################################################################################################################
def get_command(self, ndim, shape, save=True, is_scalar=False, hasExceededMaxAstDepth=False):
from charmnumeric.array import ndarray

# Ndims and Shape setup
if is_scalar:
cmd = to_bytes(1, 'B')
else:
cmd = to_bytes(2, 'B')
cmd += to_bytes(ndim, 'B')
for _shape in shape:
cmd += to_bytes(_shape, 'L')

if self.opcode == 0:
cmd = to_bytes(self.opcode, 'L')
cmd += to_bytes(False, '?')
cmd += to_bytes(self.operands[0].name, 'L')
cmd += to_bytes(0, 'I') + to_bytes(False, '?') + to_bytes(self.operands[0].name, 'L')
return cmd
cmd = to_bytes(self.opcode, 'L') + to_bytes(self.name, 'L')
cmd += to_bytes(save, '?') + to_bytes(len(self.operands), 'B')

cmd += to_bytes(self.opcode, 'I') + to_bytes(save, '?') + to_bytes(self.name, 'L') + to_bytes(self.multiLineFuse, '?')
cmd += to_bytes(len(self.args), 'I')
for arg in self.args:
cmd += to_bytes(arg, 'd')

cmd += to_bytes(len(self.operands), 'B')
for op in self.operands:
# an operand can also be a double
if isinstance(op, ndarray):
if op.name in validated_arrays:
opcmd = to_bytes(0, 'L')
opcmd += to_bytes(False, '?')
opcmd += to_bytes(op.name, 'L')
cmd += to_bytes(len(opcmd), 'I')
cmd += opcmd
if op.valid:
if op.is_scalar:
opcmd = to_bytes(1, 'B')
else:
opcmd = to_bytes(2, 'B')
opcmd += to_bytes(op.ndim, 'B')
for _shape in op.shape:
opcmd += to_bytes(_shape, 'L')
opcmd += to_bytes(0, 'I') + to_bytes(False, '?') + to_bytes(op.name, 'L')
else:
save_op = True if c_long.from_address(id(op)).value - 2 > 0 else False
opcmd = op.command_buffer.get_command(validated_arrays,
save=save_op)
if not op.valid and save_op:
validated_arrays[op.name] = op
cmd += to_bytes(len(opcmd), 'I')
cmd += opcmd
### this will only be true when AST is being flushed because of exceeding max depth and ensures that unnecessary temporaries are not saved
if hasExceededMaxAstDepth:
save_op = True if c_long.from_address(id(op)).value - 4 > 0 else False
else:
save_op = True if c_long.from_address(id(op)).value - 2 > 0 else False
opcmd = op.command_buffer.get_command(op.ndim, op.shape, save=save_op, is_scalar=op.is_scalar)
if save_op or (op.command_buffer.opcode == OPCODES.get('@')) or (op.command_buffer.opcode == OPCODES.get('copy')):
op.validate()
elif isinstance(op, float) or isinstance(op, int):
opcmd = to_bytes(0, 'L')
opcmd += to_bytes(True, '?')
opcmd += to_bytes(op, 'd')
cmd += to_bytes(len(opcmd), 'I')
cmd += opcmd
cmd += to_bytes(self.arg, 'd')
opcmd = to_bytes(0, 'B')
for _shape in shape:
opcmd += to_bytes(_shape, 'L')
opcmd += to_bytes(float(op), 'd')
cmd += to_bytes(len(opcmd), 'I')
cmd += opcmd
return cmd

def plot_graph(self, validated_arrays={}, G=None, node_map={},
Expand Down
77 changes: 66 additions & 11 deletions charmnumeric/ccs.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,63 @@
import struct
import atexit
from pyccs import Server
from charmnumeric import array
import gc

debug = False
server = None
client_id = 0
next_name = 0
epoch = 0

OPCODES = {'+': 1, '-': 2, '*': 3 ,'/': 4, '@': 5, 'copy': 6, 'axpy': 7,
'axpy_multiplier': 8, 'setitem': 9, 'pow': 10, '>': 11,
'<': 12, '>=': 13, '<=': 14, '==': 15, '!=': 16, '&': 17,
'|': 18, '!':19, 'where':20, 'log': 21, 'exp': 22, 'abs': 23, 'any':24, 'all':25}
# base_op: arithmetic, matmul, copy, comparison, logical and selection
# operators. Codes 7 and 8 are not assigned in this table.
_BASE_OPCODES = {
    '+': 1, '-': 2, '*': 3, '/': 4, '@': 5, 'copy': 6,
    '>': 9, '<': 10, '>=': 11, '<=': 12, '==': 13, '!=': 14,
    '&': 15, '|': 16, '!': 17, 'where': 18,
}

# custom_unary_op: codes 41-52.
_CUSTOM_UNARY_OPCODES = {
    'exp': 41, 'log': 42, 'abs': 43, 'negate': 44,
    'square': 45, 'sqrt': 46, 'reciprocal': 47,
    'sin': 48, 'cos': 49, 'relu': 50,
    'scale': 51, 'add_constant': 52,
}

# custom_binary_op: codes 71-84.
_CUSTOM_BINARY_OPCODES = {
    'add': 71, 'subtract': 72, 'multiply': 73, 'divide': 74,
    'power': 75, 'modulo': 76, 'max': 77, 'min': 78,
    'greater_than': 79, 'less_than': 80, 'equal': 81,
    'atan2': 82, 'weighted_average': 83, 'axpy': 84,
}

# Operation name -> numeric opcode. The three groups are merged in the order
# base, unary, binary, so iteration order matches the original literal.
OPCODES = {
    **_BASE_OPCODES,
    **_CUSTOM_UNARY_OPCODES,
    **_CUSTOM_BINARY_OPCODES,
}

INV_OPCODES = {v: k for k, v in OPCODES.items()}

Expand Down Expand Up @@ -69,14 +114,25 @@ def connect(server_ip, server_port):
atexit.register(disconnect)

def disconnect():
from charmnumeric.array import deletion_buffer, deletion_buffer_size
global client_id, deletion_buffer, deletion_buffer_size
if deletion_buffer_size > 0:
cmd = to_bytes(len(deletion_buffer), 'I') + deletion_buffer
# cleanup the remaining ndarrays
from charmnumeric.array import ndarray
deleted_id = []
for obj in gc.get_objects():
if isinstance(obj, ndarray):
if not obj.name in deleted_id:
print(obj.name)
deleted_id.append(obj.name)
obj.__del__()
from charmnumeric.array import deletion_buffer, deletion_buffer_size, deferred_deletion_buffer_size, deferred_deletion_buffer
if (deletion_buffer_size > 0) or (deferred_deletion_buffer_size > 0):
cmd = to_bytes(deletion_buffer_size, 'I') + deletion_buffer + to_bytes(deferred_deletion_buffer_size, 'I') + deferred_deletion_buffer
cmd = to_bytes(get_epoch(), 'i') + to_bytes(len(cmd), 'I') + cmd
send_command_async(Handlers.delete_handler, cmd)
deletion_buffer = b''
deletion_buffer_size = b''
deletion_buffer_size = 0
deferred_deletion_buffer = b''
deferred_deletion_buffer_size = 0
global client_id
cmd = to_bytes(client_id, 'B')
cmd = to_bytes(get_epoch(), 'i') + to_bytes(len(cmd), 'I') + cmd
send_command_async(Handlers.disconnection_handler, cmd)
Expand All @@ -95,7 +151,6 @@ def get_creation_command(arr, name, shape, buf=None):
cmd += buf
elif arr.init_value is not None:
cmd += to_bytes(arr.init_value, 'd')
print(cmd)
cmd = to_bytes(get_epoch(), 'i') + to_bytes(len(cmd), 'I') + cmd
return cmd

Expand Down
14 changes: 3 additions & 11 deletions charmnumeric/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,8 @@
from charmnumeric.ast import ASTNode


def axpy(a, x, y, multiplier=None):
operands = [a, x, y]
if multiplier is not None:
operands.append(multiplier)
operation = 'axpy_multiplier'
else:
operation = 'axpy'
def axpy(a, x, y):
res = get_name()
cmd_buffer = ASTNode(res, OPCODES.get(operation), operands)
return create_ndarray(x.ndim, x.dtype,
cmd_buffer = ASTNode(res, OPCODES.get('axpy'), [x, y], args=[a])
return create_ndarray(x.ndim, x.dtype, x.shape,
name=res, command_buffer=cmd_buffer)


13 changes: 13 additions & 0 deletions config.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Machine-specific locations: edit these for the local installation.
set(CHARM_DIR "/home/anant/winter2024/lbp/study/charm/netlrts-linux-x86_64/")
set(BASE_DIR "/home/anant/sem7/LibCharmtyles")
set(EIGEN_DIR "/usr/include/eigen3")
# Placeholder: must be replaced with the real CUDA location before it is used.
set(CUDA_DIR "/path/to/CUDA/directory")
set(KOKKOS_DIR "${BASE_DIR}/kokkos/install")
set(KOKKOS_KERNELS_DIR "${BASE_DIR}/kokkos-kernels/install")

# Compiler driver and option sets derived from the locations above.
set(CHARMC "${CHARM_DIR}/bin/charmc")
set(CPU_OPTS "-c++-option -std=c++20 -O3 -march=native -DNDEBUG")
set(GPU_OPTS "-std=c++20 -O3 -march=native -DNDEBUG")
# The Kokkos Kernels library path is taken from KOKKOS_KERNELS_DIR rather than
# a hard-coded user directory, so one setting controls where it is found.
set(GPU_LINK_OPTS -O3 -language charm++ -L${KOKKOS_DIR}/lib64 -L${KOKKOS_KERNELS_DIR}/lib64 -lkokkoscore -lkokkoscontainers -lkokkoskernels -L${CUDA_DIR} -lcuda -lcudart -lcusparse -lcublas)
set(LD_OPTS "")
set(INCS "-I${BASE_DIR}")
51 changes: 51 additions & 0 deletions examples/bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from charmnumeric.array import connect, ndarray
from charmnumeric.ast import set_max_depth
from charmnumeric.ccs import enable_debug
import charmnumeric.linalg as lg
import numpy as np
import time

set_max_depth(10)

# Benchmark driver: times 100 repetitions of three array expressions, first
# on a "vector" set of arrays and then on a "matrix" set, printing the elapsed
# wall-clock time for each section. Returns nothing.
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether the .get() calls below run inside each for-loop or once after it --
# confirm against the checked-in file.
def f():
# Vector operands -- presumably ndarray(ndim, shape, dtype, init_value=...);
# verify against the ndarray constructor.
b = ndarray(1, 10, np.float64, init_value=10)
v = ndarray(1, 10, np.float64, init_value=20)
c = ndarray(1, 10, np.float64, init_value=20)
# The same three expressions are evaluated once before the timer starts
# (presumably a warm-up pass).
v1 = (b + v) - c * 3
v2 = b.scale(3) + c.add_constant(10)
v3 = (b + c) @ v
v1.get()
v2.get()
v3.get()
# Timed vector section.
start = time.time()
for i in range(100):
v1 = (b + v) - c * 3
v2 = b.scale(3) + c.add_constant(10)
v3 = (b + c) @ v
v1.get()
v2.get()
v3.get()

end = time.time()

print("VECTOR BENCH ", end-start)


# Timed matrix section.
# NOTE(review): unlike the vector section, the timer starts before the arrays
# are created and there is no untimed first pass, so the two timings are not
# measured the same way.
start = time.time()
b = ndarray(2, [10,10], np.float64, init_value=10)
v = ndarray(2, [10,10], np.float64, init_value=20)
c = ndarray(2, [10,10], np.float64, init_value=20)
for i in range(100):
v1 = (b + v) - c * 3
# This section uses b.exp() where the vector section uses b.scale(3).
v2 = b.exp() + c.add_constant(10)
v3 = (b + c) @ v
v1.get()
v2.get()
v3.get()

end = time.time()
# NOTE(review): "MARIX" looks like a typo for "MATRIX" in the printed label.
print("MARIX BENCH ", end-start)

if __name__ == "__main__":
    # Open the connection first, then run the benchmark.
    connect("172.17.0.1", 10000)
    s = f()
44 changes: 44 additions & 0 deletions examples/charm_fuse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from charmnumeric.array import connect, ndarray
from charmnumeric.ast import set_max_depth, charm_fuse
from charmnumeric.ccs import enable_debug
import charmnumeric.linalg as lg
import numpy as np
import time
set_max_depth(float('inf'))

@charm_fuse
def f():
    """Build and evaluate the sample expression under ``charm_fuse``.

    Two arrays are created, combined through a chain of unary operations and
    additions, and the final sum is fetched with ``get()``.

    Returns:
        Whatever ``get()`` returns for the final array.
    """
    negatives = ndarray(1, 1e2, np.float64, init_value=-20)
    positives = ndarray(1, 1e2, np.float64, init_value=10)

    # Each intermediate stays bound to its own local, exactly as before.
    lhs = negatives.abs().scale(2).scale(2).add_constant(29) + positives + 32
    rhs = positives.log(2).exp()
    total = lhs + rhs
    return total.get()

def g():
    """Build and evaluate the sample expression without any decorator.

    Two arrays are created, combined through a chain of unary operations and
    additions, and the final sum is fetched with ``get()``.

    Returns:
        Whatever ``get()`` returns for the final array.
    """
    negatives = ndarray(1, 1e2, np.float64, init_value=-20)
    positives = ndarray(1, 1e2, np.float64, init_value=10)

    # Each intermediate stays bound to its own local, exactly as before.
    lhs = negatives.abs().scale(2).scale(2).add_constant(29) + positives + 32
    rhs = positives.log(2).exp()
    total = lhs + rhs
    return total.get()

if __name__ == "__main__":
    connect("127.0.0.1", 10000)

    # Disabled single call ahead of the timed loop:
    # s = f()

    # Time 1000 calls of the decorated function; `s` keeps the last result.
    start = time.time()
    for _ in range(1000):
        s = f()
    end = time.time()
    print(s)
    print("Time taken(multi line fused): ", end - start)

    # Disabled counterpart that times the undecorated function:
    # k = g()
    # start = time.time()
    # for(i) in range(100):
    #     k = g()
    # end = time.time()
    # print(k)
    # print("Time taken(multi line unfused): ", end - start)
Loading