diff --git a/README.md b/README.md
index b37f228..6189a19 100644
--- a/README.md
+++ b/README.md
@@ -30,11 +30,12 @@ Key Contributions:
 
 ## Release
-- [10/30] 🔥 We released the code of **FERRET** model.
+- [12/14] 🔥 We released the [checkpoints (7B, 13B)](#checkpoints).
+- [10/30] 🔥 We released the code of the **FERRET** model and [Ferret-Bench](ferret/eval/ferret_gpt4_data).
 
-**Usage and License Notices**: The data, and code is intended and licensed for research use only. They are also restricted to uses that follow the license agreement of LLaMA, Vicuna and GPT-4. The dataset is CC BY NC 4.0 (allowing only non-commercial use) and models trained using the dataset should not be used outside of research purposes.
+**Usage and License Notices**: The data and code are intended and licensed for research use only. They are also restricted to uses that follow the license agreements of LLaMA, Vicuna, and GPT-4. The dataset is CC BY NC 4.0 (allowing only non-commercial use), and models trained using the dataset should not be used outside of research purposes.
 
 ## Contents
 - [Install](#install)
@@ -95,6 +96,25 @@ The scripts are provided ([7B](experiments/ferret_7b_train.sh), [13B](experiment
 
 Please see this [doc](EVAL.md) for the details.
 
+## Checkpoints
+We extracted the `delta` between our pre-trained model and Vicuna. First, download the Vicuna weights following the [previous instruction](#prepare-vicuna-checkpoint-and-llavas-projector). Then download our prepared weight deltas ([7B](https://docs-assets.developer.apple.com/ml-research/models/ferret/ferret-7b/ferret-7b-delta.zip), [13B](https://docs-assets.developer.apple.com/ml-research/models/ferret/ferret-13b/ferret-13b-delta.zip)) using `wget` or `curl`, and unzip them. Finally, apply the deltas to the Vicuna weights by running the following script:
+```Shell
+# 7B
+python3 -m ferret.model.apply_delta \
+    --base ./model/vicuna-7b-v1-3 \
+    --target ./model/ferret-7b-v1-3 \
+    --delta path/to/ferret-7b-delta
+# 13B
+python3 -m ferret.model.apply_delta \
+    --base ./model/vicuna-13b-v1-3 \
+    --target ./model/ferret-13b-v1-3 \
+    --delta path/to/ferret-13b-delta
+```
+
+**Notices**: Apple's rights in the attached weight differentials are hereby licensed under the CC-BY-NC license. Apple makes no representations with regards to LLaMa or any other third party software, which are subject to their own terms.
+
+Please refer to the next section for how to set up a local demo with the pre-trained weights.
+
 ## Demo
 
 To run our demo, you need to train FERRET and use the checkpoints locally. Gradio web UI is used. Please run the following commands one by one.
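Once the merge finishes, a cheap sanity check before setting up the demo is to load the merged checkpoint once. Below is a minimal smoke test, assuming the 7B merge above wrote its output (including tokenizer files) to `./model/ferret-7b-v1-3`; it is illustrative only and not part of the release scripts:

```python
import torch
from transformers import AutoTokenizer
from ferret import FERRETLlamaForCausalLM

# Illustrative path: assumes the apply_delta command above wrote the merged
# model (and tokenizer files) to ./model/ferret-7b-v1-3.
merged_path = "./model/ferret-7b-v1-3"
tokenizer = AutoTokenizer.from_pretrained(merged_path)
model = FERRETLlamaForCausalLM.from_pretrained(
    merged_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
print(f"loaded {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B parameters")
```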
diff --git a/ferret/model/apply_delta.py b/ferret/model/apply_delta.py
index 75abb6e..5b0cd68 100644
--- a/ferret/model/apply_delta.py
+++ b/ferret/model/apply_delta.py
@@ -1,6 +1,16 @@
 """
 Usage:
-python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
+# 7B
+python3 -m ferret.model.apply_delta \
+    --base ./model/vicuna-7b-v1-3 \
+    --target ./model/ferret-7b-v1-3 \
+    --delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/ferret-7b-delta
+
+# 13B
+python3 -m ferret.model.apply_delta \
+    --base ./model/vicuna-13b-v1-3 \
+    --target ./model/ferret-13b-v1-3 \
+    --delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/ferret-13b-delta
 """
 import argparse
 
@@ -10,6 +20,18 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from ferret import FERRETLlamaForCausalLM
 
+exclude_name_lists = ['model.mm_projector.weight', 'model.mm_projector.bias',
+                      'model.region_geo_sampler.agg_projector_list.0.net.0.bias', 'model.region_geo_sampler.agg_projector_list.0.net.0.weight',
+                      'model.region_geo_sampler.agg_projector_list.0.norm.bias', 'model.region_geo_sampler.agg_projector_list.0.norm.weight',
+                      'model.region_geo_sampler.agg_projector_list.1.net.0.bias', 'model.region_geo_sampler.agg_projector_list.1.net.0.weight',
+                      'model.region_geo_sampler.agg_projector_list.1.norm.bias', 'model.region_geo_sampler.agg_projector_list.1.norm.weight',
+                      'model.region_geo_sampler.diff_projector_list.0.bias', 'model.region_geo_sampler.diff_projector_list.0.weight',
+                      'model.region_geo_sampler.diff_projector_list.1.bias', 'model.region_geo_sampler.diff_projector_list.1.weight',
+                      'model.region_geo_sampler.dim_projector.bias', 'model.region_geo_sampler.dim_projector.weight',
+                      'model.region_geo_sampler.flatten_projector.bias', 'model.region_geo_sampler.flatten_projector.weight'
+                      ]
+
+
 def apply_delta(base_model_path, target_model_path, delta_path):
     print("Loading base model")
     base = AutoModelForCausalLM.from_pretrained(
@@ -22,7 +44,7 @@ def apply_delta(base_model_path, target_model_path, delta_path):
     print("Applying delta")
     for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
         if name not in base.state_dict():
-            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
+            assert name in exclude_name_lists, f'{name} not in base model'
             continue
         if param.data.shape == base.state_dict()[name].shape:
             param.data += base.state_dict()[name]
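For readers skimming the hunk above: the merge rule is elementwise addition with two special cases. Ferret-only modules (the geo sampler and mm projector, whitelisted in `exclude_name_lists`) are stored whole in the delta and skipped, and matrices whose vocabulary dimension grew (`model.embed_tokens.weight`, `lm_head.weight`) receive the base only in the overlapping rows. A toy re-enactment of that arithmetic, with made-up names and shapes:

```python
import torch

# Toy version of the merge in apply_delta (made-up 4-token base vocab that
# grew to 6 tokens in the target; real shapes differ).
base = {"model.embed_tokens.weight": torch.ones(4, 8),
        "model.layers.0.mlp.w.weight": torch.full((8, 8), 2.0)}
delta = {"model.embed_tokens.weight": torch.zeros(6, 8),     # 2 new token rows
         "model.layers.0.mlp.w.weight": torch.full((8, 8), 0.5),
         "model.mm_projector.weight": torch.randn(8, 8)}     # Ferret-only module

for name, param in delta.items():
    if name not in base:
        continue  # Ferret-only: the delta already stores the full weights
    if param.shape == base[name].shape:
        param += base[name]
    else:
        # Vocab grew: add the base into the overlapping rows only; the new
        # token rows keep the values stored in the delta.
        bparam = base[name]
        param[:bparam.shape[0], :bparam.shape[1]] += bparam

print(delta["model.embed_tokens.weight"][:, 0])  # tensor([1., 1., 1., 1., 0., 0.])
```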
diff --git a/ferret/model/make_delta.py b/ferret/model/make_delta.py
new file mode 100644
index 0000000..f7781e9
--- /dev/null
+++ b/ferret/model/make_delta.py
@@ -0,0 +1,74 @@
+"""
+Usage:
+# 7B
+python3 -m ferret.model.make_delta \
+    --base ./model/vicuna-7b-v1-3 \
+    --target ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
+    --delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/ferret-7b-delta
+
+# 13B
+python3 -m ferret.model.make_delta \
+    --base ./model/vicuna-13b-v1-3 \
+    --target ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \
+    --delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/ferret-13b-delta
+"""
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from ferret.model.utils import auto_upgrade
+
+# all the parameters inside the geosampler and mm projector
+exclude_name_lists = ['model.mm_projector.weight', 'model.mm_projector.bias',
+                      'model.region_geo_sampler.agg_projector_list.0.net.0.bias', 'model.region_geo_sampler.agg_projector_list.0.net.0.weight',
+                      'model.region_geo_sampler.agg_projector_list.0.norm.bias', 'model.region_geo_sampler.agg_projector_list.0.norm.weight',
+                      'model.region_geo_sampler.agg_projector_list.1.net.0.bias', 'model.region_geo_sampler.agg_projector_list.1.net.0.weight',
+                      'model.region_geo_sampler.agg_projector_list.1.norm.bias', 'model.region_geo_sampler.agg_projector_list.1.norm.weight',
+                      'model.region_geo_sampler.diff_projector_list.0.bias', 'model.region_geo_sampler.diff_projector_list.0.weight',
+                      'model.region_geo_sampler.diff_projector_list.1.bias', 'model.region_geo_sampler.diff_projector_list.1.weight',
+                      'model.region_geo_sampler.dim_projector.bias', 'model.region_geo_sampler.dim_projector.weight',
+                      'model.region_geo_sampler.flatten_projector.bias', 'model.region_geo_sampler.flatten_projector.weight'
+                      ]
+
+
+def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
+    print("Loading base model")
+    base = AutoModelForCausalLM.from_pretrained(
+        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+    print("Loading target model")
+    auto_upgrade(target_model_path)
+    target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+    print("Calculating delta")
+    for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
+        if name not in base.state_dict():
+            assert name in exclude_name_lists, f'{name} not in base model'
+            continue
+        if param.data.shape == base.state_dict()[name].shape:
+            param.data -= base.state_dict()[name]
+        else:
+            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
+            bparam = base.state_dict()[name]
+            param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
+
+    print("Saving delta")
+    if hub_repo_id:
+        kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
+    else:
+        kwargs = {}
+    target.save_pretrained(delta_path, **kwargs)
+    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
+    target_tokenizer.save_pretrained(delta_path, **kwargs)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--base-model-path", type=str, required=True)
+    parser.add_argument("--target-model-path", type=str, required=True)
+    parser.add_argument("--delta-path", type=str, required=True)
+    parser.add_argument("--hub-repo-id", type=str, default=None)
+    args = parser.parse_args()
+
+    make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
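`make_delta` mirrors `apply_delta` operation for operation, so the two round-trip: subtracting the base and adding it back reproduces the fine-tuned weights. A toy sketch of that invariant (random tensors, not real checkpoints):

```python
import torch

# Toy round trip: make_delta computes finetuned - base, apply_delta adds base
# back; the result must match the fine-tuned weights up to floating-point
# rounding, which is the invariant scripts/verify_equal.py checks for real.
base = torch.randn(4, 8)
finetuned = base + 0.01 * torch.randn(4, 8)

delta = finetuned - base   # make_delta: param.data -= base.state_dict()[name]
merged = base + delta      # apply_delta: param.data += base.state_dict()[name]

print(torch.allclose(merged, finetuned, atol=1e-6))  # True up to rounding
```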
diff --git a/ferret/serve/gradio_web_server.py b/ferret/serve/gradio_web_server.py
index c90f46c..67576c8 100644
--- a/ferret/serve/gradio_web_server.py
+++ b/ferret/serve/gradio_web_server.py
@@ -216,7 +216,7 @@ def clear_history(request: gr.Request):
     logger.info(f"clear_history. ip: {request.client.host}")
     state = default_conversation.copy()
     return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 + \
-        (None, {'region_placeholder_tokens':[],'region_coordinates':[],'region_masks':[],'masks':[]}, [], None)
+        (None, {'region_placeholder_tokens':[],'region_coordinates':[],'region_masks':[],'region_masks_in_prompts':[],'masks':[]}, [], None)
 
 
 def resize_bbox(box, image_w=None, image_h=None, default_wh=VOCAB_IMAGE_W):
@@ -321,7 +321,26 @@ def post_process_code(code):
     return code
 
 
+def find_indices_in_order(str_list, STR):
+    indices = []
+    i = 0
+    while i < len(STR):
+        for element in str_list:
+            if STR[i:i+len(element)] == element:
+                indices.append(str_list.index(element))
+                i += len(element) - 1
+                break
+        i += 1
+    return indices
+
+
 def format_region_prompt(prompt, refer_input_state):
+    # Find regions in prompts and assign corresponding region masks
+    refer_input_state['region_masks_in_prompts'] = []
+    indices_region_placeholder_in_prompt = find_indices_in_order(refer_input_state['region_placeholder_tokens'], prompt)
+    refer_input_state['region_masks_in_prompts'] = [refer_input_state['region_masks'][iii] for iii in indices_region_placeholder_in_prompt]
+
+    # Find regions in prompts and replace with real coordinates and region feature token.
     for region_ph_index, region_ph_i in enumerate(refer_input_state['region_placeholder_tokens']):
         prompt = prompt.replace(region_ph_i, '{} {}'.format(refer_input_state['region_coordinates'][region_ph_index], DEFAULT_REGION_FEA_TOKEN))
     return prompt
@@ -341,6 +360,32 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
     if len(state.messages) == state.offset + 2:
         # First round of conversation
         template_name = 'ferret_v1'
+        # Below is LLaVA's original template-selection logic, kept for reference.
+        # if "llava" in model_name.lower():
+        #     if 'llama-2' in model_name.lower():
+        #         template_name = "llava_llama_2"
+        #     elif "v1" in model_name.lower():
+        #         if 'mmtag' in model_name.lower():
+        #             template_name = "v1_mmtag"
+        #         elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+        #             template_name = "v1_mmtag"
+        #         else:
+        #             template_name = "llava_v1"
+        #     elif "mpt" in model_name.lower():
+        #         template_name = "mpt"
+        #     else:
+        #         if 'mmtag' in model_name.lower():
+        #             template_name = "v0_mmtag"
+        #         elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+        #             template_name = "v0_mmtag"
+        #         else:
+        #             template_name = "llava_v0"
+        # elif "mpt" in model_name:
+        #     template_name = "mpt_text"
+        # elif "llama-2" in model_name:
+        #     template_name = "llama_2"
+        # else:
+        #     template_name = "vicuna_v1"
     new_state = conv_templates[template_name].copy()
     new_state.append_message(new_state.roles[0], state.messages[-2][1])
     new_state.append_message(new_state.roles[1], None)
@@ -386,8 +431,8 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
     }
     logger.info(f"==== request ====\n{pload}")
     if args.add_region_feature:
-        pload['region_masks'] = refer_input_state['region_masks']
-        logger.info(f"==== add region_masks to request ====\n")
+        pload['region_masks'] = refer_input_state['region_masks_in_prompts']
+        logger.info(f"==== add region_masks_in_prompts to request ====\n")
 
     pload['images'] = state.get_images()
     print(f'Input Prompt: {prompt}')
@@ -439,8 +484,8 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
 title_markdown = ("""
 # 🦦 Ferret: Refer and Ground Anything Anywhere at Any Granularity
-[[Code](https://github.com/apple/ml-ferret)] [[Paper](https://arxiv.org/abs/2310.07704)]
 """)
+# [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485)
 
 tos_markdown = ("""
 ### Terms of use
@@ -554,6 +599,7 @@ def draw(input_mode, input, refer_input_state, refer_text_show, imagebox_refer):
         refer_input_state['region_placeholder_tokens'].append(cur_region_token)
         refer_input_state['region_coordinates'].append(cur_region_coordinates)
         refer_input_state['region_masks'].append(cur_region_masks)
+        assert len(refer_input_state['region_masks']) == len(refer_input_state['region_coordinates']) == len(refer_input_state['region_placeholder_tokens'])
         refer_text_show.append((cur_region_token, ''))
 
         # Show Parsed Referring.
@@ -597,6 +643,7 @@ def build_demo(embed_mode):
     refer_input_state = gr.State({'region_placeholder_tokens':[],
                                   'region_coordinates':[],
                                   'region_masks':[],
+                                  'region_masks_in_prompts':[],
                                   'masks':[],
                                   })
     refer_text_show = gr.HighlightedText(value=[], label="Referring Input Cache")
diff --git a/scripts/extract_geosampler_and_mm_projector.py b/scripts/extract_geosampler_and_mm_projector.py
new file mode 100644
index 0000000..6685429
--- /dev/null
+++ b/scripts/extract_geosampler_and_mm_projector.py
@@ -0,0 +1,77 @@
+"""
+Usage:
+
+# 7B
+To extract region_geo_sampler:
+python scripts/extract_geosampler_and_mm_projector.py \
+    --keys_to_match=region_geo_sampler \
+    --model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
+    --output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/extracted_region_geo_sampler.bin
+
+To extract mm_projector:
+python scripts/extract_geosampler_and_mm_projector.py \
+    --keys_to_match=mm_projector \
+    --model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
+    --output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/extracted_mm_projector.bin
+
+# 13B
+To extract region_geo_sampler:
+python scripts/extract_geosampler_and_mm_projector.py \
+    --keys_to_match=region_geo_sampler \
+    --model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \
+    --output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/extracted_region_geo_sampler.bin
+
+To extract mm_projector:
+python scripts/extract_geosampler_and_mm_projector.py \
+    --keys_to_match=mm_projector \
+    --model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \
+    --output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/extracted_mm_projector.bin
+"""
+
+
+import os
+import argparse
+import torch
+import json
+from collections import defaultdict
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Extract MMProjector or GeoSampler weights')
+    parser.add_argument('--model-path', type=str, help='model folder')
+    parser.add_argument('--output', type=str, help='output file')
+    parser.add_argument('--keys_to_match', type=str, default="region_geo_sampler", choices=["mm_projector", "region_geo_sampler"], help='keys to be matched')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    keys_to_match = [args.keys_to_match]
+    ckpt_to_key = defaultdict(list)
+    print('----indexing keys_to_match...----')
+    try:
+        # Sharded checkpoints: map each matching key to the shard that stores it.
+        model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json')))
+        for k, v in model_indices['weight_map'].items():
+            if any(key_match in k for key_match in keys_to_match):
+                ckpt_to_key[v].append(k)
+    except FileNotFoundError:
+        # Smaller models or checkpoints saved by DeepSpeed: a single pytorch_model.bin, no index.
+        v = 'pytorch_model.bin'
+        for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys():
+            if any(key_match in k for key_match in keys_to_match):
+                ckpt_to_key[v].append(k)
+
+    loaded_weights = {}
+
+    print('----loading weights...----')
+    for ckpt_name, weight_keys in ckpt_to_key.items():
+        ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
+        for k in weight_keys:
+            loaded_weights[k] = ckpt[k]
+
+    print('----saving weights...----')
+    print(f'the keys of saved weights: {loaded_weights.keys()}')
+    print(f'----saved to {args.output}----')
+    torch.save(loaded_weights, args.output)
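The script leans on the Hugging Face sharded-checkpoint layout, where `pytorch_model.bin.index.json` holds a `weight_map` from tensor names to the shard files storing them, so only shards containing matching keys are ever loaded. A minimal sketch of the same index-driven lookup against a fabricated two-shard checkpoint (every file name and tensor below is made up):

```python
import json
import os
import tempfile
import torch

# Fabricate a two-shard checkpoint to show how the weight_map index is used.
tmp = tempfile.mkdtemp()
torch.save({'model.layers.0.w': torch.zeros(2),
            'model.mm_projector.weight': torch.ones(2)},
           os.path.join(tmp, 'pytorch_model-00001-of-00002.bin'))
torch.save({'model.mm_projector.bias': torch.ones(1)},
           os.path.join(tmp, 'pytorch_model-00002-of-00002.bin'))
index = {'weight_map': {'model.layers.0.w': 'pytorch_model-00001-of-00002.bin',
                        'model.mm_projector.weight': 'pytorch_model-00001-of-00002.bin',
                        'model.mm_projector.bias': 'pytorch_model-00002-of-00002.bin'}}
with open(os.path.join(tmp, 'pytorch_model.bin.index.json'), 'w') as f:
    json.dump(index, f)

# Same lookup the script performs: tensor name -> owning shard, filtered by substring.
weight_map = json.load(open(os.path.join(tmp, 'pytorch_model.bin.index.json')))['weight_map']
wanted = {k: v for k, v in weight_map.items() if 'mm_projector' in k}
extracted = {k: torch.load(os.path.join(tmp, shard), map_location='cpu')[k]
             for k, shard in wanted.items()}
print(sorted(extracted))  # ['model.mm_projector.bias', 'model.mm_projector.weight']
```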
diff --git a/scripts/verify_equal.py b/scripts/verify_equal.py
new file mode 100644
index 0000000..8dbabf9
--- /dev/null
+++ b/scripts/verify_equal.py
@@ -0,0 +1,51 @@
+"""
+Usage:
+python3 scripts/verify_equal.py \
+    --orig-model-path ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
+    --new-model-path ./model/ferret-7b-v1-3
+
+"""
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from ferret import FERRETLlamaForCausalLM
+
+
+def verify_equal(old_model_path, new_model_path):
+    print("Loading old model")
+    old = FERRETLlamaForCausalLM.from_pretrained(old_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+    print("Loading saved model")
+    new = FERRETLlamaForCausalLM.from_pretrained(new_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+    # Get state dictionaries of both models
+    state_dict1 = old.state_dict()
+    state_dict2 = new.state_dict()
+
+    # Compare each parameter
+    for name, param in tqdm(state_dict1.items(), desc="Traverse all params"):
+        # Check if the parameter name exists in the second model
+        if name not in state_dict2:
+            print(f"Parameter {name} found in the first model but not in the second.")
+            return False
+
+        # Compare weights; the tolerance absorbs dtype-conversion rounding (e.g., bf16 vs. fp32)
+        if not torch.allclose(param, state_dict2[name], atol=1e-4):
+            print(param.shape)
+            print(state_dict2[name].shape)
+            print(f"Parameter weights for {name} are different.")
+            return False
+
+    print("All parameter names and weights are the same.")
+    return True
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--orig-model-path", type=str, required=True)
+    parser.add_argument("--new-model-path", type=str, required=True)
+
+    args = parser.parse_args()
+
+    print(verify_equal(args.orig_model_path, args.new_model_path))
\ No newline at end of file
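A note on the `atol=1e-4` above: weights pass through float16 during the make/apply round trip, so bit-exact comparison is too strict, while half-precision rounding at typical weight magnitudes sits comfortably below the tolerance. A standalone illustration:

```python
import torch

# fp16 round-trip rounding at a typical transformer weight scale (~0.02):
# bit-exact equality fails, but errors stay well inside atol=1e-4.
w = torch.randn(10000) * 0.02
w_roundtrip = w.to(torch.float16).to(torch.float32)

print(torch.equal(w, w_roundtrip))                # False: bits changed
print((w - w_roundtrip).abs().max().item())       # a few 1e-5 at this scale
print(torch.allclose(w, w_roundtrip, atol=1e-4))  # True
```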