diff options
Diffstat (limited to 'examples/embd-input/llava.py')
-rw-r--r-- | examples/embd-input/llava.py | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/examples/embd-input/llava.py b/examples/embd-input/llava.py new file mode 100644 index 0000000..2f20cb7 --- /dev/null +++ b/examples/embd-input/llava.py @@ -0,0 +1,70 @@ +import sys +import os +sys.path.insert(0, os.path.dirname(__file__)) +from embd_input import MyModel +import numpy as np +from torch import nn +import torch +from transformers import CLIPVisionModel, CLIPImageProcessor +from PIL import Image + +# model parameters from 'liuhaotian/LLaVA-13b-delta-v1-1' +vision_tower = "openai/clip-vit-large-patch14" +select_hidden_state_layer = -2 +# (vision_config.image_size // vision_config.patch_size) ** 2 +image_token_len = (224//14)**2 + +class Llava: + def __init__(self, args): + self.image_processor = CLIPImageProcessor.from_pretrained(vision_tower) + self.vision_tower = CLIPVisionModel.from_pretrained(vision_tower) + self.mm_projector = nn.Linear(1024, 5120) + self.model = MyModel(["main", *args]) + + def load_projection(self, path): + state = torch.load(path) + self.mm_projector.load_state_dict({ + "weight": state["model.mm_projector.weight"], + "bias": state["model.mm_projector.bias"]}) + + def chat(self, question): + self.model.eval_string("user: ") + self.model.eval_string(question) + self.model.eval_string("\nassistant: ") + return self.model.generate_with_print() + + def chat_with_image(self, image, question): + with torch.no_grad(): + embd_image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + image_forward_out = self.vision_tower(embd_image.unsqueeze(0), output_hidden_states=True) + select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer] + image_feature = select_hidden_state[:, 1:] + embd_image = self.mm_projector(image_feature) + embd_image = embd_image.cpu().numpy()[0] + self.model.eval_string("user: ") + self.model.eval_token(32003-2) # im_start + self.model.eval_float(embd_image.T) + for i in range(image_token_len-embd_image.shape[0]): + self.model.eval_token(32003-3) # im_patch + self.model.eval_token(32003-1) # im_end + self.model.eval_string(question) + self.model.eval_string("\nassistant: ") + return self.model.generate_with_print() + + +if __name__=="__main__": + # model form liuhaotian/LLaVA-13b-delta-v1-1 + a = Llava(["--model", "./models/ggml-llava-13b-v1.1.bin", "-c", "2048"]) + # Extract from https://huggingface.co/liuhaotian/LLaVA-13b-delta-v1-1/blob/main/pytorch_model-00003-of-00003.bin. + # Also here can use pytorch_model-00003-of-00003.bin directly. + a.load_projection(os.path.join( + os.path.dirname(__file__) , + "llava_projetion.pth")) + respose = a.chat_with_image( + Image.open("./media/llama1-logo.png").convert('RGB'), + "what is the text in the picture?") + respose + a.chat("what is the color of it?") + + + |