# sine2pi/ASR-model

This model optionally uses any mix of these feature streams (a usage sketch follows the snippet):

    extract_args = {
        "spectrogram": False,
        "pitch": False,
        "waveform": False,
        "harmonics": False,
        "aperiodics": False,
        "phase": False,
        "hilbert": False,
        "pitch_tokens": True,
    }
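For orientation, here is an illustrative sketch of how these flags can be turned into keyword arguments for `Model.forward` below. The dummy tensors and shapes are assumptions, not repo code, and at least one of `spectrogram`, `pitch`, or `waveform` should be enabled because the model uses it as the fallback feature.

    import torch

    # dummy tensors standing in for a real front-end's outputs (shapes are illustrative)
    batch, mels, frames, samples = 2, 128, 300, 48000
    candidates = {
        "spectrogram": torch.randn(batch, mels, frames),
        "pitch": torch.randn(batch, 1, frames),
        "waveform": torch.randn(batch, 1, samples),
        "pitch_tokens": torch.randint(0, 200, (batch, frames)),
    }

    extract_args = {"spectrogram": False, "pitch": True, "waveform": False,
                    "harmonics": False, "aperiodics": False, "phase": False,
                    "hilbert": False, "pitch_tokens": True}

    # keep only the streams that are both enabled and actually produced
    inputs = {k: v for k, v in candidates.items() if extract_args.get(k)}
    # these kwargs line up with Model.forward(input_ids=..., labels=..., **inputs)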
# helpers, constants, and functions

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor, einsum
from torch.nn.functional import scaled_dot_product_attention
from dataclasses import dataclass
from typing import Iterable
from einops.layers.torch import Rearrange
from tensordict import TensorDict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

THETA = 30000.0

def l2norm(t):
    return torch.nn.functional.normalize(t, dim = -1)

def have(a):
    return a is not None
    
def Sequential(*modules):
    return nn.Sequential(*filter(have, modules))    

def aorb(a, b):
    return a if have(a) else b

def no_none(xa):
    # drop feature entries that were never provided so downstream modules only see tensors
    return TensorDict({k: v for k, v in xa.items() if v is not None}, batch_size=xa.batch_size)

@dataclass
class Dimensions:
    tokens: int
    mels: int
    dims: int
    head: int
    layer: int
    act: str
    n_type: str

def get_norm(n_type: str, dims=None, num_groups=None) -> nn.Module:

    if n_type in ["batchnorm", "instancenorm"] and dims is None:
        raise ValueError(f"'{n_type}' requires 'dims'.")
    if n_type == "groupnorm" and num_groups is None:
        raise ValueError(f"'{n_type}' requires 'num_groups'.")

    norm_map = {
        "layernorm": lambda: nn.LayerNorm(normalized_shape=dims, bias=False),
        "instancenorm": lambda: nn.InstanceNorm1d(num_features=dims, affine=False, track_running_stats=False),     
        "rmsnorm": lambda: nn.RMSNorm(normalized_shape=dims),        
        "batchnorm": lambda: nn.BatchNorm1d(num_features=dims),
        "instancenorm2d": lambda: nn.InstanceNorm2d(num_features=dims),
        "groupnorm": lambda: nn.GroupNorm(num_groups=num_groups, num_channels=dims),
        }
   
    norm_func = norm_map.get(n_type)
    if norm_func:
        return norm_func()
    else:
        print(f"Warning: Norm type '{n_type}' not found. Returning LayerNorm.")
        return nn.LayerNorm(dims) 

def get_activation(act: str) -> nn.Module:

    act_map = {
        "gelu": nn.GELU(), 
        "relu": nn.ReLU(), 
        "sigmoid": nn.Sigmoid(), 
        "tanh": nn.Tanh(), 
        "swish": nn.SiLU(), 
        "tanhshrink": nn.Tanhshrink(), 
        "softplus": nn.Softplus(), 
        "softshrink": nn.Softshrink(), 
        "leaky_relu": nn.LeakyReLU(), 
        "elu": nn.ELU(),
    }
    return act_map.get(act, nn.GELU())

def sinusoids(ctx, dims, theta=THETA):
    # fixed sin/cos positional embeddings; theta (possibly a learned tensor) sets the longest wavelength
    theta = theta if torch.is_tensor(theta) else torch.tensor(float(theta), device=device)
    tscales = torch.exp(-torch.log(theta) / (dims // 2 - 1) * torch.arange(dims // 2, device=device, dtype=torch.float32))
    scaled = torch.arange(ctx, device=device, dtype=torch.float32).unsqueeze(1) * tscales.unsqueeze(0)
    return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1)
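A quick shape check (illustrative values; relies on the `device` global defined above):

    pe = sinusoids(ctx=100, dims=512)
    print(pe.shape)   # torch.Size([100, 512]) -> added to encoder features of the same shape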

###### Model components

class AudioEncoder(nn.Module):
    def __init__(n, mels, dims, head, act, n_type, norm=False, enc=False):
        super().__init__()

        act_fn = get_activation(act)
        # conv1 ingests multi-channel features (e.g. mel bins); conv2 ingests single-channel streams
        n.conv1 = nn.Conv1d(mels, dims, kernel_size=3, stride=1, padding=1)
        n.conv2 = nn.Conv1d(1, dims, kernel_size=3, stride=1, padding=1)

        n.encoder = nn.Sequential(
            act_fn, nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1),
            act_fn, nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

        # register theta on the module so the positional base frequency is actually trainable
        n.theta = nn.Parameter(torch.tensor(THETA), requires_grad=True)
        n.audio = lambda length, dims: sinusoids(length, dims, n.theta)
        n.EncoderLayer = nn.TransformerEncoderLayer(d_model=dims, nhead=head, batch_first=True) if enc else nn.Identity()
        n.ln = get_norm(n_type, dims) if norm else nn.Identity()

    def _process_feature(n, x):
        if x.dim() == 2:
            x = x.unsqueeze(0)        # treat a 2-D input as a single unbatched (channels, time) feature
        if x.shape[1] > 1:            # multi-channel input (e.g. mel spectrogram)
            x = n.conv1(x)
        else:                         # single-channel stream (e.g. pitch contour, waveform)
            x = n.conv2(x)
        x = n.encoder(x).permute(0, 2, 1).contiguous().to(device, dtype)
        x = x + n.audio(x.shape[1], x.shape[-1]).to(device, dtype)
        x = n.ln(x)
        return n.EncoderLayer(x)
               
    def forward(n, x):
        if isinstance(x, TensorDict):
            return x.apply(n._process_feature)
        else:
            return n._process_feature(x)
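A minimal smoke test for the encoder (hyperparameter values are illustrative):

    enc = AudioEncoder(mels=128, dims=512, head=4, act="gelu", n_type="rmsnorm").to(device)
    spec = torch.randn(2, 128, 300, device=device)   # (batch, mels, frames)
    out = enc(spec)
    print(out.shape)                                  # torch.Size([2, 300, 512])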

class rotary(nn.Module):
    def __init__(n, dims, head):
        super().__init__()

        n.head_dim = dims // head
        n._head = nn.Linear(dims, 1) 

    def _compute_freqs(n, mask=None):
        if mask is None:
            # mel-spaced rotation frequencies (roughly 0 to 4 rad/step across the half head dim)
            scale = torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 4000/200)), n.head_dim // 2, device=device, dtype=dtype) / 2595) - 1
            return 200 * scale / 1000
        else:
            # standard rotary inverse frequencies: THETA ** (-2i / head_dim)
            return torch.exp(-torch.arange(0, n.head_dim, 2, device=device, dtype=dtype) / n.head_dim * torch.log(torch.tensor(THETA)))

    def radius(n, xa, freqs, mask=None):
        if xa is not None:
            per_step = n._head(xa).squeeze(-1)
            radius = torch.clamp(per_step / 100.0, 0.5, 2.0)
            return torch.polar(radius.unsqueeze(-1), freqs.unsqueeze(0))
        else:
            return torch.polar(torch.ones_like(freqs).unsqueeze(0), freqs.unsqueeze(0))

    def forward(n, x=None, xa=None, mask=None):
        ctx = x.shape[2]
        t = torch.arange(ctx, device=device, dtype=dtype).float()
        freqs = torch.einsum('i,j->ij', t,  n._compute_freqs(mask))
        freqs = n.radius(xa, freqs, mask)

        x1 = x[..., :freqs.shape[-1]*2]
        x2 = x[..., freqs.shape[-1]*2:]
        orig_shape = x1.shape
        x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
        x1 = torch.view_as_complex(x1) * freqs
        x1 = torch.view_as_real(x1).flatten(-2)
        x1 = x1.view(orig_shape)
        rotary_out = torch.cat([x1.type_as(x), x2], dim=-1)
        return rotary_out
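`rotary` expects per-head tensors, such as the outputs of the `Rearrange('b c (h d) -> b h c d')` projections used below (illustrative values):

    rot = rotary(dims=512, head=4).to(device)
    q = torch.randn(2, 4, 100, 128, device=device)   # (batch, heads, ctx, head_dim)
    print(rot(q).shape)                               # torch.Size([2, 4, 100, 128]), rotated pairwise along head_dim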
    
class attention(nn.Module):
    def __init__(n, dims, head, layer, n_type=None, modal=False, pattern=None): 
        super().__init__()
        n.layer = layer

        n.scale = (dims // head) ** -0.25
        n.modal = modal

        n.q   = nn.Sequential(get_norm(n_type, dims), nn.Linear(dims, dims), Rearrange('b c (h d) -> b h c d', h = head))
        n.kv  = nn.Sequential(get_norm(n_type, dims), nn.Linear(dims, dims * 2), Rearrange('b c (kv h d) -> kv b h c d', kv = 2, h = head))
        n.out = nn.Sequential(Rearrange('b h c d -> b c (h d)'), nn.Linear(dims, dims))

        n.conv = nn.Conv2d(head, head, 1, bias=False) if modal else nn.Identity()
        n.ln = get_norm(n_type, dims // head)
        n.rot = rotary(dims, head)

    def forward(n, x, xa=None, mask=None, pt=None, skip=False, pattern=None):
        p = pattern

        k, v = n.kv(x if xa is None else xa)
        q = n.q(x)
        q, k = n.rot(q, xa=None, mask=mask), n.rot(k, xa=None, mask=mask)

        if skip and not have(p):
            # deeper layers keep more keys/values: the stride shrinks as n.layer grows
            step = max(1, 6 - n.layer)
            a = scaled_dot_product_attention(n.ln(q),
                n.ln(k[:, :, ::step, :]), v[:, :, ::step, :], is_causal=mask is not None)
        elif have(p) and p > 1:
            # fixed subsampling pattern over keys/values
            k, v = k[:, :, ::p, :], v[:, :, ::p, :]
            a = scaled_dot_product_attention(n.ln(q), n.ln(k), v, is_causal=mask is not None)
        else:
            a = scaled_dot_product_attention(n.ln(q), n.ln(k), v, is_causal=mask is not None)

        if n.modal and xa is not None:
            # cross-modal pass: each stream's queries attend to the other stream's keys/values
            qa, qb = n.q(x), n.q(xa)
            (ka, va), (kb, vb) = n.kv(x), n.kv(xa)
            qa, qb, ka, kb = n.rot(qa), n.rot(qb), n.rot(ka), n.rot(kb)

            if have(p) and p > 1:
                ka, va = ka[:, :, ::p, :], va[:, :, ::p, :]
                kb, vb = kb[:, :, ::p, :], vb[:, :, ::p, :]
            elif skip:
                step = max(1, 6 - n.layer)
                ka, va = ka[:, :, ::step, :], va[:, :, ::step, :]
                kb, vb = kb[:, :, ::step, :], vb[:, :, ::step, :]

            ab = scaled_dot_product_attention(n.ln(qa), n.ln(kb), vb, is_causal=mask is not None)
            ba = scaled_dot_product_attention(n.ln(qb), n.ln(ka), va, is_causal=mask is not None)

            return n.out(a), n.out(n.conv(ab)), n.out(n.conv(ba))
        else:
            return n.out(a)
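A self-attention smoke test (illustrative; `n_type` must be one of the keys handled by `get_norm`):

    attn = attention(dims=512, head=4, layer=0, n_type="rmsnorm").to(device)
    x = torch.randn(2, 50, 512, device=device)
    print(attn(x).shape)   # torch.Size([2, 50, 512])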

class attentionb(nn.Module):
    def __init__(n, dims: int, head: int, layer: int, n_type):
        super().__init__()

        n.q   = nn.Sequential(get_norm(n_type, dims), nn.Linear(dims, dims), Rearrange('b c (h d) -> b h c d', h = head))
        n.c   = nn.Sequential(get_norm(n_type, dims), nn.Linear(dims, dims), Rearrange('b c (h d) -> b h c d', h = head))
        n.kv  = nn.Sequential(get_norm(n_type, dims), nn.Linear(dims, dims * 2), Rearrange('b c (kv h d) -> kv b h c d', kv = 2, h = head))
        n.out = nn.Sequential(Rearrange('b h n d -> b n (h d)'), nn.Linear(dims, dims), nn.Dropout(0.01))
        
        n.ln = get_norm(n_type, dims // head)
        
    def forward_triplet_toy(n, x, xa=None, mask=None, pt=None, context_window=3):
        q = n.q(x)
        k, v = n.kv(aorb(xa, x))
        b, h, seq_len, d = q.shape
        scale = d ** -0.5

        if pt is not None:
            c = n.c(pt)
        else:
            c = torch.zeros_like(q)   # no pitch tokens: the triplet term contributes nothing

        # O(n^3) toy version of a query-key-context triplet score, for illustration only
        triplet_scores = torch.zeros(b, h, seq_len, seq_len, device=device)

        for i in range(seq_len):
            for j in range(seq_len):
                context_start = max(0, min(i, j) - context_window)
                context_end = min(seq_len, max(i, j) + context_window)

                for m in range(context_start, context_end):
                    score = (q[:, :, i, :] * k[:, :, j, :] * c[:, :, m, :]).sum(dim=-1)
                    triplet_scores[:, :, i, j] += score

        qk = einsum('b h k d, b h q d -> b h k q', q, k) * scale + triplet_scores

        if have(mask):
            qk = qk + mask[:seq_len, :seq_len]

        qk = torch.nn.functional.softmax(qk, dim=-1)
        wv = einsum('b h k q, b h q d -> b h k d', qk, v)
        return n.out(wv)

class gate(nn.Module):
    def __init__(n, dims, num_feature):
        super().__init__()

        n.gates = nn.ModuleList([nn.Sequential(nn.Linear(dims, 1), nn.Sigmoid()) for _ in range(num_feature)])
        n.features = nn.Sequential(nn.Linear(dims, num_feature), nn.Softmax(dim=-1))

        n.top = nn.Linear(dims, num_feature)
        n.alpha = nn.Parameter(torch.ones(1), requires_grad=True)

    def forward(n, x, num=None):
        # blend a hard top-k routing with the soft routing, weighted by a learned alpha
        types, indices = torch.topk(n.top(x), num, dim=-1)
        hard = torch.zeros_like(n.features(x))
        hard.scatter_(-1, indices, torch.nn.functional.softmax(types, dim=-1))
        features = torch.sigmoid(n.alpha) * hard + (1 - torch.sigmoid(n.alpha)) * n.features(x)
        return torch.sum(torch.stack([g(x) for g in n.gates], dim=-1) * features.unsqueeze(2), dim=-1)
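The gate returns a single scalar per position that is added (broadcast) onto the feature stream in `residual` (illustrative values):

    g = gate(dims=256, num_feature=4).to(device)
    x = torch.randn(2, 50, 256, device=device)
    print(g(x, num=2).shape)   # torch.Size([2, 50, 1])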

class residual(nn.Module):
    def __init__(n, dims, head, layer, act, n_type, expand=4, skip=False, pattern=None):
        super().__init__()

        n.head = head
        n.skip = skip
        n.ln = get_norm(n_type=n_type, dims=dims)
        n.act_fn = get_activation(act)

        n.attn = attention(dims, head, layer, n_type=n_type)
        n.gate = gate(dims, num_feature=head) if not skip else nn.Identity()
        n.mlp = nn.Sequential(n.ln, nn.Linear(dims, dims*expand), get_activation(act), nn.Linear(dims*expand, dims))

    def forward(n, x, xa=None, mask=None, pt=None, skip=None, pattern=None):

        # self-attention over the incoming stream
        x = x + n.attn(n.ln(x), mask=mask, pt=pt, skip=skip, pattern=pattern)

        if xa is not None:
            # gate the audio features, then cross-attend to them
            xa = (xa + n.gate(xa, n.head // 2)) if not n.skip else xa
            x = x + n.attn(n.ln(x), xa=xa, pt=pt, skip=skip, pattern=pattern)

        return x + n.mlp(x)

class attn_pass(nn.Module):
    def __init__(n, dims, head, layer, act, n_type, skip=None, pattern=None):
        super().__init__()
        
        n.layers = nn.ModuleList()
        for i in range(layer):
            n.layers.append(residual(dims, head, layer, act, n_type, skip=skip and i in skip, pattern=pattern[i] if pattern else None))

    def forward(n, x, override=None):
        for i, layer in enumerate(n.layers):
            # the layer index doubles as the skip flag (falsy for layer 0, truthy afterwards)
            x = layer(x, skip=i, pattern=override[i] if override else None)
        return x

class processor(nn.Module):
    def __init__(n, tokens, mels, dims, head, layer, act, n_type, ctx=2048): 
        super().__init__()
        
        n.ln = get_norm(n_type, dims)

        n.token = nn.Embedding(tokens, dims) 
        n.position = nn.Parameter(torch.ones(ctx, dims), requires_grad=True)
        n.blend = nn.Parameter(torch.tensor(0.5), requires_grad=True)

        n.block: Iterable[residual] = nn.ModuleList(
            [residual(dims=dims, head=head, layer=layer, act=act, n_type=n_type) for _ in range(layer)]) 
        
        n.register_buffer("mask", torch.empty(ctx, ctx).fill_(-np.inf).triu_(1), persistent=False)

    def forward(n, x, xa=None, seq=False) -> Tensor:

        # pull the individual feature streams (and optional pitch tokens) out of the TensorDict
        pt, xb, xc, xa = (xa.pop(k, None) for k in ('pt', 'b', 'c', 'a')) if isinstance(xa, TensorDict) else (None, None, None, None)

        blend = torch.sigmoid(n.blend)
        x = (n.token(x) + n.position[:x.shape[-1]]).to(device, dtype)

        for i in n.block:
            a = i(x, mask=n.mask, pt=pt)       # causal self-attention over the text
            b = i(a, xa=i(xa, pt=pt))          # cross-attend to stream a
            c = i(b, xa=i(xb, pt=pt))          # then stream b
            d = i(c, xa=i(xc, pt=pt))          # then stream c

            for j in [xa, xb, xc]:             # e ends up holding the result for the last stream
                e = i(x, xa=i(j, pt=pt))

            f = torch.cat([d, e], dim=1)
            g = i(x=f[:, :x.shape[1]], xa=f[:, x.shape[1]:])
            x = g if seq else blend * d + (1 - blend) * g

        return (n.ln(x) @ torch.transpose(n.token.weight.to(dtype), 0, 1)).float()

class Model(nn.Module):
    def __init__(n, param: Dimensions):
        super().__init__()

        n.param = param
        n.processor = processor(
            tokens=param.tokens,
            mels=param.mels,
            dims=param.dims,
            head=param.head,
            layer=param.layer,
            act=param.act,
            n_type=param.n_type,
            )

        n.enc = AudioEncoder(param.mels, param.dims, param.head, param.act, param.n_type, norm=False, enc=False)

        n.layer = 0
        for name, module in n.named_modules():
            if name == '':
                continue
            n.layer += 1        

    def forward(n, labels=None, input_ids=None, spectrogram=None, pitch=None, waveform=None,
            harmonics=None, aperiodics=None, phase=None, hilbert=None, pitch_tokens=None):

        # at least one of pitch / spectrogram / waveform must be provided; it backfills the other slots
        fb = next((t for t in (pitch, spectrogram, waveform) if t is not None), None)
        features = {
            'a': aorb(pitch, fb),
            'b': aorb(spectrogram, fb),
            'c': aorb(waveform, fb),
            'd': harmonics,
            'e': aperiodics,
            'f': phase,
            'g': hilbert,
            }
        xa = TensorDict({k: v for k, v in features.items() if v is not None}, batch_size=fb.shape[0])

        x = input_ids
        xa = n.enc(no_none(xa))
        if pitch_tokens is not None:
            xa['pt'] = pitch_tokens

        output = n.processor(x, xa, seq=False)
        loss = None
        if labels is not None:
            loss = torch.nn.functional.cross_entropy(output.view(-1, output.shape[-1]),
                                                    labels.view(-1), ignore_index=0)
        return {"logits": output, "loss": loss}

    def _init_w(n, m):
        # called once per submodule via n.apply: tally the module types seen
        if isinstance(m, nn.RMSNorm):
            n.counts["RMSNorm"] += 1
        elif isinstance(m, nn.LayerNorm):
            n.counts["LayerNorm"] += 1
        elif isinstance(m, nn.Linear):
            n.counts["Linear"] += 1
        elif isinstance(m, nn.Conv1d):
            n.counts["Conv1d"] += 1
        elif isinstance(m, nn.Conv2d):
            n.counts["Conv2d"] += 1
        elif isinstance(m, processor):
            n.counts["processor"] += 1
        elif isinstance(m, attention):
            n.counts["attention"] += 1
        elif isinstance(m, residual):
            n.counts["Residual"] += 1

    def init_w(n):
        print("Initializing model w...")
        n.counts = {"Linear": 0, "Conv1d": 0, "LayerNorm": 0, "RMSNorm": 0, "Conv2d": 0, "processor": 0, "attention": 0, "Residual": 0}
        n.apply(n._init_w)
        print("Initialization summary:")
        for module_type, count in n.counts.items():
            if count > 0:
                print(f"{module_type}: {count}")