diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a8240ad --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +build +charmnumeric.egg-info +*.decl.h +*.def.h +dist +.vscode diff --git a/charmnumeric/array.py b/charmnumeric/array.py index 47162f7..b42edc8 100644 --- a/charmnumeric/array.py +++ b/charmnumeric/array.py @@ -11,11 +11,14 @@ deletion_buffer = b'' deletion_buffer_size = 0 +doDeferredDeletions = False +deferred_deletion_buffer = b'' +deferred_deletion_buffer_size = 0 -def create_ndarray(ndim, dtype, shape=None, name=None, command_buffer=None): +def create_ndarray(ndim, dtype, shape=None, name=None, command_buffer=None, is_scalar=False): z = ndarray(ndim, dtype=dtype, shape=shape, name=name, - command_buffer=command_buffer) + command_buffer=command_buffer, is_scalar=is_scalar) return z @@ -23,10 +26,22 @@ def from_numpy(nparr): return ndarray(nparr.ndim, dtype=nparr.dtype, shape=nparr.shape, nparr=nparr) +def isScalarResult(a, b): + return a.is_scalar and b.is_scalar + +def getDimShape(a, b): + if isinstance(b, float) or isinstance(b, int): + return [a.ndim, a.shape.copy()] + elif isinstance(a, float) or isinstance(a, int): + return [b.ndim, b.shape.copy()] + elif a.is_scalar: + return [b.ndim, b.shape.copy()] + else: + return [a.ndim, a.shape.copy()] class ndarray: def __init__(self, ndim, shape=None, dtype=np.float64, init_value=None, - nparr=None, name=None, command_buffer=None): + nparr=None, name=None, command_buffer=None, is_scalar=False): """ This is the wrapper class for AUM array objects. The argument 'name' should be None except when wrapping @@ -41,6 +56,7 @@ def __init__(self, ndim, shape=None, dtype=np.float64, init_value=None, self.itemsize = np.dtype(dtype).itemsize self.init_value = init_value self.command_buffer = command_buffer + self.is_scalar = is_scalar if isinstance(shape, np.ndarray) or isinstance(shape, list) or \ isinstance(shape, tuple): self.shape = np.asarray(shape, dtype=np.int32) @@ -72,39 +88,33 @@ def __init__(self, ndim, shape=None, dtype=np.float64, init_value=None, if is_debug(): print("Maximum AST depth exceeded for %i, " "flushing buffer" % self.name) - self._flush_command_buffer() + self._flush_command_buffer(hasExceededMaxAstDepth=True) def __del__(self): - global deletion_buffer, deletion_buffer_size - if self.valid: - deletion_buffer += to_bytes(self.name, 'L') - deletion_buffer_size += 1 + global doDeferredDeletions + if doDeferredDeletions: + global deferred_deletion_buffer, deferred_deletion_buffer_size + if self.valid: + deferred_deletion_buffer += to_bytes(self.name, 'L') + deferred_deletion_buffer_size += 1 + else: + global deletion_buffer, deletion_buffer_size + if self.valid: + deletion_buffer += to_bytes(self.name, 'L') + deletion_buffer_size += 1 def __len__(self): return self.shape[0] - #def __str__(self): - # print(self.get()) - - #def __repr__(self): - # #self._flush_command_buffer() - # # FIXME add repr - # pass - - def __setitem__(self, key, value): - if not isinstance(key, slice) or key.start != None or \ - key.stop != None or key.step != None: - raise ValueError("Can't set items or slices") - self.cmd_buffer = ASTNode(res, OPCODES.get('setitem'), [self, value]) - def __neg__(self): return self * -1 def __add__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('+'), [self, other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + 
name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) def __radd__(self, other): return self + other @@ -112,8 +122,10 @@ def __radd__(self, other): def __sub__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('-'), [self, other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) + def __rsub__(self, other): return -1 * (self - other) @@ -121,8 +133,10 @@ def __rsub__(self, other): def __lt__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('<'), [self, other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) + def __rlt__(self, other): return self >= other @@ -130,8 +144,10 @@ def __rlt__(self, other): def __gt__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('>'), [self, other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) + def __rgt__(self, other): return self <= other @@ -139,8 +155,10 @@ def __rgt__(self, other): def __le__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('<='), [self, other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) + def __rle__(self, other): return self > other @@ -148,8 +166,10 @@ def __rle__(self, other): def __ge__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('>='), [self, other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) + def __rge__(self, other): return self < other @@ -157,8 +177,10 @@ def __rge__(self, other): def __eq__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('=='), [self, other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) + def __req__(self, other): return self == other @@ -166,8 +188,10 @@ def __req__(self, other): def __ne__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('!='), [self, other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) + def __rne__(self, other): return self != other @@ -175,8 +199,10 @@ def __rne__(self, 
other): def __and__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('&'), [self, other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) + def __rand__(self, other): return self & other @@ -184,8 +210,10 @@ def __rand__(self, other): def __or__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('|'), [self, other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) + def __ror__(self, other): return self | other @@ -194,13 +222,15 @@ def __invert__(self): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('!'), [self]) return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + name=res, command_buffer=cmd_buffer, is_scalar=self.is_scalar) def __mul__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('*'), [self, other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) + def __rmul__(self, other): return self * other @@ -208,16 +238,21 @@ def __rmul__(self, other): def __truediv__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('/'), [self, other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) + def __rtruediv__(self, other): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('/'), [1., self/other]) - return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), - name=res, command_buffer=cmd_buffer) + ndim, shape = getDimShape(self, other) + return create_ndarray(ndim, self.dtype, shape=shape, + name=res, command_buffer=cmd_buffer, is_scalar=isScalarResult(self, other)) + def __matmul__(self, other): + is_scalar = False if self.ndim == 2 and other.ndim == 2: res_ndim = 2 shape = np.array([self.shape[0], other.shape[1]], dtype=np.int32) @@ -225,41 +260,34 @@ def __matmul__(self, other): res_ndim = 1 shape = np.array([self.shape[0]], dtype=np.int32) elif self.ndim == 1 and other.ndim == 1: - res_ndim = 0 + res_ndim = 1 shape = np.array([1], dtype=np.int32) + is_scalar = True else: raise RuntimeError("Dimension mismatch") res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('@'), [self, other]) return create_ndarray(res_ndim, self.dtype, shape=shape, - name=res, command_buffer=cmd_buffer) + name=res, command_buffer=cmd_buffer, is_scalar=is_scalar) - def _flush_command_buffer(self): + def _flush_command_buffer(self, hasExceededMaxAstDepth=False): # send the command to server # finally set command buffer to array name - global deletion_buffer, deletion_buffer_size + global deletion_buffer, deletion_buffer_size, deferred_deletion_buffer, deferred_deletion_buffer_size debug = is_debug() if debug: 
self.command_buffer.plot_graph() if self.valid: return - validated_arrays = {self.name : self} - cmd = self.command_buffer.get_command(validated_arrays) - reply_size = 0 - for name, arr in validated_arrays.items(): - reply_size += 8 + 8 * arr.ndim + cmd = self.command_buffer.get_command(self.ndim, self.shape, is_scalar=self.is_scalar, hasExceededMaxAstDepth=hasExceededMaxAstDepth) if not debug: - cmd = to_bytes(deletion_buffer_size, 'I') + deletion_buffer + cmd + cmd = to_bytes(deletion_buffer_size, 'I') + deletion_buffer + to_bytes(deferred_deletion_buffer_size, 'I') + deferred_deletion_buffer + cmd cmd = to_bytes(get_epoch(), 'i') + to_bytes(len(cmd), 'I') + cmd send_command_async(Handlers.operation_handler, cmd) deletion_buffer = b'' deletion_buffer_size = 0 - for i in range(len(validated_arrays)): - arr = validated_arrays[name] - arr.validate() - else: - for name, arr in validated_arrays.items(): - arr.validate() + deferred_deletion_buffer = b'' + deferred_deletion_buffer_size = 0 self.validate() def get(self): @@ -280,68 +308,186 @@ def evaluate(self): self._flush_command_buffer() def validate(self): + global doDeferredDeletions self.valid = True - self.command_buffer = ASTNode(self.name, 0, [self]) + doDeferredDeletions = True + self.command_buffer = ASTNode(self.name, 0, [weakref.proxy(self)]) + doDeferredDeletions = False def copy(self): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('copy'), [self]) - return create_ndarray(self.ndim, self.dtype,shape=self.shape.copy(), + return create_ndarray(self.ndim, self.dtype,shape=self.shape.copy(), name=res, command_buffer=cmd_buffer, is_scalar=self.is_scalar) + + def where(self, other, third): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('where'), [other, third, self]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer, is_scalar=self.is_scalar) + + def exp(self): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('exp'), [self]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), name=res, command_buffer=cmd_buffer) - def sqrt(self): + + def log(self, base = np.e): res = get_name() - cmd_buffer = ASTNode(res, OPCODES.get('pow'), [self], arg=0.5) + cmd_buffer = ASTNode(res, OPCODES.get('log'), [self], args=[base]) return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), name=res, command_buffer=cmd_buffer) - def cbrt(self): + def log10(self): res = get_name() - cmd_buffer = ASTNode(res, OPCODES.get('pow'), [self], arg=1/3) + cmd_buffer = ASTNode(res, OPCODES.get('log'), [self], args = [10]) return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), name=res, command_buffer=cmd_buffer) - def pow(self, exponent): + def log2(self): res = get_name() - cmd_buffer = ASTNode(res, OPCODES.get('pow'), [self], arg=exponent) + cmd_buffer = ASTNode(res, OPCODES.get('log'), [self], args = [2]) return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), name=res, command_buffer=cmd_buffer) - def log(self, base=np.e): + def abs(self): res = get_name() - cmd_buffer = ASTNode(res, OPCODES.get('log'), [self], arg=base) + cmd_buffer = ASTNode(res, OPCODES.get('abs'), [self]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def negate(self): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('negate'), [self]) return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), name=res, command_buffer=cmd_buffer) - def log10(self): + def 
square(self): res = get_name() - cmd_buffer = ASTNode(res, OPCODES.get('log'), [self], arg=10) + cmd_buffer = ASTNode(res, OPCODES.get('square'), [self]) return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), name=res, command_buffer=cmd_buffer) - def log2(self): + def sqrt(self): res = get_name() - cmd_buffer = ASTNode(res, OPCODES.get('log'), [self], arg=2) + cmd_buffer = ASTNode(res, OPCODES.get('sqrt'), [self]) return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), name=res, command_buffer=cmd_buffer) - def exp(self): + def reciprocal(self): res = get_name() - cmd_buffer = ASTNode(res, OPCODES.get('exp'), [self]) + cmd_buffer = ASTNode(res, OPCODES.get('reciprocal'), [self]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def sin(self): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('sin'), [self]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def cos(self): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('cos'), [self]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def relu(self): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('relu'), [self]) return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), name=res, command_buffer=cmd_buffer) - def absolute(self): + def scale(self, scalar): res = get_name() - cmd_buffer = ASTNode(res, OPCODES.get('abs'), [self]) + cmd_buffer = ASTNode(res, OPCODES.get('scale'), [self], args=[scalar]) return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), name=res, command_buffer=cmd_buffer) - def where(self, other, third): + def add_constant(self, constant): res = get_name() - cmd_buffer = ASTNode(res, OPCODES.get('where'), [other, third, self]) + cmd_buffer = ASTNode(res, OPCODES.get('add_constant'), [self], args=[constant]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def add(self, other): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('add'), [self, other]) return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), name=res, command_buffer=cmd_buffer) + def subtract(self, other): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('subtract'), [self, other]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def multiply(self, other): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('multiply'), [self, other]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def divide(self, other): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('divide'), [self, other]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def modulo(self, other): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('modulo'), [self, other]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def power(self, other): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('power'), [self, other]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def max(self, other): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('max'), [self, other]) + return
create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def min(self, other): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('min'), [self, other]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def greater_than(self, other): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('greater_than'), [self, other]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def less_than(self, other): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('less_than'), [self, other]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def equal(self, other, epsilon=1e-5): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('equal'), [self, other], args=[epsilon]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def atan2(self, other): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('atan2'), [self, other]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + + def weighted_average(self, other, w1, w2): + res = get_name() + cmd_buffer = ASTNode(res, OPCODES.get('weighted_average'), [self, other], + args=[w1, w2]) + return create_ndarray(self.ndim, self.dtype, shape=self.shape.copy(), + name=res, command_buffer=cmd_buffer) + def any(self): res = get_name() cmd_buffer = ASTNode(res, OPCODES.get('any'), [self]) diff --git a/charmnumeric/ast.py b/charmnumeric/ast.py index f67a98b..217cb61 100644 --- a/charmnumeric/ast.py +++ b/charmnumeric/ast.py @@ -7,7 +7,7 @@ max_depth = 10 - +multiLineFuse = False def set_max_depth(d): global max_depth @@ -18,55 +18,96 @@ def get_max_depth(): global max_depth return max_depth +def charm_fuse(func): + def compile_wrapper(*args, **kwargs): + global multiLineFuse + orig_max_depth = get_max_depth() + multiLineFuse = True + set_max_depth(float('inf')) + out = func(*args, **kwargs) + multiLineFuse = False + set_max_depth(orig_max_depth) + return out + return compile_wrapper class ASTNode(object): - def __init__(self, name, opcode, operands, arg=0.0): - from charmtiles.array import ndarray + def __init__(self, name, opcode, operands, args=[]): + from charmnumeric.array import ndarray + global multiLineFuse # contains opcode, operands # operands are ndarrays self.name = name self.opcode = opcode self.operands = operands self.depth = 0 - self.arg = arg + self.args = args + self.multiLineFuse = multiLineFuse if self.opcode != 0: for op in self.operands: if isinstance(op, ndarray): self.depth = max(self.depth, 1 + op.command_buffer.depth) - def get_command(self, validated_arrays, save=True): + ################################################################################################################################################################# + # Marker determines whether we are dealing with a tensor, a scalar or an arithmetic type # + # Marker = 0 : arithmetic type # + # Marker = 1 : scalar type # + # Marker = 2 : tensor type # + # Encoding = | Marker | dim | shape | opcode | save_op | ID | multiLineFuse | NumArgs | Args | NumOperands | OperandEncodingSize | RecursiveOperandEncoding | # + # | 8 | 8 | 64 | 32 | 1 | 64 | 1 | 32 | 64 | 8 | 32 | ........................ 
| # + # NB: If opcode is 0, the encoding is limited to ID # + # Encoding = | Marker | shape | val | # + # | 8 | 64 | 64 | # + # NB: Latter encoding for double constants # + ################################################################################################################################################################# + def get_command(self, ndim, shape, save=True, is_scalar=False, hasExceededMaxAstDepth=False): from charmnumeric.array import ndarray + + # Ndims and Shape setup + if is_scalar: + cmd = to_bytes(1, 'B') + else: + cmd = to_bytes(2, 'B') + cmd += to_bytes(ndim, 'B') + for _shape in shape: + cmd += to_bytes(_shape, 'L') + if self.opcode == 0: - cmd = to_bytes(self.opcode, 'L') - cmd += to_bytes(False, '?') - cmd += to_bytes(self.operands[0].name, 'L') + cmd += to_bytes(0, 'I') + to_bytes(False, '?') + to_bytes(self.operands[0].name, 'L') return cmd - cmd = to_bytes(self.opcode, 'L') + to_bytes(self.name, 'L') - cmd += to_bytes(save, '?') + to_bytes(len(self.operands), 'B') + + cmd += to_bytes(self.opcode, 'I') + to_bytes(save, '?') + to_bytes(self.name, 'L') + to_bytes(self.multiLineFuse, '?') + cmd += to_bytes(len(self.args), 'I') + for arg in self.args: + cmd += to_bytes(arg, 'd') + + cmd += to_bytes(len(self.operands), 'B') for op in self.operands: - # an operand can also be a double if isinstance(op, ndarray): - if op.name in validated_arrays: - opcmd = to_bytes(0, 'L') - opcmd += to_bytes(False, '?') - opcmd += to_bytes(op.name, 'L') - cmd += to_bytes(len(opcmd), 'I') - cmd += opcmd + if op.valid: + if op.is_scalar: + opcmd = to_bytes(1, 'B') + else: + opcmd = to_bytes(2, 'B') + opcmd += to_bytes(op.ndim, 'B') + for _shape in op.shape: + opcmd += to_bytes(_shape, 'L') + opcmd += to_bytes(0, 'I') + to_bytes(False, '?') + to_bytes(op.name, 'L') else: - save_op = True if c_long.from_address(id(op)).value - 2 > 0 else False - opcmd = op.command_buffer.get_command(validated_arrays, - save=save_op) - if not op.valid and save_op: - validated_arrays[op.name] = op - cmd += to_bytes(len(opcmd), 'I') - cmd += opcmd + ### this will only be true when AST is being flushed because of exceeding max depth and ensures that unnecessary temporaries are not saved + if hasExceededMaxAstDepth: + save_op = True if c_long.from_address(id(op)).value - 4 > 0 else False + else: + save_op = True if c_long.from_address(id(op)).value - 2 > 0 else False + opcmd = op.command_buffer.get_command(op.ndim, op.shape, save=save_op, is_scalar=op.is_scalar) + if save_op or (op.command_buffer.opcode == OPCODES.get('@')) or (op.command_buffer.opcode == OPCODES.get('copy')): + op.validate() elif isinstance(op, float) or isinstance(op, int): - opcmd = to_bytes(0, 'L') - opcmd += to_bytes(True, '?') - opcmd += to_bytes(op, 'd') - cmd += to_bytes(len(opcmd), 'I') - cmd += opcmd - cmd += to_bytes(self.arg, 'd') + opcmd = to_bytes(0, 'B') + for _shape in shape: + opcmd += to_bytes(_shape, 'L') + opcmd += to_bytes(float(op), 'd') + cmd += to_bytes(len(opcmd), 'I') + cmd += opcmd return cmd def plot_graph(self, validated_arrays={}, G=None, node_map={}, diff --git a/charmnumeric/ccs.py b/charmnumeric/ccs.py index 3e4639f..c565e6a 100644 --- a/charmnumeric/ccs.py +++ b/charmnumeric/ccs.py @@ -1,7 +1,7 @@ import struct import atexit from pyccs import Server -from charmnumeric import array +import gc debug = False server = None @@ -9,10 +9,55 @@ next_name = 0 epoch = 0 -OPCODES = {'+': 1, '-': 2, '*': 3 ,'/': 4, '@': 5, 'copy': 6, 'axpy': 7, - 'axpy_multiplier': 8, 'setitem': 9, 'pow': 10, '>': 11, - '<': 12, 
'>=': 13, '<=': 14, '==': 15, '!=': 16, '&': 17, - '|': 18, '!':19, 'where':20, 'log': 21, 'exp': 22, 'abs': 23, 'any':24, 'all':25} +OPCODES = { + # base_op + '+': 1, + '-': 2, + '*': 3, + '/': 4, + '@': 5, + 'copy': 6, + '>': 9, + '<': 10, + '>=': 11, + '<=': 12, + '==': 13, + '!=': 14, + '&': 15, + '|': 16, + '!': 17, + 'where': 18, + + # custom_unary_op + 'exp': 41, + 'log': 42, + 'abs': 43, + 'negate': 44, + 'square': 45, + 'sqrt': 46, + 'reciprocal': 47, + 'sin': 48, + 'cos': 49, + 'relu': 50, + 'scale': 51, + 'add_constant': 52, + + # custom_binary_op + 'add': 71, + 'subtract': 72, + 'multiply': 73, + 'divide': 74, + 'power': 75, + 'modulo': 76, + 'max': 77, + 'min': 78, + 'greater_than': 79, + 'less_than': 80, + 'equal': 81, + 'atan2': 82, + 'weighted_average': 83, + 'axpy': 84 +} INV_OPCODES = {v: k for k, v in OPCODES.items()} @@ -69,14 +114,25 @@ def connect(server_ip, server_port): atexit.register(disconnect) def disconnect(): - from charmnumeric.array import deletion_buffer, deletion_buffer_size - global client_id, deletion_buffer, deletion_buffer_size - if deletion_buffer_size > 0: - cmd = to_bytes(len(deletion_buffer), 'I') + deletion_buffer + # cleanup the remaining ndarrays + from charmnumeric.array import ndarray + deleted_id = [] + for obj in gc.get_objects(): + if isinstance(obj, ndarray): + if not obj.name in deleted_id: + print(obj.name) + deleted_id.append(obj.name) + obj.__del__() + from charmnumeric.array import deletion_buffer, deletion_buffer_size, deferred_deletion_buffer_size, deferred_deletion_buffer + if (deletion_buffer_size > 0) or (deferred_deletion_buffer_size > 0): + cmd = to_bytes(deletion_buffer_size, 'I') + deletion_buffer + to_bytes(deferred_deletion_buffer_size, 'I') + deferred_deletion_buffer cmd = to_bytes(get_epoch(), 'i') + to_bytes(len(cmd), 'I') + cmd send_command_async(Handlers.delete_handler, cmd) deletion_buffer = b'' - deletion_buffer_size = b'' + deletion_buffer_size = 0 + deferred_deletion_buffer = b'' + deferred_deletion_buffer_size = 0 + global client_id cmd = to_bytes(client_id, 'B') cmd = to_bytes(get_epoch(), 'i') + to_bytes(len(cmd), 'I') + cmd send_command_async(Handlers.disconnection_handler, cmd) @@ -95,7 +151,6 @@ def get_creation_command(arr, name, shape, buf=None): cmd += buf elif arr.init_value is not None: cmd += to_bytes(arr.init_value, 'd') - print(cmd) cmd = to_bytes(get_epoch(), 'i') + to_bytes(len(cmd), 'I') + cmd return cmd diff --git a/charmnumeric/linalg.py b/charmnumeric/linalg.py index 823b7d5..2894e4b 100644 --- a/charmnumeric/linalg.py +++ b/charmnumeric/linalg.py @@ -7,16 +7,8 @@ from charmnumeric.ast import ASTNode -def axpy(a, x, y, multiplier=None): - operands = [a, x, y] - if multiplier is not None: - operands.append(multiplier) - operation = 'axpy_multiplier' - else: - operation = 'axpy' +def axpy(a, x, y): res = get_name() - cmd_buffer = ASTNode(res, OPCODES.get(operation), operands) - return create_ndarray(x.ndim, x.dtype, + cmd_buffer = ASTNode(res, OPCODES.get('axpy'), [x, y], args=[a]) + return create_ndarray(x.ndim, x.dtype, x.shape, name=res, command_buffer=cmd_buffer) - - diff --git a/config.cmake b/config.cmake new file mode 100644 index 0000000..53660ec --- /dev/null +++ b/config.cmake @@ -0,0 +1,13 @@ +set(CHARM_DIR "/home/anant/winter2024/lbp/study/charm/netlrts-linux-x86_64/") +set(BASE_DIR "/home/anant/sem7/LibCharmtyles") +set(EIGEN_DIR "/usr/include/eigen3") +set(CUDA_DIR "/path/to/CUDA/directory") +set(KOKKOS_DIR "${BASE_DIR}/kokkos/install") +set(KOKKOS_KERNELS_DIR 
"${BASE_DIR}/kokkos-kernels/install") + +set(CHARMC "${CHARM_DIR}/bin/charmc") +set(CPU_OPTS "-c++-option -std=c++20 -O3 -march=native -DNDEBUG") +set(GPU_OPTS "-std=c++20 -O3 -march=native -DNDEBUG") +set(GPU_LINK_OPTS -O3 -language charm++ -L${KOKKOS_DIR}/lib64 -L/u/ajain18/kokkos_kernels_install/lib64 -lkokkoscore -lkokkoscontainers -lkokkoskernels -L${CUDA_DIR} -lcuda -lcudart -lcusparse -lcublas) +set(LD_OPTS "") +set(INCS "-I${BASE_DIR}") diff --git a/examples/bench.py b/examples/bench.py new file mode 100644 index 0000000..5e63348 --- /dev/null +++ b/examples/bench.py @@ -0,0 +1,51 @@ +from charmnumeric.array import connect, ndarray +from charmnumeric.ast import set_max_depth +from charmnumeric.ccs import enable_debug +import charmnumeric.linalg as lg +import numpy as np +import time + +set_max_depth(10) + +def f(): + b = ndarray(1, 10, np.float64, init_value=10) + v = ndarray(1, 10, np.float64, init_value=20) + c = ndarray(1, 10, np.float64, init_value=20) + v1 = (b + v) - c * 3 + v2 = b.scale(3) + c.add_constant(10) + v3 = (b + c) @ v + v1.get() + v2.get() + v3.get() + start = time.time() + for i in range(100): + v1 = (b + v) - c * 3 + v2 = b.scale(3) + c.add_constant(10) + v3 = (b + c) @ v + v1.get() + v2.get() + v3.get() + + end = time.time() + + print("VECTOR BENCH ", end-start) + + + start = time.time() + b = ndarray(2, [10,10], np.float64, init_value=10) + v = ndarray(2, [10,10], np.float64, init_value=20) + c = ndarray(2, [10,10], np.float64, init_value=20) + for i in range(100): + v1 = (b + v) - c * 3 + v2 = b.exp() + c.add_constant(10) + v3 = (b + c) @ v + v1.get() + v2.get() + v3.get() + + end = time.time() + print("MARIX BENCH ", end-start) + +if __name__ == '__main__': + connect("172.17.0.1", 10000) + s = f() \ No newline at end of file diff --git a/examples/charm_fuse.py b/examples/charm_fuse.py new file mode 100644 index 0000000..62c5cdc --- /dev/null +++ b/examples/charm_fuse.py @@ -0,0 +1,44 @@ +from charmnumeric.array import connect, ndarray +from charmnumeric.ast import set_max_depth, charm_fuse +from charmnumeric.ccs import enable_debug +import charmnumeric.linalg as lg +import numpy as np +import time +set_max_depth(float('inf')) + +@charm_fuse +def f(): + v = ndarray(1, 1e2, np.float64, init_value=-20) + b = ndarray(1, 1e2, np.float64, init_value=10) + + g1 = v.abs().scale(2).scale(2).add_constant(29) + b + 32 + g2 = b.log(2).exp() + d = g1 + g2 + return d.get() + +def g(): + v = ndarray(1, 1e2, np.float64, init_value=-20) + b = ndarray(1, 1e2, np.float64, init_value=10) + + g1 = v.abs().scale(2).scale(2).add_constant(29) + b + 32 + g2 = b.log(2).exp() + d = g1 + g2 + return d.get() + +if __name__ == '__main__': + connect("127.0.0.1", 10000) + # s = f() + start = time.time() + for(i) in range(1000): + s = f() + end = time.time() + print(s) + print("Time taken(multi line fused): ", end - start) + + # k = g() + # start = time.time() + # for(i) in range(100): + # k = g() + # end = time.time() + # print(k) + # print("Time taken(multi line unfused): ", end - start) diff --git a/examples/conjugate_gradient.py b/examples/conjugate_gradient.py index 3f9c216..1a69d5e 100644 --- a/examples/conjugate_gradient.py +++ b/examples/conjugate_gradient.py @@ -1,52 +1,89 @@ from charmnumeric.array import connect, ndarray -import charmnumeric.linalg as lg -from charmnumeric.ccs import enable_debug, sync from charmnumeric.ast import set_max_depth import numpy as np -import gc - import time -#enable_debug() set_max_depth(10) -#gc.set_threshold(1, 1, 1) -def solve(A, b, x): +def 
generate_2D(N, corners=True): + if corners: + print( + "Generating %dx%d 2-D adjacency system with corners..." + % (N**2, N**2) + ) + A = np.zeros((N**2, N**2)) + 8 * np.eye(N**2) + else: + print( + "Generating %dx%d 2-D adjacency system without corners..." + % (N**2, N**2) + ) + A = np.zeros((N**2, N**2)) + 4 * np.eye(N**2) + # These are the same for both cases + off_one = np.full(N**2 - 1, -1, dtype=np.float64) + A += np.diag(off_one, k=1) + A += np.diag(off_one, k=-1) + off_N = np.full(N * (N - 1), -1, dtype=np.float64) + A += np.diag(off_N, k=N) + A += np.diag(off_N, k=-N) + # If we have corners then we have four more cases + if corners: + off_N_plus = np.full(N * (N - 1) - 1, -1, dtype=np.float64) + A += np.diag(off_N_plus, k=N + 1) + A += np.diag(off_N_plus, k=-(N + 1)) + off_N_minus = np.full(N * (N - 1) + 1, -1, dtype=np.float64) + A += np.diag(off_N_minus, k=N - 1) + A += np.diag(off_N_minus, k=-(N - 1)) + # Then we can generate a random b matrix + b = np.random.rand(N**2) + return A, b + + +def solve(A, b, x_cp): + x = x_cp.copy() r = b - A @ x p = r.copy() rsold = r @ r - for i in range(100): - #if i % 10 == 0: - gc.collect() + for _ in range(10): Ap = A @ p alpha = rsold / (p @ Ap) - x = lg.axpy(alpha, p, x) - r = lg.axpy(alpha, Ap, r, multiplier=-1.) - + x = alpha * p + x + r = alpha * Ap - r rsnew = r @ r - #if np.sqrt(rsnew.get()) < 1e-8: - # print("Converged in %i iterations" % (i + 1)) - # break - - p = lg.axpy(rsnew / rsold, p, r) + p = (rsnew / rsold) * p + r rsold = rsnew return x if __name__ == '__main__': - connect("172.17.0.1", 10000) + connect("127.0.0.1", 10000) - A = ndarray(2, (184, 184), np.float64) - b = ndarray(1, 184, np.float64) - x = ndarray(1, 184, np.float64) + n = 50 - #d = (b @ x).get() + A_np, b_np = generate_2D(n) + A = ndarray(2, (n**2, n**2), np.float64, nparr = A_np) + b = ndarray(1, n**2, np.float64, nparr = b_np) + x = ndarray(1, A.shape[1], np.float64, init_value = 0) + + # Pre-Compilation + _ = solve(A, b, x) + __ = _.get() + print(__) start = time.time() x = solve(A, b, x) - x.evaluate() - sync() - print("Execution time = %.6f" % (time.time() - start)) + x_charm = x.get() + print("Execution time (Charm) = %.6f s" % (time.time() - start)) + + + x = np.zeros(A_np.shape[1], dtype=np.float64) + start = time.time() + x_np = solve(A_np, b_np, x) + print("Execution time (NumPy) = %.6f s" % (time.time() - start)) + + if np.allclose(x_np, x_charm, atol=1e-5): + print("[SUCCESS]") + else: + print("[FAIL]") diff --git a/examples/custom_ops.py b/examples/custom_ops.py new file mode 100644 index 0000000..0a640d5 --- /dev/null +++ b/examples/custom_ops.py @@ -0,0 +1,26 @@ +from charmnumeric.array import connect, ndarray +from charmnumeric.ast import set_max_depth +from charmnumeric.ccs import enable_debug +import charmnumeric.linalg as lg +import numpy as np + +set_max_depth(10) + +def f(): + v = ndarray(1, 50, np.float64, init_value=-20) + b = ndarray(1, 50, np.float64, init_value=10) + c = ndarray(1, 50, np.float64, init_value=30) + d = ndarray(1, 50, np.float64, init_value=5) + + g1 = v.abs().add(b).weighted_average(c, 0.7, 0.3) + g2 = b.log(2).exp() + g3 = v.abs().scale(2).scale(2).add_constant(29) + b + 32 + + + print("k1:", g1.get()) + print("k2:", g2.get()) + print("k3:", g3.get()) + +if __name__ == '__main__': + connect("127.0.0.1", 10000) + f() diff --git a/examples/graph.py b/examples/graph.py index c414995..c0c78a1 100644 --- a/examples/graph.py +++ b/examples/graph.py @@ -4,22 +4,85 @@ import charmnumeric.linalg as lg import numpy as np 
-#enable_debug() -set_max_depth(100) +# enable_debug() +set_max_depth(2) def f(): - v = ndarray(1, 10, np.float64) - b = ndarray(1, 10, np.float64, init_value=10) - c = ndarray(1, 10, np.float64) - w = c - for i in range(5): - y = v + b + w - z = v - y - w = 2 * (c - z) + b - w.evaluate() + a = ndarray(1, 10, np.float64, init_value=4) + b = ndarray(1, 10, np.float64, init_value=1) + c = ndarray(1, 10, np.float64, init_value=3) + d = ndarray(1, 10, np.float64, init_value=2) + # vx = [a, b] + for _ in range(1): + e = a + b + c + d + f = e + a + print(f.get()) + # e = a + b * d - c + 42 - 34 + # f = e + c / a + 32 - b + # g = f.scale(69) + 53 - a / 32 + # g = f + d + # k = a + b + c + v + # k.get() + # prog + # k + -> k temp object -> ref k + # + operation + # tree + # op (generate command) + # l.get() + # print(k.get()) + # print(l.get()) + # v1 = v @ b + # v1 = (b + c) @ (b - c) + # q.get() + # v1 = q @ c + # v2 = b @ c + # q1 = q @ v + # v3 = b @ c + + # v1.get() + # v2.get() + # v3.get() + + # res = v3 * 8 + v1 - 4 + v2.abs() + + # res.get() + + # final_res = res + 42 + + # q.get() + # a1 = b @ c + # print(a1.get()) + # a2 = v @ c + # print(a2.get()) + # res = (a1 / a2) * b + c + # v = 2 + # # a3 = a1 + a2 + # print(v.get()) + # res = a3 * v + # w = c @ b + # w.get() + # res = q.abs() + c + # baka = final_res.get() + # v1 = b.copy() + # print(res.get()) + # r = b.where(42, 69) + # g = b.where(v, c) + # z = ~r + # g = v.abs() + b + 32 + # k = g + c * 8 + # k = g + 1 + # k = g + 2 * c - 3 * v + # q = k.get() + # print(res.get()) + # w = c + # for i in range(5): + # y = v + b + w + # z = v - y + # w = 2 * (c - z) + b + # w.evaluate() if __name__ == '__main__': - connect("172.17.0.1", 10000) + connect("127.0.0.1", 10000) s = f() diff --git a/examples/run.sh b/examples/run.sh new file mode 100755 index 0000000..92d37a7 --- /dev/null +++ b/examples/run.sh @@ -0,0 +1,4 @@ +cd .. 
+python setup.py install +cd examples +python $1 diff --git a/setup.py b/setup.py index 8071e21..9c0b53a 100644 --- a/setup.py +++ b/setup.py @@ -10,15 +10,6 @@ def get_version(): exec(compile(open(fname).read(), fname, 'exec'), data) return data.get('__version__') - -def compile_server(): - charmc = os.environ.get('CHARMC', - '/home/adityapb/charm/charm/netlrts-linux-x86_64/bin/charmc') - aum_base = os.environ.get('AUM_HOME', '/home/adityapb/charm/LibCharmtyles') - subprocess.run(["make", "-C", "src/", - "CHARMC=%s" % charmc, "BASE_DIR=%s" % aum_base]) - - install_requires = ['numpy', 'charm4py'] tests_require = ['pytest'] docs_require = ['sphinx'] @@ -39,8 +30,6 @@ def compile_server(): ''' classifiers = [x.strip() for x in classes.splitlines() if x] -compile_server() - setup( name='charmnumeric', version=get_version(), diff --git a/src/.gitignore b/src/.gitignore new file mode 100644 index 0000000..378eac2 --- /dev/null +++ b/src/.gitignore @@ -0,0 +1 @@ +build diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000..6172df9 --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,60 @@ +cmake_minimum_required(VERSION 3.16) +project(charmTyles) + +include(${CMAKE_SOURCE_DIR}/../config.cmake) + +set(Kokkos_ROOT ${KOKKOS_DIR}) +find_package(Kokkos 4.5 REQUIRED CONFIG) + +set(KokkosKernels_DIR ${KOKKOS_KERNELS_DIR}/lib/cmake/KokkosKernels) +find_package(KokkosKernels REQUIRED) + +if(Charm_ENABLE_GPU) + message(STATUS "Building for a GPU backend") + add_definitions(-DGPU_BACKEND) + add_definitions(-DCUDA_DIR=\"${CUDA_DIR}\") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GPU_OPTS} ${LD_OPTS}") +else() + message(STATUS "Building for a CPU backend") + set(CMAKE_CXX_COMPILER "${CHARMC}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CPU_OPTS} ${LD_OPTS}") +endif() + +add_definitions(-DKOKKOS_DIR=\"${KOKKOS_DIR}\") + +add_custom_command( + OUTPUT ${BASE_DIR}/charmtyles/backend/libcharmtyles.decl.h + COMMAND ${CMAKE_COMMAND} -E chdir ${BASE_DIR}/charmtyles/backend + ${CHARMC} ${BASE_DIR}/charmtyles/backend/charmtyles.ci + DEPENDS ${BASE_DIR}/charmtyles/backend/charmtyles.ci + COMMENT "Processing charmtyles ci files" +) + +add_custom_command( + OUTPUT ${CMAKE_SOURCE_DIR}/server.decl.h + COMMAND ${CMAKE_COMMAND} -E chdir ${CMAKE_SOURCE_DIR} + ${CHARMC} ${CMAKE_SOURCE_DIR}/server.ci + DEPENDS ${CMAKE_SOURCE_DIR}/server.ci + COMMENT "Processing server ci files" +) + +if(Charm_ENABLE_GPU) + add_library(server OBJECT server.cpp ${BASE_DIR}/charmtyles/backend/libcharmtyles.decl.h ${CMAKE_SOURCE_DIR}/server.decl.h) + target_include_directories(server PRIVATE ${BASE_DIR} ${BASE_DIR}/charmtyles/backend ${EIGEN_DIR}/include ${CHARM_DIR}/include) + target_link_libraries(server Kokkos::kokkos Kokkos::kokkoskernels) + + add_custom_command( + OUTPUT "${CMAKE_BINARY_DIR}/server.out" + COMMAND ${CHARMC} ${GPU_LINK_OPTS} $ -o ${CMAKE_BINARY_DIR}/server.out + DEPENDS server + COMMENT "Linking charm build against kokkos and cuda" + ) + + add_custom_target(result ALL + DEPENDS "${CMAKE_BINARY_DIR}/server.out" + ) +else() + add_executable(server.out server.cpp ${BASE_DIR}/charmtyles/backend/libcharmtyles.decl.h ${CMAKE_SOURCE_DIR}/server.decl.h) + target_include_directories(server.out PRIVATE ${BASE_DIR} ${BASE_DIR}/charmtyles/backend ${EIGEN_DIR}) + target_link_libraries(server.out Kokkos::kokkos Kokkos::kokkoskernels) +endif() diff --git a/src/Makefile b/src/Makefile deleted file mode 100644 index 084d354..0000000 --- a/src/Makefile +++ /dev/null @@ -1,21 +0,0 @@ 
-CHARMC=/home/adityapb/charm/charm/netlrts-linux-x86_64/bin/charmc -BASE_DIR=/home/adityapb/charm/LibCharmtyles -LIBS_DIR=$(BASE_DIR) -EIGEN_DIR = /usr/include/eigen3 -OPTS=-c++-option -std=c++17 -O3 -I$(EIGEN_DIR) -DNDEBUG -g - -all: server - -.PHONY: clean server.out - -server_ci: server.ci - $(CHARMC) -E server.ci - -server: server.cpp server_ci - $(CHARMC) $< -L$(LIBS_DIR)/charmtyles -lcharmtyles -I$(BASE_DIR) -I$(BASE_DIR)/charmtyles/backend -o $@.out $(OPTS) - -run-server: server.out - ./charmrun +p1 ./server.out ++server ++server-port 10000 ++local - -clean: - rm *.decl.h *.def.h *.out charmrun diff --git a/src/ast.hpp b/src/ast.hpp index 093c019..7bed1ee 100644 --- a/src/ast.hpp +++ b/src/ast.hpp @@ -1,106 +1,603 @@ -#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include + +using ctop = ct::util::Operation; +using ct_name_t = uint64_t; +using ct_array_t = std::variant, std::unique_ptr>; +std::unordered_map symbol_table; + +inline static void insert(ct_name_t name, ct_array_t arr) { + CkPrintf("Created array %" PRIu64 " on server\n", name); + symbol_table[name] = std::move(arr); +} -template -inline T extract(char *&msg, bool increment = true) -{ +inline static void remove(ct_name_t name) noexcept { + { + ckout<< "Before deletion, the available symbols were: "; + for (const auto& it : symbol_table) + ckout<< it.first << " "; + ckout << endl; + ckout<< "and I have been tasked to remove " << name << endl; + } + symbol_table.erase(name); +} + +static ct_array_t &lookup(ct_name_t name) { + auto find = symbol_table.find(name); + CkPrintf("Looking up array %" PRIu64 " on server\n", name); + if (find == std::end(symbol_table)) + CmiAbort("Symbol%" PRIu64 "not found", name); + return find->second; +} + +template inline T extract(char *&msg) noexcept { T arg = *(reinterpret_cast(msg)); - if (increment) - msg += sizeof(T); + msg += sizeof(T); return arg; } -enum class operation : uint64_t -{ - noop = 0, - add = 1, - sub = 2, - mul = 3, - div = 4, - matmul = 5, - copy = 6, - axpy = 7, - axpy_multiplier = 8, - pow = 10, - greater = 11, - lesser = 12, - geq = 13, - leq = 14, - eq = 15, - neq = 16, - logical_and = 17, - logical_or = 18, - logical_not = 19, - where = 20, - log = 21, - exp = 22, - abs = 23, - any = 24, - all = 25 -}; - -class astnode -{ -public: - bool store; - bool is_scalar; - double arg; - // FIXME double scalars fit into name, but should probably - // handle this better - uint64_t name; - operation oper; - std::vector operands; -}; - -operation lookup_operation(uint64_t opcode) -{ - return static_cast(opcode); +template inline T peek(char *&msg) noexcept { + return *(reinterpret_cast(msg)); } -astnode *decode(char *cmd) -{ - uint64_t opcode = extract(cmd); - astnode *node = new astnode; - node->oper = lookup_operation(opcode); - if (opcode == 0) - { - node->is_scalar = extract(cmd); - // if leaf is a scalar - if (node->is_scalar) - { - double value = extract(cmd); - memcpy(&(node->name), &value, sizeof(double)); +template +std::vector process_tensor(char *cmd, bool flush = false); + +template +std::pair getFlushedOperand(char *cmd) { + char *recurse_cmd = cmd; + + uint8_t marker = extract(cmd); + if (marker != 2) + CmiAbort("Matmuls only supported with Tensor Types"); + + uint8_t dim = extract(cmd); + if (dim < 1 || dim > 2) + CmiAbort("Matmuls not supported with dimension%" PRIu8 "", dim); + + cmd += dim * sizeof(uint64_t); + + uint32_t opcode = extract(cmd); + if (opcode) + process_tensor(recurse_cmd, true); + + cmd += 
sizeof(bool); + + uint64_t tensorID = extract(cmd); + return {dim, tensorID}; +} + +ctop inline to_ctop(uint64_t opcode) noexcept { + if (opcode >= 41 and opcode <= 52) + return ctop::unary_expr; + if (opcode >= 71 and opcode <= 83) + return ctop::binary_expr; + switch (opcode) { + case 0: + return ctop::noop; + case 1: + return ctop::add; + case 2: + return ctop::sub; + case 3: + return ctop::multiply; + case 4: + return ctop::divide; + case 5: + return ctop::matmul; + case 6: + return ctop::copy; + case 9: + return ctop::greater; + case 10: + return ctop::lesser; + case 11: + return ctop::geq; + case 12: + return ctop::leq; + case 13: + return ctop::eq; + case 14: + return ctop::neq; + case 15: + return ctop::logical_and; + case 16: + return ctop::logical_or; + case 17: + return ctop::logical_not; + case 18: + return ctop::where; + default: + return ctop::noop; + } +} + +std::shared_ptr +to_ct_unary(uint64_t opcode, const std::vector &args) noexcept { + switch (opcode) { + case 41: + return ct::unary_ops::exp(args); + case 42: + return ct::unary_ops::log(args); + case 43: + return ct::unary_ops::abs(args); + case 44: + return ct::unary_ops::negate(args); + case 45: + return ct::unary_ops::square(args); + case 46: + return ct::unary_ops::sqrt(args); + case 47: + return ct::unary_ops::reciprocal(args); + case 48: + return ct::unary_ops::sin(args); + case 49: + return ct::unary_ops::cos(args); + case 50: + return ct::unary_ops::relu(args); + case 51: + return ct::unary_ops::scale(args); + case 52: + return ct::unary_ops::add_constant(args); + default: + return nullptr; + } +} + +std::shared_ptr +to_ct_binary(uint64_t opcode, const std::vector &args) noexcept { + switch (opcode) { + case 71: + return ct::binary_ops::add(args); + case 72: + return ct::binary_ops::subtract(args); + case 73: + return ct::binary_ops::multiply(args); + case 74: + return ct::binary_ops::divide(args); + case 75: + return ct::binary_ops::power(args); + case 76: + return ct::binary_ops::modulo(args); + case 77: + return ct::binary_ops::max(args); + case 78: + return ct::binary_ops::min(args); + case 79: + return ct::binary_ops::greater_than(args); + case 80: + return ct::binary_ops::less_than(args); + case 81: + return ct::binary_ops::equal(args); + case 82: + return ct::binary_ops::atan2(args); + case 83: + return ct::binary_ops::weighted_average(args); + default: + return nullptr; + } +} + +double process_scalar(char *cmd) { + uint8_t marker = extract(cmd); + if (marker == 0) + return extract(cmd); + + /* dims = */ extract(cmd); + /* shape = */ extract(cmd); + + ctop ctopcode = to_ctop(extract(cmd)); + bool store = extract(cmd); + uint64_t tensorID = extract(cmd); + + if (ctopcode == ctop::noop) + return std::get(lookup(tensorID)); + + /* multLineFuse = */ extract(cmd); + /* NumcustomOpArgs = */ extract(cmd); + + if (ctopcode == ctop::unary_expr || ctopcode == ctop::binary_expr) + CmiAbort("Custom Ops are not defined for scalar type"); + + uint8_t numOperands = extract(cmd); + + // when we encounter a matmul, we treat it as a dot product returning a + // scalar. 
+ if (ctopcode == ctop::matmul) { + uint32_t operand_size = extract(cmd); + std::pair xOperandInfo = + getFlushedOperand(cmd); + cmd += operand_size; + operand_size = extract(cmd); + std::pair yOperandInfo = + getFlushedOperand(cmd); + cmd += operand_size; + + const uint8_t &xDim = xOperandInfo.first; + const uint8_t &yDim = yOperandInfo.first; + const uint64_t &xID = xOperandInfo.second; + const uint64_t &yID = yOperandInfo.second; + + if (xDim == 1 and yDim == 1) { + const auto &x = std::get>(lookup(xID)); + const auto &y = std::get>(lookup(yID)); + + ct::scalar tensor0D = ct::dot(*x, *y); + double result = tensor0D.get(); + insert(tensorID, result); + return result; + } else { + CmiAbort("dot product of tensors does not result in a scalar"); } - else - node->name = extract(cmd); - return node; } - node->is_scalar = false; - node->name = extract(cmd); - node->store = extract(cmd); - uint8_t num_operands = extract(cmd); - for (uint8_t i = 0; i < num_operands; i++) - { + double result; + + if (numOperands == 1) { uint32_t operand_size = extract(cmd); - astnode *opnode = decode(cmd); - node->operands.push_back(opnode); + double lhs = process_scalar(cmd); cmd += operand_size; + + switch (ctopcode) { + case ctop::copy: + result = lhs; + break; + case ctop::logical_not: + result = !lhs; + break; + default: + CmiAbort("unrecognized unary op for scalar operands"); + } + } else if (numOperands == 2) { + uint32_t operand_size = extract(cmd); + double lhs = process_scalar(cmd); + cmd += operand_size; + operand_size = extract(cmd); + double rhs = process_scalar(cmd); + cmd += operand_size; + + switch (ctopcode) { + case ctop::add: + result = lhs + rhs; + break; + case ctop::sub: + result = lhs - rhs; + break; + case ctop::multiply: + result = lhs * rhs; + break; + case ctop::divide: + result = lhs / rhs; + break; + case ctop::greater: + result = lhs > rhs; + break; + case ctop::lesser: + result = lhs < rhs; + break; + case ctop::geq: + result = lhs >= rhs; + break; + case ctop::leq: + result = lhs <= rhs; + break; + case ctop::eq: + result = lhs == rhs; + break; + case ctop::neq: + result = lhs != rhs; + break; + default: + CmiAbort("unrecognized binary op for scalar operands"); + } + } else if (numOperands == 3) { + uint32_t operand_size = extract(cmd); + double lhs = process_scalar(cmd); + cmd += operand_size; + + operand_size = extract(cmd); + double rhs = process_scalar(cmd); + cmd += operand_size; + + operand_size = extract(cmd); + double ths = process_scalar(cmd); + cmd += operand_size; + + switch (ctopcode) { + case ctop::where: + result = ths ? 
lhs : rhs; + break; + default: + CmiAbort("unrecognized ternary op for scalar operands"); + } } - node->arg = extract(cmd); - return node; + + if (store) + insert(tensorID, result); + return result; } -void delete_ast(astnode *node) -{ - if (node->oper == operation::noop) - { - delete node; - return; +template +std::vector process_tensor(char *cmd, bool flush) { + if(peek(cmd) == 1) { + double result = process_scalar(cmd); + if constexpr (std::is_same_v) { + tensorAstNodeType temp_node(0, ctop::broadcast, result, {1}); + return {temp_node}; + } else if constexpr (std::is_same_v) { + tensorAstNodeType temp_node(0, ctop::broadcast, result, {1, 1}); + return {temp_node}; + } + } + + uint8_t marker = extract(cmd); + + std::vector shape; + shape.reserve(2); + + if (marker == 0) { + if constexpr (std::is_same_v) { + shape.push_back(extract(cmd)); + } else if constexpr (std::is_same_v) { + shape.push_back(extract(cmd)); + shape.push_back(extract(cmd)); + } + double value = extract(cmd); + tensorAstNodeType temp_node(0, ctop::broadcast, value, shape); + return {temp_node}; } - for (astnode *n : node->operands) - delete_ast(n); + uint8_t dims = extract(cmd); + + for (uint8_t i = 0; i < dims; i++) + shape.push_back(extract(cmd)); + + uint32_t opcode = extract(cmd); + bool store = extract(cmd); + uint64_t tensorID = extract(cmd); - delete node; + if (opcode == 0) { + const auto &tmp = std::get>(lookup(tensorID)); + return (*tmp)(); + } + bool multiLineFuse = extract(cmd); + + // Args for custom unops/binops + uint32_t numArgs = extract(cmd); + std::vector args; + for (uint32_t i = 0; i < numArgs; i++) + args.push_back(extract(cmd)); + + tensorAstNodeType rootNode; + ctop ctopcode = to_ctop(opcode); + if (ctopcode == ctop::unary_expr) { + rootNode = + tensorAstNodeType(-1, ctopcode, to_ct_unary(opcode, args), shape); + } else if (ctopcode == ctop::binary_expr) { + rootNode = + tensorAstNodeType(-1, ctopcode, to_ct_binary(opcode, args), shape); + } else { + rootNode = tensorAstNodeType(ctopcode, shape); + } + rootNode.multiLineFuse = multiLineFuse; + std::vector ast; + + uint8_t numOperands = extract(cmd); + + // when we encounter a matmul, we treat it as a : + // 1. a dot product returning a scalar if both the operands are vectors + // 2. a dot product returning a vector if one operand is a matrix and the + // other a vector + // 3. 
a gemm returning a matrix if both the operands are matrices
+    if (ctopcode == ctop::matmul) {
+        uint32_t operand_size = extract(cmd);
+        std::pair xOperandInfo =
+            getFlushedOperand(cmd);
+        cmd += operand_size;
+        operand_size = extract(cmd);
+        std::pair yOperandInfo =
+            getFlushedOperand(cmd);
+        cmd += operand_size;
+
+        const uint8_t &xDim = xOperandInfo.first;
+        const uint8_t &yDim = yOperandInfo.first;
+        const uint64_t &xID = xOperandInfo.second;
+        const uint64_t &yID = yOperandInfo.second;
+
+        if (xDim == 1 and yDim == 1) {
+            const auto &x = std::get>(lookup(xID));
+            const auto &y = std::get>(lookup(yID));
+
+            ct::scalar tensor0D = ct::dot(*x, *y);
+            double result = tensor0D.get();
+
+            insert(tensorID, result);
+            tensorAstNodeType temp_node(0, ctop::broadcast, result, shape);
+            return {temp_node};
+        } else if constexpr (std::is_same_v) {
+            if (xDim == 1 and yDim == 2) {
+                const auto &x = std::get>(lookup(xID));
+                const auto &y = std::get>(lookup(yID));
+
+                std::unique_ptr tensor = std::make_unique(std::move(ct::dot(*x, *y)));
+                const auto &tensorNode = (*tensor)();
+                insert(tensorID, std::move(tensor));
+
+                return tensorNode;
+            } else if (xDim == 2 and yDim == 1) {
+                const auto &x = std::get>(lookup(xID));
+                const auto &y = std::get>(lookup(yID));
+
+                std::unique_ptr tensor = std::make_unique(std::move(ct::dot(*x, *y)));
+                const auto &tensorNode = (*tensor)();
+                insert(tensorID, std::move(tensor));
+
+                return tensorNode;
+            }
+        } else if constexpr (std::is_same_v) {
+            if (xDim == 2 and yDim == 2) {
+                const auto &x = std::get>(lookup(xID));
+                const auto &y = std::get>(lookup(yID));
+
+                std::unique_ptr tensor = std::make_unique(std::move(ct::matmul(*x, *y)));
+                const auto &tensorNode = (*tensor)();
+                insert(tensorID, std::move(tensor));
+
+                return tensorNode;
+            }
+        }
+    } else if (ctopcode == ctop::copy) {
+        uint32_t operand_size = extract(cmd);
+        std::pair copyOperandInfo =
+            getFlushedOperand(cmd);
+        cmd += operand_size;
+
+        const uint64_t &copyID = copyOperandInfo.second;
+        const auto &copy = std::get>(lookup(copyID));
+        std::unique_ptr tensor = std::make_unique(*copy);
+
+        const auto &tensorNode = (*tensor)();
+        insert(tensorID, std::move(tensor));
+        return tensorNode;
+    }
+
+    if (numOperands == 1) {
+        uint32_t operand_size = extract(cmd);
+        std::vector left = process_tensor(cmd);
+        cmd += operand_size;
+        rootNode.left_ = 1;
+        rootNode.right_ = -1;
+        ast.reserve(left.size() + 1);
+        ast.emplace_back(rootNode);
+        std::copy(left.begin(), left.end(), std::back_inserter(ast));
+        for (int i = 1; i != left.size(); ++i) {
+            if (ast[i].left_ != -1) {
+                ast[i].left_ += 1;
+            }
+
+            if (ast[i].right_ != -1) {
+                ast[i].right_ += 1;
+            }
+
+            if (ast[i].ter_ != -1) {
+                ast[i].ter_ += 1;
+            }
+        }
+    } else if (numOperands == 2) {
+        uint32_t operand_size = extract(cmd);
+        std::vector left = process_tensor(cmd);
+        cmd += operand_size;
+        operand_size = extract(cmd);
+        std::vector right = process_tensor(cmd);
+        cmd += operand_size;
+
+        rootNode.left_ = 1;
+        rootNode.right_ = left.size() + 1;
+
+        ast.reserve(left.size() + right.size() + 1);
+        ast.emplace_back(rootNode);
+        std::copy(left.begin(), left.end(), std::back_inserter(ast));
+        std::copy(right.begin(), right.end(), std::back_inserter(ast));
+
+        for (int i = 1; i != left.size(); ++i) {
+            if (ast[i].left_ != -1) {
+                ast[i].left_ += 1;
+            }
+
+            if (ast[i].right_ != -1) {
+                ast[i].right_ += 1;
+            }
+
+            if (ast[i].ter_ != -1) {
+                ast[i].ter_ += 1;
+            }
+        }
+
+        for (int i = 1 + left.size(); i != ast.size(); ++i) {
+            if (ast[i].left_ != -1)
+            {
+                ast[i].left_ += 1 + left.size();
+            }
+
+            if (ast[i].right_ != -1)
+            {
+                ast[i].right_ += 1 + left.size();
+            }
+
+            if (ast[i].ter_ != -1)
+            {
+                ast[i].ter_ += 1 + left.size();
+            }
+        }
+    } else {
+        uint32_t operand_size = extract(cmd);
+        std::vector left = process_tensor(cmd);
+        cmd += operand_size;
+        operand_size = extract(cmd);
+        std::vector right = process_tensor(cmd);
+        cmd += operand_size;
+        operand_size = extract(cmd);
+        std::vector ter = process_tensor(cmd);
+        cmd += operand_size;
+
+        rootNode.left_ = 1;
+        rootNode.right_ = left.size() + 1;
+        rootNode.ter_ = left.size() + right.size() + 1;
+
+        ast.reserve(left.size() + right.size() + ter.size() + 1);
+
+        ast.emplace_back(rootNode);
+        std::copy(left.begin(), left.end(), std::back_inserter(ast));
+        std::copy(right.begin(), right.end(), std::back_inserter(ast));
+        std::copy(ter.begin(), ter.end(), std::back_inserter(ast));
+
+        for (int i = 1; i != left.size(); ++i) {
+            if (ast[i].left_ != -1)
+                ast[i].left_ += 1;
+
+            if (ast[i].right_ != -1)
+                ast[i].right_ += 1;
+
+            if (ast[i].ter_ != -1)
+                ast[i].ter_ += 1;
+        }
+
+        for (int i = 1 + left.size(); i != left.size() + right.size(); ++i) {
+            if (ast[i].left_ != -1)
+                ast[i].left_ += 1 + left.size();
+
+            if (ast[i].right_ != -1)
+                ast[i].right_ += 1 + left.size();
+
+            if (ast[i].ter_ != -1)
+                ast[i].ter_ += 1 + left.size();
+        }
+
+        for (int i = 1 + left.size() + right.size(); i != ast.size(); ++i) {
+            if (ast[i].left_ != -1)
+                ast[i].left_ += 1 + left.size() + right.size();
+
+            if (ast[i].right_ != -1)
+                ast[i].right_ += 1 + left.size() + right.size();
+
+            if (ast[i].ter_ != -1)
+                ast[i].ter_ += 1 + left.size() + right.size();
+        }
+    }
+
+    if (store or flush) {
+        std::unique_ptr tensor = std::make_unique(ast);
+        const auto &tensorNode = (*tensor)();
+        insert(tensorID, std::move(tensor));
+        return tensorNode;
+    }
+    return ast;
 }
diff --git a/src/server.ci b/src/server.ci
index 4ebb75b..4f12dfa 100644
--- a/src/server.ci
+++ b/src/server.ci
@@ -1,11 +1,36 @@
 mainmodule server {
     extern module libcharmtyles;
-    PUPable pow_t;
-    PUPable log_t;
-    PUPable exp_t;
-    PUPable abs_t;
-
+    // Register all basic unary operators
+    PUPable ct::negate_op;
+    PUPable ct::abs_op;
+    PUPable ct::square_op;
+    PUPable ct::sqrt_op;
+    PUPable ct::reciprocal_op;
+    PUPable ct::sin_op;
+    PUPable ct::cos_op;
+    PUPable ct::log_op;
+    PUPable ct::exp_op;
+    PUPable ct::scale_op;
+    PUPable ct::add_constant_op;
+    PUPable ct::relu_op;
+
+    // Register all basic binary operators
+    PUPable ct::add_op;
+    PUPable ct::subtract_op;
+    PUPable ct::multiply_op;
+    PUPable ct::divide_op;
+    PUPable ct::power_op;
+    PUPable ct::modulo_op;
+    PUPable ct::max_op;
+    PUPable ct::min_op;
+    PUPable ct::greater_than_op;
+    PUPable ct::less_than_op;
+    PUPable ct::equal_op;
+    PUPable ct::atan2_op;
+    PUPable ct::weighted_average_op;
+    PUPable ct::axpy_op;
+
     mainchare Main {
         entry Main(CkArgMsg*);
diff --git a/src/server.cpp b/src/server.cpp
index 9e88a39..6380d4e 100644
--- a/src/server.cpp
+++ b/src/server.cpp
@@ -1,4 +1,3 @@
-#include 
 #include "server.hpp"
 #include "converse.h"
 #include "conv-ccs.h"
@@ -50,7 +49,7 @@ void connection_handler(char *msg)
 
 void disconnection_handler(char *msg)
 {
-    CkExit();
+    ct::sync();
     char *cmd = msg + CmiMsgHeaderSizeBytes;
     int epoch = extract(cmd);
     uint32_t size = extract(cmd);
@@ -106,7 +105,6 @@ void Main::handle_command(int epoch, uint8_t kind, uint32_t size, char *cmd)
     while (!command_buffer.empty() && std::get<0>(command_buffer.top()) == EPOCH)
     {
         buffer_t buffer = command_buffer.top();
-        // CkPrintf("Executing buffered at epoch %i, current %i\n", std::get<0>(buffer), EPOCH);
         execute_command(std::get<0>(buffer), std::get<1>(buffer), (int)size, std::get<2>(buffer));
         free(std::get<2>(buffer));
         command_buffer.pop();
@@ -131,23 +129,20 @@ void Main::send_reply(int epoch, int size, char *msg)
     server.reply_buffer.erase(epoch);
 }
 
-void Main::execute_operation(int epoch, int size, char *cmd)
-{
-    // first delete arrays
+void Main::execute_operation(int epoch, int size, char *cmd) {
     uint32_t num_deletions = extract(cmd);
-    // CkPrintf("Num deletions = %u\n", num_deletions);
-    CkPrintf("Memory usage before delete is %f MB\n", CmiMemoryUsage() / (1024. * 1024.));
     for (int i = 0; i < num_deletions; i++)
-    {
-        ct_name_t name = extract(cmd);
-        Server::remove(name);
-    }
-    CkPrintf("Memory usage after %u deletions is %f MB\n", num_deletions, CmiMemoryUsage() / (1024. * 1024.));
-
-    astnode *head = decode(cmd);
-    std::vector metadata;
-    calculate(head, metadata);
-    delete_ast(head);
+        remove(extract(cmd));
+    uint32_t num_deferred_deletions = extract(cmd);
+    std::vector deferred_deletions; deferred_deletions.reserve(num_deferred_deletions);
+    for (int i = 0; i < num_deferred_deletions; i++)
+        deferred_deletions.emplace_back(extract(cmd));
+    char* dimPos = cmd + sizeof(uint8_t);
+    if (peek(cmd) == 1) process_scalar(cmd);
+    else if (peek(dimPos) == 1) process_tensor(cmd);
+    else if (peek(dimPos) == 2) process_tensor(cmd);
+    for (const auto& it : deferred_deletions)
+        remove(it);
 }
 
 void Main::execute_command(int epoch, uint8_t kind, int size, char *cmd)
@@ -203,8 +198,7 @@ void Main::execute_creation(int epoch, int size, char *cmd)
     {
     case 0:
     {
-        // create scalar
-        CmiAbort("Not implemented");
+        CmiAbort("Scalars can only be made through reduction ops and matmuls");
    }
     case 1:
     {
@@ -214,18 +208,18 @@ void Main::execute_creation(int epoch, int size, char *cmd)
         if (has_buf)
         {
             double *init_buf = (double *)cmd;
-            res = ct::from_vector(init_buf, size);
+            res = ct::from_vector_unique(init_buf, size);
         }
         else if (has_init)
         {
             double init_value = extract(cmd);
-            res = ct::vector(size, init_value);
+            res = std::make_unique(size, init_value);
         }
         else
         {
-            res = ct::vector(size);
+            res = std::make_unique(size);
         }
-        Server::insert(res_name, std::move(res));
+        insert(res_name, std::move(res));
         break;
     }
     case 2:
@@ -237,23 +231,22 @@ void Main::execute_creation(int epoch, int size, char *cmd)
         if (has_buf)
         {
             double *init_buf = (double *)cmd;
-            res = ct::from_matrix(init_buf, size1, size2);
+            res = ct::from_matrix_unique(init_buf, size1, size2);
         }
         else if (has_init)
         {
             double init_value = extract(cmd);
-            res = ct::matrix(size1, size2, init_value);
+            res = std::make_unique(size1, size2, init_value);
         }
         else
        {
-            res = ct::matrix(size1, size2);
+            res = std::make_unique(size1, size2);
         }
-        Server::insert(res_name, std::move(res));
+        insert(res_name, std::move(res));
         break;
     }
     default:
     {
-        // FIXME is this correctly caught?
CmiAbort("Greater than 2 dimensions not supported"); } } @@ -262,50 +255,48 @@ void Main::execute_creation(int epoch, int size, char *cmd) void Main::execute_fetch(int epoch, int size, char *cmd) { ct_name_t name = extract(cmd); - ct_array_t &arr = Server::lookup(name); + ct_array_t &arr = lookup(name); char *reply = nullptr; int reply_size = 0; std::visit( - [&](auto &x) + [&](auto &x) + { + using T = std::decay_t; + if constexpr (std::is_same_v) + { + reply = (char *)&x; + reply_size += 8; + send_reply(epoch, reply_size, reply); + } + else if constexpr (std::is_same_v>) + { + std::vector values = x->get(); + reply = (char *)values.data(); + reply_size += values.size() * sizeof(double); + send_reply(epoch, reply_size, reply); + } + else if constexpr (std::is_same_v>) { - using T = std::decay_t; - if constexpr (std::is_same_v) - { - double value = x.get(); - reply = (char *)&value; - reply_size += 8; - send_reply(epoch, reply_size, reply); - // CcsSendReply(reply_size, reply); - } - else if constexpr (std::is_same_v) - { - std::vector values = x.get(); - reply = (char *)values.data(); - reply_size += values.size() * sizeof(double); - send_reply(epoch, reply_size, reply); - } - else if constexpr (std::is_same_v) - { - std::vector> values = x.get(); - std::vector flat; - for (const auto &row : values) - flat.insert(flat.end(), row.begin(), row.end()); - reply = reinterpret_cast(flat.data()); - reply_size += flat.size() * sizeof(double); - send_reply(epoch, reply_size, reply); - } - }, - arr); + std::vector> values = x->get(); + std::vector flat; + for (const auto &row : values) + flat.insert(flat.end(), row.begin(), row.end()); + reply = reinterpret_cast(flat.data()); + reply_size += flat.size() * sizeof(double); + send_reply(epoch, reply_size, reply); + } + }, + arr); } void Main::execute_delete(int epoch, int size, char *cmd) { uint32_t num_deletions = extract(cmd); for (int i = 0; i < num_deletions; i++) - { - ct_name_t name = extract(cmd); - Server::remove(name); - } + remove(extract(cmd)); + uint32_t num_deferred_deletions = extract(cmd); + for (int i = 0; i < num_deferred_deletions; i++) + remove(extract(cmd)); } void Main::execute_disconnect(int epoch, int size, char *cmd) @@ -319,6 +310,7 @@ void Main::execute_disconnect(int epoch, int size, char *cmd) void Main::execute_sync(int epoch, int size, char *cmd) { + ct::sync(); CkPrintf("Execution time = %f\n", CkTimer() - start_time); bool r = true; send_reply(epoch, 1, (char *)&r); diff --git a/src/server.hpp b/src/server.hpp index 211ec84..f4f7bb2 100644 --- a/src/server.hpp +++ b/src/server.hpp @@ -1,24 +1,12 @@ -#include -#include -#include -#include -#include -#include -#include #include "ast.hpp" #include "server.decl.h" -using ct_name_t = uint64_t; -using ct_array_t = std::variant; using buffer_t = std::tuple; -std::unordered_map symbol_table; std::stack client_ids; CProxy_Main main_proxy; -ct_array_t calculate(astnode *node, std::vector &metadata); - enum class opkind : uint8_t { creation = 0, @@ -58,134 +46,6 @@ class Main : public CBase_Main void execute_sync(int epoch, int size, char *cmd); }; -class pow_t : public ct::unary_operator -{ -public: - pow_t(double arg) : arg_(arg) {} - ~pow_t() {} - - using ct::unary_operator::unary_operator; - - inline double operator()(std::size_t index, double value) override final - { - return std::pow(value, arg_); - } - - inline double operator()(std::size_t rows, std::size_t cols, double value) override final - { - return std::pow(value, arg_); - } - - PUPable_decl(pow_t); - 
pow_t(CkMigrateMessage *m) - : ct::unary_operator(m) - { - } - - void pup(PUP::er &p) final - { - p | arg_; - ct::unary_operator::pup(p); - } - -private: - double arg_; -}; - -class log_t : public ct::unary_operator -{ -public: - log_t(double arg) : arg_(arg) {} - ~log_t() {} - - using ct::unary_operator::unary_operator; - - inline double operator()(std::size_t index, double value) override final - { - return std::log(value) / std::log(arg_); - } - - inline double operator()(std::size_t rows, std::size_t cols, double value) override final - { - return std::log(value) / std::log(arg_); - } - - PUPable_decl(log_t); - log_t(CkMigrateMessage *m) - : ct::unary_operator(m) - { - } - - void pup(PUP::er &p) final - { - p | arg_; - ct::unary_operator::pup(p); - } - -private: - double arg_; -}; - -class exp_t : public ct::unary_operator -{ -public: - exp_t() = default; - ~exp_t() {} - - using ct::unary_operator::unary_operator; - - inline double operator()(std::size_t index, double value) override final - { - return std::exp(value); - } - - inline double operator()(std::size_t rows, std::size_t cols, double value) override final - { - return std::exp(value); - } - - PUPable_decl(exp_t); - exp_t(CkMigrateMessage *m) - : ct::unary_operator(m) - { - } - - void pup(PUP::er &p) final - { - ct::unary_operator::pup(p); - } -}; - -class abs_t : public ct::unary_operator -{ -public: - abs_t() = default; - ~abs_t() {} - - using ct::unary_operator::unary_operator; - - inline double operator()(std::size_t index, double value) override final - { - return std::abs(value); - } - - inline double operator()(std::size_t rows, std::size_t cols, double value) override final - { - return std::abs(value); - } - - PUPable_decl(abs_t); - abs_t(CkMigrateMessage *m) - : ct::unary_operator(m) - { - } - - void pup(PUP::er &p) final - { - ct::unary_operator::pup(p); - } -}; - class Server { public: @@ -197,22 +57,6 @@ class Server client_ids.push((uint8_t)i); } - inline static void insert(ct_name_t name, ct_array_t arr) - { -#ifndef NDEBUG - CkPrintf("Created array %" PRIu64 " on server\n", name); -#endif - symbol_table[name] = arr; - } - - inline static void remove(ct_name_t name) - { - symbol_table.erase(name); -#ifndef NDEBUG - CkPrintf("Deleted array %" PRIu64 " on server\n", name); -#endif - } - inline static uint8_t get_client_id() { if (client_ids.empty()) @@ -221,935 +65,6 @@ class Server client_ids.pop(); return client_id; } - - static ct_array_t &lookup(ct_name_t name) - { - auto find = symbol_table.find(name); - if (find == std::end(symbol_table)) - { -#ifndef NDEBUG - CkPrintf("Active symbols: "); - for (auto it : symbol_table) - CkPrintf("%" PRIu64 ", ", it.first); - CkPrintf("\n"); -#endif - CmiAbort("Symbol %i not found", name); - } - return find->second; - } -}; - -ct_array_t calculate(astnode *node, std::vector &metadata) -{ - switch (node->oper) - { - case operation::noop: - { - if (node->is_scalar) - return *reinterpret_cast(&(node->name)); - else - return Server::lookup(node->name); - } - case operation::add: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr ((std::is_same_v && std::is_same_v)) - { - res = x.get() + y.get(); - } - else if constexpr ((std::is_same_v)) - { - res = x.get() + y; - } - else if constexpr ((std::is_same_v)) - { - res = x + y.get(); - } - else if constexpr ((std::is_same_v && 
std::is_same_v) || - (std::is_same_v && std::is_same_v)) - { - CkAbort("Vector + Matrix operations not supported"); - } - else - { - // Everything else should work with the normal + operator - res = x + y; - } - }, - s1, s2); - - if (node->store) - Server::insert(node->name, res); - return res; - } - case operation::sub: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr ((std::is_same_v && std::is_same_v)) - { - res = x.get() - y.get(); - } - else if constexpr ((std::is_same_v)) - { - res = x.get() - y; - } - else if constexpr ((std::is_same_v)) - { - res = x - y.get(); - } - else if constexpr ((std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)) - { - CkAbort("Vector + Matrix operations not supported"); - } - else - { - // Everything else should work with the normal + operator - res = x - y; - } - }, - s1, s2); - - if (node->store) - { - Server::insert(node->name, res); - } - return res; - } - case operation::mul: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr ((std::is_same_v && std::is_same_v)) - { - res = x.get() * y.get(); - } - else if constexpr ((std::is_same_v)) - { - res = x.get() * y; - } - else if constexpr ((std::is_same_v)) - { - res = x * y.get(); - } - else if constexpr ((std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)) - { - CkAbort("Vector + Matrix operations not supported"); - } - else - { - res = x * y; - } - }, - s1, s2); - - if (node->store) - { - Server::insert(node->name, res); - } - return res; - } - case operation::div: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr ((std::is_same_v && std::is_same_v)) - { - res = x.get() / y.get(); - } - else if constexpr ((std::is_same_v)) - { - res = x.get() / y; - } - else if constexpr ((std::is_same_v)) - { - res = x / y.get(); - } - else if constexpr ((std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)) - { - CkAbort("Vector + Matrix operations not supported"); - } - else - { - res = x / y; - } - }, - s1, s2); - - if (node->store) - { - Server::insert(node->name, res); - } - return res; - } - case operation::matmul: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr (std::is_same_v && - std::is_same_v) - CmiAbort("Matrix multiplication not yet implemented"); - else if constexpr (std::is_same_v && - std::is_same_v) - res = ct::dot(x, y); - else if constexpr ((std::is_same_v || - std::is_same_v) && - std::is_same_v) - res = ct::dot(x, y); - else - CmiAbort("Operation not permitted5"); - }, - s1, s2); - - if (node->store) - { - Server::insert(node->name, res); - } - return res; - } - case operation::copy: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t res; - - std::visit( - [&](auto &x) - { - using T = std::decay_t; - if constexpr 
(std::is_same_v || - std::is_same_v || std::is_same_v) - res = x; - else - CmiAbort("Matrix copy not implemented"); - }, - s1); - - if (node->store) - { - Server::insert(node->name, res); - } - return res; - } - case operation::axpy: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t s3 = calculate(node->operands[2], metadata); - ct_array_t res; - - std::visit( - [&](auto &a, auto &x, auto &y) - { - using S = std::decay_t; - using T = std::decay_t; - using V = std::decay_t; - if constexpr (std::is_same_v && - std::is_same_v && - std::is_same_v) - res = ct::axpy(a, x, y); - else - CmiAbort("Operation not permitted6"); - }, - s1, s2, s3); - - if (node->store) - { - Server::insert(node->name, res); - } - return res; - } - - case operation::where: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t s3 = calculate(node->operands[2], metadata); - ct_array_t res; - - std::visit( - [&](auto &a, auto &x, auto &y) - { - using S = std::decay_t; - using T = std::decay_t; - using V = std::decay_t; - if constexpr (std::is_same_v && - std::is_same_v && - std::is_same_v || - std::is_same_v && - std::is_same_v && - std::is_same_v) - res = ct::where(a, x, y); - else - CmiAbort("All where operations must be of the same type"); - }, - s1, s2, s3); - - if (node->store) - { - Server::insert(node->name, res); - } - return res; - } - - case operation::pow: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t res; - std::visit( - [&](auto &a) - { - using T = std::decay_t; - if constexpr (std::is_same_v) - { - std::shared_ptr pow_ = std::make_shared(node->arg); - ct::vector vec = ct::unary_expr(a, pow_); - res = ct_array_t{vec}; - } - else if constexpr (std::is_same_v) - { - std::shared_ptr pow_ = std::make_shared(node->arg); - ct::matrix mat = ct::unary_expr(a, pow_); - res = ct_array_t{mat}; - } - else if constexpr (std::is_same_v) - { - res = std::pow(a.get(), node->arg); - } - else if constexpr (std::is_same_v) - { - res = std::pow(a, node->arg); - } - }, - s1); - if (node->store) - { - Server::insert(node->name, res); - } - return res; - } - - case operation::log: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t res; - std::visit( - [&](auto &a) - { - using T = std::decay_t; - if constexpr (std::is_same_v) - { - std::shared_ptr log_ = std::make_shared(node->arg); - ct::vector vec = ct::unary_expr(a, log_); - res = ct_array_t{vec}; - } - else if constexpr (std::is_same_v) - { - std::shared_ptr log_ = std::make_shared(node->arg); - ct::matrix mat = ct::unary_expr(a, log_); - res = ct_array_t{mat}; - } - else if constexpr (std::is_same_v) - { - res = std::log(a.get()) / std::log(node->arg); - } - else if constexpr (std::is_same_v) - { - res = std::log(a) / std::log(node->arg); - } - }, - s1); - if (node->store) - { - Server::insert(node->name, res); - } - return res; - } - - case operation::exp: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t res; - std::visit( - [&](auto &a) - { - using T = std::decay_t; - if constexpr (std::is_same_v) - { - std::shared_ptr exp_ = std::make_shared(); - ct::vector vec = ct::unary_expr(a, exp_); - res = ct_array_t{vec}; - } - else if constexpr (std::is_same_v) - { - std::shared_ptr exp_ = std::make_shared(); - ct::matrix mat = ct::unary_expr(a, exp_); - res = ct_array_t{mat}; - } - else if constexpr (std::is_same_v) - { - res = std::exp(a.get()); 
- } - else if constexpr (std::is_same_v) - { - res = std::exp(a); - } - }, - s1); - if (node->store) - { - Server::insert(node->name, res); - } - return res; - } - - case operation::abs: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t res; - std::visit( - [&](auto &a) - { - using T = std::decay_t; - if constexpr (std::is_same_v) - { - std::shared_ptr abs_ = std::make_shared(); - ct::vector vec = ct::unary_expr(a, abs_); - res = ct_array_t{vec}; - } - else if constexpr (std::is_same_v) - { - std::shared_ptr abs_ = std::make_shared(); - ct::matrix mat = ct::unary_expr(a, abs_); - res = ct_array_t{mat}; - } - else if constexpr (std::is_same_v) - { - res = std::abs(a.get()); - } - else if constexpr (std::is_same_v) - { - res = std::abs(a); - } - }, - s1); - if (node->store) - { - Server::insert(node->name, res); - } - return res; - } - - case operation::axpy_multiplier: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t s3 = calculate(node->operands[2], metadata); - ct_array_t multiplier = calculate(node->operands[3], metadata); - ct_array_t res; - - std::visit( - [&](auto &multiplier, auto &a, auto &x, auto &y) - { - using S = std::decay_t; - using T = std::decay_t; - using V = std::decay_t; - using M = std::decay_t; - if constexpr (std::is_same_v && - std::is_same_v && - std::is_same_v && - std::is_same_v) - res = ct::axpy(multiplier * a, x, y); - else - CmiAbort("Operation not permitted7"); - }, - multiplier, s1, s2, s3); - - if (node->store) - Server::insert(node->name, res); - return res; - } - case operation::greater: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr ((std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)) - { - CkAbort("Vector + Matrix operations not supported"); - } - else if constexpr ((std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v)) - { - res = x > y; - } - else if constexpr ((std::is_same_v && std::is_same_v)) - { - res = static_cast(x.get() > y.get()); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x.get() > y); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x > y.get()); - } - else - { - res = static_cast(x > y); - } - }, - s1, s2); - - if (node->store) - Server::insert(node->name, res); - return res; - } - case operation::lesser: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr ((std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)) - { - CkAbort("Vector + Matrix operations not supported"); - } - else if constexpr ((std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v)) - { - res = x < y; - } - else if constexpr ((std::is_same_v && std::is_same_v)) - { - res = static_cast(x.get() < y.get()); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x.get() < y); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x < y.get()); - } - else - { - res = static_cast(x < y); - } - }, - s1, s2); - - if (node->store) - Server::insert(node->name, res); - return res; - } - - case operation::geq: - { - ct_array_t s1 = 
calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr ((std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)) - { - CkAbort("Vector + Matrix operations not supported"); - } - else if constexpr ((std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v)) - { - res = x >= y; - } - else if constexpr ((std::is_same_v && std::is_same_v)) - { - res = static_cast(x.get() >= y.get()); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x.get() >= y); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x >= y.get()); - } - else - { - res = static_cast(x >= y); - } - }, - s1, s2); - - if (node->store) - Server::insert(node->name, res); - return res; - } - - case operation::leq: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr ((std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)) - { - CkAbort("Vector + Matrix operations not supported"); - } - else if constexpr ((std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v)) - { - res = x <= y; - } - else if constexpr ((std::is_same_v && std::is_same_v)) - { - res = static_cast(x.get() <= y.get()); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x.get() <= y); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x <= y.get()); - } - else - { - res = static_cast(x <= y); - } - }, - s1, s2); - - if (node->store) - Server::insert(node->name, res); - return res; - } - - case operation::eq: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr ((std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)) - { - CkAbort("Vector + Matrix operations not supported"); - } - else if constexpr ((std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v)) - { - res = x == y; - } - else if constexpr ((std::is_same_v && std::is_same_v)) - { - res = static_cast(x.get() == y.get()); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x.get() == y); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x == y.get()); - } - else - { - res = static_cast(x == y); - } - }, - s1, s2); - - if (node->store) - Server::insert(node->name, res); - return res; - } - - case operation::neq: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr ((std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)) - { - CkAbort("Vector + Matrix operations not supported"); - } - else if constexpr ((std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v)) - { - res = x != y; - } - else if constexpr ((std::is_same_v && std::is_same_v)) - { - res = static_cast(x.get() != y.get()); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x.get() != y); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x != y.get()); 
- } - else - { - res = static_cast(x != y); - } - }, - s1, s2); - - if (node->store) - Server::insert(node->name, res); - return res; - } - - case operation::logical_and: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr ((std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)) - { - CkAbort("Vector + Matrix operations not supported"); - } - else if constexpr ((std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v)) - { - res = x && y; - } - else if constexpr ((std::is_same_v && std::is_same_v)) - { - res = static_cast(x.get() && y.get()); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x.get() && y); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x && y.get()); - } - else - { - res = static_cast(x && y); - } - }, - s1, s2); - - if (node->store) - Server::insert(node->name, res); - return res; - } - - case operation::logical_or: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t s2 = calculate(node->operands[1], metadata); - ct_array_t res; - - std::visit( - [&](auto &x, auto &y) - { - using T = std::decay_t; - using V = std::decay_t; - if constexpr ((std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)) - { - CkAbort("Vector + Matrix operations not supported"); - } - else if constexpr ((std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v)) - { - res = x || y; - } - else if constexpr ((std::is_same_v && std::is_same_v)) - { - res = static_cast(x.get() || y.get()); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x.get() || y); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x || y.get()); - } - else - { - res = static_cast(x || y); - } - }, - s1, s2); - - if (node->store) - Server::insert(node->name, res); - return res; - } - - case operation::logical_not: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t res; - - std::visit( - [&](auto &x) - { - using T = std::decay_t; - if constexpr ((std::is_same_v || std::is_same_v)) - { - res = !x; - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(!x.get()); - } - else - { - res = static_cast(!x); - } - }, - s1); - - if (node->store) - Server::insert(node->name, res); - return res; - } - - case operation::any: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t res; - - std::visit( - [&](auto &x) - { - using T = std::decay_t; - if constexpr ((std::is_same_v || std::is_same_v)) - { - res = static_cast(x.any()); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x.get()); - } - else - { - res = static_cast(x); - } - }, - s1); - - if (node->store) - Server::insert(node->name, res); - return res; - } - - case operation::all: - { - ct_array_t s1 = calculate(node->operands[0], metadata); - ct_array_t res; - - std::visit( - [&](auto &x) - { - using T = std::decay_t; - if constexpr ((std::is_same_v || std::is_same_v)) - { - res = static_cast(x.all()); - } - else if constexpr ((std::is_same_v)) - { - res = static_cast(x.get()); - } - else - { - res = static_cast(x); - } - }, - s1); - - if (node->store) - Server::insert(node->name, res); - return res; - } - - default: - { - CmiAbort("Operation not implemented8"); - } - } }; #include "server.def.h"
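Note on the AST splicing in process_tensor: each operand's sub-AST arrives as an already-flattened vector, the root of the combined expression goes to index 0, and every child index inside a copied subtree is shifted by the offset at which that subtree lands in the combined vector. The following standalone sketch (not part of the patch; Node, flatten and main are illustrative names only, and it keeps only left/right children) shows the same index-fixup idea on a minimal node type:

// Standalone sketch: splicing two flattened sub-ASTs under a new root and
// shifting the child indices of each subtree by its placement offset.
#include <cstddef>
#include <cstdio>
#include <vector>

struct Node {
    int op;            // opcode, 0 used here for "leaf"
    int left_ = -1;    // index of left child in the flat vector, -1 if none
    int right_ = -1;   // index of right child in the flat vector, -1 if none
};

std::vector<Node> flatten(Node root,
                          const std::vector<Node>& left,
                          const std::vector<Node>& right) {
    std::vector<Node> ast;
    ast.reserve(1 + left.size() + right.size());

    root.left_ = 1;                                        // left subtree starts at index 1
    root.right_ = static_cast<int>(left.size()) + 1;       // right subtree starts after it
    ast.push_back(root);
    ast.insert(ast.end(), left.begin(), left.end());
    ast.insert(ast.end(), right.begin(), right.end());

    // Shift child indices of the left subtree by +1 (its new starting offset).
    for (std::size_t i = 1; i < 1 + left.size(); ++i) {
        if (ast[i].left_ != -1)  ast[i].left_ += 1;
        if (ast[i].right_ != -1) ast[i].right_ += 1;
    }
    // Shift child indices of the right subtree by 1 + left.size().
    const int off = 1 + static_cast<int>(left.size());
    for (std::size_t i = 1 + left.size(); i < ast.size(); ++i) {
        if (ast[i].left_ != -1)  ast[i].left_ += off;
        if (ast[i].right_ != -1) ast[i].right_ += off;
    }
    return ast;
}

int main() {
    std::vector<Node> left  = {{1, 1, -1}, {0}};   // an op whose only child is a leaf
    std::vector<Node> right = {{0}};               // a single leaf

    std::vector<Node> ast = flatten(Node{2}, left, right);
    for (std::size_t i = 0; i < ast.size(); ++i)
        std::printf("%zu: op=%d left=%d right=%d\n", i, ast[i].op, ast[i].left_, ast[i].right_);
    // Expected: node 0 points at 1 and 3; node 1 (left root) now points at 2.
}

Keeping the combined expression in one contiguous vector means nodes refer to their children only by integer offset, so the flattened AST can be handed to the evaluator without any pointer fixups.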
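Note on Main::execute_operation's new control flow: immediate deletions are applied before evaluation, deferred deletions are only collected up front and removed after the expression has been processed, and the leading bytes of the payload decide whether the result is handled as a scalar, a 1-D tensor, or a 2-D tensor. A minimal sketch of that ordering follows; the extract/peek helpers, the byte layout, and execute_operation_sketch are simplified assumptions for illustration, not the actual server protocol:

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

template <typename T>
static T extract(const char*& cmd) {        // read a value and advance the cursor
    T v;
    std::memcpy(&v, cmd, sizeof(T));
    cmd += sizeof(T);
    return v;
}

template <typename T>
static T peek(const char* cmd) {            // read a value without advancing
    T v;
    std::memcpy(&v, cmd, sizeof(T));
    return v;
}

static void execute_operation_sketch(const char* cmd) {
    uint32_t num_deletions = extract<uint32_t>(cmd);
    for (uint32_t i = 0; i < num_deletions; i++)
        std::printf("delete %llu now\n", (unsigned long long)extract<uint64_t>(cmd));

    // Deferred deletions are only collected here; they are applied after the
    // expression has been evaluated.
    uint32_t num_deferred = extract<uint32_t>(cmd);
    std::vector<uint64_t> deferred;
    deferred.reserve(num_deferred);
    for (uint32_t i = 0; i < num_deferred; i++)
        deferred.push_back(extract<uint64_t>(cmd));

    // Assumed layout: first byte is the result kind (1 == scalar), the next
    // byte is the operand dimensionality.
    const char* dimPos = cmd + sizeof(uint8_t);
    if (peek<uint8_t>(cmd) == 1)
        std::printf("process scalar expression\n");
    else if (peek<uint8_t>(dimPos) == 1)
        std::printf("process 1-D (vector) expression\n");
    else if (peek<uint8_t>(dimPos) == 2)
        std::printf("process 2-D (matrix) expression\n");

    for (uint64_t name : deferred)
        std::printf("delete %llu after evaluation\n", (unsigned long long)name);
}

int main() {
    // Hand-built payload: no immediate deletions, one deferred deletion of
    // array 7, then a header requesting a 1-D expression.
    std::vector<char> buf;
    auto put = [&](const void* p, std::size_t n) {
        buf.insert(buf.end(), (const char*)p, (const char*)p + n);
    };
    uint32_t zero = 0, one = 1;
    uint64_t name = 7;
    uint8_t kind = 0, dim = 1;
    put(&zero, 4); put(&one, 4); put(&name, 8); put(&kind, 1); put(&dim, 1);

    const char* cmd = buf.data();
    execute_operation_sketch(cmd);
}

Applying the second batch only after evaluation means arrays that the incoming expression still references are not destroyed until the result has been produced.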