diff --git a/ray_adapter/_private/state.py b/ray_adapter/_private/state.py index c4722c26b6100e2b745a916356d4f119414425d4..37987dedae06917a812225c50103b2ac3247650c 100644 --- a/ray_adapter/_private/state.py +++ b/ray_adapter/_private/state.py @@ -36,17 +36,19 @@ def build_pg_dict(rg_info: RgInfo): for index, bundle in enumerate(rg_info.bundles): bundle_dict = {} - for key, value in bundle.resources.resources.items(): + for value in bundle.resources.resources.items(): resource = value if resource.type == Resource.Type.SCALER: - if "GPU" in resource.name: - bundle_dict["GPU"] = resource.scalar.value - elif "NPU" in resource.name: - bundle_dict["NPU"] = resource.scalar.value - elif "Memory" in resource.name: - bundle_dict["memory"] = resource.scalar.value - else: - bundle_dict[resource.name] = resource.scalar.value + resource_mapping = { + "GPU": "GPU", + "NPU": "NPU", + "Memory": "memory" + } + target_key = next( + (v for k, v in resource_mapping.items() if k in resource.name), + resource.name + ) + bundle_dict[target_key] = resource.scalar.value rg_dict['bundles'][index] = bundle_dict rg_dict['bundles_to_node_id'][index] = bundle.functionProxyId