diff --git a/llm/inference/janus_pro/.gitignore b/llm/inference/janus_pro/.gitignore new file mode 100644 index 000000000..f2d0321a1 --- /dev/null +++ b/llm/inference/janus_pro/.gitignore @@ -0,0 +1,139 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site +*__pycache__* +*kernel_meta* +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ diff --git a/llm/inference/janus_pro/generated_samples/img_0.jpg b/llm/inference/janus_pro/generated_samples/img_0.jpg new file mode 100644 index 000000000..8c046404e Binary files /dev/null and b/llm/inference/janus_pro/generated_samples/img_0.jpg differ diff --git a/llm/inference/janus_pro/generation.py b/llm/inference/janus_pro/generation.py new file mode 100644 index 000000000..11fd731b5 --- /dev/null +++ b/llm/inference/janus_pro/generation.py @@ -0,0 +1,124 @@ +import os +import PIL.Image +import mindspore +from mindspore._c_expression import disable_multi_thread +disable_multi_thread() +import mindspore as ms +import numpy as np +from mindnlp.core import ops +from mindnlp.transformers import AutoModelForCausalLM +from janus.models import MultiModalityCausalLM, VLChatProcessor +import mindspore.context as context + +from mindnlp.configs import use_pyboost, set_pyboost +set_pyboost(False) +print('use_pyboost:', use_pyboost()) +mindspore.set_context( + mode=mindspore.PYNATIVE_MODE, + # max_device_memory="15GB", + pynative_synchronize=True, + device_target="Ascend", + # mode=mindspore.GRAPH_MODE, + # jit_config={"jit_level":"O2"}, + ascend_config={"precision_mode":"allow_mix_precision"}) +print(mindspore.get_context("mode")) +# specify the path to the model +model_path = "/home/HwHiAiUser/Janus-Pro-1B" +vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) +tokenizer = vl_chat_processor.tokenizer + +vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, ms_dtype=mindspore.float16 +) +print('loaded processor and ckpt ') + + +conversation = [ + { + "role": "<|User|>", + "content": "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair", + # "content": "sun under blue sky", + }, + {"role": "<|Assistant|>", "content": ""}, +] + +sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts( + conversations=conversation, + sft_format=vl_chat_processor.sft_format, + system_prompt="", +) +prompt = sft_format + vl_chat_processor.image_start_tag +from mindnlp.core import no_grad + +# @torch.inference_mode() +with no_grad(): + def generate( + mmgpt: MultiModalityCausalLM, + vl_chat_processor: VLChatProcessor, + prompt: str, + temperature: float = 1, + parallel_size: int = 1, #16, + cfg_weight: float = 5, + # image_token_num_per_image: int = 8,#576, + image_token_num_per_image: int = 576,#576, + img_size: int = 384, + patch_size: int = 16, + ): + input_ids = vl_chat_processor.tokenizer.encode(prompt) + input_ids = ms.Tensor(input_ids, dtype=ms.int64) + + tokens = ops.zeros(parallel_size*2, len(input_ids), dtype=ms.int32) + for i in range(parallel_size*2): + tokens[i, :] = input_ids + if i % 2 != 0: + tokens[i, 1:-1] = vl_chat_processor.pad_id + + inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens) #(parallel_size*2, len(input_ids) ) + + generated_tokens = ops.zeros(parallel_size, image_token_num_per_image, dtype=ms.int32) + + for i 
in range(image_token_num_per_image): + print(f"generating token {i}") + outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None) + hidden_states = outputs.last_hidden_state # (parallel_size*2, len(input_ids), 2048) + + logits = mmgpt.gen_head(hidden_states[:, -1, :]) # feed the last position's hidden state into gen_head => (parallel_size*2, vocab_size) + logit_cond = logits[0::2, :] + logit_uncond = logits[1::2, :] + + logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond) + probs = ops.softmax(logits / temperature, dim=-1) + + next_token = ops.multinomial(probs, num_samples=1) # (parallel_size, num_samples=1) + generated_tokens[:, i] = next_token.squeeze(axis=-1) + + next_token = ops.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1) # (parallel_size*2) + img_embeds = mmgpt.prepare_gen_img_embeds(next_token) # (parallel_size*2, 2048) + # print("img_embeds.shape:", img_embeds.shape) + # print("img_embeds.dtype:", img_embeds.dtype) + inputs_embeds = img_embeds.unsqueeze(dim=1) # (parallel_size*2, 1, 2048) + print("generated one token") + + if image_token_num_per_image==576: + dec = mmgpt.gen_vision_model.decode_code(generated_tokens.astype(ms.int32), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size]) + else: + pad_last_token = generated_tokens[:,-1].unsqueeze(dim=1).tile((1, 576-image_token_num_per_image)) + cat_generated_tokens=ops.cat([generated_tokens, pad_last_token], dim=1) + print("cat_generated_tokens.shape:",cat_generated_tokens.shape) #(1,576) + dec = mmgpt.gen_vision_model.decode_code(cat_generated_tokens.astype(ms.int32), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size]) + dec = dec.astype(ms.float32).asnumpy().transpose(0, 2, 3, 1) + + dec = np.clip((dec + 1) / 2 * 255, 0, 255) + + visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8) + visual_img[:, :, :] = dec + + os.makedirs('generated_samples', exist_ok=True) + for i in range(parallel_size): + save_path = os.path.join('generated_samples', "img_{}.jpg".format(i)) + PIL.Image.fromarray(visual_img[i]).save(save_path) + generate( + vl_gpt, + vl_chat_processor, + prompt, + ) \ No newline at end of file diff --git a/llm/inference/janus_pro/inpain_model_cat.png b/llm/inference/janus_pro/inpain_model_cat.png new file mode 100644 index 000000000..e82640973 Binary files /dev/null and b/llm/inference/janus_pro/inpain_model_cat.png differ diff --git a/llm/inference/janus_pro/janus/__init__.py b/llm/inference/janus_pro/janus/__init__.py new file mode 100644 index 000000000..09cc08cee --- /dev/null +++ b/llm/inference/janus_pro/janus/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +# check if python version is above 3.10 +import sys + +if sys.version_info >= (3, 10): + print("Python version is above 3.10, patching the collections module.") + # Monkey patch collections + import collections + import collections.abc + + for type_name in collections.abc.__all__: + setattr(collections, type_name, getattr(collections.abc, type_name)) diff --git a/llm/inference/janus_pro/janus/models/__init__.py b/llm/inference/janus_pro/janus/models/__init__.py new file mode 100644 index 000000000..946919361 --- /dev/null +++ b/llm/inference/janus_pro/janus/models/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +from .image_processing_vlm import VLMImageProcessor +from .modeling_vlm import MultiModalityCausalLM +from .processing_vlm import VLChatProcessor + +__all__ = [ + "VLMImageProcessor", + "VLChatProcessor", + "MultiModalityCausalLM", +] diff --git a/llm/inference/janus_pro/janus/models/clip_encoder.py b/llm/inference/janus_pro/janus/models/clip_encoder.py new file mode 100644 index 000000000..a0620cfe6 --- /dev/null +++ b/llm/inference/janus_pro/janus/models/clip_encoder.py @@ -0,0 +1,121 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +from typing import Dict, List, Literal, Optional, Tuple, Union + +import mindspore +from mindnlp.core import nn +from mindspore.dataset.vision import Normalize + +from janus.models.siglip_vit import create_siglip_vit + + +class CLIPVisionTower(nn.Module): + def __init__( + self, + model_name: str = "siglip_large_patch16_384", + image_size: Union[Tuple[int, int], int] = 336, + select_feature: str = "patch", + select_layer: int = -2, + select_layers: list = None, + ckpt_path: str = "", + pixel_mean: Optional[List[float]] = None, + pixel_std: Optional[List[float]] = None, + **kwargs, + ): + super().__init__() + + self.model_name = model_name + self.select_feature = select_feature + self.select_layer = select_layer + self.select_layers = select_layers + + vision_tower_params = { + "model_name": model_name, + "image_size": image_size, + "ckpt_path": ckpt_path, + "select_layer": select_layer, + } + vision_tower_params.update(kwargs) + self.vision_tower, self.forward_kwargs = self.build_vision_tower( + vision_tower_params + ) + + if pixel_mean is not None and pixel_std is not None: + image_norm = Normalize( + mean=pixel_mean, std=pixel_std + ) + else: + image_norm = None + + self.image_norm = image_norm + + def build_vision_tower(self, vision_tower_params): + if self.model_name.startswith("siglip"): + self.select_feature = "same" + vision_tower = create_siglip_vit(**vision_tower_params) + forward_kwargs = dict() + + elif self.model_name.startswith("sam"): + vision_tower = create_sam_vit(**vision_tower_params) + forward_kwargs = dict() + + else: # huggingface + from mindnlp.transformers import CLIPVisionModel + + vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params) + forward_kwargs = dict(output_hidden_states=True) + + return vision_tower, forward_kwargs + + def feature_select(self, image_forward_outs): + if isinstance(image_forward_outs, mindspore.Tensor): + # the output has been the self.select_layer"s features + image_features = image_forward_outs + else: + image_features = image_forward_outs.hidden_states[self.select_layer] + + if self.select_feature == "patch": + # if the output has cls_token + image_features = image_features[:, 1:] + elif self.select_feature == "cls_patch": + image_features = image_features + elif self.select_feature == "same": + image_features = image_features + + else: + raise ValueError(f"Unexpected select feature: {self.select_feature}") + return image_features + + def forward(self, images): + """ + + Args: + images (torch.Tensor): [b, 3, H, W] + + Returns: + image_features (torch.Tensor): [b, n_patch, d] + """ + + if self.image_norm is not None: + images = self.image_norm(images) + + image_forward_outs = self.vision_tower(images, **self.forward_kwargs) + image_features = self.feature_select(image_forward_outs) + return image_features diff --git a/llm/inference/janus_pro/janus/models/image_processing_vlm.py b/llm/inference/janus_pro/janus/models/image_processing_vlm.py new file mode 100644 index 000000000..450a481c2 --- /dev/null +++ b/llm/inference/janus_pro/janus/models/image_processing_vlm.py @@ -0,0 +1,213 @@ +# Copyright (c) 2023-2024 DeepSeek. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +from typing import List, Tuple, Union + +import numpy as np +import mindspore as ms +# import torchvision +# import torchvision.transforms.functional +from PIL import Image +from mindnlp.transformers import AutoImageProcessor, PretrainedConfig +from mindnlp.transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from mindnlp.transformers.image_utils import to_numpy_array + + +ImageType = Union[np.ndarray, ms.Tensor, Image.Image] +IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073) +IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711) +IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) +IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +class VLMImageProcessorConfig(PretrainedConfig): + model_type = "deepseek_vlm" + image_size: int + min_size: int + image_mean: Union[Tuple[float, float, float], List[float]] + image_std: Union[Tuple[float, float, float], List[float]] + rescale_factor: float + do_normalize: bool + + def __init__( + self, + image_size: int, + min_size: int = 14, + image_mean: Union[Tuple[float, float, float], List[float]] = ( + 0.48145466, + 0.4578275, + 0.40821073, + ), + image_std: Union[Tuple[float, float, float], List[float]] = ( + 0.26862954, + 0.26130258, + 0.27577711, + ), + rescale_factor: float = 1.0 / 255.0, + do_normalize: bool = True, + **kwargs, + ): + self.image_size = image_size + self.min_size = min_size + self.image_mean = image_mean + self.image_std = image_std + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + + super().__init__(**kwargs) + + +class VLMImageProcessor(BaseImageProcessor): + model_input_names = ["pixel_values"] + + def __init__( + self, + image_size: int, + min_size: int = 14, + image_mean: Union[Tuple[float, float, float], List[float]] = ( + 0.48145466, + 0.4578275, + 0.40821073, + ), + image_std: Union[Tuple[float, float, float], List[float]] = ( + 0.26862954, + 0.26130258, + 0.27577711, + ), + rescale_factor: float = 1.0 / 255.0, + do_normalize: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size 
+ self.rescale_factor = rescale_factor + self.image_mean = image_mean + self.image_std = image_std + self.min_size = min_size + self.do_normalize = do_normalize + + if image_mean is None: + self.background_color = (127, 127, 127) + else: + self.background_color = tuple([int(x * 255) for x in image_mean]) + + def resize(self, pil_img: Image) -> np.ndarray: + """ + + Args: + pil_img (PIL.Image): [H, W, 3] RGB PIL.Image + + Returns: + x (np.ndarray): [3, self.image_size, self.image_size] + """ + + width, height = pil_img.size + max_size = max(width, height) + + size = [ + max(int(height / max_size * self.image_size), self.min_size), + max(int(width / max_size * self.image_size), self.min_size), + ] + + if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0: + print(f"orig size = {pil_img.size}, new size = {size}") + raise ValueError("Invalid size!") + + # pil_img = torchvision.transforms.functional.resize( + # pil_img, + # size, + # interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC, + # antialias=True, + # ) + # pil_img = ms.Tensor(np.array(pil_img)) + pil_img = ms.dataset.vision.Resize( + size, + interpolation=ms.dataset.vision.Inter.BICUBIC, + # antialias=True, (this parameter is not available here) + )(pil_img) + # pil_img = Image.fromarray(pil_img.asnumpy()) + + pil_img = expand2square(pil_img, self.background_color) + x = to_numpy_array(pil_img) + + # [H, W, 3] -> [3, H, W] + x = np.transpose(x, (2, 0, 1)) + + return x + + def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature: + # resize and pad to [self.image_size, self.image_size] + # then convert from [H, W, 3] to [3, H, W] + images: List[np.ndarray] = [self.resize(image) for image in images] + + # rescale from [0, 255] -> [0, 1] + images = [ + self.rescale( + image=image, + scale=self.rescale_factor, + input_data_format="channels_first", + ) + for image in images + ] + + # normalize + if self.do_normalize: + images = [ + self.normalize( + image=image, + mean=self.image_mean, + std=self.image_std, + input_data_format="channels_first", + ) + for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + @property + def default_shape(self): + return [3, self.image_size, self.image_size] + + +AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor) + + +if __name__ == "__main__": + image_processor = VLMImageProcessor( + image_size=1024, + image_mean=IMAGENET_INCEPTION_MEAN, + image_std=IMAGENET_INCEPTION_STD, + do_normalize=True, + ) diff --git a/llm/inference/janus_pro/janus/models/modeling_vlm.py b/llm/inference/janus_pro/janus/models/modeling_vlm.py new file mode 100644 index 000000000..2910d5a10 --- /dev/null +++ b/llm/inference/janus_pro/janus/models/modeling_vlm.py @@ -0,0 +1,303 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software.
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import mindspore +from mindnlp.core import nn +from attrdict import AttrDict +from mindnlp.core import ops +from mindnlp.transformers import ( + AutoConfig, + AutoModelForCausalLM, + LlamaConfig, + LlamaForCausalLM, + PreTrainedModel, +) +import numpy as np +from mindnlp.transformers.configuration_utils import PretrainedConfig +from mindnlp.core import Tensor +from janus.models.clip_encoder import CLIPVisionTower +from janus.models.projector import MlpProjector + + +class vision_head(nn.Module): + def __init__(self, params): + super().__init__() + self.output_mlp_projector = nn.Linear( + params.n_embed, params.image_token_embed + ) + self.vision_activation = nn.GELU() + self.vision_head = nn.Linear( + params.image_token_embed, params.image_token_size + ) + + def forward(self, x): + x = self.output_mlp_projector(x) + x = self.vision_activation(x) + x = self.vision_head(x) + return x + + +def model_name_to_cls(cls_name): + if "MlpProjector" in cls_name: + cls = MlpProjector + + elif "CLIPVisionTower" in cls_name: + cls = CLIPVisionTower + + elif "VQ" in cls_name: + from janus.models.vq_model import VQ_models + + cls = VQ_models[cls_name] + elif "vision_head" in cls_name: + cls = vision_head + else: + raise ValueError(f"class_name {cls_name} is invalid.") + + return cls + + +class VisionConfig(PretrainedConfig): + model_type = "vision" + cls: str = "" + params: AttrDict = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.cls = kwargs.get("cls", "") + if not isinstance(self.cls, str): + self.cls = self.cls.__name__ + + self.params = AttrDict(kwargs.get("params", {})) + + +class AlignerConfig(PretrainedConfig): + model_type = "aligner" + cls: str = "" + params: AttrDict = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.cls = kwargs.get("cls", "") + if not isinstance(self.cls, str): + self.cls = self.cls.__name__ + + self.params = AttrDict(kwargs.get("params", {})) + + +class GenVisionConfig(PretrainedConfig): + model_type = "gen_vision" + cls: str = "" + params: AttrDict = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.cls = kwargs.get("cls", "") + if not isinstance(self.cls, str): + self.cls = self.cls.__name__ + + self.params = AttrDict(kwargs.get("params", {})) + + +class GenAlignerConfig(PretrainedConfig): + model_type = "gen_aligner" + cls: str = "" + params: AttrDict = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.cls = kwargs.get("cls", "") + if not isinstance(self.cls, str): + self.cls = self.cls.__name__ + + self.params = AttrDict(kwargs.get("params", {})) + + +class GenHeadConfig(PretrainedConfig): + model_type = "gen_head" + cls: str = "" + params: AttrDict = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.cls = kwargs.get("cls", "") + if not isinstance(self.cls, str): + self.cls = self.cls.__name__ + + self.params = AttrDict(kwargs.get("params", {})) + + +class MultiModalityConfig(PretrainedConfig): + model_type = "multi_modality" + vision_config: VisionConfig + 
aligner_config: AlignerConfig + + gen_vision_config: GenVisionConfig + gen_aligner_config: GenAlignerConfig + gen_head_config: GenHeadConfig + + language_config: LlamaConfig + + def __init__(self, **kwargs): + super().__init__(**kwargs) + vision_config = kwargs.get("vision_config", {}) + self.vision_config = VisionConfig(**vision_config) + + aligner_config = kwargs.get("aligner_config", {}) + self.aligner_config = AlignerConfig(**aligner_config) + + gen_vision_config = kwargs.get("gen_vision_config", {}) + self.gen_vision_config = GenVisionConfig(**gen_vision_config) + + gen_aligner_config = kwargs.get("gen_aligner_config", {}) + self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config) + + gen_head_config = kwargs.get("gen_head_config", {}) + self.gen_head_config = GenHeadConfig(**gen_head_config) + + language_config = kwargs.get("language_config", {}) + if isinstance(language_config, LlamaConfig): + self.language_config = language_config + else: + self.language_config = LlamaConfig(**language_config) + + +class MultiModalityPreTrainedModel(PreTrainedModel): + config_class = MultiModalityConfig + base_model_prefix = "multi_modality" + _no_split_modules = [] + _skip_keys_device_placement = "past_key_values" + + +class MultiModalityCausalLM(MultiModalityPreTrainedModel): + def __init__(self, config: MultiModalityConfig): + super().__init__(config) + + vision_config = config.vision_config + vision_cls = model_name_to_cls(vision_config.cls) + self.vision_model = vision_cls(**vision_config.params) + + aligner_config = config.aligner_config + aligner_cls = model_name_to_cls(aligner_config.cls) + self.aligner = aligner_cls(aligner_config.params) + + gen_vision_config = config.gen_vision_config + gen_vision_cls = model_name_to_cls(gen_vision_config.cls) + self.gen_vision_model = gen_vision_cls() + + gen_aligner_config = config.gen_aligner_config + gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls) + self.gen_aligner = gen_aligner_cls(gen_aligner_config.params) + + gen_head_config = config.gen_head_config + gen_head_cls = model_name_to_cls(gen_head_config.cls) + self.gen_head = gen_head_cls(gen_head_config.params) + + self.gen_embed = nn.Embedding( + gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed + ) + + language_config = config.language_config + self.language_model = LlamaForCausalLM(language_config) + + def prepare_inputs_embeds( + self, + input_ids: mindspore.int64, + pixel_values: mindspore.float32, + images_seq_mask: mindspore.int64, + images_emb_mask: mindspore.int64, + **kwargs, + ): + """ + + Args: + input_ids (mindspore.int64): [b, T] + pixel_values (torch.FloatTensor): [b, n_images, 3, h, w] + images_seq_mask (torch.BoolTensor): [b, T] + images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens] + + assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask) + + Returns: + input_embeds (torch.Tensor): [b, T, D] + """ + + bs, n = pixel_values.shape[0:2] + # "b n c h w -> (b n) c h w" + images = ops.reshape( + pixel_values, (bs * n, pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4])) + images_embeds = self.aligner(self.vision_model(images)) + + # "(b n) t d -> b (n t) d" + images_embeds = ops.reshape( + images_embeds, (bs, n * images_embeds.shape[1], images_embeds.shape[2])) + images_emb_mask = ops.reshape( + images_emb_mask, (bs, n * images_emb_mask.shape[2])) # "b n t -> b (n t)" + + # [b, T, D] + # input_ids[input_ids < 0] = 0 # ignore the image embeddings + condition = input_ids < 0 + input_ids = (1-condition) * 
input_ids + condition * \ + 0 # ignore the image embeddings + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + + # replace with the image embeddings + # 627 576 + # inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask] + # print("inputs_embeds:", inputs_embeds.shape) + # print("images_embeds[images_emb_mask].dtype", images_embeds[images_emb_mask].dtype) + print("inputs_embeds.dtype", inputs_embeds.dtype) + padding_size = images_seq_mask.shape[1] - images_emb_mask.shape[1] + padding = Tensor(np.full((images_seq_mask.shape[0], padding_size), False), dtype=images_emb_mask.dtype) + padded_images_emb_mask = ops.concat((images_emb_mask, padding), dim=1) + print("padded_images_emb_mask.shape:",padded_images_emb_mask.shape) + print("images_embeds.shape:",images_embeds.shape) + print("images_seq_mask.shape:",images_seq_mask.shape) + first_true = images_seq_mask.nonzero().squeeze()[0][1] # 42 + last_true = images_seq_mask.nonzero().squeeze()[-1][1] # 42 + print("first_true:",first_true) + print("last_true:",last_true) + left = inputs_embeds[:,:first_true].astype(mindspore.float32) + print(left.shape) + right = inputs_embeds[:, last_true+1:].astype(mindspore.float32) + print(right.shape) + inputs_embeds = ops.cat((left, images_embeds.astype(mindspore.float32), right),1) + print("inputs_embeds.shape:",inputs_embeds.shape) + print("inputs_embeds.dtype:",inputs_embeds.dtype) + + + + # inputs_embeds = images_embeds[padded_images_emb_mask] * images_seq_mask + inputs_embeds * (1 - images_seq_mask) + return inputs_embeds + + def prepare_gen_img_embeds(self, image_ids: mindspore.int64): + return self.gen_aligner(self.gen_embed(image_ids)) + + +AutoConfig.register("vision", VisionConfig) +AutoConfig.register("aligner", AlignerConfig) +AutoConfig.register("gen_vision", GenVisionConfig) +AutoConfig.register("gen_aligner", GenAlignerConfig) +AutoConfig.register("gen_head", GenHeadConfig) +AutoConfig.register("multi_modality", MultiModalityConfig) +AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM) diff --git a/llm/inference/janus_pro/janus/models/processing_vlm.py b/llm/inference/janus_pro/janus/models/processing_vlm.py new file mode 100644 index 000000000..7f881e396 --- /dev/null +++ b/llm/inference/janus_pro/janus/models/processing_vlm.py @@ -0,0 +1,424 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ +from dataclasses import dataclass +from typing import Dict, List + +import mindspore +import mindspore.common.dtype as mstype +from PIL.Image import Image +from mindnlp.transformers import LlamaTokenizerFast +from mindnlp.transformers.processing_utils import ProcessorMixin +from mindnlp.core import ops +from mindnlp.core import Tensor + +from janus.models.image_processing_vlm import VLMImageProcessor +from janus.utils.conversation import get_conv_template + + +class DictOutput(object): + def keys(self): + return self.__dict__.keys() + + def __getitem__(self, item): + return self.__dict__[item] + + def __setitem__(self, key, value): + self.__dict__[key] = value + + +@dataclass +class VLChatProcessorOutput(DictOutput): + sft_format: str + input_ids: mindspore.Tensor + pixel_values: mindspore.Tensor + num_image_tokens: mindspore.int32 + + def __len__(self): + return len(self.input_ids) + + +@dataclass +class BatchedVLChatProcessorOutput(DictOutput): + sft_format: List[str] + input_ids: mindspore.Tensor + pixel_values: mindspore.Tensor + attention_mask: mindspore.Tensor + images_seq_mask: mindspore.Tensor + images_emb_mask: mindspore.Tensor + + def to(self, device, dtype=mindspore.float16): + self.input_ids = self.input_ids.to(device) + self.attention_mask = self.attention_mask.to(device) + self.images_seq_mask = self.images_seq_mask.to(device) + self.images_emb_mask = self.images_emb_mask.to(device) + self.pixel_values = self.pixel_values.to(device=device, dtype=dtype) + return self + + +class VLChatProcessor(ProcessorMixin): + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + + attributes = ["image_processor", "tokenizer"] + + system_prompt = ( + "You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language." + ) + + def __init__( + self, + image_processor: VLMImageProcessor, + tokenizer: LlamaTokenizerFast, + image_tag: str = "<image_placeholder>", + image_start_tag: str = "<begin_of_image>", + image_end_tag: str = "<end_of_image>", + pad_tag: str = "<|▁pad▁|>", + num_image_tokens: int = 576, + add_special_token: bool = False, + sft_format: str = "deepseek", + mask_prompt: bool = True, + ignore_id: int = -100, + **kwargs, + ): + self.image_processor = image_processor + self.tokenizer = tokenizer + + image_id = self.tokenizer.vocab.get(image_tag) + if image_id is None: + special_tokens = [image_tag] + special_tokens_dict = {"additional_special_tokens": special_tokens} + self.tokenizer.add_special_tokens(special_tokens_dict) + print(f"Add image tag = {image_tag} to the tokenizer") + + self.image_tag = image_tag + self.image_start_tag = image_start_tag + self.image_end_tag = image_end_tag + self.pad_tag = pad_tag + + self.num_image_tokens = num_image_tokens + self.add_special_token = add_special_token + self.sft_format = sft_format + self.mask_prompt = mask_prompt + self.ignore_id = ignore_id + + super().__init__( + image_processor, + tokenizer, + image_tag, + num_image_tokens, + add_special_token, + sft_format, + mask_prompt, + ignore_id, + **kwargs, + ) + + def new_chat_template(self): + conv = get_conv_template(self.sft_format) + conv.set_system_message(self.system_prompt) + return conv + + def apply_sft_template_for_multi_turn_prompts( + self, + conversations: List[Dict[str, str]], + sft_format: str = "deepseek", + system_prompt: str = "", + ): + """ + Applies the SFT template to conversation.
+ + An example of conversation: + conversation = [ + { + "role": "User", + "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?", + "images": [ + "./multi-images/attribute_comparison_1.png", + "./multi-images/attribute_comparison_2.png" + ] + }, + { + "role": "Assistant", + "content": "" + } + ] + + Args: + conversations (List[Dict]): A conversation with a List of Dict[str, str] text. + sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek". + system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "". + + Returns: + sft_prompt (str): The formatted text. + """ + + conv = get_conv_template(sft_format) + conv.set_system_message(system_prompt) + for message in conversations: + conv.append_message(message["role"], message["content"].strip()) + sft_prompt = conv.get_prompt().strip() + + return sft_prompt + + @property + def image_token(self): + return self.image_tag + + @property + def image_id(self): + image_id = self.tokenizer.vocab.get(self.image_tag) + return image_id + + @property + def image_start_id(self): + image_start_id = self.tokenizer.vocab.get(self.image_start_tag) + return image_start_id + + @property + def image_end_id(self): + image_end_id = self.tokenizer.vocab.get(self.image_end_tag) + return image_end_id + + @property + def image_start_token(self): + return self.image_start_tag + + @property + def image_end_token(self): + return self.image_end_tag + + @property + def pad_id(self): + pad_id = self.tokenizer.vocab.get(self.pad_tag) + # pad_id = self.tokenizer.pad_token_id + # if pad_id is None: + # pad_id = self.tokenizer.eos_token_id + + return pad_id + + def add_image_token( + self, + image_indices: List[int], + input_ids: mindspore.int64, + ): + """ + + Args: + image_indices (List[int]): [index_0, index_1, ..., index_j] + input_ids (torch.LongTensor): [N] + + Returns: + input_ids (torch.LongTensor): [N + image tokens] + num_image_tokens (torch.IntTensor): [n_images] + """ + + input_slices = [] + + start = 0 + for index in image_indices: + if isinstance(index, mindspore.Tensor): + index = index.item() + if self.add_special_token: + end = index + 1 + else: + end = index + + # original text tokens + + input_slices.append(input_ids[start:end]) + + # add boi, image tokens, eoi and set the mask as False + input_slices.append(self.image_start_id * ops.ones(1, dtype=mindspore.int64)) + input_slices.append( + self.image_id * ops.ones(self.num_image_tokens, dtype=mindspore.int64) + ) + input_slices.append(self.image_end_id * ops.ones(1, dtype=mindspore.int64)) + start = index + 1 + + # the left part + input_slices.append(input_ids[start:]) + + # concat all slices + input_ids = ops.cat(input_slices, dim=0) + num_image_tokens = mindspore.Tensor([self.num_image_tokens] * len(image_indices), mindspore.int32) + + return input_ids, num_image_tokens + + def process_one( + self, + prompt: str = None, + conversations: List[Dict[str, str]] = None, + images: List[Image] = None, + **kwargs, + ): + """ + + Args: + prompt (str): the formatted prompt; + conversations (List[Dict]): conversations with a list of messages; + images (List[ImageType]): the list of images; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - target_ids (torch.LongTensor): [N + image tokens] + - images (torch.FloatTensor): [n_images, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """
+ + assert ( + prompt is None or conversations is None + ), "prompt and conversations cannot be used at the same time." + + if prompt is None: + # apply sft format + sft_format = self.apply_sft_template_for_multi_turn_prompts( + conversations=conversations, + sft_format=self.sft_format, + system_prompt=self.system_prompt, + ) + else: + sft_format = prompt + + # tokenize + input_ids = self.tokenizer.encode(sft_format) + input_ids = Tensor(input_ids, dtype=mindspore.int64) + + # add image tokens to the input_ids + image_token_mask = input_ids == self.image_id + image_indices = image_token_mask.nonzero() + input_ids, num_image_tokens = self.add_image_token( + image_indices=image_indices, + input_ids=input_ids, + ) + + # load images + images_outputs = self.image_processor(images, return_tensors="ms") + + prepare = VLChatProcessorOutput( + sft_format=sft_format, + input_ids=input_ids, + pixel_values=images_outputs.pixel_values, + num_image_tokens=num_image_tokens, + ) + + return prepare + + def __call__( + self, + *, + prompt: str = None, + conversations: List[Dict[str, str]] = None, + images: List[Image] = None, + force_batchify: bool = True, + **kwargs, + ): + """ + + Args: + prompt (str): the formatted prompt; + conversations (List[Dict]): conversations with a list of messages; + images (List[ImageType]): the list of images; + force_batchify (bool): force batchify the inputs; + **kwargs: + + Returns: + outputs (BaseProcessorOutput): the output of the processor, + - input_ids (torch.LongTensor): [N + image tokens] + - images (torch.FloatTensor): [n_images, 3, H, W] + - image_id (int): the id of the image token + - num_image_tokens (List[int]): the number of image tokens + """ + + prepare = self.process_one( + prompt=prompt, conversations=conversations, images=images + ) + + if force_batchify: + prepare = self.batchify([prepare]) + + return prepare + + def batchify( + self, prepare_list: List[VLChatProcessorOutput] + ) -> BatchedVLChatProcessorOutput: + """ + Preprocesses the inputs for multimodal inference. + + Args: + prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput. + + Returns: + BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference. 
+ """ + + batch_size = len(prepare_list) + sft_format = [] + n_images = [] + seq_lens = [] + for prepare in prepare_list: + n_images.append(len(prepare.num_image_tokens)) + seq_lens.append(len(prepare)) + + input_token_max_len = max(seq_lens) + max_n_images = max(1, max(n_images)) + + batched_input_ids = ops.full( + (batch_size, input_token_max_len), self.pad_id + ).long() # FIXME + batched_attention_mask = ops.zeros(batch_size, input_token_max_len).long() + batched_pixel_values = ops.zeros( + batch_size, max_n_images, *self.image_processor.default_shape + ).float() + batched_images_seq_mask = ops.zeros(batch_size, input_token_max_len).bool() + batched_images_emb_mask = ops.zeros( + batch_size, max_n_images, self.num_image_tokens + ).bool() + + for i, prepare in enumerate(prepare_list): + input_ids = prepare.input_ids + seq_len = len(prepare) + n_image = len(prepare.num_image_tokens) + # left-padding + batched_attention_mask[i, -seq_len:] = 1 + batched_input_ids[i, -seq_len:] = mindspore.Tensor(input_ids, dtype=mindspore.int64) + batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id + + if n_image > 0: + batched_pixel_values[i, :n_image] = prepare.pixel_values + for j, n_image_tokens in enumerate(prepare.num_image_tokens): + batched_images_emb_mask[i, j, :n_image_tokens] = True + + sft_format.append(prepare.sft_format) + + batched_prepares = BatchedVLChatProcessorOutput( + input_ids=batched_input_ids, + attention_mask=batched_attention_mask, + pixel_values=batched_pixel_values, + images_seq_mask=batched_images_seq_mask, + images_emb_mask=batched_images_emb_mask, + sft_format=sft_format, + ) + + return batched_prepares diff --git a/llm/inference/janus_pro/janus/models/projector.py b/llm/inference/janus_pro/janus/models/projector.py new file mode 100644 index 000000000..8c3c096f4 --- /dev/null +++ b/llm/inference/janus_pro/janus/models/projector.py @@ -0,0 +1,110 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ +from typing import Tuple, Union + +import mindspore +import mindspore.ops as mops +from mindnlp.core import ops +import mindnlp.core.nn as nn +from attrdict import AttrDict + +def contains_nan_or_inf(tensor, info): + tensor = tensor.astype(mindspore.float16) + havenan = mops.isnan(tensor).any() + haveinf = mops.isinf(tensor).any() + if haveinf: + print(info+'haveinf') + if havenan: + print(info+'havenan') +class MlpProjector(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + + if cfg.projector_type == "identity": + modules = nn.Identity() + + elif cfg.projector_type == "linear": + modules = nn.Linear(cfg.input_dim, cfg.n_embed) + + elif cfg.projector_type == "mlp_gelu": + mlp_depth = cfg.get("depth", 1) + modules = [nn.Linear(cfg.input_dim, cfg.n_embed)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) + modules = nn.Sequential(*modules) + + elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu": + mlp_depth = cfg.get("depth", 1) + self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) + self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) + + modules = [] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) + modules = nn.Sequential(*modules) + + else: + raise ValueError(f"Unknown projector type: {cfg.projector_type}") + + self.layers = modules + + def forward( + self, x_or_tuple: Union[Tuple[mindspore.Tensor, mindspore.Tensor], mindspore.Tensor] + ): + """ + + Args: + x_or_tuple (Union[Tuple[mindspore.Tensor, mindspore.Tensor], mindspore.Tensor]: if it is a tuple of mindspore.Tensor, + then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x); + otherwise it is the feature from the single vision encoder. + + Returns: + x (mindspore.Tensor): [b, s, c] + """ + + if isinstance(x_or_tuple, tuple): + # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu": + high_x, low_x = x_or_tuple + high_x = self.high_up_proj(high_x) + low_x = self.low_up_proj(low_x) + x = ops.concat([high_x, low_x], dim=-1) + else: + x = x_or_tuple + + contains_nan_or_inf(x,"MlpProjector.layer(x)") + return self.layers(x) + + +if __name__ == "__main__": + cfg = AttrDict( + input_dim=1024, + n_embed=2048, + depth=2, + projector_type="low_high_hybrid_split_mlp_gelu", + ) + inputs = (mindspore.mint.rand(4, 576, 1024), mindspore.mint.rand(4, 576, 1024)) + + m = MlpProjector(cfg) + out = m(inputs) + print(out.shape) diff --git a/llm/inference/janus_pro/janus/models/siglip_vit.py b/llm/inference/janus_pro/janus/models/siglip_vit.py new file mode 100644 index 000000000..56a6f2999 --- /dev/null +++ b/llm/inference/janus_pro/janus/models/siglip_vit.py @@ -0,0 +1,705 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py +import math +import warnings +from dataclasses import dataclass +from functools import partial +from typing import ( + Callable, + Dict, + Final, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) + +import mindspore +import mindspore as ms +from mindnlp.core import ops +from mindspore.ops import uniform +import mindnlp.core.nn as nn +from mindnlp.core import no_grad, Tensor +import numpy as np +from scipy.special import erfinv +import mindnlp.core.nn.functional as F +from janus.models.timm_layers import ( + AttentionPoolLatent, + DropPath, + LayerType, + Mlp, + PatchDropout, + PatchEmbed, + resample_abs_pos_embed, +) + + +# ============================================================== +def named_apply( + fn: Callable, + module: nn.Module, name='', + depth_first: bool = True, + include_root: bool = False, +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + with no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor = uniform(tensor.shape, Tensor(2 * l - 1, dtype=tensor.dtype), Tensor(2 * u - 1, dtype=tensor.dtype), dtype=tensor.dtype) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + + # tensor = ops.erfinv(tensor) + tensor = Tensor(erfinv(tensor.asnumpy()), dtype=tensor.dtype) + + # Transform to proper mean, std + tensor = tensor.mul(std * math.sqrt(2.) 
) + tensor = tensor.add(mean) + + # Clamp to ensure it's in the proper range + tensor = tensor.clamp(min=a, max=b) + return tensor + + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (torch.Tensor, float, float, float, float) -> torch.Tensor + r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first + convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype. + Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn + from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + + with no_grad(): + dtype = tensor.dtype + tensor_fp32 = tensor.astype(mindspore.float32) + tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b) + tensor = tensor_fp32.astype(dtype) + + +def init_weights(self): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) + trunc_normal_(self.latent, std=self.latent_dim**-0.5) + + +def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, "init_weights"): + module.init_weights() + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + # self.fused_attn = use_fused_attn() + self.fused_attn = True + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity() + + def forward(self, x: mindspore.Tensor) -> mindspore.Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=False, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = q @ k.swapaxes(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.swapaxes(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = 
False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * ops.ones(dim)) + + def forward(self, x: mindspore.Tensor) -> mindspore.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x: mindspore.Tensor) -> mindspore.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal["", "avg", "token", "map"] = "token", + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = None, + class_token: bool = True, + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0.0, + pos_drop_rate: float = 0.0, + patch_drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", + embed_layer: Callable = PatchEmbed, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + ignore_head: bool = False, + ) -> None: + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Number of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token.
+ no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. + fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + embed_layer: Patch embedding layer. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. + """ + super().__init__() + assert global_pool in ("", "avg", "token", "map") + assert class_token or global_pool != "token" + use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm + # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + # act_layer = get_act_layer(act_layer) or nn.GELU + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = ( + no_embed_class # don't embed prefix positions (includes reg) + ) + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = False + self.ignore_head = ignore_head + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt="NHWC")) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = ( + nn.Parameter(ops.zeros(1, 1, embed_dim)) if class_token else None + ) + self.reg_token = ( + nn.Parameter(ops.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + ) + embed_len = ( + num_patches if no_embed_class else num_patches + self.num_prefix_tokens + ) + self.pos_embed = nn.Parameter(ops.randn(1, embed_len, embed_dim) * 0.02) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [ + x.item() for x in ops.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Classifier Head + if global_pool == "map": + AttentionPoolLatent.init_weights = init_weights + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + ) + else: + self.attn_pool = None + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + if weight_init != "skip": + self.init_weights(weight_init) + + def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None: + assert mode in ("jax", "jax_nlhb", "moco", "") + # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 + trunc_normal_(self.pos_embed, std=0.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + # @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {"pos_embed", "cls_token", "dist_token"} + + # @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r"^cls_token|pos_embed|patch_embed", # stem and embed + blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], + ) + + # @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + # @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None) -> None: + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ("", "avg", "token", "map") + if global_pool == "map" and self.attn_pool is None: + assert ( + False + ), "Cannot currently add attention pooling in reset_classifier()." 
+ elif global_pool != "map " and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + def _pos_embed(self, x: mindspore.Tensor) -> mindspore.Tensor: + if self.dynamic_img_size: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if to_cat: + x = ops.cat(to_cat + [x], axis=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if to_cat: + x = ops.cat(to_cat + [x], axis=1) + x = x + pos_embed + + return self.pos_drop(x) + + def _intermediate_layers( + self, + x: mindspore.Tensor, + n: Union[int, Sequence] = 1, + ) -> List[mindspore.Tensor]: + outputs, num_blocks = [], len(self.blocks) + take_indices = set( + range(num_blocks - n, num_blocks) if isinstance(n, int) else n + ) + + # forward pass + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in take_indices: + outputs.append(x) + + return outputs + + def get_intermediate_layers( + self, + x: mindspore.Tensor, + n: Union[int, Sequence] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> Tuple[Union[mindspore.Tensor, Tuple[mindspore.Tensor]]]: + """Intermediate layer accessor (NOTE: This is a WIP experiment). 
+ Inspired by DINO / DINOv2 interface + """ + # take last n blocks if n is an int, if in is a sequence, select by matching indices + outputs = self._intermediate_layers(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs] + outputs = [out[:, self.num_prefix_tokens :] for out in outputs] + + if reshape: + grid_size = self.patch_embed.grid_size + outputs = [ + out.reshape(x.shape[0], grid_size[0], grid_size[1], -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + + if return_prefix_tokens: + return tuple(zip(outputs, prefix_tokens)) + return tuple(outputs) + + def forward_features(self, x: mindspore.Tensor) -> mindspore.Tensor: + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + # if self.grad_checkpointing and not torch.jit.is_scripting(): + # x = checkpoint_seq(self.blocks, x) + # else: + # x = self.blocks(x) + x = self.blocks(x) + x = self.norm(x) + return x + + def forward_head(self, x: mindspore.Tensor, pre_logits: bool = False) -> mindspore.Tensor: + if self.attn_pool is not None: + x = self.attn_pool(x) + elif self.global_pool == "avg": + x = x[:, self.num_prefix_tokens :].mean(dim=1) + elif self.global_pool: + x = x[:, 0] # class token + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x: mindspore.Tensor) -> mindspore.Tensor: + x = self.forward_features(x) + if not self.ignore_head: + x = self.forward_head(x) + return x + + +@dataclass +class SigLIPVisionCfg: + width: int = 1152 + layers: Union[Tuple[int, int, int, int], int] = 27 + heads: int = 16 + patch_size: int = 14 + image_size: Union[Tuple[int, int], int] = 336 + global_pool: str = "map" + mlp_ratio: float = 3.7362 + class_token: bool = False + num_classes: int = 0 + use_checkpoint: bool = False + + +SigLIP_MODEL_CONFIG = { + "siglip_so400m_patch14_384": { + "image_size": 336, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_so400m_patch14_224": { + "image_size": 224, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_large_patch16_384": { + "image_size": 384, + "patch_size": 16, + "width": 1024, + "layers": 24, + "heads": 16, + "mlp_ratio": 4, + "global_pool": "map", + "use_checkpoint": False, + }, +} + + +def create_siglip_vit( + model_name: str = "siglip_so400m_patch14_384", + image_size: int = 384, + select_layer: int = -1, + ckpt_path: str = "", + **kwargs, +): + assert ( + model_name in SigLIP_MODEL_CONFIG.keys() + ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}" + + vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name]) + + if select_layer <= 0: + layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1) + else: + layers = min(vision_cfg.layers, select_layer) + + model = VisionTransformer( + img_size=image_size, + patch_size=vision_cfg.patch_size, + embed_dim=vision_cfg.width, + depth=layers, + num_heads=vision_cfg.heads, + mlp_ratio=vision_cfg.mlp_ratio, + class_token=vision_cfg.class_token, + global_pool=vision_cfg.global_pool, + ignore_head=kwargs.get("ignore_head", True), + weight_init=kwargs.get("weight_init", "skip"), + num_classes=0, + ) + + if ckpt_path: + assert False, "Not implemented load ckpt for create_siglip_vit" + # state_dict = torch.load(ckpt_path, 
map_location="cpu") + + incompatible_keys = model.load_state_dict(state_dict, strict=False) + print( + f"SigLIP-ViT restores from {ckpt_path},\n" + f"\tincompatible_keys:', {incompatible_keys}." + ) + + return model diff --git a/llm/inference/janus_pro/janus/models/test_timm_layers.py b/llm/inference/janus_pro/janus/models/test_timm_layers.py new file mode 100644 index 000000000..fb285c63d --- /dev/null +++ b/llm/inference/janus_pro/janus/models/test_timm_layers.py @@ -0,0 +1,59 @@ +import pytest +from .timm_layers import ( + Mlp, + AttentionPoolLatent, + DropPath, + LayerType, + PatchDropout, + PatchEmbed, + resample_abs_pos_embed +) +import mindspore as ms +import numpy as np + +# 测试 Mlp 类 +def test_mlp(): + mlp = Mlp(in_features=10, hidden_features=20, out_features=30) + x = ms.Tensor(np.random.randn(1, 10), dtype=ms.float32) + output = mlp(x) + assert output.shape == (1, 30) + +# 测试 AttentionPoolLatent 类 +def test_attention_pool_latent(): + attn_pool = AttentionPoolLatent(in_features=10, out_features=20) + x = ms.Tensor(np.random.randn(1, 100, 10)) + output = attn_pool(x) + assert output.shape == (1, 20) + +# 测试 DropPath 类 +def test_drop_path(): + drop_path = DropPath(drop_prob=0.5) + x = ms.Tensor(np.random.randn(1, 10)) + output = drop_path(x) + # 由于随机性,这里只检查输出的形状 + assert output.shape == x.shape + +# 测试 PatchDropout 类 +def test_patch_dropout(): + patch_dropout = PatchDropout(prob=0.5) + x = ms.Tensor(np.random.randn(1, 100, 10)) + output = patch_dropout(x) + # 由于随机性,这里只检查输出的形状 + assert output.shape == x.shape + +# 测试 PatchEmbed 类 +def test_patch_embed(): + patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768) + x = ms.Tensor(np.random.randn(1, 3, 224, 224)) + output = patch_embed(x) + assert output.shape == (1, 196, 768) + +# 测试 resample_abs_pos_embed 函数 +def test_resample_abs_pos_embed(): + posemb = ms.Tensor(np.random.randn(1, 100, 768)) + new_size = [50, 50] + output = resample_abs_pos_embed(posemb, new_size) + assert output.shape == (1, 2500, 768) + +if __name__ == '__main__': + pytest.main() diff --git a/llm/inference/janus_pro/janus/models/timm_layers.py b/llm/inference/janus_pro/janus/models/timm_layers.py new file mode 100644 index 000000000..c4ff51792 --- /dev/null +++ b/llm/inference/janus_pro/janus/models/timm_layers.py @@ -0,0 +1,595 @@ + +import collections +from enum import Enum +from functools import partial +from itertools import repeat +import warnings +import mindspore +import mindspore as ms +from mindnlp.core import nn, no_grad +from mindspore.ops import uniform + +from typing import Callable, List, Optional, Tuple, Type, Union +import math +from mindnlp.core.nn import Module +from mindnlp.core import ops +import mindnlp.core.nn.functional as F +import numpy as np +from mindnlp.configs import set_pyboost +set_pyboost(False) +# ============ Mlp ====================== +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0., + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = 
norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + +# =======AttentionPoolLatent ============ + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor = uniform(tensor.shape, ms.Tensor(2 * l - 1, dtype=tensor.dtype), ms.Tensor(2 * u - 1, dtype=tensor.dtype), dtype=tensor.dtype) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor = ops.erfinv(tensor) + + # Transform to proper mean, std + tensor = tensor.mul(std * math.sqrt(2.) ) + tensor = tensor.add(mean) + + # Clamp to ensure it's in the proper range + tensor = tensor.clamp(min=a, max=b) + return tensor +def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsquently scaled and shifted by the mean and std args. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor = tensor.mul(std).add(mean) + return tensor +class AttentionPoolLatent(nn.Module): + """ Attention pooling w/ latent query + """ + + def __init__( + self, + in_features: int, + out_features: int = None, + embed_dim: int = None, + num_heads: int = 8, + feat_size: Optional[int] = None, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + latent_len: int = 1, + latent_dim: int = None, + pos_embed: str = '', + pool_type: str = 'token', + norm_layer: Optional[nn.Module] = None, + drop: float = 0.0, + ): + super().__init__() + embed_dim = embed_dim or in_features + out_features = out_features or in_features + assert embed_dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.feat_size = feat_size + self.scale = self.head_dim ** -0.5 + self.pool = pool_type + + if pos_embed == 'abs': + assert feat_size is not None + self.pos_embed = nn.Parameter(ops.zeros(feat_size, in_features)) + else: + self.pos_embed = None + + self.latent_dim = latent_dim or embed_dim + self.latent_len = latent_len + self.latent = nn.Parameter(ops.zeros(1, self.latent_len, embed_dim)) + + self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) + self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.proj = nn.Linear(embed_dim, embed_dim) + self.proj_drop = nn.Dropout(drop) + + self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity() + self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio)) + + self.init_weights() + + def init_weights(self): + if self.pos_embed is not None: + trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) + trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5) + + def forward(self, x): + B, N, C = x.shape + + if self.pos_embed is not None: + x = x + self.pos_embed.unsqueeze(0).to(x.dtype) + + q_latent = self.latent.expand((B, -1, -1)) + q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).swapaxes(1, 2) + + kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + k, v = kv.unbind(0) + + q, k = self.q_norm(q), self.k_norm(k) + + q = q * self.scale + attn = q @ k.swapaxes(-2, -1) + attn = attn.softmax(axis=-1) + + x = attn @ v + x = x.swapaxes(1, 2).reshape(B, self.latent_len, C) + x = self.proj(x) + x = self.proj_drop(x) + + x = x + self.mlp(self.norm(x)) + + # optional pool if latent seq_len > 1 and pooled output is desired + if self.pool == 'token': + x = x[:, 0] + elif self.pool == 'avg': + x = x.mean(1) + return x + +# ============ DropPath ================= +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+    """
+    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (len(x.shape) - 1)  # work with diff dim tensors, not just 2D ConvNets
+    # Per-sample Bernoulli mask; rescale by keep_prob so the expected activation is unchanged.
+    random_tensor = ops.bernoulli(ms.ops.zeros(shape, x.dtype), p=keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor = random_tensor.div(keep_prob)
+    return x * random_tensor
+
+# ============ LayerType ================
+LayerType = Union[str, Callable, Type[Module]]
+PadType = Union[str, int, Tuple[int, int]]
+
+# ============ PatchDropout =============
+class PatchDropout(nn.Module):
+    """
+    https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
+    """
+    return_indices: bool
+
+    def __init__(
+            self,
+            prob: float = 0.5,
+            num_prefix_tokens: int = 1,
+            ordered: bool = False,
+            return_indices: bool = False,
+    ):
+        super().__init__()
+        assert 0 <= prob < 1.
+        self.prob = prob
+        self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
+        self.ordered = ordered
+        self.return_indices = return_indices
+
+    def forward(self, x) -> Union[mindspore.Tensor, Tuple[mindspore.Tensor, Optional[mindspore.Tensor]]]:
+        if not self.training or self.prob == 0.:
+            if self.return_indices:
+                return x, None
+            return x
+
+        if self.num_prefix_tokens:
+            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
+        else:
+            prefix_tokens = None
+
+        B = x.shape[0]
+        L = x.shape[1]
+        num_keep = max(1, int(L * (1.
- self.prob))) + rand_tensor = ops.randn(B, L) + argsort_tensor = ops.argsort(rand_tensor, dim=-1) + keep_indices = argsort_tensor[:, :num_keep] + print(keep_indices.shape) + if self.ordered: + # NOTE does not need to maintain patch order in typical transformer use, + # but possibly useful for debug / visualization + keep_indices = keep_indices.sort(axis=-1)[0] + x = x.gather_elements(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:])) + + if prefix_tokens is not None: + x = ops.cat((prefix_tokens, x), dim=1) + + if self.return_indices: + return x, keep_indices + return x + +# ============ PatchEmbed =============== +class Format(str, Enum): + NCHW = 'NCHW' + NHWC = 'NHWC' + NCL = 'NCL' + NLC = 'NLC' +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + return parse +to_2tuple = _ntuple(2) +def nchw_to(x: mindspore.Tensor, fmt: Format): + if fmt == Format.NHWC: + x = x.permute(0, 2, 3, 1) + elif fmt == Format.NLC: + x = x.flatten(2).swapaxes(1, 2) + elif fmt == Format.NCL: + x = x.flatten(2) + return x +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + output_fmt: Format + dynamic_img_pad: bool + + def __init__( + self, + img_size: Optional[int] = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = True, + output_fmt: Optional[str] = None, + bias: bool = True, + strict_img_size: bool = True, + dynamic_img_pad: bool = False, + ): + super().__init__() + self.patch_size = to_2tuple(patch_size) + self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size) + + if output_fmt is not None: + self.flatten = False + self.output_fmt = Format(output_fmt) + else: + # flatten spatial dim and transpose to channels last, kept for bwd compat + self.flatten = flatten + self.output_fmt = Format.NCHW + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def _init_img_size(self, img_size: Union[int, Tuple[int, int]]): + assert self.patch_size + if img_size is None: + return None, None, None + img_size = to_2tuple(img_size) + grid_size = tuple(s // p for s, p in zip(img_size, self.patch_size)) + num_patches = grid_size[0] * grid_size[1] + return img_size, grid_size, num_patches + + def set_input_size( + self, + img_size: Optional[Union[int, Tuple[int, int]]] = None, + patch_size: Optional[Union[int, Tuple[int, int]]] = None, + ): + new_patch_size = None + if patch_size is not None: + new_patch_size = to_2tuple(patch_size) + if new_patch_size is not None and new_patch_size != self.patch_size: + with no_grad(): + new_proj = nn.Conv2d( + self.proj.in_channels, + self.proj.out_channels, + kernel_size=new_patch_size, + stride=new_patch_size, + bias=self.proj.bias is not None, + ) + new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True)) + if self.proj.bias is not None: + new_proj.bias.copy_(self.proj.bias) + self.proj = new_proj + self.patch_size = new_patch_size + img_size = img_size or self.img_size + if img_size != self.img_size or new_patch_size is not None: + self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size) + + def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]: + if as_scalar: + return 
max(self.patch_size) + else: + return self.patch_size + + def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]: + """ Get grid (feature) size for given image size taking account of dynamic padding. + NOTE: must be torchscript compatible so using fixed tuple indexing + """ + if self.dynamic_img_pad: + return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1]) + else: + return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1] + + def forward(self, x): + B, C, H, W = x.shape + if self.img_size is not None: + if self.strict_img_size: + assert H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]})." + assert W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]})." + elif not self.dynamic_img_pad: + assert H % self.patch_size[0] == 0, f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." + assert W % self.patch_size[1] == 0, f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." + if self.dynamic_img_pad: + pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] + pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] + x = F.pad(x, (0, pad_w, 0, pad_h)) + x = self.proj(x) + if self.flatten: + x = x.flatten(start_dim=2).swapaxes(1, 2) # NCHW -> NLC + mindspore.Tensor.flatten + elif self.output_fmt != Format.NCHW: + x = nchw_to(x, self.output_fmt) + x = self.norm(x) + return x +def resample_patch_embed( + patch_embed, + new_size: List[int], + interpolation: str = 'bicubic', + antialias: bool = True, + verbose: bool = False, +): + """Resample the weights of the patch embedding kernel to target resolution. + We resample the patch embedding kernel by approximately inverting the effect + of patch resizing. + + Code based on: + https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py + + With this resizing, we can for example load a B/8 filter into a B/16 model + and, on 2x larger input image, the result will match. + + Args: + patch_embed: original parameter to be resized. + new_size (tuple(int, int): target shape (height, width)-only. + interpolation (str): interpolation for resize + antialias (bool): use anti-aliasing filter in resize + verbose (bool): log operation + Returns: + Resized patch embedding kernel. + """ + import numpy as np + # try: + # from torch import vmap + # except ImportError: + # from functorch import vmap + + assert len(patch_embed.shape) == 4, "Four dimensions expected" + assert len(new_size) == 2, "New shape should only be hw" + old_size = patch_embed.shape[-2:] + if tuple(old_size) == tuple(new_size): + return patch_embed + + # if verbose: + # _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.") + + def resize(x_np, _new_size): + x_tf = mindspore.Tensor(x_np)[None, None, ...] + x_upsampled = F.interpolate( + x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy() + return x_upsampled + + def get_resize_mat(_old_size, _new_size): + mat = [] + for i in range(np.prod(_old_size)): + basis_vec = np.zeros(_old_size) + basis_vec[np.unravel_index(i, _old_size)] = 1. 
+ mat.append(resize(basis_vec, _new_size).reshape(-1)) + return np.stack(mat).T + + resize_mat = get_resize_mat(old_size, new_size) + resize_mat_pinv = mindspore.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device) + + def resample_kernel(kernel): + resampled_kernel = resize_mat_pinv @ kernel.reshape(-1) + return resampled_kernel.reshape(new_size) + + v_resample_kernel = mindspore.vmap(mindspore.vmap(resample_kernel, 0, 0), 1, 1) + orig_dtype = patch_embed.dtype + patch_embed = patch_embed.float() + patch_embed = v_resample_kernel(patch_embed) + patch_embed = patch_embed.to(orig_dtype) + return patch_embed + +# ======== resample_abs_pos_embed ======= +def resample_abs_pos_embed( + posemb: mindspore.Tensor, + new_size: List[int], + old_size: Optional[List[int]] = None, + num_prefix_tokens: int = 1, + interpolation: str = 'bicubic', + antialias: bool = True, + verbose: bool = False, +): + # sort out sizes, assume square if old size not provided + num_pos_tokens = posemb.shape[1] + num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens + if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]: + return posemb + + if old_size is None: + hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens)) + old_size = hw, hw + + if num_prefix_tokens: + posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:] + else: + posemb_prefix = None + + # do the interpolation + embed_dim = posemb.shape[-1] + orig_dtype = posemb.dtype + posemb = posemb.float() # interpolate needs float32 + posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2) + posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias) + posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim) + posemb = posemb.to(orig_dtype) + + # add back extra (class, etc) prefix tokens + if posemb_prefix is not None: + posemb = ops.cat([posemb_prefix, posemb], dim=1) + + + return posemb + +# 测试 Mlp 类 +def test_mlp(): + mlp = Mlp(in_features=10, hidden_features=20, out_features=30) + x = ms.Tensor(np.random.randn(1, 10), dtype=ms.float32) + output = mlp(x) + assert output.shape == (1, 30) + +# 测试 AttentionPoolLatent 类 +def test_attention_pool_latent(): + attn_pool = AttentionPoolLatent(in_features=16, out_features=8) + x = ms.Tensor(np.random.randn(1, 100, 16),dtype=ms.float32) + output = attn_pool(x) + print(output.shape) # 1,16 + assert output.shape == (1,16) + +# 测试 DropPath 类 +def test_drop_path(): + drop_path = DropPath(drop_prob=1) + x = ms.Tensor(np.random.randn(1, 10),dtype=ms.float32) + output = drop_path(x) + print(output) + # 由于随机性,这里只检查输出的形状 + # print(output.shape) + +# 测试 PatchDropout 类 +def test_patch_dropout(): + patch_dropout = PatchDropout(prob=0.5) + x = ms.Tensor(np.random.randn(1, 100, 10),dtype=ms.float32) + output = patch_dropout(x) + # 由于随机性,这里只检查输出的形状 + print("PatchDropout:",output.shape) + print("PatchDropout:",output) + +# 测试 PatchEmbed 类 +def test_patch_embed(): + patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768) + x = ms.Tensor(np.random.randn(1, 3, 224, 224),dtype=ms.float32) + output = patch_embed(x) + print("PatchEmbed:",output.shape) + print("PatchEmbed:",output) + +# 测试 resample_abs_pos_embed 函数 +def test_resample_abs_pos_embed(): + posemb = ms.Tensor(np.random.randn(1, 101, 768),dtype=ms.float32) + new_size = [50, 50] + output = resample_abs_pos_embed(posemb, new_size) + # assert output.shape == (1, 2501, 768) + print("test_resample_abs_pos_embed:",output.shape) + 
print("test_resample_abs_pos_embed:",output) +if __name__ =='__main__': + test_mlp() #pass + test_attention_pool_latent() #pass + test_drop_path() # fail,无法打印 + test_patch_dropout() + test_patch_embed() + test_resample_abs_pos_embed() + diff --git a/llm/inference/janus_pro/janus/models/vq_model.py b/llm/inference/janus_pro/janus/models/vq_model.py new file mode 100755 index 000000000..c96fffb50 --- /dev/null +++ b/llm/inference/janus_pro/janus/models/vq_model.py @@ -0,0 +1,576 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +from dataclasses import dataclass, field +from typing import List + +import mindspore +# import mindspore.ops as ops +import mindnlp.core.ops as ops +import mindnlp.core.nn as nn +import mindspore.common.dtype as mstype +import mindnlp.core.nn.functional as F +from mindnlp.core.nn import Parameter +from mindnlp.core import Tensor +from mindspore.common.initializer import initializer, Uniform +from functools import partial +import numpy as np + + +@dataclass +class ModelArgs: + codebook_size: int = 16384 + codebook_embed_dim: int = 8 + codebook_l2_norm: bool = True + codebook_show_usage: bool = True + commit_loss_beta: float = 0.25 + entropy_loss_ratio: float = 0.0 + + encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4]) + z_channels: int = 256 + dropout_p: float = 0.0 + +def normalize_np(x, p=2, dim=1, eps=1e-12): + x = x.asnumpy() + norm = np.linalg.norm(x, ord=p, axis=dim, keepdims=True) + norm = np.maximum(norm, eps) + y = x / norm + y = mindspore.Tensor(y) + return y + +class Encoder(nn.Module): + def __init__( + self, + in_channels=3, + ch=128, + ch_mult=(1, 1, 2, 2, 4), + num_res_blocks=2, + norm_type="group", + dropout=0.0, + resamp_with_conv=True, + z_channels=256, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1) + + # downsampling + in_ch_mult = (1,) + tuple(ch_mult) + self.conv_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + conv_block = nn.Module() + # res & attn + res_block = nn.ModuleList() + attn_block = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + res_block.append( + ResnetBlock( + block_in, block_out, dropout=dropout, norm_type=norm_type + ) + ) + block_in = 
block_out + if i_level == self.num_resolutions - 1: + attn_block.append(AttnBlock(block_in, norm_type)) + conv_block.res = res_block + conv_block.attn = attn_block + # downsample + if i_level != self.num_resolutions - 1: + conv_block.downsample = Downsample(block_in, resamp_with_conv) + self.conv_blocks.append(conv_block) + + # middle + self.mid = nn.ModuleList() + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + self.mid.append(AttnBlock(block_in, norm_type=norm_type)) + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + + # end + self.norm_out = Normalize(block_in, norm_type) + self.conv_out = nn.Conv2d( + block_in, z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + h = self.conv_in(x) + # downsampling + for i_level, block in enumerate(self.conv_blocks): + for i_block in range(self.num_res_blocks): + h = block.res[i_block](h) + if len(block.attn) > 0: + h = block.attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = block.downsample(h) + + # middle + for mid_block in self.mid: + h = mid_block(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + z_channels=256, + ch=128, + ch_mult=(1, 1, 2, 2, 4), + num_res_blocks=2, + norm_type="group", + dropout=0.0, + resamp_with_conv=True, + out_channels=3, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + + block_in = ch * ch_mult[self.num_resolutions - 1] + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.ModuleList() + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + self.mid.append(AttnBlock(block_in, norm_type=norm_type)) + self.mid.append( + ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type) + ) + + # upsampling + self.conv_blocks = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + conv_block = nn.Module() + # res & attn + res_block = nn.ModuleList() + attn_block = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + block_in, block_out, dropout=dropout, norm_type=norm_type + ) + ) + block_in = block_out + if i_level == self.num_resolutions - 1: + attn_block.append(AttnBlock(block_in, norm_type)) + conv_block.res = res_block + conv_block.attn = attn_block + # downsample + if i_level != 0: + conv_block.upsample = Upsample(block_in, resamp_with_conv) + self.conv_blocks.append(conv_block) + + # end + self.norm_out = Normalize(block_in, norm_type) + self.conv_out = nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + @property + def last_layer(self): + return self.conv_out.weight + + def forward(self, z): + # z to block_in + h = self.conv_in(z) + + # middle + for mid_block in self.mid: + h = mid_block(h) + + # upsampling + for i_level, block in enumerate(self.conv_blocks): + for i_block in range(self.num_res_blocks + 1): + h = block.res[i_block](h) + if len(block.attn) > 0: + h = block.attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = block.upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VectorQuantizer(nn.Module): + def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage): + super().__init__() + 
self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.entropy_loss_ratio = entropy_loss_ratio + self.l2_norm = l2_norm + self.show_usage = show_usage + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + left= mindspore.Tensor(-1.0 / self.n_e, dtype=self.embedding.weight.dtype) + right= mindspore.Tensor(1.0 / self.n_e, dtype=self.embedding.weight.dtype) + # self.embedding.weight = Parameter(initializer(Uniform(1.0 / self.n_e), self.embedding.weight.shape, self.embedding.weight.dtype)) + nn.init.uniform_(self.embedding.weight, -1.0 / self.n_e, 1.0 / self.n_e) + # uniform(self.embedding.weight.shape, left, right,self.embedding.weight.dtype, seed=0) + # mindspore.common.initializer.Uniform((-1.0 / self.n_e, 1.0 / self.n_e)).apply(self.embedding.weight) + if self.l2_norm: + self.embedding.weight = Parameter(F.normalize( + self.embedding.weight, p=2, dim=-1 + )) + if self.show_usage: + self.register_buffer("codebook_used", nn.Parameter(ops.zeros(65536))) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = ops.permute(z, (0, 2, 3, 1)) + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + if self.l2_norm: + z = F.normalize(z, p=2, dim=-1) + z_flattened = F.normalize(z_flattened, p=2, dim=-1) + embedding = F.normalize(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + + d = ( + ops.sum(z_flattened**2, dim=1, keepdim=True) + + ops.sum(embedding**2, dim=1) + - 2 + * ops.einsum( + "bd,dn->bn", z_flattened, ops.einsum("n d -> d n", embedding) + ) + ) + + min_encoding_indices = ops.argmin(d, dim=1) + z_q = embedding[min_encoding_indices].view(z.shape) + perplexity = None + min_encodings = None + vq_loss = None + commit_loss = None + entropy_loss = None + + # compute loss for embedding + if self.training: + vq_loss = ops.mean((z_q - ops.stop_gradient(z)) ** 2) + commit_loss = self.beta * ops.mean((ops.stop_gradient(z_q) - z) ** 2) + entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d) + + # preserve gradients + z_q = z + ops.stop_gradient(z_q - z) + + # reshape back to match original input shape + z_q = ops.einsum("b h w c -> b c h w", z_q) + + return ( + z_q, + (vq_loss, commit_loss, entropy_loss), + (perplexity, min_encodings, min_encoding_indices), + ) + + def get_codebook_entry(self, indices, shape=None, channel_first=True): + # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel) + if self.l2_norm: + print("self.embedding.weight.dtype:", self.embedding.weight.dtype) + embedding = F.normalize(self.embedding.weight.astype(mindspore.float16), p=2, dim=-1) + # embedding = normalize_np(self.embedding.weight, p=2, dim=-1) + else: + embedding = self.embedding.weight + z_q = embedding[indices] # (b*h*w, c) + + if shape is not None: + if channel_first: + z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1]) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + else: + z_q = z_q.view(shape) + return z_q + + +class ResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + norm_type="group", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type) + self.conv1 = nn.Conv2d( + in_channels, out_channels, 
kernel_size=3, stride=1, padding=1
+        )
+        self.norm2 = Normalize(out_channels, norm_type)
+        self.dropout = nn.Dropout(dropout)
+        self.conv2 = nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = nn.Conv2d(
+                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
+                )
+            else:
+                self.nin_shortcut = nn.Conv2d(
+                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
+                )
+
+    def forward(self, x):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+        return x + h
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels, norm_type="group"):
+        super().__init__()
+        self.norm = Normalize(in_channels, norm_type)
+        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.proj_out = nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q.shape
+        q = q.reshape(b, c, h * w)
+        q = q.permute(0, 2, 1)      # b,hw,c
+        k = k.reshape(b, c, h * w)  # b,c,hw
+        w_ = ops.bmm(q, k)          # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = F.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b, c, h * w)
+        w_ = w_.permute(0, 2, 1)    # b,hw,hw (first hw of k, second of q)
+        h_ = ops.bmm(v, w_)         # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b, c, h, w)
+
+        # cast to fp16 before the output projection (the VQ model is run in fp16)
+        h_ = self.proj_out(h_.astype(mindspore.float16))
+
+        return x + h_
+
+
+def nonlinearity(x):
+    # swish
+    return x * ops.sigmoid(x)
+
+
+def Normalize(in_channels, norm_type="group"):
+    assert norm_type in ["group", "batch"]
+    if norm_type == "group":
+        return nn.GroupNorm(
+            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+        )
+    elif norm_type == "batch":
+        assert False, "batch norm is not implemented"
+        # return nn.SyncBatchNorm(in_channels)
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = nn.Conv2d(
+                in_channels, in_channels, kernel_size=3, stride=1, padding=1
+            )
+
+    def forward(self, x):
+        # interpolate in fp32 (then cast back) when the input is fp16
+        if x.dtype != mindspore.float32:
+            x = F.interpolate(x.astype(mstype.float32), scale_factor=2.0, mode="nearest", recompute_scale_factor=True).astype(mstype.float16)
+        else:
+            x = F.interpolate(x, scale_factor=2.0, mode="nearest", recompute_scale_factor=True)
+
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = nn.Conv2d(
+                in_channels, in_channels,
kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01): + flat_affinity = affinity.reshape(-1, affinity.shape[-1]) + flat_affinity /= temperature + probs = F.softmax(flat_affinity, dim=-1) + log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1) + if loss_type == "softmax": + target_probs = probs + else: + raise ValueError("Entropy loss {} not supported".format(loss_type)) + avg_probs = ops.mean(target_probs, axis=0) + avg_entropy = -ops.sum(avg_probs * ops.log(avg_probs + 1e-5)) + sample_entropy = -ops.mean(ops.sum(target_probs * log_probs, axis=-1)) + loss = sample_entropy - avg_entropy + return loss + + +class VQModel(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.encoder = Encoder( + ch_mult=config.encoder_ch_mult, + z_channels=config.z_channels, + dropout=config.dropout_p, + ) + self.decoder = Decoder( + ch_mult=config.decoder_ch_mult, + z_channels=config.z_channels, + dropout=config.dropout_p, + ) + + self.quantize = VectorQuantizer( + config.codebook_size, + config.codebook_embed_dim, + config.commit_loss_beta, + config.entropy_loss_ratio, + config.codebook_l2_norm, + config.codebook_show_usage, + ) + self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1) + self.post_quant_conv = nn.Conv2d( + config.codebook_embed_dim, config.z_channels, 1 + ) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b, shape=None, channel_first=True): + quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, _ = self.encode(input) + dec = self.decode(quant) + return dec, diff + + +################################################################################# +# VQ Model Configs # +################################################################################# +def VQ_16(**kwargs): + return VQModel( + ModelArgs( + encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs + ) + ) + + +VQ_models = {"VQ-16": VQ_16} +if __name__ == '__main__': + import numpy as np + import mindspore as ms + np_tokens = np.load('/home/xuhang/np_token.npy') + print('np_tokens shape:', np_tokens.shape) + vq_model = VQModel( + ModelArgs( + encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4] + ) + ) + # vq_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[ + # parallel_size, 8, img_size//patch_size, img_size//patch_size]) + # ret = vq_model.decode_code(ms.Tensor(np_tokens).astype(ms.int32), shape=[ + # 16, 8, 384//16, 384//16]) + # print(ret.shape) # 16, 3, 384, 384 + ms_token = np.load('/home/xuhang/ms_token.npy') + print('ms_token shape:', ms_token.shape) + # vq_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[ + # parallel_size, 8, img_size//patch_size, img_size//patch_size]) + ret = vq_model.decode_code(Tensor(ms_token).astype(ms.int32), shape=[ + 2, 8, 384//16, 384//16]) + print(ret.shape) # 2, 3, 384, 384 + ret = ret.astype(ms.float32).asnumpy().transpose(0, 2, 3, 1) \ No newline at end of file 
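The `__main__` check above decodes token dumps from hard-coded `.npy` paths. A minimal, self-contained sketch of the same `decode_code` API is shown below; it uses a randomly initialized VQ-16 and random codebook indices (assumptions of this sketch, not part of the original file) purely to illustrate the expected shapes, and dtype behaviour may differ off Ascend.

import numpy as np
import mindspore as ms
from janus.models.vq_model import VQ_16

vq_model = VQ_16()  # codebook_size=16384, codebook_embed_dim=8 by default
parallel_size, img_size, patch_size = 1, 384, 16
num_image_tokens = (img_size // patch_size) ** 2  # 24 * 24 = 576 tokens per image
tokens = ms.Tensor(
    np.random.randint(0, 16384, size=(parallel_size, num_image_tokens)), dtype=ms.int32
)
dec = vq_model.decode_code(
    tokens,
    shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size],
)
print(dec.shape)  # expected: (1, 3, 384, 384)

With trained Janus-Pro weights loaded into the VQModel, the same call maps the 576 image tokens produced by the language model back to a 384x384 RGB image, as in generation.py.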
diff --git a/llm/inference/janus_pro/janus/utils/__init__.py b/llm/inference/janus_pro/janus/utils/__init__.py new file mode 100644 index 000000000..8cb76409f --- /dev/null +++ b/llm/inference/janus_pro/janus/utils/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/llm/inference/janus_pro/janus/utils/conversation.py b/llm/inference/janus_pro/janus/utils/conversation.py new file mode 100644 index 000000000..98609d014 --- /dev/null +++ b/llm/inference/janus_pro/janus/utils/conversation.py @@ -0,0 +1,365 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
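+# Example rendering (illustrative, derived from the SeparatorStyle.DeepSeek branch
+# of get_prompt() below): with the "deepseek" template (empty system message,
+# sep="\n\n", sep2="<|end▁of▁sentence|>"), a short conversation ending in an empty
+# assistant turn renders roughly as
+#
+#   <|User|>: Hello!
+#
+#   <|Assistant|>: Hi! This is Tony.<|end▁of▁sentence|><|User|>: How are you?
+#
+#   <|Assistant|>:
+#
+# i.e. each user turn is closed by `sep`, each assistant turn by `sep2`, and a
+# trailing None message leaves "<|Assistant|>:" open for the model to complete.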
+ +""" +From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py +""" + +import dataclasses +from enum import IntEnum, auto +from typing import Dict, List + + +class SeparatorStyle(IntEnum): + """Separator styles.""" + + ADD_COLON_SINGLE = auto() + ADD_COLON_TWO = auto() + ADD_COLON_SPACE_SINGLE = auto() + NO_COLON_SINGLE = auto() + NO_COLON_TWO = auto() + ADD_NEW_LINE_SINGLE = auto() + LLAMA2 = auto() + CHATGLM = auto() + CHATML = auto() + CHATINTERN = auto() + DOLLY = auto() + RWKV = auto() + PHOENIX = auto() + ROBIN = auto() + DeepSeek = auto() + PLAIN = auto() + ALIGNMENT = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that manages prompt templates and keeps all conversation history.""" + + # The name of this template + name: str + # The template of the system prompt + system_template: str = "{system_message}" + # The system message + system_message: str = "" + # The names of two roles + roles: List[str] = (("USER", "ASSISTANT"),) + # All messages. Each item is (role, message). + messages: List[List[str]] = () + # The number of few shot examples + offset: int = 0 + # The separator style and configurations + sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE + sep: str = "\n" + sep2: str = None + # Stop criteria (the default one is EOS token) + stop_str: str = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + def get_prompt(self) -> str: + """Get the prompt for generation.""" + system_prompt = self.system_template.format(system_message=self.system_message) + + if self.sep_style == SeparatorStyle.DeepSeek: + seps = [self.sep, self.sep2] + if system_prompt == "" or system_prompt is None: + ret = "" + else: + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.LLAMA2: + seps = [self.sep, self.sep2] + if self.system_message: + ret = system_prompt + else: + ret = "[INST] " + for i, (role, message) in enumerate(self.messages): + tag = self.roles[i % 2] + if message: + if type(message) is tuple: # multimodal message + message, _ = message + if i == 0: + ret += message + " " + else: + ret += tag + " " + message + seps[i % 2] + else: + ret += tag + return ret + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = "" + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, _, _ = message + if i % 2 == 0: + ret += message + seps[i % 2] + else: + ret += message + seps[i % 2] + else: + ret += "" + return ret + elif self.sep_style == SeparatorStyle.ALIGNMENT: + seps = [self.sep, self.sep2] + ret = "" + for i, (role, message) in enumerate(self.messages): + if message: + if type(message) is tuple: + message, _, _ = message + if i % 2 == 0: + ret += "\n" + seps[i % 2] + else: + ret += message + seps[i % 2] + else: + ret += "" + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def get_prompt_for_current_round(self, content=None): + """Get current round formatted question prompt during sft training""" + if self.sep_style == SeparatorStyle.PLAIN: + formatted_question = "\n" + elif self.sep_style == SeparatorStyle.DeepSeek: + formatted_question = ( + f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:" + ) + else: + raise ValueError(f"Unsupported sep_style: {self.sep_style}") + return formatted_question + + def 
set_system_message(self, system_message: str): + """Set the system message.""" + self.system_message = system_message + + def append_message(self, role: str, message: str): + """Append a new message.""" + self.messages.append([role, message]) + + def reset_message(self): + """Reset a new message.""" + self.messages = [] + + def update_last_message(self, message: str): + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. + """ + self.messages[-1][1] = message + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + system_prompt = self.system_template.format(system_message=self.system_message) + ret = [{"role": "system", "content": system_prompt}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({"role": "user", "content": msg}) + else: + if msg is not None: + ret.append({"role": "assistant", "content": msg}) + return ret + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids, + ) + + def dict(self): + return { + "template_name": self.name, + "system_message": self.system_message, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + } + + +# A global registry for all conversation templates +conv_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert ( + template.name not in conv_templates + ), f"{template.name} has been registered." + + conv_templates[template.name] = template + + +def get_conv_template(name: str) -> Conversation: + """Get a conversation template.""" + return conv_templates[name].copy() + + +# llava_llama2 template +register_conv_template( + Conversation( + name="llava_llama2", + system_message="You are a helpful language and vision assistant. " + "You are able to understand the visual content that the user provides, " + "and assist the user with a variety of tasks using natural language.", + system_template="[INST] <>\n{system_message}\n<>\n\n", + roles=("[INST]", "[/INST]"), + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + stop_token_ids=[2], + ) +) + +# llama2 template +# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212 +register_conv_template( + Conversation( + name="llama-2", + system_template="[INST] <>\n{system_message}\n<>\n\n", + roles=("[INST]", "[/INST]"), + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + stop_token_ids=[2], + ) +) + + +# deepseek template +register_conv_template( + Conversation( + name="deepseek_old", + system_template="{system_message}", + # system_message="You are a helpful assistant. 
Please answer truthfully and write out your " + # "thinking step by step to be sure you get the right answer.", + system_message="", + roles=("User", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.DeepSeek, + sep="\n\n", + sep2="<|end▁of▁sentence|>", + stop_token_ids=[100001], + stop_str=["User:", "<|end▁of▁sentence|>"], + ) +) +register_conv_template( + Conversation( + name="deepseek", + system_template="{system_message}", + # system_message="You are a helpful assistant. Please answer truthfully and write out your " + # "thinking step by step to be sure you get the right answer.", + system_message="", + roles=("<|User|>", "<|Assistant|>"), + messages=(), + offset=0, + sep_style=SeparatorStyle.DeepSeek, + sep="\n\n", + sep2="<|end▁of▁sentence|>", + stop_token_ids=[100001], + stop_str=["<|User|>", "<|end▁of▁sentence|>"] + ) +) + +register_conv_template( + Conversation( + name="plain", + system_template="", + system_message="", + roles=("", ""), + messages=(), + offset=0, + sep_style=SeparatorStyle.PLAIN, + sep="", + sep2="", + stop_token_ids=[2], + stop_str=[""], + ) +) + + +register_conv_template( + Conversation( + name="alignment", + system_template="", + system_message="", + roles=("", ""), + messages=(), + offset=0, + sep_style=SeparatorStyle.ALIGNMENT, + sep="", + sep2="", + stop_token_ids=[2], + stop_str=[""], + ) +) + + +if __name__ == "__main__": + # print("Llama-2 template:") + # conv = get_conv_template("llama-2") + # conv.set_system_message("You are a helpful, respectful and honest assistant.") + # conv.append_message(conv.roles[0], "Hello!") + # conv.append_message(conv.roles[1], "Hi!") + # conv.append_message(conv.roles[0], "How are you?") + # conv.append_message(conv.roles[1], None) + # print(conv.get_prompt()) + + # print("\n") + + print("deepseek template:") + conv = get_conv_template("deepseek") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi! This is Tony.") + conv.append_message(conv.roles[0], "Who are you?") + conv.append_message(conv.roles[1], "I am a helpful assistant.") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) diff --git a/llm/inference/janus_pro/janus/utils/io.py b/llm/inference/janus_pro/janus/utils/io.py new file mode 100644 index 000000000..48fee08df --- /dev/null +++ b/llm/inference/janus_pro/janus/utils/io.py @@ -0,0 +1,89 @@ +# Copyright (c) 2023-2024 DeepSeek. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
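For reference, the deepseek demo in the __main__ block above renders to the string below: sep ("\n\n") follows each user turn, sep2 ("<|end▁of▁sentence|>") closes each completed assistant turn, and the trailing empty assistant message becomes a bare "<|Assistant|>:" that the model is expected to continue.

# Prompt produced by conv.get_prompt() in the demo above, traced by hand
# through the SeparatorStyle.DeepSeek branch of get_prompt().
expected = (
    "<|User|>: Hello!\n\n"
    "<|Assistant|>: Hi! This is Tony.<|end▁of▁sentence|>"
    "<|User|>: Who are you?\n\n"
    "<|Assistant|>: I am a helpful assistant.<|end▁of▁sentence|>"
    "<|User|>: How are you?\n\n"
    "<|Assistant|>:"
)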
+ +import json +from typing import Dict, List + +import PIL.Image +import mindspore +import base64 +import io +from mindnlp.transformers import AutoModelForCausalLM + +from janus.models import MultiModalityCausalLM, VLChatProcessor + + +def load_pretrained_model(model_path: str): + vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) + tokenizer = vl_chat_processor.tokenizer + + vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True + ) + vl_gpt = vl_gpt.to_float(mindspore.float16).set_train(False) + + return tokenizer, vl_chat_processor, vl_gpt + + +def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]: + """ + + Support file path or base64 images. + + Args: + conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : + [ + { + "role": "User", + "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.", + "images": ["./examples/table_datasets.png"] + }, + {"role": "Assistant", "content": ""}, + ] + + Returns: + pil_images (List[PIL.Image.Image]): the list of PIL images. + + """ + + pil_images = [] + + for message in conversations: + if "images" not in message: + continue + + for image_data in message["images"]: + if image_data.startswith("data:image"): + # Image data is in base64 format + _, image_data = image_data.split(",", 1) + image_bytes = base64.b64decode(image_data) + pil_img = PIL.Image.open(io.BytesIO(image_bytes)) + else: + # Image data is a file path + pil_img = PIL.Image.open(image_data) + pil_img = pil_img.convert("RGB") + pil_images.append(pil_img) + + return pil_images + + +def load_json(filepath): + with open(filepath, "r") as f: + data = json.load(f) + return data diff --git a/llm/inference/janus_pro/understanding.py b/llm/inference/janus_pro/understanding.py new file mode 100644 index 000000000..c41fc752f --- /dev/null +++ b/llm/inference/janus_pro/understanding.py @@ -0,0 +1,69 @@ +import mindspore +from mindspore._c_expression import disable_multi_thread +disable_multi_thread() +from mindnlp.transformers import AutoModelForCausalLM +from janus.models import MultiModalityCausalLM, VLChatProcessor +from janus.utils.io import load_pil_images +from mindnlp.configs import set_pyboost, use_pyboost +from mindnlp.core import nn, Tensor +from mindnlp.core import no_grad + + +from mindnlp.configs import use_pyboost, set_pyboost +print('use_pyboost:', use_pyboost()) # defaults to False here +mindspore.set_context( + mode=mindspore.PYNATIVE_MODE, + pynative_synchronize=True, + device_target="Ascend", + # mode=mindspore.GRAPH_MODE, + # jit_config={"jit_level":"O2"}, + ascend_config={"precision_mode": "allow_mix_precision"}) +print(mindspore.get_context("mode")) +# specify the path to the model +model_path = "/home/HwHiAiUser/Janus-Pro-1B" +print('start load processor') +vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained( + model_path) +tokenizer = vl_chat_processor.tokenizer +print('loaded processor') +vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, ms_dtype=mindspore.float16 +) +print('loaded processor and ckpt ') +# question = 'describe this image' +question = 'what is the animal in the image' +image = "./inpain_model_cat.png" +conversation = [ + { + "role": "<|User|>", + "content": f"<image_placeholder>\n{question}", + "images": [image], + }, + {"role": "<|Assistant|>", "content": ""}, +] + +# load images and prepare for inputs +pil_images =
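load_pil_images() above accepts either plain file paths or base64 data URIs in the "images" field of a message; the <image_placeholder> tag in the content mirrors the docstring example. A small usage sketch of the data-URI path ("cat.png" is a placeholder file name, not an asset from this PR):

# Hedged usage sketch for load_pil_images(); "cat.png" is a placeholder.
import base64
from janus.utils.io import load_pil_images

with open("cat.png", "rb") as f:
    data_uri = "data:image/png;base64," + base64.b64encode(f.read()).decode()

conversation = [
    {
        "role": "<|User|>",
        "content": "<image_placeholder>\nDescribe this image.",
        "images": [data_uri],                 # a plain file path works as well
    },
    {"role": "<|Assistant|>", "content": ""},
]

pil_images = load_pil_images(conversation)    # -> list of RGB PIL.Image.Image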
load_pil_images(conversation) +prepare_inputs = vl_chat_processor( + conversations=conversation, images=pil_images, force_batchify=True +) +print('process inputs') +# # run image encoder to get the image embeddings +inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) +print('prepare inputs') +with no_grad(): + # # run the model to get the response + outputs = vl_gpt.language_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=prepare_inputs.attention_mask, + pad_token_id=tokenizer.eos_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=512, + do_sample=False, + use_cache=True, + ) + + answer = tokenizer.decode( + outputs[0].asnumpy().tolist(), skip_special_tokens=True) + print(f"{prepare_inputs['sft_format'][0]}", answer) diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py index ddc5e3c14..d104f368f 100644 --- a/mindnlp/core/nn/functional.py +++ b/mindnlp/core/nn/functional.py @@ -476,6 +476,10 @@ def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode output_1d = output_2d.squeeze(2) return output_1d +def addcmul(input, tensor1, tensor2, value=1): + if not isinstance(value, mindspore.Tensor): + value = mindspore.Tensor(value, dtype=input.dtype) + return input + value*tensor1*tensor2 def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): if use_pyboost(): @@ -491,7 +495,8 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5): affine_param_shape[1] = C affine_param_shape = tuple(affine_param_shape) if weight is not None and bias is not None: - out = bias.view(affine_param_shape).addcmul(out, weight.view(affine_param_shape), 1) + # out = bias.view(affine_param_shape).addcmul(out, weight.view(affine_param_shape), 1) + out = addcmul(bias.view(affine_param_shape), out, weight.view(affine_param_shape), 1) elif weight is not None: out = out.mul(weight.view(affine_param_shape)) elif bias is not None: diff --git a/mindnlp/core/nn/modules/conv.py b/mindnlp/core/nn/modules/conv.py index 5906acb31..62d49d465 100644 --- a/mindnlp/core/nn/modules/conv.py +++ b/mindnlp/core/nn/modules/conv.py @@ -1,6 +1,7 @@ # coding=utf-8 """conv""" import math +import mindspore from typing import Optional, Tuple, Union, List from mindspore import Tensor, ops as mops from ..parameter import Parameter @@ -267,9 +268,9 @@ def __init__( def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): if self.padding_mode != 'zeros': input = ops.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode) - output = self.conv2d(input, weight) + output = self.conv2d(input.astype(mindspore.float16), weight.astype(mindspore.float16)) if bias is not None: - output = mops.bias_add(output, bias) + output = mops.bias_add(output.astype(mindspore.float16), bias.astype(mindspore.float16)) return output def forward(self, input): diff --git a/mindnlp/core/nn/modules/linear.py b/mindnlp/core/nn/modules/linear.py index 409e7c690..091db56d2 100644 --- a/mindnlp/core/nn/modules/linear.py +++ b/mindnlp/core/nn/modules/linear.py @@ -1,7 +1,9 @@ """linear""" from typing import Any import math +import mindspore from mindspore import Tensor +from mindspore import ops as mops from ..parameter import Parameter from .module import Module from .. 
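The addcmul() fallback added to mindnlp/core/nn/functional.py above computes input + value * tensor1 * tensor2 elementwise, which is exactly what the patched group_norm() affine step needs. A quick numerical check, assuming the patch above is applied:

# Sanity check of the addcmul() fallback: every element is 1 + 0.5 * 2 * 3 = 4.
import numpy as np
import mindspore as ms
from mindnlp.core.nn.functional import addcmul

a  = ms.Tensor(np.ones((2, 2), np.float32))
t1 = ms.Tensor(np.full((2, 2), 2.0, np.float32))
t2 = ms.Tensor(np.full((2, 2), 3.0, np.float32))

print(addcmul(a, t1, t2, value=0.5))   # [[4. 4.] [4. 4.]]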
import init @@ -58,6 +60,10 @@ def reset_parameters(self) -> None: init.uniform_(self.bias, -bound, bound) def forward(self, input): + if self.weight.dtype == mindspore.float32: + self.weight = Parameter(self.weight.astype(mindspore.float16)) + if self.bias is not None and self.bias.dtype == mindspore.float32: + self.bias = Parameter(self.bias.astype(mindspore.float16)) return F.linear(input, self.weight, self.bias) def __repr__(self): diff --git a/mindnlp/core/nn/modules/module.py b/mindnlp/core/nn/modules/module.py index c5efd9ac8..15309ab78 100644 --- a/mindnlp/core/nn/modules/module.py +++ b/mindnlp/core/nn/modules/module.py @@ -1225,7 +1225,7 @@ def train(self, mode=True): Module: self """ if ON_ORANGE_PI: - set_pyboost(not mode) + set_pyboost(False) self.training = mode for module in self.children(): module.train(mode) diff --git a/mindnlp/core/ops/blas.py b/mindnlp/core/ops/blas.py index 5ad117bd8..12c3d1a8f 100644 --- a/mindnlp/core/ops/blas.py +++ b/mindnlp/core/ops/blas.py @@ -30,7 +30,7 @@ def baddbmm(input, batch1, batch2, *, beta=1, alpha=1): def bmm(input, other): if ON_ORANGE_PI: input = input.to(mindspore.float16) - other = input.to(mindspore.float16) + other = other.to(mindspore.float16) if use_pyboost() and has_bmm: return mindspore.mint.bmm(input, other) return ops.bmm(input, other) diff --git a/mindnlp/transformers/cache_utils.py b/mindnlp/transformers/cache_utils.py index cadd2e047..1f2c4f95f 100644 --- a/mindnlp/transformers/cache_utils.py +++ b/mindnlp/transformers/cache_utils.py @@ -364,7 +364,10 @@ def update( # Update the number of seen tokens if layer_idx == 0: self._seen_tokens += key_states.shape[-2] - + if key_states.dtype!=mindspore.float16: + key_states = key_states.astype(mindspore.float16) + if value_states.dtype!=mindspore.float16: + value_states = value_states.astype(mindspore.float16) # Update the cache if len(self.key_cache) <= layer_idx: self.key_cache.append(key_states) @@ -375,7 +378,7 @@ def update( self.key_cache[layer_idx] = key_states self.value_cache[layer_idx] = value_states else: - self.key_cache[layer_idx] = ops.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.key_cache[layer_idx] = ops.cat([self.key_cache[layer_idx].astype(mindspore.float16), key_states.astype(mindspore.float16)], dim=-2) self.value_cache[layer_idx] = ops.cat([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] diff --git a/mindnlp/transformers/models/llama/modeling_llama.py b/mindnlp/transformers/models/llama/modeling_llama.py index de8bd31c2..cb4ee7047 100644 --- a/mindnlp/transformers/models/llama/modeling_llama.py +++ b/mindnlp/transformers/models/llama/modeling_llama.py @@ -92,7 +92,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( causal_mask = causal_mask.copy() mask_length = attention_mask.shape[-1] # padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = ops.narrow(causal_mask, -1, 0, mask_length) + attention_mask.view(attention_mask.shape[0], 1, 1, attention_mask.shape[1]) + padding_mask = ops.narrow(causal_mask, -1, 0, mask_length).astype(mindspore.float16) + attention_mask.view(attention_mask.shape[0], 1, 1, attention_mask.shape[1]).astype(mindspore.float16) padding_mask = padding_mask == 0 # causal_mask[:, :, :, :mask_length] = ops.narrow(causal_mask, -1, 0, mask_length).masked_fill( # padding_mask, min_dtype @@ -101,10 +101,10 @@ causal_mask =
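The DynamicCache.update() patch above normalizes both key_states and value_states to float16 before they are stored or concatenated, so ops.cat() never sees mixed dtypes on Ascend. A minimal sketch of that cast-before-concat pattern (shapes are illustrative only):

# Illustrative only: keep the cached tensor and the incoming slice in the same
# low-precision dtype before concatenating along the sequence axis.
import numpy as np
import mindspore as ms
from mindnlp.core import ops

cache  = ops.zeros(1, 2, 4, 8, dtype=ms.float16)             # (batch, heads, seq, head_dim)
new_kv = ms.Tensor(np.random.randn(1, 2, 1, 8), ms.float32)  # arrives in float32

if new_kv.dtype != ms.float16:
    new_kv = new_kv.astype(ms.float16)                       # normalize dtype first
cache = ops.cat([cache, new_kv], dim=-2)                     # seq grows from 4 to 5
print(cache.dtype, cache.shape)                              # Float16 (1, 2, 5, 8)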
causal_mask.masked_fill(padding_mask, min_dtype) else: causal_mask = ops.cat( - [ops.narrow(causal_mask, -1, 0, mask_length).masked_fill(padding_mask, min_dtype), - ops.narrow(causal_mask, -1, mask_length, causal_mask.shape[-1] - mask_length)], - dim=-1 - ) + [ops.narrow(causal_mask, -1, 0, mask_length).masked_fill(padding_mask, min_dtype), + ops.narrow(causal_mask, -1, mask_length, causal_mask.shape[-1] - mask_length)], + dim=-1 + ) return causal_mask @@ -165,7 +165,8 @@ def __init__( else: # BC: "rope_type" was originally "type" if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings @@ -174,7 +175,8 @@ def __init__( self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -426,10 +428,10 @@ def forward( if attention_mask is not None: # no matter the length, we just slice it # causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] causal_mask = ops.narrow(attention_mask, 3, 0, key_states.shape[-2]) - attn_weights = attn_weights + causal_mask + attn_weights = attn_weights.astype(mindspore.float16) + causal_mask.astype(mindspore.float16) # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(query_states.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float16).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = ops.matmul(attn_weights, value_states) @@ -849,7 +851,7 @@ def forward( hidden_states = outputs[0] if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + lm_head_slices = ops.split(self.lm_head.weight,self.vocab_size // self.config.pretraining_tp, dim=0) logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] logits = ops.cat(logits, dim=-1) else: @@ -1228,6 +1230,7 @@ def forward( attentions=outputs.attentions, ) + __all__ = [ "LlamaForCausalLM", "LlamaModel",
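The attention changes above cast both the score tensor and the sliced causal mask to float16 before adding them, and keep the softmax in float16 instead of upcasting to float32. A small illustrative sketch of that dtype flow (shapes are arbitrary and -65504.0 stands in for min_dtype at float16 precision):

# Illustrative dtype flow for the patched eager-attention path.
import numpy as np
import mindspore as ms
from mindnlp.core import nn

attn_weights = ms.Tensor(np.random.randn(1, 2, 4, 4), ms.float32)       # raw scores
causal_mask  = ms.Tensor(np.triu(np.full((1, 1, 4, 4), -65504.0, dtype=np.float32), k=1))

attn_weights = attn_weights.astype(ms.float16) + causal_mask.astype(ms.float16)
probs = nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float16)   # stays in fp16
print(probs.dtype, probs.sum(-1))                                       # rows sum to ~1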