@@ -67,14 +67,56 @@ NEML2TensorCompute::NEML2TensorCompute(const InputParameters & params)
67
67
NEML2Utils ::assertNEML2Enabled ();
68
68
69
69
#ifdef NEML2_ENABLED
70
- for (const auto & [swift_input_name , neml2_input_name ] :
71
- getParam < TensorInputBufferName , std ::string > ("swift_inputs ", "neml2_inputs "))
70
+ const auto inputs = getParam < TensorInputBufferName , std ::string > ("swift_inputs ", "neml2_inputs ");
71
+ std ::map < neml2 ::LabeledAxisAccessor , TensorInputBufferName > lookup_swift_name ;
72
+ const auto model_inputs = _model .consumed_items ();
73
+
74
+ // current inputs
75
+ for (const auto & [swift_input_name , neml2_input_name ] : inputs )
72
76
{
73
- const auto * input_buffer = & getInputBufferByName < > (swift_input_name );
74
- _input_mapping .emplace_back (
75
- input_buffer , neml2 ::LabeledAxisAccessor (NEML2Utils ::parseVariableName (neml2_input_name )));
77
+ const auto neml2_input =
78
+ neml2 ::LabeledAxisAccessor (NEML2Utils ::parseVariableName (neml2_input_name ));
79
+
80
+ // populate reverse lookup map
81
+ if (lookup_swift_name .find (neml2_input ) != lookup_swift_name .end ())
82
+ mooseError ("Repeated NEML2 input " , neml2_input_name );
83
+ lookup_swift_name [neml2_input ] = swift_input_name ;
84
+
85
+ // the user should only specify current neml2 axis
86
+ if (!neml2_input .is_state () && !neml2_input .is_force ())
87
+ mooseError ("Specify only current forces or states as inputs. Old forces and states are "
88
+ "automatically coupled when needed." );
89
+
90
+ // add input if the model requires it
91
+ if (model_inputs .count (neml2_input ))
92
+ {
93
+ const auto * input_buffer = & getInputBufferByName < > (swift_input_name );
94
+ const auto type = _model .input_variable (neml2_input ).type ();
95
+ _input_mapping .emplace_back (input_buffer , type , neml2_input );
96
+ }
76
97
}
77
98
99
+ // old state inputs
100
+ for (const auto & neml2_input : model_inputs )
101
+ if (neml2_input .is_old_state ())
102
+ {
103
+ // check if we couple the current state
104
+ auto it = lookup_swift_name .find (neml2_input .current ());
105
+ if (it == lookup_swift_name .end ())
106
+ mooseError ("The model requires " ,
107
+ neml2_input ,
108
+ " but no tensor buffer is assigned to " ,
109
+ neml2_input .current (),
110
+ "." );
111
+ const auto & swift_input_name = it -> second ;
112
+
113
+ const auto * old_states = & getBufferOldByName < > (swift_input_name , 1 );
114
+ // we also get the current state here just to step zero, when no old state exists!
115
+ const auto * input_buffer = & getInputBufferByName < > (swift_input_name );
116
+ const auto type = _model .input_variable (neml2_input ).type ();
117
+ _old_input_mapping .emplace_back (old_states , input_buffer , type , neml2_input );
118
+ }
119
+
78
120
for (const auto & [neml2_output_name , swift_output_name ] :
79
121
getParam < std ::string , TensorInputBufferName > ("neml2_outputs" , "swift_outputs" ))
80
122
{
@@ -99,20 +141,30 @@ NEML2TensorCompute::computeBuffer()
99
141
{
100
142
#ifdef NEML2_ENABLED
101
143
neml2 ::ValueMap in ;
102
- for (const auto & [ tensor_ptr , label ] : _input_mapping )
144
+ auto insert_tensor = [ & in , this ] (const auto & tensor , auto type , const auto & label )
103
145
{
104
146
// convert tensors on the fly at runtime
105
- auto sizes = tensor_ptr -> sizes ();
106
- mooseInfoRepeated (name (), " sizes size " , sizes .size (), " is " , Moose ::stringify (sizes ));
107
- if (sizes .size () == _dim )
108
- in [label ] = neml2 ::Scalar (* tensor_ptr );
109
- else if (sizes .size () == _dim + 1 )
110
- in [label ] = neml2 ::Vec (* tensor_ptr , _domain .getShape ());
111
- else if (sizes .size () == _dim + 3 )
112
- in [label ] = neml2 ::R2 (* tensor_ptr , _domain .getShape ());
147
+ auto sizes = tensor .sizes ();
148
+ if (sizes .size () == _dim && type == neml2 ::TensorType ::kScalar )
149
+ in [label ] = neml2 ::Scalar (tensor );
150
+ else if (sizes .size () == _dim + 1 && type == neml2 ::TensorType ::kVec )
151
+ in [label ] = neml2 ::Vec (tensor , _domain .getShape ());
152
+ else if (sizes .size () == _dim + 3 && type == neml2 ::TensorType ::kR2 )
153
+ in [label ] = neml2 ::R2 (tensor , _domain .getShape ());
113
154
else
114
- mooseError ("Unsupported tensor dimension" );
115
- }
155
+ mooseError ("Unsupported/mismatching tensor dimension" );
156
+ };
157
+
158
+ // insert current state
159
+ for (const auto & [current_state , type , label ] : _input_mapping )
160
+ insert_tensor (* current_state , type , label );
161
+
162
+ // insert old state
163
+ for (const auto & [old_states , current_state , type , label ] : _old_input_mapping )
164
+ if (old_states -> empty ())
165
+ insert_tensor (* current_state , type , label );
166
+ else
167
+ insert_tensor ((* old_states )[0 ], type , label );
116
168
117
169
auto out = _model .value (in );
118
170
0 commit comments