checkpoints release

Haotian Zhang 2023-12-14 21:16:40 -08:00
parent 5fc7ff83ef
commit 262a943e1f
6 changed files with 299 additions and 8 deletions


@@ -30,11 +30,12 @@ Key Contributions:
## Release
- [10/30] 🔥 We released the code of **FERRET** model.
- [12/14] 🔥 We released the [checkpoints (7B, 13B)](#checkpoints).
- [10/30] 🔥 We released the code of **FERRET** model and [Ferret-Bench](ferret/eval/ferret_gpt4_data).
**Usage and License Notices**: The data and code are intended and licensed for research use only. They are also restricted to uses that follow the license agreements of LLaMA, Vicuna, and GPT-4. The dataset is CC BY-NC 4.0 (allowing only non-commercial use), and models trained using the dataset should not be used outside of research purposes.
## Contents
- [Install](#install)
@@ -95,6 +96,25 @@ The scripts are provided ([7B](experiments/ferret_7b_train.sh), [13B](experiment
Please see this [doc](EVAL.md) for the details.
## Checkpoints
We extracted the `delta` between our pre-trained model and Vicuna. Please first download the Vicuna weights following the [previous instruction](#prepare-vicuna-checkpoint-and-llavas-projector). Then download our prepared weight offsets: [7B](https://docs-assets.developer.apple.com/ml-research/models/ferret/ferret-7b/ferret-7b-delta.zip), [13B](https://docs-assets.developer.apple.com/ml-research/models/ferret/ferret-13b/ferret-13b-delta.zip) using `wget` or `curl`, and unzip them. Lastly, apply the offsets to the Vicuna weights by running the following script:
```Shell
# 7B
python3 -m ferret.model.apply_delta \
--base ./model/vicuna-7b-v1-3 \
--target ./model/ferret-7b-v1-3 \
--delta path/to/ferret-7b-delta
# 13B
python3 -m ferret.model.apply_delta \
--base ./model/vicuna-13b-v1-3 \
--target ./model/ferret-13b-v1-3 \
--delta path/to/ferret-13b-delta
```
**Notices**: Apple's rights in the attached weight differentials are hereby licensed under the CC-BY-NC license. Apple makes no representations with regard to LLaMa or any other third-party software, which are subject to their own terms.
Please refer to the next section for how to set up a local demo with the pre-trained weights.
## Demo
To run our demo, you need FERRET checkpoints available locally, either trained by yourself or obtained by applying the released deltas above. The demo uses a Gradio web UI. Please run the following commands one by one.


@@ -1,6 +1,16 @@
"""
Usage:
python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
# 7B
python3 -m ferret.model.apply_delta \
--base ./model/vicuna-7b-v1-3 \
--target ./model/ferret-7b-v1-3 \
--delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/ferret-7b-delta
# 13B
python3 -m ferret.model.apply_delta \
--base ./model/vicuna-13b-v1-3 \
--target ./model/ferret-13b-v1-3 \
--delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/ferret-13b-delta
"""
import argparse
@@ -10,6 +20,18 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
from ferret import FERRETLlamaForCausalLM
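# All the parameters inside the geo sampler and mm projector (same list as in
# make_delta.py); these exist only in the Ferret model, not in the Vicuna base,
# so they pass through the delta unchanged.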
exclude_name_lists = ['model.mm_projector.weight', 'model.mm_projector.bias',
'model.region_geo_sampler.agg_projector_list.0.net.0.bias', 'model.region_geo_sampler.agg_projector_list.0.net.0.weight',
'model.region_geo_sampler.agg_projector_list.0.norm.bias', 'model.region_geo_sampler.agg_projector_list.0.norm.weight',
'model.region_geo_sampler.agg_projector_list.1.net.0.bias', 'model.region_geo_sampler.agg_projector_list.1.net.0.weight',
'model.region_geo_sampler.agg_projector_list.1.norm.bias', 'model.region_geo_sampler.agg_projector_list.1.norm.weight',
'model.region_geo_sampler.diff_projector_list.0.bias', 'model.region_geo_sampler.diff_projector_list.0.weight',
'model.region_geo_sampler.diff_projector_list.1.bias', 'model.region_geo_sampler.diff_projector_list.1.weight',
'model.region_geo_sampler.dim_projector.bias', 'model.region_geo_sampler.dim_projector.weight',
'model.region_geo_sampler.flatten_projector.bias', 'model.region_geo_sampler.flatten_projector.weight'
]
def apply_delta(base_model_path, target_model_path, delta_path):
print("Loading base model")
base = AutoModelForCausalLM.from_pretrained(
@@ -22,7 +44,7 @@ def apply_delta(base_model_path, target_model_path, delta_path):
print("Applying delta")
for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
if name not in base.state_dict():
assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
assert name in exclude_name_lists, f'{name} not in base model'
continue
if param.data.shape == base.state_dict()[name].shape:
param.data += base.state_dict()[name]
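For context, this hunk shows only the equal-shape branch; the resized-tensor branch (visible in make_delta below, for `model.embed_tokens.weight` and `lm_head.weight`) presumably adds the base back into the overlapping slice. A minimal sketch of that addition, assuming the same overlap convention:

```python
import torch

# Hedged sketch, not the repo's exact code: invert make_delta's subtraction.
# Only the overlapping slice of a resized delta stores "target - base";
# the extra rows (e.g. embeddings for added tokens) are stored verbatim.
def apply_delta_to_tensor(delta: torch.Tensor, base: torch.Tensor) -> torch.Tensor:
    if delta.shape == base.shape:
        return delta + base
    out = delta.clone()
    out[:base.shape[0], :base.shape[1]] += base
    return out
```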


@@ -0,0 +1,74 @@
"""
Usage:
# 7B
python3 -m ferret.model.make_delta \
--base ./model/vicuna-7b-v1-3 \
--target ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/ferret-7b-delta
# 13B
python3 -m ferret.model.make_delta \
--base ./model/vicuna-13b-v1-3 \
--target ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--delta ./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/ferret-13b-delta
"""
import argparse
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from ferret.model.utils import auto_upgrade
# all the parameters inside the geosampler and mm projector
exclude_name_lists = ['model.mm_projector.weight', 'model.mm_projector.bias',
'model.region_geo_sampler.agg_projector_list.0.net.0.bias', 'model.region_geo_sampler.agg_projector_list.0.net.0.weight',
'model.region_geo_sampler.agg_projector_list.0.norm.bias', 'model.region_geo_sampler.agg_projector_list.0.norm.weight',
'model.region_geo_sampler.agg_projector_list.1.net.0.bias', 'model.region_geo_sampler.agg_projector_list.1.net.0.weight',
'model.region_geo_sampler.agg_projector_list.1.norm.bias', 'model.region_geo_sampler.agg_projector_list.1.norm.weight',
'model.region_geo_sampler.diff_projector_list.0.bias', 'model.region_geo_sampler.diff_projector_list.0.weight',
'model.region_geo_sampler.diff_projector_list.1.bias', 'model.region_geo_sampler.diff_projector_list.1.weight',
'model.region_geo_sampler.dim_projector.bias', 'model.region_geo_sampler.dim_projector.weight',
'model.region_geo_sampler.flatten_projector.bias', 'model.region_geo_sampler.flatten_projector.weight'
]
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
print("Loading base model")
base = AutoModelForCausalLM.from_pretrained(
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
print("Loading target model")
auto_upgrade(target_model_path)
target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
print("Calculating delta")
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
if name not in base.state_dict():
assert name in exclude_name_lists, f'{name} not in base model'
continue
if param.data.shape == base.state_dict()[name].shape:
param.data -= base.state_dict()[name]
else:
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
bparam = base.state_dict()[name]
param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
print("Saving delta")
if hub_repo_id:
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
else:
kwargs = {}
target.save_pretrained(delta_path, **kwargs)
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
target_tokenizer.save_pretrained(delta_path, **kwargs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-model-path", type=str, required=True)
parser.add_argument("--target-model-path", type=str, required=True)
parser.add_argument("--delta-path", type=str, required=True)
parser.add_argument("--hub-repo-id", type=str, default=None)
args = parser.parse_args()
make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
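To make the dimension-mismatch branch concrete, a toy illustration with hypothetical shapes (the real case is `model.embed_tokens.weight` and `lm_head.weight` growing when tokens are added):

```python
import torch

# Toy shapes, not real vocab sizes: the target grew from 4 to 6 rows.
base = torch.zeros(4, 3)
target = torch.ones(6, 3)
delta = target.clone()
delta[:base.shape[0], :base.shape[1]] -= base
# delta[:4] stores target[:4] - base; delta[4:] keeps the new rows verbatim,
# so applying the delta later only needs to add base back into the overlap.
```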


@@ -216,7 +216,7 @@ def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = default_conversation.copy()
return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5 + \
(None, {'region_placeholder_tokens':[],'region_coordinates':[],'region_masks':[],'masks':[]}, [], None)
(None, {'region_placeholder_tokens':[],'region_coordinates':[],'region_masks':[],'region_masks_in_prompts':[],'masks':[]}, [], None)
def resize_bbox(box, image_w=None, image_h=None, default_wh=VOCAB_IMAGE_W):
@@ -321,7 +321,26 @@ def post_process_code(code):
return code
def find_indices_in_order(str_list, STR):
indices = []
i = 0
while i < len(STR):
for element in str_list:
if STR[i:i+len(element)] == element:
indices.append(str_list.index(element))
i += len(element) - 1
break
i += 1
return indices
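# Illustrative behavior (hypothetical inputs): with
# str_list = ['<region0>', '<region1>'] and
# STR = 'Compare <region1> with <region0>.', the scan returns [1, 0]:
# the index of each placeholder in str_list, in order of appearance in STR.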
def format_region_prompt(prompt, refer_input_state):
# Find regions in prompts and assign corresponding region masks
refer_input_state['region_masks_in_prompts'] = []
indices_region_placeholder_in_prompt = find_indices_in_order(refer_input_state['region_placeholder_tokens'], prompt)
refer_input_state['region_masks_in_prompts'] = [refer_input_state['region_masks'][iii] for iii in indices_region_placeholder_in_prompt]
# Find regions in prompts and replace with real coordinates and region feature token.
for region_ph_index, region_ph_i in enumerate(refer_input_state['region_placeholder_tokens']):
prompt = prompt.replace(region_ph_i, '{} {}'.format(refer_input_state['region_coordinates'][region_ph_index], DEFAULT_REGION_FEA_TOKEN))
return prompt
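# E.g. (hypothetical values): 'Describe <region0>.' becomes
# 'Describe [100, 140, 300, 400] <region_fea>.', where the bracketed numbers
# are the stored coordinates and <region_fea> stands for the value of
# DEFAULT_REGION_FEA_TOKEN; region_masks_in_prompts then lists the masks
# in prompt order.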
@@ -341,6 +360,32 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
if len(state.messages) == state.offset + 2:
# First round of conversation
template_name = 'ferret_v1'
# Below are LLaVA's original templates.
# if "llava" in model_name.lower():
# if 'llama-2' in model_name.lower():
# template_name = "llava_llama_2"
# elif "v1" in model_name.lower():
# if 'mmtag' in model_name.lower():
# template_name = "v1_mmtag"
# elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
# template_name = "v1_mmtag"
# else:
# template_name = "llava_v1"
# elif "mpt" in model_name.lower():
# template_name = "mpt"
# else:
# if 'mmtag' in model_name.lower():
# template_name = "v0_mmtag"
# elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
# template_name = "v0_mmtag"
# else:
# template_name = "llava_v0"
# elif "mpt" in model_name:
# template_name = "mpt_text"
# elif "llama-2" in model_name:
# template_name = "llama_2"
# else:
# template_name = "vicuna_v1"
new_state = conv_templates[template_name].copy()
new_state.append_message(new_state.roles[0], state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
@@ -386,8 +431,8 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
}
logger.info(f"==== request ====\n{pload}")
if args.add_region_feature:
pload['region_masks'] = refer_input_state['region_masks']
logger.info(f"==== add region_masks to request ====\n")
pload['region_masks'] = refer_input_state['region_masks_in_prompts']
logger.info(f"==== add region_masks_in_prompts to request ====\n")
pload['images'] = state.get_images()
print(f'Input Prompt: {prompt}')
@@ -439,8 +484,8 @@ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, refer_in
title_markdown = ("""
# 🦦 Ferret: Refer and Ground Anything Anywhere at Any Granularity
[[Code](https://github.com/apple/ml-ferret)] [[Paper](https://arxiv.org/abs/2310.07704)]
""")
# [[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485)
tos_markdown = ("""
### Terms of use
@@ -554,6 +599,7 @@ def draw(input_mode, input, refer_input_state, refer_text_show, imagebox_refer):
refer_input_state['region_placeholder_tokens'].append(cur_region_token)
refer_input_state['region_coordinates'].append(cur_region_coordinates)
refer_input_state['region_masks'].append(cur_region_masks)
assert len(refer_input_state['region_masks']) == len(refer_input_state['region_coordinates']) == len(refer_input_state['region_placeholder_tokens'])
refer_text_show.append((cur_region_token, ''))
# Show Parsed Referring.
@@ -597,6 +643,7 @@ def build_demo(embed_mode):
refer_input_state = gr.State({'region_placeholder_tokens':[],
'region_coordinates':[],
'region_masks':[],
'region_masks_in_prompts':[],
'masks':[],
})
refer_text_show = gr.HighlightedText(value=[], label="Referring Input Cache")


@@ -0,0 +1,77 @@
"""
Usage:
# 7B
To extract region_geo_sampler:
python misc/extract_geosampler_and_mm_projector.py \
--keys_to_match=region_geo_sampler \
--model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/extracted_region_geo_sampler.bin
To extract mm_projector:
python misc/extract_geosampler_and_mm_projector.py \
--keys_to_match=mm_projector \
--model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/extracted_mm_projector.bin
# 13B
To extract region_geo_sampler:
python misc/extract_geosampler_and_mm_projector.py \
--keys_to_match=region_geo_sampler \
--model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/extracted_region_geo_sampler.bin
To extract mm_projector:
python misc/extract_geosampler_and_mm_projector.py \
--keys_to_match=mm_projector \
--model-path=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--output=./checkpoints/ferret_ft_clipL336_vicunaV1-3-13b_3Ep_dataV16_RSamplerV2/extracted_mm_projector.bin
"""
import os
import argparse
import torch
import json
from collections import defaultdict
def parse_args():
parser = argparse.ArgumentParser(description='Extract MMProjector or GeoSampler weights')
parser.add_argument('--model-path', type=str, help='model folder')
parser.add_argument('--output', type=str, help='output file')
parser.add_argument('--keys_to_match', type=str, default="region_geo_sampler", choices=["mm_projector", "region_geo_sampler"], help='keys to be matched')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
keys_to_match = [args.keys_to_match]
ckpt_to_key = defaultdict(list)
print('----indexing keys_to_match...----')
try:
model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json')))
for k, v in model_indices['weight_map'].items():
if any(key_match in k for key_match in keys_to_match):
ckpt_to_key[v].append(k)
except FileNotFoundError:
# Smaller models or model checkpoints saved by DeepSpeed.
v = 'pytorch_model.bin'
for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys():
if any(key_match in k for key_match in keys_to_match):
ckpt_to_key[v].append(k)
loaded_weights = {}
print('----loading weights...----')
for ckpt_name, weight_keys in ckpt_to_key.items():
ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
for k in weight_keys:
loaded_weights[k] = ckpt[k]
print('----saving weights...----')
print(f'the keys of saved weights: {loaded_weights.keys()}')
print(f'----saved to {args.output}----')
torch.save(loaded_weights, args.output)
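A hedged sketch of consuming the extracted file (the path and the `model` object are illustrative, not defined in this script):

```python
import torch

# The file holds only the matched subset of keys, so merge it non-strictly.
weights = torch.load("extracted_mm_projector.bin", map_location="cpu")
print(f"loaded {len(weights)} tensors")
# model.load_state_dict(weights, strict=False)  # 'model' assumed to exist
```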

scripts/verify_equal.py (new file, 51 lines)

@@ -0,0 +1,51 @@
"""
Usage:
python3 misc/verify_equal.py \
--orig-model-path ./checkpoints/ferret_ft_clipL336_vicunaV1-3-7b_3Ep_dataV16_RSamplerV2/checkpoint-final \
--new-model-path ./model/ferret-7b-v1-3
"""
import argparse
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from ferret import FERRETLlamaForCausalLM
def verify_equal(old_model_path, new_model_path):
print("Loading old model")
old = FERRETLlamaForCausalLM.from_pretrained(old_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
print("Loading saved model")
new = FERRETLlamaForCausalLM.from_pretrained(new_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
# Get state dictionaries of both models
state_dict1 = old.state_dict()
state_dict2 = new.state_dict()
# Compare each parameter
for name, param in tqdm(state_dict1.items(), desc="Traverse all params"):
# Check if the parameter name exists in the second model
if name not in state_dict2:
print(f"Parameter {name} found in the first model but not in the second.")
return False
# Check if the parameter weights match; atol tolerates bf16 vs. f32 rounding
if not torch.allclose(param, state_dict2[name], atol=1e-4):
print(param.shape)
print(state_dict2[name].shape)
print(f"Parameter weights for {name} are different.")
return False
print("All parameter names and weights are the same.")
return True
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--orig-model-path", type=str, required=True)
parser.add_argument("--new-model-path", type=str, required=True)
args = parser.parse_args()
print(verify_equal(args.orig_model_path, args.new_model_path))