mirror of https://github.com/apple/ml-ferret.git (synced 2025-01-02 19:02:31 +00:00)

commit 262a943e1f (parent 5fc7ff83ef): checkpoints release

README.md (24 changed lines)
```diff
@@ -30,11 +30,12 @@ Key Contributions:
 
 ## Release
-- [10/30] 🔥 We released the code of **FERRET** model.
+- [12/14] 🔥 We released the [checkpoints (7B, 13B)](#checkpoints).
+- [10/30] 🔥 We released the code of the **FERRET** model and [Ferret-Bench](ferret/eval/ferret_gpt4_data).
 
 **Usage and License Notices**: The data and code are intended and licensed for research use only. They are also restricted to uses that follow the license agreements of LLaMA, Vicuna, and GPT-4. The dataset is CC BY-NC 4.0 (allowing only non-commercial use), and models trained using the dataset should not be used outside of research purposes.
 
 ## Contents
 - [Install](#install)
```
@@ -95,6 +96,25 @@ The scripts are provided ([7B](experiments/ferret_7b_train.sh), [13B](experiment

Please see this [doc](EVAL.md) for the details.
## Checkpoints
We extracted the `delta` between our pre-trained model and Vicuna. Please first download the Vicuna weights following the [previous instructions](#prepare-vicuna-checkpoint-and-llavas-projector). Then download our prepared weight offsets, [7B](https://docs-assets.developer.apple.com/ml-research/models/ferret/ferret-7b/ferret-7b-delta.zip) and [13B](https://docs-assets.developer.apple.com/ml-research/models/ferret/ferret-13b/ferret-13b-delta.zip), using `wget` or `curl`, and unzip the downloaded offsets.
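For the download-and-unzip step, a minimal sketch using the 7B link above (`curl -O` works equally well; the unzip destination directory is an arbitrary choice):

```Shell
wget https://docs-assets.developer.apple.com/ml-research/models/ferret/ferret-7b/ferret-7b-delta.zip
unzip ferret-7b-delta.zip -d ./ferret-7b-delta
```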
Lastly, apply the offsets to the Vicuna weights by running the following script:
```Shell
# 7B
python3 -m ferret.model.apply_delta \
    --base ./model/vicuna-7b-v1-3 \
    --target ./model/ferret-7b-v1-3 \
    --delta path/to/ferret-7b-delta

# 13B
python3 -m ferret.model.apply_delta \
    --base ./model/vicuna-13b-v1-3 \
    --target ./model/ferret-13b-v1-3 \
    --delta path/to/ferret-13b-delta
```
**Notices**: Apple's rights in the attached weight differentials are hereby licensed under the CC-BY-NC license. Apple makes no representations with regards to LLaMa or any other third party software, which are subject to their own terms.
Please refer to the next section for how to set up a local demo with the pre-trained weights.
## Demo
To run our demo, you need to train FERRET and use the checkpoints locally. The demo uses a Gradio web UI; please run the following commands one by one.
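The launch commands themselves fall outside this hunk. Under the LLaVA-style serving layout this repo follows, the sequence would look roughly like the sketch below; the module names, ports, and model path are assumptions (only `gradio_web_server` and its `--add_region_feature` flag are confirmed by the server code later in this commit):

```Shell
# Terminal 1: controller (port is an assumption)
python -m ferret.serve.controller --host 0.0.0.0 --port 10000

# Terminal 2: Gradio web server with region referring enabled
python -m ferret.serve.gradio_web_server --controller http://localhost:10000 --add_region_feature

# Terminal 3: model worker serving the reconstructed checkpoint
python -m ferret.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 \
    --port 40000 --worker http://localhost:40000 \
    --model-path ./model/ferret-7b-v1-3 --add_region_feature
```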
ferret/model/apply_delta.py

```diff
@@ -1,6 +1,16 @@
 """
 Usage:
-python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
+# 7B
+python3 -m ferret.model.apply_delta \
+    --base ./model/vicuna-7b-v1-3 \
+    --target ./model/ferret-7b-v1-3 \
+    --delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/ferret-7b-delta
+
+# 13B
+python3 -m ferret.model.apply_delta \
+    --base ./model/vicuna-13b-v1-3 \
+    --target ./model/ferret-13b-v1-3 \
+    --delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/ferret-13b-delta
 """
 import argparse
@@ -10,6 +20,18 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from ferret import FERRETLlamaForCausalLM
 
 
+exclude_name_lists = ['model.mm_projector.weight', 'model.mm_projector.bias',
+                      'model.region_geo_sampler.agg_projector_list.0.net.0.bias', 'model.region_geo_sampler.agg_projector_list.0.net.0.weight',
+                      'model.region_geo_sampler.agg_projector_list.0.norm.bias', 'model.region_geo_sampler.agg_projector_list.0.norm.weight',
+                      'model.region_geo_sampler.agg_projector_list.1.net.0.bias', 'model.region_geo_sampler.agg_projector_list.1.net.0.weight',
+                      'model.region_geo_sampler.agg_projector_list.1.norm.bias', 'model.region_geo_sampler.agg_projector_list.1.norm.weight',
+                      'model.region_geo_sampler.diff_projector_list.0.bias', 'model.region_geo_sampler.diff_projector_list.0.weight',
+                      'model.region_geo_sampler.diff_projector_list.1.bias', 'model.region_geo_sampler.diff_projector_list.1.weight',
+                      'model.region_geo_sampler.dim_projector.bias', 'model.region_geo_sampler.dim_projector.weight',
+                      'model.region_geo_sampler.flatten_projector.bias', 'model.region_geo_sampler.flatten_projector.weight'
+                      ]
 
 
 def apply_delta(base_model_path, target_model_path, delta_path):
     print("Loading base model")
     base = AutoModelForCausalLM.from_pretrained(
@@ -22,7 +44,7 @@ def apply_delta(base_model_path, target_model_path, delta_path):
     print("Applying delta")
     for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
         if name not in base.state_dict():
-            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
+            assert name in exclude_name_lists, f'{name} not in base model'
             continue
         if param.data.shape == base.state_dict()[name].shape:
             param.data += base.state_dict()[name]
```
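One note on the flags: the usage strings pass `--base`, `--target`, and `--delta`, while the parser (shown in `make_delta.py` below, and presumably mirrored in `apply_delta.py`) registers `--base-model-path`, `--target-model-path`, and `--delta-path`. This works because argparse resolves unambiguous prefixes of long options by default, as a minimal standalone sketch shows:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--base-model-path")
parser.add_argument("--target-model-path")
parser.add_argument("--delta-path")

# '--base' is an unambiguous prefix of '--base-model-path', so argparse
# (with its default allow_abbrev=True) resolves it to the full option.
args = parser.parse_args(["--base", "./model/vicuna-7b-v1-3"])
print(args.base_model_path)  # ./model/vicuna-7b-v1-3
```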
ferret/model/make_delta.py (new file, 74 lines)

```python
"""
Usage:
# 7B
python3 -m ferret.model.make_delta \
    --base ./model/vicuna-7b-v1-3 \
    --target ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
    --delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/ferret-7b-delta

# 13B
python3 -m ferret.model.make_delta \
    --base ./model/vicuna-13b-v1-3 \
    --target ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \
    --delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/ferret-13b-delta
"""
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from ferret.model.utils import auto_upgrade


# all the parameters inside the geosampler and mm projector
exclude_name_lists = ['model.mm_projector.weight', 'model.mm_projector.bias',
                      'model.region_geo_sampler.agg_projector_list.0.net.0.bias', 'model.region_geo_sampler.agg_projector_list.0.net.0.weight',
                      'model.region_geo_sampler.agg_projector_list.0.norm.bias', 'model.region_geo_sampler.agg_projector_list.0.norm.weight',
                      'model.region_geo_sampler.agg_projector_list.1.net.0.bias', 'model.region_geo_sampler.agg_projector_list.1.net.0.weight',
                      'model.region_geo_sampler.agg_projector_list.1.norm.bias', 'model.region_geo_sampler.agg_projector_list.1.norm.weight',
                      'model.region_geo_sampler.diff_projector_list.0.bias', 'model.region_geo_sampler.diff_projector_list.0.weight',
                      'model.region_geo_sampler.diff_projector_list.1.bias', 'model.region_geo_sampler.diff_projector_list.1.weight',
                      'model.region_geo_sampler.dim_projector.bias', 'model.region_geo_sampler.dim_projector.weight',
                      'model.region_geo_sampler.flatten_projector.bias', 'model.region_geo_sampler.flatten_projector.weight'
                      ]


def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading target model")
    auto_upgrade(target_model_path)
    target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Calculating delta")
    for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
        if name not in base.state_dict():
            # FERRET-only parameters (geo sampler / mm projector) are stored
            # in the delta unchanged.
            assert name in exclude_name_lists, f'{name} not in base model'
            continue
        if param.data.shape == base.state_dict()[name].shape:
            param.data -= base.state_dict()[name]
        else:
            # The embedding and output head may have extra rows (added tokens);
            # subtract the base only over the overlapping slice.
            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
            bparam = base.state_dict()[name]
            param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam

    print("Saving delta")
    if hub_repo_id:
        kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
    else:
        kwargs = {}
    target.save_pretrained(delta_path, **kwargs)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
    target_tokenizer.save_pretrained(delta_path, **kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)
    parser.add_argument("--hub-repo-id", type=str, default=None)
    args = parser.parse_args()

    make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
```
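The delta scheme is plain tensor arithmetic, so `apply_delta` inverts `make_delta` up to floating-point rounding. A standalone toy sketch (random tensors, not the real checkpoints):

```python
import torch

base = torch.randn(8, 8)    # stands in for a Vicuna weight
target = torch.randn(8, 8)  # stands in for the fine-tuned FERRET weight

delta = target - base       # what make_delta stores (param.data -= base)
restored = delta + base     # what apply_delta computes (param.data += base)
assert torch.allclose(restored, target)
```

Since the real scripts run everything in float16, the round trip can pick up small rounding errors, which is presumably why `verify_equal.py` below compares with `atol=1e-4` rather than exact equality.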
ferret/serve/gradio_web_server.py

```diff
@@ -216,7 +216,7 @@ def clear_history(request: gr.Request):
     logger.info(f"clear_history. ip: {request.client.host}")
     state = default_conversation.copy()
     return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 + \
-        (None, {'region_placeholder_tokens':[],'region_coordinates':[],'region_masks':[],'masks':[]}, [], None)
+        (None, {'region_placeholder_tokens':[],'region_coordinates':[],'region_masks':[],'region_masks_in_prompts':[],'masks':[]}, [], None)
 
 
 def resize_bbox(box, image_w=None, image_h=None, default_wh=VOCAB_IMAGE_W):
@@ -321,7 +321,26 @@ def post_process_code(code):
     return code
 
 
+def find_indices_in_order(str_list, STR):
+    indices = []
+    i = 0
+    while i < len(STR):
+        for element in str_list:
+            if STR[i:i+len(element)] == element:
+                indices.append(str_list.index(element))
+                i += len(element) - 1
+                break
+        i += 1
+    return indices
+
+
 def format_region_prompt(prompt, refer_input_state):
+    # Find regions in prompts and assign corresponding region masks
+    refer_input_state['region_masks_in_prompts'] = []
+    indices_region_placeholder_in_prompt = find_indices_in_order(refer_input_state['region_placeholder_tokens'], prompt)
+    refer_input_state['region_masks_in_prompts'] = [refer_input_state['region_masks'][iii] for iii in indices_region_placeholder_in_prompt]
+
     # Find regions in prompts and replace with real coordinates and region feature token.
     for region_ph_index, region_ph_i in enumerate(refer_input_state['region_placeholder_tokens']):
         prompt = prompt.replace(region_ph_i, '{} {}'.format(refer_input_state['region_coordinates'][region_ph_index], DEFAULT_REGION_FEA_TOKEN))
     return prompt
@@ -341,6 +360,32 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
     if len(state.messages) == state.offset + 2:
         # First round of conversation
         template_name = 'ferret_v1'
+        # Below are LLaVA's original templates.
+        # if "llava" in model_name.lower():
+        #     if 'llama-2' in model_name.lower():
+        #         template_name = "llava_llama_2"
+        #     elif "v1" in model_name.lower():
+        #         if 'mmtag' in model_name.lower():
+        #             template_name = "v1_mmtag"
+        #         elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+        #             template_name = "v1_mmtag"
+        #         else:
+        #             template_name = "llava_v1"
+        #     elif "mpt" in model_name.lower():
+        #         template_name = "mpt"
+        #     else:
+        #         if 'mmtag' in model_name.lower():
+        #             template_name = "v0_mmtag"
+        #         elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+        #             template_name = "v0_mmtag"
+        #         else:
+        #             template_name = "llava_v0"
+        # elif "mpt" in model_name:
+        #     template_name = "mpt_text"
+        # elif "llama-2" in model_name:
+        #     template_name = "llama_2"
+        # else:
+        #     template_name = "vicuna_v1"
         new_state = conv_templates[template_name].copy()
         new_state.append_message(new_state.roles[0], state.messages[-2][1])
         new_state.append_message(new_state.roles[1], None)
@@ -386,8 +431,8 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
     }
     logger.info(f"==== request ====\n{pload}")
     if args.add_region_feature:
-        pload['region_masks'] = refer_input_state['region_masks']
-        logger.info(f"==== add region_masks to request ====\n")
+        pload['region_masks'] = refer_input_state['region_masks_in_prompts']
+        logger.info(f"==== add region_masks_in_prompts to request ====\n")
 
     pload['images'] = state.get_images()
     print(f'Input Prompt: {prompt}')
@@ -439,8 +484,8 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
 
 title_markdown = ("""
 # 🦦 Ferret: Refer and Ground Anything Anywhere at Any Granularity
 [[Code](https://github.com/apple/ml-ferret)] [[Paper](https://arxiv.org/abs/2310.07704)]
 """)
 # [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485)
 
 tos_markdown = ("""
 ### Terms of use
@@ -554,6 +599,7 @@ def draw(input_mode, input, refer_input_state, refer_text_show, imagebox_refer):
     refer_input_state['region_placeholder_tokens'].append(cur_region_token)
     refer_input_state['region_coordinates'].append(cur_region_coordinates)
     refer_input_state['region_masks'].append(cur_region_masks)
+    assert len(refer_input_state['region_masks']) == len(refer_input_state['region_coordinates']) == len(refer_input_state['region_placeholder_tokens'])
     refer_text_show.append((cur_region_token, ''))
 
     # Show Parsed Referring.
@@ -597,6 +643,7 @@ def build_demo(embed_mode):
     refer_input_state = gr.State({'region_placeholder_tokens':[],
                                   'region_coordinates':[],
                                   'region_masks':[],
+                                  'region_masks_in_prompts':[],
                                   'masks':[],
                                   })
     refer_text_show = gr.HighlightedText(value=[], label="Referring Input Cache")
```
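To make the ordering behavior of `find_indices_in_order` concrete, a hypothetical standalone call (the placeholder token names are made up for illustration):

```python
str_list = ['[region1]', '[region2]']
prompt = 'Compare [region2] with [region1].'

print(find_indices_in_order(str_list, prompt))  # [1, 0]
# '[region2]' appears first in the prompt, so its index in str_list (1) is
# recorded first. format_region_prompt uses this ordering to line up
# region_masks_in_prompts with the regions' order of appearance in the prompt.
```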
scripts/extract_geosampler_and_mm_projector.py (new file, 77 lines)

```python
"""
Usage:

# 7B
To extract region_geo_sampler:
python scripts/extract_geosampler_and_mm_projector.py \
    --keys_to_match=region_geo_sampler \
    --model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
    --output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/extracted_region_geo_sampler.bin

To extract mm_projector:
python scripts/extract_geosampler_and_mm_projector.py \
    --keys_to_match=mm_projector \
    --model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
    --output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/extracted_mm_projector.bin

# 13B
To extract region_geo_sampler:
python scripts/extract_geosampler_and_mm_projector.py \
    --keys_to_match=region_geo_sampler \
    --model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \
    --output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/extracted_region_geo_sampler.bin

To extract mm_projector:
python scripts/extract_geosampler_and_mm_projector.py \
    --keys_to_match=mm_projector \
    --model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \
    --output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/extracted_mm_projector.bin
"""


import os
import argparse
import torch
import json
from collections import defaultdict


def parse_args():
    parser = argparse.ArgumentParser(description='Extract MMProjector or GeoSampler weights')
    parser.add_argument('--model-path', type=str, help='model folder')
    parser.add_argument('--output', type=str, help='output file')
    parser.add_argument('--keys_to_match', type=str, default="region_geo_sampler", choices=["mm_projector", "region_geo_sampler"], help='keys to be matched')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    keys_to_match = [args.keys_to_match]
    ckpt_to_key = defaultdict(list)  # maps shard filename -> matching weight keys
    print('----indexing keys_to_match...----')
    try:
        # Sharded checkpoints list every weight's shard in an index file.
        model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json')))
        for k, v in model_indices['weight_map'].items():
            if any(key_match in k for key_match in keys_to_match):
                ckpt_to_key[v].append(k)
    except FileNotFoundError:
        # Smaller models or model checkpoints saved by DeepSpeed.
        v = 'pytorch_model.bin'
        for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys():
            if any(key_match in k for key_match in keys_to_match):
                ckpt_to_key[v].append(k)

    loaded_weights = {}

    print('----loading weights...----')
    for ckpt_name, weight_keys in ckpt_to_key.items():
        ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
        for k in weight_keys:
            loaded_weights[k] = ckpt[k]

    print('----saving weights...----')
    print(f'the keys of saved weights: {loaded_weights.keys()}')
    print(f'----saved to {args.output}----')
    torch.save(loaded_weights, args.output)
```
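The output is an ordinary `state_dict`-style file written by `torch.save`, so a quick sanity check is just a `torch.load` (the filename follows the usage string above; the check itself is a sketch, not part of the repo):

```python
import torch

w = torch.load('extracted_region_geo_sampler.bin', map_location='cpu')
print(f'{len(w)} tensors, {sum(p.numel() for p in w.values()):,} parameters')
print(sorted(w)[:3])  # every key should contain "region_geo_sampler"
```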
scripts/verify_equal.py (new file, 51 lines)

```python
"""
Usage:
python3 scripts/verify_equal.py \
    --orig-model-path ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
    --new-model-path ./model/ferret-7b-v1-3
"""
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from ferret import FERRETLlamaForCausalLM


def verify_equal(old_model_path, new_model_path):
    print("Loading old model")
    old = FERRETLlamaForCausalLM.from_pretrained(old_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading saved model")
    new = FERRETLlamaForCausalLM.from_pretrained(new_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    # Get state dictionaries of both models
    state_dict1 = old.state_dict()
    state_dict2 = new.state_dict()

    # Compare each parameter
    for name, param in tqdm(state_dict1.items(), desc="Traverse all params"):
        # Check if the parameter name exists in the second model
        if name not in state_dict2:
            print(f"Parameter {name} found in the first model but not in the second.")
            return False

        # Check if the parameter weights are the same, bf16 vs. f32
        if not torch.allclose(param, state_dict2[name], atol=1e-4):
            print(param.shape)
            print(state_dict2[name].shape)
            print(f"Parameter weights for {name} are different.")
            return False

    print("All parameter names and weights are the same.")
    return True


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--orig-model-path", type=str, required=True)
    parser.add_argument("--new-model-path", type=str, required=True)

    args = parser.parse_args()

    print(verify_equal(args.orig_model_path, args.new_model_path))
```
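Putting the three scripts together gives an end-to-end round trip for the released checkpoints; a sketch for 7B, using the checkpoint directory named throughout this commit:

```Shell
CKPT=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2

# 1. Diff the fine-tuned model against the Vicuna base.
python3 -m ferret.model.make_delta \
    --base ./model/vicuna-7b-v1-3 \
    --target $CKPT/checkpoint-final \
    --delta $CKPT/ferret-7b-delta

# 2. Reconstruct the model from Vicuna + delta.
python3 -m ferret.model.apply_delta \
    --base ./model/vicuna-7b-v1-3 \
    --target ./model/ferret-7b-v1-3 \
    --delta $CKPT/ferret-7b-delta

# 3. Confirm the reconstruction matches the original checkpoint.
python3 scripts/verify_equal.py \
    --orig-model-path $CKPT/checkpoint-final \
    --new-model-path ./model/ferret-7b-v1-3
```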