# log_model: script that logs a model checkpoint from a W&B experiment run as a W&B artifact.

import argparse
import logging
import os
import shlex
from typing import Optional, Sequence, Union

from fgvc.utils.wandb import wandb

# Module-level logger, registered under the fixed name "script".
logger = logging.getLogger("script")


def load_args(args: Optional[Union[str, Sequence[str]]] = None) -> argparse.Namespace:
    """Load script arguments using `argparse` library.

    Parameters
    ----------
    args
        Arguments to parse. Either a sequence of already-split tokens
        (e.g. ``["--wandb-run-id", "abc"]``) or a single command-line string,
        which is split with shell-like quoting rules.
        When `None`, `argparse` reads the arguments from ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Namespace with attributes `wandb_run_id`, `wandb_project` and `wandb_entity`.

    Raises
    ------
    SystemExit
        Raised by `argparse` when a required argument is missing or unknown.
    """
    if isinstance(args, str):
        # argparse iterates over `args`; a plain string would be consumed
        # character by character, so split it into tokens first.
        args = shlex.split(args)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wandb-run-id",
        help="Run id for logging experiment to wandb.",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--wandb-project",
        help="Project name for logging experiment to wandb.",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--wandb-entity",
        help="Entity name for logging experiment to wandb.",
        type=str,
        required=True,
    )
    return parser.parse_args(args)


def log_model(
    *,
    wandb_entity: str = None,
    wandb_project: str = None,
    wandb_run_id: str = None,
    **kwargs,
):
    """Attach the ``best_loss.pth`` checkpoint of a W&B run to that run as a model artifact.

    The run is re-opened with ``resume="must"``, its config is read to locate the
    experiment directory (``exp_path``) and the architecture name (``architecture``),
    and the checkpoint file is logged as an artifact named ``<run name>-<architecture>``.

    Parameters
    ----------
    wandb_entity
        W&B entity that owns the run.
    wandb_project
        W&B project that contains the run.
    wandb_run_id
        Id of the W&B run to resume.
    **kwargs
        Accepted but not used by this function.

    Notes
    -----
    If any of the three identifiers is `None`, all three are taken from the
    command-line arguments, replacing any values that were passed in.

    Raises
    ------
    ImportError
        If the `wandb` package is not available.
    ValueError
        If the checkpoint file does not exist in the experiment directory.
    """
    if wandb is None:
        raise ImportError("Package wandb is not installed.")

    # fall back to script arguments unless the run is fully specified by the caller
    identifiers = (wandb_entity, wandb_project, wandb_run_id)
    if any(value is None for value in identifiers):
        cli_args = load_args()
        wandb_entity, wandb_project, wandb_run_id = (
            cli_args.wandb_entity,
            cli_args.wandb_project,
            cli_args.wandb_run_id,
        )

    run_path = f"{wandb_entity}/{wandb_project}/{wandb_run_id}"
    logger.info(f"Logging model in W&B from experiment run: {run_path}")

    # re-open the existing run
    init_options = dict(
        id=wandb_run_id,
        project=wandb_project,
        entity=wandb_entity,
        save_code=False,
        resume="must",
    )
    with wandb.init(**init_options) as run:
        # locate the checkpoint inside the experiment directory stored in the run config
        checkpoint_path = os.path.join(run.config["exp_path"], "best_loss.pth")
        if os.path.isfile(checkpoint_path):
            logger.info(f"Using model checkpoint from the file: {checkpoint_path}")
        else:
            raise ValueError(f"Model checkpoint '{checkpoint_path}' not found.")
        architecture = run.config["architecture"]

        # build the artifact and attach it to the run
        artifact_name = f"{run.name}-{architecture}"
        model_artifact = wandb.Artifact(artifact_name, type="model", metadata=None)
        model_artifact.add_file(local_path=checkpoint_path)
        run.log_artifact(model_artifact)
        logger.info(f"Created W&B artifact: {artifact_name}")


# Script entry point: log_model is called without keyword arguments, so the
# W&B entity, project and run id are read from the command line.
if __name__ == "__main__":
    log_model()