@@ -134,17 +134,20 @@ constexpr auto& retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& exp
134134 static_assert (has_tensor_types_v<T,binary_tensor_expression<T,EL,ER,OP>>,
135135 " Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors." );
136136
137+ auto const & lexpr = expr.left_expr ();
138+ auto const & rexpr = expr.right_expr ();
139+
137140 if constexpr ( same_exp<T,EL> )
138- return expr. el .extents ();
141+ return lexpr .extents ();
139142
140143 else if constexpr ( same_exp<T,ER> )
141- return expr. er .extents ();
144+ return rexpr .extents ();
142145
143146 else if constexpr ( has_tensor_types_v<T,EL> )
144- return retrieve_extents (expr. el );
147+ return retrieve_extents (lexpr );
145148
146149 else if constexpr ( has_tensor_types_v<T,ER> )
147- return retrieve_extents (expr. er );
150+ return retrieve_extents (rexpr );
148151}
149152
150153#ifdef _MSC_VER
@@ -164,12 +167,14 @@ constexpr auto& retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
164167
165168 static_assert (has_tensor_types_v<T,unary_tensor_expression<T,E,OP>>,
166169 " Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors." );
170+
171+ auto const & uexpr = expr.expr ();
167172
168173 if constexpr ( same_exp<T,E> )
169- return expr. e .extents ();
174+ return uexpr .extents ();
170175
171176 else if constexpr ( has_tensor_types_v<T,E> )
172- return retrieve_extents (expr. e );
177+ return retrieve_extents (uexpr );
173178}
174179
175180} // namespace boost::numeric::ublas::detail
@@ -221,20 +226,23 @@ constexpr auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& exp
221226 using ::operator ==;
222227 using ::operator !=;
223228
229+ auto const & lexpr = expr.left_expr ();
230+ auto const & rexpr = expr.right_expr ();
231+
224232 if constexpr ( same_exp<T,EL> )
225- if (e != expr. el .extents ())
233+ if (e != lexpr .extents ())
226234 return false ;
227235
228236 if constexpr ( same_exp<T,ER> )
229- if (e != expr. er .extents ())
237+ if (e != rexpr .extents ())
230238 return false ;
231239
232240 if constexpr ( has_tensor_types_v<T,EL> )
233- if (!all_extents_equal (expr. el , e))
241+ if (!all_extents_equal (lexpr , e))
234242 return false ;
235243
236244 if constexpr ( has_tensor_types_v<T,ER> )
237- if (!all_extents_equal (expr. er , e))
245+ if (!all_extents_equal (rexpr , e))
238246 return false ;
239247
240248 return true ;
@@ -250,12 +258,14 @@ constexpr auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, ex
250258
251259 using ::operator ==;
252260
261+ auto const & uexpr = expr.expr ();
262+
253263 if constexpr ( same_exp<T,E> )
254- if (e != expr. e .extents ())
264+ if (e != uexpr .extents ())
255265 return false ;
256266
257267 if constexpr ( has_tensor_types_v<T,E> )
258- if (!all_extents_equal (expr. e , e))
268+ if (!all_extents_equal (uexpr , e))
259269 return false ;
260270
261271 return true ;
@@ -281,9 +291,11 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
281291 if (!all_extents_equal (expr, lhs.extents () ))
282292 throw std::runtime_error (" Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes." );
283293
284- #pragma omp parallel for
294+ auto const & rhs = cast_tensor_exression (expr);
295+
296+ #pragma omp parallel for
285297 for (auto i = 0u ; i < lhs.size (); ++i)
286- lhs (i) = expr () (i);
298+ lhs (i) = rhs (i);
287299}
288300
289301/* * @brief Evaluates expression for a tensor_core
@@ -310,9 +322,11 @@ inline void eval(tensor_type& lhs, tensor_expression<other_tensor_type, derived_
310322 throw std::runtime_error (" Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes." );
311323 }
312324
325+ auto const & rhs = cast_tensor_exression (expr);
326+
313327 #pragma omp parallel for
314328 for (auto i = 0u ; i < lhs.size (); ++i)
315- lhs (i) = expr () (i);
329+ lhs (i) = rhs (i);
316330}
317331
318332/* * @brief Evaluates expression for a tensor_core
@@ -330,9 +344,11 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
330344 if (!all_extents_equal ( expr, lhs.extents () ))
331345 throw std::runtime_error (" Error in boost::numeric::ublas::tensor_core: expression contains tensors with different shapes." );
332346
347+ auto const & rhs = cast_tensor_exression (expr);
348+
333349 #pragma omp parallel for
334350 for (auto i = 0u ; i < lhs.size (); ++i)
335- fn (lhs (i), expr () (i));
351+ fn (lhs (i), rhs (i));
336352}
337353
338354
@@ -347,7 +363,7 @@ inline void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type>
347363template <class tensor_type , class unary_fn >
348364inline void eval (tensor_type& lhs, unary_fn const & fn)
349365{
350- #pragma omp parallel for
366+ #pragma omp parallel for
351367 for (auto i = 0u ; i < lhs.size (); ++i)
352368 fn (lhs (i));
353369}
0 commit comments