Refactor user_defined_types.py

Change-Id: I8fca1740a99daa14da77fac0632af495925116d4
This commit is contained in:
Henrik Wahlqvist
2025-04-02 16:15:14 +02:00
parent c9583f4055
commit 6faabe161d

View File

@@ -49,6 +49,8 @@ class UserDefinedTypes(ProblemLogger):
)
raise TypeError(err)
start_time = time.time()
self.info(' Start parsing files with user defined data types')
self._build_prj_cfg = build_prj_config
self._unit_configs = unit_configs
self.enums_per_unit = {}
@@ -57,7 +59,9 @@ class UserDefinedTypes(ProblemLogger):
self._parse_all_user_defined_types()
self.common_header_files = []
# Must run last to be able to compare with TL/EC data types
self.all_enums = self._get_enumerations()
self._interface_data_types = self._parse_interface_data_types()
self.info(' Finished parsing files with user defined data types (in %4.2f s)', time.time() - start_time)
@staticmethod
def convert_interface_enum_to_simulink(interface_enum, underlying_data_type=None):
@@ -111,12 +115,9 @@ class UserDefinedTypes(ProblemLogger):
def _parse_all_user_defined_types(self):
"""Parse all files containing user defined data types."""
start_time = time.time()
self.info(' Start parsing files with user defined data types')
src_dirs = self._build_prj_cfg.get_unit_src_dirs()
for unit, src_dir in src_dirs.items():
self._parse_unit_user_defined_types(unit, src_dir)
self.info(' Finished parsing files with user defined data types (in %4.2f s)', time.time() - start_time)
def _parse_unit_user_defined_types(self, unit, unit_src_dir):
"""Parse unit defined types for a given unit.
@@ -134,7 +135,7 @@ class UserDefinedTypes(ProblemLogger):
for file_prefix in self.FILE_PREFIXES:
found_files.extend(Path(unit_src_dir).glob(file_prefix + '*.h'))
for found_file in found_files:
self.info(' Parsing file: %s', str(found_file))
self.debug(' Parsing file: %s', str(found_file))
self._parse_target_link_enum(unit, found_file)
self._parse_target_link_struct(unit, found_file)
self._validate_project_enumerations()
@@ -197,27 +198,24 @@ class UserDefinedTypes(ProblemLogger):
valid_interface_enumerations (dict): Specification, valid interface enumerations.
"""
valid_interface_enumerations = {}
user_defined_enumerations = self.get_enumerations()
for enum_name, enum_data in interface_enumerations.items():
if enum_name in valid_interface_enumerations:
self.critical('%s is multiply defined in interface enumeration definitions.', enum_name)
elif enum_name not in user_defined_enumerations:
elif enum_name not in self.all_enums:
self.critical('%s is not defined in the project', enum_name)
else:
converted = self.convert_interface_enum_to_simulink(
enum_data,
user_defined_enumerations[enum_name]['underlying_data_type']
self.all_enums[enum_name]['underlying_data_type']
)
is_consistent = self._compare_enum_definitions(
user_defined_enumerations[enum_name]['units'],
self.all_enums[enum_name]['units'],
enum_name,
converted,
user_defined_enumerations[enum_name]
self.all_enums[enum_name]
)
if is_consistent:
valid_interface_enumerations[enum_name] = enum_data
return valid_interface_enumerations
def _validate_project_enumerations(self):
@@ -424,6 +422,33 @@ class UserDefinedTypes(ProblemLogger):
}
return common_enums
def _get_enumerations(self):
"""Get all enumeration defined in the project, together with unit usage.
Information already provided in self._validate_project_enumerations during initialization.
Returns:
enumerations (dict): Enumerations defined in the projects, including unit usage.
"""
enumerations = {}
for unit, enum_names in self.enums_per_unit.items():
for enum_name, enum_data in enum_names.items():
if enum_name in enumerations:
enumerations[enum_name]['units'].append(unit)
else:
enumerations[enum_name] = deepcopy(enum_data)
enumerations[enum_name]['units'] = [unit]
if self._build_prj_cfg.get_code_generation_config("includeAllEnums"):
for enum_name, enum_data in self.common_enums.items():
if enum_name not in enumerations:
self.warning(
"Enumeration %s is not used in any unit. Included since 'includeAllEnums' is set in config.",
enum_name
)
enumerations[enum_name] = deepcopy(enum_data)
enumerations[enum_name]['units'] = []
return enumerations
def get_default_enum_value(self, unit, enum_name):
"""Get default value of given enumeration name by searching in the unit configuration.
@@ -466,29 +491,10 @@ class UserDefinedTypes(ProblemLogger):
def get_enumerations(self):
"""Get all enumeration defined in the project, together with unit usage.
Information already provided in self._validate_project_enumerations during initialization.
Returns:
enumerations (dict): Enumerations defined in the projects, including unit usage.
self.all_enums (dict): Enumerations defined in the projects, including unit usage.
"""
enumerations = {}
for unit, enum_names in self.enums_per_unit.items():
for enum_name, enum_data in enum_names.items():
if enum_name in enumerations:
enumerations[enum_name]['units'].append(unit)
else:
enumerations[enum_name] = deepcopy(enum_data)
enumerations[enum_name]['units'] = [unit]
if self._build_prj_cfg.get_code_generation_config("includeAllEnums"):
for enum_name, enum_data in self.common_enums.items():
if enum_name not in enumerations:
self.warning(
"Enumeration %s is not used in any unit. Included since 'includeAllEnums' is set in config.",
enum_name
)
enumerations[enum_name] = deepcopy(enum_data)
enumerations[enum_name]['units'] = []
return enumerations
return self.all_enums
def get_interface_data_types(self):
"""Returns all interface data types"""