Source code for plate_model_manager.plate_model_manager

import json
import os

import requests

from .plate_model import PlateModel


class PlateModelManager:
    """Load a models.json manifest and manage the plate models it lists.

    See an example models.json file at
    PlateModelManager.get_default_repo_url()
    """

    # Seconds to wait on the manifest server before giving up; without a
    # timeout ``requests.get`` can block forever on an unresponsive host.
    _HTTP_TIMEOUT = 30

    def __init__(self, model_manifest=None):
        """Load the model manifest from a local file or an http(s) URL.

        :param model_manifest: the path to a models.json file, or an http(s)
            URL pointing at one. When empty/None, the URL returned by
            PlateModelManager.get_default_repo_url() is used.
        :raises Exception: if the manifest URL cannot be fetched, or if
            ``model_manifest`` is neither an existing local file nor an
            http(s) URL.
        """
        if not model_manifest:
            self.model_manifest = PlateModelManager.get_default_repo_url()
        else:
            self.model_manifest = model_manifest
        self.models = None

        # A local file takes precedence over interpreting the value as a URL.
        if os.path.isfile(self.model_manifest):
            with open(self.model_manifest, encoding="utf-8") as f:
                self.models = json.load(f)
        elif str(self.model_manifest).startswith(("http://", "https://")):
            try:
                r = requests.get(self.model_manifest, timeout=self._HTTP_TIMEOUT)
                # Fail on 4xx/5xx instead of trying to parse an error page.
                r.raise_for_status()
                self.models = r.json()
            except (
                requests.exceptions.ConnectionError,
                requests.exceptions.Timeout,
                requests.exceptions.HTTPError,
            ) as err:
                # Keep the original cause attached for debugging.
                raise Exception(
                    f"Unable to fetch {self.model_manifest}. "
                    + "No network connection or invalid URL!"
                ) from err
        else:
            raise Exception(
                f"The model_manifest '{self.model_manifest}' should be either a local file path or a http(s) URL."
            )

    def get_model(self, model_name: str = "default", data_dir: str = "."):
        """Return a PlateModel object by model_name.

        The lookup is case-insensitive on ``model_name`` (it is lower-cased
        before use). A manifest entry whose value is a string is treated as
        an alias for another entry; a leading "@" on the alias target is
        stripped before the target is looked up.

        :param model_name: model name
        :param data_dir: the default data_dir for the model. This dir can be
            changed with PlateModel.set_data_dir() later.

        :returns: a PlateModel object or None if the model name (or the model
            an alias points at) is not available
        """
        model_name = model_name.lower()
        if model_name not in self.models:
            print(f"Model {model_name} is not available.")
            return None

        entry = self.models[model_name]
        if not isinstance(entry, str):
            return PlateModel(model_name, model_cfg=entry, data_dir=data_dir)

        # The entry is an alias: resolve the target and reuse its config
        # under the alias name.
        target_name = entry[1:] if entry.startswith("@") else entry
        target = self.get_model(target_name, data_dir=data_dir)
        if target is None:
            # Dangling alias; the recursive call already reported the
            # missing target.
            return None
        return PlateModel(model_name, model_cfg=target.get_cfg(), data_dir=data_dir)

    def get_available_model_names(self):
        """Return the names of available models as a list."""
        return list(self.models)

    @staticmethod
    def get_local_available_model_names(local_dir):
        """List all model names in a local folder.

        A sub-folder counts as a model when it contains a ``.metadata.json``
        file.

        :param local_dir: the folder to scan
        :returns: a list of sub-folder names, in ``os.listdir`` order
        """
        models = []
        for entry in os.listdir(local_dir):
            d = os.path.join(local_dir, entry)
            if os.path.isdir(d) and os.path.isfile(os.path.join(d, ".metadata.json")):
                models.append(entry)
        return models

    @staticmethod
    def get_default_repo_url():
        """Return the URL of the default models.json manifest."""
        return "https://repo.gplates.org/webdav/pmm/models.json"

    def download_all_models(self, data_dir="./"):
        """Download all available models into data_dir.

        Names that cannot be resolved to a model are skipped so that one bad
        manifest entry does not abort the remaining downloads.

        :param data_dir: the folder to download the models into
        """
        model_names = self.get_available_model_names()
        for name in model_names:
            print(f"download {name}")
            model = self.get_model(name)
            if model is None:
                # get_model() has already printed why the name is unusable.
                continue
            model.set_data_dir(data_dir)
            model.download_all_layers()