from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2 as cv
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torch.autograd import Variable
from config import config
from config import update_config
from PIL import Image
import argparse
from models import cls_hrnet
# Experiment config and sample image used when prediect() is called without
# arguments (same paths the script originally hard-coded).
_DEFAULT_CFG_PATH = r'C:\Users\User\Desktop\HRnet-class\HRNet-Classification\experiments\cls_hrnet_w18_small_v2_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
_DEFAULT_IMG_PATH = r'C:\Users\User\Desktop\HRnet-class\HRNet-Classification\imagenet\images\train\normal\im0088.jpg'

# Human-readable label for each output index of the classifier.
_CLASS_NAMES = ('fall', 'normal')


def _build_model():
    """Build HRNet from the global ``config`` and load its trained weights.

    Returns:
        The model in eval mode. It is wrapped in ``DataParallel`` over
        ``config.GPUS`` when CUDA is available; otherwise it stays on the CPU.

    Raises:
        FileNotFoundError: if ``config.TEST.MODEL_FILE`` is empty.
    """
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED

    model = cls_hrnet.get_cls_net(config)
    print('*********************************')
    print(config.TEST.MODEL_FILE)
    if not config.TEST.MODEL_FILE:
        # Message means "model file not found". Predicting with untrained
        # weights would print a meaningless class, so stop here.
        print('没找到模型文件')
        raise FileNotFoundError('config.TEST.MODEL_FILE is not set')
    # map_location='cpu' lets a checkpoint saved on a GPU load on a
    # CPU-only machine; the weights are moved to the GPU below if available.
    model.load_state_dict(torch.load(config.TEST.MODEL_FILE, map_location='cpu'))

    if torch.cuda.is_available():
        model = torch.nn.DataParallel(model, device_ids=list(config.GPUS)).cuda()
    model.eval()
    return model


def _load_image(img_path):
    """Load *img_path* as a normalized (1, 3, S, S) float tensor.

    S is ``config.MODEL.IMAGE_SIZE[0]``. The image is resized to S / 0.875,
    center-cropped to S x S and normalized with the mean/std constants below.
    """
    size = config.MODEL.IMAGE_SIZE[0]
    preprocess = transforms.Compose([
        transforms.Resize(int(size / 0.875)),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # convert('RGB') guarantees three channels for the 3-element Normalize
    # (grayscale, palette or RGBA files would otherwise fail). The context
    # manager closes the underlying file handle.
    with Image.open(img_path) as img:
        tensor = preprocess(img.convert('RGB'))
    # Add the batch dimension the model expects.
    return tensor.unsqueeze(0)


def prediect(cfg_path=_DEFAULT_CFG_PATH, img_path=_DEFAULT_IMG_PATH):
    """Classify one image with a trained HRNet and print the predicted class.

    Args:
        cfg_path: experiment YAML file merged into the global ``config``.
        img_path: path of the image to classify.

    Returns:
        The predicted class index (0 -> 'fall', 1 -> 'normal').

    Raises:
        FileNotFoundError: if the config does not name a model file.
    """
    # The global config is frozen after the first merge. Defrosting first
    # (yacs CfgNode API) lets prediect() be called more than once.
    config.defrost()
    config.merge_from_file(cfg_path)
    config.freeze()

    model = _build_model()
    batch = _load_image(img_path)
    # Put the input on the same device as the model's weights.
    device = next(model.parameters()).device

    with torch.no_grad():
        output = model(batch.to(device))
    if torch.cuda.is_available():
        # Free cached GPU memory held by the forward pass.
        torch.cuda.empty_cache()

    cls_pred = int(output.argmax(dim=1)[0])
    print(cls_pred)
    # Indices without a known label print nothing, as before.
    if 0 <= cls_pred < len(_CLASS_NAMES):
        print('This pic belong to class:')
        print(_CLASS_NAMES[cls_pred])
    return cls_pred
# Run the prediction only when executed as a script, not on import.
if __name__ == '__main__':
    prediect()
使用HRnet训练自己的模型并检测
最新推荐文章于 2026-04-29 07:04:46 发布
Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版
Qwen
文本生成
Qwen3
本镜像基于 RTX 4090D 24GB 显存 + CUDA 12.4 + 驱动 550.90.07 深度优化,内置完整运行环境与 Qwen3-32B 模型依赖,开箱即用。
您可能感兴趣的与本文相关的镜像
Qwen3-32B-Chat 私有部署镜像 | RTX4090D 24G 显存 CUDA12.4 优化版
Qwen
文本生成
Qwen3
本镜像基于 RTX 4090D 24GB 显存 + CUDA 12.4 + 驱动 550.90.07 深度优化,内置完整运行环境与 Qwen3-32B 模型依赖,开箱即用。
1万+

被折叠的 条评论
为什么被折叠?



