diff --git a/extract_features/README.md b/extract_features/README.md new file mode 100644 index 0000000..8416477 --- /dev/null +++ b/extract_features/README.md @@ -0,0 +1,5 @@ +To run: + +pip install geowatch[all-strict] +pip install --prefer-binary GDAL>=3.4.1 --find-links https://girder.github.io/large_image_wheels +pip install timm \ No newline at end of file diff --git a/extract_features/create_coco_file.py b/extract_features/create_coco_file.py new file mode 100644 index 0000000..f27908a --- /dev/null +++ b/extract_features/create_coco_file.py @@ -0,0 +1,111 @@ +# %% Load Libraries +import numpy as np +from pathlib import Path +import copy +import json +from rich.progress import Progress + + +import kwimage +import kwcoco + +import logging +from rich.logging import RichHandler + +logging.basicConfig( + level="NOTSET", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()] +) + +log = logging.getLogger("rich") + + +# %% Set Parameters + +path_to_dataset_folder = '/mnt/3bcbbdb5-8833-4d40-a51a-7fca7f39ec24/data2/xview2/geotiffs' +output_coco_file = './xview2.coco.zip' + +# %% Load dataset + + +dataset = Path(path_to_dataset_folder) + +img_list = list(dataset.glob('*/*/*.tif')) +# %% Test to see if labels for each image + +def get_labels_path(img_path): + labels_path = copy.deepcopy(img_path) + return Path(img_path.parent.parent,'labels',img_path.stem + '.json') + +for img_path in img_list: + labels_path = get_labels_path(img_path) + if not labels_path.exists(): + log.error(labels_path) + +else: + log.info('All label files found') + + + + +# %% Create COCO + +coco = kwcoco.CocoDataset() + +coco.add_category('un-classified') +coco.add_category('no-damage') +coco.add_category('minor-damage') +coco.add_category('major-damage') +coco.add_category('destroyed') + +n = len(img_list) + +with Progress() as progress: + task1 = progress.add_task("[red]running_image...", total=n) + + for img_path in img_list: + progress.update(task1, advance=1, refresh=True) + + # Load labels + labels_path = get_labels_path(img_path) + with open(labels_path,'r') as f: + label_json = json.load(f) + + + # Collect Metadata + metadata = label_json['metadata'] + pre_disaster = labels_path.name[-17:-14] == 'pre' + if pre_disaster: + time = 'pre_disaster' + else: + time = 'post disaster' + + # add image to coco + gid = coco.add_image(img_path, width=metadata['width'], + height=metadata['height'], metadata=metadata, + split=img_path.parts[-3], time=time) + + # For each annotation, extract info and add to coco + for d_ann in label_json['features']['xy']: + properties = d_ann['properties'] + polygon = d_ann['wkt'] + bbox = kwimage.Polygon.coerce(d_ann['wkt']).to_box().toformat('xywh') + + #pre disaster doesn't have labels so assuming no-damage + if pre_disaster: + subtype = 'no-damage' + else: + subtype = properties['subtype'] + + cid = coco.name_to_cat[subtype]['id'] + coco.add_annotation(image_id=gid, category_id=cid, + bbox=bbox, polygon=polygon, properties=properties) + + +# %% Save Dataset + +# coco._ensure_json_serializable() + +coco.fpath = output_coco_file +coco.dump() + +# %% diff --git a/extract_features/extract_features.py b/extract_features/extract_features.py new file mode 100644 index 0000000..80aa237 --- /dev/null +++ b/extract_features/extract_features.py @@ -0,0 +1,145 @@ +# %% Load Libraries +import numpy as np + +import kwimage +import kwcoco + +import timm +import torch +from torch.cuda.amp import autocast +from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, MofNCompleteColumn, TimeElapsedColumn, TimeRemainingColumn +import logging +from rich.logging import RichHandler + + +logging.basicConfig( + level="NOTSET", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()] +) + +log = logging.getLogger("rich") + +# %% Set Parameters + + +# Model to extract out the files +timm_model_name = 'convnext_xlarge_in22ft1k' + +# Coco file you want to extract feature out of +input_coco_filename = './xview2.coco.json' + +# Coco file you want the features to be saved into (will retain other images) +output_coco_filename = f'/mnt/3bcbbdb5-8833-4d40-a51a-7fca7f39ec24/data2/xview2/xview2_{timm_model_name}.coco.json' + +device = 'cuda' + + +# %% Load the data + + +coco = kwcoco.CocoDataset.coerce(input_coco_filename) + + +# %% Setup Feature Extractor + +# Check to see if model is available +avail_pretrained_models = timm.list_models(pretrained=True) + +if timm_model_name not in avail_pretrained_models: + log.error(f'Model called "{timm_model_name}". Here are all the models {avail_pretrained_models}') + raise Exception("Fix Network Name") + +# Will download pretrained network if not already downloaded +feature_model = timm.create_model(timm_model_name, pretrained=True, num_classes=0) +feature_model.to(device) +# logit_model = timm.create_model(timm_model_name, pretrained=True) +# logit_model.to(device) + + +model_info = timm.data.resolve_data_config(args={},model=feature_model) + +feature_model.eval() +# logit_model.eval() + +def predict_feature(chip=None): + # return dict({'logits':logit_model(chip), 'features':feature_model(chip)}) + return dict({'features':feature_model(chip)}) + +# x = torch.rand((1,3,model_info['input_size'][1], model_info['input_size'][2])).to(device) +# y = feature_model(x) +# feature_array = np.zeros((coco.n_annots,y.shape[1])) +# y = logit_model(x) +# logit_array = np.zeros((coco.n_annots,y.shape[1])) +# del y + +# %% +def prepare_image(x): + x = torch.asarray(img).to(device) + x = x.transpose(0,2) + x = x / 255.0 + x[0] = x[0] - model_info['mean'][0] / model_info['std'][0] + x[0] = x[0] - model_info['mean'][1] / model_info['std'][1] + x[0] = x[0] - model_info['mean'][2] / model_info['std'][2] + return x[None, :] + + +with Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), +) as progress: + task1 = progress.add_task("[red] Images...", total=coco.n_images) + task2 = progress.add_task("[blue] Annots...", total=coco.n_annots) + id = 0 + for gid in coco.images(): + coco_img = coco.coco_image(gid) + annots = coco_img.annots() + for ann in annots.objs: + box = kwimage.Box.coerce(ann['bbox'], format='xywh') + box = box.to_ltrb() + delay_img = coco_img.imdelay() + delay_img = delay_img.crop(box.quantize().to_slice(), clip=False, wrap=False) + delay_img = delay_img.resize(model_info['input_size'][1:]) + img = delay_img.finalize() + x = prepare_image(img) + + feats = predict_feature(x) + ann['features'] = feats['features'].tolist() + # ann['logits'] = feats['logits'].tolist() + # plt.imshow(img) + + progress.update(task2, advance=1, refresh=True) + + progress.update(task1, advance=1, refresh=True) + + + +# %% +coco._ensure_json_serializable() + +coco.fpath = output_coco_filename +coco.dump() + + + +# %% + +# coco2 = coco.view_sql() +# coco2.pandas_table('images') + + +# %% + +# annots = coco.annots() +# # %% + + +# data = annots.lookup('image_id') +# flags = [x%2==0 for x in data] +# annots.compress(flags) + +# annots.__dict__ + +# %%