diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..fe58f73b --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,17 @@ +name: Lint + +on: + push: + branches: [main] + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v5 + - run: uv python install 3.10 + - run: uv pip install ruff + - run: uv run ruff check . + - run: uv run ruff format --check . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..700cbc40 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.12.2 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format diff --git a/agents/game_agent/app.py b/agents/game_agent/app.py index 65d64507..8fbdc9cb 100644 --- a/agents/game_agent/app.py +++ b/agents/game_agent/app.py @@ -5,15 +5,15 @@ # Jianwei Yang (jianwyan@microsoft.com) # -------------------------------------------------------- -import pygame -import numpy as np +import random +import re + import gradio as gr -import time +import numpy as np +import pygame import torch from PIL import Image from transformers import AutoModelForCausalLM, AutoProcessor -import re -import random pygame.mixer.quit() # Disable sound diff --git a/agents/libero/eval_magma_libero.py b/agents/libero/eval_magma_libero.py index 1ef1a8ef..b5c29e81 100644 --- a/agents/libero/eval_magma_libero.py +++ b/agents/libero/eval_magma_libero.py @@ -1,22 +1,13 @@ import os -import numpy as np -import draccus from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Tuple + +import draccus import tqdm from libero.libero import benchmark -from libero_env_utils import ( - get_libero_env, - get_libero_dummy_action, - get_libero_obs, - get_max_steps, - set_seed_everywhere -) -from libero_magma_utils import ( - get_magma_model, - get_magma_prompt, - get_magma_action -) +from libero_env_utils import get_libero_dummy_action, get_libero_env, get_libero_obs, get_max_steps, set_seed_everywhere +from libero_magma_utils import get_magma_action, get_magma_model, get_magma_prompt + @dataclass class LiberoConfig: diff --git a/agents/libero/libero_env_utils.py b/agents/libero/libero_env_utils.py index d7bd19c1..d5be9fec 100644 --- a/agents/libero/libero_env_utils.py +++ b/agents/libero/libero_env_utils.py @@ -2,14 +2,16 @@ import math import os -import torch import random -from PIL import Image + import imageio import numpy as np import tensorflow as tf +import torch from libero.libero import get_libero_path from libero.libero.envs import OffScreenRenderEnv +from PIL import Image + def resize_image(img, resize_size): """ @@ -91,7 +93,7 @@ def quat2axisangle(quat): def save_rollout_video(replay_images, success, task_description): """Saves a video replay of a rollout in libero.""" - save_dir = f"./libero_videos" + save_dir = "./libero_videos" os.makedirs(save_dir, exist_ok=True) processed_task_description = task_description.lower().replace(" ", "_").replace("\n", "_").replace(".", "_")[:50] video_path = f"{save_dir}/quick_eval-success={success}--task={processed_task_description}.mp4" diff --git a/agents/libero/libero_magma_utils.py b/agents/libero/libero_magma_utils.py index 3f7afa53..52a6d3fa 100644 --- a/agents/libero/libero_magma_utils.py +++ b/agents/libero/libero_magma_utils.py @@ -1,10 +1,12 @@ -import os import json -import torch +import os + import numpy as np -from magma.image_processing_magma import MagmaImageProcessor -from magma.processing_magma import MagmaProcessor +import torch + from magma.modeling_magma import MagmaForConditionalGeneration +from magma.processing_magma import MagmaProcessor + def get_magma_model(model_name): processor = MagmaProcessor.from_pretrained(model_name, trust_remote_code=True) diff --git a/agents/robot_traj/app.py b/agents/robot_traj/app.py index c84d7b43..ef9a0097 100644 --- a/agents/robot_traj/app.py +++ b/agents/robot_traj/app.py @@ -5,17 +5,14 @@ # Jianwei Yang (jianwyan@microsoft.com) # -------------------------------------------------------- -import os -import warnings -from utils.visualizer import Visualizer -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple +import ast import random -import gradio as gr -import ast, re +import gradio as gr import torch import torchvision from transformers import AutoModelForCausalLM, AutoProcessor +from utils.visualizer import Visualizer ''' build model diff --git a/agents/robot_traj/app.pyi b/agents/robot_traj/app.pyi index 13ab2d91..8478a4be 100644 --- a/agents/robot_traj/app.pyi +++ b/agents/robot_traj/app.pyi @@ -5,17 +5,15 @@ # Jianwei Yang (jianwyan@microsoft.com) # -------------------------------------------------------- -import os -import warnings -from utils.visualizer import Visualizer -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple +import ast import random -import gradio as gr -import ast, re +from typing import TYPE_CHECKING +import gradio as gr import torch import torchvision from transformers import AutoModelForCausalLM, AutoProcessor +from utils.visualizer import Visualizer ''' build model @@ -132,7 +130,6 @@ def inference(image, task, *args, **kwargs): except Exception as e: print(e) return None -from gradio.events import Dependency class ImageMask(gr.components.Image): """ @@ -146,7 +143,8 @@ class ImageMask(gr.components.Image): def preprocess(self, x): return super().preprocess(x) - from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING + from typing import Literal, Sequence + from gradio.blocks import Block if TYPE_CHECKING: from gradio.components import Timer @@ -163,7 +161,8 @@ class Video(gr.components.Video): def preprocess(self, x): return super().preprocess(x) - from typing import Callable, Literal, Sequence, Any, TYPE_CHECKING + from typing import Literal, Sequence + from gradio.blocks import Block if TYPE_CHECKING: from gradio.components import Timer diff --git a/agents/robot_traj/utils/visualizer.py b/agents/robot_traj/utils/visualizer.py index 88287c37..8599dfd8 100644 --- a/agents/robot_traj/utils/visualizer.py +++ b/agents/robot_traj/utils/visualizer.py @@ -4,14 +4,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import os -import numpy as np + import imageio +import matplotlib.pyplot as plt +import numpy as np import torch - -from matplotlib import cm import torch.nn.functional as F import torchvision.transforms as transforms -import matplotlib.pyplot as plt +from matplotlib import cm from PIL import Image, ImageDraw diff --git a/agents/ui_agent/app.py b/agents/ui_agent/app.py index 76e2b65f..328048ad 100644 --- a/agents/ui_agent/app.py +++ b/agents/ui_agent/app.py @@ -5,27 +5,19 @@ # Jianwei Yang (jianwyan@microsoft.com) # -------------------------------------------------------- +import base64 +import io from typing import Optional -import spaces + import gradio as gr -import numpy as np +import spaces import torch +from huggingface_hub import snapshot_download from PIL import Image -import io -import re - -import base64, os -from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img +from transformers import AutoModelForCausalLM, AutoProcessor +from util.process_utils import extract_bbox, extract_mark_id, pred_2_point from util.som import MarkHelper, plot_boxes_with_marks, plot_circles_with_marks -from util.process_utils import pred_2_point, extract_bbox, extract_mark_id - -import torch -from PIL import Image - -from huggingface_hub import snapshot_download -import torch -from transformers import AutoModelForCausalLM -from transformers import AutoProcessor +from util.utils import check_ocr_box, get_caption_model_processor, get_som_labeled_img, get_yolo_model # Define repository and local directory repo_id = "microsoft/OmniParser-v2.0" # HF repo diff --git a/agents/ui_agent/util/box_annotator.py b/agents/ui_agent/util/box_annotator.py index 26af49c0..ede66c3b 100644 --- a/agents/ui_agent/util/box_annotator.py +++ b/agents/ui_agent/util/box_annotator.py @@ -1,8 +1,7 @@ -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Tuple, Union import cv2 import numpy as np - from supervision.detection.core import Detections from supervision.draw.color import Color, ColorPalette diff --git a/agents/ui_agent/util/omniparser.py b/agents/ui_agent/util/omniparser.py index 536385e6..b590ca97 100644 --- a/agents/ui_agent/util/omniparser.py +++ b/agents/ui_agent/util/omniparser.py @@ -1,9 +1,13 @@ -from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box -import torch -from PIL import Image -import io import base64 +import io from typing import Dict + +import torch +from PIL import Image + +from util.utils import check_ocr_box, get_caption_model_processor, get_som_labeled_img, get_yolo_model + + class Omniparser(object): def __init__(self, config: Dict): self.config = config diff --git a/agents/ui_agent/util/process_utils.py b/agents/ui_agent/util/process_utils.py index ea7ac26d..b36f6067 100644 --- a/agents/ui_agent/util/process_utils.py +++ b/agents/ui_agent/util/process_utils.py @@ -1,5 +1,6 @@ import re + # is instruction English def is_english_simple(text): try: diff --git a/agents/ui_agent/util/som.py b/agents/ui_agent/util/som.py index d04f5648..f6df03d4 100644 --- a/agents/ui_agent/util/som.py +++ b/agents/ui_agent/util/som.py @@ -1,13 +1,11 @@ -import torch -from ultralytics import YOLO from PIL import Image -import io -import base64 + device = 'cuda' -from PIL import Image, ImageDraw, ImageFont -import numpy as np import networkx as nx +import numpy as np +from PIL import ImageDraw, ImageFont + # import cv2 font_path = "agents/ui_agent/util/arial.ttf" diff --git a/agents/ui_agent/util/utils.py b/agents/ui_agent/util/utils.py index 9fd3b76a..b61e1b05 100644 --- a/agents/ui_agent/util/utils.py +++ b/agents/ui_agent/util/utils.py @@ -1,23 +1,18 @@ # from ultralytics import YOLO -import os -import io import base64 +import io import time -from PIL import Image, ImageDraw, ImageFont -import json -import requests -# utility function -import os -import json -import sys -import os +# utility function import cv2 +import easyocr import numpy as np + # %matplotlib inline from matplotlib import pyplot as plt -import easyocr from paddleocr import PaddleOCR +from PIL import Image + reader = easyocr.Reader(['en']) paddle_ocr = PaddleOCR( lang='en', # other lang also available @@ -28,26 +23,23 @@ use_dilation=True, # improves accuracy det_db_score_mode='slow', # improves accuracy rec_batch_num=1024) -import time -import base64 -import os -import ast +from typing import List, Tuple, Union + +import supervision as sv import torch -from typing import Tuple, List, Union +import torchvision.transforms as T from torchvision.ops import box_convert -import re from torchvision.transforms import ToPILImage -import supervision as sv -import torchvision.transforms as T -from util.box_annotator import BoxAnnotator + +from util.box_annotator import BoxAnnotator def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None): if not device: device = "cuda" if torch.cuda.is_available() else "cpu" if model_name == "blip2": - from transformers import Blip2Processor, Blip2ForConditionalGeneration + from transformers import Blip2ForConditionalGeneration, Blip2Processor processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") if device == 'cpu': model = Blip2ForConditionalGeneration.from_pretrained( @@ -58,7 +50,7 @@ def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2 model_name_or_path, device_map=None, torch_dtype=torch.float16 ).to(device) elif model_name == "florence2": - from transformers import AutoProcessor, AutoModelForCausalLM + from transformers import AutoModelForCausalLM, AutoProcessor processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True) if device == 'cpu': model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True) diff --git a/data/__init__.py b/data/__init__.py index 4f49a3ba..a871fa6e 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -1,15 +1,13 @@ # datasets -from .epic import epic +# data collators +from .data_collator import DataCollatorForHFDataset, DataCollatorForSupervisedDataset + +# (joint) datasets +from .dataset import build_joint_dataset from .ego4d import ego4d +from .epic import epic +from .llava import llava +from .magma import magma from .openx import openx from .openx_magma import openx_magma -from .magma import magma -from .llava import llava from .seeclick import seeclick - -# (joint) datasets -from .dataset import build_joint_dataset - -# data collators -from .data_collator import DataCollatorForSupervisedDataset -from .data_collator import DataCollatorForHFDataset diff --git a/data/conversations.py b/data/conversations.py index b4b01541..88f8e3fa 100644 --- a/data/conversations.py +++ b/data/conversations.py @@ -1,18 +1,17 @@ -import torch -import torchvision +import os +import random import re +import time + import cv2 -import numpy as np -import os -import yaml -from PIL import Image -from data.utils.visual_trace import visual_trace -from data.utils.som_tom import som_prompting, tom_prompting -import torchvision.io as tv_io +import torch import torchvision -import time -import random from decord import VideoReader, cpu +from PIL import Image + +from data.utils.som_tom import som_prompting +from data.utils.visual_trace import visual_trace + class Constructor(): def __init__(self, **kwargs): @@ -192,7 +191,7 @@ def _construct_conv(self, item, video_path, visual_traces): return item if 'image_size' not in item: - assert '(height,width)' in item, f"image_size not in item and (height,width) not in item" + assert '(height,width)' in item, "image_size not in item and (height,width) not in item" item['image_size'] = item['(height,width)'][::-1] if isinstance(item['image_size'][0], torch.Tensor): @@ -514,7 +513,7 @@ def _construct_conv(self, item, video_path, visual_traces): item['conversations'].append({'from': 'gpt', 'value': conv_gpt}) item['image'].append(image) - import pdb; pdb.set_trace() + import pdb; pdb.set_trace() # noqa: I001 return item @@ -580,7 +579,7 @@ def _construct_caption(self, item, video_path, visual_traces): return item if 'image_size' not in item: - assert '(height,width)' in item, f"image_size not in item and (height,width) not in item" + assert '(height,width)' in item, "image_size not in item and (height,width) not in item" item['image_size'] = item['(height,width)'][::-1] if isinstance(item['image_size'][0], torch.Tensor): diff --git a/data/data_collator.py b/data/data_collator.py index af4dff1c..c4c94a54 100644 --- a/data/data_collator.py +++ b/data/data_collator.py @@ -1,9 +1,14 @@ +from dataclasses import dataclass +from typing import Dict, Sequence + import torch -from dataclasses import dataclass, field +import transformers + +from data.utils.constants import ( + IGNORE_INDEX, +) from magma.processing_magma import MagmaProcessor -from typing import Dict, Optional, Sequence, List -import transformers -from data.utils.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN + @dataclass class DataCollatorForSupervisedDataset(object): diff --git a/data/data_item.py b/data/data_item.py index 2afb3d98..36669337 100644 --- a/data/data_item.py +++ b/data/data_item.py @@ -1,13 +1,12 @@ import json -import yaml -import torch -import random import os -import glob -import pickle + +import torch +import yaml from datasets import load_dataset + from .openx import OpenXDataItem -from tqdm import tqdm + class DataItem: """ diff --git a/data/dataset.py b/data/dataset.py index 4f924c74..f90c6f02 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -1,34 +1,25 @@ -import os +import collections import copy -from dataclasses import dataclass, field -import json -import logging -import pathlib -from typing import Dict, Optional, Sequence, List -import pandas as pd -import torch -import deepspeed -import glob -import pandas as pd -import transformers -import tokenizers +import os import random -import re -import cv2 -from torch.utils.data import Dataset, DataLoader -from torch.utils.data import DataLoader, Dataset, DistributedSampler, IterableDataset +from typing import Dict, List, Sequence + +import torch import torch.distributed as dist -import collections from PIL import Image -from io import BytesIO -from data.utils.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN -from magma.image_processing_magma import MagmaImageProcessor +from torch.utils.data import DataLoader, Dataset, DistributedSampler + +from data.utils.constants import ( + DEFAULT_IM_END_TOKEN, + DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_TOKEN, + IGNORE_INDEX, +) from magma.processing_magma import MagmaProcessor -from .data_item import DataItem + from . import * -from PIL import Image, ImageFile -from PIL import ImageDraw, ImageFont -from typing import List, Optional, Union +from .data_item import DataItem + def preprocess_multimodal( sources: Sequence[str], diff --git a/data/ego4d/data_utils.py b/data/ego4d/data_utils.py index 6bce0afc..133cb5c2 100644 --- a/data/ego4d/data_utils.py +++ b/data/ego4d/data_utils.py @@ -1,16 +1,11 @@ -import torch -import torchvision -import re -import cv2 -import numpy as np +import logging import os + import yaml from tqdm import tqdm -from PIL import Image -from data.utils.visual_trace import visual_trace -from data.utils.som_tom import som_prompting, tom_prompting + from data.conversations import Constructor -import logging + logger = logging.getLogger(__name__) class Ego4d(Constructor): diff --git a/data/epic/data_utils.py b/data/epic/data_utils.py index 4d1d2d39..3f1b47d9 100644 --- a/data/epic/data_utils.py +++ b/data/epic/data_utils.py @@ -1,13 +1,10 @@ -import torch -import torchvision -import re -import cv2 -import numpy as np import os + import yaml -from PIL import Image + from data.conversations import Constructor + class EpicKitchen(Constructor): def __init__(self, **kwargs): super(EpicKitchen, self).__init__(**kwargs) diff --git a/data/llava/data_utils.py b/data/llava/data_utils.py index 6bfc23e9..17eafb59 100644 --- a/data/llava/data_utils.py +++ b/data/llava/data_utils.py @@ -1,16 +1,11 @@ -import torch -import torchvision -import re -import cv2 -import numpy as np import os + import yaml from tqdm import tqdm -from PIL import Image -from data.utils.visual_trace import visual_trace -from data.utils.som_tom import som_prompting, tom_prompting + from data.conversations import Constructor + class LlaVA(Constructor): def __init__(self, **kwargs): super(LlaVA, self).__init__(**kwargs) diff --git a/data/magma/data_utils.py b/data/magma/data_utils.py index e49d2874..12b9121d 100644 --- a/data/magma/data_utils.py +++ b/data/magma/data_utils.py @@ -1,14 +1,11 @@ -import torch -import torchvision -import re -import cv2 -import numpy as np import os + import yaml from tqdm import tqdm -from PIL import Image + from data.conversations import Constructor + class Magma(Constructor): def __init__(self, **kwargs): super(Magma, self).__init__(**kwargs) diff --git a/data/openx/__init__.py b/data/openx/__init__.py index a9721b7c..450d0dcc 100644 --- a/data/openx/__init__.py +++ b/data/openx/__init__.py @@ -1,2 +1,2 @@ +from .data_utils import OpenX as openx from .data_utils import OpenXDataItem -from .data_utils import OpenX as openx \ No newline at end of file diff --git a/data/openx/data_utils.py b/data/openx/data_utils.py index b95f6213..c5c45520 100644 --- a/data/openx/data_utils.py +++ b/data/openx/data_utils.py @@ -1,24 +1,18 @@ -import torch -import torchvision -import re -import cv2 -import numpy as np -import os -import yaml import logging -from PIL import Image +import os +from dataclasses import dataclass + +import torch import torch.distributed as dist -from data.utils.visual_trace import visual_trace -from data.utils.som_tom import som_prompting, tom_prompting +import yaml + from data.conversations import Constructor -from .conf import VLAConfig, VLARegistry -from dataclasses import dataclass, field -from magma.processing_magma import MagmaProcessor -from .materialize import get_vla_dataset_and_collator -from .datasets.rlds.utils.data_utils import save_dataset_statistics from data.utils.visual_tracker import visual_tracker +from .conf import VLAConfig +from .materialize import get_vla_dataset_and_collator + logger = logging.getLogger(__name__) """ @@ -27,12 +21,7 @@ General utilities and classes for facilitating data loading and collation. """ -from dataclasses import dataclass -from typing import Callable, Dict, Sequence, Tuple -import torch -from torch.nn.utils.rnn import pad_sequence -from torch import distributed as dist # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) IGNORE_INDEX = -100 diff --git a/data/openx/datasets/datasets.py b/data/openx/datasets/datasets.py index 0782414c..e1e3d01c 100644 --- a/data/openx/datasets/datasets.py +++ b/data/openx/datasets/datasets.py @@ -5,20 +5,17 @@ format to OpenVLA, IterableDataset shim. """ +import collections from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Tuple, Type -import collections -import os +from typing import Any, Dict, Tuple + import numpy as np -import cv2 import torch from PIL import Image -import torchvision -from torchvision.transforms import transforms from torch.utils.data import Dataset, IterableDataset +from torchvision.transforms import transforms from transformers import PreTrainedTokenizerBase -from data.utils.som_tom import som_prompting, tom_prompting # from prismatic.models.backbones.llm.prompting import PromptBuilder # from prismatic.models.backbones.vision import ImageTransform @@ -29,7 +26,8 @@ # HuggingFace Default / LLaMa-2 IGNORE_INDEX (for labels) IGNORE_INDEX = -100 -from typing import Callable, Dict, Sequence, Tuple +from typing import Callable, Sequence + def tree_map(fn: Callable, tree: dict) -> dict: """Maps a function over a nested dictionary.""" diff --git a/data/openx/datasets/rlds/dataset.py b/data/openx/datasets/rlds/dataset.py index 7f38a116..84be0938 100644 --- a/data/openx/datasets/rlds/dataset.py +++ b/data/openx/datasets/rlds/dataset.py @@ -7,15 +7,15 @@ import copy import inspect import json +import logging from functools import partial from typing import Callable, Dict, List, Optional, Tuple, Union -import logging -import torch.distributed as dist import dlimp as dl import numpy as np import tensorflow as tf import tensorflow_datasets as tfds +import torch.distributed as dist from data.openx.datasets.rlds import obs_transforms, traj_transforms from data.openx.datasets.rlds.utils import goal_relabeling, task_augmentation diff --git a/data/openx/datasets/rlds/datasets_latent.py b/data/openx/datasets/rlds/datasets_latent.py index 9ad15bed..12de0bf9 100644 --- a/data/openx/datasets/rlds/datasets_latent.py +++ b/data/openx/datasets/rlds/datasets_latent.py @@ -14,7 +14,6 @@ import numpy as np import tensorflow as tf import tensorflow_datasets as tfds - from prismatic.logging import initialize_logging from prismatic.vla.datasets.rlds import obs_transforms, traj_transforms from prismatic.vla.datasets.rlds.utils import goal_relabeling, task_augmentation diff --git a/data/openx/datasets/rlds/utils/data_utils.py b/data/openx/datasets/rlds/utils/data_utils.py index f520b2ef..e85c1e4d 100644 --- a/data/openx/datasets/rlds/utils/data_utils.py +++ b/data/openx/datasets/rlds/utils/data_utils.py @@ -6,6 +6,7 @@ import hashlib import json +import logging import os from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple @@ -14,7 +15,6 @@ import numpy as np import tensorflow as tf from tqdm import tqdm -import logging # from prismatic.logging import initialize_logging diff --git a/data/openx/materialize.py b/data/openx/materialize.py index ac1a0e3b..84209b43 100644 --- a/data/openx/materialize.py +++ b/data/openx/materialize.py @@ -5,17 +5,18 @@ exports individual functions for clear control flow. """ +from dataclasses import dataclass from pathlib import Path -from typing import Tuple, Type, Dict, Sequence -from dataclasses import dataclass, field +from typing import Dict, Sequence, Tuple +import torch from torch.utils.data import Dataset from transformers import PreTrainedTokenizerBase -import torch from .action_tokenizer import ActionTokenizer from .datasets import EpisodicRLDSDataset, RLDSBatchTransform, RLDSDataset + @dataclass class PaddedCollatorForLanguageModeling: model_max_length: int diff --git a/data/openx_magma/data_utils.py b/data/openx_magma/data_utils.py index 83eedbde..0a4711cd 100644 --- a/data/openx_magma/data_utils.py +++ b/data/openx_magma/data_utils.py @@ -1,15 +1,15 @@ +import os + +import cv2 import torch import torchvision -import re -import cv2 -import numpy as np -import os import yaml -from PIL import Image -from data.utils.visual_trace import visual_trace -from data.utils.som_tom import som_prompting, tom_prompting + from data.conversations import Constructor from data.openx.action_tokenizer import ActionTokenizer +from data.utils.som_tom import som_prompting +from data.utils.visual_trace import visual_trace + class OpenXMagma(Constructor): def __init__(self, **kwargs): diff --git a/data/seeclick/data_utils.py b/data/seeclick/data_utils.py index aafb00bf..2c2e010c 100644 --- a/data/seeclick/data_utils.py +++ b/data/seeclick/data_utils.py @@ -1,16 +1,12 @@ -import torch -import torchvision -import re -import cv2 -import numpy as np import os +import re + import yaml from tqdm import tqdm -from PIL import Image -from data.utils.visual_trace import visual_trace -from data.utils.som_tom import som_prompting, tom_prompting + from data.conversations import Constructor + class SeeClick(Constructor): def __init__(self, **kwargs): super(SeeClick, self).__init__(**kwargs) diff --git a/data/utils/som_tom.py b/data/utils/som_tom.py index 755d7ba8..b969dd21 100644 --- a/data/utils/som_tom.py +++ b/data/utils/som_tom.py @@ -1,7 +1,8 @@ +import matplotlib.pyplot as plt +import numpy as np import torch from PIL import Image, ImageDraw, ImageFont -import numpy as np -import matplotlib.pyplot as plt + def som_prompting(image, pos_traces, neg_traces, draw_som_positive=False, draw_som_negative=False): """ diff --git a/data/utils/visual_trace.py b/data/utils/visual_trace.py index 5c061ac6..c391f941 100644 --- a/data/utils/visual_trace.py +++ b/data/utils/visual_trace.py @@ -1,16 +1,9 @@ -import io -import os -import cv2 -import json -import torch -import numpy as np -from PIL import Image -from IPython import display -from tqdm import tqdm -from cotracker.utils.visualizer import Visualizer, read_video_from_path -from matplotlib import cm import faiss +import torch +from cotracker.utils.visualizer import Visualizer from kmeans_pytorch import kmeans +from matplotlib import cm + class visual_trace(): def __init__( diff --git a/data/utils/visual_tracker.py b/data/utils/visual_tracker.py index 57e4f56d..79a6dbb5 100644 --- a/data/utils/visual_tracker.py +++ b/data/utils/visual_tracker.py @@ -1,16 +1,9 @@ -import io -import os -import cv2 -import json +import faiss import torch -import numpy as np -from PIL import Image -from IPython import display -from tqdm import tqdm -from cotracker.utils.visualizer import Visualizer, read_video_from_path from cotracker.predictor import CoTrackerPredictor +from cotracker.utils.visualizer import Visualizer from matplotlib import cm -import faiss + class visual_tracker(): def __init__( diff --git a/magma/configuration_magma.py b/magma/configuration_magma.py index 51a79ef1..06df5e39 100644 --- a/magma/configuration_magma.py +++ b/magma/configuration_magma.py @@ -20,8 +20,8 @@ """Magma model configuration""" from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging from transformers.models.auto import CONFIG_MAPPING +from transformers.utils import logging logger = logging.get_logger(__name__) diff --git a/magma/image_processing_magma.py b/magma/image_processing_magma.py index 1e889046..92bb95dc 100644 --- a/magma/image_processing_magma.py +++ b/magma/image_processing_magma.py @@ -15,9 +15,10 @@ """Image processor class for Magma.""" -from typing import List, Optional, Union import ast -import numpy as np +from typing import List, Optional, Union + +from transformers import AutoImageProcessor from transformers.image_processing_utils import BaseImageProcessor, BatchFeature from transformers.image_transforms import ( convert_to_rgb, @@ -31,17 +32,16 @@ ) from transformers.utils import TensorType, is_vision_available, logging -from transformers import AutoImageProcessor - logger = logging.get_logger(__name__) if is_vision_available(): - from PIL import Image + pass import torch import torchvision + def select_best_resolution(original_size, possible_resolutions): """ Selects the best resolution from a list of possible resolutions based on the original size. diff --git a/magma/image_tower_magma.py b/magma/image_tower_magma.py index d1bb0139..6b0fa847 100755 --- a/magma/image_tower_magma.py +++ b/magma/image_tower_magma.py @@ -15,33 +15,22 @@ """Image processor class for Magma.""" -from typing import List, Optional, Union import logging +from typing import Optional, Union # Configure root logger logging.basicConfig(level=logging.INFO) -from transformers.image_processing_utils import BaseImageProcessor, BatchFeature -from transformers.image_transforms import ( - convert_to_rgb, -) -from transformers.image_utils import ( - OPENAI_CLIP_MEAN, - OPENAI_CLIP_STD, - ImageInput, - make_list_of_images, - valid_images, -) - -from transformers.utils import TensorType, is_vision_available, logging + +from transformers.utils import is_vision_available, logging + # logging.set_verbosity_info() logger = logging.get_logger(__name__) if is_vision_available(): - from PIL import Image + pass -import torchvision # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. @@ -49,22 +38,17 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import json -import torch -import torch.nn as nn -import torch.nn.functional as F +from dataclasses import asdict +from typing import Any, Dict, Tuple import open_clip -from open_clip.transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs -from open_clip.pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\ - list_pretrained_tags_by_model, download_pretrained_from_hf -from open_clip.model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ - resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg -from pathlib import Path -from typing import Optional, Tuple, Type -from functools import partial +import torch +import torch.nn as nn import torch.utils.checkpoint as checkpoint -from typing import Any, Dict, Optional, Tuple, Union -from dataclasses import asdict +from open_clip.model import CLIP, CustomTextCLIP, convert_weights_to_lp, get_cast_dtype, set_model_preprocess_cfg +from open_clip.pretrained import download_pretrained, download_pretrained_from_hf, get_pretrained_cfg +from open_clip.transform import AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs + HF_HUB_PREFIX = 'hf-hub:' def _get_hf_config(model_id, cache_dir=None): diff --git a/magma/modeling_magma.py b/magma/modeling_magma.py index ab210fd2..2bfb1132 100755 --- a/magma/modeling_magma.py +++ b/magma/modeling_magma.py @@ -14,30 +14,27 @@ # limitations under the License. """PyTorch Magma model.""" -import math -import re import os +import re from dataclasses import dataclass from typing import List, Optional, Tuple, Union -import numpy as np import torch +import torch.distributed as dist import torch.utils.checkpoint -from torch import nn import wandb -import torch.distributed as dist +from torch import nn +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.cache_utils import Cache from transformers.modeling_utils import PreTrainedModel -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.utils import ModelOutput from transformers.utils import ( - add_code_sample_docstrings, + ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ) -from transformers import AutoConfig, AutoModelForCausalLM + from .configuration_magma import MagmaConfig from .image_tower_magma import MagmaImageTower @@ -327,7 +324,6 @@ def tie_weights(self): return self.language_model.tie_weights() def load_special_module_from_ckpt(self, ckpt_path, torch_dtype=None): - from deepspeed.runtime.zero import Init from deepspeed import zero # Defer initialization for ZeRO-3 compatibility # with Init(data_parallel_group=None): diff --git a/magma/processing_magma.py b/magma/processing_magma.py index 18c65daa..121fac67 100644 --- a/magma/processing_magma.py +++ b/magma/processing_magma.py @@ -23,7 +23,6 @@ from transformers.processing_utils import ProcessorMixin from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy from transformers.utils import TensorType -from .configuration_magma import MagmaConfig class MagmaProcessor(ProcessorMixin): diff --git a/pyproject.toml b/pyproject.toml index d8e63c39..0f4ef4d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,3 +132,20 @@ exclude = [ "wandb", "docs", ] + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["E", "F", "I"] +# F821: missing imports in several files (real issues, out of scope for this PR) +# I001: inline `import pdb; pdb.set_trace()` debug statements trigger block-level sort check +# E101: mixed tabs/spaces in tools/simplerenv-magma +# F811: intentional subclassing of imported TrainingArguments in train.py +ignore = ["E501", "E402", "E731", "E722", "E741", "E721", "E701", "E702", "E711", "I001", "F841", "F403", "F405", "F821", "E101", "F811"] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] + +[tool.uv] +# Install with: uv sync diff --git a/server/test_api.py b/server/test_api.py index fb1f5ecf..b2f88ab6 100755 --- a/server/test_api.py +++ b/server/test_api.py @@ -3,14 +3,13 @@ Test script for the Magma API service. This script tests if the API is running correctly and can load the model. """ -import requests import argparse -import time import base64 -from PIL import Image -import io -import json import sys +import time + +import requests + def test_health(base_url): """Test the health endpoint""" diff --git a/tools/lmms-eval-magma/magma.py b/tools/lmms-eval-magma/magma.py index 0fd7f946..f8ff82a8 100644 --- a/tools/lmms-eval-magma/magma.py +++ b/tools/lmms-eval-magma/magma.py @@ -1,26 +1,23 @@ -import os -import uuid import warnings from typing import List, Optional, Tuple, Union +import numpy as np +import PIL import torch from accelerate import Accelerator, DistributedType -from tqdm import tqdm -import PIL -from torchvision.transforms.functional import to_pil_image from decord import VideoReader, cpu -import numpy as np -from lmms_eval import utils from lmms_eval.api.instance import Instance from lmms_eval.api.model import lmms from lmms_eval.api.registry import register_model -from lmms_eval.models.model_utils.qwen.qwen_generate_utils import make_context +from torchvision.transforms.functional import to_pil_image +from tqdm import tqdm warnings.simplefilter("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore") from loguru import logger as eval_logger -from transformers import AutoModelForCausalLM, AutoProcessor +from transformers import AutoModelForCausalLM, AutoProcessor + @register_model("magma") class Magma(lmms): diff --git a/tools/simplerenv-magma/simpler_env/main_inference_magma.py b/tools/simplerenv-magma/simpler_env/main_inference_magma.py index 14622144..8a707700 100644 --- a/tools/simplerenv-magma/simpler_env/main_inference_magma.py +++ b/tools/simplerenv-magma/simpler_env/main_inference_magma.py @@ -2,11 +2,11 @@ import numpy as np import tensorflow as tf - from simpler_env.evaluation.argparse import get_args from simpler_env.evaluation.maniskill2_evaluator import maniskill2_evaluator -from simpler_env.policies.octo.octo_server_model import OctoServerInference from simpler_env.policies.magma.magma_model import MagmaInference +from simpler_env.policies.octo.octo_server_model import OctoServerInference + try: from simpler_env.policies.octo.octo_model import OctoInference except ImportError as e: diff --git a/tools/simplerenv-magma/simpler_env/policies/magma/magma_model.py b/tools/simplerenv-magma/simpler_env/policies/magma/magma_model.py index 2b950f6f..bf8d6d9f 100644 --- a/tools/simplerenv-magma/simpler_env/policies/magma/magma_model.py +++ b/tools/simplerenv-magma/simpler_env/policies/magma/magma_model.py @@ -1,15 +1,11 @@ -import numpy as np -from PIL import Image import random + +import numpy as np import torch -import torchvision -import json -import sys -import os -from transformers import AutoProcessor, AutoModelForCausalLM +from PIL import Image +from transformers import AutoModelForCausalLM, AutoProcessor from transforms3d.euler import euler2axangle - action_norm_stats = { "bridge_orig": {'mask': [True, True, True, True, True, True, False], 'max': [0.41691166162490845, 0.25864794850349426, 0.21218234300613403, 3.122201919555664, 1.8618112802505493, 6.280478477478027, 1.0], 'mean': [0.0002334194869035855, 0.00013004911306779832, -0.00012762474943883717, -0.0001556558854645118, -0.0004039328487124294, 0.00023557482927571982, 0.5764579176902771], 'min': [-0.4007510244846344, -0.13874775171279907, -0.22553899884223938, -3.2010786533355713, -1.8618112802505493, -6.279075622558594, 0.0], 'q01': [-0.02872725307941437, -0.04170349963009357, -0.026093858778476715, -0.08092105075716972, -0.09288699507713317, -0.20718276381492615, 0.0], 'q99': [0.028309678435325586, 0.040855254605412394, 0.040161586627364146, 0.08192047759890528, 0.07792850524187081, 0.20382574498653397, 1.0], 'std': [0.009765930473804474, 0.013689135201275349, 0.012667362578213215, 0.028534092009067535, 0.030637972056865692, 0.07691419124603271, 0.4973701536655426]}, "google_robot": {'mask': [True, True, True, True, True, True, False], 'max': [2.9984593391418457, 22.09052848815918, 2.7507524490356445, 1.570636510848999, 1.5321086645126343, 1.5691522359848022, 1.0], 'mean': [0.006987582892179489, 0.006265917327255011, -0.01262515690177679, 0.04333311319351196, -0.005756212864071131, 0.0009130256366916001, 0.5354204773902893], 'min': [-2.0204520225524902, -5.497899532318115, -2.031663417816162, -1.569917917251587, -1.569892168045044, -1.570419430732727, 0.0], 'q01': [-0.22453527510166169, -0.14820013284683228, -0.231589707583189, -0.3517994859814644, -0.4193011274933815, -0.43643461108207704, 0.0], 'q99': [0.17824687153100965, 0.14938379630446405, 0.21842354819178575, 0.5892666035890578, 0.35272657424211445, 0.44796681255102094, 1.0], 'std': [0.0692116990685463, 0.05970962345600128, 0.07353084534406662, 0.15610496699810028, 0.13164450228214264, 0.14593800902366638, 0.497110515832901]} diff --git a/tools/som_tom/demo.py b/tools/som_tom/demo.py index debfd322..5b990774 100644 --- a/tools/som_tom/demo.py +++ b/tools/som_tom/demo.py @@ -1,18 +1,12 @@ -import os -import json import cv2 -import csv -import io import numpy as np -from PIL import Image -from tqdm import tqdm - import torch import torchvision - from cotracker.utils.visualizer import Visualizer +from PIL import Image + +from data.utils.som_tom import som_prompting from data.utils.visual_trace import visual_trace -from data.utils.som_tom import som_prompting, tom_prompting device = 'cuda' grid_size = 15 @@ -80,7 +74,7 @@ def som_tom(video, pred_tracks, pred_visibility, item={}, epsilon=2): # visualize the traces images = [image] * pos_tracks.shape[1] video = torch.stack([torchvision.transforms.ToTensor()(img) for img in images])[None].float()*255 - _ = vis.visualize(video, pos_tracks, pos_visibility, filename=f"som_tom") + _ = vis.visualize(video, pos_tracks, pos_visibility, filename="som_tom") video_path = "assets/videos/tom_orig_sample.mp4" # load video diff --git a/tools/video_preprocessing/curate_list.py b/tools/video_preprocessing/curate_list.py index 8cd242f4..de50a710 100644 --- a/tools/video_preprocessing/curate_list.py +++ b/tools/video_preprocessing/curate_list.py @@ -1,17 +1,11 @@ -import torch -import torchvision -from torch.utils.data import DataLoader -import os -import sys import argparse -from typing import Dict, Optional, Sequence, List -from dataclasses import dataclass, field -import clip -import multiprocessing as mp -from dataloader import * -import threading import json +import os import pickle +import threading + +import torch +from dataloader import * parser = argparse.ArgumentParser('') parser.add_argument('--dataset_name', type=str, default="video-dataset", metavar='DN', diff --git a/tools/video_preprocessing/run_clip_filtering.py b/tools/video_preprocessing/run_clip_filtering.py index 68e563f3..19e2887c 100644 --- a/tools/video_preprocessing/run_clip_filtering.py +++ b/tools/video_preprocessing/run_clip_filtering.py @@ -1,16 +1,14 @@ +import argparse +import concurrent.futures +import os +from dataclasses import dataclass +from typing import Dict, Sequence + +import clip import torch import torchvision from torch.utils.data import DataLoader, Dataset -from torchvision import datasets -from torchvision.transforms import Resize from torchvision.transforms import ToPILImage -import os -import sys -import argparse -from typing import Dict, Optional, Sequence, List -from dataclasses import dataclass, field -import clip -import concurrent.futures parser = argparse.ArgumentParser('') parser.add_argument('--dataset_name', type=str, default="video-dataset", metavar='DN', @@ -126,7 +124,6 @@ def main(): if not os.path.exists(args.output_dir): os.mkdir(args.output_dir) - import clip model, preprocess = clip.load("ViT-L/14@336px", device='cuda') model.eval() diff --git a/tools/video_preprocessing/run_detect_segments.py b/tools/video_preprocessing/run_detect_segments.py index 009c70a1..67d3701e 100644 --- a/tools/video_preprocessing/run_detect_segments.py +++ b/tools/video_preprocessing/run_detect_segments.py @@ -1,16 +1,13 @@ -import torch -import json -import cv2 -import os -import sys -import csv -import pickle import argparse -import random -import numpy as np +import json import multiprocessing as mp +import os + +import cv2 import imageio -from scenedetect import detect, ContentDetector +import numpy as np +import torch +from scenedetect import ContentDetector, detect parser = argparse.ArgumentParser('') parser.add_argument('--ann_path', type=str, default="/path/to/json/file", metavar='AP', diff --git a/train.py b/train.py index 0a80bdec..0352a2b6 100644 --- a/train.py +++ b/train.py @@ -14,37 +14,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import copy -from dataclasses import dataclass, field import json import logging +import os import pathlib -from typing import Dict, Optional, Sequence, List +from dataclasses import dataclass, field +from typing import Dict, Optional, Sequence + import torch -import deepspeed -import glob import transformers -import tokenizers -import random -import re - -from magma.image_processing_magma import MagmaImageProcessor -from magma.processing_magma import MagmaProcessor -from magma.modeling_magma import MagmaForCausalLM -from magma.configuration_magma import MagmaConfig from transformers import ( + AutoConfig, AutoModelForCausalLM, - AutoProcessor, - BitsAndBytesConfig, - Trainer, - TrainingArguments, + AutoTokenizer, + TrainingArguments, ) -from transformers import AutoTokenizer, AutoConfig from transformers.trainer import get_model_param_count -from trainer import MagmaTrainer from data import * +from magma.configuration_magma import MagmaConfig +from magma.image_processing_magma import MagmaImageProcessor +from magma.modeling_magma import MagmaForCausalLM +from magma.processing_magma import MagmaProcessor +from trainer import MagmaTrainer local_rank = None @@ -52,7 +44,6 @@ def rank0_print(*args): if local_rank == 0: print(*args) -from packaging import version @dataclass class ModelArguments: @@ -237,7 +228,7 @@ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, os.makedirs(mm_projector_folder, exist_ok=True) torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) else: - torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) + torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin')) return if trainer.deepspeed: @@ -423,7 +414,7 @@ def train(): **bnb_model_from_pretrained_args ) # reload vision encoder - from open_clip.pretrained import download_pretrained_from_hf + from open_clip.pretrained import download_pretrained_from_hf if vision_config['vision_tower'] == 'convnext': model_id = 'laion/CLIP-convnext_large-laion2B-s34B-b82K-augreg' else: diff --git a/trainer/trainer.py b/trainer/trainer.py index cfbef6c6..39678ecf 100644 --- a/trainer/trainer.py +++ b/trainer/trainer.py @@ -1,18 +1,17 @@ import os -import torch +from typing import List, Optional -from torch.utils.data import Sampler +import torch from torch.cuda import synchronize - +from torch.utils.data import Sampler from transformers import Trainer from transformers.trainer import ( - is_sagemaker_mp_enabled, + ALL_LAYERNORM_LAYERS, get_parameter_names, has_length, - ALL_LAYERNORM_LAYERS, + is_sagemaker_mp_enabled, logger, ) -from typing import List, Optional def maybe_zero_3(param, ignore_status=False, name=None): @@ -296,8 +295,8 @@ def _save_checkpoint(self, model, trial, metrics=None): if self.args.local_rank == 0 or self.args.local_rank == -1: self.model.config.save_pretrained(output_dir) print(f"keys to match: {keys_to_match}") - print(f"save checkpoint to {os.path.join(output_dir, f'mm_projector.bin')}") - torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) + print(f"save checkpoint to {os.path.join(output_dir, 'mm_projector.bin')}") + torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin')) else: super(MagmaTrainer, self)._save_checkpoint(model, trial, metrics)