Skip to content

Commit a2cf69b

Browse files
committed
Add old state (#50)
1 parent 0206a59 commit a2cf69b

File tree

4 files changed

+98
-18
lines changed

4 files changed

+98
-18
lines changed

include/tensor_buffers/TensorBuffer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include "TensorBufferBase.h"
1212

1313
/**
14-
* Tensor wrapper arbitrary tensor value dimensions
14+
* Tensor wrapper for arbitrary tensor value dimensions
1515
*/
1616
template <typename T>
1717
class TensorBuffer : public TensorBufferBase

include/tensor_computes/NEML2TensorCompute.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,14 @@ class NEML2TensorCompute : public TensorOperatorBase
3232
#ifdef NEML2_ENABLED
3333
neml2::Model & _model;
3434

35-
std::vector<std::pair<const torch::Tensor *, neml2::LabeledAxisAccessor>> _input_mapping;
35+
std::vector<std::tuple<const torch::Tensor *, neml2::TensorType, neml2::LabeledAxisAccessor>>
36+
_input_mapping;
37+
std::vector<std::tuple<const std::vector<torch::Tensor> *,
38+
const torch::Tensor *,
39+
neml2::TensorType,
40+
neml2::LabeledAxisAccessor>>
41+
_old_input_mapping;
42+
3643
std::vector<std::pair<neml2::LabeledAxisAccessor, torch::Tensor *>> _output_mapping;
3744
#endif
3845
};

include/tensor_computes/TensorOperatorBase.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ class TensorOperatorBase : public MooseObject,
5959
template <typename T = torch::Tensor>
6060
T & getOutputBufferByName(const TensorOutputBufferName & buffer_name);
6161

62+
template <typename T = torch::Tensor>
63+
const std::vector<T> & getBufferOld(const std::string & param, unsigned int max_states);
64+
65+
template <typename T = torch::Tensor>
66+
const std::vector<T> & getBufferOldByName(const std::string & buffer_name,
67+
unsigned int max_states);
68+
6269
std::set<std::string> _requested_buffers;
6370
std::set<std::string> _supplied_buffers;
6471

@@ -110,3 +117,17 @@ TensorOperatorBase::getOutputBufferByName(const TensorOutputBufferName & buffer_
110117
_supplied_buffers.insert(buffer_name);
111118
return _tensor_problem.getBuffer<T>(buffer_name);
112119
}
120+
121+
template <typename T>
122+
const std::vector<T> &
123+
TensorOperatorBase::getBufferOld(const std::string & param, unsigned int max_states)
124+
{
125+
return getBufferOldByName<T>(getParam<TensorInputBufferName>(param), max_states);
126+
}
127+
128+
template <typename T>
129+
const std::vector<T> &
130+
TensorOperatorBase::getBufferOldByName(const std::string & buffer_name, unsigned int max_states)
131+
{
132+
return _tensor_problem.getBufferOld<T>(buffer_name, max_states);
133+
}

src/tensor_computes/NEML2TensorCompute.C

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,56 @@ NEML2TensorCompute::NEML2TensorCompute(const InputParameters & params)
6767
NEML2Utils::assertNEML2Enabled();
6868

6969
#ifdef NEML2_ENABLED
70-
for (const auto & [swift_input_name, neml2_input_name] :
71-
getParam<TensorInputBufferName, std::string>("swift_inputs", "neml2_inputs"))
70+
const auto inputs = getParam<TensorInputBufferName, std::string>("swift_inputs", "neml2_inputs");
71+
std::map<neml2::LabeledAxisAccessor, TensorInputBufferName> lookup_swift_name;
72+
const auto model_inputs = _model.consumed_items();
73+
74+
// current inputs
75+
for (const auto & [swift_input_name, neml2_input_name] : inputs)
7276
{
73-
const auto * input_buffer = &getInputBufferByName<>(swift_input_name);
74-
_input_mapping.emplace_back(
75-
input_buffer, neml2::LabeledAxisAccessor(NEML2Utils::parseVariableName(neml2_input_name)));
77+
const auto neml2_input =
78+
neml2::LabeledAxisAccessor(NEML2Utils::parseVariableName(neml2_input_name));
79+
80+
// populate reverse lookup map
81+
if (lookup_swift_name.find(neml2_input) != lookup_swift_name.end())
82+
mooseError("Repeated NEML2 input ", neml2_input_name);
83+
lookup_swift_name[neml2_input] = swift_input_name;
84+
85+
// the user should only specify current neml2 axis
86+
if (!neml2_input.is_state() && !neml2_input.is_force())
87+
mooseError("Specify only current forces or states as inputs. Old forces and states are "
88+
"automatically coupled when needed.");
89+
90+
// add input if the model requires it
91+
if (model_inputs.count(neml2_input))
92+
{
93+
const auto * input_buffer = &getInputBufferByName<>(swift_input_name);
94+
const auto type = _model.input_variable(neml2_input).type();
95+
_input_mapping.emplace_back(input_buffer, type, neml2_input);
96+
}
7697
}
7798

