Skip to content

Commit 77e39c7

Browse files
committed
Add Einsum
1 parent 559aca4 commit 77e39c7

File tree

9 files changed

+110
-5
lines changed

9 files changed

+110
-5
lines changed

docs/modules/ops.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ API - Operations
116116
diag
117117
mask_select
118118
eye
119+
einsum
119120

120121
TensorLayerX Tensor Operations
121122
--------------------------------
@@ -558,4 +559,8 @@ mask_select
558559

559560
eye
560561
^^^^^^^^^^^^^^^^^^^^^^^
561-
.. autofunction:: eye
562+
.. autofunction:: eye
563+
564+
einsum
565+
^^^^^^^^^^^^^^^^^^^^^^^
566+
.. autofunction:: einsum

examples/basic_tutorials/tutorial_tensorlayerx_graph.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
#! /usr/bin/python
22
# -*- coding: utf-8 -*-
33

4+
import os
5+
# os.environ['TL_BACKEND'] = 'tensorflow'
6+
# os.environ['TL_BACKEND'] = 'mindspore'
7+
os.environ['TL_BACKEND'] = 'torch'
8+
49
import tensorlayerx as tlx
510
from tensorlayerx.nn import Module
611
from tensorlayerx.nn import Linear, Conv2d, BatchNorm2d, MaxPool2d, Flatten

tensorlayerx/backend/ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
from .load_backend import QuanConvBn
8585

8686
# load ops
87+
from .load_backend import einsum
8788
from .load_backend import Variable
8889
from .load_backend import matmul
8990
from .load_backend import add
@@ -245,3 +246,4 @@
245246
from .load_backend import ClipGradByValue
246247
from .load_backend import ClipGradByNorm
247248
from .load_backend import ClipByGlobalNorm
249+
from .load_backend import Einsum

tensorlayerx/backend/ops/mindspore_backend.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1821,3 +1821,21 @@ def eye(n, m = None, dtype = None):
18211821
if dtype is None:
18221822
dtype = mstype.float32
18231823
return ms.numpy.eye(n, m, dtype = dtype)
1824+
1825+
1826+
def einsum(equation, *operands):
1827+
if ms.__version__ < '1.7.0':
1828+
raise NotImplementedError("Only MindSpore versions later than 1.7.0 are supported.")
1829+
einsum_obj = ms.ops.Einsum(equation)
1830+
return einsum_obj(tuple(operands))
1831+
1832+
1833+
class Einsum(Cell):
1834+
def __init__(self, equation):
1835+
super(Einsum, self).__init__()
1836+
if ms.__version__ < '1.7.0':
1837+
raise NotImplementedError("Only MindSpore versions later than 1.7.0 are supported.")
1838+
self.einsum = ms.ops.Einsum(equation)
1839+
1840+
def __call__(self, *args):
1841+
return self.einsum(tuple(args))

tensorlayerx/backend/ops/paddle_backend.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,3 +1834,25 @@ def _apply_mask_1d(reshaped_tensor, mask, axis=None):
18341834

18351835
def eye(n, m=None, dtype=None):
18361836
return paddle.eye(n, m, dtype)
1837+
1838+
1839+
def einsum(equation, *operands):
1840+
try:
1841+
from paddlenlp.ops import einsum
1842+
except:
1843+
raise Exception("Paddlenlp needs to be installed.")
1844+
return einsum(equation, *operands)
1845+
1846+
1847+
class Einsum(object):
1848+
def __init__(self, equation):
1849+
super(Einsum, self).__init__()
1850+
try:
1851+
from paddlenlp.ops import einsum
1852+
except:
1853+
raise Exception("Paddlenlp needs to be installed.")
1854+
self.equation = equation
1855+
1856+
def __call__(self, *args):
1857+
return einsum(self.equation, *args)
1858+

tensorlayerx/backend/ops/tensorflow_backend.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3827,3 +3827,44 @@ def eye(n, m = None, dtype = None):
38273827
dtype = tf.dtypes.float32
38283828
return tf.eye(n, m, dtype = dtype)
38293829

3830+
3831+
def einsum(equation, *operands):
3832+
"""
3833+
Sums the product of the elements of the input operands along dimensions specified
3834+
using a notation based on the Einstein summation convention.
3835+
3836+
Parameters
3837+
----------
3838+
equation : An attribute
3839+
represent the operation you want to do.
3840+
the value can contain only letters([a-z][A-Z]), commas(,), ellipsis(…), and arrow(->).
3841+
the letters represent inputs’s tensor dimension, commas(,)represent separate tensors, ellipsis(…) indicates
3842+
the tensor dimension that you do not care about, the left of the arrow(->) indicates the input tensors,
3843+
and the right of it indicates the desired output dimension.
3844+
3845+
operands : list
3846+
input tensor used for calculation. the data type of the tensor must be the same.
3847+
3848+
Returns
3849+
-------
3850+
Tensor, the shape of it can be obtained from the equation, and the data type is the same as input tensors.
3851+
3852+
Examples
3853+
---------
3854+
>>> import tensorlayerx as tlx
3855+
>>> x = tlx.nn.Input((5,))
3856+
>>> y = tlx.nn.Input((4,))
3857+
>>> out = tlx.ops.einsum('i,j->ij', x, y)
3858+
>>> cal_enisum = tlx.ops.Einsum('i,j->ij')
3859+
>>> out = cal_enisum(x, y)
3860+
"""
3861+
return tf.einsum(equation, *operands)
3862+
3863+
3864+
class Einsum(object):
3865+
def __init__(self, equation):
3866+
super(Einsum, self).__init__()
3867+
self.equation = equation
3868+
3869+
def __call__(self, *args):
3870+
return tf.einsum(self.equation, *args)

tensorlayerx/backend/ops/torch_backend.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1666,4 +1666,17 @@ def mask_select(x, mask, axis = 0):
16661666
return x[:,:,:, mask]
16671667

16681668
def eye(n, m=None, dtype=None):
1669-
return torch.eye(n = n, m = m, dtype =dtype)
1669+
return torch.eye(n = n, m = m, dtype =dtype)
1670+
1671+
1672+
def einsum(equation, *operands):
1673+
return torch.einsum(equation, *operands)
1674+
1675+
1676+
class Einsum(object):
1677+
def __init__(self, equation):
1678+
super(Einsum, self).__init__()
1679+
self.equation = equation
1680+
1681+
def __call__(self, *args):
1682+
return torch.einsum(self.equation, *args)

tensorlayerx/nn/core/core_mindspore.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,7 @@ def init_build(self, *inputs, **kwargs):
287287

288288
def build_graph(self, *inputs, **kwargs):
289289
# Add nodes only when the composition is needed.
290-
layers = self.cells_and_names(name_prefix='')
291-
for layer_name, layer in layers:
290+
for layer_name, layer in self._cells.items():
292291
if isinstance(layer, Module):
293292
layer._build_graph = True
294293
self.set_eval()

tensorlayerx/nn/core/core_torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def init_build(self, *inputs, **kwargs):
174174

175175
def build_graph(self, *inputs, **kwargs):
176176
# Add nodes only when the composition is needed.
177-
for name, layer in self.named_modules():
177+
for name, layer in self._modules.items():
178178
if isinstance(layer, Module):
179179
layer._build_graph = True
180180
self.set_eval()

0 commit comments

Comments
 (0)