diff --git a/.gitignore b/.gitignore
index df8b21794..01886f168 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,7 +15,9 @@ flower_data
*.xml
*.bin
*.mapping
+*.csv
checkpoint
data
VOCdevkit
ssd_resnet50_v1_fpn_shared_box_predictor
+runs
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 000000000..f288702d2
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+ <program> Copyright (C) <year> <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/README.md b/README.md
index 7128d5066..4fb186d53 100644
--- a/README.md
+++ b/README.md
@@ -40,56 +40,74 @@
* [ResNeXt network explained](https://www.bilibili.com/video/BV1Ap4y1p71v/)
* [Building ResNeXt with Pytorch](https://www.bilibili.com/video/BV1rX4y1N7tE)
- * MobileNet_v1_v2 (completed)
- * [MobileNet_v1_v2 network explained](https://www.bilibili.com/video/BV1yE411p7L7)
+ * MobileNet_V1_V2 (completed)
+ * [MobileNet_V1_V2 network explained](https://www.bilibili.com/video/BV1yE411p7L7)
* [Building MobileNetV2 with Pytorch](https://www.bilibili.com/video/BV1qE411T7qZ)
* [Building MobileNetV2 with Tensorflow2](https://www.bilibili.com/video/BV1NE411K7tX)
- * MobileNet_v3 (completed)
- * [MobileNet_v3 network explained](https://www.bilibili.com/video/BV1GK4y1p7uE)
+ * MobileNet_V3 (completed)
+ * [MobileNet_V3 network explained](https://www.bilibili.com/video/BV1GK4y1p7uE)
* [Building MobileNetV3 with Pytorch](https://www.bilibili.com/video/BV1zT4y1P7pd)
* [Building MobileNetV3 with Tensorflow2](https://www.bilibili.com/video/BV1KA411g7wX)
- * ShuffleNet_v1_v2 (completed)
- * [ShuffleNet_v1_v2 network explained](https://www.bilibili.com/video/BV15y4y1Y7SY)
- * [Building ShuffleNetv2 with Pytorch](https://www.bilibili.com/video/BV1dh411r76X)
- * [Building ShuffleNetv2 with Tensorflow2](https://www.bilibili.com/video/BV1kr4y1N7bh)
+ * ShuffleNet_V1_V2 (completed)
+ * [ShuffleNet_V1_V2 network explained](https://www.bilibili.com/video/BV15y4y1Y7SY)
+ * [Building ShuffleNetV2 with Pytorch](https://www.bilibili.com/video/BV1dh411r76X)
+ * [Building ShuffleNetV2 with Tensorflow2](https://www.bilibili.com/video/BV1kr4y1N7bh)
- * EfficientNet_v1 (completed)
+ * EfficientNet_V1 (completed)
* [EfficientNet network explained](https://www.bilibili.com/video/BV1XK4y1U7PX)
* [Building EfficientNet with Pytorch](https://www.bilibili.com/video/BV19z4y1179h/)
* [Building EfficientNet with Tensorflow2](https://www.bilibili.com/video/BV1PK4y1S7Jf)
- * EfficientNet_v2 (completed)
- * [EfficientNetV2 network explained](https://b23.tv/NDR7Ug)
- * [Building EfficientNetV2 with Pytorch](https://b23.tv/M4hagB)
- * [Building EfficientNetV2 with Tensorflow](https://b23.tv/KUPbdr)
+ * EfficientNet_V2 (completed)
+ * [EfficientNetV2 network explained](https://www.bilibili.com/video/BV19v41157AU)
+ * [Building EfficientNetV2 with Pytorch](https://www.bilibili.com/video/BV1Xy4y1g74u)
+ * [Building EfficientNetV2 with Tensorflow](https://www.bilibili.com/video/BV19K4y1g7m4)
+
+ * RepVGG (completed)
+ * [RepVGG network explained](https://www.bilibili.com/video/BV15f4y1o7QR)
* Vision Transformer (completed)
- * [Multi-Head Attention explained](https://b23.tv/gucpvt)
+ * [Multi-Head Attention explained](https://www.bilibili.com/video/BV15v411W78M)
* [Vision Transformer network explained](https://www.bilibili.com/video/BV1Jh411Y7WQ)
- * [Building Vision Transformer with Pytorch](https://b23.tv/TT4VBM)
+ * [Building Vision Transformer with Pytorch](https://www.bilibili.com/video/BV1AL411W7dT)
* [Building Vision Transformer with tensorflow2](https://www.bilibili.com/video/BV1q64y1X7GY)
* Swin Transformer (completed)
* [Swin Transformer network explained](https://www.bilibili.com/video/BV1pL4y1v7jC)
- * [Building Swin Transformer with Pytorch](https://b23.tv/vZnpJf)
- * [Building Swin Transformer with Tensorflow2](https://b23.tv/UHLMSF)
+ * [Building Swin Transformer with Pytorch](https://www.bilibili.com/video/BV1yg411K7Yc)
+ * [Building Swin Transformer with Tensorflow2](https://www.bilibili.com/video/BV1bR4y1t7qT)
+
+ * ConvNeXt (completed)
+ * [ConvNeXt network explained](https://www.bilibili.com/video/BV1SS4y157fu)
+ * [Building ConvNeXt with Pytorch](https://www.bilibili.com/video/BV14S4y1L791)
+ * [Building ConvNeXt with Tensorflow2](https://www.bilibili.com/video/BV1TS4y1V7Gz)
+
+ * MobileViT (completed)
+ * [MobileViT network explained](https://www.bilibili.com/video/BV1TG41137sb)
+ * [Building MobileViT with Pytorch](https://www.bilibili.com/video/BV1ae411L7Ki)
* Object Detection
* Faster-RCNN/FPN (completed)
* [Faster-RCNN network explained](https://www.bilibili.com/video/BV1af4y1m7iL)
- * [FPN network explained](https://b23.tv/Qhn6xA)
+ * [FPN network explained](https://www.bilibili.com/video/BV1dh411U7D9)
* [Faster-RCNN source code walkthrough (Pytorch)](https://www.bilibili.com/video/BV1of4y1m7nj)
* SSD/RetinaNet (completed)
* [SSD network explained](https://www.bilibili.com/video/BV1fT4y1L7Gi)
- * [RetinaNet network explained](https://b23.tv/ZYCfd2)
+ * [RetinaNet network explained](https://www.bilibili.com/video/BV1Q54y1L7sM)
* [SSD source code walkthrough (Pytorch)](https://www.bilibili.com/video/BV1vK411H771)
- * YOLOv3 SPP (completed)
- * [YOLO series networks explained](https://www.bilibili.com/video/BV1yi4y1g7ro)
+ * YOLO Series (completed)
+ * [YOLO series networks explained (V1~V3)](https://www.bilibili.com/video/BV1yi4y1g7ro)
* [YOLOv3 SPP source code walkthrough (Pytorch)](https://www.bilibili.com/video/BV1t54y1C7ra)
+ * [YOLOV4 network explained](https://www.bilibili.com/video/BV1NF41147So)
+ * [YOLOV5 network explained](https://www.bilibili.com/video/BV1T3411p7zR)
+ * [YOLOX network explained](https://www.bilibili.com/video/BV1JW4y1k76c)
+
+ * FCOS (completed)
+ * [FCOS network explained](https://www.bilibili.com/video/BV1G5411X7jw)
* Semantic Segmentation
* FCN (completed)
@@ -98,7 +116,7 @@
* DeepLabV3 (completed)
* [DeepLabV1 network explained](https://www.bilibili.com/video/BV1SU4y1N7Ao)
- * [DeepLabV2 network explained](https://www.bilibili.com/video/BV1gP4y1G7TC)
+ * [DeepLabV2 network explained](https://www.bilibili.com/video/BV1gP4y1G7TC)
* [DeepLabV3 network explained](https://www.bilibili.com/video/BV1Jb4y1q7j7)
* [DeepLabV3 source code walkthrough (Pytorch)](https://www.bilibili.com/video/BV1TD4y1c7Wx)
@@ -106,21 +124,32 @@
* [LR-ASPP network explained](https://www.bilibili.com/video/BV1LS4y1M76E)
* [LR-ASPP source code walkthrough (Pytorch)](https://www.bilibili.com/video/bv13D4y1F7ML)
- * UNet (in preparation)
- * [UNet network explained](https://www.bilibili.com/video/BV1Vq4y127fB/)
+ * U-Net (completed)
+ * [U-Net network explained](https://www.bilibili.com/video/BV1Vq4y127fB/)
+ * [U-Net source code walkthrough (Pytorch)](https://www.bilibili.com/video/BV1Vq4y127fB)
+
+ * U2Net (completed)
+ * [U2Net network explained](https://www.bilibili.com/video/BV1yB4y1z7mj)
+ * [U2Net source code walkthrough (Pytorch)](https://www.bilibili.com/video/BV1Kt4y137iS)
+
+* Instance Segmentation
+ * Mask R-CNN (completed)
+ * [Mask R-CNN network explained](https://www.bilibili.com/video/BV1ZY411774T)
+ * [Mask R-CNN source code walkthrough (Pytorch)](https://www.bilibili.com/video/BV1hY411E7wD)
+
+* Keypoint Detection
+ * DeepPose (completed)
+ * [DeepPose network explained](https://www.bilibili.com/video/BV1bm421g7aJ)
+ * [DeepPose source code walkthrough (Pytorch)](https://www.bilibili.com/video/BV1bm421g7aJ)
+
+ * HRNet (completed)
+ * [HRNet network explained](https://www.bilibili.com/video/BV1bB4y1y7qP)
+ * [HRNet source code walkthrough (Pytorch)](https://www.bilibili.com/video/BV1ar4y157JM)
**[For more related videos, check my bilibili channel](https://space.bilibili.com/18161609/channel/index)**
---
-## Requirements
-* Anaconda3 (recommended)
-* python3.6/3.7/3.8
-* pycharm (IDE)
-* pytorch 1.7.1 (pip package)
-* torchvision 0.8.1 (pip package)
-* tensorflow 2.4.1 (pip package)
-
Feel free to follow my WeChat official account (**阿喆学习小记**), where I regularly post related study notes.
If you have any questions, you are also welcome to discuss them on my CSDN.
diff --git a/article_link/README.md b/article_link/README.md
index cba6499e4..a1ca1ba6f 100644
--- a/article_link/README.md
+++ b/article_link/README.md
@@ -1,7 +1,5 @@
# Paper Links
------
-
## Image Classification
- LeNet [http://yann.lecun.com/exdb/lenet/index.html](http://yann.lecun.com/exdb/lenet/index.html)
- AlexNet [http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)
@@ -34,7 +32,11 @@
- Swin Transformer V2: Scaling Up Capacity and Resolution [https://arxiv.org/abs/2111.09883](https://arxiv.org/abs/2111.09883)
- BEiT: BERT Pre-Training of Image Transformers [https://arxiv.org/abs/2106.08254](https://arxiv.org/abs/2106.08254)
- MAE(Masked Autoencoders Are Scalable Vision Learners) [https://arxiv.org/abs/2111.06377](https://arxiv.org/abs/2111.06377)
-------
+- ConvNeXt(A ConvNet for the 2020s) [https://arxiv.org/abs/2201.03545](https://arxiv.org/abs/2201.03545)
+- MobileViT V1 [https://arxiv.org/abs/2110.02178](https://arxiv.org/abs/2110.02178)
+- MobileViT V2(Separable Self-attention for Mobile Vision Transformers) [https://arxiv.org/abs/2206.02680](https://arxiv.org/abs/2206.02680)
+- MobileOne(An Improved One millisecond Mobile Backbone) [https://arxiv.org/abs/2206.04040](https://arxiv.org/abs/2206.04040)
+
## Object Detection
- R-CNN [https://arxiv.org/abs/1311.2524](https://arxiv.org/abs/1311.2524)
@@ -51,26 +53,49 @@
- YOLOv3 [https://arxiv.org/abs/1804.02767](https://arxiv.org/abs/1804.02767)
- YOLOv4 [https://arxiv.org/abs/2004.10934](https://arxiv.org/abs/2004.10934)
- YOLOX(Exceeding YOLO Series in 2021) [https://arxiv.org/abs/2107.08430](https://arxiv.org/abs/2107.08430)
+- YOLOv7 [https://arxiv.org/abs/2207.02696](https://arxiv.org/abs/2207.02696)
- PP-YOLO [https://arxiv.org/abs/2007.12099](https://arxiv.org/abs/2007.12099)
- PP-YOLOv2 [https://arxiv.org/abs/2104.10419](https://arxiv.org/abs/2104.10419)
- CornerNet [https://arxiv.org/abs/1808.01244](https://arxiv.org/abs/1808.01244)
-- FCOS [https://arxiv.org/abs/1904.01355](https://arxiv.org/abs/1904.01355)
+- FCOS(Old) [https://arxiv.org/abs/1904.01355](https://arxiv.org/abs/1904.01355)
+- FCOS(New) [https://arxiv.org/abs/2006.09214](https://arxiv.org/abs/2006.09214)
- CenterNet [https://arxiv.org/abs/1904.07850](https://arxiv.org/abs/1904.07850)
-## Image Segmentation
+## Semantic Segmentation
- FCN(Fully Convolutional Networks for Semantic Segmentation) [https://arxiv.org/abs/1411.4038](https://arxiv.org/abs/1411.4038)
- UNet(U-Net: Convolutional Networks for Biomedical Image Segmentation) [https://arxiv.org/abs/1505.04597](https://arxiv.org/abs/1505.04597)
- DeepLabv1(Semantic Image Segmentation with Deep Convolutional Nets and Fully Connected CRFs) [https://arxiv.org/abs/1412.7062](https://arxiv.org/abs/1412.7062)
- DeepLabv2(Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs) [https://arxiv.org/abs/1606.00915](https://arxiv.org/abs/1606.00915)
- DeepLabv3(Rethinking Atrous Convolution for Semantic Image Segmentation) [https://arxiv.org/abs/1706.05587](https://arxiv.org/abs/1706.05587)
- DeepLabv3+(Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation) [https://arxiv.org/abs/1802.02611](https://arxiv.org/abs/1802.02611)
+- SegFormer [https://arxiv.org/abs/2105.15203](https://arxiv.org/abs/2105.15203)
+
+
+## Salient Object Detection
+- U2Net [https://arxiv.org/abs/2005.09007](https://arxiv.org/abs/2005.09007)
+
+
+## Instance Segmentation
- Mask R-CNN [https://arxiv.org/abs/1703.06870](https://arxiv.org/abs/1703.06870)
+## Keypoint Detection
+- HRNet(Deep High-Resolution Representation Learning for Human Pose Estimation) [https://arxiv.org/abs/1902.09212](https://arxiv.org/abs/1902.09212)
+
+## Network Quantization
+- Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference [https://arxiv.org/abs/1712.05877](https://arxiv.org/abs/1712.05877)
+- Quantizing deep convolutional networks for efficient inference: A whitepaper [https://arxiv.org/abs/1806.08342](https://arxiv.org/abs/1806.08342)
+- Data-Free Quantization Through Weight Equalization and Bias Correction [https://arxiv.org/abs/1906.04721](https://arxiv.org/abs/1906.04721)
+- LSQ: Learned Step Size Quantization [https://arxiv.org/abs/1902.08153](https://arxiv.org/abs/1902.08153)
+- LSQ+: Improving low-bit quantization through learnable offsets and better initialization [https://arxiv.org/abs/2004.09576](https://arxiv.org/abs/2004.09576)
+
+
+
## Natural Language Processing
- Attention Is All You Need [https://arxiv.org/abs/1706.03762](https://arxiv.org/abs/1706.03762)
## Others
- Microsoft COCO: Common Objects in Context [https://arxiv.org/abs/1405.0312](https://arxiv.org/abs/1405.0312)
- The PASCAL Visual Object Classes Challenge: A Retrospective [http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham15.pdf](http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham15.pdf)
+- Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization [https://arxiv.org/abs/1610.02391](https://arxiv.org/abs/1610.02391)
diff --git a/course_ppt/README.md b/course_ppt/README.md
index 98339951c..42b85e94a 100644
--- a/course_ppt/README.md
+++ b/course_ppt/README.md
@@ -1,42 +1,62 @@
# To keep the project lean, all PPTs used in the course have been moved to Baidu Cloud
+**All PPTs in a single folder** Link: https://pan.baidu.com/s/1VL6QTQ86sfY2aMDVo4Z-kg Extraction code: 4ydw
+
+**Links to the individual PPTs are listed below**:
## Classification networks
-- **AlexNet** Link: https://pan.baidu.com/s/1RJn5lzY8LwrmckUPvXcjmg Password: 34ue
-- **VGG** Link: https://pan.baidu.com/s/1BnYpdaDwAIcgRm7YwakEZw Password: 8ev0
-- **GoogleNet** Link: https://pan.baidu.com/s/1XjZXprvayV3dDMvLjoOk3A Password: 9hq4
-- **ResNet** Link: https://pan.baidu.com/s/1I2LUlwCSjNKr37T0n3NKzg Password: f1s9
-- **ResNext** Link: https://pan.baidu.com/s/1-anFYX5572MJmiQym9D4Eg Password: f8ob
-- **MobileNet_v1_v2** Link: https://pan.baidu.com/s/1ReDDCuK8wyH0XqniUgiSYQ Password: ipqv
-- **MobileNet_v3** Link: https://pan.baidu.com/s/13mzSpyxuA4T4ki7kEN1Xqw Password: fp5g
-- **ShuffleNet_v1_v2** Link: https://pan.baidu.com/s/1-DDwePMPCDvjw08YU8nAAA Password: ad6n
-- **EfficientNet_v1** Link: https://pan.baidu.com/s/1Sep9W0vLzfjhcHAXr6Bv0Q Password: eufl
-- **EfficientNet_v2** Link: https://pan.baidu.com/s/1tesrgY4CHLmq6P7s7TcHCw Password: y2kz
-- **Transformer** Link: https://pan.baidu.com/s/1DE6RDySr7NS0HQ35gBqP_g Password: y9e7
-- **Vision Transformer** Link: https://pan.baidu.com/s/1wzpHG8EK5gxg6UCMscYqMw Password: cm1m
-- **Swin Transformer** Link: https://pan.baidu.com/s/1O6XEEZUb6B6AGYON7-EOgA Password: qkrn
-- **ConfusionMatrix** Link: https://pan.baidu.com/s/1EtKzHkZyv2XssYtqmGYCLg Password: uoo5
+- **AlexNet** Link: https://pan.baidu.com/s/1RJn5lzY8LwrmckUPvXcjmg Extraction code: 34ue
+- **VGG** Link: https://pan.baidu.com/s/1BnYpdaDwAIcgRm7YwakEZw Extraction code: 8ev0
+- **GoogleNet** Link: https://pan.baidu.com/s/1XjZXprvayV3dDMvLjoOk3A Extraction code: 9hq4
+- **ResNet** Link: https://pan.baidu.com/s/1I2LUlwCSjNKr37T0n3NKzg Extraction code: f1s9
+- **ResNext** Link: https://pan.baidu.com/s/1-anFYX5572MJmiQym9D4Eg Extraction code: f8ob
+- **MobileNet_v1_v2** Link: https://pan.baidu.com/s/1ReDDCuK8wyH0XqniUgiSYQ Extraction code: ipqv
+- **MobileNet_v3** Link: https://pan.baidu.com/s/13mzSpyxuA4T4ki7kEN1Xqw Extraction code: fp5g
+- **ShuffleNet_v1_v2** Link: https://pan.baidu.com/s/1-DDwePMPCDvjw08YU8nAAA Extraction code: ad6n
+- **EfficientNet_v1** Link: https://pan.baidu.com/s/1Sep9W0vLzfjhcHAXr6Bv0Q Extraction code: eufl
+- **EfficientNet_v2** Link: https://pan.baidu.com/s/1tesrgY4CHLmq6P7s7TcHCw Extraction code: y2kz
+- **Transformer** Link: https://pan.baidu.com/s/1DE6RDySr7NS0HQ35gBqP_g Extraction code: y9e7
+- **Vision Transformer** Link: https://pan.baidu.com/s/1wzpHG8EK5gxg6UCMscYqMw Extraction code: cm1m
+- **Swin Transformer** Link: https://pan.baidu.com/s/1O6XEEZUb6B6AGYON7-EOgA Extraction code: qkrn
+- **ConvNeXt** Link: https://pan.baidu.com/s/1mgZjkirJPZ8huVls-O0xXA Extraction code: kvqx
+- **RepVGG** Link: https://pan.baidu.com/s/1uJP3hCHI79-tUdBNR_VAWQ Extraction code: qe8a
+- **MobileViT** Link: https://pan.baidu.com/s/1F8QJtFhTPWX8Vjr8_97scQ Extraction code: lfn5
+- **ConfusionMatrix** Link: https://pan.baidu.com/s/1EtKzHkZyv2XssYtqmGYCLg Extraction code: uoo5
+- **Grad-CAM** Link: https://pan.baidu.com/s/1ZHKBW7hINQXFI36hBYdC0Q Extraction code: aru7
## Object detection networks
-- **R-CNN** Link: https://pan.baidu.com/s/1l_ZxkfJdyp3KoMLqwWbx5A Password: nm1l
-- **Fast R-CNN** Link: https://pan.baidu.com/s/1Pe_Tg43OVo-yZWj7t-_L6Q Password: fe73
-- **Faster R-CNN** Link: https://pan.baidu.com/s/1Dd0d_LY8l7Y1YkHQhp-WfA Password: vzp4
-- **FPN** Link: https://pan.baidu.com/s/1O9H0iqQMg9f_FZezUEKZ9g Password: qbl8
-- **SSD** Link: https://pan.baidu.com/s/15zF3GhIdg-E_tZX2Y2X-rw Password: u7k1
-- **RetinaNet** Link: https://pan.baidu.com/s/1beW612VCSnSu-v8iu_2-fA Password: vqbu
-- **YOLOv1** Link: https://pan.baidu.com/s/1vVyUNQHYEGjqosezlx_1Mg Password: b3i0
-- **YOLOv2** Link: https://pan.baidu.com/s/132aW1e_NYbaxxGi3cDVLYg Password: tak7
-- **YOLOv3** Link: https://pan.baidu.com/s/10oqZewzJmx5ptT9A4t-64w Password: npji
-- **YOLOv3SPP** Link: https://pan.baidu.com/s/15LRssnPez9pn6jRpW89Wlw Password: nv9f
-- **Calculate mAP** Link: https://pan.baidu.com/s/1jdA_n78J7nSUoOg6TTO5Bg Password: eh62
-- **Introduction to the COCO dataset** Link: https://pan.baidu.com/s/1HfCvjt-8o9j5a916IYNVjw Password: 6rec
+- **R-CNN** Link: https://pan.baidu.com/s/1l_ZxkfJdyp3KoMLqwWbx5A Extraction code: nm1l
+- **Fast R-CNN** Link: https://pan.baidu.com/s/1Pe_Tg43OVo-yZWj7t-_L6Q Extraction code: fe73
+- **Faster R-CNN** Link: https://pan.baidu.com/s/1Dd0d_LY8l7Y1YkHQhp-WfA Extraction code: vzp4
+- **FPN** Link: https://pan.baidu.com/s/1O9H0iqQMg9f_FZezUEKZ9g Extraction code: qbl8
+- **SSD** Link: https://pan.baidu.com/s/15zF3GhIdg-E_tZX2Y2X-rw Extraction code: u7k1
+- **RetinaNet** Link: https://pan.baidu.com/s/1beW612VCSnSu-v8iu_2-fA Extraction code: vqbu
+- **YOLOv1** Link: https://pan.baidu.com/s/1vVyUNQHYEGjqosezlx_1Mg Extraction code: b3i0
+- **YOLOv2** Link: https://pan.baidu.com/s/132aW1e_NYbaxxGi3cDVLYg Extraction code: tak7
+- **YOLOv3** Link: https://pan.baidu.com/s/1hZqdgh7wA7QeGAYTttlVOQ Extraction code: 5ulo
+- **YOLOv3SPP** Link: https://pan.baidu.com/s/15LRssnPez9pn6jRpW89Wlw Extraction code: nv9f
+- **YOLOv4** Link: https://pan.baidu.com/s/1Ltw4v1pg0eZNFYR2ZBbZmQ Extraction code: qjx4
+- **YOLOv5** Link: https://pan.baidu.com/s/1rnvjwHLvOlJ9KpJ5z95GWw Extraction code: kt04
+- **YOLOX** Link: https://pan.baidu.com/s/1ex54twQC7hBE3szNko_K5A Extraction code: al0r
+- **FCOS** Link: https://pan.baidu.com/s/1KUc9dzvAbtwtGGm3ZZy_cw Extraction code: h0as
+- **Calculate mAP** Link: https://pan.baidu.com/s/1jdA_n78J7nSUoOg6TTO5Bg Extraction code: eh62
+- **Introduction to the COCO dataset** Link: https://pan.baidu.com/s/1HfCvjt-8o9j5a916IYNVjw Extraction code: 6rec
## Image segmentation networks
-- **Introduction to semantic segmentation** Link: https://pan.baidu.com/s/1cwxe2wbaA_2DqNYADq3myA Password: zzij
-- **Transposed convolution** Link: https://pan.baidu.com/s/1A8688168fuWHyxJQtzupHw Password: pgnf
-- **FCN** Link: https://pan.baidu.com/s/1XLUneTLrdUyDAiV6kqi9rw Password: 126a
-- **Dilated convolution** Link: https://pan.baidu.com/s/1QlQyniuMhBeXyEK420MIdQ Password: ry6p
-- **DeepLab V1** Link: https://pan.baidu.com/s/1NFxb7ADQOMVYLxmIKqTONQ Password: 500s
-- **DeepLab V2** Link: https://pan.baidu.com/s/1woe3lJYBVkOdnn6XXlKf8g Password: 76ec
-- **DeepLab V3** Link: https://pan.baidu.com/s/1WVBgc2Ld13D0_dkHGwhTpA Password: m54m
\ No newline at end of file
+- **Introduction to semantic segmentation** Link: https://pan.baidu.com/s/1cwxe2wbaA_2DqNYADq3myA Extraction code: zzij
+- **Transposed convolution** Link: https://pan.baidu.com/s/1A8688168fuWHyxJQtzupHw Extraction code: pgnf
+- **FCN** Link: https://pan.baidu.com/s/1XLUneTLrdUyDAiV6kqi9rw Extraction code: 126a
+- **Dilated convolution** Link: https://pan.baidu.com/s/1QlQyniuMhBeXyEK420MIdQ Extraction code: ry6p
+- **DeepLab V1** Link: https://pan.baidu.com/s/1NFxb7ADQOMVYLxmIKqTONQ Extraction code: 500s
+- **DeepLab V2** Link: https://pan.baidu.com/s/1woe3lJYBVkOdnn6XXlKf8g Extraction code: 76ec
+- **DeepLab V3** Link: https://pan.baidu.com/s/1WVBgc2Ld13D0_dkHGwhTpA Extraction code: m54m
+- **U2Net** Link: https://pan.baidu.com/s/1ekbEm4dsjlFamK8dCs8yfA Extraction code: 472j
+
+
+## Instance segmentation
+- **Mask R-CNN** Link: https://pan.baidu.com/s/1JpQ7ENEv_x9A1-O_NpjwYA Extraction code: 1t4i
+
+## Keypoint detection
+- **HRNet** Link: https://pan.baidu.com/s/1-8AJdU82K1j70KZK_rN7aQ Extraction code: t4me
+
diff --git a/data_set/README.md b/data_set/README.md
index b81800caf..60007a5a5 100644
--- a/data_set/README.md
+++ b/data_set/README.md
@@ -1,7 +1,7 @@
-## This folder is used to store the training samples
+## This folder is used to store the training data
### Usage steps:
* (1) Create a new folder "flower_data" under the data_set folder
-* (2) Download the flower classification dataset from [http://download.tensorflow.org/example_images/flower_photos.tgz](http://download.tensorflow.org/example_images/flower_photos.tgz)
+* (2) Download the flower classification dataset from [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz)
* (3) Extract the dataset into the flower_data folder
* (4) Run the "split_data.py" script to automatically split the dataset into a training set (train) and a validation set (val); a minimal sketch of such a split is shown at the end of this file
@@ -10,4 +10,4 @@
├── flower_photos (extracted dataset folder, 3670 samples)
├── train (generated training set, 3306 samples)
└── val (generated validation set, 364 samples)
-```
\ No newline at end of file
+```
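+
+For reference, the split performed by "split_data.py" is essentially a per-class random train/val split (roughly 9:1, judging from the sample counts above). Below is a minimal, hypothetical sketch of such a split; the actual script in this repository may differ in its details:
+```python
+import os
+import random
+from shutil import copy
+
+
+def split_dataset(src_root: str = "flower_data/flower_photos",
+                  dst_root: str = "flower_data",
+                  val_rate: float = 0.1):
+    random.seed(0)
+    classes = [c for c in os.listdir(src_root)
+               if os.path.isdir(os.path.join(src_root, c))]
+    for cls in classes:
+        images = os.listdir(os.path.join(src_root, cls))
+        # hold out roughly val_rate of every class for validation
+        val_images = set(random.sample(images, k=int(len(images) * val_rate)))
+        for name in images:
+            subset = "val" if name in val_images else "train"
+            dst_dir = os.path.join(dst_root, subset, cls)
+            os.makedirs(dst_dir, exist_ok=True)
+            copy(os.path.join(src_root, cls, name), dst_dir)
+
+
+if __name__ == "__main__":
+    split_dataset()
+```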
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/README.md b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/README.md
new file mode 100644
index 000000000..0376a4994
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/README.md
@@ -0,0 +1,15 @@
+This project shows how to convert the ResNet34 network from Pytorch into OpenVINO's IR format and then quantize it. The workflow is as follows:
+1. Set up the environment according to `requirements.txt`
+2. Download the pre-trained ResNet34 weights (trained earlier on the flower classification dataset) and place them in the current folder. Baidu Cloud link: https://pan.baidu.com/s/1x4WFX1HynYcXLium3UaaFQ Password: qvi6
+3. Use `convert_pytorch2onnx.py` to convert ResNet34 to the ONNX format (a rough sketch of this step is included at the end of this README)
+4. Convert the ONNX model to the IR format with the following command:
+```
+mo --input_model resnet34.onnx \
+ --input_shape "[1,3,224,224]" \
+ --mean_values="[123.675,116.28,103.53]" \
+ --scale_values="[58.395,57.12,57.375]" \
+ --data_type FP32 \
+ --output_dir ir_output
+```
+5. Download and extract the flower classification dataset, then point `data_path` in `quantization_int8.py` to the extracted `flower_photos` folder
+6. Quantize the model with `quantization_int8.py`
\ No newline at end of file
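+
+For reference, the Pytorch-to-ONNX export in step 3 usually boils down to something like the minimal sketch below (the weight file name `resNet34.pth` and the helper function are placeholders; the actual `convert_pytorch2onnx.py` in this folder may differ in its details):
+```python
+import torch
+from torchvision.models import resnet34
+
+
+def export_onnx(weights_path: str = "resNet34.pth", onnx_path: str = "resnet34.onnx"):
+    # the flower classification dataset has 5 classes
+    model = resnet34(num_classes=5)
+    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
+    model.eval()
+
+    # fixed input size, matching --input_shape "[1,3,224,224]" passed to mo above
+    dummy_input = torch.randn(1, 3, 224, 224)
+    torch.onnx.export(model,
+                      dummy_input,
+                      onnx_path,
+                      input_names=["input"],
+                      output_names=["output"],
+                      opset_version=11)
+
+
+if __name__ == "__main__":
+    export_onnx()
+```
+Note that the `--mean_values`/`--scale_values` arguments given to `mo` in step 4 fold the ImageNet normalization into the IR graph, so the resulting IR model expects raw (unnormalized) pixel values as its input.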
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/compare_fps.py b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/compare_fps.py
new file mode 100644
index 000000000..c74639c25
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/compare_fps.py
@@ -0,0 +1,126 @@
+import time
+import numpy as np
+import torch
+import onnxruntime
+import matplotlib.pyplot as plt
+from openvino.runtime import Core
+from torchvision.models import resnet34
+
+
+def normalize(image: np.ndarray) -> np.ndarray:
+ """
+ Normalize the image to the given mean and standard deviation
+ """
+ image = image.astype(np.float32)
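+    # ImageNet channel statistics; they match the --mean_values/--scale_values
+    # baked into the IR model at conversion time (e.g. 123.675 / 255 = 0.485)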
+ mean = (0.485, 0.456, 0.406)
+ std = (0.229, 0.224, 0.225)
+ image /= 255.0
+ image -= mean
+ image /= std
+ return image
+
+
+def onnx_inference(onnx_path: str, image: np.ndarray, num_images: int = 20):
+ # load onnx model
+ ort_session = onnxruntime.InferenceSession(onnx_path)
+
+ # compute onnx Runtime output prediction
+ ort_inputs = {ort_session.get_inputs()[0].name: image}
+
+ start = time.perf_counter()
+ for _ in range(num_images):
+ ort_session.run(None, ort_inputs)
+ end = time.perf_counter()
+ time_onnx = end - start
+    print(
+        f"ONNX model in ONNX Runtime/CPU: {time_onnx / num_images:.3f} "
+        f"seconds per image, FPS: {num_images / time_onnx:.2f}"
+    )
+
+ return num_images / time_onnx
+
+
+def ir_inference(ir_path: str, image: np.ndarray, num_images: int = 20):
+ # Load the network in Inference Engine
+ ie = Core()
+ model_ir = ie.read_model(model=ir_path)
+ compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU")
+
+ # Get input and output layers
+ input_layer_ir = next(iter(compiled_model_ir.inputs))
+ output_layer_ir = next(iter(compiled_model_ir.outputs))
+
+ start = time.perf_counter()
+ request_ir = compiled_model_ir.create_infer_request()
+ for _ in range(num_images):
+ request_ir.infer(inputs={input_layer_ir.any_name: image})
+ end = time.perf_counter()
+ time_ir = end - start
+ print(
+ f"IR model in Inference Engine/CPU: {time_ir / num_images:.3f} "
+ f"seconds per image, FPS: {num_images / time_ir:.2f}"
+ )
+
+ return num_images / time_ir
+
+
+def pytorch_inference(image: np.ndarray, num_images: int = 20):
+ image = torch.as_tensor(image, dtype=torch.float32)
+
+ model = resnet34(pretrained=False, num_classes=5)
+ model.eval()
+
+ with torch.no_grad():
+ start = time.perf_counter()
+ for _ in range(num_images):
+ model(image)
+ end = time.perf_counter()
+ time_torch = end - start
+
+ print(
+ f"PyTorch model on CPU: {time_torch / num_images:.3f} seconds per image, "
+ f"FPS: {num_images / time_torch:.2f}"
+ )
+
+ return num_images / time_torch
+
+
+def plot_fps(v: dict):
+ x = list(v.keys())
+ y = list(v.values())
+
+ plt.bar(range(len(x)), y, align='center')
+ plt.xticks(range(len(x)), x)
+ for i, v in enumerate(y):
+ plt.text(x=i, y=v+0.5, s=f"{v:.2f}", ha='center')
+ plt.xlabel('model format')
+ plt.ylabel('fps')
+ plt.title('FPS comparison')
+    plt.savefig('fps_vs.jpg')
+    plt.show()
+
+
+def main():
+ image_h = 224
+ image_w = 224
+ onnx_path = "resnet34.onnx"
+ ir_path = "ir_output/resnet34.xml"
+
+ image = np.random.randn(image_h, image_w, 3)
+ normalized_image = normalize(image)
+
+ # Convert the resized images to network input shape
+ # [h, w, c] -> [c, h, w] -> [1, c, h, w]
+ input_image = np.expand_dims(np.transpose(image, (2, 0, 1)), 0)
+ normalized_input_image = np.expand_dims(np.transpose(normalized_image, (2, 0, 1)), 0)
+
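+    # The IR model receives the raw (un-normalized) image: the mean/scale normalization
+    # was folded into the IR by mo (--mean_values/--scale_values), whereas the ONNX and
+    # PyTorch models still expect the normalized input.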
+ onnx_fps = onnx_inference(onnx_path, normalized_input_image, num_images=100)
+ ir_fps = ir_inference(ir_path, input_image, num_images=100)
+ pytorch_fps = pytorch_inference(normalized_input_image, num_images=100)
+ plot_fps({"pytorch": round(pytorch_fps, 2),
+ "onnx": round(onnx_fps, 2),
+ "ir": round(ir_fps, 2)})
+
+
+if __name__ == '__main__':
+ main()
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/compare_onnx_and_ir.py b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/compare_onnx_and_ir.py
new file mode 100644
index 000000000..c8ac7f32e
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/compare_onnx_and_ir.py
@@ -0,0 +1,65 @@
+import numpy as np
+import onnxruntime
+from openvino.runtime import Core
+
+
+def normalize(image: np.ndarray) -> np.ndarray:
+ """
+ Normalize the image to the given mean and standard deviation
+ """
+ image = image.astype(np.float32)
+ mean = (0.485, 0.456, 0.406)
+ std = (0.229, 0.224, 0.225)
+ image /= 255.0
+ image -= mean
+ image /= std
+ return image
+
+
+def onnx_inference(onnx_path: str, image: np.ndarray):
+ # load onnx model
+ ort_session = onnxruntime.InferenceSession(onnx_path)
+
+ # compute onnx Runtime output prediction
+ ort_inputs = {ort_session.get_inputs()[0].name: image}
+ res_onnx = ort_session.run(None, ort_inputs)[0]
+ return res_onnx
+
+
+def ir_inference(ir_path: str, image: np.ndarray):
+ # Load the network in Inference Engine
+ ie = Core()
+ model_ir = ie.read_model(model=ir_path)
+ compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU")
+
+ # Get input and output layers
+ input_layer_ir = next(iter(compiled_model_ir.inputs))
+ output_layer_ir = next(iter(compiled_model_ir.outputs))
+
+ # Run inference on the input image
+ res_ir = compiled_model_ir([image])[output_layer_ir]
+ return res_ir
+
+
+def main():
+ image_h = 224
+ image_w = 224
+ onnx_path = "resnet34.onnx"
+ ir_path = "ir_output/resnet34.xml"
+
+ image = np.random.randn(image_h, image_w, 3)
+ normalized_image = normalize(image)
+
+ # Convert the resized images to network input shape
+ # [h, w, c] -> [c, h, w] -> [1, c, h, w]
+ input_image = np.expand_dims(np.transpose(image, (2, 0, 1)), 0)
+ normalized_input_image = np.expand_dims(np.transpose(normalized_image, (2, 0, 1)), 0)
+
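+    # note: the IR model takes the raw image, since mo already folded the normalization into it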
+ onnx_res = onnx_inference(onnx_path, normalized_input_image)
+ ir_res = ir_inference(ir_path, input_image)
+ np.testing.assert_allclose(onnx_res, ir_res, rtol=1e-03, atol=1e-05)
+ print("Exported model has been tested with OpenvinoRuntime, and the result looks good!")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/convert_pytorch2onnx.py b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/convert_pytorch2onnx.py
new file mode 100644
index 000000000..9fd00349a
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/convert_pytorch2onnx.py
@@ -0,0 +1,56 @@
+import torch
+import torch.onnx
+import onnx
+import onnxruntime
+import numpy as np
+from torchvision.models import resnet34
+
+device = torch.device("cpu")
+
+
+def to_numpy(tensor):
+ return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+
+def main():
+ weights_path = "resNet34(flower).pth"
+ onnx_file_name = "resnet34.onnx"
+ batch_size = 1
+ img_h = 224
+ img_w = 224
+ img_channel = 3
+
+    # create model and load pretrained weights
+ model = resnet34(pretrained=False, num_classes=5)
+ model.load_state_dict(torch.load(weights_path, map_location='cpu'))
+
+ model.eval()
+ # input to the model
+ # [batch, channel, height, width]
+ x = torch.rand(batch_size, img_channel, img_h, img_w, requires_grad=True)
+ torch_out = model(x)
+
+ # export the model
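+    # no dynamic_axes are specified, so the exported graph keeps a fixed input shape of
+    # [1, 3, 224, 224], the same shape later passed to mo via --input_shape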
+ torch.onnx.export(model, # model being run
+ x, # model input (or a tuple for multiple inputs)
+ onnx_file_name, # where to save the model (can be a file or file-like object)
+ verbose=False)
+
+ # check onnx model
+ onnx_model = onnx.load(onnx_file_name)
+ onnx.checker.check_model(onnx_model)
+
+ ort_session = onnxruntime.InferenceSession(onnx_file_name)
+
+ # compute ONNX Runtime output prediction
+ ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
+ ort_outs = ort_session.run(None, ort_inputs)
+
+ # compare ONNX Runtime and Pytorch results
+ # assert_allclose: Raises an AssertionError if two objects are not equal up to desired tolerance.
+ np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
+ print("Exported model has been tested with ONNXRuntime, and the result looks good!")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/model.py b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/model.py
new file mode 100644
index 000000000..c6faa981c
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/model.py
@@ -0,0 +1,302 @@
+from typing import Callable, List, Optional
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+from functools import partial
+
+
+def _make_divisible(ch, divisor=8, min_ch=None):
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ """
+ if min_ch is None:
+ min_ch = divisor
+ new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_ch < 0.9 * ch:
+ new_ch += divisor
+ return new_ch
+
+
+class ConvBNActivation(nn.Sequential):
+ def __init__(self,
+ in_planes: int,
+ out_planes: int,
+ kernel_size: int = 3,
+ stride: int = 1,
+ groups: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ activation_layer: Optional[Callable[..., nn.Module]] = None):
+ padding = (kernel_size - 1) // 2
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if activation_layer is None:
+ activation_layer = nn.ReLU6
+ super(ConvBNActivation, self).__init__(nn.Conv2d(in_channels=in_planes,
+ out_channels=out_planes,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ groups=groups,
+ bias=False),
+ norm_layer(out_planes),
+ activation_layer(inplace=True))
+
+
+class SqueezeExcitation(nn.Module):
+ def __init__(self, input_c: int, squeeze_factor: int = 4):
+ super(SqueezeExcitation, self).__init__()
+ squeeze_c = _make_divisible(input_c // squeeze_factor, 8)
+ self.fc1 = nn.Conv2d(input_c, squeeze_c, 1)
+ self.fc2 = nn.Conv2d(squeeze_c, input_c, 1)
+
+ def forward(self, x: Tensor) -> Tensor:
+ scale = F.adaptive_avg_pool2d(x, output_size=(1, 1))
+ scale = self.fc1(scale)
+ scale = F.relu(scale, inplace=True)
+ scale = self.fc2(scale)
+ scale = F.hardsigmoid(scale, inplace=True)
+ return scale * x
+
+
+class InvertedResidualConfig:
+ def __init__(self,
+ input_c: int,
+ kernel: int,
+ expanded_c: int,
+ out_c: int,
+ use_se: bool,
+ activation: str,
+ stride: int,
+ width_multi: float):
+ self.input_c = self.adjust_channels(input_c, width_multi)
+ self.kernel = kernel
+ self.expanded_c = self.adjust_channels(expanded_c, width_multi)
+ self.out_c = self.adjust_channels(out_c, width_multi)
+ self.use_se = use_se
+ self.use_hs = activation == "HS" # whether using h-swish activation
+ self.stride = stride
+
+ @staticmethod
+ def adjust_channels(channels: int, width_multi: float):
+ return _make_divisible(channels * width_multi, 8)
+
+
+class InvertedResidual(nn.Module):
+ def __init__(self,
+ cnf: InvertedResidualConfig,
+ norm_layer: Callable[..., nn.Module]):
+ super(InvertedResidual, self).__init__()
+
+ if cnf.stride not in [1, 2]:
+ raise ValueError("illegal stride value.")
+
+ self.use_res_connect = (cnf.stride == 1 and cnf.input_c == cnf.out_c)
+
+ layers: List[nn.Module] = []
+ activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU
+
+ # expand
+ if cnf.expanded_c != cnf.input_c:
+ layers.append(ConvBNActivation(cnf.input_c,
+ cnf.expanded_c,
+ kernel_size=1,
+ norm_layer=norm_layer,
+ activation_layer=activation_layer))
+
+ # depthwise
+ layers.append(ConvBNActivation(cnf.expanded_c,
+ cnf.expanded_c,
+ kernel_size=cnf.kernel,
+ stride=cnf.stride,
+ groups=cnf.expanded_c,
+ norm_layer=norm_layer,
+ activation_layer=activation_layer))
+
+ if cnf.use_se:
+ layers.append(SqueezeExcitation(cnf.expanded_c))
+
+ # project
+ layers.append(ConvBNActivation(cnf.expanded_c,
+ cnf.out_c,
+ kernel_size=1,
+ norm_layer=norm_layer,
+ activation_layer=nn.Identity))
+
+ self.block = nn.Sequential(*layers)
+ self.out_channels = cnf.out_c
+ self.is_strided = cnf.stride > 1
+
+ def forward(self, x: Tensor) -> Tensor:
+ result = self.block(x)
+ if self.use_res_connect:
+ result += x
+
+ return result
+
+
+class MobileNetV3(nn.Module):
+ def __init__(self,
+ inverted_residual_setting: List[InvertedResidualConfig],
+ last_channel: int,
+ num_classes: int = 1000,
+ block: Optional[Callable[..., nn.Module]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None):
+ super(MobileNetV3, self).__init__()
+
+ if not inverted_residual_setting:
+ raise ValueError("The inverted_residual_setting should not be empty.")
+ elif not (isinstance(inverted_residual_setting, List) and
+ all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])):
+ raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")
+
+ if block is None:
+ block = InvertedResidual
+
+ if norm_layer is None:
+ norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
+
+ layers: List[nn.Module] = []
+
+ # building first layer
+ firstconv_output_c = inverted_residual_setting[0].input_c
+ layers.append(ConvBNActivation(3,
+ firstconv_output_c,
+ kernel_size=3,
+ stride=2,
+ norm_layer=norm_layer,
+ activation_layer=nn.Hardswish))
+ # building inverted residual blocks
+ for cnf in inverted_residual_setting:
+ layers.append(block(cnf, norm_layer))
+
+ # building last several layers
+ lastconv_input_c = inverted_residual_setting[-1].out_c
+ lastconv_output_c = 6 * lastconv_input_c
+ layers.append(ConvBNActivation(lastconv_input_c,
+ lastconv_output_c,
+ kernel_size=1,
+ norm_layer=norm_layer,
+ activation_layer=nn.Hardswish))
+ self.features = nn.Sequential(*layers)
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
+ self.classifier = nn.Sequential(nn.Linear(lastconv_output_c, last_channel),
+ nn.Hardswish(inplace=True),
+ nn.Dropout(p=0.2, inplace=True),
+ nn.Linear(last_channel, num_classes))
+
+ # initial weights
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+
+ def _forward_impl(self, x: Tensor) -> Tensor:
+ x = self.features(x)
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+ x = self.classifier(x)
+
+ return x
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self._forward_impl(x)
+
+
+def mobilenet_v3_large(num_classes: int = 1000,
+ reduced_tail: bool = False) -> MobileNetV3:
+ """
+ Constructs a large MobileNetV3 architecture from
+ "Searching for MobileNetV3" .
+
+ weights_link:
+ https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth
+
+ Args:
+ num_classes (int): number of classes
+ reduced_tail (bool): If True, reduces the channel counts of all feature layers
+ between C4 and C5 by 2. It is used to reduce the channel redundancy in the
+ backbone for Detection and Segmentation.
+ """
+ width_multi = 1.0
+ bneck_conf = partial(InvertedResidualConfig, width_multi=width_multi)
+ adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_multi=width_multi)
+
+ reduce_divider = 2 if reduced_tail else 1
+
+ inverted_residual_setting = [
+ # input_c, kernel, expanded_c, out_c, use_se, activation, stride
+ bneck_conf(16, 3, 16, 16, False, "RE", 1),
+ bneck_conf(16, 3, 64, 24, False, "RE", 2), # C1
+ bneck_conf(24, 3, 72, 24, False, "RE", 1),
+ bneck_conf(24, 5, 72, 40, True, "RE", 2), # C2
+ bneck_conf(40, 5, 120, 40, True, "RE", 1),
+ bneck_conf(40, 5, 120, 40, True, "RE", 1),
+ bneck_conf(40, 3, 240, 80, False, "HS", 2), # C3
+ bneck_conf(80, 3, 200, 80, False, "HS", 1),
+ bneck_conf(80, 3, 184, 80, False, "HS", 1),
+ bneck_conf(80, 3, 184, 80, False, "HS", 1),
+ bneck_conf(80, 3, 480, 112, True, "HS", 1),
+ bneck_conf(112, 3, 672, 112, True, "HS", 1),
+ bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2), # C4
+ bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
+ bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
+ ]
+ last_channel = adjust_channels(1280 // reduce_divider) # C5
+
+ return MobileNetV3(inverted_residual_setting=inverted_residual_setting,
+ last_channel=last_channel,
+ num_classes=num_classes)
+
+
+def mobilenet_v3_small(num_classes: int = 1000,
+ reduced_tail: bool = False) -> MobileNetV3:
+ """
+    Constructs a small MobileNetV3 architecture from
+    "Searching for MobileNetV3".
+
+ weights_link:
+ https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth
+
+ Args:
+ num_classes (int): number of classes
+ reduced_tail (bool): If True, reduces the channel counts of all feature layers
+ between C4 and C5 by 2. It is used to reduce the channel redundancy in the
+ backbone for Detection and Segmentation.
+ """
+ width_multi = 1.0
+ bneck_conf = partial(InvertedResidualConfig, width_multi=width_multi)
+ adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_multi=width_multi)
+
+ reduce_divider = 2 if reduced_tail else 1
+
+ inverted_residual_setting = [
+ # input_c, kernel, expanded_c, out_c, use_se, activation, stride
+ bneck_conf(16, 3, 16, 16, True, "RE", 2), # C1
+ bneck_conf(16, 3, 72, 24, False, "RE", 2), # C2
+ bneck_conf(24, 3, 88, 24, False, "RE", 1),
+ bneck_conf(24, 5, 96, 40, True, "HS", 2), # C3
+ bneck_conf(40, 5, 240, 40, True, "HS", 1),
+ bneck_conf(40, 5, 240, 40, True, "HS", 1),
+ bneck_conf(40, 5, 120, 48, True, "HS", 1),
+ bneck_conf(48, 5, 144, 48, True, "HS", 1),
+ bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2), # C4
+ bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
+ bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1)
+ ]
+ last_channel = adjust_channels(1024 // reduce_divider) # C5
+
+ return MobileNetV3(inverted_residual_setting=inverted_residual_setting,
+ last_channel=last_channel,
+ num_classes=num_classes)
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/quantization_int8.py b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/quantization_int8.py
new file mode 100644
index 000000000..a6d663735
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/quantization_int8.py
@@ -0,0 +1,84 @@
+from addict import Dict
+from compression.engines.ie_engine import IEEngine
+from compression.graph import load_model, save_model
+from compression.graph.model_utils import compress_model_weights
+from compression.pipeline.initializer import create_pipeline
+from utils import MyDataLoader, Accuracy, read_split_data
+
+
+def main():
+ data_path = "/data/flower_photos"
+ ir_model_xml = "ir_output/resnet34.xml"
+ ir_model_bin = "ir_output/resnet34.bin"
+ save_dir = "quant_ir_output"
+ model_name = "quantized_resnet34"
+ img_w = 224
+ img_h = 224
+
+ model_config = Dict({
+ 'model_name': 'resnet34',
+ 'model': ir_model_xml,
+ 'weights': ir_model_bin
+ })
+ engine_config = Dict({
+ 'device': 'CPU',
+ 'stat_requests_number': 2,
+ 'eval_requests_number': 2
+ })
+ dataset_config = {
+ 'data_source': data_path
+ }
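+    # 'DefaultQuantization' with the 'performance' preset applies symmetric INT8 quantization;
+    # 'stat_subset_size' is the number of calibration samples used to collect activation statistics.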
+ algorithms = [
+ {
+ 'name': 'DefaultQuantization',
+ 'params': {
+ 'target_device': 'CPU',
+ 'preset': 'performance',
+ 'stat_subset_size': 300
+ }
+ }
+ ]
+
+ # Steps 1-7: Model optimization
+ # Step 1: Load the model.
+ model = load_model(model_config)
+
+ # Step 2: Initialize the data loader.
+ _, _, val_images_path, val_images_label = read_split_data(data_path, val_rate=0.2)
+ data_loader = MyDataLoader(dataset_config, val_images_path, val_images_label, img_w, img_h)
+
+ # Step 3 (Optional. Required for AccuracyAwareQuantization): Initialize the metric.
+ metric = Accuracy(top_k=1)
+
+ # Step 4: Initialize the engine for metric calculation and statistics collection.
+ engine = IEEngine(engine_config, data_loader, metric)
+
+ # Step 5: Create a pipeline of compression algorithms.
+ pipeline = create_pipeline(algorithms, engine)
+
+ # Step 6: Execute the pipeline.
+ compressed_model = pipeline.run(model)
+
+ # Step 7 (Optional): Compress model weights quantized precision
+ # in order to reduce the size of final .bin file.
+ compress_model_weights(compressed_model)
+
+ # Step 8: Save the compressed model to the desired path.
+ compressed_model_paths = save_model(model=compressed_model,
+ save_path=save_dir,
+ model_name=model_name)
+
+ # Step 9: Compare accuracy of the original and quantized models.
+ metric_results = pipeline.evaluate(model)
+ if metric_results:
+ for name, value in metric_results.items():
+ print(f"Accuracy of the original model: {name}: {value}")
+
+ metric_results = pipeline.evaluate(compressed_model)
+ if metric_results:
+ for name, value in metric_results.items():
+ print(f"Accuracy of the optimized model: {name}: {value}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/requirements.txt b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/requirements.txt
new file mode 100644
index 000000000..662c48d20
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/requirements.txt
@@ -0,0 +1,7 @@
+torch==1.11.0
+torchvision==0.12.0
+onnx==1.13.0
+onnxruntime==1.8.0
+protobuf==3.19.5
+openvino-dev==2022.1.0
+matplotlib
\ No newline at end of file
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/utils.py b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/utils.py
new file mode 100644
index 000000000..62d0ae03c
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_resnet34/utils.py
@@ -0,0 +1,137 @@
+import os
+import json
+import random
+
+from PIL import Image
+import numpy as np
+from compression.api import DataLoader, Metric
+from torchvision.transforms import transforms
+
+
+def read_split_data(root: str, val_rate: float = 0.2):
+    random.seed(0)  # make the random split reproducible
+    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
+
+    # traverse the folders; each folder corresponds to one class
+    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
+    # sort to keep the class order consistent
+    flower_class.sort()
+    # map each class name to a numeric index
+ class_indices = dict((k, v) for v, k in enumerate(flower_class))
+ json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
+ with open('class_indices.json', 'w') as json_file:
+ json_file.write(json_str)
+
+    train_images_path = []  # paths of all training images
+    train_images_label = []  # class indices of the training images
+    val_images_path = []  # paths of all validation images
+    val_images_label = []  # class indices of the validation images
+    every_class_num = []  # number of samples per class
+    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
+    # iterate over the files in each class folder
+ for cla in flower_class:
+ cla_path = os.path.join(root, cla)
+        # collect the paths of all files with a supported extension
+        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
+                  if os.path.splitext(i)[-1] in supported]
+        # numeric index of this class
+        image_class = class_indices[cla]
+        # record the number of samples for this class
+        every_class_num.append(len(images))
+        # randomly sample validation images at the given ratio
+        val_path = random.sample(images, k=int(len(images) * val_rate))
+
+        for img_path in images:
+            if img_path in val_path:  # sampled for validation, add to the validation set
+                val_images_path.append(img_path)
+                val_images_label.append(image_class)
+            else:  # otherwise add to the training set
+ train_images_path.append(img_path)
+ train_images_label.append(image_class)
+
+ print("{} images were found in the dataset.".format(sum(every_class_num)))
+ print("{} images for training.".format(len(train_images_path)))
+ print("{} images for validation.".format(len(val_images_path)))
+
+ return train_images_path, train_images_label, val_images_path, val_images_label
+
+
+# Custom implementation of classification accuracy metric.
+class Accuracy(Metric):
+ # Required methods
+ def __init__(self, top_k=1):
+ super().__init__()
+ self._top_k = top_k
+ self._name = 'accuracy@top{}'.format(self._top_k)
+ self._matches = []
+
+ @property
+ def value(self):
+ """ Returns accuracy metric value for the last model output. """
+ return {self._name: self._matches[-1]}
+
+ @property
+ def avg_value(self):
+ """ Returns accuracy metric value for all model outputs. """
+ return {self._name: np.ravel(self._matches).mean()}
+
+ def update(self, output, target):
+ """ Updates prediction matches.
+ :param output: model output
+ :param target: annotations
+ """
+ if len(output) > 1:
+ raise Exception('The accuracy metric cannot be calculated '
+ 'for a model with multiple outputs')
+ if isinstance(target, dict):
+ target = list(target.values())
+ predictions = np.argsort(output[0], axis=1)[:, -self._top_k:]
+ match = [float(t in predictions[i]) for i, t in enumerate(target)]
+
+ self._matches.append(match)
+
+ def reset(self):
+ """ Resets collected matches """
+ self._matches = []
+
+ def get_attributes(self):
+ """
+ Returns a dictionary of metric attributes {metric_name: {attribute_name: value}}.
+ Required attributes: 'direction': 'higher-better' or 'higher-worse'
+ 'type': metric type
+ """
+ return {self._name: {'direction': 'higher-better',
+ 'type': 'accuracy'}}
+
+
+class MyDataLoader(DataLoader):
+ def __init__(self, cfg, images_path: list, images_label: list, img_w: int = 224, img_h: int = 224):
+ super().__init__(cfg)
+ self.images_path = images_path
+ self.images_label = images_label
+ self.image_w = img_w
+ self.image_h = img_h
+ self.transforms = transforms.Compose([
+ transforms.Resize(min(img_h, img_w)),
+ transforms.CenterCrop((img_h, img_w))
+ ])
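+        # note: no ToTensor/Normalize here; the mean/scale normalization was embedded
+        # into the IR by mo (--mean_values/--scale_values), so raw resized pixels are fed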
+
+ def __len__(self):
+ return len(self.images_label)
+
+ def __getitem__(self, index):
+ """
+ Return one sample of index, label and picture.
+ :param index: index of the taken sample.
+ """
+ if index >= len(self):
+ raise IndexError
+
+ img = Image.open(self.images_path[index])
+ img = self.transforms(img)
+
+ # Convert the resized images to network input shape
+ # [h, w, c] -> [c, h, w] -> [1, c, h, w]
+ img = np.expand_dims(np.transpose(np.array(img), (2, 0, 1)), 0)
+
+ return (index, self.images_label[index]), img
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/README.md b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/README.md
new file mode 100644
index 000000000..682bb111f
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/README.md
@@ -0,0 +1,61 @@
+Quantizing YOLOv5 with OpenVINO
+
+1. Set up the environment according to `requirements.txt`
+2. Convert YOLOv5 to ONNX
+The official YOLOv5 repository can export to both ONNX and OpenVINO, but here only the ONNX export is used, with YOLOv5s as an example:
+```
+python export.py --weights yolov5s.pt --include onnx
+```
+
+3. Convert ONNX to IR
+Use OpenVINO's `mo` tool to convert the ONNX model to OpenVINO IR format:
+```
+mo --input_model yolov5s.onnx \
+ --input_shape "[1,3,640,640]" \
+ --scale 255 \
+ --data_type FP32 \
+ --output_dir ir_output
+```
+
+4. Quantize the model
+Use `quantization_int8.py` to quantize the model. Quantization requires the COCO2017 dataset; point `data_path` to the coco2017 directory, laid out as below. An example invocation is shown after the directory layout.
+```
+├── coco2017: dataset root directory
+    ├── train2017: all training images (118287 images)
+    ├── val2017: all validation images (5000 images)
+    └── annotations: annotation files
+        ├── instances_train2017.json: training-set annotations for detection/segmentation
+        ├── instances_val2017.json: validation-set annotations for detection/segmentation
+        ├── captions_train2017.json: training-set annotations for image captioning
+        ├── captions_val2017.json: validation-set annotations for image captioning
+        ├── person_keypoints_train2017.json: training-set annotations for keypoint detection
+        └── person_keypoints_val2017.json: validation-set annotations for keypoint detection
+```
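+
+For example, with `data_path` in the scripts pointing at the directory above (a minimal sketch):
+```
+python quantization_int8.py
+python evaluation.py    # optional: COCO mAP check of quant_ir_output/quantized_yolov5s.xml
+```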
+
+5. Benchmark
+Use the `benchmark_app` tool to measure the `Throughput` before and after quantization, here on a `CPU: Intel(R) Core(TM) i7-6700 CPU @ 3.40GHz` as an example:
+```
+benchmark_app -m ir_output/yolov5s.xml -d CPU -api sync
+```
+output:
+```
+Latency:
+ Median: 59.56 ms
+ AVG: 63.30 ms
+ MIN: 57.88 ms
+ MAX: 99.89 ms
+Throughput: 16.79 FPS
+```
+
+```
+benchmark_app -m quant_ir_output/quantized_yolov5s.xml -d CPU -api sync
+```
+output:
+```
+Latency:
+ Median: 42.97 ms
+ AVG: 46.56 ms
+ MIN: 41.18 ms
+ MAX: 95.75 ms
+Throughput: 23.27 FPS
+```
\ No newline at end of file
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/compare_fps.py b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/compare_fps.py
new file mode 100644
index 000000000..0a4abfd84
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/compare_fps.py
@@ -0,0 +1,121 @@
+import time
+import numpy as np
+import torch
+import onnxruntime
+import matplotlib.pyplot as plt
+from openvino.runtime import Core
+
+
+def normalize(image: np.ndarray) -> np.ndarray:
+ """
+    Scale pixel values to the [0, 1] range
+ """
+ image = image.astype(np.float32)
+ image /= 255.0
+ return image
+
+
+def onnx_inference(onnx_path: str, image: np.ndarray, num_images: int = 20):
+ # load onnx model
+ ort_session = onnxruntime.InferenceSession(onnx_path)
+
+ # compute onnx Runtime output prediction
+ ort_inputs = {ort_session.get_inputs()[0].name: image}
+
+ start = time.perf_counter()
+ for _ in range(num_images):
+ ort_session.run(None, ort_inputs)
+ end = time.perf_counter()
+ time_onnx = end - start
+ print(
+ f"ONNX model in Inference Engine/CPU: {time_onnx / num_images:.3f} "
+ f"seconds per image, FPS: {num_images / time_onnx:.2f}"
+ )
+
+ return num_images / time_onnx
+
+
+def ir_inference(ir_path: str, image: np.ndarray, num_images: int = 20):
+ # Load the network in Inference Engine
+ ie = Core()
+ model_ir = ie.read_model(model=ir_path)
+ compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU")
+
+ # Get input and output layers
+ input_layer_ir = next(iter(compiled_model_ir.inputs))
+ output_layer_ir = next(iter(compiled_model_ir.outputs))
+
+ start = time.perf_counter()
+ request_ir = compiled_model_ir.create_infer_request()
+ for _ in range(num_images):
+ request_ir.infer(inputs={input_layer_ir.any_name: image})
+ end = time.perf_counter()
+ time_ir = end - start
+ print(
+ f"IR model in Inference Engine/CPU: {time_ir / num_images:.3f} "
+ f"seconds per image, FPS: {num_images / time_ir:.2f}"
+ )
+
+ return num_images / time_ir
+
+
+def pytorch_inference(image: np.ndarray, num_images: int = 20):
+ image = torch.as_tensor(image, dtype=torch.float32)
+
+ model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
+ model.eval()
+
+ with torch.no_grad():
+ start = time.perf_counter()
+ for _ in range(num_images):
+ model(image)
+ end = time.perf_counter()
+ time_torch = end - start
+
+ print(
+ f"PyTorch model on CPU: {time_torch / num_images:.3f} seconds per image, "
+ f"FPS: {num_images / time_torch:.2f}"
+ )
+
+ return num_images / time_torch
+
+
+def plot_fps(v: dict):
+ x = list(v.keys())
+ y = list(v.values())
+
+ plt.bar(range(len(x)), y, align='center')
+ plt.xticks(range(len(x)), x)
+ for i, v in enumerate(y):
+ plt.text(x=i, y=v+0.5, s=f"{v:.2f}", ha='center')
+ plt.xlabel('model format')
+ plt.ylabel('fps')
+ plt.title('FPS comparison')
+    plt.savefig('fps_vs.jpg')
+    plt.show()
+
+
+def main():
+ image_h = 640
+ image_w = 640
+ onnx_path = "yolov5s.onnx"
+ ir_path = "ir_output/yolov5s.xml"
+
+ image = np.random.randn(image_h, image_w, 3)
+ normalized_image = normalize(image)
+
+ # Convert the resized images to network input shape
+ # [h, w, c] -> [c, h, w] -> [1, c, h, w]
+ input_image = np.expand_dims(np.transpose(image, (2, 0, 1)), 0)
+ normalized_input_image = np.expand_dims(np.transpose(normalized_image, (2, 0, 1)), 0)
+
+ onnx_fps = onnx_inference(onnx_path, normalized_input_image, num_images=100)
+ ir_fps = ir_inference(ir_path, input_image, num_images=100)
+ pytorch_fps = pytorch_inference(normalized_input_image, num_images=100)
+ plot_fps({"pytorch": round(pytorch_fps, 2),
+ "onnx": round(onnx_fps, 2),
+ "ir": round(ir_fps, 2)})
+
+
+if __name__ == '__main__':
+ main()
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/compare_onnx_and_ir.py b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/compare_onnx_and_ir.py
new file mode 100644
index 000000000..110f22e3c
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/compare_onnx_and_ir.py
@@ -0,0 +1,61 @@
+import numpy as np
+import onnxruntime
+from openvino.runtime import Core
+
+
+def normalize(image: np.ndarray) -> np.ndarray:
+ """
+    Scale pixel values to the [0, 1] range
+ """
+ image = image.astype(np.float32)
+ image /= 255.0
+ return image
+
+
+def onnx_inference(onnx_path: str, image: np.ndarray):
+ # load onnx model
+ ort_session = onnxruntime.InferenceSession(onnx_path)
+
+ # compute onnx Runtime output prediction
+ ort_inputs = {ort_session.get_inputs()[0].name: image}
+ res_onnx = ort_session.run(None, ort_inputs)[0]
+ return res_onnx
+
+
+def ir_inference(ir_path: str, image: np.ndarray):
+ # Load the network in Inference Engine
+ ie = Core()
+ model_ir = ie.read_model(model=ir_path)
+ compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU")
+
+ # Get input and output layers
+ input_layer_ir = next(iter(compiled_model_ir.inputs))
+ output_layer_ir = next(iter(compiled_model_ir.outputs))
+
+ # Run inference on the input image
+ res_ir = compiled_model_ir([image])[output_layer_ir]
+ return res_ir
+
+
+def main():
+ image_h = 640
+ image_w = 640
+ onnx_path = "yolov5s.onnx"
+ ir_path = "ir_output/yolov5s.xml"
+
+ image = np.random.randn(image_h, image_w, 3)
+ normalized_image = normalize(image)
+
+ # Convert the resized images to network input shape
+ # [h, w, c] -> [c, h, w] -> [1, c, h, w]
+ input_image = np.expand_dims(np.transpose(image, (2, 0, 1)), 0)
+ normalized_input_image = np.expand_dims(np.transpose(normalized_image, (2, 0, 1)), 0)
+
+ onnx_res = onnx_inference(onnx_path, normalized_input_image)
+ ir_res = ir_inference(ir_path, input_image)
+ np.testing.assert_allclose(onnx_res, ir_res, rtol=1e-03, atol=1e-05)
+ print("Exported model has been tested with OpenvinoRuntime, and the result looks good!")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/draw_box_utils.py b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/draw_box_utils.py
new file mode 100644
index 000000000..835d7f7c1
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/draw_box_utils.py
@@ -0,0 +1,153 @@
+from PIL.Image import Image, fromarray
+import PIL.ImageDraw as ImageDraw
+import PIL.ImageFont as ImageFont
+from PIL import ImageColor
+import numpy as np
+
+STANDARD_COLORS = [
+ 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
+ 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
+ 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
+ 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
+ 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
+ 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
+ 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
+ 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
+ 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
+ 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
+ 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
+ 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
+ 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
+ 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
+ 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
+ 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
+ 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
+ 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
+ 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
+ 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
+ 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
+ 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
+ 'WhiteSmoke', 'Yellow', 'YellowGreen'
+]
+
+
+def draw_text(draw,
+ box: list,
+ cls: int,
+ score: float,
+ category_index: dict,
+ color: str,
+ font: str = 'arial.ttf',
+ font_size: int = 24):
+ """
+    Draw the target's bounding box and class information onto the image.
+ """
+ try:
+ font = ImageFont.truetype(font, font_size)
+ except IOError:
+ font = ImageFont.load_default()
+
+ left, top, right, bottom = box
+ # If the total height of the display strings added to the top of the bounding
+ # box exceeds the top of the image, stack the strings below the bounding box
+ # instead of above.
+ display_str = f"{category_index[str(cls)]}: {int(100 * score)}%"
+ display_str_heights = [font.getsize(ds)[1] for ds in display_str]
+ # Each display_str has a top and bottom margin of 0.05x.
+ display_str_height = (1 + 2 * 0.05) * max(display_str_heights)
+
+ if top > display_str_height:
+ text_top = top - display_str_height
+ text_bottom = top
+ else:
+ text_top = bottom
+ text_bottom = bottom + display_str_height
+
+ for ds in display_str:
+ text_width, text_height = font.getsize(ds)
+ margin = np.ceil(0.05 * text_width)
+ draw.rectangle([(left, text_top),
+ (left + text_width + 2 * margin, text_bottom)], fill=color)
+ draw.text((left + margin, text_top),
+ ds,
+ fill='black',
+ font=font)
+ left += text_width
+
+
+def draw_masks(image, masks, colors, thresh: float = 0.7, alpha: float = 0.5):
+ np_image = np.array(image)
+ masks = np.where(masks > thresh, True, False)
+
+ # colors = np.array(colors)
+ img_to_draw = np.copy(np_image)
+ # TODO: There might be a way to vectorize this
+ for mask, color in zip(masks, colors):
+ img_to_draw[mask] = color
+
+ out = np_image * (1 - alpha) + img_to_draw * alpha
+ return fromarray(out.astype(np.uint8))
+
+
+def draw_objs(image: Image,
+ boxes: np.ndarray = None,
+ classes: np.ndarray = None,
+ scores: np.ndarray = None,
+ masks: np.ndarray = None,
+ category_index: dict = None,
+ box_thresh: float = 0.1,
+ mask_thresh: float = 0.5,
+ line_thickness: int = 8,
+ font: str = 'arial.ttf',
+ font_size: int = 24,
+ draw_boxes_on_image: bool = True,
+ draw_masks_on_image: bool = False):
+ """
+    Draw bounding boxes, class labels and masks onto the image.
+    Args:
+        image: image to draw on
+        boxes: bounding box coordinates
+        classes: class indices
+        scores: confidence scores
+        masks: segmentation masks
+        category_index: dict mapping class index to class name
+        box_thresh: score threshold used to filter detections
+        mask_thresh: threshold used to binarize masks
+        line_thickness: bounding box line width
+        font: font type
+        font_size: font size
+        draw_boxes_on_image: whether to draw boxes
+        draw_masks_on_image: whether to draw masks
+
+ Returns:
+
+ """
+
+    # filter out low-confidence detections
+ idxs = np.greater(scores, box_thresh)
+ boxes = boxes[idxs]
+ classes = classes[idxs]
+ scores = scores[idxs]
+ if masks is not None:
+ masks = masks[idxs]
+ if len(boxes) == 0:
+ return image
+
+ colors = [ImageColor.getrgb(STANDARD_COLORS[cls % len(STANDARD_COLORS)]) for cls in classes]
+
+ if draw_boxes_on_image:
+ # Draw all boxes onto image.
+ draw = ImageDraw.Draw(image)
+ for box, cls, score, color in zip(boxes, classes, scores, colors):
+ left, top, right, bottom = box
+            # draw the bounding box
+ draw.line([(left, top), (left, bottom), (right, bottom),
+ (right, top), (left, top)], width=line_thickness, fill=color)
+            # draw the class label and score
+ draw_text(draw, box.tolist(), int(cls), float(score), category_index, color, font, font_size)
+
+ if draw_masks_on_image and (masks is not None):
+ # Draw all mask onto image.
+ image = draw_masks(image, masks, colors, mask_thresh)
+
+ return image
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/evaluation.py b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/evaluation.py
new file mode 100644
index 000000000..96f1ada13
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/evaluation.py
@@ -0,0 +1,44 @@
+from tqdm import tqdm
+import torch
+from openvino.runtime import Core
+from utils import MyDataLoader, EvalCOCOMetric, non_max_suppression
+
+
+def main():
+ data_path = "/data/coco2017"
+ ir_model_xml = "quant_ir_output/quantized_yolov5s.xml"
+ img_size = (640, 640) # h, w
+
+ data_loader = MyDataLoader(data_path, "val", size=img_size)
+ coco80_to_91 = data_loader.coco_id80_to_id91
+ metrics = EvalCOCOMetric(coco=data_loader.coco, classes_mapping=coco80_to_91)
+
+ # Load the network in Inference Engine
+ ie = Core()
+ model_ir = ie.read_model(model=ir_model_xml)
+ compiled_model = ie.compile_model(model=model_ir, device_name="CPU")
+ inputs_names = compiled_model.inputs
+ outputs_names = compiled_model.outputs
+
+ # inference
+ request = compiled_model.create_infer_request()
+ for i in tqdm(range(len(data_loader))):
+ data = data_loader[i]
+ ann, img, info = data
+ ann = ann + (info,)
+
+ request.infer(inputs={inputs_names[0]: img})
+ result = request.get_output_tensor(outputs_names[0].index).data
+
+ # post-process
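+        # conf_thres=0.001 and iou_thres=0.6 match the defaults YOLOv5's val.py uses for mAP evaluation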
+ result = non_max_suppression(torch.Tensor(result), conf_thres=0.001, iou_thres=0.6, multi_label=True)[0]
+ boxes = result[:, :4].numpy()
+ scores = result[:, 4].numpy()
+ cls = result[:, 5].numpy().astype(int)
+ metrics.update(ann, [boxes, cls, scores])
+
+ metrics.evaluate()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/predict.py b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/predict.py
new file mode 100644
index 000000000..6f01b5709
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/predict.py
@@ -0,0 +1,50 @@
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+import matplotlib.pyplot as plt
+from openvino.runtime import Core
+from utils import letterbox, scale_coords, non_max_suppression, coco80_names
+from draw_box_utils import draw_objs
+
+
+def main():
+ img_path = "test.jpg"
+ ir_model_xml = "ir_output/yolov5s.xml"
+ img_size = (640, 640) # h, w
+
+ origin_img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
+ reshape_img, ratio, pad = letterbox(origin_img, img_size, auto=False)
+ input_img = np.expand_dims(np.transpose(reshape_img, [2, 0, 1]), 0).astype(np.float32)
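+    # pixel values stay in the 0-255 range: the IR applies the 1/255 scaling itself,
+    # since the model was converted with mo --scale 255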
+
+ # Load the network in Inference Engine
+ ie = Core()
+ model_ir = ie.read_model(model=ir_model_xml)
+ compiled_model = ie.compile_model(model=model_ir, device_name="CPU")
+ inputs_names = compiled_model.inputs
+ outputs_names = compiled_model.outputs
+
+ # inference
+ request = compiled_model.create_infer_request()
+ request.infer(inputs={inputs_names[0]: input_img})
+ result = request.get_output_tensor(outputs_names[0].index).data
+
+ # post-process
+ result = non_max_suppression(torch.Tensor(result))[0]
+ boxes = result[:, :4].numpy()
+ scores = result[:, 4].numpy()
+ cls = result[:, 5].numpy().astype(int)
+ boxes = scale_coords(reshape_img.shape, boxes, origin_img.shape, (ratio, pad))
+
+ draw_img = draw_objs(Image.fromarray(origin_img),
+ boxes,
+ cls,
+ scores,
+ category_index=dict([(str(i), v) for i, v in enumerate(coco80_names)]))
+ plt.imshow(draw_img)
+ plt.show()
+ draw_img.save("predict.jpg")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/quantization_int8.py b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/quantization_int8.py
new file mode 100644
index 000000000..d0decff0b
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/quantization_int8.py
@@ -0,0 +1,102 @@
+import time
+from addict import Dict
+from compression.engines.ie_engine import IEEngine
+from compression.graph import load_model, save_model
+from compression.graph.model_utils import compress_model_weights
+from compression.pipeline.initializer import create_pipeline
+from yaspin import yaspin
+from utils import MyDataLoader, MAPMetric
+
+
+def main():
+ data_path = "/data/coco2017"
+ ir_model_xml = "ir_output/yolov5s.xml"
+ ir_model_bin = "ir_output/yolov5s.bin"
+ save_dir = "quant_ir_output"
+ model_name = "quantized_yolov5s"
+ img_w = 640
+ img_h = 640
+
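+    # 'inputs'/'outputs' below are the tensor names of the exported YOLOv5s ONNX graph
+    # ('images' and 'output', as produced by the YOLOv5 export script)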
+ model_config = Dict({
+ 'model_name': 'yolov5s',
+ 'model': ir_model_xml,
+ 'weights': ir_model_bin,
+ 'inputs': 'images',
+ 'outputs': 'output'
+ })
+ engine_config = Dict({'device': 'CPU'})
+
+ algorithms = [
+ {
+ 'name': 'DefaultQuantization',
+ 'params': {
+ 'target_device': 'CPU',
+ 'preset': 'performance',
+ 'stat_subset_size': 300
+ }
+ }
+ ]
+
+ # Step 1: Load the model.
+ model = load_model(model_config)
+
+ # Step 2: Initialize the data loader.
+ data_loader = MyDataLoader(data_path, "val", (img_h, img_w))
+
+ # Step 3: initialize the metric
+ # For DefaultQuantization, specifying a metric is optional: metric can be set to None
+ metric = MAPMetric(map_value="map")
+
+ # Step 4: Initialize the engine for metric calculation and statistics collection.
+ engine = IEEngine(config=engine_config, data_loader=data_loader, metric=metric)
+
+ # Step 5: Create a pipeline of compression algorithms.
+ pipeline = create_pipeline(algorithms, engine)
+
+ # Step 6: Execute the pipeline to quantize the model
+ algorithm_name = pipeline.algo_seq[0].name
+ with yaspin(
+ text=f"Executing POT pipeline on {model_config['model']} with {algorithm_name}"
+ ) as sp:
+ start_time = time.perf_counter()
+ compressed_model = pipeline.run(model)
+ end_time = time.perf_counter()
+ sp.ok("✔")
+ print(f"Quantization finished in {end_time - start_time:.2f} seconds")
+
+ # Step 7 (Optional): Compress model weights to quantized precision
+ # in order to reduce the size of the final .bin file
+ compress_model_weights(compressed_model)
+
+ # Step 8: Save the compressed model to the desired path.
+ # Set save_path to the directory where the compressed model should be stored
+ compressed_model_paths = save_model(
+ model=compressed_model,
+ save_path=save_dir,
+ model_name=model_name,
+ )
+
+ compressed_model_path = compressed_model_paths[0]["model"]
+ print("The quantized model is stored at", compressed_model_path)
+
+ # Compute the mAP on the quantized model and compare with the mAP on the FP16 IR model.
+ ir_model = load_model(model_config=model_config)
+ evaluation_pipeline = create_pipeline(algo_config=dict(), engine=engine)
+
+ with yaspin(text="Evaluating original IR model") as sp:
+ original_metric = evaluation_pipeline.evaluate(ir_model)
+
+ if original_metric:
+ for key, value in original_metric.items():
+ print(f"The {key} score of the original model is {value:.5f}")
+
+ with yaspin(text="Evaluating quantized IR model") as sp:
+ quantized_metric = pipeline.evaluate(compressed_model)
+
+ if quantized_metric:
+ for key, value in quantized_metric.items():
+ print(f"The {key} score of the quantized INT8 model is {value:.5f}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/requirements.txt b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/requirements.txt
new file mode 100644
index 000000000..30b17622c
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/requirements.txt
@@ -0,0 +1,8 @@
+torch==1.13.1
+torchvision==0.14.1
+onnx==1.13.0
+onnxruntime==1.8.0
+protobuf==3.19.5
+openvino-dev==2022.1.0
+matplotlib
+torchmetrics==0.9.1
\ No newline at end of file
diff --git a/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/utils.py b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/utils.py
new file mode 100644
index 000000000..e3bcf6d4a
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_openvino/convert_yolov5/utils.py
@@ -0,0 +1,552 @@
+import os
+import time
+import json
+import copy
+
+import cv2
+import numpy as np
+import torch
+from torchmetrics.detection.mean_ap import MeanAveragePrecision
+import torchvision
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+from compression.api import DataLoader, Metric
+
+
+coco80_names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
+ 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
+ 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+ 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
+ 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+ 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+ 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
+ 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
+ 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
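+# note: the list above is the 80-class YOLOv5 label set; MyDataLoader below maps these
+# indices to the 91-id category scheme used in the COCO annotation files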
+
+
+def box_iou(box1, box2):
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+ """
+ Return intersection-over-union (Jaccard index) of boxes.
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+ Arguments:
+ box1 (Tensor[N, 4])
+ box2 (Tensor[M, 4])
+ Returns:
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
+ IoU values for every element in boxes1 and boxes2
+ """
+
+ def box_area(box):
+ # box = 4xn
+ return (box[2] - box[0]) * (box[3] - box[1])
+
+ area1 = box_area(box1.T)
+ area2 = box_area(box2.T)
+
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
+ return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
+
+
+def xywh2xyxy(x):
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
+ return y
+
+
+def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
+ labels=(), max_det=300):
+ """Runs Non-Maximum Suppression (NMS) on inference results
+
+ Returns:
+ list of detections, on (n,6) tensor per image [xyxy, conf, cls]
+ """
+
+ nc = prediction.shape[2] - 5 # number of classes
+ xc = prediction[..., 4] > conf_thres # candidates
+
+ # Checks
+ assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
+ assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
+
+ # Settings
+ min_wh, max_wh = 2, 7680 # (pixels) minimum and maximum box width and height
+ max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
+ time_limit = 10.0 # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ t = time.time()
+ output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
+ for xi, x in enumerate(prediction): # image index, image inference
+ # Apply constraints
+ x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
+ x = x[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ lb = labels[xi]
+ v = torch.zeros((len(lb), nc + 5), device=x.device)
+ v[:, :4] = lb[:, 1:5] # box
+ v[:, 4] = 1.0 # conf
+ v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Compute conf
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
+
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+ box = xywh2xyxy(x[:, :4])
+
+ # Detections matrix nx6 (xyxy, conf, cls)
+ if multi_label:
+ i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
+ else: # best class only
+ conf, j = x[:, 5:].max(1, keepdim=True)
+ x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # Apply finite constraint
+ # if not torch.isfinite(x).all():
+ # x = x[torch.isfinite(x).all(1)]
+
+ # Check shape
+ n = x.shape[0] # number of boxes
+ if not n: # no boxes
+ continue
+ elif n > max_nms: # excess boxes
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
+
+ # Batched NMS
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+ if i.shape[0] > max_det: # limit detections
+ i = i[:max_det]
+ if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if (time.time() - t) > time_limit:
+ print(f'WARNING: NMS time limit {time_limit}s exceeded')
+ break # time limit exceeded
+
+ return output
+
+
+class MAPMetric(Metric):
+ def __init__(self, map_value="map", conf_thres=0.001, iou_thres=0.6):
+ """
+ Mean Average Precision Metric. Wraps torchmetrics implementation, see
+ https://torchmetrics.readthedocs.io/en/stable/detection/mean_average_precision.html
+
+ :map_value: specific metric to return. Default: "map"
+ Change `to one of the values in the list below to return a different value
+ ['mar_1', 'mar_10', 'mar_100', 'mar_small', 'mar_medium', 'mar_large',
+ 'map', 'map_50', 'map_75', 'map_small', 'map_medium', 'map_large']
+ See torchmetrics documentation for more details.
+ """
+
+ self._name = map_value
+ self.metric = MeanAveragePrecision(box_format="xyxy")
+ self.conf_thres = conf_thres
+ self.iou_thres = iou_thres
+ super().__init__()
+
+ @property
+ def value(self):
+ """
+ Returns metric value for the last model output.
+ Possible format: {metric_name: [metric_values_per_image]}
+ """
+ return {self._name: [0]}
+
+ @property
+ def avg_value(self):
+ """
+ Returns average metric value for all model outputs.
+ Possible format: {metric_name: metric_value}
+ """
+ return {self._name: self.metric.compute()[self._name].item()}
+
+ def update(self, output, target):
+ """
+ Convert network output and labels to the format that torchmetrics' MAP
+ implementation expects, and call `metric.update()`.
+
+ :param output: model output
+ :param target: annotations for model output
+ """
+ targetboxes = []
+ targetlabels = []
+ predboxes = []
+ predlabels = []
+ scores = []
+
+ for single_target in target[0]:
+ txmin, tymin, txmax, tymax = single_target["bbox"]
+ category = single_target["category_id"]
+
+ targetbox = [round(txmin), round(tymin), round(txmax), round(tymax)]
+ targetboxes.append(targetbox)
+ targetlabels.append(category)
+
+ output = torch.Tensor(output[0]).float()
+ output = non_max_suppression(output, conf_thres=self.conf_thres, iou_thres=self.iou_thres, multi_label=True)
+ for single_output in output:
+ for pred in single_output.numpy():
+ xmin, ymin, xmax, ymax, conf, label = pred
+
+ predbox = [round(xmin), round(ymin), round(xmax), round(ymax)]
+ predboxes.append(predbox)
+ predlabels.append(label)
+ scores.append(conf)
+
+ preds = [
+ dict(
+ boxes=torch.Tensor(predboxes).float(),
+ labels=torch.Tensor(predlabels).short(),
+ scores=torch.Tensor(scores),
+ )
+ ]
+ targets = [
+ dict(
+ boxes=torch.Tensor(targetboxes).float(),
+ labels=torch.Tensor(targetlabels).short(),
+ )
+ ]
+ self.metric.update(preds, targets)
+
+ def reset(self):
+ """
+ Resets metric
+ """
+ self.metric.reset()
+
+ def get_attributes(self):
+ """
+ Returns a dictionary of metric attributes {metric_name: {attribute_name: value}}.
+ Required attributes: 'direction': 'higher-better' or 'higher-worse'
+ 'type': metric type
+ """
+ return {self._name: {"direction": "higher-better", "type": "mAP"}}
+
+
+def _coco_remove_images_without_annotations(dataset, ids):
+ """
+    Remove images from the COCO dataset that contain no objects or only extremely small objects.
+ refer to:
+ https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py
+ :param dataset:
+ :param cat_list:
+ :return:
+ """
+ def _has_only_empty_bbox(anno):
+ return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
+
+ def _has_valid_annotation(anno):
+ # if it's empty, there is no annotation
+ if len(anno) == 0:
+ return False
+ # if all boxes have close to zero area, there is no annotation
+ if _has_only_empty_bbox(anno):
+ return False
+
+ return True
+
+ valid_ids = []
+ for ds_idx, img_id in enumerate(ids):
+ ann_ids = dataset.getAnnIds(imgIds=img_id, iscrowd=None)
+ anno = dataset.loadAnns(ann_ids)
+
+ if _has_valid_annotation(anno):
+ valid_ids.append(img_id)
+
+ return valid_ids
+
+
+def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ assert ratio_pad[0][0] == ratio_pad[0][1]
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
+ coords[:, [0, 2]] -= pad[0] # x padding
+ coords[:, [1, 3]] -= pad[1] # y padding
+ coords[:, :4] /= gain
+ clip_coords(coords, img0_shape)
+ return coords
+
+
+def clip_coords(boxes, shape):
+ # Clip bounding xyxy bounding boxes to image shape (height, width)
+ if isinstance(boxes, torch.Tensor): # faster individually
+ boxes[:, 0].clamp_(0, shape[1]) # x1
+ boxes[:, 1].clamp_(0, shape[0]) # y1
+ boxes[:, 2].clamp_(0, shape[1]) # x2
+ boxes[:, 3].clamp_(0, shape[0]) # y2
+ else: # np.array (faster grouped)
+ boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
+ boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
+
+
+def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
+ # Resize and pad image while meeting stride-multiple constraints
+ shape = im.shape[:2] # current shape [height, width]
+ if isinstance(new_shape, int):
+ new_shape = (new_shape, new_shape)
+
+ # Scale ratio (new / old)
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
+ r = min(r, 1.0)
+
+ # Compute padding
+ ratio = r, r # width, height ratios
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
+ if auto: # minimum rectangle
+ dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
+ elif scaleFill: # stretch
+ dw, dh = 0.0, 0.0
+ new_unpad = (new_shape[1], new_shape[0])
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
+
+ dw /= 2 # divide padding into 2 sides
+ dh /= 2
+
+ if shape[::-1] != new_unpad: # resize
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
+ return im, ratio, (left, top)
+
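+# Worked example (illustrative, added comment): letterboxing a 480x640 (HxW) image to
+# new_shape=(640, 640) with auto=False returns ratio=(1.0, 1.0) and pad=(left, top)=(0, 80),
+# i.e. an 80-pixel gray border above and below. scale_coords() above undoes exactly this
+# transform by subtracting the padding and dividing by the gain, mapping predicted boxes
+# back to the original image.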
+
+class MyDataLoader(DataLoader):
+ """MS COCO Detection dataset wrapper.
+
+ Args:
+ root (string): Root directory where images are downloaded to.
+ dataset (string): "train" or "val".
+ size (tuple): target (h, w) after letterboxing.
+ """
+ def __init__(self, root, dataset="train", size=(640, 640)):
+ assert dataset in ["train", "val"], 'dataset must be in ["train", "val"]'
+ anno_file = "instances_{}2017.json".format(dataset)
+ assert os.path.exists(root), "file '{}' does not exist.".format(root)
+ self.img_root = os.path.join(root, "{}2017".format(dataset))
+ assert os.path.exists(self.img_root), "path '{}' does not exist.".format(self.img_root)
+ self.anno_path = os.path.join(root, "annotations", anno_file)
+ assert os.path.exists(self.anno_path), "file '{}' does not exist.".format(self.anno_path)
+
+ self.mode = dataset
+ self.size = size
+ self.coco = COCO(self.anno_path)
+
+ self.coco91_id2classes = dict([(v["id"], v["name"]) for k, v in self.coco.cats.items()])
+ coco90_classes2id = dict([(v["name"], v["id"]) for k, v in self.coco.cats.items()])
+
+ self.coco80_classes = coco80_names
+ self.coco_id80_to_id91 = dict([(i, coco90_classes2id[k]) for i, k in enumerate(coco80_names)])
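+ # Note (added comment): COCO annotations use sparse category ids in 1-90, while the
+ # detector predicts contiguous indices 0-79 (the order of coco80_names);
+ # coco_id80_to_id91 maps a predicted index back to the original COCO category id.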
+
+ ids = list(sorted(self.coco.imgs.keys()))
+
+ # remove images that have no objects or whose objects have near-zero area
+ valid_ids = _coco_remove_images_without_annotations(self.coco, ids)
+ self.ids = valid_ids
+
+ def parse_targets(self,
+ coco_targets: list,
+ w: int = None,
+ h: int = None,
+ ratio: tuple = None,
+ pad: tuple = None):
+ assert w > 0
+ assert h > 0
+
+ # keep only non-crowd objects (iscrowd == 0)
+ anno = [obj for obj in coco_targets if obj['iscrowd'] == 0]
+
+ boxes = [obj["bbox"] for obj in anno]
+
+ # guard against no boxes via resizing
+ boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
+ # [xmin, ymin, w, h] -> [xmin, ymin, xmax, ymax]
+ boxes[:, 2:] += boxes[:, :2]
+ boxes[:, 0::2] = np.clip(boxes[:, 0::2], a_min=0, a_max=w)
+ boxes[:, 1::2] = np.clip(boxes[:, 1::2], a_min=0, a_max=h)
+
+ classes = [self.coco80_classes.index(self.coco91_id2classes[obj["category_id"]])
+ for obj in anno]
+ classes = np.array(classes, dtype=int)
+
+ # keep only valid boxes, i.e. x_max > x_min and y_max > y_min
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+ boxes = boxes[keep]
+ classes = classes[keep]
+
+ if ratio is not None:
+ # width, height ratios
+ boxes[:, 0::2] *= ratio[0]
+ boxes[:, 1::2] *= ratio[1]
+
+ if pad is not None:
+ # dw, dh padding
+ dw, dh = pad
+ boxes[:, 0::2] += dw
+ boxes[:, 1::2] += dh
+
+ target_annotations = []
+ for i in range(boxes.shape[0]):
+ target_annotation = {
+ "category_id": int(classes[i]),
+ "bbox": boxes[i].tolist()
+ }
+ target_annotations.append(target_annotation)
+
+ return target_annotations
+
+ def __getitem__(self, index):
+ """
+ Get an item from the dataset at the specified index.
+ The image is letterboxed to self.size, and the ground-truth boxes are scaled and padded
+ accordingly, so they are absolute pixel coordinates in the letterboxed image.
+
+ :return: (annotation, input_image, metadata) where annotation is (index, target_annotations),
+ each target_annotation being a dictionary with keys "category_id" and "bbox"
+ ([xmin, ymin, xmax, ymax] in pixels), and metadata a dictionary with keys
+ "filename", "origin_shape", "shape", "img_id" and "ratio_pad"
+ """
+ coco = self.coco
+ img_id = self.ids[index]
+ ann_ids = coco.getAnnIds(imgIds=img_id)
+ coco_target = coco.loadAnns(ann_ids)
+
+ image_path = coco.loadImgs(img_id)[0]['file_name']
+ img = cv2.imread(os.path.join(self.img_root, image_path))
+
+ origin_h, origin_w, c = img.shape
+ image, ratio, pad = letterbox(img, auto=False, new_shape=self.size)
+ target_annotations = self.parse_targets(coco_target, origin_w, origin_h, ratio, pad)
+
+ item_annotation = (index, target_annotations)
+ input_image = np.expand_dims(image.transpose(2, 0, 1), axis=0).astype(
+ np.float32
+ )
+ return (
+ item_annotation,
+ input_image,
+ {"filename": str(image_path),
+ "origin_shape": img.shape,
+ "shape": image.shape,
+ "img_id": img_id,
+ "ratio_pad": [ratio, pad]},
+ )
+
+ def __len__(self):
+ return len(self.ids)
+
+ @staticmethod
+ def collate_fn(x):
+ return x
+
+
+class EvalCOCOMetric:
+ def __init__(self,
+ coco: COCO = None,
+ iou_type: str = "bbox",
+ results_file_name: str = "predict_results.json",
+ classes_mapping: dict = None):
+ self.coco = copy.deepcopy(coco)
+ self.results = []
+ self.classes_mapping = classes_mapping
+ self.coco_evaluator = None
+ assert iou_type in ["bbox"]
+ self.iou_type = iou_type
+ self.results_file_name = results_file_name
+
+ def prepare_for_coco_detection(self, ann, output):
+ """Convert the predictions of one image into the format required by COCOeval (object detection)."""
+ # predictions of a single image; skip if there are none
+ if len(output[0]) == 0:
+ return
+
+ img_id = ann[2]["img_id"]
+ per_image_boxes = output[0]
+ per_image_boxes = scale_coords(img1_shape=ann[2]["shape"],
+ coords=per_image_boxes,
+ img0_shape=ann[2]["origin_shape"],
+ ratio_pad=ann[2]["ratio_pad"])
+ # coco_eval expects each box as [x_min, y_min, w, h],
+ # while the predicted boxes are [x_min, y_min, x_max, y_max], so convert the format
+ per_image_boxes[:, 2:] -= per_image_boxes[:, :2]
+ per_image_classes = output[1].tolist()
+ per_image_scores = output[2].tolist()
+
+ # iterate over every detected object
+ for object_score, object_class, object_box in zip(
+ per_image_scores, per_image_classes, per_image_boxes):
+ object_score = float(object_score)
+ class_idx = int(object_class)
+ if self.classes_mapping is not None:
+ class_idx = self.classes_mapping[class_idx]
+ # We recommend rounding coordinates to the nearest tenth of a pixel
+ # to reduce resulting JSON file size.
+ object_box = [round(b, 2) for b in object_box.tolist()]
+
+ res = {"image_id": img_id,
+ "category_id": class_idx,
+ "bbox": object_box,
+ "score": round(object_score, 3)}
+ self.results.append(res)
+
+ def update(self, targets, outputs):
+ if self.iou_type == "bbox":
+ self.prepare_for_coco_detection(targets, outputs)
+ else:
+ raise KeyError(f"unsupported iou_type: {self.iou_type}")
+
+ def evaluate(self):
+ # write predict results into json file
+ json_str = json.dumps(self.results, indent=4)
+ with open(self.results_file_name, 'w') as json_file:
+ json_file.write(json_str)
+
+ # accumulate predictions from all images
+ coco_true = self.coco
+ coco_pre = coco_true.loadRes(self.results_file_name)
+
+ self.coco_evaluator = COCOeval(cocoGt=coco_true, cocoDt=coco_pre, iouType=self.iou_type)
+
+ self.coco_evaluator.evaluate()
+ self.coco_evaluator.accumulate()
+ print(f"IoU metric: {self.iou_type}")
+ self.coco_evaluator.summarize()
+
+ coco_info = self.coco_evaluator.stats.tolist() # numpy to list
+ return coco_info
+
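+
+
+# Minimal usage sketch (illustrative only, not part of the original file). It assumes a
+# hypothetical `run_model` callable that returns the detections of one image as a tuple
+# (boxes[N, 4] as an xyxy tensor on the letterboxed input, labels[N], scores[N]);
+# MyDataLoader and EvalCOCOMetric are the classes defined above.
+#
+# def evaluate_on_coco(run_model, coco_root):
+#     dataset = MyDataLoader(coco_root, dataset="val", size=(640, 640))
+#     metric = EvalCOCOMetric(coco=dataset.coco, iou_type="bbox",
+#                             classes_mapping=dataset.coco_id80_to_id91)
+#     for i in range(len(dataset)):
+#         item = dataset[i]             # (annotation, input_image, metadata)
+#         outputs = run_model(item[1])  # hypothetical model call
+#         metric.update(item, outputs)
+#     return metric.evaluate()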
diff --git a/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/compare_onnx_and_trt.py b/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/compare_onnx_and_trt.py
new file mode 100644
index 000000000..a9293236b
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/compare_onnx_and_trt.py
@@ -0,0 +1,90 @@
+import numpy as np
+import tensorrt as trt
+import onnxruntime
+import pycuda.driver as cuda
+import pycuda.autoinit
+
+
+def normalize(image: np.ndarray) -> np.ndarray:
+ """
+ Normalize the image with the ImageNet mean and standard deviation
+ """
+ image = image.astype(np.float32)
+ mean = (0.485, 0.456, 0.406)
+ std = (0.229, 0.224, 0.225)
+ image /= 255.0
+ image -= mean
+ image /= std
+ return image
+
+
+def onnx_inference(onnx_path: str, image: np.ndarray):
+ # load onnx model
+ ort_session = onnxruntime.InferenceSession(onnx_path)
+
+ # compute onnx Runtime output prediction
+ ort_inputs = {ort_session.get_inputs()[0].name: image}
+ res_onnx = ort_session.run(None, ort_inputs)[0]
+ return res_onnx
+
+
+def trt_inference(trt_path: str, image: np.ndarray):
+ # Load the serialized TensorRT engine
+ trt_logger = trt.Logger(trt.Logger.WARNING)
+ with open(trt_path, "rb") as f, trt.Runtime(trt_logger) as runtime:
+ engine = runtime.deserialize_cuda_engine(f.read())
+
+ with engine.create_execution_context() as context:
+ # Set input shape based on image dimensions for inference
+ context.set_binding_shape(engine.get_binding_index("input"), (1, 3, image.shape[-2], image.shape[-1]))
+ # Allocate host and device buffers
+ bindings = []
+ for binding in engine:
+ binding_idx = engine.get_binding_index(binding)
+ size = trt.volume(context.get_binding_shape(binding_idx))
+ dtype = trt.nptype(engine.get_binding_dtype(binding))
+ if engine.binding_is_input(binding):
+ input_buffer = np.ascontiguousarray(image)
+ input_memory = cuda.mem_alloc(image.nbytes)
+ bindings.append(int(input_memory))
+ else:
+ output_buffer = cuda.pagelocked_empty(size, dtype)
+ output_memory = cuda.mem_alloc(output_buffer.nbytes)
+ bindings.append(int(output_memory))
+
+ stream = cuda.Stream()
+ # Transfer input data to the GPU.
+ cuda.memcpy_htod_async(input_memory, input_buffer, stream)
+ # Run inference
+ context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
+ # Transfer prediction output from the GPU.
+ cuda.memcpy_dtoh_async(output_buffer, output_memory, stream)
+ # Synchronize the stream
+ stream.synchronize()
+
+ res_trt = np.reshape(output_buffer, (1, -1))
+
+ return res_trt
+
+
+def main():
+ image_h = 224
+ image_w = 224
+ onnx_path = "resnet34.onnx"
+ trt_path = "trt_output/resnet34.trt"
+
+ image = np.random.randn(image_h, image_w, 3)
+ normalized_image = normalize(image)
+
+ # Convert the resized images to network input shape
+ # [h, w, c] -> [c, h, w] -> [1, c, h, w]
+ normalized_image = np.expand_dims(np.transpose(normalized_image, (2, 0, 1)), 0)
+
+ onnx_res = onnx_inference(onnx_path, normalized_image)
+ ir_res = trt_inference(trt_path, normalized_image)
+ np.testing.assert_allclose(onnx_res, ir_res, rtol=1e-03, atol=1e-05)
+ print("Exported model has been tested with TensorRT Runtime, and the result looks good!")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/convert_pytorch2onnx.py b/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/convert_pytorch2onnx.py
new file mode 100644
index 000000000..7dec1c402
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/convert_pytorch2onnx.py
@@ -0,0 +1,58 @@
+import torch
+import torch.onnx
+import onnx
+import onnxruntime
+import numpy as np
+from torchvision.models import resnet34
+
+device = torch.device("cpu")
+
+
+def to_numpy(tensor):
+ return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+
+def main():
+ weights_path = "resNet34(flower).pth"
+ onnx_file_name = "resnet34.onnx"
+ batch_size = 1
+ img_h = 224
+ img_w = 224
+ img_channel = 3
+
+ # create model and load pretrain weights
+ model = resnet34(pretrained=False, num_classes=5)
+ model.load_state_dict(torch.load(weights_path, map_location='cpu'))
+
+ model.eval()
+ # input to the model
+ # [batch, channel, height, width]
+ x = torch.rand(batch_size, img_channel, img_h, img_w, requires_grad=True)
+ torch_out = model(x)
+
+ # export the model
+ torch.onnx.export(model, # model being run
+ x, # model input (or a tuple for multiple inputs)
+ onnx_file_name, # where to save the model (can be a file or file-like object)
+ input_names=["input"],
+ output_names=["output"],
+ verbose=False)
+
+ # check onnx model
+ onnx_model = onnx.load(onnx_file_name)
+ onnx.checker.check_model(onnx_model)
+
+ ort_session = onnxruntime.InferenceSession(onnx_file_name)
+
+ # compute ONNX Runtime output prediction
+ ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
+ ort_outs = ort_session.run(None, ort_inputs)
+
+ # compare ONNX Runtime and Pytorch results
+ # assert_allclose: Raises an AssertionError if two objects are not equal up to desired tolerance.
+ np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
+ print("Exported model has been tested with ONNXRuntime, and the result looks good!")
+
+
+if __name__ == '__main__':
+ main()
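+
+
+# Optional variation (illustrative sketch, not part of the original script): torch.onnx.export
+# also accepts a dynamic_axes argument when the exported model should allow a variable batch
+# size at inference time, e.g.
+#
+# torch.onnx.export(model, x, onnx_file_name,
+#                   input_names=["input"], output_names=["output"],
+#                   dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})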
diff --git a/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/my_dataset.py b/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/my_dataset.py
new file mode 100644
index 000000000..167bc9a30
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/my_dataset.py
@@ -0,0 +1,37 @@
+from PIL import Image
+import torch
+from torch.utils.data import Dataset
+
+
+class MyDataSet(Dataset):
+ """Custom dataset"""
+
+ def __init__(self, images_path: list, images_class: list, transform=None):
+ self.images_path = images_path
+ self.images_class = images_class
+ self.transform = transform
+
+ def __len__(self):
+ return len(self.images_path)
+
+ def __getitem__(self, item):
+ img = Image.open(self.images_path[item])
+ # mode 'RGB' means a color image, 'L' means a grayscale image
+ if img.mode != 'RGB':
+ raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
+ label = self.images_class[item]
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ return img, label
+
+ @staticmethod
+ def collate_fn(batch):
+ # the official default_collate implementation can be found at
+ # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
+ images, labels = tuple(zip(*batch))
+
+ images = torch.stack(images, dim=0)
+ labels = torch.as_tensor(labels)
+ return images, labels
diff --git a/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/quantization.py b/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/quantization.py
new file mode 100644
index 000000000..6516ed744
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/quantization.py
@@ -0,0 +1,196 @@
+"""
+refer to:
+https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/userguide.html
+"""
+import os
+import math
+import argparse
+
+from absl import logging
+from tqdm import tqdm
+import torch
+import torch.optim as optim
+import torch.optim.lr_scheduler as lr_scheduler
+from torchvision import transforms
+from torchvision.models.resnet import resnet34 as create_model
+from pytorch_quantization import nn as quant_nn
+from pytorch_quantization import quant_modules, calib
+from pytorch_quantization.tensor_quant import QuantDescriptor
+
+from my_dataset import MyDataSet
+from utils import read_split_data, train_one_epoch, evaluate
+
+logging.set_verbosity(logging.FATAL)
+
+
+def export_onnx(model, onnx_filename, onnx_bs):
+ model.eval()
+ # We have to shift to pytorch's fake quant ops before exporting the model to ONNX
+ quant_nn.TensorQuantizer.use_fb_fake_quant = True
+ opset_version = 13
+
+ print(f"Export ONNX file: {onnx_filename}")
+ dummy_input = torch.randn(onnx_bs, 3, 224, 224).cuda()
+ torch.onnx.export(model,
+ dummy_input,
+ onnx_filename,
+ verbose=False,
+ opset_version=opset_version,
+ enable_onnx_checker=False,
+ input_names=["input"],
+ output_names=["output"])
+
+
+def collect_stats(model, data_loader, num_batches):
+ """Feed data to the network and collect statistics"""
+
+ # Enable calibrators
+ for name, module in model.named_modules():
+ if isinstance(module, quant_nn.TensorQuantizer):
+ if module._calibrator is not None:
+ module.disable_quant()
+ module.enable_calib()
+ else:
+ module.disable()
+
+ for i, (images, _) in tqdm(enumerate(data_loader), total=num_batches):
+ model(images.cuda())
+ if i >= num_batches:
+ break
+
+ # Disable calibrators
+ for name, module in model.named_modules():
+ if isinstance(module, quant_nn.TensorQuantizer):
+ if module._calibrator is not None:
+ module.enable_quant()
+ module.disable_calib()
+ else:
+ module.enable()
+
+
+def compute_amax(model, **kwargs):
+ # Load calib result
+ for name, module in model.named_modules():
+ if isinstance(module, quant_nn.TensorQuantizer):
+ if module._calibrator is not None:
+ if isinstance(module._calibrator, calib.MaxCalibrator):
+ module.load_calib_amax()
+ else:
+ module.load_calib_amax(**kwargs)
+ print(f"{name:40}: {module}")
+ model.cuda()
+
+
+def main(args):
+ quant_modules.initialize()
+ assert torch.cuda.is_available(), "only support GPU!"
+
+ train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
+
+ data_transform = {
+ "train": transforms.Compose([transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+ "val": transforms.Compose([transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}
+
+ # instantiate the training dataset
+ train_dataset = MyDataSet(images_path=train_images_path,
+ images_class=train_images_label,
+ transform=data_transform["train"])
+
+ # instantiate the validation dataset
+ val_dataset = MyDataSet(images_path=val_images_path,
+ images_class=val_images_label,
+ transform=data_transform["val"])
+
+ batch_size = args.batch_size
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
+ print('Using {} dataloader workers per process'.format(nw))
+ train_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+
+ val_loader = torch.utils.data.DataLoader(val_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=val_dataset.collate_fn)
+
+ # ########################## #
+ # Post Training Quantization #
+ # ########################## #
+ # We will use histogram based calibration for activations and the default max calibration for weights.
+ quant_desc_input = QuantDescriptor(calib_method='histogram')
+ quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
+ quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
+
+ model = create_model(num_classes=args.num_classes)
+ assert os.path.exists(args.weights), "weights file: '{}' does not exist.".format(args.weights)
+ model.load_state_dict(torch.load(args.weights, map_location='cpu'))
+ model.cuda()
+
+ # It is a bit slow since we collect histograms on CPU
+ with torch.no_grad():
+ collect_stats(model, val_loader, num_batches=1000 // batch_size)
+ compute_amax(model, method="percentile", percentile=99.99)
+ # validate
+ evaluate(model=model, data_loader=val_loader, epoch=0)
+
+ torch.save(model.state_dict(), "quant_model_calibrated.pth")
+
+ if args.qat:
+ # ########################### #
+ # Quantization Aware Training #
+ # ########################### #
+ pg = [p for p in model.parameters() if p.requires_grad]
+ optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)
+ # Scheduler(half of a cosine period)
+ lf = lambda x: (math.cos(x * math.pi / 2 / args.epochs)) * (1 - args.lrf) + args.lrf
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
+
+ for epoch in range(args.epochs):
+ # train
+ train_one_epoch(model=model, optimizer=optimizer, data_loader=train_loader, epoch=epoch)
+
+ scheduler.step()
+
+ # validate
+ evaluate(model=model, data_loader=val_loader, epoch=epoch)
+
+ export_onnx(model, args.onnx_filename, args.onnx_bs)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_classes', type=int, default=5)
+ parser.add_argument('--epochs', type=int, default=5)
+ parser.add_argument('--batch-size', type=int, default=8)
+ parser.add_argument('--lr', type=float, default=0.0001)
+ parser.add_argument('--lrf', type=float, default=0.01)
+
+ # root directory of the dataset
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
+ parser.add_argument('--data-path', type=str,
+ default="/data/flower_photos")
+
+ # path to the trained weights
+ parser.add_argument('--weights', type=str, default='./resNet(flower).pth',
+ help='trained weights path')
+
+ parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
+
+ parser.add_argument('--onnx-filename', default='resnet34.onnx', help='save onnx model filename')
+ parser.add_argument('--onnx-bs', default=1, help='save onnx model batch size')
+ parser.add_argument('--qat', type=bool, default=True, help='whether use quantization aware training')
+
+ opt = parser.parse_args()
+
+ main(opt)
diff --git a/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/utils.py b/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/utils.py
new file mode 100644
index 000000000..309c32675
--- /dev/null
+++ b/deploying_service/deploying_pytorch/convert_tensorrt/convert_resnet34/utils.py
@@ -0,0 +1,131 @@
+import os
+import sys
+import json
+import pickle
+import random
+
+import torch
+from tqdm import tqdm
+
+
+def read_split_data(root: str, val_rate: float = 0.2):
+ random.seed(0) # make the random split reproducible
+ assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
+
+ # traverse the folders; each folder corresponds to one class
+ flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
+ # sort to keep the order consistent
+ flower_class.sort()
+ # generate class names and their corresponding numeric indices
+ class_indices = dict((k, v) for v, k in enumerate(flower_class))
+ json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
+ with open('class_indices.json', 'w') as json_file:
+ json_file.write(json_str)
+
+ train_images_path = [] # paths of all training images
+ train_images_label = [] # class indices of the training images
+ val_images_path = [] # paths of all validation images
+ val_images_label = [] # class indices of the validation images
+ every_class_num = [] # number of samples per class
+ supported = [".jpg", ".JPG", ".png", ".PNG"] # supported file extensions
+ # iterate over the files in each class folder
+ for cla in flower_class:
+ cla_path = os.path.join(root, cla)
+ # collect the paths of all files with a supported extension
+ images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
+ if os.path.splitext(i)[-1] in supported]
+ # index corresponding to this class
+ image_class = class_indices[cla]
+ # record the number of samples of this class
+ every_class_num.append(len(images))
+ # randomly sample validation images according to val_rate
+ val_path = random.sample(images, k=int(len(images) * val_rate))
+
+ for img_path in images:
+ if img_path in val_path: # if the path was sampled for validation, add it to the validation set
+ val_images_path.append(img_path)
+ val_images_label.append(image_class)
+ else: # otherwise add it to the training set
+ train_images_path.append(img_path)
+ train_images_label.append(image_class)
+
+ print("{} images were found in the dataset.".format(sum(every_class_num)))
+ print("{} images for training.".format(len(train_images_path)))
+ print("{} images for validation.".format(len(val_images_path)))
+
+ return train_images_path, train_images_label, val_images_path, val_images_label
+
+
+def write_pickle(list_info: list, file_name: str):
+ with open(file_name, 'wb') as f:
+ pickle.dump(list_info, f)
+
+
+def read_pickle(file_name: str) -> list:
+ with open(file_name, 'rb') as f:
+ info_list = pickle.load(f)
+ return info_list
+
+
+def train_one_epoch(model, optimizer, data_loader, epoch):
+ model.train()
+ loss_function = torch.nn.CrossEntropyLoss()
+ accu_loss = torch.zeros(1).cuda() # accumulated loss
+ accu_num = torch.zeros(1).cuda() # accumulated number of correctly predicted samples
+ optimizer.zero_grad()
+
+ sample_num = 0
+ data_loader = tqdm(data_loader, file=sys.stdout)
+ for step, data in enumerate(data_loader):
+ images, labels = data
+ sample_num += images.shape[0]
+
+ pred = model(images.cuda())
+ pred_classes = torch.max(pred, dim=1)[1]
+ accu_num += torch.eq(pred_classes, labels.cuda()).sum()
+
+ loss = loss_function(pred, labels.cuda())
+ loss.backward()
+ accu_loss += loss.detach()
+
+ data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
+ accu_loss.item() / (step + 1),
+ accu_num.item() / sample_num)
+
+ if not torch.isfinite(loss):
+ print('WARNING: non-finite loss, ending training ', loss)
+ sys.exit(1)
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ return accu_loss.item() / (step + 1), accu_num.item() / sample_num
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, epoch):
+ loss_function = torch.nn.CrossEntropyLoss()
+
+ model.eval()
+
+ accu_num = torch.zeros(1).cuda() # accumulated number of correctly predicted samples
+ accu_loss = torch.zeros(1).cuda() # accumulated loss
+
+ sample_num = 0
+ data_loader = tqdm(data_loader, file=sys.stdout)
+ for step, data in enumerate(data_loader):
+ images, labels = data
+ sample_num += images.shape[0]
+
+ pred = model(images.cuda())
+ pred_classes = torch.max(pred, dim=1)[1]
+ accu_num += torch.eq(pred_classes, labels.cuda()).sum()
+
+ loss = loss_function(pred, labels.cuda())
+ accu_loss += loss
+
+ data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
+ accu_loss.item() / (step + 1),
+ accu_num.item() / sample_num)
+
+ return accu_loss.item() / (step + 1), accu_num.item() / sample_num
diff --git a/deploying_service/deploying_pytorch/pytorch_flask_service/main.py b/deploying_service/deploying_pytorch/pytorch_flask_service/main.py
index 974d2a453..2f25d6d0d 100644
--- a/deploying_service/deploying_pytorch/pytorch_flask_service/main.py
+++ b/deploying_service/deploying_pytorch/pytorch_flask_service/main.py
@@ -20,10 +20,10 @@
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# create model
-model = MobileNetV2(num_classes=5)
+model = MobileNetV2(num_classes=5).to(device)
# load model weights
model.load_state_dict(torch.load(weights_path, map_location=device))
-model.to(device)
+
model.eval()
# load class info
diff --git a/deploying_service/deploying_pytorch/pytorch_flask_service/requirements.txt b/deploying_service/deploying_pytorch/pytorch_flask_service/requirements.txt
index 83b476f73..bdbbb72cd 100644
--- a/deploying_service/deploying_pytorch/pytorch_flask_service/requirements.txt
+++ b/deploying_service/deploying_pytorch/pytorch_flask_service/requirements.txt
@@ -1,3 +1,3 @@
-Flask==1.1.1
+Flask==2.2.5
Flask_Cors==3.0.9
Pillow
diff --git a/pytorch_classification/ConvNeXt/README.md b/pytorch_classification/ConvNeXt/README.md
new file mode 100644
index 000000000..c93d9df0e
--- /dev/null
+++ b/pytorch_classification/ConvNeXt/README.md
@@ -0,0 +1,12 @@
+## How to use this code
+
+1. Download the dataset. By default the code uses the flower classification dataset, available at [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz);
+if that link is unavailable, it can also be downloaded from Baidu Netdisk: https://pan.baidu.com/s/1QLCTA4sXnQAw_yvxPj9szg extraction code: 58p0
+2. In `train.py`, set `--data-path` to the absolute path of the extracted `flower_photos` folder.
+3. Download the pre-trained weights. Each model in `model.py` lists the download address of its pre-trained weights; download the one that matches the model you use.
+4. In `train.py`, set the `--weights` argument to the path of the downloaded pre-trained weights.
+5. With `--data-path` and `--weights` set, you can start training with `train.py` (a `class_indices.json` file is generated automatically during training).
+6. In `predict.py`, import the same model as in the training script and set `model_weight_path` to the path of the trained weights (saved in the `weights` folder by default).
+7. In `predict.py`, set `img_path` to the absolute path of the image you want to predict.
+8. With `model_weight_path` and `img_path` set, you can run predictions with `predict.py` (see the sketch after this list for the minimal code this corresponds to).
+9. To use your own dataset, arrange it with the same folder structure as the flower dataset (one folder per class) and set `num_classes` in the training and prediction scripts to the number of classes in your data.
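+
+A minimal prediction sketch (illustrative; it assumes training has already produced `./weights/best_model.pth` and `class_indices.json`, and simply mirrors what `predict.py` does):
+
+```python
+import json
+import torch
+from PIL import Image
+from torchvision import transforms
+from model import convnext_tiny as create_model
+
+data_transform = transforms.Compose(
+    [transforms.Resize(int(224 * 1.14)), transforms.CenterCrop(224), transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+img = data_transform(Image.open("tulip.jpg")).unsqueeze(0)  # the image path is just an example
+
+model = create_model(num_classes=5)
+model.load_state_dict(torch.load("./weights/best_model.pth", map_location="cpu"))
+model.eval()
+with torch.no_grad():
+    probs = torch.softmax(model(img).squeeze(0), dim=0)
+
+class_indict = json.load(open("class_indices.json"))
+print(class_indict[str(int(probs.argmax()))], float(probs.max()))
+```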
diff --git a/pytorch_classification/ConvNeXt/model.py b/pytorch_classification/ConvNeXt/model.py
new file mode 100644
index 000000000..6e8337c2b
--- /dev/null
+++ b/pytorch_classification/ConvNeXt/model.py
@@ -0,0 +1,212 @@
+"""
+original code from facebook research:
+https://github.com/facebookresearch/ConvNeXt
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class LayerNorm(nn.Module):
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape), requires_grad=True)
+ self.bias = nn.Parameter(torch.zeros(normalized_shape), requires_grad=True)
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise ValueError(f"unsupported data format '{self.data_format}'")
+ self.normalized_shape = (normalized_shape,)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.data_format == "channels_last":
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ elif self.data_format == "channels_first":
+ # [batch_size, channels, height, width]
+ mean = x.mean(1, keepdim=True)
+ var = (x - mean).pow(2).mean(1, keepdim=True)
+ x = (x - mean) / torch.sqrt(var + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
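+# Equivalence note (added comment): for x of shape [N, C, H, W],
+# LayerNorm(C, data_format="channels_first")(x) matches
+# LayerNorm(C, data_format="channels_last")(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+# up to floating-point error; the manual branch above simply avoids the permutes.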
+
+class Block(nn.Module):
+ r""" ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_rate (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ """
+ def __init__(self, dim, drop_rate=0., layer_scale_init_value=1e-6):
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6, data_format="channels_last")
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim,)),
+ requires_grad=True) if layer_scale_init_value > 0 else None
+ self.drop_path = DropPath(drop_rate) if drop_rate > 0. else nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = x
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 3, 1) # [N, C, H, W] -> [N, H, W, C]
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # [N, H, W, C] -> [N, C, H, W]
+
+ x = shortcut + self.drop_path(x)
+ return x
+
+
+class ConvNeXt(nn.Module):
+ r""" ConvNeXt
+ A PyTorch impl of : `A ConvNet for the 2020s` -
+ https://arxiv.org/pdf/2201.03545.pdf
+ Args:
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+ """
+ def __init__(self, in_chans: int = 3, num_classes: int = 1000, depths: list = None,
+ dims: list = None, drop_path_rate: float = 0., layer_scale_init_value: float = 1e-6,
+ head_init_scale: float = 1.):
+ super().__init__()
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
+ stem = nn.Sequential(nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first"))
+ self.downsample_layers.append(stem)
+
+ # the 3 downsample layers placed before stage2-stage4
+ for i in range(3):
+ downsample_layer = nn.Sequential(LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2))
+ self.downsample_layers.append(downsample_layer)
+
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple blocks
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+ cur = 0
+ # build the blocks stacked in each stage
+ for i in range(4):
+ stage = nn.Sequential(
+ *[Block(dim=dims[i], drop_rate=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_value)
+ for j in range(depths[i])]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
+ self.head = nn.Linear(dims[-1], num_classes)
+ self.apply(self._init_weights)
+ self.head.weight.data.mul_(head_init_scale)
+ self.head.bias.data.mul_(head_init_scale)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ nn.init.trunc_normal_(m.weight, std=0.2)
+ nn.init.constant_(m.bias, 0)
+
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+ for i in range(4):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+
+ return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+def convnext_tiny(num_classes: int):
+ # https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth
+ model = ConvNeXt(depths=[3, 3, 9, 3],
+ dims=[96, 192, 384, 768],
+ num_classes=num_classes)
+ return model
+
+
+def convnext_small(num_classes: int):
+ # https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth
+ model = ConvNeXt(depths=[3, 3, 27, 3],
+ dims=[96, 192, 384, 768],
+ num_classes=num_classes)
+ return model
+
+
+def convnext_base(num_classes: int):
+ # https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth
+ # https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth
+ model = ConvNeXt(depths=[3, 3, 27, 3],
+ dims=[128, 256, 512, 1024],
+ num_classes=num_classes)
+ return model
+
+
+def convnext_large(num_classes: int):
+ # https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth
+ # https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth
+ model = ConvNeXt(depths=[3, 3, 27, 3],
+ dims=[192, 384, 768, 1536],
+ num_classes=num_classes)
+ return model
+
+
+def convnext_xlarge(num_classes: int):
+ # https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth
+ model = ConvNeXt(depths=[3, 3, 27, 3],
+ dims=[256, 512, 1024, 2048],
+ num_classes=num_classes)
+ return model
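+
+
+# Quick shape check (illustrative, not part of the original file):
+# m = convnext_tiny(num_classes=5)
+# y = m(torch.randn(1, 3, 224, 224))
+# assert y.shape == (1, 5)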
diff --git a/pytorch_classification/ConvNeXt/my_dataset.py b/pytorch_classification/ConvNeXt/my_dataset.py
new file mode 100644
index 000000000..167bc9a30
--- /dev/null
+++ b/pytorch_classification/ConvNeXt/my_dataset.py
@@ -0,0 +1,37 @@
+from PIL import Image
+import torch
+from torch.utils.data import Dataset
+
+
+class MyDataSet(Dataset):
+ """Custom dataset"""
+
+ def __init__(self, images_path: list, images_class: list, transform=None):
+ self.images_path = images_path
+ self.images_class = images_class
+ self.transform = transform
+
+ def __len__(self):
+ return len(self.images_path)
+
+ def __getitem__(self, item):
+ img = Image.open(self.images_path[item])
+ # mode 'RGB' means a color image, 'L' means a grayscale image
+ if img.mode != 'RGB':
+ raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
+ label = self.images_class[item]
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ return img, label
+
+ @staticmethod
+ def collate_fn(batch):
+ # the official default_collate implementation can be found at
+ # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
+ images, labels = tuple(zip(*batch))
+
+ images = torch.stack(images, dim=0)
+ labels = torch.as_tensor(labels)
+ return images, labels
diff --git a/pytorch_classification/ConvNeXt/predict.py b/pytorch_classification/ConvNeXt/predict.py
new file mode 100644
index 000000000..a603b22e7
--- /dev/null
+++ b/pytorch_classification/ConvNeXt/predict.py
@@ -0,0 +1,63 @@
+import os
+import json
+
+import torch
+from PIL import Image
+from torchvision import transforms
+import matplotlib.pyplot as plt
+
+from model import convnext_tiny as create_model
+
+
+def main():
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print(f"using {device} device.")
+
+ num_classes = 5
+ img_size = 224
+ data_transform = transforms.Compose(
+ [transforms.Resize(int(img_size * 1.14)),
+ transforms.CenterCrop(img_size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+
+ # load image
+ img_path = "../tulip.jpg"
+ assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
+ img = Image.open(img_path)
+ plt.imshow(img)
+ # [N, C, H, W]
+ img = data_transform(img)
+ # expand batch dimension
+ img = torch.unsqueeze(img, dim=0)
+
+ # read class_indict
+ json_path = './class_indices.json'
+ assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
+
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
+
+ # create model
+ model = create_model(num_classes=num_classes).to(device)
+ # load model weights
+ model_weight_path = "./weights/best_model.pth"
+ model.load_state_dict(torch.load(model_weight_path, map_location=device))
+ model.eval()
+ with torch.no_grad():
+ # predict class
+ output = torch.squeeze(model(img.to(device))).cpu()
+ predict = torch.softmax(output, dim=0)
+ predict_cla = torch.argmax(predict).numpy()
+
+ print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
+ predict[predict_cla].numpy())
+ plt.title(print_res)
+ for i in range(len(predict)):
+ print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
+ predict[i].numpy()))
+ plt.show()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/pytorch_classification/ConvNeXt/train.py b/pytorch_classification/ConvNeXt/train.py
new file mode 100644
index 000000000..cbec65967
--- /dev/null
+++ b/pytorch_classification/ConvNeXt/train.py
@@ -0,0 +1,139 @@
+import os
+import argparse
+
+import torch
+import torch.optim as optim
+from torch.utils.tensorboard import SummaryWriter
+from torchvision import transforms
+
+from my_dataset import MyDataSet
+from model import convnext_tiny as create_model
+from utils import read_split_data, create_lr_scheduler, get_params_groups, train_one_epoch, evaluate
+
+
+def main(args):
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ print(f"using {device} device.")
+
+ if os.path.exists("./weights") is False:
+ os.makedirs("./weights")
+
+ tb_writer = SummaryWriter()
+
+ train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
+
+ img_size = 224
+ data_transform = {
+ "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
+ "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
+ transforms.CenterCrop(img_size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
+
+ # instantiate the training dataset
+ train_dataset = MyDataSet(images_path=train_images_path,
+ images_class=train_images_label,
+ transform=data_transform["train"])
+
+ # instantiate the validation dataset
+ val_dataset = MyDataSet(images_path=val_images_path,
+ images_class=val_images_label,
+ transform=data_transform["val"])
+
+ batch_size = args.batch_size
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
+ print('Using {} dataloader workers per process'.format(nw))
+ train_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+
+ val_loader = torch.utils.data.DataLoader(val_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=val_dataset.collate_fn)
+
+ model = create_model(num_classes=args.num_classes).to(device)
+
+ if args.weights != "":
+ assert os.path.exists(args.weights), "weights file: '{}' does not exist.".format(args.weights)
+ weights_dict = torch.load(args.weights, map_location=device)["model"]
+ # remove the classification-head weights (the number of classes may differ)
+ for k in list(weights_dict.keys()):
+ if "head" in k:
+ del weights_dict[k]
+ print(model.load_state_dict(weights_dict, strict=False))
+
+ if args.freeze_layers:
+ for name, para in model.named_parameters():
+ # freeze all weights except the head
+ if "head" not in name:
+ para.requires_grad_(False)
+ else:
+ print("training {}".format(name))
+
+ # pg = [p for p in model.parameters() if p.requires_grad]
+ pg = get_params_groups(model, weight_decay=args.wd)
+ optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=args.wd)
+ lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs,
+ warmup=True, warmup_epochs=1)
+
+ best_acc = 0.
+ for epoch in range(args.epochs):
+ # train
+ train_loss, train_acc = train_one_epoch(model=model,
+ optimizer=optimizer,
+ data_loader=train_loader,
+ device=device,
+ epoch=epoch,
+ lr_scheduler=lr_scheduler)
+
+ # validate
+ val_loss, val_acc = evaluate(model=model,
+ data_loader=val_loader,
+ device=device,
+ epoch=epoch)
+
+ tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
+ tb_writer.add_scalar(tags[0], train_loss, epoch)
+ tb_writer.add_scalar(tags[1], train_acc, epoch)
+ tb_writer.add_scalar(tags[2], val_loss, epoch)
+ tb_writer.add_scalar(tags[3], val_acc, epoch)
+ tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
+
+ if best_acc < val_acc:
+ torch.save(model.state_dict(), "./weights/best_model.pth")
+ best_acc = val_acc
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_classes', type=int, default=5)
+ parser.add_argument('--epochs', type=int, default=10)
+ parser.add_argument('--batch-size', type=int, default=8)
+ parser.add_argument('--lr', type=float, default=5e-4)
+ parser.add_argument('--wd', type=float, default=5e-2)
+
+ # root directory of the dataset
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
+ parser.add_argument('--data-path', type=str,
+ default="/data/flower_photos")
+
+ # path to the pre-trained weights; set it to an empty string to skip loading
+ # link: https://pan.baidu.com/s/1aNqQW4n_RrUlWUBNlaJRHA password: i83t
+ parser.add_argument('--weights', type=str, default='./convnext_tiny_1k_224_ema.pth',
+ help='initial weights path')
+ # whether to freeze all weights except the head
+ parser.add_argument('--freeze-layers', type=bool, default=False)
+ parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
+
+ opt = parser.parse_args()
+
+ main(opt)
diff --git a/pytorch_classification/ConvNeXt/utils.py b/pytorch_classification/ConvNeXt/utils.py
new file mode 100644
index 000000000..c2bcb5594
--- /dev/null
+++ b/pytorch_classification/ConvNeXt/utils.py
@@ -0,0 +1,241 @@
+import os
+import sys
+import json
+import pickle
+import random
+import math
+
+import torch
+from tqdm import tqdm
+
+import matplotlib.pyplot as plt
+
+
+def read_split_data(root: str, val_rate: float = 0.2):
+ random.seed(0) # make the random split reproducible
+ assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
+
+ # traverse the folders; each folder corresponds to one class
+ flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
+ # sort to keep the order consistent across platforms
+ flower_class.sort()
+ # generate class names and their corresponding numeric indices
+ class_indices = dict((k, v) for v, k in enumerate(flower_class))
+ json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
+ with open('class_indices.json', 'w') as json_file:
+ json_file.write(json_str)
+
+ train_images_path = [] # paths of all training images
+ train_images_label = [] # class indices of the training images
+ val_images_path = [] # paths of all validation images
+ val_images_label = [] # class indices of the validation images
+ every_class_num = [] # number of samples per class
+ supported = [".jpg", ".JPG", ".png", ".PNG"] # supported file extensions
+ # iterate over the files in each class folder
+ for cla in flower_class:
+ cla_path = os.path.join(root, cla)
+ # collect the paths of all files with a supported extension
+ images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
+ if os.path.splitext(i)[-1] in supported]
+ # sort to keep the order consistent across platforms
+ images.sort()
+ # index corresponding to this class
+ image_class = class_indices[cla]
+ # record the number of samples of this class
+ every_class_num.append(len(images))
+ # randomly sample validation images according to val_rate
+ val_path = random.sample(images, k=int(len(images) * val_rate))
+
+ for img_path in images:
+ if img_path in val_path: # if the path was sampled for validation, add it to the validation set
+ val_images_path.append(img_path)
+ val_images_label.append(image_class)
+ else: # otherwise add it to the training set
+ train_images_path.append(img_path)
+ train_images_label.append(image_class)
+
+ print("{} images were found in the dataset.".format(sum(every_class_num)))
+ print("{} images for training.".format(len(train_images_path)))
+ print("{} images for validation.".format(len(val_images_path)))
+ assert len(train_images_path) > 0, "number of training images must be greater than 0."
+ assert len(val_images_path) > 0, "number of validation images must be greater than 0."
+
+ plot_image = False
+ if plot_image:
+ # bar chart of the number of samples per class
+ plt.bar(range(len(flower_class)), every_class_num, align='center')
+ # replace the x ticks 0,1,2,3,4 with the corresponding class names
+ plt.xticks(range(len(flower_class)), flower_class)
+ # add a value label on top of each bar
+ for i, v in enumerate(every_class_num):
+ plt.text(x=i, y=v + 5, s=str(v), ha='center')
+ # x-axis label
+ plt.xlabel('image class')
+ # y-axis label
+ plt.ylabel('number of images')
+ # chart title
+ plt.title('flower class distribution')
+ plt.show()
+
+ return train_images_path, train_images_label, val_images_path, val_images_label
+
+
+def plot_data_loader_image(data_loader):
+ batch_size = data_loader.batch_size
+ plot_num = min(batch_size, 4)
+
+ json_path = './class_indices.json'
+ assert os.path.exists(json_path), json_path + " does not exist."
+ json_file = open(json_path, 'r')
+ class_indices = json.load(json_file)
+
+ for data in data_loader:
+ images, labels = data
+ for i in range(plot_num):
+ # [C, H, W] -> [H, W, C]
+ img = images[i].numpy().transpose(1, 2, 0)
+ # undo the Normalize transform
+ img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
+ label = labels[i].item()
+ plt.subplot(1, plot_num, i+1)
+ plt.xlabel(class_indices[str(label)])
+ plt.xticks([]) # remove x-axis ticks
+ plt.yticks([]) # remove y-axis ticks
+ plt.imshow(img.astype('uint8'))
+ plt.show()
+
+
+def write_pickle(list_info: list, file_name: str):
+ with open(file_name, 'wb') as f:
+ pickle.dump(list_info, f)
+
+
+def read_pickle(file_name: str) -> list:
+ with open(file_name, 'rb') as f:
+ info_list = pickle.load(f)
+ return info_list
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, lr_scheduler):
+ model.train()
+ loss_function = torch.nn.CrossEntropyLoss()
+ accu_loss = torch.zeros(1).to(device) # accumulated loss
+ accu_num = torch.zeros(1).to(device) # accumulated number of correctly predicted samples
+ optimizer.zero_grad()
+
+ sample_num = 0
+ data_loader = tqdm(data_loader, file=sys.stdout)
+ for step, data in enumerate(data_loader):
+ images, labels = data
+ sample_num += images.shape[0]
+
+ pred = model(images.to(device))
+ pred_classes = torch.max(pred, dim=1)[1]
+ accu_num += torch.eq(pred_classes, labels.to(device)).sum()
+
+ loss = loss_function(pred, labels.to(device))
+ loss.backward()
+ accu_loss += loss.detach()
+
+ data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}, lr: {:.5f}".format(
+ epoch,
+ accu_loss.item() / (step + 1),
+ accu_num.item() / sample_num,
+ optimizer.param_groups[0]["lr"]
+ )
+
+ if not torch.isfinite(loss):
+ print('WARNING: non-finite loss, ending training ', loss)
+ sys.exit(1)
+
+ optimizer.step()
+ optimizer.zero_grad()
+ # update lr
+ lr_scheduler.step()
+
+ return accu_loss.item() / (step + 1), accu_num.item() / sample_num
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, epoch):
+ loss_function = torch.nn.CrossEntropyLoss()
+
+ model.eval()
+
+ accu_num = torch.zeros(1).to(device) # accumulated number of correctly predicted samples
+ accu_loss = torch.zeros(1).to(device) # accumulated loss
+
+ sample_num = 0
+ data_loader = tqdm(data_loader, file=sys.stdout)
+ for step, data in enumerate(data_loader):
+ images, labels = data
+ sample_num += images.shape[0]
+
+ pred = model(images.to(device))
+ pred_classes = torch.max(pred, dim=1)[1]
+ accu_num += torch.eq(pred_classes, labels.to(device)).sum()
+
+ loss = loss_function(pred, labels.to(device))
+ accu_loss += loss
+
+ data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(
+ epoch,
+ accu_loss.item() / (step + 1),
+ accu_num.item() / sample_num
+ )
+
+ return accu_loss.item() / (step + 1), accu_num.item() / sample_num
+
+
+def create_lr_scheduler(optimizer,
+ num_step: int,
+ epochs: int,
+ warmup=True,
+ warmup_epochs=1,
+ warmup_factor=1e-3,
+ end_factor=1e-6):
+ assert num_step > 0 and epochs > 0
+ if warmup is False:
+ warmup_epochs = 0
+
+ def f(x):
+ """
+ Return a learning-rate multiplier as a function of the step count.
+ Note that PyTorch calls lr_scheduler.step() once before training starts.
+ """
+ if warmup is True and x <= (warmup_epochs * num_step):
+ alpha = float(x) / (warmup_epochs * num_step)
+ # during warmup the lr factor goes from warmup_factor -> 1
+ return warmup_factor * (1 - alpha) + alpha
+ else:
+ current_step = (x - warmup_epochs * num_step)
+ cosine_steps = (epochs - warmup_epochs) * num_step
+ # after warmup the lr factor goes from 1 -> end_factor
+ return ((1 + math.cos(current_step * math.pi / cosine_steps)) / 2) * (1 - end_factor) + end_factor
+
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
+
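+# Worked example (illustrative, added comment): with num_step=100 steps per epoch, epochs=10
+# and warmup_epochs=1, f(x) ramps linearly from warmup_factor (1e-3) towards 1.0 over the
+# first 100 scheduler steps, then follows half a cosine from 1.0 down to end_factor (1e-6)
+# over the remaining 900 steps; the optimizer lr is base_lr * f(x).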
+
+def get_params_groups(model: torch.nn.Module, weight_decay: float = 1e-5):
+ # parameters the optimizer will train, grouped by weight decay
+ parameter_group_vars = {"decay": {"params": [], "weight_decay": weight_decay},
+ "no_decay": {"params": [], "weight_decay": 0.}}
+
+ # corresponding parameter names (for logging)
+ parameter_group_names = {"decay": {"params": [], "weight_decay": weight_decay},
+ "no_decay": {"params": [], "weight_decay": 0.}}
+
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+
+ if len(param.shape) == 1 or name.endswith(".bias"):
+ group_name = "no_decay"
+ else:
+ group_name = "decay"
+
+ parameter_group_vars[group_name]["params"].append(param)
+ parameter_group_names[group_name]["params"].append(name)
+
+ print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+ return list(parameter_group_vars.values())
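+
+
+# Usage sketch (illustrative, added comment): the returned list plugs directly into an
+# optimizer, e.g. optim.AdamW(get_params_groups(model, weight_decay=5e-2), lr=5e-4),
+# so biases and 1-D parameters (e.g. LayerNorm weights) are excluded from weight decay.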
diff --git a/pytorch_classification/MobileViT/README.md b/pytorch_classification/MobileViT/README.md
new file mode 100644
index 000000000..c93d9df0e
--- /dev/null
+++ b/pytorch_classification/MobileViT/README.md
@@ -0,0 +1,12 @@
+## How to use this code
+
+1. Download the dataset. By default the code uses the flower classification dataset, available at [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz);
+if that link is unavailable, it can also be downloaded from Baidu Netdisk: https://pan.baidu.com/s/1QLCTA4sXnQAw_yvxPj9szg extraction code: 58p0
+2. In `train.py`, set `--data-path` to the absolute path of the extracted `flower_photos` folder.
+3. Download the pre-trained weights. Each model in `model.py` lists the download address of its pre-trained weights; download the one that matches the model you use.
+4. In `train.py`, set the `--weights` argument to the path of the downloaded pre-trained weights.
+5. With `--data-path` and `--weights` set, you can start training with `train.py` (a `class_indices.json` file is generated automatically during training).
+6. In `predict.py`, import the same model as in the training script and set `model_weight_path` to the path of the trained weights (saved in the `weights` folder by default).
+7. In `predict.py`, set `img_path` to the absolute path of the image you want to predict.
+8. With `model_weight_path` and `img_path` set, you can run predictions with `predict.py`.
+9. To use your own dataset, arrange it with the same folder structure as the flower dataset (one folder per class) and set `num_classes` in the training and prediction scripts to the number of classes in your data.
diff --git a/pytorch_classification/MobileViT/model.py b/pytorch_classification/MobileViT/model.py
new file mode 100644
index 000000000..1606f1b69
--- /dev/null
+++ b/pytorch_classification/MobileViT/model.py
@@ -0,0 +1,562 @@
+"""
+original code from apple:
+https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/mobilevit.py
+"""
+
+from typing import Optional, Tuple, Union, Dict
+import math
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+
+from transformer import TransformerEncoder
+from model_config import get_config
+
+
+def make_divisible(
+ v: Union[float, int],
+ divisor: Optional[int] = 8,
+ min_value: Optional[Union[float, int]] = None,
+) -> Union[float, int]:
+ """
+ This function is taken from the original tf repo.
+ It ensures that all layers have a channel number that is divisible by 8.
+ It can be seen here:
+ https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+ :param v:
+ :param divisor:
+ :param min_value:
+ :return:
+ """
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
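+# Examples (illustrative, added comment): make_divisible(17) == 16 and make_divisible(91) == 88,
+# i.e. the value is rounded to the nearest multiple of 8, and the result is bumped up by one
+# divisor if rounding down would lose more than 10% of the original value.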
+
+class ConvLayer(nn.Module):
+ """
+ Applies a 2D convolution over an input
+
+ Args:
+ in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
+ out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
+ kernel_size (Union[int, Tuple[int, int]]): Kernel size for convolution.
+ stride (Union[int, Tuple[int, int]]): Stride for convolution. Default: 1
+ groups (Optional[int]): Number of groups in convolution. Default: 1
+ bias (Optional[bool]): Use bias. Default: ``False``
+ use_norm (Optional[bool]): Use normalization layer after convolution. Default: ``True``
+ use_act (Optional[bool]): Use activation layer after convolution (or convolution and normalization).
+ Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C_{out}, H_{out}, W_{out})`
+
+ .. note::
+ For depth-wise convolution, `groups=C_{in}=C_{out}`.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Optional[Union[int, Tuple[int, int]]] = 1,
+ groups: Optional[int] = 1,
+ bias: Optional[bool] = False,
+ use_norm: Optional[bool] = True,
+ use_act: Optional[bool] = True,
+ ) -> None:
+ super().__init__()
+
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size, kernel_size)
+
+ if isinstance(stride, int):
+ stride = (stride, stride)
+
+ assert isinstance(kernel_size, Tuple)
+ assert isinstance(stride, Tuple)
+
+ padding = (
+ int((kernel_size[0] - 1) / 2),
+ int((kernel_size[1] - 1) / 2),
+ )
+
+ block = nn.Sequential()
+
+ conv_layer = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ groups=groups,
+ padding=padding,
+ bias=bias
+ )
+
+ block.add_module(name="conv", module=conv_layer)
+
+ if use_norm:
+ norm_layer = nn.BatchNorm2d(num_features=out_channels, momentum=0.1)
+ block.add_module(name="norm", module=norm_layer)
+
+ if use_act:
+ act_layer = nn.SiLU()
+ block.add_module(name="act", module=act_layer)
+
+ self.block = block
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.block(x)
+
+
+class InvertedResidual(nn.Module):
+ """
+ This class implements the inverted residual block, as described in the MobileNetV2 paper
+
+ Args:
+ in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H_{in}, W_{in})`
+ out_channels (int): :math:`C_{out}` from an expected output of size :math:`(N, C_{out}, H_{out}, W_{out})`
+ stride (int): Use convolutions with a stride. Default: 1
+ expand_ratio (Union[int, float]): Expand the input channels by this factor in depth-wise conv
+ skip_connection (Optional[bool]): Use skip-connection. Default: True
+
+ Shape:
+ - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C_{out}, H_{out}, W_{out})`
+
+ .. note::
+ If `in_channels != out_channels` and `stride > 1`, we set `skip_connection=False`
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ stride: int,
+ expand_ratio: Union[int, float],
+ skip_connection: Optional[bool] = True,
+ ) -> None:
+ assert stride in [1, 2]
+ hidden_dim = make_divisible(int(round(in_channels * expand_ratio)), 8)
+
+ super().__init__()
+
+ block = nn.Sequential()
+ if expand_ratio != 1:
+ block.add_module(
+ name="exp_1x1",
+ module=ConvLayer(
+ in_channels=in_channels,
+ out_channels=hidden_dim,
+ kernel_size=1
+ ),
+ )
+
+ block.add_module(
+ name="conv_3x3",
+ module=ConvLayer(
+ in_channels=hidden_dim,
+ out_channels=hidden_dim,
+ stride=stride,
+ kernel_size=3,
+ groups=hidden_dim
+ ),
+ )
+
+ block.add_module(
+ name="red_1x1",
+ module=ConvLayer(
+ in_channels=hidden_dim,
+ out_channels=out_channels,
+ kernel_size=1,
+ use_act=False,
+ use_norm=True,
+ ),
+ )
+
+ self.block = block
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.exp = expand_ratio
+ self.stride = stride
+ self.use_res_connect = (
+ self.stride == 1 and in_channels == out_channels and skip_connection
+ )
+
+ def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
+ if self.use_res_connect:
+ return x + self.block(x)
+ else:
+ return self.block(x)
+
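+# Residual usage sketch:
+#   InvertedResidual(32, 32, stride=1, expand_ratio=4)  -> adds the skip connection
+#   InvertedResidual(32, 64, stride=2, expand_ratio=4)  -> plain feed-forward path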
+
+class MobileViTBlock(nn.Module):
+ """
+ This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178>`_
+
+ Args:
+ in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
+ transformer_dim (int): Input dimension to the transformer unit
+ ffn_dim (int): Dimension of the FFN block
+ n_transformer_blocks (int): Number of transformer blocks. Default: 2
+ head_dim (int): Head dimension in the multi-head attention. Default: 32
+ attn_dropout (float): Dropout in multi-head attention. Default: 0.0
+ dropout (float): Dropout rate. Default: 0.0
+ ffn_dropout (float): Dropout between FFN layers in transformer. Default: 0.0
+ patch_h (int): Patch height for unfolding operation. Default: 8
+ patch_w (int): Patch width for unfolding operation. Default: 8
+ conv_ksize (int): Kernel size to learn local representations in MobileViT block. Default: 3
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ transformer_dim: int,
+ ffn_dim: int,
+ n_transformer_blocks: int = 2,
+ head_dim: int = 32,
+ attn_dropout: float = 0.0,
+ dropout: float = 0.0,
+ ffn_dropout: float = 0.0,
+ patch_h: int = 8,
+ patch_w: int = 8,
+ conv_ksize: Optional[int] = 3,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+
+ conv_3x3_in = ConvLayer(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=conv_ksize,
+ stride=1
+ )
+ conv_1x1_in = ConvLayer(
+ in_channels=in_channels,
+ out_channels=transformer_dim,
+ kernel_size=1,
+ stride=1,
+ use_norm=False,
+ use_act=False
+ )
+
+ conv_1x1_out = ConvLayer(
+ in_channels=transformer_dim,
+ out_channels=in_channels,
+ kernel_size=1,
+ stride=1
+ )
+ conv_3x3_out = ConvLayer(
+ in_channels=2 * in_channels,
+ out_channels=in_channels,
+ kernel_size=conv_ksize,
+ stride=1
+ )
+
+ self.local_rep = nn.Sequential()
+ self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
+ self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)
+
+ assert transformer_dim % head_dim == 0
+ num_heads = transformer_dim // head_dim
+
+ global_rep = [
+ TransformerEncoder(
+ embed_dim=transformer_dim,
+ ffn_latent_dim=ffn_dim,
+ num_heads=num_heads,
+ attn_dropout=attn_dropout,
+ dropout=dropout,
+ ffn_dropout=ffn_dropout
+ )
+ for _ in range(n_transformer_blocks)
+ ]
+ global_rep.append(nn.LayerNorm(transformer_dim))
+ self.global_rep = nn.Sequential(*global_rep)
+
+ self.conv_proj = conv_1x1_out
+ self.fusion = conv_3x3_out
+
+ self.patch_h = patch_h
+ self.patch_w = patch_w
+ self.patch_area = self.patch_w * self.patch_h
+
+ self.cnn_in_dim = in_channels
+ self.cnn_out_dim = transformer_dim
+ self.n_heads = num_heads
+ self.ffn_dim = ffn_dim
+ self.dropout = dropout
+ self.attn_dropout = attn_dropout
+ self.ffn_dropout = ffn_dropout
+ self.n_blocks = n_transformer_blocks
+ self.conv_ksize = conv_ksize
+
+ def unfolding(self, x: Tensor) -> Tuple[Tensor, Dict]:
+ patch_w, patch_h = self.patch_w, self.patch_h
+ patch_area = patch_w * patch_h
+ batch_size, in_channels, orig_h, orig_w = x.shape
+
+ new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
+ new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)
+
+ interpolate = False
+ if new_w != orig_w or new_h != orig_h:
+ # Note: Padding can be done, but then it needs to be handled in attention function.
+ x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
+ interpolate = True
+
+ # number of patches along width and height
+ num_patch_w = new_w // patch_w # n_w
+ num_patch_h = new_h // patch_h # n_h
+ num_patches = num_patch_h * num_patch_w # N
+
+ # [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
+ x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
+ # [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
+ x = x.transpose(1, 2)
+ # [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
+ x = x.reshape(batch_size, in_channels, num_patches, patch_area)
+ # [B, C, N, P] -> [B, P, N, C]
+ x = x.transpose(1, 3)
+ # [B, P, N, C] -> [BP, N, C]
+ x = x.reshape(batch_size * patch_area, num_patches, -1)
+
+ info_dict = {
+ "orig_size": (orig_h, orig_w),
+ "batch_size": batch_size,
+ "interpolate": interpolate,
+ "total_patches": num_patches,
+ "num_patches_w": num_patch_w,
+ "num_patches_h": num_patch_h,
+ }
+
+ return x, info_dict
+
+ def folding(self, x: Tensor, info_dict: Dict) -> Tensor:
+ n_dim = x.dim()
+ assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
+ x.shape
+ )
+ # [BP, N, C] --> [B, P, N, C]
+ x = x.contiguous().view(
+ info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
+ )
+
+ batch_size, pixels, num_patches, channels = x.size()
+ num_patch_h = info_dict["num_patches_h"]
+ num_patch_w = info_dict["num_patches_w"]
+
+ # [B, P, N, C] -> [B, C, N, P]
+ x = x.transpose(1, 3)
+ # [B, C, N, P] -> [B*C*n_h, n_w, p_h, p_w]
+ x = x.reshape(batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w)
+ # [B*C*n_h, n_w, p_h, p_w] -> [B*C*n_h, p_h, n_w, p_w]
+ x = x.transpose(1, 2)
+ # [B*C*n_h, p_h, n_w, p_w] -> [B, C, H, W]
+ x = x.reshape(batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w)
+ if info_dict["interpolate"]:
+ x = F.interpolate(
+ x,
+ size=info_dict["orig_size"],
+ mode="bilinear",
+ align_corners=False,
+ )
+ return x
+
+ def forward(self, x: Tensor) -> Tensor:
+ res = x
+
+ fm = self.local_rep(x)
+
+ # convert feature map to patches
+ patches, info_dict = self.unfolding(fm)
+
+ # learn global representations
+ for transformer_layer in self.global_rep:
+ patches = transformer_layer(patches)
+
+ # convert patches back to the feature map: [BP, N, C] -> [B, C, H, W]
+ fm = self.folding(x=patches, info_dict=info_dict)
+
+ fm = self.conv_proj(fm)
+
+ fm = self.fusion(torch.cat((res, fm), dim=1))
+ return fm
+
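+# Shape walk-through for MobileViTBlock (illustrative values, patch_h = patch_w = 2,
+# d = transformer_dim):
+#   input  [B, C, 32, 32] --local_rep--> [B, d, 32, 32]
+#   unfolding: n_h = n_w = 16, N = 256 patches of P = 4 pixels  -> [B*4, 256, d]
+#   global_rep (transformer encoders) keep this shape           -> [B*4, 256, d]
+#   folding                                                     -> [B, d, 32, 32]
+#   conv_proj, then fusion with the residual input              -> [B, C, 32, 32]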
+
+class MobileViT(nn.Module):
+ """
+ This class implements the `MobileViT architecture <https://arxiv.org/abs/2110.02178>`_
+ """
+ def __init__(self, model_cfg: Dict, num_classes: int = 1000):
+ super().__init__()
+
+ image_channels = 3
+ out_channels = 16
+
+ self.conv_1 = ConvLayer(
+ in_channels=image_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=2
+ )
+
+ self.layer_1, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer1"])
+ self.layer_2, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer2"])
+ self.layer_3, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer3"])
+ self.layer_4, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer4"])
+ self.layer_5, out_channels = self._make_layer(input_channel=out_channels, cfg=model_cfg["layer5"])
+
+ exp_channels = min(model_cfg["last_layer_exp_factor"] * out_channels, 960)
+ self.conv_1x1_exp = ConvLayer(
+ in_channels=out_channels,
+ out_channels=exp_channels,
+ kernel_size=1
+ )
+
+ self.classifier = nn.Sequential()
+ self.classifier.add_module(name="global_pool", module=nn.AdaptiveAvgPool2d(1))
+ self.classifier.add_module(name="flatten", module=nn.Flatten())
+ if 0.0 < model_cfg["cls_dropout"] < 1.0:
+ self.classifier.add_module(name="dropout", module=nn.Dropout(p=model_cfg["cls_dropout"]))
+ self.classifier.add_module(name="fc", module=nn.Linear(in_features=exp_channels, out_features=num_classes))
+
+ # weight init
+ self.apply(self.init_parameters)
+
+ def _make_layer(self, input_channel, cfg: Dict) -> Tuple[nn.Sequential, int]:
+ block_type = cfg.get("block_type", "mobilevit")
+ if block_type.lower() == "mobilevit":
+ return self._make_mit_layer(input_channel=input_channel, cfg=cfg)
+ else:
+ return self._make_mobilenet_layer(input_channel=input_channel, cfg=cfg)
+
+ @staticmethod
+ def _make_mobilenet_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
+ output_channels = cfg.get("out_channels")
+ num_blocks = cfg.get("num_blocks", 2)
+ expand_ratio = cfg.get("expand_ratio", 4)
+ block = []
+
+ for i in range(num_blocks):
+ stride = cfg.get("stride", 1) if i == 0 else 1
+
+ layer = InvertedResidual(
+ in_channels=input_channel,
+ out_channels=output_channels,
+ stride=stride,
+ expand_ratio=expand_ratio
+ )
+ block.append(layer)
+ input_channel = output_channels
+
+ return nn.Sequential(*block), input_channel
+
+ @staticmethod
+ def _make_mit_layer(input_channel: int, cfg: Dict) -> Tuple[nn.Sequential, int]:
+ stride = cfg.get("stride", 1)
+ block = []
+
+ if stride == 2:
+ layer = InvertedResidual(
+ in_channels=input_channel,
+ out_channels=cfg.get("out_channels"),
+ stride=stride,
+ expand_ratio=cfg.get("mv_expand_ratio", 4)
+ )
+
+ block.append(layer)
+ input_channel = cfg.get("out_channels")
+
+ transformer_dim = cfg["transformer_channels"]
+ ffn_dim = cfg.get("ffn_dim")
+ num_heads = cfg.get("num_heads", 4)
+ head_dim = transformer_dim // num_heads
+
+ if transformer_dim % head_dim != 0:
+ raise ValueError("Transformer input dimension should be divisible by head dimension. "
+ "Got {} and {}.".format(transformer_dim, head_dim))
+
+ block.append(MobileViTBlock(
+ in_channels=input_channel,
+ transformer_dim=transformer_dim,
+ ffn_dim=ffn_dim,
+ n_transformer_blocks=cfg.get("transformer_blocks", 1),
+ patch_h=cfg.get("patch_h", 2),
+ patch_w=cfg.get("patch_w", 2),
+ dropout=cfg.get("dropout", 0.1),
+ ffn_dropout=cfg.get("ffn_dropout", 0.0),
+ attn_dropout=cfg.get("attn_dropout", 0.1),
+ head_dim=head_dim,
+ conv_ksize=3
+ ))
+
+ return nn.Sequential(*block), input_channel
+
+ @staticmethod
+ def init_parameters(m):
+ if isinstance(m, nn.Conv2d):
+ if m.weight is not None:
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
+ if m.weight is not None:
+ nn.init.ones_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, (nn.Linear,)):
+ if m.weight is not None:
+ nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ else:
+ pass
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.conv_1(x)
+ x = self.layer_1(x)
+ x = self.layer_2(x)
+
+ x = self.layer_3(x)
+ x = self.layer_4(x)
+ x = self.layer_5(x)
+ x = self.conv_1x1_exp(x)
+ x = self.classifier(x)
+ return x
+
+
+def mobile_vit_xx_small(num_classes: int = 1000):
+ # pretrain weight link
+ # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xxs.pt
+ config = get_config("xx_small")
+ m = MobileViT(config, num_classes=num_classes)
+ return m
+
+
+def mobile_vit_x_small(num_classes: int = 1000):
+ # pretrain weight link
+ # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_xs.pt
+ config = get_config("x_small")
+ m = MobileViT(config, num_classes=num_classes)
+ return m
+
+
+def mobile_vit_small(num_classes: int = 1000):
+ # pretrain weight link
+ # https://docs-assets.developer.apple.com/ml-research/models/cvnets/classification/mobilevit_s.pt
+ config = get_config("small")
+ m = MobileViT(config, num_classes=num_classes)
+ return m
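+
+
+if __name__ == '__main__':
+    # Minimal smoke test (illustrative): for a 224x224 input the network downsamples
+    # by a factor of 32 and the classifier returns one logit per class.
+    net = mobile_vit_xx_small(num_classes=5)
+    out = net(torch.randn(1, 3, 224, 224))
+    print(out.shape)  # expected: torch.Size([1, 5])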
diff --git a/pytorch_classification/MobileViT/model_config.py b/pytorch_classification/MobileViT/model_config.py
new file mode 100644
index 000000000..932a0a0f0
--- /dev/null
+++ b/pytorch_classification/MobileViT/model_config.py
@@ -0,0 +1,176 @@
+def get_config(mode: str = "xx_small") -> dict:
+ if mode == "xx_small":
+ mv2_exp_mult = 2
+ config = {
+ "layer1": {
+ "out_channels": 16,
+ "expand_ratio": mv2_exp_mult,
+ "num_blocks": 1,
+ "stride": 1,
+ "block_type": "mv2",
+ },
+ "layer2": {
+ "out_channels": 24,
+ "expand_ratio": mv2_exp_mult,
+ "num_blocks": 3,
+ "stride": 2,
+ "block_type": "mv2",
+ },
+ "layer3": { # 28x28
+ "out_channels": 48,
+ "transformer_channels": 64,
+ "ffn_dim": 128,
+ "transformer_blocks": 2,
+ "patch_h": 2, # 8,
+ "patch_w": 2, # 8,
+ "stride": 2,
+ "mv_expand_ratio": mv2_exp_mult,
+ "num_heads": 4,
+ "block_type": "mobilevit",
+ },
+ "layer4": { # 14x14
+ "out_channels": 64,
+ "transformer_channels": 80,
+ "ffn_dim": 160,
+ "transformer_blocks": 4,
+ "patch_h": 2, # 4,
+ "patch_w": 2, # 4,
+ "stride": 2,
+ "mv_expand_ratio": mv2_exp_mult,
+ "num_heads": 4,
+ "block_type": "mobilevit",
+ },
+ "layer5": { # 7x7
+ "out_channels": 80,
+ "transformer_channels": 96,
+ "ffn_dim": 192,
+ "transformer_blocks": 3,
+ "patch_h": 2,
+ "patch_w": 2,
+ "stride": 2,
+ "mv_expand_ratio": mv2_exp_mult,
+ "num_heads": 4,
+ "block_type": "mobilevit",
+ },
+ "last_layer_exp_factor": 4,
+ "cls_dropout": 0.1
+ }
+ elif mode == "x_small":
+ mv2_exp_mult = 4
+ config = {
+ "layer1": {
+ "out_channels": 32,
+ "expand_ratio": mv2_exp_mult,
+ "num_blocks": 1,
+ "stride": 1,
+ "block_type": "mv2",
+ },
+ "layer2": {
+ "out_channels": 48,
+ "expand_ratio": mv2_exp_mult,
+ "num_blocks": 3,
+ "stride": 2,
+ "block_type": "mv2",
+ },
+ "layer3": { # 28x28
+ "out_channels": 64,
+ "transformer_channels": 96,
+ "ffn_dim": 192,
+ "transformer_blocks": 2,
+ "patch_h": 2,
+ "patch_w": 2,
+ "stride": 2,
+ "mv_expand_ratio": mv2_exp_mult,
+ "num_heads": 4,
+ "block_type": "mobilevit",
+ },
+ "layer4": { # 14x14
+ "out_channels": 80,
+ "transformer_channels": 120,
+ "ffn_dim": 240,
+ "transformer_blocks": 4,
+ "patch_h": 2,
+ "patch_w": 2,
+ "stride": 2,
+ "mv_expand_ratio": mv2_exp_mult,
+ "num_heads": 4,
+ "block_type": "mobilevit",
+ },
+ "layer5": { # 7x7
+ "out_channels": 96,
+ "transformer_channels": 144,
+ "ffn_dim": 288,
+ "transformer_blocks": 3,
+ "patch_h": 2,
+ "patch_w": 2,
+ "stride": 2,
+ "mv_expand_ratio": mv2_exp_mult,
+ "num_heads": 4,
+ "block_type": "mobilevit",
+ },
+ "last_layer_exp_factor": 4,
+ "cls_dropout": 0.1
+ }
+ elif mode == "small":
+ mv2_exp_mult = 4
+ config = {
+ "layer1": {
+ "out_channels": 32,
+ "expand_ratio": mv2_exp_mult,
+ "num_blocks": 1,
+ "stride": 1,
+ "block_type": "mv2",
+ },
+ "layer2": {
+ "out_channels": 64,
+ "expand_ratio": mv2_exp_mult,
+ "num_blocks": 3,
+ "stride": 2,
+ "block_type": "mv2",
+ },
+ "layer3": { # 28x28
+ "out_channels": 96,
+ "transformer_channels": 144,
+ "ffn_dim": 288,
+ "transformer_blocks": 2,
+ "patch_h": 2,
+ "patch_w": 2,
+ "stride": 2,
+ "mv_expand_ratio": mv2_exp_mult,
+ "num_heads": 4,
+ "block_type": "mobilevit",
+ },
+ "layer4": { # 14x14
+ "out_channels": 128,
+ "transformer_channels": 192,
+ "ffn_dim": 384,
+ "transformer_blocks": 4,
+ "patch_h": 2,
+ "patch_w": 2,
+ "stride": 2,
+ "mv_expand_ratio": mv2_exp_mult,
+ "num_heads": 4,
+ "block_type": "mobilevit",
+ },
+ "layer5": { # 7x7
+ "out_channels": 160,
+ "transformer_channels": 240,
+ "ffn_dim": 480,
+ "transformer_blocks": 3,
+ "patch_h": 2,
+ "patch_w": 2,
+ "stride": 2,
+ "mv_expand_ratio": mv2_exp_mult,
+ "num_heads": 4,
+ "block_type": "mobilevit",
+ },
+ "last_layer_exp_factor": 4,
+ "cls_dropout": 0.1
+ }
+ else:
+ raise NotImplementedError
+
+ for k in ["layer1", "layer2", "layer3", "layer4", "layer5"]:
+ config[k].update({"dropout": 0.1, "ffn_dropout": 0.0, "attn_dropout": 0.0})
+
+ return config
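+
+
+# Usage sketch (assumes model.py from this folder):
+#   from model import MobileViT
+#   net = MobileViT(get_config("x_small"), num_classes=5)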
diff --git a/pytorch_classification/MobileViT/my_dataset.py b/pytorch_classification/MobileViT/my_dataset.py
new file mode 100644
index 000000000..167bc9a30
--- /dev/null
+++ b/pytorch_classification/MobileViT/my_dataset.py
@@ -0,0 +1,37 @@
+from PIL import Image
+import torch
+from torch.utils.data import Dataset
+
+
+class MyDataSet(Dataset):
+ """自定义数据集"""
+
+ def __init__(self, images_path: list, images_class: list, transform=None):
+ self.images_path = images_path
+ self.images_class = images_class
+ self.transform = transform
+
+ def __len__(self):
+ return len(self.images_path)
+
+ def __getitem__(self, item):
+ img = Image.open(self.images_path[item])
+ # mode 'RGB' is a color image, 'L' a grayscale image
+ if img.mode != 'RGB':
+ raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
+ label = self.images_class[item]
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ return img, label
+
+ @staticmethod
+ def collate_fn(batch):
+ # the official default_collate implementation can be found at:
+ # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
+ images, labels = tuple(zip(*batch))
+
+ images = torch.stack(images, dim=0)
+ labels = torch.as_tensor(labels)
+ return images, labels
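+
+
+# Usage sketch (images_path / images_class / data_transform are placeholders,
+# e.g. produced by read_split_data in utils.py):
+#   dataset = MyDataSet(images_path, images_class, transform=data_transform)
+#   loader = torch.utils.data.DataLoader(dataset, batch_size=8,
+#                                        collate_fn=MyDataSet.collate_fn)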
diff --git a/pytorch_classification/MobileViT/predict.py b/pytorch_classification/MobileViT/predict.py
new file mode 100644
index 000000000..525260912
--- /dev/null
+++ b/pytorch_classification/MobileViT/predict.py
@@ -0,0 +1,61 @@
+import os
+import json
+
+import torch
+from PIL import Image
+from torchvision import transforms
+import matplotlib.pyplot as plt
+
+from model import mobile_vit_xx_small as create_model
+
+
+def main():
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ img_size = 224
+ data_transform = transforms.Compose(
+ [transforms.Resize(int(img_size * 1.14)),
+ transforms.CenterCrop(img_size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+
+ # load image
+ img_path = "../tulip.jpg"
+ assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
+ img = Image.open(img_path)
+ plt.imshow(img)
+ # [N, C, H, W]
+ img = data_transform(img)
+ # expand batch dimension
+ img = torch.unsqueeze(img, dim=0)
+
+ # read class_indict
+ json_path = './class_indices.json'
+ assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
+
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
+
+ # create model
+ model = create_model(num_classes=5).to(device)
+ # load model weights
+ model_weight_path = "./weights/best_model.pth"
+ model.load_state_dict(torch.load(model_weight_path, map_location=device))
+ model.eval()
+ with torch.no_grad():
+ # predict class
+ output = torch.squeeze(model(img.to(device))).cpu()
+ predict = torch.softmax(output, dim=0)
+ predict_cla = torch.argmax(predict).numpy()
+
+ print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
+ predict[predict_cla].numpy())
+ plt.title(print_res)
+ for i in range(len(predict)):
+ print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
+ predict[i].numpy()))
+ plt.show()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/pytorch_classification/MobileViT/train.py b/pytorch_classification/MobileViT/train.py
new file mode 100644
index 000000000..edb26ee98
--- /dev/null
+++ b/pytorch_classification/MobileViT/train.py
@@ -0,0 +1,135 @@
+import os
+import argparse
+
+import torch
+import torch.optim as optim
+from torch.utils.tensorboard import SummaryWriter
+from torchvision import transforms
+
+from my_dataset import MyDataSet
+from model import mobile_vit_xx_small as create_model
+from utils import read_split_data, train_one_epoch, evaluate
+
+
+def main(args):
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+
+ if os.path.exists("./weights") is False:
+ os.makedirs("./weights")
+
+ tb_writer = SummaryWriter()
+
+ train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)
+
+ img_size = 224
+ data_transform = {
+ "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
+ "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
+ transforms.CenterCrop(img_size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
+
+ # instantiate the training dataset
+ train_dataset = MyDataSet(images_path=train_images_path,
+ images_class=train_images_label,
+ transform=data_transform["train"])
+
+ # instantiate the validation dataset
+ val_dataset = MyDataSet(images_path=val_images_path,
+ images_class=val_images_label,
+ transform=data_transform["val"])
+
+ batch_size = args.batch_size
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
+ print('Using {} dataloader workers per process'.format(nw))
+ train_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+
+ val_loader = torch.utils.data.DataLoader(val_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=val_dataset.collate_fn)
+
+ model = create_model(num_classes=args.num_classes).to(device)
+
+ if args.weights != "":
+ assert os.path.exists(args.weights), "weights file: '{}' does not exist.".format(args.weights)
+ weights_dict = torch.load(args.weights, map_location=device)
+ weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
+ # remove the classifier weights (the number of classes may differ)
+ for k in list(weights_dict.keys()):
+ if "classifier" in k:
+ del weights_dict[k]
+ print(model.load_state_dict(weights_dict, strict=False))
+
+ if args.freeze_layers:
+ for name, para in model.named_parameters():
+ # freeze all weights except the classifier head
+ if "classifier" not in name:
+ para.requires_grad_(False)
+ else:
+ print("training {}".format(name))
+
+ pg = [p for p in model.parameters() if p.requires_grad]
+ optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=1E-2)
+
+ best_acc = 0.
+ for epoch in range(args.epochs):
+ # train
+ train_loss, train_acc = train_one_epoch(model=model,
+ optimizer=optimizer,
+ data_loader=train_loader,
+ device=device,
+ epoch=epoch)
+
+ # validate
+ val_loss, val_acc = evaluate(model=model,
+ data_loader=val_loader,
+ device=device,
+ epoch=epoch)
+
+ tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
+ tb_writer.add_scalar(tags[0], train_loss, epoch)
+ tb_writer.add_scalar(tags[1], train_acc, epoch)
+ tb_writer.add_scalar(tags[2], val_loss, epoch)
+ tb_writer.add_scalar(tags[3], val_acc, epoch)
+ tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
+
+ if val_acc > best_acc:
+ best_acc = val_acc
+ torch.save(model.state_dict(), "./weights/best_model.pth")
+
+ torch.save(model.state_dict(), "./weights/latest_model.pth")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_classes', type=int, default=5)
+ parser.add_argument('--epochs', type=int, default=10)
+ parser.add_argument('--batch-size', type=int, default=8)
+ parser.add_argument('--lr', type=float, default=0.0002)
+
+ # root directory of the dataset
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
+ parser.add_argument('--data-path', type=str,
+ default="/data/flower_photos")
+
+ # path to the pre-trained weights; set it to an empty string to train from scratch
+ parser.add_argument('--weights', type=str, default='./mobilevit_xxs.pt',
+ help='initial weights path')
+ # whether to freeze the backbone weights
+ parser.add_argument('--freeze-layers', type=bool, default=False)
+ parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
+
+ opt = parser.parse_args()
+
+ main(opt)
diff --git a/pytorch_classification/MobileViT/transformer.py b/pytorch_classification/MobileViT/transformer.py
new file mode 100644
index 000000000..1124820df
--- /dev/null
+++ b/pytorch_classification/MobileViT/transformer.py
@@ -0,0 +1,155 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+
+class MultiHeadAttention(nn.Module):
+ """
+ This layer applies a multi-head self- or cross-attention as described in
+ `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper
+
+ Args:
+ embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
+ num_heads (int): Number of heads in multi-head attention
+ attn_dropout (float): Attention dropout. Default: 0.0
+ bias (bool): Use bias or not. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
+ and :math:`C_{in}` is input embedding dim
+ - Output: same shape as the input
+
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ attn_dropout: float = 0.0,
+ bias: bool = True,
+ *args,
+ **kwargs
+ ) -> None:
+ super().__init__()
+ if embed_dim % num_heads != 0:
+ raise ValueError(
+ "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
+ self.__class__.__name__, embed_dim, num_heads
+ )
+ )
+
+ self.qkv_proj = nn.Linear(in_features=embed_dim, out_features=3 * embed_dim, bias=bias)
+
+ self.attn_dropout = nn.Dropout(p=attn_dropout)
+ self.out_proj = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=bias)
+
+ self.head_dim = embed_dim // num_heads
+ self.scaling = self.head_dim ** -0.5
+ self.softmax = nn.Softmax(dim=-1)
+ self.num_heads = num_heads
+ self.embed_dim = embed_dim
+
+ def forward(self, x_q: Tensor) -> Tensor:
+ # [N, P, C]
+ b_sz, n_patches, in_channels = x_q.shape
+
+ # self-attention
+ # [N, P, C] -> [N, P, 3C] -> [N, P, 3, h, c] where C = hc
+ qkv = self.qkv_proj(x_q).reshape(b_sz, n_patches, 3, self.num_heads, -1)
+
+ # [N, P, 3, h, c] -> [N, h, 3, P, C]
+ qkv = qkv.transpose(1, 3).contiguous()
+
+ # [N, h, 3, P, C] -> [N, h, P, C] x 3
+ query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
+
+ query = query * self.scaling
+
+ # [N, h, P, c] -> [N, h, c, P]
+ key = key.transpose(-1, -2)
+
+ # QK^T
+ # [N, h, P, c] x [N, h, c, P] -> [N, h, P, P]
+ attn = torch.matmul(query, key)
+ attn = self.softmax(attn)
+ attn = self.attn_dropout(attn)
+
+ # weighted sum
+ # [N, h, P, P] x [N, h, P, c] -> [N, h, P, c]
+ out = torch.matmul(attn, value)
+
+ # [N, h, P, c] -> [N, P, h, c] -> [N, P, C]
+ out = out.transpose(1, 2).reshape(b_sz, n_patches, -1)
+ out = self.out_proj(out)
+
+ return out
+
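+# Shape sketch: MultiHeadAttention(embed_dim=64, num_heads=4) splits the 64 channels
+# into 4 heads of head_dim=16 (scaling = 16 ** -0.5) and maps a (B, N, 64) input
+# back to (B, N, 64).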
+
+class TransformerEncoder(nn.Module):
+ """
+ This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_
+ Args:
+ embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
+ ffn_latent_dim (int): Inner dimension of the FFN
+ num_heads (int) : Number of heads in multi-head attention. Default: 8
+ attn_dropout (float): Dropout rate for attention in multi-head attention. Default: 0.0
+ dropout (float): Dropout rate. Default: 0.0
+ ffn_dropout (float): Dropout between FFN layers. Default: 0.0
+
+ Shape:
+ - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
+ and :math:`C_{in}` is input embedding dim
+ - Output: same shape as the input
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ ffn_latent_dim: int,
+ num_heads: Optional[int] = 8,
+ attn_dropout: Optional[float] = 0.0,
+ dropout: Optional[float] = 0.0,
+ ffn_dropout: Optional[float] = 0.0,
+ *args,
+ **kwargs
+ ) -> None:
+
+ super().__init__()
+
+ attn_unit = MultiHeadAttention(
+ embed_dim,
+ num_heads,
+ attn_dropout=attn_dropout,
+ bias=True
+ )
+
+ self.pre_norm_mha = nn.Sequential(
+ nn.LayerNorm(embed_dim),
+ attn_unit,
+ nn.Dropout(p=dropout)
+ )
+
+ self.pre_norm_ffn = nn.Sequential(
+ nn.LayerNorm(embed_dim),
+ nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
+ nn.SiLU(),
+ nn.Dropout(p=ffn_dropout),
+ nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
+ nn.Dropout(p=dropout)
+ )
+ self.embed_dim = embed_dim
+ self.ffn_dim = ffn_latent_dim
+ self.ffn_dropout = ffn_dropout
+ self.std_dropout = dropout
+
+ def forward(self, x: Tensor) -> Tensor:
+ # multi-head attention
+ res = x
+ x = self.pre_norm_mha(x)
+ x = x + res
+
+ # feed forward network
+ x = x + self.pre_norm_ffn(x)
+ return x
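+
+
+# Usage sketch: MobileViTBlock.global_rep in model.py stacks several of these blocks, e.g.
+#   enc = TransformerEncoder(embed_dim=96, ffn_latent_dim=192, num_heads=4)
+#   out = enc(torch.randn(8, 256, 96))   # -> shape (8, 256, 96)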
diff --git a/pytorch_classification/MobileViT/unfold_test.py b/pytorch_classification/MobileViT/unfold_test.py
new file mode 100644
index 000000000..6370a4b7d
--- /dev/null
+++ b/pytorch_classification/MobileViT/unfold_test.py
@@ -0,0 +1,56 @@
+import time
+import torch
+
+batch_size = 8
+in_channels = 32
+patch_h = 2
+patch_w = 2
+num_patch_h = 16
+num_patch_w = 16
+num_patches = num_patch_h * num_patch_w
+patch_area = patch_h * patch_w
+
+
+def official(x: torch.Tensor):
+ # [B, C, H, W] -> [B * C * n_h, p_h, n_w, p_w]
+ x = x.reshape(batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w)
+ # [B * C * n_h, p_h, n_w, p_w] -> [B * C * n_h, n_w, p_h, p_w]
+ x = x.transpose(1, 2)
+ # [B * C * n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
+ x = x.reshape(batch_size, in_channels, num_patches, patch_area)
+ # [B, C, N, P] -> [B, P, N, C]
+ x = x.transpose(1, 3)
+ # [B, P, N, C] -> [BP, N, C]
+ x = x.reshape(batch_size * patch_area, num_patches, -1)
+
+ return x
+
+
+def my_self(x: torch.Tensor):
+ # [B, C, H, W] -> [B, C, n_h, p_h, n_w, p_w]
+ x = x.reshape(batch_size, in_channels, num_patch_h, patch_h, num_patch_w, patch_w)
+ # [B, C, n_h, p_h, n_w, p_w] -> [B, C, n_h, n_w, p_h, p_w]
+ x = x.transpose(3, 4)
+ # [B, C, n_h, n_w, p_h, p_w] -> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
+ x = x.reshape(batch_size, in_channels, num_patches, patch_area)
+ # [B, C, N, P] -> [B, P, N, C]
+ x = x.transpose(1, 3)
+ # [B, P, N, C] -> [BP, N, C]
+ x = x.reshape(batch_size * patch_area, num_patches, -1)
+
+ return x
+
+
+if __name__ == '__main__':
+ t = torch.randn(batch_size, in_channels, num_patch_h * patch_h, num_patch_w * patch_w)
+ print(torch.equal(official(t), my_self(t)))
+
+ t1 = time.time()
+ for _ in range(1000):
+ official(t)
+ print(f"official time: {time.time() - t1}")
+
+ t1 = time.time()
+ for _ in range(1000):
+ my_self(t)
+ print(f"self time: {time.time() - t1}")
diff --git a/pytorch_classification/MobileViT/utils.py b/pytorch_classification/MobileViT/utils.py
new file mode 100644
index 000000000..da201e6eb
--- /dev/null
+++ b/pytorch_classification/MobileViT/utils.py
@@ -0,0 +1,179 @@
+import os
+import sys
+import json
+import pickle
+import random
+
+import torch
+from tqdm import tqdm
+
+import matplotlib.pyplot as plt
+
+
+def read_split_data(root: str, val_rate: float = 0.2):
+ random.seed(0)  # make the random split reproducible
+ assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
+
+ # traverse the folders; each folder corresponds to one class
+ flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
+ # sort to keep the order consistent across platforms
+ flower_class.sort()
+ # map each class name to a numeric index
+ class_indices = dict((k, v) for v, k in enumerate(flower_class))
+ json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
+ with open('class_indices.json', 'w') as json_file:
+ json_file.write(json_str)
+
+ train_images_path = []  # paths of all training images
+ train_images_label = []  # class index of each training image
+ val_images_path = []  # paths of all validation images
+ val_images_label = []  # class index of each validation image
+ every_class_num = []  # number of samples per class
+ supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
+ # iterate over the files in each class folder
+ for cla in flower_class:
+ cla_path = os.path.join(root, cla)
+ # collect the paths of all files with a supported extension
+ images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
+ if os.path.splitext(i)[-1] in supported]
+ # sort to keep the order consistent across platforms
+ images.sort()
+ # index corresponding to this class
+ image_class = class_indices[cla]
+ # record the number of samples of this class
+ every_class_num.append(len(images))
+ # randomly sample validation images according to val_rate
+ val_path = random.sample(images, k=int(len(images) * val_rate))
+
+ for img_path in images:
+ if img_path in val_path:  # sampled paths go to the validation set
+ val_images_path.append(img_path)
+ val_images_label.append(image_class)
+ else:  # all remaining paths go to the training set
+ train_images_path.append(img_path)
+ train_images_label.append(image_class)
+
+ print("{} images were found in the dataset.".format(sum(every_class_num)))
+ print("{} images for training.".format(len(train_images_path)))
+ print("{} images for validation.".format(len(val_images_path)))
+ assert len(train_images_path) > 0, "number of training images must be greater than 0."
+ assert len(val_images_path) > 0, "number of validation images must be greater than 0."
+
+ plot_image = False
+ if plot_image:
+ # bar chart with the number of images per class
+ plt.bar(range(len(flower_class)), every_class_num, align='center')
+ # replace the x ticks 0,1,2,3,4 with the class names
+ plt.xticks(range(len(flower_class)), flower_class)
+ # add the count above each bar
+ for i, v in enumerate(every_class_num):
+ plt.text(x=i, y=v + 5, s=str(v), ha='center')
+ # x-axis label
+ plt.xlabel('image class')
+ # y-axis label
+ plt.ylabel('number of images')
+ # chart title
+ plt.title('flower class distribution')
+ plt.show()
+
+ return train_images_path, train_images_label, val_images_path, val_images_label
+
+
+def plot_data_loader_image(data_loader):
+ batch_size = data_loader.batch_size
+ plot_num = min(batch_size, 4)
+
+ json_path = './class_indices.json'
+ assert os.path.exists(json_path), json_path + " does not exist."
+ json_file = open(json_path, 'r')
+ class_indices = json.load(json_file)
+
+ for data in data_loader:
+ images, labels = data
+ for i in range(plot_num):
+ # [C, H, W] -> [H, W, C]
+ img = images[i].numpy().transpose(1, 2, 0)
+ # undo the Normalize transform
+ img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
+ label = labels[i].item()
+ plt.subplot(1, plot_num, i+1)
+ plt.xlabel(class_indices[str(label)])
+ plt.xticks([])  # remove x-axis ticks
+ plt.yticks([])  # remove y-axis ticks
+ plt.imshow(img.astype('uint8'))
+ plt.show()
+
+
+def write_pickle(list_info: list, file_name: str):
+ with open(file_name, 'wb') as f:
+ pickle.dump(list_info, f)
+
+
+def read_pickle(file_name: str) -> list:
+ with open(file_name, 'rb') as f:
+ info_list = pickle.load(f)
+ return info_list
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch):
+ model.train()
+ loss_function = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
+ accu_loss = torch.zeros(1).to(device)  # accumulated loss
+ accu_num = torch.zeros(1).to(device)  # accumulated number of correct predictions
+ optimizer.zero_grad()
+
+ sample_num = 0
+ data_loader = tqdm(data_loader, file=sys.stdout)
+ for step, data in enumerate(data_loader):
+ images, labels = data
+ sample_num += images.shape[0]
+
+ pred = model(images.to(device))
+ pred_classes = torch.max(pred, dim=1)[1]
+ accu_num += torch.eq(pred_classes, labels.to(device)).sum()
+
+ loss = loss_function(pred, labels.to(device))
+ loss.backward()
+ accu_loss += loss.detach()
+
+ data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
+ accu_loss.item() / (step + 1),
+ accu_num.item() / sample_num)
+
+ if not torch.isfinite(loss):
+ print('WARNING: non-finite loss, ending training ', loss)
+ sys.exit(1)
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ return accu_loss.item() / (step + 1), accu_num.item() / sample_num
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, epoch):
+ loss_function = torch.nn.CrossEntropyLoss()
+
+ model.eval()
+
+ accu_num = torch.zeros(1).to(device)  # accumulated number of correct predictions
+ accu_loss = torch.zeros(1).to(device)  # accumulated loss
+
+ sample_num = 0
+ data_loader = tqdm(data_loader, file=sys.stdout)
+ for step, data in enumerate(data_loader):
+ images, labels = data
+ sample_num += images.shape[0]
+
+ pred = model(images.to(device))
+ pred_classes = torch.max(pred, dim=1)[1]
+ accu_num += torch.eq(pred_classes, labels.to(device)).sum()
+
+ loss = loss_function(pred, labels.to(device))
+ accu_loss += loss
+
+ data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
+ accu_loss.item() / (step + 1),
+ accu_num.item() / sample_num)
+
+ return accu_loss.item() / (step + 1), accu_num.item() / sample_num
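+
+
+# Typical epoch loop (sketch; train.py wires this up together with TensorBoard logging):
+#   for epoch in range(epochs):
+#       train_loss, train_acc = train_one_epoch(model, optimizer, train_loader, device, epoch)
+#       val_loss, val_acc = evaluate(model, val_loader, device, epoch)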
diff --git a/pytorch_classification/Test10_regnet/README.md b/pytorch_classification/Test10_regnet/README.md
new file mode 100644
index 000000000..4b41177f8
--- /dev/null
+++ b/pytorch_classification/Test10_regnet/README.md
@@ -0,0 +1,12 @@
+## How to use this code
+
+1. Download the dataset. The code uses the flower classification dataset by default, download link: [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz);
+if that link does not work, it is also available on Baidu Cloud: https://pan.baidu.com/s/1QLCTA4sXnQAw_yvxPj9szg access code: 58p0
+2. In `train.py`, set `--data-path` to the absolute path of the extracted `flower_photos` folder
+3. Download the pre-trained weights matching the model you use: https://pan.baidu.com/s/1XTo3walj9ai7ZhWz7jh-YA access code: 8lmu
+4. In `train.py`, set the `--weights` argument to the path of the downloaded pre-trained weights
+5. With `--data-path` and `--weights` set, you can start training with `train.py` (a `class_indices.json` file is generated automatically during training)
+6. In `predict.py`, import the same model as in the training script and set `model_weight_path` to the path of your trained weights (saved in the `weights` folder by default)
+7. In `predict.py`, set `img_path` to the absolute path of the image you want to predict
+8. With `model_weight_path` and `img_path` set, you can run `predict.py` to make predictions
+9. To use your own dataset, arrange it with the same folder structure as the flower dataset (one folder per class) and set `num_classes` in the training and prediction scripts to the number of classes in your data
diff --git a/pytorch_classification/Test10_regnet/predict.py b/pytorch_classification/Test10_regnet/predict.py
index d0f9b21b2..32df3cb2a 100644
--- a/pytorch_classification/Test10_regnet/predict.py
+++ b/pytorch_classification/Test10_regnet/predict.py
@@ -32,8 +32,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = create_regnet(model_name="RegNetY_400MF", num_classes=5).to(device)
diff --git a/pytorch_classification/Test10_regnet/train.py b/pytorch_classification/Test10_regnet/train.py
index 1a95cf567..19ce89940 100644
--- a/pytorch_classification/Test10_regnet/train.py
+++ b/pytorch_classification/Test10_regnet/train.py
@@ -123,7 +123,7 @@ def main(args):
parser.add_argument('--lrf', type=float, default=0.01)
# 数据集所在根目录
- # http://download.tensorflow.org/example_images/flower_photos.tgz
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str,
default="/data/flower_photos")
parser.add_argument('--model-name', default='RegNetY_400MF', help='create model name')
diff --git a/pytorch_classification/Test10_regnet/utils.py b/pytorch_classification/Test10_regnet/utils.py
index f4355900b..11f677974 100644
--- a/pytorch_classification/Test10_regnet/utils.py
+++ b/pytorch_classification/Test10_regnet/utils.py
@@ -16,7 +16,7 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历文件夹,一个文件夹对应一个类别
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
- # 排序,保证顺序一致
+ # sort to keep the order consistent across platforms
flower_class.sort()
# 生成类别名称以及对应的数字索引
class_indices = dict((k, v) for v, k in enumerate(flower_class))
@@ -36,6 +36,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历获取supported支持的所有文件路径
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
if os.path.splitext(i)[-1] in supported]
+ # sort to keep the order consistent across platforms
+ images.sort()
# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
@@ -54,6 +56,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
print("{} images were found in the dataset.".format(sum(every_class_num)))
print("{} images for training.".format(len(train_images_path)))
print("{} images for validation.".format(len(val_images_path)))
+ assert len(train_images_path) > 0, "number of training images must be greater than 0."
+ assert len(val_images_path) > 0, "number of validation images must be greater than 0."
plot_image = False
if plot_image:
diff --git a/pytorch_classification/Test11_efficientnetV2/README.md b/pytorch_classification/Test11_efficientnetV2/README.md
new file mode 100644
index 000000000..36fb99997
--- /dev/null
+++ b/pytorch_classification/Test11_efficientnetV2/README.md
@@ -0,0 +1,12 @@
+## How to use this code
+
+1. Download the dataset. The code uses the flower classification dataset by default, download link: [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz);
+if that link does not work, it is also available on Baidu Cloud: https://pan.baidu.com/s/1QLCTA4sXnQAw_yvxPj9szg access code: 58p0
+2. In `train.py`, set `--data-path` to the absolute path of the extracted `flower_photos` folder
+3. Download the pre-trained weights matching the model you use: https://pan.baidu.com/s/1uZX36rvrfEss-JGj4yfzbQ access code: 5gu1
+4. In `train.py`, set the `--weights` argument to the path of the downloaded pre-trained weights
+5. With `--data-path` and `--weights` set, you can start training with `train.py` (a `class_indices.json` file is generated automatically during training)
+6. In `predict.py`, import the same model as in the training script and set `model_weight_path` to the path of your trained weights (saved in the `weights` folder by default)
+7. In `predict.py`, set `img_path` to the absolute path of the image you want to predict
+8. With `model_weight_path` and `img_path` set, you can run `predict.py` to make predictions
+9. To use your own dataset, arrange it with the same folder structure as the flower dataset (one folder per class) and set `num_classes` in the training and prediction scripts to the number of classes in your data
diff --git a/pytorch_classification/Test11_efficientnetV2/predict.py b/pytorch_classification/Test11_efficientnetV2/predict.py
index d803571c6..690ddec6b 100644
--- a/pytorch_classification/Test11_efficientnetV2/predict.py
+++ b/pytorch_classification/Test11_efficientnetV2/predict.py
@@ -37,8 +37,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = create_model(num_classes=5).to(device)
diff --git a/pytorch_classification/Test11_efficientnetV2/train.py b/pytorch_classification/Test11_efficientnetV2/train.py
index 7aaab9b5e..cfe08bff1 100644
--- a/pytorch_classification/Test11_efficientnetV2/train.py
+++ b/pytorch_classification/Test11_efficientnetV2/train.py
@@ -127,7 +127,7 @@ def main(args):
parser.add_argument('--lrf', type=float, default=0.01)
# 数据集所在根目录
- # http://download.tensorflow.org/example_images/flower_photos.tgz
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str,
default="/data/flower_photos")
diff --git a/pytorch_classification/Test11_efficientnetV2/utils.py b/pytorch_classification/Test11_efficientnetV2/utils.py
index 96ad54a4b..23c53a06f 100644
--- a/pytorch_classification/Test11_efficientnetV2/utils.py
+++ b/pytorch_classification/Test11_efficientnetV2/utils.py
@@ -16,7 +16,7 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历文件夹,一个文件夹对应一个类别
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
- # 排序,保证顺序一致
+ # sort to keep the order consistent across platforms
flower_class.sort()
# 生成类别名称以及对应的数字索引
class_indices = dict((k, v) for v, k in enumerate(flower_class))
@@ -36,6 +36,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历获取supported支持的所有文件路径
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
if os.path.splitext(i)[-1] in supported]
+ # sort to keep the order consistent across platforms
+ images.sort()
# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
@@ -54,6 +56,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
print("{} images were found in the dataset.".format(sum(every_class_num)))
print("{} images for training.".format(len(train_images_path)))
print("{} images for validation.".format(len(val_images_path)))
+ assert len(train_images_path) > 0, "number of training images must be greater than 0."
+ assert len(val_images_path) > 0, "number of validation images must be greater than 0."
plot_image = False
if plot_image:
diff --git a/pytorch_classification/Test1_official_demo/predict.py b/pytorch_classification/Test1_official_demo/predict.py
index b1b597c00..c0ecf31f9 100644
--- a/pytorch_classification/Test1_official_demo/predict.py
+++ b/pytorch_classification/Test1_official_demo/predict.py
@@ -23,7 +23,7 @@ def main():
with torch.no_grad():
outputs = net(im)
- predict = torch.max(outputs, dim=1)[1].data.numpy()
+ predict = torch.max(outputs, dim=1)[1].numpy()
print(classes[int(predict)])
diff --git a/pytorch_classification/Test1_official_demo/train.py b/pytorch_classification/Test1_official_demo/train.py
index ae935ce03..fd61ddae2 100644
--- a/pytorch_classification/Test1_official_demo/train.py
+++ b/pytorch_classification/Test1_official_demo/train.py
@@ -25,7 +25,7 @@ def main():
val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,
shuffle=False, num_workers=0)
val_data_iter = iter(val_loader)
- val_image, val_label = val_data_iter.next()
+ val_image, val_label = next(val_data_iter)
# classes = ('plane', 'car', 'bird', 'cat',
# 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
diff --git a/pytorch_classification/Test2_alexnet/predict.py b/pytorch_classification/Test2_alexnet/predict.py
index 3b2fc1d7b..e96329867 100644
--- a/pytorch_classification/Test2_alexnet/predict.py
+++ b/pytorch_classification/Test2_alexnet/predict.py
@@ -32,8 +32,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = AlexNet(num_classes=5).to(device)
diff --git a/pytorch_classification/Test3_vggnet/predict.py b/pytorch_classification/Test3_vggnet/predict.py
index 248d4cbbc..a0375e9b7 100644
--- a/pytorch_classification/Test3_vggnet/predict.py
+++ b/pytorch_classification/Test3_vggnet/predict.py
@@ -31,8 +31,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = vgg(model_name="vgg16", num_classes=5).to(device)
diff --git a/pytorch_classification/Test4_googlenet/model.py b/pytorch_classification/Test4_googlenet/model.py
index 2282c56e9..954de7191 100644
--- a/pytorch_classification/Test4_googlenet/model.py
+++ b/pytorch_classification/Test4_googlenet/model.py
@@ -116,6 +116,8 @@ def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_pr
self.branch3 = nn.Sequential(
BasicConv2d(in_channels, ch5x5red, kernel_size=1),
+ # in the official implementation this branch actually uses a 3x3 kernel, not 5x5; it is left as 5x5 here, see the issue below for details
+ # Please see https://github.com/pytorch/vision/issues/906 for details.
BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2) # 保证输出大小等于输入大小
)
diff --git a/pytorch_classification/Test4_googlenet/predict.py b/pytorch_classification/Test4_googlenet/predict.py
index 11955e308..d91011fc8 100644
--- a/pytorch_classification/Test4_googlenet/predict.py
+++ b/pytorch_classification/Test4_googlenet/predict.py
@@ -31,8 +31,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = GoogLeNet(num_classes=5, aux_logits=False).to(device)
diff --git a/pytorch_classification/Test4_googlenet/train.py b/pytorch_classification/Test4_googlenet/train.py
index 0218478fd..32f8e0c10 100644
--- a/pytorch_classification/Test4_googlenet/train.py
+++ b/pytorch_classification/Test4_googlenet/train.py
@@ -60,8 +60,13 @@ def main():
# test_data_iter = iter(validate_loader)
# test_image, test_label = test_data_iter.next()
+ net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
+ # to use the official pre-trained weights, note that they must be loaded into the official model,
+ # not into our own implementation: the official model uses BN layers and different parameters, so the two cannot be mixed
+ # import torchvision
# net = torchvision.models.googlenet(num_classes=5)
# model_dict = net.state_dict()
+ # # pre-trained weights download url: https://download.pytorch.org/models/googlenet-1378be20.pth
# pretrain_model = torch.load("googlenet.pth")
# del_list = ["aux1.fc2.weight", "aux1.fc2.bias",
# "aux2.fc2.weight", "aux2.fc2.bias",
@@ -69,7 +74,6 @@ def main():
# pretrain_dict = {k: v for k, v in pretrain_model.items() if k not in del_list}
# model_dict.update(pretrain_dict)
# net.load_state_dict(model_dict)
- net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0003)
diff --git a/pytorch_classification/Test5_resnet/predict.py b/pytorch_classification/Test5_resnet/predict.py
index c327741c7..f478b3bfd 100644
--- a/pytorch_classification/Test5_resnet/predict.py
+++ b/pytorch_classification/Test5_resnet/predict.py
@@ -32,8 +32,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = resnet34(num_classes=5).to(device)
diff --git a/pytorch_classification/Test5_resnet/train.py b/pytorch_classification/Test5_resnet/train.py
index 2f8befdd9..310b462ce 100644
--- a/pytorch_classification/Test5_resnet/train.py
+++ b/pytorch_classification/Test5_resnet/train.py
@@ -1,4 +1,5 @@
import os
+import sys
import json
import torch
@@ -62,7 +63,7 @@ def main():
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
- net.load_state_dict(torch.load(model_weight_path, map_location=device))
+ net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
# for param in net.parameters():
# param.requires_grad = False
diff --git a/pytorch_classification/Test6_mobilenet/predict.py b/pytorch_classification/Test6_mobilenet/predict.py
index a8a03ceb9..a0e6df088 100644
--- a/pytorch_classification/Test6_mobilenet/predict.py
+++ b/pytorch_classification/Test6_mobilenet/predict.py
@@ -32,8 +32,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = MobileNetV2(num_classes=5).to(device)
diff --git a/pytorch_classification/Test6_mobilenet/train.py b/pytorch_classification/Test6_mobilenet/train.py
index 594185467..0fe629212 100644
--- a/pytorch_classification/Test6_mobilenet/train.py
+++ b/pytorch_classification/Test6_mobilenet/train.py
@@ -67,7 +67,7 @@ def main():
# download url: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
model_weight_path = "./mobilenet_v2.pth"
assert os.path.exists(model_weight_path), "file {} dose not exist.".format(model_weight_path)
- pre_weights = torch.load(model_weight_path, map_location=device)
+ pre_weights = torch.load(model_weight_path, map_location='cpu')
# delete classifier weights
pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}
diff --git a/pytorch_classification/Test7_shufflenet/README.md b/pytorch_classification/Test7_shufflenet/README.md
new file mode 100644
index 000000000..c93d9df0e
--- /dev/null
+++ b/pytorch_classification/Test7_shufflenet/README.md
@@ -0,0 +1,12 @@
+## How to use this code
+
+1. Download the dataset. The code uses the flower classification dataset by default, download link: [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz);
+if that link does not work, it is also available on Baidu Cloud: https://pan.baidu.com/s/1QLCTA4sXnQAw_yvxPj9szg access code: 58p0
+2. In `train.py`, set `--data-path` to the absolute path of the extracted `flower_photos` folder
+3. Download the pre-trained weights: `model.py` provides a download link for each model, pick the one matching the model you use
+4. In `train.py`, set the `--weights` argument to the path of the downloaded pre-trained weights
+5. With `--data-path` and `--weights` set, you can start training with `train.py` (a `class_indices.json` file is generated automatically during training)
+6. In `predict.py`, import the same model as in the training script and set `model_weight_path` to the path of your trained weights (saved in the `weights` folder by default)
+7. In `predict.py`, set `img_path` to the absolute path of the image you want to predict
+8. With `model_weight_path` and `img_path` set, you can run `predict.py` to make predictions
+9. To use your own dataset, arrange it with the same folder structure as the flower dataset (one folder per class) and set `num_classes` in the training and prediction scripts to the number of classes in your data
diff --git a/pytorch_classification/Test7_shufflenet/model.py b/pytorch_classification/Test7_shufflenet/model.py
index adc1dfa48..dbdb81967 100644
--- a/pytorch_classification/Test7_shufflenet/model.py
+++ b/pytorch_classification/Test7_shufflenet/model.py
@@ -147,6 +147,23 @@ def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
+def shufflenet_v2_x0_5(num_classes=1000):
+ """
+ Constructs a ShuffleNetV2 with 0.5x output channels, as described in
+ `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+ <https://arxiv.org/abs/1807.11164>`_.
+ weight: https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth
+
+ :param num_classes:
+ :return:
+ """
+ model = ShuffleNetV2(stages_repeats=[4, 8, 4],
+ stages_out_channels=[24, 48, 96, 192, 1024],
+ num_classes=num_classes)
+
+ return model
+
+
def shufflenet_v2_x1_0(num_classes=1000):
"""
Constructs a ShuffleNetV2 with 1.0x output channels, as described in
@@ -164,18 +181,35 @@ def shufflenet_v2_x1_0(num_classes=1000):
return model
-def shufflenet_v2_x0_5(num_classes=1000):
+def shufflenet_v2_x1_5(num_classes=1000):
"""
- Constructs a ShuffleNetV2 with 0.5x output channels, as described in
+ Constructs a ShuffleNetV2 with 1.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
`.
- weight: https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth
+ weight: https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth
:param num_classes:
:return:
"""
model = ShuffleNetV2(stages_repeats=[4, 8, 4],
- stages_out_channels=[24, 48, 96, 192, 1024],
+ stages_out_channels=[24, 176, 352, 704, 1024],
+ num_classes=num_classes)
+
+ return model
+
+
+def shufflenet_v2_x2_0(num_classes=1000):
+ """
+ Constructs a ShuffleNetV2 with 2.0x output channels, as described in
+ `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
+ <https://arxiv.org/abs/1807.11164>`_.
+ weight: https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth
+
+ :param num_classes:
+ :return:
+ """
+ model = ShuffleNetV2(stages_repeats=[4, 8, 4],
+ stages_out_channels=[24, 244, 488, 976, 2048],
num_classes=num_classes)
return model
diff --git a/pytorch_classification/Test7_shufflenet/predict.py b/pytorch_classification/Test7_shufflenet/predict.py
index 2d62e6eac..8845b0a42 100644
--- a/pytorch_classification/Test7_shufflenet/predict.py
+++ b/pytorch_classification/Test7_shufflenet/predict.py
@@ -32,8 +32,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = shufflenet_v2_x1_0(num_classes=5).to(device)
diff --git a/pytorch_classification/Test7_shufflenet/train.py b/pytorch_classification/Test7_shufflenet/train.py
index 59e148eb3..1973a72fe 100644
--- a/pytorch_classification/Test7_shufflenet/train.py
+++ b/pytorch_classification/Test7_shufflenet/train.py
@@ -118,7 +118,7 @@ def main(args):
parser.add_argument('--lrf', type=float, default=0.1)
# 数据集所在根目录
- # http://download.tensorflow.org/example_images/flower_photos.tgz
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str,
default="/data/flower_photos")
diff --git a/pytorch_classification/Test7_shufflenet/utils.py b/pytorch_classification/Test7_shufflenet/utils.py
index f4355900b..11f677974 100644
--- a/pytorch_classification/Test7_shufflenet/utils.py
+++ b/pytorch_classification/Test7_shufflenet/utils.py
@@ -16,7 +16,7 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历文件夹,一个文件夹对应一个类别
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
- # 排序,保证顺序一致
+ # sort to keep the order consistent across platforms
flower_class.sort()
# 生成类别名称以及对应的数字索引
class_indices = dict((k, v) for v, k in enumerate(flower_class))
@@ -36,6 +36,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历获取supported支持的所有文件路径
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
if os.path.splitext(i)[-1] in supported]
+ # sort to keep the order consistent across platforms
+ images.sort()
# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
@@ -54,6 +56,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
print("{} images were found in the dataset.".format(sum(every_class_num)))
print("{} images for training.".format(len(train_images_path)))
print("{} images for validation.".format(len(val_images_path)))
+ assert len(train_images_path) > 0, "number of training images must be greater than 0."
+ assert len(val_images_path) > 0, "number of validation images must be greater than 0."
plot_image = False
if plot_image:
diff --git a/pytorch_classification/Test8_densenet/README.md b/pytorch_classification/Test8_densenet/README.md
new file mode 100644
index 000000000..c93d9df0e
--- /dev/null
+++ b/pytorch_classification/Test8_densenet/README.md
@@ -0,0 +1,12 @@
+## Usage
+
+1. Download the dataset. By default the code uses the flower classification dataset, download link: [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz),
+or, if that link does not work, via Baidu Netdisk: https://pan.baidu.com/s/1QLCTA4sXnQAw_yvxPj9szg (extraction code: 58p0)
+2. In `train.py`, set `--data-path` to the absolute path of the extracted `flower_photos` folder
+3. Download the pre-trained weights; `model.py` provides a download link for each model, so download the one matching the model you use
+4. In `train.py`, set the `--weights` argument to the path of the downloaded pre-trained weights
+5. With `--data-path` and `--weights` set, you can start training with `train.py` (a `class_indices.json` file is generated automatically during training)
+6. In `predict.py`, import the same model as in the training script and set `model_weight_path` to the trained weights (saved in the `weights` folder by default)
+7. In `predict.py`, set `img_path` to the absolute path of the image you want to predict
+8. With `model_weight_path` and `img_path` set, you can run prediction with `predict.py`
+9. To use your own dataset, arrange it like the flower dataset (one folder per class) and set `num_classes` in both the training and prediction scripts to the number of classes in your data (a quick layout check is sketched below)
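Step 9 above asks for one folder per class. The sketch below (standard library only; `data_root` is a hypothetical path) lists the detected classes and image counts so the layout can be verified before training:

```python
# Sketch: verify the "one folder per class" layout expected by the training script.
# data_root is a hypothetical path; point it at your own dataset.
import os

data_root = "/data/flower_photos"
classes = [d for d in sorted(os.listdir(data_root))
           if os.path.isdir(os.path.join(data_root, d))]
for cls in classes:
    n_imgs = len(os.listdir(os.path.join(data_root, cls)))
    print(f"{cls}: {n_imgs} images")
print(f"set num_classes to {len(classes)}")
```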
diff --git a/pytorch_classification/Test8_densenet/predict.py b/pytorch_classification/Test8_densenet/predict.py
index aa9d5d9ab..535358bee 100644
--- a/pytorch_classification/Test8_densenet/predict.py
+++ b/pytorch_classification/Test8_densenet/predict.py
@@ -32,8 +32,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = densenet121(num_classes=5).to(device)
diff --git a/pytorch_classification/Test8_densenet/train.py b/pytorch_classification/Test8_densenet/train.py
index 7f628c3d0..07b615dd0 100644
--- a/pytorch_classification/Test8_densenet/train.py
+++ b/pytorch_classification/Test8_densenet/train.py
@@ -115,7 +115,7 @@ def main(args):
parser.add_argument('--lrf', type=float, default=0.1)
# 数据集所在根目录
- # http://download.tensorflow.org/example_images/flower_photos.tgz
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str,
default="/data/flower_photos")
diff --git a/pytorch_classification/Test8_densenet/utils.py b/pytorch_classification/Test8_densenet/utils.py
index f4355900b..11f677974 100644
--- a/pytorch_classification/Test8_densenet/utils.py
+++ b/pytorch_classification/Test8_densenet/utils.py
@@ -16,7 +16,7 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历文件夹,一个文件夹对应一个类别
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
- # 排序,保证顺序一致
+ # 排序,保证各平台顺序一致
flower_class.sort()
# 生成类别名称以及对应的数字索引
class_indices = dict((k, v) for v, k in enumerate(flower_class))
@@ -36,6 +36,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历获取supported支持的所有文件路径
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
if os.path.splitext(i)[-1] in supported]
+ # 排序,保证各平台顺序一致
+ images.sort()
# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
@@ -54,6 +56,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
print("{} images were found in the dataset.".format(sum(every_class_num)))
print("{} images for training.".format(len(train_images_path)))
print("{} images for validation.".format(len(val_images_path)))
+ assert len(train_images_path) > 0, "number of training images must be greater than 0."
+ assert len(val_images_path) > 0, "number of validation images must be greater than 0."
plot_image = False
if plot_image:
diff --git a/pytorch_classification/Test9_efficientNet/README.md b/pytorch_classification/Test9_efficientNet/README.md
new file mode 100644
index 000000000..24fb5021d
--- /dev/null
+++ b/pytorch_classification/Test9_efficientNet/README.md
@@ -0,0 +1,12 @@
+## Usage
+
+1. Download the dataset. By default the code uses the flower classification dataset, download link: [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz),
+or, if that link does not work, via Baidu Netdisk: https://pan.baidu.com/s/1QLCTA4sXnQAw_yvxPj9szg (extraction code: 58p0)
+2. In `train.py`, set `--data-path` to the absolute path of the extracted `flower_photos` folder
+3. Download the pre-trained weights matching the model you use: https://pan.baidu.com/s/1ouX0UmjCsmSx3ZrqXbowjw (extraction code: 090i)
+4. In `train.py`, set the `--weights` argument to the path of the downloaded pre-trained weights
+5. With `--data-path` and `--weights` set, you can start training with `train.py` (a `class_indices.json` file is generated automatically during training)
+6. In `predict.py`, import the same model as in the training script and set `model_weight_path` to the trained weights (saved in the `weights` folder by default)
+7. In `predict.py`, set `img_path` to the absolute path of the image you want to predict
+8. With `model_weight_path` and `img_path` set, you can run prediction with `predict.py`
+9. To use your own dataset, arrange it like the flower dataset (one folder per class) and set `num_classes` in both the training and prediction scripts to the number of classes in your data
diff --git a/pytorch_classification/Test9_efficientNet/predict.py b/pytorch_classification/Test9_efficientNet/predict.py
index 56a278123..22f8e40c8 100644
--- a/pytorch_classification/Test9_efficientNet/predict.py
+++ b/pytorch_classification/Test9_efficientNet/predict.py
@@ -42,8 +42,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = create_model(num_classes=5).to(device)
diff --git a/pytorch_classification/Test9_efficientNet/train.py b/pytorch_classification/Test9_efficientNet/train.py
index 52f07a7fa..e20ec0692 100644
--- a/pytorch_classification/Test9_efficientNet/train.py
+++ b/pytorch_classification/Test9_efficientNet/train.py
@@ -129,7 +129,7 @@ def main(args):
parser.add_argument('--lrf', type=float, default=0.01)
# 数据集所在根目录
- # http://download.tensorflow.org/example_images/flower_photos.tgz
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str,
default="/data/flower_photos")
diff --git a/pytorch_classification/Test9_efficientNet/utils.py b/pytorch_classification/Test9_efficientNet/utils.py
index f4355900b..11f677974 100644
--- a/pytorch_classification/Test9_efficientNet/utils.py
+++ b/pytorch_classification/Test9_efficientNet/utils.py
@@ -16,7 +16,7 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历文件夹,一个文件夹对应一个类别
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
- # 排序,保证顺序一致
+ # 排序,保证各平台顺序一致
flower_class.sort()
# 生成类别名称以及对应的数字索引
class_indices = dict((k, v) for v, k in enumerate(flower_class))
@@ -36,6 +36,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历获取supported支持的所有文件路径
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
if os.path.splitext(i)[-1] in supported]
+ # 排序,保证各平台顺序一致
+ images.sort()
# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
@@ -54,6 +56,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
print("{} images were found in the dataset.".format(sum(every_class_num)))
print("{} images for training.".format(len(train_images_path)))
print("{} images for validation.".format(len(val_images_path)))
+ assert len(train_images_path) > 0, "number of training images must be greater than 0."
+ assert len(val_images_path) > 0, "number of validation images must be greater than 0."
plot_image = False
if plot_image:
diff --git a/pytorch_classification/custom_dataset/main.py b/pytorch_classification/custom_dataset/main.py
index 632756c74..3f987787c 100644
--- a/pytorch_classification/custom_dataset/main.py
+++ b/pytorch_classification/custom_dataset/main.py
@@ -6,7 +6,7 @@
from my_dataset import MyDataSet
from utils import read_split_data, plot_data_loader_image
-# http://download.tensorflow.org/example_images/flower_photos.tgz
+# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
root = "/home/wz/my_github/data_set/flower_data/flower_photos" # 数据集所在根目录
diff --git a/pytorch_classification/grad_cam/README.md b/pytorch_classification/grad_cam/README.md
index f17087ebf..328600e1d 100644
--- a/pytorch_classification/grad_cam/README.md
+++ b/pytorch_classification/grad_cam/README.md
@@ -1 +1,12 @@
-Original Impl: https://github.com/jacobgil/pytorch-grad-cam
+## Grad-CAM
+- Original Impl: [https://github.com/jacobgil/pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam)
+- Grad-CAM introduction (video, in Chinese): [https://b23.tv/1kccjmb](https://b23.tv/1kccjmb)
+- Implementing Grad-CAM with PyTorch and drawing heatmaps (video, in Chinese): [https://b23.tv/n1e60vN](https://b23.tv/n1e60vN)
+
+## Workflow (swapping in your own network)
+1. Replace the model-creation code with your own and load your trained weights
+2. Choose `target_layers` appropriate for your network
+3. Use the preprocessing that matches your network
+4. Assign the path of the image to analyse to `img_path`
+5. Assign the class id of interest to `target_category` (a sketch of the full flow follows below)
+
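A sketch of these five steps wired together is shown below. The backbone, weight path and image path are placeholders, and the `cam(...)` / `show_cam_on_image(...)` call signatures are assumed to match the ones used by `main_cnn.py`; adjust them if your local `utils.py` differs.

```python
# Sketch only: Grad-CAM on a custom classifier. Placeholders are marked; the
# GradCAM / show_cam_on_image signatures are assumed to match main_cnn.py.
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms
from utils import GradCAM, show_cam_on_image, center_crop_img

model = models.mobilenet_v3_large(weights=None)            # step 1: your own model
# model.load_state_dict(torch.load("my_weights.pth", map_location="cpu"))  # your trained weights
target_layers = [model.features[-1]]                       # step 2: a suitable target layer

data_transform = transforms.Compose([                      # step 3: preprocessing used in training
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

img_path = "./my_image.jpg"                                # step 4: hypothetical image path
img = np.array(Image.open(img_path).convert('RGB'), dtype=np.uint8)
img = center_crop_img(img, 224)
input_tensor = torch.unsqueeze(data_transform(img), dim=0)

cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
grayscale_cam = cam(input_tensor=input_tensor, target_category=281)  # step 5: class id of interest
visualization = show_cam_on_image(img.astype(np.float32) / 255., grayscale_cam[0, :], use_rgb=True)
```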
diff --git a/pytorch_classification/grad_cam/main_cnn.py b/pytorch_classification/grad_cam/main_cnn.py
index 6e2cb2476..254f8e767 100644
--- a/pytorch_classification/grad_cam/main_cnn.py
+++ b/pytorch_classification/grad_cam/main_cnn.py
@@ -5,7 +5,7 @@
import matplotlib.pyplot as plt
from torchvision import models
from torchvision import transforms
-from utils import GradCAM, show_cam_on_image
+from utils import GradCAM, show_cam_on_image, center_crop_img
def main():
@@ -31,10 +31,12 @@ def main():
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
img = Image.open(img_path).convert('RGB')
img = np.array(img, dtype=np.uint8)
+ # img = center_crop_img(img, 224)
- # [N, C, H, W]
+ # [C, H, W]
img_tensor = data_transform(img)
# expand batch dimension
+ # [C, H, W] -> [N, C, H, W]
input_tensor = torch.unsqueeze(img_tensor, dim=0)
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
diff --git a/pytorch_classification/grad_cam/main_swin.py b/pytorch_classification/grad_cam/main_swin.py
index 600cea00b..292d2a30c 100644
--- a/pytorch_classification/grad_cam/main_swin.py
+++ b/pytorch_classification/grad_cam/main_swin.py
@@ -1,4 +1,5 @@
import os
+import math
import numpy as np
import torch
from PIL import Image
@@ -9,15 +10,19 @@
class ResizeTransform:
- def __init__(self, height=7, width=7):
- self.height = height
- self.width = width
+ def __init__(self, im_h: int, im_w: int):
+ self.height = self.feature_size(im_h)
+ self.width = self.feature_size(im_w)
+
+ @staticmethod
+ def feature_size(s):
+ s = math.ceil(s / 4) # PatchEmbed
+ s = math.ceil(s / 2) # PatchMerging1
+ s = math.ceil(s / 2) # PatchMerging2
+ s = math.ceil(s / 2) # PatchMerging3
+ return s
def __call__(self, x):
- if isinstance(x, tuple):
- self.height = x[1]
- self.width = x[2]
- x = x[0]
result = x.reshape(x.size(0),
self.height,
self.width,
@@ -25,18 +30,24 @@ def __call__(self, x):
# Bring the channels to the first dimension,
# like in CNNs.
- result = result.transpose(2, 3).transpose(1, 2)
+ # [batch_size, H, W, C] -> [batch, C, H, W]
+ result = result.permute(0, 3, 1, 2)
return result
def main():
+ # Note: the input image size must be a multiple of 32,
+ # otherwise the internal padding makes the attention map drift
+ img_size = 224
+ assert img_size % 32 == 0
+
model = swin_base_patch4_window7_224()
# https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth
weights_path = "./swin_base_patch4_window7_224.pth"
model.load_state_dict(torch.load(weights_path, map_location="cpu")["model"], strict=False)
- target_layers = [model.layers[-2]]
+ target_layers = [model.norm]
data_transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
@@ -45,14 +56,16 @@ def main():
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
img = Image.open(img_path).convert('RGB')
img = np.array(img, dtype=np.uint8)
- img = center_crop_img(img, 224)
+ img = center_crop_img(img, img_size)
- # [N, C, H, W]
+ # [C, H, W]
img_tensor = data_transform(img)
# expand batch dimension
+ # [C, H, W] -> [N, C, H, W]
input_tensor = torch.unsqueeze(img_tensor, dim=0)
- cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False, reshape_transform=ResizeTransform())
+ cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False,
+ reshape_transform=ResizeTransform(im_h=img_size, im_w=img_size))
target_category = 281 # tabby, tabby cat
# target_category = 254 # pug, pug-dog
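The `feature_size` helper above divides the input size by the patch-embedding stride (4) and then by the three patch-merging strides (2 each). The sketch below shows why the added `img_size % 32 == 0` assertion matters: sizes that are not a multiple of 32 get padded inside the model, so the token grid no longer matches the unpadded size that `ResizeTransform` assumes.

```python
# Sketch: token-grid size after PatchEmbed (stride 4) and three PatchMerging
# stages (stride 2 each), as computed by ResizeTransform.feature_size above.
import math

def feature_size(s: int) -> int:
    for stride in (4, 2, 2, 2):
        s = math.ceil(s / stride)
    return s

for size in (224, 256, 250):
    note = "multiple of 32" if size % 32 == 0 else "not a multiple of 32 -> padded inside the model"
    print(f"{size} -> {feature_size(size)} x {feature_size(size)} tokens ({note})")
```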
diff --git a/pytorch_classification/grad_cam/main_vit.py b/pytorch_classification/grad_cam/main_vit.py
index 8fc126b94..44a95c1fa 100644
--- a/pytorch_classification/grad_cam/main_vit.py
+++ b/pytorch_classification/grad_cam/main_vit.py
@@ -17,6 +17,7 @@ def __init__(self, model):
def __call__(self, x):
# remove cls token and reshape
+ # [batch_size, num_tokens, token_dim]
result = x[:, 1:, :].reshape(x.size(0),
self.h,
self.w,
@@ -24,7 +25,8 @@ def __call__(self, x):
# Bring the channels to the first dimension,
# like in CNNs.
- result = result.transpose(2, 3).transpose(1, 2)
+ # [batch_size, H, W, C] -> [batch, C, H, W]
+ result = result.permute(0, 3, 1, 2)
return result
@@ -47,9 +49,10 @@ def main():
img = Image.open(img_path).convert('RGB')
img = np.array(img, dtype=np.uint8)
img = center_crop_img(img, 224)
- # [N, C, H, W]
+ # [C, H, W]
img_tensor = data_transform(img)
# expand batch dimension
+ # [C, H, W] -> [N, C, H, W]
input_tensor = torch.unsqueeze(img_tensor, dim=0)
cam = GradCAM(model=model,
diff --git a/pytorch_classification/grad_cam/utils.py b/pytorch_classification/grad_cam/utils.py
index 005e6c477..acbb0f4da 100644
--- a/pytorch_classification/grad_cam/utils.py
+++ b/pytorch_classification/grad_cam/utils.py
@@ -4,7 +4,7 @@
class ActivationsAndGradients:
""" Class for extracting activations and
- registering gradients from targetted intermediate layers """
+ registering gradients from targeted intermediate layers """
def __init__(self, model, target_layers, reshape_transform):
self.model = model
@@ -16,7 +16,7 @@ def __init__(self, model, target_layers, reshape_transform):
self.handles.append(
target_layer.register_forward_hook(
self.save_activation))
- # Backward compatability with older pytorch versions:
+ # Backward compatibility with older pytorch versions:
if hasattr(target_layer, 'register_full_backward_hook'):
self.handles.append(
target_layer.register_full_backward_hook(
@@ -70,7 +70,7 @@ def __init__(self,
@staticmethod
def get_cam_weights(grads):
- return np.mean(grads, axis=(2, 3))
+ return np.mean(grads, axis=(2, 3), keepdims=True)
@staticmethod
def get_loss(output, target_category):
@@ -81,7 +81,7 @@ def get_loss(output, target_category):
def get_cam_image(self, activations, grads):
weights = self.get_cam_weights(grads)
- weighted_activations = weights[:, :, None, None] * activations
+ weighted_activations = weights * activations
cam = weighted_activations.sum(axis=1)
return cam
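The `keepdims=True` change keeps the pooled gradients at shape `[N, C, 1, 1]`, so `get_cam_image` can multiply them with the activations directly instead of re-inserting the two singleton axes. A quick NumPy sketch of the equivalence:

```python
# Sketch: with keepdims=True the channel weights broadcast directly against
# the activations, giving the same CAM as the old explicit [:, :, None, None].
import numpy as np

grads = np.random.rand(2, 512, 7, 7)        # [N, C, H, W] gradients at the target layer
activations = np.random.rand(2, 512, 7, 7)  # [N, C, H, W] activations at the target layer

weights = np.mean(grads, axis=(2, 3), keepdims=True)   # [N, C, 1, 1]
old_weights = np.mean(grads, axis=(2, 3))              # [N, C]

cam_new = (weights * activations).sum(axis=1)                         # [N, H, W]
cam_old = (old_weights[:, :, None, None] * activations).sum(axis=1)   # [N, H, W]
print(np.allclose(cam_new, cam_old))  # True
```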
diff --git a/pytorch_classification/swin_transformer/README.md b/pytorch_classification/swin_transformer/README.md
new file mode 100644
index 000000000..c93d9df0e
--- /dev/null
+++ b/pytorch_classification/swin_transformer/README.md
@@ -0,0 +1,12 @@
+## Usage
+
+1. Download the dataset. By default the code uses the flower classification dataset, download link: [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz),
+or, if that link does not work, via Baidu Netdisk: https://pan.baidu.com/s/1QLCTA4sXnQAw_yvxPj9szg (extraction code: 58p0)
+2. In `train.py`, set `--data-path` to the absolute path of the extracted `flower_photos` folder
+3. Download the pre-trained weights; `model.py` provides a download link for each model, so download the one matching the model you use
+4. In `train.py`, set the `--weights` argument to the path of the downloaded pre-trained weights
+5. With `--data-path` and `--weights` set, you can start training with `train.py` (a `class_indices.json` file is generated automatically during training)
+6. In `predict.py`, import the same model as in the training script and set `model_weight_path` to the trained weights (saved in the `weights` folder by default)
+7. In `predict.py`, set `img_path` to the absolute path of the image you want to predict
+8. With `model_weight_path` and `img_path` set, you can run prediction with `predict.py`
+9. To use your own dataset, arrange it like the flower dataset (one folder per class) and set `num_classes` in both the training and prediction scripts to the number of classes in your data
diff --git a/pytorch_classification/swin_transformer/predict.py b/pytorch_classification/swin_transformer/predict.py
index 999fde040..26e95c584 100644
--- a/pytorch_classification/swin_transformer/predict.py
+++ b/pytorch_classification/swin_transformer/predict.py
@@ -33,8 +33,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = create_model(num_classes=5).to(device)
diff --git a/pytorch_classification/swin_transformer/train.py b/pytorch_classification/swin_transformer/train.py
index 047545cc9..845d77575 100644
--- a/pytorch_classification/swin_transformer/train.py
+++ b/pytorch_classification/swin_transformer/train.py
@@ -113,7 +113,7 @@ def main(args):
parser.add_argument('--lr', type=float, default=0.0001)
# 数据集所在根目录
- # http://download.tensorflow.org/example_images/flower_photos.tgz
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str,
default="/data/flower_photos")
diff --git a/pytorch_classification/swin_transformer/utils.py b/pytorch_classification/swin_transformer/utils.py
index 96ad54a4b..23c53a06f 100644
--- a/pytorch_classification/swin_transformer/utils.py
+++ b/pytorch_classification/swin_transformer/utils.py
@@ -16,7 +16,7 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历文件夹,一个文件夹对应一个类别
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
- # 排序,保证顺序一致
+ # 排序,保证各平台顺序一致
flower_class.sort()
# 生成类别名称以及对应的数字索引
class_indices = dict((k, v) for v, k in enumerate(flower_class))
@@ -36,6 +36,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历获取supported支持的所有文件路径
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
if os.path.splitext(i)[-1] in supported]
+ # 排序,保证各平台顺序一致
+ images.sort()
# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
@@ -54,6 +56,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
print("{} images were found in the dataset.".format(sum(every_class_num)))
print("{} images for training.".format(len(train_images_path)))
print("{} images for validation.".format(len(val_images_path)))
+ assert len(train_images_path) > 0, "number of training images must be greater than 0."
+ assert len(val_images_path) > 0, "number of validation images must be greater than 0."
plot_image = False
if plot_image:
diff --git a/pytorch_classification/tensorboard_test/requirements.txt b/pytorch_classification/tensorboard_test/requirements.txt
index 15ba10ebe..c42b25958 100644
--- a/pytorch_classification/tensorboard_test/requirements.txt
+++ b/pytorch_classification/tensorboard_test/requirements.txt
@@ -1,6 +1,6 @@
torchvision==0.7.0
tqdm==4.42.1
matplotlib==3.2.1
-torch==1.6.0
+torch==1.13.1
Pillow
tensorboard
diff --git a/pytorch_classification/tensorboard_test/train.py b/pytorch_classification/tensorboard_test/train.py
index a2382e8da..25482b58b 100644
--- a/pytorch_classification/tensorboard_test/train.py
+++ b/pytorch_classification/tensorboard_test/train.py
@@ -150,7 +150,7 @@ def main(args):
parser.add_argument('--lrf', type=float, default=0.1)
# 数据集所在根目录
- # http://download.tensorflow.org/example_images/flower_photos.tgz
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
img_root = "/home/wz/my_project/my_github/data_set/flower_data/flower_photos"
parser.add_argument('--data-path', type=str, default=img_root)
diff --git a/pytorch_classification/train_multi_GPU/requirements.txt b/pytorch_classification/train_multi_GPU/requirements.txt
index f5fd2973e..a31687d0e 100644
--- a/pytorch_classification/train_multi_GPU/requirements.txt
+++ b/pytorch_classification/train_multi_GPU/requirements.txt
@@ -1,4 +1,4 @@
matplotlib==3.2.1
tqdm==4.42.1
torchvision==0.7.0
-torch==1.6.0
+torch==1.13.1
diff --git a/pytorch_classification/train_multi_GPU/train_multi_gpu_using_launch.py b/pytorch_classification/train_multi_GPU/train_multi_gpu_using_launch.py
index 6c5a84a31..944db144a 100644
--- a/pytorch_classification/train_multi_GPU/train_multi_gpu_using_launch.py
+++ b/pytorch_classification/train_multi_GPU/train_multi_gpu_using_launch.py
@@ -28,6 +28,7 @@ def main(args):
batch_size = args.batch_size
weights_path = args.weights
args.lr *= args.world_size # 学习率要根据并行GPU的数量进行倍增
+ checkpoint_path = ""
if rank == 0: # 在第一个进程中打印信息,并实例化tensorboard
print(args)
@@ -172,7 +173,7 @@ def main(args):
parser.add_argument('--syncBN', type=bool, default=True)
# 数据集所在根目录
- # http://download.tensorflow.org/example_images/flower_photos.tgz
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str, default="/home/wz/data_set/flower_data/flower_photos")
# resnet34 官方权重下载地址
diff --git a/pytorch_classification/train_multi_GPU/train_multi_gpu_using_spawn.py b/pytorch_classification/train_multi_GPU/train_multi_gpu_using_spawn.py
index 3aa3c7e28..1f8f9a564 100644
--- a/pytorch_classification/train_multi_GPU/train_multi_gpu_using_spawn.py
+++ b/pytorch_classification/train_multi_GPU/train_multi_gpu_using_spawn.py
@@ -46,6 +46,7 @@ def main_fun(rank, world_size, args):
batch_size = args.batch_size
weights_path = args.weights
args.lr *= args.world_size # 学习率要根据并行GPU的数量进行倍增
+ checkpoint_path = ""
if rank == 0: # 在第一个进程中打印信息,并实例化tensorboard
print(args)
@@ -191,7 +192,7 @@ def main_fun(rank, world_size, args):
parser.add_argument('--syncBN', type=bool, default=True)
# 数据集所在根目录
- # http://download.tensorflow.org/example_images/flower_photos.tgz
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str, default="/home/wz/data_set/flower_data/flower_photos")
# resnet34 官方权重下载地址
diff --git a/pytorch_classification/train_multi_GPU/train_single_gpu.py b/pytorch_classification/train_multi_GPU/train_single_gpu.py
index bc0fd7312..ce9df27ae 100644
--- a/pytorch_classification/train_multi_GPU/train_single_gpu.py
+++ b/pytorch_classification/train_multi_GPU/train_single_gpu.py
@@ -125,7 +125,7 @@ def main(args):
parser.add_argument('--lrf', type=float, default=0.1)
# 数据集所在根目录
- # http://download.tensorflow.org/example_images/flower_photos.tgz
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str,
default="/home/w180662/my_project/my_github/data_set/flower_data/flower_photos")
diff --git a/pytorch_classification/train_multi_GPU/utils.py b/pytorch_classification/train_multi_GPU/utils.py
index 5365d4ef2..54b2c7d18 100644
--- a/pytorch_classification/train_multi_GPU/utils.py
+++ b/pytorch_classification/train_multi_GPU/utils.py
@@ -12,7 +12,7 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历文件夹,一个文件夹对应一个类别
class_names = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
- # 排序,保证顺序一致
+ # 排序,保证各平台顺序一致
class_names.sort()
# 生成类别名称以及对应的数字索引
class_indices = dict((k, v) for v, k in enumerate(class_names))
@@ -32,6 +32,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历获取supported支持的所有文件路径
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
if os.path.splitext(i)[-1] in supported]
+ # 排序,保证各平台顺序一致
+ images.sort()
# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
@@ -50,6 +52,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
print("{} images were found in the dataset.".format(sum(every_class_num)))
print("{} images for training.".format(len(train_images_path)))
print("{} images for validation.".format(len(val_images_path)))
+ assert len(train_images_path) > 0, "number of training images must be greater than 0."
+ assert len(val_images_path) > 0, "number of validation images must be greater than 0."
plot_image = False
if plot_image:
diff --git a/pytorch_classification/vision_transformer/README.md b/pytorch_classification/vision_transformer/README.md
new file mode 100644
index 000000000..4b700b2df
--- /dev/null
+++ b/pytorch_classification/vision_transformer/README.md
@@ -0,0 +1,12 @@
+## Usage
+
+1. Download the dataset. By default the code uses the flower classification dataset, download link: [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz),
+or, if that link does not work, via Baidu Netdisk: https://pan.baidu.com/s/1QLCTA4sXnQAw_yvxPj9szg (extraction code: 58p0)
+2. In `train.py`, set `--data-path` to the absolute path of the extracted `flower_photos` folder
+3. Download the pre-trained weights; `vit_model.py` provides a download link for each model, so download the one matching the model you use
+4. In `train.py`, set the `--weights` argument to the path of the downloaded pre-trained weights
+5. With `--data-path` and `--weights` set, you can start training with `train.py` (a `class_indices.json` file is generated automatically during training)
+6. In `predict.py`, import the same model as in the training script and set `model_weight_path` to the trained weights (saved in the `weights` folder by default)
+7. In `predict.py`, set `img_path` to the absolute path of the image you want to predict
+8. With `model_weight_path` and `img_path` set, you can run prediction with `predict.py`
+9. To use your own dataset, arrange it like the flower dataset (one folder per class) and set `num_classes` in both the training and prediction scripts to the number of classes in your data
diff --git a/pytorch_classification/vision_transformer/predict.py b/pytorch_classification/vision_transformer/predict.py
index fad2d117a..1c4c7fe30 100644
--- a/pytorch_classification/vision_transformer/predict.py
+++ b/pytorch_classification/vision_transformer/predict.py
@@ -32,8 +32,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = create_model(num_classes=5, has_logits=False).to(device)
diff --git a/pytorch_classification/vision_transformer/train.py b/pytorch_classification/vision_transformer/train.py
index 6c1e7cda5..66bb1d296 100644
--- a/pytorch_classification/vision_transformer/train.py
+++ b/pytorch_classification/vision_transformer/train.py
@@ -61,7 +61,7 @@ def main(args):
num_workers=nw,
collate_fn=val_dataset.collate_fn)
- model = create_model(num_classes=5, has_logits=False).to(device)
+ model = create_model(num_classes=args.num_classes, has_logits=False).to(device)
if args.weights != "":
assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)
@@ -122,7 +122,7 @@ def main(args):
parser.add_argument('--lrf', type=float, default=0.01)
# 数据集所在根目录
- # http://download.tensorflow.org/example_images/flower_photos.tgz
+ # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str,
default="/data/flower_photos")
parser.add_argument('--model-name', default='', help='create model name')
diff --git a/pytorch_classification/vision_transformer/utils.py b/pytorch_classification/vision_transformer/utils.py
index 96ad54a4b..23c53a06f 100644
--- a/pytorch_classification/vision_transformer/utils.py
+++ b/pytorch_classification/vision_transformer/utils.py
@@ -16,7 +16,7 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历文件夹,一个文件夹对应一个类别
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
- # 排序,保证顺序一致
+ # 排序,保证各平台顺序一致
flower_class.sort()
# 生成类别名称以及对应的数字索引
class_indices = dict((k, v) for v, k in enumerate(flower_class))
@@ -36,6 +36,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
# 遍历获取supported支持的所有文件路径
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
if os.path.splitext(i)[-1] in supported]
+ # 排序,保证各平台顺序一致
+ images.sort()
# 获取该类别对应的索引
image_class = class_indices[cla]
# 记录该类别的样本数量
@@ -54,6 +56,8 @@ def read_split_data(root: str, val_rate: float = 0.2):
print("{} images were found in the dataset.".format(sum(every_class_num)))
print("{} images for training.".format(len(train_images_path)))
print("{} images for validation.".format(len(val_images_path)))
+ assert len(train_images_path) > 0, "number of training images must be greater than 0."
+ assert len(val_images_path) > 0, "number of validation images must be greater than 0."
plot_image = False
if plot_image:
diff --git a/pytorch_keypoint/DeepPose/README.md b/pytorch_keypoint/DeepPose/README.md
new file mode 100644
index 000000000..9d0a54d79
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/README.md
@@ -0,0 +1,68 @@
+# DeepPose
+## Paper
+Title: `DeepPose: Human Pose Estimation via Deep Neural Networks`
+arXiv link: [https://arxiv.org/abs/1312.4659](https://arxiv.org/abs/1312.4659)
+
+## Environment
+The main development environment is listed below; see `requirements.txt` for the remaining Python dependencies
+- Python 3.10
+- torch 2.0.1+cu118 (this version or newer is recommended)
+- torchvision 0.15.2+cu118 (this version or newer is recommended)
+
+## Training dataset
+The project is trained on the WFLW dataset (98-point facial landmark detection), official page: [https://wywu.github.io/projects/LAB/WFLW.html](https://wywu.github.io/projects/LAB/WFLW.html)
+
+Download the dataset from the official page, extract it and arrange it as follows:
+```
+WFLW
+  ├── WFLW_annotations
+  │    ├── list_98pt_rect_attr_train_test
+  │    └── list_98pt_test
+  └── WFLW_images
+       ├── 0--Parade
+       ├── 1--Handshaking
+       ├── 10--People_Marching
+       ├── 11--Meeting
+       ├── 12--Group
+       └── ......
+```
+
+## Pre-trained weights
+The default backbone is torchvision's resnet50, so the ImageNet pre-trained weights are downloaded automatically when the model is instantiated.
+- If the training machine has network access, the weights are downloaded automatically
+- If it does not, download them beforehand on a connected machine ([https://download.pytorch.org/models/resnet50-11ad3fa6.pth](https://download.pytorch.org/models/resnet50-11ad3fa6.pth)) and copy the file to `~/.cache/torch/hub/checkpoints` on the training server
+
+## Training
+Set `--dataset_dir` in the training script to the absolute path of your `WFLW` dataset, e.g. `/home/wz/datasets/WFLW`
+### Single GPU
+Use the `train.py` script:
+```bash
+python train.py
+```
+### Multiple GPUs
+Use the `train_multi_GPU.py` script:
+```
+torchrun --nproc_per_node=8 train_multi_GPU.py
+```
+To restrict training to specific GPUs, prepend `CUDA_VISIBLE_DEVICES`, e.g.:
+```
+CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node=4 train_multi_GPU.py
+```
+
+## Trained weights
+If you cannot train the model yourself or just want a quick try, you can use my trained weights (the file is somewhat large because it also stores the optimizer state). They reach an NME of `0.048` on the WFLW validation set. Baidu Netdisk: [https://pan.baidu.com/s/1L_zg-fmocEyzhSTxj8IDJw](https://pan.baidu.com/s/1L_zg-fmocEyzhSTxj8IDJw)
+Extraction code: 8fux
+
+After downloading, create a `weights` folder in this project and place the weights file inside it.
+
+## Testing an image
+See `predict.py`: set `img_path` to the face image you want to predict. Only single-face keypoint detection is supported, so provide an image of a single face; in practice, combine it with a face detector. Example input:
+
+*(image: `test_img.jpg`)*
+
+Visualization of the network prediction:
+
+*(image: `predict.jpg`)*
+
+## Exporting an ONNX model (optional)
+Use the `export_onnx.py` script to export an ONNX model.
\ No newline at end of file
diff --git a/pytorch_keypoint/DeepPose/datasets.py b/pytorch_keypoint/DeepPose/datasets.py
new file mode 100644
index 000000000..7e79cef12
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/datasets.py
@@ -0,0 +1,121 @@
+import os
+from typing import List, Tuple
+
+import cv2
+import torch
+import torch.utils.data as data
+import numpy as np
+
+
+class WFLWDataset(data.Dataset):
+ """
+ https://wywu.github.io/projects/LAB/WFLW.html
+
+ dataset structure:
+
+ ├── WFLW_annotations
+ │ ├── list_98pt_rect_attr_train_test
+ │ └── list_98pt_test
+ └── WFLW_images
+ ├── 0--Parade
+ ├── 1--Handshaking
+ ├── 10--People_Marching
+ ├── 11--Meeting
+ ├── 12--Group
+ └── ......
+ """
+ def __init__(self,
+ root: str,
+ train: bool = True,
+ transforms=None):
+ super().__init__()
+ self.img_root = os.path.join(root, "WFLW_images")
+ assert os.path.exists(self.img_root), "path '{}' does not exist.".format(self.img_root)
+ ana_txt_name = "list_98pt_rect_attr_train.txt" if train else "list_98pt_rect_attr_test.txt"
+ self.anno_path = os.path.join(root, "WFLW_annotations", "list_98pt_rect_attr_train_test", ana_txt_name)
+ assert os.path.exists(self.anno_path), "file '{}' does not exist.".format(self.anno_path)
+
+ self.transforms = transforms
+ self.keypoints: List[np.ndarray] = []
+ self.face_rects: List[List[int]] = []
+ self.img_paths: List[str] = []
+ with open(self.anno_path, "rt") as f:
+ for line in f.readlines():
+ if not line.strip():
+ continue
+
+ split_list = line.strip().split(" ")
+ keypoint_ = self.get_98_points(split_list)
+ keypoint = np.array(keypoint_, dtype=np.float32).reshape((-1, 2))
+ face_rect = list(map(int, split_list[196: 196 + 4])) # xmin, ymin, xmax, ymax
+ img_name = split_list[-1]
+
+ self.keypoints.append(keypoint)
+ self.face_rects.append(face_rect)
+ self.img_paths.append(os.path.join(self.img_root, img_name))
+
+ @staticmethod
+ def get_5_points(keypoints: List[str]) -> List[float]:
+ five_num = [76, 82, 54, 96, 97]
+ five_keypoint = []
+ for i in five_num:
+ five_keypoint.append(keypoints[i * 2])
+ five_keypoint.append(keypoints[i * 2 + 1])
+ return list(map(float, five_keypoint))
+
+ @staticmethod
+ def get_98_points(keypoints: List[str]) -> List[float]:
+ return list(map(float, keypoints[:196]))
+
+ @staticmethod
+ def collate_fn(batch_infos: List[Tuple[torch.Tensor, dict]]):
+ imgs, ori_keypoints, keypoints, m_invs = [], [], [], []
+ for info in batch_infos:
+ imgs.append(info[0])
+ ori_keypoints.append(info[1]["ori_keypoint"])
+ keypoints.append(info[1]["keypoint"])
+ m_invs.append(info[1]["m_inv"])
+
+ imgs_tensor = torch.stack(imgs)
+ keypoints_tensor = torch.stack(keypoints)
+ ori_keypoints_tensor = torch.stack(ori_keypoints)
+ m_invs_tensor = torch.stack(m_invs)
+
+ targets = {"ori_keypoints": ori_keypoints_tensor,
+ "keypoints": keypoints_tensor,
+ "m_invs": m_invs_tensor}
+ return imgs_tensor, targets
+
+ def __getitem__(self, idx: int):
+ img_bgr = cv2.imread(self.img_paths[idx], flags=cv2.IMREAD_COLOR)
+ img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+ target = {
+ "box": self.face_rects[idx],
+ "ori_keypoint": self.keypoints[idx],
+ "keypoint": self.keypoints[idx]
+ }
+
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.keypoints)
+
+
+if __name__ == '__main__':
+ train_dataset = WFLWDataset("/home/wz/datasets/WFLW", train=True)
+ print(len(train_dataset))
+
+ eval_dataset = WFLWDataset("/home/wz/datasets/WFLW", train=False)
+ print(len(eval_dataset))
+
+ from utils import draw_keypoints
+ img, target = train_dataset[0]
+ keypoint = target["keypoint"]
+ h, w, c = img.shape
+ keypoint[:, 0] /= w
+ keypoint[:, 1] /= h
+ draw_keypoints(img, keypoint, "test_plot.jpg", is_rel=True)
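Each line of `list_98pt_rect_attr_*.txt` is a flat, whitespace-separated record: 98 (x, y) coordinate pairs, a face rectangle, a few attribute flags, and the image path as the last field, which is exactly what `__init__` above slices apart. A sketch with a synthetic line (the values and the image name are made up):

```python
# Sketch: parse one (synthetic) WFLW annotation line the same way WFLWDataset does.
import numpy as np

fake_line = " ".join(["10.0"] * 196                       # 98 (x, y) keypoint pairs
                     + ["5", "5", "100", "120"]           # face rect: xmin ymin xmax ymax
                     + ["0"] * 6                          # attribute flags
                     + ["0--Parade/example.jpg"])         # image path (made up)
fields = fake_line.strip().split(" ")

keypoints = np.array(list(map(float, fields[:196])), dtype=np.float32).reshape((-1, 2))
face_rect = list(map(int, fields[196:196 + 4]))
img_name = fields[-1]
print(keypoints.shape, face_rect, img_name)   # (98, 2) [5, 5, 100, 120] 0--Parade/example.jpg
```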
diff --git a/pytorch_keypoint/DeepPose/export_onnx.py b/pytorch_keypoint/DeepPose/export_onnx.py
new file mode 100644
index 000000000..3d44dc37e
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/export_onnx.py
@@ -0,0 +1,29 @@
+import os
+import torch
+from model import create_deep_pose_model
+
+
+def main():
+ img_hw = [256, 256]
+ num_keypoints = 98
+ weights_path = "./weights/model_weights_209.pth"
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # create model
+ model = create_deep_pose_model(num_keypoints=num_keypoints)
+
+ # load model weights
+ assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
+ model.load_state_dict(torch.load(weights_path, map_location="cpu")["model"])
+ model.to(device)
+
+ model.eval()
+ with torch.inference_mode():
+ x = torch.randn(size=(1, 3, img_hw[0], img_hw[1]), device=device)
+ torch.onnx.export(model=model,
+ args=(x,),
+ f="deeppose.onnx")
+
+
+if __name__ == '__main__':
+ main()
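To check the exported file, the sketch below runs it with onnxruntime (an optional dependency not listed in `requirements.txt`); the 256×256 input size and the rel-to-abs rescaling mirror `predict.py`:

```python
# Sketch (assumes the optional onnxruntime package): run deeppose.onnx on a
# dummy input and reshape the flat output back to 98 (x, y) pairs.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("deeppose.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name

dummy = np.random.randn(1, 3, 256, 256).astype(np.float32)
(pred,) = sess.run(None, {input_name: dummy})                                   # [1, 196]
keypoints = pred.reshape(-1, 2) * np.array([256.0, 256.0], dtype=np.float32)    # rel -> abs in the 256x256 crop
print(keypoints.shape)  # (98, 2)
```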
diff --git a/pytorch_keypoint/DeepPose/model.py b/pytorch_keypoint/DeepPose/model.py
new file mode 100644
index 000000000..1d5abdfb2
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/model.py
@@ -0,0 +1,21 @@
+import torch
+import torch.nn as nn
+from torchvision.models import resnet50, ResNet50_Weights
+
+
+def create_deep_pose_model(num_keypoints: int) -> nn.Module:
+ res50 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
+ in_features = res50.fc.in_features
+ res50.fc = nn.Linear(in_features=in_features, out_features=num_keypoints * 2)
+
+ return res50
+
+
+if __name__ == '__main__':
+ torch.manual_seed(1234)
+ model = create_deep_pose_model(98)
+ model.eval()
+ with torch.inference_mode():
+ x = torch.randn(1, 3, 224, 224)
+ res = model(x)
+ print(res.shape)
diff --git a/pytorch_keypoint/DeepPose/predict.jpg b/pytorch_keypoint/DeepPose/predict.jpg
new file mode 100644
index 000000000..2107a2fe7
Binary files /dev/null and b/pytorch_keypoint/DeepPose/predict.jpg differ
diff --git a/pytorch_keypoint/DeepPose/predict.py b/pytorch_keypoint/DeepPose/predict.py
new file mode 100644
index 000000000..a12a60d6c
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/predict.py
@@ -0,0 +1,55 @@
+import os
+
+import torch
+import numpy as np
+from PIL import Image
+
+import transforms
+from model import create_deep_pose_model
+from utils import draw_keypoints
+
+
+def main():
+ img_hw = [256, 256]
+ num_keypoints = 98
+ img_path = "./test_img.jpg"
+ weights_path = "./weights/model_weights_209.pth"
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ transform = transforms.Compose([
+ transforms.AffineTransform(scale_prob=0., rotate_prob=0., shift_prob=0., fixed_size=img_hw),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+ # load image
+ assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
+ img = np.array(Image.open(img_path))
+ h, w, c = img.shape
+ target = {"box": [0, 0, w, h]}
+ img_tensor, target = transform(img, target=target)
+ # expand batch dimension
+ img_tensor = img_tensor.unsqueeze(0)
+
+ # create model
+ model = create_deep_pose_model(num_keypoints=num_keypoints)
+
+ # load model weights
+ assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
+ model.load_state_dict(torch.load(weights_path, map_location="cpu")["model"])
+ model.to(device)
+
+ # prediction
+ model.eval()
+ with torch.inference_mode():
+ with torch.autocast(device_type=device.type):
+ pred = torch.squeeze(model(img_tensor.to(device))).reshape([-1, 2]).cpu().numpy()
+
+ wh_tensor = np.array(img_hw[::-1], dtype=np.float32).reshape([1, 2])
+ pred = pred * wh_tensor # rel coord to abs coord
+ pred = transforms.affine_points_np(pred, target["m_inv"].numpy())
+ draw_keypoints(img, coordinate=pred, save_path="predict.jpg", radius=2)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/pytorch_keypoint/DeepPose/requirements.txt b/pytorch_keypoint/DeepPose/requirements.txt
new file mode 100644
index 000000000..385ffc3f2
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/requirements.txt
@@ -0,0 +1,5 @@
+torch>=2.0.1
+torchvision>=0.15.2
+opencv-python
+tqdm
+tensorboard
\ No newline at end of file
diff --git a/pytorch_keypoint/DeepPose/test_img.jpg b/pytorch_keypoint/DeepPose/test_img.jpg
new file mode 100644
index 000000000..6388b49ed
Binary files /dev/null and b/pytorch_keypoint/DeepPose/test_img.jpg differ
diff --git a/pytorch_keypoint/DeepPose/train.py b/pytorch_keypoint/DeepPose/train.py
new file mode 100644
index 000000000..4d8e108f6
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/train.py
@@ -0,0 +1,176 @@
+import os
+
+import torch
+import torch.amp
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+import transforms
+from model import create_deep_pose_model
+from datasets import WFLWDataset
+from train_utils.train_eval_utils import train_one_epoch, evaluate
+
+
+def get_args_parser(add_help=True):
+ import argparse
+
+ parser = argparse.ArgumentParser(description="PyTorch DeepPose Training", add_help=add_help)
+ parser.add_argument("--dataset_dir", type=str, default="/home/wz/datasets/WFLW", help="WFLW dataset directory")
+ parser.add_argument("--device", type=str, default="cuda:0", help="training device, e.g. cpu, cuda:0")
+ parser.add_argument("--save_weights_dir", type=str, default="./weights", help="save dir for model weights")
+ parser.add_argument("--save_freq", type=int, default=10, help="save frequency for weights and generated imgs")
+ parser.add_argument("--eval_freq", type=int, default=5, help="evaluate frequency")
+ parser.add_argument('--img_hw', default=[256, 256], nargs='+', type=int, help='training image size[h, w]')
+ parser.add_argument("--epochs", type=int, default=210, help="number of epochs of training")
+ parser.add_argument("--batch_size", type=int, default=32, help="size of the batches")
+ parser.add_argument("--num_workers", type=int, default=8, help="number of workers, default: 8")
+ parser.add_argument("--num_keypoints", type=int, default=98, help="number of keypoints")
+ parser.add_argument("--lr", type=float, default=5e-4, help="Adam: learning rate")
+ parser.add_argument('--lr_steps', default=[170, 200], nargs='+', type=int,
+ help='decrease lr every step-size epochs')
+ parser.add_argument("--warmup_epoch", type=int, default=10, help="number of warmup epoch for training")
+ parser.add_argument('--resume', default='', type=str, help='resume from checkpoint')
+ parser.add_argument('--test_only', action="store_true", help='Only test the model')
+
+ return parser
+
+
+def main(args):
+ torch.manual_seed(1234)
+ dataset_dir = args.dataset_dir
+ save_weights_dir = args.save_weights_dir
+ save_freq = args.save_freq
+ eval_freq = args.eval_freq
+ num_keypoints = args.num_keypoints
+ num_workers = args.num_workers
+ epochs = args.epochs
+ bs = args.batch_size
+ start_epoch = 0
+ img_hw = args.img_hw
+ os.makedirs(save_weights_dir, exist_ok=True)
+
+ if "cuda" in args.device and not torch.cuda.is_available():
+ device = torch.device("cpu")
+ else:
+ device = torch.device(args.device)
+ print(f"using device: {device} for training.")
+
+ # tensorboard writer
+ tb_writer = SummaryWriter()
+
+ # create model
+ model = create_deep_pose_model(num_keypoints)
+ model.to(device)
+
+ # config dataset and dataloader
+ data_transform = {
+ "train": transforms.Compose([
+ transforms.AffineTransform(scale_factor=(0.65, 1.35), rotate=45, shift_factor=0.15, fixed_size=img_hw),
+ transforms.RandomHorizontalFlip(0.5),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ]),
+ "val": transforms.Compose([
+ transforms.AffineTransform(scale_prob=0., rotate_prob=0., shift_prob=0., fixed_size=img_hw),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+ }
+ train_dataset = WFLWDataset(root=dataset_dir,
+ train=True,
+ transforms=data_transform["train"])
+ val_dataset = WFLWDataset(root=dataset_dir,
+ train=False,
+ transforms=data_transform["val"])
+
+ train_loader = DataLoader(train_dataset,
+ batch_size=bs,
+ shuffle=True,
+ pin_memory=True,
+ num_workers=num_workers,
+ collate_fn=WFLWDataset.collate_fn,
+ persistent_workers=True)
+
+ val_loader = DataLoader(val_dataset,
+ batch_size=bs,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=num_workers,
+ collate_fn=WFLWDataset.collate_fn,
+ persistent_workers=True)
+
+ # define optimizers
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+ # define learning rate scheduler
+ warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
+ optimizer=optimizer,
+ start_factor=0.01,
+ end_factor=1.0,
+ total_iters=len(train_loader) * args.warmup_epoch
+ )
+ multi_step_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer,
+ milestones=[len(train_loader) * i for i in args.lr_steps],
+ gamma=0.1
+ )
+
+ lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler([warmup_scheduler, multi_step_scheduler])
+
+ if args.resume:
+ assert os.path.exists(args.resume)
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ model.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ start_epoch = checkpoint['epoch'] + 1
+ print("the training process from epoch{}...".format(start_epoch))
+
+ if args.test_only:
+ evaluate(model=model,
+ epoch=start_epoch,
+ val_loader=val_loader,
+ device=device,
+ tb_writer=tb_writer,
+ affine_points_torch_func=transforms.affine_points_torch,
+ num_keypoints=num_keypoints,
+ img_hw=img_hw)
+ return
+
+ for epoch in range(start_epoch, epochs):
+ # train
+ train_one_epoch(model=model,
+ epoch=epoch,
+ train_loader=train_loader,
+ device=device,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ tb_writer=tb_writer,
+ num_keypoints=num_keypoints,
+ img_hw=img_hw)
+
+ # eval
+ if epoch % eval_freq == 0 or epoch == args.epochs - 1:
+ evaluate(model=model,
+ epoch=epoch,
+ val_loader=val_loader,
+ device=device,
+ tb_writer=tb_writer,
+ affine_points_torch_func=transforms.affine_points_torch,
+ num_keypoints=num_keypoints,
+ img_hw=img_hw)
+
+ # save weights
+ if epoch % save_freq == 0 or epoch == args.epochs - 1:
+ save_files = {
+ 'model': model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch
+ }
+ torch.save(save_files, os.path.join(save_weights_dir, f"model_weights_{epoch}.pth"))
+
+
+if __name__ == '__main__':
+ args = get_args_parser().parse_args()
+ main(args)
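Both schedulers above are parameterised in iterations (`len(train_loader) * ...`), so the chained scheduler is intended to be stepped once per batch inside `train_one_epoch`. The sketch below reproduces the resulting schedule with a made-up `iters_per_epoch`:

```python
# Sketch of the schedule configured in main(): per-iteration linear warmup
# chained with a per-iteration multi-step decay. iters_per_epoch is made up.
import torch

iters_per_epoch = 100          # in train.py this is len(train_loader)
warmup_epoch, lr_steps, epochs = 10, [170, 200], 210

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.01, end_factor=1.0, total_iters=iters_per_epoch * warmup_epoch)
decay = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[iters_per_epoch * i for i in lr_steps], gamma=0.1)
lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler([warmup, decay])

for it in range(epochs * iters_per_epoch):
    optimizer.step()        # dummy training iteration
    lr_scheduler.step()     # stepped every iteration, matching the iteration-based milestones
    if it % (50 * iters_per_epoch) == 0:
        print(f"iter {it}: lr = {optimizer.param_groups[0]['lr']:.2e}")
```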
diff --git a/pytorch_keypoint/DeepPose/train_multi_GPU.py b/pytorch_keypoint/DeepPose/train_multi_GPU.py
new file mode 100644
index 000000000..d1d1c2f9a
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/train_multi_GPU.py
@@ -0,0 +1,188 @@
+import os
+
+import torch
+import torch.amp
+from torch.utils.data import DataLoader, DistributedSampler, BatchSampler
+from torch.utils.tensorboard import SummaryWriter
+
+import transforms
+from model import create_deep_pose_model
+from datasets import WFLWDataset
+from train_utils.train_eval_utils import train_one_epoch, evaluate
+from train_utils.distributed_utils import init_distributed_mode, is_main_process
+
+
+def get_args_parser(add_help=True):
+ import argparse
+
+ parser = argparse.ArgumentParser(description="PyTorch DeepPose Training", add_help=add_help)
+ parser.add_argument("--dataset_dir", type=str, default="/home/wz/datasets/WFLW", help="WFLW dataset directory")
+ parser.add_argument("--device", type=str, default="cuda", help="training device, e.g. cpu, cuda")
+ parser.add_argument("--save_weights_dir", type=str, default="./weights", help="save dir for model weights")
+ parser.add_argument("--save_freq", type=int, default=5, help="save frequency for weights and generated imgs")
+ parser.add_argument("--eval_freq", type=int, default=5, help="evaluate frequency")
+ parser.add_argument('--img_hw', default=[256, 256], nargs='+', type=int, help='training image size[h, w]')
+ parser.add_argument("--epochs", type=int, default=210, help="number of epochs of training")
+ parser.add_argument("--batch_size", type=int, default=32, help="size of the batches")
+ parser.add_argument("--num_workers", type=int, default=8, help="number of workers, default: 8")
+ parser.add_argument("--num_keypoints", type=int, default=98, help="number of keypoints")
+ parser.add_argument("--lr", type=float, default=5e-4, help="Adam: learning rate")
+ parser.add_argument('--lr_steps', default=[170, 200], nargs='+', type=int,
+ help='decrease lr every step-size epochs')
+ parser.add_argument("--warmup_epoch", type=int, default=10, help="number of warmup epoch for training")
+ parser.add_argument('--resume', default='', type=str, help='resume from checkpoint')
+ parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+ parser.add_argument('--test_only', action="store_true", help='Only test the model')
+
+ return parser
+
+
+def main(args):
+ torch.manual_seed(1234)
+ init_distributed_mode(args)
+ if not args.distributed:
+ raise EnvironmentError("distributed mode is required; launch this script with torchrun.")
+
+ dataset_dir = args.dataset_dir
+ save_weights_dir = args.save_weights_dir
+ save_freq = args.save_freq
+ eval_freq = args.eval_freq
+ num_keypoints = args.num_keypoints
+ num_workers = args.num_workers
+ epochs = args.epochs
+ bs = args.batch_size
+ start_epoch = 0
+ img_hw = args.img_hw
+ device = torch.device(args.device)
+ os.makedirs(save_weights_dir, exist_ok=True)
+
+ # adjust learning rate
+ args.lr = args.lr * args.world_size
+
+ tb_writer = None
+ if is_main_process():
+ # tensorboard writer
+ tb_writer = SummaryWriter()
+
+ # create model
+ model = create_deep_pose_model(num_keypoints)
+ model.to(device)
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+
+ # config dataset and dataloader
+ data_transform = {
+ "train": transforms.Compose([
+ transforms.AffineTransform(scale_factor=(0.65, 1.35), rotate=45, shift_factor=0.15, fixed_size=img_hw),
+ transforms.RandomHorizontalFlip(0.5),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ]),
+ "val": transforms.Compose([
+ transforms.AffineTransform(scale_prob=0., rotate_prob=0., shift_prob=0., fixed_size=img_hw),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+ }
+ train_dataset = WFLWDataset(root=dataset_dir,
+ train=True,
+ transforms=data_transform["train"])
+ val_dataset = WFLWDataset(root=dataset_dir,
+ train=False,
+ transforms=data_transform["val"])
+
+ train_sampler = DistributedSampler(train_dataset)
+ val_sampler = DistributedSampler(val_dataset)
+ train_batch_sampler = BatchSampler(train_sampler, args.batch_size, drop_last=True)
+
+ train_loader = DataLoader(train_dataset,
+ batch_sampler=train_batch_sampler,
+ pin_memory=True,
+ num_workers=num_workers,
+ collate_fn=WFLWDataset.collate_fn,
+ persistent_workers=True)
+
+ val_loader = DataLoader(val_dataset,
+ batch_size=bs,
+ sampler=val_sampler,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=num_workers,
+ collate_fn=WFLWDataset.collate_fn,
+ persistent_workers=True)
+
+ # define optimizers
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+ # define learning rate scheduler
+ warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
+ optimizer=optimizer,
+ start_factor=0.01,
+ end_factor=1.0,
+ total_iters=len(train_loader) * args.warmup_epoch
+ )
+ multi_step_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+ optimizer=optimizer,
+ milestones=[len(train_loader) * i for i in args.lr_steps],
+ gamma=0.1
+ )
+
+ lr_scheduler = torch.optim.lr_scheduler.ChainedScheduler([warmup_scheduler, multi_step_scheduler])
+
+ if args.resume:
+ assert os.path.exists(args.resume)
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ model.module.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ start_epoch = checkpoint['epoch'] + 1
+ print("the training process from epoch{}...".format(start_epoch))
+
+ if args.test_only:
+ evaluate(model=model,
+ epoch=start_epoch,
+ val_loader=val_loader,
+ device=device,
+ tb_writer=tb_writer,
+ affine_points_torch_func=transforms.affine_points_torch,
+ num_keypoints=num_keypoints,
+ img_hw=img_hw)
+ return
+
+ for epoch in range(start_epoch, epochs):
+ # train
+ train_sampler.set_epoch(epoch) # shuffle training data
+ train_one_epoch(model=model,
+ epoch=epoch,
+ train_loader=train_loader,
+ device=device,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ tb_writer=tb_writer,
+ num_keypoints=num_keypoints,
+ img_hw=img_hw)
+
+ # eval
+ if epoch % eval_freq == 0 or epoch == args.epochs - 1:
+ evaluate(model=model,
+ epoch=epoch,
+ val_loader=val_loader,
+ device=device,
+ tb_writer=tb_writer,
+ affine_points_torch_func=transforms.affine_points_torch,
+ num_keypoints=num_keypoints,
+ img_hw=img_hw)
+
+ # save weights
+ if is_main_process() and (epoch % save_freq == 0 or epoch == args.epochs - 1):
+ save_files = {
+ 'model': model.module.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch
+ }
+ torch.save(save_files, os.path.join(save_weights_dir, f"model_weights_{epoch}.pth"))
+
+
+if __name__ == '__main__':
+ args = get_args_parser().parse_args()
+ main(args)
diff --git a/pytorch_keypoint/DeepPose/train_utils/distributed_utils.py b/pytorch_keypoint/DeepPose/train_utils/distributed_utils.py
new file mode 100644
index 000000000..ef3cdef66
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/train_utils/distributed_utils.py
@@ -0,0 +1,95 @@
+import os
+
+import torch
+import torch.distributed as dist
+
+
+def reduce_value(input_value: torch.Tensor, average=True) -> torch.Tensor:
+ """
+ Args:
+ input_value (Tensor): all the values will be reduced
+ average (bool): whether to do average or sum
+ Reduce the values from all processes so that all processes
+ have the averaged results.
+ """
+ world_size = get_world_size()
+ if world_size < 2:  # single process / single GPU
+ return input_value
+
+ with torch.inference_mode():  # multi-GPU: all-reduce across processes
+ dist.all_reduce(input_value)
+ if average:
+ input_value /= world_size
+
+ return input_value
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in the master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ """检查是否支持分布式环境"""
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def init_distributed_mode(args):
+ if not torch.cuda.is_available():
+ print('No available device')
+ args.distributed = False
+ return
+
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print(f'| distributed init (rank {args.rank}): {args.dist_url}', flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
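`reduce_value` degrades to a no-op when a single process is running, so metric and logging code can call it unconditionally. A sketch (assuming the module is importable as `train_utils.distributed_utils`) that also runs without `torchrun`:

```python
# Sketch: reduce_value is safe to call in both single- and multi-process runs.
import torch
from train_utils.distributed_utils import get_world_size, reduce_value

loss = torch.tensor(0.5)
print(get_world_size())                    # 1 when launched without torchrun
print(reduce_value(loss, average=True))    # unchanged here; averaged across ranks under torchrun
```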
diff --git a/pytorch_keypoint/DeepPose/train_utils/losses.py b/pytorch_keypoint/DeepPose/train_utils/losses.py
new file mode 100644
index 000000000..163a93af2
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/train_utils/losses.py
@@ -0,0 +1,128 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class L1Loss(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def forward(self, pred: torch.Tensor, label: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+ """
+ Args:
+ pred [N, K, 2]
+ label [N, K, 2]
+ mask [N, K]
+ """
+ losses = F.l1_loss(pred, label, reduction="none")
+ if mask is not None:
+ # filter invalid keypoints(e.g. out of range)
+ losses = losses * mask.unsqueeze(2)
+
+ return torch.mean(torch.sum(losses, dim=(1, 2)), dim=0)
+
+
+class SmoothL1Loss(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def forward(self, pred: torch.Tensor, label: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+ """
+ Args:
+ pred [N, K, 2]
+ label [N, K, 2]
+ mask [N, K]
+ """
+ losses = F.smooth_l1_loss(pred, label, reduction="none")
+ if mask is not None:
+ # filter invalid keypoints(e.g. out of range)
+ losses = losses * mask.unsqueeze(2)
+
+ return torch.mean(torch.sum(losses, dim=(1, 2)), dim=0)
+
+
+class L2Loss(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def forward(self, pred: torch.Tensor, label: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+ """
+ Args:
+ pred [N, K, 2]
+ label [N, K, 2]
+ mask [N, K]
+ """
+ losses = F.mse_loss(pred, label, reduction="none")
+ if mask is not None:
+ # filter invalid keypoints(e.g. out of range)
+ losses = losses * mask.unsqueeze(2)
+
+ return torch.mean(torch.sum(losses, dim=(1, 2)), dim=0)
+
+
+class WingLoss(nn.Module):
+ """refer https://github.com/TropComplique/wing-loss/blob/master/loss.py
+ """
+ def __init__(self, w: float = 10.0, epsilon: float = 2.0) -> None:
+ super().__init__()
+ self.w = w
+ self.epsilon = epsilon
+ self.C = w * (1.0 - math.log(1.0 + w / epsilon))
+
+ def forward(self,
+ pred: torch.Tensor,
+ label: torch.Tensor,
+ wh_tensor: torch.Tensor,
+ mask: torch.Tensor = None) -> torch.Tensor:
+ """
+ Args:
+ pred [N, K, 2]
+ wh_tensor [1, 1, 2]
+ label [N, K, 2]
+ mask [N, K]
+ """
+ delta = (pred - label).abs() * wh_tensor # rel to abs
+ losses = torch.where(condition=self.w > delta,
+ input=self.w * torch.log(1.0 + delta / self.epsilon),
+ other=delta - self.C)
+ if mask is not None:
+ # filter invalid keypoints(e.g. out of range)
+ losses = losses * mask.unsqueeze(2)
+
+ return torch.mean(torch.sum(losses, dim=(1, 2)), dim=0)
+
+
+class SoftWingLoss(nn.Module):
+ """refer mmpose/models/losses/regression_loss.py
+ """
+ def __init__(self, omega1: float = 2.0, omega2: float = 20.0, epsilon: float = 0.5) -> None:
+ super().__init__()
+ self.omega1 = omega1
+ self.omega2 = omega2
+ self.epsilon = epsilon
+ self.B = omega1 - omega2 * math.log(1.0 + omega1 / epsilon)
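+        # Soft wing loss, applied element-wise to the error x:
+        #   loss(x) = x                                    if x < omega1
+        #           = omega2 * ln(1 + x / epsilon) + B     otherwise
+        # with B = omega1 - omega2 * ln(1 + omega1 / epsilon) so the curve is continuous at omega1.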
+
+ def forward(self,
+ pred: torch.Tensor,
+ label: torch.Tensor,
+ wh_tensor: torch.Tensor,
+                mask: torch.Tensor = None) -> torch.Tensor:
+ """
+ Args:
+ pred [N, K, 2]
+ label [N, K, 2]
+ wh_tensor [1, 1, 2]
+ mask [N, K]
+ """
+ delta = (pred - label).abs() * wh_tensor # rel to abs
+ losses = torch.where(condition=delta < self.omega1,
+ input=delta,
+ other=self.omega2 * torch.log(1.0 + delta / self.epsilon) + self.B)
+ if mask is not None:
+ # filter invalid keypoints(e.g. out of range)
+ losses = losses * mask.unsqueeze(2)
+
+ loss = torch.mean(torch.sum(losses, dim=(1, 2)), dim=0)
+ return loss
diff --git a/pytorch_keypoint/DeepPose/train_utils/metrics.py b/pytorch_keypoint/DeepPose/train_utils/metrics.py
new file mode 100644
index 000000000..b0f0c7ce3
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/train_utils/metrics.py
@@ -0,0 +1,61 @@
+import torch
+
+from .distributed_utils import reduce_value, is_dist_avail_and_initialized
+
+
+class NMEMetric:
+ def __init__(self, device: torch.device) -> None:
+        # keypoint indices of the two outer eye corners (WFLW)
+ self.keypoint_idxs = [60, 72]
+ self.nme_accumulator: float = 0.
+ self.counter: float = 0.
+ self.device = device
+
+ def update(self, pred: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor = None):
+ """
+ Args:
+ pred (shape [N, K, 2]): pred keypoints
+ gt (shape [N, K, 2]): gt keypoints
+ mask (shape [N, K]): valid keypoints mask
+ """
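+        # NME: for every valid sample, average the per-keypoint L2 error, divide it by the
+        # inter-ocular distance (outer eye corners), then average the ratio over samples.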
+ # ion: inter-ocular distance normalized error
+ ion = torch.linalg.norm(gt[:, self.keypoint_idxs[0]] - gt[:, self.keypoint_idxs[1]], dim=1)
+
+ valid_ion_mask = ion > 0
+ if mask is None:
+ mask = valid_ion_mask
+ else:
+ mask = torch.logical_and(mask, valid_ion_mask.unsqueeze_(dim=1)).sum(dim=1) > 0
+ num_valid = mask.sum().item()
+
+ # equal: (pred - gt).pow(2).sum(dim=2).pow(0.5).mean(dim=1)
+ l2_dis = torch.linalg.norm(pred - gt, dim=2)[mask].mean(dim=1) # [N]
+
+ # avoid divide by zero
+ ion = ion[mask] # [N]
+
+ self.nme_accumulator += l2_dis.div(ion).sum().item()
+ self.counter += num_valid
+
+ def evaluate(self):
+ return self.nme_accumulator / self.counter
+
+ def synchronize_results(self):
+ if is_dist_avail_and_initialized():
+ self.nme_accumulator = reduce_value(
+ torch.as_tensor(self.nme_accumulator, device=self.device),
+ average=False
+ ).item()
+
+            self.counter = reduce_value(
+                torch.as_tensor(self.counter, device=self.device),
+                average=False
+            ).item()
+
+
+if __name__ == '__main__':
+    # note: because of the relative imports above, run this check as a module,
+    # e.g. `python -m train_utils.metrics` from the DeepPose directory
+    metric = NMEMetric(device=torch.device("cpu"))
+    metric.update(pred=torch.randn(32, 98, 2),
+                  gt=torch.randn(32, 98, 2),
+                  mask=torch.randint(0, 2, (32, 98)).bool())
+    print(metric.evaluate())
diff --git a/pytorch_keypoint/DeepPose/train_utils/train_eval_utils.py b/pytorch_keypoint/DeepPose/train_utils/train_eval_utils.py
new file mode 100644
index 000000000..bba484af5
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/train_utils/train_eval_utils.py
@@ -0,0 +1,92 @@
+import sys
+import math
+from typing import Callable, List
+
+from tqdm import tqdm
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+from .losses import WingLoss
+from .metrics import NMEMetric
+from .distributed_utils import is_main_process, reduce_value
+
+
+def train_one_epoch(model: torch.nn.Module,
+ epoch: int,
+ train_loader: DataLoader,
+ device: torch.device,
+ optimizer: torch.optim.Optimizer,
+ lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
+ tb_writer: SummaryWriter,
+ num_keypoints: int,
+ img_hw: List[int]) -> None:
+ # define loss function
+ loss_func = WingLoss()
+ wh_tensor = torch.as_tensor(img_hw[::-1], dtype=torch.float32, device=device).reshape([1, 1, 2])
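+    # img_hw is (h, w); reversing it gives (w, h), so multiplying the model's relative
+    # [x, y] outputs by wh_tensor converts them to absolute pixel coordinates inside WingLoss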
+
+ model.train()
+ train_bar = train_loader
+ if is_main_process():
+ train_bar = tqdm(train_loader, file=sys.stdout)
+
+ for step, (imgs, targets) in enumerate(train_bar):
+ imgs = imgs.to(device)
+ labels = targets["keypoints"].to(device)
+
+ optimizer.zero_grad()
+ # use mixed precision to speed up training
+ with torch.autocast(device_type=device.type):
+ pred: torch.Tensor = model(imgs)
+ loss: torch.Tensor = loss_func(pred.reshape((-1, num_keypoints, 2)), labels, wh_tensor)
+
+ loss_value = reduce_value(loss).item()
+ if not math.isfinite(loss_value):
+ print("Loss is {}, stopping training".format(loss_value))
+ sys.exit(1)
+
+ loss.backward()
+ optimizer.step()
+ lr_scheduler.step()
+
+ if is_main_process():
+ train_bar.desc = f"train epoch[{epoch}] loss:{loss_value:.3f}"
+
+ global_step = epoch * len(train_loader) + step
+ tb_writer.add_scalar("train loss", loss.item(), global_step=global_step)
+ tb_writer.add_scalar("learning rate", optimizer.param_groups[0]["lr"], global_step=global_step)
+
+
+@torch.inference_mode()
+def evaluate(model: torch.nn.Module,
+ epoch: int,
+ val_loader: DataLoader,
+ device: torch.device,
+ tb_writer: SummaryWriter,
+ affine_points_torch_func: Callable,
+ num_keypoints: int,
+ img_hw: List[int]) -> None:
+ model.eval()
+ metric = NMEMetric(device=device)
+ wh_tensor = torch.as_tensor(img_hw[::-1], dtype=torch.float32, device=device).reshape([1, 1, 2])
+ eval_bar = val_loader
+ if is_main_process():
+ eval_bar = tqdm(val_loader, file=sys.stdout, desc="evaluation")
+
+ for step, (imgs, targets) in enumerate(eval_bar):
+ imgs = imgs.to(device)
+ m_invs = targets["m_invs"].to(device)
+ labels = targets["ori_keypoints"].to(device)
+
+ pred = model(imgs)
+ pred = pred.reshape((-1, num_keypoints, 2)) # [N, K, 2]
+ pred = pred * wh_tensor # rel coord to abs coord
+ pred = affine_points_torch_func(pred, m_invs)
+
+ metric.update(pred, labels)
+
+ metric.synchronize_results()
+ if is_main_process():
+ nme = metric.evaluate()
+ tb_writer.add_scalar("evaluation nme", nme, global_step=epoch)
+ print(f"evaluation NME[{epoch}]: {nme:.3f}")
diff --git a/pytorch_keypoint/DeepPose/transforms.py b/pytorch_keypoint/DeepPose/transforms.py
new file mode 100644
index 000000000..ea55d25fb
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/transforms.py
@@ -0,0 +1,217 @@
+import math
+import random
+from typing import Tuple
+
+import cv2
+import torch
+import numpy as np
+
+from wflw_horizontal_flip_indices import wflw_flip_indices_dict
+
+
+def adjust_box(xmin: int, ymin: int, xmax: int, ymax: int, fixed_size: Tuple[int, int]):
+    """Fix the aspect ratio of the input crop by enlarging either w or h."""
+ w = xmax - xmin
+ h = ymax - ymin
+
+ hw_ratio = fixed_size[0] / fixed_size[1]
+ if h / w > hw_ratio:
+        # pad along the w direction
+ wi = h / hw_ratio
+ pad_w = (wi - w) / 2
+ xmin = xmin - pad_w
+ xmax = xmax + pad_w
+ else:
+        # pad along the h direction
+ hi = w * hw_ratio
+ pad_h = (hi - h) / 2
+ ymin = ymin - pad_h
+ ymax = ymax + pad_h
+
+ return xmin, ymin, xmax, ymax
+
+
+def affine_points_np(keypoint: np.ndarray, m: np.ndarray) -> np.ndarray:
+ """
+ Args:
+ keypoint [k, 2]
+ m [2, 3]
+ """
+ ones = np.ones((keypoint.shape[0], 1), dtype=np.float32)
+ keypoint = np.concatenate([keypoint, ones], axis=1) # [k, 3]
+ new_keypoint = np.matmul(keypoint, m.T)
+ return new_keypoint
+
+
+def affine_points_torch(keypoint: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ keypoint [n, k, 2]
+ m [n, 2, 3]
+ """
+ dtype = keypoint.dtype
+ device = keypoint.device
+
+ n, k, _ = keypoint.shape
+ ones = torch.ones(size=(n, k, 1), dtype=dtype, device=device)
+ keypoint = torch.concat([keypoint, ones], dim=2) # [n, k, 3]
+ new_keypoint = torch.matmul(keypoint, m.transpose(1, 2))
+ return new_keypoint
+
+
+class Compose(object):
+    """Compose several transform functions."""
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, image, target):
+ for t in self.transforms:
+ image, target = t(image, target)
+ return image, target
+
+
+class Resize(object):
+ def __init__(self, h: int, w: int):
+ self.h = h
+ self.w = w
+
+ def __call__(self, image: np.ndarray, target):
+ image = cv2.resize(image, dsize=(self.w, self.h), fx=0, fy=0,
+ interpolation=cv2.INTER_LINEAR)
+
+ return image, target
+
+
+class ToTensor(object):
+    """Convert an OpenCV image to a Tensor (HWC -> CHW) and scale values to 0~1."""
+ def __call__(self, image, target):
+ image = torch.from_numpy(image).permute((2, 0, 1))
+ image = image.to(torch.float32) / 255.
+
+ if "ori_keypoint" in target and "keypoint" in target:
+ target["ori_keypoint"] = torch.from_numpy(target["ori_keypoint"])
+ target["keypoint"] = torch.from_numpy(target["keypoint"])
+ target["m_inv"] = torch.from_numpy(target["m_inv"])
+ return image, target
+
+
+class Normalize(object):
+ def __init__(self, mean=None, std=None):
+ self.mean = torch.as_tensor(mean, dtype=torch.float32).reshape((3, 1, 1))
+ self.std = torch.as_tensor(std, dtype=torch.float32).reshape((3, 1, 1))
+
+ def __call__(self, image: torch.Tensor, target: dict):
+ image.sub_(self.mean).div_(self.std)
+
+ if "keypoint" in target:
+ _, h, w = image.shape
+ keypoint = target["keypoint"]
+ keypoint[:, 0] /= w
+ keypoint[:, 1] /= h
+ target["keypoint"] = keypoint
+ return image, target
+
+
+class RandomHorizontalFlip(object):
+    """Randomly flip the input image horizontally."""
+ def __init__(self, p: float = 0.5):
+ self.p = p
+ self.wflw_flip_ids = list(wflw_flip_indices_dict.values())
+
+ def __call__(self, image: np.ndarray, target: dict):
+ if random.random() < self.p:
+ # [h, w, c]
+ image = np.ascontiguousarray(np.flip(image, axis=[1]))
+
+ # [k, 2]
+ if "keypoint" in target:
+ _, w, _ = image.shape
+                keypoint: np.ndarray = target["keypoint"]
+ keypoint = keypoint[self.wflw_flip_ids]
+ keypoint[:, 0] = w - keypoint[:, 0]
+ target["keypoint"] = keypoint
+
+ return image, target
+
+
+class AffineTransform(object):
+ """shift+scale+rotation"""
+ def __init__(self,
+ scale_factor: Tuple[float, float] = (0.65, 1.35),
+ scale_prob: float = 1.,
+ rotate: int = 45,
+ rotate_prob: float = 0.6,
+ shift_factor: float = 0.15,
+ shift_prob: float = 0.3,
+ fixed_size: Tuple[int, int] = (256, 256)):
+ self.scale_factor = scale_factor
+ self.scale_prob = scale_prob
+ self.rotate = rotate
+ self.rotate_prob = rotate_prob
+ self.shift_factor = shift_factor
+ self.shift_prob = shift_prob
+ self.fixed_size = fixed_size # (h, w)
+
+ def __call__(self, img: np.ndarray, target: dict):
+ src_xmin, src_ymin, src_xmax, src_ymax = adjust_box(*target["box"], fixed_size=self.fixed_size)
+ src_w = src_xmax - src_xmin
+ src_h = src_ymax - src_ymin
+
+ if random.random() < self.shift_prob:
+ shift_w_factor = random.uniform(-self.shift_factor, self.shift_factor)
+ shift_h_factor = random.uniform(-self.shift_factor, self.shift_factor)
+ src_xmin -= int(src_w * shift_w_factor)
+ src_xmax -= int(src_w * shift_w_factor)
+ src_ymin -= int(src_h * shift_h_factor)
+ src_ymax -= int(src_h * shift_h_factor)
+
+ src_center = np.array([(src_xmin + src_xmax) / 2, (src_ymin + src_ymax) / 2], dtype=np.float32)
+ src_p2 = src_center + np.array([0, -src_h / 2], dtype=np.float32) # top middle
+ src_p3 = src_center + np.array([src_w / 2, 0], dtype=np.float32) # right middle
+
+ dst_center = np.array([(self.fixed_size[1] - 1) / 2, (self.fixed_size[0] - 1) / 2], dtype=np.float32)
+ dst_p2 = np.array([(self.fixed_size[1] - 1) / 2, 0], dtype=np.float32) # top middle
+ dst_p3 = np.array([self.fixed_size[1] - 1, (self.fixed_size[0] - 1) / 2], dtype=np.float32) # right middle
+
+ if random.random() < self.scale_prob:
+ scale = random.uniform(*self.scale_factor)
+ src_w = src_w * scale
+ src_h = src_h * scale
+ src_p2 = src_center + np.array([0, -src_h / 2], dtype=np.float32) # top middle
+ src_p3 = src_center + np.array([src_w / 2, 0], dtype=np.float32) # right middle
+
+ if random.random() < self.rotate_prob:
+            angle = random.randint(-self.rotate, self.rotate)  # in degrees
+            angle = angle / 180 * math.pi  # convert to radians
+ src_p2 = src_center + np.array([src_h / 2 * math.sin(angle),
+ -src_h / 2 * math.cos(angle)], dtype=np.float32)
+ src_p3 = src_center + np.array([src_w / 2 * math.cos(angle),
+ src_w / 2 * math.sin(angle)], dtype=np.float32)
+
+ src = np.stack([src_center, src_p2, src_p3])
+ dst = np.stack([dst_center, dst_p2, dst_p3])
+
+        m = cv2.getAffineTransform(src, dst).astype(np.float32)  # forward affine transform matrix
+        m_inv = cv2.getAffineTransform(dst, src).astype(np.float32)  # inverse affine matrix, used later to map predictions back
+
+        # apply the affine transform to the image
+ warp_img = cv2.warpAffine(src=img,
+ M=m,
+ dsize=tuple(self.fixed_size[::-1]), # [w, h]
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=(0, 0, 0),
+ flags=cv2.INTER_LINEAR)
+
+ if "keypoint" in target:
+ keypoint = target["keypoint"]
+ keypoint = affine_points_np(keypoint, m)
+ target["keypoint"] = keypoint
+
+ # from utils import draw_keypoints
+ # keypoint[:, 0] /= self.fixed_size[1]
+ # keypoint[:, 1] /= self.fixed_size[0]
+ # draw_keypoints(warp_img, keypoint, "affine.jpg", 2, is_rel=True)
+
+ target["m"] = m
+ target["m_inv"] = m_inv
+ return warp_img, target
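+
+
+if __name__ == '__main__':
+    # Minimal round-trip sketch (not part of the original file; shapes and values are made up):
+    # keypoints mapped with the forward matrix `m` and then with `m_inv` should land back
+    # where they started, up to floating point error.
+    pts = np.random.rand(98, 2).astype(np.float32) * 200.0
+    dummy_target = {"box": [0, 0, 255, 255], "keypoint": pts.copy()}
+    dummy_img = np.zeros((256, 256, 3), dtype=np.uint8)
+    transform = AffineTransform(scale_prob=0., rotate_prob=0., shift_prob=0.)
+    warp_img, dummy_target = transform(dummy_img, dummy_target)
+    restored = affine_points_np(dummy_target["keypoint"], dummy_target["m_inv"])
+    print("max round-trip error:", np.abs(restored - pts).max())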
diff --git a/pytorch_keypoint/DeepPose/utils.py b/pytorch_keypoint/DeepPose/utils.py
new file mode 100644
index 000000000..e848022c0
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/utils.py
@@ -0,0 +1,17 @@
+import cv2
+import numpy as np
+
+
+def draw_keypoints(img: np.ndarray, coordinate: np.ndarray, save_path: str, radius: int = 3, is_rel: bool = False):
+ coordinate_ = coordinate.copy()
+ if is_rel:
+ h, w, c = img.shape
+ coordinate_[:, 0] *= w
+ coordinate_[:, 1] *= h
+ coordinate_ = coordinate_.astype(np.int64).tolist()
+
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ for x, y in coordinate_:
+ cv2.circle(img_bgr, center=(x, y), radius=radius, color=(255, 0, 0), thickness=-1)
+
+ cv2.imwrite(save_path, img_bgr)
diff --git a/pytorch_keypoint/DeepPose/wflw_horizontal_flip_indices.py b/pytorch_keypoint/DeepPose/wflw_horizontal_flip_indices.py
new file mode 100644
index 000000000..de0151319
--- /dev/null
+++ b/pytorch_keypoint/DeepPose/wflw_horizontal_flip_indices.py
@@ -0,0 +1,100 @@
+wflw_flip_indices_dict = {
+ 0: 32,
+ 1: 31,
+ 2: 30,
+ 3: 29,
+ 4: 28,
+ 5: 27,
+ 6: 26,
+ 7: 25,
+ 8: 24,
+ 9: 23,
+ 10: 22,
+ 11: 21,
+ 12: 20,
+ 13: 19,
+ 14: 18,
+ 15: 17,
+ 16: 16,
+ 17: 15,
+ 18: 14,
+ 19: 13,
+ 20: 12,
+ 21: 11,
+ 22: 10,
+ 23: 9,
+ 24: 8,
+ 25: 7,
+ 26: 6,
+ 27: 5,
+ 28: 4,
+ 29: 3,
+ 30: 2,
+ 31: 1,
+ 32: 0,
+ 33: 46,
+ 34: 45,
+ 35: 44,
+ 36: 43,
+ 37: 42,
+ 38: 50,
+ 39: 49,
+ 40: 48,
+ 41: 47,
+ 42: 37,
+ 43: 36,
+ 44: 35,
+ 45: 34,
+ 46: 33,
+ 47: 41,
+ 48: 40,
+ 49: 39,
+ 50: 38,
+ 51: 51,
+ 52: 52,
+ 53: 53,
+ 54: 54,
+ 55: 59,
+ 56: 58,
+ 57: 57,
+ 58: 56,
+ 59: 55,
+ 60: 72,
+ 61: 71,
+ 62: 70,
+ 63: 69,
+ 64: 68,
+ 65: 75,
+ 66: 74,
+ 67: 73,
+ 68: 64,
+ 69: 63,
+ 70: 62,
+ 71: 61,
+ 72: 60,
+ 73: 67,
+ 74: 66,
+ 75: 65,
+ 76: 82,
+ 77: 81,
+ 78: 80,
+ 79: 79,
+ 80: 78,
+ 81: 77,
+ 82: 76,
+ 83: 87,
+ 84: 86,
+ 85: 85,
+ 86: 84,
+ 87: 83,
+ 88: 92,
+ 89: 91,
+ 90: 90,
+ 91: 89,
+ 92: 88,
+ 93: 95,
+ 94: 94,
+ 95: 93,
+ 96: 97,
+ 97: 96,
+}
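+
+
+if __name__ == '__main__':
+    # Sanity-check sketch (not part of the original file): a horizontal-flip index map
+    # must be an involution, i.e. flipping an index twice returns the original index.
+    assert all(wflw_flip_indices_dict[v] == k for k, v in wflw_flip_indices_dict.items())
+    print(f"valid flip mapping over {len(wflw_flip_indices_dict)} WFLW keypoints")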
diff --git a/pytorch_keypoint/HRNet/HRNet.png b/pytorch_keypoint/HRNet/HRNet.png
new file mode 100644
index 000000000..96e83b8b5
Binary files /dev/null and b/pytorch_keypoint/HRNet/HRNet.png differ
diff --git a/pytorch_keypoint/HRNet/README.md b/pytorch_keypoint/HRNet/README.md
new file mode 100644
index 000000000..509f1097f
--- /dev/null
+++ b/pytorch_keypoint/HRNet/README.md
@@ -0,0 +1,105 @@
+# HRNet
+
+## This project is mainly based on the following repositories
+* https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
+* https://github.com/stefanopini/simple-HRNet
+
+## Environment
+* Python 3.6/3.7/3.8
+* PyTorch 1.10 or later
+* pycocotools (Linux: `pip install pycocotools`; Windows: `pip install pycocotools-windows`, no extra Visual Studio installation required)
+* Ubuntu or CentOS (Windows is not recommended)
+* Training on GPU is strongly recommended
+* See `requirements.txt` for the detailed environment
+
+## Repository structure
+```
+  ├── model: HRNet model definition
+  ├── train_utils: training/validation utilities (including COCO evaluation)
+  ├── my_dataset_coco.py: custom dataset for reading COCO2017
+  ├── person_keypoints.json: person keypoint meta information for the COCO dataset
+  ├── train.py: training script for a single GPU/CPU
+  ├── train_multi_GPU.py: training script for multi-GPU users
+  ├── predict.py: simple prediction script using trained weights
+  ├── validation.py: computes COCO metrics on validation/test data with trained weights and writes record_mAP.txt
+  └── transforms.py: data augmentation
+```
+
+## Pretrained weights (download them into this folder)
+The pretrained weights provided by the original authors (ImageNet and COCO) are hosted on Google Drive and OneDrive, which can be hard to access from mainland China, so all of them have been re-uploaded to Baidu Netdisk.
+Download link: https://pan.baidu.com/s/1Lu6mMAWfm_8GGykttFMpVw  extraction code: f43o
+
+The downloaded archive is organized as follows:
+```
+├── pytorch
+ ├── pose_mpii
+ ├── pose_coco
+ │ ├── pose_resnet_50_384x288.pth
+ │ ├── pose_resnet_50_256x192.pth
+ │ ├── pose_resnet_101_384x288.pth
+ │ ├── pose_resnet_101_256x192.pth
+ │ ├── pose_hrnet_w32_384x288.pth
+ │ └── pose_hrnet_w32_256x192.pth
+ └── imagenet
+ ├── resnet50-19c8e357.pth
+ ├── resnet152-b121ed2d.pth
+ ├── resnet101-5d3b4d8f.pth
+ └── hrnet_w32-36af842e.pth
+```
+To run prediction directly with weights pretrained on COCO, download `pose_hrnet_w32_256x192.pth` from the pose_coco folder.
+To train the network from scratch, download `hrnet_w32-36af842e.pth` from the imagenet folder and rename it to `hrnet_w32.pth`.
+
+There is also a `person_detection_results` file that stores the person detector results mentioned in the paper. Download it if you need it, but using the GT boxes from COCO val is recommended instead.
+Link: https://pan.baidu.com/s/19Z4mmNHUD934GQ9QYcF5iw  extraction code: i08q
+
+## Dataset (this project uses COCO2017)
+* COCO official website: https://cocodataset.org/
+* If you are not familiar with the dataset, see my blog post: https://blog.csdn.net/qq_37541097/article/details/113247318
+* Taking coco2017 as an example, three files need to be downloaded:
+  * `2017 Train images [118K/18GB]`: all images used during training
+  * `2017 Val images [5K/1GB]`: all images used during validation
+  * `2017 Train/Val annotations [241MB]`: annotation json files for the train and val sets
+* Extract everything into a `coco2017` folder, which gives the following structure:
+```
+├── coco2017: dataset root
+    ├── train2017: all training images (118287 images)
+    ├── val2017: all validation images (5000 images)
+    └── annotations: annotation files
+        ├── instances_train2017.json: training annotations for detection/segmentation
+        ├── instances_val2017.json: validation annotations for detection/segmentation
+        ├── captions_train2017.json: training annotations for image captioning
+        ├── captions_val2017.json: validation annotations for image captioning
+        ├── person_keypoints_train2017.json: training annotations for person keypoint detection
+        └── person_keypoints_val2017.json: validation annotations for person keypoint detection
+```
+
+## Training
+* Note: training HRNet from scratch with this project reaches 76.1 mAP[@0.50:0.95] on the MS COCO2017 val set, while the weights provided by the original authors reach 76.6, a gap of 0.5 points whose cause has not been found yet. Training takes 210 epochs (following the paper) and is therefore very long, so using the authors' pretrained weights is recommended. Also, GPU utilization during training is fairly low (fluctuating between 20% and 60%), presumably due to the network architecture.
+* Prepare the dataset in advance
+* Download the corresponding pretrained weights in advance
+* Set `--num-joints` (number of person keypoints, 17 for COCO), `--fixed-size` (height and width of the input crop, default [256, 192]) and `--data-path` (path to the `coco2017` directory)
+* For single-GPU training, use the train.py script directly
+* For multi-GPU training, use `torchrun --nproc_per_node=8 train_multi_GPU.py`, where `nproc_per_node` is the number of GPUs
+* To select specific GPUs, prefix the command with `CUDA_VISIBLE_DEVICES=0,3` (e.g. to use only the 1st and 4th GPU)
+* `CUDA_VISIBLE_DEVICES=0,3 torchrun --nproc_per_node=2 train_multi_GPU.py`
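+
+A fuller multi-GPU launch might look like the following (data path, GPU ids and batch size are placeholders, adjust them to your setup):
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train_multi_GPU.py --data-path /data/coco2017 --batch-size 32 --amp
+```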
+
+## Notes
+1. When using the training scripts, make sure `--data-path` points to the **root directory** of your dataset:
+assuming the COCO dataset is used and has been extracted to /data/coco2017, the custom CocoKeypoint dataset will read it from there
+```
+python train.py --data-path /data/coco2017
+```
+2. The `results.txt` saved during training records the COCO metrics on the validation set for each epoch: the first 10 values are the COCO metrics, followed by the mean training loss and the learning rate
+3. When using the prediction script, set `weights_path` to the path of your own trained weights if you want to load them.
+
+
+## If you are not familiar with the HRNet network, see my bilibili video
+https://www.bilibili.com/video/BV1bB4y1y7qP
+
+## For a walkthrough of this project and an analysis of the HRNet code, see my bilibili video
+https://www.bilibili.com/video/BV1ar4y157JM
+
+## HRNet architecture
+![HRNet](HRNet.png)
diff --git a/pytorch_keypoint/HRNet/draw_utils.py b/pytorch_keypoint/HRNet/draw_utils.py
new file mode 100644
index 000000000..dbaddd579
--- /dev/null
+++ b/pytorch_keypoint/HRNet/draw_utils.py
@@ -0,0 +1,58 @@
+import numpy as np
+from numpy import ndarray
+import PIL
+from PIL import ImageDraw, ImageFont
+from PIL.Image import Image
+
+# COCO 17 points
+point_name = ["nose", "left_eye", "right_eye",
+ "left_ear", "right_ear",
+ "left_shoulder", "right_shoulder",
+ "left_elbow", "right_elbow",
+ "left_wrist", "right_wrist",
+ "left_hip", "right_hip",
+ "left_knee", "right_knee",
+ "left_ankle", "right_ankle"]
+
+point_color = [(240, 2, 127), (240, 2, 127), (240, 2, 127),
+ (240, 2, 127), (240, 2, 127),
+ (255, 255, 51), (255, 255, 51),
+ (254, 153, 41), (44, 127, 184),
+ (217, 95, 14), (0, 0, 255),
+ (255, 255, 51), (255, 255, 51), (228, 26, 28),
+ (49, 163, 84), (252, 176, 243), (0, 176, 240),
+ (255, 255, 0), (169, 209, 142),
+ (255, 255, 0), (169, 209, 142),
+ (255, 255, 0), (169, 209, 142)]
+
+
+def draw_keypoints(img: Image,
+ keypoints: ndarray,
+ scores: ndarray = None,
+ thresh: float = 0.2,
+ r: int = 2,
+ draw_text: bool = False,
+ font: str = 'arial.ttf',
+ font_size: int = 10):
+ if isinstance(img, ndarray):
+ img = PIL.Image.fromarray(img)
+
+ if scores is None:
+ scores = np.ones(keypoints.shape[0])
+
+ if draw_text:
+ try:
+ font = ImageFont.truetype(font, font_size)
+ except IOError:
+ font = ImageFont.load_default()
+
+ draw = ImageDraw.Draw(img)
+ for i, (point, score) in enumerate(zip(keypoints, scores)):
+ if score > thresh and np.max(point) > 0:
+ draw.ellipse([point[0] - r, point[1] - r, point[0] + r, point[1] + r],
+ fill=point_color[i],
+ outline=(255, 255, 255))
+ if draw_text:
+ draw.text((point[0] + r, point[1] + r), text=point_name[i], font=font)
+
+ return img
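+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (not part of the original file): draw 17 dummy keypoints
+    # on a blank image and save the result.
+    dummy_img = PIL.Image.new("RGB", (220, 220), (255, 255, 255))
+    dummy_pts = np.array([[20.0 + 10.0 * i, 110.0] for i in range(17)])
+    draw_keypoints(dummy_img, dummy_pts, draw_text=True).save("keypoints_demo.jpg")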
diff --git a/pytorch_keypoint/HRNet/model/__init__.py b/pytorch_keypoint/HRNet/model/__init__.py
new file mode 100644
index 000000000..8db7aa786
--- /dev/null
+++ b/pytorch_keypoint/HRNet/model/__init__.py
@@ -0,0 +1 @@
+from .hrnet import HighResolutionNet
diff --git a/pytorch_keypoint/HRNet/model/hrnet.py b/pytorch_keypoint/HRNet/model/hrnet.py
new file mode 100644
index 000000000..4524aa693
--- /dev/null
+++ b/pytorch_keypoint/HRNet/model/hrnet.py
@@ -0,0 +1,278 @@
+import torch.nn as nn
+
+BN_MOMENTUM = 0.1
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+ bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion,
+ momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class StageModule(nn.Module):
+ def __init__(self, input_branches, output_branches, c):
+ """
+        Builds one stage, i.e. the module that fuses features of different scales.
+        :param input_branches: number of input branches, one per scale
+        :param output_branches: number of output branches
+        :param c: number of channels of the first (highest-resolution) input branch
+ """
+ super().__init__()
+ self.input_branches = input_branches
+ self.output_branches = output_branches
+
+ self.branches = nn.ModuleList()
+        for i in range(self.input_branches):  # each branch first goes through 4 BasicBlocks
+            w = c * (2 ** i)  # number of channels of the i-th branch
+ branch = nn.Sequential(
+ BasicBlock(w, w),
+ BasicBlock(w, w),
+ BasicBlock(w, w),
+ BasicBlock(w, w)
+ )
+ self.branches.append(branch)
+
+        self.fuse_layers = nn.ModuleList()  # used to fuse the outputs of all branches
+ for i in range(self.output_branches):
+ self.fuse_layers.append(nn.ModuleList())
+ for j in range(self.input_branches):
+ if i == j:
+                    # same input and output branch: no transformation needed
+ self.fuse_layers[-1].append(nn.Identity())
+ elif i < j:
+                    # input branch j has a higher downsampling rate than output branch i,
+                    # so adjust its channels and upsample it before the later element-wise add
+ self.fuse_layers[-1].append(
+ nn.Sequential(
+ nn.Conv2d(c * (2 ** j), c * (2 ** i), kernel_size=1, stride=1, bias=False),
+ nn.BatchNorm2d(c * (2 ** i), momentum=BN_MOMENTUM),
+ nn.Upsample(scale_factor=2.0 ** (j - i), mode='nearest')
+ )
+ )
+ else: # i > j
+                    # input branch j has a lower downsampling rate than output branch i,
+                    # so adjust its channels and downsample it before the later element-wise add
+                    # note: each 2x downsampling is one 3x3 conv, 4x needs two, 8x needs three, i-j in total
+ ops = []
+                    # the first i-j-1 convs only downsample, without changing the channel count
+ for k in range(i - j - 1):
+ ops.append(
+ nn.Sequential(
+ nn.Conv2d(c * (2 ** j), c * (2 ** j), kernel_size=3, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(c * (2 ** j), momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)
+ )
+ )
+                    # the last conv both adjusts the channels and downsamples
+ ops.append(
+ nn.Sequential(
+ nn.Conv2d(c * (2 ** j), c * (2 ** i), kernel_size=3, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(c * (2 ** i), momentum=BN_MOMENTUM)
+ )
+ )
+ self.fuse_layers[-1].append(nn.Sequential(*ops))
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+        # pass each branch through its own blocks
+ x = [branch(xi) for branch, xi in zip(self.branches, x)]
+
+        # then fuse information across scales
+ x_fused = []
+ for i in range(len(self.fuse_layers)):
+ x_fused.append(
+ self.relu(
+ sum([self.fuse_layers[i][j](x[j]) for j in range(len(self.branches))])
+ )
+ )
+
+ return x_fused
+
+
+class HighResolutionNet(nn.Module):
+ def __init__(self, base_channel: int = 32, num_joints: int = 17):
+ super().__init__()
+ # Stem
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+ self.relu = nn.ReLU(inplace=True)
+
+ # Stage1
+ downsample = nn.Sequential(
+ nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
+ nn.BatchNorm2d(256, momentum=BN_MOMENTUM)
+ )
+ self.layer1 = nn.Sequential(
+ Bottleneck(64, 64, downsample=downsample),
+ Bottleneck(256, 64),
+ Bottleneck(256, 64),
+ Bottleneck(256, 64)
+ )
+
+ self.transition1 = nn.ModuleList([
+ nn.Sequential(
+ nn.Conv2d(256, base_channel, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(base_channel, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)
+ ),
+ nn.Sequential(
+            nn.Sequential(  # nested Sequential kept so the state_dict keys match the weights from the original project
+ nn.Conv2d(256, base_channel * 2, kernel_size=3, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(base_channel * 2, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)
+ )
+ )
+ ])
+
+ # Stage2
+ self.stage2 = nn.Sequential(
+ StageModule(input_branches=2, output_branches=2, c=base_channel)
+ )
+
+ # transition2
+ self.transition2 = nn.ModuleList([
+ nn.Identity(), # None, - Used in place of "None" because it is callable
+ nn.Identity(), # None, - Used in place of "None" because it is callable
+ nn.Sequential(
+ nn.Sequential(
+ nn.Conv2d(base_channel * 2, base_channel * 4, kernel_size=3, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(base_channel * 4, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)
+ )
+ )
+ ])
+
+ # Stage3
+ self.stage3 = nn.Sequential(
+ StageModule(input_branches=3, output_branches=3, c=base_channel),
+ StageModule(input_branches=3, output_branches=3, c=base_channel),
+ StageModule(input_branches=3, output_branches=3, c=base_channel),
+ StageModule(input_branches=3, output_branches=3, c=base_channel)
+ )
+
+ # transition3
+ self.transition3 = nn.ModuleList([
+ nn.Identity(), # None, - Used in place of "None" because it is callable
+ nn.Identity(), # None, - Used in place of "None" because it is callable
+ nn.Identity(), # None, - Used in place of "None" because it is callable
+ nn.Sequential(
+ nn.Sequential(
+ nn.Conv2d(base_channel * 4, base_channel * 8, kernel_size=3, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(base_channel * 8, momentum=BN_MOMENTUM),
+ nn.ReLU(inplace=True)
+ )
+ )
+ ])
+
+ # Stage4
+        # note: the last StageModule only outputs the highest-resolution feature map
+ self.stage4 = nn.Sequential(
+ StageModule(input_branches=4, output_branches=4, c=base_channel),
+ StageModule(input_branches=4, output_branches=4, c=base_channel),
+ StageModule(input_branches=4, output_branches=1, c=base_channel)
+ )
+
+ # Final layer
+ self.final_layer = nn.Conv2d(base_channel, num_joints, kernel_size=1, stride=1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+
+ x = self.layer1(x)
+ x = [trans(x) for trans in self.transition1] # Since now, x is a list
+
+ x = self.stage2(x)
+ x = [
+ self.transition2[0](x[0]),
+ self.transition2[1](x[1]),
+ self.transition2[2](x[-1])
+ ] # New branch derives from the "upper" branch only
+
+ x = self.stage3(x)
+ x = [
+ self.transition3[0](x[0]),
+ self.transition3[1](x[1]),
+ self.transition3[2](x[2]),
+ self.transition3[3](x[-1]),
+ ] # New branch derives from the "upper" branch only
+
+ x = self.stage4(x)
+
+ x = self.final_layer(x[0])
+
+ return x
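+
+
+if __name__ == '__main__':
+    # Quick shape-check sketch (not part of the original file). The stem downsamples the
+    # input by 4x, so a 256x192 crop should give a [N, num_joints, 64, 48] heatmap tensor.
+    import torch
+    net = HighResolutionNet(base_channel=32, num_joints=17)
+    out = net(torch.randn(1, 3, 256, 192))
+    print(out.shape)  # expected: torch.Size([1, 17, 64, 48])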
diff --git a/pytorch_keypoint/HRNet/my_dataset_coco.py b/pytorch_keypoint/HRNet/my_dataset_coco.py
new file mode 100644
index 000000000..ff1cea78a
--- /dev/null
+++ b/pytorch_keypoint/HRNet/my_dataset_coco.py
@@ -0,0 +1,108 @@
+import os
+import copy
+
+import torch
+import numpy as np
+import cv2
+import torch.utils.data as data
+from pycocotools.coco import COCO
+
+
+class CocoKeypoint(data.Dataset):
+ def __init__(self,
+ root,
+ dataset="train",
+ years="2017",
+ transforms=None,
+ det_json_path=None,
+ fixed_size=(256, 192)):
+ super().__init__()
+ assert dataset in ["train", "val"], 'dataset must be in ["train", "val"]'
+ anno_file = f"person_keypoints_{dataset}{years}.json"
+ assert os.path.exists(root), "file '{}' does not exist.".format(root)
+ self.img_root = os.path.join(root, f"{dataset}{years}")
+ assert os.path.exists(self.img_root), "path '{}' does not exist.".format(self.img_root)
+ self.anno_path = os.path.join(root, "annotations", anno_file)
+ assert os.path.exists(self.anno_path), "file '{}' does not exist.".format(self.anno_path)
+
+ self.fixed_size = fixed_size
+ self.mode = dataset
+ self.transforms = transforms
+ self.coco = COCO(self.anno_path)
+ img_ids = list(sorted(self.coco.imgs.keys()))
+
+ if det_json_path is not None:
+ det = self.coco.loadRes(det_json_path)
+ else:
+ det = self.coco
+
+ self.valid_person_list = []
+ obj_idx = 0
+ for img_id in img_ids:
+ img_info = self.coco.loadImgs(img_id)[0]
+ ann_ids = det.getAnnIds(imgIds=img_id)
+ anns = det.loadAnns(ann_ids)
+ for ann in anns:
+ # only save person class
+ if ann["category_id"] != 1:
+ print(f'warning: find not support id: {ann["category_id"]}, only support id: 1 (person)')
+ continue
+
+                # COCO_val2017_detections_AP_H_56_person.json only contains detections, no keypoints, so skip these checks
+ if det_json_path is None:
+ # skip objs without keypoints annotation
+ if "keypoints" not in ann:
+ continue
+ if max(ann["keypoints"]) == 0:
+ continue
+
+ xmin, ymin, w, h = ann['bbox']
+ # Use only valid bounding boxes
+ if w > 0 and h > 0:
+ info = {
+ "box": [xmin, ymin, w, h],
+ "image_path": os.path.join(self.img_root, img_info["file_name"]),
+ "image_id": img_id,
+ "image_width": img_info['width'],
+ "image_height": img_info['height'],
+ "obj_origin_hw": [h, w],
+ "obj_index": obj_idx,
+ "score": ann["score"] if "score" in ann else 1.
+ }
+
+                    # COCO_val2017_detections_AP_H_56_person.json only contains detections, no keypoints, so skip
+ if det_json_path is None:
+ keypoints = np.array(ann["keypoints"]).reshape([-1, 3])
+ visible = keypoints[:, 2]
+ keypoints = keypoints[:, :2]
+ info["keypoints"] = keypoints
+ info["visible"] = visible
+
+ self.valid_person_list.append(info)
+ obj_idx += 1
+
+ def __getitem__(self, idx):
+ target = copy.deepcopy(self.valid_person_list[idx])
+
+ image = cv2.imread(target["image_path"])
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ if self.transforms is not None:
+            image, target = self.transforms(image, target)
+
+ return image, target
+
+ def __len__(self):
+ return len(self.valid_person_list)
+
+ @staticmethod
+ def collate_fn(batch):
+ imgs_tuple, targets_tuple = tuple(zip(*batch))
+ imgs_tensor = torch.stack(imgs_tuple)
+ return imgs_tensor, targets_tuple
+
+
+if __name__ == '__main__':
+ train = CocoKeypoint("/data/coco2017/", dataset="val")
+ print(len(train))
+ t = train[0]
+ print(t)
diff --git a/pytorch_keypoint/HRNet/person.png b/pytorch_keypoint/HRNet/person.png
new file mode 100644
index 000000000..f647848b5
Binary files /dev/null and b/pytorch_keypoint/HRNet/person.png differ
diff --git a/pytorch_keypoint/HRNet/person_keypoints.json b/pytorch_keypoint/HRNet/person_keypoints.json
new file mode 100644
index 000000000..ffdbc5bd8
--- /dev/null
+++ b/pytorch_keypoint/HRNet/person_keypoints.json
@@ -0,0 +1,8 @@
+{
+ "keypoints": ["nose","left_eye","right_eye","left_ear","right_ear","left_shoulder","right_shoulder","left_elbow","right_elbow","left_wrist","right_wrist","left_hip","right_hip","left_knee","right_knee","left_ankle","right_ankle"],
+ "skeleton": [[16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13],[6,7],[6,8],[7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]],
+ "flip_pairs": [[1,2], [3,4], [5,6], [7,8], [9,10], [11,12], [13,14], [15,16]],
+ "kps_weights": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5, 1.0, 1.0, 1.2, 1.2, 1.5, 1.5],
+ "upper_body_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+ "lower_body_ids": [11, 12, 13, 14, 15, 16]
+}
\ No newline at end of file
diff --git a/pytorch_keypoint/HRNet/plot_curve.py b/pytorch_keypoint/HRNet/plot_curve.py
new file mode 100644
index 000000000..188df710e
--- /dev/null
+++ b/pytorch_keypoint/HRNet/plot_curve.py
@@ -0,0 +1,46 @@
+import datetime
+import matplotlib.pyplot as plt
+
+
+def plot_loss_and_lr(train_loss, learning_rate):
+ try:
+ x = list(range(len(train_loss)))
+ fig, ax1 = plt.subplots(1, 1)
+ ax1.plot(x, train_loss, 'r', label='loss')
+ ax1.set_xlabel("step")
+ ax1.set_ylabel("loss")
+ ax1.set_title("Train Loss and lr")
+ plt.legend(loc='best')
+
+ ax2 = ax1.twinx()
+ ax2.plot(x, learning_rate, label='lr')
+ ax2.set_ylabel("learning rate")
+        ax2.set_xlim(0, len(train_loss))  # set the x-axis range
+ plt.legend(loc='best')
+
+ handles1, labels1 = ax1.get_legend_handles_labels()
+ handles2, labels2 = ax2.get_legend_handles_labels()
+ plt.legend(handles1 + handles2, labels1 + labels2, loc='upper right')
+
+        fig.subplots_adjust(right=0.8)  # leave room on the right so the saved figure is not cut off
+ fig.savefig('./loss_and_lr{}.png'.format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
+ plt.close()
+        print("successfully saved loss curve!")
+ except Exception as e:
+ print(e)
+
+
+def plot_map(mAP):
+ try:
+ x = list(range(len(mAP)))
+        plt.plot(x, mAP, label='mAP')
+ plt.xlabel('epoch')
+ plt.ylabel('mAP')
+ plt.title('Eval mAP')
+ plt.xlim(0, len(mAP))
+ plt.legend(loc='best')
+ plt.savefig('./mAP.png')
+ plt.close()
+        print("successfully saved mAP curve!")
+ except Exception as e:
+ print(e)
diff --git a/pytorch_keypoint/HRNet/predict.py b/pytorch_keypoint/HRNet/predict.py
new file mode 100644
index 000000000..ffb46a24c
--- /dev/null
+++ b/pytorch_keypoint/HRNet/predict.py
@@ -0,0 +1,82 @@
+import os
+import json
+
+import torch
+import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+
+from model import HighResolutionNet
+from draw_utils import draw_keypoints
+import transforms
+
+
+def predict_all_person():
+ # TODO
+ pass
+
+
+def predict_single_person():
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print(f"using device: {device}")
+
+ flip_test = True
+ resize_hw = (256, 192)
+ img_path = "./person.png"
+ weights_path = "./pose_hrnet_w32_256x192.pth"
+ keypoint_json_path = "person_keypoints.json"
+ assert os.path.exists(img_path), f"file: {img_path} does not exist."
+ assert os.path.exists(weights_path), f"file: {weights_path} does not exist."
+ assert os.path.exists(keypoint_json_path), f"file: {keypoint_json_path} does not exist."
+
+ data_transform = transforms.Compose([
+ transforms.AffineTransform(scale=(1.25, 1.25), fixed_size=resize_hw),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+
+ # read json file
+ with open(keypoint_json_path, "r") as f:
+ person_info = json.load(f)
+
+ # read single-person image
+ img = cv2.imread(img_path)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_tensor, target = data_transform(img, {"box": [0, 0, img.shape[1] - 1, img.shape[0] - 1]})
+ img_tensor = torch.unsqueeze(img_tensor, dim=0)
+
+ # create model
+ # HRNet-W32: base_channel=32
+ # HRNet-W48: base_channel=48
+ model = HighResolutionNet(base_channel=32)
+ weights = torch.load(weights_path, map_location=device)
+ weights = weights if "model" not in weights else weights["model"]
+ model.load_state_dict(weights)
+ model.to(device)
+ model.eval()
+
+ with torch.inference_mode():
+ outputs = model(img_tensor.to(device))
+
+ if flip_test:
+ flip_tensor = transforms.flip_images(img_tensor)
+ flip_outputs = torch.squeeze(
+ transforms.flip_back(model(flip_tensor.to(device)), person_info["flip_pairs"]),
+ )
+ # feature is not aligned, shift flipped heatmap for higher accuracy
+ # https://github.com/leoxiaobin/deep-high-resolution-net.pytorch/issues/22
+ flip_outputs[..., 1:] = flip_outputs.clone()[..., 0: -1]
+ outputs = (outputs + flip_outputs) * 0.5
+
+ keypoints, scores = transforms.get_final_preds(outputs, [target["reverse_trans"]], True)
+ keypoints = np.squeeze(keypoints)
+ scores = np.squeeze(scores)
+
+ plot_img = draw_keypoints(img, keypoints, scores, thresh=0.2, r=3)
+ plt.imshow(plot_img)
+ plt.show()
+ plot_img.save("test_result.jpg")
+
+
+if __name__ == '__main__':
+ predict_single_person()
diff --git a/pytorch_keypoint/HRNet/requirements.txt b/pytorch_keypoint/HRNet/requirements.txt
new file mode 100644
index 000000000..d57b6b410
--- /dev/null
+++ b/pytorch_keypoint/HRNet/requirements.txt
@@ -0,0 +1,8 @@
+numpy
+opencv_python==4.5.4.60
+lxml
+torch==1.10.1
+torchvision==0.11.1
+pycocotools
+matplotlib
+tqdm
\ No newline at end of file
diff --git a/pytorch_keypoint/HRNet/train.py b/pytorch_keypoint/HRNet/train.py
new file mode 100644
index 000000000..7b7fa31f6
--- /dev/null
+++ b/pytorch_keypoint/HRNet/train.py
@@ -0,0 +1,229 @@
+import json
+import os
+import datetime
+
+import torch
+from torch.utils import data
+import numpy as np
+
+import transforms
+from model import HighResolutionNet
+from my_dataset_coco import CocoKeypoint
+from train_utils import train_eval_utils as utils
+
+
+def create_model(num_joints, load_pretrain_weights=True):
+ model = HighResolutionNet(base_channel=32, num_joints=num_joints)
+
+ if load_pretrain_weights:
+        # load pretrained weights
+        # link: https://pan.baidu.com/s/1Lu6mMAWfm_8GGykttFMpVw  extraction code: f43o
+ weights_dict = torch.load("./hrnet_w32.pth", map_location='cpu')
+
+ for k in list(weights_dict.keys()):
+            # when loading ImageNet weights, drop the unused classification-head weights
+ if ("head" in k) or ("fc" in k):
+ del weights_dict[k]
+
+            # when loading COCO weights, drop the final layer if num_joints does not match
+ if "final_layer" in k:
+ if weights_dict[k].shape[0] != num_joints:
+ del weights_dict[k]
+
+ missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
+ if len(missing_keys) != 0:
+ print("missing_keys: ", missing_keys)
+
+ return model
+
+
+def main(args):
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ print("Using {} device training.".format(device.type))
+
+    # file used to record the COCO metrics for each epoch
+ results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
+
+ with open(args.keypoints_path, "r") as f:
+ person_kps_info = json.load(f)
+
+ fixed_size = args.fixed_size
+ heatmap_hw = (args.fixed_size[0] // 4, args.fixed_size[1] // 4)
+ kps_weights = np.array(person_kps_info["kps_weights"],
+ dtype=np.float32).reshape((args.num_joints,))
+ data_transform = {
+ "train": transforms.Compose([
+ transforms.HalfBody(0.3, person_kps_info["upper_body_ids"], person_kps_info["lower_body_ids"]),
+ transforms.AffineTransform(scale=(0.65, 1.35), rotation=(-45, 45), fixed_size=fixed_size),
+ transforms.RandomHorizontalFlip(0.5, person_kps_info["flip_pairs"]),
+ transforms.KeypointToHeatMap(heatmap_hw=heatmap_hw, gaussian_sigma=2, keypoints_weights=kps_weights),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ]),
+ "val": transforms.Compose([
+ transforms.AffineTransform(scale=(1.25, 1.25), fixed_size=fixed_size),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+ }
+
+ data_root = args.data_path
+
+ # load train data set
+ # coco2017 -> annotations -> person_keypoints_train2017.json
+ train_dataset = CocoKeypoint(data_root, "train", transforms=data_transform["train"], fixed_size=args.fixed_size)
+
+    # note: a custom collate_fn is needed because each sample contains both the image and its targets, so the default batching cannot be used
+ batch_size = args.batch_size
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
+ print('Using %g dataloader workers' % nw)
+
+ train_data_loader = data.DataLoader(train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+
+ # load validation data set
+ # coco2017 -> annotations -> person_keypoints_val2017.json
+ val_dataset = CocoKeypoint(data_root, "val", transforms=data_transform["val"], fixed_size=args.fixed_size,
+ det_json_path=args.person_det)
+ val_data_loader = data.DataLoader(val_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=val_dataset.collate_fn)
+
+ # create model
+ model = create_model(num_joints=args.num_joints)
+ # print(model)
+
+ model.to(device)
+
+ # define optimizer
+ params = [p for p in model.parameters() if p.requires_grad]
+ optimizer = torch.optim.AdamW(params,
+ lr=args.lr,
+ weight_decay=args.weight_decay)
+
+ scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+ # learning rate scheduler
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
+
+    # if a checkpoint from a previous run is specified, resume training from it
+ if args.resume != "":
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ model.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if args.amp and "scaler" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler"])
+ print("the training process from epoch{}...".format(args.start_epoch))
+
+ train_loss = []
+ learning_rate = []
+ val_map = []
+
+ for epoch in range(args.start_epoch, args.epochs):
+ # train for one epoch, printing every 50 iterations
+ mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
+ device=device, epoch=epoch,
+ print_freq=50, warmup=True,
+ scaler=scaler)
+ train_loss.append(mean_loss.item())
+ learning_rate.append(lr)
+
+ # update the learning rate
+ lr_scheduler.step()
+
+ # evaluate on the test dataset
+ coco_info = utils.evaluate(model, val_data_loader, device=device,
+ flip=True, flip_pairs=person_kps_info["flip_pairs"])
+
+ # write into txt
+ with open(results_file, "a") as f:
+            # the written values are the COCO metrics followed by the loss and the learning rate
+ result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
+ txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
+ f.write(txt + "\n")
+
+ val_map.append(coco_info[1]) # @0.5 mAP
+
+ # save weights
+ save_files = {
+ 'model': model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch}
+ if args.amp:
+ save_files["scaler"] = scaler.state_dict()
+ torch.save(save_files, "./save_weights/model-{}.pth".format(epoch))
+
+ # plot loss and lr curve
+ if len(train_loss) != 0 and len(learning_rate) != 0:
+ from plot_curve import plot_loss_and_lr
+ plot_loss_and_lr(train_loss, learning_rate)
+
+ # plot mAP curve
+ if len(val_map) != 0:
+ from plot_curve import plot_map
+ plot_map(val_map)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description=__doc__)
+
+    # training device
+    parser.add_argument('--device', default='cuda:0', help='device')
+    # root directory of the training dataset (coco2017)
+    parser.add_argument('--data-path', default='/data/coco2017', help='dataset')
+    # person keypoint meta information of the COCO dataset
+    parser.add_argument('--keypoints-path', default="./person_keypoints.json", type=str,
+                        help='person_keypoints.json path')
+    # person detection results for the val set provided by the original project;
+    # keep it as None (recommended) to use the GT boxes instead
+    parser.add_argument('--person-det', type=str, default=None)
+    parser.add_argument('--fixed-size', default=[256, 192], nargs='+', type=int, help='input size')
+    # number of keypoints
+    parser.add_argument('--num-joints', default=17, type=int, help='num_joints')
+    # directory where weights are saved
+    parser.add_argument('--output-dir', default='./save_weights', help='path where to save')
+    # checkpoint to resume training from
+    parser.add_argument('--resume', default='', type=str, help='resume from checkpoint')
+    # epoch to start training from
+    parser.add_argument('--start-epoch', default=0, type=int, help='start epoch')
+    # total number of training epochs
+    parser.add_argument('--epochs', default=210, type=int, metavar='N',
+                        help='number of total epochs to run')
+    # parameter of torch.optim.lr_scheduler.MultiStepLR
+    parser.add_argument('--lr-steps', default=[170, 200], nargs='+', type=int, help='decrease lr every step-size epochs')
+    # parameter of torch.optim.lr_scheduler.MultiStepLR
+    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
+    # learning rate
+    parser.add_argument('--lr', default=0.001, type=float,
+                        help='initial learning rate')
+    # weight_decay of AdamW
+    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                        metavar='W', help='weight decay (default: 1e-4)',
+                        dest='weight_decay')
+    # training batch size
+    parser.add_argument('--batch-size', default=32, type=int, metavar='N',
+                        help='batch size when training.')
+    # whether to use mixed precision training (requires GPU support)
+    parser.add_argument("--amp", action="http://github.com/store_true", help="Use torch.cuda.amp for mixed precision training")
+
+ args = parser.parse_args()
+ print(args)
+
+    # create the output directory for weights if it does not exist
+ if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+
+ main(args)
diff --git a/pytorch_keypoint/HRNet/train_multi_GPU.py b/pytorch_keypoint/HRNet/train_multi_GPU.py
new file mode 100644
index 000000000..9235db6e1
--- /dev/null
+++ b/pytorch_keypoint/HRNet/train_multi_GPU.py
@@ -0,0 +1,272 @@
+import json
+import time
+import os
+import datetime
+
+import torch
+from torch.utils import data
+import numpy as np
+
+import transforms
+from model import HighResolutionNet
+from my_dataset_coco import CocoKeypoint
+import train_utils.train_eval_utils as utils
+from train_utils import init_distributed_mode, save_on_master, mkdir
+
+
+def create_model(num_joints, load_pretrain_weights=True):
+ model = HighResolutionNet(base_channel=32, num_joints=num_joints)
+
+ if load_pretrain_weights:
+        # load pretrained weights
+        # link: https://pan.baidu.com/s/1Lu6mMAWfm_8GGykttFMpVw  extraction code: f43o
+ weights_dict = torch.load("./hrnet_w32.pth", map_location='cpu')
+
+ for k in list(weights_dict.keys()):
+            # when loading ImageNet weights, drop the unused classification-head weights
+ if ("head" in k) or ("fc" in k):
+ del weights_dict[k]
+
+            # when loading COCO weights, drop the final layer if num_joints does not match
+ if "final_layer" in k:
+ if weights_dict[k].shape[0] != num_joints:
+ del weights_dict[k]
+
+ missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
+ if len(missing_keys) != 0:
+ print("missing_keys: ", missing_keys)
+
+ return model
+
+
+def main(args):
+ init_distributed_mode(args)
+ print(args)
+
+ device = torch.device(args.device)
+
+    # file used to record the COCO metrics for each epoch
+ now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+ key_results_file = f"results{now}.txt"
+
+ with open(args.keypoints_path, "r") as f:
+ person_kps_info = json.load(f)
+
+ fixed_size = args.fixed_size
+ heatmap_hw = (args.fixed_size[0] // 4, args.fixed_size[1] // 4)
+ kps_weights = np.array(person_kps_info["kps_weights"],
+ dtype=np.float32).reshape((args.num_joints,))
+ data_transform = {
+ "train": transforms.Compose([
+ transforms.HalfBody(0.3, person_kps_info["upper_body_ids"], person_kps_info["lower_body_ids"]),
+ transforms.AffineTransform(scale=(0.65, 1.35), rotation=(-45, 45), fixed_size=fixed_size),
+ transforms.RandomHorizontalFlip(0.5, person_kps_info["flip_pairs"]),
+ transforms.KeypointToHeatMap(heatmap_hw=heatmap_hw, gaussian_sigma=2, keypoints_weights=kps_weights),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ]),
+ "val": transforms.Compose([
+ transforms.AffineTransform(scale=(1.25, 1.25), fixed_size=fixed_size),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+ }
+
+ data_root = args.data_path
+
+ # load train data set
+ # coco2017 -> annotations -> person_keypoints_train2017.json
+ train_dataset = CocoKeypoint(data_root, "train", transforms=data_transform["train"], fixed_size=args.fixed_size)
+
+ # load validation data set
+ # coco2017 -> annotations -> person_keypoints_val2017.json
+ val_dataset = CocoKeypoint(data_root, "val", transforms=data_transform["val"], fixed_size=args.fixed_size,
+ det_json_path=args.person_det)
+
+ print("Creating data loaders")
+ if args.distributed:
+ train_sampler = data.distributed.DistributedSampler(train_dataset)
+ test_sampler = data.distributed.DistributedSampler(val_dataset)
+ else:
+ train_sampler = data.RandomSampler(train_dataset)
+ test_sampler = data.SequentialSampler(val_dataset)
+
+ train_batch_sampler = data.BatchSampler(train_sampler, args.batch_size, drop_last=True)
+
+ data_loader = data.DataLoader(train_dataset,
+ batch_sampler=train_batch_sampler,
+ num_workers=args.workers,
+ collate_fn=train_dataset.collate_fn)
+
+ data_loader_test = data.DataLoader(val_dataset,
+ batch_size=args.batch_size,
+ sampler=test_sampler,
+ num_workers=args.workers,
+ collate_fn=train_dataset.collate_fn)
+
+ print("Creating model")
+    # create the keypoint model
+ model = create_model(num_joints=args.num_joints)
+ model.to(device)
+
+ if args.distributed and args.sync_bn:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ params = [p for p in model.parameters() if p.requires_grad]
+ optimizer = torch.optim.AdamW(params,
+ lr=args.lr,
+ weight_decay=args.weight_decay)
+
+ scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
+
+    # if a resume checkpoint (weights from a previous run) is given, continue training from it
+ if args.resume:
+ # If map_location is missing, torch.load will first load the module to CPU
+ # and then copy each parameter to where it was saved,
+ # which would result in all processes on the same machine using the same set of devices.
+        checkpoint = torch.load(args.resume, map_location='cpu')  # load the saved checkpoint (including optimizer and lr scheduler states)
+ model_without_ddp.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if args.amp and "scaler" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler"])
+
+ if args.test_only:
+ utils.evaluate(model, data_loader_test, device=device,
+ flip=True, flip_pairs=person_kps_info["flip_pairs"])
+ return
+
+ train_loss = []
+ learning_rate = []
+ val_map = []
+
+ print("Start training")
+ start_time = time.time()
+ for epoch in range(args.start_epoch, args.epochs):
+ if args.distributed:
+ train_sampler.set_epoch(epoch)
+ mean_loss, lr = utils.train_one_epoch(model, optimizer, data_loader,
+ device, epoch, args.print_freq,
+ warmup=True, scaler=scaler)
+
+ # update learning rate
+ lr_scheduler.step()
+
+ # evaluate after every epoch
+ key_info = utils.evaluate(model, data_loader_test, device=device,
+ flip=True, flip_pairs=person_kps_info["flip_pairs"])
+
+        # only the main process writes results
+ if args.rank in [-1, 0]:
+ train_loss.append(mean_loss.item())
+ learning_rate.append(lr)
+ val_map.append(key_info[1]) # @0.5 mAP
+
+ # write into txt
+ with open(key_results_file, "a") as f:
+                # the written values are the COCO metrics followed by the loss and the learning rate
+ result_info = [f"{i:.4f}" for i in key_info + [mean_loss.item()]] + [f"{lr:.6f}"]
+ txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
+ f.write(txt + "\n")
+
+ if args.output_dir:
+            # weights are saved on the main process only
+ save_files = {'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'args': args,
+ 'epoch': epoch}
+ if args.amp:
+ save_files["scaler"] = scaler.state_dict()
+ save_on_master(save_files,
+ os.path.join(args.output_dir, f'model_{epoch}.pth'))
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+ if args.rank in [-1, 0]:
+ # plot loss and lr curve
+ if len(train_loss) != 0 and len(learning_rate) != 0:
+ from plot_curve import plot_loss_and_lr
+ plot_loss_and_lr(train_loss, learning_rate)
+
+ # plot mAP curve
+ if len(val_map) != 0:
+ from plot_curve import plot_map
+ plot_map(val_map)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description=__doc__)
+
+    # root directory of the training dataset (coco2017)
+    parser.add_argument('--data-path', default='/data/coco2017', help='dataset')
+    # training device
+    parser.add_argument('--device', default='cuda', help='device')
+    # person keypoint meta information of the COCO dataset
+    parser.add_argument('--keypoints-path', default="./person_keypoints.json", type=str,
+                        help='person_keypoints.json path')
+    # person detection results for the val set provided by the original project;
+    # keep it as None (recommended) to use the GT boxes instead
+    parser.add_argument('--person-det', type=str, default=None)
+    parser.add_argument('--fixed-size', default=[256, 192], nargs='+', type=int, help='input size')
+    # number of keypoints
+    parser.add_argument('--num-joints', default=17, type=int, help='num_joints(num_keypoints)')
+    # batch size per GPU
+    parser.add_argument('-b', '--batch-size', default=32, type=int,
+                        help='images per gpu, the total batch size is $NGPU x batch_size')
+    # epoch to start training from
+    parser.add_argument('--start-epoch', default=0, type=int, help='start epoch')
+    # total number of training epochs
+    parser.add_argument('--epochs', default=210, type=int, metavar='N',
+                        help='number of total epochs to run')
+    # number of data loading / preprocessing workers
+    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                        help='number of data loading workers (default: 4)')
+    # learning rate
+    parser.add_argument('--lr', default=0.001, type=float,
+                        help='initial learning rate, 0.001 is the default value for training '
+                             'on 4 gpus and 32 images_per_gpu')
+    # weight_decay of AdamW
+    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                        metavar='W', help='weight decay (default: 1e-4)',
+                        dest='weight_decay')
+    # parameter of torch.optim.lr_scheduler.MultiStepLR
+    parser.add_argument('--lr-steps', default=[170, 200], nargs='+', type=int,
+                        help='decrease lr every step-size epochs')
+    # parameter of torch.optim.lr_scheduler.MultiStepLR
+    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
+    # how often training information is printed
+    parser.add_argument('--print-freq', default=50, type=int, help='print frequency')
+    # directory where weights are saved
+    parser.add_argument('--output-dir', default='./multi_train', help='path where to save')
+    # checkpoint to resume training from
+    parser.add_argument('--resume', default='', help='resume from checkpoint')
+    parser.add_argument('--test-only', action="http://github.com/store_true", help="test only")
+
+    # number of processes to launch (processes, not threads)
+    parser.add_argument('--world-size', default=4, type=int,
+                        help='number of distributed processes')
+    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+    parser.add_argument("--sync-bn", action="http://github.com/store_true", help="Use sync batch norm")
+    # whether to use mixed precision training (requires GPU support)
+    parser.add_argument("--amp", action="http://github.com/store_true", help="Use torch.cuda.amp for mixed precision training")
+
+ args = parser.parse_args()
+
+    # if an output directory is specified, create it when it does not exist
+ if args.output_dir:
+ mkdir(args.output_dir)
+
+ main(args)
diff --git a/pytorch_keypoint/HRNet/train_utils/__init__.py b/pytorch_keypoint/HRNet/train_utils/__init__.py
new file mode 100644
index 000000000..3dfa7eadc
--- /dev/null
+++ b/pytorch_keypoint/HRNet/train_utils/__init__.py
@@ -0,0 +1,4 @@
+from .group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
+from .distributed_utils import init_distributed_mode, save_on_master, mkdir
+from .coco_eval import EvalCOCOMetric
+from .coco_utils import coco_remove_images_without_annotations, convert_coco_poly_mask, convert_to_coco_api
diff --git a/pytorch_keypoint/HRNet/train_utils/coco_eval.py b/pytorch_keypoint/HRNet/train_utils/coco_eval.py
new file mode 100644
index 000000000..99aff2c20
--- /dev/null
+++ b/pytorch_keypoint/HRNet/train_utils/coco_eval.py
@@ -0,0 +1,132 @@
+import json
+import copy
+
+from PIL import Image, ImageDraw
+import numpy as np
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+from .distributed_utils import all_gather, is_main_process
+from transforms import affine_points
+
+
+def merge(img_ids, eval_results):
+    """Gather the data from all processes and merge it together."""
+ all_img_ids = all_gather(img_ids)
+ all_eval_results = all_gather(eval_results)
+
+ merged_img_ids = []
+ for p in all_img_ids:
+ merged_img_ids.extend(p)
+
+ merged_eval_results = []
+ for p in all_eval_results:
+ merged_eval_results.extend(p)
+
+ merged_img_ids = np.array(merged_img_ids)
+
+ # keep only unique (and in sorted order) images
+    # remove duplicated image ids; to give every process the same number of images, an image may be assigned to more than one process during multi-GPU training
+ merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
+ merged_eval_results = [merged_eval_results[i] for i in idx]
+
+ return list(merged_img_ids), merged_eval_results
+
+
+class EvalCOCOMetric:
+ def __init__(self,
+ coco: COCO = None,
+ iou_type: str = "keypoints",
+ results_file_name: str = "predict_results.json",
+ classes_mapping: dict = None,
+ threshold: float = 0.2):
+ self.coco = copy.deepcopy(coco)
+        self.obj_ids = []  # ids of the targets (persons) already processed by this process
+ self.results = []
+ self.aggregation_results = None
+ self.classes_mapping = classes_mapping
+ self.coco_evaluator = None
+ assert iou_type in ["keypoints"]
+ self.iou_type = iou_type
+ self.results_file_name = results_file_name
+ self.threshold = threshold
+
+ def plot_img(self, img_path, keypoints, r=3):
+ img = Image.open(img_path)
+ draw = ImageDraw.Draw(img)
+ for i, point in enumerate(keypoints):
+ draw.ellipse([point[0] - r, point[1] - r, point[0] + r, point[1] + r],
+ fill=(255, 0, 0))
+ img.show()
+
+ def prepare_for_coco_keypoints(self, targets, outputs):
+        # iterate over the prediction for each person (note: per person, not per image; one image may contain several persons)
+ for target, keypoints, scores in zip(targets, outputs[0], outputs[1]):
+ if len(keypoints) == 0:
+ continue
+
+ obj_idx = int(target["obj_index"])
+ if obj_idx in self.obj_ids:
+                # skip duplicated entries
+ continue
+
+ self.obj_ids.append(obj_idx)
+ # self.plot_img(target["image_path"], keypoints)
+
+            mask = np.greater(scores, self.threshold)
+ if mask.sum() == 0:
+ k_score = 0
+ else:
+ k_score = np.mean(scores[mask])
+
+ keypoints = np.concatenate([keypoints, scores], axis=1)
+ keypoints = np.reshape(keypoints, -1)
+
+            # round the coordinates (two decimals here)
+            # to reduce the resulting JSON file size.
+ keypoints = [round(k, 2) for k in keypoints.tolist()]
+
+ res = {"image_id": target["image_id"],
+ "category_id": 1, # person
+ "keypoints": keypoints,
+ "score": target["score"] * k_score}
+
+ self.results.append(res)
+
+ def update(self, targets, outputs):
+ if self.iou_type == "keypoints":
+ self.prepare_for_coco_keypoints(targets, outputs)
+ else:
+ raise KeyError(f"not support iou_type: {self.iou_type}")
+
+ def synchronize_results(self):
+        # synchronize the data collected in all processes
+ eval_ids, eval_results = merge(self.obj_ids, self.results)
+ self.aggregation_results = {"obj_ids": eval_ids, "results": eval_results}
+
+        # only the main process needs to save the results
+ if is_main_process():
+ # results = []
+ # [results.extend(i) for i in eval_results]
+ # write predict results into json file
+ json_str = json.dumps(eval_results, indent=4)
+ with open(self.results_file_name, 'w') as json_file:
+ json_file.write(json_str)
+
+ def evaluate(self):
+        # evaluation only needs to run on the main process
+ if is_main_process():
+ # accumulate predictions from all images
+ coco_true = self.coco
+ coco_pre = coco_true.loadRes(self.results_file_name)
+
+ self.coco_evaluator = COCOeval(cocoGt=coco_true, cocoDt=coco_pre, iouType=self.iou_type)
+
+ self.coco_evaluator.evaluate()
+ self.coco_evaluator.accumulate()
+ print(f"IoU metric: {self.iou_type}")
+ self.coco_evaluator.summarize()
+
+ coco_info = self.coco_evaluator.stats.tolist() # numpy to list
+ return coco_info
+ else:
+ return None
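
For reference, each entry that `prepare_for_coco_keypoints` appends to `self.results` (and that later ends up in the results json) has the structure sketched below; the numbers are illustrative only, the layout follows the code above:

```python
# Illustrative only: one keypoint result entry as built in prepare_for_coco_keypoints.
example_result = {
    "image_id": 397133,
    "category_id": 1,                      # always "person"
    # 17 joints flattened as [x1, y1, s1, x2, y2, s2, ...] -> 51 numbers,
    # coordinates rounded to keep the JSON file small
    "keypoints": [382.31, 156.72, 0.91] * 17,
    # person-box detection score multiplied by the mean keypoint score
    "score": 0.87,
}
```
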
diff --git a/pytorch_keypoint/HRNet/train_utils/coco_utils.py b/pytorch_keypoint/HRNet/train_utils/coco_utils.py
new file mode 100644
index 000000000..7a3b3122e
--- /dev/null
+++ b/pytorch_keypoint/HRNet/train_utils/coco_utils.py
@@ -0,0 +1,98 @@
+import torch
+import torch.utils.data
+from pycocotools import mask as coco_mask
+from pycocotools.coco import COCO
+
+
+def coco_remove_images_without_annotations(dataset, ids):
+ """
+    Remove images that contain no objects, or whose objects all have a very small area, from a COCO dataset.
+    refer to:
+    https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py
+    :param dataset:
+    :param ids:
+ :return:
+ """
+ def _has_only_empty_bbox(anno):
+ return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
+
+ def _has_valid_annotation(anno):
+ # if it's empty, there is no annotation
+ if len(anno) == 0:
+ return False
+ # if all boxes have close to zero area, there is no annotation
+ if _has_only_empty_bbox(anno):
+ return False
+
+ return True
+
+ valid_ids = []
+ for ds_idx, img_id in enumerate(ids):
+ ann_ids = dataset.getAnnIds(imgIds=img_id, iscrowd=None)
+ anno = dataset.loadAnns(ann_ids)
+
+ if _has_valid_annotation(anno):
+ valid_ids.append(img_id)
+
+ return valid_ids
+
+
+def convert_coco_poly_mask(segmentations, height, width):
+ masks = []
+ for polygons in segmentations:
+ rles = coco_mask.frPyObjects(polygons, height, width)
+ mask = coco_mask.decode(rles)
+ if len(mask.shape) < 3:
+ mask = mask[..., None]
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
+ mask = mask.any(dim=2)
+ masks.append(mask)
+ if masks:
+ masks = torch.stack(masks, dim=0)
+ else:
+        # if masks is empty there are no objects, so return an all-zero mask
+ masks = torch.zeros((0, height, width), dtype=torch.uint8)
+ return masks
+
+
+def convert_to_coco_api(self):
+ coco_ds = COCO()
+ # annotation IDs need to start at 1, not 0, see torchvision issue #1530
+ ann_id = 1
+ dataset = {"images": [], "categories": [], "annotations": []}
+ categories = set()
+ for img_idx in range(len(self)):
+ targets, h, w = self.get_annotations(img_idx)
+ img_id = targets["image_id"].item()
+ img_dict = {"id": img_id,
+ "height": h,
+ "width": w}
+ dataset["images"].append(img_dict)
+ bboxes = targets["boxes"].clone()
+ # convert (x_min, ymin, xmax, ymax) to (xmin, ymin, w, h)
+ bboxes[:, 2:] -= bboxes[:, :2]
+ bboxes = bboxes.tolist()
+ labels = targets["labels"].tolist()
+ areas = targets["area"].tolist()
+ iscrowd = targets["iscrowd"].tolist()
+ if "masks" in targets:
+ masks = targets["masks"]
+ # make masks Fortran contiguous for coco_mask
+ masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
+ num_objs = len(bboxes)
+ for i in range(num_objs):
+ ann = {"image_id": img_id,
+ "bbox": bboxes[i],
+ "category_id": labels[i],
+ "area": areas[i],
+ "iscrowd": iscrowd[i],
+ "id": ann_id}
+ categories.add(labels[i])
+ if "masks" in targets:
+ ann["segmentation"] = coco_mask.encode(masks[i].numpy())
+ dataset["annotations"].append(ann)
+ ann_id += 1
+ dataset["categories"] = [{"id": i} for i in sorted(categories)]
+ coco_ds.dataset = dataset
+ coco_ds.createIndex()
+ return coco_ds
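
`convert_to_coco_api` only assumes the dataset exposes `__len__` and `get_annotations(idx) -> (targets, h, w)` with `image_id`, `boxes`, `labels`, `area` and `iscrowd` tensors. A tiny dummy dataset satisfying that contract (hypothetical, for illustration; the import path assumes the HRNet project root is the working directory):

```python
import torch
from train_utils.coco_utils import convert_to_coco_api  # assumed import path

class DummyDetDataset:
    """Minimal in-memory dataset matching the contract used by convert_to_coco_api."""
    def __len__(self):
        return 1

    def get_annotations(self, idx):
        targets = {
            "image_id": torch.tensor(idx + 1),
            "boxes": torch.tensor([[10., 20., 110., 220.]]),  # xmin, ymin, xmax, ymax
            "labels": torch.tensor([1]),
            "area": torch.tensor([100. * 200.]),
            "iscrowd": torch.tensor([0]),
        }
        return targets, 480, 640  # targets, height, width

coco_gt = convert_to_coco_api(DummyDetDataset())
print(coco_gt.getImgIds())  # [1]
```
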
diff --git a/pytorch_keypoint/HRNet/train_utils/distributed_utils.py b/pytorch_keypoint/HRNet/train_utils/distributed_utils.py
new file mode 100644
index 000000000..514b8fd92
--- /dev/null
+++ b/pytorch_keypoint/HRNet/train_utils/distributed_utils.py
@@ -0,0 +1,298 @@
+from collections import defaultdict, deque
+import datetime
+import pickle
+import time
+import errno
+import os
+
+import torch
+import torch.distributed as dist
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{value:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)  # bounded deque holding the most recent values
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+    def median(self):  # @property exposes median as a read-only attribute
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+def all_gather(data):
+ """
+    Gather data from every process.
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+    world_size = get_world_size()  # number of processes
+ if world_size == 1:
+ return [data]
+
+ data_list = [None] * world_size
+ dist.all_gather_object(data_list, data)
+
+ return data_list
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Args:
+ input_dict (dict): all the values will be reduced
+ average (bool): whether to do average or sum
+ Reduce the values in the dictionary from all processes so that all processes
+ have the averaged results. Returns a dict with the same fields as
+ input_dict, after reduction.
+ """
+ world_size = get_world_size()
+    if world_size < 2:  # single-GPU case
+ return input_dict
+    with torch.no_grad():  # multi-GPU case
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.all_reduce(values)
+ if average:
+ values /= world_size
+
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+ if torch.cuda.is_available():
+ log_msg = self.delimiter.join([header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}',
+ 'max mem: {memory:.0f}'])
+ else:
+ log_msg = self.delimiter.join([header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'])
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_second = int(iter_time.global_avg * (len(iterable) - i))
+ eta_string = str(datetime.timedelta(seconds=eta_second))
+ if torch.cuda.is_available():
+ print(log_msg.format(i, len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(i, len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(header,
+ total_time_str,
+ total_time / len(iterable)))
+
+
+def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
+
+ def f(x):
+        """Return a learning-rate multiplier for the given step."""
+        if x >= warmup_iters:  # once the iteration count reaches warmup_iters the multiplier is 1
+            return 1
+        alpha = float(x) / warmup_iters
+        # during warmup the multiplier grows linearly from warmup_factor to 1
+ return warmup_factor * (1 - alpha) + alpha
+
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
+
+
+def mkdir(path):
+ try:
+ os.makedirs(path)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+
+
+def setup_for_distributed(is_master):
+ """
+    This function disables printing when not in the master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    """Check whether distributed training is available and initialized."""
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}'.format(
+ args.rank, args.dist_url), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
diff --git a/pytorch_keypoint/HRNet/train_utils/group_by_aspect_ratio.py b/pytorch_keypoint/HRNet/train_utils/group_by_aspect_ratio.py
new file mode 100644
index 000000000..e7b8b9e88
--- /dev/null
+++ b/pytorch_keypoint/HRNet/train_utils/group_by_aspect_ratio.py
@@ -0,0 +1,201 @@
+import bisect
+from collections import defaultdict
+import copy
+from itertools import repeat, chain
+import math
+import numpy as np
+
+import torch
+import torch.utils.data
+from torch.utils.data.sampler import BatchSampler, Sampler
+from torch.utils.model_zoo import tqdm
+import torchvision
+
+from PIL import Image
+
+
+def _repeat_to_at_least(iterable, n):
+ repeat_times = math.ceil(n / len(iterable))
+ repeated = chain.from_iterable(repeat(iterable, repeat_times))
+ return list(repeated)
+
+
+class GroupedBatchSampler(BatchSampler):
+ """
+ Wraps another sampler to yield a mini-batch of indices.
+ It enforces that the batch only contain elements from the same group.
+ It also tries to provide mini-batches which follows an ordering which is
+ as close as possible to the ordering from the original sampler.
+ Arguments:
+ sampler (Sampler): Base sampler.
+ group_ids (list[int]): If the sampler produces indices in range [0, N),
+ `group_ids` must be a list of `N` ints which contains the group id of each sample.
+ The group ids must be a continuous set of integers starting from
+ 0, i.e. they must be in the range [0, num_groups).
+ batch_size (int): Size of mini-batch.
+ """
+ def __init__(self, sampler, group_ids, batch_size):
+ if not isinstance(sampler, Sampler):
+ raise ValueError(
+ "sampler should be an instance of "
+ "torch.utils.data.Sampler, but got sampler={}".format(sampler)
+ )
+ self.sampler = sampler
+ self.group_ids = group_ids
+ self.batch_size = batch_size
+
+ def __iter__(self):
+ buffer_per_group = defaultdict(list)
+ samples_per_group = defaultdict(list)
+
+ num_batches = 0
+ for idx in self.sampler:
+ group_id = self.group_ids[idx]
+ buffer_per_group[group_id].append(idx)
+ samples_per_group[group_id].append(idx)
+ if len(buffer_per_group[group_id]) == self.batch_size:
+ yield buffer_per_group[group_id]
+ num_batches += 1
+ del buffer_per_group[group_id]
+ assert len(buffer_per_group[group_id]) < self.batch_size
+
+ # now we have run out of elements that satisfy
+ # the group criteria, let's return the remaining
+ # elements so that the size of the sampler is
+ # deterministic
+ expected_num_batches = len(self)
+ num_remaining = expected_num_batches - num_batches
+ if num_remaining > 0:
+ # for the remaining batches, take first the buffers with largest number
+ # of elements
+ for group_id, _ in sorted(buffer_per_group.items(),
+ key=lambda x: len(x[1]), reverse=True):
+ remaining = self.batch_size - len(buffer_per_group[group_id])
+ samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
+ buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
+ assert len(buffer_per_group[group_id]) == self.batch_size
+ yield buffer_per_group[group_id]
+ num_remaining -= 1
+ if num_remaining == 0:
+ break
+ assert num_remaining == 0
+
+ def __len__(self):
+ return len(self.sampler) // self.batch_size
+
+
+def _compute_aspect_ratios_slow(dataset, indices=None):
+ print("Your dataset doesn't support the fast path for "
+ "computing the aspect ratios, so will iterate over "
+ "the full dataset and load every image instead. "
+ "This might take some time...")
+ if indices is None:
+ indices = range(len(dataset))
+
+ class SubsetSampler(Sampler):
+ def __init__(self, indices):
+ self.indices = indices
+
+ def __iter__(self):
+ return iter(self.indices)
+
+ def __len__(self):
+ return len(self.indices)
+
+ sampler = SubsetSampler(indices)
+ data_loader = torch.utils.data.DataLoader(
+ dataset, batch_size=1, sampler=sampler,
+ num_workers=14, # you might want to increase it for faster processing
+ collate_fn=lambda x: x[0])
+ aspect_ratios = []
+ with tqdm(total=len(dataset)) as pbar:
+ for _i, (img, _) in enumerate(data_loader):
+ pbar.update(1)
+ height, width = img.shape[-2:]
+ aspect_ratio = float(width) / float(height)
+ aspect_ratios.append(aspect_ratio)
+ return aspect_ratios
+
+
+def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
+ if indices is None:
+ indices = range(len(dataset))
+ aspect_ratios = []
+ for i in indices:
+ height, width = dataset.get_height_and_width(i)
+ aspect_ratio = float(width) / float(height)
+ aspect_ratios.append(aspect_ratio)
+ return aspect_ratios
+
+
+def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
+ if indices is None:
+ indices = range(len(dataset))
+ aspect_ratios = []
+ for i in indices:
+ img_info = dataset.coco.imgs[dataset.ids[i]]
+ aspect_ratio = float(img_info["width"]) / float(img_info["height"])
+ aspect_ratios.append(aspect_ratio)
+ return aspect_ratios
+
+
+def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
+ if indices is None:
+ indices = range(len(dataset))
+ aspect_ratios = []
+ for i in indices:
+ # this doesn't load the data into memory, because PIL loads it lazily
+ width, height = Image.open(dataset.images[i]).size
+ aspect_ratio = float(width) / float(height)
+ aspect_ratios.append(aspect_ratio)
+ return aspect_ratios
+
+
+def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
+ if indices is None:
+ indices = range(len(dataset))
+
+ ds_indices = [dataset.indices[i] for i in indices]
+ return compute_aspect_ratios(dataset.dataset, ds_indices)
+
+
+def compute_aspect_ratios(dataset, indices=None):
+ if hasattr(dataset, "get_height_and_width"):
+ return _compute_aspect_ratios_custom_dataset(dataset, indices)
+
+ if isinstance(dataset, torchvision.datasets.CocoDetection):
+ return _compute_aspect_ratios_coco_dataset(dataset, indices)
+
+ if isinstance(dataset, torchvision.datasets.VOCDetection):
+ return _compute_aspect_ratios_voc_dataset(dataset, indices)
+
+ if isinstance(dataset, torch.utils.data.Subset):
+ return _compute_aspect_ratios_subset_dataset(dataset, indices)
+
+ # slow path
+ return _compute_aspect_ratios_slow(dataset, indices)
+
+
+def _quantize(x, bins):
+ bins = copy.deepcopy(bins)
+ bins = sorted(bins)
+    # bisect_right: return the index at which y would be inserted into the sorted bins (to the right of equal values)
+ quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
+ return quantized
+
+
+def create_aspect_ratio_groups(dataset, k=0):
+    # compute the width/height aspect ratio of every image in the dataset
+    aspect_ratios = compute_aspect_ratios(dataset)
+    # generate 2*k+1 bin edges, log-spaced over the interval [0.5, 2]
+    bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
+
+    # map every aspect ratio to the index of the bin it falls into
+    groups = _quantize(aspect_ratios, bins)
+    # count number of elements per group
+    # (i.e. how many images fall into each aspect-ratio bin)
+ counts = np.unique(groups, return_counts=True)[1]
+ fbins = [0] + bins + [np.inf]
+ print("Using {} as bins for aspect ratio quantization".format(fbins))
+ print("Count of instances per bin: {}".format(counts))
+ return groups
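
A sketch of wiring the two exported helpers into a `DataLoader`. `DummyDataset` below is hypothetical and only exists to provide `get_height_and_width`; a real COCO/VOC dataset works the same way:

```python
import random
import torch
from train_utils import GroupedBatchSampler, create_aspect_ratio_groups

class DummyDataset(torch.utils.data.Dataset):
    """Toy dataset whose only job is to expose get_height_and_width for grouping."""
    def __init__(self, n=64):
        self.hw = [(random.choice([480, 640]), random.choice([480, 640])) for _ in range(n)]
    def __len__(self):
        return len(self.hw)
    def __getitem__(self, idx):
        return torch.tensor(self.hw[idx]), idx
    def get_height_and_width(self, idx):
        return self.hw[idx]

dataset = DummyDataset()
sampler = torch.utils.data.RandomSampler(dataset)
group_ids = create_aspect_ratio_groups(dataset, k=1)          # bins over [0.5, 2]
batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size=8)
loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
for hw, idxs in loader:
    pass  # every batch is drawn from a single aspect-ratio group
```
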
diff --git a/pytorch_keypoint/HRNet/train_utils/loss.py b/pytorch_keypoint/HRNet/train_utils/loss.py
new file mode 100644
index 000000000..1628dbf9a
--- /dev/null
+++ b/pytorch_keypoint/HRNet/train_utils/loss.py
@@ -0,0 +1,20 @@
+import torch
+
+
+class KpLoss(object):
+ def __init__(self):
+ self.criterion = torch.nn.MSELoss(reduction='none')
+
+ def __call__(self, logits, targets):
+ assert len(logits.shape) == 4, 'logits should be 4-ndim'
+ device = logits.device
+ bs = logits.shape[0]
+ # [num_kps, H, W] -> [B, num_kps, H, W]
+ heatmaps = torch.stack([t["heatmap"].to(device) for t in targets])
+ # [num_kps] -> [B, num_kps]
+ kps_weights = torch.stack([t["kps_weights"].to(device) for t in targets])
+
+ # [B, num_kps, H, W] -> [B, num_kps]
+ loss = self.criterion(logits, heatmaps).mean(dim=[2, 3])
+ loss = torch.sum(loss * kps_weights) / bs
+ return loss
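
`KpLoss` expects raw heatmap logits plus, per sample, a target heatmap and per-keypoint weights. A self-contained sketch with random tensors, using the 64×48 heatmap size that goes with 256×192 inputs (import path assumes the HRNet project root):

```python
import torch
from train_utils.loss import KpLoss  # assumed import path

batch_size, num_kps, h, w = 2, 17, 64, 48
logits = torch.rand(batch_size, num_kps, h, w)
targets = [
    {"heatmap": torch.rand(num_kps, h, w),
     "kps_weights": torch.ones(num_kps)}   # a weight of 0 silences an invisible keypoint
    for _ in range(batch_size)
]

criterion = KpLoss()
# MSE per keypoint heatmap, averaged over HxW, weighted, then summed and divided by the batch size
loss = criterion(logits, targets)
print(loss.item())
```
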
diff --git a/pytorch_keypoint/HRNet/train_utils/train_eval_utils.py b/pytorch_keypoint/HRNet/train_utils/train_eval_utils.py
new file mode 100644
index 000000000..5f678b8de
--- /dev/null
+++ b/pytorch_keypoint/HRNet/train_utils/train_eval_utils.py
@@ -0,0 +1,119 @@
+import math
+import sys
+import time
+
+import torch
+
+import transforms
+import train_utils.distributed_utils as utils
+from .coco_eval import EvalCOCOMetric
+from .loss import KpLoss
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch,
+ print_freq=50, warmup=False, scaler=None):
+ model.train()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+
+ lr_scheduler = None
+    if epoch == 0 and warmup is True:  # enable warmup for the first epoch (epoch 0)
+ warmup_factor = 1.0 / 1000
+ warmup_iters = min(1000, len(data_loader) - 1)
+
+ lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
+
+ mse = KpLoss()
+ mloss = torch.zeros(1).to(device) # mean losses
+ for i, [images, targets] in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ images = torch.stack([image.to(device) for image in images])
+
+        # autocast context manager for mixed-precision training; it is a no-op on CPU
+ with torch.cuda.amp.autocast(enabled=scaler is not None):
+ results = model(images)
+
+ losses = mse(results, targets)
+
+ # reduce losses over all GPUs for logging purpose
+ loss_dict_reduced = utils.reduce_dict({"losses": losses})
+ losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+ loss_value = losses_reduced.item()
+        # record the running mean of the training loss
+ mloss = (mloss * i + loss_value) / (i + 1) # update mean losses
+
+        if not math.isfinite(loss_value):  # stop training when the loss is no longer finite
+ print("Loss is {}, stopping training".format(loss_value))
+ print(loss_dict_reduced)
+ sys.exit(1)
+
+ optimizer.zero_grad()
+ if scaler is not None:
+ scaler.scale(losses).backward()
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ losses.backward()
+ optimizer.step()
+
+        if lr_scheduler is not None:  # step the warmup scheduler during the first epoch
+ lr_scheduler.step()
+
+ metric_logger.update(loss=losses_reduced)
+ now_lr = optimizer.param_groups[0]["lr"]
+ metric_logger.update(lr=now_lr)
+
+ return mloss, now_lr
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, device, flip=False, flip_pairs=None):
+ if flip:
+ assert flip_pairs is not None, "enable flip must provide flip_pairs."
+
+ model.eval()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = "Test: "
+
+ key_metric = EvalCOCOMetric(data_loader.dataset.coco, "keypoints", "key_results.json")
+ for image, targets in metric_logger.log_every(data_loader, 100, header):
+ images = torch.stack([img.to(device) for img in image])
+
+        # skip GPU synchronization when running on CPU
+ if device != torch.device("cpu"):
+ torch.cuda.synchronize(device)
+
+ model_time = time.time()
+ outputs = model(images)
+ if flip:
+ flipped_images = transforms.flip_images(images)
+ flipped_outputs = model(flipped_images)
+ flipped_outputs = transforms.flip_back(flipped_outputs, flip_pairs)
+ # feature is not aligned, shift flipped heatmap for higher accuracy
+ # https://github.com/leoxiaobin/deep-high-resolution-net.pytorch/issues/22
+ flipped_outputs[..., 1:] = flipped_outputs.clone()[..., 0:-1]
+ outputs = (outputs + flipped_outputs) * 0.5
+
+ model_time = time.time() - model_time
+
+ # decode keypoint
+ reverse_trans = [t["reverse_trans"] for t in targets]
+ outputs = transforms.get_final_preds(outputs, reverse_trans, post_processing=True)
+
+ key_metric.update(targets, outputs)
+ metric_logger.update(model_time=model_time)
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+
+    # synchronize the results collected in all processes
+ key_metric.synchronize_results()
+
+ if utils.is_main_process():
+ coco_info = key_metric.evaluate()
+ else:
+ coco_info = None
+
+ return coco_info
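
The flip-test branch above averages the normal heatmaps with heatmaps from a horizontally flipped input; after `flip_back`, a one-column shift compensates for the feature misalignment discussed in the linked issue. A standalone sketch of just that tensor manipulation; the `flip_pairs` indices are made up (not the COCO ones) and the import assumes the HRNet directory is on the path:

```python
import torch
from transforms import flip_back  # transforms.py from this project

heatmaps = torch.rand(1, 4, 64, 48)           # [batch, joints, H, W], pretend model output
flipped_heatmaps = torch.rand(1, 4, 64, 48)   # pretend output of model(flip_images(images))
flip_pairs = [[0, 1], [2, 3]]                 # illustrative left/right joint index pairs

flipped_heatmaps = flip_back(flipped_heatmaps, flip_pairs)
# shift one column to the right to realign the flipped prediction
flipped_heatmaps[..., 1:] = flipped_heatmaps.clone()[..., 0:-1]
averaged = (heatmaps + flipped_heatmaps) * 0.5
```
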
diff --git a/pytorch_keypoint/HRNet/transforms.py b/pytorch_keypoint/HRNet/transforms.py
new file mode 100644
index 000000000..b914e2fe3
--- /dev/null
+++ b/pytorch_keypoint/HRNet/transforms.py
@@ -0,0 +1,443 @@
+import math
+import random
+from typing import Tuple
+
+import cv2
+import numpy as np
+import torch
+from torchvision.transforms import functional as F
+import matplotlib.pyplot as plt
+
+
+def flip_images(img):
+ assert len(img.shape) == 4, 'images has to be [batch_size, channels, height, width]'
+ img = torch.flip(img, dims=[3])
+ return img
+
+
+def flip_back(output_flipped, matched_parts):
+ assert len(output_flipped.shape) == 4, 'output_flipped has to be [batch_size, num_joints, height, width]'
+ output_flipped = torch.flip(output_flipped, dims=[3])
+
+ for pair in matched_parts:
+ tmp = output_flipped[:, pair[0]].clone()
+ output_flipped[:, pair[0]] = output_flipped[:, pair[1]]
+ output_flipped[:, pair[1]] = tmp
+
+ return output_flipped
+
+
+def get_max_preds(batch_heatmaps):
+ """
+ get predictions from score maps
+ heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
+ """
+ assert isinstance(batch_heatmaps, torch.Tensor), 'batch_heatmaps should be torch.Tensor'
+ assert len(batch_heatmaps.shape) == 4, 'batch_images should be 4-ndim'
+
+ batch_size, num_joints, h, w = batch_heatmaps.shape
+ heatmaps_reshaped = batch_heatmaps.reshape(batch_size, num_joints, -1)
+ maxvals, idx = torch.max(heatmaps_reshaped, dim=2)
+
+ maxvals = maxvals.unsqueeze(dim=-1)
+ idx = idx.float()
+
+ preds = torch.zeros((batch_size, num_joints, 2)).to(batch_heatmaps)
+
+    preds[:, :, 0] = idx % w  # column index, i.e. the x coordinate of the maximum
+    preds[:, :, 1] = torch.floor(idx / w)  # row index, i.e. the y coordinate of the maximum
+
+ pred_mask = torch.gt(maxvals, 0.0).repeat(1, 1, 2).float().to(batch_heatmaps.device)
+
+ preds *= pred_mask
+ return preds, maxvals
+
+
+def affine_points(pt, t):
+ ones = np.ones((pt.shape[0], 1), dtype=float)
+ pt = np.concatenate([pt, ones], axis=1).T
+ new_pt = np.dot(t, pt)
+ return new_pt.T
+
+
+def get_final_preds(batch_heatmaps: torch.Tensor,
+ trans: list = None,
+ post_processing: bool = False):
+ assert trans is not None
+ coords, maxvals = get_max_preds(batch_heatmaps)
+
+ heatmap_height = batch_heatmaps.shape[2]
+ heatmap_width = batch_heatmaps.shape[3]
+
+ # post-processing
+ if post_processing:
+ for n in range(coords.shape[0]):
+ for p in range(coords.shape[1]):
+ hm = batch_heatmaps[n][p]
+ px = int(math.floor(coords[n][p][0] + 0.5))
+ py = int(math.floor(coords[n][p][1] + 0.5))
+ if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
+ diff = torch.tensor(
+ [
+ hm[py][px + 1] - hm[py][px - 1],
+ hm[py + 1][px] - hm[py - 1][px]
+ ]
+ ).to(batch_heatmaps.device)
+ coords[n][p] += torch.sign(diff) * .25
+
+ preds = coords.clone().cpu().numpy()
+
+ # Transform back
+ for i in range(coords.shape[0]):
+ preds[i] = affine_points(preds[i], trans[i])
+
+ return preds, maxvals.cpu().numpy()
+
+
+def decode_keypoints(outputs, origin_hw, num_joints: int = 17):
+ keypoints = []
+ scores = []
+ heatmap_h, heatmap_w = outputs.shape[-2:]
+ for i in range(num_joints):
+ pt = np.unravel_index(np.argmax(outputs[i]), (heatmap_h, heatmap_w))
+ score = outputs[i, pt[0], pt[1]]
+ keypoints.append(pt[::-1]) # hw -> wh(xy)
+ scores.append(score)
+
+ keypoints = np.array(keypoints, dtype=float)
+ scores = np.array(scores, dtype=float)
+ # convert to full image scale
+ keypoints[:, 0] = np.clip(keypoints[:, 0] / heatmap_w * origin_hw[1],
+ a_min=0,
+ a_max=origin_hw[1])
+ keypoints[:, 1] = np.clip(keypoints[:, 1] / heatmap_h * origin_hw[0],
+ a_min=0,
+ a_max=origin_hw[0])
+ return keypoints, scores
+
+
+def resize_pad(img: np.ndarray, size: tuple):
+ h, w, c = img.shape
+    src = np.array([[0, 0],       # top-left corner of the image in the source coordinate system
+                    [w - 1, 0],   # top-right corner of the image in the source coordinate system
+                    [0, h - 1]],  # bottom-left corner of the image in the source coordinate system
+                   dtype=np.float32)
+ dst = np.zeros((3, 2), dtype=np.float32)
+    if h / w > size[0] / size[1]:
+        # the image needs padding along the width direction
+        wi = size[0] * (w / h)
+        pad_w = (size[1] - wi) / 2
+        dst[0, :] = [pad_w - 1, 0]            # top-left corner in the target coordinate system
+        dst[1, :] = [size[1] - pad_w - 1, 0]  # top-right corner in the target coordinate system
+        dst[2, :] = [pad_w - 1, size[0] - 1]  # bottom-left corner in the target coordinate system
+    else:
+        # the image needs padding along the height direction
+        hi = size[1] * (h / w)
+        pad_h = (size[0] - hi) / 2
+        dst[0, :] = [0, pad_h - 1]            # top-left corner in the target coordinate system
+        dst[1, :] = [size[1] - 1, pad_h - 1]  # top-right corner in the target coordinate system
+        dst[2, :] = [0, size[0] - pad_h - 1]  # bottom-left corner in the target coordinate system
+
+    trans = cv2.getAffineTransform(src, dst)  # compute the forward affine transform matrix
+    # apply the affine transform to the image
+ resize_img = cv2.warpAffine(img,
+ trans,
+ size[::-1], # w, h
+ flags=cv2.INTER_LINEAR)
+ # import matplotlib.pyplot as plt
+ # plt.imshow(resize_img)
+ # plt.show()
+
+    dst /= 4  # the predicted heatmap is 1/4 the size of the network input
+    reverse_trans = cv2.getAffineTransform(dst, src)  # compute the inverse affine transform to map predictions back later
+
+ return resize_img, reverse_trans
+
+
+def adjust_box(xmin: float, ymin: float, w: float, h: float, fixed_size: Tuple[float, float]):
+    """Keep the aspect ratio of the input crop fixed by enlarging w or h."""
+ xmax = xmin + w
+ ymax = ymin + h
+
+ hw_ratio = fixed_size[0] / fixed_size[1]
+ if h / w > hw_ratio:
+        # pad along the width direction
+ wi = h / hw_ratio
+ pad_w = (wi - w) / 2
+ xmin = xmin - pad_w
+ xmax = xmax + pad_w
+ else:
+        # pad along the height direction
+ hi = w * hw_ratio
+ pad_h = (hi - h) / 2
+ ymin = ymin - pad_h
+ ymax = ymax + pad_h
+
+ return xmin, ymin, xmax, ymax
+
+
+def scale_box(xmin: float, ymin: float, w: float, h: float, scale_ratio: Tuple[float, float]):
+    """Recompute xmin, ymin, w, h from the given (h, w) scale factors scale_ratio."""
+ s_h = h * scale_ratio[0]
+ s_w = w * scale_ratio[1]
+ xmin = xmin - (s_w - w) / 2.
+ ymin = ymin - (s_h - h) / 2.
+ return xmin, ymin, s_w, s_h
+
+
+def plot_heatmap(image, heatmap, kps, kps_weights):
+ for kp_id in range(len(kps_weights)):
+ if kps_weights[kp_id] > 0:
+ plt.subplot(1, 2, 1)
+ plt.imshow(image)
+ plt.plot(*kps[kp_id].tolist(), "ro")
+ plt.title("image")
+ plt.subplot(1, 2, 2)
+ plt.imshow(heatmap[kp_id], cmap=plt.cm.Blues)
+ plt.colorbar(ticks=[0, 1])
+ plt.title(f"kp_id: {kp_id}")
+ plt.show()
+
+
+class Compose(object):
+    """Compose several transforms."""
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, image, target):
+ for t in self.transforms:
+ image, target = t(image, target)
+ return image, target
+
+
+class ToTensor(object):
+    """Convert a PIL image to a Tensor."""
+ def __call__(self, image, target):
+ image = F.to_tensor(image)
+ return image, target
+
+
+class Normalize(object):
+ def __init__(self, mean=None, std=None):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, image, target):
+ image = F.normalize(image, mean=self.mean, std=self.std)
+ return image, target
+
+
+class HalfBody(object):
+ def __init__(self, p: float = 0.3, upper_body_ids=None, lower_body_ids=None):
+ assert upper_body_ids is not None
+ assert lower_body_ids is not None
+ self.p = p
+ self.upper_body_ids = upper_body_ids
+ self.lower_body_ids = lower_body_ids
+
+ def __call__(self, image, target):
+ if random.random() < self.p:
+ kps = target["keypoints"]
+ vis = target["visible"]
+ upper_kps = []
+ lower_kps = []
+
+            # split the visible keypoints into upper- and lower-body groups
+ for i, v in enumerate(vis):
+ if v > 0.5:
+ if i in self.upper_body_ids:
+ upper_kps.append(kps[i])
+ else:
+ lower_kps.append(kps[i])
+
+            # pick the upper or the lower body with 50% probability
+ if random.random() < 0.5:
+ selected_kps = upper_kps
+ else:
+ selected_kps = lower_kps
+
+            # if there are too few keypoints, leave the target unchanged
+ if len(selected_kps) > 2:
+ selected_kps = np.array(selected_kps, dtype=np.float32)
+ xmin, ymin = np.min(selected_kps, axis=0).tolist()
+ xmax, ymax = np.max(selected_kps, axis=0).tolist()
+ w = xmax - xmin
+ h = ymax - ymin
+ if w > 1 and h > 1:
+                    # enlarge w and h a bit so that the keypoints do not sit right at the crop border
+ xmin, ymin, w, h = scale_box(xmin, ymin, w, h, (1.5, 1.5))
+ target["box"] = [xmin, ymin, w, h]
+
+ return image, target
+
+
+class AffineTransform(object):
+ """scale+rotation"""
+ def __init__(self,
+ scale: Tuple[float, float] = None, # e.g. (0.65, 1.35)
+ rotation: Tuple[int, int] = None, # e.g. (-45, 45)
+ fixed_size: Tuple[int, int] = (256, 192)):
+ self.scale = scale
+ self.rotation = rotation
+ self.fixed_size = fixed_size
+
+ def __call__(self, img, target):
+ src_xmin, src_ymin, src_xmax, src_ymax = adjust_box(*target["box"], fixed_size=self.fixed_size)
+ src_w = src_xmax - src_xmin
+ src_h = src_ymax - src_ymin
+ src_center = np.array([(src_xmin + src_xmax) / 2, (src_ymin + src_ymax) / 2])
+ src_p2 = src_center + np.array([0, -src_h / 2]) # top middle
+ src_p3 = src_center + np.array([src_w / 2, 0]) # right middle
+
+ dst_center = np.array([(self.fixed_size[1] - 1) / 2, (self.fixed_size[0] - 1) / 2])
+ dst_p2 = np.array([(self.fixed_size[1] - 1) / 2, 0]) # top middle
+ dst_p3 = np.array([self.fixed_size[1] - 1, (self.fixed_size[0] - 1) / 2]) # right middle
+
+ if self.scale is not None:
+ scale = random.uniform(*self.scale)
+ src_w = src_w * scale
+ src_h = src_h * scale
+ src_p2 = src_center + np.array([0, -src_h / 2]) # top middle
+ src_p3 = src_center + np.array([src_w / 2, 0]) # right middle
+
+ if self.rotation is not None:
+            angle = random.randint(*self.rotation)  # angle in degrees
+            angle = angle / 180 * math.pi  # convert to radians
+ src_p2 = src_center + np.array([src_h / 2 * math.sin(angle), -src_h / 2 * math.cos(angle)])
+ src_p3 = src_center + np.array([src_w / 2 * math.cos(angle), src_w / 2 * math.sin(angle)])
+
+ src = np.stack([src_center, src_p2, src_p3]).astype(np.float32)
+ dst = np.stack([dst_center, dst_p2, dst_p3]).astype(np.float32)
+
+        trans = cv2.getAffineTransform(src, dst)  # compute the forward affine transform matrix
+        dst /= 4  # the predicted heatmap is 1/4 the size of the network input
+        reverse_trans = cv2.getAffineTransform(dst, src)  # compute the inverse affine transform to map predictions back later
+
+        # apply the affine transform to the image
+ resize_img = cv2.warpAffine(img,
+ trans,
+ tuple(self.fixed_size[::-1]), # [w, h]
+ flags=cv2.INTER_LINEAR)
+
+ if "keypoints" in target:
+ kps = target["keypoints"]
+ mask = np.logical_and(kps[:, 0] != 0, kps[:, 1] != 0)
+ kps[mask] = affine_points(kps[mask], trans)
+ target["keypoints"] = kps
+
+ # import matplotlib.pyplot as plt
+ # from draw_utils import draw_keypoints
+ # resize_img = draw_keypoints(resize_img, target["keypoints"])
+ # plt.imshow(resize_img)
+ # plt.show()
+
+ target["trans"] = trans
+ target["reverse_trans"] = reverse_trans
+ return resize_img, target
+
+
+class RandomHorizontalFlip(object):
+    """Randomly flip the input image horizontally; this transform must be applied after AffineTransform."""
+ def __init__(self, p: float = 0.5, matched_parts: list = None):
+ assert matched_parts is not None
+ self.p = p
+ self.matched_parts = matched_parts
+
+ def __call__(self, image, target):
+ if random.random() < self.p:
+ # [h, w, c]
+ image = np.ascontiguousarray(np.flip(image, axis=[1]))
+ keypoints = target["keypoints"]
+ visible = target["visible"]
+ width = image.shape[1]
+
+ # Flip horizontal
+ keypoints[:, 0] = width - keypoints[:, 0] - 1
+
+ # Change left-right parts
+ for pair in self.matched_parts:
+ keypoints[pair[0], :], keypoints[pair[1], :] = \
+ keypoints[pair[1], :], keypoints[pair[0], :].copy()
+
+ visible[pair[0]], visible[pair[1]] = \
+ visible[pair[1]], visible[pair[0]].copy()
+
+ target["keypoints"] = keypoints
+ target["visible"] = visible
+
+ return image, target
+
+
+class KeypointToHeatMap(object):
+ def __init__(self,
+ heatmap_hw: Tuple[int, int] = (256 // 4, 192 // 4),
+ gaussian_sigma: int = 2,
+ keypoints_weights=None):
+ self.heatmap_hw = heatmap_hw
+ self.sigma = gaussian_sigma
+ self.kernel_radius = self.sigma * 3
+ self.use_kps_weights = False if keypoints_weights is None else True
+ self.kps_weights = keypoints_weights
+
+ # generate gaussian kernel(not normalized)
+ kernel_size = 2 * self.kernel_radius + 1
+ kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
+ x_center = y_center = kernel_size // 2
+ for x in range(kernel_size):
+ for y in range(kernel_size):
+ kernel[y, x] = np.exp(-((x - x_center) ** 2 + (y - y_center) ** 2) / (2 * self.sigma ** 2))
+ # print(kernel)
+
+ self.kernel = kernel
+
+ def __call__(self, image, target):
+ kps = target["keypoints"]
+ num_kps = kps.shape[0]
+ kps_weights = np.ones((num_kps,), dtype=np.float32)
+ if "visible" in target:
+ visible = target["visible"]
+ kps_weights = visible
+
+ heatmap = np.zeros((num_kps, self.heatmap_hw[0], self.heatmap_hw[1]), dtype=np.float32)
+        heatmap_kps = (kps / 4 + 0.5).astype(np.int64)  # round to the nearest heatmap cell (np.int was removed from NumPy)
+ for kp_id in range(num_kps):
+ v = kps_weights[kp_id]
+ if v < 0.5:
+                # skip keypoints whose visibility is very low
+ continue
+
+ x, y = heatmap_kps[kp_id]
+ ul = [x - self.kernel_radius, y - self.kernel_radius] # up-left x,y
+ br = [x + self.kernel_radius, y + self.kernel_radius] # bottom-right x,y
+            # if the window of radius kernel_radius centered at (x, y) has no overlap with the heatmap,
+            # ignore this keypoint (the check is not strict)
+            if ul[0] > self.heatmap_hw[1] - 1 or \
+                    ul[1] > self.heatmap_hw[0] - 1 or \
+                    br[0] < 0 or \
+                    br[1] < 0:
+ kps_weights[kp_id] = 0
+ continue
+
+ # Usable gaussian range
+            # valid range of the gaussian kernel (in kernel coordinates)
+ g_x = (max(0, -ul[0]), min(br[0], self.heatmap_hw[1] - 1) - ul[0])
+ g_y = (max(0, -ul[1]), min(br[1], self.heatmap_hw[0] - 1) - ul[1])
+ # image range
+            # valid range within the heatmap (in heatmap coordinates)
+ img_x = (max(0, ul[0]), min(br[0], self.heatmap_hw[1] - 1))
+ img_y = (max(0, ul[1]), min(br[1], self.heatmap_hw[0] - 1))
+
+ if kps_weights[kp_id] > 0.5:
+                # copy the valid part of the gaussian kernel into the corresponding heatmap region
+ heatmap[kp_id][img_y[0]:img_y[1] + 1, img_x[0]:img_x[1] + 1] = \
+ self.kernel[g_y[0]:g_y[1] + 1, g_x[0]:g_x[1] + 1]
+
+ if self.use_kps_weights:
+ kps_weights = np.multiply(kps_weights, self.kps_weights)
+
+ # plot_heatmap(image, heatmap, kps, kps_weights)
+
+ target["heatmap"] = torch.as_tensor(heatmap, dtype=torch.float32)
+ target["kps_weights"] = torch.as_tensor(kps_weights, dtype=torch.float32)
+
+ return image, target
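
A sketch of how these transforms compose into a training pipeline. The flip pairs, body-part ids and keypoint weights below are placeholders; in the project they come from `person_keypoints.json`:

```python
import transforms  # transforms.py from this project

fixed_size = (256, 192)            # (h, w) network input
heatmap_hw = (256 // 4, 192 // 4)  # heatmaps are 1/4 of the input resolution
flip_pairs = [[1, 2], [3, 4]]      # placeholder left/right keypoint index pairs
kps_weights = [1.0] * 17           # placeholder per-keypoint loss weights

train_transform = transforms.Compose([
    transforms.HalfBody(0.3, upper_body_ids=list(range(11)), lower_body_ids=list(range(11, 17))),
    transforms.AffineTransform(scale=(0.65, 1.35), rotation=(-45, 45), fixed_size=fixed_size),
    transforms.RandomHorizontalFlip(0.5, matched_parts=flip_pairs),
    transforms.KeypointToHeatMap(heatmap_hw=heatmap_hw, gaussian_sigma=2, keypoints_weights=kps_weights),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# image, target = train_transform(image, target)  # target needs "box", "keypoints", "visible"
```
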
diff --git a/pytorch_keypoint/HRNet/validation.py b/pytorch_keypoint/HRNet/validation.py
new file mode 100644
index 000000000..63d7611d0
--- /dev/null
+++ b/pytorch_keypoint/HRNet/validation.py
@@ -0,0 +1,205 @@
+"""
+This script loads trained model weights and computes COCO metrics on the validation/test set.
+"""
+
+import os
+import json
+
+import torch
+from tqdm import tqdm
+import numpy as np
+
+from model import HighResolutionNet
+from train_utils import EvalCOCOMetric
+from my_dataset_coco import CocoKeypoint
+import transforms
+
+
+def summarize(self, catId=None):
+ """
+ Compute and display summary metrics for evaluation results.
+    Note this function can *only* be applied on the default parameter setting
+ """
+
+ def _summarize(ap=1, iouThr=None, areaRng='all', maxDets=100):
+ p = self.params
+ iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
+ titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
+ typeStr = '(AP)' if ap == 1 else '(AR)'
+ iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
+ if iouThr is None else '{:0.2f}'.format(iouThr)
+
+ aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
+ mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
+
+ if ap == 1:
+ # dimension of precision: [TxRxKxAxM]
+ s = self.eval['precision']
+ # IoU
+ if iouThr is not None:
+ t = np.where(iouThr == p.iouThrs)[0]
+ s = s[t]
+
+ if isinstance(catId, int):
+ s = s[:, :, catId, aind, mind]
+ else:
+ s = s[:, :, :, aind, mind]
+
+ else:
+ # dimension of recall: [TxKxAxM]
+ s = self.eval['recall']
+ if iouThr is not None:
+ t = np.where(iouThr == p.iouThrs)[0]
+ s = s[t]
+
+ if isinstance(catId, int):
+ s = s[:, catId, aind, mind]
+ else:
+ s = s[:, :, aind, mind]
+
+ if len(s[s > -1]) == 0:
+ mean_s = -1
+ else:
+ mean_s = np.mean(s[s > -1])
+
+ print_string = iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s)
+ return mean_s, print_string
+
+    if not self.eval:
+        raise Exception('Please run accumulate() first')
+
+    stats, print_list = [0] * 10, [""] * 10
+ stats[0], print_list[0] = _summarize(1, maxDets=20)
+ stats[1], print_list[1] = _summarize(1, maxDets=20, iouThr=.5)
+ stats[2], print_list[2] = _summarize(1, maxDets=20, iouThr=.75)
+ stats[3], print_list[3] = _summarize(1, maxDets=20, areaRng='medium')
+ stats[4], print_list[4] = _summarize(1, maxDets=20, areaRng='large')
+ stats[5], print_list[5] = _summarize(0, maxDets=20)
+ stats[6], print_list[6] = _summarize(0, maxDets=20, iouThr=.5)
+ stats[7], print_list[7] = _summarize(0, maxDets=20, iouThr=.75)
+ stats[8], print_list[8] = _summarize(0, maxDets=20, areaRng='medium')
+ stats[9], print_list[9] = _summarize(0, maxDets=20, areaRng='large')
+
+ print_info = "\n".join(print_list)
+
+ return stats, print_info
+
+
+def save_info(coco_evaluator,
+ save_name: str = "record_mAP.txt"):
+ # calculate COCO info for all keypoints
+ coco_stats, print_coco = summarize(coco_evaluator)
+
+    # save the validation results to a txt file
+ with open(save_name, "w") as f:
+ record_lines = ["COCO results:", print_coco]
+ f.write("\n".join(record_lines))
+
+
+def main(args):
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+    print("Using {} device.".format(device.type))
+
+ data_transform = {
+ "val": transforms.Compose([
+ transforms.AffineTransform(scale=(1.25, 1.25), fixed_size=args.resize_hw),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ ])
+ }
+
+ # read class_indict
+ label_json_path = args.label_json_path
+    assert os.path.exists(label_json_path), "json file {} does not exist.".format(label_json_path)
+ with open(label_json_path, 'r') as f:
+ person_coco_info = json.load(f)
+
+ data_root = args.data_path
+
+    # note: the dataset provides a custom collate_fn, because each sample contains both the image and its targets and cannot be batched with the default collate
+ batch_size = args.batch_size
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
+ print('Using %g dataloader workers' % nw)
+
+ # load validation data set
+ val_dataset = CocoKeypoint(data_root, "val", transforms=data_transform["val"], det_json_path=None)
+ # VOCdevkit -> VOC2012 -> ImageSets -> Main -> val.txt
+ # val_dataset = VOCInstances(data_root, year="2012", txt_name="val.txt", transforms=data_transform["val"])
+ val_dataset_loader = torch.utils.data.DataLoader(val_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=val_dataset.collate_fn)
+
+ # create model
+ model = HighResolutionNet()
+
+    # load your own trained model weights
+ weights_path = args.weights_path
+ assert os.path.exists(weights_path), "not found {} file.".format(weights_path)
+ model.load_state_dict(torch.load(weights_path, map_location='cpu'))
+ # print(model)
+ model.to(device)
+
+ # evaluate on the val dataset
+ key_metric = EvalCOCOMetric(val_dataset.coco, "keypoints", "key_results.json")
+ model.eval()
+ with torch.no_grad():
+ for images, targets in tqdm(val_dataset_loader, desc="validation..."):
+            # move the images to the specified device
+ images = images.to(device)
+
+ # inference
+ outputs = model(images)
+ if args.flip:
+ flipped_images = transforms.flip_images(images)
+ flipped_outputs = model(flipped_images)
+ flipped_outputs = transforms.flip_back(flipped_outputs, person_coco_info["flip_pairs"])
+ # feature is not aligned, shift flipped heatmap for higher accuracy
+ # https://github.com/leoxiaobin/deep-high-resolution-net.pytorch/issues/22
+ flipped_outputs[..., 1:] = flipped_outputs.clone()[..., 0:-1]
+ outputs = (outputs + flipped_outputs) * 0.5
+
+ # decode keypoint
+ reverse_trans = [t["reverse_trans"] for t in targets]
+ outputs = transforms.get_final_preds(outputs, reverse_trans, post_processing=True)
+
+ key_metric.update(targets, outputs)
+
+ key_metric.synchronize_results()
+ key_metric.evaluate()
+
+ save_info(key_metric.coco_evaluator, "keypoint_record_mAP.txt")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description=__doc__)
+
+    # device type to use
+    parser.add_argument('--device', default='cuda:0', help='device')
+
+    parser.add_argument('--resize-hw', nargs='+', type=int, default=[256, 192], help="resize (h, w) used for prediction")
+    # whether to additionally evaluate on horizontally flipped images
+    parser.add_argument('--flip', type=bool, default=True, help='whether using flipped images')
+
+    # root directory of the dataset
+    parser.add_argument('--data-path', default='/data/coco2017', help='dataset root')
+
+    # trained weights file
+    parser.add_argument('--weights-path', default='./pose_hrnet_w32_256x192.pth', type=str, help='training weights')
+
+    # batch size
+    parser.add_argument('--batch-size', default=1, type=int, metavar='N',
+                        help='batch size used for validation')
+    # person keypoints metadata json (keypoint names, flip pairs, etc.)
+    parser.add_argument('--label-json-path', type=str, default="person_keypoints.json")
+    # person detection results on the val set provided by the original project; set to None to use the GT information instead
+    parser.add_argument('--person-det', type=str, default="./COCO_val2017_detections_AP_H_56_person.json")
+
+ args = parser.parse_args()
+
+ main(args)
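
If `key_results.json` already exists, the keypoint metrics can also be recomputed offline with pycocotools directly, mirroring what `EvalCOCOMetric.evaluate` does; the annotation path below is a placeholder:

```python
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

coco_gt = COCO("/data/coco2017/annotations/person_keypoints_val2017.json")  # placeholder path
coco_dt = coco_gt.loadRes("key_results.json")  # file written by EvalCOCOMetric

coco_eval = COCOeval(cocoGt=coco_gt, cocoDt=coco_dt, iouType="keypoints")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
print(coco_eval.stats)  # AP/AR values at the thresholds printed above
```
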
diff --git a/pytorch_object_detection/faster_rcnn/README.md b/pytorch_object_detection/faster_rcnn/README.md
index c674ca58a..08ac15cf0 100644
--- a/pytorch_object_detection/faster_rcnn/README.md
+++ b/pytorch_object_detection/faster_rcnn/README.md
@@ -6,10 +6,10 @@
## 环境配置:
* Python3.6/3.7/3.8
* Pytorch1.7.1(注意:必须是1.6.0或以上,因为使用官方提供的混合精度训练1.6.0后才支持)
-* pycocotools(Linux:```pip install pycocotools```; Windows:```pip install pycocotools-windows```(不需要额外安装vs))
+* pycocotools(Linux:`pip install pycocotools`; Windows:`pip install pycocotools-windows`(不需要额外安装vs))
* Ubuntu或Centos(不建议Windows)
* 最好使用GPU训练
-* 详细环境配置见```requirements.txt```
+* 详细环境配置见`requirements.txt`
## 文件结构:
```
@@ -26,10 +26,11 @@
```
## 预训练权重下载地址(下载后放入backbone文件夹中):
-* MobileNetV2 backbone: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
-* ResNet50+FPN backbone: https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
-* 注意,下载的预训练权重记得要重命名,比如在train_resnet50_fpn.py中读取的是```fasterrcnn_resnet50_fpn_coco.pth```文件,
- 不是```fasterrcnn_resnet50_fpn_coco-258fb6c6.pth```
+* MobileNetV2 weights(下载后重命名为`mobilenet_v2.pth`,然后放到`backbone`文件夹下): https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
+* Resnet50 weights(下载后重命名为`resnet50.pth`,然后放到`backbone`文件夹下): https://download.pytorch.org/models/resnet50-0676ba61.pth
+* ResNet50+FPN weights: https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
+* 注意,下载的预训练权重记得要重命名,比如在train_resnet50_fpn.py中读取的是`fasterrcnn_resnet50_fpn_coco.pth`文件,
+ 不是`fasterrcnn_resnet50_fpn_coco-258fb6c6.pth`,然后放到当前项目根目录下即可。
## 数据集,本例程使用的是PASCAL VOC2012数据集
@@ -42,16 +43,17 @@
* 确保提前下载好对应预训练模型权重
* 若要训练mobilenetv2+fasterrcnn,直接使用train_mobilenet.py训练脚本
* 若要训练resnet50+fpn+fasterrcnn,直接使用train_resnet50_fpn.py训练脚本
-* 若要使用多GPU训练,使用```python -m torch.distributed.launch --nproc_per_node=8 --use_env train_multi_GPU.py```指令,```nproc_per_node```参数为使用GPU数量
-* 如果想指定使用哪些GPU设备可在指令前加上```CUDA_VISIBLE_DEVICES=0,3```(例如我只要使用设备中的第1块和第4块GPU设备)
-* ```CUDA_VISIBLE_DEVICES=0,3 python -m torch.distributed.launch --nproc_per_node=2 --use_env train_multi_GPU.py```
+* 若要使用多GPU训练,使用`python -m torch.distributed.launch --nproc_per_node=8 --use_env train_multi_GPU.py`指令,`nproc_per_node`参数为使用GPU数量
+* 如果想指定使用哪些GPU设备可在指令前加上`CUDA_VISIBLE_DEVICES=0,3`(例如我只要使用设备中的第1块和第4块GPU设备)
+* `CUDA_VISIBLE_DEVICES=0,3 python -m torch.distributed.launch --nproc_per_node=2 --use_env train_multi_GPU.py`
## 注意事项
-* 在使用训练脚本时,注意要将'--data-path'(VOC_root)设置为自己存放'VOCdevkit'文件夹所在的**根目录**
+* 在使用训练脚本时,注意要将`--data-path`(VOC_root)设置为自己存放`VOCdevkit`文件夹所在的**根目录**
* 由于带有FPN结构的Faster RCNN很吃显存,如果GPU的显存不够(如果batch_size小于8的话)建议在create_model函数中使用默认的norm_layer,
即不传递norm_layer变量,默认去使用FrozenBatchNorm2d(即不会去更新参数的bn层),使用中发现效果也很好。
-* 在使用预测脚本时,要将'train_weights'设置为你自己生成的权重路径。
-* 使用validation文件时,注意确保你的验证集或者测试集中必须包含每个类别的目标,并且使用时只需要修改'--num-classes'、'--data-path'和'--weights'即可,其他代码尽量不要改动
+* 训练过程中保存的`results.txt`是每个epoch在验证集上的COCO指标,前12个值是COCO指标,后面两个值是训练平均损失以及学习率
+* 在使用预测脚本时,要将`train_weights`设置为你自己生成的权重路径。
+* 使用validation文件时,注意确保你的验证集或者测试集中必须包含每个类别的目标,并且使用时只需要修改`--num-classes`、`--data-path`和`--weights-path`即可,其他代码尽量不要改动
## 如果对Faster RCNN原理不是很理解可参考我的bilibili
* https://b23.tv/sXcBSP
diff --git a/pytorch_object_detection/faster_rcnn/backbone/__init__.py b/pytorch_object_detection/faster_rcnn/backbone/__init__.py
index f7559da86..1cedf7584 100644
--- a/pytorch_object_detection/faster_rcnn/backbone/__init__.py
+++ b/pytorch_object_detection/faster_rcnn/backbone/__init__.py
@@ -1,3 +1,4 @@
from .resnet50_fpn_model import resnet50_fpn_backbone
from .mobilenetv2_model import MobileNetV2
from .vgg_model import vgg
+from .feature_pyramid_network import LastLevelMaxPool, BackboneWithFPN
diff --git a/pytorch_object_detection/faster_rcnn/backbone/feature_pyramid_network.py b/pytorch_object_detection/faster_rcnn/backbone/feature_pyramid_network.py
index 636829fc0..450960985 100644
--- a/pytorch_object_detection/faster_rcnn/backbone/feature_pyramid_network.py
+++ b/pytorch_object_detection/faster_rcnn/backbone/feature_pyramid_network.py
@@ -8,6 +8,59 @@
from torch.jit.annotations import Tuple, List, Dict
+class IntermediateLayerGetter(nn.ModuleDict):
+ """
+ Module wrapper that returns intermediate layers from a model
+ It has a strong assumption that the modules have been registered
+ into the model in the same order as they are used.
+ This means that one should **not** reuse the same nn.Module
+ twice in the forward if you want this to work.
+ Additionally, it is only able to query submodules that are directly
+ assigned to the model. So if `model` is passed, `model.feature1` can
+ be returned, but not `model.feature1.layer2`.
+ Arguments:
+ model (nn.Module): model on which we will extract the features
+ return_layers (Dict[name, new_name]): a dict containing the names
+ of the modules for which the activations will be returned as
+ the key of the dict, and the value of the dict is the name
+ of the returned activation (which the user can specify).
+ """
+ __annotations__ = {
+ "return_layers": Dict[str, str],
+ }
+
+ def __init__(self, model, return_layers):
+ if not set(return_layers).issubset([name for name, _ in model.named_children()]):
+ raise ValueError("return_layers are not present in model")
+
+ orig_return_layers = return_layers
+ return_layers = {str(k): str(v) for k, v in return_layers.items()}
+ layers = OrderedDict()
+
+        # iterate over the model's children in order and store them in an ordered dict,
+        # keeping only layer4 and everything before it; later structures are discarded
+ for name, module in model.named_children():
+ layers[name] = module
+ if name in return_layers:
+ del return_layers[name]
+ if not return_layers:
+ break
+
+ super().__init__(layers)
+ self.return_layers = orig_return_layers
+
+ def forward(self, x):
+ out = OrderedDict()
+        # run every submodule in order and collect the outputs
+        # of layer1, layer2, layer3 and layer4
+ for name, module in self.items():
+ x = module(x)
+ if name in self.return_layers:
+ out_name = self.return_layers[name]
+ out[out_name] = x
+ return out
+
+
class FeaturePyramidNetwork(nn.Module):
"""
Module that adds a FPN from on top of a set of feature maps. This is based on
@@ -27,7 +80,7 @@ class FeaturePyramidNetwork(nn.Module):
"""
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
- super(FeaturePyramidNetwork, self).__init__()
+ super().__init__()
# 用来调整resnet特征矩阵(layer1,2,3,4)的channel(kernel_size=1)
self.inner_blocks = nn.ModuleList()
# 对调整后的特征矩阵使用3x3的卷积核来得到对应的预测特征矩阵
@@ -48,8 +101,7 @@ def __init__(self, in_channels_list, out_channels, extra_blocks=None):
self.extra_blocks = extra_blocks
- def get_result_from_inner_blocks(self, x, idx):
- # type: (Tensor, int) -> Tensor
+ def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.inner_blocks[idx](x),
but torchscript doesn't support this yet
@@ -65,8 +117,7 @@ def get_result_from_inner_blocks(self, x, idx):
i += 1
return out
- def get_result_from_layer_blocks(self, x, idx):
- # type: (Tensor, int) -> Tensor
+ def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.layer_blocks[idx](x),
but torchscript doesn't support this yet
@@ -82,8 +133,7 @@ def get_result_from_layer_blocks(self, x, idx):
i += 1
return out
- def forward(self, x):
- # type: (Dict[str, Tensor]) -> Dict[str, Tensor]
+ def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Computes the FPN for a set of feature maps.
Arguments:
@@ -127,8 +177,59 @@ class LastLevelMaxPool(torch.nn.Module):
Applies a max_pool2d on top of the last feature map
"""
- def forward(self, x, y, names):
- # type: (List[Tensor], List[Tensor], List[str]) -> Tuple[List[Tensor], List[str]]
+ def forward(self, x: List[Tensor], y: List[Tensor], names: List[str]) -> Tuple[List[Tensor], List[str]]:
names.append("pool")
x.append(F.max_pool2d(x[-1], 1, 2, 0)) # input, kernel_size, stride, padding
return x, names
+
+
+class BackboneWithFPN(nn.Module):
+ """
+ Adds a FPN on top of a model.
+ Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
+ extract a submodel that returns the feature maps specified in return_layers.
+    The same limitations of IntermediateLayerGetter apply here.
+ Arguments:
+ backbone (nn.Module)
+ return_layers (Dict[name, new_name]): a dict containing the names
+ of the modules for which the activations will be returned as
+ the key of the dict, and the value of the dict is the name
+ of the returned activation (which the user can specify).
+ in_channels_list (List[int]): number of channels for each feature map
+ that is returned, in the order they are present in the OrderedDict
+ out_channels (int): number of channels in the FPN.
+ extra_blocks: ExtraFPNBlock
+ Attributes:
+ out_channels (int): the number of channels in the FPN
+ """
+
+ def __init__(self,
+ backbone: nn.Module,
+ return_layers=None,
+ in_channels_list=None,
+ out_channels=256,
+ extra_blocks=None,
+ re_getter=True):
+ super().__init__()
+
+ if extra_blocks is None:
+ extra_blocks = LastLevelMaxPool()
+
+ if re_getter is True:
+ assert return_layers is not None
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+ else:
+ self.body = backbone
+
+ self.fpn = FeaturePyramidNetwork(
+ in_channels_list=in_channels_list,
+ out_channels=out_channels,
+ extra_blocks=extra_blocks,
+ )
+
+ self.out_channels = out_channels
+
+ def forward(self, x):
+ x = self.body(x)
+ x = self.fpn(x)
+ return x
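
A sketch of wrapping a plain torchvision ResNet-50 with `BackboneWithFPN` (assumes a recent torchvision and running from the faster_rcnn project root); the channel list follows ResNet-50's layer1-4 outputs:

```python
import torch
import torchvision
from backbone import BackboneWithFPN  # exported via backbone/__init__.py above

resnet = torchvision.models.resnet50(weights=None)
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
in_channels_list = [256, 512, 1024, 2048]

backbone = BackboneWithFPN(resnet,
                           return_layers=return_layers,
                           in_channels_list=in_channels_list,
                           out_channels=256)

feats = backbone(torch.randn(1, 3, 224, 224))
print([(k, tuple(v.shape)) for k, v in feats.items()])
# four FPN levels ('0'..'3') with 256 channels each, plus the extra max-pooled 'pool' level
```
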
diff --git a/pytorch_object_detection/faster_rcnn/backbone/resnet50_fpn_model.py b/pytorch_object_detection/faster_rcnn/backbone/resnet50_fpn_model.py
index 8c796cfac..b15930765 100644
--- a/pytorch_object_detection/faster_rcnn/backbone/resnet50_fpn_model.py
+++ b/pytorch_object_detection/faster_rcnn/backbone/resnet50_fpn_model.py
@@ -1,19 +1,17 @@
import os
-from collections import OrderedDict
import torch
import torch.nn as nn
-from torch.jit.annotations import List, Dict
from torchvision.ops.misc import FrozenBatchNorm2d
-from .feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
+from .feature_pyramid_network import BackboneWithFPN, LastLevelMaxPool
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm_layer=None):
- super(Bottleneck, self).__init__()
+ super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
@@ -56,7 +54,7 @@ def forward(self, x):
class ResNet(nn.Module):
def __init__(self, block, blocks_num, num_classes=1000, include_top=True, norm_layer=None):
- super(ResNet, self).__init__()
+ super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
@@ -136,100 +134,6 @@ def overwrite_eps(model, eps):
module.eps = eps
-class IntermediateLayerGetter(nn.ModuleDict):
- """
- Module wrapper that returns intermediate layers from a model
- It has a strong assumption that the modules have been registered
- into the model in the same order as they are used.
- This means that one should **not** reuse the same nn.Module
- twice in the forward if you want this to work.
- Additionally, it is only able to query submodules that are directly
- assigned to the model. So if `model` is passed, `model.feature1` can
- be returned, but not `model.feature1.layer2`.
- Arguments:
- model (nn.Module): model on which we will extract the features
- return_layers (Dict[name, new_name]): a dict containing the names
- of the modules for which the activations will be returned as
- the key of the dict, and the value of the dict is the name
- of the returned activation (which the user can specify).
- """
- __annotations__ = {
- "return_layers": Dict[str, str],
- }
-
- def __init__(self, model, return_layers):
- if not set(return_layers).issubset([name for name, _ in model.named_children()]):
- raise ValueError("return_layers are not present in model")
-
- orig_return_layers = return_layers
- return_layers = {str(k): str(v) for k, v in return_layers.items()}
- layers = OrderedDict()
-
- # 遍历模型子模块按顺序存入有序字典
- # 只保存layer4及其之前的结构,舍去之后不用的结构
- for name, module in model.named_children():
- layers[name] = module
- if name in return_layers:
- del return_layers[name]
- if not return_layers:
- break
-
- super(IntermediateLayerGetter, self).__init__(layers)
- self.return_layers = orig_return_layers
-
- def forward(self, x):
- out = OrderedDict()
- # 依次遍历模型的所有子模块,并进行正向传播,
- # 收集layer1, layer2, layer3, layer4的输出
- for name, module in self.items():
- x = module(x)
- if name in self.return_layers:
- out_name = self.return_layers[name]
- out[out_name] = x
- return out
-
-
-class BackboneWithFPN(nn.Module):
- """
- Adds a FPN on top of a model.
- Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
- extract a submodel that returns the feature maps specified in return_layers.
- The same limitations of IntermediatLayerGetter apply here.
- Arguments:
- backbone (nn.Module)
- return_layers (Dict[name, new_name]): a dict containing the names
- of the modules for which the activations will be returned as
- the key of the dict, and the value of the dict is the name
- of the returned activation (which the user can specify).
- in_channels_list (List[int]): number of channels for each feature map
- that is returned, in the order they are present in the OrderedDict
- out_channels (int): number of channels in the FPN.
- extra_blocks: ExtraFPNBlock
- Attributes:
- out_channels (int): the number of channels in the FPN
- """
-
- def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
- super(BackboneWithFPN, self).__init__()
-
- if extra_blocks is None:
- extra_blocks = LastLevelMaxPool()
-
- self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
- self.fpn = FeaturePyramidNetwork(
- in_channels_list=in_channels_list,
- out_channels=out_channels,
- extra_blocks=extra_blocks,
- )
-
- self.out_channels = out_channels
-
- def forward(self, x):
- x = self.body(x)
- x = self.fpn(x)
- return x
-
-
def resnet50_fpn_backbone(pretrain_path="",
norm_layer=FrozenBatchNorm2d, # FrozenBatchNorm2d的功能与BatchNorm2d类似,但参数无法更新
trainable_layers=3,
diff --git a/pytorch_object_detection/faster_rcnn/change_backbone_with_fpn.py b/pytorch_object_detection/faster_rcnn/change_backbone_with_fpn.py
new file mode 100644
index 000000000..4ee20fd44
--- /dev/null
+++ b/pytorch_object_detection/faster_rcnn/change_backbone_with_fpn.py
@@ -0,0 +1,255 @@
+import os
+import datetime
+
+import torch
+
+import transforms
+from network_files import FasterRCNN, AnchorsGenerator
+from my_dataset import VOCDataSet
+from train_utils import GroupedBatchSampler, create_aspect_ratio_groups
+from train_utils import train_eval_utils as utils
+from backbone import BackboneWithFPN, LastLevelMaxPool
+
+
+def create_model(num_classes):
+ import torchvision
+ from torchvision.models.feature_extraction import create_feature_extractor
+
+ # --- mobilenet_v3_large fpn backbone --- #
+ backbone = torchvision.models.mobilenet_v3_large(pretrained=True)
+ # print(backbone)
+ return_layers = {"features.6": "0", # stride 8
+ "features.12": "1", # stride 16
+ "features.16": "2"} # stride 32
+    # channels of each feature map fed to the FPN
+ in_channels_list = [40, 112, 960]
+ new_backbone = create_feature_extractor(backbone, return_layers)
+ # img = torch.randn(1, 3, 224, 224)
+ # outputs = new_backbone(img)
+ # [print(f"{k} shape: {v.shape}") for k, v in outputs.items()]
+
+ # --- efficientnet_b0 fpn backbone --- #
+ # backbone = torchvision.models.efficientnet_b0(pretrained=True)
+ # # print(backbone)
+ # return_layers = {"features.3": "0", # stride 8
+ # "features.4": "1", # stride 16
+ # "features.8": "2"} # stride 32
+    # # channels of each feature map fed to the FPN
+ # in_channels_list = [40, 80, 1280]
+ # new_backbone = create_feature_extractor(backbone, return_layers)
+ # # img = torch.randn(1, 3, 224, 224)
+ # # outputs = new_backbone(img)
+ # # [print(f"{k} shape: {v.shape}") for k, v in outputs.items()]
+
+ backbone_with_fpn = BackboneWithFPN(new_backbone,
+ return_layers=return_layers,
+ in_channels_list=in_channels_list,
+ out_channels=256,
+ extra_blocks=LastLevelMaxPool(),
+ re_getter=False)
+
+ anchor_sizes = ((64,), (128,), (256,), (512,))
+ aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+ anchor_generator = AnchorsGenerator(sizes=anchor_sizes,
+ aspect_ratios=aspect_ratios)
+
+    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2'],  # feature levels used for RoIAlign pooling
+                                                    output_size=[7, 7],  # output size of the RoIAlign pooled feature maps
+                                                    sampling_ratio=2)  # sampling ratio
+
+ model = FasterRCNN(backbone=backbone_with_fpn,
+ num_classes=num_classes,
+ rpn_anchor_generator=anchor_generator,
+ box_roi_pool=roi_pooler)
+
+ return model
+
+
+def main(args):
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ print("Using {} device training.".format(device.type))
+
+ # 用来保存coco_info的文件
+ results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
+
+ data_transform = {
+ "train": transforms.Compose([transforms.ToTensor(),
+ transforms.RandomHorizontalFlip(0.5)]),
+ "val": transforms.Compose([transforms.ToTensor()])
+ }
+
+ VOC_root = args.data_path
+ # check voc root
+ if os.path.exists(os.path.join(VOC_root, "VOCdevkit")) is False:
+        raise FileNotFoundError("VOCdevkit does not exist in path:'{}'.".format(VOC_root))
+
+ # load train data set
+ # VOCdevkit -> VOC2012 -> ImageSets -> Main -> train.txt
+ train_dataset = VOCDataSet(VOC_root, "2012", data_transform["train"], "train.txt")
+ train_sampler = None
+
+ # 是否按图片相似高宽比采样图片组成batch
+ # 使用的话能够减小训练时所需GPU显存,默认使用
+ if args.aspect_ratio_group_factor >= 0:
+ train_sampler = torch.utils.data.RandomSampler(train_dataset)
+ # 统计所有图像高宽比例在bins区间中的位置索引
+ group_ids = create_aspect_ratio_groups(train_dataset, k=args.aspect_ratio_group_factor)
+ # 每个batch图片从同一高宽比例区间中取
+ train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
+
+ # 注意这里的collate_fn是自定义的,因为读取的数据包括image和targets,不能直接使用默认的方法合成batch
+ batch_size = args.batch_size
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
+ print('Using %g dataloader workers' % nw)
+ if train_sampler:
+ # 如果按照图片高宽比采样图片,dataloader中需要使用batch_sampler
+ train_data_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_sampler=train_batch_sampler,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+ else:
+ train_data_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+
+ # load validation data set
+ # VOCdevkit -> VOC2012 -> ImageSets -> Main -> val.txt
+ val_dataset = VOCDataSet(VOC_root, "2012", data_transform["val"], "val.txt")
+ val_data_set_loader = torch.utils.data.DataLoader(val_dataset,
+ batch_size=1,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=val_dataset.collate_fn)
+
+ # create model num_classes equal background + 20 classes
+ model = create_model(num_classes=args.num_classes + 1)
+ # print(model)
+
+ model.to(device)
+
+ # define optimizer
+ params = [p for p in model.parameters() if p.requires_grad]
+ optimizer = torch.optim.SGD(params,
+ lr=args.lr,
+ momentum=args.momentum,
+ weight_decay=args.weight_decay)
+
+ scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+ # learning rate scheduler
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
+ step_size=3,
+ gamma=0.33)
+
+ # 如果指定了上次训练保存的权重文件地址,则接着上次结果接着训练
+ if args.resume != "":
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ model.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if args.amp and "scaler" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler"])
+ print("the training process from epoch{}...".format(args.start_epoch))
+
+ train_loss = []
+ learning_rate = []
+ val_map = []
+
+ for epoch in range(args.start_epoch, args.epochs):
+ # train for one epoch, printing every 10 iterations
+ mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
+ device=device, epoch=epoch,
+ print_freq=50, warmup=True,
+ scaler=scaler)
+ train_loss.append(mean_loss.item())
+ learning_rate.append(lr)
+
+ # update the learning rate
+ lr_scheduler.step()
+
+ # evaluate on the test dataset
+ coco_info = utils.evaluate(model, val_data_set_loader, device=device)
+
+ # write into txt
+ with open(results_file, "a") as f:
+ # 写入的数据包括coco指标还有loss和learning rate
+ result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
+ txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
+ f.write(txt + "\n")
+
+ val_map.append(coco_info[1]) # pascal mAP
+
+ # save weights
+ save_files = {
+ 'model': model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch}
+ if args.amp:
+ save_files["scaler"] = scaler.state_dict()
+ torch.save(save_files, "./save_weights/resNetFpn-model-{}.pth".format(epoch))
+
+ # plot loss and lr curve
+ if len(train_loss) != 0 and len(learning_rate) != 0:
+ from plot_curve import plot_loss_and_lr
+ plot_loss_and_lr(train_loss, learning_rate)
+
+ # plot mAP curve
+ if len(val_map) != 0:
+ from plot_curve import plot_map
+ plot_map(val_map)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description=__doc__)
+
+ # 训练设备类型
+ parser.add_argument('--device', default='cuda:0', help='device')
+ # 训练数据集的根目录(VOCdevkit)
+ parser.add_argument('--data-path', default='./', help='dataset')
+ # 检测目标类别数(不包含背景)
+ parser.add_argument('--num-classes', default=20, type=int, help='num_classes')
+ # 文件保存地址
+ parser.add_argument('--output-dir', default='./save_weights', help='path where to save')
+ # 若需要接着上次训练,则指定上次训练保存权重文件地址
+ parser.add_argument('--resume', default='', type=str, help='resume from checkpoint')
+ # 指定接着从哪个epoch数开始训练
+ parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
+ # 训练的总epoch数
+ parser.add_argument('--epochs', default=15, type=int, metavar='N',
+ help='number of total epochs to run')
+ # 学习率
+ parser.add_argument('--lr', default=0.005, type=float,
+ help='initial learning rate, 0.02 is the default value for training '
+ 'on 8 gpus and 2 images_per_gpu')
+ # SGD的momentum参数
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+ help='momentum')
+ # SGD的weight_decay参数
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+ metavar='W', help='weight decay (default: 1e-4)',
+ dest='weight_decay')
+ # 训练的batch size
+ parser.add_argument('--batch_size', default=4, type=int, metavar='N',
+ help='batch size when training.')
+ parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
+ # 是否使用混合精度训练(需要GPU支持混合精度)
+ parser.add_argument("--amp", default=False, help="Use torch.cuda.amp for mixed precision training")
+
+ args = parser.parse_args()
+ print(args)
+
+ # 检查保存权重文件夹是否存在,不存在则创建
+ if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+
+ main(args)
diff --git a/pytorch_object_detection/faster_rcnn/change_backbone_without_fpn.py b/pytorch_object_detection/faster_rcnn/change_backbone_without_fpn.py
new file mode 100644
index 000000000..f4c9e5938
--- /dev/null
+++ b/pytorch_object_detection/faster_rcnn/change_backbone_without_fpn.py
@@ -0,0 +1,243 @@
+import os
+import datetime
+
+import torch
+
+import transforms
+from network_files import FasterRCNN, AnchorsGenerator
+from my_dataset import VOCDataSet
+from train_utils import GroupedBatchSampler, create_aspect_ratio_groups
+from train_utils import train_eval_utils as utils
+
+
+def create_model(num_classes):
+ import torchvision
+ from torchvision.models.feature_extraction import create_feature_extractor
+
+ # vgg16
+ backbone = torchvision.models.vgg16_bn(pretrained=True)
+ # print(backbone)
+ backbone = create_feature_extractor(backbone, return_nodes={"features.42": "0"})
+ # out = backbone(torch.rand(1, 3, 224, 224))
+ # print(out["0"].shape)
+ backbone.out_channels = 512
+
+ # resnet50 backbone
+ # backbone = torchvision.models.resnet50(pretrained=True)
+ # # print(backbone)
+ # backbone = create_feature_extractor(backbone, return_nodes={"layer3": "0"})
+ # # out = backbone(torch.rand(1, 3, 224, 224))
+ # # print(out["0"].shape)
+ # backbone.out_channels = 1024
+
+ # EfficientNetB0
+ # backbone = torchvision.models.efficientnet_b0(pretrained=True)
+ # # print(backbone)
+ # backbone = create_feature_extractor(backbone, return_nodes={"features.5": "0"})
+ # # out = backbone(torch.rand(1, 3, 224, 224))
+ # # print(out["0"].shape)
+ # backbone.out_channels = 112
+
+ anchor_generator = AnchorsGenerator(sizes=((32, 64, 128, 256, 512),),
+ aspect_ratios=((0.5, 1.0, 2.0),))
+
+    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],  # feature levels used for RoIAlign pooling
+                                                    output_size=[7, 7],  # output size of the RoIAlign pooled feature maps
+                                                    sampling_ratio=2)  # sampling ratio
+
+ model = FasterRCNN(backbone=backbone,
+ num_classes=num_classes,
+ rpn_anchor_generator=anchor_generator,
+ box_roi_pool=roi_pooler)
+
+ return model
+
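
The node names passed to `create_feature_extractor` (such as `"features.42"` above) can be listed beforehand; a short sketch, assuming torchvision >= 0.11 where `get_graph_node_names` is available:

```python
import torchvision
from torchvision.models.feature_extraction import get_graph_node_names

backbone = torchvision.models.vgg16_bn(pretrained=False)
train_nodes, eval_nodes = get_graph_node_names(backbone)
# inspect candidate return_nodes, e.g. the last few feature layers of vgg16_bn
print(eval_nodes[-10:])
```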
+
+def main(args):
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ print("Using {} device training.".format(device.type))
+
+ # 用来保存coco_info的文件
+ results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
+
+ data_transform = {
+ "train": transforms.Compose([transforms.ToTensor(),
+ transforms.RandomHorizontalFlip(0.5)]),
+ "val": transforms.Compose([transforms.ToTensor()])
+ }
+
+ VOC_root = args.data_path
+ # check voc root
+ if os.path.exists(os.path.join(VOC_root, "VOCdevkit")) is False:
+        raise FileNotFoundError("VOCdevkit does not exist in path:'{}'.".format(VOC_root))
+
+ # load train data set
+ # VOCdevkit -> VOC2012 -> ImageSets -> Main -> train.txt
+ train_dataset = VOCDataSet(VOC_root, "2012", data_transform["train"], "train.txt")
+ train_sampler = None
+
+ # 是否按图片相似高宽比采样图片组成batch
+ # 使用的话能够减小训练时所需GPU显存,默认使用
+ if args.aspect_ratio_group_factor >= 0:
+ train_sampler = torch.utils.data.RandomSampler(train_dataset)
+ # 统计所有图像高宽比例在bins区间中的位置索引
+ group_ids = create_aspect_ratio_groups(train_dataset, k=args.aspect_ratio_group_factor)
+ # 每个batch图片从同一高宽比例区间中取
+ train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
+
+ # 注意这里的collate_fn是自定义的,因为读取的数据包括image和targets,不能直接使用默认的方法合成batch
+ batch_size = args.batch_size
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
+ print('Using %g dataloader workers' % nw)
+ if train_sampler:
+ # 如果按照图片高宽比采样图片,dataloader中需要使用batch_sampler
+ train_data_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_sampler=train_batch_sampler,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+ else:
+ train_data_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+
+ # load validation data set
+ # VOCdevkit -> VOC2012 -> ImageSets -> Main -> val.txt
+ val_dataset = VOCDataSet(VOC_root, "2012", data_transform["val"], "val.txt")
+ val_data_set_loader = torch.utils.data.DataLoader(val_dataset,
+ batch_size=1,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=val_dataset.collate_fn)
+
+ # create model num_classes equal background + 20 classes
+ model = create_model(num_classes=args.num_classes + 1)
+ # print(model)
+
+ model.to(device)
+
+ # define optimizer
+ params = [p for p in model.parameters() if p.requires_grad]
+ optimizer = torch.optim.SGD(params,
+ lr=args.lr,
+ momentum=args.momentum,
+ weight_decay=args.weight_decay)
+
+ scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+ # learning rate scheduler
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
+ step_size=3,
+ gamma=0.33)
+
+ # 如果指定了上次训练保存的权重文件地址,则接着上次结果接着训练
+ if args.resume != "":
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ model.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if args.amp and "scaler" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler"])
+ print("the training process from epoch{}...".format(args.start_epoch))
+
+ train_loss = []
+ learning_rate = []
+ val_map = []
+
+ for epoch in range(args.start_epoch, args.epochs):
+ # train for one epoch, printing every 10 iterations
+ mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
+ device=device, epoch=epoch,
+ print_freq=50, warmup=True,
+ scaler=scaler)
+ train_loss.append(mean_loss.item())
+ learning_rate.append(lr)
+
+ # update the learning rate
+ lr_scheduler.step()
+
+ # evaluate on the test dataset
+ coco_info = utils.evaluate(model, val_data_set_loader, device=device)
+
+ # write into txt
+ with open(results_file, "a") as f:
+ # 写入的数据包括coco指标还有loss和learning rate
+ result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
+ txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
+ f.write(txt + "\n")
+
+ val_map.append(coco_info[1]) # pascal mAP
+
+ # save weights
+ save_files = {
+ 'model': model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch}
+ if args.amp:
+ save_files["scaler"] = scaler.state_dict()
+ torch.save(save_files, "./save_weights/resNetFpn-model-{}.pth".format(epoch))
+
+ # plot loss and lr curve
+ if len(train_loss) != 0 and len(learning_rate) != 0:
+ from plot_curve import plot_loss_and_lr
+ plot_loss_and_lr(train_loss, learning_rate)
+
+ # plot mAP curve
+ if len(val_map) != 0:
+ from plot_curve import plot_map
+ plot_map(val_map)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description=__doc__)
+
+ # 训练设备类型
+ parser.add_argument('--device', default='cuda:0', help='device')
+ # 训练数据集的根目录(VOCdevkit)
+ parser.add_argument('--data-path', default='./', help='dataset')
+ # 检测目标类别数(不包含背景)
+ parser.add_argument('--num-classes', default=20, type=int, help='num_classes')
+ # 文件保存地址
+ parser.add_argument('--output-dir', default='./save_weights', help='path where to save')
+ # 若需要接着上次训练,则指定上次训练保存权重文件地址
+ parser.add_argument('--resume', default='', type=str, help='resume from checkpoint')
+ # 指定接着从哪个epoch数开始训练
+ parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
+ # 训练的总epoch数
+ parser.add_argument('--epochs', default=15, type=int, metavar='N',
+ help='number of total epochs to run')
+ # 学习率
+ parser.add_argument('--lr', default=0.005, type=float,
+ help='initial learning rate, 0.02 is the default value for training '
+ 'on 8 gpus and 2 images_per_gpu')
+ # SGD的momentum参数
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+ help='momentum')
+ # SGD的weight_decay参数
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+ metavar='W', help='weight decay (default: 1e-4)',
+ dest='weight_decay')
+ # 训练的batch size
+ parser.add_argument('--batch_size', default=4, type=int, metavar='N',
+ help='batch size when training.')
+ parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
+ # 是否使用混合精度训练(需要GPU支持混合精度)
+ parser.add_argument("--amp", default=False, help="Use torch.cuda.amp for mixed precision training")
+
+ args = parser.parse_args()
+ print(args)
+
+ # 检查保存权重文件夹是否存在,不存在则创建
+ if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+
+ main(args)
diff --git a/pytorch_object_detection/faster_rcnn/draw_box_utils.py b/pytorch_object_detection/faster_rcnn/draw_box_utils.py
index 1a2926583..835d7f7c1 100644
--- a/pytorch_object_detection/faster_rcnn/draw_box_utils.py
+++ b/pytorch_object_detection/faster_rcnn/draw_box_utils.py
@@ -1,6 +1,7 @@
-import collections
+from PIL.Image import Image, fromarray
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
+from PIL import ImageColor
import numpy as np
STANDARD_COLORS = [
@@ -30,66 +31,123 @@
]
-def filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map):
- for i in range(boxes.shape[0]):
- if scores[i] > thresh:
- box = tuple(boxes[i].tolist()) # numpy -> list -> tuple
- if classes[i] in category_index.keys():
- class_name = category_index[classes[i]]
- else:
- class_name = 'N/A'
- display_str = str(class_name)
- display_str = '{}: {}%'.format(display_str, int(100 * scores[i]))
- box_to_display_str_map[box].append(display_str)
- box_to_color_map[box] = STANDARD_COLORS[
- classes[i] % len(STANDARD_COLORS)]
- else:
- break # 网络输出概率已经排序过,当遇到一个不满足后面的肯定不满足
-
-
-def draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color):
+def draw_text(draw,
+ box: list,
+ cls: int,
+ score: float,
+ category_index: dict,
+ color: str,
+ font: str = 'arial.ttf',
+ font_size: int = 24):
+ """
+    Draw the object bounding box and class information onto the image.
+ """
try:
- font = ImageFont.truetype('arial.ttf', 24)
+ font = ImageFont.truetype(font, font_size)
except IOError:
font = ImageFont.load_default()
+ left, top, right, bottom = box
# If the total height of the display strings added to the top of the bounding
# box exceeds the top of the image, stack the strings below the bounding box
# instead of above.
- display_str_heights = [font.getsize(ds)[1] for ds in box_to_display_str_map[box]]
+ display_str = f"{category_index[str(cls)]}: {int(100 * score)}%"
+ display_str_heights = [font.getsize(ds)[1] for ds in display_str]
# Each display_str has a top and bottom margin of 0.05x.
- total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
+ display_str_height = (1 + 2 * 0.05) * max(display_str_heights)
- if top > total_display_str_height:
+ if top > display_str_height:
+ text_top = top - display_str_height
text_bottom = top
else:
- text_bottom = bottom + total_display_str_height
- # Reverse list and print from bottom to top.
- for display_str in box_to_display_str_map[box][::-1]:
- text_width, text_height = font.getsize(display_str)
- margin = np.ceil(0.05 * text_height)
- draw.rectangle([(left, text_bottom - text_height - 2 * margin),
- (left + text_width, text_bottom)], fill=color)
- draw.text((left + margin, text_bottom - text_height - margin),
- display_str,
+ text_top = bottom
+ text_bottom = bottom + display_str_height
+
+ for ds in display_str:
+ text_width, text_height = font.getsize(ds)
+ margin = np.ceil(0.05 * text_width)
+ draw.rectangle([(left, text_top),
+ (left + text_width + 2 * margin, text_bottom)], fill=color)
+ draw.text((left + margin, text_top),
+ ds,
fill='black',
font=font)
- text_bottom -= text_height - 2 * margin
+ left += text_width
+
+
+def draw_masks(image, masks, colors, thresh: float = 0.7, alpha: float = 0.5):
+ np_image = np.array(image)
+ masks = np.where(masks > thresh, True, False)
+
+ # colors = np.array(colors)
+ img_to_draw = np.copy(np_image)
+ # TODO: There might be a way to vectorize this
+ for mask, color in zip(masks, colors):
+ img_to_draw[mask] = color
+
+ out = np_image * (1 - alpha) + img_to_draw * alpha
+ return fromarray(out.astype(np.uint8))
+
+
+def draw_objs(image: Image,
+ boxes: np.ndarray = None,
+ classes: np.ndarray = None,
+ scores: np.ndarray = None,
+ masks: np.ndarray = None,
+ category_index: dict = None,
+ box_thresh: float = 0.1,
+ mask_thresh: float = 0.5,
+ line_thickness: int = 8,
+ font: str = 'arial.ttf',
+ font_size: int = 24,
+ draw_boxes_on_image: bool = True,
+ draw_masks_on_image: bool = False):
+ """
+ 将目标边界框信息,类别信息,mask信息绘制在图片上
+ Args:
+ image: 需要绘制的图片
+ boxes: 目标边界框信息
+ classes: 目标类别信息
+ scores: 目标概率信息
+ masks: 目标mask信息
+ category_index: 类别与名称字典
+ box_thresh: 过滤的概率阈值
+ mask_thresh:
+ line_thickness: 边界框宽度
+ font: 字体类型
+ font_size: 字体大小
+ draw_boxes_on_image:
+ draw_masks_on_image:
+
+ Returns:
+
+ """
+
+    # filter out low-confidence objects
+ idxs = np.greater(scores, box_thresh)
+ boxes = boxes[idxs]
+ classes = classes[idxs]
+ scores = scores[idxs]
+ if masks is not None:
+ masks = masks[idxs]
+ if len(boxes) == 0:
+ return image
+ colors = [ImageColor.getrgb(STANDARD_COLORS[cls % len(STANDARD_COLORS)]) for cls in classes]
-def draw_box(image, boxes, classes, scores, category_index, thresh=0.5, line_thickness=8):
- box_to_display_str_map = collections.defaultdict(list)
- box_to_color_map = collections.defaultdict(str)
+ if draw_boxes_on_image:
+ # Draw all boxes onto image.
+ draw = ImageDraw.Draw(image)
+ for box, cls, score, color in zip(boxes, classes, scores, colors):
+ left, top, right, bottom = box
+            # draw the object bounding box
+ draw.line([(left, top), (left, bottom), (right, bottom),
+ (right, top), (left, top)], width=line_thickness, fill=color)
+            # draw the class label and confidence score
+ draw_text(draw, box.tolist(), int(cls), float(score), category_index, color, font, font_size)
- filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map)
+ if draw_masks_on_image and (masks is not None):
+ # Draw all mask onto image.
+ image = draw_masks(image, masks, colors, mask_thresh)
- # Draw all boxes onto image.
- draw = ImageDraw.Draw(image)
- im_width, im_height = image.size
- for box, color in box_to_color_map.items():
- xmin, ymin, xmax, ymax = box
- (left, right, top, bottom) = (xmin * 1, xmax * 1,
- ymin * 1, ymax * 1)
- draw.line([(left, top), (left, bottom), (right, bottom),
- (right, top), (left, top)], width=line_thickness, fill=color)
- draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color)
+ return image
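
A small usage sketch of the new `draw_objs` API with dummy predictions (not part of the patch; the `category_index` below is a hypothetical single-class mapping whose keys are strings of the class ids, matching how `predict.py` builds it):

```python
import numpy as np
from PIL import Image
from draw_box_utils import draw_objs

img = Image.new("RGB", (300, 300), (255, 255, 255))
boxes = np.array([[30., 30., 150., 200.]])  # xmin, ymin, xmax, ymax
classes = np.array([1])
scores = np.array([0.92])
category_index = {"1": "person"}            # hypothetical class-id -> name mapping

plot_img = draw_objs(img, boxes, classes, scores,
                     category_index=category_index,
                     box_thresh=0.5,
                     line_thickness=3,
                     font='arial.ttf',
                     font_size=20)
plot_img.save("demo_result.jpg")
```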
diff --git a/pytorch_object_detection/faster_rcnn/my_dataset.py b/pytorch_object_detection/faster_rcnn/my_dataset.py
index 23986bdf5..efabd862e 100644
--- a/pytorch_object_detection/faster_rcnn/my_dataset.py
+++ b/pytorch_object_detection/faster_rcnn/my_dataset.py
@@ -1,3 +1,4 @@
+import numpy as np
from torch.utils.data import Dataset
import os
import torch
@@ -11,7 +12,11 @@ class VOCDataSet(Dataset):
def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
- self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
+        # be tolerant: voc_root may already point inside VOCdevkit
+ if "VOCdevkit" in voc_root:
+ self.root = os.path.join(voc_root, f"VOC{year}")
+ else:
+ self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")
@@ -20,20 +25,34 @@ def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "trai
assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
with open(txt_path) as read:
- self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
- for line in read.readlines() if len(line.strip()) > 0]
+ xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
+ for line in read.readlines() if len(line.strip()) > 0]
+ self.xml_list = []
# check file
+ for xml_path in xml_list:
+ if os.path.exists(xml_path) is False:
+ print(f"Warning: not found '{xml_path}', skip this annotation file.")
+ continue
+
+ # check for targets
+ with open(xml_path) as fid:
+ xml_str = fid.read()
+ xml = etree.fromstring(xml_str)
+ data = self.parse_xml_to_dict(xml)["annotation"]
+ if "object" not in data:
+ print(f"INFO: no objects in {xml_path}, skip this annotation file.")
+ continue
+
+ self.xml_list.append(xml_path)
+
assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)
- for xml_path in self.xml_list:
- assert os.path.exists(xml_path), "not found '{}' file.".format(xml_path)
# read class_indict
json_file = './pascal_voc_classes.json'
assert os.path.exists(json_file), "{} file not exist.".format(json_file)
- json_file = open(json_file, 'r')
- self.class_dict = json.load(json_file)
- json_file.close()
+ with open(json_file, 'r') as f:
+ self.class_dict = json.load(f)
self.transforms = transforms
@@ -181,7 +200,7 @@ def collate_fn(batch):
return tuple(zip(*batch))
# import transforms
-# from draw_box_utils import draw_box
+# from draw_box_utils import draw_objs
# from PIL import Image
# import json
# import matplotlib.pyplot as plt
@@ -193,7 +212,7 @@ def collate_fn(batch):
# try:
# json_file = open('./pascal_voc_classes.json', 'r')
# class_dict = json.load(json_file)
-# category_index = {v: k for k, v in class_dict.items()}
+# category_index = {str(v): str(k) for k, v in class_dict.items()}
# except Exception as e:
# print(e)
# exit(-1)
@@ -210,12 +229,14 @@ def collate_fn(batch):
# for index in random.sample(range(0, len(train_data_set)), k=5):
# img, target = train_data_set[index]
# img = ts.ToPILImage()(img)
-# draw_box(img,
-# target["boxes"].numpy(),
-# target["labels"].numpy(),
-# [1 for i in range(len(target["labels"].numpy()))],
-# category_index,
-# thresh=0.5,
-# line_thickness=5)
-# plt.imshow(img)
+# plot_img = draw_objs(img,
+# target["boxes"].numpy(),
+# target["labels"].numpy(),
+# np.ones(target["labels"].shape[0]),
+# category_index=category_index,
+# box_thresh=0.5,
+# line_thickness=3,
+# font='arial.ttf',
+# font_size=20)
+# plt.imshow(plot_img)
# plt.show()
diff --git a/pytorch_object_detection/faster_rcnn/network_files/boxes.py b/pytorch_object_detection/faster_rcnn/network_files/boxes.py
index f720df1f8..8eeca4573 100644
--- a/pytorch_object_detection/faster_rcnn/network_files/boxes.py
+++ b/pytorch_object_detection/faster_rcnn/network_files/boxes.py
@@ -23,7 +23,7 @@ def nms(boxes, scores, iou_threshold):
scores for each one of the boxes
iou_threshold : float
discards all overlapping
- boxes with IoU < iou_threshold
+ boxes with IoU > iou_threshold
Returns
-------
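
A quick sanity check of the corrected wording, using `torchvision.ops.nms`, which has the same semantics as this wrapper: boxes whose IoU with an already-kept, higher-scoring box exceeds `iou_threshold` are discarded.

```python
import torch
from torchvision.ops import nms

boxes = torch.tensor([[0., 0., 10., 10.],     # highest score, always kept
                      [1., 1., 11., 11.],     # IoU about 0.68 with the first box -> suppressed
                      [20., 20., 30., 30.]])  # no overlap -> kept
scores = torch.tensor([0.9, 0.8, 0.7])

keep = nms(boxes, scores, iou_threshold=0.5)
print(keep)  # tensor([0, 2])
```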
diff --git a/pytorch_object_detection/faster_rcnn/predict.py b/pytorch_object_detection/faster_rcnn/predict.py
index 35ad35dd2..2b85400be 100644
--- a/pytorch_object_detection/faster_rcnn/predict.py
+++ b/pytorch_object_detection/faster_rcnn/predict.py
@@ -10,7 +10,7 @@
from torchvision import transforms
from network_files import FasterRCNN, FastRCNNPredictor, AnchorsGenerator
from backbone import resnet50_fpn_backbone, MobileNetV2
-from draw_box_utils import draw_box
+from draw_box_utils import draw_objs
def create_model(num_classes):
@@ -52,18 +52,20 @@ def main():
model = create_model(num_classes=21)
# load train weights
- train_weights = "./save_weights/model.pth"
- assert os.path.exists(train_weights), "{} file dose not exist.".format(train_weights)
- model.load_state_dict(torch.load(train_weights, map_location=device)["model"])
+ weights_path = "./save_weights/model.pth"
+    assert os.path.exists(weights_path), "{} file does not exist.".format(weights_path)
+ weights_dict = torch.load(weights_path, map_location='cpu')
+ weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
+ model.load_state_dict(weights_dict)
model.to(device)
# read class_indict
label_json_path = './pascal_voc_classes.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
- json_file = open(label_json_path, 'r')
- class_dict = json.load(json_file)
- json_file.close()
- category_index = {v: k for k, v in class_dict.items()}
+ with open(label_json_path, 'r') as f:
+ class_dict = json.load(f)
+
+ category_index = {str(v): str(k) for k, v in class_dict.items()}
# load image
original_img = Image.open("./test.jpg")
@@ -93,19 +95,20 @@ def main():
if len(predict_boxes) == 0:
print("没有检测到任何目标!")
- draw_box(original_img,
- predict_boxes,
- predict_classes,
- predict_scores,
- category_index,
- thresh=0.5,
- line_thickness=3)
- plt.imshow(original_img)
+ plot_img = draw_objs(original_img,
+ predict_boxes,
+ predict_classes,
+ predict_scores,
+ category_index=category_index,
+ box_thresh=0.5,
+ line_thickness=3,
+ font='arial.ttf',
+ font_size=20)
+ plt.imshow(plot_img)
plt.show()
# 保存预测的图片结果
- original_img.save("test_result.jpg")
+ plot_img.save("test_result.jpg")
if __name__ == '__main__':
main()
-
diff --git a/pytorch_object_detection/faster_rcnn/train_mobilenetv2.py b/pytorch_object_detection/faster_rcnn/train_mobilenetv2.py
index eab4fd4ec..bcbd28f8e 100644
--- a/pytorch_object_detection/faster_rcnn/train_mobilenetv2.py
+++ b/pytorch_object_detection/faster_rcnn/train_mobilenetv2.py
@@ -146,7 +146,7 @@ def main():
# write into txt
with open(results_file, "a") as f:
# 写入的数据包括coco指标还有loss和learning rate
- result_info = [str(round(i, 4)) for i in coco_info + [mean_loss.item()]] + [str(round(lr, 6))]
+ result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
f.write(txt + "\n")
@@ -193,7 +193,7 @@ def main():
# write into txt
with open(results_file, "a") as f:
# 写入的数据包括coco指标还有loss和learning rate
- result_info = [str(round(i, 4)) for i in coco_info + [mean_loss.item()]] + [str(round(lr, 6))]
+ result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
f.write(txt + "\n")
diff --git a/pytorch_object_detection/faster_rcnn/train_multi_GPU.py b/pytorch_object_detection/faster_rcnn/train_multi_GPU.py
index 3315bc760..1ec76c076 100644
--- a/pytorch_object_detection/faster_rcnn/train_multi_GPU.py
+++ b/pytorch_object_detection/faster_rcnn/train_multi_GPU.py
@@ -96,6 +96,9 @@ def main(args):
model = create_model(num_classes=args.num_classes + 1)
model.to(device)
+ if args.distributed and args.sync_bn:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
@@ -154,7 +157,7 @@ def main(args):
# write into txt
with open(results_file, "a") as f:
# 写入的数据包括coco指标还有loss和learning rate
- result_info = [str(round(i, 4)) for i in coco_info + [mean_loss.item()]] + [str(round(lr, 6))]
+ result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
f.write(txt + "\n")
@@ -246,6 +249,7 @@ def main(args):
parser.add_argument('--world-size', default=4, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+ parser.add_argument("--sync-bn", dest="sync_bn", help="Use sync batch norm", type=bool, default=False)
# 是否使用混合精度训练(需要GPU支持混合精度)
parser.add_argument("--amp", default=False, help="Use torch.cuda.amp for mixed precision training")
diff --git a/pytorch_object_detection/faster_rcnn/train_res50_fpn.py b/pytorch_object_detection/faster_rcnn/train_res50_fpn.py
index b45c4897e..f45e62901 100644
--- a/pytorch_object_detection/faster_rcnn/train_res50_fpn.py
+++ b/pytorch_object_detection/faster_rcnn/train_res50_fpn.py
@@ -11,22 +11,26 @@
from train_utils import train_eval_utils as utils
-def create_model(num_classes):
+def create_model(num_classes, load_pretrain_weights=True):
# 注意,这里的backbone默认使用的是FrozenBatchNorm2d,即不会去更新bn参数
# 目的是为了防止batch_size太小导致效果更差(如果显存很小,建议使用默认的FrozenBatchNorm2d)
# 如果GPU显存很大可以设置比较大的batch_size就可以将norm_layer设置为普通的BatchNorm2d
# trainable_layers包括['layer4', 'layer3', 'layer2', 'layer1', 'conv1'], 5代表全部训练
- backbone = resnet50_fpn_backbone(norm_layer=torch.nn.BatchNorm2d,
+ # resnet50 imagenet weights url: https://download.pytorch.org/models/resnet50-0676ba61.pth
+ backbone = resnet50_fpn_backbone(pretrain_path="./backbone/resnet50.pth",
+ norm_layer=torch.nn.BatchNorm2d,
trainable_layers=3)
# 训练自己数据集时不要修改这里的91,修改的是传入的num_classes参数
model = FasterRCNN(backbone=backbone, num_classes=91)
- # 载入预训练模型权重
- # https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
- weights_dict = torch.load("./backbone/fasterrcnn_resnet50_fpn_coco.pth", map_location='cpu')
- missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
- if len(missing_keys) != 0 or len(unexpected_keys) != 0:
- print("missing_keys: ", missing_keys)
- print("unexpected_keys: ", unexpected_keys)
+
+ if load_pretrain_weights:
+        # load pre-trained model weights
+ # https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
+ weights_dict = torch.load("./backbone/fasterrcnn_resnet50_fpn_coco.pth", map_location='cpu')
+ missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False)
+ if len(missing_keys) != 0 or len(unexpected_keys) != 0:
+ print("missing_keys: ", missing_keys)
+ print("unexpected_keys: ", unexpected_keys)
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
@@ -36,8 +40,8 @@ def create_model(num_classes):
return model
-def main(parser_data):
- device = torch.device(parser_data.device if torch.cuda.is_available() else "cpu")
+def main(args):
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print("Using {} device training.".format(device.type))
# 用来保存coco_info的文件
@@ -49,7 +53,7 @@ def main(parser_data):
"val": transforms.Compose([transforms.ToTensor()])
}
- VOC_root = parser_data.data_path
+ VOC_root = args.data_path
# check voc root
if os.path.exists(os.path.join(VOC_root, "VOCdevkit")) is False:
raise FileNotFoundError("VOCdevkit dose not in path:'{}'.".format(VOC_root))
@@ -69,7 +73,7 @@ def main(parser_data):
train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
# 注意这里的collate_fn是自定义的,因为读取的数据包括image和targets,不能直接使用默认的方法合成batch
- batch_size = parser_data.batch_size
+ batch_size = args.batch_size
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
print('Using %g dataloader workers' % nw)
if train_sampler:
@@ -98,15 +102,17 @@ def main(parser_data):
collate_fn=val_dataset.collate_fn)
# create model num_classes equal background + 20 classes
- model = create_model(num_classes=parser_data.num_classes + 1)
+ model = create_model(num_classes=args.num_classes + 1)
# print(model)
model.to(device)
# define optimizer
params = [p for p in model.parameters() if p.requires_grad]
- optimizer = torch.optim.SGD(params, lr=0.005,
- momentum=0.9, weight_decay=0.0005)
+ optimizer = torch.optim.SGD(params,
+ lr=args.lr,
+ momentum=args.momentum,
+ weight_decay=args.weight_decay)
scaler = torch.cuda.amp.GradScaler() if args.amp else None
@@ -116,21 +122,21 @@ def main(parser_data):
gamma=0.33)
# 如果指定了上次训练保存的权重文件地址,则接着上次结果接着训练
- if parser_data.resume != "":
- checkpoint = torch.load(parser_data.resume, map_location='cpu')
+ if args.resume != "":
+ checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
- parser_data.start_epoch = checkpoint['epoch'] + 1
+ args.start_epoch = checkpoint['epoch'] + 1
if args.amp and "scaler" in checkpoint:
scaler.load_state_dict(checkpoint["scaler"])
- print("the training process from epoch{}...".format(parser_data.start_epoch))
+ print("the training process from epoch{}...".format(args.start_epoch))
train_loss = []
learning_rate = []
val_map = []
- for epoch in range(parser_data.start_epoch, parser_data.epochs):
+ for epoch in range(args.start_epoch, args.epochs):
# train for one epoch, printing every 10 iterations
mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
device=device, epoch=epoch,
@@ -148,7 +154,7 @@ def main(parser_data):
# write into txt
with open(results_file, "a") as f:
# 写入的数据包括coco指标还有loss和learning rate
- result_info = [str(round(i, 4)) for i in coco_info + [mean_loss.item()]] + [str(round(lr, 6))]
+ result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
f.write(txt + "\n")
@@ -196,6 +202,17 @@ def main(parser_data):
# 训练的总epoch数
parser.add_argument('--epochs', default=15, type=int, metavar='N',
help='number of total epochs to run')
+ # 学习率
+ parser.add_argument('--lr', default=0.01, type=float,
+ help='initial learning rate, 0.02 is the default value for training '
+ 'on 8 gpus and 2 images_per_gpu')
+ # SGD的momentum参数
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+ help='momentum')
+ # SGD的weight_decay参数
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+ metavar='W', help='weight decay (default: 1e-4)',
+ dest='weight_decay')
# 训练的batch size
parser.add_argument('--batch_size', default=8, type=int, metavar='N',
help='batch size when training.')
diff --git a/pytorch_object_detection/faster_rcnn/validation.py b/pytorch_object_detection/faster_rcnn/validation.py
index 95b3ba696..d353aed4e 100644
--- a/pytorch_object_detection/faster_rcnn/validation.py
+++ b/pytorch_object_detection/faster_rcnn/validation.py
@@ -100,9 +100,9 @@ def main(parser_data):
# read class_indict
label_json_path = './pascal_voc_classes.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
- json_file = open(label_json_path, 'r')
- class_dict = json.load(json_file)
- json_file.close()
+ with open(label_json_path, 'r') as f:
+ class_dict = json.load(f)
+
category_index = {v: k for k, v in class_dict.items()}
VOC_root = parser_data.data_path
@@ -130,9 +130,11 @@ def main(parser_data):
model = FasterRCNN(backbone=backbone, num_classes=parser_data.num_classes + 1)
# 载入你自己训练好的模型权重
- weights_path = parser_data.weights
+ weights_path = parser_data.weights_path
assert os.path.exists(weights_path), "not found {} file.".format(weights_path)
- model.load_state_dict(torch.load(weights_path, map_location=device)['model'])
+ weights_dict = torch.load(weights_path, map_location='cpu')
+ weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
+ model.load_state_dict(weights_dict)
# print(model)
model.to(device)
@@ -201,7 +203,7 @@ def main(parser_data):
parser.add_argument('--data-path', default='/data/', help='dataset root')
# 训练好的权重文件
- parser.add_argument('--weights', default='./save_weights/model.pth', type=str, help='training weights')
+ parser.add_argument('--weights-path', default='./save_weights/model.pth', type=str, help='training weights')
# batch size
parser.add_argument('--batch_size', default=1, type=int, metavar='N',
diff --git a/pytorch_object_detection/mask_rcnn/README.md b/pytorch_object_detection/mask_rcnn/README.md
new file mode 100644
index 000000000..77f014021
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/README.md
@@ -0,0 +1,153 @@
+# Mask R-CNN
+
+## This project is based on the source code of the official PyTorch torchvision detection references (the pycocotools usage differs slightly)
+* https://github.com/pytorch/vision/tree/master/references/detection
+
+## Environment:
+* Python3.6/3.7/3.8
+* PyTorch 1.10 or later
+* pycocotools (Linux: `pip install pycocotools`; Windows: `pip install pycocotools-windows`, no extra Visual Studio install needed)
+* Ubuntu or CentOS (Windows is not recommended)
+* Training on a GPU is strongly recommended
+* See `requirements.txt` for the detailed environment configuration
+
+## File structure:
+```
+  ├── backbone: feature extraction network
+  ├── network_files: Mask R-CNN network
+  ├── train_utils: training and validation utilities (including COCO evaluation)
+  ├── my_dataset_coco.py: custom dataset for reading the COCO2017 dataset
+  ├── my_dataset_voc.py: custom dataset for reading the Pascal VOC dataset
+  ├── train.py: training script for a single GPU/CPU
+  ├── train_multi_GPU.py: training script for multi-GPU users
+  ├── predict.py: simple prediction script that runs inference with trained weights
+  ├── validation.py: evaluates the COCO metrics of trained weights on validation/test data and generates record_mAP.txt
+  └── transforms.py: data preprocessing (random horizontal flip of images and bboxes, PIL image to Tensor)
+```
+
+## Pre-trained weights download links (place the files in the current folder after downloading):
+* Resnet50 pre-trained weights https://download.pytorch.org/models/resnet50-0676ba61.pth (note: rename the file after downloading;
+for example, train.py reads `resnet50.pth`, not `resnet50-0676ba61.pth`)
+* Mask R-CNN (Resnet50+FPN) pre-trained weights https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth (note:
+rename the file after downloading; for example, train.py reads `maskrcnn_resnet50_fpn_coco.pth`, not `maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth`)
+
+
+## Datasets: this project uses the COCO2017 and Pascal VOC2012 datasets
+### COCO2017 dataset
+* Official COCO website: https://cocodataset.org/
+* If you are not familiar with the dataset, see my blog post: https://blog.csdn.net/qq_37541097/article/details/113247318
+* Taking the coco2017 dataset as an example, download these three files:
+  * `2017 Train images [118K/18GB]`: all images used during training
+  * `2017 Val images [5K/1GB]`: all images used during validation
+  * `2017 Train/Val annotations [241MB]`: annotation json files for the training and validation sets
+* Extract everything into a `coco2017` folder to obtain the following structure:
+```
+├── coco2017: dataset root directory
+     ├── train2017: all training images (118287)
+     ├── val2017: all validation images (5000)
+     └── annotations: annotation files
+              ├── instances_train2017.json: training-set annotations for the detection/segmentation tasks
+              ├── instances_val2017.json: validation-set annotations for the detection/segmentation tasks
+              ├── captions_train2017.json: training-set annotations for image captioning
+              ├── captions_val2017.json: validation-set annotations for image captioning
+              ├── person_keypoints_train2017.json: training-set annotations for human keypoint detection
+              └── person_keypoints_val2017.json: validation-set annotations for human keypoint detection
+```
+
+### Pascal VOC2012 dataset
+* Download link: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit
+* If you are not familiar with the dataset, see my blog post: https://blog.csdn.net/qq_37541097/article/details/115787033
+* The folder structure after extraction looks like this:
+```
+VOCdevkit
+    └── VOC2012
+         ├── Annotations               all image annotation files (XML)
+         ├── ImageSets
+         │   ├── Action                image info for human actions
+         │   ├── Layout                image info for human body parts
+         │   │
+         │   ├── Main                  image info for detection/classification
+         │   │   ├── train.txt         training set (5717)
+         │   │   ├── val.txt           validation set (5823)
+         │   │   └── trainval.txt      training + validation set (11540)
+         │   │
+         │   └── Segmentation          image info for segmentation
+         │       ├── train.txt         training set (1464)
+         │       ├── val.txt           validation set (1449)
+         │       └── trainval.txt      training + validation set (2913)
+         │
+         ├── JPEGImages                all image files
+         ├── SegmentationClass         semantic segmentation PNG images (by class)
+         └── SegmentationObject        instance segmentation PNG images (by object)
+```
+
+## Training
+* Make sure the dataset is prepared in advance
+* Make sure the corresponding pre-trained weights are downloaded in advance
+* Make sure `--num-classes` and `--data-path` are set correctly
+* To train on a single GPU, use the train.py script directly
+* To train on multiple GPUs, use the `torchrun --nproc_per_node=8 train_multi_GPU.py` command, where `nproc_per_node` is the number of GPUs to use
+* To restrict which GPU devices are used, prefix the command with `CUDA_VISIBLE_DEVICES=0,3` (for example, to use only the 1st and 4th GPU)
+* `CUDA_VISIBLE_DEVICES=0,3 torchrun --nproc_per_node=2 train_multi_GPU.py`
+
+## Notes
+1. When using the training scripts, set `--data-path` to the **root directory** where your dataset is stored:
+```
+# To use the COCO dataset, enable the custom CocoDetection dataset reader and extract the dataset to /data/coco2017
+python train.py --data-path /data/coco2017
+
+# To use the Pascal VOC dataset, enable the custom VOCInstances dataset reader and extract the dataset to /data/VOCdevkit
+python train.py --data-path /data/VOCdevkit
+```
+
+2. If you double the `batch_size`, it is recommended to double the learning rate as well. For example, when increasing `batch_size` from 4 to 8, raise `lr` from 0.004 to 0.008
+3. When using Batch Normalization, `batch_size` should not be smaller than 4, otherwise results get worse. **If GPU memory is limited and batch_size must be smaller than 4**, it is recommended to set
+`norm_layer` to `FrozenBatchNorm2d` or set `trainable_layers` to 0 (i.e. freeze the whole `backbone`) when creating `resnet50_fpn_backbone`; see the sketch after this list
+4. The `det_results.txt` (object detection) and `seg_results.txt` (instance segmentation) files saved during training contain the COCO metrics on the validation set for each epoch; the first 12 values are the COCO metrics, followed by the mean training loss and the learning rate
+5. When using the prediction script, set `weights_path` to the path of your own trained weights.
+6. When using the validation script, make sure your validation or test set contains objects of every class, and adjust `--num-classes`, `--data-path`, `--weights-path` and
+`--label-json-path` (this parameter depends on the dataset used for training). Try not to modify any other code
+
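A minimal sketch for note 3, assuming the `resnet50_fpn_backbone` in this project keeps the same `pretrain_path` / `norm_layer` / `trainable_layers` arguments as the Faster R-CNN version shown earlier in this patch:

```python
from torchvision.ops.misc import FrozenBatchNorm2d
from backbone import resnet50_fpn_backbone

# small batch size: keep BN statistics frozen and fine-tune none of the backbone stages
backbone = resnet50_fpn_backbone(pretrain_path="./resnet50.pth",
                                 norm_layer=FrozenBatchNorm2d,
                                 trainable_layers=0)
```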
+
+## Reproduced results
+Reproduced on the COCO2017 dataset, loading only the Resnet50 pre-trained weights and training for 26 epochs. The training command was:
+```
+torchrun --nproc_per_node=8 train_multi_GPU.py --batch-size 8 --lr 0.08 --pretrain False --amp True
+```
+
+Download link for the trained weights: https://pan.baidu.com/s/1qpXUIsvnj8RHY-V05J-mnA  password: 63d5
+
+mAP on the COCO2017 validation set (object detection task):
+```
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.381
+ Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.588
+ Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.411
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.215
+ Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.420
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.492
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.315
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.499
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.523
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.319
+ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.565
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.666
+```
+
+mAP on the COCO2017 validation set (instance segmentation task):
+```
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.340
+ Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.552
+ Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.361
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.151
+ Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.369
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.500
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.290
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.449
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.468
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.266
+ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.509
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.619
+```
+
+## If you are not familiar with how Mask R-CNN works, you can check out my bilibili video
+https://www.bilibili.com/video/BV1ZY411774T
diff --git a/pytorch_object_detection/mask_rcnn/backbone/__init__.py b/pytorch_object_detection/mask_rcnn/backbone/__init__.py
new file mode 100644
index 000000000..314eb748f
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/backbone/__init__.py
@@ -0,0 +1 @@
+from .resnet50_fpn_model import resnet50_fpn_backbone
diff --git a/pytorch_object_detection/mask_rcnn/backbone/feature_pyramid_network.py b/pytorch_object_detection/mask_rcnn/backbone/feature_pyramid_network.py
new file mode 100644
index 000000000..fc2fc757f
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/backbone/feature_pyramid_network.py
@@ -0,0 +1,235 @@
+from collections import OrderedDict
+
+import torch.nn as nn
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+
+from torch.jit.annotations import Tuple, List, Dict
+
+
+class IntermediateLayerGetter(nn.ModuleDict):
+ """
+ Module wrapper that returns intermediate layers from a model
+ It has a strong assumption that the modules have been registered
+ into the model in the same order as they are used.
+ This means that one should **not** reuse the same nn.Module
+ twice in the forward if you want this to work.
+ Additionally, it is only able to query submodules that are directly
+ assigned to the model. So if `model` is passed, `model.feature1` can
+ be returned, but not `model.feature1.layer2`.
+ Arguments:
+ model (nn.Module): model on which we will extract the features
+ return_layers (Dict[name, new_name]): a dict containing the names
+ of the modules for which the activations will be returned as
+ the key of the dict, and the value of the dict is the name
+ of the returned activation (which the user can specify).
+ """
+ __annotations__ = {
+ "return_layers": Dict[str, str],
+ }
+
+ def __init__(self, model, return_layers):
+ if not set(return_layers).issubset([name for name, _ in model.named_children()]):
+ raise ValueError("return_layers are not present in model")
+
+ orig_return_layers = return_layers
+ return_layers = {str(k): str(v) for k, v in return_layers.items()}
+ layers = OrderedDict()
+
+        # iterate over the model's children in order and store them in an OrderedDict
+        # keep only the layers up to and including layer4; the unused layers after it are dropped
+ for name, module in model.named_children():
+ layers[name] = module
+ if name in return_layers:
+ del return_layers[name]
+ if not return_layers:
+ break
+
+ super().__init__(layers)
+ self.return_layers = orig_return_layers
+
+ def forward(self, x):
+ out = OrderedDict()
+        # run the input through each child module in turn,
+        # collecting the outputs of layer1, layer2, layer3 and layer4
+ for name, module in self.items():
+ x = module(x)
+ if name in self.return_layers:
+ out_name = self.return_layers[name]
+ out[out_name] = x
+ return out
+
+
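
A minimal sketch of what this wrapper does, using a torchvision resnet50 (a sketch, not part of the patch; it assumes the script is run from the mask_rcnn directory so `backbone.feature_pyramid_network` is importable; the channel counts are the standard resnet50 values):

```python
import torch
import torchvision

from backbone.feature_pyramid_network import IntermediateLayerGetter

model = torchvision.models.resnet50(pretrained=False)
getter = IntermediateLayerGetter(model, return_layers={"layer1": "0", "layer2": "1",
                                                       "layer3": "2", "layer4": "3"})
outs = getter(torch.randn(1, 3, 224, 224))
for name, feat in outs.items():
    print(name, feat.shape)  # channels 256 / 512 / 1024 / 2048, strides 4 / 8 / 16 / 32
```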
+class BackboneWithFPN(nn.Module):
+ """
+ Adds a FPN on top of a model.
+ Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
+ extract a submodel that returns the feature maps specified in return_layers.
+    The same limitations of IntermediateLayerGetter apply here.
+ Arguments:
+ backbone (nn.Module)
+ return_layers (Dict[name, new_name]): a dict containing the names
+ of the modules for which the activations will be returned as
+ the key of the dict, and the value of the dict is the name
+ of the returned activation (which the user can specify).
+ in_channels_list (List[int]): number of channels for each feature map
+ that is returned, in the order they are present in the OrderedDict
+ out_channels (int): number of channels in the FPN.
+ extra_blocks: ExtraFPNBlock
+ Attributes:
+ out_channels (int): the number of channels in the FPN
+ """
+
+ def __init__(self,
+ backbone: nn.Module,
+ return_layers=None,
+ in_channels_list=None,
+ out_channels=256,
+ extra_blocks=None,
+ re_getter=True):
+ super().__init__()
+
+ if extra_blocks is None:
+ extra_blocks = LastLevelMaxPool()
+
+ if re_getter:
+ assert return_layers is not None
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+ else:
+ self.body = backbone
+
+ self.fpn = FeaturePyramidNetwork(
+ in_channels_list=in_channels_list,
+ out_channels=out_channels,
+ extra_blocks=extra_blocks,
+ )
+
+ self.out_channels = out_channels
+
+ def forward(self, x):
+ x = self.body(x)
+ x = self.fpn(x)
+ return x
+
+
+class FeaturePyramidNetwork(nn.Module):
+ """
+    Module that adds an FPN on top of a set of feature maps. This is based on
+    `"Feature Pyramid Networks for Object Detection" <https://arxiv.org/abs/1612.03144>`_.
+ The feature maps are currently supposed to be in increasing depth
+ order.
+ The input to the model is expected to be an OrderedDict[Tensor], containing
+ the feature maps on top of which the FPN will be added.
+ Arguments:
+ in_channels_list (list[int]): number of channels for each feature map that
+ is passed to the module
+ out_channels (int): number of channels of the FPN representation
+ extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
+ be performed. It is expected to take the fpn features, the original
+ features and the names of the original features as input, and returns
+ a new list of feature maps and their corresponding names
+ """
+
+ def __init__(self, in_channels_list, out_channels, extra_blocks=None):
+ super().__init__()
+        # 1x1 convs that adjust the channel number of the resnet feature maps (layer1-4)
+ self.inner_blocks = nn.ModuleList()
+        # 3x3 convs applied to the adjusted feature maps to produce the prediction feature maps
+ self.layer_blocks = nn.ModuleList()
+ for in_channels in in_channels_list:
+ if in_channels == 0:
+ continue
+ inner_block_module = nn.Conv2d(in_channels, out_channels, 1)
+ layer_block_module = nn.Conv2d(out_channels, out_channels, 3, padding=1)
+ self.inner_blocks.append(inner_block_module)
+ self.layer_blocks.append(layer_block_module)
+
+ # initialize parameters now to avoid modifying the initialization of top_blocks
+ for m in self.children():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_uniform_(m.weight, a=1)
+ nn.init.constant_(m.bias, 0)
+
+ self.extra_blocks = extra_blocks
+
+ def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
+ """
+ This is equivalent to self.inner_blocks[idx](x),
+ but torchscript doesn't support this yet
+ """
+ num_blocks = len(self.inner_blocks)
+ if idx < 0:
+ idx += num_blocks
+ i = 0
+ out = x
+ for module in self.inner_blocks:
+ if i == idx:
+ out = module(x)
+ i += 1
+ return out
+
+ def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
+ """
+ This is equivalent to self.layer_blocks[idx](x),
+ but torchscript doesn't support this yet
+ """
+ num_blocks = len(self.layer_blocks)
+ if idx < 0:
+ idx += num_blocks
+ i = 0
+ out = x
+ for module in self.layer_blocks:
+ if i == idx:
+ out = module(x)
+ i += 1
+ return out
+
+ def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
+ """
+ Computes the FPN for a set of feature maps.
+ Arguments:
+ x (OrderedDict[Tensor]): feature maps for each feature level.
+ Returns:
+ results (OrderedDict[Tensor]): feature maps after FPN layers.
+ They are ordered from highest resolution first.
+ """
+ # unpack OrderedDict into two lists for easier handling
+ names = list(x.keys())
+ x = list(x.values())
+
+        # adjust the channels of the resnet layer4 feature map to out_channels
+ # last_inner = self.inner_blocks[-1](x[-1])
+ last_inner = self.get_result_from_inner_blocks(x[-1], -1)
+        # results stores every output feature map
+ results = []
+        # pass the channel-adjusted layer4 feature map through the 3x3 conv to get its output feature map
+ # results.append(self.layer_blocks[-1](last_inner))
+ results.append(self.get_result_from_layer_blocks(last_inner, -1))
+
+ for idx in range(len(x) - 2, -1, -1):
+ inner_lateral = self.get_result_from_inner_blocks(x[idx], idx)
+ feat_shape = inner_lateral.shape[-2:]
+ inner_top_down = F.interpolate(last_inner, size=feat_shape, mode="nearest")
+ last_inner = inner_lateral + inner_top_down
+ results.insert(0, self.get_result_from_layer_blocks(last_inner, idx))
+
+        # generate the extra output feature map (level 5) on top of the layer4 output
+ if self.extra_blocks is not None:
+ results, names = self.extra_blocks(results, x, names)
+
+ # make it back an OrderedDict
+ out = OrderedDict([(k, v) for k, v in zip(names, results)])
+
+ return out
+
+
+class LastLevelMaxPool(torch.nn.Module):
+ """
+ Applies a max_pool2d on top of the last feature map
+ """
+
+ def forward(self, x: List[Tensor], y: List[Tensor], names: List[str]) -> Tuple[List[Tensor], List[str]]:
+ names.append("pool")
+ x.append(F.max_pool2d(x[-1], 1, 2, 0))
+ return x, names
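+
+
+# Illustrative smoke test, not part of the original file: it builds a standalone
+# FeaturePyramidNetwork and feeds it random feature maps. The channel counts and
+# spatial sizes below are assumptions chosen to mimic a resnet50 backbone, and it
+# relies on the module-level imports (torch, OrderedDict) already used above.
+if __name__ == '__main__':
+    fpn = FeaturePyramidNetwork(in_channels_list=[256, 512, 1024, 2048],
+                                out_channels=256,
+                                extra_blocks=LastLevelMaxPool())
+    feats = OrderedDict()
+    feats["0"] = torch.randn(1, 256, 64, 64)
+    feats["1"] = torch.randn(1, 512, 32, 32)
+    feats["2"] = torch.randn(1, 1024, 16, 16)
+    feats["3"] = torch.randn(1, 2048, 8, 8)
+    outputs = fpn(feats)
+    # expected: five 256-channel maps under keys "0"-"3" plus "pool"
+    print([(k, tuple(v.shape)) for k, v in outputs.items()])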
diff --git a/pytorch_object_detection/mask_rcnn/backbone/resnet50_fpn_model.py b/pytorch_object_detection/mask_rcnn/backbone/resnet50_fpn_model.py
new file mode 100644
index 000000000..a79502e5b
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/backbone/resnet50_fpn_model.py
@@ -0,0 +1,199 @@
+import os
+
+import torch
+import torch.nn as nn
+from torchvision.ops.misc import FrozenBatchNorm2d
+
+from .feature_pyramid_network import BackboneWithFPN, LastLevelMaxPool
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm_layer=None):
+ super().__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+
+ self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
+ kernel_size=1, stride=1, bias=False) # squeeze channels
+ self.bn1 = norm_layer(out_channel)
+ # -----------------------------------------
+ self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
+ kernel_size=3, stride=stride, bias=False, padding=1)
+ self.bn2 = norm_layer(out_channel)
+ # -----------------------------------------
+ self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion,
+ kernel_size=1, stride=1, bias=False) # unsqueeze channels
+ self.bn3 = norm_layer(out_channel * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ def forward(self, x):
+ identity = x
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, block, blocks_num, num_classes=1000, include_top=True, norm_layer=None):
+ super().__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.include_top = include_top
+ self.in_channel = 64
+
+ self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
+ padding=3, bias=False)
+ self.bn1 = norm_layer(self.in_channel)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, blocks_num[0])
+ self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
+ if self.include_top:
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1)
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+ def _make_layer(self, block, channel, block_num, stride=1):
+ norm_layer = self._norm_layer
+ downsample = None
+ if stride != 1 or self.in_channel != channel * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
+ norm_layer(channel * block.expansion))
+
+ layers = []
+ layers.append(block(self.in_channel, channel, downsample=downsample,
+ stride=stride, norm_layer=norm_layer))
+ self.in_channel = channel * block.expansion
+
+ for _ in range(1, block_num):
+ layers.append(block(self.in_channel, channel, norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ if self.include_top:
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+
+ return x
+
+
+def overwrite_eps(model, eps):
+ """
+ This method overwrites the default eps values of all the
+ FrozenBatchNorm2d layers of the model with the provided value.
+ This is necessary to address the BC-breaking change introduced
+ by the bug-fix at pytorch/vision#2933. The overwrite is applied
+ only when the pretrained weights are loaded to maintain compatibility
+ with previous versions.
+
+ Args:
+ model (nn.Module): The model on which we perform the overwrite.
+ eps (float): The new value of eps.
+ """
+ for module in model.modules():
+ if isinstance(module, FrozenBatchNorm2d):
+ module.eps = eps
+
+
+def resnet50_fpn_backbone(pretrain_path="",
+ norm_layer=nn.BatchNorm2d,
+ trainable_layers=3,
+ returned_layers=None,
+ extra_blocks=None):
+ """
+    Build a resnet50 + FPN backbone.
+    Args:
+        pretrain_path: path to pretrained resnet50 weights; leave empty to skip loading
+        norm_layer: defaults to nn.BatchNorm2d. If GPU memory is limited and batch_size
+                    cannot be made large, consider setting norm_layer to FrozenBatchNorm2d
+                    (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
+        trainable_layers: which layers to keep trainable (not frozen)
+        returned_layers: which layers' outputs should be returned
+        extra_blocks: extra blocks appended on top of the returned feature maps
+
+ Returns:
+
+ """
+ resnet_backbone = ResNet(Bottleneck, [3, 4, 6, 3],
+ include_top=False,
+ norm_layer=norm_layer)
+
+    # norm_layer is passed as a class, not an instance, so compare the class itself
+    if norm_layer is FrozenBatchNorm2d:
+ overwrite_eps(resnet_backbone, 0.0)
+
+ if pretrain_path != "":
+ assert os.path.exists(pretrain_path), "{} is not exist.".format(pretrain_path)
+ # 载入预训练权重
+ print(resnet_backbone.load_state_dict(torch.load(pretrain_path), strict=False))
+
+    # select layers that won't be frozen
+ assert 0 <= trainable_layers <= 5
+ layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
+
+    # when training every layer, don't forget bn1, which follows conv1
+ if trainable_layers == 5:
+ layers_to_train.append("bn1")
+
+ # freeze layers
+ for name, parameter in resnet_backbone.named_parameters():
+        # freeze the parameters of layers that are not in layers_to_train
+ if all([not name.startswith(layer) for layer in layers_to_train]):
+ parameter.requires_grad_(False)
+
+ if extra_blocks is None:
+ extra_blocks = LastLevelMaxPool()
+
+ if returned_layers is None:
+ returned_layers = [1, 2, 3, 4]
+    # the returned layer indices must be greater than 0 and less than 5
+ assert min(returned_layers) > 0 and max(returned_layers) < 5
+
+ # return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
+ return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)}
+
+    # in_channel is the channel count of layer4's output feature map = 2048
+    in_channels_stage2 = resnet_backbone.in_channel // 8  # 256
+    # channels of every feature map that resnet50 feeds into the fpn
+    in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
+    # channels of every feature map produced by the fpn
+ out_channels = 256
+ return BackboneWithFPN(resnet_backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
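+
+
+# Illustrative usage sketch, not part of the original file: it builds the backbone
+# without pretrained weights and prints the shapes of the FPN feature maps. The
+# input size is an arbitrary assumption; running this requires the package context
+# so that the relative import at the top of the file resolves.
+if __name__ == '__main__':
+    model = resnet50_fpn_backbone(pretrain_path="", trainable_layers=3)
+    dummy = torch.randn(1, 3, 224, 224)
+    with torch.no_grad():
+        feature_maps = model(dummy)
+    # expected keys: '0'-'3' plus 'pool', each feature map with 256 channels
+    for name, feature in feature_maps.items():
+        print(name, tuple(feature.shape))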
diff --git a/pytorch_object_detection/mask_rcnn/coco91_indices.json b/pytorch_object_detection/mask_rcnn/coco91_indices.json
new file mode 100644
index 000000000..decbe58ce
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/coco91_indices.json
@@ -0,0 +1,92 @@
+{
+ "1": "person",
+ "2": "bicycle",
+ "3": "car",
+ "4": "motorcycle",
+ "5": "airplane",
+ "6": "bus",
+ "7": "train",
+ "8": "truck",
+ "9": "boat",
+ "10": "traffic light",
+ "11": "fire hydrant",
+ "12": "N/A",
+ "13": "stop sign",
+ "14": "parking meter",
+ "15": "bench",
+ "16": "bird",
+ "17": "cat",
+ "18": "dog",
+ "19": "horse",
+ "20": "sheep",
+ "21": "cow",
+ "22": "elephant",
+ "23": "bear",
+ "24": "zebra",
+ "25": "giraffe",
+ "26": "N/A",
+ "27": "backpack",
+ "28": "umbrella",
+ "29": "N/A",
+ "30": "N/A",
+ "31": "handbag",
+ "32": "tie",
+ "33": "suitcase",
+ "34": "frisbee",
+ "35": "skis",
+ "36": "snowboard",
+ "37": "sports ball",
+ "38": "kite",
+ "39": "baseball bat",
+ "40": "baseball glove",
+ "41": "skateboard",
+ "42": "surfboard",
+ "43": "tennis racket",
+ "44": "bottle",
+ "45": "N/A",
+ "46": "wine glass",
+ "47": "cup",
+ "48": "fork",
+ "49": "knife",
+ "50": "spoon",
+ "51": "bowl",
+ "52": "banana",
+ "53": "apple",
+ "54": "sandwich",
+ "55": "orange",
+ "56": "broccoli",
+ "57": "carrot",
+ "58": "hot dog",
+ "59": "pizza",
+ "60": "donut",
+ "61": "cake",
+ "62": "chair",
+ "63": "couch",
+ "64": "potted plant",
+ "65": "bed",
+ "66": "N/A",
+ "67": "dining table",
+ "68": "N/A",
+ "69": "N/A",
+ "70": "toilet",
+ "71": "N/A",
+ "72": "tv",
+ "73": "laptop",
+ "74": "mouse",
+ "75": "remote",
+ "76": "keyboard",
+ "77": "cell phone",
+ "78": "microwave",
+ "79": "oven",
+ "80": "toaster",
+ "81": "sink",
+ "82": "refrigerator",
+ "83": "N/A",
+ "84": "book",
+ "85": "clock",
+ "86": "vase",
+ "87": "scissors",
+ "88": "teddy bear",
+ "89": "hair drier",
+ "90": "toothbrush"
+}
\ No newline at end of file
diff --git a/pytorch_object_detection/mask_rcnn/det_results20220406-141544.txt b/pytorch_object_detection/mask_rcnn/det_results20220406-141544.txt
new file mode 100644
index 000000000..28014527b
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/det_results20220406-141544.txt
@@ -0,0 +1,26 @@
+epoch:0 0.171 0.342 0.154 0.099 0.211 0.213 0.184 0.315 0.334 0.168 0.375 0.440 1.3826 0.08
+epoch:1 0.230 0.419 0.230 0.132 0.266 0.288 0.224 0.374 0.395 0.216 0.435 0.512 1.0356 0.08
+epoch:2 0.242 0.435 0.244 0.133 0.272 0.313 0.233 0.393 0.416 0.232 0.452 0.532 0.9718 0.08
+epoch:3 0.261 0.456 0.269 0.145 0.284 0.326 0.248 0.415 0.440 0.260 0.475 0.550 0.9363 0.08
+epoch:4 0.266 0.458 0.277 0.150 0.301 0.337 0.250 0.409 0.433 0.245 0.467 0.564 0.9145 0.08
+epoch:5 0.272 0.465 0.286 0.155 0.309 0.348 0.251 0.407 0.429 0.247 0.461 0.561 0.8982 0.08
+epoch:6 0.288 0.482 0.303 0.163 0.321 0.363 0.263 0.431 0.452 0.265 0.491 0.570 0.8859 0.08
+epoch:7 0.287 0.483 0.302 0.164 0.320 0.363 0.268 0.432 0.454 0.268 0.483 0.584 0.8771 0.08
+epoch:8 0.298 0.492 0.318 0.166 0.336 0.377 0.268 0.434 0.454 0.265 0.500 0.580 0.8685 0.08
+epoch:9 0.289 0.484 0.306 0.156 0.325 0.374 0.263 0.428 0.450 0.252 0.490 0.589 0.8612 0.08
+epoch:10 0.297 0.489 0.316 0.167 0.330 0.381 0.270 0.436 0.459 0.258 0.501 0.579 0.8547 0.08
+epoch:11 0.299 0.494 0.317 0.171 0.335 0.382 0.272 0.439 0.461 0.276 0.501 0.586 0.8498 0.08
+epoch:12 0.301 0.497 0.321 0.178 0.333 0.390 0.270 0.443 0.466 0.277 0.505 0.600 0.8461 0.08
+epoch:13 0.307 0.503 0.327 0.175 0.345 0.388 0.276 0.441 0.465 0.269 0.510 0.574 0.8409 0.08
+epoch:14 0.299 0.491 0.319 0.171 0.339 0.372 0.271 0.445 0.470 0.284 0.508 0.593 0.8355 0.08
+epoch:15 0.306 0.503 0.324 0.166 0.342 0.396 0.278 0.443 0.468 0.271 0.511 0.598 0.8330 0.08
+epoch:16 0.374 0.579 0.407 0.214 0.415 0.476 0.311 0.500 0.526 0.325 0.573 0.659 0.7421 0.008
+epoch:17 0.379 0.587 0.409 0.214 0.420 0.484 0.316 0.502 0.528 0.322 0.569 0.668 0.7157 0.008
+epoch:18 0.380 0.587 0.411 0.214 0.423 0.486 0.315 0.503 0.528 0.323 0.571 0.669 0.7016 0.008
+epoch:19 0.381 0.588 0.413 0.216 0.422 0.490 0.317 0.508 0.532 0.332 0.574 0.676 0.6897 0.008
+epoch:20 0.379 0.586 0.410 0.212 0.418 0.488 0.313 0.499 0.523 0.317 0.566 0.667 0.6802 0.008
+epoch:21 0.378 0.587 0.408 0.210 0.418 0.488 0.314 0.496 0.520 0.314 0.560 0.667 0.6708 0.008
+epoch:22 0.381 0.588 0.411 0.213 0.420 0.495 0.316 0.500 0.524 0.318 0.567 0.673 0.6497 0.0008
+epoch:23 0.381 0.588 0.411 0.215 0.420 0.492 0.315 0.499 0.523 0.319 0.565 0.666 0.6447 0.0008
+epoch:24 0.381 0.588 0.412 0.214 0.419 0.495 0.316 0.499 0.523 0.317 0.565 0.669 0.6421 0.0008
+epoch:25 0.380 0.585 0.411 0.214 0.419 0.494 0.314 0.498 0.522 0.316 0.566 0.664 0.6398 0.0008
diff --git a/pytorch_object_detection/mask_rcnn/draw_box_utils.py b/pytorch_object_detection/mask_rcnn/draw_box_utils.py
new file mode 100644
index 000000000..2d74c9529
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/draw_box_utils.py
@@ -0,0 +1,153 @@
+from PIL.Image import Image, fromarray
+import PIL.ImageDraw as ImageDraw
+import PIL.ImageFont as ImageFont
+from PIL import ImageColor
+import numpy as np
+
+STANDARD_COLORS = [
+ 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
+ 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
+ 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
+ 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
+ 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
+ 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
+ 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
+ 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
+ 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
+ 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
+ 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
+ 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
+ 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
+ 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
+ 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
+ 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
+ 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
+ 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
+ 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
+ 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
+ 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
+ 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
+ 'WhiteSmoke', 'Yellow', 'YellowGreen'
+]
+
+
+def draw_text(draw,
+ box: list,
+ cls: int,
+ score: float,
+ category_index: dict,
+ color: str,
+ font: str = 'arial.ttf',
+ font_size: int = 24):
+ """
+    Draw the bounding box and class information onto the image.
+ """
+ try:
+ font = ImageFont.truetype(font, font_size)
+ except IOError:
+ font = ImageFont.load_default()
+
+ left, top, right, bottom = box
+ # If the total height of the display strings added to the top of the bounding
+ # box exceeds the top of the image, stack the strings below the bounding box
+ # instead of above.
+ display_str = f"{category_index[str(cls)]}: {int(100 * score)}%"
+ display_str_heights = [font.getsize(ds)[1] for ds in display_str]
+ # Each display_str has a top and bottom margin of 0.05x.
+ display_str_height = (1 + 2 * 0.05) * max(display_str_heights)
+
+ if top > display_str_height:
+ text_top = top - display_str_height
+ text_bottom = top
+ else:
+ text_top = bottom
+ text_bottom = bottom + display_str_height
+
+    for ds in display_strs:
+ text_width, text_height = font.getsize(ds)
+ margin = np.ceil(0.05 * text_width)
+ draw.rectangle([(left, text_top),
+ (left + text_width + 2 * margin, text_bottom)], fill=color)
+ draw.text((left + margin, text_top),
+ ds,
+ fill='black',
+ font=font)
+ left += text_width
+
+
+def draw_masks(image, masks, colors, thresh: float = 0.7, alpha: float = 0.5):
+ np_image = np.array(image)
+ masks = np.where(masks > thresh, True, False)
+
+ # colors = np.array(colors)
+ img_to_draw = np.copy(np_image)
+ # TODO: There might be a way to vectorize this
+ for mask, color in zip(masks, colors):
+ img_to_draw[mask] = color
+
+ out = np_image * (1 - alpha) + img_to_draw * alpha
+ return fromarray(out.astype(np.uint8))
+
+
+def draw_objs(image: Image,
+ boxes: np.ndarray = None,
+ classes: np.ndarray = None,
+ scores: np.ndarray = None,
+ masks: np.ndarray = None,
+ category_index: dict = None,
+ box_thresh: float = 0.1,
+ mask_thresh: float = 0.5,
+ line_thickness: int = 8,
+ font: str = 'arial.ttf',
+ font_size: int = 24,
+ draw_boxes_on_image: bool = True,
+ draw_masks_on_image: bool = True):
+ """
+    Draw bounding boxes, class labels and masks onto the image.
+    Args:
+        image: image to draw on
+        boxes: bounding boxes of the detected objects
+        classes: class indices of the detected objects
+        scores: detection scores
+        masks: instance masks
+        category_index: dict mapping class indices to class names
+        box_thresh: score threshold used to filter detections
+        mask_thresh:
+        line_thickness: bounding-box line width
+        font: font type
+        font_size: font size
+ draw_boxes_on_image:
+ draw_masks_on_image:
+
+ Returns:
+
+ """
+
+    # filter out low-score detections
+ idxs = np.greater(scores, box_thresh)
+ boxes = boxes[idxs]
+ classes = classes[idxs]
+ scores = scores[idxs]
+ if masks is not None:
+ masks = masks[idxs]
+ if len(boxes) == 0:
+ return image
+
+ colors = [ImageColor.getrgb(STANDARD_COLORS[cls % len(STANDARD_COLORS)]) for cls in classes]
+
+ if draw_boxes_on_image:
+ # Draw all boxes onto image.
+ draw = ImageDraw.Draw(image)
+ for box, cls, score, color in zip(boxes, classes, scores, colors):
+ left, top, right, bottom = box
+            # draw the bounding box
+ draw.line([(left, top), (left, bottom), (right, bottom),
+ (right, top), (left, top)], width=line_thickness, fill=color)
+            # draw the class label and score
+ draw_text(draw, box.tolist(), int(cls), float(score), category_index, color, font, font_size)
+
+ if draw_masks_on_image and (masks is not None):
+ # Draw all mask onto image.
+ image = draw_masks(image, masks, colors, mask_thresh)
+
+ return image
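+
+
+# Illustrative sketch, not part of the original file: it overlays a single dummy
+# mask on a blank canvas with draw_masks. The canvas size, mask, color and output
+# filename are arbitrary assumptions.
+if __name__ == '__main__':
+    import PIL.Image as PILImage
+
+    canvas = PILImage.new("RGB", (64, 64), (255, 255, 255))
+    dummy_mask = np.zeros((1, 64, 64), dtype=np.float32)
+    dummy_mask[0, 16:48, 16:48] = 1.0  # a 32x32 square in the center
+    blended = draw_masks(canvas, dummy_mask, colors=[(255, 0, 0)], thresh=0.5)
+    blended.save("draw_masks_demo.jpg")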
diff --git a/pytorch_object_detection/mask_rcnn/my_dataset_coco.py b/pytorch_object_detection/mask_rcnn/my_dataset_coco.py
new file mode 100644
index 000000000..6946e07e9
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/my_dataset_coco.py
@@ -0,0 +1,154 @@
+import os
+import json
+
+import torch
+from PIL import Image
+import torch.utils.data as data
+from pycocotools.coco import COCO
+from train_utils import coco_remove_images_without_annotations, convert_coco_poly_mask
+
+
+class CocoDetection(data.Dataset):
+ """`MS Coco Detection `_ Dataset.
+
+ Args:
+ root (string): Root directory where images are downloaded to.
+ dataset (string): train or val.
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
+ and returns a transformed version.
+ """
+
+ def __init__(self, root, dataset="train", transforms=None, years="2017"):
+ super(CocoDetection, self).__init__()
+ assert dataset in ["train", "val"], 'dataset must be in ["train", "val"]'
+ anno_file = f"instances_{dataset}{years}.json"
+ assert os.path.exists(root), "file '{}' does not exist.".format(root)
+ self.img_root = os.path.join(root, f"{dataset}{years}")
+ assert os.path.exists(self.img_root), "path '{}' does not exist.".format(self.img_root)
+ self.anno_path = os.path.join(root, "annotations", anno_file)
+ assert os.path.exists(self.anno_path), "file '{}' does not exist.".format(self.anno_path)
+
+ self.mode = dataset
+ self.transforms = transforms
+ self.coco = COCO(self.anno_path)
+
+        # build the mapping between coco category ids and category names
+        # note: the object80 ids are not contiguous; although there are only 80 classes,
+        # the ids still follow the stuff91 numbering
+ data_classes = dict([(v["id"], v["name"]) for k, v in self.coco.cats.items()])
+ max_index = max(data_classes.keys()) # 90
+        # set the names of missing category ids to N/A
+ coco_classes = {}
+ for k in range(1, max_index + 1):
+ if k in data_classes:
+ coco_classes[k] = data_classes[k]
+ else:
+ coco_classes[k] = "N/A"
+
+ if dataset == "train":
+ json_str = json.dumps(coco_classes, indent=4)
+ with open("coco91_indices.json", "w") as f:
+ f.write(json_str)
+
+ self.coco_classes = coco_classes
+
+ ids = list(sorted(self.coco.imgs.keys()))
+ if dataset == "train":
+            # remove images that have no objects or whose objects have very small area
+ valid_ids = coco_remove_images_without_annotations(self.coco, ids)
+ self.ids = valid_ids
+ else:
+ self.ids = ids
+
+ def parse_targets(self,
+ img_id: int,
+ coco_targets: list,
+ w: int = None,
+ h: int = None):
+ assert w > 0
+ assert h > 0
+
+        # keep only non-crowd objects (iscrowd == 0)
+ anno = [obj for obj in coco_targets if obj['iscrowd'] == 0]
+
+ boxes = [obj["bbox"] for obj in anno]
+
+ # guard against no boxes via resizing
+ boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
+ # [xmin, ymin, w, h] -> [xmin, ymin, xmax, ymax]
+ boxes[:, 2:] += boxes[:, :2]
+ boxes[:, 0::2].clamp_(min=0, max=w)
+ boxes[:, 1::2].clamp_(min=0, max=h)
+
+ classes = [obj["category_id"] for obj in anno]
+ classes = torch.tensor(classes, dtype=torch.int64)
+
+ area = torch.tensor([obj["area"] for obj in anno])
+ iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
+
+ segmentations = [obj["segmentation"] for obj in anno]
+ masks = convert_coco_poly_mask(segmentations, h, w)
+
+        # keep only valid boxes, i.e. x_max > x_min and y_max > y_min
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+ boxes = boxes[keep]
+ classes = classes[keep]
+ masks = masks[keep]
+ area = area[keep]
+ iscrowd = iscrowd[keep]
+
+ target = {}
+ target["boxes"] = boxes
+ target["labels"] = classes
+ target["masks"] = masks
+ target["image_id"] = torch.tensor([img_id])
+
+ # for conversion to coco api
+ target["area"] = area
+ target["iscrowd"] = iscrowd
+
+ return target
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
+ """
+ coco = self.coco
+ img_id = self.ids[index]
+ ann_ids = coco.getAnnIds(imgIds=img_id)
+ coco_target = coco.loadAnns(ann_ids)
+
+ path = coco.loadImgs(img_id)[0]['file_name']
+ img = Image.open(os.path.join(self.img_root, path)).convert('RGB')
+
+ w, h = img.size
+ target = self.parse_targets(img_id, coco_target, w, h)
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.ids)
+
+ def get_height_and_width(self, index):
+ coco = self.coco
+ img_id = self.ids[index]
+
+ img_info = coco.loadImgs(img_id)[0]
+ w = img_info["width"]
+ h = img_info["height"]
+ return h, w
+
+ @staticmethod
+ def collate_fn(batch):
+ return tuple(zip(*batch))
+
+
+if __name__ == '__main__':
+ train = CocoDetection("/data/coco2017", dataset="train")
+ print(len(train))
+ t = train[0]
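+
+    # Illustrative addition, not part of the original file: collate_fn simply
+    # regroups a batch into tuples, so variable-sized targets survive a DataLoader.
+    # It assumes the same local /data/coco2017 path as the lines above.
+    loader = data.DataLoader(train, batch_size=2, shuffle=False,
+                             collate_fn=CocoDetection.collate_fn)
+    images, targets = next(iter(loader))
+    print(len(images), [tgt["boxes"].shape for tgt in targets])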
diff --git a/pytorch_object_detection/mask_rcnn/my_dataset_voc.py b/pytorch_object_detection/mask_rcnn/my_dataset_voc.py
new file mode 100644
index 000000000..2034b5ace
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/my_dataset_voc.py
@@ -0,0 +1,215 @@
+import os
+import json
+
+from lxml import etree
+import numpy as np
+from PIL import Image
+import torch
+from torch.utils.data import Dataset
+from train_utils import convert_to_coco_api
+
+
+class VOCInstances(Dataset):
+ def __init__(self, voc_root, year="2012", txt_name: str = "train.txt", transforms=None):
+ super().__init__()
+ if isinstance(year, int):
+ year = str(year)
+ assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
+ if "VOCdevkit" in voc_root:
+ root = os.path.join(voc_root, f"VOC{year}")
+ else:
+ root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
+ assert os.path.exists(root), "path '{}' does not exist.".format(root)
+ image_dir = os.path.join(root, 'JPEGImages')
+ xml_dir = os.path.join(root, 'Annotations')
+ mask_dir = os.path.join(root, 'SegmentationObject')
+
+ txt_path = os.path.join(root, "ImageSets", "Segmentation", txt_name)
+ assert os.path.exists(txt_path), "file '{}' does not exist.".format(txt_path)
+ with open(os.path.join(txt_path), "r") as f:
+ file_names = [x.strip() for x in f.readlines() if len(x.strip()) > 0]
+
+ # read class_indict
+ json_file = 'pascal_voc_indices.json'
+ assert os.path.exists(json_file), "{} file not exist.".format(json_file)
+ with open(json_file, 'r') as f:
+ idx2classes = json.load(f)
+ self.class_dict = dict([(v, k) for k, v in idx2classes.items()])
+
+        self.images_path = []     # image paths
+        self.xmls_path = []       # xml annotation paths
+        self.xmls_info = []       # parsed xml dicts
+        self.masks_path = []      # SegmentationObject image paths
+        self.objects_bboxes = []  # parsed object boxes and related info
+        self.masks = []           # loaded SegmentationObject mask arrays
+
+        # check that the image, xml file and mask all exist
+ images_path = [os.path.join(image_dir, x + ".jpg") for x in file_names]
+ xmls_path = [os.path.join(xml_dir, x + '.xml') for x in file_names]
+ masks_path = [os.path.join(mask_dir, x + ".png") for x in file_names]
+ for idx, (img_path, xml_path, mask_path) in enumerate(zip(images_path, xmls_path, masks_path)):
+ assert os.path.exists(img_path), f"not find {img_path}"
+ assert os.path.exists(xml_path), f"not find {xml_path}"
+ assert os.path.exists(mask_path), f"not find {mask_path}"
+
+            # parse bbox info from the xml
+ with open(xml_path) as fid:
+ xml_str = fid.read()
+ xml = etree.fromstring(xml_str)
+            obs_dict = parse_xml_to_dict(xml)["annotation"]  # parse the xml file into a dict
+            obs_bboxes = parse_objects(obs_dict, xml_path, self.class_dict, idx)  # parse object info
+ num_objs = obs_bboxes["boxes"].shape[0]
+
+            # read the SegmentationObject mask and check it matches the number of bboxes
+ instances_mask = Image.open(mask_path)
+ instances_mask = np.array(instances_mask)
+            instances_mask[instances_mask == 255] = 0  # 255 marks background or ignored regions; treat it as background (0) for simplicity
+
+            # check that the number of annotated bboxes equals the number of instances
+ num_instances = instances_mask.max()
+ if num_objs != num_instances:
+ print(f"warning: num_boxes:{num_objs} and num_instances:{num_instances} do not correspond. "
+ f"skip image:{img_path}")
+ continue
+
+ self.images_path.append(img_path)
+ self.xmls_path.append(xml_path)
+ self.xmls_info.append(obs_dict)
+ self.masks_path.append(mask_path)
+ self.objects_bboxes.append(obs_bboxes)
+ self.masks.append(instances_mask)
+
+ self.transforms = transforms
+ self.coco = convert_to_coco_api(self)
+
+ def parse_mask(self, idx: int):
+ mask = self.masks[idx]
+        c = mask.max()  # the max index equals the number of instances
+        masks = []
+        # store each instance's mask in its own channel
+ for i in range(1, c+1):
+ masks.append(mask == i)
+ masks = np.stack(masks, axis=0)
+ return torch.as_tensor(masks, dtype=torch.uint8)
+
+ def __getitem__(self, idx):
+ """
+ Args:
+ idx (int): Index
+
+ Returns:
+ tuple: (image, target) where target is the image segmentation.
+ """
+ img = Image.open(self.images_path[idx]).convert('RGB')
+ target = self.objects_bboxes[idx]
+ masks = self.parse_mask(idx)
+ target["masks"] = masks
+
+ if self.transforms is not None:
+ img, target = self.transforms(img, target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.images_path)
+
+ def get_height_and_width(self, idx):
+ """方便统计所有图片的高宽比例信息"""
+ # read xml
+ data = self.xmls_info[idx]
+ data_height = int(data["size"]["height"])
+ data_width = int(data["size"]["width"])
+ return data_height, data_width
+
+ def get_annotations(self, idx):
+ """方便构建COCO()"""
+ data = self.xmls_info[idx]
+ h = int(data["size"]["height"])
+ w = int(data["size"]["width"])
+ target = self.objects_bboxes[idx]
+ masks = self.parse_mask(idx)
+ target["masks"] = masks
+ return target, h, w
+
+ @staticmethod
+ def collate_fn(batch):
+ return tuple(zip(*batch))
+
+
+def parse_xml_to_dict(xml):
+ """
+    Parse an xml file into a dict; adapted from tensorflow's recursive_parse_xml_to_dict.
+ Args:
+ xml: xml tree obtained by parsing XML file contents using lxml.etree
+
+ Returns:
+ Python dictionary holding XML contents.
+ """
+
+    if len(xml) == 0:  # reached a leaf node, return its tag and text
+ return {xml.tag: xml.text}
+
+ result = {}
+ for child in xml:
+        child_result = parse_xml_to_dict(child)  # recursively parse child tags
+ if child.tag != 'object':
+ result[child.tag] = child_result[child.tag]
+ else:
+            if child.tag not in result:  # there may be multiple objects, so collect them in a list
+ result[child.tag] = []
+ result[child.tag].append(child_result[child.tag])
+ return {xml.tag: result}
+
+
+def parse_objects(data: dict, xml_path: str, class_dict: dict, idx: int):
+ """
+    Parse bboxes, labels, iscrowd and area information.
+    Args:
+        data: annotation data parsed from the xml into a dict
+        xml_path: path of the corresponding xml file
+        class_dict: mapping between class names and indices
+        idx: index of the image
+
+ Returns:
+
+ """
+ boxes = []
+ labels = []
+ iscrowd = []
+ assert "object" in data, "{} lack of object information.".format(xml_path)
+ for obj in data["object"]:
+ xmin = float(obj["bndbox"]["xmin"])
+ xmax = float(obj["bndbox"]["xmax"])
+ ymin = float(obj["bndbox"]["ymin"])
+ ymax = float(obj["bndbox"]["ymax"])
+
+        # further data check: some annotations may have w or h equal to 0, which makes the regression loss nan
+ if xmax <= xmin or ymax <= ymin:
+ print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
+ continue
+
+ boxes.append([xmin, ymin, xmax, ymax])
+ labels.append(int(class_dict[obj["name"]]))
+ if "difficult" in obj:
+ iscrowd.append(int(obj["difficult"]))
+ else:
+ iscrowd.append(0)
+
+ # convert everything into a torch.Tensor
+ boxes = torch.as_tensor(boxes, dtype=torch.float32)
+ labels = torch.as_tensor(labels, dtype=torch.int64)
+ iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
+ image_id = torch.tensor([idx])
+ area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
+
+ return {"boxes": boxes,
+ "labels": labels,
+ "iscrowd": iscrowd,
+ "image_id": image_id,
+ "area": area}
+
+
+if __name__ == '__main__':
+ dataset = VOCInstances(voc_root="/data/")
+ print(len(dataset))
+ d1 = dataset[0]
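+
+    # Illustrative addition, not part of the original file: each sample is an
+    # (image, target) pair and target["masks"] holds one binary channel per
+    # instance. It assumes the same local /data/ VOC path as the line above.
+    image, target = d1
+    print(target["boxes"].shape, target["masks"].shape, target["labels"])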
diff --git a/pytorch_object_detection/mask_rcnn/network_files/__init__.py b/pytorch_object_detection/mask_rcnn/network_files/__init__.py
new file mode 100644
index 000000000..3a2ed2299
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/network_files/__init__.py
@@ -0,0 +1,3 @@
+from .faster_rcnn_framework import FasterRCNN, FastRCNNPredictor
+from .rpn_function import AnchorsGenerator
+from .mask_rcnn import MaskRCNN
diff --git a/pytorch_object_detection/mask_rcnn/network_files/boxes.py b/pytorch_object_detection/mask_rcnn/network_files/boxes.py
new file mode 100644
index 000000000..8eeca4573
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/network_files/boxes.py
@@ -0,0 +1,181 @@
+import torch
+from typing import Tuple
+from torch import Tensor
+import torchvision
+
+
+def nms(boxes, scores, iou_threshold):
+ # type: (Tensor, Tensor, float) -> Tensor
+ """
+ Performs non-maximum suppression (NMS) on the boxes according
+ to their intersection-over-union (IoU).
+
+ NMS iteratively removes lower scoring boxes which have an
+ IoU greater than iou_threshold with another (higher scoring)
+ box.
+
+ Parameters
+ ----------
+ boxes : Tensor[N, 4])
+ boxes to perform NMS on. They
+ are expected to be in (x1, y1, x2, y2) format
+ scores : Tensor[N]
+ scores for each one of the boxes
+ iou_threshold : float
+ discards all overlapping
+ boxes with IoU > iou_threshold
+
+ Returns
+ -------
+ keep : Tensor
+ int64 tensor with the indices
+ of the elements that have been kept
+ by NMS, sorted in decreasing order of scores
+ """
+ return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
+
+
+def batched_nms(boxes, scores, idxs, iou_threshold):
+ # type: (Tensor, Tensor, Tensor, float) -> Tensor
+ """
+ Performs non-maximum suppression in a batched fashion.
+
+ Each index value correspond to a category, and NMS
+ will not be applied between elements of different categories.
+
+ Parameters
+ ----------
+ boxes : Tensor[N, 4]
+ boxes where NMS will be performed. They
+ are expected to be in (x1, y1, x2, y2) format
+ scores : Tensor[N]
+ scores for each one of the boxes
+ idxs : Tensor[N]
+ indices of the categories for each one of the boxes.
+ iou_threshold : float
+ discards all overlapping boxes
+ with IoU < iou_threshold
+
+ Returns
+ -------
+ keep : Tensor
+ int64 tensor with the indices of
+ the elements that have been kept by NMS, sorted
+ in decreasing order of scores
+ """
+ if boxes.numel() == 0:
+ return torch.empty((0,), dtype=torch.int64, device=boxes.device)
+
+ # strategy: in order to perform NMS independently per class.
+ # we add an offset to all the boxes. The offset is dependent
+ # only on the class idx, and is large enough so that boxes
+ # from different classes do not overlap
+    # get the largest coordinate value among all boxes (xmin, ymin, xmax, ymax)
+ max_coordinate = boxes.max()
+
+ # to(): Performs Tensor dtype and/or device conversion
+    # generate a large offset for every category / level
+    # to() here only makes the dtype and device of the generated tensor match boxes
+    offsets = idxs.to(boxes) * (max_coordinate + 1)
+    # after adding the per-level offset, boxes of different categories / levels can no longer overlap
+ boxes_for_nms = boxes + offsets[:, None]
+ keep = nms(boxes_for_nms, scores, iou_threshold)
+ return keep
+
+
+def remove_small_boxes(boxes, min_size):
+ # type: (Tensor, float) -> Tensor
+ """
+    Remove boxes which contain at least one side smaller than min_size.
+    Returns the indices of boxes whose width and height are both at least min_size.
+ Arguments:
+ boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
+ min_size (float): minimum size
+
+ Returns:
+ keep (Tensor[K]): indices of the boxes that have both sides
+ larger than min_size
+ """
+    ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]  # widths and heights of the predicted boxes
+    # keep = (ws >= min_size) & (hs >= min_size)  # True when both width and height exceed the threshold
+ keep = torch.logical_and(torch.ge(ws, min_size), torch.ge(hs, min_size))
+ # nonzero(): Returns a tensor containing the indices of all non-zero elements of input
+ # keep = keep.nonzero().squeeze(1)
+ keep = torch.where(keep)[0]
+ return keep
+
+
+def clip_boxes_to_image(boxes, size):
+ # type: (Tensor, Tuple[int, int]) -> Tensor
+ """
+ Clip boxes so that they lie inside an image of size `size`.
+    Clip predicted boxes so that out-of-range coordinates are moved onto the image border.
+
+ Arguments:
+ boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
+ size (Tuple[height, width]): size of the image
+
+ Returns:
+ clipped_boxes (Tensor[N, 4])
+ """
+ dim = boxes.dim()
+ boxes_x = boxes[..., 0::2] # x1, x2
+ boxes_y = boxes[..., 1::2] # y1, y2
+ height, width = size
+
+ if torchvision._is_tracing():
+ boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
+ boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
+ boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
+ boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
+ else:
+        boxes_x = boxes_x.clamp(min=0, max=width)   # clamp x coordinates to [0, width]
+        boxes_y = boxes_y.clamp(min=0, max=height)  # clamp y coordinates to [0, height]
+
+ clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
+ return clipped_boxes.reshape(boxes.shape)
+
+
+def box_area(boxes):
+ """
+ Computes the area of a set of bounding boxes, which are specified by its
+ (x1, y1, x2, y2) coordinates.
+
+ Arguments:
+ boxes (Tensor[N, 4]): boxes for which the area will be computed. They
+ are expected to be in (x1, y1, x2, y2) format
+
+ Returns:
+ area (Tensor[N]): area for each box
+ """
+ return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+
+
+def box_iou(boxes1, boxes2):
+ """
+ Return intersection-over-union (Jaccard index) of boxes.
+
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+
+ Arguments:
+ boxes1 (Tensor[N, 4])
+ boxes2 (Tensor[M, 4])
+
+ Returns:
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
+ IoU values for every element in boxes1 and boxes2
+ """
+ area1 = box_area(boxes1)
+ area2 = box_area(boxes2)
+
+ # When the shapes do not match,
+ # the shape of the returned output tensor follows the broadcasting rules
+ lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # left-top [N,M,2]
+ rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # right-bottom [N,M,2]
+
+ wh = (rb - lt).clamp(min=0) # [N,M,2]
+ inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
+
+ iou = inter / (area1[:, None] + area2 - inter)
+ return iou
+
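+
+# Illustrative example, not part of the original file: IoU of two toy boxes and
+# class-aware NMS over them. The coordinates, scores and threshold are arbitrary
+# assumptions.
+if __name__ == '__main__':
+    b1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
+    b2 = torch.tensor([[5.0, 5.0, 15.0, 15.0]])
+    print(box_iou(b1, b2))  # intersection 25 / union 175 ≈ 0.1429
+
+    all_boxes = torch.cat([b1, b2], dim=0)
+    scores = torch.tensor([0.9, 0.8])
+    idxs = torch.tensor([0, 1])  # different classes, so both boxes survive NMS
+    print(batched_nms(all_boxes, scores, idxs, iou_threshold=0.1))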
diff --git a/pytorch_object_detection/mask_rcnn/network_files/det_utils.py b/pytorch_object_detection/mask_rcnn/network_files/det_utils.py
new file mode 100644
index 000000000..6b4fe6013
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/network_files/det_utils.py
@@ -0,0 +1,408 @@
+import torch
+import math
+from typing import List, Tuple
+from torch import Tensor
+
+
+class BalancedPositiveNegativeSampler(object):
+ """
+ This class samples batches, ensuring that they contain a fixed proportion of positives
+ """
+
+ def __init__(self, batch_size_per_image, positive_fraction):
+ # type: (int, float) -> None
+ """
+ Arguments:
+ batch_size_per_image (int): number of elements to be selected per image
+ positive_fraction (float): percentage of positive elements per batch
+ """
+ self.batch_size_per_image = batch_size_per_image
+ self.positive_fraction = positive_fraction
+
+ def __call__(self, matched_idxs):
+ # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+ """
+ Arguments:
+ matched idxs: list of tensors containing -1, 0 or positive values.
+ Each tensor corresponds to a specific image.
+ -1 values are ignored, 0 are considered as negatives and > 0 as
+ positives.
+
+ Returns:
+ pos_idx (list[tensor])
+ neg_idx (list[tensor])
+
+ Returns two lists of binary masks for each image.
+ The first list contains the positive elements that were selected,
+ and the second list the negative example.
+ """
+ pos_idx = []
+ neg_idx = []
+        # iterate over the matched_idxs of each image
+ for matched_idxs_per_image in matched_idxs:
+ # >= 1的为正样本, nonzero返回非零元素索引
+ # positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1)
+ positive = torch.where(torch.ge(matched_idxs_per_image, 1))[0]
+ # = 0的为负样本
+ # negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1)
+ negative = torch.where(torch.eq(matched_idxs_per_image, 0))[0]
+
+            # target number of positive samples
+            num_pos = int(self.batch_size_per_image * self.positive_fraction)
+            # protect against not enough positive examples
+            # if there are not enough positives, use all of them
+            num_pos = min(positive.numel(), num_pos)
+            # target number of negative samples
+            num_neg = self.batch_size_per_image - num_pos
+            # protect against not enough negative examples
+            # if there are not enough negatives, use all of them
+            num_neg = min(negative.numel(), num_neg)
+
+ # randomly select positive and negative examples
+ # Returns a random permutation of integers from 0 to n - 1.
+            # randomly select the specified number of positive and negative samples
+ perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
+ perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
+
+ pos_idx_per_image = positive[perm1]
+ neg_idx_per_image = negative[perm2]
+
+ # create binary mask from indices
+ pos_idx_per_image_mask = torch.zeros_like(
+ matched_idxs_per_image, dtype=torch.uint8
+ )
+ neg_idx_per_image_mask = torch.zeros_like(
+ matched_idxs_per_image, dtype=torch.uint8
+ )
+
+ pos_idx_per_image_mask[pos_idx_per_image] = 1
+ neg_idx_per_image_mask[neg_idx_per_image] = 1
+
+ pos_idx.append(pos_idx_per_image_mask)
+ neg_idx.append(neg_idx_per_image_mask)
+
+ return pos_idx, neg_idx
+
+
+@torch.jit._script_if_tracing
+def encode_boxes(reference_boxes, proposals, weights):
+ # type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
+ """
+ Encode a set of proposals with respect to some
+ reference boxes
+
+ Arguments:
+ reference_boxes (Tensor): reference boxes(gt)
+ proposals (Tensor): boxes to be encoded(anchors)
+ weights:
+ """
+
+ # perform some unpacking to make it JIT-fusion friendly
+ wx = weights[0]
+ wy = weights[1]
+ ww = weights[2]
+ wh = weights[3]
+
+ # unsqueeze()
+ # Returns a new tensor with a dimension of size one inserted at the specified position.
+ proposals_x1 = proposals[:, 0].unsqueeze(1)
+ proposals_y1 = proposals[:, 1].unsqueeze(1)
+ proposals_x2 = proposals[:, 2].unsqueeze(1)
+ proposals_y2 = proposals[:, 3].unsqueeze(1)
+
+ reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
+ reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
+ reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
+ reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
+
+ # implementation starts here
+ # parse widths and heights
+ ex_widths = proposals_x2 - proposals_x1
+ ex_heights = proposals_y2 - proposals_y1
+ # parse coordinate of center point
+ ex_ctr_x = proposals_x1 + 0.5 * ex_widths
+ ex_ctr_y = proposals_y1 + 0.5 * ex_heights
+
+ gt_widths = reference_boxes_x2 - reference_boxes_x1
+ gt_heights = reference_boxes_y2 - reference_boxes_y1
+ gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
+ gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
+
+ targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
+ targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
+ targets_dw = ww * torch.log(gt_widths / ex_widths)
+ targets_dh = wh * torch.log(gt_heights / ex_heights)
+
+ targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+ return targets
+
+
+class BoxCoder(object):
+ """
+ This class encodes and decodes a set of bounding boxes into
+ the representation used for training the regressors.
+ """
+
+ def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
+ # type: (Tuple[float, float, float, float], float) -> None
+ """
+ Arguments:
+ weights (4-element tuple)
+ bbox_xform_clip (float)
+ """
+ self.weights = weights
+ self.bbox_xform_clip = bbox_xform_clip
+
+ def encode(self, reference_boxes, proposals):
+ # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+ """
+        Compute the regression parameters from the anchors and their matched gt boxes.
+        Args:
+            reference_boxes: List[Tensor] gt_boxes matched to each proposal/anchor
+            proposals: List[Tensor] anchors/proposals
+
+        Returns: regression parameters
+
+        """
+        # record the number of anchors per image so the concatenated tensor can be split back later
+        # reference_boxes and proposals share the same data structure
+ boxes_per_image = [len(b) for b in reference_boxes]
+ reference_boxes = torch.cat(reference_boxes, dim=0)
+ proposals = torch.cat(proposals, dim=0)
+
+ # targets_dx, targets_dy, targets_dw, targets_dh
+ targets = self.encode_single(reference_boxes, proposals)
+ return targets.split(boxes_per_image, 0)
+
+ def encode_single(self, reference_boxes, proposals):
+ """
+ Encode a set of proposals with respect to some
+ reference boxes
+
+ Arguments:
+ reference_boxes (Tensor): reference boxes
+ proposals (Tensor): boxes to be encoded
+ """
+ dtype = reference_boxes.dtype
+ device = reference_boxes.device
+ weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
+ targets = encode_boxes(reference_boxes, proposals, weights)
+
+ return targets
+
+ def decode(self, rel_codes, boxes):
+ # type: (Tensor, List[Tensor]) -> Tensor
+ """
+
+ Args:
+ rel_codes: bbox regression parameters
+ boxes: anchors/proposals
+
+ Returns:
+
+ """
+ assert isinstance(boxes, (list, tuple))
+ assert isinstance(rel_codes, torch.Tensor)
+ boxes_per_image = [b.size(0) for b in boxes]
+ concat_boxes = torch.cat(boxes, dim=0)
+
+ box_sum = 0
+ for val in boxes_per_image:
+ box_sum += val
+
+        # apply the predicted bbox regression parameters to the anchors to get the predicted bbox coordinates
+ pred_boxes = self.decode_single(
+ rel_codes, concat_boxes
+ )
+
+        # avoid a reshape error when pred_boxes is empty
+ if box_sum > 0:
+ pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
+
+ return pred_boxes
+
+ def decode_single(self, rel_codes, boxes):
+ """
+ From a set of original boxes and encoded relative box offsets,
+ get the decoded boxes.
+
+ Arguments:
+ rel_codes (Tensor): encoded boxes (bbox regression parameters)
+ boxes (Tensor): reference boxes (anchors/proposals)
+ """
+ boxes = boxes.to(rel_codes.dtype)
+
+ # xmin, ymin, xmax, ymax
+        widths = boxes[:, 2] - boxes[:, 0]   # anchor/proposal widths
+        heights = boxes[:, 3] - boxes[:, 1]  # anchor/proposal heights
+        ctr_x = boxes[:, 0] + 0.5 * widths   # anchor/proposal center x coordinates
+        ctr_y = boxes[:, 1] + 0.5 * heights  # anchor/proposal center y coordinates
+
+        wx, wy, ww, wh = self.weights  # [1,1,1,1] in the RPN, [10,10,5,5] in fast rcnn
+        dx = rel_codes[:, 0::4] / wx   # predicted center-x regression parameters of the anchors/proposals
+        dy = rel_codes[:, 1::4] / wy   # predicted center-y regression parameters of the anchors/proposals
+        dw = rel_codes[:, 2::4] / ww   # predicted width regression parameters of the anchors/proposals
+        dh = rel_codes[:, 3::4] / wh   # predicted height regression parameters of the anchors/proposals
+
+ # limit max value, prevent sending too large values into torch.exp()
+ # self.bbox_xform_clip=math.log(1000. / 16) 4.135
+ dw = torch.clamp(dw, max=self.bbox_xform_clip)
+ dh = torch.clamp(dh, max=self.bbox_xform_clip)
+
+ pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
+ pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
+ pred_w = torch.exp(dw) * widths[:, None]
+ pred_h = torch.exp(dh) * heights[:, None]
+
+ # xmin
+ pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
+ # ymin
+ pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
+ # xmax
+ pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
+ # ymax
+ pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
+
+ pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
+ return pred_boxes
+
+
+class Matcher(object):
+ BELOW_LOW_THRESHOLD = -1
+ BETWEEN_THRESHOLDS = -2
+
+ __annotations__ = {
+ 'BELOW_LOW_THRESHOLD': int,
+ 'BETWEEN_THRESHOLDS': int,
+ }
+
+ def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
+ # type: (float, float, bool) -> None
+ """
+ Args:
+ high_threshold (float): quality values greater than or equal to
+ this value are candidate matches.
+ low_threshold (float): a lower quality threshold used to stratify
+ matches into three levels:
+ 1) matches >= high_threshold
+ 2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
+ 3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
+ allow_low_quality_matches (bool): if True, produce additional matches
+ for predictions that have only low-quality match candidates. See
+ set_low_quality_matches_ for more details.
+ """
+ self.BELOW_LOW_THRESHOLD = -1
+ self.BETWEEN_THRESHOLDS = -2
+ assert low_threshold <= high_threshold
+ self.high_threshold = high_threshold # 0.7
+ self.low_threshold = low_threshold # 0.3
+ self.allow_low_quality_matches = allow_low_quality_matches
+
+ def __call__(self, match_quality_matrix):
+ """
+ 计算anchors与每个gtboxes匹配的iou最大值,并记录索引,
+ iou= self.low_threshold) & (
+ matched_vals < self.high_threshold
+ )
+ # iou小于low_threshold的matches索引置为-1
+ matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD # -1
+
+ # iou在[low_threshold, high_threshold]之间的matches索引置为-2
+ matches[between_thresholds] = self.BETWEEN_THRESHOLDS # -2
+
+ if self.allow_low_quality_matches:
+ assert all_matches is not None
+ self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
+
+ return matches
+
+ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
+ """
+ Produce additional matches for predictions that have only low-quality matches.
+ Specifically, for each ground-truth find the set of predictions that have
+ maximum overlap with it (including ties); for each prediction in that set, if
+ it is unmatched, then match it to the ground-truth with which it has the highest
+ quality value.
+ """
+ # For each gt, find the prediction with which it has highest quality
+        # for each gt box, find the anchor with which it has the highest iou;
+        # highest_quality_foreach_gt holds that maximum iou value
+        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)  # the dimension to reduce.
+
+ # Find highest quality match available, even if it is low, including ties
+        # find the anchor indices with the highest iou for each gt box; a gt may tie with several anchors
+ # gt_pred_pairs_of_highest_quality = torch.nonzero(
+ # match_quality_matrix == highest_quality_foreach_gt[:, None]
+ # )
+ gt_pred_pairs_of_highest_quality = torch.where(
+ torch.eq(match_quality_matrix, highest_quality_foreach_gt[:, None])
+ )
+ # Example gt_pred_pairs_of_highest_quality:
+ # tensor([[ 0, 39796],
+ # [ 1, 32055],
+ # [ 1, 32070],
+ # [ 2, 39190],
+ # [ 2, 40255],
+ # [ 3, 40390],
+ # [ 3, 41455],
+ # [ 4, 45470],
+ # [ 5, 45325],
+ # [ 5, 46390]])
+ # Each row is a (gt index, prediction index)
+ # Note how gt items 1, 2, 3, and 5 each have two ties
+
+        # gt_pred_pairs_of_highest_quality[:, 0] would be the gt index (not needed here)
+        # pre_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
+        pre_inds_to_update = gt_pred_pairs_of_highest_quality[1]
+        # keep the gt match with the highest iou for these anchors, even if the iou is below the threshold
+ matches[pre_inds_to_update] = all_matches[pre_inds_to_update]
+
+
+def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = True):
+ """
+ very similar to the smooth_l1_loss from pytorch, but with
+ the extra beta parameter
+ """
+ n = torch.abs(input - target)
+ # cond = n < beta
+ cond = torch.lt(n, beta)
+ loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
+ if size_average:
+ return loss.mean()
+ return loss.sum()
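+
+
+# Illustrative round-trip check, not part of the original file: encoding a gt box
+# against an anchor and decoding the result should recover the gt box. The weights
+# (1., 1., 1., 1.) match the RPN setting mentioned in the comments above; the box
+# coordinates are arbitrary assumptions.
+if __name__ == '__main__':
+    box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
+    anchors = [torch.tensor([[0.0, 0.0, 10.0, 10.0]])]
+    gt_boxes = [torch.tensor([[2.0, 2.0, 8.0, 12.0]])]
+    regression_targets = box_coder.encode(gt_boxes, anchors)[0]
+    decoded = box_coder.decode(regression_targets, anchors)
+    print(decoded.reshape(-1, 4))  # ≈ tensor([[ 2.,  2.,  8., 12.]])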
diff --git a/pytorch_object_detection/mask_rcnn/network_files/faster_rcnn_framework.py b/pytorch_object_detection/mask_rcnn/network_files/faster_rcnn_framework.py
new file mode 100644
index 000000000..827d8c653
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/network_files/faster_rcnn_framework.py
@@ -0,0 +1,354 @@
+import warnings
+from collections import OrderedDict
+from typing import Tuple, List, Dict, Optional, Union
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from torchvision.ops import MultiScaleRoIAlign
+
+from .roi_head import RoIHeads
+from .transform import GeneralizedRCNNTransform
+from .rpn_function import AnchorsGenerator, RPNHead, RegionProposalNetwork
+
+
+class FasterRCNNBase(nn.Module):
+ """
+ Main class for Generalized R-CNN.
+
+ Arguments:
+ backbone (nn.Module):
+ rpn (nn.Module):
+ roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
+ detections / masks from it.
+ transform (nn.Module): performs the data transformation from the inputs to feed into
+ the model
+ """
+
+ def __init__(self, backbone, rpn, roi_heads, transform):
+ super(FasterRCNNBase, self).__init__()
+ self.transform = transform
+ self.backbone = backbone
+ self.rpn = rpn
+ self.roi_heads = roi_heads
+ # used only on torchscript mode
+ self._has_warned = False
+
+ @torch.jit.unused
+ def eager_outputs(self, losses, detections):
+ # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
+ if self.training:
+ return losses
+
+ return detections
+
+ def forward(self, images, targets=None):
+ # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+ """
+ Arguments:
+ images (list[Tensor]): images to be processed
+ targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
+
+ Returns:
+ result (list[BoxList] or dict[Tensor]): the output from the model.
+ During training, it returns a dict[Tensor] which contains the losses.
+            During testing, it returns a list[BoxList] containing additional fields
+ like `scores`, `labels` and `mask` (for Mask R-CNN models).
+
+ """
+ if self.training and targets is None:
+ raise ValueError("In training mode, targets should be passed")
+
+ if self.training:
+ assert targets is not None
+            for target in targets:         # further check that the boxes field of each target is well formed
+                boxes = target["boxes"]
+                if isinstance(boxes, torch.Tensor):
+                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
+                        raise ValueError("Expected target boxes to be a tensor "
+                                         "of shape [N, 4], got {:}.".format(
+ boxes.shape))
+ else:
+ raise ValueError("Expected target boxes to be of type "
+ "Tensor, got {:}.".format(type(boxes)))
+
+ original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
+ for img in images:
+ val = img.shape[-2:]
+            assert len(val) == 2  # guard against a 1-d input
+ original_image_sizes.append((val[0], val[1]))
+ # original_image_sizes = [img.shape[-2:] for img in images]
+
+        images, targets = self.transform(images, targets)  # preprocess the images
+        # print(images.tensors.shape)
+        features = self.backbone(images.tensors)  # feed the images through the backbone to get feature maps
+        if isinstance(features, torch.Tensor):  # when predicting on a single feature map, wrap it in an OrderedDict under key '0'
+            features = OrderedDict([('0', features)])  # when predicting on multiple feature maps, features is already an OrderedDict
+
+        # pass the feature maps and the annotated targets into the rpn
+        # proposals: List[Tensor], Tensor_shape: [num_proposals, 4],
+        # each proposal is in absolute coordinates, in (x1, y1, x2, y2) format
+        proposals, proposal_losses = self.rpn(images, features, targets)
+
+        # pass the rpn outputs and the annotated targets into the second (fast rcnn) stage
+        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
+
+        # post-process the predictions (mainly map the bboxes back to the original image scale)
+ detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
+
+ losses = {}
+ losses.update(detector_losses)
+ losses.update(proposal_losses)
+
+ if torch.jit.is_scripting():
+ if not self._has_warned:
+ warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
+ self._has_warned = True
+ return losses, detections
+ else:
+ return self.eager_outputs(losses, detections)
+
+ # if self.training:
+ # return losses
+ #
+ # return detections
+
+
+class TwoMLPHead(nn.Module):
+ """
+ Standard heads for FPN-based models
+
+ Arguments:
+ in_channels (int): number of input channels
+ representation_size (int): size of the intermediate representation
+ """
+
+ def __init__(self, in_channels, representation_size):
+ super(TwoMLPHead, self).__init__()
+
+ self.fc6 = nn.Linear(in_channels, representation_size)
+ self.fc7 = nn.Linear(representation_size, representation_size)
+
+ def forward(self, x):
+ x = x.flatten(start_dim=1)
+
+ x = F.relu(self.fc6(x))
+ x = F.relu(self.fc7(x))
+
+ return x
+
+
+class FastRCNNPredictor(nn.Module):
+ """
+ Standard classification + bounding box regression layers
+ for Fast R-CNN.
+
+ Arguments:
+ in_channels (int): number of input channels
+ num_classes (int): number of output classes (including background)
+ """
+
+ def __init__(self, in_channels, num_classes):
+ super(FastRCNNPredictor, self).__init__()
+ self.cls_score = nn.Linear(in_channels, num_classes)
+ self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
+
+ def forward(self, x):
+ if x.dim() == 4:
+ assert list(x.shape[2:]) == [1, 1]
+ x = x.flatten(start_dim=1)
+ scores = self.cls_score(x)
+ bbox_deltas = self.bbox_pred(x)
+
+ return scores, bbox_deltas
+
+
+class FasterRCNN(FasterRCNNBase):
+ """
+ Implements Faster R-CNN.
+
+ The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+ image, and should be in 0-1 range. Different images can have different sizes.
+
+ The behavior of the model changes depending if it is in training or evaluation mode.
+
+ During training, the model expects both the input tensors, as well as a targets (list of dictionary),
+ containing:
+ - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values
+ between 0 and H and 0 and W
+ - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+ The model returns a Dict[Tensor] during training, containing the classification and regression
+ losses for both the RPN and the R-CNN.
+
+ During inference, the model requires only the input tensors, and returns the post-processed
+ predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+ follows:
+ - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between
+ 0 and H and 0 and W
+ - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores of each prediction
+
+ Arguments:
+ backbone (nn.Module): the network used to compute the features for the model.
+ It should contain a out_channels attribute, which indicates the number of output
+ channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or an OrderedDict[Tensor].
+ num_classes (int): number of output classes of the model (including the background).
+ If box_predictor is specified, num_classes should be None.
+ min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+ max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+ image_mean (Tuple[float, float, float]): mean values used for input normalization.
+ They are generally the mean values of the dataset on which the backbone has been trained
+ on
+ image_std (Tuple[float, float, float]): std values used for input normalization.
+ They are generally the std values of the dataset on which the backbone has been trained on
+ rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+ maps.
+ rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+ rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+ rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+ rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+ rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+ rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+ rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+ considered as positive during training of the RPN.
+ rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+ considered as negative during training of the RPN.
+ rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+ for computing the loss
+ rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+ of the RPN
+ rpn_score_thresh (float): during inference, only return proposals with a classification score
+ greater than rpn_score_thresh
+ box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+ the locations indicated by the bounding boxes
+ box_head (nn.Module): module that takes the cropped feature maps as input
+ box_predictor (nn.Module): module that takes the output of box_head and returns the
+ classification logits and box regression deltas.
+ box_score_thresh (float): during inference, only return proposals with a classification score
+ greater than box_score_thresh
+ box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+ box_detections_per_img (int): maximum number of detections per image, for all classes.
+ box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+ considered as positive during training of the classification head
+ box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+ considered as negative during training of the classification head
+ box_batch_size_per_image (int): number of proposals that are sampled during training of the
+ classification head
+ box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+ of the classification head
+ bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+ bounding boxes
+
+ """
+
+ def __init__(self, backbone, num_classes=None,
+ # transform parameters
+ min_size=800, max_size=1333, # min/max size the images are resized to during preprocessing
+ image_mean=None, image_std=None, # mean and std used to normalize the images during preprocessing
+ # RPN parameters
+ rpn_anchor_generator=None, rpn_head=None,
+ rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, # number of proposals kept before NMS in the RPN (ranked by score)
+ rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, # number of proposals kept after NMS in the RPN
+ rpn_nms_thresh=0.7, # IoU threshold used for NMS inside the RPN
+ rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, # IoU thresholds for sampling positive/negative anchors for the RPN loss
+ rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, # number of anchors sampled for the RPN loss and the fraction of positives among them
+ rpn_score_thresh=0.0,
+ # Box parameters
+ box_roi_pool=None, box_head=None, box_predictor=None,
+ # score threshold to drop low-scoring detections / NMS threshold of the Fast R-CNN head / keep the top 100 detections ranked by score
+ box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
+ box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, # IoU thresholds for sampling positive/negative proposals for the Fast R-CNN loss
+ box_batch_size_per_image=512, box_positive_fraction=0.25, # number of proposals sampled for the Fast R-CNN loss and the fraction of positives among them
+ bbox_reg_weights=None):
+ if not hasattr(backbone, "out_channels"):
+ raise ValueError(
+ "backbone should contain an attribute out_channels "
+ "specifying the number of output channels (assumed to be the "
+ "same for all the levels)"
+ )
+
+ # assert isinstance(rpn_anchor_generator, (AnchorsGenerator, type(None)))
+ assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None)))
+
+ if num_classes is not None:
+ if box_predictor is not None:
+ raise ValueError("num_classes should be None when box_predictor "
+ "is specified")
+ else:
+ if box_predictor is None:
+ raise ValueError("num_classes should not be None when box_predictor "
+ "is not specified")
+
+ # number of channels of the backbone's feature maps
+ out_channels = backbone.out_channels
+
+ # if no anchor generator is provided, build the default one for a resnet50_fpn backbone
+ if rpn_anchor_generator is None:
+ anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+ aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+ rpn_anchor_generator = AnchorsGenerator(
+ anchor_sizes, aspect_ratios
+ )
+
+ # build the RPN head that predicts objectness and bbox regression via a sliding window
+ if rpn_head is None:
+ rpn_head = RPNHead(
+ out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
+ )
+
+ # defaults: rpn_pre_nms_top_n_train = 2000, rpn_pre_nms_top_n_test = 1000,
+ # defaults: rpn_post_nms_top_n_train = 2000, rpn_post_nms_top_n_test = 1000,
+ rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
+ rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
+
+ # assemble the full RPN
+ rpn = RegionProposalNetwork(
+ rpn_anchor_generator, rpn_head,
+ rpn_fg_iou_thresh, rpn_bg_iou_thresh,
+ rpn_batch_size_per_image, rpn_positive_fraction,
+ rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh,
+ score_thresh=rpn_score_thresh)
+
+ # Multi-scale RoIAlign pooling
+ if box_roi_pool is None:
+ box_roi_pool = MultiScaleRoIAlign(
+ featmap_names=['0', '1', '2', '3'], # feature levels on which RoI pooling is performed
+ output_size=[7, 7],
+ sampling_ratio=2)
+
+ # the two fully connected layers applied to the flattened RoI-pooled features in Fast R-CNN
+ if box_head is None:
+ resolution = box_roi_pool.output_size[0] # defaults to 7
+ representation_size = 1024
+ box_head = TwoMLPHead(
+ out_channels * resolution ** 2,
+ representation_size
+ )
+
+ # prediction layers applied on top of box_head's output
+ if box_predictor is None:
+ representation_size = 1024
+ box_predictor = FastRCNNPredictor(
+ representation_size,
+ num_classes)
+
+ # combine roi pooling, box_head and box_predictor
+ roi_heads = RoIHeads(
+ # box
+ box_roi_pool, box_head, box_predictor,
+ box_fg_iou_thresh, box_bg_iou_thresh, # 0.5 0.5
+ box_batch_size_per_image, box_positive_fraction, # 512 0.25
+ bbox_reg_weights,
+ box_score_thresh, box_nms_thresh, box_detections_per_img) # 0.05 0.5 100
+
+ if image_mean is None:
+ image_mean = [0.485, 0.456, 0.406]
+ if image_std is None:
+ image_std = [0.229, 0.224, 0.225]
+
+ # normalizes, resizes and batches the input images
+ transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+
+ super(FasterRCNN, self).__init__(backbone, rpn, roi_heads, transform)
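For orientation, a minimal construction and inference sketch (illustrative only; it assumes torchvision's resnet_fpn_backbone helper, whose FPN backbone exposes the out_channels attribute checked above, and the older pretrained= keyword of that helper):

    import torch
    from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

    backbone = resnet_fpn_backbone('resnet50', pretrained=False)  # exposes backbone.out_channels
    model = FasterRCNN(backbone=backbone, num_classes=91)         # 90 classes + background, hypothetical

    model.eval()
    images = [torch.rand(3, 480, 640), torch.rand(3, 600, 800)]   # 0-1 range, varying sizes
    with torch.no_grad():
        predictions = model(images)                               # List[Dict] with 'boxes', 'labels', 'scores'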
diff --git a/pytorch_object_detection/mask_rcnn/network_files/image_list.py b/pytorch_object_detection/mask_rcnn/network_files/image_list.py
new file mode 100644
index 000000000..a1b36f334
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/network_files/image_list.py
@@ -0,0 +1,27 @@
+from typing import List, Tuple
+from torch import Tensor
+
+
+class ImageList(object):
+ """
+ Structure that holds a list of images (of possibly
+ varying sizes) as a single tensor.
+ This works by padding the images to the same size,
+ and storing in a field the original sizes of each image
+ """
+
+ def __init__(self, tensors, image_sizes):
+ # type: (Tensor, List[Tuple[int, int]]) -> None
+ """
+ Arguments:
+ tensors (tensor): the batched image data after padding
+ image_sizes (list[tuple[int, int]]): the image sizes before padding
+ """
+ self.tensors = tensors
+ self.image_sizes = image_sizes
+
+ def to(self, device):
+ # type: (Device) -> ImageList # noqa
+ cast_tensor = self.tensors.to(device)
+ return ImageList(cast_tensor, self.image_sizes)
+
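A tiny illustration of the two fields this structure carries (hypothetical tensors, just to show the shapes):

    import torch

    batched = torch.zeros(2, 3, 600, 800)            # two images padded to a common canvas
    original_sizes = [(480, 640), (600, 800)]        # (H, W) of each image before padding
    image_list = ImageList(batched, original_sizes)
    print(image_list.tensors.shape)                  # torch.Size([2, 3, 600, 800])
    print(image_list.image_sizes)                    # [(480, 640), (600, 800)]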
diff --git a/pytorch_object_detection/mask_rcnn/network_files/mask_rcnn.py b/pytorch_object_detection/mask_rcnn/network_files/mask_rcnn.py
new file mode 100644
index 000000000..97a8d7fe9
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/network_files/mask_rcnn.py
@@ -0,0 +1,239 @@
+from collections import OrderedDict
+import torch.nn as nn
+from torchvision.ops import MultiScaleRoIAlign
+
+from .faster_rcnn_framework import FasterRCNN
+
+
+class MaskRCNN(FasterRCNN):
+ """
+ Implements Mask R-CNN.
+
+ The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+ image, and should be in 0-1 range. Different images can have different sizes.
+
+ The behavior of the model changes depending on whether it is in training or evaluation mode.
+
+ During training, the model expects both the input tensors and the targets (a list of dictionaries),
+ containing:
+ - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+ - labels (Int64Tensor[N]): the class label for each ground-truth box
+ - masks (UInt8Tensor[N, H, W]): the segmentation binary masks for each instance
+
+ The model returns a Dict[Tensor] during training, containing the classification and regression
+ losses for both the RPN and the R-CNN, and the mask loss.
+
+ During inference, the model requires only the input tensors, and returns the post-processed
+ predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+ follows:
+ - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+ ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
+ - labels (Int64Tensor[N]): the predicted labels for each image
+ - scores (Tensor[N]): the scores of each prediction
+ - masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range. In order to
+ obtain the final segmentation masks, the soft masks can be thresholded, generally
+ with a value of 0.5 (mask >= 0.5)
+
+ Args:
+ backbone (nn.Module): the network used to compute the features for the model.
+ It should contain an out_channels attribute, which indicates the number of output
+ channels that each feature map has (and it should be the same for all feature maps).
+ The backbone should return a single Tensor or an OrderedDict[Tensor].
+ num_classes (int): number of output classes of the model (including the background).
+ If box_predictor is specified, num_classes should be None.
+ min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+ max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+ image_mean (Tuple[float, float, float]): mean values used for input normalization.
+ They are generally the mean values of the dataset on which the backbone has been
+ trained
+ image_std (Tuple[float, float, float]): std values used for input normalization.
+ They are generally the std values of the dataset on which the backbone has been trained
+ rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+ maps.
+ rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
+ rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
+ rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
+ rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
+ rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
+ rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+ rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+ considered as positive during training of the RPN.
+ rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+ considered as negative during training of the RPN.
+ rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+ for computing the loss
+ rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
+ of the RPN
+ rpn_score_thresh (float): during inference, only return proposals with a classification score
+ greater than rpn_score_thresh
+ box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+ the locations indicated by the bounding boxes
+ box_head (nn.Module): module that takes the cropped feature maps as input
+ box_predictor (nn.Module): module that takes the output of box_head and returns the
+ classification logits and box regression deltas.
+ box_score_thresh (float): during inference, only return proposals with a classification score
+ greater than box_score_thresh
+ box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
+ box_detections_per_img (int): maximum number of detections per image, for all classes.
+ box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
+ considered as positive during training of the classification head
+ box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
+ considered as negative during training of the classification head
+ box_batch_size_per_image (int): number of proposals that are sampled during training of the
+ classification head
+ box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
+ of the classification head
+ bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
+ bounding boxes
+ mask_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
+ the locations indicated by the bounding boxes, which will be used for the mask head.
+ mask_head (nn.Module): module that takes the cropped feature maps as input
+ mask_predictor (nn.Module): module that takes the output of the mask_head and returns the
+ segmentation mask logits
+
+ """
+
+ def __init__(
+ self,
+ backbone,
+ num_classes=None,
+ # transform parameters
+ min_size=800,
+ max_size=1333,
+ image_mean=None,
+ image_std=None,
+ # RPN parameters
+ rpn_anchor_generator=None,
+ rpn_head=None,
+ rpn_pre_nms_top_n_train=2000,
+ rpn_pre_nms_top_n_test=1000,
+ rpn_post_nms_top_n_train=2000,
+ rpn_post_nms_top_n_test=1000,
+ rpn_nms_thresh=0.7,
+ rpn_fg_iou_thresh=0.7,
+ rpn_bg_iou_thresh=0.3,
+ rpn_batch_size_per_image=256,
+ rpn_positive_fraction=0.5,
+ rpn_score_thresh=0.0,
+ # Box parameters
+ box_roi_pool=None,
+ box_head=None,
+ box_predictor=None,
+ box_score_thresh=0.05,
+ box_nms_thresh=0.5,
+ box_detections_per_img=100,
+ box_fg_iou_thresh=0.5,
+ box_bg_iou_thresh=0.5,
+ box_batch_size_per_image=512,
+ box_positive_fraction=0.25,
+ bbox_reg_weights=None,
+ # Mask parameters
+ mask_roi_pool=None,
+ mask_head=None,
+ mask_predictor=None,
+ ):
+
+ if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))):
+ raise TypeError(
+ f"mask_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(mask_roi_pool)}"
+ )
+
+ if num_classes is not None:
+ if mask_predictor is not None:
+ raise ValueError("num_classes should be None when mask_predictor is specified")
+
+ out_channels = backbone.out_channels
+
+ if mask_roi_pool is None:
+ mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
+
+ if mask_head is None:
+ mask_layers = (256, 256, 256, 256)
+ mask_dilation = 1
+ mask_head = MaskRCNNHeads(out_channels, mask_layers, mask_dilation)
+
+ if mask_predictor is None:
+ mask_predictor_in_channels = 256
+ mask_dim_reduced = 256
+ mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes)
+
+ super().__init__(
+ backbone,
+ num_classes,
+ # transform parameters
+ min_size,
+ max_size,
+ image_mean,
+ image_std,
+ # RPN-specific parameters
+ rpn_anchor_generator,
+ rpn_head,
+ rpn_pre_nms_top_n_train,
+ rpn_pre_nms_top_n_test,
+ rpn_post_nms_top_n_train,
+ rpn_post_nms_top_n_test,
+ rpn_nms_thresh,
+ rpn_fg_iou_thresh,
+ rpn_bg_iou_thresh,
+ rpn_batch_size_per_image,
+ rpn_positive_fraction,
+ rpn_score_thresh,
+ # Box parameters
+ box_roi_pool,
+ box_head,
+ box_predictor,
+ box_score_thresh,
+ box_nms_thresh,
+ box_detections_per_img,
+ box_fg_iou_thresh,
+ box_bg_iou_thresh,
+ box_batch_size_per_image,
+ box_positive_fraction,
+ bbox_reg_weights,
+ )
+
+ self.roi_heads.mask_roi_pool = mask_roi_pool
+ self.roi_heads.mask_head = mask_head
+ self.roi_heads.mask_predictor = mask_predictor
+
+
+class MaskRCNNHeads(nn.Sequential):
+ def __init__(self, in_channels, layers, dilation):
+ """
+ Args:
+ in_channels (int): number of input channels
+ layers (tuple): feature dimensions of each FCN layer
+ dilation (int): dilation rate of kernel
+ """
+ d = OrderedDict()
+ next_feature = in_channels
+
+ for layer_idx, layers_features in enumerate(layers, 1):
+ d[f"mask_fcn{layer_idx}"] = nn.Conv2d(next_feature,
+ layers_features,
+ kernel_size=3,
+ stride=1,
+ padding=dilation,
+ dilation=dilation)
+ d[f"relu{layer_idx}"] = nn.ReLU(inplace=True)
+ next_feature = layers_features
+
+ super().__init__(d)
+ # initial params
+ for name, param in self.named_parameters():
+ if "weight" in name:
+ nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
+
+
+class MaskRCNNPredictor(nn.Sequential):
+ def __init__(self, in_channels, dim_reduced, num_classes):
+ super().__init__(OrderedDict([
+ ("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)),
+ ("relu", nn.ReLU(inplace=True)),
+ ("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0))
+ ]))
+ # initial params
+ for name, param in self.named_parameters():
+ if "weight" in name:
+ nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
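A common way to adapt these heads to a custom dataset is to swap the predictors for a new number of classes. A hedged sketch (illustrative only; it reuses torchvision's resnet_fpn_backbone and the 1024/256 representation sizes used by the default box and mask heads in these files):

    from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

    num_classes = 3  # hypothetical: 2 foreground classes + background
    backbone = resnet_fpn_backbone('resnet50', pretrained=False)
    model = MaskRCNN(backbone, num_classes=num_classes)

    # equivalent after-the-fact replacement on an existing model:
    # model.roi_heads.box_predictor = FastRCNNPredictor(1024, num_classes)
    # model.roi_heads.mask_predictor = MaskRCNNPredictor(256, 256, num_classes)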
diff --git a/pytorch_object_detection/mask_rcnn/network_files/roi_head.py b/pytorch_object_detection/mask_rcnn/network_files/roi_head.py
new file mode 100644
index 000000000..7269f58da
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/network_files/roi_head.py
@@ -0,0 +1,560 @@
+from typing import Optional, List, Dict, Tuple
+
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+from torchvision.ops import roi_align
+
+from . import det_utils
+from . import boxes as box_ops
+
+
+def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
+ # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+ """
+ Computes the loss for Faster R-CNN.
+
+ Arguments:
+ class_logits : predicted class scores, shape=[num_anchors, num_classes]
+ box_regression : predicted bounding box regression parameters
+ labels : ground-truth class labels
+ regression_targets : ground-truth bounding box regression targets
+
+ Returns:
+ classification_loss (Tensor)
+ box_loss (Tensor)
+ """
+
+ labels = torch.cat(labels, dim=0)
+ regression_targets = torch.cat(regression_targets, dim=0)
+
+ # classification loss
+ classification_loss = F.cross_entropy(class_logits, labels)
+
+ # get indices that correspond to the regression targets for
+ # the corresponding ground truth labels, to be used with
+ # advanced indexing
+ # indices of the samples whose label is greater than 0 (foreground)
+ # sampled_pos_inds_subset = torch.nonzero(torch.gt(labels, 0)).squeeze(1)
+ sampled_pos_inds_subset = torch.where(torch.gt(labels, 0))[0]
+
+ # class labels of those foreground samples
+ labels_pos = labels[sampled_pos_inds_subset]
+
+ # shape=[num_proposal, num_classes]
+ N, num_classes = class_logits.shape
+ box_regression = box_regression.reshape(N, -1, 4)
+
+ # bounding box regression loss
+ box_loss = det_utils.smooth_l1_loss(
+ # select, for each foreground proposal, the box prediction of its ground-truth class
+ box_regression[sampled_pos_inds_subset, labels_pos],
+ regression_targets[sampled_pos_inds_subset],
+ beta=1 / 9,
+ size_average=False,
+ ) / labels.numel()
+
+ return classification_loss, box_loss
+
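The advanced indexing above keeps, for every positive proposal, only the four regression values of its ground-truth class; a small self-contained illustration (random numbers, not the repository's tensors):

    import torch

    N, num_classes = 5, 4
    box_regression = torch.randn(N, num_classes * 4).reshape(N, -1, 4)
    labels = torch.tensor([0, 2, 1, 0, 3])        # 0 = background
    pos = torch.where(labels > 0)[0]              # tensor([1, 2, 4])
    selected = box_regression[pos, labels[pos]]   # one [4]-vector per positive proposal
    print(selected.shape)                         # torch.Size([3, 4])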
+
+def maskrcnn_inference(x, labels):
+ # type: (Tensor, List[Tensor]) -> List[Tensor]
+ """
+ From the results of the CNN, post process the masks
+ by taking the mask corresponding to the class with max
+ probability (which are of fixed size and directly output
+ by the CNN) and return the masks in the mask field of the BoxList.
+
+ Args:
+ x (Tensor): the mask logits
+ labels (list[BoxList]): bounding boxes that are used as
+ reference, one for each image
+
+ Returns:
+ results (list[BoxList]): one BoxList for each image, containing
+ the extra field mask
+ """
+ # squash the logits into the 0~1 range with a sigmoid
+ mask_prob = x.sigmoid()
+
+ # select masks corresponding to the predicted classes
+ num_masks = x.shape[0]
+ # first record the number of boxes/masks per image
+ boxes_per_image = [label.shape[0] for label in labels]
+ # then concatenate the masks of all images (processing them together improves parallelism)
+ labels = torch.cat(labels)
+ index = torch.arange(num_masks, device=labels.device)
+ # keep, for each detection, the mask channel of its predicted class
+ mask_prob = mask_prob[index, labels][:, None]
+ # finally split the masks back per image according to the recorded counts
+ mask_prob = mask_prob.split(boxes_per_image, dim=0)
+
+ return mask_prob
+
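The same per-class gather is used for the masks; a toy check of the indexing and the final per-image split (random data, shapes only):

    import torch

    mask_prob = torch.rand(6, 4, 28, 28)              # 6 detections, 4 classes, 28x28 masks
    labels = torch.tensor([1, 3, 2, 1, 1, 2])         # predicted class of each detection
    index = torch.arange(6)
    per_class = mask_prob[index, labels][:, None]     # [6, 1, 28, 28]
    per_image = per_class.split([2, 4], dim=0)        # e.g. 2 detections in image 0, 4 in image 1
    print([m.shape for m in per_image])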
+
+def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
+ # type: (Tensor, Tensor, Tensor, int) -> Tensor
+ """
+ Given segmentation masks and the bounding boxes corresponding
+ to the location of the masks in the image, this function
+ crops and resizes the masks in the position defined by the
+ boxes. This prepares the masks for them to be fed to the
+ loss computation as the targets.
+ """
+ matched_idxs = matched_idxs.to(boxes)
+ rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
+ gt_masks = gt_masks[:, None].to(rois)
+ return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
+
+
+def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
+ # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
+ """
+
+ Args:
+ mask_logits:
+ proposals:
+ gt_masks:
+ gt_labels:
+ mask_matched_idxs:
+
+ Returns:
+ mask_loss (Tensor): scalar tensor containing the loss
+ """
+
+ # 28 (spatial size of the masks output by the FCN branch)
+ discretization_size = mask_logits.shape[-1]
+ # ground-truth class of each proposal (all proposals here are positive samples)
+ labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
+ # crop the region of gt_masks covered by each proposal to build the actual mask targets used in the loss
+ mask_targets = [
+ project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
+ ]
+
+ # concatenate the information of all proposals in the batch (processing them together improves parallelism)
+ labels = torch.cat(labels, dim=0)
+ mask_targets = torch.cat(mask_targets, dim=0)
+
+ # torch.mean (in binary_cross_entropy_with_logits) doesn't
+ # accept empty tensors, so handle it separately
+ if mask_targets.numel() == 0:
+ return mask_logits.sum() * 0
+
+ # binary cross-entropy loss between the predicted masks and the ground-truth masks
+ mask_loss = F.binary_cross_entropy_with_logits(
+ mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
+ )
+ return mask_loss
+
+
+class RoIHeads(torch.nn.Module):
+ __annotations__ = {
+ 'box_coder': det_utils.BoxCoder,
+ 'proposal_matcher': det_utils.Matcher,
+ 'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
+ }
+
+ def __init__(self,
+ box_roi_pool, # Multi-scale RoIAlign pooling
+ box_head, # TwoMLPHead
+ box_predictor, # FastRCNNPredictor
+ # Faster R-CNN training
+ fg_iou_thresh, bg_iou_thresh, # default: 0.5, 0.5
+ batch_size_per_image, positive_fraction, # default: 512, 0.25
+ bbox_reg_weights, # None
+ # Faster R-CNN inference
+ score_thresh, # default: 0.05
+ nms_thresh, # default: 0.5
+ detection_per_img, # default: 100
+ # Mask
+ mask_roi_pool=None,
+ mask_head=None,
+ mask_predictor=None,
+ ):
+ super(RoIHeads, self).__init__()
+
+ self.box_similarity = box_ops.box_iou
+ # assign ground-truth boxes for each proposal
+ self.proposal_matcher = det_utils.Matcher(
+ fg_iou_thresh, # default: 0.5
+ bg_iou_thresh, # default: 0.5
+ allow_low_quality_matches=False)
+
+ self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
+ batch_size_per_image, # default: 512
+ positive_fraction) # default: 0.25
+
+ if bbox_reg_weights is None:
+ bbox_reg_weights = (10., 10., 5., 5.)
+ self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
+
+ self.box_roi_pool = box_roi_pool # Multi-scale RoIAlign pooling
+ self.box_head = box_head # TwoMLPHead
+ self.box_predictor = box_predictor # FastRCNNPredictor
+
+ self.score_thresh = score_thresh # default: 0.05
+ self.nms_thresh = nms_thresh # default: 0.5
+ self.detection_per_img = detection_per_img # default: 100
+
+ self.mask_roi_pool = mask_roi_pool
+ self.mask_head = mask_head
+ self.mask_predictor = mask_predictor
+
+ def has_mask(self):
+ if self.mask_roi_pool is None:
+ return False
+ if self.mask_head is None:
+ return False
+ if self.mask_predictor is None:
+ return False
+ return True
+
+ def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
+ # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+ """
+ Match each proposal with a gt box and split the proposals into positive and negative samples
+ Args:
+ proposals:
+ gt_boxes:
+ gt_labels:
+
+ Returns:
+
+ """
+ matched_idxs = []
+ labels = []
+ # iterate over the proposals, gt_boxes and gt_labels of each image
+ for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
+ if gt_boxes_in_image.numel() == 0: # this image has no gt boxes, everything is background
+ # background image
+ device = proposals_in_image.device
+ clamped_matched_idxs_in_image = torch.zeros(
+ (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+ )
+ labels_in_image = torch.zeros(
+ (proposals_in_image.shape[0],), dtype=torch.int64, device=device
+ )
+ else:
+ # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+ # IoU between every proposal and every gt box
+ match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
+
+ # for each proposal, record the index of the gt box with the highest IoU;
+ # proposals with iou < low_threshold get index -1, those with low_threshold <= iou < high_threshold get index -2
+ matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)
+
+ # clamp to a minimum of 0 to avoid out-of-bounds indexing when gathering labels
+ # note: gt indices -1 and -2 are clamped to 0, so these proposals temporarily pick up the label of gt 0 (not their real label); this is corrected below
+ clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)
+ # labels of the gt boxes matched to the proposals
+ labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
+ labels_in_image = labels_in_image.to(dtype=torch.int64)
+
+ # label background (below the low threshold)
+ # set the label of proposals with gt index -1 to 0, i.e. background / negative samples
+ bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD # -1
+ labels_in_image[bg_inds] = 0
+
+ # label ignore proposals (between low and high threshold)
+ # set the label of proposals with gt index -2 to -1, i.e. discarded samples
+ ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS # -2
+ labels_in_image[ignore_inds] = -1 # -1 is ignored by sampler
+
+ matched_idxs.append(clamped_matched_idxs_in_image)
+ labels.append(labels_in_image)
+ return matched_idxs, labels
+
+ def subsample(self, labels):
+ # type: (List[Tensor]) -> List[Tensor]
+ # BalancedPositiveNegativeSampler
+ sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+ sampled_inds = []
+ # iterate over the positive/negative sample indices of each image
+ for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
+ # record the indices of all sampled proposals (both positives and negatives)
+ # img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
+ img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
+ sampled_inds.append(img_sampled_inds)
+ return sampled_inds
+
+ def add_gt_proposals(self, proposals, gt_boxes):
+ # type: (List[Tensor], List[Tensor]) -> List[Tensor]
+ """
+ Append the gt boxes to the proposals
+ Args:
+ proposals: boxes predicted by the RPN for each image in the batch
+ gt_boxes: ground-truth boxes for each image in the batch
+
+ Returns:
+
+ """
+ proposals = [
+ torch.cat((proposal, gt_box))
+ for proposal, gt_box in zip(proposals, gt_boxes)
+ ]
+ return proposals
+
+ def check_targets(self, targets):
+ # type: (Optional[List[Dict[str, Tensor]]]) -> None
+ assert targets is not None
+ assert all(["boxes" in t for t in targets])
+ assert all(["labels" in t for t in targets])
+
+ def select_training_samples(self,
+ proposals, # type: List[Tensor]
+ targets # type: Optional[List[Dict[str, Tensor]]]
+ ):
+ # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
+ """
+ Split the proposals into positive/negative samples and gather the matched gt labels and box regression targets.
+ Each returned list has batch_size elements.
+ Args:
+ proposals: boxes predicted by the RPN
+ targets:
+
+ Returns:
+
+ """
+
+ # check that the targets are not empty
+ self.check_targets(targets)
+ if targets is None:
+ raise ValueError("target should not be None.")
+
+ dtype = proposals[0].dtype
+ device = proposals[0].device
+
+ # gather the annotated boxes and labels
+ gt_boxes = [t["boxes"].to(dtype) for t in targets]
+ gt_labels = [t["labels"] for t in targets]
+
+ # append ground-truth bboxes to proposal
+ # append the gt boxes to the proposals
+ proposals = self.add_gt_proposals(proposals, gt_boxes)
+
+ # get matching gt indices for each proposal
+ # match every proposal with a gt box and split into positive/negative samples
+ matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
+ # sample a fixed proportion of positive-negative proposals
+ # sample positives/negatives with the configured batch size and positive fraction
+ sampled_inds = self.subsample(labels)
+ matched_gt_boxes = []
+ num_images = len(proposals)
+
+ # iterate over the images
+ for img_id in range(num_images):
+ # sampled indices (positives and negatives) of this image
+ img_sampled_inds = sampled_inds[img_id]
+ # keep only the sampled proposals
+ proposals[img_id] = proposals[img_id][img_sampled_inds]
+ # their ground-truth class labels
+ labels[img_id] = labels[img_id][img_sampled_inds]
+ # and the indices of their matched gt boxes
+ matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]
+
+ gt_boxes_in_image = gt_boxes[img_id]
+ if gt_boxes_in_image.numel() == 0:
+ gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
+ # gt box coordinates matched to the sampled proposals
+ matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
+
+ # compute the box regression targets of the proposals with respect to their matched gt boxes
+ regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
+ return proposals, matched_idxs, labels, regression_targets
+
+ def postprocess_detections(self,
+ class_logits, # type: Tensor
+ box_regression, # type: Tensor
+ proposals, # type: List[Tensor]
+ image_shapes # type: List[Tuple[int, int]]
+ ):
+ # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
+ """
+ Post-process the network predictions:
+ (1) compute the final bbox coordinates from the proposals and the predicted regression parameters
+ (2) apply softmax to the predicted class scores
+ (3) clip the predicted boxes to the image boundaries
+ (4) remove all background predictions
+ (5) remove low-scoring predictions
+ (6) remove small boxes
+ (7) run NMS and sort the results by score
+ (8) keep only the top-k highest scoring predictions
+ Args:
+ class_logits: predicted class scores
+ box_regression: predicted bounding box regression parameters
+ proposals: proposals output by the RPN
+ image_shapes: width and height of each image before batching
+
+ Returns:
+
+ """
+ device = class_logits.device
+ # number of predicted classes
+ num_classes = class_logits.shape[-1]
+
+ # number of predicted boxes per image
+ boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
+ # decode the final bbox coordinates from the proposals and the predicted regression parameters
+ pred_boxes = self.box_coder.decode(box_regression, proposals)
+
+ # apply softmax to the predicted class scores
+ pred_scores = F.softmax(class_logits, -1)
+
+ # split boxes and scores per image
+ # split the results back per image
+ pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+ pred_scores_list = pred_scores.split(boxes_per_image, 0)
+
+ all_boxes = []
+ all_scores = []
+ all_labels = []
+ # iterate over the predictions of each image
+ for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
+ # clip out-of-image box coordinates to the image boundaries
+ boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
+
+ # create labels for each prediction
+ labels = torch.arange(num_classes, device=device)
+ labels = labels.view(1, -1).expand_as(scores)
+
+ # remove prediction with the background label
+ # drop column 0, which corresponds to the background class
+ boxes = boxes[:, 1:]
+ scores = scores[:, 1:]
+ labels = labels[:, 1:]
+
+ # batch everything, by making every class prediction be a separate instance
+ boxes = boxes.reshape(-1, 4)
+ scores = scores.reshape(-1)
+ labels = labels.reshape(-1)
+
+ # remove low scoring boxes
+ # remove low-scoring predictions, self.score_thresh=0.05
+ # gt: Computes input > other element-wise.
+ # inds = torch.nonzero(torch.gt(scores, self.score_thresh)).squeeze(1)
+ inds = torch.where(torch.gt(scores, self.score_thresh))[0]
+ boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
+
+ # remove empty boxes
+ # remove small boxes
+ keep = box_ops.remove_small_boxes(boxes, min_size=1.)
+ boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+ # non-maximum suppression, independently done per class
+ # run NMS; the surviving indices are returned sorted by score in descending order
+ keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+
+ # keep only topk scoring predictions
+ # keep the top-k predictions ranked by score
+ keep = keep[:self.detection_per_img]
+ boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+ all_boxes.append(boxes)
+ all_scores.append(scores)
+ all_labels.append(labels)
+
+ return all_boxes, all_scores, all_labels
+
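The box_ops helpers used above mirror functions that also exist publicly in torchvision.ops; a hedged single-image sketch of steps (3) and (6)-(8) with those public functions (hypothetical boxes and thresholds):

    import torch
    from torchvision.ops import clip_boxes_to_image, remove_small_boxes, batched_nms

    boxes = torch.tensor([[-5., 10., 120., 200.], [30., 30., 90., 95.], [40., 50., 90., 160.]])
    scores = torch.tensor([0.90, 0.80, 0.75])
    labels = torch.tensor([1, 1, 2])

    boxes = clip_boxes_to_image(boxes, (180, 150))   # (3) clip to an H=180, W=150 image
    keep = remove_small_boxes(boxes, min_size=1.)    # (6) drop degenerate boxes
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
    keep = batched_nms(boxes, scores, labels, 0.5)   # (7) per-class NMS, sorted by score
    keep = keep[:100]                                # (8) keep the top-k
    boxes, scores, labels = boxes[keep], scores[keep], labels[keep]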
+ def forward(self,
+ features, # type: Dict[str, Tensor]
+ proposals, # type: List[Tensor]
+ image_shapes, # type: List[Tuple[int, int]]
+ targets=None # type: Optional[List[Dict[str, Tensor]]]
+ ):
+ # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
+ """
+ Arguments:
+ features (List[Tensor])
+ proposals (List[Tensor[N, 4]])
+ image_shapes (List[Tuple[H, W]])
+ targets (List[Dict])
+ """
+
+ # check that the targets have the expected dtypes
+ if targets is not None:
+ for t in targets:
+ floating_point_types = (torch.float, torch.double, torch.half)
+ assert t["boxes"].dtype in floating_point_types, "target boxes must be of float type"
+ assert t["labels"].dtype == torch.int64, "target labels must be of int64 type"
+
+ if self.training:
+ # split the proposals into positive/negative samples and gather the matched gt labels and regression targets
+ proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
+ else:
+ labels = None
+ regression_targets = None
+ matched_idxs = None
+
+ # pass the sampled proposals through the multi-scale RoIAlign pooling layer
+ # box_features_shape: [num_proposals, channel, height, width]
+ box_features = self.box_roi_pool(features, proposals, image_shapes)
+
+ # two fully connected layers applied after roi_pooling
+ # box_features_shape: [num_proposals, representation_size]
+ box_features = self.box_head(box_features)
+
+ # then predict the class scores and the bbox regression parameters
+ class_logits, box_regression = self.box_predictor(box_features)
+
+ result: List[Dict[str, torch.Tensor]] = []
+ losses = {}
+ if self.training:
+ assert labels is not None and regression_targets is not None
+ loss_classifier, loss_box_reg = fastrcnn_loss(
+ class_logits, box_regression, labels, regression_targets)
+ losses = {
+ "loss_classifier": loss_classifier,
+ "loss_box_reg": loss_box_reg
+ }
+ else:
+ boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
+ num_images = len(boxes)
+ for i in range(num_images):
+ result.append(
+ {
+ "boxes": boxes[i],
+ "labels": labels[i],
+ "scores": scores[i],
+ }
+ )
+
+ if self.has_mask():
+ mask_proposals = [p["boxes"] for p in result] # take the final predicted boxes
+ if self.training:
+ # matched_idxs holds the gt index matched to each proposal during sampling (background proposals default to gt index 0)
+ if matched_idxs is None:
+ raise ValueError("if in training, matched_idxs should not be None")
+
+ # during training, only focus on positive boxes
+ num_images = len(proposals)
+ mask_proposals = []
+ pos_matched_idxs = []
+ for img_id in range(num_images):
+ pos = torch.where(labels[img_id] > 0)[0] # proposals whose gt class is > 0, i.e. positive samples
+ mask_proposals.append(proposals[img_id][pos])
+ pos_matched_idxs.append(matched_idxs[img_id][pos])
+ else:
+ pos_matched_idxs = None
+
+ mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes)
+ mask_features = self.mask_head(mask_features)
+ mask_logits = self.mask_predictor(mask_features)
+
+ loss_mask = {}
+ if self.training:
+ if targets is None or pos_matched_idxs is None or mask_logits is None:
+ raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")
+
+ gt_masks = [t["masks"] for t in targets]
+ gt_labels = [t["labels"] for t in targets]
+ rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
+ loss_mask = {"loss_mask": rcnn_loss_mask}
+ else:
+ labels = [r["labels"] for r in result]
+ mask_probs = maskrcnn_inference(mask_logits, labels)
+ for mask_prob, r in zip(mask_probs, result):
+ r["masks"] = mask_prob
+
+ losses.update(loss_mask)
+
+ return result, losses
diff --git a/pytorch_object_detection/mask_rcnn/network_files/rpn_function.py b/pytorch_object_detection/mask_rcnn/network_files/rpn_function.py
new file mode 100644
index 000000000..b18689884
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/network_files/rpn_function.py
@@ -0,0 +1,643 @@
+from typing import List, Optional, Dict, Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+import torchvision
+
+from . import det_utils
+from . import boxes as box_ops
+from .image_list import ImageList
+
+
+@torch.jit.unused
+def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
+ # type: (Tensor, int) -> Tuple[int, int]
+ from torch.onnx import operators
+ num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
+ pre_nms_top_n = torch.min(torch.cat(
+ (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype),
+ num_anchors), 0))
+
+ return num_anchors, pre_nms_top_n
+
+
+class AnchorsGenerator(nn.Module):
+ __annotations__ = {
+ "cell_anchors": Optional[List[torch.Tensor]],
+ "_cache": Dict[str, List[torch.Tensor]]
+ }
+
+ """
+ Anchor generator.
+ Module that generates anchors for a set of feature maps and
+ image sizes.
+
+ The module support computing anchors at multiple sizes and aspect ratios
+ per feature map.
+
+ sizes and aspect_ratios should have the same number of elements, and it should
+ correspond to the number of feature maps.
+
+ sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
+ and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
+ per spatial location for feature map i.
+
+ Arguments:
+ sizes (Tuple[Tuple[int]]):
+ aspect_ratios (Tuple[Tuple[float]]):
+ """
+
+ def __init__(self, sizes=(128, 256, 512), aspect_ratios=(0.5, 1.0, 2.0)):
+ super(AnchorsGenerator, self).__init__()
+
+ if not isinstance(sizes[0], (list, tuple)):
+ # TODO change this
+ sizes = tuple((s,) for s in sizes)
+ if not isinstance(aspect_ratios[0], (list, tuple)):
+ aspect_ratios = (aspect_ratios,) * len(sizes)
+
+ assert len(sizes) == len(aspect_ratios)
+
+ self.sizes = sizes
+ self.aspect_ratios = aspect_ratios
+ self.cell_anchors = None
+ self._cache = {}
+
+ def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device=torch.device("cpu")):
+ # type: (List[int], List[float], torch.dtype, torch.device) -> Tensor
+ """
+ compute anchor sizes
+ Arguments:
+ scales: sqrt(anchor_area)
+ aspect_ratios: h/w ratios
+ dtype: float32
+ device: cpu/gpu
+ """
+ scales = torch.as_tensor(scales, dtype=dtype, device=device)
+ aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
+ h_ratios = torch.sqrt(aspect_ratios)
+ w_ratios = 1.0 / h_ratios
+
+ # [r1, r2, r3]' * [s1, s2, s3]
+ # number of elements is len(ratios)*len(scales)
+ ws = (w_ratios[:, None] * scales[None, :]).view(-1)
+ hs = (h_ratios[:, None] * scales[None, :]).view(-1)
+
+ # left-top, right-bottom coordinate relative to anchor center(0, 0)
+ # the anchor templates are all centered at (0, 0), shape [len(ratios)*len(scales), 4]
+ base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
+
+ return base_anchors.round() # round to the nearest integer
+
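A quick numeric check of the templates this method produces, e.g. for a single size of 32 with the default aspect ratios (values worked out by hand, rounding included):

    gen = AnchorsGenerator(sizes=((32,),), aspect_ratios=((0.5, 1.0, 2.0),))
    templates = gen.generate_anchors(scales=[32], aspect_ratios=[0.5, 1.0, 2.0])
    print(templates)
    # tensor([[-23., -11.,  23.,  11.],     # ratio 0.5: wide and short
    #         [-16., -16.,  16.,  16.],     # ratio 1.0: square
    #         [-11., -23.,  11.,  23.]])    # ratio 2.0: tall and narrow

All three templates cover roughly 32*32 pixels; only the aspect ratio changes.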
+ def set_cell_anchors(self, dtype, device):
+ # type: (torch.dtype, torch.device) -> None
+ if self.cell_anchors is not None:
+ cell_anchors = self.cell_anchors
+ assert cell_anchors is not None
+ # suppose that all anchors have the same device
+ # which is a valid assumption in the current state of the codebase
+ if cell_anchors[0].device == device:
+ return
+
+ # generate the anchor templates from the given sizes and aspect_ratios
+ # all templates are centered at (0, 0)
+ cell_anchors = [
+ self.generate_anchors(sizes, aspect_ratios, dtype, device)
+ for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
+ ]
+ self.cell_anchors = cell_anchors
+
+ def num_anchors_per_location(self):
+ # number of anchors predicted at each sliding-window position of every feature level
+ return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
+
+ # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
+ # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
+ def grid_anchors(self, grid_sizes, strides):
+ # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
+ """
+ Map the anchor positions from the feature-map grid back onto the original image,
+ i.e. compute the coordinates of all anchors on the original image.
+ Args:
+ grid_sizes: height and width of the prediction feature maps
+ strides: stride on the original image corresponding to one step on the feature map
+ """
+ anchors = []
+ cell_anchors = self.cell_anchors
+ assert cell_anchors is not None
+
+ # iterate over the grid size, stride and anchor templates of every feature level
+ for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
+ grid_height, grid_width = size
+ stride_height, stride_width = stride
+ device = base_anchors.device
+
+ # For output anchor, compute [x_center, y_center, x_center, y_center]
+ # shape: [grid_width], x coordinates (columns) on the original image
+ shifts_x = torch.arange(0, grid_width, dtype=torch.float32, device=device) * stride_width
+ # shape: [grid_height], y coordinates (rows) on the original image
+ shifts_y = torch.arange(0, grid_height, dtype=torch.float32, device=device) * stride_height
+
+ # compute, for every point of the feature map, its coordinates on the original image
+ # (these are the offsets applied to the anchor templates); torch.meshgrid takes the row and
+ # column coordinates and returns the row/column coordinate grids, shape: [grid_height, grid_width]
+ shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+ shift_x = shift_x.reshape(-1)
+ shift_y = shift_y.reshape(-1)
+
+ # offsets of the anchor coordinates (xmin, ymin, xmax, ymax) on the original image
+ # shape: [grid_width*grid_height, 4]
+ shifts = torch.stack([shift_x, shift_y, shift_x, shift_y], dim=1)
+
+ # For every (base anchor, output anchor) pair,
+ # offset each zero-centered base anchor by the center of the output anchor.
+ # add the offsets to the anchor templates to get the coordinates of all anchors on the original image (broadcasting handles the different shapes)
+ shifts_anchor = shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)
+ anchors.append(shifts_anchor.reshape(-1, 4))
+
+ return anchors # List[Tensor(all_num_anchors, 4)]
+
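The broadcasting in the last step pairs every grid offset with every template; a tiny illustration with a 2x2 grid, a stride of 16 and a single hypothetical template:

    import torch

    base_anchors = torch.tensor([[-16., -16., 16., 16.]])        # one template
    shifts_x = torch.arange(0, 2, dtype=torch.float32) * 16      # [0., 16.]
    shifts_y = torch.arange(0, 2, dtype=torch.float32) * 16
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
    sx, sy = shift_x.reshape(-1), shift_y.reshape(-1)
    shifts = torch.stack([sx, sy, sx, sy], dim=1)                # [4, 4] offsets
    anchors = (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
    print(anchors.shape)                                         # torch.Size([4, 4]): 2*2 positions x 1 template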
+ def cached_grid_anchors(self, grid_sizes, strides):
+ # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
+ """Cache all computed anchors."""
+ key = str(grid_sizes) + str(strides)
+ # self._cache is a dict
+ if key in self._cache:
+ return self._cache[key]
+ anchors = self.grid_anchors(grid_sizes, strides)
+ self._cache[key] = anchors
+ return anchors
+
+ def forward(self, image_list, feature_maps):
+ # type: (ImageList, List[Tensor]) -> List[Tensor]
+ # (height, width) of every prediction feature map
+ grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
+
+ # height and width of the (padded) input images
+ image_size = image_list.tensors.shape[-2:]
+
+ # dtype and device of the feature maps
+ dtype, device = feature_maps[0].dtype, feature_maps[0].device
+
+ # one step in feature map equate n pixel stride in origin image
+ # one step on the feature map corresponds to this many pixels on the original image
+ strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
+ torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
+
+ # generate the anchor templates from the given sizes and aspect_ratios
+ self.set_cell_anchors(dtype, device)
+
+ # compute/read the coordinates of all anchors (anchors mapped onto the original image, not the templates)
+ # the result is a list with one entry per feature level, holding that level's anchors mapped back to the original image
+ anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
+
+ anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
+ # iterate over the images in the batch
+ for i, (image_height, image_width) in enumerate(image_list.image_sizes):
+ anchors_in_image = []
+ # iterate over the anchors of every feature level
+ for anchors_per_feature_map in anchors_over_all_feature_maps:
+ anchors_in_image.append(anchors_per_feature_map)
+ anchors.append(anchors_in_image)
+ # concatenate the anchors of all feature levels for each image
+ # anchors is a list with one element per image holding all of its anchors
+ anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
+ # Clear the cache in case that memory leaks.
+ self._cache.clear()
+ return anchors
+
+
+class RPNHead(nn.Module):
+ """
+ add a RPN head with classification and regression
+ Computes objectness scores and bbox regression parameters via a sliding window.
+
+ Arguments:
+ in_channels: number of channels of the input feature
+ num_anchors: number of anchors to be predicted
+ """
+
+ def __init__(self, in_channels, num_anchors):
+ super(RPNHead, self).__init__()
+ # 3x3 sliding window
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+ # predicted objectness scores (here "object" only distinguishes foreground from background)
+ self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
+ # predicted bbox regression parameters
+ self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)
+
+ for layer in self.children():
+ if isinstance(layer, nn.Conv2d):
+ torch.nn.init.normal_(layer.weight, std=0.01)
+ torch.nn.init.constant_(layer.bias, 0)
+
+ def forward(self, x):
+ # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
+ logits = []
+ bbox_reg = []
+ for i, feature in enumerate(x):
+ t = F.relu(self.conv(feature))
+ logits.append(self.cls_logits(t))
+ bbox_reg.append(self.bbox_pred(t))
+ return logits, bbox_reg
+
+
+def permute_and_flatten(layer, N, A, C, H, W):
+ # type: (Tensor, int, int, int, int, int) -> Tensor
+ """
+ Permute the tensor dimensions and reshape.
+ Args:
+ layer: objectness scores or bbox regression parameters predicted on one feature level
+ N: batch_size
+ A: anchors_num_per_position
+ C: classes_num or 4 (bbox coordinates)
+ H: height
+ W: width
+
+ Returns:
+ layer: the permuted and reshaped result, [N, -1, C]
+ """
+ # view and reshape do the same thing here: flatten all elements, then lay them out in the given shape
+ # view only works on tensors that are stored contiguously in memory; operations such as permute make a
+ # tensor non-contiguous, after which view can no longer be called; reshape has no such requirement
+ # [batch_size, anchors_num_per_position * (C or 4), height, width]
+ layer = layer.view(N, -1, C, H, W)
+ # permute the tensor dimensions
+ layer = layer.permute(0, 3, 4, 1, 2) # [N, H, W, -1, C]
+ layer = layer.reshape(N, -1, C)
+ return layer
+
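Shape-wise, for one FPN level with batch size 2, 3 anchors per location, one objectness channel and a 25x25 feature map, a quick check (random data):

    import torch

    N, A, C, H, W = 2, 3, 1, 25, 25
    layer = torch.randn(N, A * C, H, W)             # raw cls_logits of one level
    out = permute_and_flatten(layer, N, A, C, H, W)
    print(out.shape)                                # torch.Size([2, 1875, 1]) = [N, H*W*A, C]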
+
+def concat_box_prediction_layers(box_cls, box_regression):
+ # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+ """
+ Adjust the tensor layout and shape of the per-level predictions in box_cls
+ and box_regression -> [N, -1, C]
+ Args:
+ box_cls: predicted objectness scores of every feature level
+ box_regression: predicted bbox regression parameters of every feature level
+
+ Returns:
+
+ """
+ box_cls_flattened = []
+ box_regression_flattened = []
+
+ # iterate over every prediction feature level
+ for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
+ # [batch_size, anchors_num_per_position * classes_num, height, width]
+ # note: when computing RPN proposals, classes_num=1, i.e. only object vs. background is distinguished
+ N, AxC, H, W = box_cls_per_level.shape
+ # [batch_size, anchors_num_per_position * 4, height, width]
+ Ax4 = box_regression_per_level.shape[1]
+ # anchors_num_per_position
+ A = Ax4 // 4
+ # classes_num
+ C = AxC // A
+
+ # [N, -1, C]
+ box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W)
+ box_cls_flattened.append(box_cls_per_level)
+
+ # [N, -1, C]
+ box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
+ box_regression_flattened.append(box_regression_per_level)
+
+ box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2) # start_dim, end_dim
+ box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
+ return box_cls, box_regression
+
+
+class RegionProposalNetwork(torch.nn.Module):
+ """
+ Implements Region Proposal Network (RPN).
+
+ Arguments:
+ anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+ maps.
+ head (nn.Module): module that computes the objectness and regression deltas
+ fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+ considered as positive during training of the RPN.
+ bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+ considered as negative during training of the RPN.
+ batch_size_per_image (int): number of anchors that are sampled during training of the RPN
+ for computing the loss
+ positive_fraction (float): proportion of positive anchors in a mini-batch during training
+ of the RPN
+ pre_nms_top_n (Dict[str]): number of proposals to keep before applying NMS. It should
+ contain two fields: training and testing, to allow for different values depending
+ on training or evaluation
+ post_nms_top_n (Dict[str]): number of proposals to keep after applying NMS. It should
+ contain two fields: training and testing, to allow for different values depending
+ on training or evaluation
+ nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
+
+ """
+ __annotations__ = {
+ 'box_coder': det_utils.BoxCoder,
+ 'proposal_matcher': det_utils.Matcher,
+ 'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
+ 'pre_nms_top_n': Dict[str, int],
+ 'post_nms_top_n': Dict[str, int],
+ }
+
+ def __init__(self, anchor_generator, head,
+ fg_iou_thresh, bg_iou_thresh,
+ batch_size_per_image, positive_fraction,
+ pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0):
+ super(RegionProposalNetwork, self).__init__()
+ self.anchor_generator = anchor_generator
+ self.head = head
+ self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
+
+ # use during training
+ # IoU between anchors and ground-truth boxes
+ self.box_similarity = box_ops.box_iou
+
+ self.proposal_matcher = det_utils.Matcher(
+ fg_iou_thresh, # anchors with IoU above fg_iou_thresh (0.7) are treated as positive samples
+ bg_iou_thresh, # anchors with IoU below bg_iou_thresh (0.3) are treated as negative samples
+ allow_low_quality_matches=True
+ )
+
+ self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
+ batch_size_per_image, positive_fraction # 256, 0.5
+ )
+
+ # use during testing
+ self._pre_nms_top_n = pre_nms_top_n
+ self._post_nms_top_n = post_nms_top_n
+ self.nms_thresh = nms_thresh
+ self.score_thresh = score_thresh
+ self.min_size = 1.
+
+ def pre_nms_top_n(self):
+ if self.training:
+ return self._pre_nms_top_n['training']
+ return self._pre_nms_top_n['testing']
+
+ def post_nms_top_n(self):
+ if self.training:
+ return self._post_nms_top_n['training']
+ return self._post_nms_top_n['testing']
+
+ def assign_targets_to_anchors(self, anchors, targets):
+ # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
+ """
+ For every anchor, find its best matching gt box and classify the anchor as positive, background or discarded.
+ Args:
+ anchors: (List[Tensor])
+ targets: (List[Dict[Tensor])
+ Returns:
+ labels: anchor labels (1, 0, -1 for positive, background and discarded samples respectively)
+ note: the RPN only distinguishes foreground from background, so every positive sample has label 1 and 0 means background
+ matched_gt_boxes: the gt boxes matched to the anchors
+ """
+ labels = []
+ matched_gt_boxes = []
+ # iterate over the anchors and targets of each image
+ for anchors_per_image, targets_per_image in zip(anchors, targets):
+ gt_boxes = targets_per_image["boxes"]
+ if gt_boxes.numel() == 0:
+ device = anchors_per_image.device
+ matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
+ labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
+ else:
+ # IoU between the anchors and the ground-truth boxes
+ # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
+ match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
+ # for each anchor, record the index of the gt box with the highest IoU
+ # (anchors with iou < 0.3 get index -1, anchors with 0.3 <= iou < 0.7 get index -2)
+ matched_idxs = self.proposal_matcher(match_quality_matrix)
+
+ # get the targets corresponding GT for each anchor
+ # NB: need to clamp the indices because we can have a single
+ # GT in the image, and matched_idxs can be -2, which goes
+ # out of bounds
+ matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]
+
+ # record which anchors are positive samples (matched_idxs >= 0); negatives and discarded anchors are handled below
+ labels_per_image = matched_idxs >= 0
+ labels_per_image = labels_per_image.to(dtype=torch.float32)
+
+ # background (negative examples)
+ bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD # -1
+ labels_per_image[bg_indices] = 0.0
+
+ # discard indices that are between thresholds
+ inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS # -2
+ labels_per_image[inds_to_discard] = -1.0
+
+ labels.append(labels_per_image)
+ matched_gt_boxes.append(matched_gt_boxes_per_image)
+ return labels, matched_gt_boxes
+
+ def _get_top_n_idx(self, objectness, num_anchors_per_level):
+ # type: (Tensor, List[int]) -> Tensor
+ """
+ Get the indices of the top pre_nms_top_n anchors on every feature level, ranked by predicted objectness.
+ Args:
+ objectness: Tensor (predicted objectness scores of every image)
+ num_anchors_per_level: List (number of anchors predicted on every feature level)
+ Returns:
+
+ """
+ r = [] # indices of the pre_nms_top_n highest-scoring anchors of every feature level
+ offset = 0
+ # iterate over the objectness predictions of every feature level
+ for ob in objectness.split(num_anchors_per_level, 1):
+ if torchvision._is_tracing():
+ num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
+ else:
+ num_anchors = ob.shape[1] # number of anchors predicted on this feature level
+ pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
+
+ # Returns the k largest elements of the given input tensor along a given dimension
+ _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
+ r.append(top_n_idx + offset)
+ offset += num_anchors
+ return torch.cat(r, dim=1)
+
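The offset bookkeeping keeps the per-level top-k indices valid in the concatenated tensor; a toy run with two levels holding 4 and 3 anchors (random scores, pretending pre_nms_top_n() returns 2):

    import torch

    objectness = torch.randn(1, 7)                  # batch of 1, 4 + 3 anchors
    r, offset = [], 0
    for ob in objectness.split([4, 3], 1):
        k = min(2, ob.shape[1])
        _, top_idx = ob.topk(k, dim=1)
        r.append(top_idx + offset)                  # shift into the global index space
        offset += ob.shape[1]
    print(torch.cat(r, dim=1))                      # two indices in [0, 3], then two in [4, 6]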
+ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
+ # type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
+ """
+ Remove small boxes, run NMS and keep the top post_nms_top_n proposals ranked by predicted score.
+ Args:
+ proposals: predicted bbox coordinates
+ objectness: predicted objectness scores
+ image_shapes: size of every image in the batch
+ num_anchors_per_level: number of anchors predicted on every feature level
+
+ Returns:
+
+ """
+ num_images = proposals.shape[0]
+ device = proposals.device
+
+ # do not backprop through objectness
+ objectness = objectness.detach()
+ objectness = objectness.reshape(num_images, -1)
+
+ # torch.full returns a tensor of the given size filled with fill_value
+ # levels records which feature level each anchor index belongs to
+ levels = [torch.full((n, ), idx, dtype=torch.int64, device=device)
+ for idx, n in enumerate(num_anchors_per_level)]
+ levels = torch.cat(levels, 0)
+
+ # Expand this tensor to the same size as objectness
+ levels = levels.reshape(1, -1).expand_as(objectness)
+
+ # select top_n boxes independently per level before applying nms
+ # indices of the top pre_nms_top_n anchors of every feature level, ranked by predicted score
+ top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)
+
+ image_range = torch.arange(num_images, device=device)
+ batch_idx = image_range[:, None] # [batch_size, 1]
+
+ # gather the objectness scores and levels of those top anchors
+ objectness = objectness[batch_idx, top_n_idx]
+ levels = levels[batch_idx, top_n_idx]
+ # and the corresponding bbox coordinates
+ proposals = proposals[batch_idx, top_n_idx]
+
+ objectness_prob = torch.sigmoid(objectness)
+
+ final_boxes = []
+ final_scores = []
+ # iterate over the predictions of each image
+ for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
+ # clip out-of-image box coordinates to the image boundaries
+ boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
+
+ # indices of boxes whose width and height are both at least min_size
+ keep = box_ops.remove_small_boxes(boxes, self.min_size)
+ boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
+
+ # remove low-scoring boxes, see the link below
+ # https://github.com/pytorch/vision/pull/3205
+ keep = torch.where(torch.ge(scores, self.score_thresh))[0] # ge: >=
+ boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
+
+ # non-maximum suppression, independently done per level
+ keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
+
+ # keep only topk scoring predictions
+ keep = keep[: self.post_nms_top_n()]
+ boxes, scores = boxes[keep], scores[keep]
+
+ final_boxes.append(boxes)
+ final_scores.append(scores)
+ return final_boxes, final_scores
+
+ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
+ # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
+ """
+ Compute the RPN losses: the objectness loss (foreground vs. background) and the bbox regression loss.
+ Arguments:
+ objectness (Tensor): predicted objectness scores
+ pred_bbox_deltas (Tensor): predicted bbox regression parameters
+ labels (List[Tensor]): ground-truth labels 1, 0, -1 (one list element per image in the batch)
+ regression_targets (List[Tensor]): ground-truth bbox regression targets
+
+ Returns:
+ objectness_loss (Tensor): objectness (classification) loss
+ box_loss (Tensor): bounding box regression loss
+ """
+ # sample positive/negative anchors according to batch_size_per_image and positive_fraction
+ sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
+ # concatenate the per-image positive/negative sample masks and take the indices of the non-zero entries
+ # sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
+ sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
+ # sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)
+ sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]
+
+ # concatenate all positive and negative sample indices
+ sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)
+ objectness = objectness.flatten()
+
+ labels = torch.cat(labels, dim=0)
+ regression_targets = torch.cat(regression_targets, dim=0)
+
+ # bounding box regression loss
+ box_loss = det_utils.smooth_l1_loss(
+ pred_bbox_deltas[sampled_pos_inds],
+ regression_targets[sampled_pos_inds],
+ beta=1 / 9,
+ size_average=False,
+ ) / (sampled_inds.numel())
+
+ # objectness (classification) loss
+ objectness_loss = F.binary_cross_entropy_with_logits(
+ objectness[sampled_inds], labels[sampled_inds]
+ )
+
+ return objectness_loss, box_loss
+
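Note that the box loss is summed over the positive samples but normalized by the total number of sampled anchors (positives plus negatives). A tiny numeric sketch of that convention using torch's built-in smooth L1, which for the same beta should match det_utils.smooth_l1_loss up to the reduction (an assumption, since det_utils is defined elsewhere in this repo):

    import torch
    import torch.nn.functional as F

    pred = torch.tensor([[0.20, 0.10, 0.00, 0.30]])   # deltas of one positive anchor
    target = torch.zeros(1, 4)
    num_sampled = 256                                  # positives + negatives per image
    box_loss = F.smooth_l1_loss(pred, target, beta=1 / 9, reduction="sum") / num_sampled
    print(box_loss)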
+ def forward(self,
+ images, # type: ImageList
+ features, # type: Dict[str, Tensor]
+ targets=None # type: Optional[List[Dict[str, Tensor]]]
+ ):
+ # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
+ """
+ Arguments:
+ images (ImageList): images for which we want to compute the predictions
+ features (Dict[Tensor]): features computed from the images that are
+ used for computing the predictions. Each tensor in the list
+ correspond to different feature levels
+ targets (List[Dict[Tensor]): ground-truth boxes present in the image (optional).
+ If provided, each element in the dict should contain a field `boxes`,
+ with the locations of the ground-truth boxes.
+
+ Returns:
+ boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
+ image.
+ losses (Dict[Tensor]): the losses for the model during training. During
+ testing, it is an empty dict.
+ """
+ # RPN uses all feature maps that are available
+ # features is an OrderedDict of all prediction feature levels
+ features = list(features.values())
+
+ # predict objectness scores and bbox regression parameters on every feature level
+ # objectness and pred_bbox_deltas are both lists
+ objectness, pred_bbox_deltas = self.head(features)
+
+ # generate the anchors of all images in the batch; the list has batch_size elements
+ anchors = self.anchor_generator(images, features)
+
+ # batch_size
+ num_images = len(anchors)
+
+ # numel() Returns the total number of elements in the input tensor.
+ # number of anchors on every prediction feature level
+ num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
+ num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
+
+ # adjust the internal tensor layout and shape
+ objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness,
+ pred_bbox_deltas)
+
+ # apply pred_bbox_deltas to anchors to obtain the decoded proposals
+ # note that we detach the deltas because Faster R-CNN does not backprop through
+ # the proposals
+ # apply the predicted bbox regression parameters to the anchors to obtain the predicted bbox coordinates
+ proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
+ proposals = proposals.view(num_images, -1, 4)
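+ # proposals now has shape [num_images, anchors_per_image, 4]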
+
+ # remove small boxes, apply NMS, and keep the post_nms_top_n proposals with the highest predicted scores
+ boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
+
+ losses = {}
+ if self.training:
+ assert targets is not None
+ # find the best-matching gt for every anchor and classify anchors as foreground, background or discarded
+ labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
+ # compute the regression targets from the anchors and their matched gt boxes
+ regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
+ loss_objectness, loss_rpn_box_reg = self.compute_loss(
+ objectness, pred_bbox_deltas, labels, regression_targets
+ )
+ losses = {
+ "loss_objectness": loss_objectness,
+ "loss_rpn_box_reg": loss_rpn_box_reg
+ }
+ return boxes, losses
diff --git a/pytorch_object_detection/mask_rcnn/network_files/transform.py b/pytorch_object_detection/mask_rcnn/network_files/transform.py
new file mode 100644
index 000000000..420d8ed0e
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/network_files/transform.py
@@ -0,0 +1,490 @@
+import math
+from typing import List, Tuple, Dict, Optional
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+import torchvision
+
+from .image_list import ImageList
+
+
+def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
+ one = torch.ones(1, dtype=torch.int64)
+ zero = torch.zeros(1, dtype=torch.int64)
+
+ w = box[2] - box[0] + one
+ h = box[3] - box[1] + one
+ w = torch.max(torch.cat((w, one)))
+ h = torch.max(torch.cat((h, one)))
+
+ # Set shape to [batchxCxHxW]
+ mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
+
+ # Resize mask
+ mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
+ mask = mask[0][0]
+
+ x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
+ x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0))))
+ y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
+ y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
+
+ unpaded_im_mask = mask[(y_0 - box[1]): (y_1 - box[1]), (x_0 - box[0]): (x_1 - box[0])]
+
+ # TODO : replace below with a dynamic padding when support is added in ONNX
+
+ # pad y
+ zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
+ zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
+ concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
+ # pad x
+ zeros_x0 = torch.zeros(concat_0.size(0), x_0)
+ zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
+ im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
+ return im_mask
+
+
+@torch.jit._script_if_tracing
+def _onnx_paste_mask_in_image_loop(masks, boxes, im_h, im_w):
+ res_append = torch.zeros(0, im_h, im_w)
+ for i in range(masks.size(0)):
+ mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w)
+ mask_res = mask_res.unsqueeze(0)
+ res_append = torch.cat((res_append, mask_res))
+
+ return res_append
+
+
+@torch.jit.unused
+def _get_shape_onnx(image: Tensor) -> Tensor:
+ from torch.onnx import operators
+
+ return operators.shape_as_tensor(image)[-2:]
+
+
+@torch.jit.unused
+def _fake_cast_onnx(v: Tensor) -> float:
+ # ONNX requires a tensor but here we fake its type for JIT.
+ return v
+
+
+def _resize_image_and_masks(image: Tensor,
+ self_min_size: float,
+ self_max_size: float,
+ target: Optional[Dict[str, Tensor]] = None,
+ fixed_size: Optional[Tuple[int, int]] = None
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+
+ if torchvision._is_tracing():
+ im_shape = _get_shape_onnx(image)
+ else:
+ im_shape = torch.tensor(image.shape[-2:])
+
+ size: Optional[List[int]] = None
+ scale_factor: Optional[float] = None
+ recompute_scale_factor: Optional[bool] = None
+ if fixed_size is not None:
+ size = [fixed_size[1], fixed_size[0]]
+ else:
+ min_size = torch.min(im_shape).to(dtype=torch.float32) # the smaller of height and width
+ max_size = torch.max(im_shape).to(dtype=torch.float32) # the larger of height and width
+ scale = torch.min(self_min_size / min_size, self_max_size / max_size) # compute the scale factor
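+ # taking the minimum of the two ratios scales the short side up to self_min_size unless
+ # that would make the long side exceed self_max_size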
+
+ if torchvision._is_tracing():
+ scale_factor = _fake_cast_onnx(scale)
+ else:
+ scale_factor = scale.item()
+ recompute_scale_factor = True
+
+ # resize the image via interpolation
+ # image[None] prepends a batch dimension: [C, H, W] -> [1, C, H, W]
+ # bilinear mode only supports 4D tensors
+ image = torch.nn.functional.interpolate(
+ image[None],
+ size=size,
+ scale_factor=scale_factor,
+ mode="bilinear",
+ recompute_scale_factor=recompute_scale_factor,
+ align_corners=False)[0]
+
+ if target is None:
+ return image, target
+
+ if "masks" in target:
+ mask = target["masks"]
+ mask = torch.nn.functional.interpolate(
+ mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
+ )[:, 0].byte() # self.byte() is equivalent to self.to(torch.uint8).
+ target["masks"] = mask
+
+ return image, target
+
+
+def _onnx_expand_boxes(boxes, scale):
+ # type: (Tensor, float) -> Tensor
+ w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+ h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+ x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+ y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+ w_half = w_half.to(dtype=torch.float32) * scale
+ h_half = h_half.to(dtype=torch.float32) * scale
+
+ boxes_exp0 = x_c - w_half
+ boxes_exp1 = y_c - h_half
+ boxes_exp2 = x_c + w_half
+ boxes_exp3 = y_c + h_half
+ boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1)
+ return boxes_exp
+
+
+# the next two functions should be merged inside Masker
+# but are kept here for the moment while we need them
+# temporarily for paste_mask_in_image
+def expand_boxes(boxes, scale):
+ # type: (Tensor, float) -> Tensor
+ if torchvision._is_tracing():
+ return _onnx_expand_boxes(boxes, scale)
+ w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
+ h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
+ x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
+ y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
+
+ w_half *= scale
+ h_half *= scale
+
+ boxes_exp = torch.zeros_like(boxes)
+ boxes_exp[:, 0] = x_c - w_half
+ boxes_exp[:, 2] = x_c + w_half
+ boxes_exp[:, 1] = y_c - h_half
+ boxes_exp[:, 3] = y_c + h_half
+ return boxes_exp
+
+
+@torch.jit.unused
+def expand_masks_tracing_scale(M, padding):
+ # type: (int, int) -> float
+ return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32)
+
+
+def expand_masks(mask, padding):
+ # type: (Tensor, int) -> Tuple[Tensor, float]
+ M = mask.shape[-1]
+ if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
+ scale = expand_masks_tracing_scale(M, padding)
+ else:
+ scale = float(M + 2 * padding) / M
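+ # F.pad with (padding,) * 4 pads the last two dimensions (H and W) of the mask by padding on each side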
+ padded_mask = F.pad(mask, (padding,) * 4)
+ return padded_mask, scale
+
+
+def paste_mask_in_image(mask, box, im_h, im_w):
+ # type: (Tensor, Tensor, int, int) -> Tensor
+
+ # refer to: https://github.com/pytorch/vision/issues/5845
+ TO_REMOVE = 1
+ w = int(box[2] - box[0] + TO_REMOVE)
+ h = int(box[3] - box[1] + TO_REMOVE)
+ w = max(w, 1)
+ h = max(h, 1)
+
+ # Set shape to [batch, C, H, W]
+ # the subsequent bilinear interpolation only supports 4D tensors
+ mask = mask.expand((1, 1, -1, -1)) # -1 means not changing the size of that dimension
+
+ # Resize mask
+ mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
+ mask = mask[0][0] # [batch, C, H, W] -> [H, W]
+
+ im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
+ # compute the target region in the original image (clipped to the image bounds)
+ x_0 = max(box[0], 0)
+ x_1 = min(box[2] + 1, im_w)
+ y_0 = max(box[1], 0)
+ y_1 = min(box[3] + 1, im_h)
+
+ # paste the resized mask into the corresponding target region
+ im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]):(y_1 - box[1]), (x_0 - box[0]):(x_1 - box[0])]
+ return im_mask
+
+
+def paste_masks_in_image(masks, boxes, img_shape, padding=1):
+ # type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
+
+ # according to the official PyTorch implementation, expanding the masks slightly improves mAP
+ # refer to: https://github.com/pytorch/vision/issues/5845
+ masks, scale = expand_masks(masks, padding=padding)
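+ # enlarge the boxes by the same factor used to pad the masks so that masks and boxes stay aligned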
+ boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
+ im_h, im_w = img_shape
+
+ if torchvision._is_tracing():
+ return _onnx_paste_mask_in_image_loop(
+ masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
+ )[:, None]
+ res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
+ if len(res) > 0:
+ ret = torch.stack(res, dim=0)[:, None] # [num_obj, 1, H, W]
+ else:
+ ret = masks.new_empty((0, 1, im_h, im_w))
+ return ret
+
+
+class GeneralizedRCNNTransform(nn.Module):
+ """
+ Performs input / target transformation before feeding the data to a GeneralizedRCNN
+ model.
+
+ The transformations it performs are:
+ - input normalization (mean subtraction and std division)
+ - input / target resizing to match min_size / max_size
+
+ It returns an ImageList for the inputs, and a List[Dict[Tensor]] for the targets
+ """
+
+ def __init__(self,
+ min_size: int,
+ max_size: int,
+ image_mean: List[float],
+ image_std: List[float],
+ size_divisible: int = 32,
+ fixed_size: Optional[Tuple[int, int]] = None):
+ super().__init__()
+ if not isinstance(min_size, (list, tuple)):
+ min_size = (min_size,)
+ self.min_size = min_size # target range for the shorter image side
+ self.max_size = max_size # upper bound for the longer image side
+ self.image_mean = image_mean # per-channel mean used for normalization
+ self.image_std = image_std # per-channel std used for normalization
+ self.size_divisible = size_divisible
+ self.fixed_size = fixed_size
+
+ def normalize(self, image):
+ """标准化处理"""
+ dtype, device = image.dtype, image.device
+ mean = torch.as_tensor(self.image_mean, dtype=dtype, device=device)
+ std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
+ # [:, None, None]: shape [3] -> [3, 1, 1]
+ return (image - mean[:, None, None]) / std[:, None, None]
+
+ def torch_choice(self, k):
+ # type: (List[int]) -> int
+ """
+ Implements `random.choice` via torch ops so it can be compiled with
+ TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
+ is fixed.
+ """
+ index = int(torch.empty(1).uniform_(0., float(len(k))).item())
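+ # draw a uniform float in [0, len(k)) and truncate it to an integer index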
+ return k[index]
+
+ def resize(self, image, target):
+ # type: (Tensor, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
+ """
+ Resize the image into the configured size range and rescale the bboxes accordingly.
+ Args:
+ image: input image
+ target: annotations of the input image (including the bboxes)
+
+ Returns:
+ image: the resized image
+ target: the annotations with bboxes rescaled
+ """
+ # image shape is [channel, height, width]
+ h, w = image.shape[-2:]
+
+ if self.training:
+ size = float(self.torch_choice(self.min_size)) # target length for the shorter image side; note this is self.min_size, not min_size
+ else:
+ # FIXME assume for now that testing uses the largest scale
+ size = float(self.min_size[-1]) # target length for the shorter image side; note this is self.min_size, not min_size
+
+ image, target = _resize_image_and_masks(image, size, float(self.max_size), target, self.fixed_size)
+
+ if target is None:
+ return image, target
+
+ bbox = target["boxes"]
+ # rescale the bboxes by the same ratio used for the image
+ bbox = resize_boxes(bbox, [h, w], image.shape[-2:])
+ target["boxes"] = bbox
+
+ return image, target
+
+ # _onnx_batch_images() is an implementation of
+ # batch_images() that is supported by ONNX tracing.
+ @torch.jit.unused
+ def _onnx_batch_images(self, images, size_divisible=32):
+ # type: (List[Tensor], int) -> Tensor
+ max_size = []
+ for i in range(images[0].dim()):
+ max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
+ max_size.append(max_size_i)
+ stride = size_divisible
+ max_size[1] = (torch.ceil((max_size[1].to(torch.float32)) / stride) * stride).to(torch.int64)
+ max_size[2] = (torch.ceil((max_size[2].to(torch.float32)) / stride) * stride).to(torch.int64)
+ max_size = tuple(max_size)
+
+ # work around for
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+ # which is not yet supported in onnx
+ padded_imgs = []
+ for img in images:
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
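+ # F.pad takes (before, after) pairs starting from the last dimension, so this pads
+ # width, height and channels on the right/bottom side only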
+ padded_img = torch.nn.functional.pad(img, [0, padding[2], 0, padding[1], 0, padding[0]])
+ padded_imgs.append(padded_img)
+
+ return torch.stack(padded_imgs)
+
+ def max_by_axis(self, the_list):
+ # type: (List[List[int]]) -> List[int]
+ maxes = the_list[0]
+ for sublist in the_list[1:]:
+ for index, item in enumerate(sublist):
+ maxes[index] = max(maxes[index], item)
+ return maxes
+
+ def batch_images(self, images, size_divisible=32):
+ # type: (List[Tensor], int) -> Tensor
+ """
+ Pack a list of images into one batched tensor (every image in the batch ends up with the same shape).
+ Args:
+ images: the input images
+ size_divisible: pad the image height and width up to a multiple of this value
+
+ Returns:
+ batched_imgs: the batched image tensor
+ """
+
+ if torchvision._is_tracing():
+ # batch_images() does not export well to ONNX
+ # call _onnx_batch_images() instead
+ return self._onnx_batch_images(images, size_divisible)
+
+ # compute the maximum channel, height and width over all images in the batch
+ max_size = self.max_by_axis([list(img.shape) for img in images])
+
+ stride = float(size_divisible)
+ # max_size = list(max_size)
+ # round height up to a multiple of stride
+ max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
+ # round width up to a multiple of stride
+ max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)
+
+ # [batch, channel, height, width]
+ batch_shape = [len(images)] + max_size
+
+ # create a tensor of shape batch_shape filled with zeros
+ batched_imgs = images[0].new_full(batch_shape, 0)
+ for img, pad_img in zip(images, batched_imgs):
+ # copy each input image into the corresponding slice of batched_imgs, aligned to the top-left corner,
+ # so the bbox coordinates stay valid and every image in the batch ends up with the same shape
+ # copy_: copies the elements from src into self tensor and returns self
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+
+ return batched_imgs
+
+ def postprocess(self,
+ result, # type: List[Dict[str, Tensor]]
+ image_shapes, # type: List[Tuple[int, int]]
+ original_image_sizes # type: List[Tuple[int, int]]
+ ):
+ # type: (...) -> List[Dict[str, Tensor]]
+ """
+ Post-process the network predictions (mainly mapping bboxes and masks back to the original image scale).
+ Args:
+ result: list(dict), network predictions, len(result) == batch_size
+ image_shapes: list(torch.Size), image sizes after the resize preprocessing, len(image_shapes) == batch_size
+ original_image_sizes: list(torch.Size), original image sizes, len(original_image_sizes) == batch_size
+
+ Returns:
+ result: the predictions mapped back to the original image scale
+ """
+ if self.training:
+ return result
+
+ # iterate over the predictions of each image and map the boxes back to the original scale
+ for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
+ boxes = pred["boxes"]
+ boxes = resize_boxes(boxes, im_s, o_im_s) # rescale the bboxes back to the original image scale
+ result[i]["boxes"] = boxes
+ if "masks" in pred:
+ masks = pred["masks"]
+ # paste the masks back onto the original image scale
+ masks = paste_masks_in_image(masks, boxes, o_im_s)
+ result[i]["masks"] = masks
+
+ return result
+
+ def __repr__(self):
+ """自定义输出实例化对象的信息,可通过print打印实例信息"""
+ format_string = self.__class__.__name__ + '('
+ _indent = '\n '
+ format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
+ format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(_indent, self.min_size,
+ self.max_size)
+ format_string += '\n)'
+ return format_string
+
+ def forward(self,
+ images, # type: List[Tensor]
+ targets=None # type: Optional[List[Dict[str, Tensor]]]
+ ):
+ # type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
+ images = [img for img in images]
+ for i in range(len(images)):
+ image = images[i]
+ target_index = targets[i] if targets is not None else None
+
+ if image.dim() != 3:
+ raise ValueError("images is expected to be a list of 3d tensors "
+ "of shape [C, H, W], got {}".format(image.shape))
+ image = self.normalize(image) # normalize the image
+ image, target_index = self.resize(image, target_index) # resize the image and its bboxes into the configured range
+ images[i] = image
+ if targets is not None and target_index is not None:
+ targets[i] = target_index
+
+ # record the image sizes after resizing
+ image_sizes = [img.shape[-2:] for img in images]
+ images = self.batch_images(images, self.size_divisible) # pack the images into one batch
+ image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], [])
+
+ for image_size in image_sizes:
+ assert len(image_size) == 2
+ image_sizes_list.append((image_size[0], image_size[1]))
+
+ image_list = ImageList(images, image_sizes_list)
+ return image_list, targets
+
+
+def resize_boxes(boxes, original_size, new_size):
+ # type: (Tensor, List[int], List[int]) -> Tensor
+ """
+ Rescale the boxes according to how the image was resized.
+
+ Arguments:
+ original_size: image size before resizing
+ new_size: image size after resizing
+ """
+ ratios = [
+ torch.tensor(s, dtype=torch.float32, device=boxes.device) /
+ torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
+ for s, s_orig in zip(new_size, original_size)
+ ]
+ ratios_height, ratios_width = ratios
+ # Removes a tensor dimension, boxes [minibatch, 4]
+ # Returns a tuple of all slices along a given dimension, already without it.
+ xmin, ymin, xmax, ymax = boxes.unbind(1)
+ xmin = xmin * ratios_width
+ xmax = xmax * ratios_width
+ ymin = ymin * ratios_height
+ ymax = ymax * ratios_height
+ return torch.stack((xmin, ymin, xmax, ymax), dim=1)
+
+
+
+
+
+
+
+
diff --git a/pytorch_object_detection/mask_rcnn/pascal_voc_indices.json b/pytorch_object_detection/mask_rcnn/pascal_voc_indices.json
new file mode 100644
index 000000000..1c795887b
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/pascal_voc_indices.json
@@ -0,0 +1,22 @@
+{
+ "1": "aeroplane",
+ "2": "bicycle",
+ "3": "bird",
+ "4": "boat",
+ "5": "bottle",
+ "6": "bus",
+ "7": "car",
+ "8": "cat",
+ "9": "chair",
+ "10": "cow",
+ "11": "diningtable",
+ "12": "dog",
+ "13": "horse",
+ "14": "motorbike",
+ "15": "person",
+ "16": "pottedplant",
+ "17": "sheep",
+ "18": "sofa",
+ "19": "train",
+ "20": "tvmonitor"
+}
\ No newline at end of file
diff --git a/pytorch_object_detection/mask_rcnn/plot_curve.py b/pytorch_object_detection/mask_rcnn/plot_curve.py
new file mode 100644
index 000000000..188df710e
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/plot_curve.py
@@ -0,0 +1,46 @@
+import datetime
+import matplotlib.pyplot as plt
+
+
+def plot_loss_and_lr(train_loss, learning_rate):
+ try:
+ x = list(range(len(train_loss)))
+ fig, ax1 = plt.subplots(1, 1)
+ ax1.plot(x, train_loss, 'r', label='loss')
+ ax1.set_xlabel("step")
+ ax1.set_ylabel("loss")
+ ax1.set_title("Train Loss and lr")
+ plt.legend(loc='best')
+
+ ax2 = ax1.twinx()
+ ax2.plot(x, learning_rate, label='lr')
+ ax2.set_ylabel("learning rate")
+ ax2.set_xlim(0, len(train_loss)) # set the x-axis range
+ plt.legend(loc='best')
+
+ handles1, labels1 = ax1.get_legend_handles_labels()
+ handles2, labels2 = ax2.get_legend_handles_labels()
+ plt.legend(handles1 + handles2, labels1 + labels2, loc='upper right')
+
+ fig.subplots_adjust(right=0.8) # prevent the saved figure from being cut off
+ fig.savefig('./loss_and_lr{}.png'.format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
+ plt.close()
+ print("successful save loss curve! ")
+ except Exception as e:
+ print(e)
+
+
+def plot_map(mAP):
+ try:
+ x = list(range(len(mAP)))
+ plt.plot(x, mAP, label='mAP')
+ plt.xlabel('epoch')
+ plt.ylabel('mAP')
+ plt.title('Eval mAP')
+ plt.xlim(0, len(mAP))
+ plt.legend(loc='best')
+ plt.savefig('./mAP.png')
+ plt.close()
+ print("successful save mAP curve!")
+ except Exception as e:
+ print(e)
diff --git a/pytorch_object_detection/mask_rcnn/predict.py b/pytorch_object_detection/mask_rcnn/predict.py
new file mode 100644
index 000000000..46f086756
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/predict.py
@@ -0,0 +1,106 @@
+import os
+import time
+import json
+
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+import torch
+from torchvision import transforms
+
+from network_files import MaskRCNN
+from backbone import resnet50_fpn_backbone
+from draw_box_utils import draw_objs
+
+
+def create_model(num_classes, box_thresh=0.5):
+ backbone = resnet50_fpn_backbone()
+ model = MaskRCNN(backbone,
+ num_classes=num_classes,
+ rpn_score_thresh=box_thresh,
+ box_score_thresh=box_thresh)
+
+ return model
+
+
+def time_synchronized():
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ return time.time()
+
+
+def main():
+ num_classes = 90 # excluding background
+ box_thresh = 0.5
+ weights_path = "./save_weights/model_25.pth"
+ img_path = "./test.jpg"
+ label_json_path = './coco91_indices.json'
+
+ # get devices
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print("using {} device.".format(device))
+
+ # create model
+ model = create_model(num_classes=num_classes + 1, box_thresh=box_thresh)
+
+ # load train weights
+ assert os.path.exists(weights_path), "{} file does not exist.".format(weights_path)
+ weights_dict = torch.load(weights_path, map_location='cpu')
+ weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
+ model.load_state_dict(weights_dict)
+ model.to(device)
+
+ # read class_indict
+ assert os.path.exists(label_json_path), "json file {} does not exist.".format(label_json_path)
+ with open(label_json_path, 'r') as json_file:
+ category_index = json.load(json_file)
+
+ # load image
+ assert os.path.exists(img_path), f"{img_path} does not exist."
+ original_img = Image.open(img_path).convert('RGB')
+
+ # from pil image to tensor, do not normalize image
+ data_transform = transforms.Compose([transforms.ToTensor()])
+ img = data_transform(original_img)
+ # expand batch dimension
+ img = torch.unsqueeze(img, dim=0)
+
+ model.eval() # switch to evaluation mode
+ with torch.no_grad():
+ # warm-up: run a dummy forward pass so the timed inference below excludes one-time initialization cost
+ img_height, img_width = img.shape[-2:]
+ init_img = torch.zeros((1, 3, img_height, img_width), device=device)
+ model(init_img)
+
+ t_start = time_synchronized()
+ predictions = model(img.to(device))[0]
+ t_end = time_synchronized()
+ print("inference+NMS time: {}".format(t_end - t_start))
+
+ predict_boxes = predictions["boxes"].to("cpu").numpy()
+ predict_classes = predictions["labels"].to("cpu").numpy()
+ predict_scores = predictions["scores"].to("cpu").numpy()
+ predict_mask = predictions["masks"].to("cpu").numpy()
+ predict_mask = np.squeeze(predict_mask, axis=1) # [batch, 1, h, w] -> [batch, h, w]
+
+ if len(predict_boxes) == 0:
+ print("没有检测到任何目标!")
+ return
+
+ plot_img = draw_objs(original_img,
+ boxes=predict_boxes,
+ classes=predict_classes,
+ scores=predict_scores,
+ masks=predict_mask,
+ category_index=category_index,
+ line_thickness=3,
+ font='arial.ttf',
+ font_size=20)
+ plt.imshow(plot_img)
+ plt.show()
+ # save the visualized prediction result
+ plot_img.save("test_result.jpg")
+
+
+if __name__ == '__main__':
+ main()
+
diff --git a/pytorch_object_detection/mask_rcnn/requirements.txt b/pytorch_object_detection/mask_rcnn/requirements.txt
new file mode 100644
index 000000000..9e524e23e
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/requirements.txt
@@ -0,0 +1,8 @@
+lxml
+matplotlib
+numpy
+tqdm
+pycocotools
+Pillow
+torch==1.13.1
+torchvision==0.11.1
diff --git a/pytorch_object_detection/mask_rcnn/seg_results20220406-141544.txt b/pytorch_object_detection/mask_rcnn/seg_results20220406-141544.txt
new file mode 100644
index 000000000..ac46baf82
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/seg_results20220406-141544.txt
@@ -0,0 +1,26 @@
+epoch:0 0.172 0.321 0.167 0.065 0.195 0.250 0.188 0.307 0.324 0.147 0.366 0.440 1.3826 0.08
+epoch:1 0.223 0.395 0.225 0.092 0.249 0.322 0.222 0.354 0.372 0.186 0.413 0.499 1.0356 0.08
+epoch:2 0.235 0.408 0.241 0.100 0.258 0.350 0.230 0.372 0.392 0.204 0.429 0.517 0.9718 0.08
+epoch:3 0.246 0.426 0.252 0.103 0.267 0.357 0.241 0.386 0.408 0.225 0.448 0.521 0.9363 0.08
+epoch:4 0.250 0.424 0.257 0.106 0.272 0.367 0.242 0.381 0.400 0.210 0.438 0.530 0.9145 0.08
+epoch:5 0.255 0.434 0.262 0.109 0.279 0.375 0.242 0.379 0.398 0.209 0.433 0.534 0.8982 0.08
+epoch:6 0.270 0.456 0.283 0.120 0.293 0.392 0.254 0.403 0.421 0.229 0.462 0.551 0.8859 0.08
+epoch:7 0.269 0.455 0.280 0.118 0.296 0.388 0.257 0.402 0.421 0.228 0.454 0.564 0.8771 0.08
+epoch:8 0.276 0.465 0.290 0.120 0.301 0.398 0.255 0.401 0.418 0.227 0.461 0.553 0.8685 0.08
+epoch:9 0.271 0.458 0.282 0.113 0.297 0.404 0.253 0.398 0.417 0.211 0.460 0.570 0.8612 0.08
+epoch:10 0.277 0.463 0.289 0.119 0.299 0.410 0.258 0.405 0.425 0.221 0.466 0.558 0.8547 0.08
+epoch:11 0.276 0.463 0.287 0.122 0.304 0.405 0.259 0.406 0.425 0.236 0.466 0.559 0.8498 0.08
+epoch:12 0.276 0.464 0.288 0.127 0.294 0.409 0.257 0.406 0.425 0.236 0.459 0.563 0.8461 0.08
+epoch:13 0.284 0.477 0.296 0.124 0.311 0.412 0.262 0.407 0.429 0.229 0.474 0.555 0.8409 0.08
+epoch:14 0.277 0.464 0.292 0.121 0.304 0.397 0.257 0.410 0.431 0.238 0.473 0.565 0.8355 0.08
+epoch:15 0.282 0.474 0.296 0.121 0.308 0.413 0.264 0.411 0.432 0.231 0.473 0.575 0.833 0.08
+epoch:16 0.336 0.549 0.356 0.149 0.367 0.491 0.288 0.451 0.473 0.269 0.519 0.620 0.7421 0.008
+epoch:17 0.339 0.553 0.360 0.153 0.371 0.496 0.292 0.454 0.475 0.271 0.518 0.624 0.7157 0.008
+epoch:18 0.340 0.553 0.361 0.150 0.371 0.494 0.290 0.453 0.473 0.269 0.516 0.620 0.7016 0.008
+epoch:19 0.341 0.555 0.363 0.154 0.372 0.500 0.293 0.458 0.478 0.273 0.522 0.630 0.6897 0.008
+epoch:20 0.340 0.554 0.361 0.154 0.370 0.496 0.289 0.450 0.471 0.266 0.514 0.622 0.6802 0.008
+epoch:21 0.338 0.552 0.358 0.151 0.367 0.500 0.289 0.447 0.467 0.262 0.507 0.622 0.6708 0.008
+epoch:22 0.340 0.553 0.360 0.151 0.370 0.500 0.290 0.450 0.470 0.267 0.513 0.623 0.6497 0.0008
+epoch:23 0.340 0.552 0.361 0.151 0.369 0.500 0.290 0.449 0.468 0.266 0.509 0.619 0.6447 0.0008
+epoch:24 0.339 0.552 0.359 0.150 0.369 0.500 0.290 0.448 0.468 0.264 0.510 0.619 0.6421 0.0008
+epoch:25 0.338 0.551 0.359 0.152 0.367 0.500 0.289 0.448 0.467 0.264 0.509 0.618 0.6398 0.0008
diff --git a/pytorch_object_detection/mask_rcnn/train.py b/pytorch_object_detection/mask_rcnn/train.py
new file mode 100644
index 000000000..3f5179d61
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/train.py
@@ -0,0 +1,240 @@
+import os
+import datetime
+
+import torch
+from torchvision.ops.misc import FrozenBatchNorm2d
+
+import transforms
+from network_files import MaskRCNN
+from backbone import resnet50_fpn_backbone
+from my_dataset_coco import CocoDetection
+from my_dataset_voc import VOCInstances
+from train_utils import train_eval_utils as utils
+from train_utils import GroupedBatchSampler, create_aspect_ratio_groups
+
+
+def create_model(num_classes, load_pretrain_weights=True):
+ # if GPU memory is limited and batch_size cannot be large, it is recommended to set norm_layer to FrozenBatchNorm2d (default is nn.BatchNorm2d)
+ # FrozenBatchNorm2d behaves like BatchNorm2d, but its parameters cannot be updated
+ # trainable_layers covers ['layer4', 'layer3', 'layer2', 'layer1', 'conv1']; 5 means all of them are trained
+ # backbone = resnet50_fpn_backbone(norm_layer=FrozenBatchNorm2d,
+ # trainable_layers=3)
+ # resnet50 imagenet weights url: https://download.pytorch.org/models/resnet50-0676ba61.pth
+ backbone = resnet50_fpn_backbone(pretrain_path="resnet50.pth", trainable_layers=3)
+
+ model = MaskRCNN(backbone, num_classes=num_classes)
+
+ if load_pretrain_weights:
+ # coco weights url: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth"
+ weights_dict = torch.load("./maskrcnn_resnet50_fpn_coco.pth", map_location="cpu")
+ for k in list(weights_dict.keys()):
+ if ("box_predictor" in k) or ("mask_fcn_logits" in k):
+ del weights_dict[k]
+
+ print(model.load_state_dict(weights_dict, strict=False))
+
+ return model
+
+
+def main(args):
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ print("Using {} device training.".format(device.type))
+
+ # files used to save the coco evaluation info
+ now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+ det_results_file = f"det_results{now}.txt"
+ seg_results_file = f"seg_results{now}.txt"
+
+ data_transform = {
+ "train": transforms.Compose([transforms.ToTensor(),
+ transforms.RandomHorizontalFlip(0.5)]),
+ "val": transforms.Compose([transforms.ToTensor()])
+ }
+
+ data_root = args.data_path
+
+ # load train data set
+ # coco2017 -> annotations -> instances_train2017.json
+ train_dataset = CocoDetection(data_root, "train", data_transform["train"])
+ # VOCdevkit -> VOC2012 -> ImageSets -> Main -> train.txt
+ # train_dataset = VOCInstances(data_root, year="2012", txt_name="train.txt", transforms=data_transform["train"])
+ train_sampler = None
+
+ # whether to group images with similar aspect ratios into the same batch
+ # enabling this reduces the GPU memory needed for training; it is enabled by default
+ if args.aspect_ratio_group_factor >= 0:
+ train_sampler = torch.utils.data.RandomSampler(train_dataset)
+ # compute, for every image, the index of the aspect-ratio bin it falls into
+ group_ids = create_aspect_ratio_groups(train_dataset, k=args.aspect_ratio_group_factor)
+ # every batch draws its images from the same aspect-ratio bin
+ train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
+
+ # note that collate_fn is custom here: each sample contains an image and its targets, so the default collate cannot batch them
+ batch_size = args.batch_size
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
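+ # i.e. capped by the CPU count, the batch size and 8; falls back to 0 workers when batch_size is 1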
+ print('Using %g dataloader workers' % nw)
+
+ if train_sampler:
+ # when sampling by aspect ratio, the dataloader has to use batch_sampler instead of batch_size
+ train_data_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_sampler=train_batch_sampler,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+ else:
+ train_data_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+
+ # load validation data set
+ # coco2017 -> annotations -> instances_val2017.json
+ val_dataset = CocoDetection(data_root, "val", data_transform["val"])
+ # VOCdevkit -> VOC2012 -> ImageSets -> Main -> val.txt
+ # val_dataset = VOCInstances(data_root, year="2012", txt_name="val.txt", transforms=data_transform["val"])
+ val_data_loader = torch.utils.data.DataLoader(val_dataset,
+ batch_size=1,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+
+ # create model; num_classes equals background + object classes
+ model = create_model(num_classes=args.num_classes + 1, load_pretrain_weights=args.pretrain)
+ model.to(device)
+
+ train_loss = []
+ learning_rate = []
+ val_map = []
+
+ # define optimizer
+ params = [p for p in model.parameters() if p.requires_grad]
+ optimizer = torch.optim.SGD(params, lr=args.lr,
+ momentum=args.momentum,
+ weight_decay=args.weight_decay)
+
+ scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+ # learning rate scheduler
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
+ milestones=args.lr_steps,
+ gamma=args.lr_gamma)
+ # if a resume path (the checkpoint saved by the previous run) is given, continue training from it
+ if args.resume:
+ # If map_location is missing, torch.load will first load the module to CPU
+ # and then copy each parameter to where it was saved,
+ # which would result in all processes on the same machine using the same set of devices.
+ checkpoint = torch.load(args.resume, map_location='cpu') # load the previously saved checkpoint (including optimizer and lr scheduler state)
+ model.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if args.amp and "scaler" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler"])
+
+ for epoch in range(args.start_epoch, args.epochs):
+ # train for one epoch, printing every 50 iterations
+ mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
+ device, epoch, print_freq=50,
+ warmup=True, scaler=scaler)
+ train_loss.append(mean_loss.item())
+ learning_rate.append(lr)
+
+ # update the learning rate
+ lr_scheduler.step()
+
+ # evaluate on the test dataset
+ det_info, seg_info = utils.evaluate(model, val_data_loader, device=device)
+
+ # write detection into txt
+ with open(det_results_file, "a") as f:
+ # each line contains the coco metrics followed by the mean loss and the learning rate
+ result_info = [f"{i:.4f}" for i in det_info + [mean_loss.item()]] + [f"{lr:.6f}"]
+ txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
+ f.write(txt + "\n")
+
+ # write seg into txt
+ with open(seg_results_file, "a") as f:
+ # each line contains the coco metrics followed by the mean loss and the learning rate
+ result_info = [f"{i:.4f}" for i in seg_info + [mean_loss.item()]] + [f"{lr:.6f}"]
+ txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
+ f.write(txt + "\n")
+
+ val_map.append(det_info[1]) # pascal mAP
+
+ # save weights
+ save_files = {
+ 'model': model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch}
+ if args.amp:
+ save_files["scaler"] = scaler.state_dict()
+ torch.save(save_files, "./save_weights/model_{}.pth".format(epoch))
+
+ # plot loss and lr curve
+ if len(train_loss) != 0 and len(learning_rate) != 0:
+ from plot_curve import plot_loss_and_lr
+ plot_loss_and_lr(train_loss, learning_rate)
+
+ # plot mAP curve
+ if len(val_map) != 0:
+ from plot_curve import plot_map
+ plot_map(val_map)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description=__doc__)
+
+ # training device
+ parser.add_argument('--device', default='cuda:0', help='device')
+ # root directory of the training dataset
+ parser.add_argument('--data-path', default='/data/coco2017', help='dataset')
+ # number of detection classes (excluding background)
+ parser.add_argument('--num-classes', default=90, type=int, help='num_classes')
+ # directory where output files are saved
+ parser.add_argument('--output-dir', default='./save_weights', help='path where to save')
+ # to resume training, pass the path of the checkpoint saved by the previous run
+ parser.add_argument('--resume', default='', type=str, help='resume from checkpoint')
+ # epoch to start training from
+ parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
+ # total number of training epochs
+ parser.add_argument('--epochs', default=26, type=int, metavar='N',
+ help='number of total epochs to run')
+ # learning rate
+ parser.add_argument('--lr', default=0.004, type=float,
+ help='initial learning rate, 0.02 is the default value for training '
+ 'on 8 gpus and 2 images_per_gpu')
+ # SGD momentum
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+ help='momentum')
+ # SGD weight decay
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+ metavar='W', help='weight decay (default: 1e-4)',
+ dest='weight_decay')
+ # parameters for torch.optim.lr_scheduler.MultiStepLR
+ parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,
+ help='epochs at which to decrease the lr (MultiStepLR milestones)')
+ # parameters for torch.optim.lr_scheduler.MultiStepLR
+ parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
+ # training batch size (increase it if memory / GPU memory allows)
+ parser.add_argument('--batch_size', default=2, type=int, metavar='N',
+ help='batch size when training.')
+ parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
+ parser.add_argument("--pretrain", type=bool, default=True, help="load COCO pretrain weights.")
+ # whether to use mixed precision training (requires GPU support)
+ parser.add_argument("--amp", default=False, help="Use torch.cuda.amp for mixed precision training")
+
+ args = parser.parse_args()
+ print(args)
+
+ # create the weights output directory if it does not exist
+ if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+
+ main(args)
diff --git a/pytorch_object_detection/mask_rcnn/train_multi_GPU.py b/pytorch_object_detection/mask_rcnn/train_multi_GPU.py
new file mode 100644
index 000000000..05647edef
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/train_multi_GPU.py
@@ -0,0 +1,268 @@
+import time
+import os
+import datetime
+
+import torch
+from torchvision.ops.misc import FrozenBatchNorm2d
+
+import transforms
+from my_dataset_coco import CocoDetection
+from my_dataset_voc import VOCInstances
+from backbone import resnet50_fpn_backbone
+from network_files import MaskRCNN
+import train_utils.train_eval_utils as utils
+from train_utils import GroupedBatchSampler, create_aspect_ratio_groups, init_distributed_mode, save_on_master, mkdir
+
+
+def create_model(num_classes, load_pretrain_weights=True):
+ # if GPU memory is limited and batch_size cannot be large, it is recommended to set norm_layer to FrozenBatchNorm2d (default is nn.BatchNorm2d)
+ # FrozenBatchNorm2d behaves like BatchNorm2d, but its parameters cannot be updated
+ # trainable_layers covers ['layer4', 'layer3', 'layer2', 'layer1', 'conv1']; 5 means all of them are trained
+ # backbone = resnet50_fpn_backbone(norm_layer=FrozenBatchNorm2d,
+ # trainable_layers=3)
+ # resnet50 imagenet weights url: https://download.pytorch.org/models/resnet50-0676ba61.pth
+ backbone = resnet50_fpn_backbone(pretrain_path="resnet50.pth", trainable_layers=3)
+ model = MaskRCNN(backbone, num_classes=num_classes)
+
+ if load_pretrain_weights:
+ # coco weights url: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth"
+ weights_dict = torch.load("./maskrcnn_resnet50_fpn_coco.pth", map_location="cpu")
+ for k in list(weights_dict.keys()):
+ if ("box_predictor" in k) or ("mask_fcn_logits" in k):
+ del weights_dict[k]
+
+ print(model.load_state_dict(weights_dict, strict=False))
+
+ return model
+
+
+def main(args):
+ init_distributed_mode(args)
+ print(args)
+
+ device = torch.device(args.device)
+
+ # files used to save the coco evaluation info
+ now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+ det_results_file = f"det_results{now}.txt"
+ seg_results_file = f"seg_results{now}.txt"
+
+ # Data loading code
+ print("Loading data")
+
+ data_transform = {
+ "train": transforms.Compose([transforms.ToTensor(),
+ transforms.RandomHorizontalFlip(0.5)]),
+ "val": transforms.Compose([transforms.ToTensor()])
+ }
+
+ COCO_root = args.data_path
+
+ # load train data set
+ # coco2017 -> annotations -> instances_train2017.json
+ train_dataset = CocoDetection(COCO_root, "train", data_transform["train"])
+ # VOCdevkit -> VOC2012 -> ImageSets -> Main -> train.txt
+ # train_dataset = VOCInstances(data_root, year="2012", txt_name="train.txt")
+
+ # load validation data set
+ # coco2017 -> annotations -> instances_val2017.json
+ val_dataset = CocoDetection(COCO_root, "val", data_transform["val"])
+ # VOCdevkit -> VOC2012 -> ImageSets -> Main -> val.txt
+ # val_dataset = VOCInstances(data_root, year="2012", txt_name="val.txt")
+
+ print("Creating data loaders")
+ if args.distributed:
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+ test_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
+ else:
+ train_sampler = torch.utils.data.RandomSampler(train_dataset)
+ test_sampler = torch.utils.data.SequentialSampler(val_dataset)
+
+ if args.aspect_ratio_group_factor >= 0:
+ # compute, for every image, the index of the aspect-ratio bin it falls into
+ group_ids = create_aspect_ratio_groups(train_dataset, k=args.aspect_ratio_group_factor)
+ train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
+ else:
+ train_batch_sampler = torch.utils.data.BatchSampler(
+ train_sampler, args.batch_size, drop_last=True)
+
+ data_loader = torch.utils.data.DataLoader(
+ train_dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
+ collate_fn=train_dataset.collate_fn)
+
+ data_loader_test = torch.utils.data.DataLoader(
+ val_dataset, batch_size=1,
+ sampler=test_sampler, num_workers=args.workers,
+ collate_fn=train_dataset.collate_fn)
+
+ print("Creating model")
+ # create model; num_classes equals background + object classes
+ model = create_model(num_classes=args.num_classes + 1, load_pretrain_weights=args.pretrain)
+ model.to(device)
+
+ if args.distributed and args.sync_bn:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ params = [p for p in model.parameters() if p.requires_grad]
+ optimizer = torch.optim.SGD(
+ params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+
+ scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+ # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
+
+ # if a resume path (the checkpoint saved by the previous run) is given, continue training from it
+ if args.resume:
+ # If map_location is missing, torch.load will first load the module to CPU
+ # and then copy each parameter to where it was saved,
+ # which would result in all processes on the same machine using the same set of devices.
+ checkpoint = torch.load(args.resume, map_location='cpu') # load the previously saved checkpoint (including optimizer and lr scheduler state)
+ model_without_ddp.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if args.amp and "scaler" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler"])
+
+ if args.test_only:
+ utils.evaluate(model, data_loader_test, device=device)
+ return
+
+ train_loss = []
+ learning_rate = []
+ val_map = []
+
+ print("Start training")
+ start_time = time.time()
+ for epoch in range(args.start_epoch, args.epochs):
+ if args.distributed:
+ train_sampler.set_epoch(epoch)
+ mean_loss, lr = utils.train_one_epoch(model, optimizer, data_loader,
+ device, epoch, args.print_freq,
+ warmup=True, scaler=scaler)
+
+ # update learning rate
+ lr_scheduler.step()
+
+ # evaluate after every epoch
+ det_info, seg_info = utils.evaluate(model, data_loader_test, device=device)
+
+ # only the main process writes the results
+ if args.rank in [-1, 0]:
+ train_loss.append(mean_loss.item())
+ learning_rate.append(lr)
+ val_map.append(det_info[1]) # pascal mAP
+
+ # write into txt
+ with open(det_results_file, "a") as f:
+ # each line contains the coco metrics followed by the mean loss and the learning rate
+ result_info = [f"{i:.4f}" for i in det_info + [mean_loss.item()]] + [f"{lr:.6f}"]
+ txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
+ f.write(txt + "\n")
+
+ with open(seg_results_file, "a") as f:
+ # each line contains the coco metrics followed by the mean loss and the learning rate
+ result_info = [f"{i:.4f}" for i in seg_info + [mean_loss.item()]] + [f"{lr:.6f}"]
+ txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
+ f.write(txt + "\n")
+
+ if args.output_dir:
+ # only the main process saves the weights
+ save_files = {'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'args': args,
+ 'epoch': epoch}
+ if args.amp:
+ save_files["scaler"] = scaler.state_dict()
+ save_on_master(save_files,
+ os.path.join(args.output_dir, f'model_{epoch}.pth'))
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+ if args.rank in [-1, 0]:
+ # plot loss and lr curve
+ if len(train_loss) != 0 and len(learning_rate) != 0:
+ from plot_curve import plot_loss_and_lr
+ plot_loss_and_lr(train_loss, learning_rate)
+
+ # plot mAP curve
+ if len(val_map) != 0:
+ from plot_curve import plot_map
+ plot_map(val_map)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description=__doc__)
+
+ # root directory of the dataset (coco2017)
+ parser.add_argument('--data-path', default='/data/coco2017', help='dataset')
+ # training device
+ parser.add_argument('--device', default='cuda', help='device')
+ # number of detection classes (excluding background)
+ parser.add_argument('--num-classes', default=90, type=int, help='num_classes')
+ # batch_size per GPU
+ parser.add_argument('-b', '--batch-size', default=4, type=int,
+ help='images per gpu, the total batch size is $NGPU x batch_size')
+ # epoch to start training from
+ parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
+ # total number of training epochs
+ parser.add_argument('--epochs', default=26, type=int, metavar='N',
+ help='number of total epochs to run')
+ # number of data loading / preprocessing workers
+ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+ help='number of data loading workers (default: 4)')
+ # learning rate; scale it with the number of GPUs and the batch size: 0.02 / bs * num_GPU
+ parser.add_argument('--lr', default=0.005, type=float,
+ help='initial learning rate, 0.02 is the default value for training '
+ 'on 8 gpus and 2 images_per_gpu')
+ # SGD momentum
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+ help='momentum')
+ # SGD weight decay
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+ metavar='W', help='weight decay (default: 1e-4)',
+ dest='weight_decay')
+ # parameter for torch.optim.lr_scheduler.StepLR
+ parser.add_argument('--lr-step-size', default=8, type=int, help='decrease lr every step-size epochs')
+ # parameters for torch.optim.lr_scheduler.MultiStepLR
+ parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,
+ help='epochs at which to decrease the lr (MultiStepLR milestones)')
+ # parameters for torch.optim.lr_scheduler.MultiStepLR
+ parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
+ # how often training info is printed
+ parser.add_argument('--print-freq', default=50, type=int, help='print frequency')
+ # directory where output files are saved
+ parser.add_argument('--output-dir', default='./multi_train', help='path where to save')
+ # resume training from a previous checkpoint
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
+ parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
+ parser.add_argument('--test-only', action="http://github.com/store_true", help="test only")
+
+ # number of distributed processes (not threads)
+ parser.add_argument('--world-size', default=4, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+ parser.add_argument("--sync-bn", dest="sync_bn", help="Use sync batch norm", type=bool, default=False)
+ parser.add_argument("--pretrain", type=bool, default=True, help="load COCO pretrain weights.")
+ # whether to use mixed precision training (requires GPU support)
+ parser.add_argument("--amp", default=False, help="Use torch.cuda.amp for mixed precision training")
+
+ args = parser.parse_args()
+
+ # if an output directory is specified, create it if it does not exist
+ if args.output_dir:
+ mkdir(args.output_dir)
+
+ main(args)
diff --git a/pytorch_object_detection/mask_rcnn/train_utils/__init__.py b/pytorch_object_detection/mask_rcnn/train_utils/__init__.py
new file mode 100644
index 000000000..3dfa7eadc
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/train_utils/__init__.py
@@ -0,0 +1,4 @@
+from .group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
+from .distributed_utils import init_distributed_mode, save_on_master, mkdir
+from .coco_eval import EvalCOCOMetric
+from .coco_utils import coco_remove_images_without_annotations, convert_coco_poly_mask, convert_to_coco_api
diff --git a/pytorch_object_detection/mask_rcnn/train_utils/coco_eval.py b/pytorch_object_detection/mask_rcnn/train_utils/coco_eval.py
new file mode 100644
index 000000000..b8df0204d
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/train_utils/coco_eval.py
@@ -0,0 +1,163 @@
+import json
+import copy
+
+import numpy as np
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+import pycocotools.mask as mask_util
+from .distributed_utils import all_gather, is_main_process
+
+
+def merge(img_ids, eval_results):
+ """将多个进程之间的数据汇总在一起"""
+ all_img_ids = all_gather(img_ids)
+ all_eval_results = all_gather(eval_results)
+
+ merged_img_ids = []
+ for p in all_img_ids:
+ merged_img_ids.extend(p)
+
+ merged_eval_results = []
+ for p in all_eval_results:
+ merged_eval_results.extend(p)
+
+ merged_img_ids = np.array(merged_img_ids)
+
+ # keep only unique (and in sorted order) images
+ # remove duplicate image ids: with multi-GPU training an image may be given to several processes so that every process sees the same number of images
+ merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
+ merged_eval_results = [merged_eval_results[i] for i in idx]
+
+ return list(merged_img_ids), merged_eval_results
+
+
+class EvalCOCOMetric:
+ def __init__(self,
+ coco: COCO = None,
+ iou_type: str = None,
+ results_file_name: str = "predict_results.json",
+ classes_mapping: dict = None):
+ self.coco = copy.deepcopy(coco)
+ self.img_ids = [] # ids of the images handled by this process
+ self.results = []
+ self.aggregation_results = None
+ self.classes_mapping = classes_mapping
+ self.coco_evaluator = None
+ assert iou_type in ["bbox", "segm", "keypoints"]
+ self.iou_type = iou_type
+ self.results_file_name = results_file_name
+
+ def prepare_for_coco_detection(self, targets, outputs):
+ """将预测的结果转换成COCOeval指定的格式,针对目标检测任务"""
+ # 遍历每张图像的预测结果
+ for target, output in zip(targets, outputs):
+ if len(output) == 0:
+ continue
+
+ img_id = int(target["image_id"])
+ if img_id in self.img_ids:
+ # skip duplicated images
+ continue
+ self.img_ids.append(img_id)
+ per_image_boxes = output["boxes"]
+ # coco_eval expects each box in the format [x_min, y_min, w, h]
+ # our predicted boxes are [x_min, y_min, x_max, y_max], so convert the format
+ per_image_boxes[:, 2:] -= per_image_boxes[:, :2]
+ per_image_classes = output["labels"].tolist()
+ per_image_scores = output["scores"].tolist()
+
+ res_list = []
+ # iterate over every detected object
+ for object_score, object_class, object_box in zip(
+ per_image_scores, per_image_classes, per_image_boxes):
+ object_score = float(object_score)
+ class_idx = int(object_class)
+ if self.classes_mapping is not None:
+ class_idx = int(self.classes_mapping[str(class_idx)])
+ # We recommend rounding coordinates to the nearest tenth of a pixel
+ # to reduce resulting JSON file size.
+ object_box = [round(b, 2) for b in object_box.tolist()]
+
+ res = {"image_id": img_id,
+ "category_id": class_idx,
+ "bbox": object_box,
+ "score": round(object_score, 3)}
+ res_list.append(res)
+ self.results.append(res_list)
+
+ def prepare_for_coco_segmentation(self, targets, outputs):
+ """将预测的结果转换成COCOeval指定的格式,针对实例分割任务"""
+ # 遍历每张图像的预测结果
+ for target, output in zip(targets, outputs):
+ if len(output) == 0:
+ continue
+
+ img_id = int(target["image_id"])
+ if img_id in self.img_ids:
+ # skip duplicated images
+ continue
+
+ self.img_ids.append(img_id)
+ per_image_masks = output["masks"]
+ per_image_classes = output["labels"].tolist()
+ per_image_scores = output["scores"].tolist()
+
+ masks = per_image_masks > 0.5
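+ # binarize the soft masks at a 0.5 threshold before RLE encoding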
+
+ res_list = []
+ # iterate over every detected object
+ for mask, label, score in zip(masks, per_image_classes, per_image_scores):
+ rle = mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
+ rle["counts"] = rle["counts"].decode("utf-8")
+
+ class_idx = int(label)
+ if self.classes_mapping is not None:
+ class_idx = int(self.classes_mapping[str(class_idx)])
+
+ res = {"image_id": img_id,
+ "category_id": class_idx,
+ "segmentation": rle,
+ "score": round(score, 3)}
+ res_list.append(res)
+ self.results.append(res_list)
+
+ def update(self, targets, outputs):
+ if self.iou_type == "bbox":
+ self.prepare_for_coco_detection(targets, outputs)
+ elif self.iou_type == "segm":
+ self.prepare_for_coco_segmentation(targets, outputs)
+ else:
+ raise KeyError(f"not support iou_type: {self.iou_type}")
+
+ def synchronize_results(self):
+ # synchronize the results across all processes
+ eval_ids, eval_results = merge(self.img_ids, self.results)
+ self.aggregation_results = {"img_ids": eval_ids, "results": eval_results}
+
+ # only the main process needs to save the results
+ if is_main_process():
+ results = []
+ [results.extend(i) for i in eval_results]
+ # write predict results into json file
+ json_str = json.dumps(results, indent=4)
+ with open(self.results_file_name, 'w') as json_file:
+ json_file.write(json_str)
+
+ def evaluate(self):
+ # only the main process runs the evaluation
+ if is_main_process():
+ # accumulate predictions from all images
+ coco_true = self.coco
+ coco_pre = coco_true.loadRes(self.results_file_name)
+
+ self.coco_evaluator = COCOeval(cocoGt=coco_true, cocoDt=coco_pre, iouType=self.iou_type)
+
+ self.coco_evaluator.evaluate()
+ self.coco_evaluator.accumulate()
+ print(f"IoU metric: {self.iou_type}")
+ self.coco_evaluator.summarize()
+
+ coco_info = self.coco_evaluator.stats.tolist() # numpy to list
+ return coco_info
+ else:
+ return None
diff --git a/pytorch_object_detection/mask_rcnn/train_utils/coco_utils.py b/pytorch_object_detection/mask_rcnn/train_utils/coco_utils.py
new file mode 100644
index 000000000..7a3b3122e
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/train_utils/coco_utils.py
@@ -0,0 +1,98 @@
+import torch
+import torch.utils.data
+from pycocotools import mask as coco_mask
+from pycocotools.coco import COCO
+
+
+def coco_remove_images_without_annotations(dataset, ids):
+ """
+ Remove images that contain no objects, or whose objects all have a very small area, from the coco dataset.
+ refer to:
+ https://github.com/pytorch/vision/blob/master/references/detection/coco_utils.py
+ :param dataset: COCO api object
+ :param ids: image ids to filter
+ :return: the ids of images that have at least one valid annotation
+ """
+ def _has_only_empty_bbox(anno):
+ return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
+
+ def _has_valid_annotation(anno):
+ # if it's empty, there is no annotation
+ if len(anno) == 0:
+ return False
+ # if all boxes have close to zero area, there is no annotation
+ if _has_only_empty_bbox(anno):
+ return False
+
+ return True
+
+ valid_ids = []
+ for ds_idx, img_id in enumerate(ids):
+ ann_ids = dataset.getAnnIds(imgIds=img_id, iscrowd=None)
+ anno = dataset.loadAnns(ann_ids)
+
+ if _has_valid_annotation(anno):
+ valid_ids.append(img_id)
+
+ return valid_ids
+
+
+def convert_coco_poly_mask(segmentations, height, width):
+ masks = []
+ for polygons in segmentations:
+ rles = coco_mask.frPyObjects(polygons, height, width)
+ mask = coco_mask.decode(rles)
+ if len(mask.shape) < 3:
+ mask = mask[..., None]
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
+ mask = mask.any(dim=2)
+ masks.append(mask)
+ if masks:
+ masks = torch.stack(masks, dim=0)
+ else:
+ # if masks is empty there are no objects, so return an all-zero mask tensor
+ masks = torch.zeros((0, height, width), dtype=torch.uint8)
+ return masks
+
+
+def convert_to_coco_api(self):
+ coco_ds = COCO()
+ # annotation IDs need to start at 1, not 0, see torchvision issue #1530
+ ann_id = 1
+ dataset = {"images": [], "categories": [], "annotations": []}
+ categories = set()
+ for img_idx in range(len(self)):
+ targets, h, w = self.get_annotations(img_idx)
+ img_id = targets["image_id"].item()
+ img_dict = {"id": img_id,
+ "height": h,
+ "width": w}
+ dataset["images"].append(img_dict)
+ bboxes = targets["boxes"].clone()
+ # convert (x_min, ymin, xmax, ymax) to (xmin, ymin, w, h)
+ bboxes[:, 2:] -= bboxes[:, :2]
+ bboxes = bboxes.tolist()
+ labels = targets["labels"].tolist()
+ areas = targets["area"].tolist()
+ iscrowd = targets["iscrowd"].tolist()
+ if "masks" in targets:
+ masks = targets["masks"]
+ # make masks Fortran contiguous for coco_mask
+ masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
+ num_objs = len(bboxes)
+ for i in range(num_objs):
+ ann = {"image_id": img_id,
+ "bbox": bboxes[i],
+ "category_id": labels[i],
+ "area": areas[i],
+ "iscrowd": iscrowd[i],
+ "id": ann_id}
+ categories.add(labels[i])
+ if "masks" in targets:
+ ann["segmentation"] = coco_mask.encode(masks[i].numpy())
+ dataset["annotations"].append(ann)
+ ann_id += 1
+ dataset["categories"] = [{"id": i} for i in sorted(categories)]
+ coco_ds.dataset = dataset
+ coco_ds.createIndex()
+ return coco_ds
diff --git a/pytorch_object_detection/mask_rcnn/train_utils/distributed_utils.py b/pytorch_object_detection/mask_rcnn/train_utils/distributed_utils.py
new file mode 100644
index 000000000..80b2412c6
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/train_utils/distributed_utils.py
@@ -0,0 +1,299 @@
+from collections import defaultdict, deque
+import datetime
+import pickle
+import time
+import errno
+import os
+
+import torch
+import torch.distributed as dist
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{value:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size) # deque can be thought of as an enhanced list
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self): # @property exposes median as a read-only attribute
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+def all_gather(data):
+ """
+ Gather data from all processes.
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ world_size = get_world_size() # number of processes
+ if world_size == 1:
+ return [data]
+
+ data_list = [None] * world_size
+ dist.all_gather_object(data_list, data)
+
+ return data_list
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Args:
+ input_dict (dict): all the values will be reduced
+ average (bool): whether to do average or sum
+ Reduce the values in the dictionary from all processes so that all processes
+ have the averaged results. Returns a dict with the same fields as
+ input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2: # single-GPU case
+ return input_dict
+ with torch.no_grad(): # multi-GPU case
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.all_reduce(values)
+ if average:
+ values /= world_size
+
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+ if torch.cuda.is_available():
+ log_msg = self.delimiter.join([header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}',
+ 'max mem: {memory:.0f}'])
+ else:
+ log_msg = self.delimiter.join([header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'])
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_second = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=eta_second))
+ if torch.cuda.is_available():
+ print(log_msg.format(i, len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(i, len(iterable),
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print('{} Total time: {} ({:.4f} s / it)'.format(header,
+                                                          total_time_str,
+                                                          total_time / len(iterable)))
+
+
+def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
+
+ def f(x):
+        """Return a learning-rate multiplier for the given step."""
+        if x >= warmup_iters:  # once past warmup_iters the multiplier stays at 1
+            return 1
+        alpha = float(x) / warmup_iters
+        # the multiplier grows linearly from warmup_factor to 1 during warmup
+        return warmup_factor * (1 - alpha) + alpha
+
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
+
+
+def mkdir(path):
+ try:
+ os.makedirs(path)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+
+
+def setup_for_distributed(is_master):
+ """
+    This function disables printing when not in the master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    """Check whether torch.distributed is available and initialized."""
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}'.format(
+ args.rank, args.dist_url), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
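For orientation, the sketch below shows how `SmoothedValue` and `MetricLogger` from this file are meant to be driven inside a loop. This is a minimal sketch, not part of the repository: the loop body and loss values are placeholders.

```python
# Minimal usage sketch for MetricLogger / SmoothedValue (placeholder values only).
from train_utils import distributed_utils as utils

logger = utils.MetricLogger(delimiter="  ")
logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))

for step, _ in enumerate(logger.log_every(range(100), print_freq=10, header="Epoch: [0]")):
    fake_loss = 1.0 / (step + 1)          # stand-in for a real loss value
    logger.update(loss=fake_loss, lr=0.01)

# every logged quantity is available afterwards as a SmoothedValue meter
print(logger.loss.global_avg)
```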
diff --git a/pytorch_object_detection/mask_rcnn/train_utils/group_by_aspect_ratio.py b/pytorch_object_detection/mask_rcnn/train_utils/group_by_aspect_ratio.py
new file mode 100644
index 000000000..e7b8b9e88
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/train_utils/group_by_aspect_ratio.py
@@ -0,0 +1,201 @@
+import bisect
+from collections import defaultdict
+import copy
+from itertools import repeat, chain
+import math
+import numpy as np
+
+import torch
+import torch.utils.data
+from torch.utils.data.sampler import BatchSampler, Sampler
+from torch.utils.model_zoo import tqdm
+import torchvision
+
+from PIL import Image
+
+
+def _repeat_to_at_least(iterable, n):
+ repeat_times = math.ceil(n / len(iterable))
+ repeated = chain.from_iterable(repeat(iterable, repeat_times))
+ return list(repeated)
+
+
+class GroupedBatchSampler(BatchSampler):
+ """
+ Wraps another sampler to yield a mini-batch of indices.
+ It enforces that the batch only contain elements from the same group.
+    It also tries to provide mini-batches which follow an ordering that is
+    as close as possible to the ordering from the original sampler.
+ Arguments:
+ sampler (Sampler): Base sampler.
+ group_ids (list[int]): If the sampler produces indices in range [0, N),
+ `group_ids` must be a list of `N` ints which contains the group id of each sample.
+ The group ids must be a continuous set of integers starting from
+ 0, i.e. they must be in the range [0, num_groups).
+ batch_size (int): Size of mini-batch.
+ """
+ def __init__(self, sampler, group_ids, batch_size):
+ if not isinstance(sampler, Sampler):
+ raise ValueError(
+ "sampler should be an instance of "
+ "torch.utils.data.Sampler, but got sampler={}".format(sampler)
+ )
+ self.sampler = sampler
+ self.group_ids = group_ids
+ self.batch_size = batch_size
+
+ def __iter__(self):
+ buffer_per_group = defaultdict(list)
+ samples_per_group = defaultdict(list)
+
+ num_batches = 0
+ for idx in self.sampler:
+ group_id = self.group_ids[idx]
+ buffer_per_group[group_id].append(idx)
+ samples_per_group[group_id].append(idx)
+ if len(buffer_per_group[group_id]) == self.batch_size:
+ yield buffer_per_group[group_id]
+ num_batches += 1
+ del buffer_per_group[group_id]
+ assert len(buffer_per_group[group_id]) < self.batch_size
+
+ # now we have run out of elements that satisfy
+ # the group criteria, let's return the remaining
+ # elements so that the size of the sampler is
+ # deterministic
+ expected_num_batches = len(self)
+ num_remaining = expected_num_batches - num_batches
+ if num_remaining > 0:
+ # for the remaining batches, take first the buffers with largest number
+ # of elements
+ for group_id, _ in sorted(buffer_per_group.items(),
+ key=lambda x: len(x[1]), reverse=True):
+ remaining = self.batch_size - len(buffer_per_group[group_id])
+ samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
+ buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
+ assert len(buffer_per_group[group_id]) == self.batch_size
+ yield buffer_per_group[group_id]
+ num_remaining -= 1
+ if num_remaining == 0:
+ break
+ assert num_remaining == 0
+
+ def __len__(self):
+ return len(self.sampler) // self.batch_size
+
+
+def _compute_aspect_ratios_slow(dataset, indices=None):
+ print("Your dataset doesn't support the fast path for "
+ "computing the aspect ratios, so will iterate over "
+ "the full dataset and load every image instead. "
+ "This might take some time...")
+ if indices is None:
+ indices = range(len(dataset))
+
+ class SubsetSampler(Sampler):
+ def __init__(self, indices):
+ self.indices = indices
+
+ def __iter__(self):
+ return iter(self.indices)
+
+ def __len__(self):
+ return len(self.indices)
+
+ sampler = SubsetSampler(indices)
+ data_loader = torch.utils.data.DataLoader(
+ dataset, batch_size=1, sampler=sampler,
+ num_workers=14, # you might want to increase it for faster processing
+ collate_fn=lambda x: x[0])
+ aspect_ratios = []
+ with tqdm(total=len(dataset)) as pbar:
+ for _i, (img, _) in enumerate(data_loader):
+ pbar.update(1)
+ height, width = img.shape[-2:]
+ aspect_ratio = float(width) / float(height)
+ aspect_ratios.append(aspect_ratio)
+ return aspect_ratios
+
+
+def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
+ if indices is None:
+ indices = range(len(dataset))
+ aspect_ratios = []
+ for i in indices:
+ height, width = dataset.get_height_and_width(i)
+ aspect_ratio = float(width) / float(height)
+ aspect_ratios.append(aspect_ratio)
+ return aspect_ratios
+
+
+def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
+ if indices is None:
+ indices = range(len(dataset))
+ aspect_ratios = []
+ for i in indices:
+ img_info = dataset.coco.imgs[dataset.ids[i]]
+ aspect_ratio = float(img_info["width"]) / float(img_info["height"])
+ aspect_ratios.append(aspect_ratio)
+ return aspect_ratios
+
+
+def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
+ if indices is None:
+ indices = range(len(dataset))
+ aspect_ratios = []
+ for i in indices:
+ # this doesn't load the data into memory, because PIL loads it lazily
+ width, height = Image.open(dataset.images[i]).size
+ aspect_ratio = float(width) / float(height)
+ aspect_ratios.append(aspect_ratio)
+ return aspect_ratios
+
+
+def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
+ if indices is None:
+ indices = range(len(dataset))
+
+ ds_indices = [dataset.indices[i] for i in indices]
+ return compute_aspect_ratios(dataset.dataset, ds_indices)
+
+
+def compute_aspect_ratios(dataset, indices=None):
+ if hasattr(dataset, "get_height_and_width"):
+ return _compute_aspect_ratios_custom_dataset(dataset, indices)
+
+ if isinstance(dataset, torchvision.datasets.CocoDetection):
+ return _compute_aspect_ratios_coco_dataset(dataset, indices)
+
+ if isinstance(dataset, torchvision.datasets.VOCDetection):
+ return _compute_aspect_ratios_voc_dataset(dataset, indices)
+
+ if isinstance(dataset, torch.utils.data.Subset):
+ return _compute_aspect_ratios_subset_dataset(dataset, indices)
+
+ # slow path
+ return _compute_aspect_ratios_slow(dataset, indices)
+
+
+def _quantize(x, bins):
+ bins = copy.deepcopy(bins)
+ bins = sorted(bins)
+    # bisect_right: return the index where y would be inserted into the sorted bins (to the right of any equal values)
+ quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
+ return quantized
+
+
+def create_aspect_ratio_groups(dataset, k=0):
+    # compute the width/height ratio of every image in the dataset
+    aspect_ratios = compute_aspect_ratios(dataset)
+    # split the [0.5, 2] interval into 2*k+1 bins
+    bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
+
+    # map each image's aspect ratio to the index of its bin
+    groups = _quantize(aspect_ratios, bins)
+    # count the number of elements per group
+ counts = np.unique(groups, return_counts=True)[1]
+ fbins = [0] + bins + [np.inf]
+ print("Using {} as bins for aspect ratio quantization".format(fbins))
+ print("Count of instances per bin: {}".format(counts))
+ return groups
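As a usage sketch (the train scripts later in this diff wire these pieces up the same way), the aspect-ratio groups feed a `GroupedBatchSampler` that replaces the plain batch sampler; `train_dataset` below is assumed to be one of this repo's dataset classes, which provide a `collate_fn`.

```python
# Sketch: batch images of similar aspect ratio together (assumes train_dataset exists).
import torch
from train_utils.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups

train_sampler = torch.utils.data.RandomSampler(train_dataset)
group_ids = create_aspect_ratio_groups(train_dataset, k=3)   # 2*k+1 aspect-ratio bins
train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, batch_size=8)

data_loader = torch.utils.data.DataLoader(train_dataset,
                                          batch_sampler=train_batch_sampler,
                                          num_workers=4,
                                          collate_fn=train_dataset.collate_fn)
```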
diff --git a/pytorch_object_detection/mask_rcnn/train_utils/train_eval_utils.py b/pytorch_object_detection/mask_rcnn/train_utils/train_eval_utils.py
new file mode 100644
index 000000000..29bae2fb2
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/train_utils/train_eval_utils.py
@@ -0,0 +1,109 @@
+import math
+import sys
+import time
+
+import torch
+
+import train_utils.distributed_utils as utils
+from .coco_eval import EvalCOCOMetric
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch,
+ print_freq=50, warmup=False, scaler=None):
+ model.train()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+
+ lr_scheduler = None
+    if epoch == 0 and warmup is True:  # use warmup during the first epoch (epoch=0) to ramp up the learning rate
+ warmup_factor = 1.0 / 1000
+ warmup_iters = min(1000, len(data_loader) - 1)
+
+ lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
+
+ mloss = torch.zeros(1).to(device) # mean losses
+ for i, [images, targets] in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ images = list(image.to(device) for image in images)
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+        # autocast context manager for mixed-precision training; it is a no-op on CPU
+ with torch.cuda.amp.autocast(enabled=scaler is not None):
+ loss_dict = model(images, targets)
+
+ losses = sum(loss for loss in loss_dict.values())
+
+ # reduce losses over all GPUs for logging purpose
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
+ losses_reduced = sum(loss for loss in loss_dict_reduced.values())
+
+ loss_value = losses_reduced.item()
+        # track the mean training loss
+ mloss = (mloss * i + loss_value) / (i + 1) # update mean losses
+
+        if not math.isfinite(loss_value):  # stop training if the loss becomes infinite/NaN
+ print("Loss is {}, stopping training".format(loss_value))
+ print(loss_dict_reduced)
+ sys.exit(1)
+
+ optimizer.zero_grad()
+ if scaler is not None:
+ scaler.scale(losses).backward()
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ losses.backward()
+ optimizer.step()
+
+        if lr_scheduler is not None:  # step the warmup scheduler during the first epoch
+ lr_scheduler.step()
+
+ metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
+ now_lr = optimizer.param_groups[0]["lr"]
+ metric_logger.update(lr=now_lr)
+
+ return mloss, now_lr
+
+
+@torch.no_grad()
+def evaluate(model, data_loader, device):
+ cpu_device = torch.device("cpu")
+ model.eval()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = "Test: "
+
+ det_metric = EvalCOCOMetric(data_loader.dataset.coco, iou_type="bbox", results_file_name="det_results.json")
+ seg_metric = EvalCOCOMetric(data_loader.dataset.coco, iou_type="segm", results_file_name="seg_results.json")
+ for image, targets in metric_logger.log_every(data_loader, 100, header):
+ image = list(img.to(device) for img in image)
+
+        # skip GPU-specific calls when running on CPU
+ if device != torch.device("cpu"):
+ torch.cuda.synchronize(device)
+
+ model_time = time.time()
+ outputs = model(image)
+
+ outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
+ model_time = time.time() - model_time
+
+ det_metric.update(targets, outputs)
+ seg_metric.update(targets, outputs)
+ metric_logger.update(model_time=model_time)
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+
+    # synchronize results across all processes
+ det_metric.synchronize_results()
+ seg_metric.synchronize_results()
+
+ if utils.is_main_process():
+ coco_info = det_metric.evaluate()
+ seg_info = seg_metric.evaluate()
+ else:
+ coco_info = None
+ seg_info = None
+
+ return coco_info, seg_info
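A condensed sketch of the epoch loop that consumes these two helpers (the train scripts later in this diff follow the same pattern); `model`, `optimizer`, `lr_scheduler`, the data loaders and the epoch range are assumed to be set up elsewhere.

```python
# Sketch of the per-epoch loop built on train_one_epoch / evaluate.
from train_utils.train_eval_utils import train_one_epoch, evaluate

for epoch in range(start_epoch, num_epochs):
    # returns the mean loss of the epoch and the current learning rate
    mean_loss, lr = train_one_epoch(model, optimizer, train_data_loader,
                                    device, epoch, print_freq=50,
                                    warmup=True, scaler=None)
    lr_scheduler.step()

    # COCO metrics for boxes and masks on the validation set
    det_info, seg_info = evaluate(model, val_data_loader, device=device)
```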
diff --git a/pytorch_object_detection/mask_rcnn/transforms.py b/pytorch_object_detection/mask_rcnn/transforms.py
new file mode 100644
index 000000000..6b3abe871
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/transforms.py
@@ -0,0 +1,38 @@
+import random
+from torchvision.transforms import functional as F
+
+
+class Compose(object):
+    """Compose several transform functions."""
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, image, target):
+ for t in self.transforms:
+ image, target = t(image, target)
+ return image, target
+
+
+class ToTensor(object):
+    """Convert a PIL image to a Tensor."""
+ def __call__(self, image, target):
+ image = F.to_tensor(image)
+ return image, target
+
+
+class RandomHorizontalFlip(object):
+    """Randomly flip the image and its bounding boxes horizontally."""
+ def __init__(self, prob=0.5):
+ self.prob = prob
+
+ def __call__(self, image, target):
+ if random.random() < self.prob:
+ height, width = image.shape[-2:]
+            image = image.flip(-1)  # flip the image horizontally
+ bbox = target["boxes"]
+ # bbox: xmin, ymin, xmax, ymax
+            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # flip the corresponding bbox coordinates
+ target["boxes"] = bbox
+ if "masks" in target:
+ target["masks"] = target["masks"].flip(-1)
+ return image, target
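For reference, a minimal sketch of how these transforms are combined; the validation.py below only uses `ToTensor`, while a training pipeline would typically add the random flip as well.

```python
# Sketch: typical transform pipelines built from this module.
import transforms

data_transform = {
    "train": transforms.Compose([transforms.ToTensor(),
                                 transforms.RandomHorizontalFlip(0.5)]),
    "val": transforms.Compose([transforms.ToTensor()]),
}

# each transform takes and returns an (image, target) pair:
# image, target = data_transform["train"](image, target)
```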
diff --git a/pytorch_object_detection/mask_rcnn/validation.py b/pytorch_object_detection/mask_rcnn/validation.py
new file mode 100644
index 000000000..a27288121
--- /dev/null
+++ b/pytorch_object_detection/mask_rcnn/validation.py
@@ -0,0 +1,218 @@
+"""
+This script loads trained model weights and computes the COCO metrics on the
+validation/test set, as well as the per-category mAP (IoU=0.5).
+"""
+
+import os
+import json
+
+import torch
+from tqdm import tqdm
+import numpy as np
+
+import transforms
+from backbone import resnet50_fpn_backbone
+from network_files import MaskRCNN
+from my_dataset_coco import CocoDetection
+from my_dataset_voc import VOCInstances
+from train_utils import EvalCOCOMetric
+
+
+def summarize(self, catId=None):
+ """
+ Compute and display summary metrics for evaluation results.
+    Note this function can *only* be applied on the default parameter setting
+ """
+
+ def _summarize(ap=1, iouThr=None, areaRng='all', maxDets=100):
+ p = self.params
+ iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
+ titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
+ typeStr = '(AP)' if ap == 1 else '(AR)'
+ iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
+ if iouThr is None else '{:0.2f}'.format(iouThr)
+
+ aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
+ mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
+
+ if ap == 1:
+ # dimension of precision: [TxRxKxAxM]
+ s = self.eval['precision']
+ # IoU
+ if iouThr is not None:
+ t = np.where(iouThr == p.iouThrs)[0]
+ s = s[t]
+
+ if isinstance(catId, int):
+ s = s[:, :, catId, aind, mind]
+ else:
+ s = s[:, :, :, aind, mind]
+
+ else:
+ # dimension of recall: [TxKxAxM]
+ s = self.eval['recall']
+ if iouThr is not None:
+ t = np.where(iouThr == p.iouThrs)[0]
+ s = s[t]
+
+ if isinstance(catId, int):
+ s = s[:, catId, aind, mind]
+ else:
+ s = s[:, :, aind, mind]
+
+ if len(s[s > -1]) == 0:
+ mean_s = -1
+ else:
+ mean_s = np.mean(s[s > -1])
+
+ print_string = iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s)
+ return mean_s, print_string
+
+    if not self.eval:
+        raise Exception('Please run accumulate() first')
+
+    stats, print_list = [0] * 12, [""] * 12
+ stats[0], print_list[0] = _summarize(1)
+ stats[1], print_list[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
+ stats[2], print_list[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
+ stats[3], print_list[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
+ stats[4], print_list[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
+ stats[5], print_list[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
+ stats[6], print_list[6] = _summarize(0, maxDets=self.params.maxDets[0])
+ stats[7], print_list[7] = _summarize(0, maxDets=self.params.maxDets[1])
+ stats[8], print_list[8] = _summarize(0, maxDets=self.params.maxDets[2])
+ stats[9], print_list[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
+ stats[10], print_list[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
+ stats[11], print_list[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
+
+ print_info = "\n".join(print_list)
+
+ return stats, print_info
+
+
+def save_info(coco_evaluator,
+ category_index: dict,
+ save_name: str = "record_mAP.txt"):
+ iou_type = coco_evaluator.params.iouType
+ print(f"IoU metric: {iou_type}")
+ # calculate COCO info for all classes
+ coco_stats, print_coco = summarize(coco_evaluator)
+
+    # calculate voc info for every class (IoU=0.5)
+ classes = [v for v in category_index.values() if v != "N/A"]
+ voc_map_info_list = []
+ for i in range(len(classes)):
+ stats, _ = summarize(coco_evaluator, catId=i)
+ voc_map_info_list.append(" {:15}: {}".format(classes[i], stats[1]))
+
+ print_voc = "\n".join(voc_map_info_list)
+ print(print_voc)
+
+    # save the validation results to a txt file
+ with open(save_name, "w") as f:
+ record_lines = ["COCO results:",
+ print_coco,
+ "",
+ "mAP(IoU=0.5) for each category:",
+ print_voc]
+ f.write("\n".join(record_lines))
+
+
+def main(parser_data):
+ device = torch.device(parser_data.device if torch.cuda.is_available() else "cpu")
+ print("Using {} device training.".format(device.type))
+
+ data_transform = {
+ "val": transforms.Compose([transforms.ToTensor()])
+ }
+
+ # read class_indict
+ label_json_path = parser_data.label_json_path
+    assert os.path.exists(label_json_path), "json file {} does not exist.".format(label_json_path)
+ with open(label_json_path, 'r') as f:
+ category_index = json.load(f)
+
+ data_root = parser_data.data_path
+
+    # note: a custom collate_fn is needed because each sample contains both image and targets, so the default batching cannot be used
+ batch_size = parser_data.batch_size
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
+ print('Using %g dataloader workers' % nw)
+
+ # load validation data set
+ val_dataset = CocoDetection(data_root, "val", data_transform["val"])
+ # VOCdevkit -> VOC2012 -> ImageSets -> Main -> val.txt
+ # val_dataset = VOCInstances(data_root, year="2012", txt_name="val.txt", transforms=data_transform["val"])
+ val_dataset_loader = torch.utils.data.DataLoader(val_dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=val_dataset.collate_fn)
+
+ # create model
+ backbone = resnet50_fpn_backbone()
+    model = MaskRCNN(backbone, num_classes=parser_data.num_classes + 1)
+
+    # load your own trained model weights
+ weights_path = parser_data.weights_path
+    assert os.path.exists(weights_path), "{} file not found.".format(weights_path)
+ model.load_state_dict(torch.load(weights_path, map_location='cpu')['model'])
+ # print(model)
+
+ model.to(device)
+
+ # evaluate on the val dataset
+ cpu_device = torch.device("cpu")
+
+ det_metric = EvalCOCOMetric(val_dataset.coco, "bbox", "det_results.json")
+ seg_metric = EvalCOCOMetric(val_dataset.coco, "segm", "seg_results.json")
+ model.eval()
+ with torch.no_grad():
+ for image, targets in tqdm(val_dataset_loader, desc="validation..."):
+            # move the images to the specified device
+ image = list(img.to(device) for img in image)
+
+ # inference
+ outputs = model(image)
+
+ outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
+ det_metric.update(targets, outputs)
+ seg_metric.update(targets, outputs)
+
+ det_metric.synchronize_results()
+ seg_metric.synchronize_results()
+ det_metric.evaluate()
+ seg_metric.evaluate()
+
+ save_info(det_metric.coco_evaluator, category_index, "det_record_mAP.txt")
+ save_info(seg_metric.coco_evaluator, category_index, "seg_record_mAP.txt")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description=__doc__)
+
+    # device to use
+ parser.add_argument('--device', default='cuda', help='device')
+
+    # number of object classes (excluding background)
+ parser.add_argument('--num-classes', type=int, default=90, help='number of classes')
+
+    # dataset root directory
+ parser.add_argument('--data-path', default='/data/coco2017', help='dataset root')
+
+    # trained weights file
+ parser.add_argument('--weights-path', default='./save_weights/model_25.pth', type=str, help='training weights')
+
+ # batch size(set to 1, don't change)
+ parser.add_argument('--batch-size', default=1, type=int, metavar='N',
+ help='batch size when validation.')
+    # mapping between category indices and category names
+ parser.add_argument('--label-json-path', type=str, default="coco91_indices.json")
+
+ args = parser.parse_args()
+
+ main(args)
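Beyond the command line, `main` can also be driven programmatically. The sketch below builds an argument namespace by hand and assumes it is run from the mask_rcnn directory; all paths are placeholders and must point at your own data and weights.

```python
# Sketch: invoking validation.py programmatically (placeholder paths).
from types import SimpleNamespace
import validation

args = SimpleNamespace(
    device="cuda",
    num_classes=90,                              # foreground classes, background excluded
    data_path="/data/coco2017",                  # root containing the COCO val images/annotations
    weights_path="./save_weights/model_25.pth",  # your trained checkpoint
    batch_size=1,                                # the script expects batch size 1
    label_json_path="coco91_indices.json",
)
validation.main(args)
```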
diff --git a/pytorch_object_detection/retinaNet/README.md b/pytorch_object_detection/retinaNet/README.md
index ffd48ecfb..bab2a22f8 100644
--- a/pytorch_object_detection/retinaNet/README.md
+++ b/pytorch_object_detection/retinaNet/README.md
@@ -6,10 +6,10 @@
## Environment setup:
* Python3.6/3.7/3.8
* Pytorch1.7.1 (note: must be 1.6.0 or later, because the official mixed-precision training is only supported from 1.6.0 on)
-* pycocotools(Linux: ```pip install pycocotools```; Windows: ```pip install pycocotools-windows``` (no extra VS installation required))
+* pycocotools(Linux: `pip install pycocotools`; Windows: `pip install pycocotools-windows` (no extra VS installation required))
* Ubuntu or CentOS (Windows is not recommended)
* Training on a GPU is strongly recommended
-* See ```requirements.txt``` for the detailed environment configuration
+* See `requirements.txt` for the detailed environment configuration
## File structure:
```
@@ -26,8 +26,8 @@
## Pre-trained weights download link (place the file in the backbone folder after downloading):
* ResNet50+FPN backbone: https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth
-* Note: remember to rename the downloaded pre-trained weights; for example, train.py reads the ```retinanet_resnet50_fpn_coco.pth``` file,
-  not ```retinanet_resnet50_fpn_coco-eeacb38b.pth```
+* Note: remember to rename the downloaded pre-trained weights; for example, train.py reads the `retinanet_resnet50_fpn_coco.pth` file,
+  not `retinanet_resnet50_fpn_coco-eeacb38b.pth`
## Dataset: this example uses the PASCAL VOC2012 dataset
@@ -54,14 +54,15 @@
* Make sure the dataset is prepared in advance
* Make sure the corresponding pre-trained weights are downloaded in advance
* For single-GPU training, use the train.py script directly
-* For multi-GPU training, use the ```python -m torch.distributed.launch --nproc_per_node=8 --use_env train_multi_GPU.py``` command, where the ```nproc_per_node``` parameter is the number of GPUs to use
-* To specify which GPU devices to use, prepend ```CUDA_VISIBLE_DEVICES=0,3``` to the command (for example, to use only the 1st and 4th GPU devices)
-* ```CUDA_VISIBLE_DEVICES=0,3 python -m torch.distributed.launch --nproc_per_node=2 --use_env train_multi_GPU.py```
+* For multi-GPU training, use the `python -m torch.distributed.launch --nproc_per_node=8 --use_env train_multi_GPU.py` command, where the `nproc_per_node` parameter is the number of GPUs to use
+* To specify which GPU devices to use, prepend `CUDA_VISIBLE_DEVICES=0,3` to the command (for example, to use only the 1st and 4th GPU devices)
+* `CUDA_VISIBLE_DEVICES=0,3 python -m torch.distributed.launch --nproc_per_node=2 --use_env train_multi_GPU.py`
## Notes
-* When using the training scripts, make sure to set '--data-path' (VOC_root) to the **root directory** that contains your 'VOCdevkit' folder
+* When using the training scripts, make sure to set `--data-path` (VOC_root) to the **root directory** that contains your `VOCdevkit` folder
* Since Faster RCNN with an FPN is very memory-hungry, if your GPU memory is insufficient (i.e. batch_size is smaller than 8) it is recommended to use the default norm_layer in the create_model function,
  i.e. do not pass the norm_layer argument, so that FrozenBatchNorm2d (a bn layer whose parameters are never updated) is used by default; in practice this also works well.
-* When using the prediction script, set 'train_weights' to the path of your own generated weights.
-* When using the validation file, make sure your validation or test set contains objects of every category, and only modify '--num-classes', '--data-path' and '--weights'; try not to change any other code
+* The `results.txt` saved during training contains the COCO metrics on the validation set for each epoch: the first 12 values are the COCO metrics, and the last two values are the mean training loss and the learning rate
+* When using the prediction script, set `weights_path` to the path of your own generated weights.
+* When using the validation file, make sure your validation or test set contains objects of every category, and only modify `--num-classes`, `--data-path` and `--weights-path`; try not to change any other code
diff --git a/pytorch_object_detection/retinaNet/backbone/feature_pyramid_network.py b/pytorch_object_detection/retinaNet/backbone/feature_pyramid_network.py
index b9f4ea50b..505fbae3b 100644
--- a/pytorch_object_detection/retinaNet/backbone/feature_pyramid_network.py
+++ b/pytorch_object_detection/retinaNet/backbone/feature_pyramid_network.py
@@ -8,6 +8,111 @@
from torch.jit.annotations import Tuple, List, Dict
+class IntermediateLayerGetter(nn.ModuleDict):
+ """
+ Module wrapper that returns intermediate layers from a model
+ It has a strong assumption that the modules have been registered
+ into the model in the same order as they are used.
+ This means that one should **not** reuse the same nn.Module
+ twice in the forward if you want this to work.
+ Additionally, it is only able to query submodules that are directly
+ assigned to the model. So if `model` is passed, `model.feature1` can
+ be returned, but not `model.feature1.layer2`.
+ Arguments:
+ model (nn.Module): model on which we will extract the features
+ return_layers (Dict[name, new_name]): a dict containing the names
+ of the modules for which the activations will be returned as
+ the key of the dict, and the value of the dict is the name
+ of the returned activation (which the user can specify).
+ """
+ __annotations__ = {
+ "return_layers": Dict[str, str],
+ }
+
+ def __init__(self, model, return_layers):
+ if not set(return_layers).issubset([name for name, _ in model.named_children()]):
+ raise ValueError("return_layers are not present in model")
+
+ orig_return_layers = return_layers
+ return_layers = {str(k): str(v) for k, v in return_layers.items()}
+ layers = OrderedDict()
+
+        # iterate over the model's children in order and store them in an ordered dict
+        # keep only layer4 and everything before it; drop the unused layers after it
+ for name, module in model.named_children():
+ layers[name] = module
+ if name in return_layers:
+ del return_layers[name]
+ if not return_layers:
+ break
+
+ super().__init__(layers)
+ self.return_layers = orig_return_layers
+
+ def forward(self, x):
+ out = OrderedDict()
+        # run the forward pass through every child module in order,
+        # collecting the outputs of layer1, layer2, layer3 and layer4
+ for name, module in self.items():
+ x = module(x)
+ if name in self.return_layers:
+ out_name = self.return_layers[name]
+ out[out_name] = x
+ return out
+
+
+class BackboneWithFPN(nn.Module):
+ """
+ Adds a FPN on top of a model.
+ Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
+ extract a submodel that returns the feature maps specified in return_layers.
+    The same limitations of IntermediateLayerGetter apply here.
+ Arguments:
+ backbone (nn.Module)
+ return_layers (Dict[name, new_name]): a dict containing the names
+ of the modules for which the activations will be returned as
+ the key of the dict, and the value of the dict is the name
+ of the returned activation (which the user can specify).
+ in_channels_list (List[int]): number of channels for each feature map
+ that is returned, in the order they are present in the OrderedDict
+ out_channels (int): number of channels in the FPN.
+ extra_blocks: ExtraFPNBlock
+ Attributes:
+ out_channels (int): the number of channels in the FPN
+ """
+
+ def __init__(self,
+ backbone: nn.Module,
+ return_layers=None,
+ in_channels_list=None,
+ out_channels=256,
+ extra_blocks=None,
+ re_getter=True):
+ super().__init__()
+
+ if extra_blocks is None:
+ extra_blocks = LastLevelMaxPool()
+
+ if re_getter:
+ assert return_layers is not None
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+ else:
+ self.body = backbone
+
+ self.fpn = FeaturePyramidNetwork(
+ in_channels_list=in_channels_list,
+ out_channels=out_channels,
+ extra_blocks=extra_blocks,
+ )
+
+ self.out_channels = out_channels
+
+ def forward(self, x):
+ x = self.body(x)
+ x = self.fpn(x)
+ return x
+
+
class ExtraFPNBlock(nn.Module):
"""
Base class for the extra block in the FPN.
@@ -35,8 +140,7 @@ class LastLevelMaxPool(torch.nn.Module):
Applies a max_pool2d on top of the last feature map
"""
- def forward(self, x, y, names):
- # type: (List[Tensor], List[Tensor], List[str]) -> Tuple[List[Tensor], List[str]]
+ def forward(self, x: List[Tensor], y: List[Tensor], names: List[str]) -> Tuple[List[Tensor], List[str]]:
names.append("pool")
x.append(F.max_pool2d(x[-1], 1, 2, 0))
return x, names
@@ -47,7 +151,7 @@ class LastLevelP6P7(ExtraFPNBlock):
This module is used in RetinaNet to generate extra layers, P6 and P7.
"""
def __init__(self, in_channels: int, out_channels: int):
- super(LastLevelP6P7, self).__init__()
+ super().__init__()
self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
for module in [self.p6, self.p7]:
@@ -87,7 +191,7 @@ class FeaturePyramidNetwork(nn.Module):
"""
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
- super(FeaturePyramidNetwork, self).__init__()
+ super().__init__()
# 用来调整resnet特征矩阵(layer1,2,3,4)的channel(kernel_size=1)
self.inner_blocks = nn.ModuleList()
# 对调整后的特征矩阵使用3x3的卷积核来得到对应的预测特征矩阵
@@ -108,8 +212,7 @@ def __init__(self, in_channels_list, out_channels, extra_blocks=None):
self.extra_blocks = extra_blocks
- def get_result_from_inner_blocks(self, x, idx):
- # type: (Tensor, int) -> Tensor
+ def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.inner_blocks[idx](x),
but torchscript doesn't support this yet
@@ -125,8 +228,7 @@ def get_result_from_inner_blocks(self, x, idx):
i += 1
return out
- def get_result_from_layer_blocks(self, x, idx):
- # type: (Tensor, int) -> Tensor
+ def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.layer_blocks[idx](x),
but torchscript doesn't support this yet
@@ -142,8 +244,7 @@ def get_result_from_layer_blocks(self, x, idx):
i += 1
return out
- def forward(self, x):
- # type: (Dict[str, Tensor]) -> Dict[str, Tensor]
+ def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Computes the FPN for a set of feature maps.
Arguments:
diff --git a/pytorch_object_detection/retinaNet/backbone/resnet50_fpn_model.py b/pytorch_object_detection/retinaNet/backbone/resnet50_fpn_model.py
index 553f8aac8..451bf5649 100644
--- a/pytorch_object_detection/retinaNet/backbone/resnet50_fpn_model.py
+++ b/pytorch_object_detection/retinaNet/backbone/resnet50_fpn_model.py
@@ -1,19 +1,17 @@
import os
-from collections import OrderedDict
import torch.nn as nn
import torch
-from torch.jit.annotations import List, Dict
from torchvision.ops.misc import FrozenBatchNorm2d
-from .feature_pyramid_network import LastLevelMaxPool, FeaturePyramidNetwork
+from .feature_pyramid_network import LastLevelMaxPool, BackboneWithFPN
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm_layer=None):
- super(Bottleneck, self).__init__()
+ super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
@@ -56,7 +54,7 @@ def forward(self, x):
class ResNet(nn.Module):
def __init__(self, block, blocks_num, num_classes=1000, include_top=True, norm_layer=None):
- super(ResNet, self).__init__()
+ super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
@@ -136,100 +134,6 @@ def overwrite_eps(model, eps):
module.eps = eps
-class IntermediateLayerGetter(nn.ModuleDict):
- """
- Module wrapper that returns intermediate layers from a model
- It has a strong assumption that the modules have been registered
- into the model in the same order as they are used.
- This means that one should **not** reuse the same nn.Module
- twice in the forward if you want this to work.
- Additionally, it is only able to query submodules that are directly
- assigned to the model. So if `model` is passed, `model.feature1` can
- be returned, but not `model.feature1.layer2`.
- Arguments:
- model (nn.Module): model on which we will extract the features
- return_layers (Dict[name, new_name]): a dict containing the names
- of the modules for which the activations will be returned as
- the key of the dict, and the value of the dict is the name
- of the returned activation (which the user can specify).
- """
- __annotations__ = {
- "return_layers": Dict[str, str],
- }
-
- def __init__(self, model, return_layers):
- if not set(return_layers).issubset([name for name, _ in model.named_children()]):
- raise ValueError("return_layers are not present in model")
-
- orig_return_layers = return_layers
- return_layers = {str(k): str(v) for k, v in return_layers.items()}
- layers = OrderedDict()
-
- # 遍历模型子模块按顺序存入有序字典
- # 只保存layer4及其之前的结构,舍去之后不用的结构
- for name, module in model.named_children():
- layers[name] = module
- if name in return_layers:
- del return_layers[name]
- if not return_layers:
- break
-
- super(IntermediateLayerGetter, self).__init__(layers)
- self.return_layers = orig_return_layers
-
- def forward(self, x):
- out = OrderedDict()
- # 依次遍历模型的所有子模块,并进行正向传播,
- # 收集layer1, layer2, layer3, layer4的输出
- for name, module in self.items():
- x = module(x)
- if name in self.return_layers:
- out_name = self.return_layers[name]
- out[out_name] = x
- return out
-
-
-class BackboneWithFPN(nn.Module):
- """
- Adds a FPN on top of a model.
- Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
- extract a submodel that returns the feature maps specified in return_layers.
- The same limitations of IntermediatLayerGetter apply here.
- Arguments:
- backbone (nn.Module)
- return_layers (Dict[name, new_name]): a dict containing the names
- of the modules for which the activations will be returned as
- the key of the dict, and the value of the dict is the name
- of the returned activation (which the user can specify).
- in_channels_list (List[int]): number of channels for each feature map
- that is returned, in the order they are present in the OrderedDict
- out_channels (int): number of channels in the FPN.
- extra_blocks: ExtraFPNBlock
- Attributes:
- out_channels (int): the number of channels in the FPN
- """
-
- def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
- super(BackboneWithFPN, self).__init__()
-
- if extra_blocks is None:
- extra_blocks = LastLevelMaxPool()
-
- self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
- self.fpn = FeaturePyramidNetwork(
- in_channels_list=in_channels_list,
- out_channels=out_channels,
- extra_blocks=extra_blocks,
- )
-
- self.out_channels = out_channels
-
- def forward(self, x):
- x = self.body(x)
- x = self.fpn(x)
- return x
-
-
def resnet50_fpn_backbone(pretrain_path="",
norm_layer=FrozenBatchNorm2d, # FrozenBatchNorm2d的功能与BatchNorm2d类似,但参数无法更新
trainable_layers=3,
diff --git a/pytorch_object_detection/retinaNet/draw_box_utils.py b/pytorch_object_detection/retinaNet/draw_box_utils.py
index 1a2926583..835d7f7c1 100644
--- a/pytorch_object_detection/retinaNet/draw_box_utils.py
+++ b/pytorch_object_detection/retinaNet/draw_box_utils.py
@@ -1,6 +1,7 @@
-import collections
+from PIL.Image import Image, fromarray
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
+from PIL import ImageColor
import numpy as np
STANDARD_COLORS = [
@@ -30,66 +31,123 @@
]
-def filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map):
- for i in range(boxes.shape[0]):
- if scores[i] > thresh:
- box = tuple(boxes[i].tolist()) # numpy -> list -> tuple
- if classes[i] in category_index.keys():
- class_name = category_index[classes[i]]
- else:
- class_name = 'N/A'
- display_str = str(class_name)
- display_str = '{}: {}%'.format(display_str, int(100 * scores[i]))
- box_to_display_str_map[box].append(display_str)
- box_to_color_map[box] = STANDARD_COLORS[
- classes[i] % len(STANDARD_COLORS)]
- else:
- break # 网络输出概率已经排序过,当遇到一个不满足后面的肯定不满足
-
-
-def draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color):
+def draw_text(draw,
+ box: list,
+ cls: int,
+ score: float,
+ category_index: dict,
+ color: str,
+ font: str = 'arial.ttf',
+ font_size: int = 24):
+ """
+    Draw the bounding box and class information on the image.
+ """
try:
- font = ImageFont.truetype('arial.ttf', 24)
+ font = ImageFont.truetype(font, font_size)
except IOError:
font = ImageFont.load_default()
+ left, top, right, bottom = box
# If the total height of the display strings added to the top of the bounding
# box exceeds the top of the image, stack the strings below the bounding box
# instead of above.
- display_str_heights = [font.getsize(ds)[1] for ds in box_to_display_str_map[box]]
+ display_str = f"{category_index[str(cls)]}: {int(100 * score)}%"
+ display_str_heights = [font.getsize(ds)[1] for ds in display_str]
# Each display_str has a top and bottom margin of 0.05x.
- total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
+ display_str_height = (1 + 2 * 0.05) * max(display_str_heights)
- if top > total_display_str_height:
+ if top > display_str_height:
+ text_top = top - display_str_height
text_bottom = top
else:
- text_bottom = bottom + total_display_str_height
- # Reverse list and print from bottom to top.
- for display_str in box_to_display_str_map[box][::-1]:
- text_width, text_height = font.getsize(display_str)
- margin = np.ceil(0.05 * text_height)
- draw.rectangle([(left, text_bottom - text_height - 2 * margin),
- (left + text_width, text_bottom)], fill=color)
- draw.text((left + margin, text_bottom - text_height - margin),
- display_str,
+ text_top = bottom
+ text_bottom = bottom + display_str_height
+
+ for ds in display_str:
+ text_width, text_height = font.getsize(ds)
+ margin = np.ceil(0.05 * text_width)
+ draw.rectangle([(left, text_top),
+ (left + text_width + 2 * margin, text_bottom)], fill=color)
+ draw.text((left + margin, text_top),
+ ds,
fill='black',
font=font)
- text_bottom -= text_height - 2 * margin
+ left += text_width
+
+
+def draw_masks(image, masks, colors, thresh: float = 0.7, alpha: float = 0.5):
+ np_image = np.array(image)
+ masks = np.where(masks > thresh, True, False)
+
+ # colors = np.array(colors)
+ img_to_draw = np.copy(np_image)
+ # TODO: There might be a way to vectorize this
+ for mask, color in zip(masks, colors):
+ img_to_draw[mask] = color
+
+ out = np_image * (1 - alpha) + img_to_draw * alpha
+ return fromarray(out.astype(np.uint8))
+
+
+def draw_objs(image: Image,
+ boxes: np.ndarray = None,
+ classes: np.ndarray = None,
+ scores: np.ndarray = None,
+ masks: np.ndarray = None,
+ category_index: dict = None,
+ box_thresh: float = 0.1,
+ mask_thresh: float = 0.5,
+ line_thickness: int = 8,
+ font: str = 'arial.ttf',
+ font_size: int = 24,
+ draw_boxes_on_image: bool = True,
+ draw_masks_on_image: bool = False):
+ """
+    Draw bounding boxes, class labels and (optionally) masks on the image
+    Args:
+        image: the image to draw on
+        boxes: bounding box information
+        classes: class index information
+        scores: prediction score information
+        masks: mask information
+        category_index: dict mapping class indices to class names
+        box_thresh: score threshold used to filter boxes
+        mask_thresh: threshold used to binarize masks
+        line_thickness: bounding box line width
+        font: font type
+        font_size: font size
+        draw_boxes_on_image: whether to draw the boxes
+        draw_masks_on_image: whether to draw the masks
+
+ Returns:
+
+ """
+
+    # filter out low-score objects
+ idxs = np.greater(scores, box_thresh)
+ boxes = boxes[idxs]
+ classes = classes[idxs]
+ scores = scores[idxs]
+ if masks is not None:
+ masks = masks[idxs]
+ if len(boxes) == 0:
+ return image
+ colors = [ImageColor.getrgb(STANDARD_COLORS[cls % len(STANDARD_COLORS)]) for cls in classes]
-def draw_box(image, boxes, classes, scores, category_index, thresh=0.5, line_thickness=8):
- box_to_display_str_map = collections.defaultdict(list)
- box_to_color_map = collections.defaultdict(str)
+ if draw_boxes_on_image:
+ # Draw all boxes onto image.
+ draw = ImageDraw.Draw(image)
+ for box, cls, score, color in zip(boxes, classes, scores, colors):
+ left, top, right, bottom = box
+            # draw the bounding box
+ draw.line([(left, top), (left, bottom), (right, bottom),
+ (right, top), (left, top)], width=line_thickness, fill=color)
+            # draw the class and score information
+ draw_text(draw, box.tolist(), int(cls), float(score), category_index, color, font, font_size)
- filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map)
+ if draw_masks_on_image and (masks is not None):
+        # Draw all masks onto the image.
+ image = draw_masks(image, masks, colors, mask_thresh)
- # Draw all boxes onto image.
- draw = ImageDraw.Draw(image)
- im_width, im_height = image.size
- for box, color in box_to_color_map.items():
- xmin, ymin, xmax, ymax = box
- (left, right, top, bottom) = (xmin * 1, xmax * 1,
- ymin * 1, ymax * 1)
- draw.line([(left, top), (left, bottom), (right, bottom),
- (right, top), (left, top)], width=line_thickness, fill=color)
- draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color)
+ return image
diff --git a/pytorch_object_detection/retinaNet/my_dataset.py b/pytorch_object_detection/retinaNet/my_dataset.py
index 5a8a4e93a..3dc863bc0 100644
--- a/pytorch_object_detection/retinaNet/my_dataset.py
+++ b/pytorch_object_detection/retinaNet/my_dataset.py
@@ -11,7 +11,11 @@ class VOCDataSet(Dataset):
def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
- self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
+        # tolerate both layouts: voc_root may or may not already include "VOCdevkit"
+ if "VOCdevkit" in voc_root:
+ self.root = os.path.join(voc_root, f"VOC{year}")
+ else:
+ self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")
@@ -31,9 +35,8 @@ def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "trai
# read class_indict
json_file = './pascal_voc_classes.json'
assert os.path.exists(json_file), "{} file not exist.".format(json_file)
- json_file = open(json_file, 'r')
- self.class_dict = json.load(json_file)
- json_file.close()
+ with open(json_file, 'r') as f:
+ self.class_dict = json.load(f)
self.transforms = transforms
@@ -181,7 +184,7 @@ def collate_fn(batch):
return tuple(zip(*batch))
# import transforms
-# from draw_box_utils import draw_box
+# from draw_box_utils import draw_objs
# from PIL import Image
# import json
# import matplotlib.pyplot as plt
@@ -193,7 +196,7 @@ def collate_fn(batch):
# try:
# json_file = open('./pascal_voc_classes.json', 'r')
# class_dict = json.load(json_file)
-# category_index = {v: k for k, v in class_dict.items()}
+# category_index = {str(v): str(k) for k, v in class_dict.items()}
# except Exception as e:
# print(e)
# exit(-1)
@@ -210,12 +213,14 @@ def collate_fn(batch):
# for index in random.sample(range(0, len(train_data_set)), k=5):
# img, target = train_data_set[index]
# img = ts.ToPILImage()(img)
-# draw_box(img,
-# target["boxes"].numpy(),
-# target["labels"].numpy(),
-# [1 for i in range(len(target["labels"].numpy()))],
-# category_index,
-# thresh=0.5,
-# line_thickness=5)
-# plt.imshow(img)
+# plot_img = draw_objs(img,
+# target["boxes"].numpy(),
+# target["labels"].numpy(),
+# np.ones(target["labels"].shape[0]),
+# category_index=category_index,
+# box_thresh=0.5,
+# line_thickness=3,
+# font='arial.ttf',
+# font_size=20)
+# plt.imshow(plot_img)
# plt.show()
diff --git a/pytorch_object_detection/retinaNet/network_files/boxes.py b/pytorch_object_detection/retinaNet/network_files/boxes.py
index f720df1f8..8eeca4573 100644
--- a/pytorch_object_detection/retinaNet/network_files/boxes.py
+++ b/pytorch_object_detection/retinaNet/network_files/boxes.py
@@ -23,7 +23,7 @@ def nms(boxes, scores, iou_threshold):
scores for each one of the boxes
iou_threshold : float
discards all overlapping
- boxes with IoU < iou_threshold
+ boxes with IoU > iou_threshold
Returns
-------
diff --git a/pytorch_object_detection/retinaNet/predict.py b/pytorch_object_detection/retinaNet/predict.py
index 47ed83008..954fd336e 100644
--- a/pytorch_object_detection/retinaNet/predict.py
+++ b/pytorch_object_detection/retinaNet/predict.py
@@ -9,7 +9,7 @@
from torchvision import transforms
from network_files import RetinaNet
from backbone import resnet50_fpn_backbone, LastLevelP6P7
-from draw_box_utils import draw_box
+from draw_box_utils import draw_objs
def create_model(num_classes):
@@ -38,18 +38,20 @@ def main():
model = create_model(num_classes=20)
# load train weights
- train_weights = "./save_weights/model.pth"
- assert os.path.exists(train_weights), "{} file dose not exist.".format(train_weights)
- model.load_state_dict(torch.load(train_weights, map_location=device)["model"])
+ weights_path = "./save_weights/model.pth"
+    assert os.path.exists(weights_path), "{} file does not exist.".format(weights_path)
+ weights_dict = torch.load(weights_path, map_location='cpu')
+ weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
+ model.load_state_dict(weights_dict)
model.to(device)
# read class_indict
label_json_path = './pascal_voc_classes.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
- json_file = open(label_json_path, 'r')
- class_dict = json.load(json_file)
- json_file.close()
- category_index = {v: k for k, v in class_dict.items()}
+ with open(label_json_path, 'r') as f:
+ class_dict = json.load(f)
+
+ category_index = {str(v): str(k) for k, v in class_dict.items()}
# load image
original_img = Image.open("./test.jpg")
@@ -79,17 +81,19 @@ def main():
if len(predict_boxes) == 0:
print("没有检测到任何目标!")
- draw_box(original_img,
- predict_boxes,
- predict_classes,
- predict_scores,
- category_index,
- thresh=0.4,
- line_thickness=3)
- plt.imshow(original_img)
+ plot_img = draw_objs(original_img,
+ predict_boxes,
+ predict_classes,
+ predict_scores,
+ category_index=category_index,
+ box_thresh=0.5,
+ line_thickness=3,
+ font='arial.ttf',
+ font_size=20)
+ plt.imshow(plot_img)
plt.show()
# 保存预测的图片结果
- original_img.save("test_result.jpg")
+ plot_img.save("test_result.jpg")
if __name__ == '__main__':
diff --git a/pytorch_object_detection/retinaNet/requirements.txt b/pytorch_object_detection/retinaNet/requirements.txt
index b5854c8d5..846ad37de 100644
--- a/pytorch_object_detection/retinaNet/requirements.txt
+++ b/pytorch_object_detection/retinaNet/requirements.txt
@@ -1,6 +1,6 @@
lxml
matplotlib
-nump
+numpy
tqdm
torch==1.7.1
torchvision==0.8.2
diff --git a/pytorch_object_detection/retinaNet/train.py b/pytorch_object_detection/retinaNet/train.py
index bded930ff..314bad117 100644
--- a/pytorch_object_detection/retinaNet/train.py
+++ b/pytorch_object_detection/retinaNet/train.py
@@ -35,8 +35,8 @@ def create_model(num_classes):
return model
-def main(parser_data):
- device = torch.device(parser_data.device if torch.cuda.is_available() else "cpu")
+def main(args):
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
print("Using {} device training.".format(device.type))
results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
@@ -47,7 +47,7 @@ def main(parser_data):
"val": transforms.Compose([transforms.ToTensor()])
}
- VOC_root = parser_data.data_path
+ VOC_root = args.data_path
# check voc root
if os.path.exists(os.path.join(VOC_root, "VOCdevkit")) is False:
raise FileNotFoundError("VOCdevkit dose not in path:'{}'.".format(VOC_root))
@@ -67,7 +67,7 @@ def main(parser_data):
train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
# 注意这里的collate_fn是自定义的,因为读取的数据包括image和targets,不能直接使用默认的方法合成batch
- batch_size = parser_data.batch_size
+ batch_size = args.batch_size
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
print('Using %g dataloader workers' % nw)
if train_sampler:
@@ -97,7 +97,7 @@ def main(parser_data):
# create model
# 注意:不包含背景
- model = create_model(num_classes=parser_data.num_classes)
+ model = create_model(num_classes=args.num_classes)
# print(model)
model.to(device)
@@ -115,21 +115,21 @@ def main(parser_data):
gamma=0.33)
# 如果指定了上次训练保存的权重文件地址,则接着上次结果接着训练
- if parser_data.resume != "":
- checkpoint = torch.load(parser_data.resume, map_location='cpu')
+ if args.resume != "":
+ checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
- parser_data.start_epoch = checkpoint['epoch'] + 1
+ args.start_epoch = checkpoint['epoch'] + 1
if args.amp and "scaler" in checkpoint:
scaler.load_state_dict(checkpoint["scaler"])
- print("the training process from epoch{}...".format(parser_data.start_epoch))
+ print("the training process from epoch{}...".format(args.start_epoch))
train_loss = []
learning_rate = []
val_map = []
- for epoch in range(parser_data.start_epoch, parser_data.epochs):
+ for epoch in range(args.start_epoch, args.epochs):
# train for one epoch, printing every 10 iterations
mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
device, epoch, print_freq=50,
@@ -146,7 +146,7 @@ def main(parser_data):
# write into txt
with open(results_file, "a") as f:
# 写入的数据包括coco指标还有loss和learning rate
- result_info = [str(round(i, 4)) for i in coco_info + [mean_loss.item()]] + [str(round(lr, 6))]
+ result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
f.write(txt + "\n")
diff --git a/pytorch_object_detection/retinaNet/train_multi_GPU.py b/pytorch_object_detection/retinaNet/train_multi_GPU.py
index 047a64eb2..35ed8fc77 100644
--- a/pytorch_object_detection/retinaNet/train_multi_GPU.py
+++ b/pytorch_object_detection/retinaNet/train_multi_GPU.py
@@ -156,7 +156,7 @@ def main(args):
# write into txt
with open(results_file, "a") as f:
# 写入的数据包括coco指标还有loss和learning rate
- result_info = [str(round(i, 4)) for i in coco_info + [mean_loss.item()]] + [str(round(lr, 6))]
+ result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
f.write(txt + "\n")
diff --git a/pytorch_object_detection/retinaNet/validation.py b/pytorch_object_detection/retinaNet/validation.py
index ffd320443..cc2826763 100644
--- a/pytorch_object_detection/retinaNet/validation.py
+++ b/pytorch_object_detection/retinaNet/validation.py
@@ -100,9 +100,9 @@ def main(parser_data):
# read class_indict
label_json_path = './pascal_voc_classes.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
- json_file = open(label_json_path, 'r')
- class_dict = json.load(json_file)
- json_file.close()
+ with open(label_json_path, 'r') as f:
+ class_dict = json.load(f)
+
category_index = {v: k for k, v in class_dict.items()}
VOC_root = parser_data.data_path
@@ -132,9 +132,11 @@ def main(parser_data):
model = RetinaNet(backbone, parser_data.num_classes)
# 载入你自己训练好的模型权重
- weights_path = parser_data.weights
+ weights_path = parser_data.weights_path
assert os.path.exists(weights_path), "not found {} file.".format(weights_path)
- model.load_state_dict(torch.load(weights_path, map_location=device)['model'])
+ weights_dict = torch.load(weights_path, map_location='cpu')
+ weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
+ model.load_state_dict(weights_dict)
# print(model)
model.to(device)
@@ -203,7 +205,7 @@ def main(parser_data):
parser.add_argument('--data-path', default='/data', help='dataset root')
# 训练好的权重文件
- parser.add_argument('--weights', default='./save_weights/model.pth', type=str, help='training weights')
+ parser.add_argument('--weights-path', default='./save_weights/model.pth', type=str, help='training weights')
# batch size
parser.add_argument('--batch_size', default=1, type=int, metavar='N',
diff --git a/pytorch_object_detection/ssd/README.md b/pytorch_object_detection/ssd/README.md
index be7f2f435..ab51771ab 100644
--- a/pytorch_object_detection/ssd/README.md
+++ b/pytorch_object_detection/ssd/README.md
@@ -38,6 +38,7 @@
* Make sure the corresponding pre-trained weights are downloaded in advance
* For single-GPU or CPU training, use the train_ssd300.py script directly
* For multi-GPU training, use the "python -m torch.distributed.launch --nproc_per_node=8 --use_env train_multi_GPU.py" command, where the nproc_per_node parameter is the number of GPUs to use
+* The `results.txt` saved during training contains the COCO metrics on the validation set for each epoch: the first 12 values are the COCO metrics, and the last two values are the mean training loss and the learning rate
## If you are not familiar with how the SSD algorithm works, you can refer to my bilibili videos
* https://www.bilibili.com/video/BV1fT4y1L7Gi
diff --git a/pytorch_object_detection/ssd/draw_box_utils.py b/pytorch_object_detection/ssd/draw_box_utils.py
index 1a2926583..835d7f7c1 100644
--- a/pytorch_object_detection/ssd/draw_box_utils.py
+++ b/pytorch_object_detection/ssd/draw_box_utils.py
@@ -1,6 +1,7 @@
-import collections
+from PIL.Image import Image, fromarray
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
+from PIL import ImageColor
import numpy as np
STANDARD_COLORS = [
@@ -30,66 +31,123 @@
]
-def filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map):
- for i in range(boxes.shape[0]):
- if scores[i] > thresh:
- box = tuple(boxes[i].tolist()) # numpy -> list -> tuple
- if classes[i] in category_index.keys():
- class_name = category_index[classes[i]]
- else:
- class_name = 'N/A'
- display_str = str(class_name)
- display_str = '{}: {}%'.format(display_str, int(100 * scores[i]))
- box_to_display_str_map[box].append(display_str)
- box_to_color_map[box] = STANDARD_COLORS[
- classes[i] % len(STANDARD_COLORS)]
- else:
- break # 网络输出概率已经排序过,当遇到一个不满足后面的肯定不满足
-
-
-def draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color):
+def draw_text(draw,
+ box: list,
+ cls: int,
+ score: float,
+ category_index: dict,
+ color: str,
+ font: str = 'arial.ttf',
+ font_size: int = 24):
+ """
+    Draw the bounding box and class information on the image.
+ """
try:
- font = ImageFont.truetype('arial.ttf', 24)
+ font = ImageFont.truetype(font, font_size)
except IOError:
font = ImageFont.load_default()
+ left, top, right, bottom = box
# If the total height of the display strings added to the top of the bounding
# box exceeds the top of the image, stack the strings below the bounding box
# instead of above.
- display_str_heights = [font.getsize(ds)[1] for ds in box_to_display_str_map[box]]
+ display_str = f"{category_index[str(cls)]}: {int(100 * score)}%"
+ display_str_heights = [font.getsize(ds)[1] for ds in display_str]
# Each display_str has a top and bottom margin of 0.05x.
- total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
+ display_str_height = (1 + 2 * 0.05) * max(display_str_heights)
- if top > total_display_str_height:
+ if top > display_str_height:
+ text_top = top - display_str_height
text_bottom = top
else:
- text_bottom = bottom + total_display_str_height
- # Reverse list and print from bottom to top.
- for display_str in box_to_display_str_map[box][::-1]:
- text_width, text_height = font.getsize(display_str)
- margin = np.ceil(0.05 * text_height)
- draw.rectangle([(left, text_bottom - text_height - 2 * margin),
- (left + text_width, text_bottom)], fill=color)
- draw.text((left + margin, text_bottom - text_height - margin),
- display_str,
+ text_top = bottom
+ text_bottom = bottom + display_str_height
+
+ for ds in display_str:
+ text_width, text_height = font.getsize(ds)
+ margin = np.ceil(0.05 * text_width)
+ draw.rectangle([(left, text_top),
+ (left + text_width + 2 * margin, text_bottom)], fill=color)
+ draw.text((left + margin, text_top),
+ ds,
fill='black',
font=font)
- text_bottom -= text_height - 2 * margin
+ left += text_width
+
+
+def draw_masks(image, masks, colors, thresh: float = 0.7, alpha: float = 0.5):
+ np_image = np.array(image)
+ masks = np.where(masks > thresh, True, False)
+
+ # colors = np.array(colors)
+ img_to_draw = np.copy(np_image)
+ # TODO: There might be a way to vectorize this
+ for mask, color in zip(masks, colors):
+ img_to_draw[mask] = color
+
+ out = np_image * (1 - alpha) + img_to_draw * alpha
+ return fromarray(out.astype(np.uint8))
+
+
+def draw_objs(image: Image,
+ boxes: np.ndarray = None,
+ classes: np.ndarray = None,
+ scores: np.ndarray = None,
+ masks: np.ndarray = None,
+ category_index: dict = None,
+ box_thresh: float = 0.1,
+ mask_thresh: float = 0.5,
+ line_thickness: int = 8,
+ font: str = 'arial.ttf',
+ font_size: int = 24,
+ draw_boxes_on_image: bool = True,
+ draw_masks_on_image: bool = False):
+ """
+    Draw bounding boxes, class labels and masks onto an image
+    Args:
+        image: image to draw on
+        boxes: bounding box information
+        classes: class information
+        scores: object probability scores
+        masks: mask information
+        category_index: dict mapping class ids to class names
+        box_thresh: probability threshold used for filtering
+        mask_thresh:
+        line_thickness: bounding box line width
+        font: font type
+        font_size: font size
+        draw_boxes_on_image:
+        draw_masks_on_image:
+
+    Returns:
+
+    """
+
+    # filter out low-probability objects
+ idxs = np.greater(scores, box_thresh)
+ boxes = boxes[idxs]
+ classes = classes[idxs]
+ scores = scores[idxs]
+ if masks is not None:
+ masks = masks[idxs]
+ if len(boxes) == 0:
+ return image
+ colors = [ImageColor.getrgb(STANDARD_COLORS[cls % len(STANDARD_COLORS)]) for cls in classes]
-def draw_box(image, boxes, classes, scores, category_index, thresh=0.5, line_thickness=8):
- box_to_display_str_map = collections.defaultdict(list)
- box_to_color_map = collections.defaultdict(str)
+ if draw_boxes_on_image:
+ # Draw all boxes onto image.
+ draw = ImageDraw.Draw(image)
+ for box, cls, score, color in zip(boxes, classes, scores, colors):
+ left, top, right, bottom = box
+            # draw the object bounding box
+ draw.line([(left, top), (left, bottom), (right, bottom),
+ (right, top), (left, top)], width=line_thickness, fill=color)
+            # draw class and probability information
+ draw_text(draw, box.tolist(), int(cls), float(score), category_index, color, font, font_size)
- filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map)
+ if draw_masks_on_image and (masks is not None):
+ # Draw all mask onto image.
+ image = draw_masks(image, masks, colors, mask_thresh)
- # Draw all boxes onto image.
- draw = ImageDraw.Draw(image)
- im_width, im_height = image.size
- for box, color in box_to_color_map.items():
- xmin, ymin, xmax, ymax = box
- (left, right, top, bottom) = (xmin * 1, xmax * 1,
- ymin * 1, ymax * 1)
- draw.line([(left, top), (left, bottom), (right, bottom),
- (right, top), (left, top)], width=line_thickness, fill=color)
- draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color)
+ return image
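
The rewritten `draw_box_utils.py` replaces `draw_box` with `draw_objs`, which filters detections by `box_thresh`, expects `category_index` keyed by *string* class ids, and returns the annotated image. A hedged usage sketch with made-up predictions (in the repo these values come from the SSD / Faster R-CNN model output):

```python
import numpy as np
from PIL import Image
from draw_box_utils import draw_objs

# Hypothetical single detection; replace with real model outputs.
img = Image.open("test.jpg").convert("RGB")
boxes = np.array([[50., 60., 220., 300.]])      # [xmin, ymin, xmax, ymax]
classes = np.array([1])
scores = np.array([0.92])
category_index = {"1": "person"}                # keys must be strings of class ids

plot_img = draw_objs(img, boxes, classes, scores,
                     category_index=category_index,
                     box_thresh=0.5,             # detections below 0.5 are dropped
                     line_thickness=3,
                     font="arial.ttf",           # falls back to the default font if missing
                     font_size=20)
plot_img.save("test_result.jpg")
```
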
diff --git a/pytorch_object_detection/ssd/my_dataset.py b/pytorch_object_detection/ssd/my_dataset.py
index 23bce7430..ebea5635f 100644
--- a/pytorch_object_detection/ssd/my_dataset.py
+++ b/pytorch_object_detection/ssd/my_dataset.py
@@ -11,7 +11,11 @@ class VOCDataSet(Dataset):
def __init__(self, voc_root, year="2012", transforms=None, train_set='train.txt'):
assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
- self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
+        # add some tolerance for how the dataset path is given
+ if "VOCdevkit" in voc_root:
+ self.root = os.path.join(voc_root, f"VOC{year}")
+ else:
+ self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")
@@ -24,9 +28,8 @@ def __init__(self, voc_root, year="2012", transforms=None, train_set='train.txt'
# read class_indict
json_file = "./pascal_voc_classes.json"
assert os.path.exists(json_file), "{} file not exist.".format(json_file)
- json_file = open(json_file, 'r')
- self.class_dict = json.load(json_file)
- json_file.close()
+ with open(json_file, 'r') as f:
+ self.class_dict = json.load(f)
self.transforms = transforms
@@ -198,7 +201,7 @@ def collate_fn(batch):
return images, targets
# import transforms
-# from draw_box_utils import draw_box
+# from draw_box_utils import draw_objs
# from PIL import Image
# import json
# import matplotlib.pyplot as plt
@@ -210,7 +213,7 @@ def collate_fn(batch):
# try:
# json_file = open('./pascal_voc_classes.json', 'r')
# class_dict = json.load(json_file)
-# category_index = {v: k for k, v in class_dict.items()}
+# category_index = {str(v): str(k) for k, v in class_dict.items()}
# except Exception as e:
# print(e)
# exit(-1)
@@ -227,12 +230,14 @@ def collate_fn(batch):
# for index in random.sample(range(0, len(train_data_set)), k=5):
# img, target = train_data_set[index]
# img = ts.ToPILImage()(img)
-# draw_box(img,
-# target["boxes"].numpy(),
-# target["labels"].numpy(),
-# [1 for i in range(len(target["labels"].numpy()))],
-# category_index,
-# thresh=0.5,
-# line_thickness=5)
-# plt.imshow(img)
+# plot_img = draw_objs(img,
+# target["boxes"].numpy(),
+# target["labels"].numpy(),
+# np.ones(target["labels"].shape[0]),
+# category_index=category_index,
+# box_thresh=0.5,
+# line_thickness=3,
+# font='arial.ttf',
+# font_size=20)
+# plt.imshow(plot_img)
# plt.show()
diff --git a/pytorch_object_detection/ssd/predict_test.py b/pytorch_object_detection/ssd/predict_test.py
index dee265c49..ea8e8eeef 100644
--- a/pytorch_object_detection/ssd/predict_test.py
+++ b/pytorch_object_detection/ssd/predict_test.py
@@ -8,7 +8,7 @@
import transforms
from src import SSD300, Backbone
-from draw_box_utils import draw_box
+from draw_box_utils import draw_objs
def create_model(num_classes):
@@ -34,8 +34,10 @@ def main():
model = create_model(num_classes=num_classes)
# load train weights
- train_weights = "./save_weights/ssd300-14.pth"
- model.load_state_dict(torch.load(train_weights, map_location=device)['model'])
+ weights_path = "./save_weights/ssd300-14.pth"
+ weights_dict = torch.load(weights_path, map_location='cpu')
+ weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
+ model.load_state_dict(weights_dict)
model.to(device)
# read class_indict
@@ -44,7 +46,7 @@ def main():
json_file = open(json_path, 'r')
class_dict = json.load(json_file)
json_file.close()
- category_index = {v: k for k, v in class_dict.items()}
+ category_index = {str(v): str(k) for k, v in class_dict.items()}
# load image
original_img = Image.open("./test.jpg")
@@ -77,15 +79,19 @@ def main():
if len(predict_boxes) == 0:
print("没有检测到任何目标!")
- draw_box(original_img,
- predict_boxes,
- predict_classes,
- predict_scores,
- category_index,
- thresh=0.5,
- line_thickness=5)
- plt.imshow(original_img)
+ plot_img = draw_objs(original_img,
+ predict_boxes,
+ predict_classes,
+ predict_scores,
+ category_index=category_index,
+ box_thresh=0.5,
+ line_thickness=3,
+ font='arial.ttf',
+ font_size=20)
+ plt.imshow(plot_img)
plt.show()
+    # save the predicted image result
+ plot_img.save("test_result.jpg")
if __name__ == "__main__":
diff --git a/pytorch_object_detection/ssd/validation.py b/pytorch_object_detection/ssd/validation.py
index aed5e55fc..4cda72ab3 100644
--- a/pytorch_object_detection/ssd/validation.py
+++ b/pytorch_object_detection/ssd/validation.py
@@ -101,9 +101,9 @@ def main(parser_data):
# read class_indict
label_json_path = './pascal_voc_classes.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
- json_file = open(label_json_path, 'r')
- class_dict = json.load(json_file)
- json_file.close()
+ with open(label_json_path, 'r') as f:
+ class_dict = json.load(f)
+
category_index = {v: k for k, v in class_dict.items()}
VOC_root = parser_data.data_path
@@ -133,7 +133,9 @@ def main(parser_data):
# load your own trained model weights
weights_path = parser_data.weights
assert os.path.exists(weights_path), "not found {} file.".format(weights_path)
- model.load_state_dict(torch.load(weights_path, map_location=device)['model'])
+ weights_dict = torch.load(weights_path, map_location='cpu')
+ weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
+ model.load_state_dict(weights_dict)
# print(model)
model.to(device)
diff --git a/pytorch_object_detection/train_coco_dataset/README.md b/pytorch_object_detection/train_coco_dataset/README.md
index db98ce758..eb516bc8b 100644
--- a/pytorch_object_detection/train_coco_dataset/README.md
+++ b/pytorch_object_detection/train_coco_dataset/README.md
@@ -5,7 +5,7 @@
## Environment setup:
* Python3.6/3.7/3.8
-* Pytorch1.7.1
+* Pytorch1.10.0
* pycocotools (Linux: ```pip install pycocotools```; Windows: ```pip install pycocotools-windows``` (no separate Visual Studio install required))
* Ubuntu or Centos (Windows is not recommended)
* Training on a GPU is strongly recommended
@@ -17,28 +17,27 @@
  ├── network_files: Faster R-CNN network (including the Fast R-CNN and RPN modules)
  ├── train_utils: training/validation utilities (including pycocotools)
  ├── my_dataset.py: custom dataset for reading COCO2017
-  ├── train.py: training with VGG16 as the backbone
+  ├── train.py: training with resnet50 as the backbone
  ├── train_multi_GPU.py: for users training with multiple GPUs
  ├── predict.py: simple prediction script that runs inference with trained weights
  ├── validation.py: computes COCO metrics for validation/test data with trained weights and writes record_mAP.txt
-  ├── transforms.py: data preprocessing (random horizontal flip of images and bboxes, PIL image to Tensor)
-  └── compute_receptive_field.py: computes the receptive field of the VGG16 feature extractor (excluding the last maxpool layer, 228)
+  └── transforms.py: data preprocessing (random horizontal flip of images and bboxes, PIL image to Tensor)
```
-## Pre-trained weights download (place them in the backbone folder after downloading):
-* VGG16 https://download.pytorch.org/models/vgg16-397923af.pth
-* Note: remember to rename the downloaded weights; for example, train.py reads the ```vgg16.pth``` file,
-  not ```vgg16-397923af.pth```
+## Pre-trained weights download (place them in the project root after downloading):
+* Resnet50 https://download.pytorch.org/models/resnet50-19c8e357.pth
+* Note: remember to rename the downloaded weights; for example, train.py reads the `resnet50.pth` file,
+  not `resnet50-19c8e357.pth`
## Dataset: this example uses the COCO2017 dataset
* COCO official website: https://cocodataset.org/
* If you are not familiar with the dataset, see my blog post: https://blog.csdn.net/qq_37541097/article/details/113247318
* Taking coco2017 as an example, download the following three files:
-    * ```2017 Train images [118K/18GB]```: all image files used during training
-    * ```2017 Val images [5K/1GB]```: all image files used during validation
-    * ```2017 Train/Val annotations [241MB]```: annotation json files for the training and validation sets
-* Extract everything into a ```coco2017``` folder to get the following file structure:
+    * `2017 Train images [118K/18GB]`: all image files used during training
+    * `2017 Val images [5K/1GB]`: all image files used during validation
+    * `2017 Train/Val annotations [241MB]`: annotation json files for the training and validation sets
+* Extract everything into a `coco2017` folder to get the following file structure:
```
├── coco2017: dataset root directory
    ├── train2017: folder with all training images (118287 images)
@@ -56,35 +55,36 @@
* Make sure the dataset is prepared in advance
* Make sure the corresponding pre-trained model weights are downloaded in advance
* For single-GPU training, use the train.py training script directly
-* For multi-GPU training, use the ```python -m torch.distributed.launch --nproc_per_node=8 --use_env train_multi_GPU.py``` command, where ```nproc_per_node``` is the number of GPUs to use
-* To choose which GPU devices to use, prefix the command with ```CUDA_VISIBLE_DEVICES=0,3``` (for example, to use only the 1st and 4th GPUs)
-* ```CUDA_VISIBLE_DEVICES=0,3 python -m torch.distributed.launch --nproc_per_node=2 --use_env train_multi_GPU.py```
+* For multi-GPU training, use the `torchrun --nproc_per_node=8 train_multi_GPU.py` command, where `nproc_per_node` is the number of GPUs to use
+* To choose which GPU devices to use, prefix the command with `CUDA_VISIBLE_DEVICES=0,3` (for example, to use only the 1st and 4th GPUs)
+* `CUDA_VISIBLE_DEVICES=0,3 torchrun --nproc_per_node=2 train_multi_GPU.py`
## Notes
-* When using the training script, make sure '--data-path' is set to the **root directory** that contains your 'coco2017' folder
-* When using the prediction script, set 'train_weights' to the path of the weights you generated.
-* When using the validation script, make sure your validation or test set contains objects of every class; you only need to modify '--num-classes', '--data-path' and '--weights', and other code should be left unchanged as much as possible
+* When using the training script, make sure `--data-path` is set to the **root directory** that contains your `coco2017` folder
+* The `results.txt` saved during training records the COCO metrics on the validation set for each epoch: the first 12 values are the COCO metrics, and the last two are the mean training loss and the learning rate
+* When using the prediction script, set `weights_path` to the path of the weights you generated.
+* When using the validation script, make sure your validation or test set contains objects of every class; you only need to modify `--num-classes`, `--data-path` and `--weights-path`, and other code should be left unchanged as much as possible
-## Weights trained in this project (Faster R-CNN + VGG16)
-* Link: https://pan.baidu.com/s/1fz_9raY6gGLNuAO2_uNp9Q Password: 7l3v
+## Weights trained in this project (Faster R-CNN + Resnet50)
+* Link: https://pan.baidu.com/s/1iF-Yl_9TkFFeAy-JysfGSw Password: d2d8
* mAP on the COCO2017 validation set:
```
- Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.233
- Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.415
- Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.233
- Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.104
- Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.262
- Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.323
- Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.216
- Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.319
- Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.327
- Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.145
- Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.361
- Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.463
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.277
+ Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.453
+ Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.290
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.126
+ Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.308
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.378
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.243
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.358
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.366
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.169
+ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.402
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.512
```
## If you are not very familiar with how Faster RCNN works, see my bilibili videos
* https://b23.tv/sXcBSP
## Faster RCNN framework diagram
-
\ No newline at end of file
+
diff --git a/pytorch_object_detection/train_coco_dataset/backbone/__init__.py b/pytorch_object_detection/train_coco_dataset/backbone/__init__.py
index f7559da86..292a703ac 100644
--- a/pytorch_object_detection/train_coco_dataset/backbone/__init__.py
+++ b/pytorch_object_detection/train_coco_dataset/backbone/__init__.py
@@ -1,3 +1,5 @@
from .resnet50_fpn_model import resnet50_fpn_backbone
from .mobilenetv2_model import MobileNetV2
from .vgg_model import vgg
+from .resnet import *
+from .feature_pyramid_network import BackboneWithFPN, LastLevelMaxPool
diff --git a/pytorch_object_detection/train_coco_dataset/backbone/feature_pyramid_network.py b/pytorch_object_detection/train_coco_dataset/backbone/feature_pyramid_network.py
index 79739f219..fc2fc757f 100644
--- a/pytorch_object_detection/train_coco_dataset/backbone/feature_pyramid_network.py
+++ b/pytorch_object_detection/train_coco_dataset/backbone/feature_pyramid_network.py
@@ -8,6 +8,111 @@
from torch.jit.annotations import Tuple, List, Dict
+class IntermediateLayerGetter(nn.ModuleDict):
+ """
+ Module wrapper that returns intermediate layers from a model
+ It has a strong assumption that the modules have been registered
+ into the model in the same order as they are used.
+ This means that one should **not** reuse the same nn.Module
+ twice in the forward if you want this to work.
+ Additionally, it is only able to query submodules that are directly
+ assigned to the model. So if `model` is passed, `model.feature1` can
+ be returned, but not `model.feature1.layer2`.
+ Arguments:
+ model (nn.Module): model on which we will extract the features
+ return_layers (Dict[name, new_name]): a dict containing the names
+ of the modules for which the activations will be returned as
+ the key of the dict, and the value of the dict is the name
+ of the returned activation (which the user can specify).
+ """
+ __annotations__ = {
+ "return_layers": Dict[str, str],
+ }
+
+ def __init__(self, model, return_layers):
+ if not set(return_layers).issubset([name for name, _ in model.named_children()]):
+ raise ValueError("return_layers are not present in model")
+
+ orig_return_layers = return_layers
+ return_layers = {str(k): str(v) for k, v in return_layers.items()}
+ layers = OrderedDict()
+
+        # iterate over the model's child modules in order and store them in an OrderedDict
+        # keep everything up to and including layer4, discard the unused layers after it
+ for name, module in model.named_children():
+ layers[name] = module
+ if name in return_layers:
+ del return_layers[name]
+ if not return_layers:
+ break
+
+ super().__init__(layers)
+ self.return_layers = orig_return_layers
+
+ def forward(self, x):
+ out = OrderedDict()
+        # run the forward pass through all child modules in order,
+        # collecting the outputs of layer1, layer2, layer3 and layer4
+ for name, module in self.items():
+ x = module(x)
+ if name in self.return_layers:
+ out_name = self.return_layers[name]
+ out[out_name] = x
+ return out
+
+
+class BackboneWithFPN(nn.Module):
+ """
+ Adds a FPN on top of a model.
+ Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
+ extract a submodel that returns the feature maps specified in return_layers.
+    The same limitations of IntermediateLayerGetter apply here.
+ Arguments:
+ backbone (nn.Module)
+ return_layers (Dict[name, new_name]): a dict containing the names
+ of the modules for which the activations will be returned as
+ the key of the dict, and the value of the dict is the name
+ of the returned activation (which the user can specify).
+ in_channels_list (List[int]): number of channels for each feature map
+ that is returned, in the order they are present in the OrderedDict
+ out_channels (int): number of channels in the FPN.
+ extra_blocks: ExtraFPNBlock
+ Attributes:
+ out_channels (int): the number of channels in the FPN
+ """
+
+ def __init__(self,
+ backbone: nn.Module,
+ return_layers=None,
+ in_channels_list=None,
+ out_channels=256,
+ extra_blocks=None,
+ re_getter=True):
+ super().__init__()
+
+ if extra_blocks is None:
+ extra_blocks = LastLevelMaxPool()
+
+ if re_getter:
+ assert return_layers is not None
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+ else:
+ self.body = backbone
+
+ self.fpn = FeaturePyramidNetwork(
+ in_channels_list=in_channels_list,
+ out_channels=out_channels,
+ extra_blocks=extra_blocks,
+ )
+
+ self.out_channels = out_channels
+
+ def forward(self, x):
+ x = self.body(x)
+ x = self.fpn(x)
+ return x
+
+
class FeaturePyramidNetwork(nn.Module):
"""
Module that adds a FPN from on top of a set of feature maps. This is based on
@@ -27,7 +132,7 @@ class FeaturePyramidNetwork(nn.Module):
"""
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
- super(FeaturePyramidNetwork, self).__init__()
+ super().__init__()
# 用来调整resnet特征矩阵(layer1,2,3,4)的channel(kernel_size=1)
self.inner_blocks = nn.ModuleList()
# 对调整后的特征矩阵使用3x3的卷积核来得到对应的预测特征矩阵
@@ -48,8 +153,7 @@ def __init__(self, in_channels_list, out_channels, extra_blocks=None):
self.extra_blocks = extra_blocks
- def get_result_from_inner_blocks(self, x, idx):
- # type: (Tensor, int) -> Tensor
+ def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.inner_blocks[idx](x),
but torchscript doesn't support this yet
@@ -65,8 +169,7 @@ def get_result_from_inner_blocks(self, x, idx):
i += 1
return out
- def get_result_from_layer_blocks(self, x, idx):
- # type: (Tensor, int) -> Tensor
+ def get_result_from_layer_blocks(self, x: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.layer_blocks[idx](x),
but torchscript doesn't support this yet
@@ -82,8 +185,7 @@ def get_result_from_layer_blocks(self, x, idx):
i += 1
return out
- def forward(self, x):
- # type: (Dict[str, Tensor]) -> Dict[str, Tensor]
+ def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Computes the FPN for a set of feature maps.
Arguments:
@@ -127,8 +229,7 @@ class LastLevelMaxPool(torch.nn.Module):
Applies a max_pool2d on top of the last feature map
"""
- def forward(self, x, y, names):
- # type: (List[Tensor], List[Tensor], List[str]) -> Tuple[List[Tensor], List[str]]
+ def forward(self, x: List[Tensor], y: List[Tensor], names: List[str]) -> Tuple[List[Tensor], List[str]]:
names.append("pool")
x.append(F.max_pool2d(x[-1], 1, 2, 0))
return x, names
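
`BackboneWithFPN` now lives next to `FeaturePyramidNetwork` and gains a `re_getter` switch: with `re_getter=True` (the default) it wraps the backbone in `IntermediateLayerGetter` via `return_layers`; with `re_getter=False` it assumes the backbone already returns an ordered dict of feature maps. A rough sketch of the first path, assuming a plain torchvision ResNet-50 (the channel numbers are the standard layer1–layer4 outputs, not something defined in this file):

```python
import torch
import torchvision
from backbone import BackboneWithFPN  # re-exported in backbone/__init__.py

resnet = torchvision.models.resnet50(pretrained=False)
# map backbone children to FPN level names "0".."3"
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
in_channels_list = [256, 512, 1024, 2048]  # output channels of layer1..layer4

backbone = BackboneWithFPN(resnet,
                           return_layers=return_layers,
                           in_channels_list=in_channels_list,
                           out_channels=256)          # re_getter=True by default

feats = backbone(torch.randn(1, 3, 224, 224))
# OrderedDict with keys "0".."3" plus "pool" from LastLevelMaxPool, all with 256 channels
print([(k, tuple(v.shape)) for k, v in feats.items()])
```
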
diff --git a/pytorch_object_detection/train_coco_dataset/backbone/resnet.py b/pytorch_object_detection/train_coco_dataset/backbone/resnet.py
new file mode 100644
index 000000000..c2aa086fe
--- /dev/null
+++ b/pytorch_object_detection/train_coco_dataset/backbone/resnet.py
@@ -0,0 +1,198 @@
+import torch.nn as nn
+import torch
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
+ super(BasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
+ kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(out_channel)
+ self.relu = nn.ReLU()
+ self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
+ kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(out_channel)
+ self.downsample = downsample
+
+ def forward(self, x):
+ identity = x
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """
+    Note: in the original paper, on the main branch of the dashed (downsampling) residual block,
+    the first 1x1 conv has stride 2 and the second 3x3 conv has stride 1. In the official PyTorch
+    implementation the first 1x1 conv has stride 1 and the second 3x3 conv has stride 2, which improves top-1 accuracy by roughly 0.5%.
+    See Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
+ """
+ expansion = 4
+
+ def __init__(self, in_channel, out_channel, stride=1, downsample=None,
+ groups=1, width_per_group=64):
+ super(Bottleneck, self).__init__()
+
+ width = int(out_channel * (width_per_group / 64.)) * groups
+
+ self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
+ kernel_size=1, stride=1, bias=False) # squeeze channels
+ self.bn1 = nn.BatchNorm2d(width)
+ # -----------------------------------------
+ self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
+ kernel_size=3, stride=stride, bias=False, padding=1)
+ self.bn2 = nn.BatchNorm2d(width)
+ # -----------------------------------------
+ self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
+ kernel_size=1, stride=1, bias=False) # unsqueeze channels
+ self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ def forward(self, x):
+ identity = x
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self,
+ block,
+ blocks_num,
+ num_classes=1000,
+ include_top=True,
+ groups=1,
+ width_per_group=64):
+ super(ResNet, self).__init__()
+ self.include_top = include_top
+ self.in_channel = 64
+
+ self.groups = groups
+ self.width_per_group = width_per_group
+
+ self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
+ padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.in_channel)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, blocks_num[0])
+ self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
+ if self.include_top:
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1)
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+ def _make_layer(self, block, channel, block_num, stride=1):
+ downsample = None
+ if stride != 1 or self.in_channel != channel * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(channel * block.expansion))
+
+ layers = []
+ layers.append(block(self.in_channel,
+ channel,
+ downsample=downsample,
+ stride=stride,
+ groups=self.groups,
+ width_per_group=self.width_per_group))
+ self.in_channel = channel * block.expansion
+
+ for _ in range(1, block_num):
+ layers.append(block(self.in_channel,
+ channel,
+ groups=self.groups,
+ width_per_group=self.width_per_group))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ if self.include_top:
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+
+ return x
+
+
+def resnet34(num_classes=1000, include_top=True):
+ # https://download.pytorch.org/models/resnet34-333f7ec4.pth
+ return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
+
+
+def resnet50(num_classes=1000, include_top=True):
+ # https://download.pytorch.org/models/resnet50-19c8e357.pth
+ return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
+
+
+def resnet101(num_classes=1000, include_top=True):
+ # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
+ return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)
+
+
+def resnext50_32x4d(num_classes=1000, include_top=True):
+ # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
+ groups = 32
+ width_per_group = 4
+ return ResNet(Bottleneck, [3, 4, 6, 3],
+ num_classes=num_classes,
+ include_top=include_top,
+ groups=groups,
+ width_per_group=width_per_group)
+
+
+def resnext101_32x8d(num_classes=1000, include_top=True):
+ # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
+ groups = 32
+ width_per_group = 8
+ return ResNet(Bottleneck, [3, 4, 23, 3],
+ num_classes=num_classes,
+ include_top=include_top,
+ groups=groups,
+ width_per_group=width_per_group)
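
The new `backbone/resnet.py` mirrors the torchvision ResNet family, with `include_top` controlling whether the avgpool/fc head is built. A possible way to use it as a detection backbone with the official ImageNet weights, assuming the downloaded file has been renamed to `resnet50.pth` as the README suggests:

```python
import torch
from backbone.resnet import resnet50

# Backbone without the classification head; include_top=False skips avgpool/fc.
net = resnet50(include_top=False)

# Assumed local copy of the official weights (resnet50-19c8e357.pth renamed to resnet50.pth).
weights = torch.load("resnet50.pth", map_location="cpu")
# fc.* keys have no counterpart when include_top=False, so load non-strictly.
missing, unexpected = net.load_state_dict(weights, strict=False)
print("unexpected keys (expected to be fc.*):", unexpected)

feat = net(torch.randn(1, 3, 224, 224))
print(feat.shape)  # torch.Size([1, 2048, 7, 7]) -- layer4 output, stride 32
```
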
diff --git a/pytorch_object_detection/train_coco_dataset/backbone/resnet50_fpn_model.py b/pytorch_object_detection/train_coco_dataset/backbone/resnet50_fpn_model.py
index 8c796cfac..b15930765 100644
--- a/pytorch_object_detection/train_coco_dataset/backbone/resnet50_fpn_model.py
+++ b/pytorch_object_detection/train_coco_dataset/backbone/resnet50_fpn_model.py
@@ -1,19 +1,17 @@
import os
-from collections import OrderedDict
import torch
import torch.nn as nn
-from torch.jit.annotations import List, Dict
from torchvision.ops.misc import FrozenBatchNorm2d
-from .feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
+from .feature_pyramid_network import BackboneWithFPN, LastLevelMaxPool
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm_layer=None):
- super(Bottleneck, self).__init__()
+ super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
@@ -56,7 +54,7 @@ def forward(self, x):
class ResNet(nn.Module):
def __init__(self, block, blocks_num, num_classes=1000, include_top=True, norm_layer=None):
- super(ResNet, self).__init__()
+ super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
@@ -136,100 +134,6 @@ def overwrite_eps(model, eps):
module.eps = eps
-class IntermediateLayerGetter(nn.ModuleDict):
- """
- Module wrapper that returns intermediate layers from a model
- It has a strong assumption that the modules have been registered
- into the model in the same order as they are used.
- This means that one should **not** reuse the same nn.Module
- twice in the forward if you want this to work.
- Additionally, it is only able to query submodules that are directly
- assigned to the model. So if `model` is passed, `model.feature1` can
- be returned, but not `model.feature1.layer2`.
- Arguments:
- model (nn.Module): model on which we will extract the features
- return_layers (Dict[name, new_name]): a dict containing the names
- of the modules for which the activations will be returned as
- the key of the dict, and the value of the dict is the name
- of the returned activation (which the user can specify).
- """
- __annotations__ = {
- "return_layers": Dict[str, str],
- }
-
- def __init__(self, model, return_layers):
- if not set(return_layers).issubset([name for name, _ in model.named_children()]):
- raise ValueError("return_layers are not present in model")
-
- orig_return_layers = return_layers
- return_layers = {str(k): str(v) for k, v in return_layers.items()}
- layers = OrderedDict()
-
- # 遍历模型子模块按顺序存入有序字典
- # 只保存layer4及其之前的结构,舍去之后不用的结构
- for name, module in model.named_children():
- layers[name] = module
- if name in return_layers:
- del return_layers[name]
- if not return_layers:
- break
-
- super(IntermediateLayerGetter, self).__init__(layers)
- self.return_layers = orig_return_layers
-
- def forward(self, x):
- out = OrderedDict()
- # 依次遍历模型的所有子模块,并进行正向传播,
- # 收集layer1, layer2, layer3, layer4的输出
- for name, module in self.items():
- x = module(x)
- if name in self.return_layers:
- out_name = self.return_layers[name]
- out[out_name] = x
- return out
-
-
-class BackboneWithFPN(nn.Module):
- """
- Adds a FPN on top of a model.
- Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
- extract a submodel that returns the feature maps specified in return_layers.
- The same limitations of IntermediatLayerGetter apply here.
- Arguments:
- backbone (nn.Module)
- return_layers (Dict[name, new_name]): a dict containing the names
- of the modules for which the activations will be returned as
- the key of the dict, and the value of the dict is the name
- of the returned activation (which the user can specify).
- in_channels_list (List[int]): number of channels for each feature map
- that is returned, in the order they are present in the OrderedDict
- out_channels (int): number of channels in the FPN.
- extra_blocks: ExtraFPNBlock
- Attributes:
- out_channels (int): the number of channels in the FPN
- """
-
- def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
- super(BackboneWithFPN, self).__init__()
-
- if extra_blocks is None:
- extra_blocks = LastLevelMaxPool()
-
- self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
- self.fpn = FeaturePyramidNetwork(
- in_channels_list=in_channels_list,
- out_channels=out_channels,
- extra_blocks=extra_blocks,
- )
-
- self.out_channels = out_channels
-
- def forward(self, x):
- x = self.body(x)
- x = self.fpn(x)
- return x
-
-
def resnet50_fpn_backbone(pretrain_path="",
norm_layer=FrozenBatchNorm2d, # FrozenBatchNorm2d的功能与BatchNorm2d类似,但参数无法更新
trainable_layers=3,
diff --git a/pytorch_object_detection/train_coco_dataset/change_backbone_with_fpn.py b/pytorch_object_detection/train_coco_dataset/change_backbone_with_fpn.py
new file mode 100644
index 000000000..36b2fa554
--- /dev/null
+++ b/pytorch_object_detection/train_coco_dataset/change_backbone_with_fpn.py
@@ -0,0 +1,257 @@
+import os
+import datetime
+
+import torch
+
+import transforms
+from network_files import FasterRCNN, AnchorsGenerator
+from my_dataset import CocoDetection
+from train_utils import GroupedBatchSampler, create_aspect_ratio_groups
+from train_utils import train_eval_utils as utils
+from backbone import BackboneWithFPN, LastLevelMaxPool
+
+
+def create_model(num_classes):
+ import torchvision
+ from torchvision.models.feature_extraction import create_feature_extractor
+
+ # --- mobilenet_v3_large fpn backbone --- #
+ backbone = torchvision.models.mobilenet_v3_large(pretrained=True)
+ # print(backbone)
+ return_layers = {"features.6": "0", # stride 8
+ "features.12": "1", # stride 16
+ "features.16": "2"} # stride 32
+    # channels of each feature map fed to the fpn
+ in_channels_list = [40, 112, 960]
+ new_backbone = create_feature_extractor(backbone, return_layers)
+ # img = torch.randn(1, 3, 224, 224)
+ # outputs = new_backbone(img)
+ # [print(f"{k} shape: {v.shape}") for k, v in outputs.items()]
+
+ # --- efficientnet_b0 fpn backbone --- #
+ # backbone = torchvision.models.efficientnet_b0(pretrained=True)
+ # # print(backbone)
+ # return_layers = {"features.3": "0", # stride 8
+ # "features.4": "1", # stride 16
+ # "features.8": "2"} # stride 32
+ # # 提供给fpn的每个特征层channel
+ # in_channels_list = [40, 80, 1280]
+ # new_backbone = create_feature_extractor(backbone, return_layers)
+ # # img = torch.randn(1, 3, 224, 224)
+ # # outputs = new_backbone(img)
+ # # [print(f"{k} shape: {v.shape}") for k, v in outputs.items()]
+
+ backbone_with_fpn = BackboneWithFPN(new_backbone,
+ return_layers=return_layers,
+ in_channels_list=in_channels_list,
+ out_channels=256,
+ extra_blocks=LastLevelMaxPool(),
+ re_getter=False)
+
+ anchor_sizes = ((64,), (128,), (256,), (512,))
+ aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+ anchor_generator = AnchorsGenerator(sizes=anchor_sizes,
+ aspect_ratios=aspect_ratios)
+
+    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0', '1', '2'],  # feature maps on which RoIAlign pooling is performed
+                                                    output_size=[7, 7],  # output size of the RoIAlign pooled features
+                                                    sampling_ratio=2)  # sampling ratio
+
+ model = FasterRCNN(backbone=backbone_with_fpn,
+ num_classes=num_classes,
+ rpn_anchor_generator=anchor_generator,
+ box_roi_pool=roi_pooler)
+
+ return model
+
+
+def main(args):
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ print("Using {} device training.".format(device.type))
+
+    # file used to save coco_info
+ results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
+
+ data_transform = {
+ "train": transforms.Compose([transforms.ToTensor(),
+ transforms.RandomHorizontalFlip(0.5)]),
+ "val": transforms.Compose([transforms.ToTensor()])
+ }
+
+ COCO_root = args.data_path
+
+ # load train data set
+ # coco2017 -> annotations -> instances_train2017.json
+ train_dataset = CocoDetection(COCO_root, "train", data_transform["train"])
+ train_sampler = None
+
+    # whether to sample images with similar aspect ratios into the same batch
+    # doing so reduces the GPU memory needed during training; enabled by default
+    if args.aspect_ratio_group_factor >= 0:
+        train_sampler = torch.utils.data.RandomSampler(train_dataset)
+        # compute, for every image, the index of the aspect-ratio bin it falls into
+        group_ids = create_aspect_ratio_groups(train_dataset, k=args.aspect_ratio_group_factor)
+        # each batch is drawn from a single aspect-ratio bin
+        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
+
+    # note that collate_fn is custom here: the loaded data contains both image and targets, so the default batching cannot be used
+ batch_size = args.batch_size
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
+ print('Using %g dataloader workers' % nw)
+ if train_sampler:
+        # when sampling images by aspect ratio, the dataloader must use batch_sampler
+ train_data_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_sampler=train_batch_sampler,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+ else:
+ train_data_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+
+ # load validation data set
+ # coco2017 -> annotations -> instances_val2017.json
+ val_dataset = CocoDetection(COCO_root, "val", data_transform["val"])
+ val_data_set_loader = torch.utils.data.DataLoader(val_dataset,
+ batch_size=1,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=val_dataset.collate_fn)
+
+ # create model num_classes equal background + classes
+ model = create_model(num_classes=args.num_classes + 1)
+ # print(model)
+
+ model.to(device)
+
+ # define optimizer
+ params = [p for p in model.parameters() if p.requires_grad]
+ optimizer = torch.optim.SGD(params,
+ lr=args.lr,
+ momentum=args.momentum,
+ weight_decay=args.weight_decay)
+
+ scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+ # learning rate scheduler
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
+ milestones=args.lr_steps,
+ gamma=args.lr_gamma)
+
+    # if a checkpoint from a previous run is specified, resume training from it
+ if args.resume != "":
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ model.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if args.amp and "scaler" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler"])
+ print("the training process from epoch{}...".format(args.start_epoch))
+
+ train_loss = []
+ learning_rate = []
+ val_map = []
+
+ for epoch in range(args.start_epoch, args.epochs):
+ # train for one epoch, printing every 10 iterations
+ mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
+ device=device, epoch=epoch,
+ print_freq=50, warmup=True,
+ scaler=scaler)
+ train_loss.append(mean_loss.item())
+ learning_rate.append(lr)
+
+ # update the learning rate
+ lr_scheduler.step()
+
+ # evaluate on the test dataset
+ coco_info = utils.evaluate(model, val_data_set_loader, device=device)
+
+ # write into txt
+ with open(results_file, "a") as f:
+            # the written data includes the coco metrics plus the loss and learning rate
+ result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
+ txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
+ f.write(txt + "\n")
+
+ val_map.append(coco_info[1]) # pascal mAP
+
+ # save weights
+ save_files = {
+ 'model': model.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch}
+ if args.amp:
+ save_files["scaler"] = scaler.state_dict()
+ torch.save(save_files, "./save_weights/model-{}.pth".format(epoch))
+
+ # plot loss and lr curve
+ if len(train_loss) != 0 and len(learning_rate) != 0:
+ from plot_curve import plot_loss_and_lr
+ plot_loss_and_lr(train_loss, learning_rate)
+
+ # plot mAP curve
+ if len(val_map) != 0:
+ from plot_curve import plot_map
+ plot_map(val_map)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description=__doc__)
+
+    # training device type
+    parser.add_argument('--device', default='cuda:0', help='device')
+    # root directory of the training dataset
+    parser.add_argument('--data-path', default='/data/coco2017', help='dataset')
+    # number of detection classes (excluding background)
+    parser.add_argument('--num-classes', default=90, type=int, help='num_classes')
+    # directory for saving files
+    parser.add_argument('--output-dir', default='./save_weights', help='path where to save')
+    # to resume training, specify the path of the checkpoint saved by the previous run
+    parser.add_argument('--resume', default='', type=str, help='resume from checkpoint')
+    # epoch to resume training from
+    parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
+    # total number of training epochs
+    parser.add_argument('--epochs', default=26, type=int, metavar='N',
+                        help='number of total epochs to run')
+    # learning rate
+    parser.add_argument('--lr', default=0.005, type=float,
+                        help='initial learning rate, 0.02 is the default value for training '
+                             'on 8 gpus and 2 images_per_gpu')
+    # momentum parameter for SGD
+    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                        help='momentum')
+    # weight_decay parameter for SGD
+    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+                        metavar='W', help='weight decay (default: 1e-4)',
+                        dest='weight_decay')
+    # parameters for torch.optim.lr_scheduler.MultiStepLR
+    parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,
+                        help='decrease lr every step-size epochs')
+    # parameters for torch.optim.lr_scheduler.MultiStepLR
+    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
+    # training batch size
+    parser.add_argument('--batch_size', default=4, type=int, metavar='N',
+                        help='batch size when training.')
+    parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
+    # whether to use mixed-precision training (requires a GPU with mixed-precision support)
+    parser.add_argument("--amp", default=False, help="Use torch.cuda.amp for mixed precision training")
+
+ args = parser.parse_args()
+ print(args)
+
+    # check whether the folder for saving weights exists; create it if not
+ if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+
+ main(args)
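
Before plugging a new backbone into `BackboneWithFPN`, it is worth confirming that the chosen `return_layers` really produce the channel counts listed in `in_channels_list`; the commented-out lines in `create_model` hint at exactly this check. A standalone sketch for the mobilenet_v3_large case used above (pretrained=False here only to avoid a download; the script uses pretrained weights):

```python
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

backbone = torchvision.models.mobilenet_v3_large(pretrained=False)
return_layers = {"features.6": "0",    # stride 8
                 "features.12": "1",   # stride 16
                 "features.16": "2"}   # stride 32
extractor = create_feature_extractor(backbone, return_layers)

with torch.no_grad():
    outputs = extractor(torch.randn(1, 3, 224, 224))
for name, feat in outputs.items():
    # the channel dim (feat.shape[1]) must match in_channels_list, here [40, 112, 960]
    print(f"{name}: shape {tuple(feat.shape)}")
```
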
diff --git a/pytorch_object_detection/train_coco_dataset/coco80_indices.json b/pytorch_object_detection/train_coco_dataset/coco80_indices.json
deleted file mode 100644
index e317e9299..000000000
--- a/pytorch_object_detection/train_coco_dataset/coco80_indices.json
+++ /dev/null
@@ -1,82 +0,0 @@
-{
- "1": "person",
- "2": "bicycle",
- "3": "car",
- "4": "motorcycle",
- "5": "airplane",
- "6": "bus",
- "7": "train",
- "8": "truck",
- "9": "boat",
- "10": "traffic light",
- "11": "fire hydrant",
- "12": "stop sign",
- "13": "parking meter",
- "14": "bench",
- "15": "bird",
- "16": "cat",
- "17": "dog",
- "18": "horse",
- "19": "sheep",
- "20": "cow",
- "21": "elephant",
- "22": "bear",
- "23": "zebra",
- "24": "giraffe",
- "25": "backpack",
- "26": "umbrella",
- "27": "handbag",
- "28": "tie",
- "29": "suitcase",
- "30": "frisbee",
- "31": "skis",
- "32": "snowboard",
- "33": "sports ball",
- "34": "kite",
- "35": "baseball bat",
- "36": "baseball glove",
- "37": "skateboard",
- "38": "surfboard",
- "39": "tennis racket",
- "40": "bottle",
- "41": "wine glass",
- "42": "cup",
- "43": "fork",
- "44": "knife",
- "45": "spoon",
- "46": "bowl",
- "47": "banana",
- "48": "apple",
- "49": "sandwich",
- "50": "orange",
- "51": "broccoli",
- "52": "carrot",
- "53": "hot dog",
- "54": "pizza",
- "55": "donut",
- "56": "cake",
- "57": "chair",
- "58": "couch",
- "59": "potted plant",
- "60": "bed",
- "61": "dining table",
- "62": "toilet",
- "63": "tv",
- "64": "laptop",
- "65": "mouse",
- "66": "remote",
- "67": "keyboard",
- "68": "cell phone",
- "69": "microwave",
- "70": "oven",
- "71": "toaster",
- "72": "sink",
- "73": "refrigerator",
- "74": "book",
- "75": "clock",
- "76": "vase",
- "77": "scissors",
- "78": "teddy bear",
- "79": "hair drier",
- "80": "toothbrush"
-}
\ No newline at end of file
diff --git a/pytorch_object_detection/train_coco_dataset/coco91_indices.json b/pytorch_object_detection/train_coco_dataset/coco91_indices.json
new file mode 100644
index 000000000..decbe58ce
--- /dev/null
+++ b/pytorch_object_detection/train_coco_dataset/coco91_indices.json
@@ -0,0 +1,92 @@
+{
+ "1": "person",
+ "2": "bicycle",
+ "3": "car",
+ "4": "motorcycle",
+ "5": "airplane",
+ "6": "bus",
+ "7": "train",
+ "8": "truck",
+ "9": "boat",
+ "10": "traffic light",
+ "11": "fire hydrant",
+ "12": "N/A",
+ "13": "stop sign",
+ "14": "parking meter",
+ "15": "bench",
+ "16": "bird",
+ "17": "cat",
+ "18": "dog",
+ "19": "horse",
+ "20": "sheep",
+ "21": "cow",
+ "22": "elephant",
+ "23": "bear",
+ "24": "zebra",
+ "25": "giraffe",
+ "26": "N/A",
+ "27": "backpack",
+ "28": "umbrella",
+ "29": "N/A",
+ "30": "N/A",
+ "31": "handbag",
+ "32": "tie",
+ "33": "suitcase",
+ "34": "frisbee",
+ "35": "skis",
+ "36": "snowboard",
+ "37": "sports ball",
+ "38": "kite",
+ "39": "baseball bat",
+ "40": "baseball glove",
+ "41": "skateboard",
+ "42": "surfboard",
+ "43": "tennis racket",
+ "44": "bottle",
+ "45": "N/A",
+ "46": "wine glass",
+ "47": "cup",
+ "48": "fork",
+ "49": "knife",
+ "50": "spoon",
+ "51": "bowl",
+ "52": "banana",
+ "53": "apple",
+ "54": "sandwich",
+ "55": "orange",
+ "56": "broccoli",
+ "57": "carrot",
+ "58": "hot dog",
+ "59": "pizza",
+ "60": "donut",
+ "61": "cake",
+ "62": "chair",
+ "63": "couch",
+ "64": "potted plant",
+ "65": "bed",
+ "66": "N/A",
+ "67": "dining table",
+ "68": "N/A",
+ "69": "N/A",
+ "70": "toilet",
+ "71": "N/A",
+ "72": "tv",
+ "73": "laptop",
+ "74": "mouse",
+ "75": "remote",
+ "76": "keyboard",
+ "77": "cell phone",
+ "78": "microwave",
+ "79": "oven",
+ "80": "toaster",
+ "81": "sink",
+ "82": "refrigerator",
+ "83": "N/A",
+ "84": "book",
+ "85": "clock",
+ "86": "vase",
+ "87": "scissors",
+ "88": "teddy bear",
+ "89": "hair drier",
+ "90": "toothbrush"
+}
\ No newline at end of file
diff --git a/pytorch_object_detection/train_coco_dataset/coco91_to_80.json b/pytorch_object_detection/train_coco_dataset/coco91_to_80.json
deleted file mode 100644
index fd190538e..000000000
--- a/pytorch_object_detection/train_coco_dataset/coco91_to_80.json
+++ /dev/null
@@ -1,82 +0,0 @@
-{
- "1": 1,
- "2": 2,
- "3": 3,
- "4": 4,
- "5": 5,
- "6": 6,
- "7": 7,
- "8": 8,
- "9": 9,
- "10": 10,
- "11": 11,
- "13": 12,
- "14": 13,
- "15": 14,
- "16": 15,
- "17": 16,
- "18": 17,
- "19": 18,
- "20": 19,
- "21": 20,
- "22": 21,
- "23": 22,
- "24": 23,
- "25": 24,
- "27": 25,
- "28": 26,
- "31": 27,
- "32": 28,
- "33": 29,
- "34": 30,
- "35": 31,
- "36": 32,
- "37": 33,
- "38": 34,
- "39": 35,
- "40": 36,
- "41": 37,
- "42": 38,
- "43": 39,
- "44": 40,
- "46": 41,
- "47": 42,
- "48": 43,
- "49": 44,
- "50": 45,
- "51": 46,
- "52": 47,
- "53": 48,
- "54": 49,
- "55": 50,
- "56": 51,
- "57": 52,
- "58": 53,
- "59": 54,
- "60": 55,
- "61": 56,
- "62": 57,
- "63": 58,
- "64": 59,
- "65": 60,
- "67": 61,
- "70": 62,
- "72": 63,
- "73": 64,
- "74": 65,
- "75": 66,
- "76": 67,
- "77": 68,
- "78": 69,
- "79": 70,
- "80": 71,
- "81": 72,
- "82": 73,
- "84": 74,
- "85": 75,
- "86": 76,
- "87": 77,
- "88": 78,
- "89": 79,
- "90": 80
-}
\ No newline at end of file
diff --git a/pytorch_object_detection/train_coco_dataset/draw_box_utils.py b/pytorch_object_detection/train_coco_dataset/draw_box_utils.py
index 25d86f4fa..835d7f7c1 100644
--- a/pytorch_object_detection/train_coco_dataset/draw_box_utils.py
+++ b/pytorch_object_detection/train_coco_dataset/draw_box_utils.py
@@ -1,6 +1,7 @@
-import collections
+from PIL.Image import Image, fromarray
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
+from PIL import ImageColor
import numpy as np
STANDARD_COLORS = [
@@ -30,68 +31,123 @@
]
-def filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map):
- for i in range(boxes.shape[0]):
- if scores[i] > thresh:
- box = tuple(boxes[i].tolist()) # numpy -> list -> tuple
- if classes[i] in category_index.keys():
- class_name = category_index[classes[i]]
- elif str(classes[i]) in category_index.keys():
- class_name = category_index[str(classes[i])]
- else:
- class_name = 'N/A'
- display_str = str(class_name)
- display_str = '{}: {}%'.format(display_str, int(100 * scores[i]))
- box_to_display_str_map[box].append(display_str)
- box_to_color_map[box] = STANDARD_COLORS[
- classes[i] % len(STANDARD_COLORS)]
- else:
- break # 网络输出概率已经排序过,当遇到一个不满足后面的肯定不满足
-
-
-def draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color):
+def draw_text(draw,
+ box: list,
+ cls: int,
+ score: float,
+ category_index: dict,
+ color: str,
+ font: str = 'arial.ttf',
+ font_size: int = 24):
+ """
+    Draw the object bounding box and class information onto the image
+ """
try:
- font = ImageFont.truetype('arial.ttf', 24)
+ font = ImageFont.truetype(font, font_size)
except IOError:
font = ImageFont.load_default()
+ left, top, right, bottom = box
# If the total height of the display strings added to the top of the bounding
# box exceeds the top of the image, stack the strings below the bounding box
# instead of above.
- display_str_heights = [font.getsize(ds)[1] for ds in box_to_display_str_map[box]]
+ display_str = f"{category_index[str(cls)]}: {int(100 * score)}%"
+ display_str_heights = [font.getsize(ds)[1] for ds in display_str]
# Each display_str has a top and bottom margin of 0.05x.
- total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
+ display_str_height = (1 + 2 * 0.05) * max(display_str_heights)
- if top > total_display_str_height:
+ if top > display_str_height:
+ text_top = top - display_str_height
text_bottom = top
else:
- text_bottom = bottom + total_display_str_height
- # Reverse list and print from bottom to top.
- for display_str in box_to_display_str_map[box][::-1]:
- text_width, text_height = font.getsize(display_str)
- margin = np.ceil(0.05 * text_height)
- draw.rectangle([(left, text_bottom - text_height - 2 * margin),
- (left + text_width, text_bottom)], fill=color)
- draw.text((left + margin, text_bottom - text_height - margin),
- display_str,
+ text_top = bottom
+ text_bottom = bottom + display_str_height
+
+ for ds in display_str:
+ text_width, text_height = font.getsize(ds)
+ margin = np.ceil(0.05 * text_width)
+ draw.rectangle([(left, text_top),
+ (left + text_width + 2 * margin, text_bottom)], fill=color)
+ draw.text((left + margin, text_top),
+ ds,
fill='black',
font=font)
- text_bottom -= text_height - 2 * margin
+ left += text_width
+
+
+def draw_masks(image, masks, colors, thresh: float = 0.7, alpha: float = 0.5):
+ np_image = np.array(image)
+ masks = np.where(masks > thresh, True, False)
+
+ # colors = np.array(colors)
+ img_to_draw = np.copy(np_image)
+ # TODO: There might be a way to vectorize this
+ for mask, color in zip(masks, colors):
+ img_to_draw[mask] = color
+
+ out = np_image * (1 - alpha) + img_to_draw * alpha
+ return fromarray(out.astype(np.uint8))
+
+
+def draw_objs(image: Image,
+ boxes: np.ndarray = None,
+ classes: np.ndarray = None,
+ scores: np.ndarray = None,
+ masks: np.ndarray = None,
+ category_index: dict = None,
+ box_thresh: float = 0.1,
+ mask_thresh: float = 0.5,
+ line_thickness: int = 8,
+ font: str = 'arial.ttf',
+ font_size: int = 24,
+ draw_boxes_on_image: bool = True,
+ draw_masks_on_image: bool = False):
+ """
+    Draw bounding boxes, class labels and masks onto an image
+    Args:
+        image: image to draw on
+        boxes: bounding box information
+        classes: class information
+        scores: object probability scores
+        masks: mask information
+        category_index: dict mapping class ids to class names
+        box_thresh: probability threshold used for filtering
+        mask_thresh:
+        line_thickness: bounding box line width
+        font: font type
+        font_size: font size
+        draw_boxes_on_image:
+        draw_masks_on_image:
+
+    Returns:
+
+    """
+
+    # filter out low-probability objects
+ idxs = np.greater(scores, box_thresh)
+ boxes = boxes[idxs]
+ classes = classes[idxs]
+ scores = scores[idxs]
+ if masks is not None:
+ masks = masks[idxs]
+ if len(boxes) == 0:
+ return image
+ colors = [ImageColor.getrgb(STANDARD_COLORS[cls % len(STANDARD_COLORS)]) for cls in classes]
-def draw_box(image, boxes, classes, scores, category_index, thresh=0.5, line_thickness=8):
- box_to_display_str_map = collections.defaultdict(list)
- box_to_color_map = collections.defaultdict(str)
+ if draw_boxes_on_image:
+ # Draw all boxes onto image.
+ draw = ImageDraw.Draw(image)
+ for box, cls, score, color in zip(boxes, classes, scores, colors):
+ left, top, right, bottom = box
+            # draw the object bounding box
+ draw.line([(left, top), (left, bottom), (right, bottom),
+ (right, top), (left, top)], width=line_thickness, fill=color)
+            # draw class and probability information
+ draw_text(draw, box.tolist(), int(cls), float(score), category_index, color, font, font_size)
- filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map)
+ if draw_masks_on_image and (masks is not None):
+ # Draw all mask onto image.
+ image = draw_masks(image, masks, colors, mask_thresh)
- # Draw all boxes onto image.
- draw = ImageDraw.Draw(image)
- im_width, im_height = image.size
- for box, color in box_to_color_map.items():
- xmin, ymin, xmax, ymax = box
- (left, right, top, bottom) = (xmin * 1, xmax * 1,
- ymin * 1, ymax * 1)
- draw.line([(left, top), (left, bottom), (right, bottom),
- (right, top), (left, top)], width=line_thickness, fill=color)
- draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color)
+ return image
diff --git a/pytorch_object_detection/train_coco_dataset/my_dataset.py b/pytorch_object_detection/train_coco_dataset/my_dataset.py
index dbacb54d9..31b71cee9 100644
--- a/pytorch_object_detection/train_coco_dataset/my_dataset.py
+++ b/pytorch_object_detection/train_coco_dataset/my_dataset.py
@@ -64,30 +64,24 @@ def __init__(self, root, dataset="train", transforms=None):
self.transforms = transforms
self.coco = COCO(self.anno_path)
- if dataset == "train":
- # 获取coco数据索引与类别名称的关系
- # 注意在object80中的索引并不是连续的,虽然只有80个类别,但索引还是按照stuff91来排序的
- coco_classes = dict([(v["id"], v["name"]) for k, v in self.coco.cats.items()])
-
- # 将stuff91的类别索引重新编排,从1到80
- coco91to80 = dict([(str(k), idx+1) for idx, (k, _) in enumerate(coco_classes.items())])
- json_str = json.dumps(coco91to80, indent=4)
- with open('coco91_to_80.json', 'w') as json_file:
- json_file.write(json_str)
-
- # 记录重新编排后的索引以及类别名称关系
- coco80_info = dict([(str(idx+1), v) for idx, (_, v) in enumerate(coco_classes.items())])
- json_str = json.dumps(coco80_info, indent=4)
- with open('coco80_indices.json', 'w') as json_file:
- json_file.write(json_str)
- else:
- # 如果是验证集就直接读取生成好的数据
- coco91to80_path = 'coco91_to_80.json'
- assert os.path.exists(coco91to80_path), "file '{}' does not exist.".format(coco91to80_path)
+        # get the mapping between coco category ids and class names
+        # note that the indices in object80 are not contiguous: although there are only 80 classes, the ids still follow the stuff91 numbering
+        data_classes = dict([(v["id"], v["name"]) for k, v in self.coco.cats.items()])
+        max_index = max(data_classes.keys())  # 90
+        # set the missing category names to N/A
+ coco_classes = {}
+ for k in range(1, max_index + 1):
+ if k in data_classes:
+ coco_classes[k] = data_classes[k]
+ else:
+ coco_classes[k] = "N/A"
- coco91to80 = json.load(open(coco91to80_path, "r"))
+ if dataset == "train":
+ json_str = json.dumps(coco_classes, indent=4)
+ with open("coco91_indices.json", "w") as f:
+ f.write(json_str)
- self.coco91to80 = coco91to80
+ self.coco_classes = coco_classes
ids = list(sorted(self.coco.imgs.keys()))
if dataset == "train":
@@ -102,34 +96,40 @@ def parse_targets(self,
coco_targets: list,
w: int = None,
h: int = None):
+ assert w > 0
+ assert h > 0
+
# 只筛选出单个对象的情况
anno = [obj for obj in coco_targets if obj['iscrowd'] == 0]
- # 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
- boxes = []
- for obj in anno:
- if obj["bbox"][2] > 0 and obj["bbox"][3] > 0:
- boxes.append(obj["bbox"])
+ boxes = [obj["bbox"] for obj in anno]
# guard against no boxes via resizing
boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
# [xmin, ymin, w, h] -> [xmin, ymin, xmax, ymax]
boxes[:, 2:] += boxes[:, :2]
- if (w is not None) and (h is not None):
- boxes[:, 0::2].clamp_(min=0, max=w)
- boxes[:, 1::2].clamp_(min=0, max=h)
+ boxes[:, 0::2].clamp_(min=0, max=w)
+ boxes[:, 1::2].clamp_(min=0, max=h)
- classes = [self.coco91to80[str(obj["category_id"])] for obj in anno]
+ classes = [obj["category_id"] for obj in anno]
classes = torch.tensor(classes, dtype=torch.int64)
+ area = torch.tensor([obj["area"] for obj in anno])
+ iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
+
+        # keep only valid objects, i.e. those with x_max > x_min and y_max > y_min
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+ boxes = boxes[keep]
+ classes = classes[keep]
+ area = area[keep]
+ iscrowd = iscrowd[keep]
+
target = {}
target["boxes"] = boxes
target["labels"] = classes
target["image_id"] = torch.tensor([img_id])
# for conversion to coco api
- area = torch.tensor([obj["area"] for obj in anno])
- iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
target["area"] = area
target["iscrowd"] = iscrowd
diff --git a/pytorch_object_detection/train_coco_dataset/network_files/boxes.py b/pytorch_object_detection/train_coco_dataset/network_files/boxes.py
index f720df1f8..8eeca4573 100644
--- a/pytorch_object_detection/train_coco_dataset/network_files/boxes.py
+++ b/pytorch_object_detection/train_coco_dataset/network_files/boxes.py
@@ -23,7 +23,7 @@ def nms(boxes, scores, iou_threshold):
scores for each one of the boxes
iou_threshold : float
discards all overlapping
- boxes with IoU < iou_threshold
+ boxes with IoU > iou_threshold
Returns
-------
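
The corrected docstring now matches the actual NMS semantics: boxes whose IoU with an already-kept, higher-scoring box *exceeds* `iou_threshold` are discarded, not kept. A quick illustration with torchvision's operator:

```python
import torch
from torchvision.ops import nms

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],     # IoU ~0.68 with the first box
                      [20., 20., 30., 30.]])
scores = torch.tensor([0.9, 0.8, 0.7])

# With iou_threshold=0.5 the second box is suppressed by the first (0.68 > 0.5).
print(nms(boxes, scores, iou_threshold=0.5))   # tensor([0, 2])
# With iou_threshold=0.7 its overlap no longer exceeds the threshold, so it survives.
print(nms(boxes, scores, iou_threshold=0.7))   # tensor([0, 1, 2])
```
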
diff --git a/pytorch_object_detection/train_coco_dataset/network_files/faster_rcnn_framework.py b/pytorch_object_detection/train_coco_dataset/network_files/faster_rcnn_framework.py
index 20f9bccbd..d658b0113 100644
--- a/pytorch_object_detection/train_coco_dataset/network_files/faster_rcnn_framework.py
+++ b/pytorch_object_detection/train_coco_dataset/network_files/faster_rcnn_framework.py
@@ -245,7 +245,7 @@ class FasterRCNN(FasterRCNNBase):
def __init__(self, backbone, num_classes=None,
# transform parameter
- min_size=800, max_size=1000, # 预处理resize时限制的最小尺寸与最大尺寸
+ min_size=800, max_size=1333, # 预处理resize时限制的最小尺寸与最大尺寸
image_mean=None, image_std=None, # 预处理normalize时使用的均值和方差
# RPN parameters
rpn_anchor_generator=None, rpn_head=None,
diff --git a/pytorch_object_detection/train_coco_dataset/predict.py b/pytorch_object_detection/train_coco_dataset/predict.py
index b74831cdf..2dc508d7e 100644
--- a/pytorch_object_detection/train_coco_dataset/predict.py
+++ b/pytorch_object_detection/train_coco_dataset/predict.py
@@ -6,20 +6,18 @@
import torchvision
from PIL import Image
import matplotlib.pyplot as plt
-
from torchvision import transforms
+from torchvision.models.feature_extraction import create_feature_extractor
+
from network_files import FasterRCNN, AnchorsGenerator
-from backbone import vgg, MobileNetV2
-from draw_box_utils import draw_box
+from backbone import vgg, MobileNetV2, resnet50
+from draw_box_utils import draw_objs
def create_model(num_classes):
- vgg_feature = vgg(model_name="vgg16").features
- backbone = torch.nn.Sequential(*list(vgg_feature._modules.values())[:-1]) # 删除feature中最后的maxpool层
- backbone.out_channels = 512
-
- # backbone = MobileNetV2().features
- # backbone.out_channels = 1280 # 设置对应backbone输出特征矩阵的channels
+ res50 = resnet50()
+ backbone = create_feature_extractor(res50, return_nodes={"layer3": "0"})
+ backbone.out_channels = 1024
anchor_generator = AnchorsGenerator(sizes=((32, 64, 128, 256, 512),),
aspect_ratios=((0.5, 1.0, 2.0),))
@@ -47,21 +45,22 @@ def main():
print("using {} device.".format(device))
# create model
- num_classes = 80
- model = create_model(num_classes=num_classes+1)
+    num_classes = 90  # number of classes, excluding background
+ model = create_model(num_classes=num_classes + 1)
# load train weights
- train_weights = "./save_weights/model_25.pth"
- assert os.path.exists(train_weights), "{} file dose not exist.".format(train_weights)
- model.load_state_dict(torch.load(train_weights, map_location=device)["model"])
+ weights_path = "./save_weights/model_25.pth"
+    assert os.path.exists(weights_path), "{} file does not exist.".format(weights_path)
+ weights_dict = torch.load(weights_path, map_location='cpu')
+ weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
+ model.load_state_dict(weights_dict)
model.to(device)
# read class_indict
- label_json_path = './coco80_indices.json'
+ label_json_path = './coco91_indices.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
- json_file = open(label_json_path, 'r')
- category_index = json.load(json_file)
- json_file.close()
+ with open(label_json_path, 'r') as f:
+ category_index = json.load(f)
# load image
original_img = Image.open("./test.jpg")
@@ -91,17 +90,19 @@ def main():
if len(predict_boxes) == 0:
print("没有检测到任何目标!")
- draw_box(original_img,
- predict_boxes,
- predict_classes,
- predict_scores,
- category_index,
- thresh=0.5,
- line_thickness=3)
- plt.imshow(original_img)
+ plot_img = draw_objs(original_img,
+ predict_boxes,
+ predict_classes,
+ predict_scores,
+ category_index=category_index,
+ box_thresh=0.5,
+ line_thickness=3,
+ font='arial.ttf',
+ font_size=20)
+ plt.imshow(plot_img)
plt.show()
# 保存预测的图片结果
- original_img.save("test_result.jpg")
+ plot_img.save("test_result.jpg")
if __name__ == '__main__':
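
`predict.py` now builds the backbone with `create_feature_extractor` instead of slicing an `nn.Sequential`. A self-contained sketch using torchvision's own `resnet50` (the repo's `backbone.resnet50` is assumed to expose the same node names):

```python
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor

res50 = resnet50()
# expose only the layer3 output (stride 16, 1024 channels) under the key "0",
# which is the feature-map name the RPN and RoI heads expect
backbone = create_feature_extractor(res50, return_nodes={"layer3": "0"})
backbone.out_channels = 1024

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 800, 800))
print(feats["0"].shape)  # torch.Size([1, 1024, 50, 50])
```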
diff --git a/pytorch_object_detection/train_coco_dataset/requirements.txt b/pytorch_object_detection/train_coco_dataset/requirements.txt
index 2d0345811..3a50c7ad2 100644
--- a/pytorch_object_detection/train_coco_dataset/requirements.txt
+++ b/pytorch_object_detection/train_coco_dataset/requirements.txt
@@ -4,5 +4,5 @@ numpy
tqdm
pycocotools
Pillow
-torch==1.7.1
-torchvision==0.8.2
+torch==1.10
+torchvision==0.11.1
diff --git a/pytorch_object_detection/train_coco_dataset/results20210412-092355.txt b/pytorch_object_detection/train_coco_dataset/results20210412-092355.txt
deleted file mode 100644
index c0a250e10..000000000
--- a/pytorch_object_detection/train_coco_dataset/results20210412-092355.txt
+++ /dev/null
@@ -1,26 +0,0 @@
-epoch:0 0.0413 0.109 0.0222 0.0107 0.0495 0.062 0.0567 0.0896 0.0919 0.0169 0.0882 0.1438 1.6263 0.01
-epoch:1 0.0986 0.2188 0.0751 0.0359 0.1154 0.1411 0.111 0.167 0.1703 0.0504 0.1853 0.2523 1.0788 0.01
-epoch:2 0.1258 0.2702 0.1021 0.0473 0.1483 0.1761 0.1359 0.2075 0.2132 0.0694 0.2369 0.3164 1.0252 0.01
-epoch:3 0.1369 0.2809 0.1176 0.0543 0.1551 0.1961 0.1433 0.2086 0.2133 0.0778 0.2312 0.3145 0.992 0.01
-epoch:4 0.1553 0.3137 0.1371 0.0635 0.1774 0.2202 0.1592 0.2367 0.2422 0.0918 0.2677 0.3499 0.9698 0.01
-epoch:5 0.1626 0.3245 0.1439 0.0663 0.1869 0.2251 0.1663 0.2442 0.2501 0.0958 0.2726 0.3581 0.9524 0.01
-epoch:6 0.1739 0.3442 0.1539 0.0752 0.2007 0.2348 0.1742 0.2617 0.2689 0.1091 0.3073 0.3663 0.9372 0.01
-epoch:7 0.1762 0.3417 0.1609 0.0756 0.1963 0.2461 0.1728 0.2529 0.2583 0.1034 0.2799 0.3733 0.9238 0.01
-epoch:8 0.1844 0.3551 0.1709 0.0792 0.2107 0.2535 0.1796 0.2716 0.2785 0.1134 0.3089 0.3964 0.9136 0.01
-epoch:9 0.1909 0.3631 0.1811 0.0857 0.2172 0.2603 0.1837 0.2731 0.2797 0.1231 0.3095 0.3914 0.9045 0.01
-epoch:10 0.1955 0.3684 0.1894 0.0858 0.2242 0.2667 0.1873 0.2756 0.282 0.1173 0.3138 0.3901 0.896 0.01
-epoch:11 0.1995 0.373 0.1932 0.0856 0.2234 0.2732 0.1889 0.2804 0.2874 0.1217 0.3172 0.4012 0.8883 0.01
-epoch:12 0.2067 0.38 0.2019 0.0884 0.2332 0.281 0.1962 0.2891 0.2959 0.1223 0.3251 0.4223 0.881 0.01
-epoch:13 0.2109 0.3912 0.2052 0.0935 0.2408 0.2837 0.2016 0.3053 0.3138 0.1343 0.3526 0.4342 0.8743 0.01
-epoch:14 0.2131 0.3901 0.206 0.0933 0.241 0.2885 0.2023 0.301 0.3086 0.13 0.3429 0.4311 0.8676 0.01
-epoch:15 0.216 0.3968 0.2136 0.0967 0.2432 0.2984 0.2028 0.2999 0.307 0.1338 0.3364 0.4383 0.8631 0.01
-epoch:16 0.2286 0.4099 0.2293 0.1024 0.2594 0.3138 0.2118 0.315 0.3228 0.1409 0.3608 0.4518 0.8302 0.001
-epoch:17 0.2296 0.4102 0.2283 0.1027 0.2577 0.3116 0.2131 0.316 0.3239 0.1422 0.3564 0.4546 0.8255 0.001
-epoch:18 0.2306 0.4125 0.2308 0.1031 0.2592 0.3155 0.213 0.3161 0.3242 0.1416 0.3585 0.4566 0.8239 0.001
-epoch:19 0.2324 0.4163 0.2327 0.1042 0.2618 0.3163 0.2146 0.3193 0.3273 0.1483 0.3635 0.4543 0.8221 0.001
-epoch:20 0.2306 0.4129 0.2293 0.103 0.2611 0.3147 0.2121 0.3143 0.3219 0.1394 0.359 0.4506 0.8216 0.001
-epoch:21 0.2325 0.4147 0.2338 0.1052 0.2623 0.3167 0.2157 0.3185 0.3263 0.1458 0.3625 0.4546 0.8208 0.001
-epoch:22 0.2321 0.4145 0.2313 0.1034 0.261 0.3226 0.2145 0.3181 0.3261 0.1428 0.3616 0.462 0.8159 0.0001
-epoch:23 0.232 0.4143 0.2305 0.1036 0.2613 0.3198 0.2139 0.3181 0.3261 0.1433 0.3617 0.4559 0.8162 0.0001
-epoch:24 0.2315 0.4136 0.2302 0.1032 0.2603 0.3209 0.2131 0.317 0.3249 0.1422 0.3598 0.4594 0.8161 0.0001
-epoch:25 0.232 0.4145 0.2317 0.1035 0.2614 0.3219 0.215 0.3183 0.3262 0.1444 0.3605 0.4601 0.8158 0.0001
diff --git a/pytorch_object_detection/train_coco_dataset/results20220408-201436.txt b/pytorch_object_detection/train_coco_dataset/results20220408-201436.txt
new file mode 100644
index 000000000..0927e308c
--- /dev/null
+++ b/pytorch_object_detection/train_coco_dataset/results20220408-201436.txt
@@ -0,0 +1,26 @@
+epoch:0 0.0504 0.1144 0.0362 0.0207 0.0601 0.0657 0.0702 0.1069 0.1087 0.0335 0.1153 0.1486 1.7430 0.005000
+epoch:1 0.1138 0.2300 0.0994 0.0494 0.1279 0.1554 0.1303 0.1940 0.1980 0.0747 0.2051 0.2831 1.2282 0.005000
+epoch:2 0.1461 0.2773 0.1394 0.0636 0.1635 0.1997 0.1530 0.2243 0.2288 0.0938 0.2435 0.3309 1.1391 0.005000
+epoch:3 0.1669 0.3134 0.1642 0.0750 0.1843 0.2282 0.1680 0.2509 0.2561 0.1091 0.2705 0.3701 1.0902 0.005000
+epoch:4 0.1857 0.3389 0.1828 0.0829 0.2074 0.2568 0.1830 0.2708 0.2756 0.1140 0.2937 0.3998 1.0581 0.005000
+epoch:5 0.1908 0.3431 0.1930 0.0901 0.2128 0.2578 0.1839 0.2704 0.2753 0.1197 0.2927 0.3893 1.0337 0.005000
+epoch:6 0.2044 0.3634 0.2077 0.0954 0.2247 0.2796 0.1947 0.2893 0.2956 0.1317 0.3138 0.4178 1.0127 0.005000
+epoch:7 0.2068 0.3651 0.2099 0.0953 0.2269 0.2840 0.1959 0.2869 0.2926 0.1290 0.3093 0.4186 0.9945 0.005000
+epoch:8 0.2171 0.3788 0.2218 0.0996 0.2470 0.2969 0.2012 0.3001 0.3071 0.1329 0.3375 0.4371 0.9806 0.005000
+epoch:9 0.2146 0.3717 0.2207 0.0946 0.2315 0.3038 0.2011 0.2910 0.2962 0.1277 0.3091 0.4321 0.9691 0.005000
+epoch:10 0.2280 0.3974 0.2345 0.1035 0.2535 0.3108 0.2118 0.3119 0.3182 0.1402 0.3429 0.4537 0.9567 0.005000
+epoch:11 0.2332 0.3983 0.2443 0.1111 0.2534 0.3149 0.2136 0.3128 0.3190 0.1515 0.3417 0.4438 0.9450 0.005000
+epoch:12 0.2400 0.4094 0.2486 0.1102 0.2622 0.3251 0.2175 0.3214 0.3289 0.1507 0.3521 0.4588 0.9369 0.005000
+epoch:13 0.2449 0.4152 0.2563 0.1121 0.2741 0.3308 0.2234 0.3286 0.3363 0.1552 0.3703 0.4627 0.9286 0.005000
+epoch:14 0.2466 0.4192 0.2542 0.1131 0.2765 0.3412 0.2220 0.3258 0.3322 0.1481 0.3627 0.4776 0.9203 0.005000
+epoch:15 0.2492 0.4216 0.2569 0.1147 0.2781 0.3417 0.2254 0.3337 0.3402 0.1565 0.3666 0.4893 0.9116 0.005000
+epoch:16 0.2689 0.4433 0.2814 0.1246 0.2963 0.3705 0.2384 0.3495 0.3569 0.1671 0.3864 0.5046 0.8616 0.000500
+epoch:17 0.2719 0.4473 0.2865 0.1243 0.3021 0.3743 0.2399 0.3519 0.3593 0.1669 0.3931 0.5017 0.8515 0.000500
+epoch:18 0.2738 0.4521 0.2857 0.1256 0.3048 0.3718 0.2416 0.3564 0.3645 0.1713 0.3996 0.5037 0.8472 0.000500
+epoch:19 0.2759 0.4534 0.2893 0.1259 0.3094 0.3719 0.2448 0.3603 0.3681 0.1691 0.4073 0.5055 0.8439 0.000500
+epoch:20 0.2720 0.4483 0.2838 0.1250 0.3021 0.3681 0.2400 0.3532 0.3613 0.1688 0.3944 0.4994 0.8417 0.000500
+epoch:21 0.2748 0.4501 0.2904 0.1241 0.3019 0.3759 0.2421 0.3561 0.3641 0.1682 0.3941 0.5101 0.8378 0.000500
+epoch:22 0.2754 0.4532 0.2896 0.1281 0.3064 0.3759 0.2419 0.3586 0.3660 0.1712 0.3993 0.5115 0.8304 0.000050
+epoch:23 0.2757 0.4516 0.2907 0.1271 0.3068 0.3748 0.2423 0.3572 0.3650 0.1692 0.4005 0.5087 0.8307 0.000050
+epoch:24 0.2750 0.4500 0.2888 0.1256 0.3017 0.3760 0.2411 0.3536 0.3611 0.1669 0.3894 0.5040 0.8299 0.000050
+epoch:25 0.2769 0.4537 0.2903 0.1263 0.3082 0.3782 0.2424 0.3582 0.3663 0.1693 0.4020 0.5116 0.8281 0.000050
diff --git a/pytorch_object_detection/train_coco_dataset/train.py b/pytorch_object_detection/train_coco_dataset/train.py
index b02200d39..4b068a3ec 100644
--- a/pytorch_object_detection/train_coco_dataset/train.py
+++ b/pytorch_object_detection/train_coco_dataset/train.py
@@ -6,20 +6,30 @@
import transforms
from network_files import FasterRCNN, AnchorsGenerator
-from backbone import MobileNetV2, vgg
+from backbone import MobileNetV2, vgg, resnet50
from my_dataset import CocoDetection
from train_utils import train_eval_utils as utils
+from train_utils import GroupedBatchSampler, create_aspect_ratio_groups
+from torchvision.models.feature_extraction import create_feature_extractor
def create_model(num_classes):
- # https://download.pytorch.org/models/vgg16-397923af.pth
- # 如果使用mobilenetv2的话就下载对应预训练权重并注释下面三行,接着把mobilenetv2模型对应的两行代码注释取消掉
- vgg_feature = vgg(model_name="vgg16", weights_path="./backbone/vgg16.pth").features
- backbone = torch.nn.Sequential(*list(vgg_feature._modules.values())[:-1]) # 删除feature中最后的maxpool层
- backbone.out_channels = 512
-
- # https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
- # backbone = MobileNetV2(weights_path="./backbone/mobilenet_v2.pth").features
+    # vgg16 as the backbone
+    # pretrained weights: https://download.pytorch.org/models/vgg16-397923af.pth
+    # vgg16 = vgg(model_name="vgg16", weights_path="./vgg16.pth")
+    # backbone = create_feature_extractor(vgg16, return_nodes={"features.29": "0"})  # drop the final maxpool layer in features
+    # backbone.out_channels = 512
+
+    # resnet50 as the backbone
+    # pretrained weights: https://download.pytorch.org/models/resnet50-19c8e357.pth
+    res50 = resnet50()
+    res50.load_state_dict(torch.load("./resnet50.pth", map_location="cpu"))
+    backbone = create_feature_extractor(res50, return_nodes={"layer3": "0"})
+    backbone.out_channels = 1024
+
+    # mobilenetv2 as the backbone
+    # pretrained weights: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
+    # backbone = MobileNetV2(weights_path="./mobilenet_v2.pth").features
# backbone.out_channels = 1280 # 设置对应backbone输出特征矩阵的channels
anchor_generator = AnchorsGenerator(sizes=((32, 64, 128, 256, 512),),
@@ -54,29 +64,49 @@ def main(args):
# load train data set
# coco2017 -> annotations -> instances_train2017.json
- train_data_set = CocoDetection(COCO_root, "train", data_transform["train"])
+ train_dataset = CocoDetection(COCO_root, "train", data_transform["train"])
+ train_sampler = None
+
+    # whether to sample images with similar aspect ratios into the same batch
+    # enabled by default; it reduces the GPU memory needed during training
+    if args.aspect_ratio_group_factor >= 0:
+        train_sampler = torch.utils.data.RandomSampler(train_dataset)
+        # compute, for every image, the index of the aspect-ratio bin it falls into
+        group_ids = create_aspect_ratio_groups(train_dataset, k=args.aspect_ratio_group_factor)
+        # every batch then draws its images from a single aspect-ratio bin
+ train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
+
# 注意这里的collate_fn是自定义的,因为读取的数据包括image和targets,不能直接使用默认的方法合成batch
batch_size = args.batch_size
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
print('Using %g dataloader workers' % nw)
- train_data_loader = torch.utils.data.DataLoader(train_data_set,
- batch_size=batch_size,
- shuffle=True,
- pin_memory=True,
- num_workers=nw,
- collate_fn=train_data_set.collate_fn)
+
+ if train_sampler:
+        # when sampling by aspect ratio, the dataloader must be given a batch_sampler instead of batch_size
+ train_data_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_sampler=train_batch_sampler,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+ else:
+ train_data_loader = torch.utils.data.DataLoader(train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
# load validation data set
# coco2017 -> annotations -> instances_val2017.json
- val_data_set = CocoDetection(COCO_root, "val", data_transform["val"])
- val_data_set_loader = torch.utils.data.DataLoader(val_data_set,
- batch_size=batch_size,
- shuffle=False,
- pin_memory=True,
- num_workers=nw,
- collate_fn=train_data_set.collate_fn)
-
- # create model num_classes equal background + 80 classes
+ val_dataset = CocoDetection(COCO_root, "val", data_transform["val"])
+ val_data_loader = torch.utils.data.DataLoader(val_dataset,
+ batch_size=1,
+ shuffle=False,
+ pin_memory=True,
+ num_workers=nw,
+ collate_fn=train_dataset.collate_fn)
+
+    # create the model; num_classes equals background + object classes
model = create_model(num_classes=args.num_classes + 1)
# print(model)
@@ -123,12 +153,12 @@ def main(args):
lr_scheduler.step()
# evaluate on the test dataset
- coco_info = utils.evaluate(model, val_data_set_loader, device=device)
+ coco_info = utils.evaluate(model, val_data_loader, device=device)
# write into txt
with open(results_file, "a") as f:
# 写入的数据包括coco指标还有loss和learning rate
- result_info = [str(round(i, 4)) for i in coco_info + [mean_loss.item()]] + [str(round(lr, 6))]
+ result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
f.write(txt + "\n")
@@ -166,7 +196,7 @@ def main(args):
# 训练数据集的根目录
parser.add_argument('--data-path', default='/data/coco2017', help='dataset')
# 检测目标类别数(不包含背景)
- parser.add_argument('--num-classes', default=80, type=int, help='num_classes')
+ parser.add_argument('--num-classes', default=90, type=int, help='num_classes')
# 文件保存地址
parser.add_argument('--output-dir', default='./save_weights', help='path where to save')
# 若需要接着上次训练,则指定上次训练保存权重文件地址
@@ -177,7 +207,7 @@ def main(args):
parser.add_argument('--epochs', default=26, type=int, metavar='N',
help='number of total epochs to run')
# 学习率
- parser.add_argument('--lr', default=0.002, type=float,
+ parser.add_argument('--lr', default=0.005, type=float,
help='initial learning rate, 0.02 is the default value for training '
'on 8 gpus and 2 images_per_gpu')
# SGD的momentum参数
@@ -193,8 +223,9 @@ def main(args):
# 针对torch.optim.lr_scheduler.MultiStepLR的参数
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
# 训练的batch size(如果内存/GPU显存充裕,建议设置更大)
- parser.add_argument('--batch_size', default=2, type=int, metavar='N',
+ parser.add_argument('--batch_size', default=4, type=int, metavar='N',
help='batch size when training.')
+ parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
# 是否使用混合精度训练(需要GPU支持混合精度)
parser.add_argument("--amp", default=False, help="Use torch.cuda.amp for mixed precision training")
diff --git a/pytorch_object_detection/train_coco_dataset/train_multi_GPU.py b/pytorch_object_detection/train_coco_dataset/train_multi_GPU.py
index ccc9cd364..5ee8e5303 100644
--- a/pytorch_object_detection/train_coco_dataset/train_multi_GPU.py
+++ b/pytorch_object_detection/train_coco_dataset/train_multi_GPU.py
@@ -7,22 +7,20 @@
import transforms
from my_dataset import CocoDetection
-from backbone import vgg
+from backbone import resnet50
from network_files import FasterRCNN, AnchorsGenerator
import train_utils.train_eval_utils as utils
from train_utils import GroupedBatchSampler, create_aspect_ratio_groups, init_distributed_mode, save_on_master, mkdir
+from torchvision.models.feature_extraction import create_feature_extractor
def create_model(num_classes):
- # https://download.pytorch.org/models/vgg16-397923af.pth
- # 如果使用mobilenetv2的话就下载对应预训练权重并注释下面三行,接着把mobilenetv2模型对应的两行代码注释取消掉
- vgg_feature = vgg(model_name="vgg16", weights_path="./backbone/vgg16.pth").features
- backbone = torch.nn.Sequential(*list(vgg_feature._modules.values())[:-1]) # 删除feature中最后的maxpool层
- backbone.out_channels = 512
-
- # https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
- # backbone = MobileNetV2(weights_path="./backbone/mobilenet_v2.pth").features
- # backbone.out_channels = 1280 # 设置对应backbone输出特征矩阵的channels
+    # resnet50 as the backbone
+    # pretrained weights: https://download.pytorch.org/models/resnet50-19c8e357.pth
+ res50 = resnet50()
+ res50.load_state_dict(torch.load("./resnet50.pth", map_location="cpu"))
+ backbone = create_feature_extractor(res50, return_nodes={"layer3": "0"})
+ backbone.out_channels = 1024
anchor_generator = AnchorsGenerator(sizes=((32, 64, 128, 256, 512),),
aspect_ratios=((0.5, 1.0, 2.0),))
@@ -61,42 +59,45 @@ def main(args):
# load train data set
# coco2017 -> annotations -> instances_train2017.json
- train_data_set = CocoDetection(COCO_root, "train", data_transform["train"])
+ train_dataset = CocoDetection(COCO_root, "train", data_transform["train"])
# load validation data set
# coco2017 -> annotations -> instances_val2017.json
- val_data_set = CocoDetection(COCO_root, "val", data_transform["val"])
+ val_dataset = CocoDetection(COCO_root, "val", data_transform["val"])
print("Creating data loaders")
if args.distributed:
- train_sampler = torch.utils.data.distributed.DistributedSampler(train_data_set)
- test_sampler = torch.utils.data.distributed.DistributedSampler(val_data_set)
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+ test_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
else:
- train_sampler = torch.utils.data.RandomSampler(train_data_set)
- test_sampler = torch.utils.data.SequentialSampler(val_data_set)
+ train_sampler = torch.utils.data.RandomSampler(train_dataset)
+ test_sampler = torch.utils.data.SequentialSampler(val_dataset)
if args.aspect_ratio_group_factor >= 0:
# 统计所有图像比例在bins区间中的位置索引
- group_ids = create_aspect_ratio_groups(train_data_set, k=args.aspect_ratio_group_factor)
+ group_ids = create_aspect_ratio_groups(train_dataset, k=args.aspect_ratio_group_factor)
train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
else:
train_batch_sampler = torch.utils.data.BatchSampler(
train_sampler, args.batch_size, drop_last=True)
data_loader = torch.utils.data.DataLoader(
- train_data_set, batch_sampler=train_batch_sampler, num_workers=args.workers,
- collate_fn=train_data_set.collate_fn)
+ train_dataset, batch_sampler=train_batch_sampler, num_workers=args.workers,
+ collate_fn=train_dataset.collate_fn)
data_loader_test = torch.utils.data.DataLoader(
- val_data_set, batch_size=1,
+ val_dataset, batch_size=1,
sampler=test_sampler, num_workers=args.workers,
- collate_fn=train_data_set.collate_fn)
+ collate_fn=train_dataset.collate_fn)
print("Creating model")
- # create model num_classes equal background + 80 classes
+    # create the model; num_classes equals background + object classes
model = create_model(num_classes=args.num_classes + 1)
model.to(device)
+ if args.distributed and args.sync_bn:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
@@ -152,7 +153,7 @@ def main(args):
# write into txt
with open(results_file, "a") as f:
# 写入的数据包括coco指标还有loss和learning rate
- result_info = [str(round(i, 4)) for i in coco_info + [mean_loss.item()]] + [str(round(lr, 6))]
+ result_info = [f"{i:.4f}" for i in coco_info + [mean_loss.item()]] + [f"{lr:.6f}"]
txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
f.write(txt + "\n")
@@ -195,9 +196,9 @@ def main(args):
# 训练设备类型
parser.add_argument('--device', default='cuda', help='device')
# 检测目标类别数(不包含背景)
- parser.add_argument('--num-classes', default=80, type=int, help='num_classes')
+ parser.add_argument('--num-classes', default=90, type=int, help='num_classes')
# 每块GPU上的batch_size
- parser.add_argument('-b', '--batch-size', default=16, type=int,
+ parser.add_argument('-b', '--batch-size', default=4, type=int,
help='images per gpu, the total batch size is $NGPU x batch_size')
# 指定接着从哪个epoch数开始训练
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
@@ -237,6 +238,7 @@ def main(args):
parser.add_argument('--world-size', default=4, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+ parser.add_argument("--sync-bn", dest="sync_bn", help="Use sync batch norm", type=bool, default=False)
# 是否使用混合精度训练(需要GPU支持混合精度)
parser.add_argument("--amp", default=False, help="Use torch.cuda.amp for mixed precision training")
diff --git a/pytorch_object_detection/train_coco_dataset/train_utils/__init__.py b/pytorch_object_detection/train_coco_dataset/train_utils/__init__.py
index 78167b64d..ce519bc94 100644
--- a/pytorch_object_detection/train_coco_dataset/train_utils/__init__.py
+++ b/pytorch_object_detection/train_coco_dataset/train_utils/__init__.py
@@ -1,2 +1,3 @@
from .group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from .distributed_utils import init_distributed_mode, save_on_master, mkdir
+from .coco_eval import EvalCOCOMetric
diff --git a/pytorch_object_detection/train_coco_dataset/train_utils/coco_eval.py b/pytorch_object_detection/train_coco_dataset/train_utils/coco_eval.py
new file mode 100644
index 000000000..b8df0204d
--- /dev/null
+++ b/pytorch_object_detection/train_coco_dataset/train_utils/coco_eval.py
@@ -0,0 +1,163 @@
+import json
+import copy
+
+import numpy as np
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+import pycocotools.mask as mask_util
+from .distributed_utils import all_gather, is_main_process
+
+
+def merge(img_ids, eval_results):
+    """Gather data from all processes."""
+ all_img_ids = all_gather(img_ids)
+ all_eval_results = all_gather(eval_results)
+
+ merged_img_ids = []
+ for p in all_img_ids:
+ merged_img_ids.extend(p)
+
+ merged_eval_results = []
+ for p in all_eval_results:
+ merged_eval_results.extend(p)
+
+ merged_img_ids = np.array(merged_img_ids)
+
+ # keep only unique (and in sorted order) images
+    # drop duplicated image ids: to keep the number of images per process equal, multi-GPU training may assign the same image to more than one process
+ merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
+ merged_eval_results = [merged_eval_results[i] for i in idx]
+
+ return list(merged_img_ids), merged_eval_results
+
+
+class EvalCOCOMetric:
+ def __init__(self,
+ coco: COCO = None,
+ iou_type: str = None,
+ results_file_name: str = "predict_results.json",
+ classes_mapping: dict = None):
+ self.coco = copy.deepcopy(coco)
+        self.img_ids = []  # ids of the images handled by this process
+ self.results = []
+ self.aggregation_results = None
+ self.classes_mapping = classes_mapping
+ self.coco_evaluator = None
+ assert iou_type in ["bbox", "segm", "keypoints"]
+ self.iou_type = iou_type
+ self.results_file_name = results_file_name
+
+ def prepare_for_coco_detection(self, targets, outputs):
+        """Convert predictions into the format expected by COCOeval, for object detection."""
+        # iterate over the predictions of every image
+ for target, output in zip(targets, outputs):
+ if len(output) == 0:
+ continue
+
+ img_id = int(target["image_id"])
+ if img_id in self.img_ids:
+                # skip duplicated images
+ continue
+ self.img_ids.append(img_id)
+ per_image_boxes = output["boxes"]
+            # COCOeval expects every box in [x_min, y_min, w, h] format,
+            # while the predicted boxes are [x_min, y_min, x_max, y_max], so convert them
+ per_image_boxes[:, 2:] -= per_image_boxes[:, :2]
+ per_image_classes = output["labels"].tolist()
+ per_image_scores = output["scores"].tolist()
+
+ res_list = []
+            # iterate over every detected object
+ for object_score, object_class, object_box in zip(
+ per_image_scores, per_image_classes, per_image_boxes):
+ object_score = float(object_score)
+ class_idx = int(object_class)
+ if self.classes_mapping is not None:
+ class_idx = int(self.classes_mapping[str(class_idx)])
+ # We recommend rounding coordinates to the nearest tenth of a pixel
+ # to reduce resulting JSON file size.
+ object_box = [round(b, 2) for b in object_box.tolist()]
+
+ res = {"image_id": img_id,
+ "category_id": class_idx,
+ "bbox": object_box,
+ "score": round(object_score, 3)}
+ res_list.append(res)
+ self.results.append(res_list)
+
+ def prepare_for_coco_segmentation(self, targets, outputs):
+        """Convert predictions into the format expected by COCOeval, for instance segmentation."""
+        # iterate over the predictions of every image
+ for target, output in zip(targets, outputs):
+ if len(output) == 0:
+ continue
+
+ img_id = int(target["image_id"])
+ if img_id in self.img_ids:
+                # skip duplicated images
+ continue
+
+ self.img_ids.append(img_id)
+ per_image_masks = output["masks"]
+ per_image_classes = output["labels"].tolist()
+ per_image_scores = output["scores"].tolist()
+
+ masks = per_image_masks > 0.5
+
+ res_list = []
+            # iterate over every detected object
+ for mask, label, score in zip(masks, per_image_classes, per_image_scores):
+ rle = mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
+ rle["counts"] = rle["counts"].decode("utf-8")
+
+ class_idx = int(label)
+ if self.classes_mapping is not None:
+ class_idx = int(self.classes_mapping[str(class_idx)])
+
+ res = {"image_id": img_id,
+ "category_id": class_idx,
+ "segmentation": rle,
+ "score": round(score, 3)}
+ res_list.append(res)
+ self.results.append(res_list)
+
+ def update(self, targets, outputs):
+ if self.iou_type == "bbox":
+ self.prepare_for_coco_detection(targets, outputs)
+ elif self.iou_type == "segm":
+ self.prepare_for_coco_segmentation(targets, outputs)
+ else:
+ raise KeyError(f"not support iou_type: {self.iou_type}")
+
+ def synchronize_results(self):
+        # synchronize the data gathered by all processes
+ eval_ids, eval_results = merge(self.img_ids, self.results)
+ self.aggregation_results = {"img_ids": eval_ids, "results": eval_results}
+
+        # only the main process needs to save the results
+ if is_main_process():
+ results = []
+ [results.extend(i) for i in eval_results]
+ # write predict results into json file
+ json_str = json.dumps(results, indent=4)
+ with open(self.results_file_name, 'w') as json_file:
+ json_file.write(json_str)
+
+ def evaluate(self):
+        # evaluation only needs to run on the main process
+ if is_main_process():
+ # accumulate predictions from all images
+ coco_true = self.coco
+ coco_pre = coco_true.loadRes(self.results_file_name)
+
+ self.coco_evaluator = COCOeval(cocoGt=coco_true, cocoDt=coco_pre, iouType=self.iou_type)
+
+ self.coco_evaluator.evaluate()
+ self.coco_evaluator.accumulate()
+ print(f"IoU metric: {self.iou_type}")
+ self.coco_evaluator.summarize()
+
+ coco_info = self.coco_evaluator.stats.tolist() # numpy to list
+ return coco_info
+ else:
+ return None
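
Roughly how the new `EvalCOCOMetric` is meant to be driven from an evaluation loop, mirroring its use in `train_eval_utils.py` and `validation.py`; `model`, `device`, `val_dataset` and `val_data_loader` are placeholders assumed to exist:

```python
# hypothetical evaluation loop built around EvalCOCOMetric
import torch
from train_utils import EvalCOCOMetric

det_metric = EvalCOCOMetric(val_dataset.coco, iou_type="bbox",
                            results_file_name="det_results.json")
model.eval()
with torch.no_grad():
    for images, targets in val_data_loader:
        outputs = model([img.to(device) for img in images])
        outputs = [{k: v.to("cpu") for k, v in o.items()} for o in outputs]
        det_metric.update(targets, outputs)   # accumulate COCO-format results

det_metric.synchronize_results()              # gather across ranks, dump det_results.json
stats = det_metric.evaluate()                 # COCOeval stats on the main process, else None
```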
diff --git a/pytorch_object_detection/train_coco_dataset/train_utils/distributed_utils.py b/pytorch_object_detection/train_coco_dataset/train_utils/distributed_utils.py
index 95d0b11e1..80b2412c6 100644
--- a/pytorch_object_detection/train_coco_dataset/train_utils/distributed_utils.py
+++ b/pytorch_object_detection/train_coco_dataset/train_utils/distributed_utils.py
@@ -83,38 +83,8 @@ def all_gather(data):
if world_size == 1:
return [data]
- # serialized to a Tensor
- # 将数据转为tensor
- buffer = pickle.dumps(data)
- storage = torch.ByteStorage.from_buffer(buffer)
- tensor = torch.ByteTensor(storage).to("cuda")
-
- # obtain Tensor size of each rank
- # 获取每个进程中tensor的大小,并求最大值
- local_size = torch.tensor([tensor.numel()], device="cuda")
- size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
- dist.all_gather(size_list, local_size)
- size_list = [int(size.item()) for size in size_list]
- max_size = max(size_list)
-
- # receiving Tensor from all ranks
- # we pad the tensor because torch all_gather does not support
- # gathering tensors of different shapes
- # 由于现在all_gather方法只能传播相同长度的数据,所以需要pad处理
- tensor_list = []
- for _ in size_list:
- tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
- if local_size != max_size:
- padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
- tensor = torch.cat((tensor, padding), dim=0)
- dist.all_gather(tensor_list, tensor)
-
- # 将从各个进程中获取得到的数据整合在一起
- # 注意要将多余的pad给删除掉
- data_list = []
- for size, tensor in zip(size_list, tensor_list):
- buffer = tensor.cpu().numpy().tobytes()[:size]
- data_list.append(pickle.loads(buffer))
+ data_list = [None] * world_size
+ dist.all_gather_object(data_list, data)
return data_list
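
The hand-rolled pickle/pad/`all_gather` sequence is replaced by `torch.distributed.all_gather_object`, which serializes arbitrary picklable objects and handles unequal sizes internally (available since PyTorch 1.8, so covered by the pinned torch 1.10). A hedged sketch of the call pattern in a multi-process job:

```python
import torch.distributed as dist

# assumes dist.init_process_group(...) has already been called by the launcher
def gather_detection_counts(num_dets_this_rank: int):
    """Every rank contributes one Python object; every rank receives them all."""
    world_size = dist.get_world_size()
    payload = {"rank": dist.get_rank(), "num_dets": num_dets_this_rank}
    gathered = [None] * world_size
    dist.all_gather_object(gathered, payload)  # handles pickling and size padding
    return gathered

# with 2 processes, both ranks end up with something like
# [{'rank': 0, 'num_dets': 17}, {'rank': 1, 'num_dets': 23}]
```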
diff --git a/pytorch_object_detection/train_coco_dataset/train_utils/train_eval_utils.py b/pytorch_object_detection/train_coco_dataset/train_utils/train_eval_utils.py
index b3710a208..ba009fa43 100644
--- a/pytorch_object_detection/train_coco_dataset/train_utils/train_eval_utils.py
+++ b/pytorch_object_detection/train_coco_dataset/train_utils/train_eval_utils.py
@@ -1,12 +1,11 @@
import math
import sys
import time
-import json
import torch
-from pycocotools.cocoeval import COCOeval
import train_utils.distributed_utils as utils
+from .coco_eval import EvalCOCOMetric
def train_one_epoch(model, optimizer, data_loader, device, epoch,
@@ -73,9 +72,7 @@ def evaluate(model, data_loader, device):
metric_logger = utils.MetricLogger(delimiter=" ")
header = "Test: "
- coco91to80 = data_loader.dataset.coco91to80
- coco80to91 = dict([(str(v), k) for k, v in coco91to80.items()])
- results = []
+ det_metric = EvalCOCOMetric(data_loader.dataset.coco, iou_type="bbox", results_file_name="det_results.json")
for image, targets in metric_logger.log_every(data_loader, 100, header):
image = list(img.to(device) for img in image)
@@ -89,36 +86,7 @@ def evaluate(model, data_loader, device):
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
model_time = time.time() - model_time
- # 遍历每张图像的预测结果
- for target, output in zip(targets, outputs):
- if len(output) == 0:
- continue
-
- img_id = int(target["image_id"])
- per_image_boxes = output["boxes"]
- # 对于coco_eval, 需要的每个box的数据格式为[x_min, y_min, w, h]
- # 而我们预测的box格式是[x_min, y_min, x_max, y_max],所以需要转下格式
- per_image_boxes[:, 2:] -= per_image_boxes[:, :2]
- per_image_classes = output["labels"]
- per_image_scores = output["scores"]
-
- # 遍历每个目标的信息
- for object_score, object_class, object_box in zip(
- per_image_scores, per_image_classes, per_image_boxes):
- object_score = float(object_score)
- # 要将类别信息还原回coco91中
- coco80_class = int(object_class)
- coco91_class = int(coco80to91[str(coco80_class)])
- # We recommend rounding coordinates to the nearest tenth of a pixel
- # to reduce resulting JSON file size.
- object_box = [round(b, 2) for b in object_box.tolist()]
-
- res = {"image_id": img_id,
- "category_id": coco91_class,
- "bbox": object_box,
- "score": round(object_score, 3)}
- results.append(res)
-
+ det_metric.update(targets, outputs)
metric_logger.update(model_time=model_time)
# gather the stats from all processes
@@ -126,29 +94,10 @@ def evaluate(model, data_loader, device):
print("Averaged stats:", metric_logger)
# 同步所有进程中的数据
- all_results = utils.all_gather(results)
+ det_metric.synchronize_results()
if utils.is_main_process():
- # 将所有进程上的数据合并到一个list当中
- results = []
- for res in all_results:
- results.extend(res)
-
- # write predict results into json file
- json_str = json.dumps(results, indent=4)
- with open('predict_tmp.json', 'w') as json_file:
- json_file.write(json_str)
-
- # accumulate predictions from all images
- coco_true = data_loader.dataset.coco
- coco_pre = coco_true.loadRes('predict_tmp.json')
-
- coco_evaluator = COCOeval(cocoGt=coco_true, cocoDt=coco_pre, iouType="bbox")
- coco_evaluator.evaluate()
- coco_evaluator.accumulate()
- coco_evaluator.summarize()
-
- coco_info = coco_evaluator.stats.tolist() # numpy to list
+ coco_info = det_metric.evaluate()
else:
coco_info = None
diff --git a/pytorch_object_detection/train_coco_dataset/validation.py b/pytorch_object_detection/train_coco_dataset/validation.py
index e6ea58dae..98a230f77 100644
--- a/pytorch_object_detection/train_coco_dataset/validation.py
+++ b/pytorch_object_detection/train_coco_dataset/validation.py
@@ -10,12 +10,13 @@
import torchvision
from tqdm import tqdm
import numpy as np
-from pycocotools.cocoeval import COCOeval
+from torchvision.models.feature_extraction import create_feature_extractor
import transforms
from network_files import FasterRCNN, AnchorsGenerator
from my_dataset import CocoDetection
-from backbone import vgg
+from backbone import resnet50
+from train_utils import EvalCOCOMetric
def summarize(self, catId=None):
@@ -99,11 +100,10 @@ def main(parser_data):
}
# read class_indict
- label_json_path = './coco80_indices.json'
+ label_json_path = './coco91_indices.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
- json_file = open(label_json_path, 'r')
- category_index = json.load(json_file)
- json_file.close()
+ with open(label_json_path, 'r') as f:
+ category_index = json.load(f)
coco_root = parser_data.data_path
@@ -122,9 +122,9 @@ def main(parser_data):
collate_fn=val_dataset.collate_fn)
# create model
- vgg_feature = vgg(model_name="vgg16", weights_path="./backbone/vgg16.pth").features
- backbone = torch.nn.Sequential(*list(vgg_feature._modules.values())[:-1]) # 删除feature中最后的maxpool层
- backbone.out_channels = 512
+ res50 = resnet50()
+ backbone = create_feature_extractor(res50, return_nodes={"layer3": "0"})
+ backbone.out_channels = 1024
anchor_generator = AnchorsGenerator(sizes=((32, 64, 128, 256, 512),),
aspect_ratios=((0.5, 1.0, 2.0),))
@@ -140,19 +140,19 @@ def main(parser_data):
box_roi_pool=roi_pooler)
# 载入你自己训练好的模型权重
- weights_path = parser_data.weights
+ weights_path = parser_data.weights_path
assert os.path.exists(weights_path), "not found {} file.".format(weights_path)
- model.load_state_dict(torch.load(weights_path, map_location=device)['model'])
+ weights_dict = torch.load(weights_path, map_location='cpu')
+ weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
+ model.load_state_dict(weights_dict)
# print(model)
model.to(device)
# evaluate on the val dataset
cpu_device = torch.device("cpu")
- coco91to80 = val_dataset.coco91to80
- coco80to91 = dict([(str(v), k) for k, v in coco91to80.items()])
- results = []
+ det_metric = EvalCOCOMetric(val_dataset.coco, "bbox", "det_results.json")
model.eval()
with torch.no_grad():
for image, targets in tqdm(val_dataset_loader, desc="validation..."):
@@ -161,62 +161,21 @@ def main(parser_data):
# inference
outputs = model(image)
-
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
+ det_metric.update(targets, outputs)
- # 遍历每张图像的预测结果
- for target, output in zip(targets, outputs):
- if len(output) == 0:
- continue
-
- img_id = int(target["image_id"])
- per_image_boxes = output["boxes"]
- # 对于coco_eval, 需要的每个box的数据格式为[x_min, y_min, w, h]
- # 而我们预测的box格式是[x_min, y_min, x_max, y_max],所以需要转下格式
- per_image_boxes[:, 2:] -= per_image_boxes[:, :2]
- per_image_classes = output["labels"]
- per_image_scores = output["scores"]
-
- # 遍历每个目标的信息
- for object_score, object_class, object_box in zip(
- per_image_scores, per_image_classes, per_image_boxes):
- object_score = float(object_score)
- # 要将类别信息还原回coco91中
- coco80_class = int(object_class)
- coco91_class = int(coco80to91[str(coco80_class)])
- # We recommend rounding coordinates to the nearest tenth of a pixel
- # to reduce resulting JSON file size.
- object_box = [round(b, 2) for b in object_box.tolist()]
-
- res = {"image_id": img_id,
- "category_id": coco91_class,
- "bbox": object_box,
- "score": round(object_score, 3)}
- results.append(res)
-
- # accumulate predictions from all images
- # write predict results into json file
- json_str = json.dumps(results, indent=4)
- with open('predict_tmp.json', 'w') as json_file:
- json_file.write(json_str)
-
- # accumulate predictions from all images
- coco_true = val_dataset.coco
- coco_pre = coco_true.loadRes('predict_tmp.json')
-
- coco_evaluator = COCOeval(cocoGt=coco_true, cocoDt=coco_pre, iouType="bbox")
- coco_evaluator.evaluate()
- coco_evaluator.accumulate()
- coco_evaluator.summarize()
+ det_metric.synchronize_results()
+ det_metric.evaluate()
# calculate COCO info for all classes
- coco_stats, print_coco = summarize(coco_evaluator)
+ coco_stats, print_coco = summarize(det_metric.coco_evaluator)
# calculate voc info for every classes(IoU=0.5)
voc_map_info_list = []
- for i in range(len(category_index)):
- stats, _ = summarize(coco_evaluator, catId=i)
- voc_map_info_list.append(" {:15}: {}".format(category_index[str(i + 1)], stats[1]))
+ classes = [v for v in category_index.values() if v != "N/A"]
+ for i in range(len(classes)):
+ stats, _ = summarize(det_metric.coco_evaluator, catId=i)
+ voc_map_info_list.append(" {:15}: {}".format(classes[i], stats[1]))
print_voc = "\n".join(voc_map_info_list)
print(print_voc)
@@ -241,13 +200,13 @@ def main(parser_data):
parser.add_argument('--device', default='cuda', help='device')
# 检测目标类别数
- parser.add_argument('--num-classes', type=int, default='80', help='number of classes')
+ parser.add_argument('--num-classes', type=int, default=90, help='number of classes')
# 数据集的根目录(coco2017根目录)
parser.add_argument('--data-path', default='/data/coco2017', help='dataset root')
# 训练好的权重文件
- parser.add_argument('--weights', default='./save_weights/model.pth', type=str, help='training weights')
+ parser.add_argument('--weights-path', default='./save_weights/model.pth', type=str, help='training weights')
# batch size
parser.add_argument('--batch_size', default=1, type=int, metavar='N',
diff --git a/pytorch_object_detection/yolov3_spp/README.md b/pytorch_object_detection/yolov3_spp/README.md
index aee46f384..9d9301a2e 100644
--- a/pytorch_object_detection/yolov3_spp/README.md
+++ b/pytorch_object_detection/yolov3_spp/README.md
@@ -3,9 +3,9 @@
## 1 环境配置:
* Python3.6或者3.7
* Pytorch1.7.1(注意:必须是1.6.0或以上,因为使用官方提供的混合精度训练1.6.0后才支持)
-* pycocotools(Linux: ```pip install pycocotools```;
- Windows: ```pip install pycocotools-windows```(不需要额外安装vs))
-* 更多环境配置信息,请查看```requirements.txt```文件
+* pycocotools(Linux: `pip install pycocotools`;
+ Windows: `pip install pycocotools-windows`(不需要额外安装vs))
+* 更多环境配置信息,请查看`requirements.txt`文件
* 最好使用GPU训练
## 2 文件结构:
@@ -39,8 +39,8 @@
```
## 3 训练数据的准备以及目录结构
-* 这里建议标注数据时直接生成yolo格式的标签文件```.txt```,推荐使用免费开源的标注软件(支持yolo格式),[https://github.com/tzutalin/labelImg](https://github.com/tzutalin/labelImg)
-* 如果之前已经标注成pascal voc的```.xml```格式了也没关系,我写了个voc转yolo格式的转化脚本,4.1会讲怎么使用
+* 这里建议标注数据时直接生成yolo格式的标签文件`.txt`,推荐使用免费开源的标注软件(支持yolo格式),[https://github.com/tzutalin/labelImg](https://github.com/tzutalin/labelImg)
+* 如果之前已经标注成pascal voc的`.xml`格式了也没关系,我写了个voc转yolo格式的转化脚本,4.1会讲怎么使用
* 测试图像时最好将图像缩放到32的倍数
* 标注好的数据集请按照以下目录结构进行摆放:
```
@@ -58,12 +58,12 @@
├── data 利用数据集生成的一系列相关准备文件目录
│ ├── my_train_data.txt: 该文件里存储的是所有训练图片的路径地址
│ ├── my_val_data.txt: 该文件里存储的是所有验证图片的路径地址
-│ ├── my_data_label.names: 该文件里存储的是所有类别的名称,一个类别对应一行(这里会根据```.json```文件自动生成)
+│ ├── my_data_label.names: 该文件里存储的是所有类别的名称,一个类别对应一行(这里会根据`.json`文件自动生成)
│ └── my_data.data: 该文件里记录的是类别数类别信息、train以及valid对应的txt文件
```
### 4.1 将VOC标注数据转为YOLO标注数据(如果你的数据已经是YOLO格式了,可跳过该步骤)
-* 使用```trans_voc2yolo.py```脚本进行转换,并在```./data/```文件夹下生成```my_data_label.names```标签文件,
+* 使用`trans_voc2yolo.py`脚本进行转换,并在`./data/`文件夹下生成`my_data_label.names`标签文件,
* 执行脚本前,需要根据自己的路径修改以下参数
```python
# voc数据集根目录以及版本
@@ -80,7 +80,7 @@ save_file_root = "/home/wz/my_project/my_yolo_dataset"
# label标签对应json文件
label_json_path = './data/pascal_voc_classes.json'
```
-* 生成的```my_data_label.names```标签文件格式如下
+* 生成的`my_data_label.names`标签文件格式如下
```text
aeroplane
bicycle
@@ -92,7 +92,7 @@ bus
```
### 4.2 根据摆放好的数据集信息生成一系列相关准备文件
-* 使用```calculate_dataset.py```脚本生成```my_train_data.txt```文件、```my_val_data.txt```文件以及```my_data.data```文件,并生成新的```my_yolov3.cfg```文件
+* 使用`calculate_dataset.py`脚本生成`my_train_data.txt`文件、`my_val_data.txt`文件以及`my_data.data`文件,并生成新的`my_yolov3.cfg`文件
* 执行脚本前,需要根据自己的路径修改以下参数
```python
# 训练集的labels目录路径
@@ -106,21 +106,22 @@ cfg_path = "./cfg/yolov3-spp.cfg"
```
## 5 预训练权重下载地址(下载后放入weights文件夹中):
-* ```yolov3-spp-ultralytics-416.pt```: 链接: https://pan.baidu.com/s/1cK3USHKxDx-d5dONij52lA 密码: r3vm
-* ```yolov3-spp-ultralytics-512.pt```: 链接: https://pan.baidu.com/s/1k5yeTZZNv8Xqf0uBXnUK-g 密码: e3k1
-* ```yolov3-spp-ultralytics-608.pt```: 链接: https://pan.baidu.com/s/1GI8BA0wxeWMC0cjrC01G7Q 密码: ma3t
-* ```yolov3spp-voc-512.pt``` **(这是我在视频演示训练中得到的权重)**: 链接: https://pan.baidu.com/s/1aFAtaHlge0ieFtQ9nhmj3w 密码: 8ph3
+* `yolov3-spp-ultralytics-416.pt`: 链接: https://pan.baidu.com/s/1cK3USHKxDx-d5dONij52lA 密码: r3vm
+* `yolov3-spp-ultralytics-512.pt`: 链接: https://pan.baidu.com/s/1k5yeTZZNv8Xqf0uBXnUK-g 密码: e3k1
+* `yolov3-spp-ultralytics-608.pt`: 链接: https://pan.baidu.com/s/1GI8BA0wxeWMC0cjrC01G7Q 密码: ma3t
+* `yolov3spp-voc-512.pt` **(这是我在视频演示训练中得到的权重)**: 链接: https://pan.baidu.com/s/1aFAtaHlge0ieFtQ9nhmj3w 密码: 8ph3
## 6 数据集,本例程使用的是PASCAL VOC2012数据集
-* ```Pascal VOC2012``` train/val数据集下载地址:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
+* `Pascal VOC2012` train/val数据集下载地址:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
* 如果不了解数据集或者想使用自己的数据集进行训练,请参考我的bilibili:https://b23.tv/F1kSCK
## 7 使用方法
* 确保提前准备好数据集
* 确保提前下载好对应预训练模型权重
* 若要使用单GPU训练或者使用CPU训练,直接使用train.py训练脚本
-* 若要使用多GPU训练,使用```python -m torch.distributed.launch --nproc_per_node=8 --use_env train_multi_GPU.py```指令,```nproc_per_node```参数为使用GPU数量
+* 若要使用多GPU训练,使用`python -m torch.distributed.launch --nproc_per_node=8 --use_env train_multi_GPU.py`指令,`nproc_per_node`参数为使用GPU数量
+* The `results.txt` file saved during training records, for every epoch, the COCO metrics on the validation set: the first 12 values are the COCO metrics and the last two are the mean training loss and the learning rate
## 如果对YOLOv3 SPP网络原理不是很理解可参考我的bilibili
[https://www.bilibili.com/video/BV1yi4y1g7ro?p=3](https://www.bilibili.com/video/BV1yi4y1g7ro?p=3)
diff --git a/pytorch_object_detection/yolov3_spp/build_utils/img_utils.py b/pytorch_object_detection/yolov3_spp/build_utils/img_utils.py
index fc4c71929..cabc6c9b3 100644
--- a/pytorch_object_detection/yolov3_spp/build_utils/img_utils.py
+++ b/pytorch_object_detection/yolov3_spp/build_utils/img_utils.py
@@ -37,8 +37,8 @@ def letterbox(img: np.ndarray,
dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding
elif scale_fill: # stretch 简单粗暴的将图片缩放到指定尺寸
dw, dh = 0, 0
- new_unpad = new_shape
- ratio = new_shape[0] / shape[1], new_shape[1] / shape[0] # wh ratios
+ new_unpad = new_shape[::-1] # [h, w] -> [w, h]
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # wh ratios
dw /= 2 # divide padding into 2 sides 将padding分到上下,左右两侧
dh /= 2
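
The `scale_fill` fix makes the stretch branch hand `cv2.resize` a `(w, h)` size and report the ratio as `(new_w / old_w, new_h / old_h)`. A tiny sketch of that convention with a non-square target (illustrative only):

```python
import cv2
import numpy as np

img = np.zeros((375, 500, 3), dtype=np.uint8)   # img.shape[:2] = (h, w) = (375, 500)
new_shape = (512, 640)                          # target (h, w)

new_unpad = new_shape[::-1]                     # (h, w) -> (w, h), the order cv2.resize expects
ratio = new_shape[1] / img.shape[1], new_shape[0] / img.shape[0]   # (w, h) scale ratios
stretched = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
print(stretched.shape, ratio)                   # (512, 640, 3) (1.28, 1.3653...)
```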
diff --git a/pytorch_object_detection/yolov3_spp/build_utils/utils.py b/pytorch_object_detection/yolov3_spp/build_utils/utils.py
index 2c6f73a6d..bf08ea70f 100755
--- a/pytorch_object_detection/yolov3_spp/build_utils/utils.py
+++ b/pytorch_object_detection/yolov3_spp/build_utils/utils.py
@@ -273,7 +273,7 @@ def build_targets(p, targets, model):
# Build targets for compute_loss(), input targets(image_idx,class,x,y,w,h)
nt = targets.shape[0]
tcls, tbox, indices, anch = [], [], [], []
- gain = torch.ones(6, device=targets.device) # normalized to gridspace gain
+ gain = torch.ones(6, device=targets.device).long() # normalized to gridspace gain
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
for i, j in enumerate(model.yolo_layers): # j: [89, 101, 113]
diff --git a/pytorch_object_detection/yolov3_spp/draw_box_utils.py b/pytorch_object_detection/yolov3_spp/draw_box_utils.py
index 4d545148e..835d7f7c1 100644
--- a/pytorch_object_detection/yolov3_spp/draw_box_utils.py
+++ b/pytorch_object_detection/yolov3_spp/draw_box_utils.py
@@ -1,7 +1,7 @@
-import collections
-from PIL import Image
+from PIL.Image import Image, fromarray
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
+from PIL import ImageColor
import numpy as np
STANDARD_COLORS = [
@@ -31,69 +31,123 @@
]
-def filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map):
- for i in range(boxes.shape[0]):
- if scores[i] > thresh:
- box = tuple(boxes[i].tolist()) # numpy -> list -> tuple
- if classes[i] in category_index.keys():
- class_name = category_index[classes[i]]
- else:
- class_name = 'N/A'
- display_str = str(class_name)
- display_str = '{}: {}%'.format(display_str, int(100 * scores[i]))
- box_to_display_str_map[box].append(display_str)
- box_to_color_map[box] = STANDARD_COLORS[
- classes[i] % len(STANDARD_COLORS)]
- else:
- break # 网络输出概率已经排序过,当遇到一个不满足后面的肯定不满足
-
-
-def draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color):
+def draw_text(draw,
+ box: list,
+ cls: int,
+ score: float,
+ category_index: dict,
+ color: str,
+ font: str = 'arial.ttf',
+ font_size: int = 24):
+ """
+    Draw the object bounding box and class label onto the image.
+ """
try:
- font = ImageFont.truetype('arial.ttf', 20)
+ font = ImageFont.truetype(font, font_size)
except IOError:
font = ImageFont.load_default()
+ left, top, right, bottom = box
# If the total height of the display strings added to the top of the bounding
# box exceeds the top of the image, stack the strings below the bounding box
# instead of above.
- display_str_heights = [font.getsize(ds)[1] for ds in box_to_display_str_map[box]]
+ display_str = f"{category_index[str(cls)]}: {int(100 * score)}%"
+ display_str_heights = [font.getsize(ds)[1] for ds in display_str]
# Each display_str has a top and bottom margin of 0.05x.
- total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
+ display_str_height = (1 + 2 * 0.05) * max(display_str_heights)
- if top > total_display_str_height:
+ if top > display_str_height:
+ text_top = top - display_str_height
text_bottom = top
else:
- text_bottom = bottom + total_display_str_height
- # Reverse list and print from bottom to top.
- for display_str in box_to_display_str_map[box][::-1]:
- text_width, text_height = font.getsize(display_str)
- margin = np.ceil(0.05 * text_height)
- draw.rectangle([(left, text_bottom - text_height - 2 * margin),
- (left + text_width, text_bottom)], fill=color)
- draw.text((left + margin, text_bottom - text_height - margin),
- display_str,
+ text_top = bottom
+ text_bottom = bottom + display_str_height
+
+ for ds in display_str:
+ text_width, text_height = font.getsize(ds)
+ margin = np.ceil(0.05 * text_width)
+ draw.rectangle([(left, text_top),
+ (left + text_width + 2 * margin, text_bottom)], fill=color)
+ draw.text((left + margin, text_top),
+ ds,
fill='black',
font=font)
- text_bottom -= text_height - 2 * margin
-
-
-def draw_box(image, boxes, classes, scores, category_index, thresh=0.1, line_thickness=3):
- box_to_display_str_map = collections.defaultdict(list)
- box_to_color_map = collections.defaultdict(str)
-
- filter_low_thresh(boxes, scores, classes, category_index, thresh, box_to_display_str_map, box_to_color_map)
-
- # Draw all boxes onto image.
- if isinstance(image, np.ndarray):
- image = Image.fromarray(image)
- draw = ImageDraw.Draw(image)
- im_width, im_height = image.size
- for box, color in box_to_color_map.items():
- xmin, ymin, xmax, ymax = box
- (left, right, top, bottom) = (xmin * 1, xmax * 1,
- ymin * 1, ymax * 1)
- draw.line([(left, top), (left, bottom), (right, bottom),
- (right, top), (left, top)], width=line_thickness, fill=color)
- draw_text(draw, box_to_display_str_map, box, left, right, top, bottom, color)
+ left += text_width
+
+
+def draw_masks(image, masks, colors, thresh: float = 0.7, alpha: float = 0.5):
+ np_image = np.array(image)
+ masks = np.where(masks > thresh, True, False)
+
+ # colors = np.array(colors)
+ img_to_draw = np.copy(np_image)
+ # TODO: There might be a way to vectorize this
+ for mask, color in zip(masks, colors):
+ img_to_draw[mask] = color
+
+ out = np_image * (1 - alpha) + img_to_draw * alpha
+ return fromarray(out.astype(np.uint8))
+
+
+def draw_objs(image: Image,
+ boxes: np.ndarray = None,
+ classes: np.ndarray = None,
+ scores: np.ndarray = None,
+ masks: np.ndarray = None,
+ category_index: dict = None,
+ box_thresh: float = 0.1,
+ mask_thresh: float = 0.5,
+ line_thickness: int = 8,
+ font: str = 'arial.ttf',
+ font_size: int = 24,
+ draw_boxes_on_image: bool = True,
+ draw_masks_on_image: bool = False):
+ """
+    Draw the bounding boxes, class labels and masks onto the image.
+    Args:
+        image: the image to draw on
+        boxes: bounding box information
+        classes: class information
+        scores: score (probability) information
+        masks: mask information
+        category_index: dict mapping class ids to names
+        box_thresh: score threshold used to filter detections
+        mask_thresh:
+        line_thickness: bounding-box line width
+        font: font type
+        font_size: font size
+        draw_boxes_on_image:
+        draw_masks_on_image:
+
+ Returns:
+
+ """
+
+    # filter out low-score objects
+ idxs = np.greater(scores, box_thresh)
+ boxes = boxes[idxs]
+ classes = classes[idxs]
+ scores = scores[idxs]
+ if masks is not None:
+ masks = masks[idxs]
+ if len(boxes) == 0:
+ return image
+
+ colors = [ImageColor.getrgb(STANDARD_COLORS[cls % len(STANDARD_COLORS)]) for cls in classes]
+
+ if draw_boxes_on_image:
+ # Draw all boxes onto image.
+ draw = ImageDraw.Draw(image)
+ for box, cls, score, color in zip(boxes, classes, scores, colors):
+ left, top, right, bottom = box
+            # draw the bounding box
+ draw.line([(left, top), (left, bottom), (right, bottom),
+ (right, top), (left, top)], width=line_thickness, fill=color)
+            # draw the class label and score
+ draw_text(draw, box.tolist(), int(cls), float(score), category_index, color, font, font_size)
+
+ if draw_masks_on_image and (masks is not None):
+ # Draw all mask onto image.
+ image = draw_masks(image, masks, colors, mask_thresh)
+
return image
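
The new `draw_objs` takes a PIL image plus numpy arrays and returns the annotated PIL image, unlike the old in-place `draw_box`. A small usage sketch with made-up detections; the class id and the `category_index` entry are hypothetical:

```python
import numpy as np
from PIL import Image
from draw_box_utils import draw_objs

image = Image.open("test.jpg")
boxes = np.array([[50., 40., 210., 300.]])   # [x_min, y_min, x_max, y_max] in pixels
classes = np.array([12])                     # hypothetical class id
scores = np.array([0.93])
category_index = {"12": "dog"}               # str(class id) -> class name

plot_img = draw_objs(image, boxes, classes, scores,
                     category_index=category_index,
                     box_thresh=0.5, line_thickness=3,
                     font='arial.ttf', font_size=20)
plot_img.save("test_result.jpg")
```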
diff --git a/pytorch_object_detection/yolov3_spp/load_onnx_test.py b/pytorch_object_detection/yolov3_spp/load_onnx_test.py
index 1ac3dcbd5..de33fc3dd 100644
--- a/pytorch_object_detection/yolov3_spp/load_onnx_test.py
+++ b/pytorch_object_detection/yolov3_spp/load_onnx_test.py
@@ -150,7 +150,7 @@ def nms(bboxes: np.ndarray, iou_threshold=0.5, soft_threshold=0.3, sigma=0.5, me
bboxes = bboxes[iou_mask]
- return np.array(best_bboxes_index, dtype=np.int8)
+ return np.array(best_bboxes_index, dtype=np.int32)
def post_process(pred: np.ndarray, multi_label=False, conf_thres=0.3):
diff --git a/pytorch_object_detection/yolov3_spp/predict_test.py b/pytorch_object_detection/yolov3_spp/predict_test.py
index 67dd40b4f..bbd2d87b4 100644
--- a/pytorch_object_detection/yolov3_spp/predict_test.py
+++ b/pytorch_object_detection/yolov3_spp/predict_test.py
@@ -6,16 +6,17 @@
import cv2
import numpy as np
from matplotlib import pyplot as plt
+from PIL import Image
from build_utils import img_utils, torch_utils, utils
from models import Darknet
-from draw_box_utils import draw_box
+from draw_box_utils import draw_objs
def main():
img_size = 512 # 必须是32的整数倍 [416, 512, 608]
cfg = "cfg/my_yolov3.cfg" # 改成生成的.cfg文件
- weights = "weights/yolov3spp-voc-512.pt" # 改成自己训练好的权重文件
+    weights_path = "weights/yolov3spp-voc-512.pt"  # change to your own trained weights file
json_path = "./data/pascal_voc_classes.json" # json标签文件
img_path = "test.jpg"
assert os.path.exists(cfg), "cfg file {} dose not exist.".format(cfg)
@@ -23,17 +24,19 @@ def main():
assert os.path.exists(json_path), "json file {} dose not exist.".format(json_path)
assert os.path.exists(img_path), "image file {} dose not exist.".format(img_path)
- json_file = open(json_path, 'r')
- class_dict = json.load(json_file)
- json_file.close()
- category_index = {v: k for k, v in class_dict.items()}
+ with open(json_path, 'r') as f:
+ class_dict = json.load(f)
+
+ category_index = {str(v): str(k) for k, v in class_dict.items()}
input_size = (img_size, img_size)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Darknet(cfg, img_size)
- model.load_state_dict(torch.load(weights, map_location=device)["model"])
+ weights_dict = torch.load(weights_path, map_location='cpu')
+ weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
+ model.load_state_dict(weights_dict)
model.to(device)
model.eval()
@@ -75,11 +78,20 @@ def main():
scores = pred[:, 4].detach().cpu().numpy()
classes = pred[:, 5].detach().cpu().numpy().astype(np.int) + 1
- img_o = draw_box(img_o[:, :, ::-1], bboxes, classes, scores, category_index)
- plt.imshow(img_o)
+ pil_img = Image.fromarray(img_o[:, :, ::-1])
+ plot_img = draw_objs(pil_img,
+ bboxes,
+ classes,
+ scores,
+ category_index=category_index,
+ box_thresh=0.2,
+ line_thickness=3,
+ font='arial.ttf',
+ font_size=20)
+ plt.imshow(plot_img)
plt.show()
-
- img_o.save("test_result.jpg")
+    # save the prediction result image
+ plot_img.save("test_result.jpg")
if __name__ == "__main__":
diff --git a/pytorch_object_detection/yolov3_spp/validation.py b/pytorch_object_detection/yolov3_spp/validation.py
index 34737cbae..074b3c839 100644
--- a/pytorch_object_detection/yolov3_spp/validation.py
+++ b/pytorch_object_detection/yolov3_spp/validation.py
@@ -89,9 +89,9 @@ def main(parser_data):
# read class_indict
label_json_path = './data/pascal_voc_classes.json'
assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
- json_file = open(label_json_path, 'r')
- class_dict = json.load(json_file)
- json_file.close()
+ with open(label_json_path, 'r') as f:
+ class_dict = json.load(f)
+
category_index = {v: k for k, v in class_dict.items()}
data_dict = parse_data_cfg(parser_data.data)
@@ -116,7 +116,9 @@ def main(parser_data):
# create model
model = Darknet(parser_data.cfg, parser_data.img_size)
- model.load_state_dict(torch.load(parser_data.weights, map_location=device)["model"])
+ weights_dict = torch.load(parser_data.weights, map_location='cpu')
+ weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
+ model.load_state_dict(weights_dict)
model.to(device)
# evaluate on the test dataset
diff --git a/pytorch_segmentation/deeplab_v3/predict.py b/pytorch_segmentation/deeplab_v3/predict.py
index 98fc5a4db..1e14eb3c7 100644
--- a/pytorch_segmentation/deeplab_v3/predict.py
+++ b/pytorch_segmentation/deeplab_v3/predict.py
@@ -69,7 +69,7 @@ def main():
t_start = time_synchronized()
output = model(img.to(device))
t_end = time_synchronized()
- print("inference+NMS time: {}".format(t_end - t_start))
+ print("inference time: {}".format(t_end - t_start))
prediction = output['out'].argmax(1).squeeze(0)
prediction = prediction.to("cpu").numpy().astype(np.uint8)
diff --git a/pytorch_segmentation/deeplab_v3/requirements.txt b/pytorch_segmentation/deeplab_v3/requirements.txt
index 50b913cfc..ede3e2584 100644
--- a/pytorch_segmentation/deeplab_v3/requirements.txt
+++ b/pytorch_segmentation/deeplab_v3/requirements.txt
@@ -1,4 +1,4 @@
-numpy==1.21.3
+numpy==1.22.0
torch==1.10.0
torchvision==0.11.1
-Pillow==8.4.0
\ No newline at end of file
+Pillow
diff --git a/pytorch_segmentation/fcn/predict.py b/pytorch_segmentation/fcn/predict.py
index c9db2865a..f25222e58 100644
--- a/pytorch_segmentation/fcn/predict.py
+++ b/pytorch_segmentation/fcn/predict.py
@@ -69,7 +69,7 @@ def main():
t_start = time_synchronized()
output = model(img.to(device))
t_end = time_synchronized()
- print("inference+NMS time: {}".format(t_end - t_start))
+ print("inference time: {}".format(t_end - t_start))
prediction = output['out'].argmax(1).squeeze(0)
prediction = prediction.to("cpu").numpy().astype(np.uint8)
diff --git a/pytorch_segmentation/fcn/requirements.txt b/pytorch_segmentation/fcn/requirements.txt
index 50b913cfc..2c58f889e 100644
--- a/pytorch_segmentation/fcn/requirements.txt
+++ b/pytorch_segmentation/fcn/requirements.txt
@@ -1,4 +1,4 @@
-numpy==1.21.3
+numpy==1.22.0
 torch==1.10.0
torchvision==0.11.1
-Pillow==8.4.0
\ No newline at end of file
+Pillow
diff --git a/pytorch_segmentation/lraspp/predict.py b/pytorch_segmentation/lraspp/predict.py
index 11ebfd7b3..27963fbc3 100644
--- a/pytorch_segmentation/lraspp/predict.py
+++ b/pytorch_segmentation/lraspp/predict.py
@@ -63,7 +63,7 @@ def main():
t_start = time_synchronized()
output = model(img.to(device))
t_end = time_synchronized()
- print("inference+NMS time: {}".format(t_end - t_start))
+ print("inference time: {}".format(t_end - t_start))
prediction = output['out'].argmax(1).squeeze(0)
prediction = prediction.to("cpu").numpy().astype(np.uint8)
diff --git a/pytorch_segmentation/lraspp/requirements.txt b/pytorch_segmentation/lraspp/requirements.txt
index 50b913cfc..ede3e2584 100644
--- a/pytorch_segmentation/lraspp/requirements.txt
+++ b/pytorch_segmentation/lraspp/requirements.txt
@@ -1,4 +1,4 @@
-numpy==1.21.3
+numpy==1.22.0
torch==1.10.0
torchvision==0.11.1
-Pillow==8.4.0
\ No newline at end of file
+Pillow
diff --git a/pytorch_segmentation/u2net/README.md b/pytorch_segmentation/u2net/README.md
new file mode 100644
index 000000000..aa0ae013e
--- /dev/null
+++ b/pytorch_segmentation/u2net/README.md
@@ -0,0 +1,90 @@
+# U2-Net (Going Deeper with Nested U-Structure for Salient Object Detection)
+
+## This project is mainly based on the official source code
+- https://github.com/xuebinqin/U-2-Net
+- Note that this project targets Salient Object Detection (SOD)
+
+## Environment
+- Python 3.6/3.7/3.8
+- PyTorch 1.10
+- Ubuntu or CentOS (multi-GPU training is not supported on Windows)
+- Training on a GPU is recommended
+- See `requirements.txt` for the detailed environment setup
+
+
+## Project structure
+```
+├── src: model definition code
+├── train_utils: training and validation utilities
+├── my_dataset.py: custom dataset loading code
+├── predict.py: simple prediction script
+├── train.py: training script for a single GPU or the CPU
+├── train_multi_GPU.py: multi-GPU training script
+├── validation.py: standalone model evaluation script
+├── transforms.py: data preprocessing code
+└── requirements.txt: project dependencies
+```
+
+## Preparing the DUTS dataset
+- Official DUTS download page: [http://saliencydetection.net/duts/](http://saliencydetection.net/duts/)
+- If the official download does not work, you can use my Baidu cloud mirror, link: https://pan.baidu.com/s/1nBI6GTN0ZilqH4Tvu18dow password: r7k6
+- DUTS-TR is the training set and DUTS-TE is the test (validation) set; after extraction the directory layout is:
+```
+├── DUTS-TR
+│    ├── DUTS-TR-Image: all training-set images
+│    └── DUTS-TR-Mask: the corresponding GT labels of the training images (mask form)
+│
+└── DUTS-TE
+     ├── DUTS-TE-Image: all test (validation) images
+     └── DUTS-TE-Mask: the corresponding GT labels of the test (validation) images (mask form)
+```
+- Note: during training or validation, point `--data-path` at the root directory that contains `DUTS-TR`
+
+## Official weights
+Weights converted from the official release:
+- `u2net_full.pth` download link: https://pan.baidu.com/s/1ojJZS8v3F_eFKkF3DEdEXA password: fh1v
+- `u2net_lite.pth` download link: https://pan.baidu.com/s/1TIWoiuEz9qRvTX9quDqQHg password: 5stj
+
+Validation results of `u2net_full` on DUTS-TE (evaluated with `validation.py`; a sample command is shown below the numbers):
+```
+MAE: 0.044
+maxF1: 0.868
+```
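+
+A minimal sketch of such an evaluation run; the flags correspond to the argparse options in `validation.py`, and the dataset path is a placeholder you need to adapt:
+```
+python validation.py --weights ./u2net_full.pth --data-path /path/to/DUTS_root --device cuda:0
+```
+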
+**Notes:**
+- The maxF1 here differs slightly from the paper; the difference comes from post_norm. The original repository applies post_norm to the predictions, while this repository removes it.
+With post_norm the maxF1 is `0.872`. If you need this post-processing, add it yourself; the post_norm procedure is as follows, where `output` is the network prediction during validation:
+```python
+ma = torch.max(output)
+mi = torch.min(output)
+output = (output - mi) / (ma - mi)
+```
+- To load the official weights, the conv bias in the `ConvBNReLU` class in `src/model.py` must be set to True, because the official code never sets it (Conv2d's bias defaults to True).
+Since the conv is followed by BN, the bias has no effect anyway, which is why this repository sets bias to False by default; a sketch of the required change is shown below.
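+
+A minimal sketch of that one-line change inside `ConvBNReLU.__init__` (only the `bias` argument differs from the code shipped in this repository):
+```python
+# flip bias to True only when loading the official weights; with the following
+# BatchNorm layer the bias is redundant, which is why this repo defaults to False
+self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding, dilation=dilation, bias=True)
+```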
+
+## Training record (`u2net_full`)
+Training command:
+```
+torchrun --nproc_per_node=4 train_multi_GPU.py --lr 0.004 --amp
+```
+Final validation results on DUTS-TE:
+```
+MAE: 0.047
+maxF1: 0.859
+```
+The full training log is in the results.txt file; trained weights download link: https://pan.baidu.com/s/1df2jMkrjbgEv-r1NMaZCZg password: n4l6
+
+## How to train
+* Make sure the dataset is prepared in advance
+* To train on a single GPU or on the CPU, use the train.py script directly (see the example after this list)
+* To train on multiple GPUs, use `torchrun --nproc_per_node=8 train_multi_GPU.py`, where `nproc_per_node` is the number of GPUs to use
+* To restrict training to specific GPUs, prepend `CUDA_VISIBLE_DEVICES=0,3` to the command (e.g. to use only the 1st and 4th GPU of the machine)
+* `CUDA_VISIBLE_DEVICES=0,3 torchrun --nproc_per_node=2 train_multi_GPU.py`
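+
+A sketch of a typical single-GPU run; the flags map to the argparse options in `train.py`, and the dataset path is a placeholder:
+```
+python train.py --data-path /path/to/DUTS_root -b 16 --epochs 360 --amp
+```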
+
+## If you are not familiar with the U2-Net network, see my bilibili video
+- [https://www.bilibili.com/video/BV1yB4y1z7m](https://www.bilibili.com/video/BV1yB4y1z7m)
+
+## For a closer look at this project and an analysis of the U2-Net code, see my bilibili video
+- [https://www.bilibili.com/video/BV1Kt4y137iS](https://www.bilibili.com/video/BV1Kt4y137iS)
+
+## U2-Net network structure
+
\ No newline at end of file
diff --git a/pytorch_segmentation/u2net/convert_weight.py b/pytorch_segmentation/u2net/convert_weight.py
new file mode 100644
index 000000000..df447e72b
--- /dev/null
+++ b/pytorch_segmentation/u2net/convert_weight.py
@@ -0,0 +1,140 @@
+import re
+import torch
+from src import u2net_full, u2net_lite
+
+layers = {"encode": [7, 6, 5, 4, 4, 4],
+ "decode": [4, 4, 5, 6, 7]}
+
+
+def convert_conv_bn(new_weight, prefix, ks, v):
+ if "conv" in ks[0]:
+ if "weight" == ks[1]:
+ new_weight[prefix + ".conv.weight"] = v
+ elif "bias" == ks[1]:
+ new_weight[prefix + ".conv.bias"] = v
+ else:
+ print(f"unrecognized weight {prefix + ks[1]}")
+ return
+
+ if "bn" in ks[0]:
+ if "running_mean" == ks[1]:
+ new_weight[prefix + ".bn.running_mean"] = v
+ elif "running_var" == ks[1]:
+ new_weight[prefix + ".bn.running_var"] = v
+ elif "weight" == ks[1]:
+ new_weight[prefix + ".bn.weight"] = v
+ elif "bias" == ks[1]:
+ new_weight[prefix + ".bn.bias"] = v
+ elif "num_batches_tracked" == ks[1]:
+ return
+ else:
+ print(f"unrecognized weight {prefix + ks[1]}")
+ return
+
+
+def convert(old_weight: dict):
+ new_weight = {}
+ for k, v in old_weight.items():
+ ks = k.split(".")
+ if ("stage" in ks[0]) and ("d" not in ks[0]):
+ # encode stage
+ num = int(re.findall(r'\d', ks[0])[0]) - 1
+ prefix = f"encode_modules.{num}"
+ if "rebnconvin" == ks[1]:
+ # ConvBNReLU module
+ prefix += ".conv_in"
+ convert_conv_bn(new_weight, prefix, ks[2:], v)
+ elif ("rebnconv" in ks[1]) and ("d" not in ks[1]):
+ num_ = int(re.findall(r'\d', ks[1])[0]) - 1
+ prefix += f".encode_modules.{num_}"
+ convert_conv_bn(new_weight, prefix, ks[2:], v)
+ elif ("rebnconv" in ks[1]) and ("d" in ks[1]):
+ num_ = layers["encode"][num] - int(re.findall(r'\d', ks[1])[0]) - 1
+ prefix += f".decode_modules.{num_}"
+ convert_conv_bn(new_weight, prefix, ks[2:], v)
+ else:
+ print(f"unrecognized key: {k}")
+
+ elif ("stage" in ks[0]) and ("d" in ks[0]):
+ # decode stage
+ num = 5 - int(re.findall(r'\d', ks[0])[0])
+ prefix = f"decode_modules.{num}"
+ if "rebnconvin" == ks[1]:
+ # ConvBNReLU module
+ prefix += ".conv_in"
+ convert_conv_bn(new_weight, prefix, ks[2:], v)
+ elif ("rebnconv" in ks[1]) and ("d" not in ks[1]):
+ num_ = int(re.findall(r'\d', ks[1])[0]) - 1
+ prefix += f".encode_modules.{num_}"
+ convert_conv_bn(new_weight, prefix, ks[2:], v)
+ elif ("rebnconv" in ks[1]) and ("d" in ks[1]):
+ num_ = layers["decode"][num] - int(re.findall(r'\d', ks[1])[0]) - 1
+ prefix += f".decode_modules.{num_}"
+ convert_conv_bn(new_weight, prefix, ks[2:], v)
+ else:
+ print(f"unrecognized key: {k}")
+ elif "side" in ks[0]:
+ # side
+ num = 6 - int(re.findall(r'\d', ks[0])[0])
+ prefix = f"side_modules.{num}"
+ if "weight" == ks[1]:
+ new_weight[prefix + ".weight"] = v
+ elif "bias" == ks[1]:
+ new_weight[prefix + ".bias"] = v
+ else:
+ print(f"unrecognized weight {prefix + ks[1]}")
+ elif "outconv" in ks[0]:
+ prefix = f"out_conv"
+ if "weight" == ks[1]:
+ new_weight[prefix + ".weight"] = v
+ elif "bias" == ks[1]:
+ new_weight[prefix + ".bias"] = v
+ else:
+ print(f"unrecognized weight {prefix + ks[1]}")
+ else:
+ print(f"unrecognized key: {k}")
+
+ return new_weight
+
+
+def main_1():
+ from u2net import U2NET, U2NETP
+
+ old_m = U2NET()
+ old_m.load_state_dict(torch.load("u2net.pth", map_location='cpu'))
+ new_m = u2net_full()
+
+ # old_m = U2NETP()
+ # old_m.load_state_dict(torch.load("u2netp.pth", map_location='cpu'))
+ # new_m = u2net_lite()
+
+ old_w = old_m.state_dict()
+
+ w = convert(old_w)
+ new_m.load_state_dict(w, strict=True)
+
+ torch.random.manual_seed(0)
+ x = torch.randn(1, 3, 288, 288)
+ old_m.eval()
+ new_m.eval()
+ with torch.no_grad():
+ out1 = old_m(x)[0]
+ out2 = new_m(x)
+ assert torch.equal(out1, out2)
+ torch.save(new_m.state_dict(), "u2net_full.pth")
+
+
+def main():
+ old_w = torch.load("u2net.pth", map_location='cpu')
+ new_m = u2net_full()
+
+ # old_w = torch.load("u2netp.pth", map_location='cpu')
+ # new_m = u2net_lite()
+
+ w = convert(old_w)
+ new_m.load_state_dict(w, strict=True)
+ torch.save(new_m.state_dict(), "u2net_full.pth")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/pytorch_segmentation/u2net/my_dataset.py b/pytorch_segmentation/u2net/my_dataset.py
new file mode 100644
index 000000000..6c993db5d
--- /dev/null
+++ b/pytorch_segmentation/u2net/my_dataset.py
@@ -0,0 +1,80 @@
+import os
+
+import cv2
+import torch.utils.data as data
+
+
+class DUTSDataset(data.Dataset):
+ def __init__(self, root: str, train: bool = True, transforms=None):
+ assert os.path.exists(root), f"path '{root}' does not exist."
+ if train:
+ self.image_root = os.path.join(root, "DUTS-TR", "DUTS-TR-Image")
+ self.mask_root = os.path.join(root, "DUTS-TR", "DUTS-TR-Mask")
+ else:
+ self.image_root = os.path.join(root, "DUTS-TE", "DUTS-TE-Image")
+ self.mask_root = os.path.join(root, "DUTS-TE", "DUTS-TE-Mask")
+ assert os.path.exists(self.image_root), f"path '{self.image_root}' does not exist."
+ assert os.path.exists(self.mask_root), f"path '{self.mask_root}' does not exist."
+
+ image_names = [p for p in os.listdir(self.image_root) if p.endswith(".jpg")]
+ mask_names = [p for p in os.listdir(self.mask_root) if p.endswith(".png")]
+        assert len(image_names) > 0, f"no images found in {self.image_root}."
+
+ # check images and mask
+ re_mask_names = []
+ for p in image_names:
+ mask_name = p.replace(".jpg", ".png")
+ assert mask_name in mask_names, f"{p} has no corresponding mask."
+ re_mask_names.append(mask_name)
+ mask_names = re_mask_names
+
+ self.images_path = [os.path.join(self.image_root, n) for n in image_names]
+ self.masks_path = [os.path.join(self.mask_root, n) for n in mask_names]
+
+ self.transforms = transforms
+
+ def __getitem__(self, idx):
+ image_path = self.images_path[idx]
+ mask_path = self.masks_path[idx]
+ image = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
+ assert image is not None, f"failed to read image: {image_path}"
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # BGR -> RGB
+ h, w, _ = image.shape
+
+ target = cv2.imread(mask_path, flags=cv2.IMREAD_GRAYSCALE)
+ assert target is not None, f"failed to read mask: {mask_path}"
+
+ if self.transforms is not None:
+ image, target = self.transforms(image, target)
+
+ return image, target
+
+ def __len__(self):
+ return len(self.images_path)
+
+ @staticmethod
+ def collate_fn(batch):
+ images, targets = list(zip(*batch))
+ batched_imgs = cat_list(images, fill_value=0)
+ batched_targets = cat_list(targets, fill_value=0)
+
+ return batched_imgs, batched_targets
+
+
+def cat_list(images, fill_value=0):
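+    # pad every image/mask in the batch with fill_value up to the max H and W in the batch, then stack them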
+ max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
+ batch_shape = (len(images),) + max_size
+ batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
+ for img, pad_img in zip(images, batched_imgs):
+ pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
+ return batched_imgs
+
+
+if __name__ == '__main__':
+ train_dataset = DUTSDataset("./", train=True)
+ print(len(train_dataset))
+
+ val_dataset = DUTSDataset("./", train=False)
+ print(len(val_dataset))
+
+ i, t = train_dataset[0]
diff --git a/pytorch_segmentation/u2net/predict.py b/pytorch_segmentation/u2net/predict.py
new file mode 100644
index 000000000..26b2d257a
--- /dev/null
+++ b/pytorch_segmentation/u2net/predict.py
@@ -0,0 +1,71 @@
+import os
+import time
+
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+import torch
+from torchvision.transforms import transforms
+
+from src import u2net_full
+
+
+def time_synchronized():
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    return time.time()
+
+
+def main():
+ weights_path = "./u2net_full.pth"
+ img_path = "./test.png"
+ threshold = 0.5
+
+    assert os.path.exists(img_path), f"image file {img_path} does not exist."
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ data_transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Resize(320),
+ transforms.Normalize(mean=(0.485, 0.456, 0.406),
+ std=(0.229, 0.224, 0.225))
+ ])
+
+ origin_img = cv2.cvtColor(cv2.imread(img_path, flags=cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
+
+ h, w = origin_img.shape[:2]
+ img = data_transform(origin_img)
+ img = torch.unsqueeze(img, 0).to(device) # [C, H, W] -> [1, C, H, W]
+
+ model = u2net_full()
+ weights = torch.load(weights_path, map_location='cpu')
+ if "model" in weights:
+ model.load_state_dict(weights["model"])
+ else:
+ model.load_state_dict(weights)
+ model.to(device)
+ model.eval()
+
+ with torch.no_grad():
+ # init model
+ img_height, img_width = img.shape[-2:]
+ init_img = torch.zeros((1, 3, img_height, img_width), device=device)
+ model(init_img)
+
+ t_start = time_synchronized()
+ pred = model(img)
+ t_end = time_synchronized()
+ print("inference time: {}".format(t_end - t_start))
+ pred = torch.squeeze(pred).to("cpu").numpy() # [1, 1, H, W] -> [H, W]
+
+ pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
+ pred_mask = np.where(pred > threshold, 1, 0)
+ origin_img = np.array(origin_img, dtype=np.uint8)
+ seg_img = origin_img * pred_mask[..., None]
+ plt.imshow(seg_img)
+ plt.show()
+ cv2.imwrite("pred_result.png", cv2.cvtColor(seg_img.astype(np.uint8), cv2.COLOR_RGB2BGR))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/pytorch_segmentation/u2net/requirements.txt b/pytorch_segmentation/u2net/requirements.txt
new file mode 100644
index 000000000..a47904dd0
--- /dev/null
+++ b/pytorch_segmentation/u2net/requirements.txt
@@ -0,0 +1,4 @@
+numpy==1.22.0
+torch==1.13.1
+torchvision==0.14.1
+opencv_python==4.5.4.60
diff --git a/pytorch_segmentation/u2net/results20220723-123632.txt b/pytorch_segmentation/u2net/results20220723-123632.txt
new file mode 100644
index 000000000..e174d4c94
--- /dev/null
+++ b/pytorch_segmentation/u2net/results20220723-123632.txt
@@ -0,0 +1,37 @@
+[epoch: 0] train_loss: 2.7385 lr: 0.002002 MAE: 0.465 maxF1: 0.464
+[epoch: 10] train_loss: 1.0385 lr: 0.003994 MAE: 0.124 maxF1: 0.719
+[epoch: 20] train_loss: 0.7629 lr: 0.003972 MAE: 0.077 maxF1: 0.787
+[epoch: 30] train_loss: 0.6758 lr: 0.003936 MAE: 0.083 maxF1: 0.791
+[epoch: 40] train_loss: 0.4905 lr: 0.003884 MAE: 0.073 maxF1: 0.805
+[epoch: 50] train_loss: 0.4337 lr: 0.003818 MAE: 0.063 maxF1: 0.821
+[epoch: 60] train_loss: 0.4157 lr: 0.003738 MAE: 0.067 maxF1: 0.818
+[epoch: 70] train_loss: 0.3424 lr: 0.003644 MAE: 0.058 maxF1: 0.840
+[epoch: 80] train_loss: 0.2909 lr: 0.003538 MAE: 0.057 maxF1: 0.842
+[epoch: 90] train_loss: 0.3220 lr: 0.003420 MAE: 0.064 maxF1: 0.837
+[epoch: 100] train_loss: 0.2653 lr: 0.003292 MAE: 0.055 maxF1: 0.847
+[epoch: 110] train_loss: 0.2627 lr: 0.003153 MAE: 0.055 maxF1: 0.846
+[epoch: 120] train_loss: 0.3230 lr: 0.003005 MAE: 0.058 maxF1: 0.837
+[epoch: 130] train_loss: 0.2177 lr: 0.002850 MAE: 0.053 maxF1: 0.852
+[epoch: 140] train_loss: 0.2807 lr: 0.002688 MAE: 0.061 maxF1: 0.824
+[epoch: 150] train_loss: 0.2091 lr: 0.002520 MAE: 0.057 maxF1: 0.846
+[epoch: 160] train_loss: 0.1971 lr: 0.002349 MAE: 0.049 maxF1: 0.857
+[epoch: 170] train_loss: 0.2157 lr: 0.002175 MAE: 0.050 maxF1: 0.851
+[epoch: 180] train_loss: 0.1881 lr: 0.002000 MAE: 0.048 maxF1: 0.857
+[epoch: 190] train_loss: 0.1855 lr: 0.001825 MAE: 0.047 maxF1: 0.860
+[epoch: 200] train_loss: 0.1817 lr: 0.001651 MAE: 0.047 maxF1: 0.863
+[epoch: 210] train_loss: 0.1740 lr: 0.001480 MAE: 0.048 maxF1: 0.858
+[epoch: 220] train_loss: 0.1707 lr: 0.001312 MAE: 0.048 maxF1: 0.860
+[epoch: 230] train_loss: 0.1653 lr: 0.001150 MAE: 0.048 maxF1: 0.859
+[epoch: 240] train_loss: 0.1652 lr: 0.000995 MAE: 0.046 maxF1: 0.860
+[epoch: 250] train_loss: 0.1631 lr: 0.000847 MAE: 0.048 maxF1: 0.857
+[epoch: 260] train_loss: 0.1584 lr: 0.000708 MAE: 0.047 maxF1: 0.862
+[epoch: 270] train_loss: 0.1590 lr: 0.000580 MAE: 0.047 maxF1: 0.860
+[epoch: 280] train_loss: 0.1521 lr: 0.000462 MAE: 0.047 maxF1: 0.861
+[epoch: 290] train_loss: 0.1535 lr: 0.000356 MAE: 0.047 maxF1: 0.861
+[epoch: 300] train_loss: 0.1520 lr: 0.000262 MAE: 0.047 maxF1: 0.860
+[epoch: 310] train_loss: 0.1488 lr: 0.000182 MAE: 0.047 maxF1: 0.860
+[epoch: 320] train_loss: 0.1493 lr: 0.000116 MAE: 0.047 maxF1: 0.859
+[epoch: 330] train_loss: 0.1470 lr: 0.000064 MAE: 0.047 maxF1: 0.860
+[epoch: 340] train_loss: 0.1493 lr: 0.000028 MAE: 0.047 maxF1: 0.859
+[epoch: 350] train_loss: 0.1482 lr: 0.000006 MAE: 0.047 maxF1: 0.858
+[epoch: 359] train_loss: 0.1518 lr: 0.000000 MAE: 0.047 maxF1: 0.859
diff --git a/pytorch_segmentation/u2net/src/__init__.py b/pytorch_segmentation/u2net/src/__init__.py
new file mode 100644
index 000000000..9411dd2c0
--- /dev/null
+++ b/pytorch_segmentation/u2net/src/__init__.py
@@ -0,0 +1 @@
+from .model import u2net_full, u2net_lite
diff --git a/pytorch_segmentation/u2net/src/model.py b/pytorch_segmentation/u2net/src/model.py
new file mode 100644
index 000000000..9c5b38a25
--- /dev/null
+++ b/pytorch_segmentation/u2net/src/model.py
@@ -0,0 +1,233 @@
+from typing import Union, List
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ConvBNReLU(nn.Module):
+ def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, dilation: int = 1):
+ super().__init__()
+
+ padding = kernel_size // 2 if dilation == 1 else dilation
+ self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding, dilation=dilation, bias=False)
+ self.bn = nn.BatchNorm2d(out_ch)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.relu(self.bn(self.conv(x)))
+
+
+class DownConvBNReLU(ConvBNReLU):
+ def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, dilation: int = 1, flag: bool = True):
+ super().__init__(in_ch, out_ch, kernel_size, dilation)
+ self.down_flag = flag
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.down_flag:
+ x = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True)
+
+ return self.relu(self.bn(self.conv(x)))
+
+
+class UpConvBNReLU(ConvBNReLU):
+ def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, dilation: int = 1, flag: bool = True):
+ super().__init__(in_ch, out_ch, kernel_size, dilation)
+ self.up_flag = flag
+
+ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+ if self.up_flag:
+ x1 = F.interpolate(x1, size=x2.shape[2:], mode='bilinear', align_corners=False)
+ return self.relu(self.bn(self.conv(torch.cat([x1, x2], dim=1))))
+
+
+class RSU(nn.Module):
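+    # Residual U-block (RSU): a small nested U-Net whose output is added back to its input feature map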
+ def __init__(self, height: int, in_ch: int, mid_ch: int, out_ch: int):
+ super().__init__()
+
+ assert height >= 2
+ self.conv_in = ConvBNReLU(in_ch, out_ch)
+
+ encode_list = [DownConvBNReLU(out_ch, mid_ch, flag=False)]
+ decode_list = [UpConvBNReLU(mid_ch * 2, mid_ch, flag=False)]
+ for i in range(height - 2):
+ encode_list.append(DownConvBNReLU(mid_ch, mid_ch))
+ decode_list.append(UpConvBNReLU(mid_ch * 2, mid_ch if i < height - 3 else out_ch))
+
+ encode_list.append(ConvBNReLU(mid_ch, mid_ch, dilation=2))
+ self.encode_modules = nn.ModuleList(encode_list)
+ self.decode_modules = nn.ModuleList(decode_list)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x_in = self.conv_in(x)
+
+ x = x_in
+ encode_outputs = []
+ for m in self.encode_modules:
+ x = m(x)
+ encode_outputs.append(x)
+
+ x = encode_outputs.pop()
+ for m in self.decode_modules:
+ x2 = encode_outputs.pop()
+ x = m(x, x2)
+
+ return x + x_in
+
+
+class RSU4F(nn.Module):
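+    # RSU-4F: dilated-convolution variant of RSU that keeps the spatial resolution unchanged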
+ def __init__(self, in_ch: int, mid_ch: int, out_ch: int):
+ super().__init__()
+ self.conv_in = ConvBNReLU(in_ch, out_ch)
+ self.encode_modules = nn.ModuleList([ConvBNReLU(out_ch, mid_ch),
+ ConvBNReLU(mid_ch, mid_ch, dilation=2),
+ ConvBNReLU(mid_ch, mid_ch, dilation=4),
+ ConvBNReLU(mid_ch, mid_ch, dilation=8)])
+
+ self.decode_modules = nn.ModuleList([ConvBNReLU(mid_ch * 2, mid_ch, dilation=4),
+ ConvBNReLU(mid_ch * 2, mid_ch, dilation=2),
+ ConvBNReLU(mid_ch * 2, out_ch)])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x_in = self.conv_in(x)
+
+ x = x_in
+ encode_outputs = []
+ for m in self.encode_modules:
+ x = m(x)
+ encode_outputs.append(x)
+
+ x = encode_outputs.pop()
+ for m in self.decode_modules:
+ x2 = encode_outputs.pop()
+ x = m(torch.cat([x, x2], dim=1))
+
+ return x + x_in
+
+
+class U2Net(nn.Module):
+ def __init__(self, cfg: dict, out_ch: int = 1):
+ super().__init__()
+ assert "encode" in cfg
+ assert "decode" in cfg
+ self.encode_num = len(cfg["encode"])
+
+ encode_list = []
+ side_list = []
+ for c in cfg["encode"]:
+ # c: [height, in_ch, mid_ch, out_ch, RSU4F, side]
+ assert len(c) == 6
+ encode_list.append(RSU(*c[:4]) if c[4] is False else RSU4F(*c[1:4]))
+
+ if c[5] is True:
+ side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1))
+ self.encode_modules = nn.ModuleList(encode_list)
+
+ decode_list = []
+ for c in cfg["decode"]:
+ # c: [height, in_ch, mid_ch, out_ch, RSU4F, side]
+ assert len(c) == 6
+ decode_list.append(RSU(*c[:4]) if c[4] is False else RSU4F(*c[1:4]))
+
+ if c[5] is True:
+ side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1))
+ self.decode_modules = nn.ModuleList(decode_list)
+ self.side_modules = nn.ModuleList(side_list)
+ self.out_conv = nn.Conv2d(self.encode_num * out_ch, out_ch, kernel_size=1)
+
+ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
+ _, _, h, w = x.shape
+
+ # collect encode outputs
+ encode_outputs = []
+ for i, m in enumerate(self.encode_modules):
+ x = m(x)
+ encode_outputs.append(x)
+ if i != self.encode_num - 1:
+ x = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True)
+
+ # collect decode outputs
+ x = encode_outputs.pop()
+ decode_outputs = [x]
+ for m in self.decode_modules:
+ x2 = encode_outputs.pop()
+ x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)
+ x = m(torch.concat([x, x2], dim=1))
+ decode_outputs.insert(0, x)
+
+ # collect side outputs
+ side_outputs = []
+ for m in self.side_modules:
+ x = decode_outputs.pop()
+ x = F.interpolate(m(x), size=[h, w], mode='bilinear', align_corners=False)
+ side_outputs.insert(0, x)
+
+ x = self.out_conv(torch.concat(side_outputs, dim=1))
+
+ if self.training:
+            # return logits (no sigmoid) so training stays AMP-safe; the loss uses BCE-with-logits
+ return [x] + side_outputs
+ else:
+ return torch.sigmoid(x)
+
+
+def u2net_full(out_ch: int = 1):
+ cfg = {
+ # height, in_ch, mid_ch, out_ch, RSU4F, side
+ "encode": [[7, 3, 32, 64, False, False], # En1
+ [6, 64, 32, 128, False, False], # En2
+ [5, 128, 64, 256, False, False], # En3
+ [4, 256, 128, 512, False, False], # En4
+ [4, 512, 256, 512, True, False], # En5
+ [4, 512, 256, 512, True, True]], # En6
+ # height, in_ch, mid_ch, out_ch, RSU4F, side
+ "decode": [[4, 1024, 256, 512, True, True], # De5
+ [4, 1024, 128, 256, False, True], # De4
+ [5, 512, 64, 128, False, True], # De3
+ [6, 256, 32, 64, False, True], # De2
+ [7, 128, 16, 64, False, True]] # De1
+ }
+
+ return U2Net(cfg, out_ch)
+
+
+def u2net_lite(out_ch: int = 1):
+ cfg = {
+ # height, in_ch, mid_ch, out_ch, RSU4F, side
+ "encode": [[7, 3, 16, 64, False, False], # En1
+ [6, 64, 16, 64, False, False], # En2
+ [5, 64, 16, 64, False, False], # En3
+ [4, 64, 16, 64, False, False], # En4
+ [4, 64, 16, 64, True, False], # En5
+ [4, 64, 16, 64, True, True]], # En6
+ # height, in_ch, mid_ch, out_ch, RSU4F, side
+ "decode": [[4, 128, 16, 64, True, True], # De5
+ [4, 128, 16, 64, False, True], # De4
+ [5, 128, 16, 64, False, True], # De3
+ [6, 128, 16, 64, False, True], # De2
+ [7, 128, 16, 64, False, True]] # De1
+ }
+
+ return U2Net(cfg, out_ch)
+
+
+def convert_onnx(m, save_path):
+ m.eval()
+ x = torch.rand(1, 3, 288, 288, requires_grad=True)
+
+ # export the model
+ torch.onnx.export(m, # model being run
+ x, # model input (or a tuple for multiple inputs)
+ save_path, # where to save the model (can be a file or file-like object)
+ export_params=True,
+ opset_version=11)
+
+
+if __name__ == '__main__':
+ # n_m = RSU(height=7, in_ch=3, mid_ch=12, out_ch=3)
+ # convert_onnx(n_m, "RSU7.onnx")
+ #
+ # n_m = RSU4F(in_ch=3, mid_ch=12, out_ch=3)
+ # convert_onnx(n_m, "RSU4F.onnx")
+
+ u2net = u2net_full()
+ convert_onnx(u2net, "u2net_full.onnx")
diff --git a/pytorch_segmentation/u2net/train.py b/pytorch_segmentation/u2net/train.py
new file mode 100644
index 000000000..4ccbf96d4
--- /dev/null
+++ b/pytorch_segmentation/u2net/train.py
@@ -0,0 +1,160 @@
+import os
+import time
+import datetime
+from typing import Union, List
+
+import torch
+from torch.utils import data
+
+from src import u2net_full
+from train_utils import train_one_epoch, evaluate, get_params_groups, create_lr_scheduler
+from my_dataset import DUTSDataset
+import transforms as T
+
+
+class SODPresetTrain:
+ def __init__(self, base_size: Union[int, List[int]], crop_size: int,
+ hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+ self.transforms = T.Compose([
+ T.ToTensor(),
+ T.Resize(base_size, resize_mask=True),
+ T.RandomCrop(crop_size),
+ T.RandomHorizontalFlip(hflip_prob),
+ T.Normalize(mean=mean, std=std)
+ ])
+
+ def __call__(self, img, target):
+ return self.transforms(img, target)
+
+
+class SODPresetEval:
+ def __init__(self, base_size: Union[int, List[int]], mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+ self.transforms = T.Compose([
+ T.ToTensor(),
+ T.Resize(base_size, resize_mask=False),
+ T.Normalize(mean=mean, std=std),
+ ])
+
+ def __call__(self, img, target):
+ return self.transforms(img, target)
+
+
+def main(args):
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ batch_size = args.batch_size
+
+    # file used to record training and validation information
+ results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
+
+ train_dataset = DUTSDataset(args.data_path, train=True, transforms=SODPresetTrain([320, 320], crop_size=288))
+ val_dataset = DUTSDataset(args.data_path, train=False, transforms=SODPresetEval([320, 320]))
+
+ num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
+ train_data_loader = data.DataLoader(train_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ shuffle=True,
+ pin_memory=True,
+ collate_fn=train_dataset.collate_fn)
+
+ val_data_loader = data.DataLoader(val_dataset,
+ batch_size=1, # must be 1
+ num_workers=num_workers,
+ pin_memory=True,
+ collate_fn=val_dataset.collate_fn)
+
+ model = u2net_full()
+ model.to(device)
+
+ params_group = get_params_groups(model, weight_decay=args.weight_decay)
+ optimizer = torch.optim.AdamW(params_group, lr=args.lr, weight_decay=args.weight_decay)
+ lr_scheduler = create_lr_scheduler(optimizer, len(train_data_loader), args.epochs,
+ warmup=True, warmup_epochs=2)
+
+ scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+ if args.resume:
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ model.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if args.amp:
+ scaler.load_state_dict(checkpoint["scaler"])
+
+ current_mae, current_f1 = 1.0, 0.0
+ start_time = time.time()
+ for epoch in range(args.start_epoch, args.epochs):
+ mean_loss, lr = train_one_epoch(model, optimizer, train_data_loader, device, epoch,
+ lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)
+
+ save_file = {"model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ "epoch": epoch,
+ "args": args}
+ if args.amp:
+ save_file["scaler"] = scaler.state_dict()
+
+ if epoch % args.eval_interval == 0 or epoch == args.epochs - 1:
+            # validate only every eval_interval epochs to reduce validation overhead and save training time
+ mae_metric, f1_metric = evaluate(model, val_data_loader, device=device)
+ mae_info, f1_info = mae_metric.compute(), f1_metric.compute()
+ print(f"[epoch: {epoch}] val_MAE: {mae_info:.3f} val_maxF1: {f1_info:.3f}")
+ # write into txt
+ with open(results_file, "a") as f:
+                # record train_loss, lr and the validation metrics for each epoch
+ write_info = f"[epoch: {epoch}] train_loss: {mean_loss:.4f} lr: {lr:.6f} " \
+ f"MAE: {mae_info:.3f} maxF1: {f1_info:.3f} \n"
+ f.write(write_info)
+
+            # save the best checkpoint and update the best metrics seen so far
+            if current_mae >= mae_info and current_f1 <= f1_info:
+                current_mae, current_f1 = mae_info, f1_info
+                torch.save(save_file, "save_weights/model_best.pth")
+
+ # only save latest 10 epoch weights
+ if os.path.exists(f"save_weights/model_{epoch-10}.pth"):
+ os.remove(f"save_weights/model_{epoch-10}.pth")
+
+ torch.save(save_file, f"save_weights/model_{epoch}.pth")
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print("training time {}".format(total_time_str))
+
+
+def parse_args():
+ import argparse
+ parser = argparse.ArgumentParser(description="pytorch u2net training")
+
+ parser.add_argument("--data-path", default="./", help="DUTS root")
+ parser.add_argument("--device", default="cuda", help="training device")
+ parser.add_argument("-b", "--batch-size", default=16, type=int)
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+ metavar='W', help='weight decay (default: 1e-4)',
+ dest='weight_decay')
+ parser.add_argument("--epochs", default=360, type=int, metavar="N",
+ help="number of total epochs to train")
+ parser.add_argument("--eval-interval", default=10, type=int, help="validation interval default 10 Epochs")
+
+ parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
+ parser.add_argument('--print-freq', default=50, type=int, help='print frequency')
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
+ parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+ help='start epoch')
+ # Mixed precision training parameters
+ parser.add_argument("--amp", action='http://github.com/store_true',
+ help="Use torch.cuda.amp for mixed precision training")
+
+ args = parser.parse_args()
+
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ if not os.path.exists("./save_weights"):
+ os.mkdir("./save_weights")
+
+ main(args)
diff --git a/pytorch_segmentation/u2net/train_multi_GPU.py b/pytorch_segmentation/u2net/train_multi_GPU.py
new file mode 100644
index 000000000..1a62a0ec7
--- /dev/null
+++ b/pytorch_segmentation/u2net/train_multi_GPU.py
@@ -0,0 +1,224 @@
+import time
+import os
+import datetime
+from typing import Union, List
+
+import torch
+from torch.utils import data
+
+from src import u2net_full
+from train_utils import (train_one_epoch, evaluate, init_distributed_mode, save_on_master, mkdir,
+ create_lr_scheduler, get_params_groups)
+from my_dataset import DUTSDataset
+import transforms as T
+
+
+class SODPresetTrain:
+ def __init__(self, base_size: Union[int, List[int]], crop_size: int,
+ hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+ self.transforms = T.Compose([
+ T.ToTensor(),
+ T.Resize(base_size, resize_mask=True),
+ T.RandomCrop(crop_size),
+ T.RandomHorizontalFlip(hflip_prob),
+ T.Normalize(mean=mean, std=std)
+ ])
+
+ def __call__(self, img, target):
+ return self.transforms(img, target)
+
+
+class SODPresetEval:
+ def __init__(self, base_size: Union[int, List[int]], mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+ self.transforms = T.Compose([
+ T.ToTensor(),
+ T.Resize(base_size, resize_mask=False),
+ T.Normalize(mean=mean, std=std),
+ ])
+
+ def __call__(self, img, target):
+ return self.transforms(img, target)
+
+
+def main(args):
+ init_distributed_mode(args)
+ print(args)
+
+ device = torch.device(args.device)
+
+    # file used to record training and validation information
+ results_file = "results{}.txt".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
+
+ train_dataset = DUTSDataset(args.data_path, train=True, transforms=SODPresetTrain([320, 320], crop_size=288))
+ val_dataset = DUTSDataset(args.data_path, train=False, transforms=SODPresetEval([320, 320]))
+
+ print("Creating data loaders")
+ if args.distributed:
+ train_sampler = data.distributed.DistributedSampler(train_dataset)
+ test_sampler = data.distributed.DistributedSampler(val_dataset)
+ else:
+ train_sampler = data.RandomSampler(train_dataset)
+ test_sampler = data.SequentialSampler(val_dataset)
+
+ train_data_loader = data.DataLoader(
+ train_dataset, batch_size=args.batch_size,
+ sampler=train_sampler, num_workers=args.workers,
+ pin_memory=True, collate_fn=train_dataset.collate_fn, drop_last=True)
+
+ val_data_loader = data.DataLoader(
+ val_dataset, batch_size=1, # batch_size must be 1
+ sampler=test_sampler, num_workers=args.workers,
+ pin_memory=True, collate_fn=train_dataset.collate_fn)
+
+    # create the U2-Net model
+ model = u2net_full()
+ model.to(device)
+
+ if args.sync_bn:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ model_without_ddp = model.module
+
+ params_group = get_params_groups(model, weight_decay=args.weight_decay)
+ optimizer = torch.optim.AdamW(params_group, lr=args.lr, weight_decay=args.weight_decay)
+ lr_scheduler = create_lr_scheduler(optimizer, len(train_data_loader), args.epochs,
+ warmup=True, warmup_epochs=2)
+
+ scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+    # if --resume points to a previous checkpoint, continue training from it
+ if args.resume:
+ # If map_location is missing, torch.load will first load the module to CPU
+ # and then copy each parameter to where it was saved,
+ # which would result in all processes on the same machine using the same set of devices.
+        checkpoint = torch.load(args.resume, map_location='cpu')  # load the saved checkpoint (including optimizer and lr scheduler state)
+ model_without_ddp.load_state_dict(checkpoint['model'])
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if args.amp:
+ scaler.load_state_dict(checkpoint["scaler"])
+
+ if args.test_only:
+ mae_metric, f1_metric = evaluate(model, val_data_loader, device=device)
+ print(mae_metric, f1_metric)
+ return
+
+ print("Start training")
+ current_mae, current_f1 = 1.0, 0.0
+ start_time = time.time()
+ for epoch in range(args.start_epoch, args.epochs):
+ if args.distributed:
+ train_sampler.set_epoch(epoch)
+
+ mean_loss, lr = train_one_epoch(model, optimizer, train_data_loader, device, epoch,
+ lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)
+
+ save_file = {'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ 'args': args,
+ 'epoch': epoch}
+ if args.amp:
+ save_file["scaler"] = scaler.state_dict()
+
+ if epoch % args.eval_interval == 0 or epoch == args.epochs - 1:
+            # validate only every eval_interval epochs to reduce validation overhead and save training time
+ mae_metric, f1_metric = evaluate(model, val_data_loader, device=device)
+ mae_info, f1_info = mae_metric.compute(), f1_metric.compute()
+ print(f"[epoch: {epoch}] val_MAE: {mae_info:.3f} val_maxF1: {f1_info:.3f}")
+
+            # write only on the main process
+ if args.rank in [-1, 0]:
+ # write into txt
+ with open(results_file, "a") as f:
+                    # record train_loss, lr and the validation metrics for each epoch
+ write_info = f"[epoch: {epoch}] train_loss: {mean_loss:.4f} lr: {lr:.6f} " \
+ f"MAE: {mae_info:.3f} maxF1: {f1_info:.3f} \n"
+ f.write(write_info)
+
+                # save the best checkpoint and update the best metrics seen so far
+                if current_mae >= mae_info and current_f1 <= f1_info:
+                    current_mae, current_f1 = mae_info, f1_info
+                    if args.output_dir:
+                        # save weights on the main node only
+                        save_on_master(save_file,
+                                       os.path.join(args.output_dir, 'model_best.pth'))
+
+ if args.output_dir:
+ if args.rank in [-1, 0]:
+ # only save latest 10 epoch weights
+ if os.path.exists(os.path.join(args.output_dir, f'model_{epoch - 10}.pth')):
+ os.remove(os.path.join(args.output_dir, f'model_{epoch - 10}.pth'))
+
+                # save weights on the main node only
+ save_on_master(save_file,
+ os.path.join(args.output_dir, f'model_{epoch}.pth'))
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(
+ description=__doc__)
+
+    # root directory of the DUTS dataset
+ parser.add_argument('--data-path', default='./', help='DUTS root')
+    # training device type
+ parser.add_argument('--device', default='cuda', help='device')
+    # batch size per GPU
+ parser.add_argument('-b', '--batch-size', default=16, type=int,
+ help='images per gpu, the total batch size is $NGPU x batch_size')
+    # epoch to resume/start training from
+ parser.add_argument('--start-epoch', default=0, type=int, help='start epoch')
+    # total number of training epochs
+ parser.add_argument('--epochs', default=360, type=int, metavar='N',
+ help='number of total epochs to run')
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
+ metavar='W', help='weight decay (default: 1e-4)',
+ dest='weight_decay')
+    # whether to use SyncBatchNorm across GPUs (off by default; enabling it slows down training)
+ parser.add_argument('--sync-bn', action='http://github.com/store_true', help='whether using SyncBatchNorm')
+    # number of data loading and preprocessing workers
+ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+ help='number of data loading workers (default: 4)')
+    # initial learning rate
+ parser.add_argument('--lr', default=0.001, type=float,
+ help='initial learning rate')
+    # validation frequency
+ parser.add_argument("--eval-interval", default=10, type=int, help="validation interval default 10 Epochs")
+    # print frequency during training
+ parser.add_argument('--print-freq', default=20, type=int, help='print frequency')
+    # directory where weights are saved
+ parser.add_argument('--output-dir', default='./multi_train', help='path where to save')
+    # resume training from the given checkpoint
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
+    # evaluate only, do not train
+ parser.add_argument(
+ "--test-only",
+ dest="test_only",
+ help="Only test the model",
+ action="http://github.com/store_true",
+ )
+
+    # number of distributed processes
+ parser.add_argument('--world-size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
+ # Mixed precision training parameters
+ parser.add_argument("--amp", action='http://github.com/store_true',
+ help="Use torch.cuda.amp for mixed precision training")
+
+ args = parser.parse_args()
+
+    # if an output directory was specified, create it if it does not exist
+ if args.output_dir:
+ mkdir(args.output_dir)
+
+ main(args)
diff --git a/pytorch_segmentation/u2net/train_utils/__init__.py b/pytorch_segmentation/u2net/train_utils/__init__.py
new file mode 100644
index 000000000..dfe313dd8
--- /dev/null
+++ b/pytorch_segmentation/u2net/train_utils/__init__.py
@@ -0,0 +1,2 @@
+from .train_and_eval import train_one_epoch, evaluate, create_lr_scheduler, get_params_groups
+from .distributed_utils import init_distributed_mode, save_on_master, mkdir
diff --git a/pytorch_segmentation/u2net/train_utils/distributed_utils.py b/pytorch_segmentation/u2net/train_utils/distributed_utils.py
new file mode 100644
index 000000000..c9bfebb5d
--- /dev/null
+++ b/pytorch_segmentation/u2net/train_utils/distributed_utils.py
@@ -0,0 +1,356 @@
+from collections import defaultdict, deque
+import datetime
+import time
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+import errno
+import os
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{value:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+def all_gather(data):
+ """
+    Gather data from every process.
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+    world_size = get_world_size()  # number of processes
+ if world_size == 1:
+ return [data]
+
+ data_list = [None] * world_size
+ dist.all_gather_object(data_list, data)
+
+ return data_list
+
+
+class MeanAbsoluteError(object):
+ def __init__(self):
+ self.mae_list = []
+
+ def update(self, pred: torch.Tensor, gt: torch.Tensor):
+ batch_size, c, h, w = gt.shape
+ assert batch_size == 1, f"validation mode batch_size must be 1, but got batch_size: {batch_size}."
+ resize_pred = F.interpolate(pred, (h, w), mode="bilinear", align_corners=False)
+ error_pixels = torch.sum(torch.abs(resize_pred - gt), dim=(1, 2, 3)) / (h * w)
+ self.mae_list.extend(error_pixels.tolist())
+
+ def compute(self):
+ mae = sum(self.mae_list) / len(self.mae_list)
+ return mae
+
+ def gather_from_all_processes(self):
+ if not torch.distributed.is_available():
+ return
+ if not torch.distributed.is_initialized():
+ return
+ torch.distributed.barrier()
+ gather_mae_list = []
+ for i in all_gather(self.mae_list):
+ gather_mae_list.extend(i)
+ self.mae_list = gather_mae_list
+
+ def __str__(self):
+ mae = self.compute()
+ return f'MAE: {mae:.3f}'
+
+
+class F1Score(object):
+ """
+ refer: https://github.com/xuebinqin/DIS/blob/main/IS-Net/basics.py
+ """
+
+ def __init__(self, threshold: float = 0.5):
+ self.precision_cum = None
+ self.recall_cum = None
+ self.num_cum = None
+ self.threshold = threshold
+
+ def update(self, pred: torch.Tensor, gt: torch.Tensor):
+ batch_size, c, h, w = gt.shape
+ assert batch_size == 1, f"validation mode batch_size must be 1, but got batch_size: {batch_size}."
+ resize_pred = F.interpolate(pred, (h, w), mode="bilinear", align_corners=False)
+ gt_num = torch.sum(torch.gt(gt, self.threshold).float())
+
+        pp = resize_pred[torch.gt(gt, self.threshold)]  # predicted values at pixels where the GT is foreground
+        nn = resize_pred[torch.le(gt, self.threshold)]  # predicted values at pixels where the GT is background
+
+ pp_hist = torch.histc(pp, bins=255, min=0.0, max=1.0)
+ nn_hist = torch.histc(nn, bins=255, min=0.0, max=1.0)
+
+ # Sort according to the prediction probability from large to small
+ pp_hist_flip = torch.flipud(pp_hist)
+ nn_hist_flip = torch.flipud(nn_hist)
+
+ pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
+ nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
+
+ precision = pp_hist_flip_cum / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)
+ recall = pp_hist_flip_cum / (gt_num + 1e-4)
+
+ if self.precision_cum is None:
+ self.precision_cum = torch.full_like(precision, fill_value=0.)
+
+ if self.recall_cum is None:
+ self.recall_cum = torch.full_like(recall, fill_value=0.)
+
+ if self.num_cum is None:
+ self.num_cum = torch.zeros([1], dtype=gt.dtype, device=gt.device)
+
+ self.precision_cum += precision
+ self.recall_cum += recall
+ self.num_cum += batch_size
+
+ def compute(self):
+ pre_mean = self.precision_cum / self.num_cum
+ rec_mean = self.recall_cum / self.num_cum
+ f1_mean = (1 + 0.3) * pre_mean * rec_mean / (0.3 * pre_mean + rec_mean + 1e-8)
+ max_f1 = torch.amax(f1_mean).item()
+ return max_f1
+
+ def reduce_from_all_processes(self):
+ if not torch.distributed.is_available():
+ return
+ if not torch.distributed.is_initialized():
+ return
+ torch.distributed.barrier()
+ torch.distributed.all_reduce(self.precision_cum)
+ torch.distributed.all_reduce(self.recall_cum)
+ torch.distributed.all_reduce(self.num_cum)
+
+ def __str__(self):
+ max_f1 = self.compute()
+ return f'maxF1: {max_f1:.3f}'
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ if torch.cuda.is_available():
+ log_msg = self.delimiter.join([
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}',
+ 'max mem: {memory:.0f}'
+ ])
+ else:
+ log_msg = self.delimiter.join([
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ])
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {}'.format(header, total_time_str))
+
+
+def mkdir(path):
+ try:
+ os.makedirs(path)
+ except OSError as e:
+ if e.errno != errno.EEXIST:
+ raise
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ elif hasattr(args, "rank"):
+ pass
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}'.format(
+ args.rank, args.dist_url), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ setup_for_distributed(args.rank == 0)
diff --git a/pytorch_segmentation/u2net/train_utils/train_and_eval.py b/pytorch_segmentation/u2net/train_utils/train_and_eval.py
new file mode 100644
index 000000000..3ff1150aa
--- /dev/null
+++ b/pytorch_segmentation/u2net/train_utils/train_and_eval.py
@@ -0,0 +1,111 @@
+import math
+import torch
+from torch.nn import functional as F
+import train_utils.distributed_utils as utils
+
+
+def criterion(inputs, target):
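+    # inputs = [fused output] + side outputs; apply BCE-with-logits to each and sum them (deep supervision)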
+ losses = [F.binary_cross_entropy_with_logits(inputs[i], target) for i in range(len(inputs))]
+ total_loss = sum(losses)
+
+ return total_loss
+
+
+def evaluate(model, data_loader, device):
+ model.eval()
+ mae_metric = utils.MeanAbsoluteError()
+ f1_metric = utils.F1Score()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Test:'
+ with torch.no_grad():
+ for images, targets in metric_logger.log_every(data_loader, 100, header):
+ images, targets = images.to(device), targets.to(device)
+ output = model(images)
+
+ # post norm
+ # ma = torch.max(output)
+ # mi = torch.min(output)
+ # output = (output - mi) / (ma - mi)
+
+ mae_metric.update(output, targets)
+ f1_metric.update(output, targets)
+
+ mae_metric.gather_from_all_processes()
+ f1_metric.reduce_from_all_processes()
+
+ return mae_metric, f1_metric
+
+
+def train_one_epoch(model, optimizer, data_loader, device, epoch, lr_scheduler, print_freq=10, scaler=None):
+ model.train()
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+
+ for image, target in metric_logger.log_every(data_loader, print_freq, header):
+ image, target = image.to(device), target.to(device)
+ with torch.cuda.amp.autocast(enabled=scaler is not None):
+ output = model(image)
+ loss = criterion(output, target)
+
+ optimizer.zero_grad()
+ if scaler is not None:
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ loss.backward()
+ optimizer.step()
+
+ lr_scheduler.step()
+
+ lr = optimizer.param_groups[0]["lr"]
+ metric_logger.update(loss=loss.item(), lr=lr)
+
+ return metric_logger.meters["loss"].global_avg, lr
+
+
+def create_lr_scheduler(optimizer,
+ num_step: int,
+ epochs: int,
+ warmup=True,
+ warmup_epochs=1,
+ warmup_factor=1e-3,
+ end_factor=1e-6):
+ assert num_step > 0 and epochs > 0
+ if warmup is False:
+ warmup_epochs = 0
+
+ def f(x):
+ """
+        Return a learning-rate multiplier for the given step.
+        Note that PyTorch calls lr_scheduler.step() once before training starts.
+ """
+ if warmup is True and x <= (warmup_epochs * num_step):
+ alpha = float(x) / (warmup_epochs * num_step)
+            # during warmup the lr factor grows from warmup_factor -> 1
+ return warmup_factor * (1 - alpha) + alpha
+ else:
+ current_step = (x - warmup_epochs * num_step)
+ cosine_steps = (epochs - warmup_epochs) * num_step
+            # after warmup the lr factor follows a cosine decay from 1 -> end_factor
+ return ((1 + math.cos(current_step * math.pi / cosine_steps)) / 2) * (1 - end_factor) + end_factor
+
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
+
+
+def get_params_groups(model: torch.nn.Module, weight_decay: float = 1e-4):
+ params_group = [{"params": [], "weight_decay": 0.}, # no decay
+ {"params": [], "weight_decay": weight_decay}] # with decay
+
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+
+ if len(param.shape) == 1 or name.endswith(".bias"):
+ # bn:(weight,bias) conv2d:(bias) linear:(bias)
+ params_group[0]["params"].append(param) # no decay
+ else:
+ params_group[1]["params"].append(param) # with decay
+
+ return params_group
diff --git a/pytorch_segmentation/u2net/transforms.py b/pytorch_segmentation/u2net/transforms.py
new file mode 100644
index 000000000..230b0fb87
--- /dev/null
+++ b/pytorch_segmentation/u2net/transforms.py
@@ -0,0 +1,79 @@
+import random
+from typing import List, Union
+from torchvision.transforms import functional as F
+from torchvision.transforms import transforms as T
+
+
+class Compose(object):
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, image, target=None):
+ for t in self.transforms:
+ image, target = t(image, target)
+
+ return image, target
+
+
+class ToTensor(object):
+ def __call__(self, image, target):
+ image = F.to_tensor(image)
+ target = F.to_tensor(target)
+ return image, target
+
+
+class RandomHorizontalFlip(object):
+ def __init__(self, prob):
+ self.flip_prob = prob
+
+ def __call__(self, image, target):
+ if random.random() < self.flip_prob:
+ image = F.hflip(image)
+ target = F.hflip(target)
+ return image, target
+
+
+class Normalize(object):
+ def __init__(self, mean, std):
+ self.mean = mean
+ self.std = std
+
+ def __call__(self, image, target):
+ image = F.normalize(image, mean=self.mean, std=self.std)
+ return image, target
+
+
+class Resize(object):
+ def __init__(self, size: Union[int, List[int]], resize_mask: bool = True):
+ self.size = size # [h, w]
+ self.resize_mask = resize_mask
+
+ def __call__(self, image, target=None):
+ image = F.resize(image, self.size)
+ if self.resize_mask is True:
+ target = F.resize(target, self.size)
+
+ return image, target
+
+
+class RandomCrop(object):
+ def __init__(self, size: int):
+ self.size = size
+
+ def pad_if_smaller(self, img, fill=0):
+        # if the shorter image side is smaller than the given size, pad it with `fill`
+        min_size = min(img.shape[-2:])
+        if min_size < self.size:
+            oh, ow = img.shape[-2:]  # images are tensors here (ToTensor runs first), so use shape, not PIL's size
+            padh = self.size - oh if oh < self.size else 0
+            padw = self.size - ow if ow < self.size else 0
+            img = F.pad(img, [0, 0, padw, padh], fill=fill)
+ return img
+
+ def __call__(self, image, target):
+ image = self.pad_if_smaller(image)
+ target = self.pad_if_smaller(target)
+ crop_params = T.RandomCrop.get_params(image, (self.size, self.size))
+ image = F.crop(image, *crop_params)
+ target = F.crop(target, *crop_params)
+ return image, target
diff --git a/pytorch_segmentation/u2net/u2net.png b/pytorch_segmentation/u2net/u2net.png
new file mode 100644
index 000000000..61b4cba34
Binary files /dev/null and b/pytorch_segmentation/u2net/u2net.png differ
diff --git a/pytorch_segmentation/u2net/validation.py b/pytorch_segmentation/u2net/validation.py
new file mode 100644
index 000000000..0c1b4e224
--- /dev/null
+++ b/pytorch_segmentation/u2net/validation.py
@@ -0,0 +1,67 @@
+import os
+from typing import Union, List
+
+import torch
+from torch.utils import data
+
+from src import u2net_full
+from train_utils import evaluate
+from my_dataset import DUTSDataset
+import transforms as T
+
+
+class SODPresetEval:
+ def __init__(self, base_size: Union[int, List[int]], mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
+ self.transforms = T.Compose([
+ T.ToTensor(),
+ T.Resize(base_size, resize_mask=False),
+ T.Normalize(mean=mean, std=std),
+ ])
+
+ def __call__(self, img, target):
+ return self.transforms(img, target)
+
+
+def main(args):
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ assert os.path.exists(args.weights), f"weights {args.weights} not found."
+
+ val_dataset = DUTSDataset(args.data_path, train=False, transforms=SODPresetEval([320, 320]))
+
+ num_workers = 4
+ val_data_loader = data.DataLoader(val_dataset,
+ batch_size=1, # must be 1
+ num_workers=num_workers,
+ pin_memory=True,
+ shuffle=False,
+ collate_fn=val_dataset.collate_fn)
+
+ model = u2net_full()
+ pretrain_weights = torch.load(args.weights, map_location='cpu')
+ if "model" in pretrain_weights:
+ model.load_state_dict(pretrain_weights["model"])
+ else:
+ model.load_state_dict(pretrain_weights)
+ model.to(device)
+
+ mae_metric, f1_metric = evaluate(model, val_data_loader, device=device)
+ print(mae_metric, f1_metric)
+
+
+def parse_args():
+ import argparse
+ parser = argparse.ArgumentParser(description="pytorch u2net validation")
+
+ parser.add_argument("--data-path", default="./", help="DUTS root")
+ parser.add_argument("--weights", default="./u2net_full.pth")
+ parser.add_argument("--device", default="cuda:0", help="training device")
+ parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
+
+ args = parser.parse_args()
+
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+ main(args)
diff --git a/pytorch_segmentation/unet/README.md b/pytorch_segmentation/unet/README.md
index 115bce4a2..8576783df 100644
--- a/pytorch_segmentation/unet/README.md
+++ b/pytorch_segmentation/unet/README.md
@@ -47,7 +47,7 @@
 ## For more details on this project and an analysis of the U-Net code, see my bilibili
-
+* [https://b23.tv/PCJJmqN](https://b23.tv/PCJJmqN)
 ## This U-Net uses bilinear interpolation for upsampling by default; the structure is shown below
-
\ No newline at end of file
+
diff --git a/pytorch_segmentation/unet/my_dataset.py b/pytorch_segmentation/unet/my_dataset.py
index d11e1217f..969859d4f 100644
--- a/pytorch_segmentation/unet/my_dataset.py
+++ b/pytorch_segmentation/unet/my_dataset.py
@@ -7,9 +7,9 @@
class DriveDataset(Dataset):
def __init__(self, root: str, train: bool, transforms=None):
super(DriveDataset, self).__init__()
- data_root = os.path.join(root, "DRIVE", "training" if train else "test")
- assert os.path.exists(data_root), f"path '{data_root}' does not exists."
self.flag = "training" if train else "test"
+ data_root = os.path.join(root, "DRIVE", self.flag)
+        assert os.path.exists(data_root), f"path '{data_root}' does not exist."
self.transforms = transforms
img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]
self.img_list = [os.path.join(data_root, "images", i) for i in img_names]
@@ -18,14 +18,14 @@ def __init__(self, root: str, train: bool, transforms=None):
# check files
for i in self.manual:
if os.path.exists(i) is False:
- print(f"file {i} does not exists.")
+                raise FileNotFoundError(f"file {i} does not exist.")
self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")
for i in img_names]
# check files
for i in self.roi_mask:
if os.path.exists(i) is False:
- print(f"file {i} does not exists.")
+                raise FileNotFoundError(f"file {i} does not exist.")
def __getitem__(self, idx):
img = Image.open(self.img_list[idx]).convert('RGB')
diff --git a/pytorch_segmentation/unet/predict.py b/pytorch_segmentation/unet/predict.py
index 2e1e1b9a9..c7d557fa7 100644
--- a/pytorch_segmentation/unet/predict.py
+++ b/pytorch_segmentation/unet/predict.py
@@ -61,7 +61,7 @@ def main():
t_start = time_synchronized()
output = model(img.to(device))
t_end = time_synchronized()
- print("inference+NMS time: {}".format(t_end - t_start))
+ print("inference time: {}".format(t_end - t_start))
prediction = output['out'].argmax(1).squeeze(0)
prediction = prediction.to("cpu").numpy().astype(np.uint8)
diff --git a/pytorch_segmentation/unet/requirements.txt b/pytorch_segmentation/unet/requirements.txt
index 50b913cfc..2c58f889e 100644
--- a/pytorch_segmentation/unet/requirements.txt
+++ b/pytorch_segmentation/unet/requirements.txt
@@ -1,4 +1,4 @@
-numpy==1.21.3
-torch==1.10.0
+numpy==1.22.0
+torch==1.13.1
torchvision==0.11.1
-Pillow==8.4.0
\ No newline at end of file
+Pillow
diff --git a/pytorch_segmentation/unet/src/mobilenet_unet.py b/pytorch_segmentation/unet/src/mobilenet_unet.py
index 859e847ba..aff981864 100644
--- a/pytorch_segmentation/unet/src/mobilenet_unet.py
+++ b/pytorch_segmentation/unet/src/mobilenet_unet.py
@@ -88,7 +88,7 @@ def __init__(self, num_classes, pretrain_backbone: bool = False):
self.up4 = Up(c, self.stage_out_channels[0])
self.conv = OutConv(self.stage_out_channels[0], num_classes=num_classes)
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
input_shape = x.shape[-2:]
backbone_out = self.backbone(x)
x = self.up1(backbone_out['stage4'], backbone_out['stage3'])
diff --git a/pytorch_segmentation/unet/src/unet.py b/pytorch_segmentation/unet/src/unet.py
index 0b50af243..31717aea8 100644
--- a/pytorch_segmentation/unet/src/unet.py
+++ b/pytorch_segmentation/unet/src/unet.py
@@ -1,3 +1,4 @@
+from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -35,7 +36,7 @@ def __init__(self, in_channels, out_channels, bilinear=True):
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
- def forward(self, x1, x2):
+ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x1 = self.up(x1)
# [N, C, H, W]
diff_y = x2.size()[2] - x1.size()[2]
@@ -80,7 +81,7 @@ def __init__(self,
self.up4 = Up(base_c * 2, base_c, bilinear)
self.out_conv = OutConv(base_c, num_classes)
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
x1 = self.in_conv(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
diff --git a/pytorch_segmentation/unet/src/vgg_unet.py b/pytorch_segmentation/unet/src/vgg_unet.py
index 830af075a..44a21e911 100644
--- a/pytorch_segmentation/unet/src/vgg_unet.py
+++ b/pytorch_segmentation/unet/src/vgg_unet.py
@@ -88,7 +88,7 @@ def __init__(self, num_classes, pretrain_backbone: bool = False):
self.up4 = Up(c, self.stage_out_channels[0])
self.conv = OutConv(self.stage_out_channels[0], num_classes=num_classes)
- def forward(self, x):
+ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
backbone_out = self.backbone(x)
x = self.up1(backbone_out['stage4'], backbone_out['stage3'])
x = self.up2(x, backbone_out['stage2'])
diff --git a/pytorch_segmentation/unet/train.py b/pytorch_segmentation/unet/train.py
index 21ccade3b..2ac065016 100644
--- a/pytorch_segmentation/unet/train.py
+++ b/pytorch_segmentation/unet/train.py
@@ -169,7 +169,7 @@ def parse_args():
parser.add_argument("--num-classes", default=1, type=int)
parser.add_argument("--device", default="cuda", help="training device")
parser.add_argument("-b", "--batch-size", default=4, type=int)
- parser.add_argument("--epochs", default=100, type=int, metavar="N",
+ parser.add_argument("--epochs", default=200, type=int, metavar="N",
help="number of total epochs to train")
parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
diff --git a/pytorch_segmentation/unet/train_multi_GPU.py b/pytorch_segmentation/unet/train_multi_GPU.py
index 8a7007609..11b76ace3 100644
--- a/pytorch_segmentation/unet/train_multi_GPU.py
+++ b/pytorch_segmentation/unet/train_multi_GPU.py
@@ -217,7 +217,7 @@ def main(args):
# specify the epoch to resume training from
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
# total number of training epochs
- parser.add_argument('--epochs', default=100, type=int, metavar='N',
+ parser.add_argument('--epochs', default=200, type=int, metavar='N',
help='number of total epochs to run')
# whether to use SyncBatchNorm (synchronized across GPUs); off by default, since enabling it slows training down
parser.add_argument('--sync_bn', type=bool, default=False, help='whether using SyncBatchNorm')
diff --git a/pytorch_segmentation/unet/train_utils/distributed_utils.py b/pytorch_segmentation/unet/train_utils/distributed_utils.py
index 6f044f511..577d5ea3b 100644
--- a/pytorch_segmentation/unet/train_utils/distributed_utils.py
+++ b/pytorch_segmentation/unet/train_utils/distributed_utils.py
@@ -130,11 +130,13 @@ def __init__(self, num_classes: int = 2, ignore_index: int = -100):
self.cumulative_dice = None
self.num_classes = num_classes
self.ignore_index = ignore_index
- self.count = 0
+ self.count = None
def update(self, pred, target):
if self.cumulative_dice is None:
self.cumulative_dice = torch.zeros(1, dtype=pred.dtype, device=pred.device)
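+ # keep count as a tensor so it can be all_reduced across processes together with cumulative_dice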
+ if self.count is None:
+ self.count = torch.zeros(1, dtype=pred.dtype, device=pred.device)
# compute the Dice score, ignoring background
pred = F.one_hot(pred.argmax(dim=1), self.num_classes).permute(0, 3, 1, 2).float()
dice_target = build_target(target, self.num_classes, self.ignore_index)
@@ -149,10 +151,12 @@ def value(self):
return self.cumulative_dice / self.count
def reset(self):
- self.count = 0
if self.cumulative_dice is not None:
self.cumulative_dice.zero_()
+ if self.count is not None:
+ self.count.zero_()
+
def reduce_from_all_processes(self):
if not torch.distributed.is_available():
return
@@ -160,6 +164,7 @@ def reduce_from_all_processes(self):
return
torch.distributed.barrier()
torch.distributed.all_reduce(self.cumulative_dice)
+ torch.distributed.all_reduce(self.count)
class MetricLogger(object):
diff --git a/pytorch_segmentation/unet/unet.png b/pytorch_segmentation/unet/unet.png
index a9d874fa9..2107e8bc3 100644
Binary files a/pytorch_segmentation/unet/unet.png and b/pytorch_segmentation/unet/unet.png differ
diff --git a/tensorflow_classification/ConvNeXt/model.py b/tensorflow_classification/ConvNeXt/model.py
new file mode 100644
index 000000000..f1893eb72
--- /dev/null
+++ b/tensorflow_classification/ConvNeXt/model.py
@@ -0,0 +1,214 @@
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras import layers, initializers, Model
+
+KERNEL_INITIALIZER = {
+ "class_name": "TruncatedNormal",
+ "config": {
+ "stddev": 0.2
+ }
+}
+
+BIAS_INITIALIZER = "Zeros"
+
+
+class Block(layers.Layer):
+ """
+ Args:
+ dim (int): Number of input channels.
+ drop_rate (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ """
+ def __init__(self, dim, drop_rate=0., layer_scale_init_value=1e-6, name: str = None):
+ super().__init__(name=name)
+ self.layer_scale_init_value = layer_scale_init_value
+ self.dwconv = layers.DepthwiseConv2D(7,
+ padding="same",
+ depthwise_initializer=KERNEL_INITIALIZER,
+ bias_initializer=BIAS_INITIALIZER,
+ name="dwconv")
+ self.norm = layers.LayerNormalization(epsilon=1e-6, name="norm")
+ self.pwconv1 = layers.Dense(4 * dim,
+ kernel_initializer=KERNEL_INITIALIZER,
+ bias_initializer=BIAS_INITIALIZER,
+ name="pwconv1")
+ self.act = layers.Activation("gelu")
+ self.pwconv2 = layers.Dense(dim,
+ kernel_initializer=KERNEL_INITIALIZER,
+ bias_initializer=BIAS_INITIALIZER,
+ name="pwconv2")
+ self.drop_path = layers.Dropout(drop_rate, noise_shape=(None, 1, 1, 1)) if drop_rate > 0 else None
+
+ def build(self, input_shape):
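+ # gamma is the Layer Scale parameter: a learnable per-channel scale on the residual branch, initialized to a small constant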
+ if self.layer_scale_init_value > 0:
+ self.gamma = self.add_weight(shape=[input_shape[-1]],
+ initializer=initializers.Constant(self.layer_scale_init_value),
+ trainable=True,
+ dtype=tf.float32,
+ name="gamma")
+ else:
+ self.gamma = None
+
+ def call(self, x, training=False):
+ shortcut = x
+ x = self.dwconv(x)
+ x = self.norm(x, training=training)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+
+ if self.gamma is not None:
+ x = self.gamma * x
+
+ if self.drop_path is not None:
+ x = self.drop_path(x, training=training)
+
+ return shortcut + x
+
+
+class Stem(layers.Layer):
+ def __init__(self, dim, name: str = None):
+ super().__init__(name=name)
+ self.conv = layers.Conv2D(dim,
+ kernel_size=4,
+ strides=4,
+ padding="same",
+ kernel_initializer=KERNEL_INITIALIZER,
+ bias_initializer=BIAS_INITIALIZER,
+ name="conv2d")
+ self.norm = layers.LayerNormalization(epsilon=1e-6, name="norm")
+
+ def call(self, x, training=False):
+ x = self.conv(x)
+ x = self.norm(x, training=training)
+ return x
+
+
+class DownSample(layers.Layer):
+ def __init__(self, dim, name: str = None):
+ super().__init__(name=name)
+ self.norm = layers.LayerNormalization(epsilon=1e-6, name="norm")
+ self.conv = layers.Conv2D(dim,
+ kernel_size=2,
+ strides=2,
+ padding="same",
+ kernel_initializer=KERNEL_INITIALIZER,
+ bias_initializer=BIAS_INITIALIZER,
+ name="conv2d")
+
+ def call(self, x, training=False):
+ x = self.norm(x, training=training)
+ x = self.conv(x)
+ return x
+
+
+class ConvNeXt(Model):
+ r""" ConvNeXt
+ A TensorFlow implementation of `A ConvNet for the 2020s` -
+ https://arxiv.org/pdf/2201.03545.pdf
+ Args:
+ num_classes (int): Number of classes for classification head. Default: 1000
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+ dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ """
+ def __init__(self, num_classes: int, depths: list, dims: list, drop_path_rate: float = 0.,
+ layer_scale_init_value: float = 1e-6):
+ super().__init__()
+ self.stem = Stem(dims[0], name="stem")
+
+ cur = 0
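+ # stochastic depth rate grows linearly from 0 to drop_path_rate over all blocks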
+ dp_rates = np.linspace(start=0, stop=drop_path_rate, num=sum(depths))
+ self.stage1 = [Block(dim=dims[0],
+ drop_rate=dp_rates[cur + i],
+ layer_scale_init_value=layer_scale_init_value,
+ name=f"stage1_block{i}")
+ for i in range(depths[0])]
+ cur += depths[0]
+
+ self.downsample2 = DownSample(dims[1], name="downsample2")
+ self.stage2 = [Block(dim=dims[1],
+ drop_rate=dp_rates[cur + i],
+ layer_scale_init_value=layer_scale_init_value,
+ name=f"stage2_block{i}")
+ for i in range(depths[1])]
+ cur += depths[1]
+
+ self.downsample3 = DownSample(dims[2], name="downsample3")
+ self.stage3 = [Block(dim=dims[2],
+ drop_rate=dp_rates[cur + i],
+ layer_scale_init_value=layer_scale_init_value,
+ name=f"stage3_block{i}")
+ for i in range(depths[2])]
+ cur += depths[2]
+
+ self.downsample4 = DownSample(dims[3], name="downsample4")
+ self.stage4 = [Block(dim=dims[3],
+ drop_rate=dp_rates[cur + i],
+ layer_scale_init_value=layer_scale_init_value,
+ name=f"stage4_block{i}")
+ for i in range(depths[3])]
+
+ self.norm = layers.LayerNormalization(epsilon=1e-6, name="norm")
+ self.head = layers.Dense(units=num_classes,
+ kernel_initializer=KERNEL_INITIALIZER,
+ bias_initializer=BIAS_INITIALIZER,
+ name="head")
+
+ def call(self, x, training=False):
+ x = self.stem(x, training=training)
+ for block in self.stage1:
+ x = block(x, training=training)
+
+ x = self.downsample2(x, training=training)
+ for block in self.stage2:
+ x = block(x, training=training)
+
+ x = self.downsample3(x, training=training)
+ for block in self.stage3:
+ x = block(x, training=training)
+
+ x = self.downsample4(x, training=training)
+ for block in self.stage4:
+ x = block(x, training=training)
+
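+ # global average pooling over the spatial dimensions, followed by the final norm and the classification head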
+ x = tf.reduce_mean(x, axis=[1, 2])
+ x = self.norm(x, training=training)
+ x = self.head(x)
+ return x
+
+
+def convnext_tiny(num_classes: int):
+ model = ConvNeXt(depths=[3, 3, 9, 3],
+ dims=[96, 192, 384, 768],
+ num_classes=num_classes)
+ return model
+
+
+def convnext_small(num_classes: int):
+ model = ConvNeXt(depths=[3, 3, 27, 3],
+ dims=[96, 192, 384, 768],
+ num_classes=num_classes)
+ return model
+
+
+def convnext_base(num_classes: int):
+ model = ConvNeXt(depths=[3, 3, 27, 3],
+ dims=[128, 256, 512, 1024],
+ num_classes=num_classes)
+ return model
+
+
+def convnext_large(num_classes: int):
+ model = ConvNeXt(depths=[3, 3, 27, 3],
+ dims=[192, 384, 768, 1536],
+ num_classes=num_classes)
+ return model
+
+
+def convnext_xlarge(num_classes: int):
+ model = ConvNeXt(depths=[3, 3, 27, 3],
+ dims=[256, 512, 1024, 2048],
+ num_classes=num_classes)
+ return model
diff --git a/tensorflow_classification/ConvNeXt/predict.py b/tensorflow_classification/ConvNeXt/predict.py
new file mode 100644
index 000000000..269f509fd
--- /dev/null
+++ b/tensorflow_classification/ConvNeXt/predict.py
@@ -0,0 +1,63 @@
+import os
+import json
+import glob
+import numpy as np
+
+from PIL import Image
+import tensorflow as tf
+import matplotlib.pyplot as plt
+
+from model import convnext_tiny as create_model
+
+
+def main():
+ num_classes = 5
+ im_height = im_width = 224
+
+ # load image
+ img_path = "../tulip.jpg"
+ assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
+ img = Image.open(img_path)
+ # resize image
+ img = img.resize((im_width, im_height))
+ plt.imshow(img)
+
+ # read image
+ img = np.array(img).astype(np.float32)
+
+ # preprocess
+ img = (img / 255. - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
+
+ # Add the image to a batch where it's the only member.
+ img = (np.expand_dims(img, 0))
+
+ # read class_indict
+ json_path = './class_indices.json'
+ assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
+
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
+
+ # create model
+ model = create_model(num_classes=num_classes)
+ model.build([1, 224, 224, 3])
+
+ weights_path = './save_weights/model.ckpt'
+ assert len(glob.glob(weights_path+"*")), "cannot find {}".format(weights_path)
+ model.load_weights(weights_path)
+
+ result = np.squeeze(model.predict(img, batch_size=1))
+ result = tf.keras.layers.Softmax()(result)
+ predict_class = np.argmax(result)
+
+ print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_class)],
+ result[predict_class])
+ plt.title(print_res)
+ for i in range(len(result)):
+ print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
+ result[i]))
+ plt.show()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tensorflow_classification/ConvNeXt/train.py b/tensorflow_classification/ConvNeXt/train.py
new file mode 100644
index 000000000..b2cf77248
--- /dev/null
+++ b/tensorflow_classification/ConvNeXt/train.py
@@ -0,0 +1,150 @@
+import os
+import re
+import sys
+import datetime
+
+import tensorflow as tf
+from tqdm import tqdm
+
+from model import convnext_tiny as create_model
+from utils import generate_ds, cosine_scheduler
+
+assert tf.version.VERSION >= "2.4.0", "version of tf must be greater than or equal to 2.4.0"
+
+
+def main():
+ data_root = "/data/flower_photos" # get data root path
+
+ if not os.path.exists("./save_weights"):
+ os.makedirs("./save_weights")
+
+ batch_size = 8
+ epochs = 10
+ num_classes = 5
+ freeze_layers = False
+ initial_lr = 0.005
+ weight_decay = 5e-4
+
+ log_dir = "./logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+ train_writer = tf.summary.create_file_writer(os.path.join(log_dir, "train"))
+ val_writer = tf.summary.create_file_writer(os.path.join(log_dir, "val"))
+
+ # data generator with data augmentation
+ train_ds, val_ds = generate_ds(data_root, batch_size=batch_size, val_rate=0.2)
+
+ # create model
+ model = create_model(num_classes=num_classes)
+ model.build((1, 224, 224, 3))
+
+ # download the pretrained weights that I converted in advance
+ # link: https://pan.baidu.com/s/1MtYJ3FCAkiPwaMRKuyZN1Q password: 1cgp
+ # load weights
+ pre_weights_path = './convnext_tiny_1k_224.h5'
+ assert os.path.exists(pre_weights_path), "cannot find {}".format(pre_weights_path)
+ model.load_weights(pre_weights_path, by_name=True, skip_mismatch=True)
+
+ # freeze bottom layers
+ if freeze_layers:
+ for layer in model.layers:
+ if "head" not in layer.name:
+ layer.trainable = False
+ else:
+ print("training {}".format(layer.name))
+
+ model.summary()
+
+ # custom learning rate scheduler
+ scheduler = cosine_scheduler(initial_lr, epochs, len(train_ds), train_writer=train_writer)
+
+ # using keras low level api for training
+ loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ optimizer = tf.keras.optimizers.SGD(learning_rate=initial_lr, momentum=0.9)
+
+ train_loss = tf.keras.metrics.Mean(name='train_loss')
+ train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
+
+ val_loss = tf.keras.metrics.Mean(name='val_loss')
+ val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='val_accuracy')
+
+ @tf.function
+ def train_step(train_images, train_labels):
+ with tf.GradientTape() as tape:
+ output = model(train_images, training=True)
+ ce_loss = loss_object(train_labels, output)
+
+ # l2 loss
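+ # only kernel weights are decayed; bias and LayerNorm gamma/beta are excluded by the regex below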
+ matcher = re.compile(".*(bias|gamma|beta).*")
+ l2loss = weight_decay * tf.add_n([
+ tf.nn.l2_loss(v)
+ for v in model.trainable_variables
+ if not matcher.match(v.name)
+ ])
+
+ loss = ce_loss + l2loss
+
+ gradients = tape.gradient(loss, model.trainable_variables)
+ optimizer.apply_gradients(zip(gradients, model.trainable_variables))
+ train_loss(ce_loss)
+ train_accuracy(train_labels, output)
+
+ @tf.function
+ def val_step(val_images, val_labels):
+ output = model(val_images, training=False)
+ loss = loss_object(val_labels, output)
+
+ val_loss(loss)
+ val_accuracy(val_labels, output)
+
+ best_val_acc = 0.
+ for epoch in range(epochs):
+ train_loss.reset_states() # clear history info
+ train_accuracy.reset_states() # clear history info
+ val_loss.reset_states() # clear history info
+ val_accuracy.reset_states() # clear history info
+
+ # train
+ train_bar = tqdm(train_ds, file=sys.stdout)
+ for images, labels in train_bar:
+ # update learning rate
+ optimizer.learning_rate = next(scheduler)
+
+ train_step(images, labels)
+
+ # print train process
+ train_bar.desc = "train epoch[{}/{}] loss:{:.3f}, acc:{:.3f}, lr:{:.5f}".format(
+ epoch + 1,
+ epochs,
+ train_loss.result(),
+ train_accuracy.result(),
+ optimizer.learning_rate.numpy()
+ )
+
+ # validate
+ val_bar = tqdm(val_ds, file=sys.stdout)
+ for images, labels in val_bar:
+ val_step(images, labels)
+
+ # print val process
+ val_bar.desc = "valid epoch[{}/{}] loss:{:.3f}, acc:{:.3f}".format(epoch + 1,
+ epochs,
+ val_loss.result(),
+ val_accuracy.result())
+ # writing training loss and acc
+ with train_writer.as_default():
+ tf.summary.scalar("loss", train_loss.result(), epoch)
+ tf.summary.scalar("accuracy", train_accuracy.result(), epoch)
+
+ # writing validation loss and acc
+ with val_writer.as_default():
+ tf.summary.scalar("loss", val_loss.result(), epoch)
+ tf.summary.scalar("accuracy", val_accuracy.result(), epoch)
+
+ # only save best weights
+ if val_accuracy.result() > best_val_acc:
+ best_val_acc = val_accuracy.result()
+ save_name = "./save_weights/model.ckpt"
+ model.save_weights(save_name, save_format="tf")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tensorflow_classification/ConvNeXt/trans_weights.py b/tensorflow_classification/ConvNeXt/trans_weights.py
new file mode 100644
index 000000000..a35b1cc2c
--- /dev/null
+++ b/tensorflow_classification/ConvNeXt/trans_weights.py
@@ -0,0 +1,149 @@
+import torch
+from model import *
+
+
+def transpose_weights(m_type, w_dict, k, v):
+ if m_type == "conv":
+ if len(v.shape) > 1:
+ # conv weights
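+ # PyTorch layout (out_c, in_c, kH, kW) -> TF layout (kH, kW, in_c, out_c)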
+ v = np.transpose(v.numpy(), (2, 3, 1, 0)).astype(np.float32)
+ w_dict[k] = v
+ elif m_type == "dwconv":
+ if len(v.shape) > 1:
+ # dwconv weights
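+ # PyTorch layout (channels, 1, kH, kW) -> TF depthwise layout (kH, kW, channels, 1)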
+ v = np.transpose(v.numpy(), (2, 3, 0, 1)).astype(np.float32)
+ w_dict[k] = v
+ elif m_type == "linear":
+ if len(v.shape) > 1:
+ v = np.transpose(v.numpy(), (1, 0)).astype(np.float32)
+ w_dict[k] = v
+ elif m_type == "norm":
+ w_dict[k] = v
+ else:
+ ValueError(f"not support type:{m_type}")
+
+
+def main(weights_path: str,
+ model_name: str,
+ model: tf.keras.Model):
+ var_dict = {v.name.split(':')[0]: v for v in model.weights}
+
+ weights_dict = torch.load(weights_path, map_location="cpu")["model"]
+ w_dict = {}
+ for k, v in weights_dict.items():
+ if "downsample_layers" in k:
+ split_k = k.split(".")
+ if split_k[1] == "0":
+ if split_k[2] == "0":
+ k = "stem/conv2d/" + split_k[-1]
+ k = k.replace("weight", "kernel")
+ transpose_weights("conv", w_dict, k, v)
+ else:
+ k = "stem/norm/" + split_k[-1]
+ k = k.replace("weight", "gamma")
+ k = k.replace("bias", "beta")
+ transpose_weights("norm", w_dict, k, v)
+ else:
+ stage = int(split_k[1]) + 1
+ if split_k[2] == "1":
+ k = f"downsample{stage}/conv2d/" + split_k[-1]
+ k = k.replace("weight", "kernel")
+ transpose_weights("conv", w_dict, k, v)
+ else:
+ k = f"downsample{stage}/norm/" + split_k[-1]
+ k = k.replace("weight", "gamma")
+ k = k.replace("bias", "beta")
+ transpose_weights("norm", w_dict, k, v)
+ elif "stages" in k:
+ split_k = k.split(".")
+ stage = int(split_k[1]) + 1
+ block = int(split_k[2])
+ if "dwconv" in k:
+ k = f"stage{stage}_block{block}/{split_k[-2]}/{split_k[-1]}"
+ k = k.replace("weight", "depthwise_kernel")
+ transpose_weights("dwconv", w_dict, k, v)
+ elif "pwconv" in k:
+ k = f"stage{stage}_block{block}/{split_k[-2]}/{split_k[-1]}"
+ k = k.replace("weight", "kernel")
+ transpose_weights("linear", w_dict, k, v)
+ elif "norm" in k:
+ k = f"stage{stage}_block{block}/{split_k[-2]}/{split_k[-1]}"
+ k = k.replace("weight", "gamma")
+ k = k.replace("bias", "beta")
+ transpose_weights("norm", w_dict, k, v)
+ elif "gamma" in k:
+ k = f"stage{stage}_block{block}/{split_k[-1]}"
+ transpose_weights("norm", w_dict, k, v)
+ else:
+ ValueError(f"unrecognized {k}")
+ elif "norm" in k:
+ split_k = k.split(".")
+ k = f"norm/{split_k[-1]}"
+ k = k.replace("weight", "gamma")
+ k = k.replace("bias", "beta")
+ transpose_weights("norm", w_dict, k, v)
+ elif "head" in k:
+ split_k = k.split(".")
+ k = f"head/{split_k[-1]}"
+ k = k.replace("weight", "kernel")
+ transpose_weights("linear", w_dict, k, v)
+ else:
+ ValueError(f"unrecognized {k}")
+
+ for key, var in var_dict.items():
+ if key in w_dict:
+ if w_dict[key].shape != var.shape:
+ msg = "shape mismatch: {}".format(key)
+ print(msg)
+ else:
+ var.assign(w_dict[key], read_value=False)
+ else:
+ msg = "Not found {} in {}".format(key, weights_path)
+ print(msg)
+
+ model.save_weights("./{}.h5".format(model_name))
+
+
+if __name__ == '__main__':
+ model = convnext_tiny(num_classes=1000)
+ model.build((1, 224, 224, 3))
+ # https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth
+ main(weights_path="./convnext_tiny_1k_224_ema.pth",
+ model_name="convnext_tiny_1k_224",
+ model=model)
+
+ # model = convnext_small(num_classes=1000)
+ # model.build((1, 224, 224, 3))
+ # # https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth
+ # main(weights_path="./convnext_small_1k_224_ema.pth",
+ # model_name="convnext_small_1k_224",
+ # model=model)
+
+ # model = convnext_base(num_classes=1000)
+ # model.build((1, 224, 224, 3))
+ # # https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth
+ # main(weights_path="./convnext_base_1k_224_ema.pth",
+ # model_name="convnext_base_1k_224",
+ # model=model)
+
+ # model = convnext_base(num_classes=21841)
+ # model.build((1, 224, 224, 3))
+ # # https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth
+ # main(weights_path="./convnext_base_22k_224.pth",
+ # model_name="convnext_base_22k_224",
+ # model=model)
+
+ # model = convnext_large(num_classes=1000)
+ # model.build((1, 224, 224, 3))
+ # # https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth
+ # main(weights_path="./convnext_large_1k_224_ema.pth",
+ # model_name="convnext_large_1k_224",
+ # model=model)
+
+ # model = convnext_large(num_classes=21841)
+ # model.build((1, 224, 224, 3))
+ # # https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth
+ # main(weights_path="./convnext_large_22k_224.pth",
+ # model_name="convnext_large_22k_224",
+ # model=model)
+
diff --git a/tensorflow_classification/ConvNeXt/utils.py b/tensorflow_classification/ConvNeXt/utils.py
new file mode 100644
index 000000000..57470b045
--- /dev/null
+++ b/tensorflow_classification/ConvNeXt/utils.py
@@ -0,0 +1,174 @@
+import os
+import json
+import random
+import math
+
+import numpy as np
+import tensorflow as tf
+import matplotlib.pyplot as plt
+
+
+def read_split_data(root: str, val_rate: float = 0.2):
+ random.seed(0) # keep the random split reproducible
+ assert os.path.exists(root), "dataset root: {} does not exist.".format(root)
+
+ # traverse the folders; each folder corresponds to one class
+ flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
+ # sort to keep the class order consistent
+ flower_class.sort()
+ # map each class name to a numeric index
+ class_indices = dict((k, v) for v, k in enumerate(flower_class))
+ json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
+ with open('class_indices.json', 'w') as json_file:
+ json_file.write(json_str)
+
+ train_images_path = [] # paths of all training images
+ train_images_label = [] # class index of each training image
+ val_images_path = [] # paths of all validation images
+ val_images_label = [] # class index of each validation image
+ every_class_num = [] # total number of samples per class
+ supported = [".jpg", ".JPG", ".jpeg", ".JPEG"] # supported file extensions
+ # iterate over the files in each class folder
+ for cla in flower_class:
+ cla_path = os.path.join(root, cla)
+ # collect the paths of all files with a supported extension
+ images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
+ if os.path.splitext(i)[-1] in supported]
+ # index of this class
+ image_class = class_indices[cla]
+ # record the number of samples in this class
+ every_class_num.append(len(images))
+ # randomly sample validation images according to val_rate
+ val_path = random.sample(images, k=int(len(images) * val_rate))
+
+ for img_path in images:
+ if img_path in val_path: # paths sampled above go to the validation set
+ val_images_path.append(img_path)
+ val_images_label.append(image_class)
+ else: # the rest go to the training set
+ train_images_path.append(img_path)
+ train_images_label.append(image_class)
+
+ print("{} images were found in the dataset.\n{} for training, {} for validation".format(sum(every_class_num),
+ len(train_images_path),
+ len(val_images_path)
+ ))
+
+ plot_image = False
+ if plot_image:
+ # bar chart of the number of samples per class
+ plt.bar(range(len(flower_class)), every_class_num, align='center')
+ # replace the x ticks 0,1,2,3,4 with the corresponding class names
+ plt.xticks(range(len(flower_class)), flower_class)
+ # add value labels on top of the bars
+ for i, v in enumerate(every_class_num):
+ plt.text(x=i, y=v + 5, s=str(v), ha='center')
+ # x-axis label
+ plt.xlabel('image class')
+ # y-axis label
+ plt.ylabel('number of images')
+ # chart title
+ plt.title('flower class distribution')
+ plt.show()
+
+ return train_images_path, train_images_label, val_images_path, val_images_label
+
+
+def generate_ds(data_root: str,
+ train_im_height: int = 224,
+ train_im_width: int = 224,
+ val_im_height: int = None,
+ val_im_width: int = None,
+ batch_size: int = 8,
+ val_rate: float = 0.1,
+ cache_data: bool = False):
+ """
+ Read and split the dataset, and build the training and validation dataset iterators.
+ :param data_root: dataset root directory
+ :param train_im_height: height of the images fed to the network during training
+ :param train_im_width: width of the images fed to the network during training
+ :param val_im_height: height of the images fed to the network during validation
+ :param val_im_width: width of the images fed to the network during validation
+ :param batch_size: batch size used for training
+ :param val_rate: fraction of the data assigned to the validation set
+ :param cache_data: whether to cache the data in memory
+ :return:
+ """
+ assert train_im_height is not None
+ assert train_im_width is not None
+ if val_im_width is None:
+ val_im_width = train_im_width
+ if val_im_height is None:
+ val_im_height = train_im_height
+
+ train_img_path, train_img_label, val_img_path, val_img_label = read_split_data(data_root, val_rate=val_rate)
+ AUTOTUNE = tf.data.experimental.AUTOTUNE
+
+ def process_train_info(img_path, label):
+ image = tf.io.read_file(img_path)
+ image = tf.image.decode_jpeg(image, channels=3)
+ image = tf.cast(image, tf.float32)
+ image = tf.image.resize_with_crop_or_pad(image, train_im_height, train_im_width)
+ image = tf.image.random_flip_left_right(image)
+ image = (image / 255. - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
+ return image, label
+
+ def process_val_info(img_path, label):
+ image = tf.io.read_file(img_path)
+ image = tf.image.decode_jpeg(image, channels=3)
+ image = tf.cast(image, tf.float32)
+ image = tf.image.resize_with_crop_or_pad(image, val_im_height, val_im_width)
+ image = (image / 255. - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
+ return image, label
+
+ # Configure dataset for performance
+ def configure_for_performance(ds,
+ shuffle_size: int,
+ shuffle: bool = False,
+ cache: bool = False):
+ if cache:
+ ds = ds.cache() # cache the data in memory after the first read
+ if shuffle:
+ ds = ds.shuffle(buffer_size=shuffle_size) # shuffle the sample order
+ ds = ds.batch(batch_size) # set the batch size
+ ds = ds.prefetch(buffer_size=AUTOTUNE) # prepare the next step's data while the current step trains
+ return ds
+
+ train_ds = tf.data.Dataset.from_tensor_slices((tf.constant(train_img_path),
+ tf.constant(train_img_label)))
+ total_train = len(train_img_path)
+
+ # Use Dataset.map to create a dataset of image, label pairs
+ train_ds = train_ds.map(process_train_info, num_parallel_calls=AUTOTUNE)
+ train_ds = configure_for_performance(train_ds, total_train, shuffle=True, cache=cache_data)
+
+ val_ds = tf.data.Dataset.from_tensor_slices((tf.constant(val_img_path),
+ tf.constant(val_img_label)))
+ total_val = len(val_img_path)
+ # Use Dataset.map to create a dataset of image, label pairs
+ val_ds = val_ds.map(process_val_info, num_parallel_calls=AUTOTUNE)
+ val_ds = configure_for_performance(val_ds, total_val, cache=False)
+
+ return train_ds, val_ds
+
+
+def cosine_rate(now_step, total_step, end_lr_rate):
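+ # decays from 1.0 at step 0 to end_lr_rate at total_step along half a cosine period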
+ rate = ((1 + math.cos(now_step * math.pi / total_step)) / 2) * (1 - end_lr_rate) + end_lr_rate # cosine
+ return rate
+
+
+def cosine_scheduler(initial_lr, epochs, steps, warmup_epochs=1, end_lr_rate=1e-6, train_writer=None):
+ """custom learning rate scheduler"""
+ assert warmup_epochs < epochs
+ warmup = np.linspace(start=1e-8, stop=initial_lr, num=warmup_epochs*steps)
+ remainder_steps = (epochs - warmup_epochs) * steps
+ cosine = initial_lr * np.array([cosine_rate(i, remainder_steps, end_lr_rate) for i in range(remainder_steps)])
+ lr_list = np.concatenate([warmup, cosine])
+
+ for i in range(len(lr_list)):
+ new_lr = lr_list[i]
+ if train_writer is not None:
+ # writing lr into tensorboard
+ with train_writer.as_default():
+ tf.summary.scalar('learning rate', data=new_lr, step=i)
+ yield new_lr
diff --git a/tensorflow_classification/Test11_efficientnetV2/predict.py b/tensorflow_classification/Test11_efficientnetV2/predict.py
index dec912667..27476c45f 100644
--- a/tensorflow_classification/Test11_efficientnetV2/predict.py
+++ b/tensorflow_classification/Test11_efficientnetV2/predict.py
@@ -40,8 +40,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = create_model(num_classes=num_classes)
@@ -59,7 +59,7 @@ def main():
plt.title(print_res)
for i in range(len(result)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
- result[i].numpy()))
+ result[i]))
plt.show()
diff --git a/tensorflow_classification/Test2_alexnet/predict.py b/tensorflow_classification/Test2_alexnet/predict.py
index bd4401359..59fd66496 100644
--- a/tensorflow_classification/Test2_alexnet/predict.py
+++ b/tensorflow_classification/Test2_alexnet/predict.py
@@ -31,8 +31,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = AlexNet_v1(num_classes=5)
@@ -49,7 +49,7 @@ def main():
plt.title(print_res)
for i in range(len(result)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
- result[i].numpy()))
+ result[i]))
plt.show()
diff --git a/tensorflow_classification/Test3_vgg/predict.py b/tensorflow_classification/Test3_vgg/predict.py
index c060f90a3..3cb4f0dcc 100644
--- a/tensorflow_classification/Test3_vgg/predict.py
+++ b/tensorflow_classification/Test3_vgg/predict.py
@@ -31,8 +31,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = vgg("vgg16", im_height=im_height, im_width=im_width, num_classes=num_classes)
@@ -49,7 +49,7 @@ def main():
plt.title(print_res)
for i in range(len(result)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
- result[i].numpy()))
+ result[i]))
plt.show()
diff --git a/tensorflow_classification/Test4_goolenet/predict.py b/tensorflow_classification/Test4_goolenet/predict.py
index ee9ed521e..a74a07cbc 100644
--- a/tensorflow_classification/Test4_goolenet/predict.py
+++ b/tensorflow_classification/Test4_goolenet/predict.py
@@ -31,8 +31,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
model = GoogLeNet(class_num=5, aux_logits=False)
model.summary()
@@ -49,7 +49,7 @@ def main():
plt.title(print_res)
for i in range(len(result)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
- result[i].numpy()))
+ result[i]))
plt.show()
diff --git a/tensorflow_classification/Test5_resnet/predict.py b/tensorflow_classification/Test5_resnet/predict.py
index 2939f6362..9cb0df536 100644
--- a/tensorflow_classification/Test5_resnet/predict.py
+++ b/tensorflow_classification/Test5_resnet/predict.py
@@ -37,8 +37,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
feature = resnet50(num_classes=num_classes, include_top=False)
@@ -65,7 +65,7 @@ def main():
plt.title(print_res)
for i in range(len(result)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
- result[i].numpy()))
+ result[i]))
plt.show()
diff --git a/tensorflow_classification/Test6_mobilenet/predict.py b/tensorflow_classification/Test6_mobilenet/predict.py
index c98619ffb..9ba39cc86 100644
--- a/tensorflow_classification/Test6_mobilenet/predict.py
+++ b/tensorflow_classification/Test6_mobilenet/predict.py
@@ -34,8 +34,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
feature = MobileNetV2(include_top=False)
@@ -56,7 +56,7 @@ def main():
plt.title(print_res)
for i in range(len(result)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
- result[i].numpy()))
+ result[i]))
plt.show()
diff --git a/tensorflow_classification/Test7_shuffleNet/predict.py b/tensorflow_classification/Test7_shuffleNet/predict.py
index 48a4f6751..4ede6789b 100644
--- a/tensorflow_classification/Test7_shuffleNet/predict.py
+++ b/tensorflow_classification/Test7_shuffleNet/predict.py
@@ -36,8 +36,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = shufflenet_v2_x1_0(num_classes=num_classes)
@@ -54,7 +54,7 @@ def main():
plt.title(print_res)
for i in range(len(result)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
- result[i].numpy()))
+ result[i]))
plt.show()
diff --git a/tensorflow_classification/Test9_efficientNet/predict.py b/tensorflow_classification/Test9_efficientNet/predict.py
index 3897e5591..632a202b1 100644
--- a/tensorflow_classification/Test9_efficientNet/predict.py
+++ b/tensorflow_classification/Test9_efficientNet/predict.py
@@ -41,8 +41,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = create_model(num_classes=num_classes)
@@ -59,7 +59,7 @@ def main():
plt.title(print_res)
for i in range(len(result)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
- result[i].numpy()))
+ result[i]))
plt.show()
diff --git a/tensorflow_classification/swin_transformer/predict.py b/tensorflow_classification/swin_transformer/predict.py
index e5e0ae545..95e3fc892 100644
--- a/tensorflow_classification/swin_transformer/predict.py
+++ b/tensorflow_classification/swin_transformer/predict.py
@@ -35,8 +35,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = create_model(num_classes=num_classes)
@@ -55,7 +55,7 @@ def main():
plt.title(print_res)
for i in range(len(result)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
- result[i].numpy()))
+ result[i]))
plt.show()
diff --git a/tensorflow_classification/vision_transformer/predict.py b/tensorflow_classification/vision_transformer/predict.py
index 49c4c462f..95e803064 100755
--- a/tensorflow_classification/vision_transformer/predict.py
+++ b/tensorflow_classification/vision_transformer/predict.py
@@ -35,8 +35,8 @@ def main():
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
- json_file = open(json_path, "r")
- class_indict = json.load(json_file)
+ with open(json_path, "r") as f:
+ class_indict = json.load(f)
# create model
model = create_model(num_classes=num_classes, has_logits=False)
@@ -55,7 +55,7 @@ def main():
plt.title(print_res)
for i in range(len(result)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
- result[i].numpy()))
+ result[i]))
plt.show()