99+
// old state inputs
100+
for (const auto & neml2_input : model_inputs)
101+
if (neml2_input.is_old_state())
102+
{
103+
// check if we couple the current state
104+
auto it = lookup_swift_name.find(neml2_input.current());
105+
if (it == lookup_swift_name.end())
106+
mooseError("The model requires ",
107+
neml2_input,
108+
" but no tensor buffer is assigned to ",
109+
neml2_input.current(),
110+
".");
111+
const auto & swift_input_name = it->second;
112+
113+
const auto * old_states = &getBufferOldByName<>(swift_input_name, 1);
114+
// we also get the current state here just to step zero, when no old state exists!
115+
const auto * input_buffer = &getInputBufferByName<>(swift_input_name);
116+
const auto type = _model.input_variable(neml2_input).type();
117+
_old_input_mapping.emplace_back(old_states, input_buffer, type, neml2_input);
118+
}
119+
78120
for (const auto & [neml2_output_name, swift_output_name] :
79121
getParam<std::string, TensorInputBufferName>("neml2_outputs", "swift_outputs"))
80122
{
@@ -99,20 +141,30 @@ NEML2TensorCompute::computeBuffer()
99141
{
100142
#ifdef NEML2_ENABLED
101143
neml2::ValueMap in;
102-
for (const auto & [tensor_ptr, label] : _input_mapping)
144+
auto insert_tensor = [&in, this](const auto & tensor, auto type, const auto & label)
103145
{
104146
// convert tensors on the fly at runtime
105-
auto sizes = tensor_ptr->sizes();
106-
mooseInfoRepeated(name(), " sizes size ", sizes.size(), " is ", Moose::stringify(sizes));
107-
if (sizes.size() == _dim)
108-
in[label] = neml2::Scalar(*tensor_ptr);
109-
else if (sizes.size() == _dim + 1)
110-
in[label] = neml2::Vec(*tensor_ptr, _domain.getShape());
111-
else if (sizes.size() == _dim + 3)
112-
in[label] = neml2::R2(*tensor_ptr, _domain.getShape());
147+
auto sizes = tensor.sizes();
148+
if (sizes.size() == _dim && type == neml2::TensorType::kScalar)
149+
in[label] = neml2::Scalar(tensor);
150+
else if (sizes.size() == _dim + 1 && type == neml2::TensorType::kVec)
151+
in[label] = neml2::Vec(tensor, _domain.getShape());
152+
else if (sizes.size() == _dim + 3 && type == neml2::TensorType::kR2)
153+
in[label] = neml2::R2(tensor, _domain.getShape());
113154
else
114-
mooseError("Unsupported tensor dimension");
115-
}
155+
mooseError("Unsupported/mismatching tensor dimension");
156+
};
157+
158+
// insert current state
159+
for (const auto & [current_state, type, label] : _input_mapping)
160+
insert_tensor(*current_state, type, label);
161+
162+
// insert old state
163+
for (const auto & [old_states, current_state, type, label] : _old_input_mapping)
164+
if (old_states->empty())
165+
insert_tensor(*current_state, type, label);
166+
else
167+
insert_tensor((*old_states)[0], type, label);
116168

117169
auto out = _model.value(in);
118170

0 commit comments

Comments
 (0)