@@ -409,6 +409,7 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None,
409
409
410
410
return model_proto , external_tensor_storage
411
411
412
+
412
413
def from_keras3 (model , input_signature = None , opset = None , custom_ops = None , custom_op_handlers = None ,
413
414
custom_rewriter = None , inputs_as_nchw = None , outputs_as_nchw = None , extra_opset = None , shape_override = None ,
414
415
target = None , large_model = False , output_path = None , optimizers = None ):
@@ -434,10 +435,25 @@ def from_keras3(model, input_signature=None, opset=None, custom_ops=None, custom
434
435
A tuple (model_proto, external_tensor_storage_dict)
435
436
"""
436
437
if not input_signature :
438
+ if hasattr (model , "inputs" ):
439
+ model_input = model .inputs
440
+ elif hasattr (model , "input_dtype" ) and hasattr (model , "_build_shapes_dict" ):
441
+ if len (model ._build_shapes_dict ) == 1 :
442
+ shape = list (model ._build_shapes_dict .values ())[0 ]
443
+ model_input = [tf .Variable (tf .zeros (shape , dtype = model .input_dtype ), name = "input" )]
444
+ else :
445
+ raise RuntimeError (f"Not implemented yet with input_dtype={ model .input_dtype } and model._build_shapes_dict={ model ._build_shapes_dict } " )
446
+ else :
447
+ if not hasattr (model , "inputs_spec" ):
448
+ raise RuntimeError ("You may set attribute 'inputs_spec' with your inputs (model.input_specs = ...)" )
449
+ model_input = model .inputs_spec
450
+
437
451
input_signature = [
438
452
tf .TensorSpec (tensor .shape , tensor .dtype , name = tensor .name .split (":" )[0 ])
439
- for tensor in model . inputs
453
+ for tensor in model_input
440
454
]
455
+ else :
456
+ model_input = None
441
457
442
458
# Trace model
443
459
function = tf .function (model )
@@ -459,13 +475,33 @@ def from_keras3(model, input_signature=None, opset=None, custom_ops=None, custom
459
475
reverse_lookup = {v : k for k , v in tensors_to_rename .items ()}
460
476
461
477
valid_names = []
462
- for out in [t .name for t in model .outputs ]:
478
+ if hasattr (model , "outputs" ):
479
+ model_output = model .outputs
480
+ else :
481
+ if hasattr (model , "outputs_spec" ):
482
+ model_output = model .outputs_spec
483
+ elif model_input and len (model_input ) == 1 :
484
+ # Let's try something to make unit test work. This should be replaced.
485
+ model_output = [tf .Variable (model_input [0 ], name = "output" )]
486
+ else :
487
+ raise RuntimeError (
488
+ "You should set attribute 'outputs_spec' with your outputs "
489
+ "so that the expected can use that information."
490
+ )
491
+
492
+ def _get_name (t , i ):
493
+ try :
494
+ return t .name
495
+ except AttributeError :
496
+ return f"output:{ i } "
497
+
498
+ for out in [_get_name (t , i ) for i , t in enumerate (model_output )]:
463
499
if out in reverse_lookup :
464
500
valid_names .append (reverse_lookup [out ])
465
501
else :
466
502
print (f"Warning: Output name '{ out } ' not found in reverse_lookup." )
467
503
# Fallback: verwende TensorFlow-Ausgangsnamen direkt
468
- valid_names = [t . name for t in concrete_func .outputs if t .dtype != tf .dtypes .resource ]
504
+ valid_names = [_get_name ( t , i ) for i , t in enumerate ( concrete_func .outputs ) if t .dtype != tf .dtypes .resource ]
469
505
break
470
506
output_names = valid_names
471
507
0 commit comments