first commit
commit c8504c8e34

.gitignore (new file, 22 lines, vendored)
@@ -0,0 +1,22 @@
work_dirs/
predicts/
output/
data/
data

__pycache__/
*/*.un~
.*.swp



*.egg-info/
*.egg

output.txt
.vscode/*
.DS_Store
tmp.*
*.pt
*.pth
*.un~

INSTALL.md (new file, 74 lines)
@@ -0,0 +1,74 @@
# Install

1. Clone the RESA repository
    ```
    git clone https://github.com/zjulearning/resa.git
    ```
    We will refer to this directory as `$RESA_ROOT`.

2. Create a conda virtual environment and activate it (conda is optional)

    ```Shell
    conda create -n resa python=3.8 -y
    conda activate resa
    ```

3. Install dependencies

    ```Shell
    # Install pytorch first; the cudatoolkit version should match the one installed on your system.
    # (You can also use pip to install pytorch and torchvision.)
    conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

    # Or install via pip
    pip install torch torchvision

    # Install python packages
    pip install -r requirements.txt
    ```

4. Data preparation

    Download [CULane](https://xingangpan.github.io/projects/CULane.html) and [Tusimple](https://github.com/TuSimple/tusimple-benchmark/issues/3), then extract them to `$CULANEROOT` and `$TUSIMPLEROOT`. Create links to them in the `data` directory.

    ```Shell
    cd $RESA_ROOT
    ln -s $CULANEROOT data/CULane
    ln -s $TUSIMPLEROOT data/tusimple
    ```

    Tusimple does not provide segmentation annotations, so we need to generate them from the JSON annotations.

    ```Shell
    python scripts/convert_tusimple.py --root $TUSIMPLEROOT
    # this will generate segmentations and two list files: train_gt.txt and test.txt
    ```

    For CULane, you should have a structure like this:
    ```
    $RESA_ROOT/data/CULane/driver_xx_xxframe    # data folders x6
    $RESA_ROOT/data/CULane/laneseg_label_w16    # lane segmentation labels
    $RESA_ROOT/data/CULane/list                 # data lists
    ```

    For Tusimple, you should have a structure like this:
    ```
    $RESA_ROOT/data/tusimple/clips                # data folders
    $RESA_ROOT/data/tusimple/label_data_xxxx.json # label json files x4
    $RESA_ROOT/data/tusimple/test_tasks_0627.json # test tasks json file
    $RESA_ROOT/data/tusimple/test_label.json      # test label json file
    ```

5. Install CULane evaluation tools.

    This tool requires OpenCV C++. Please follow [the OpenCV install guide](https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html) to install OpenCV C++, or simply install it with `sudo apt-get install libopencv-dev`.

    Then compile the CULane evaluation tool.
    ```Shell
    cd $RESA_ROOT/runner/evaluator/culane/lane_evaluation
    make
    cd -
    ```

    Note that the default `opencv` version is 3. If you use opencv 2, change `OPENCV_VERSION := 3` to `OPENCV_VERSION := 2` in the `Makefile`.

LICENSE (new file, 201 lines)
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2021 Tu Zheng

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

README.md (new file, 148 lines)
@@ -0,0 +1,148 @@
# RESA
PyTorch implementation of the paper "[RESA: Recurrent Feature-Shift Aggregator for Lane Detection](https://arxiv.org/abs/2008.13719)".

Our paper has been accepted by AAAI 2021.

**News**: We also release RESA on [LaneDet](https://github.com/Turoad/lanedet). We recommend trying LaneDet as well.

## Introduction

- RESA recurrently shifts sliced feature maps in vertical and horizontal directions, enabling each pixel to gather global information (see the sketch after this list).
- RESA achieves SOTA results on the CULane and Tusimple datasets.

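As an editorial aid, here is a minimal PyTorch sketch of the recurrent feature-shift idea. It is not the repository's `RESA` module (configured later via `iter`, `input_channel`, and `conv_stride`); the class name, the `torch.roll`-based shifting, and the iteration order are illustrative simplifications.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RESAShiftSketch(nn.Module):
    """Illustrative sketch of RESA-style recurrent feature shifting.

    In each iteration, every row (or column) of the feature map receives
    information from a row (column) a fixed stride away, so information
    propagates across the whole map in O(log H) iterations rather than O(H).
    """

    def __init__(self, channels=128, num_iter=4, kernel=9):
        super().__init__()
        pad = kernel // 2
        # 1-D convs: (1 x k) kernels for the vertical passes,
        # (k x 1) kernels for the horizontal passes.
        self.conv_d = nn.Conv2d(channels, channels, (1, kernel), padding=(0, pad), bias=False)
        self.conv_u = nn.Conv2d(channels, channels, (1, kernel), padding=(0, pad), bias=False)
        self.conv_r = nn.Conv2d(channels, channels, (kernel, 1), padding=(pad, 0), bias=False)
        self.conv_l = nn.Conv2d(channels, channels, (kernel, 1), padding=(pad, 0), bias=False)
        self.num_iter = num_iter

    def forward(self, x):
        _, _, h, w = x.shape
        for k in range(self.num_iter):
            # The shift stride halves each iteration: h/2, h/4, ...
            sh, sw = h // 2 ** (k + 1), w // 2 ** (k + 1)
            # Down/up passes: each row gathers from the row `sh` away.
            x = x + F.relu(self.conv_d(torch.roll(x, -sh, dims=2)))
            x = x + F.relu(self.conv_u(torch.roll(x, sh, dims=2)))
            # Right/left passes: each column gathers from `sw` away.
            x = x + F.relu(self.conv_r(torch.roll(x, -sw, dims=3)))
            x = x + F.relu(self.conv_l(torch.roll(x, sw, dims=3)))
        return x
```
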
## Get started
1. Clone the RESA repository
    ```
    git clone https://github.com/zjulearning/resa.git
    ```
    We will refer to this directory as `$RESA_ROOT`.

2. Create a conda virtual environment and activate it (conda is optional)

    ```Shell
    conda create -n resa python=3.8 -y
    conda activate resa
    ```

3. Install dependencies

    ```Shell
    # Install pytorch first; the cudatoolkit version should match the one installed on your system.
    # (You can also use pip to install pytorch and torchvision.)
    conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

    # Or install via pip
    pip install torch torchvision

    # Install python packages
    pip install -r requirements.txt
    ```

4. Data preparation

    Download [CULane](https://xingangpan.github.io/projects/CULane.html) and [Tusimple](https://github.com/TuSimple/tusimple-benchmark/issues/3), then extract them to `$CULANEROOT` and `$TUSIMPLEROOT`. Create links to them in the `data` directory.

    ```Shell
    cd $RESA_ROOT
    mkdir -p data
    ln -s $CULANEROOT data/CULane
    ln -s $TUSIMPLEROOT data/tusimple
    ```

    For CULane, you should have a structure like this:
    ```
    $CULANEROOT/driver_xx_xxframe    # data folders x6
    $CULANEROOT/laneseg_label_w16    # lane segmentation labels
    $CULANEROOT/list                 # data lists
    ```

    For Tusimple, you should have a structure like this:
    ```
    $TUSIMPLEROOT/clips                # data folders
    $TUSIMPLEROOT/label_data_xxxx.json # label json files x4
    $TUSIMPLEROOT/test_tasks_0627.json # test tasks json file
    $TUSIMPLEROOT/test_label.json      # test label json file
    ```

    Tusimple does not provide segmentation annotations, so we need to generate them from the JSON annotations.

    ```Shell
    python tools/generate_seg_tusimple.py --root $TUSIMPLEROOT
    # this will generate a seg_label directory
    ```

5. Install CULane evaluation tools.

    This tool requires OpenCV C++. Please follow [the OpenCV install guide](https://docs.opencv.org/master/d7/d9f/tutorial_linux_install.html) to install OpenCV C++, or simply install it with `sudo apt-get install libopencv-dev`.

    Then compile the CULane evaluation tool.
    ```Shell
    cd $RESA_ROOT/runner/evaluator/culane/lane_evaluation
    make
    cd -
    ```

    Note that the default `opencv` version is 3. If you use opencv 2, change `OPENCV_VERSION := 3` to `OPENCV_VERSION := 2` in the `Makefile`.


## Training

For training, run

```Shell
python main.py [configs/path_to_your_config] --gpus [gpu_ids]
```

For example, run
```Shell
python main.py configs/culane.py --gpus 0 1 2 3
```

## Testing
For testing, run
```Shell
python main.py [configs/path_to_your_config] --validate --load_from [path_to_your_model] --gpus [gpu_ids]
```

For example, run
```Shell
python main.py configs/culane.py --validate --load_from culane_resnet50.pth --gpus 0 1 2 3

python main.py configs/tusimple.py --validate --load_from tusimple_resnet34.pth --gpus 0 1 2 3
```

We provide two trained ResNet models, our best-performing models on CULane and Tusimple (Tusimple: [GoogleDrive](https://drive.google.com/file/d/1M1xi82y0RoWUwYYG9LmZHXWSD2D60o0D/view?usp=sharing)/[BaiduDrive (code: s5ii)](https://pan.baidu.com/s/1CgJFrt9OHe-RUNooPpHRGA), CULane: [GoogleDrive](https://drive.google.com/file/d/1pcqq9lpJ4ixJgFVFndlPe42VgVsjgn0Q/view?usp=sharing)/[BaiduDrive (code: rlwj)](https://pan.baidu.com/s/1ODKAZxpKrZIPXyaNnxcV3g)).

## Visualization
Just add `--view`.

For example:
```Shell
python main.py configs/culane.py --validate --load_from culane_resnet50.pth --gpus 0 1 2 3 --view
```
You will find the results in `work_dirs/[DATASET]/xxx/vis`.

## Citation
If you use our method, please consider citing:
```BibTeX
@inproceedings{zheng2021resa,
  title={RESA: Recurrent Feature-Shift Aggregator for Lane Detection},
  author={Zheng, Tu and Fang, Hao and Zhang, Yi and Tang, Wenjian and Yang, Zheng and Liu, Haifeng and Cai, Deng},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={35},
  number={4},
  pages={3547--3554},
  year={2021}
}
```

<!-- ## Thanks

The evaluation code is modified from [SCNN](https://github.com/XingangPan/SCNN) and [Tusimple Benchmark](https://github.com/TuSimple/tusimple-benchmark). -->

configs/culane.py (new file, 88 lines)
@@ -0,0 +1,88 @@
import math

net = dict(
    type='RESANet',
)

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet50',
    pretrained=True,
    replace_stride_with_dilation=[False, True, True],
    out_conv=True,
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,
    iter=4,
    input_channel=128,
    conv_stride=9,
)

# decoder = 'PlainDecoder'
decoder = 'BUSD'

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='CULane',
)

optimizer = dict(
    type='sgd',
    lr=0.025,
    weight_decay=1e-4,
    momentum=0.9
)

epochs = 12
batch_size = 8
# 88880 is the number of CULane training images.
total_iter = (88880 // batch_size) * epochs
scheduler = dict(
    type='LambdaLR',
    # Polynomial ("poly") decay: LR multiplier (1 - iter/total_iter)^0.9.
    lr_lambda=lambda _iter: math.pow(1 - _iter/total_iter, 0.9)
)

loss_type = 'dice_loss'
seg_loss_weight = 2.
eval_ep = 6
save_ep = epochs

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 288
img_width = 800
cut_height = 240

dataset_path = './data/CULane'
dataset = dict(
    train=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='train_gt.txt',
    ),
    val=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    ),
    test=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    )
)

workers = 12
num_classes = 4 + 1  # 4 lanes + background
ignore_label = 255
log_interval = 500
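For reference, this is roughly how a scheduler config like the one above maps onto PyTorch objects. It is an illustrative sketch, not the repository's runner (which is not part of this commit); the stand-in model and training loop are assumptions.

```python
import math
import torch

# Hypothetical glue code: build optimizer and scheduler from the config
# values in configs/culane.py. The Linear model is a stand-in for RESANet.
model = torch.nn.Linear(10, 5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.025,
                            weight_decay=1e-4, momentum=0.9)

epochs, batch_size = 12, 8
total_iter = (88880 // batch_size) * epochs
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda _iter: math.pow(1 - _iter / total_iter, 0.9))

# Stepped once per iteration, the LR decays polynomially from 0.025 toward 0.
for _ in range(3):
    optimizer.step()
    scheduler.step()
```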

configs/culane_copy.py (new file, 97 lines)
@@ -0,0 +1,97 @@
import math

net = dict(
    type='RESANet',
)

# backbone = dict(
#     type='ResNetWrapper',
#     resnet='resnet50',
#     pretrained=True,
#     replace_stride_with_dilation=[False, True, True],
#     out_conv=True,
#     fea_stride=8,
# )

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet34',
    pretrained=True,
    replace_stride_with_dilation=[False, False, False],
    out_conv=False,
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,
    iter=4,
    input_channel=128,
    conv_stride=9,
)

# decoder = 'PlainDecoder'
decoder = 'BUSD'

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='CULane',
)

optimizer = dict(
    type='sgd',
    lr=0.025,
    weight_decay=1e-4,
    momentum=0.9
)

epochs = 20
batch_size = 8
total_iter = (88880 // batch_size) * epochs
scheduler = dict(
    type='LambdaLR',
    lr_lambda=lambda _iter: math.pow(1 - _iter/total_iter, 0.9)
)

loss_type = 'dice_loss'
seg_loss_weight = 2.
eval_ep = 1
save_ep = epochs

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 288
img_width = 800
cut_height = 240

dataset_path = './data/CULane'
dataset = dict(
    train=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='train_gt.txt',
    ),
    val=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    ),
    test=dict(
        type='CULane',
        img_path=dataset_path,
        data_list='test.txt',
    )
)

workers = 12
num_classes = 4 + 1
ignore_label = 255
log_interval = 500

configs/tusimple.py (new file, 93 lines)
@@ -0,0 +1,93 @@
import math

net = dict(
    type='RESANet',
)

backbone = dict(
    type='ResNetWrapper',
    resnet='resnet34',
    pretrained=True,
    replace_stride_with_dilation=[False, True, True],
    out_conv=True,
    fea_stride=8,
)

resa = dict(
    type='RESA',
    alpha=2.0,
    iter=5,
    input_channel=128,
    conv_stride=9,
)

decoder = 'BUSD'

trainer = dict(
    type='RESA'
)

evaluator = dict(
    type='Tusimple',
    thresh=0.60
)

optimizer = dict(
    type='sgd',
    lr=0.020,
    weight_decay=1e-4,
    momentum=0.9
)

total_iter = 181400
scheduler = dict(
    type='LambdaLR',
    lr_lambda=lambda _iter: math.pow(1 - _iter/total_iter, 0.9)
)

bg_weight = 0.4

img_norm = dict(
    mean=[103.939, 116.779, 123.68],
    std=[1., 1., 1.]
)

img_height = 368
img_width = 640
cut_height = 160
seg_label = "seg_label"

dataset_path = './data/tusimple'
test_json_file = './data/tusimple/test_label.json'

dataset = dict(
    train=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='train_val_gt.txt',
    ),
    val=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='test_gt.txt'
    ),
    test=dict(
        type='TuSimple',
        img_path=dataset_path,
        data_list='test_gt.txt'
    )
)

loss_type = 'cross_entropy'
seg_loss_weight = 1.0

batch_size = 4
workers = 12
num_classes = 6 + 1  # 6 lanes + background
ignore_label = 255
epochs = 300
log_interval = 100
eval_ep = 1
save_ep = epochs
log_note = ''

datasets/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .registry import build_dataset, build_dataloader

from .tusimple import TuSimple
from .culane import CULane

datasets/base_dataset.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import os.path as osp
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import torchvision
import utils.transforms as tf
from .registry import DATASETS


@DATASETS.register_module
class BaseDataset(Dataset):
    def __init__(self, img_path, data_list, list_path='list', cfg=None):
        self.cfg = cfg
        self.img_path = img_path
        self.list_path = osp.join(img_path, list_path)
        self.data_list = data_list
        self.is_training = ('train' in data_list)

        self.img_name_list = []
        self.full_img_path_list = []
        self.label_list = []
        self.exist_list = []

        self.transform = self.transform_train() if self.is_training else self.transform_val()

        self.init()

    def transform_train(self):
        raise NotImplementedError()

    def transform_val(self):
        val_transform = torchvision.transforms.Compose([
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return val_transform

    def view(self, img, coords, file_path=None):
        # Draw the predicted lane points onto the image and optionally save it.
        for coord in coords:
            for x, y in coord:
                if x <= 0 or y <= 0:
                    continue
                x, y = int(x), int(y)
                cv2.circle(img, (x, y), 4, (255, 0, 0), 2)

        if file_path is not None:
            if not os.path.exists(osp.dirname(file_path)):
                os.makedirs(osp.dirname(file_path))
            cv2.imwrite(file_path, img)

    def init(self):
        raise NotImplementedError()

    def __len__(self):
        return len(self.full_img_path_list)

    def __getitem__(self, idx):
        img = cv2.imread(self.full_img_path_list[idx]).astype(np.float32)
        # Crop away the top of the image (sky/hood) before transforming.
        img = img[self.cfg.cut_height:, :, :]

        if self.is_training:
            label = cv2.imread(self.label_list[idx], cv2.IMREAD_UNCHANGED)
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()
            label = label[self.cfg.cut_height:, :]
            exist = self.exist_list[idx]
            if self.transform:
                img, label = self.transform((img, label))
            label = torch.from_numpy(label).contiguous().long()
        else:
            img, = self.transform((img,))

        img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
        meta = {'full_img_path': self.full_img_path_list[idx],
                'img_name': self.img_name_list[idx]}

        data = {'img': img, 'meta': meta}
        if self.is_training:
            data.update({'label': label, 'exist': exist})
        return data
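For orientation, the batch dict produced by `__getitem__` (collated by the dataloader built in `datasets/registry.py` below) can be inspected like this. This is a hypothetical sketch; it assumes the CULane lists and images from the README are in place on disk.

```python
from utils.config import Config
from datasets import build_dataloader

# Hypothetical inspection script (requires the CULane data to exist).
cfg = Config.fromfile('configs/culane.py')
loader = build_dataloader(cfg.dataset.train, cfg, is_train=True)

batch = next(iter(loader))
print(batch['img'].shape)    # (batch_size, 3, img_height, img_width), float
print(batch['label'].shape)  # (batch_size, img_height, img_width), class ids
print(batch['exist'].shape)  # (batch_size, 4) per-lane existence flags
print(batch['meta']['img_name'][:2])  # image names from the list file
```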

datasets/culane.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import os.path as osp
import numpy as np
import torchvision
import utils.transforms as tf
from .base_dataset import BaseDataset
from .registry import DATASETS
import cv2


@DATASETS.register_module
class CULane(BaseDataset):
    def __init__(self, img_path, data_list, cfg=None):
        super().__init__(img_path, data_list, cfg=cfg)
        # Original CULane image size.
        self.ori_imgh = 590
        self.ori_imgw = 1640

    def init(self):
        with open(osp.join(self.list_path, self.data_list)) as f:
            for line in f:
                line_split = line.strip().split(" ")
                self.img_name_list.append(line_split[0])
                self.full_img_path_list.append(self.img_path + line_split[0])
                if not self.is_training:
                    continue
                self.label_list.append(self.img_path + line_split[1])
                self.exist_list.append(
                    np.array([int(line_split[2]), int(line_split[3]),
                              int(line_split[4]), int(line_split[5])]))

    def transform_train(self):
        train_transform = torchvision.transforms.Compose([
            tf.GroupRandomRotation(degree=(-2, 2)),
            tf.GroupRandomHorizontalFlip(),
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return train_transform

    def probmap2lane(self, probmaps, exists, pts=18):
        """Decode per-lane probability maps into image coordinates,
        sampling one x per fixed y row (every 20 px in the original image)."""
        coords = []
        probmaps = probmaps[1:, ...]  # drop the background channel
        exists = exists > 0.5
        for probmap, exist in zip(probmaps, exists):
            if exist == 0:
                continue
            probmap = cv2.blur(probmap, (9, 9), borderType=cv2.BORDER_REPLICATE)
            thr = 0.3
            coordinate = np.zeros(pts)
            cut_height = self.cfg.cut_height
            for i in range(pts):
                # Map the i-th sampling row from original-image coordinates
                # into the (cropped, resized) probability map.
                line = probmap[round(
                    self.cfg.img_height-i*20/(self.ori_imgh-cut_height)*self.cfg.img_height)-1]

                if np.max(line) > thr:
                    coordinate[i] = np.argmax(line)+1
            if np.sum(coordinate > 0) < 2:
                continue

            # Rescale x back to the original image width; y rows are fixed.
            img_coord = np.zeros((pts, 2))
            img_coord[:, :] = -1
            for idx, value in enumerate(coordinate):
                if value > 0:
                    img_coord[idx][0] = round(value*self.ori_imgw/self.cfg.img_width-1)
                    img_coord[idx][1] = round(self.ori_imgh-idx*20-1)

            img_coord = img_coord.astype(int)
            coords.append(img_coord)

        return coords

datasets/registry.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from utils import Registry, build_from_cfg

import torch
import torch.nn as nn  # needed by build() when cfg is a list

DATASETS = Registry('datasets')

def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_dataset(split_cfg, cfg):
    args = split_cfg.copy()
    args.pop('type')
    args = args.to_dict()
    args['cfg'] = cfg
    return build(split_cfg, DATASETS, default_args=args)

def build_dataloader(split_cfg, cfg, is_train=True):
    # Shuffle only during training.
    shuffle = is_train

    dataset = build_dataset(split_cfg, cfg)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=cfg.batch_size, shuffle=shuffle,
        num_workers=cfg.workers, pin_memory=False, drop_last=False)

    return data_loader
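`Registry` and `build_from_cfg` come from `utils`, which is not part of this commit. As a reading aid, here is a minimal sketch of the behavior this file appears to assume; the names match, but the implementation details are guesses.

```python
# Minimal sketch of the registry pattern assumed above. The real
# Registry/build_from_cfg live in utils/ and may differ in detail.
class Registry:
    def __init__(self, name):
        self.name = name
        self.module_dict = {}

    def register_module(self, cls):
        # Used as a bare decorator: @DATASETS.register_module
        self.module_dict[cls.__name__] = cls
        return cls

def build_from_cfg(cfg, registry, default_args=None):
    # Look up cfg['type'] in the registry and instantiate it with the
    # remaining keys (plus default_args) as keyword arguments.
    args = dict(cfg)
    cls = registry.module_dict[args.pop('type')]
    if default_args is not None:
        args.update(default_args)
    return cls(**args)
```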

datasets/tusimple.py (new file, 150 lines)
@@ -0,0 +1,150 @@
import os.path as osp
import numpy as np
import cv2
import torchvision
import utils.transforms as tf
from .base_dataset import BaseDataset
from .registry import DATASETS


@DATASETS.register_module
class TuSimple(BaseDataset):
    def __init__(self, img_path, data_list, cfg=None):
        super().__init__(img_path, data_list, 'seg_label/list', cfg)

    def transform_train(self):
        train_transform = torchvision.transforms.Compose([
            tf.GroupRandomRotation(),
            tf.GroupRandomHorizontalFlip(),
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
                self.cfg.img_norm['std'], (1, ))),
        ])
        return train_transform

    def init(self):
        with open(osp.join(self.list_path, self.data_list)) as f:
            for line in f:
                line_split = line.strip().split(" ")
                self.img_name_list.append(line_split[0])
                self.full_img_path_list.append(self.img_path + line_split[0])
                if not self.is_training:
                    continue
                self.label_list.append(self.img_path + line_split[1])
                self.exist_list.append(
                    np.array([int(line_split[2]), int(line_split[3]),
                              int(line_split[4]), int(line_split[5]),
                              int(line_split[6]), int(line_split[7])
                              ]))

    def fix_gap(self, coordinate):
        """Fill gaps (negative entries) between the first and last valid
        points of a lane by linear interpolation."""
        if any(x > 0 for x in coordinate):
            start = [i for i, x in enumerate(coordinate) if x > 0][0]
            end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0]
            lane = coordinate[start:end+1]
            if any(x < 0 for x in lane):
                gap_start = [i for i, x in enumerate(
                    lane[:-1]) if x > 0 and lane[i+1] < 0]
                gap_end = [i+1 for i,
                           x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0]
                gap_id = [i for i, x in enumerate(lane) if x < 0]
                if len(gap_start) == 0 or len(gap_end) == 0:
                    return coordinate
                for id in gap_id:
                    for i in range(len(gap_start)):
                        if i >= len(gap_end):
                            return coordinate
                        if id > gap_start[i] and id < gap_end[i]:
                            # Linear interpolation between the gap borders.
                            gap_width = float(gap_end[i] - gap_start[i])
                            lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + (
                                gap_end[i] - id) / gap_width * lane[gap_start[i]])
                if not all(x > 0 for x in lane):
                    print("Gaps still exist!")
                coordinate[start:end+1] = lane
        return coordinate

    def is_short(self, lane):
        # A lane with no positive x coordinate is considered empty.
        start = [i for i, x in enumerate(lane) if x > 0]
        if not start:
            return 1
        else:
            return 0

    def get_lane(self, prob_map, y_px_gap, pts, thresh, resize_shape=None):
        """
        Arguments:
            prob_map: prob map for a single lane, np.array of size (h, w)
            resize_shape: reshape size target, (H, W)

        Return:
            coords: x coords sampled bottom-up every y_px_gap px,
                0 for non-existent points, in the resized shape
        """
        if resize_shape is None:
            resize_shape = prob_map.shape
        h, w = prob_map.shape
        H, W = resize_shape
        H -= self.cfg.cut_height

        coords = np.zeros(pts)
        coords[:] = -1.0
        for i in range(pts):
            y = int((H - 10 - i * y_px_gap) * h / H)
            if y < 0:
                break
            line = prob_map[y, :]
            id = np.argmax(line)
            if line[id] > thresh:
                coords[i] = int(id / w * W)
        if (coords > 0).sum() < 2:
            coords = np.zeros(pts)
        self.fix_gap(coords)

        return coords

    def probmap2lane(self, seg_pred, exist, resize_shape=(720, 1280), smooth=True, y_px_gap=10, pts=56, thresh=0.6):
        """
        Arguments:
            seg_pred: np.array of size (num_classes, h, w)
            resize_shape: reshape size target, (H, W)
            exist: list of existence flags, e.g. [0, 1, 1, 0]
            smooth: whether to smooth the probability map
            y_px_gap: y pixel gap between sampled rows
            pts: number of points per lane
            thresh: probability threshold

        Return:
            coordinates: [x, y] lists of lanes, e.g.
                [[[9, 569], [50, 549]], [[630, 569], [647, 549]]]
        """
        if resize_shape is None:
            resize_shape = seg_pred.shape[1:]  # seg_pred (num_classes, h, w)
        _, h, w = seg_pred.shape
        H, W = resize_shape
        coordinates = []

        for i in range(self.cfg.num_classes - 1):
            prob_map = seg_pred[i + 1]
            if smooth:
                prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
            coords = self.get_lane(prob_map, y_px_gap, pts, thresh, resize_shape)
            if self.is_short(coords):
                continue
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])

        if len(coordinates) == 0:
            coords = np.zeros(pts)
            coordinates.append(
                [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in
                 range(pts)])

        return coordinates
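A hedged usage sketch of `probmap2lane` with random probability maps. It assumes the Tusimple lists from the README exist on disk, since the `TuSimple` constructor reads them in `init()`.

```python
import numpy as np
from utils.config import Config
from datasets import TuSimple

# Hypothetical usage; requires the Tusimple data/list files from step 4.
cfg = Config.fromfile('configs/tusimple.py')
dataset = TuSimple('./data/tusimple', 'test_gt.txt', cfg=cfg)

# Fake per-class probability maps (num_classes = 7, including background).
seg_pred = np.random.rand(cfg.num_classes, cfg.img_height, cfg.img_width).astype(np.float32)
exist = [1, 1, 1, 1, 0, 0]  # per-lane flags (unused by probmap2lane itself)

lanes = dataset.probmap2lane(seg_pred, exist, resize_shape=(720, 1280))
# lanes: one [x, y] list per lane, with y sampled every y_px_gap=10 px from
# the bottom of the 720x1280 image and x = -1 where the lane is absent.
```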

main.py (new file, 73 lines)
@@ -0,0 +1,73 @@
|  | import os | ||||||
|  | import os.path as osp | ||||||
|  | import time | ||||||
|  | import shutil | ||||||
|  | import torch | ||||||
|  | import torchvision | ||||||
|  | import torch.nn.parallel | ||||||
|  | import torch.backends.cudnn as cudnn | ||||||
|  | import torch.nn.functional as F | ||||||
|  | import torch.optim | ||||||
|  | import cv2 | ||||||
|  | import numpy as np | ||||||
|  | import models | ||||||
|  | import argparse | ||||||
|  | from utils.config import Config | ||||||
|  | from runner.runner import Runner  | ||||||
|  | from datasets import build_dataloader | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def main(): | ||||||
|  |     args = parse_args() | ||||||
|  |     os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in args.gpus) | ||||||
|  | 
 | ||||||
|  |     cfg = Config.fromfile(args.config) | ||||||
|  |     cfg.gpus = len(args.gpus) | ||||||
|  | 
 | ||||||
|  |     cfg.load_from = args.load_from | ||||||
|  |     cfg.finetune_from = args.finetune_from | ||||||
|  |     cfg.view = args.view | ||||||
|  | 
 | ||||||
|  |     cfg.work_dirs = args.work_dirs + '/' + cfg.dataset.train.type | ||||||
|  | 
 | ||||||
|  |     cudnn.benchmark = True | ||||||
|  |     cudnn.fastest = True | ||||||
|  | 
 | ||||||
|  |     runner = Runner(cfg) | ||||||
|  | 
 | ||||||
|  |     if args.validate: | ||||||
|  |         val_loader = build_dataloader(cfg.dataset.val, cfg, is_train=False) | ||||||
|  |         runner.validate(val_loader) | ||||||
|  |     else: | ||||||
|  |         runner.train() | ||||||
|  | 
 | ||||||
|  | def parse_args(): | ||||||
|  |     parser = argparse.ArgumentParser(description='Train a detector') | ||||||
|  |     parser.add_argument('config', help='train config file path') | ||||||
|  |     parser.add_argument( | ||||||
|  |         '--work_dirs', type=str, default='work_dirs', | ||||||
|  |         help='work dirs') | ||||||
|  |     parser.add_argument( | ||||||
|  |         '--load_from', default=None, | ||||||
|  |         help='the checkpoint file to resume from') | ||||||
|  |     parser.add_argument( | ||||||
|  |         '--finetune_from', default=None, | ||||||
|  |         help='whether to finetune from the checkpoint') | ||||||
|  |     parser.add_argument( | ||||||
|  |         '--validate', | ||||||
|  |         action='store_true', | ||||||
|  |         help='whether to evaluate the checkpoint during training') | ||||||
|  |     parser.add_argument( | ||||||
|  |         '--view', | ||||||
|  |         action='store_true', | ||||||
|  |         help='whether to show visualization result') | ||||||
|  |     parser.add_argument('--gpus', nargs='+', type=int, default=[0]) | ||||||
|  |     parser.add_argument('--seed', type=int, | ||||||
|  |                         default=None, help='random seed') | ||||||
|  |     args = parser.parse_args() | ||||||
|  | 
 | ||||||
|  |     return args | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     main() | ||||||
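As a usage sketch, `parse_args` takes a positional config path plus optional flags; the config path and checkpoint name below are placeholders, not files documented in this commit:

```python
# Hypothetical invocation, equivalent to running:
#   python main.py configs/culane.py --validate --load_from best.pth --gpus 0
import sys

sys.argv = ['main.py', 'configs/culane.py', '--validate',
            '--load_from', 'best.pth', '--gpus', '0']
args = parse_args()
assert args.validate and args.load_from == 'best.pth' and args.gpus == [0]
```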
1 models/__init__.py Normal file
							| @ -0,0 +1 @@ | |||||||
|  | from .resa import * | ||||||
129 models/decoder.py Normal file
							| @ -0,0 +1,129 @@ | |||||||
|  | from torch import nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  | 
 | ||||||
|  | class PlainDecoder(nn.Module): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         super(PlainDecoder, self).__init__() | ||||||
|  |         self.cfg = cfg | ||||||
|  | 
 | ||||||
|  |         self.dropout = nn.Dropout2d(0.1) | ||||||
|  |         self.conv8 = nn.Conv2d(128, cfg.num_classes, 1) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = self.dropout(x) | ||||||
|  |         x = self.conv8(x) | ||||||
|  |         x = F.interpolate(x, size=[self.cfg.img_height,  self.cfg.img_width], | ||||||
|  |                            mode='bilinear', align_corners=False) | ||||||
|  | 
 | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def conv1x1(in_planes, out_planes, stride=1): | ||||||
|  |     """1x1 convolution""" | ||||||
|  |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class non_bottleneck_1d(nn.Module): | ||||||
|  |     def __init__(self, chann, dropprob, dilated): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.conv3x1_1 = nn.Conv2d( | ||||||
|  |             chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True) | ||||||
|  | 
 | ||||||
|  |         self.conv1x3_1 = nn.Conv2d( | ||||||
|  |             chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True) | ||||||
|  | 
 | ||||||
|  |         self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) | ||||||
|  | 
 | ||||||
|  |         self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True, | ||||||
|  |                                    dilation=(dilated, 1)) | ||||||
|  | 
 | ||||||
|  |         self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True, | ||||||
|  |                                    dilation=(1, dilated)) | ||||||
|  | 
 | ||||||
|  |         self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) | ||||||
|  | 
 | ||||||
|  |         self.dropout = nn.Dropout2d(dropprob) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         output = self.conv3x1_1(input) | ||||||
|  |         output = F.relu(output) | ||||||
|  |         output = self.conv1x3_1(output) | ||||||
|  |         output = self.bn1(output) | ||||||
|  |         output = F.relu(output) | ||||||
|  | 
 | ||||||
|  |         output = self.conv3x1_2(output) | ||||||
|  |         output = F.relu(output) | ||||||
|  |         output = self.conv1x3_2(output) | ||||||
|  |         output = self.bn2(output) | ||||||
|  | 
 | ||||||
|  |         if self.dropout.p != 0: | ||||||
|  |             output = self.dropout(output) | ||||||
|  | 
 | ||||||
|  |         # +input = identity (residual connection) | ||||||
|  |         return F.relu(output + input) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class UpsamplerBlock(nn.Module): | ||||||
|  |     def __init__(self, ninput, noutput, up_width, up_height): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.conv = nn.ConvTranspose2d( | ||||||
|  |             ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True) | ||||||
|  | 
 | ||||||
|  |         self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True) | ||||||
|  | 
 | ||||||
|  |         self.follows = nn.ModuleList() | ||||||
|  |         self.follows.append(non_bottleneck_1d(noutput, 0, 1)) | ||||||
|  |         self.follows.append(non_bottleneck_1d(noutput, 0, 1)) | ||||||
|  | 
 | ||||||
|  |         # interpolate | ||||||
|  |         self.up_width = up_width | ||||||
|  |         self.up_height = up_height | ||||||
|  |         self.interpolate_conv = conv1x1(ninput, noutput) | ||||||
|  |         self.interpolate_bn = nn.BatchNorm2d( | ||||||
|  |             noutput, eps=1e-3, track_running_stats=True) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         output = self.conv(input) | ||||||
|  |         output = self.bn(output) | ||||||
|  |         out = F.relu(output) | ||||||
|  |         for follow in self.follows: | ||||||
|  |             out = follow(out) | ||||||
|  | 
 | ||||||
|  |         interpolate_output = self.interpolate_conv(input) | ||||||
|  |         interpolate_output = self.interpolate_bn(interpolate_output) | ||||||
|  |         interpolate_output = F.relu(interpolate_output) | ||||||
|  | 
 | ||||||
|  |         interpolate = F.interpolate(interpolate_output, size=[self.up_height,  self.up_width], | ||||||
|  |                                     mode='bilinear', align_corners=False) | ||||||
|  | 
 | ||||||
|  |         return out + interpolate | ||||||
|  | 
 | ||||||
|  | class BUSD(nn.Module): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         super().__init__() | ||||||
|  |         img_height = cfg.img_height | ||||||
|  |         img_width = cfg.img_width | ||||||
|  |         num_classes = cfg.num_classes | ||||||
|  | 
 | ||||||
|  |         self.layers = nn.ModuleList() | ||||||
|  | 
 | ||||||
|  |         self.layers.append(UpsamplerBlock(ninput=128, noutput=64, | ||||||
|  |                                           up_height=int(img_height)//4, up_width=int(img_width)//4)) | ||||||
|  |         self.layers.append(UpsamplerBlock(ninput=64, noutput=32, | ||||||
|  |                                           up_height=int(img_height)//2, up_width=int(img_width)//2)) | ||||||
|  |         self.layers.append(UpsamplerBlock(ninput=32, noutput=16, | ||||||
|  |                                           up_height=int(img_height)//1, up_width=int(img_width)//1)) | ||||||
|  | 
 | ||||||
|  |         self.output_conv = conv1x1(16, num_classes) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         output = input | ||||||
|  | 
 | ||||||
|  |         for layer in self.layers: | ||||||
|  |             output = layer(output) | ||||||
|  | 
 | ||||||
|  |         output = self.output_conv(output) | ||||||
|  | 
 | ||||||
|  |         return output | ||||||
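A minimal shape check for the two decoders, assuming a CULane-style 288x800 input and a 128-channel backbone feature at 1/8 resolution; `_Cfg` is a hypothetical stand-in for the real config object, not part of this commit:

```python
import torch

class _Cfg:  # hypothetical stub for the cfg built by utils.config.Config
    img_height, img_width, num_classes = 288, 800, 5

feat = torch.randn(1, 128, 36, 100)  # 1/8-resolution backbone feature
plain, busd = PlainDecoder(_Cfg()).eval(), BUSD(_Cfg()).eval()
with torch.no_grad():
    print(plain(feat).shape)  # torch.Size([1, 5, 288, 800])
    print(busd(feat).shape)   # torch.Size([1, 5, 288, 800])
```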
135 models/decoder_copy.py Normal file
							| @ -0,0 +1,135 @@ | |||||||
|  | from torch import nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | class PlainDecoder(nn.Module): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         super(PlainDecoder, self).__init__() | ||||||
|  |         self.cfg = cfg | ||||||
|  | 
 | ||||||
|  |         self.dropout = nn.Dropout2d(0.1) | ||||||
|  |         self.conv8 = nn.Conv2d(128, cfg.num_classes, 1) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = self.dropout(x) | ||||||
|  |         x = self.conv8(x) | ||||||
|  |         x = F.interpolate(x, size=[self.cfg.img_height,  self.cfg.img_width], | ||||||
|  |                            mode='bilinear', align_corners=False) | ||||||
|  | 
 | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def conv1x1(in_planes, out_planes, stride=1): | ||||||
|  |     """1x1 convolution""" | ||||||
|  |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class non_bottleneck_1d(nn.Module): | ||||||
|  |     def __init__(self, chann, dropprob, dilated): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.conv3x1_1 = nn.Conv2d( | ||||||
|  |             chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True) | ||||||
|  | 
 | ||||||
|  |         self.conv1x3_1 = nn.Conv2d( | ||||||
|  |             chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True) | ||||||
|  | 
 | ||||||
|  |         self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) | ||||||
|  | 
 | ||||||
|  |         self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True, | ||||||
|  |                                    dilation=(dilated, 1)) | ||||||
|  | 
 | ||||||
|  |         self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True, | ||||||
|  |                                    dilation=(1, dilated)) | ||||||
|  | 
 | ||||||
|  |         self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) | ||||||
|  | 
 | ||||||
|  |         self.dropout = nn.Dropout2d(dropprob) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         output = self.conv3x1_1(input) | ||||||
|  |         output = F.relu(output) | ||||||
|  |         output = self.conv1x3_1(output) | ||||||
|  |         output = self.bn1(output) | ||||||
|  |         output = F.relu(output) | ||||||
|  | 
 | ||||||
|  |         output = self.conv3x1_2(output) | ||||||
|  |         output = F.relu(output) | ||||||
|  |         output = self.conv1x3_2(output) | ||||||
|  |         output = self.bn2(output) | ||||||
|  | 
 | ||||||
|  |         if self.dropout.p != 0: | ||||||
|  |             output = self.dropout(output) | ||||||
|  | 
 | ||||||
|  |         # +input = identity (residual connection) | ||||||
|  |         return F.relu(output + input) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class UpsamplerBlock(nn.Module): | ||||||
|  |     def __init__(self, ninput, noutput, up_width, up_height): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.conv = nn.ConvTranspose2d( | ||||||
|  |             ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True) | ||||||
|  | 
 | ||||||
|  |         self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True) | ||||||
|  | 
 | ||||||
|  |         self.follows = nn.ModuleList() | ||||||
|  |         self.follows.append(non_bottleneck_1d(noutput, 0, 1)) | ||||||
|  |         self.follows.append(non_bottleneck_1d(noutput, 0, 1)) | ||||||
|  | 
 | ||||||
|  |         # interpolate | ||||||
|  |         self.up_width = up_width | ||||||
|  |         self.up_height = up_height | ||||||
|  |         self.interpolate_conv = conv1x1(ninput, noutput) | ||||||
|  |         self.interpolate_bn = nn.BatchNorm2d( | ||||||
|  |             noutput, eps=1e-3, track_running_stats=True) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         output = self.conv(input) | ||||||
|  |         output = self.bn(output) | ||||||
|  |         out = F.relu(output) | ||||||
|  |         for follow in self.follows: | ||||||
|  |             out = follow(out) | ||||||
|  | 
 | ||||||
|  |         interpolate_output = self.interpolate_conv(input) | ||||||
|  |         interpolate_output = self.interpolate_bn(interpolate_output) | ||||||
|  |         interpolate_output = F.relu(interpolate_output) | ||||||
|  | 
 | ||||||
|  |         interpolate = F.interpolate(interpolate_output, size=[self.up_height,  self.up_width], | ||||||
|  |                                     mode='bilinear', align_corners=False) | ||||||
|  | 
 | ||||||
|  |         return out + interpolate | ||||||
|  | 
 | ||||||
|  | class BUSD(nn.Module): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         super().__init__() | ||||||
|  |         img_height = cfg.img_height | ||||||
|  |         img_width = cfg.img_width | ||||||
|  |         num_classes = cfg.num_classes | ||||||
|  | 
 | ||||||
|  |         self.layers = nn.ModuleList() | ||||||
|  | 
 | ||||||
|  |         self.layers.append(UpsamplerBlock(ninput=128, noutput=64, | ||||||
|  |                                           up_height=int(img_height)//4, up_width=int(img_width)//4)) | ||||||
|  |         self.layers.append(UpsamplerBlock(ninput=128, noutput=64, | ||||||
|  |                                           up_height=int(img_height)//2, up_width=int(img_width)//2)) | ||||||
|  |         self.layers.append(UpsamplerBlock(ninput=64, noutput=32, | ||||||
|  |                                           up_height=int(img_height)//1, up_width=int(img_width)//1)) | ||||||
|  | 
 | ||||||
|  |         self.output_conv = conv1x1(32, num_classes) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         x = input[0] | ||||||
|  |         output = input[1] | ||||||
|  | 
 | ||||||
|  |         for i,layer in enumerate(self.layers): | ||||||
|  |             output = layer(output) | ||||||
|  |             if i == 0: | ||||||
|  |                 output = torch.cat((x, output), dim=1) | ||||||
|  |              | ||||||
|  |              | ||||||
|  | 
 | ||||||
|  |         output = self.output_conv(output) | ||||||
|  | 
 | ||||||
|  |         return output | ||||||
143 models/decoder_copy2.py Normal file
							| @ -0,0 +1,143 @@ | |||||||
|  | from torch import nn | ||||||
|  | import torch | ||||||
|  | import torch.nn.functional as F | ||||||
|  | 
 | ||||||
|  | class PlainDecoder(nn.Module): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         super(PlainDecoder, self).__init__() | ||||||
|  |         self.cfg = cfg | ||||||
|  | 
 | ||||||
|  |         self.dropout = nn.Dropout2d(0.1) | ||||||
|  |         self.conv8 = nn.Conv2d(128, cfg.num_classes, 1) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = self.dropout(x) | ||||||
|  |         x = self.conv8(x) | ||||||
|  |         x = F.interpolate(x, size=[self.cfg.img_height,  self.cfg.img_width], | ||||||
|  |                            mode='bilinear', align_corners=False) | ||||||
|  | 
 | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def conv1x1(in_planes, out_planes, stride=1): | ||||||
|  |     """1x1 convolution""" | ||||||
|  |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class non_bottleneck_1d(nn.Module): | ||||||
|  |     def __init__(self, chann, dropprob, dilated): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.conv3x1_1 = nn.Conv2d( | ||||||
|  |             chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True) | ||||||
|  | 
 | ||||||
|  |         self.conv1x3_1 = nn.Conv2d( | ||||||
|  |             chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True) | ||||||
|  | 
 | ||||||
|  |         self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) | ||||||
|  | 
 | ||||||
|  |         self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1 * dilated, 0), bias=True, | ||||||
|  |                                    dilation=(dilated, 1)) | ||||||
|  | 
 | ||||||
|  |         self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1 * dilated), bias=True, | ||||||
|  |                                    dilation=(1, dilated)) | ||||||
|  | 
 | ||||||
|  |         self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) | ||||||
|  | 
 | ||||||
|  |         self.dropout = nn.Dropout2d(dropprob) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         output = self.conv3x1_1(input) | ||||||
|  |         output = F.relu(output) | ||||||
|  |         output = self.conv1x3_1(output) | ||||||
|  |         output = self.bn1(output) | ||||||
|  |         output = F.relu(output) | ||||||
|  | 
 | ||||||
|  |         output = self.conv3x1_2(output) | ||||||
|  |         output = F.relu(output) | ||||||
|  |         output = self.conv1x3_2(output) | ||||||
|  |         output = self.bn2(output) | ||||||
|  | 
 | ||||||
|  |         if self.dropout.p != 0: | ||||||
|  |             output = self.dropout(output) | ||||||
|  | 
 | ||||||
|  |         # +input = identity (residual connection) | ||||||
|  |         return F.relu(output + input) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class UpsamplerBlock(nn.Module): | ||||||
|  |     def __init__(self, ninput, noutput, up_width, up_height): | ||||||
|  |         super().__init__() | ||||||
|  | 
 | ||||||
|  |         self.conv = nn.ConvTranspose2d( | ||||||
|  |             ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True) | ||||||
|  | 
 | ||||||
|  |         self.bn = nn.BatchNorm2d(noutput, eps=1e-3, track_running_stats=True) | ||||||
|  | 
 | ||||||
|  |         self.follows = nn.ModuleList() | ||||||
|  |         self.follows.append(non_bottleneck_1d(noutput, 0, 1)) | ||||||
|  |         self.follows.append(non_bottleneck_1d(noutput, 0, 1)) | ||||||
|  | 
 | ||||||
|  |         # interpolate | ||||||
|  |         self.up_width = up_width | ||||||
|  |         self.up_height = up_height | ||||||
|  |         self.interpolate_conv = conv1x1(ninput, noutput) | ||||||
|  |         self.interpolate_bn = nn.BatchNorm2d( | ||||||
|  |             noutput, eps=1e-3, track_running_stats=True) | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         output = self.conv(input) | ||||||
|  |         output = self.bn(output) | ||||||
|  |         out = F.relu(output) | ||||||
|  |         for follow in self.follows: | ||||||
|  |             out = follow(out) | ||||||
|  | 
 | ||||||
|  |         interpolate_output = self.interpolate_conv(input) | ||||||
|  |         interpolate_output = self.interpolate_bn(interpolate_output) | ||||||
|  |         interpolate_output = F.relu(interpolate_output) | ||||||
|  | 
 | ||||||
|  |         interpolate = F.interpolate(interpolate_output, size=[self.up_height,  self.up_width], | ||||||
|  |                                     mode='bilinear', align_corners=False) | ||||||
|  | 
 | ||||||
|  |         return out + interpolate | ||||||
|  | 
 | ||||||
|  | class BUSD(nn.Module): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         super().__init__() | ||||||
|  |         img_height = cfg.img_height | ||||||
|  |         img_width = cfg.img_width | ||||||
|  |         num_classes = cfg.num_classes | ||||||
|  | 
 | ||||||
|  |         self.layers = nn.ModuleList() | ||||||
|  | 
 | ||||||
|  |         self.layers.append(UpsamplerBlock(ninput=128, noutput=64, | ||||||
|  |                                           up_height=int(img_height)//4, up_width=int(img_width)//4)) | ||||||
|  |         self.layers.append(UpsamplerBlock(ninput=64, noutput=32, | ||||||
|  |                                           up_height=int(img_height)//2, up_width=int(img_width)//2)) | ||||||
|  |         self.layers.append(UpsamplerBlock(ninput=32, noutput=16, | ||||||
|  |                                           up_height=int(img_height)//1, up_width=int(img_width)//1)) | ||||||
|  |         self.out1 = conv1x1(128, 64) | ||||||
|  |         self.out2 = conv1x1(64, 32) | ||||||
|  |         self.output_conv = conv1x1(16, num_classes) | ||||||
|  |          | ||||||
|  | 
 | ||||||
|  |     def forward(self, input): | ||||||
|  |         out1 = input[0] | ||||||
|  |         out2 = input[1] | ||||||
|  |         output = input[2] | ||||||
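|  |         # U-Net-style fusion: after each of the first two upsampling stages, | ||||||
|  |         # the matching encoder feature (out2 first, then out1) is concatenated | ||||||
|  |         # and squeezed back down by a 1x1 conv before the next stage | ||||||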
|  | 
 | ||||||
|  |         for i,layer in enumerate(self.layers): | ||||||
|  |             if i == 0: | ||||||
|  |                 output = layer(output) | ||||||
|  |                 output = torch.cat((out2, output), dim=1) | ||||||
|  |                 output = self.out1(output) | ||||||
|  |             elif i == 1: | ||||||
|  |                 output = layer(output) | ||||||
|  |                 output = torch.cat((out1, output), dim=1) | ||||||
|  |                 output = self.out2(output) | ||||||
|  |             else: | ||||||
|  |                 output = layer(output) | ||||||
|  | 
 | ||||||
|  |         output = self.output_conv(output) | ||||||
|  | 
 | ||||||
|  |         return output | ||||||
422 models/mobilenetv2.py Normal file
							| @ -0,0 +1,422 @@ | |||||||
|  | from functools import partial | ||||||
|  | from typing import Any, Callable, List, Optional | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | from torch import nn, Tensor | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | from torchvision.transforms._presets import ImageClassification | ||||||
|  | from torchvision.utils import _log_api_usage_once | ||||||
|  | from torchvision.models._api import Weights, WeightsEnum | ||||||
|  | from torchvision.models._meta import _IMAGENET_CATEGORIES | ||||||
|  | from torchvision.models._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface | ||||||
|  | import warnings | ||||||
|  | from typing import Callable, List, Optional, Sequence, Tuple, Union, TypeVar | ||||||
|  | import collections | ||||||
|  | from itertools import repeat | ||||||
|  | M = TypeVar("M", bound=nn.Module) | ||||||
|  | 
 | ||||||
|  | BUILTIN_MODELS = {} | ||||||
|  | def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]: | ||||||
|  |     def wrapper(fn: Callable[..., M]) -> Callable[..., M]: | ||||||
|  |         key = name if name is not None else fn.__name__ | ||||||
|  |         if key in BUILTIN_MODELS: | ||||||
|  |             raise ValueError(f"An entry is already registered under the name '{key}'.") | ||||||
|  |         BUILTIN_MODELS[key] = fn | ||||||
|  |         return fn | ||||||
|  | 
 | ||||||
|  |     return wrapper | ||||||
|  | 
 | ||||||
|  | def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]: | ||||||
|  |     """ | ||||||
|  |     Make n-tuple from input x. If x is an iterable, then we just convert it to tuple. | ||||||
|  |     Otherwise, we will make a tuple of length n, all with value of x. | ||||||
|  |     reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8 | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x (Any): input value | ||||||
|  |         n (int): length of the resulting tuple | ||||||
|  |     """ | ||||||
|  |     if isinstance(x, collections.abc.Iterable): | ||||||
|  |         return tuple(x) | ||||||
|  |     return tuple(repeat(x, n)) | ||||||
|  | 
 | ||||||
|  | class ConvNormActivation(torch.nn.Sequential): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         in_channels: int, | ||||||
|  |         out_channels: int, | ||||||
|  |         kernel_size: Union[int, Tuple[int, ...]] = 3, | ||||||
|  |         stride: Union[int, Tuple[int, ...]] = 1, | ||||||
|  |         padding: Optional[Union[int, Tuple[int, ...], str]] = None, | ||||||
|  |         groups: int = 1, | ||||||
|  |         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, | ||||||
|  |         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, | ||||||
|  |         dilation: Union[int, Tuple[int, ...]] = 1, | ||||||
|  |         inplace: Optional[bool] = True, | ||||||
|  |         bias: Optional[bool] = None, | ||||||
|  |         conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d, | ||||||
|  |     ) -> None: | ||||||
|  | 
 | ||||||
|  |         if padding is None: | ||||||
|  |             if isinstance(kernel_size, int) and isinstance(dilation, int): | ||||||
|  |                 padding = (kernel_size - 1) // 2 * dilation | ||||||
|  |             else: | ||||||
|  |                 _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation) | ||||||
|  |                 kernel_size = _make_ntuple(kernel_size, _conv_dim) | ||||||
|  |                 dilation = _make_ntuple(dilation, _conv_dim) | ||||||
|  |                 padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim)) | ||||||
|  |         if bias is None: | ||||||
|  |             bias = norm_layer is None | ||||||
|  | 
 | ||||||
|  |         layers = [ | ||||||
|  |             conv_layer( | ||||||
|  |                 in_channels, | ||||||
|  |                 out_channels, | ||||||
|  |                 kernel_size, | ||||||
|  |                 stride, | ||||||
|  |                 padding, | ||||||
|  |                 dilation=dilation, | ||||||
|  |                 groups=groups, | ||||||
|  |                 bias=bias, | ||||||
|  |             ) | ||||||
|  |         ] | ||||||
|  | 
 | ||||||
|  |         if norm_layer is not None: | ||||||
|  |             layers.append(norm_layer(out_channels)) | ||||||
|  | 
 | ||||||
|  |         if activation_layer is not None: | ||||||
|  |             params = {} if inplace is None else {"inplace": inplace} | ||||||
|  |             layers.append(activation_layer(**params)) | ||||||
|  |         super().__init__(*layers) | ||||||
|  |         _log_api_usage_once(self) | ||||||
|  |         self.out_channels = out_channels | ||||||
|  | 
 | ||||||
|  |         if self.__class__ == ConvNormActivation: | ||||||
|  |             warnings.warn( | ||||||
|  |                 "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead." | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Conv2dNormActivation(ConvNormActivation): | ||||||
|  |     """ | ||||||
|  |     Configurable block used for Convolution2d-Normalization-Activation blocks. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         in_channels (int): Number of channels in the input image | ||||||
|  |         out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block | ||||||
|  |         kernel_size: (int, optional): Size of the convolving kernel. Default: 3 | ||||||
|  |         stride (int, optional): Stride of the convolution. Default: 1 | ||||||
|  |         padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation`` | ||||||
|  |         groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 | ||||||
|  |         norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d`` | ||||||
|  |         activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` | ||||||
|  |         dilation (int): Spacing between kernel elements. Default: 1 | ||||||
|  |         inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` | ||||||
|  |         bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. | ||||||
|  | 
 | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         in_channels: int, | ||||||
|  |         out_channels: int, | ||||||
|  |         kernel_size: Union[int, Tuple[int, int]] = 3, | ||||||
|  |         stride: Union[int, Tuple[int, int]] = 1, | ||||||
|  |         padding: Optional[Union[int, Tuple[int, int], str]] = None, | ||||||
|  |         groups: int = 1, | ||||||
|  |         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, | ||||||
|  |         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, | ||||||
|  |         dilation: Union[int, Tuple[int, int]] = 1, | ||||||
|  |         inplace: Optional[bool] = True, | ||||||
|  |         bias: Optional[bool] = None, | ||||||
|  |     ) -> None: | ||||||
|  | 
 | ||||||
|  |         super().__init__( | ||||||
|  |             in_channels, | ||||||
|  |             out_channels, | ||||||
|  |             kernel_size, | ||||||
|  |             stride, | ||||||
|  |             padding, | ||||||
|  |             groups, | ||||||
|  |             norm_layer, | ||||||
|  |             activation_layer, | ||||||
|  |             dilation, | ||||||
|  |             inplace, | ||||||
|  |             bias, | ||||||
|  |             torch.nn.Conv2d, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
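|  | # e.g. Conv2dNormActivation(3, 32, kernel_size=3, stride=2) builds | ||||||
|  | # Conv2d(3, 32, 3, stride=2, padding=1, bias=False) -> BatchNorm2d(32) -> ReLU(inplace=True); | ||||||
|  | # padding defaults to (kernel_size - 1) // 2 * dilation, and bias to (norm_layer is None) | ||||||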
|  | __all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # necessary for backwards compatibility | ||||||
|  | class InvertedResidual(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None | ||||||
|  |     ) -> None: | ||||||
|  |         super().__init__() | ||||||
|  |         self.stride = stride | ||||||
|  |         if stride not in [1, 2]: | ||||||
|  |             raise ValueError(f"stride should be 1 or 2 instead of {stride}") | ||||||
|  | 
 | ||||||
|  |         if norm_layer is None: | ||||||
|  |             norm_layer = nn.BatchNorm2d | ||||||
|  | 
 | ||||||
|  |         hidden_dim = int(round(inp * expand_ratio)) | ||||||
|  |         self.use_res_connect = self.stride == 1 and inp == oup | ||||||
|  | 
 | ||||||
|  |         layers: List[nn.Module] = [] | ||||||
|  |         if expand_ratio != 1: | ||||||
|  |             # pw | ||||||
|  |             layers.append( | ||||||
|  |                 Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) | ||||||
|  |             ) | ||||||
|  |         layers.extend( | ||||||
|  |             [ | ||||||
|  |                 # dw | ||||||
|  |                 Conv2dNormActivation( | ||||||
|  |                     hidden_dim, | ||||||
|  |                     hidden_dim, | ||||||
|  |                     stride=stride, | ||||||
|  |                     groups=hidden_dim, | ||||||
|  |                     norm_layer=norm_layer, | ||||||
|  |                     activation_layer=nn.ReLU6, | ||||||
|  |                 ), | ||||||
|  |                 # pw-linear | ||||||
|  |                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), | ||||||
|  |                 norm_layer(oup), | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |         self.conv = nn.Sequential(*layers) | ||||||
|  |         self.out_channels = oup | ||||||
|  |         self._is_cn = stride > 1 | ||||||
|  | 
 | ||||||
|  |     def forward(self, x: Tensor) -> Tensor: | ||||||
|  |         if self.use_res_connect: | ||||||
|  |             return x + self.conv(x) | ||||||
|  |         else: | ||||||
|  |             return self.conv(x) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class MobileNetV2(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         num_classes: int = 1000, | ||||||
|  |         width_mult: float = 1.0, | ||||||
|  |         inverted_residual_setting: Optional[List[List[int]]] = None, | ||||||
|  |         round_nearest: int = 8, | ||||||
|  |         block: Optional[Callable[..., nn.Module]] = None, | ||||||
|  |         norm_layer: Optional[Callable[..., nn.Module]] = None, | ||||||
|  |         dropout: float = 0.2, | ||||||
|  |     ) -> None: | ||||||
|  |         """ | ||||||
|  |         MobileNet V2 main class | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             num_classes (int): Number of classes | ||||||
|  |             width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount | ||||||
|  |             inverted_residual_setting: Network structure | ||||||
|  |             round_nearest (int): Round the number of channels in each layer to be a multiple of this number | ||||||
|  |             Set to 1 to turn off rounding | ||||||
|  |             block: Module specifying inverted residual building block for mobilenet | ||||||
|  |             norm_layer: Module specifying the normalization layer to use | ||||||
|  |             dropout (float): The dropout probability | ||||||
|  | 
 | ||||||
|  |         """ | ||||||
|  |         super().__init__() | ||||||
|  |         _log_api_usage_once(self) | ||||||
|  | 
 | ||||||
|  |         if block is None: | ||||||
|  |             block = InvertedResidual | ||||||
|  | 
 | ||||||
|  |         if norm_layer is None: | ||||||
|  |             norm_layer = nn.BatchNorm2d | ||||||
|  | 
 | ||||||
|  |         input_channel = 32 | ||||||
|  |         last_channel = 1280 | ||||||
|  | 
 | ||||||
|  |         if inverted_residual_setting is None: | ||||||
|  |             inverted_residual_setting = [ | ||||||
|  |                 # t, c, n, s | ||||||
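|  |                 # rows tagged ** differ from stock torchvision (stride 2 -> 1), | ||||||
|  |                 # presumably to keep the features at 1/8 resolution for segmentation | ||||||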
|  |                 [1, 16, 1, 1], | ||||||
|  |                 [6, 24, 2, 2], | ||||||
|  |                 [6, 32, 3, 2], | ||||||
|  |                 [6, 64, 4, 1], # ** | ||||||
|  |                 [6, 96, 3, 1], | ||||||
|  |                 [6, 160, 3, 1], # ** | ||||||
|  |                 [6, 320, 1, 1], | ||||||
|  |             ] | ||||||
|  | 
 | ||||||
|  |         # only check the first element, assuming user knows t,c,n,s are required | ||||||
|  |         if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}" | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         # building first layer | ||||||
|  |         input_channel = _make_divisible(input_channel * width_mult, round_nearest) | ||||||
|  |         self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) | ||||||
|  |         features: List[nn.Module] = [ | ||||||
|  |             Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) | ||||||
|  |         ] | ||||||
|  |         # building inverted residual blocks | ||||||
|  |         for t, c, n, s in inverted_residual_setting: | ||||||
|  |             output_channel = _make_divisible(c * width_mult, round_nearest) | ||||||
|  |             for i in range(n): | ||||||
|  |                 stride = s if i == 0 else 1 | ||||||
|  |                 features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) | ||||||
|  |                 input_channel = output_channel | ||||||
|  |         # building last several layers | ||||||
|  |         features.append( | ||||||
|  |             Conv2dNormActivation( | ||||||
|  |                 input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6 | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         # make it nn.Sequential | ||||||
|  |         self.features = nn.Sequential(*features) | ||||||
|  | 
 | ||||||
|  |         # building classifier | ||||||
|  |         self.classifier = nn.Sequential( | ||||||
|  |             nn.Dropout(p=dropout), | ||||||
|  |             nn.Linear(self.last_channel, num_classes), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # weight initialization | ||||||
|  |         for m in self.modules(): | ||||||
|  |             if isinstance(m, nn.Conv2d): | ||||||
|  |                 nn.init.kaiming_normal_(m.weight, mode="fan_out") | ||||||
|  |                 if m.bias is not None: | ||||||
|  |                     nn.init.zeros_(m.bias) | ||||||
|  |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | ||||||
|  |                 nn.init.ones_(m.weight) | ||||||
|  |                 nn.init.zeros_(m.bias) | ||||||
|  |             elif isinstance(m, nn.Linear): | ||||||
|  |                 nn.init.normal_(m.weight, 0, 0.01) | ||||||
|  |                 nn.init.zeros_(m.bias) | ||||||
|  | 
 | ||||||
|  |     def _forward_impl(self, x: Tensor) -> Tensor: | ||||||
|  |         # This exists since TorchScript doesn't support inheritance, so the superclass method | ||||||
|  |         # (this one) needs to have a name other than `forward` that can be accessed in a subclass | ||||||
|  |         x = self.features(x) | ||||||
|  |         # Cannot use "squeeze" as batch-size can be 1 | ||||||
|  |         # x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) | ||||||
|  |         # x = torch.flatten(x, 1) | ||||||
|  |         # x = self.classifier(x) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  |     def forward(self, x: Tensor) -> Tensor: | ||||||
|  |         return self._forward_impl(x) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | _COMMON_META = { | ||||||
|  |     "num_params": 3504872, | ||||||
|  |     "min_size": (1, 1), | ||||||
|  |     "categories": _IMAGENET_CATEGORIES, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class MobileNet_V2_Weights(WeightsEnum): | ||||||
|  |     IMAGENET1K_V1 = Weights( | ||||||
|  |         url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", | ||||||
|  |         transforms=partial(ImageClassification, crop_size=224), | ||||||
|  |         meta={ | ||||||
|  |             **_COMMON_META, | ||||||
|  |             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", | ||||||
|  |             "_metrics": { | ||||||
|  |                 "ImageNet-1K": { | ||||||
|  |                     "acc@1": 71.878, | ||||||
|  |                     "acc@5": 90.286, | ||||||
|  |                 } | ||||||
|  |             }, | ||||||
|  |             "_ops": 0.301, | ||||||
|  |             "_file_size": 13.555, | ||||||
|  |             "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     IMAGENET1K_V2 = Weights( | ||||||
|  |         url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", | ||||||
|  |         transforms=partial(ImageClassification, crop_size=224, resize_size=232), | ||||||
|  |         meta={ | ||||||
|  |             **_COMMON_META, | ||||||
|  |             "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", | ||||||
|  |             "_metrics": { | ||||||
|  |                 "ImageNet-1K": { | ||||||
|  |                     "acc@1": 72.154, | ||||||
|  |                     "acc@5": 90.822, | ||||||
|  |                 } | ||||||
|  |             }, | ||||||
|  |             "_ops": 0.301, | ||||||
|  |             "_file_size": 13.598, | ||||||
|  |             "_docs": """ | ||||||
|  |                 These weights improve upon the results of the original paper by using a modified version of TorchVision's | ||||||
|  |                 `new training recipe | ||||||
|  |                 <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. | ||||||
|  |             """, | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     DEFAULT = IMAGENET1K_V2 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # @register_model() | ||||||
|  | # @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1)) | ||||||
|  | def mobilenet_v2( | ||||||
|  |     *, weights: Optional[MobileNet_V2_Weights] = MobileNet_V2_Weights.IMAGENET1K_V1, progress: bool = True, **kwargs: Any | ||||||
|  | ) -> MobileNetV2: | ||||||
|  |     """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear | ||||||
|  |     Bottlenecks <https://arxiv.org/abs/1801.04381>`_ paper. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The | ||||||
|  |             pretrained weights to use. See | ||||||
|  |             :class:`~torchvision.models.MobileNet_V2_Weights` below for | ||||||
|  |             more details, and possible values. Unlike upstream torchvision, | ||||||
|  |             this copy defaults to the ``IMAGENET1K_V1`` weights. | ||||||
|  |         progress (bool, optional): If True, displays a progress bar of the | ||||||
|  |             download to stderr. Default is True. | ||||||
|  |         **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2`` | ||||||
|  |             base class. Please refer to the `source code | ||||||
|  |             <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py>`_ | ||||||
|  |             for more details about this class. | ||||||
|  | 
 | ||||||
|  |     .. autoclass:: torchvision.models.MobileNet_V2_Weights | ||||||
|  |         :members: | ||||||
|  |     """ | ||||||
|  |     weights = MobileNet_V2_Weights.verify(weights) | ||||||
|  | 
 | ||||||
|  |     if weights is not None: | ||||||
|  |         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) | ||||||
|  | 
 | ||||||
|  |     model = MobileNetV2(**kwargs) | ||||||
|  | 
 | ||||||
|  |     if weights is not None: | ||||||
|  |         model.load_state_dict(weights.get_state_dict(progress=progress)) | ||||||
|  | 
 | ||||||
|  |     return model | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def conv1x1(in_planes, out_planes, stride=1): | ||||||
|  |     """1x1 convolution""" | ||||||
|  |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||||||
|  | 
 | ||||||
|  | class MobileNetv2Wrapper(nn.Module): | ||||||
|  |     def __init__(self): | ||||||
|  |         super(MobileNetv2Wrapper, self).__init__() | ||||||
|  |         weights = MobileNet_V2_Weights.verify(MobileNet_V2_Weights.IMAGENET1K_V1) | ||||||
|  | 
 | ||||||
|  |         self.model = MobileNetV2() | ||||||
|  | 
 | ||||||
|  |         if weights is not None: | ||||||
|  |             self.model.load_state_dict(weights.get_state_dict(progress=True)) | ||||||
|  |         self.out = conv1x1(1280, 128) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = self.model(x)  # (N, 1280, H/8, W/8) with the modified strides | ||||||
|  |         if self.out is not None: | ||||||
|  |             x = self.out(x)  # reduce to the 128 channels the decoders expect | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
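A quick sanity-check sketch for the wrapper: with the two stride changes above the backbone keeps an overall output stride of 8, and the trailing 1x1 conv maps 1280 channels down to 128. Note the constructor downloads the ImageNet weights; the 288x800 input size here is just an assumed CULane-style resolution.

```python
import torch

net = MobileNetv2Wrapper().eval()  # loads IMAGENET1K_V1 weights on construction
with torch.no_grad():
    feat = net(torch.randn(1, 3, 288, 800))
print(feat.shape)  # torch.Size([1, 128, 36, 100]) -- 1/8 of the 288x800 input
```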
436 models/mobilenetv2_copy2.py Normal file
							| @ -0,0 +1,436 @@ | |||||||
|  | from functools import partial | ||||||
|  | from typing import Any, Callable, List, Optional | ||||||
|  | 
 | ||||||
|  | import torch | ||||||
|  | from torch import nn, Tensor | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | from torchvision.transforms._presets import ImageClassification | ||||||
|  | from torchvision.utils import _log_api_usage_once | ||||||
|  | from torchvision.models._api import Weights, WeightsEnum | ||||||
|  | from torchvision.models._meta import _IMAGENET_CATEGORIES | ||||||
|  | from torchvision.models._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface | ||||||
|  | import warnings | ||||||
|  | from typing import Callable, List, Optional, Sequence, Tuple, Union, TypeVar | ||||||
|  | import collections | ||||||
|  | from itertools import repeat | ||||||
|  | M = TypeVar("M", bound=nn.Module) | ||||||
|  | 
 | ||||||
|  | BUILTIN_MODELS = {} | ||||||
|  | def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]: | ||||||
|  |     def wrapper(fn: Callable[..., M]) -> Callable[..., M]: | ||||||
|  |         key = name if name is not None else fn.__name__ | ||||||
|  |         if key in BUILTIN_MODELS: | ||||||
|  |             raise ValueError(f"An entry is already registered under the name '{key}'.") | ||||||
|  |         BUILTIN_MODELS[key] = fn | ||||||
|  |         return fn | ||||||
|  | 
 | ||||||
|  |     return wrapper | ||||||
|  | 
 | ||||||
|  | def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]: | ||||||
|  |     """ | ||||||
|  |     Make n-tuple from input x. If x is an iterable, then we just convert it to tuple. | ||||||
|  |     Otherwise, we will make a tuple of length n, all with value of x. | ||||||
|  |     reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8 | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         x (Any): input value | ||||||
|  |         n (int): length of the resulting tuple | ||||||
|  |     """ | ||||||
|  |     if isinstance(x, collections.abc.Iterable): | ||||||
|  |         return tuple(x) | ||||||
|  |     return tuple(repeat(x, n)) | ||||||
|  | 
 | ||||||
|  | class ConvNormActivation(torch.nn.Sequential): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         in_channels: int, | ||||||
|  |         out_channels: int, | ||||||
|  |         kernel_size: Union[int, Tuple[int, ...]] = 3, | ||||||
|  |         stride: Union[int, Tuple[int, ...]] = 1, | ||||||
|  |         padding: Optional[Union[int, Tuple[int, ...], str]] = None, | ||||||
|  |         groups: int = 1, | ||||||
|  |         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, | ||||||
|  |         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, | ||||||
|  |         dilation: Union[int, Tuple[int, ...]] = 1, | ||||||
|  |         inplace: Optional[bool] = True, | ||||||
|  |         bias: Optional[bool] = None, | ||||||
|  |         conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d, | ||||||
|  |     ) -> None: | ||||||
|  | 
 | ||||||
|  |         if padding is None: | ||||||
|  |             if isinstance(kernel_size, int) and isinstance(dilation, int): | ||||||
|  |                 padding = (kernel_size - 1) // 2 * dilation | ||||||
|  |             else: | ||||||
|  |                 _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation) | ||||||
|  |                 kernel_size = _make_ntuple(kernel_size, _conv_dim) | ||||||
|  |                 dilation = _make_ntuple(dilation, _conv_dim) | ||||||
|  |                 padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim)) | ||||||
|  |         if bias is None: | ||||||
|  |             bias = norm_layer is None | ||||||
|  | 
 | ||||||
|  |         layers = [ | ||||||
|  |             conv_layer( | ||||||
|  |                 in_channels, | ||||||
|  |                 out_channels, | ||||||
|  |                 kernel_size, | ||||||
|  |                 stride, | ||||||
|  |                 padding, | ||||||
|  |                 dilation=dilation, | ||||||
|  |                 groups=groups, | ||||||
|  |                 bias=bias, | ||||||
|  |             ) | ||||||
|  |         ] | ||||||
|  | 
 | ||||||
|  |         if norm_layer is not None: | ||||||
|  |             layers.append(norm_layer(out_channels)) | ||||||
|  | 
 | ||||||
|  |         if activation_layer is not None: | ||||||
|  |             params = {} if inplace is None else {"inplace": inplace} | ||||||
|  |             layers.append(activation_layer(**params)) | ||||||
|  |         super().__init__(*layers) | ||||||
|  |         _log_api_usage_once(self) | ||||||
|  |         self.out_channels = out_channels | ||||||
|  | 
 | ||||||
|  |         if self.__class__ == ConvNormActivation: | ||||||
|  |             warnings.warn( | ||||||
|  |                 "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead." | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Conv2dNormActivation(ConvNormActivation): | ||||||
|  |     """ | ||||||
|  |     Configurable block used for Convolution2d-Normalization-Activation blocks. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         in_channels (int): Number of channels in the input image | ||||||
|  |         out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block | ||||||
|  |         kernel_size: (int, optional): Size of the convolving kernel. Default: 3 | ||||||
|  |         stride (int, optional): Stride of the convolution. Default: 1 | ||||||
|  |         padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation`` | ||||||
|  |         groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 | ||||||
|  |         norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d`` | ||||||
|  |         activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer won't be used. Default: ``torch.nn.ReLU`` | ||||||
|  |         dilation (int): Spacing between kernel elements. Default: 1 | ||||||
|  |         inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True`` | ||||||
|  |         bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``. | ||||||
|  | 
 | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         in_channels: int, | ||||||
|  |         out_channels: int, | ||||||
|  |         kernel_size: Union[int, Tuple[int, int]] = 3, | ||||||
|  |         stride: Union[int, Tuple[int, int]] = 1, | ||||||
|  |         padding: Optional[Union[int, Tuple[int, int], str]] = None, | ||||||
|  |         groups: int = 1, | ||||||
|  |         norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, | ||||||
|  |         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, | ||||||
|  |         dilation: Union[int, Tuple[int, int]] = 1, | ||||||
|  |         inplace: Optional[bool] = True, | ||||||
|  |         bias: Optional[bool] = None, | ||||||
|  |     ) -> None: | ||||||
|  | 
 | ||||||
|  |         super().__init__( | ||||||
|  |             in_channels, | ||||||
|  |             out_channels, | ||||||
|  |             kernel_size, | ||||||
|  |             stride, | ||||||
|  |             padding, | ||||||
|  |             groups, | ||||||
|  |             norm_layer, | ||||||
|  |             activation_layer, | ||||||
|  |             dilation, | ||||||
|  |             inplace, | ||||||
|  |             bias, | ||||||
|  |             torch.nn.Conv2d, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  | __all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # necessary for backwards compatibility | ||||||
|  | class InvertedResidual(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None | ||||||
|  |     ) -> None: | ||||||
|  |         super().__init__() | ||||||
|  |         self.stride = stride | ||||||
|  |         if stride not in [1, 2]: | ||||||
|  |             raise ValueError(f"stride should be 1 or 2 instead of {stride}") | ||||||
|  | 
 | ||||||
|  |         if norm_layer is None: | ||||||
|  |             norm_layer = nn.BatchNorm2d | ||||||
|  | 
 | ||||||
|  |         hidden_dim = int(round(inp * expand_ratio)) | ||||||
|  |         self.use_res_connect = self.stride == 1 and inp == oup | ||||||
|  | 
 | ||||||
|  |         layers: List[nn.Module] = [] | ||||||
|  |         if expand_ratio != 1: | ||||||
|  |             # pw | ||||||
|  |             layers.append( | ||||||
|  |                 Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) | ||||||
|  |             ) | ||||||
|  |         layers.extend( | ||||||
|  |             [ | ||||||
|  |                 # dw | ||||||
|  |                 Conv2dNormActivation( | ||||||
|  |                     hidden_dim, | ||||||
|  |                     hidden_dim, | ||||||
|  |                     stride=stride, | ||||||
|  |                     groups=hidden_dim, | ||||||
|  |                     norm_layer=norm_layer, | ||||||
|  |                     activation_layer=nn.ReLU6, | ||||||
|  |                 ), | ||||||
|  |                 # pw-linear | ||||||
|  |                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), | ||||||
|  |                 norm_layer(oup), | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |         self.conv = nn.Sequential(*layers) | ||||||
|  |         self.out_channels = oup | ||||||
|  |         self._is_cn = stride > 1 | ||||||
|  | 
 | ||||||
|  |     def forward(self, x: Tensor) -> Tensor: | ||||||
|  |         if self.use_res_connect: | ||||||
|  |             return x + self.conv(x) | ||||||
|  |         else: | ||||||
|  |             return self.conv(x) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class MobileNetV2(nn.Module): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         num_classes: int = 1000, | ||||||
|  |         width_mult: float = 1.0, | ||||||
|  |         inverted_residual_setting: Optional[List[List[int]]] = None, | ||||||
|  |         round_nearest: int = 8, | ||||||
|  |         block: Optional[Callable[..., nn.Module]] = None, | ||||||
|  |         norm_layer: Optional[Callable[..., nn.Module]] = None, | ||||||
|  |         dropout: float = 0.2, | ||||||
|  |     ) -> None: | ||||||
|  |         """ | ||||||
|  |         MobileNet V2 main class | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             num_classes (int): Number of classes | ||||||
|  |             width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount | ||||||
|  |             inverted_residual_setting: Network structure | ||||||
|  |             round_nearest (int): Round the number of channels in each layer to be a multiple of this number. | ||||||
|  |             Set to 1 to turn off rounding. | ||||||
|  |             block: Module specifying inverted residual building block for mobilenet | ||||||
|  |             norm_layer: Module specifying the normalization layer to use | ||||||
|  |             dropout (float): The dropout probability | ||||||
|  | 
 | ||||||
|  |         """ | ||||||
|  |         super().__init__() | ||||||
|  |         _log_api_usage_once(self) | ||||||
|  | 
 | ||||||
|  |         if block is None: | ||||||
|  |             block = InvertedResidual | ||||||
|  | 
 | ||||||
|  |         if norm_layer is None: | ||||||
|  |             norm_layer = nn.BatchNorm2d | ||||||
|  | 
 | ||||||
|  |         input_channel = 32 | ||||||
|  |         last_channel = 1280 | ||||||
|  | 
 | ||||||
|  |         if inverted_residual_setting is None: | ||||||
|  |             inverted_residual_setting = [ | ||||||
|  |                 # t, c, n, s | ||||||
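|  |                 # t: expansion factor, c: output channels, n: blocks per stage | ||||||
|  |                 # (only the first uses stride s), s: stride of the first block. | ||||||
|  |                 # NOTE: the 24- and 32-channel stages use s=1 where the paper uses | ||||||
|  |                 # s=2, presumably to keep higher-resolution maps for dense prediction. | ||||||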
|  |                 [1, 16, 1, 1], | ||||||
|  |                 [6, 24, 2, 1], | ||||||
|  |                 [6, 32, 3, 1], | ||||||
|  |                 [6, 64, 4, 2], | ||||||
|  |                 [6, 96, 3, 1], | ||||||
|  |                 [6, 160, 3, 2], | ||||||
|  |                 [6, 320, 1, 1], | ||||||
|  |             ] | ||||||
|  | 
 | ||||||
|  |         # only check the first element, assuming user knows t,c,n,s are required | ||||||
|  |         if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: | ||||||
|  |             raise ValueError( | ||||||
|  |                 f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}" | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         # building first layer | ||||||
|  |         input_channel = _make_divisible(input_channel * width_mult, round_nearest) | ||||||
|  |         self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) | ||||||
|  |         features: List[nn.Module] = [ | ||||||
|  |             Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) | ||||||
|  |         ] | ||||||
|  |         # building inverted residual blocks | ||||||
|  |         for t, c, n, s in inverted_residual_setting: | ||||||
|  |             output_channel = _make_divisible(c * width_mult, round_nearest) | ||||||
|  |             for i in range(n): | ||||||
|  |                 stride = s if i == 0 else 1 | ||||||
|  |                 features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) | ||||||
|  |                 input_channel = output_channel | ||||||
|  |         # building last several layers | ||||||
|  |         features.append( | ||||||
|  |             Conv2dNormActivation( | ||||||
|  |                 input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6 | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         # make it nn.Sequential | ||||||
|  |         self.features = nn.Sequential(*features) | ||||||
|  | 
 | ||||||
|  |         # building classifier | ||||||
|  |         self.classifier = nn.Sequential( | ||||||
|  |             nn.Dropout(p=dropout), | ||||||
|  |             nn.Linear(self.last_channel, num_classes), | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         # weight initialization | ||||||
|  |         for m in self.modules(): | ||||||
|  |             if isinstance(m, nn.Conv2d): | ||||||
|  |                 nn.init.kaiming_normal_(m.weight, mode="fan_out") | ||||||
|  |                 if m.bias is not None: | ||||||
|  |                     nn.init.zeros_(m.bias) | ||||||
|  |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | ||||||
|  |                 nn.init.ones_(m.weight) | ||||||
|  |                 nn.init.zeros_(m.bias) | ||||||
|  |             elif isinstance(m, nn.Linear): | ||||||
|  |                 nn.init.normal_(m.weight, 0, 0.01) | ||||||
|  |                 nn.init.zeros_(m.bias) | ||||||
|  | 
 | ||||||
|  |     def _forward_impl(self, x: Tensor) -> List[Tensor]: | ||||||
|  |         # This exists since TorchScript doesn't support inheritance, so the superclass method | ||||||
|  |         # (this one) needs to have a name other than `forward` that can be accessed in a subclass | ||||||
|  |         # Unlike the original classifier, collect the intermediate feature maps the decoder needs. | ||||||
|  |         out_layers = [] | ||||||
|  |         for i, layer in enumerate(self.features): | ||||||
|  |             x = layer(x) | ||||||
|  |             if i in [0, 10, 18]:  # feature maps at strides 2, 4 and 8 | ||||||
|  |                 out_layers.append(x) | ||||||
|  |         # The classification path of the original model is unused here: | ||||||
|  |         # x = nn.functional.adaptive_avg_pool2d(x, (1, 1))  # cannot use "squeeze" as batch-size can be 1 | ||||||
|  |         # x = torch.flatten(x, 1) | ||||||
|  |         # x = self.classifier(x) | ||||||
|  |         return out_layers | ||||||
|  | 
 | ||||||
|  |     def forward(self, x: Tensor) -> List[Tensor]: | ||||||
|  |         return self._forward_impl(x) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | _COMMON_META = { | ||||||
|  |     "num_params": 3504872, | ||||||
|  |     "min_size": (1, 1), | ||||||
|  |     "categories": _IMAGENET_CATEGORIES, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class MobileNet_V2_Weights(WeightsEnum): | ||||||
|  |     IMAGENET1K_V1 = Weights( | ||||||
|  |         url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", | ||||||
|  |         transforms=partial(ImageClassification, crop_size=224), | ||||||
|  |         meta={ | ||||||
|  |             **_COMMON_META, | ||||||
|  |             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", | ||||||
|  |             "_metrics": { | ||||||
|  |                 "ImageNet-1K": { | ||||||
|  |                     "acc@1": 71.878, | ||||||
|  |                     "acc@5": 90.286, | ||||||
|  |                 } | ||||||
|  |             }, | ||||||
|  |             "_ops": 0.301, | ||||||
|  |             "_file_size": 13.555, | ||||||
|  |             "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     IMAGENET1K_V2 = Weights( | ||||||
|  |         url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", | ||||||
|  |         transforms=partial(ImageClassification, crop_size=224, resize_size=232), | ||||||
|  |         meta={ | ||||||
|  |             **_COMMON_META, | ||||||
|  |             "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", | ||||||
|  |             "_metrics": { | ||||||
|  |                 "ImageNet-1K": { | ||||||
|  |                     "acc@1": 72.154, | ||||||
|  |                     "acc@5": 90.822, | ||||||
|  |                 } | ||||||
|  |             }, | ||||||
|  |             "_ops": 0.301, | ||||||
|  |             "_file_size": 13.598, | ||||||
|  |             "_docs": """ | ||||||
|  |                 These weights improve upon the results of the original paper by using a modified version of TorchVision's | ||||||
|  |                 `new training recipe | ||||||
|  |                 <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. | ||||||
|  |             """, | ||||||
|  |         }, | ||||||
|  |     ) | ||||||
|  |     DEFAULT = IMAGENET1K_V2 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # @register_model() | ||||||
|  | # @handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1)) | ||||||
|  | def mobilenet_v2( | ||||||
|  |     *, weights: Optional[MobileNet_V2_Weights] = MobileNet_V2_Weights.IMAGENET1K_V1, progress: bool = True, **kwargs: Any | ||||||
|  | ) -> MobileNetV2: | ||||||
|  |     """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear | ||||||
|  |     Bottlenecks <https://arxiv.org/abs/1801.04381>`_ paper. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The | ||||||
|  |             pretrained weights to use. See | ||||||
|  |             :class:`~torchvision.models.MobileNet_V2_Weights` below for | ||||||
|  |             more details, and possible values. Unlike upstream torchvision, this | ||||||
|  |             copy defaults to the ``IMAGENET1K_V1`` weights. | ||||||
|  |         progress (bool, optional): If True, displays a progress bar of the | ||||||
|  |             download to stderr. Default is True. | ||||||
|  |         **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2`` | ||||||
|  |             base class. Please refer to the `source code | ||||||
|  |             <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py>`_ | ||||||
|  |             for more details about this class. | ||||||
|  | 
 | ||||||
|  |     .. autoclass:: torchvision.models.MobileNet_V2_Weights | ||||||
|  |         :members: | ||||||
|  |     """ | ||||||
|  |     weights = MobileNet_V2_Weights.verify(weights) | ||||||
|  | 
 | ||||||
|  |     if weights is not None: | ||||||
|  |         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) | ||||||
|  | 
 | ||||||
|  |     model = MobileNetV2(**kwargs) | ||||||
|  | 
 | ||||||
|  |     if weights is not None: | ||||||
|  |         model.load_state_dict(weights.get_state_dict(progress=progress)) | ||||||
|  | 
 | ||||||
|  |     return model | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def conv1x1(in_planes, out_planes, stride=1): | ||||||
|  |     """1x1 convolution""" | ||||||
|  |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||||||
|  | 
 | ||||||
|  | class MobileNetv2Wrapper(nn.Module): | ||||||
|  |     def __init__(self): | ||||||
|  |         super(MobileNetv2Wrapper, self).__init__() | ||||||
|  |         weights = MobileNet_V2_Weights.verify(MobileNet_V2_Weights.IMAGENET1K_V1) | ||||||
|  | 
 | ||||||
|  |         self.model = MobileNetV2() | ||||||
|  | 
 | ||||||
|  |         if weights is not None: | ||||||
|  |             self.model.load_state_dict(weights.get_state_dict(progress=True)) | ||||||
|  |         self.out3 = conv1x1(1280, 128) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         out_layers = self.model(x) | ||||||
|  | 
 | ||||||
|  |         # only the deepest feature map is projected down to the RESA input width | ||||||
|  |         # out_layers[0] = self.out1(out_layers[0]) | ||||||
|  |         # out_layers[1] = self.out2(out_layers[1]) | ||||||
|  |         out_layers[2] = self.out3(out_layers[2]) | ||||||
|  |         return out_layers | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
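The wrapper above swaps MobileNetV2's classifier for three intermediate feature maps. A minimal smoke-test sketch, assuming this file is importable and the default `width_mult=1.0`; the 288x800 input is just an illustrative CULane-style resolution:

```Python
import torch

backbone = MobileNetv2Wrapper().eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 288, 800))
for f in feats:
    print(f.shape)
# Expected strides 2 / 4 / 8 relative to the input:
# (1, 32, 144, 400), (1, 64, 72, 200), (1, 128, 36, 100)
```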
models/registry.py (new file, +16 lines)
							| @ -0,0 +1,16 @@ | |||||||
|  | import torch.nn as nn | ||||||
|  | 
 | ||||||
|  | from utils import Registry, build_from_cfg | ||||||
|  | 
 | ||||||
|  | NET = Registry('net') | ||||||
|  | 
 | ||||||
|  | def build(cfg, registry, default_args=None): | ||||||
|  |     if isinstance(cfg, list): | ||||||
|  |         modules = [ | ||||||
|  |             build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg | ||||||
|  |         ] | ||||||
|  |         return nn.Sequential(*modules) | ||||||
|  |     else: | ||||||
|  |         return build_from_cfg(cfg, registry, default_args) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def build_net(cfg): | ||||||
|  |     return build(cfg.net, NET, default_args=dict(cfg=cfg)) | ||||||
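`Registry` and `build_from_cfg` (from `utils`, not shown here) follow the familiar mmcv-style pattern: a network class registers itself under its name, and `build_net` instantiates whichever class `cfg.net` names. A hypothetical sketch; `MyNet` and the `type` key are assumptions about `build_from_cfg`'s convention:

```Python
import torch.nn as nn
from models.registry import NET, build_net

@NET.register_module
class MyNet(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

# cfg.net would then name the registered class, e.g. dict(type='MyNet'),
# and build_net(cfg) returns MyNet(cfg=cfg).
```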
models/resa.py (new file, +142 lines)
							| @ -0,0 +1,142 @@ | |||||||
|  | import torch.nn as nn | ||||||
|  | import torch | ||||||
|  | import torch.nn.functional as F | ||||||
|  | 
 | ||||||
|  | from models.registry import NET | ||||||
|  | # from .resnet_copy import ResNetWrapper | ||||||
|  | # from .resnet import ResNetWrapper  | ||||||
|  | from .decoder_copy2 import BUSD, PlainDecoder | ||||||
|  | # from .decoder import BUSD, PlainDecoder | ||||||
|  | # from .mobilenetv2 import MobileNetv2Wrapper | ||||||
|  | from .mobilenetv2_copy2 import MobileNetv2Wrapper | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class RESA(nn.Module): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         super(RESA, self).__init__() | ||||||
|  |         self.iter = cfg.resa.iter | ||||||
|  |         chan = cfg.resa.input_channel | ||||||
|  |         fea_stride = cfg.backbone.fea_stride | ||||||
|  |         self.height = cfg.img_height // fea_stride | ||||||
|  |         self.width = cfg.img_width // fea_stride | ||||||
|  |         self.alpha = cfg.resa.alpha | ||||||
|  |         conv_stride = cfg.resa.conv_stride | ||||||
|  | 
 | ||||||
|  |         for i in range(self.iter): | ||||||
|  |             conv_vert1 = nn.Conv2d( | ||||||
|  |                 chan, chan, (1, conv_stride), | ||||||
|  |                 padding=(0, conv_stride//2), groups=1, bias=False) | ||||||
|  |             conv_vert2 = nn.Conv2d( | ||||||
|  |                 chan, chan, (1, conv_stride), | ||||||
|  |                 padding=(0, conv_stride//2), groups=1, bias=False) | ||||||
|  | 
 | ||||||
|  |             setattr(self, 'conv_d'+str(i), conv_vert1) | ||||||
|  |             setattr(self, 'conv_u'+str(i), conv_vert2) | ||||||
|  | 
 | ||||||
|  |             conv_hori1 = nn.Conv2d( | ||||||
|  |                 chan, chan, (conv_stride, 1), | ||||||
|  |                 padding=(conv_stride//2, 0), groups=1, bias=False) | ||||||
|  |             conv_hori2 = nn.Conv2d( | ||||||
|  |                 chan, chan, (conv_stride, 1), | ||||||
|  |                 padding=(conv_stride//2, 0), groups=1, bias=False) | ||||||
|  | 
 | ||||||
|  |             setattr(self, 'conv_r'+str(i), conv_hori1) | ||||||
|  |             setattr(self, 'conv_l'+str(i), conv_hori2) | ||||||
|  | 
 | ||||||
|  |             idx_d = (torch.arange(self.height) + self.height // | ||||||
|  |                      2**(self.iter - i)) % self.height | ||||||
|  |             setattr(self, 'idx_d'+str(i), idx_d) | ||||||
|  | 
 | ||||||
|  |             idx_u = (torch.arange(self.height) - self.height // | ||||||
|  |                      2**(self.iter - i)) % self.height | ||||||
|  |             setattr(self, 'idx_u'+str(i), idx_u) | ||||||
|  | 
 | ||||||
|  |             idx_r = (torch.arange(self.width) + self.width // | ||||||
|  |                      2**(self.iter - i)) % self.width | ||||||
|  |             setattr(self, 'idx_r'+str(i), idx_r) | ||||||
|  | 
 | ||||||
|  |             idx_l = (torch.arange(self.width) - self.width // | ||||||
|  |                      2**(self.iter - i)) % self.width | ||||||
|  |             setattr(self, 'idx_l'+str(i), idx_l) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = x.clone() | ||||||
|  | 
 | ||||||
|  |         for direction in ['d', 'u']: | ||||||
|  |             for i in range(self.iter): | ||||||
|  |                 conv = getattr(self, 'conv_' + direction + str(i)) | ||||||
|  |                 idx = getattr(self, 'idx_' + direction + str(i)) | ||||||
|  |                 x.add_(self.alpha * F.relu(conv(x[..., idx, :]))) | ||||||
|  | 
 | ||||||
|  |         for direction in ['r', 'l']: | ||||||
|  |             for i in range(self.iter): | ||||||
|  |                 conv = getattr(self, 'conv_' + direction + str(i)) | ||||||
|  |                 idx = getattr(self, 'idx_' + direction + str(i)) | ||||||
|  |                 x.add_(self.alpha * F.relu(conv(x[..., idx]))) | ||||||
|  | 
 | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
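|  | # A sketch of the shift schedule above, assuming cfg.resa.iter == 4: at iteration | ||||||
|  | # i the rows (or columns) are rolled by L // 2**(iter - i), e.g. for height 36 the | ||||||
|  | # shifts are 2, 4, 9, 18; each direction then adds alpha * relu(conv(...)) in place. | ||||||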
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ExistHead(nn.Module): | ||||||
|  |     def __init__(self, cfg=None): | ||||||
|  |         super(ExistHead, self).__init__() | ||||||
|  |         self.cfg = cfg | ||||||
|  | 
 | ||||||
|  |         self.dropout = nn.Dropout2d(0.1) | ||||||
|  |         self.conv8 = nn.Conv2d(128, cfg.num_classes, 1) | ||||||
|  | 
 | ||||||
|  |         stride = cfg.backbone.fea_stride * 2 | ||||||
|  |         self.fc9 = nn.Linear( | ||||||
|  |             int(cfg.num_classes * cfg.img_width / stride * cfg.img_height / stride), 128) | ||||||
|  |         self.fc10 = nn.Linear(128, cfg.num_classes-1) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = self.dropout(x) | ||||||
|  |         x = self.conv8(x) | ||||||
|  | 
 | ||||||
|  |         x = F.softmax(x, dim=1) | ||||||
|  |         x = F.avg_pool2d(x, 2, stride=2, padding=0) | ||||||
|  |         x = x.view(-1, x.numel() // x.shape[0]) | ||||||
|  |         x = self.fc9(x) | ||||||
|  |         x = F.relu(x) | ||||||
|  |         x = self.fc10(x) | ||||||
|  |         x = torch.sigmoid(x) | ||||||
|  | 
 | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
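|  | # Sizing note for fc9 above: its input is num_classes * (img_width / stride) * | ||||||
|  | # (img_height / stride) with stride = 2 * fea_stride (the factor 2 comes from the | ||||||
|  | # avg_pool2d). E.g. a CULane-style config (num_classes=5, 288x800 input, | ||||||
|  | # fea_stride=8) gives 5 * 50 * 18 = 4500 inputs. | ||||||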
|  | 
 | ||||||
|  | @NET.register_module | ||||||
|  | class RESANet(nn.Module): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         super(RESANet, self).__init__() | ||||||
|  |         self.cfg = cfg | ||||||
|  |     #     self.backbone = ResNetWrapper(resnet='resnet34',pretrained=True, | ||||||
|  |     # replace_stride_with_dilation=[False, False, False], | ||||||
|  |     # out_conv=False) | ||||||
|  |         self.backbone = MobileNetv2Wrapper() | ||||||
|  |         self.resa = RESA(cfg) | ||||||
|  |         self.decoder = eval(cfg.decoder)(cfg) | ||||||
|  |         self.heads = ExistHead(cfg)  | ||||||
|  | 
 | ||||||
|  |     def forward(self, batch): | ||||||
|  |         # the ResNet variant consumed (x1, fea) and a two-feature decoder: | ||||||
|  |         # x1, fea, _, _ = self.backbone(batch) | ||||||
|  |         # seg = self.decoder([x1, fea]) | ||||||
|  | 
 | ||||||
|  |         fea1, fea2, fea = self.backbone(batch) | ||||||
|  |         fea = self.resa(fea) | ||||||
|  |         seg = self.decoder([fea1, fea2, fea]) | ||||||
|  |         exist = self.heads(fea) | ||||||
|  | 
 | ||||||
|  |         output = {'seg': seg, 'exist': exist} | ||||||
|  | 
 | ||||||
|  |         return output | ||||||
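End to end, `RESANet` maps a batch of images to a dict of segmentation logits and per-lane existence scores. A hypothetical usage sketch; the exact cfg fields the decoder needs live in `decoder_copy2`, which is not shown here:

```Python
import torch

# net = build_net(cfg)  # cfg must provide resa.*, backbone.fea_stride,
#                       # img_height/img_width, num_classes and decoder
# out = net(torch.randn(1, 3, cfg.img_height, cfg.img_width))
# out['seg']    # segmentation logits, one channel per class
# out['exist']  # (N, cfg.num_classes - 1) sigmoid scores, one per lane
```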
models/resnet.py (new file, +377 lines)
							| @ -0,0 +1,377 @@ | |||||||
|  | import torch | ||||||
|  | from torch import nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  | from torch.hub import load_state_dict_from_url | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # This code is borrowed from torchvision. | ||||||
|  | 
 | ||||||
|  | model_urls = { | ||||||
|  |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', | ||||||
|  |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', | ||||||
|  |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', | ||||||
|  |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', | ||||||
|  |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', | ||||||
|  |     'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', | ||||||
|  |     'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', | ||||||
|  |     'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', | ||||||
|  |     'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | ||||||
|  |     """3x3 convolution with padding""" | ||||||
|  |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | ||||||
|  |                      padding=dilation, groups=groups, bias=False, dilation=dilation) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def conv1x1(in_planes, out_planes, stride=1): | ||||||
|  |     """1x1 convolution""" | ||||||
|  |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class BasicBlock(nn.Module): | ||||||
|  |     expansion = 1 | ||||||
|  | 
 | ||||||
|  |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | ||||||
|  |                  base_width=64, dilation=1, norm_layer=None): | ||||||
|  |         super(BasicBlock, self).__init__() | ||||||
|  |         if norm_layer is None: | ||||||
|  |             norm_layer = nn.BatchNorm2d | ||||||
|  |         if groups != 1 or base_width != 64: | ||||||
|  |             raise ValueError( | ||||||
|  |                 'BasicBlock only supports groups=1 and base_width=64') | ||||||
|  |         # if dilation > 1: | ||||||
|  |         #     raise NotImplementedError( | ||||||
|  |         #         "Dilation > 1 not supported in BasicBlock") | ||||||
|  |         # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | ||||||
|  |         self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation) | ||||||
|  |         self.bn1 = norm_layer(planes) | ||||||
|  |         self.relu = nn.ReLU(inplace=True) | ||||||
|  |         self.conv2 = conv3x3(planes, planes, dilation=dilation) | ||||||
|  |         self.bn2 = norm_layer(planes) | ||||||
|  |         self.downsample = downsample | ||||||
|  |         self.stride = stride | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         identity = x | ||||||
|  | 
 | ||||||
|  |         out = self.conv1(x) | ||||||
|  |         out = self.bn1(out) | ||||||
|  |         out = self.relu(out) | ||||||
|  | 
 | ||||||
|  |         out = self.conv2(out) | ||||||
|  |         out = self.bn2(out) | ||||||
|  | 
 | ||||||
|  |         if self.downsample is not None: | ||||||
|  |             identity = self.downsample(x) | ||||||
|  | 
 | ||||||
|  |         out += identity | ||||||
|  |         out = self.relu(out) | ||||||
|  | 
 | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Bottleneck(nn.Module): | ||||||
|  |     expansion = 4 | ||||||
|  | 
 | ||||||
|  |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | ||||||
|  |                  base_width=64, dilation=1, norm_layer=None): | ||||||
|  |         super(Bottleneck, self).__init__() | ||||||
|  |         if norm_layer is None: | ||||||
|  |             norm_layer = nn.BatchNorm2d | ||||||
|  |         width = int(planes * (base_width / 64.)) * groups | ||||||
|  |         # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | ||||||
|  |         self.conv1 = conv1x1(inplanes, width) | ||||||
|  |         self.bn1 = norm_layer(width) | ||||||
|  |         self.conv2 = conv3x3(width, width, stride, groups, dilation) | ||||||
|  |         self.bn2 = norm_layer(width) | ||||||
|  |         self.conv3 = conv1x1(width, planes * self.expansion) | ||||||
|  |         self.bn3 = norm_layer(planes * self.expansion) | ||||||
|  |         self.relu = nn.ReLU(inplace=True) | ||||||
|  |         self.downsample = downsample | ||||||
|  |         self.stride = stride | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         identity = x | ||||||
|  | 
 | ||||||
|  |         out = self.conv1(x) | ||||||
|  |         out = self.bn1(out) | ||||||
|  |         out = self.relu(out) | ||||||
|  | 
 | ||||||
|  |         out = self.conv2(out) | ||||||
|  |         out = self.bn2(out) | ||||||
|  |         out = self.relu(out) | ||||||
|  | 
 | ||||||
|  |         out = self.conv3(out) | ||||||
|  |         out = self.bn3(out) | ||||||
|  | 
 | ||||||
|  |         if self.downsample is not None: | ||||||
|  |             identity = self.downsample(x) | ||||||
|  | 
 | ||||||
|  |         out += identity | ||||||
|  |         out = self.relu(out) | ||||||
|  | 
 | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ResNetWrapper(nn.Module): | ||||||
|  | 
 | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         super(ResNetWrapper, self).__init__() | ||||||
|  |         self.cfg = cfg | ||||||
|  |         self.in_channels = [64, 128, 256, 512] | ||||||
|  |         if 'in_channels' in cfg.backbone: | ||||||
|  |             self.in_channels = cfg.backbone.in_channels | ||||||
|  |         self.model = eval(cfg.backbone.resnet)( | ||||||
|  |             pretrained=cfg.backbone.pretrained, | ||||||
|  |             replace_stride_with_dilation=cfg.backbone.replace_stride_with_dilation, in_channels=self.in_channels) | ||||||
|  |         self.out = None | ||||||
|  |         if cfg.backbone.out_conv: | ||||||
|  |             out_channel = 512 | ||||||
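|  |             # pick the width of the deepest enabled stage; a negative entry in | ||||||
|  |             # in_channels marks a stage that is skipped (cf. layer4 in ResNet below) | ||||||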
|  |             for chan in reversed(self.in_channels): | ||||||
|  |                 if chan < 0: continue | ||||||
|  |                 out_channel = chan | ||||||
|  |                 break | ||||||
|  |             self.out = conv1x1( | ||||||
|  |                 out_channel * self.model.expansion, 128) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = self.model(x) | ||||||
|  |         if self.out: | ||||||
|  |             x = self.out(x) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ResNet(nn.Module): | ||||||
|  | 
 | ||||||
|  |     def __init__(self, block, layers, zero_init_residual=False, | ||||||
|  |                  groups=1, width_per_group=64, replace_stride_with_dilation=None, | ||||||
|  |                  norm_layer=None, in_channels=None): | ||||||
|  |         super(ResNet, self).__init__() | ||||||
|  |         if norm_layer is None: | ||||||
|  |             norm_layer = nn.BatchNorm2d | ||||||
|  |         self._norm_layer = norm_layer | ||||||
|  | 
 | ||||||
|  |         self.inplanes = 64 | ||||||
|  |         self.dilation = 1 | ||||||
|  |         if replace_stride_with_dilation is None: | ||||||
|  |             # each element in the tuple indicates if we should replace | ||||||
|  |             # the 2x2 stride with a dilated convolution instead | ||||||
|  |             replace_stride_with_dilation = [False, False, False] | ||||||
|  |         if len(replace_stride_with_dilation) != 3: | ||||||
|  |             raise ValueError("replace_stride_with_dilation should be None " | ||||||
|  |                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) | ||||||
|  |         self.groups = groups | ||||||
|  |         self.base_width = width_per_group | ||||||
|  |         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, | ||||||
|  |                                bias=False) | ||||||
|  |         self.bn1 = norm_layer(self.inplanes) | ||||||
|  |         self.relu = nn.ReLU(inplace=True) | ||||||
|  |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||||||
|  |         self.in_channels = in_channels | ||||||
|  |         self.layer1 = self._make_layer(block, in_channels[0], layers[0]) | ||||||
|  |         self.layer2 = self._make_layer(block, in_channels[1], layers[1], stride=2, | ||||||
|  |                                        dilate=replace_stride_with_dilation[0]) | ||||||
|  |         self.layer3 = self._make_layer(block, in_channels[2], layers[2], stride=2, | ||||||
|  |                                        dilate=replace_stride_with_dilation[1]) | ||||||
|  |         if in_channels[3] > 0: | ||||||
|  |             self.layer4 = self._make_layer(block, in_channels[3], layers[3], stride=2, | ||||||
|  |                                            dilate=replace_stride_with_dilation[2]) | ||||||
|  |         self.expansion = block.expansion | ||||||
|  | 
 | ||||||
|  |         # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||||
|  |         # self.fc = nn.Linear(512 * block.expansion, num_classes) | ||||||
|  | 
 | ||||||
|  |         for m in self.modules(): | ||||||
|  |             if isinstance(m, nn.Conv2d): | ||||||
|  |                 nn.init.kaiming_normal_( | ||||||
|  |                     m.weight, mode='fan_out', nonlinearity='relu') | ||||||
|  |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | ||||||
|  |                 nn.init.constant_(m.weight, 1) | ||||||
|  |                 nn.init.constant_(m.bias, 0) | ||||||
|  | 
 | ||||||
|  |         # Zero-initialize the last BN in each residual branch, | ||||||
|  |         # so that the residual branch starts with zeros, and each residual block behaves like an identity. | ||||||
|  |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | ||||||
|  |         if zero_init_residual: | ||||||
|  |             for m in self.modules(): | ||||||
|  |                 if isinstance(m, Bottleneck): | ||||||
|  |                     nn.init.constant_(m.bn3.weight, 0) | ||||||
|  |                 elif isinstance(m, BasicBlock): | ||||||
|  |                     nn.init.constant_(m.bn2.weight, 0) | ||||||
|  | 
 | ||||||
|  |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | ||||||
|  |         norm_layer = self._norm_layer | ||||||
|  |         downsample = None | ||||||
|  |         previous_dilation = self.dilation | ||||||
|  |         if dilate: | ||||||
|  |             self.dilation *= stride | ||||||
|  |             stride = 1 | ||||||
|  |         if stride != 1 or self.inplanes != planes * block.expansion: | ||||||
|  |             downsample = nn.Sequential( | ||||||
|  |                 conv1x1(self.inplanes, planes * block.expansion, stride), | ||||||
|  |                 norm_layer(planes * block.expansion), | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         layers = [] | ||||||
|  |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups, | ||||||
|  |                             self.base_width, previous_dilation, norm_layer)) | ||||||
|  |         self.inplanes = planes * block.expansion | ||||||
|  |         for _ in range(1, blocks): | ||||||
|  |             layers.append(block(self.inplanes, planes, groups=self.groups, | ||||||
|  |                                 base_width=self.base_width, dilation=self.dilation, | ||||||
|  |                                 norm_layer=norm_layer)) | ||||||
|  | 
 | ||||||
|  |         return nn.Sequential(*layers) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = self.conv1(x) | ||||||
|  |         x = self.bn1(x) | ||||||
|  |         x = self.relu(x) | ||||||
|  |         x = self.maxpool(x) | ||||||
|  | 
 | ||||||
|  |         x = self.layer1(x) | ||||||
|  |         x = self.layer2(x) | ||||||
|  |         x = self.layer3(x) | ||||||
|  |         if self.in_channels[3] > 0: | ||||||
|  |             x = self.layer4(x) | ||||||
|  | 
 | ||||||
|  |         # x = self.avgpool(x) | ||||||
|  |         # x = torch.flatten(x, 1) | ||||||
|  |         # x = self.fc(x) | ||||||
|  | 
 | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _resnet(arch, block, layers, pretrained, progress, **kwargs): | ||||||
|  |     model = ResNet(block, layers, **kwargs) | ||||||
|  |     if pretrained: | ||||||
|  |         state_dict = load_state_dict_from_url(model_urls[arch], | ||||||
|  |                                               progress=progress) | ||||||
|  |         model.load_state_dict(state_dict, strict=False) | ||||||
|  |     return model | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnet18(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNet-18 model from | ||||||
|  |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, | ||||||
|  |                    **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnet34(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNet-34 model from | ||||||
|  |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, | ||||||
|  |                    **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnet50(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNet-50 model from | ||||||
|  |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, | ||||||
|  |                    **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnet101(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNet-101 model from | ||||||
|  |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, | ||||||
|  |                    **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnet152(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNet-152 model from | ||||||
|  |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, | ||||||
|  |                    **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNeXt-50 32x4d model from | ||||||
|  |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     kwargs['groups'] = 32 | ||||||
|  |     kwargs['width_per_group'] = 4 | ||||||
|  |     return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], | ||||||
|  |                    pretrained, progress, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNeXt-101 32x8d model from | ||||||
|  |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     kwargs['groups'] = 32 | ||||||
|  |     kwargs['width_per_group'] = 8 | ||||||
|  |     return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], | ||||||
|  |                    pretrained, progress, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""Wide ResNet-50-2 model from | ||||||
|  |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ | ||||||
|  | 
 | ||||||
|  |     The model is the same as ResNet except that the number of channels in the | ||||||
|  |     bottleneck is twice as large in every block. The number of channels in outer 1x1 | ||||||
|  |     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 | ||||||
|  |     channels, and in Wide ResNet-50-2 has 2048-1024-2048. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     kwargs['width_per_group'] = 64 * 2 | ||||||
|  |     return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], | ||||||
|  |                    pretrained, progress, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""Wide ResNet-101-2 model from | ||||||
|  |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ | ||||||
|  | 
 | ||||||
|  |     The model is the same as ResNet except that the number of channels in the | ||||||
|  |     bottleneck is twice as large in every block. The number of channels in outer 1x1 | ||||||
|  |     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 | ||||||
|  |     channels, and in Wide ResNet-50-2 has 2048-1024-2048. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     kwargs['width_per_group'] = 64 * 2 | ||||||
|  |     return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], | ||||||
|  |                    pretrained, progress, **kwargs) | ||||||
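A note on `replace_stride_with_dilation`: swapping a stage's stride for dilation keeps the output stride small, which matters for dense prediction. A minimal sketch using this file's constructors (the flag values are illustrative):

```Python
import torch

# Keep an output stride of 8 by dilating the last two stages instead of striding.
model = resnet34(pretrained=False,
                 replace_stride_with_dilation=[False, True, True],
                 in_channels=[64, 128, 256, 512])
feat = model(torch.randn(1, 3, 288, 800))
print(feat.shape)  # torch.Size([1, 512, 36, 100]) -- stride 8 instead of 32
```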
models/resnet_copy.py (new file, +432 lines)
							| @ -0,0 +1,432 @@ | |||||||
|  | import torch | ||||||
|  | from torch import nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  | from torch.hub import load_state_dict_from_url | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | model_urls = { | ||||||
|  |     'resnet18': | ||||||
|  |     'https://download.pytorch.org/models/resnet18-5c106cde.pth', | ||||||
|  |     'resnet34': | ||||||
|  |     'https://download.pytorch.org/models/resnet34-333f7ec4.pth', | ||||||
|  |     'resnet50': | ||||||
|  |     'https://download.pytorch.org/models/resnet50-19c8e357.pth', | ||||||
|  |     'resnet101': | ||||||
|  |     'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', | ||||||
|  |     'resnet152': | ||||||
|  |     'https://download.pytorch.org/models/resnet152-b121ed2d.pth', | ||||||
|  |     'resnext50_32x4d': | ||||||
|  |     'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', | ||||||
|  |     'resnext101_32x8d': | ||||||
|  |     'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', | ||||||
|  |     'wide_resnet50_2': | ||||||
|  |     'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', | ||||||
|  |     'wide_resnet101_2': | ||||||
|  |     'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | ||||||
|  |     """3x3 convolution with padding""" | ||||||
|  |     return nn.Conv2d(in_planes, | ||||||
|  |                      out_planes, | ||||||
|  |                      kernel_size=3, | ||||||
|  |                      stride=stride, | ||||||
|  |                      padding=dilation, | ||||||
|  |                      groups=groups, | ||||||
|  |                      bias=False, | ||||||
|  |                      dilation=dilation) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def conv1x1(in_planes, out_planes, stride=1): | ||||||
|  |     """1x1 convolution""" | ||||||
|  |     return nn.Conv2d(in_planes, | ||||||
|  |                      out_planes, | ||||||
|  |                      kernel_size=1, | ||||||
|  |                      stride=stride, | ||||||
|  |                      bias=False) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class BasicBlock(nn.Module): | ||||||
|  |     expansion = 1 | ||||||
|  | 
 | ||||||
|  |     def __init__(self, | ||||||
|  |                  inplanes, | ||||||
|  |                  planes, | ||||||
|  |                  stride=1, | ||||||
|  |                  downsample=None, | ||||||
|  |                  groups=1, | ||||||
|  |                  base_width=64, | ||||||
|  |                  dilation=1, | ||||||
|  |                  norm_layer=None): | ||||||
|  |         super(BasicBlock, self).__init__() | ||||||
|  |         if norm_layer is None: | ||||||
|  |             norm_layer = nn.BatchNorm2d | ||||||
|  |         if groups != 1 or base_width != 64: | ||||||
|  |             raise ValueError( | ||||||
|  |                 'BasicBlock only supports groups=1 and base_width=64') | ||||||
|  |         # if dilation > 1: | ||||||
|  |         #     raise NotImplementedError( | ||||||
|  |         #         "Dilation > 1 not supported in BasicBlock") | ||||||
|  |         # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | ||||||
|  |         self.conv1 = conv3x3(inplanes, planes, stride, dilation=dilation) | ||||||
|  |         self.bn1 = norm_layer(planes) | ||||||
|  |         self.relu = nn.ReLU(inplace=True) | ||||||
|  |         self.conv2 = conv3x3(planes, planes, dilation=dilation) | ||||||
|  |         self.bn2 = norm_layer(planes) | ||||||
|  |         self.downsample = downsample | ||||||
|  |         self.stride = stride | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         identity = x | ||||||
|  | 
 | ||||||
|  |         out = self.conv1(x) | ||||||
|  |         out = self.bn1(out) | ||||||
|  |         out = self.relu(out) | ||||||
|  | 
 | ||||||
|  |         out = self.conv2(out) | ||||||
|  |         out = self.bn2(out) | ||||||
|  | 
 | ||||||
|  |         if self.downsample is not None: | ||||||
|  |             identity = self.downsample(x) | ||||||
|  | 
 | ||||||
|  |         out += identity | ||||||
|  |         out = self.relu(out) | ||||||
|  | 
 | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Bottleneck(nn.Module): | ||||||
|  |     expansion = 4 | ||||||
|  | 
 | ||||||
|  |     def __init__(self, | ||||||
|  |                  inplanes, | ||||||
|  |                  planes, | ||||||
|  |                  stride=1, | ||||||
|  |                  downsample=None, | ||||||
|  |                  groups=1, | ||||||
|  |                  base_width=64, | ||||||
|  |                  dilation=1, | ||||||
|  |                  norm_layer=None): | ||||||
|  |         super(Bottleneck, self).__init__() | ||||||
|  |         if norm_layer is None: | ||||||
|  |             norm_layer = nn.BatchNorm2d | ||||||
|  |         width = int(planes * (base_width / 64.)) * groups | ||||||
|  |         # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | ||||||
|  |         self.conv1 = conv1x1(inplanes, width) | ||||||
|  |         self.bn1 = norm_layer(width) | ||||||
|  |         self.conv2 = conv3x3(width, width, stride, groups, dilation) | ||||||
|  |         self.bn2 = norm_layer(width) | ||||||
|  |         self.conv3 = conv1x1(width, planes * self.expansion) | ||||||
|  |         self.bn3 = norm_layer(planes * self.expansion) | ||||||
|  |         self.relu = nn.ReLU(inplace=True) | ||||||
|  |         self.downsample = downsample | ||||||
|  |         self.stride = stride | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         identity = x | ||||||
|  | 
 | ||||||
|  |         out = self.conv1(x) | ||||||
|  |         out = self.bn1(out) | ||||||
|  |         out = self.relu(out) | ||||||
|  | 
 | ||||||
|  |         out = self.conv2(out) | ||||||
|  |         out = self.bn2(out) | ||||||
|  |         out = self.relu(out) | ||||||
|  | 
 | ||||||
|  |         out = self.conv3(out) | ||||||
|  |         out = self.bn3(out) | ||||||
|  | 
 | ||||||
|  |         if self.downsample is not None: | ||||||
|  |             identity = self.downsample(x) | ||||||
|  | 
 | ||||||
|  |         out += identity | ||||||
|  |         out = self.relu(out) | ||||||
|  | 
 | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ResNetWrapper(nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |                  resnet='resnet18', | ||||||
|  |                  pretrained=True, | ||||||
|  |                  replace_stride_with_dilation=[False, False, False], | ||||||
|  |                  out_conv=False, | ||||||
|  |                  fea_stride=8, | ||||||
|  |                  out_channel=128, | ||||||
|  |                  in_channels=[64, 128, 256, 512], | ||||||
|  |                  cfg=None): | ||||||
|  |         super(ResNetWrapper, self).__init__() | ||||||
|  |         self.cfg = cfg | ||||||
|  |         self.in_channels = in_channels | ||||||
|  | 
 | ||||||
|  |         self.model = eval(resnet)( | ||||||
|  |             pretrained=pretrained, | ||||||
|  |             replace_stride_with_dilation=replace_stride_with_dilation, | ||||||
|  |             in_channels=self.in_channels) | ||||||
|  |         self.out = None | ||||||
|  |         if out_conv: | ||||||
|  |             out_channel = 512 | ||||||
|  |             for chan in reversed(self.in_channels): | ||||||
|  |                 if chan < 0: continue | ||||||
|  |                 out_channel = chan | ||||||
|  |                 break | ||||||
|  |             self.out = conv1x1(out_channel * self.model.expansion, | ||||||
|  |                                cfg.featuremap_out_channel) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         x = self.model(x) | ||||||
|  |         if self.out: | ||||||
|  |             x[-1] = self.out(x[-1]) | ||||||
|  |         return x | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ResNet(nn.Module): | ||||||
|  |     def __init__(self, | ||||||
|  |                  block, | ||||||
|  |                  layers, | ||||||
|  |                  zero_init_residual=False, | ||||||
|  |                  groups=1, | ||||||
|  |                  width_per_group=64, | ||||||
|  |                  replace_stride_with_dilation=None, | ||||||
|  |                  norm_layer=None, | ||||||
|  |                  in_channels=None): | ||||||
|  |         super(ResNet, self).__init__() | ||||||
|  |         if norm_layer is None: | ||||||
|  |             norm_layer = nn.BatchNorm2d | ||||||
|  |         self._norm_layer = norm_layer | ||||||
|  | 
 | ||||||
|  |         self.inplanes = 64 | ||||||
|  |         self.dilation = 1 | ||||||
|  |         if replace_stride_with_dilation is None: | ||||||
|  |             # each element in the tuple indicates if we should replace | ||||||
|  |             # the 2x2 stride with a dilated convolution instead | ||||||
|  |             replace_stride_with_dilation = [False, False, False] | ||||||
|  |         if len(replace_stride_with_dilation) != 3: | ||||||
|  |             raise ValueError("replace_stride_with_dilation should be None " | ||||||
|  |                              "or a 3-element tuple, got {}".format( | ||||||
|  |                                  replace_stride_with_dilation)) | ||||||
|  |         self.groups = groups | ||||||
|  |         self.base_width = width_per_group | ||||||
|  |         self.conv1 = nn.Conv2d(3, | ||||||
|  |                                self.inplanes, | ||||||
|  |                                kernel_size=7, | ||||||
|  |                                stride=2, | ||||||
|  |                                padding=3, | ||||||
|  |                                bias=False) | ||||||
|  |         self.bn1 = norm_layer(self.inplanes) | ||||||
|  |         self.relu = nn.ReLU(inplace=True) | ||||||
|  |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||||||
|  |         self.in_channels = in_channels | ||||||
|  |         self.layer1 = self._make_layer(block, in_channels[0], layers[0]) | ||||||
|  |         self.layer2 = self._make_layer(block, | ||||||
|  |                                        in_channels[1], | ||||||
|  |                                        layers[1], | ||||||
|  |                                        stride=2, | ||||||
|  |                                        dilate=replace_stride_with_dilation[0]) | ||||||
|  |         self.layer3 = self._make_layer(block, | ||||||
|  |                                        in_channels[2], | ||||||
|  |                                        layers[2], | ||||||
|  |                                        stride=2, | ||||||
|  |                                        dilate=replace_stride_with_dilation[1]) | ||||||
|  |         if in_channels[3] > 0: | ||||||
|  |             self.layer4 = self._make_layer( | ||||||
|  |                 block, | ||||||
|  |                 in_channels[3], | ||||||
|  |                 layers[3], | ||||||
|  |                 stride=2, | ||||||
|  |                 dilate=replace_stride_with_dilation[2]) | ||||||
|  |         self.expansion = block.expansion | ||||||
|  | 
 | ||||||
|  |         # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||||||
|  |         # self.fc = nn.Linear(512 * block.expansion, num_classes) | ||||||
|  | 
 | ||||||
|  |         for m in self.modules(): | ||||||
|  |             if isinstance(m, nn.Conv2d): | ||||||
|  |                 nn.init.kaiming_normal_(m.weight, | ||||||
|  |                                         mode='fan_out', | ||||||
|  |                                         nonlinearity='relu') | ||||||
|  |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | ||||||
|  |                 nn.init.constant_(m.weight, 1) | ||||||
|  |                 nn.init.constant_(m.bias, 0) | ||||||
|  | 
 | ||||||
|  |         # Zero-initialize the last BN in each residual branch, | ||||||
|  |         # so that the residual branch starts with zeros, and each residual block behaves like an identity. | ||||||
|  |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | ||||||
|  |         if zero_init_residual: | ||||||
|  |             for m in self.modules(): | ||||||
|  |                 if isinstance(m, Bottleneck): | ||||||
|  |                     nn.init.constant_(m.bn3.weight, 0) | ||||||
|  |                 elif isinstance(m, BasicBlock): | ||||||
|  |                     nn.init.constant_(m.bn2.weight, 0) | ||||||
|  | 
 | ||||||
|  |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | ||||||
|  |         norm_layer = self._norm_layer | ||||||
|  |         downsample = None | ||||||
|  |         previous_dilation = self.dilation | ||||||
|  |         if dilate: | ||||||
|  |             self.dilation *= stride | ||||||
|  |             stride = 1 | ||||||
|  |         if stride != 1 or self.inplanes != planes * block.expansion: | ||||||
|  |             downsample = nn.Sequential( | ||||||
|  |                 conv1x1(self.inplanes, planes * block.expansion, stride), | ||||||
|  |                 norm_layer(planes * block.expansion), | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         layers = [] | ||||||
|  |         layers.append( | ||||||
|  |             block(self.inplanes, planes, stride, downsample, self.groups, | ||||||
|  |                   self.base_width, previous_dilation, norm_layer)) | ||||||
|  |         self.inplanes = planes * block.expansion | ||||||
|  |         for _ in range(1, blocks): | ||||||
|  |             layers.append( | ||||||
|  |                 block(self.inplanes, | ||||||
|  |                       planes, | ||||||
|  |                       groups=self.groups, | ||||||
|  |                       base_width=self.base_width, | ||||||
|  |                       dilation=self.dilation, | ||||||
|  |                       norm_layer=norm_layer)) | ||||||
|  | 
 | ||||||
|  |         return nn.Sequential(*layers) | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         out_layers = [] | ||||||
|  |         x = self.conv1(x) | ||||||
|  |         x = self.bn1(x) | ||||||
|  |         x = self.relu(x) | ||||||
|  |         x = self.maxpool(x) | ||||||
|  | 
 | ||||||
|  |         for name in ['layer1', 'layer2', 'layer3', 'layer4']: | ||||||
|  |             if not hasattr(self, name): | ||||||
|  |                 continue | ||||||
|  |             layer = getattr(self, name) | ||||||
|  |             x = layer(x) | ||||||
|  |             out_layers.append(x) | ||||||
|  | 
 | ||||||
|  |         return out_layers | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _resnet(arch, block, layers, pretrained, progress, **kwargs): | ||||||
|  |     model = ResNet(block, layers, **kwargs) | ||||||
|  |     if pretrained: | ||||||
|  |         print('pretrained model: ', model_urls[arch]) | ||||||
|  |         # state_dict = torch.load(model_urls[arch])['net'] | ||||||
|  |         state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) | ||||||
|  |         model.load_state_dict(state_dict, strict=False) | ||||||
|  |     return model | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnet18(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNet-18 model from | ||||||
|  |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, | ||||||
|  |                    **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnet34(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNet-34 model from | ||||||
|  |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, | ||||||
|  |                    **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnet50(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNet-50 model from | ||||||
|  |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, | ||||||
|  |                    **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnet101(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNet-101 model from | ||||||
|  |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, | ||||||
|  |                    progress, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnet152(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNet-152 model from | ||||||
|  |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, | ||||||
|  |                    progress, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNeXt-50 32x4d model from | ||||||
|  |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     kwargs['groups'] = 32 | ||||||
|  |     kwargs['width_per_group'] = 4 | ||||||
|  |     return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, | ||||||
|  |                    progress, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""ResNeXt-101 32x8d model from | ||||||
|  |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     kwargs['groups'] = 32 | ||||||
|  |     kwargs['width_per_group'] = 8 | ||||||
|  |     return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, | ||||||
|  |                    progress, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""Wide ResNet-50-2 model from | ||||||
|  |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ | ||||||
|  |     The model is the same as ResNet except that the number of bottleneck channels | ||||||
|  |     is twice as large in every block. The number of channels in the outer 1x1 | ||||||
|  |     convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048 | ||||||
|  |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048. | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     kwargs['width_per_group'] = 64 * 2 | ||||||
|  |     return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, | ||||||
|  |                    progress, **kwargs) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): | ||||||
|  |     r"""Wide ResNet-101-2 model from | ||||||
|  |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ | ||||||
|  |     The model is the same as ResNet except that the number of bottleneck channels | ||||||
|  |     is twice as large in every block. The number of channels in the outer 1x1 | ||||||
|  |     convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048 | ||||||
|  |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048. | ||||||
|  |     Args: | ||||||
|  |         pretrained (bool): If True, returns a model pre-trained on ImageNet | ||||||
|  |         progress (bool): If True, displays a progress bar of the download to stderr | ||||||
|  |     """ | ||||||
|  |     kwargs['width_per_group'] = 64 * 2 | ||||||
|  |     return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, | ||||||
|  |                    progress, **kwargs) | ||||||
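
Note that `forward` above returns the per-stage feature maps rather than classification logits (the fully connected head is commented out), so downstream modules can consume multi-scale features. A minimal usage sketch, assuming default torchvision-style strides and an illustrative 288x800 input:

```python
# Hedged sketch: the model names come from this file; the input size is only an example.
import torch

backbone = resnet18(pretrained=False)
x = torch.randn(1, 3, 288, 800)
feats = backbone(x)  # one feature map per residual stage
for name, f in zip(['layer1', 'layer2', 'layer3', 'layer4'], feats):
    print(name, tuple(f.shape))
# with default strides: layer1 (1, 64, 72, 200) ... layer4 (1, 512, 9, 25)
```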
8  requirement.txt  Normal file
							| @ -0,0 +1,8 @@ | |||||||
|  | pandas | ||||||
|  | addict | ||||||
|  | scikit-learn | ||||||
|  | opencv-python | ||||||
|  | pytorch_warmup | ||||||
|  | scikit-image | ||||||
|  | tqdm | ||||||
|  | termcolor | ||||||
4  runner/__init__.py  Normal file
							| @ -0,0 +1,4 @@ | |||||||
|  | from .evaluator import * | ||||||
|  | from .resa_trainer import * | ||||||
|  | 
 | ||||||
|  | from .registry import build_evaluator  | ||||||
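
The `@EVALUATOR.register_module` decorator used by the evaluators below relies on `runner/registry.py`, which this commit view does not show at this point. A minimal sketch of what such a registry typically looks like; the names `Registry`, `EVALUATOR`, and `build_evaluator` match the imports above, but the internals and the config key are assumptions:

```python
# Assumed shape of runner/registry.py; a sketch, not the repo's actual code.
class Registry:
    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, cls):
        # used as a class decorator, e.g. @EVALUATOR.register_module
        self._module_dict[cls.__name__] = cls
        return cls

    def get(self, name):
        return self._module_dict[name]


EVALUATOR = Registry('evaluator')


def build_evaluator(cfg):
    # the config key holding the evaluator class name is an assumption here
    return EVALUATOR.get(cfg.evaluator_type)(cfg)
```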
2  runner/evaluator/__init__.py  Normal file
							| @ -0,0 +1,2 @@ | |||||||
|  | from .tusimple.tusimple import Tusimple | ||||||
|  | from .culane.culane import CULane | ||||||
158  runner/evaluator/culane/culane.py  Normal file
							| @ -0,0 +1,158 @@ | |||||||
|  | import torch.nn as nn | ||||||
|  | import torch | ||||||
|  | import torch.nn.functional as F | ||||||
|  | from runner.logger import get_logger | ||||||
|  | 
 | ||||||
|  | from runner.registry import EVALUATOR  | ||||||
|  | import json | ||||||
|  | import os | ||||||
|  | import subprocess | ||||||
|  | from shutil import rmtree | ||||||
|  | import cv2 | ||||||
|  | import numpy as np | ||||||
|  | 
 | ||||||
|  | def check(): | ||||||
|  |     import sys | ||||||
|  |     result = subprocess.call( | ||||||
|  |         './runner/evaluator/culane/lane_evaluation/evaluate', | ||||||
|  |         stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) | ||||||
|  |     if result > 1: | ||||||
|  |         print('There is something wrong with the evaluate tool; please compile it first.') | ||||||
|  |         sys.exit(1) | ||||||
|  | 
 | ||||||
|  | def read_helper(path): | ||||||
|  |     # parse the "key: value" pairs written by the C++ evaluate tool, | ||||||
|  |     # skipping the leading "file: ..." header line | ||||||
|  |     with open(path, 'r') as f: | ||||||
|  |         tokens = ' '.join(f.readlines()[1:]).split() | ||||||
|  |     keys = [key.rstrip(':') for key in tokens[0::2]] | ||||||
|  |     values = tokens[1::2] | ||||||
|  |     return dict(zip(keys, values)) | ||||||
|  | 
 | ||||||
|  | def call_culane_eval(data_dir, output_path='./output'): | ||||||
|  |     if not data_dir.endswith('/'): | ||||||
|  |         data_dir = data_dir + '/' | ||||||
|  |     detect_dir = os.path.join(output_path, 'lines') + '/' | ||||||
|  |  | ||||||
|  |     w_lane = 30 | ||||||
|  |     iou = 0.5  # evaluate at IoU 0.5 (0.3 is the looser alternative) | ||||||
|  |     im_w = 1640 | ||||||
|  |     im_h = 590 | ||||||
|  |     frame = 1 | ||||||
|  |     eval_cmd = './runner/evaluator/culane/lane_evaluation/evaluate' | ||||||
|  |  | ||||||
|  |     txt_dir = os.path.join(output_path, 'txt') | ||||||
|  |     if not os.path.exists(txt_dir): | ||||||
|  |         os.mkdir(txt_dir) | ||||||
|  |  | ||||||
|  |     # the nine CULane test splits, indexed as in list/test_split/ | ||||||
|  |     splits = ['normal', 'crowd', 'hlight', 'shadow', 'noline', | ||||||
|  |               'arrow', 'curve', 'cross', 'night'] | ||||||
|  |     res_all = {} | ||||||
|  |     for idx, split in enumerate(splits): | ||||||
|  |         list_file = os.path.join(data_dir, 'list/test_split/test%d_%s.txt' % (idx, split)) | ||||||
|  |         out_file = os.path.join(txt_dir, 'out%d_%s.txt' % (idx, split)) | ||||||
|  |         os.system('%s -a %s -d %s -i %s -l %s -w %s -t %s -c %s -r %s -f %s -o %s' % ( | ||||||
|  |             eval_cmd, data_dir, detect_dir, data_dir, list_file, | ||||||
|  |             w_lane, iou, im_w, im_h, frame, out_file)) | ||||||
|  |         res_all[split] = read_helper(out_file) | ||||||
|  |     return res_all | ||||||
|  | 
 | ||||||
|  | @EVALUATOR.register_module | ||||||
|  | class CULane(nn.Module): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         super(CULane, self).__init__() | ||||||
|  |         # Firstly, check the evaluation tool | ||||||
|  |         check() | ||||||
|  |         self.cfg = cfg  | ||||||
|  |         self.blur = torch.nn.Conv2d( | ||||||
|  |             5, 5, 9, padding=4, bias=False, groups=5).cuda() | ||||||
|  |         torch.nn.init.constant_(self.blur.weight, 1 / 81) | ||||||
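|  |         # the two lines above build a 9x9 per-channel (groups=5) box filter, | ||||||
|  |         # all 81 taps set to 1/81, for smoothing the 5 probability maps | ||||||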
|  |         self.logger = get_logger('resa') | ||||||
|  |         self.out_dir = os.path.join(self.cfg.work_dir, 'lines') | ||||||
|  |         if cfg.view: | ||||||
|  |             self.view_dir = os.path.join(self.cfg.work_dir, 'vis') | ||||||
|  | 
 | ||||||
|  |     def evaluate(self, dataset, output, batch): | ||||||
|  |         seg, exists = output['seg'], output['exist'] | ||||||
|  |         predictmaps = F.softmax(seg, dim=1).cpu().numpy() | ||||||
|  |         exists = exists.cpu().numpy() | ||||||
|  |         batch_size = seg.size(0) | ||||||
|  |         img_name = batch['meta']['img_name'] | ||||||
|  |         img_path = batch['meta']['full_img_path'] | ||||||
|  |         for i in range(batch_size): | ||||||
|  |             coords = dataset.probmap2lane(predictmaps[i], exists[i]) | ||||||
|  |             outname = self.out_dir + img_name[i][:-4] + '.lines.txt' | ||||||
|  |             outdir = os.path.dirname(outname) | ||||||
|  |             if not os.path.exists(outdir): | ||||||
|  |                 os.makedirs(outdir) | ||||||
|  |             with open(outname, 'w') as f: | ||||||
|  |                 for coord in coords: | ||||||
|  |                     for x, y in coord: | ||||||
|  |                         if x < 0 and y < 0: | ||||||
|  |                             continue | ||||||
|  |                         f.write('%d %d ' % (x, y)) | ||||||
|  |                     f.write('\n') | ||||||
|  | 
 | ||||||
|  |             if self.cfg.view: | ||||||
|  |                 img = cv2.imread(img_path[i]).astype(np.float32) | ||||||
|  |                 dataset.view(img, coords, self.view_dir+img_name[i]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     def summarize(self): | ||||||
|  |         self.logger.info('summarize result...') | ||||||
|  |         eval_list_path = os.path.join( | ||||||
|  |             self.cfg.dataset_path, "list", self.cfg.dataset.val.data_list) | ||||||
|  |         #prob2lines(self.prob_dir, self.out_dir, eval_list_path, self.cfg) | ||||||
|  |         res = call_culane_eval(self.cfg.dataset_path, output_path=self.cfg.work_dir) | ||||||
|  |         TP, FP, FN = 0, 0, 0 | ||||||
|  |         out_str = 'Copypaste: ' | ||||||
|  |         for k, v in res.items(): | ||||||
|  |             # accumulate per-split counts; overall P/R/F1 are re-derived below | ||||||
|  |             TP += int(v['tp']) | ||||||
|  |             FP += int(v['fp']) | ||||||
|  |             FN += int(v['fn']) | ||||||
|  |             self.logger.info(k + ': ' + str(v)) | ||||||
|  |             out_str += k | ||||||
|  |             for metric, value in v.items(): | ||||||
|  |                 out_str += ' ' + str(value).rstrip('\n') | ||||||
|  |             out_str += ' ' | ||||||
|  |         P = TP * 1.0 / (TP + FP + 1e-9) | ||||||
|  |         R = TP * 1.0 / (TP + FN + 1e-9) | ||||||
|  |         F = 2*P*R/(P + R + 1e-9) | ||||||
|  |         overall_result_str = ('Overall Precision: %f Recall: %f F1: %f' % (P, R, F)) | ||||||
|  |         self.logger.info(overall_result_str) | ||||||
|  |         out_str = out_str + overall_result_str | ||||||
|  |         self.logger.info(out_str) | ||||||
|  | 
 | ||||||
|  |         # delete the tmp output | ||||||
|  |         rmtree(self.out_dir) | ||||||
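
For reference, `read_helper` consumes the small report files that the C++ `evaluate` tool (further down in this commit) writes per split. A self-contained sanity check with invented numbers, assuming the `key: value` format that `evaluate.cpp` emits:

```python
# Illustrative only: the numbers are made up, the format mirrors evaluate.cpp's output.
sample = (
    'file: out0_normal.txt\n'
    'tp: 30000 fp: 2000 fn: 2500\n'
    'precision: 0.9375\n'
    'recall: 0.923077\n'
    'Fmeasure: 0.930233\n'
)
with open('tmp_out.txt', 'w') as f:
    f.write(sample)

res = read_helper('tmp_out.txt')
print(res['tp'], res['fp'], res['fn'])   # 30000 2000 2500
# summarize() then re-derives the overall F1 from the summed counts:
P = 30000 / (30000 + 2000)
R = 30000 / (30000 + 2500)
print(2 * P * R / (P + R))               # ~0.9302
```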
2  runner/evaluator/culane/lane_evaluation/.gitignore  vendored  Normal file
							| @ -0,0 +1,2 @@ | |||||||
|  | build/ | ||||||
|  | evaluate | ||||||
50  runner/evaluator/culane/lane_evaluation/Makefile  Normal file
							| @ -0,0 +1,50 @@ | |||||||
|  | PROJECT_NAME:= evaluate | ||||||
|  | 
 | ||||||
|  | # config ----------------------------------
 | ||||||
|  | OPENCV_VERSION := 3 | ||||||
|  | 
 | ||||||
|  | INCLUDE_DIRS := include | ||||||
|  | LIBRARY_DIRS := lib /usr/local/lib | ||||||
|  | 
 | ||||||
|  | COMMON_FLAGS := -DCPU_ONLY | ||||||
|  | CXXFLAGS := -std=c++11 -fopenmp | ||||||
|  | LDFLAGS := -fopenmp -Wl,-rpath,./lib | ||||||
|  | BUILD_DIR := build | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # make rules -------------------------------
 | ||||||
|  | CXX ?= g++ | ||||||
|  | BUILD_DIR ?= ./build | ||||||
|  | 
 | ||||||
|  | LIBRARIES += opencv_core opencv_highgui opencv_imgproc  | ||||||
|  | ifeq ($(OPENCV_VERSION), 3) | ||||||
|  | 		LIBRARIES += opencv_imgcodecs | ||||||
|  | endif | ||||||
|  | 
 | ||||||
|  | CXXFLAGS += $(COMMON_FLAGS) $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) | ||||||
|  | LDFLAGS +=  $(COMMON_FLAGS) $(foreach includedir,$(LIBRARY_DIRS),-L$(includedir)) $(foreach library,$(LIBRARIES),-l$(library)) | ||||||
|  | SRC_DIRS += $(shell find * -type d -exec bash -c "find {} -maxdepth 1 \( -name '*.cpp' -o -name '*.proto' \) | grep -q ." \; -print) | ||||||
|  | CXX_SRCS += $(shell find src/ -name "*.cpp") | ||||||
|  | CXX_TARGETS:=$(patsubst %.cpp, $(BUILD_DIR)/%.o, $(CXX_SRCS)) | ||||||
|  | ALL_BUILD_DIRS := $(sort $(BUILD_DIR) $(addprefix $(BUILD_DIR)/, $(SRC_DIRS))) | ||||||
|  | 
 | ||||||
|  | .PHONY: all | ||||||
|  | all: $(PROJECT_NAME) | ||||||
|  | 
 | ||||||
|  | .PHONY: $(ALL_BUILD_DIRS) | ||||||
|  | $(ALL_BUILD_DIRS): | ||||||
|  | 	@mkdir -p $@ | ||||||
|  | 
 | ||||||
|  | $(BUILD_DIR)/%.o: %.cpp | $(ALL_BUILD_DIRS) | ||||||
|  | 	@echo "CXX" $< | ||||||
|  | 	@$(CXX) $(CXXFLAGS) -c -o $@ $< | ||||||
|  | 
 | ||||||
|  | $(PROJECT_NAME): $(CXX_TARGETS) | ||||||
|  | 	@echo "CXX/LD" $@ | ||||||
|  | 	@$(CXX) -o $@ $^ $(LDFLAGS) | ||||||
|  | 
 | ||||||
|  | .PHONY: clean | ||||||
|  | clean: | ||||||
|  | 	@rm -rf $(CXX_TARGETS) | ||||||
|  | 	@rm -rf $(PROJECT_NAME) | ||||||
|  | 	@rm -rf $(BUILD_DIR) | ||||||
47  runner/evaluator/culane/lane_evaluation/include/counter.hpp  Normal file
							| @ -0,0 +1,47 @@ | |||||||
|  | #ifndef COUNTER_HPP | ||||||
|  | #define COUNTER_HPP | ||||||
|  | 
 | ||||||
|  | #include "lane_compare.hpp" | ||||||
|  | #include "hungarianGraph.hpp" | ||||||
|  | #include <iostream> | ||||||
|  | #include <algorithm> | ||||||
|  | #include <tuple> | ||||||
|  | #include <vector> | ||||||
|  | #include <opencv2/core/core.hpp> | ||||||
|  | 
 | ||||||
|  | using namespace std; | ||||||
|  | using namespace cv; | ||||||
|  | 
 | ||||||
|  | // Before using the functions of this class, resize the lanes to im_width and im_height with resize_lane() in lane_compare.hpp | ||||||
|  | class Counter | ||||||
|  | { | ||||||
|  | 	public: | ||||||
|  | 		Counter(int _im_width, int _im_height, double _iou_threshold=0.4, int _lane_width=10):tp(0),fp(0),fn(0){ | ||||||
|  | 			im_width = _im_width; | ||||||
|  | 			im_height = _im_height; | ||||||
|  | 			sim_threshold = _iou_threshold; | ||||||
|  | 			lane_compare = new LaneCompare(_im_width, _im_height,  _lane_width, LaneCompare::IOU); | ||||||
|  | 		}; | ||||||
|  | 		double get_precision(void); | ||||||
|  | 		double get_recall(void); | ||||||
|  | 		long getTP(void); | ||||||
|  | 		long getFP(void); | ||||||
|  | 		long getFN(void); | ||||||
|  | 		void setTP(long); | ||||||
|  | 		void setFP(long); | ||||||
|  | 		void setFN(long); | ||||||
|  | 		// direct add tp, fp, tn and fn
 | ||||||
|  | 		// first match with hungarian
 | ||||||
|  | 		tuple<vector<int>, long, long, long, long> count_im_pair(const vector<vector<Point2f> > &anno_lanes, const vector<vector<Point2f> > &detect_lanes); | ||||||
|  | 		void makeMatch(const vector<vector<double> > &similarity, vector<int> &match1, vector<int> &match2); | ||||||
|  | 
 | ||||||
|  | 	private: | ||||||
|  | 		double sim_threshold; | ||||||
|  | 		int im_width; | ||||||
|  | 		int im_height; | ||||||
|  | 		long tp; | ||||||
|  | 		long fp; | ||||||
|  | 		long fn; | ||||||
|  | 		LaneCompare *lane_compare; | ||||||
|  | }; | ||||||
|  | #endif | ||||||
71  runner/evaluator/culane/lane_evaluation/include/hungarianGraph.hpp  Normal file
| @ -0,0 +1,71 @@ | |||||||
|  | #ifndef HUNGARIAN_GRAPH_HPP | ||||||
|  | #define HUNGARIAN_GRAPH_HPP | ||||||
|  | #include <vector> | ||||||
|  | using namespace std; | ||||||
|  | 
 | ||||||
|  | struct bipartiteGraph { | ||||||
|  |     vector<vector<double> > mat; | ||||||
|  |     vector<bool> leftUsed, rightUsed; | ||||||
|  |     vector<double> leftWeight, rightWeight; | ||||||
|  |     vector<int>rightMatch, leftMatch; | ||||||
|  |     int leftNum, rightNum; | ||||||
|  |     bool matchDfs(int u) { | ||||||
|  |         leftUsed[u] = true; | ||||||
|  |         for (int v = 0; v < rightNum; v++) { | ||||||
|  |             if (!rightUsed[v] && fabs(leftWeight[u] + rightWeight[v] - mat[u][v]) < 1e-2) { | ||||||
|  |                 rightUsed[v] = true; | ||||||
|  |                 if (rightMatch[v] == -1 || matchDfs(rightMatch[v])) { | ||||||
|  |                     rightMatch[v] = u; | ||||||
|  |                     leftMatch[u] = v; | ||||||
|  |                     return true; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |     void resize(int leftNum, int rightNum) { | ||||||
|  |         this->leftNum = leftNum; | ||||||
|  |         this->rightNum = rightNum; | ||||||
|  |         leftMatch.resize(leftNum); | ||||||
|  |         rightMatch.resize(rightNum); | ||||||
|  |         leftUsed.resize(leftNum); | ||||||
|  |         rightUsed.resize(rightNum); | ||||||
|  |         leftWeight.resize(leftNum); | ||||||
|  |         rightWeight.resize(rightNum); | ||||||
|  |         mat.resize(leftNum); | ||||||
|  |         for (int i = 0; i < leftNum; i++) mat[i].resize(rightNum); | ||||||
|  |     } | ||||||
|  |     void match() { | ||||||
|  |         for (int i = 0; i < leftNum; i++) leftMatch[i] = -1; | ||||||
|  |         for (int i = 0; i < rightNum; i++) rightMatch[i] = -1; | ||||||
|  |         for (int i = 0; i < rightNum; i++) rightWeight[i] = 0; | ||||||
|  |         for (int i = 0; i < leftNum; i++) { | ||||||
|  |             leftWeight[i] = -1e5; | ||||||
|  |             for (int j = 0; j < rightNum; j++) { | ||||||
|  |                 if (leftWeight[i] < mat[i][j]) leftWeight[i] = mat[i][j]; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         for (int u = 0; u < leftNum; u++) { | ||||||
|  |             while (1) { | ||||||
|  |                 for (int i = 0; i < leftNum; i++) leftUsed[i] = false; | ||||||
|  |                 for (int i = 0; i < rightNum; i++) rightUsed[i] = false; | ||||||
|  |                 if (matchDfs(u)) break; | ||||||
|  |                 double d = 1e10; | ||||||
|  |                 for (int i = 0; i < leftNum; i++) { | ||||||
|  |                     if (leftUsed[i] ) { | ||||||
|  |                         for (int j = 0; j < rightNum; j++) { | ||||||
|  |                             if (!rightUsed[j]) d = min(d, leftWeight[i] + rightWeight[j] - mat[i][j]); | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |                 if (d == 1e10) return ; | ||||||
|  |                 for (int i = 0; i < leftNum; i++) if (leftUsed[i]) leftWeight[i] -= d; | ||||||
|  |                 for (int i = 0; i < rightNum; i++) if (rightUsed[i]) rightWeight[i] += d; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | #endif // HUNGARIAN_GRAPH_HPP
 | ||||||
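
This header implements a maximum-weight bipartite matching (Kuhn-Munkres with vertex potentials; the `1e-2` slack keeps the equality test tolerant for IoU-scale weights). As an external cross-check only, SciPy's assignment solver computes the same matching on a toy similarity matrix:

```python
# Not part of the repo: a sanity check of what bipartiteGraph.match()
# computes, using SciPy on an invented IoU matrix.
import numpy as np
from scipy.optimize import linear_sum_assignment

iou = np.array([[0.8, 0.1, 0.0],    # rows: annotated lanes
                [0.2, 0.6, 0.3]])   # cols: detected lanes
rows, cols = linear_sum_assignment(iou, maximize=True)
print(list(zip(rows, cols)))        # [(0, 0), (1, 1)]: total weight 1.4 is maximal
```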
51  runner/evaluator/culane/lane_evaluation/include/lane_compare.hpp  Normal file
| @ -0,0 +1,51 @@ | |||||||
|  | #ifndef LANE_COMPARE_HPP | ||||||
|  | #define LANE_COMPARE_HPP | ||||||
|  | 
 | ||||||
|  | #include "spline.hpp" | ||||||
|  | #include <vector> | ||||||
|  | #include <iostream> | ||||||
|  | #include <opencv2/core/version.hpp> | ||||||
|  | #include <opencv2/core/core.hpp> | ||||||
|  | 
 | ||||||
|  | #if CV_VERSION_EPOCH == 2 | ||||||
|  | #define OPENCV2 | ||||||
|  | #elif CV_VERSION_MAJOR == 3 | ||||||
|  | #define  OPENCV3 | ||||||
|  | #else | ||||||
|  | #error Unsupported OpenCV version | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  | #ifdef OPENCV3 | ||||||
|  | #include <opencv2/imgproc.hpp> | ||||||
|  | #elif defined(OPENCV2) | ||||||
|  | #include <opencv2/imgproc/imgproc.hpp> | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  | using namespace std; | ||||||
|  | using namespace cv; | ||||||
|  | 
 | ||||||
|  | class LaneCompare{ | ||||||
|  | 	public: | ||||||
|  | 		enum CompareMode{ | ||||||
|  | 			IOU, | ||||||
|  | 			Caltech | ||||||
|  | 		}; | ||||||
|  | 
 | ||||||
|  | 		LaneCompare(int _im_width, int _im_height, int _lane_width = 10, CompareMode _compare_mode = IOU){ | ||||||
|  | 			im_width = _im_width; | ||||||
|  | 			im_height = _im_height; | ||||||
|  | 			compare_mode = _compare_mode; | ||||||
|  | 			lane_width = _lane_width; | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		double get_lane_similarity(const vector<Point2f> &lane1, const vector<Point2f> &lane2); | ||||||
|  | 		void resize_lane(vector<Point2f> &curr_lane, int curr_width, int curr_height); | ||||||
|  | 	private: | ||||||
|  | 		CompareMode compare_mode; | ||||||
|  | 		int im_width; | ||||||
|  | 		int im_height; | ||||||
|  | 		int lane_width; | ||||||
|  | 		Spline splineSolver; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | #endif | ||||||
28  runner/evaluator/culane/lane_evaluation/include/spline.hpp  Normal file
							| @ -0,0 +1,28 @@ | |||||||
|  | #ifndef SPLINE_HPP | ||||||
|  | #define SPLINE_HPP | ||||||
|  | #include <vector> | ||||||
|  | #include <cstdio> | ||||||
|  | #include <math.h> | ||||||
|  | #include <opencv2/core/core.hpp> | ||||||
|  | 
 | ||||||
|  | using namespace cv; | ||||||
|  | using namespace std; | ||||||
|  | 
 | ||||||
|  | struct Func { | ||||||
|  |     double a_x; | ||||||
|  |     double b_x; | ||||||
|  |     double c_x; | ||||||
|  |     double d_x; | ||||||
|  |     double a_y; | ||||||
|  |     double b_y; | ||||||
|  |     double c_y; | ||||||
|  |     double d_y; | ||||||
|  |     double h; | ||||||
|  | }; | ||||||
|  | class Spline { | ||||||
|  | public: | ||||||
|  | 	vector<Point2f> splineInterpTimes(const vector<Point2f> &tmp_line, int times); | ||||||
|  |     vector<Point2f> splineInterpStep(vector<Point2f> tmp_line, double step); | ||||||
|  | 	vector<Func> cal_fun(const vector<Point2f> &point_v); | ||||||
|  | }; | ||||||
|  | #endif | ||||||
134  runner/evaluator/culane/lane_evaluation/src/counter.cpp  Normal file
							| @ -0,0 +1,134 @@ | |||||||
|  | /*************************************************************************
 | ||||||
|  | 	> File Name: counter.cpp | ||||||
|  | 	> Author: Xingang Pan, Jun Li | ||||||
|  | 	> Mail: px117@ie.cuhk.edu.hk | ||||||
|  | 	> Created Time: Thu Jul 14 20:23:08 2016 | ||||||
|  |  ************************************************************************/ | ||||||
|  | 
 | ||||||
|  | #include "counter.hpp" | ||||||
|  | 
 | ||||||
|  | double Counter::get_precision(void) | ||||||
|  | { | ||||||
|  | 	cerr<<"tp: "<<tp<<" fp: "<<fp<<" fn: "<<fn<<endl; | ||||||
|  | 	if(tp+fp == 0) | ||||||
|  | 	{ | ||||||
|  | 		cerr<<"no positive detection"<<endl; | ||||||
|  | 		return -1; | ||||||
|  | 	} | ||||||
|  | 	return tp/double(tp + fp); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | double Counter::get_recall(void) | ||||||
|  | { | ||||||
|  | 	if(tp+fn == 0) | ||||||
|  | 	{ | ||||||
|  | 		cerr<<"no ground truth positive"<<endl; | ||||||
|  | 		return -1; | ||||||
|  | 	} | ||||||
|  | 	return tp/double(tp + fn); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | long Counter::getTP(void) | ||||||
|  | { | ||||||
|  | 	return tp; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | long Counter::getFP(void) | ||||||
|  | { | ||||||
|  | 	return fp; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | long Counter::getFN(void) | ||||||
|  | { | ||||||
|  | 	return fn; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void Counter::setTP(long value)  | ||||||
|  | { | ||||||
|  | 	tp = value; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void Counter::setFP(long value) | ||||||
|  | { | ||||||
|  |   fp = value; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void Counter::setFN(long value) | ||||||
|  | { | ||||||
|  | 	fn = value; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | tuple<vector<int>, long, long, long, long> Counter::count_im_pair(const vector<vector<Point2f> > &anno_lanes, const vector<vector<Point2f> > &detect_lanes) | ||||||
|  | { | ||||||
|  | 	vector<int> anno_match(anno_lanes.size(), -1); | ||||||
|  | 	vector<int> detect_match; | ||||||
|  | 	if(anno_lanes.empty()) | ||||||
|  | 	{ | ||||||
|  | 		return make_tuple(anno_match, 0, detect_lanes.size(), 0, 0); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if(detect_lanes.empty()) | ||||||
|  | 	{ | ||||||
|  | 		return make_tuple(anno_match, 0, 0, 0, anno_lanes.size()); | ||||||
|  | 	} | ||||||
|  | 	// hungarian match first
 | ||||||
|  | 	 | ||||||
|  | 	// first calc similarity matrix
 | ||||||
|  | 	vector<vector<double> > similarity(anno_lanes.size(), vector<double>(detect_lanes.size(), 0)); | ||||||
|  | 	for(int i=0; i<anno_lanes.size(); i++) | ||||||
|  | 	{ | ||||||
|  | 		const vector<Point2f> &curr_anno_lane = anno_lanes[i]; | ||||||
|  | 		for(int j=0; j<detect_lanes.size(); j++) | ||||||
|  | 		{ | ||||||
|  | 			const vector<Point2f> &curr_detect_lane = detect_lanes[j]; | ||||||
|  | 			similarity[i][j] = lane_compare->get_lane_similarity(curr_anno_lane, curr_detect_lane); | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 	makeMatch(similarity, anno_match, detect_match); | ||||||
|  | 
 | ||||||
|  | 	 | ||||||
|  | 	int curr_tp = 0; | ||||||
|  | 	// count and add
 | ||||||
|  | 	for(int i=0; i<anno_lanes.size(); i++) | ||||||
|  | 	{ | ||||||
|  | 		if(anno_match[i]>=0 && similarity[i][anno_match[i]] > sim_threshold) | ||||||
|  | 		{ | ||||||
|  | 			curr_tp++; | ||||||
|  | 		} | ||||||
|  | 		else | ||||||
|  | 		{ | ||||||
|  | 			anno_match[i] = -1; | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	int curr_fn = anno_lanes.size() - curr_tp; | ||||||
|  | 	int curr_fp = detect_lanes.size() - curr_tp; | ||||||
|  | 	return make_tuple(anno_match, curr_tp, curr_fp, 0, curr_fn); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | void Counter::makeMatch(const vector<vector<double> > &similarity, vector<int> &match1, vector<int> &match2) { | ||||||
|  | 	int m = similarity.size(); | ||||||
|  | 	int n = similarity[0].size(); | ||||||
|  |     bipartiteGraph gra; | ||||||
|  |     bool have_exchange = false; | ||||||
|  |     if (m > n) { | ||||||
|  |         have_exchange = true; | ||||||
|  |         swap(m, n); | ||||||
|  |     } | ||||||
|  |     gra.resize(m, n); | ||||||
|  |     for (int i = 0; i < gra.leftNum; i++) { | ||||||
|  |         for (int j = 0; j < gra.rightNum; j++) { | ||||||
|  | 			if(have_exchange) | ||||||
|  | 				gra.mat[i][j] = similarity[j][i]; | ||||||
|  | 			else | ||||||
|  | 				gra.mat[i][j] = similarity[i][j]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     gra.match(); | ||||||
|  |     match1 = gra.leftMatch; | ||||||
|  |     match2 = gra.rightMatch; | ||||||
|  |     if (have_exchange) swap(match1, match2); | ||||||
|  | } | ||||||
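
The accounting at the end of `count_im_pair` is simple once the matching is fixed: every annotated lane whose matched detection clears `sim_threshold` is a true positive, the remaining annotations are false negatives, and the remaining detections are false positives. With invented numbers:

```python
# Illustrative numbers only.
anno_lanes, detect_lanes, matched = 4, 5, 3   # 3 matches above the IoU threshold
tp = matched                                  # 3
fn = anno_lanes - matched                     # 1 annotation left unmatched
fp = detect_lanes - matched                   # 2 detections left unmatched
print(tp, fp, fn)                             # 3 2 1
```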
302  runner/evaluator/culane/lane_evaluation/src/evaluate.cpp  Normal file
							| @ -0,0 +1,302 @@ | |||||||
|  | /*************************************************************************
 | ||||||
|  |         > File Name: evaluate.cpp | ||||||
|  |         > Author: Xingang Pan, Jun Li | ||||||
|  |         > Mail: px117@ie.cuhk.edu.hk | ||||||
|  |         > Created Time: Thu Jul 14 18:28:45 2016 | ||||||
|  |  ************************************************************************/ | ||||||
|  | 
 | ||||||
|  | #include "counter.hpp" | ||||||
|  | #include "spline.hpp" | ||||||
|  | #include <unistd.h> | ||||||
|  | #include <iostream> | ||||||
|  | #include <fstream> | ||||||
|  | #include <sstream> | ||||||
|  | #include <cstdlib> | ||||||
|  | #include <string> | ||||||
|  | #include <opencv2/core/core.hpp> | ||||||
|  | #include <opencv2/highgui/highgui.hpp> | ||||||
|  | using namespace std; | ||||||
|  | using namespace cv; | ||||||
|  | 
 | ||||||
|  | void help(void) { | ||||||
|  |   cout << "./evaluate [OPTIONS]" << endl; | ||||||
|  |   cout << "-h                  : print usage help" << endl; | ||||||
|  |   cout << "-a                  : directory for annotation files (default: " | ||||||
|  |           "/data/driving/eval_data/anno_label/)" << endl; | ||||||
|  |   cout << "-d                  : directory for detection files (default: " | ||||||
|  |           "/data/driving/eval_data/predict_label/)" << endl; | ||||||
|  |   cout << "-i                  : directory for image files (default: " | ||||||
|  |           "/data/driving/eval_data/img/)" << endl; | ||||||
|  |   cout << "-l                  : list of images used for evaluation (default: " | ||||||
|  |           "/data/driving/eval_data/img/all.txt)" << endl; | ||||||
|  |   cout << "-w                  : width of the lanes (default: 10)" << endl; | ||||||
|  |   cout << "-t                  : threshold of iou (default: 0.4)" << endl; | ||||||
|  |   cout << "-c                  : cols (max image width) (default: 1920)" | ||||||
|  |        << endl; | ||||||
|  |   cout << "-r                  : rows (max image height) (default: 1080)" | ||||||
|  |        << endl; | ||||||
|  |   cout << "-s                  : show visualization" << endl; | ||||||
|  |   cout << "-f                  : start frame in the test set (default: 1)" | ||||||
|  |        << endl; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void read_lane_file(const string &file_name, vector<vector<Point2f>> &lanes); | ||||||
|  | void visualize(string &full_im_name, vector<vector<Point2f>> &anno_lanes, | ||||||
|  |                vector<vector<Point2f>> &detect_lanes, vector<int> anno_match, | ||||||
|  |                int width_lane, string save_path = ""); | ||||||
|  | 
 | ||||||
|  | int main(int argc, char **argv) { | ||||||
|  |   // process params
 | ||||||
|  |   string anno_dir = "/data/driving/eval_data/anno_label/"; | ||||||
|  |   string detect_dir = "/data/driving/eval_data/predict_label/"; | ||||||
|  |   string im_dir = "/data/driving/eval_data/img/"; | ||||||
|  |   string list_im_file = "/data/driving/eval_data/img/all.txt"; | ||||||
|  |   string output_file = "./output.txt"; | ||||||
|  |   int width_lane = 10; | ||||||
|  |   double iou_threshold = 0.4; | ||||||
|  |   int im_width = 1920; | ||||||
|  |   int im_height = 1080; | ||||||
|  |   int oc; | ||||||
|  |   bool show = false; | ||||||
|  |   int frame = 1; | ||||||
|  |   string save_path = ""; | ||||||
|  |   while ((oc = getopt(argc, argv, "ha:d:i:l:w:t:c:r:sf:o:p:")) != -1) { | ||||||
|  |     switch (oc) { | ||||||
|  |     case 'h': | ||||||
|  |       help(); | ||||||
|  |       return 0; | ||||||
|  |     case 'a': | ||||||
|  |       anno_dir = optarg; | ||||||
|  |       break; | ||||||
|  |     case 'd': | ||||||
|  |       detect_dir = optarg; | ||||||
|  |       break; | ||||||
|  |     case 'i': | ||||||
|  |       im_dir = optarg; | ||||||
|  |       break; | ||||||
|  |     case 'l': | ||||||
|  |       list_im_file = optarg; | ||||||
|  |       break; | ||||||
|  |     case 'w': | ||||||
|  |       width_lane = atoi(optarg); | ||||||
|  |       break; | ||||||
|  |     case 't': | ||||||
|  |       iou_threshold = atof(optarg); | ||||||
|  |       break; | ||||||
|  |     case 'c': | ||||||
|  |       im_width = atoi(optarg); | ||||||
|  |       break; | ||||||
|  |     case 'r': | ||||||
|  |       im_height = atoi(optarg); | ||||||
|  |       break; | ||||||
|  |     case 's': | ||||||
|  |       show = true; | ||||||
|  |       break; | ||||||
|  |     case 'p': | ||||||
|  |       save_path = optarg; | ||||||
|  |       break; | ||||||
|  |     case 'f': | ||||||
|  |       frame = atoi(optarg); | ||||||
|  |       break; | ||||||
|  |     case 'o': | ||||||
|  |       output_file = optarg; | ||||||
|  |       break; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   cout << "------------Configuration---------" << endl; | ||||||
|  |   cout << "anno_dir: " << anno_dir << endl; | ||||||
|  |   cout << "detect_dir: " << detect_dir << endl; | ||||||
|  |   cout << "im_dir: " << im_dir << endl; | ||||||
|  |   cout << "list_im_file: " << list_im_file << endl; | ||||||
|  |   cout << "width_lane: " << width_lane << endl; | ||||||
|  |   cout << "iou_threshold: " << iou_threshold << endl; | ||||||
|  |   cout << "im_width: " << im_width << endl; | ||||||
|  |   cout << "im_height: " << im_height << endl; | ||||||
|  |   cout << "-----------------------------------" << endl; | ||||||
|  |   cout << "Evaluating the results..." << endl; | ||||||
|  |   // this is the max_width and max_height
 | ||||||
|  | 
 | ||||||
|  |   if (width_lane < 1) { | ||||||
|  |     cerr << "width_lane must be positive" << endl; | ||||||
|  |     help(); | ||||||
|  |     return 1; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   ifstream ifs_im_list(list_im_file, ios::in); | ||||||
|  |   if (ifs_im_list.fail()) { | ||||||
|  |     cerr << "Error: file " << list_im_file << " does not exist!" << endl; | ||||||
|  |     return 1; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   Counter counter(im_width, im_height, iou_threshold, width_lane); | ||||||
|  | 
 | ||||||
|  |   vector<int> anno_match; | ||||||
|  |   string sub_im_name; | ||||||
|  |   // pre-load filelist
 | ||||||
|  |   vector<string> filelists; | ||||||
|  |   while (getline(ifs_im_list, sub_im_name)) { | ||||||
|  |     filelists.push_back(sub_im_name); | ||||||
|  |   } | ||||||
|  |   ifs_im_list.close(); | ||||||
|  | 
 | ||||||
|  |   vector<tuple<vector<int>, long, long, long, long>> tuple_lists; | ||||||
|  |   tuple_lists.resize(filelists.size()); | ||||||
|  | 
 | ||||||
|  | #pragma omp parallel for | ||||||
|  |   for (size_t i = 0; i < filelists.size(); i++) { | ||||||
|  |     auto sub_im_name = filelists[i]; | ||||||
|  |     string full_im_name = im_dir + sub_im_name; | ||||||
|  |     string sub_txt_name = | ||||||
|  |         sub_im_name.substr(0, sub_im_name.find_last_of(".")) + ".lines.txt"; | ||||||
|  |     string anno_file_name = anno_dir + sub_txt_name; | ||||||
|  |     string detect_file_name = detect_dir + sub_txt_name; | ||||||
|  |     vector<vector<Point2f>> anno_lanes; | ||||||
|  |     vector<vector<Point2f>> detect_lanes; | ||||||
|  |     read_lane_file(anno_file_name, anno_lanes); | ||||||
|  |     read_lane_file(detect_file_name, detect_lanes); | ||||||
|  |     // cerr<<count<<": "<<full_im_name<<endl;
 | ||||||
|  |     tuple_lists[i] = counter.count_im_pair(anno_lanes, detect_lanes); | ||||||
|  |     if (show) { | ||||||
|  |       auto anno_match = get<0>(tuple_lists[i]); | ||||||
|  |       visualize(full_im_name, anno_lanes, detect_lanes, anno_match, width_lane); | ||||||
|  |       waitKey(0); | ||||||
|  |     } | ||||||
|  |     if (save_path != "") { | ||||||
|  |       auto anno_match = get<0>(tuple_lists[i]); | ||||||
|  |       visualize(full_im_name, anno_lanes, detect_lanes, anno_match, width_lane, | ||||||
|  |                 save_path); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   long tp = 0, fp = 0, tn = 0, fn = 0; | ||||||
|  |   for (auto result : tuple_lists) { | ||||||
|  |     tp += get<1>(result); | ||||||
|  |     fp += get<2>(result); | ||||||
|  |     // tn = get<3>(result);
 | ||||||
|  |     fn += get<4>(result); | ||||||
|  |   } | ||||||
|  |   counter.setTP(tp); | ||||||
|  |   counter.setFP(fp); | ||||||
|  |   counter.setFN(fn); | ||||||
|  | 
 | ||||||
|  |   double precision = counter.get_precision(); | ||||||
|  |   double recall = counter.get_recall(); | ||||||
|  |   double F = 2 * precision * recall / (precision + recall); | ||||||
|  |   cerr << "finished process file" << endl; | ||||||
|  |   cout << "precision: " << precision << endl; | ||||||
|  |   cout << "recall: " << recall << endl; | ||||||
|  |   cout << "Fmeasure: " << F << endl; | ||||||
|  |   cout << "----------------------------------" << endl; | ||||||
|  | 
 | ||||||
|  |   ofstream ofs_out_file; | ||||||
|  |   ofs_out_file.open(output_file, ios::out); | ||||||
|  |   ofs_out_file << "file: " << output_file << endl; | ||||||
|  |   ofs_out_file << "tp: " << counter.getTP() << " fp: " << counter.getFP() | ||||||
|  |                << " fn: " << counter.getFN() << endl; | ||||||
|  |   ofs_out_file << "precision: " << precision << endl; | ||||||
|  |   ofs_out_file << "recall: " << recall << endl; | ||||||
|  |   ofs_out_file << "Fmeasure: " << F << endl << endl; | ||||||
|  |   ofs_out_file.close(); | ||||||
|  |   return 0; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void read_lane_file(const string &file_name, vector<vector<Point2f>> &lanes) { | ||||||
|  |   lanes.clear(); | ||||||
|  |   ifstream ifs_lane(file_name, ios::in); | ||||||
|  |   if (ifs_lane.fail()) { | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   string str_line; | ||||||
|  |   while (getline(ifs_lane, str_line)) { | ||||||
|  |     vector<Point2f> curr_lane; | ||||||
|  |     stringstream ss; | ||||||
|  |     ss << str_line; | ||||||
|  |     double x, y; | ||||||
|  |     while (ss >> x >> y) { | ||||||
|  |       curr_lane.push_back(Point2f(x, y)); | ||||||
|  |     } | ||||||
|  |     lanes.push_back(curr_lane); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   ifs_lane.close(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void visualize(string &full_im_name, vector<vector<Point2f>> &anno_lanes, | ||||||
|  |                vector<vector<Point2f>> &detect_lanes, vector<int> anno_match, | ||||||
|  |                int width_lane, string save_path) { | ||||||
|  |   Mat img = imread(full_im_name, 1); | ||||||
|  |   Mat img2 = imread(full_im_name, 1); | ||||||
|  |   vector<Point2f> curr_lane; | ||||||
|  |   vector<Point2f> p_interp; | ||||||
|  |   Spline splineSolver; | ||||||
|  |   Scalar color_B = Scalar(255, 0, 0); | ||||||
|  |   Scalar color_G = Scalar(0, 255, 0); | ||||||
|  |   Scalar color_R = Scalar(0, 0, 255); | ||||||
|  |   Scalar color_P = Scalar(255, 0, 255); | ||||||
|  |   Scalar color; | ||||||
|  |   for (int i = 0; i < anno_lanes.size(); i++) { | ||||||
|  |     curr_lane = anno_lanes[i]; | ||||||
|  |     if (curr_lane.size() == 2) { | ||||||
|  |       p_interp = curr_lane; | ||||||
|  |     } else { | ||||||
|  |       p_interp = splineSolver.splineInterpTimes(curr_lane, 50); | ||||||
|  |     } | ||||||
|  |     if (anno_match[i] >= 0) { | ||||||
|  |       color = color_G; | ||||||
|  |     } else { | ||||||
|  |       color = color_P;  // unmatched ground-truth lanes drawn in purple | ||||||
|  |     } | ||||||
|  |     for (int n = 0; n < p_interp.size() - 1; n++) { | ||||||
|  |       line(img, p_interp[n], p_interp[n + 1], color, width_lane); | ||||||
|  |       line(img2, p_interp[n], p_interp[n + 1], color, 2); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   bool detected; | ||||||
|  |   for (int i = 0; i < detect_lanes.size(); i++) { | ||||||
|  |     detected = false; | ||||||
|  |     curr_lane = detect_lanes[i]; | ||||||
|  |     if (curr_lane.size() == 2) { | ||||||
|  |       p_interp = curr_lane; | ||||||
|  |     } else { | ||||||
|  |       p_interp = splineSolver.splineInterpTimes(curr_lane, 50); | ||||||
|  |     } | ||||||
|  |     for (int n = 0; n < anno_lanes.size(); n++) { | ||||||
|  |       if (anno_match[n] == i) { | ||||||
|  |         detected = true; | ||||||
|  |         break; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     if (detected) { | ||||||
|  |       color = color_B; | ||||||
|  |     } else { | ||||||
|  |       color = color_R; | ||||||
|  |     } | ||||||
|  |     for (int n = 0; n < p_interp.size() - 1; n++) { | ||||||
|  |       line(img, p_interp[n], p_interp[n + 1], color, width_lane); | ||||||
|  |       line(img2, p_interp[n], p_interp[n + 1], color, 2); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   if (save_path != "") { | ||||||
|  |     size_t pos = 0; | ||||||
|  |     string s = full_im_name; | ||||||
|  |     std::string token; | ||||||
|  |     std::string delimiter = "/"; | ||||||
|  |     vector<string> names; | ||||||
|  |     while ((pos = s.find(delimiter)) != std::string::npos) { | ||||||
|  |       token = s.substr(0, pos); | ||||||
|  |       names.emplace_back(token); | ||||||
|  |       s.erase(0, pos + delimiter.length()); | ||||||
|  |     } | ||||||
|  |     names.emplace_back(s); | ||||||
|  |     string file_name = names[3] + '_' + names[4] + '_' + names[5]; | ||||||
|  |     // cout << file_name << endl;
 | ||||||
|  |     imwrite(save_path + '/' + file_name, img); | ||||||
|  |   } else { | ||||||
|  |     namedWindow("visualize", 1); | ||||||
|  |     imshow("visualize", img); | ||||||
|  |     namedWindow("visualize2", 1); | ||||||
|  |     imshow("visualize2", img2); | ||||||
|  |   } | ||||||
|  | } | ||||||
73  runner/evaluator/culane/lane_evaluation/src/lane_compare.cpp  Normal file
							| @ -0,0 +1,73 @@ | |||||||
|  | /*************************************************************************
 | ||||||
|  | 	> File Name: lane_compare.cpp | ||||||
|  | 	> Author: Xingang Pan, Jun Li | ||||||
|  | 	> Mail: px117@ie.cuhk.edu.hk | ||||||
|  | 	> Created Time: Fri Jul 15 10:26:32 2016 | ||||||
|  |  ************************************************************************/ | ||||||
|  | 
 | ||||||
|  | #include "lane_compare.hpp" | ||||||
|  | 
 | ||||||
|  | double LaneCompare::get_lane_similarity(const vector<Point2f> &lane1, const vector<Point2f> &lane2) | ||||||
|  | { | ||||||
|  | 	if(lane1.size()<2 || lane2.size()<2) | ||||||
|  | 	{ | ||||||
|  | 		cerr<<"lane size must be greater or equal to 2"<<endl; | ||||||
|  | 		return 0; | ||||||
|  | 	} | ||||||
|  | 	Mat im1 = Mat::zeros(im_height, im_width, CV_8UC1); | ||||||
|  | 	Mat im2 = Mat::zeros(im_height, im_width, CV_8UC1); | ||||||
|  | 	// draw lines on im1 and im2
 | ||||||
|  | 	vector<Point2f> p_interp1; | ||||||
|  | 	vector<Point2f> p_interp2; | ||||||
|  | 	if(lane1.size() == 2) | ||||||
|  | 	{ | ||||||
|  | 		p_interp1 = lane1; | ||||||
|  | 	} | ||||||
|  | 	else | ||||||
|  | 	{ | ||||||
|  | 		p_interp1 = splineSolver.splineInterpTimes(lane1, 50); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if(lane2.size() == 2) | ||||||
|  | 	{ | ||||||
|  | 		p_interp2 = lane2; | ||||||
|  | 	} | ||||||
|  | 	else | ||||||
|  | 	{ | ||||||
|  | 		p_interp2 = splineSolver.splineInterpTimes(lane2, 50); | ||||||
|  | 	} | ||||||
|  | 	 | ||||||
|  | 	Scalar color_white = Scalar(1); | ||||||
|  | 	for(int n=0; n<p_interp1.size()-1; n++) | ||||||
|  | 	{ | ||||||
|  | 		line(im1, p_interp1[n], p_interp1[n+1], color_white, lane_width); | ||||||
|  | 	} | ||||||
|  | 	for(int n=0; n<p_interp2.size()-1; n++) | ||||||
|  | 	{ | ||||||
|  | 		line(im2, p_interp2[n], p_interp2[n+1], color_white, lane_width); | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	double sum_1 = cv::sum(im1).val[0]; | ||||||
|  | 	double sum_2 = cv::sum(im2).val[0]; | ||||||
|  | 	double inter_sum = cv::sum(im1.mul(im2)).val[0]; | ||||||
|  | 	double union_sum = sum_1 + sum_2 - inter_sum;  | ||||||
|  | 	double iou = inter_sum / union_sum; | ||||||
|  | 	return iou; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | // resize the lane from Size(curr_width, curr_height) to Size(im_width, im_height)
 | ||||||
|  | void LaneCompare::resize_lane(vector<Point2f> &curr_lane, int curr_width, int curr_height) | ||||||
|  | { | ||||||
|  | 	if(curr_width == im_width && curr_height == im_height) | ||||||
|  | 	{ | ||||||
|  | 		return; | ||||||
|  | 	} | ||||||
|  | 	double x_scale = im_width/(double)curr_width; | ||||||
|  | 	double y_scale = im_height/(double)curr_height; | ||||||
|  | 	for(int n=0; n<curr_lane.size(); n++) | ||||||
|  | 	{ | ||||||
|  | 		curr_lane[n] = Point2f(curr_lane[n].x*x_scale, curr_lane[n].y*y_scale); | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
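
In other words, `get_lane_similarity` rasterizes both polylines at `lane_width` thickness and scores them by mask IoU (the C++ version first densifies each lane with the spline code that follows). A NumPy/OpenCV restatement of that idea, as an illustration rather than code from the repo:

```python
# Sketch of the IoU-of-rasterized-lanes metric implemented above in C++.
import cv2
import numpy as np

def lane_iou(lane1, lane2, height, width, lane_width=30):
    im1 = np.zeros((height, width), np.uint8)
    im2 = np.zeros((height, width), np.uint8)
    cv2.polylines(im1, [np.asarray(lane1, np.int32)], False, 1, lane_width)
    cv2.polylines(im2, [np.asarray(lane2, np.int32)], False, 1, lane_width)
    inter = np.logical_and(im1, im2).sum()
    union = np.logical_or(im1, im2).sum()
    return inter / union
```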
178  runner/evaluator/culane/lane_evaluation/src/spline.cpp  Normal file
							| @ -0,0 +1,178 @@ | |||||||
|  | #include <vector> | ||||||
|  | #include <iostream> | ||||||
|  | #include "spline.hpp" | ||||||
|  | using namespace std; | ||||||
|  | using namespace cv; | ||||||
|  | 
 | ||||||
|  | vector<Point2f> Spline::splineInterpTimes(const vector<Point2f>& tmp_line, int times) { | ||||||
|  |     vector<Point2f> res; | ||||||
|  | 
 | ||||||
|  |     if(tmp_line.size() == 2) { | ||||||
|  |         double x1 = tmp_line[0].x; | ||||||
|  |         double y1 = tmp_line[0].y; | ||||||
|  |         double x2 = tmp_line[1].x; | ||||||
|  |         double y2 = tmp_line[1].y; | ||||||
|  | 
 | ||||||
|  |         for (int k = 0; k <= times; k++) { | ||||||
|  |             double xi =  x1 + double((x2 - x1) * k) / times; | ||||||
|  |             double yi =  y1 + double((y2 - y1) * k) / times; | ||||||
|  |             res.push_back(Point2f(xi, yi)); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     else if(tmp_line.size() > 2) | ||||||
|  |     { | ||||||
|  |         vector<Func> tmp_func; | ||||||
|  |         tmp_func = this->cal_fun(tmp_line); | ||||||
|  |         if (tmp_func.empty()) { | ||||||
|  |             cout << "in splineInterpTimes: cal_fun failed" << endl; | ||||||
|  |             return res; | ||||||
|  |         } | ||||||
|  |         for(int j = 0; j < tmp_func.size(); j++) | ||||||
|  |         { | ||||||
|  |             double delta = tmp_func[j].h / times; | ||||||
|  |             for(int k = 0; k < times; k++) | ||||||
|  |             { | ||||||
|  |                 double t1 = delta*k; | ||||||
|  |                 double x1 = tmp_func[j].a_x + tmp_func[j].b_x*t1 + tmp_func[j].c_x*pow(t1,2) + tmp_func[j].d_x*pow(t1,3); | ||||||
|  |                 double y1 = tmp_func[j].a_y + tmp_func[j].b_y*t1 + tmp_func[j].c_y*pow(t1,2) + tmp_func[j].d_y*pow(t1,3); | ||||||
|  |                 res.push_back(Point2f(x1, y1)); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         res.push_back(tmp_line[tmp_line.size() - 1]); | ||||||
|  |     } | ||||||
|  | 	else { | ||||||
|  | 		cerr << "in splineInterpTimes: not enough points" << endl; | ||||||
|  | 	} | ||||||
|  |     return res; | ||||||
|  | } | ||||||
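|  | // Note: cal_fun() below parameterizes each segment by its chord length | ||||||
|  | // h[i] = |P[i+1] - P[i]| and fits one cubic per segment in x(t) and y(t), | ||||||
|  | // so the loop above samples each segment at `times` equal steps of t. | ||||||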
|  | vector<Point2f> Spline::splineInterpStep(vector<Point2f> tmp_line, double step) { | ||||||
|  | 	vector<Point2f> res; | ||||||
|  | 	/*
 | ||||||
|  | 	if (tmp_line.size() == 2) { | ||||||
|  | 		double x1 = tmp_line[0].x; | ||||||
|  | 		double y1 = tmp_line[0].y; | ||||||
|  | 		double x2 = tmp_line[1].x; | ||||||
|  | 		double y2 = tmp_line[1].y; | ||||||
|  | 
 | ||||||
|  | 		for (double yi = std::min(y1, y2); yi < std::max(y1, y2); yi += step) { | ||||||
|  |             double xi; | ||||||
|  | 			if (yi == y1) xi = x1; | ||||||
|  | 			else xi = (x2 - x1) / (y2 - y1) * (yi - y1) + x1; | ||||||
|  | 			res.push_back(Point2f(xi, yi)); | ||||||
|  | 		} | ||||||
|  | 	}*/ | ||||||
|  | 	if (tmp_line.size() == 2) { | ||||||
|  | 		double x1 = tmp_line[0].x; | ||||||
|  | 		double y1 = tmp_line[0].y; | ||||||
|  | 		double x2 = tmp_line[1].x; | ||||||
|  | 		double y2 = tmp_line[1].y; | ||||||
|  | 		tmp_line[1].x = (x1 + x2) / 2; | ||||||
|  | 		tmp_line[1].y = (y1 + y2) / 2; | ||||||
|  | 		tmp_line.push_back(Point2f(x2, y2)); | ||||||
|  | 	} | ||||||
|  | 	if (tmp_line.size() > 2) { | ||||||
|  | 		vector<Func> tmp_func; | ||||||
|  | 		tmp_func = this->cal_fun(tmp_line); | ||||||
|  | 		if (tmp_func.empty()) { | ||||||
|  | 			cerr << "in splineInterpStep: cal_fun failed" << endl; | ||||||
|  | 			return res; | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		for(int j = 0; j < tmp_func.size(); j++) | ||||||
|  |         { | ||||||
|  |             for(double t1 = 0; t1 < tmp_func[j].h; t1 += step) | ||||||
|  |             { | ||||||
|  |                 double x1 = tmp_func[j].a_x + tmp_func[j].b_x*t1 + tmp_func[j].c_x*pow(t1,2) + tmp_func[j].d_x*pow(t1,3); | ||||||
|  |                 double y1 = tmp_func[j].a_y + tmp_func[j].b_y*t1 + tmp_func[j].c_y*pow(t1,2) + tmp_func[j].d_y*pow(t1,3); | ||||||
|  |                 res.push_back(Point2f(x1, y1)); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         res.push_back(tmp_line[tmp_line.size() - 1]); | ||||||
|  | 	} | ||||||
|  |     else { | ||||||
|  |         cerr << "in splineInterpStep: not enough points" << endl; | ||||||
|  |     } | ||||||
|  |     return res; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | vector<Func> Spline::cal_fun(const vector<Point2f> &point_v) | ||||||
|  | { | ||||||
|  |     vector<Func> func_v; | ||||||
|  |     int n = point_v.size(); | ||||||
|  |     if(n<=2) { | ||||||
|  |         cout << "in cal_fun: point number less than 3" << endl; | ||||||
|  |         return func_v; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     func_v.resize(point_v.size()-1); | ||||||
|  | 
 | ||||||
|  |     vector<double> Mx(n); | ||||||
|  |     vector<double> My(n); | ||||||
|  |     vector<double> A(n-2); | ||||||
|  |     vector<double> B(n-2); | ||||||
|  |     vector<double> C(n-2); | ||||||
|  |     vector<double> Dx(n-2); | ||||||
|  |     vector<double> Dy(n-2); | ||||||
|  |     vector<double> h(n-1); | ||||||
|  |     //vector<func> func_v(n-1);
 | ||||||
|  | 
 | ||||||
|  |     for(int i = 0; i < n-1; i++) | ||||||
|  |     { | ||||||
|  |         h[i] = sqrt(pow(point_v[i+1].x - point_v[i].x, 2) + pow(point_v[i+1].y - point_v[i].y, 2)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     for(int i = 0; i < n-2; i++) | ||||||
|  |     { | ||||||
|  |         A[i] = h[i]; | ||||||
|  |         B[i] = 2*(h[i]+h[i+1]); | ||||||
|  |         C[i] = h[i+1]; | ||||||
|  | 
 | ||||||
|  |         Dx[i] =  6*( (point_v[i+2].x - point_v[i+1].x)/h[i+1] - (point_v[i+1].x - point_v[i].x)/h[i] ); | ||||||
|  |         Dy[i] =  6*( (point_v[i+2].y - point_v[i+1].y)/h[i+1] - (point_v[i+1].y - point_v[i].y)/h[i] ); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     //TDMA
 | ||||||
|  |     C[0] = C[0] / B[0]; | ||||||
|  |     Dx[0] = Dx[0] / B[0]; | ||||||
|  |     Dy[0] = Dy[0] / B[0]; | ||||||
|  |     for(int i = 1; i < n-2; i++) | ||||||
|  |     { | ||||||
|  |         double tmp = B[i] - A[i]*C[i-1]; | ||||||
|  |         C[i] = C[i] / tmp; | ||||||
|  |         Dx[i] = (Dx[i] - A[i]*Dx[i-1]) / tmp; | ||||||
|  |         Dy[i] = (Dy[i] - A[i]*Dy[i-1]) / tmp; | ||||||
|  |     } | ||||||
|  |     Mx[n-2] = Dx[n-3]; | ||||||
|  |     My[n-2] = Dy[n-3]; | ||||||
|  |     for(int i = n-4; i >= 0; i--) | ||||||
|  |     { | ||||||
|  |         Mx[i+1] = Dx[i] - C[i]*Mx[i+2]; | ||||||
|  |         My[i+1] = Dy[i] - C[i]*My[i+2]; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     Mx[0] = 0; | ||||||
|  |     Mx[n-1] = 0; | ||||||
|  |     My[0] = 0; | ||||||
|  |     My[n-1] = 0; | ||||||
|  | 
 | ||||||
|  |     for(int i = 0; i < n-1; i++) | ||||||
|  |     { | ||||||
|  |         func_v[i].a_x = point_v[i].x; | ||||||
|  |         func_v[i].b_x = (point_v[i+1].x - point_v[i].x)/h[i] - (2*h[i]*Mx[i] + h[i]*Mx[i+1]) / 6; | ||||||
|  |         func_v[i].c_x = Mx[i]/2; | ||||||
|  |         func_v[i].d_x = (Mx[i+1] - Mx[i]) / (6*h[i]); | ||||||
|  | 
 | ||||||
|  |         func_v[i].a_y = point_v[i].y; | ||||||
|  |         func_v[i].b_y = (point_v[i+1].y - point_v[i].y)/h[i] - (2*h[i]*My[i] + h[i]*My[i+1]) / 6; | ||||||
|  |         func_v[i].c_y = My[i]/2; | ||||||
|  |         func_v[i].d_y = (My[i+1] - My[i]) / (6*h[i]); | ||||||
|  | 
 | ||||||
|  |         func_v[i].h = h[i]; | ||||||
|  |     } | ||||||
|  |     return func_v; | ||||||
|  | } | ||||||
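For readers who want to sanity-check the math above, here is a minimal Python sketch of the same construction: chord-length parameterization, a tridiagonal solve for the interior second derivatives with natural boundary conditions, then per-segment cubic evaluation as in `splineInterpTimes`. This is an illustration, not part of the repository.

```python
import numpy as np

def spline_interp_times(pts, times):
    """Natural cubic spline through pts, `times` samples per segment."""
    pts = np.asarray(pts, dtype=float)
    if len(pts) == 2:  # straight segment, as in the C++ two-point branch
        t = np.linspace(0.0, 1.0, times + 1)[:, None]
        return pts[0] + t * (pts[1] - pts[0])
    h = np.linalg.norm(np.diff(pts, axis=0), axis=1)  # chord lengths
    n = len(pts)
    # Tridiagonal system for interior second derivatives M[1..n-2]
    # (M[0] = M[n-1] = 0: natural boundary, matching cal_fun's TDMA).
    A = np.zeros((n - 2, n - 2))
    for i in range(n - 2):
        A[i, i] = 2 * (h[i] + h[i + 1])
        if i > 0:
            A[i, i - 1] = h[i]
        if i < n - 3:
            A[i, i + 1] = h[i + 1]
    D = 6 * ((pts[2:] - pts[1:-1]) / h[1:, None]
             - (pts[1:-1] - pts[:-2]) / h[:-1, None])
    M = np.zeros((n, 2))
    M[1:-1] = np.linalg.solve(A, D)
    out = []
    for i in range(n - 1):  # same per-segment a/b/c/d as cal_fun
        a = pts[i]
        b = (pts[i + 1] - pts[i]) / h[i] - h[i] * (2 * M[i] + M[i + 1]) / 6
        c = M[i] / 2
        d = (M[i + 1] - M[i]) / (6 * h[i])
        for k in range(times):
            t = h[i] * k / times
            out.append(a + b * t + c * t ** 2 + d * t ** 3)
    out.append(pts[-1])
    return np.array(out)
```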
51  runner/evaluator/culane/prob2lines.py  Normal file
							| @ -0,0 +1,51 @@ | |||||||
|  | import os | ||||||
|  | import numpy as np | ||||||
|  | import pandas as pd | ||||||
|  | from PIL import Image | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def getLane(probmap, pts, cfg = None): | ||||||
|  |     thr = 0.3 | ||||||
|  |     coordinate = np.zeros(pts) | ||||||
|  |     cut_height = 0 | ||||||
|  |     if cfg.cut_height: | ||||||
|  |         cut_height = cfg.cut_height | ||||||
|  |     for i in range(pts): | ||||||
|  |         line = probmap[round(cfg.img_height-i*20/(590-cut_height)*cfg.img_height)-1] | ||||||
|  |         if np.max(line)/255 > thr: | ||||||
|  |             coordinate[i] = np.argmax(line)+1 | ||||||
|  |     if np.sum(coordinate > 0) < 2: | ||||||
|  |         coordinate = np.zeros(pts) | ||||||
|  |     return coordinate | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def prob2lines(prob_dir, out_dir, list_file, cfg = None): | ||||||
|  |     lists = pd.read_csv(list_file, sep=' ', header=None, | ||||||
|  |                         names=('img', 'probmap', 'label1', 'label2', 'label3', 'label4')) | ||||||
|  |     pts = 18 | ||||||
|  | 
 | ||||||
|  |     for k, im in enumerate(lists['img'], 1): | ||||||
|  |         existPath = prob_dir + im[:-4] + '.exist.txt' | ||||||
|  |         outname = out_dir + im[:-4] + '.lines.txt' | ||||||
|  |         prefix = '/'.join(outname.split('/')[:-1]) | ||||||
|  |         if not os.path.exists(prefix): | ||||||
|  |             os.makedirs(prefix) | ||||||
|  |         f = open(outname, 'w') | ||||||
|  | 
 | ||||||
|  |         labels = list(pd.read_csv(existPath, sep=' ', header=None).iloc[0]) | ||||||
|  |         coordinates = np.zeros((4, pts)) | ||||||
|  |         for i in range(4): | ||||||
|  |             if labels[i] == 1: | ||||||
|  |                 probfile = prob_dir + im[:-4] + '_{0}_avg.png'.format(i+1) | ||||||
|  |                 probmap = np.array(Image.open(probfile)) | ||||||
|  |                 coordinates[i] = getLane(probmap, pts, cfg) | ||||||
|  | 
 | ||||||
|  |                 if np.sum(coordinates[i] > 0) > 1: | ||||||
|  |                     for idx, value in enumerate(coordinates[i]): | ||||||
|  |                         if value > 0: | ||||||
|  |                             f.write('%d %d ' % ( | ||||||
|  |                                 round(value*1640/cfg.img_width)-1, round(590-idx*20)-1)) | ||||||
|  |                     f.write('\n') | ||||||
|  |         f.close() | ||||||
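`getLane` samples 18 fixed rows of a per-lane probability map and records the 1-based column of the peak wherever the peak exceeds the 0.3 threshold. A self-contained toy run; the `cfg` fields `img_height`/`img_width`/`cut_height` are assumed here via a stand-in namespace with illustrative values:

```python
import numpy as np
from types import SimpleNamespace

cfg = SimpleNamespace(img_height=288, img_width=800, cut_height=0)
probmap = np.zeros((cfg.img_height, cfg.img_width), dtype=np.uint8)
probmap[:, 400] = 255  # a perfectly vertical "lane" at column 400

coords = getLane(probmap, pts=18, cfg=cfg)
print(coords)  # 401.0 at every sampled row (argmax + 1)
```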
115  runner/evaluator/tusimple/getLane.py  Normal file
							| @ -0,0 +1,115 @@ | |||||||
|  | import cv2 | ||||||
|  | import numpy as np | ||||||
|  | 
 | ||||||
|  | def isShort(lane): | ||||||
|  |     start = [i for i, x in enumerate(lane) if x > 0] | ||||||
|  |     if not start: | ||||||
|  |         return 1 | ||||||
|  |     else: | ||||||
|  |         return 0 | ||||||
|  | 
 | ||||||
|  | def fixGap(coordinate): | ||||||
|  |     if any(x > 0 for x in coordinate): | ||||||
|  |         start = [i for i, x in enumerate(coordinate) if x > 0][0] | ||||||
|  |         end = [i for i, x in reversed(list(enumerate(coordinate))) if x > 0][0] | ||||||
|  |         lane = coordinate[start:end+1] | ||||||
|  |         if any(x < 0 for x in lane): | ||||||
|  |             gap_start = [i for i, x in enumerate( | ||||||
|  |                 lane[:-1]) if x > 0 and lane[i+1] < 0] | ||||||
|  |             gap_end = [i+1 for i, | ||||||
|  |                        x in enumerate(lane[:-1]) if x < 0 and lane[i+1] > 0] | ||||||
|  |             gap_id = [i for i, x in enumerate(lane) if x < 0] | ||||||
|  |             if len(gap_start) == 0 or len(gap_end) == 0: | ||||||
|  |                 return coordinate | ||||||
|  |             for id in gap_id: | ||||||
|  |                 for i in range(len(gap_start)): | ||||||
|  |                     if i >= len(gap_end): | ||||||
|  |                         return coordinate | ||||||
|  |                     if id > gap_start[i] and id < gap_end[i]: | ||||||
|  |                         gap_width = float(gap_end[i] - gap_start[i]) | ||||||
|  |                         lane[id] = int((id - gap_start[i]) / gap_width * lane[gap_end[i]] + ( | ||||||
|  |                             gap_end[i] - id) / gap_width * lane[gap_start[i]]) | ||||||
|  |             if not all(x > 0 for x in lane): | ||||||
|  |                 print("Gaps still exist!") | ||||||
|  |             coordinate[start:end+1] = lane | ||||||
|  |     return coordinate | ||||||
|  | 
 | ||||||
|  | def getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape=None, cfg=None): | ||||||
|  |     """ | ||||||
|  |     Arguments: | ||||||
|  |     ---------- | ||||||
|  |     prob_map: prob map for single lane, np array size (h, w) | ||||||
|  |     resize_shape:  reshape size target, (H, W) | ||||||
|  | 
 | ||||||
|  |     Return: | ||||||
|  |     ---------- | ||||||
|  |     coords: x coords sampled bottom-up every y_px_gap px, 0 where the lane | ||||||
|  |         does not exist, given in the resized (H, W) frame | ||||||
|  |     """ | ||||||
|  |     if resize_shape is None: | ||||||
|  |         resize_shape = prob_map.shape | ||||||
|  |     h, w = prob_map.shape | ||||||
|  |     H, W = resize_shape | ||||||
|  |     H -= cfg.cut_height | ||||||
|  | 
 | ||||||
|  |     coords = np.zeros(pts) | ||||||
|  |     coords[:] = -1.0 | ||||||
|  |     for i in range(pts): | ||||||
|  |         y = int((H - 10 - i * y_px_gap) * h / H) | ||||||
|  |         if y < 0: | ||||||
|  |             break | ||||||
|  |         line = prob_map[y, :] | ||||||
|  |         id = np.argmax(line) | ||||||
|  |         if line[id] > thresh: | ||||||
|  |             coords[i] = int(id / w * W) | ||||||
|  |     if (coords > 0).sum() < 2: | ||||||
|  |         coords = np.zeros(pts) | ||||||
|  |     fixGap(coords) | ||||||
|  |     return coords | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def prob2lines_tusimple(seg_pred, exist, resize_shape=None, smooth=True, y_px_gap=10, pts=None, thresh=0.3, cfg=None): | ||||||
|  |     """ | ||||||
|  |     Arguments: | ||||||
|  |     ---------- | ||||||
|  |     seg_pred:      np.array size (5, h, w) | ||||||
|  |     resize_shape:  reshape size target, (H, W) | ||||||
|  |     exist:       list of existence, e.g. [0, 1, 1, 0] | ||||||
|  |     smooth:      whether to smooth the probability or not | ||||||
|  |     y_px_gap:    y pixel gap for sampling | ||||||
|  |     pts:     how many points for one lane | ||||||
|  |     thresh:  probability threshold | ||||||
|  | 
 | ||||||
|  |     Return: | ||||||
|  |     ---------- | ||||||
|  |     coordinates: [x, y] list of lanes, e.g.: [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ] | ||||||
|  |     """ | ||||||
|  |     if resize_shape is None: | ||||||
|  |         resize_shape = seg_pred.shape[1:]  # seg_pred (5, h, w) | ||||||
|  |     _, h, w = seg_pred.shape | ||||||
|  |     H, W = resize_shape | ||||||
|  |     coordinates = [] | ||||||
|  | 
 | ||||||
|  |     if pts is None: | ||||||
|  |         pts = round(H / 2 / y_px_gap) | ||||||
|  | 
 | ||||||
|  |     seg_pred = np.ascontiguousarray(np.transpose(seg_pred, (1, 2, 0))) | ||||||
|  |     for i in range(cfg.num_classes - 1): | ||||||
|  |         prob_map = seg_pred[..., i + 1] | ||||||
|  |         if smooth: | ||||||
|  |             prob_map = cv2.blur(prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE) | ||||||
|  |         coords = getLane_tusimple(prob_map, y_px_gap, pts, thresh, resize_shape, cfg) | ||||||
|  |         if isShort(coords): | ||||||
|  |             continue | ||||||
|  |         coordinates.append( | ||||||
|  |             [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in | ||||||
|  |              range(pts)]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     if len(coordinates) == 0: | ||||||
|  |         coords = np.zeros(pts) | ||||||
|  |         coordinates.append( | ||||||
|  |             [[coords[j], H - 10 - j * y_px_gap] if coords[j] > 0 else [-1, H - 10 - j * y_px_gap] for j in | ||||||
|  |              range(pts)]) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     return coordinates | ||||||
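`fixGap` fills interior runs of negative (missing) x-coordinates by linear interpolation between the detected endpoints, while leading and trailing gaps are left alone:

```python
import numpy as np

lane = np.array([-1, 10, -1, -1, 40, 50, -1], dtype=float)
print(fixGap(lane))  # -> [-1. 10. 20. 30. 40. 50. -1.]
```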
108  runner/evaluator/tusimple/lane.py  Normal file
							| @ -0,0 +1,108 @@ | |||||||
|  | import numpy as np | ||||||
|  | from sklearn.linear_model import LinearRegression | ||||||
|  | import json as json | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class LaneEval(object): | ||||||
|  |     lr = LinearRegression() | ||||||
|  |     pixel_thresh = 20 | ||||||
|  |     pt_thresh = 0.85 | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def get_angle(xs, y_samples): | ||||||
|  |         xs, ys = xs[xs >= 0], y_samples[xs >= 0] | ||||||
|  |         if len(xs) > 1: | ||||||
|  |             LaneEval.lr.fit(ys[:, None], xs) | ||||||
|  |             k = LaneEval.lr.coef_[0] | ||||||
|  |             theta = np.arctan(k) | ||||||
|  |         else: | ||||||
|  |             theta = 0 | ||||||
|  |         return theta | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def line_accuracy(pred, gt, thresh): | ||||||
|  |         pred = np.array([p if p >= 0 else -100 for p in pred]) | ||||||
|  |         gt = np.array([g if g >= 0 else -100 for g in gt]) | ||||||
|  |         return np.sum(np.where(np.abs(pred - gt) < thresh, 1., 0.)) / len(gt) | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def bench(pred, gt, y_samples, running_time): | ||||||
|  |         if any(len(p) != len(y_samples) for p in pred): | ||||||
|  |             raise Exception('Format of lanes error.') | ||||||
|  |         if running_time > 200 or len(gt) + 2 < len(pred): | ||||||
|  |             return 0., 0., 1. | ||||||
|  |         angles = [LaneEval.get_angle( | ||||||
|  |             np.array(x_gts), np.array(y_samples)) for x_gts in gt] | ||||||
|  |         threshs = [LaneEval.pixel_thresh / np.cos(angle) for angle in angles] | ||||||
|  |         line_accs = [] | ||||||
|  |         fp, fn = 0., 0. | ||||||
|  |         matched = 0. | ||||||
|  |         for x_gts, thresh in zip(gt, threshs): | ||||||
|  |             accs = [LaneEval.line_accuracy( | ||||||
|  |                 np.array(x_preds), np.array(x_gts), thresh) for x_preds in pred] | ||||||
|  |             max_acc = np.max(accs) if len(accs) > 0 else 0. | ||||||
|  |             if max_acc < LaneEval.pt_thresh: | ||||||
|  |                 fn += 1 | ||||||
|  |             else: | ||||||
|  |                 matched += 1 | ||||||
|  |             line_accs.append(max_acc) | ||||||
|  |         fp = len(pred) - matched | ||||||
|  |         if len(gt) > 4 and fn > 0: | ||||||
|  |             fn -= 1 | ||||||
|  |         s = sum(line_accs) | ||||||
|  |         if len(gt) > 4: | ||||||
|  |             s -= min(line_accs) | ||||||
|  |         return s / max(min(4.0, len(gt)), 1.), fp / len(pred) if len(pred) > 0 else 0., fn / max(min(len(gt), 4.), 1.) | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def bench_one_submit(pred_file, gt_file): | ||||||
|  |         try: | ||||||
|  |             json_pred = [json.loads(line) | ||||||
|  |                          for line in open(pred_file).readlines()] | ||||||
|  |         except Exception: | ||||||
|  |             raise Exception('Failed to load the json file of the prediction.') | ||||||
|  |         json_gt = [json.loads(line) for line in open(gt_file).readlines()] | ||||||
|  |         if len(json_gt) != len(json_pred): | ||||||
|  |             raise Exception( | ||||||
|  |                 'We do not get the predictions of all the test tasks') | ||||||
|  |         gts = {l['raw_file']: l for l in json_gt} | ||||||
|  |         accuracy, fp, fn = 0., 0., 0. | ||||||
|  |         for pred in json_pred: | ||||||
|  |             if 'raw_file' not in pred or 'lanes' not in pred or 'run_time' not in pred: | ||||||
|  |                 raise Exception( | ||||||
|  |                     'raw_file or lanes or run_time not in some predictions.') | ||||||
|  |             raw_file = pred['raw_file'] | ||||||
|  |             pred_lanes = pred['lanes'] | ||||||
|  |             run_time = pred['run_time'] | ||||||
|  |             if raw_file not in gts: | ||||||
|  |                 raise Exception( | ||||||
|  |                     'Some raw_file from your predictions do not exist in the test tasks.') | ||||||
|  |             gt = gts[raw_file] | ||||||
|  |             gt_lanes = gt['lanes'] | ||||||
|  |             y_samples = gt['h_samples'] | ||||||
|  |             try: | ||||||
|  |                 a, p, n = LaneEval.bench( | ||||||
|  |                     pred_lanes, gt_lanes, y_samples, run_time) | ||||||
|  |             except Exception: | ||||||
|  |                 raise Exception('Format of lanes error.') | ||||||
|  |             accuracy += a | ||||||
|  |             fp += p | ||||||
|  |             fn += n | ||||||
|  |         num = len(gts) | ||||||
|  |         # the first return parameter is the default ranking parameter | ||||||
|  |         return json.dumps([ | ||||||
|  |             {'name': 'Accuracy', 'value': accuracy / num, 'order': 'desc'}, | ||||||
|  |             {'name': 'FP', 'value': fp / num, 'order': 'asc'}, | ||||||
|  |             {'name': 'FN', 'value': fn / num, 'order': 'asc'} | ||||||
|  |         ]), accuracy / num | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     import sys | ||||||
|  |     try: | ||||||
|  |         if len(sys.argv) != 3: | ||||||
|  |             raise Exception('Invalid input arguments') | ||||||
|  |         print(LaneEval.bench_one_submit(sys.argv[1], sys.argv[2])) | ||||||
|  |     except Exception as e: | ||||||
|  |         print(e) | ||||||
|  |         sys.exit(str(e)) | ||||||
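A toy invocation of the metric: one ground-truth lane and a prediction that is off by 2 px at every sampled row, well inside the 20 px `pixel_thresh`:

```python
y_samples = list(range(160, 720, 10))
gt_lane = [700] * len(y_samples)
pred_lane = [702] * len(y_samples)

acc, fp, fn = LaneEval.bench([pred_lane], [gt_lane], y_samples, running_time=100)
print(acc, fp, fn)  # -> 1.0 0.0 0.0
```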
111  runner/evaluator/tusimple/tusimple.py  Normal file
							| @ -0,0 +1,111 @@ | |||||||
|  | import torch.nn as nn | ||||||
|  | import torch | ||||||
|  | import torch.nn.functional as F | ||||||
|  | from runner.logger import get_logger | ||||||
|  | 
 | ||||||
|  | from runner.registry import EVALUATOR  | ||||||
|  | import json | ||||||
|  | import os | ||||||
|  | import cv2 | ||||||
|  | 
 | ||||||
|  | from .lane import LaneEval | ||||||
|  | 
 | ||||||
|  | def split_path(path): | ||||||
|  |     """split path tree into list""" | ||||||
|  |     folders = [] | ||||||
|  |     while True: | ||||||
|  |         path, folder = os.path.split(path) | ||||||
|  |         if folder != "": | ||||||
|  |             folders.insert(0, folder) | ||||||
|  |         else: | ||||||
|  |             if path != "": | ||||||
|  |                 folders.insert(0, path) | ||||||
|  |             break | ||||||
|  |     return folders | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @EVALUATOR.register_module | ||||||
|  | class Tusimple(nn.Module): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         super(Tusimple, self).__init__() | ||||||
|  |         self.cfg = cfg  | ||||||
|  |         exp_dir = os.path.join(self.cfg.work_dir, "output") | ||||||
|  |         if not os.path.exists(exp_dir): | ||||||
|  |             os.mkdir(exp_dir) | ||||||
|  |         self.out_path = os.path.join(exp_dir, "coord_output") | ||||||
|  |         if not os.path.exists(self.out_path): | ||||||
|  |             os.mkdir(self.out_path) | ||||||
|  |         self.dump_to_json = []  | ||||||
|  |         self.thresh = cfg.evaluator.thresh | ||||||
|  |         self.logger = get_logger('resa') | ||||||
|  |         if cfg.view: | ||||||
|  |             self.view_dir = os.path.join(self.cfg.work_dir, 'vis') | ||||||
|  | 
 | ||||||
|  |     def evaluate_pred(self, dataset, seg_pred, exist_pred, batch): | ||||||
|  |         img_name = batch['meta']['img_name'] | ||||||
|  |         img_path = batch['meta']['full_img_path'] | ||||||
|  |         for b in range(len(seg_pred)): | ||||||
|  |             seg = seg_pred[b] | ||||||
|  |             exist = [1 if exist_pred[b, i] > | ||||||
|  |                      0.5 else 0 for i in range(self.cfg.num_classes-1)] | ||||||
|  |             lane_coords = dataset.probmap2lane(seg, exist, thresh = self.thresh) | ||||||
|  |             for i in range(len(lane_coords)): | ||||||
|  |                 lane_coords[i] = sorted( | ||||||
|  |                     lane_coords[i], key=lambda pair: pair[1]) | ||||||
|  | 
 | ||||||
|  |             path_tree = split_path(img_name[b]) | ||||||
|  |             save_dir, save_name = path_tree[-3:-1], path_tree[-1] | ||||||
|  |             save_dir = os.path.join(self.out_path, *save_dir) | ||||||
|  |             save_name = save_name[:-3] + "lines.txt" | ||||||
|  |             save_name = os.path.join(save_dir, save_name) | ||||||
|  |             if not os.path.exists(save_dir): | ||||||
|  |                 os.makedirs(save_dir, exist_ok=True) | ||||||
|  | 
 | ||||||
|  |             with open(save_name, "w") as f: | ||||||
|  |                 for l in lane_coords: | ||||||
|  |                     for (x, y) in l: | ||||||
|  |                         print("{} {}".format(x, y), end=" ", file=f) | ||||||
|  |                     print(file=f) | ||||||
|  | 
 | ||||||
|  |             json_dict = {} | ||||||
|  |             json_dict['lanes'] = [] | ||||||
|  |             json_dict['h_samples'] = [] | ||||||
|  |             json_dict['raw_file'] = os.path.join(*path_tree[-4:]) | ||||||
|  |             json_dict['run_time'] = 0 | ||||||
|  |             for l in lane_coords: | ||||||
|  |                 if len(l) == 0: | ||||||
|  |                     continue | ||||||
|  |                 json_dict['lanes'].append([]) | ||||||
|  |                 for (x, y) in l: | ||||||
|  |                     json_dict['lanes'][-1].append(int(x)) | ||||||
|  |             for (x, y) in lane_coords[0]: | ||||||
|  |                 json_dict['h_samples'].append(y) | ||||||
|  |             self.dump_to_json.append(json.dumps(json_dict)) | ||||||
|  |             if self.cfg.view: | ||||||
|  |                 img = cv2.imread(img_path[b]) | ||||||
|  |                 new_img_name = img_name[b].replace('/', '_') | ||||||
|  |                 save_dir = os.path.join(self.view_dir, new_img_name) | ||||||
|  |                 dataset.view(img, lane_coords, save_dir) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     def evaluate(self, dataset, output, batch): | ||||||
|  |         seg_pred, exist_pred = output['seg'], output['exist'] | ||||||
|  |         seg_pred = F.softmax(seg_pred, dim=1) | ||||||
|  |         seg_pred = seg_pred.detach().cpu().numpy() | ||||||
|  |         exist_pred = exist_pred.detach().cpu().numpy() | ||||||
|  |         self.evaluate_pred(dataset, seg_pred, exist_pred, batch) | ||||||
|  | 
 | ||||||
|  |     def summarize(self): | ||||||
|  |         best_acc = 0 | ||||||
|  |         output_file = os.path.join(self.out_path, 'predict_test.json') | ||||||
|  |         with open(output_file, "w+") as f: | ||||||
|  |             for line in self.dump_to_json: | ||||||
|  |                 print(line, end="\n", file=f) | ||||||
|  | 
 | ||||||
|  |         eval_result, acc = LaneEval.bench_one_submit(output_file, | ||||||
|  |                             self.cfg.test_json_file) | ||||||
|  | 
 | ||||||
|  |         self.logger.info(eval_result) | ||||||
|  |         self.dump_to_json = [] | ||||||
|  |         best_acc = max(acc, best_acc) | ||||||
|  |         return best_acc | ||||||
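`split_path` simply unrolls `os.path.split` into an ordered list of components, which `evaluate_pred` then uses to rebuild the last few directory levels under the output folder (POSIX-style path; the clip name below is illustrative):

```python
print(split_path('clips/0530/1492626047222176976_0/20.jpg'))
# -> ['clips', '0530', '1492626047222176976_0', '20.jpg']
```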
50  runner/logger.py  Normal file
							| @ -0,0 +1,50 @@ | |||||||
|  | import logging | ||||||
|  | 
 | ||||||
|  | logger_initialized = {} | ||||||
|  | 
 | ||||||
|  | def get_logger(name, log_file=None, log_level=logging.INFO): | ||||||
|  |     """Initialize and get a logger by name. | ||||||
|  |     If the logger has not been initialized, this method will initialize the | ||||||
|  |     logger by adding one or two handlers, otherwise the initialized logger will | ||||||
|  |     be directly returned. During initialization, a StreamHandler will always be | ||||||
|  |     added. If `log_file` is specified, a FileHandler will also be added. | ||||||
|  |     Args: | ||||||
|  |         name (str): Logger name. | ||||||
|  |         log_file (str | None): The log filename. If specified, a FileHandler | ||||||
|  |             will be added to the logger. | ||||||
|  |         log_level (int): The logger level, applied to the logger and all | ||||||
|  |             of its handlers. | ||||||
|  |     Returns: | ||||||
|  |         logging.Logger: The expected logger. | ||||||
|  |     """ | ||||||
|  |     logger = logging.getLogger(name) | ||||||
|  |     if name in logger_initialized: | ||||||
|  |         return logger | ||||||
|  |     # handle hierarchical names | ||||||
|  |     # e.g., logger "a" is initialized, then logger "a.b" will skip the | ||||||
|  |     # initialization since it is a child of "a". | ||||||
|  |     for logger_name in logger_initialized: | ||||||
|  |         if name.startswith(logger_name): | ||||||
|  |             return logger | ||||||
|  | 
 | ||||||
|  |     stream_handler = logging.StreamHandler() | ||||||
|  |     handlers = [stream_handler] | ||||||
|  | 
 | ||||||
|  |     if log_file is not None: | ||||||
|  |         file_handler = logging.FileHandler(log_file, 'w') | ||||||
|  |         handlers.append(file_handler) | ||||||
|  | 
 | ||||||
|  |     formatter = logging.Formatter( | ||||||
|  |         '%(asctime)s - %(name)s - %(levelname)s - %(message)s') | ||||||
|  |     for handler in handlers: | ||||||
|  |         handler.setFormatter(formatter) | ||||||
|  |         handler.setLevel(log_level) | ||||||
|  |         logger.addHandler(handler) | ||||||
|  | 
 | ||||||
|  |     logger.setLevel(log_level) | ||||||
|  | 
 | ||||||
|  |     logger_initialized[name] = True | ||||||
|  | 
 | ||||||
|  |     return logger | ||||||
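Typical usage: configure the named logger once (optionally with a file), then fetch the same instance anywhere by name (the log path below is illustrative):

```python
logger = get_logger('resa', log_file='/tmp/resa.log')
logger.info('training started')  # goes to the console and /tmp/resa.log
same = get_logger('resa')        # returns the already-initialized logger
```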
43  runner/net_utils.py  Normal file
							| @ -0,0 +1,43 @@ | |||||||
|  | import torch | ||||||
|  | import os | ||||||
|  | from torch import nn | ||||||
|  | import numpy as np | ||||||
|  | import torch.nn.functional | ||||||
|  | from termcolor import colored | ||||||
|  | from .logger import get_logger | ||||||
|  | 
 | ||||||
|  | def save_model(net, optim, scheduler, recorder, is_best=False): | ||||||
|  |     model_dir = os.path.join(recorder.work_dir, 'ckpt') | ||||||
|  |     os.system('mkdir -p {}'.format(model_dir)) | ||||||
|  |     epoch = recorder.epoch | ||||||
|  |     ckpt_name = 'best' if is_best else epoch | ||||||
|  |     torch.save({ | ||||||
|  |         'net': net.state_dict(), | ||||||
|  |         'optim': optim.state_dict(), | ||||||
|  |         'scheduler': scheduler.state_dict(), | ||||||
|  |         'recorder': recorder.state_dict(), | ||||||
|  |         'epoch': epoch | ||||||
|  |     }, os.path.join(model_dir, '{}.pth'.format(ckpt_name))) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def load_network_specified(net, model_dir, logger=None): | ||||||
|  |     pretrained_net = torch.load(model_dir)['net'] | ||||||
|  |     net_state = net.state_dict() | ||||||
|  |     state = {} | ||||||
|  |     for k, v in pretrained_net.items(): | ||||||
|  |         if k not in net_state.keys() or v.size() != net_state[k].size(): | ||||||
|  |             if logger: | ||||||
|  |                 logger.info('skip weights: ' + k) | ||||||
|  |             continue | ||||||
|  |         state[k] = v | ||||||
|  |     net.load_state_dict(state, strict=False) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def load_network(net, model_dir, finetune_from=None, logger=None): | ||||||
|  |     if finetune_from: | ||||||
|  |         if logger: | ||||||
|  |             logger.info('Finetune model from: ' + finetune_from) | ||||||
|  |         load_network_specified(net, finetune_from, logger) | ||||||
|  |         return | ||||||
|  |     pretrained_model = torch.load(model_dir) | ||||||
|  |     net.load_state_dict(pretrained_model['net'], strict=True) | ||||||
26  runner/optimizer.py  Normal file
							| @ -0,0 +1,26 @@ | |||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | _optimizer_factory = { | ||||||
|  |     'adam': torch.optim.Adam, | ||||||
|  |     'sgd': torch.optim.SGD | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def build_optimizer(cfg, net): | ||||||
|  |     params = [] | ||||||
|  |     lr = cfg.optimizer.lr | ||||||
|  |     weight_decay = cfg.optimizer.weight_decay | ||||||
|  | 
 | ||||||
|  |     for key, value in net.named_parameters(): | ||||||
|  |         if not value.requires_grad: | ||||||
|  |             continue | ||||||
|  |         params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] | ||||||
|  | 
 | ||||||
|  |     if 'adam' in cfg.optimizer.type: | ||||||
|  |         optimizer = _optimizer_factory[cfg.optimizer.type](params, lr, weight_decay=weight_decay) | ||||||
|  |     else: | ||||||
|  |         optimizer = _optimizer_factory[cfg.optimizer.type]( | ||||||
|  |                 params, lr, weight_decay=weight_decay, momentum=cfg.optimizer.momentum) | ||||||
|  | 
 | ||||||
|  |     return optimizer | ||||||
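A sketch of the config shape `build_optimizer` expects, using a stand-in namespace for the attribute-style config (all values are illustrative):

```python
from types import SimpleNamespace
import torch.nn as nn

cfg = SimpleNamespace(optimizer=SimpleNamespace(
    type='sgd', lr=0.025, weight_decay=1e-4, momentum=0.9))
net = nn.Linear(8, 2)
optim = build_optimizer(cfg, net)
print(type(optim).__name__)  # -> SGD
```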
100  runner/recorder.py  Normal file
							| @ -0,0 +1,100 @@ | |||||||
|  | from collections import deque, defaultdict | ||||||
|  | import torch | ||||||
|  | import os | ||||||
|  | import datetime | ||||||
|  | from .logger import get_logger | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class SmoothedValue(object): | ||||||
|  |     """Track a series of values and provide access to smoothed values over a | ||||||
|  |     window or the global series average. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, window_size=20): | ||||||
|  |         self.deque = deque(maxlen=window_size) | ||||||
|  |         self.total = 0.0 | ||||||
|  |         self.count = 0 | ||||||
|  | 
 | ||||||
|  |     def update(self, value): | ||||||
|  |         self.deque.append(value) | ||||||
|  |         self.count += 1 | ||||||
|  |         self.total += value | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def median(self): | ||||||
|  |         d = torch.tensor(list(self.deque)) | ||||||
|  |         return d.median().item() | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def avg(self): | ||||||
|  |         d = torch.tensor(list(self.deque)) | ||||||
|  |         return d.mean().item() | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def global_avg(self): | ||||||
|  |         return self.total / self.count | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Recorder(object): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         self.cfg = cfg | ||||||
|  |         self.work_dir = self.get_work_dir() | ||||||
|  |         cfg.work_dir = self.work_dir | ||||||
|  |         self.log_path = os.path.join(self.work_dir, 'log.txt') | ||||||
|  | 
 | ||||||
|  |         self.logger = get_logger('resa', self.log_path) | ||||||
|  |         self.logger.info('Config: \n' + cfg.text) | ||||||
|  | 
 | ||||||
|  |         # scalars | ||||||
|  |         self.epoch = 0 | ||||||
|  |         self.step = 0 | ||||||
|  |         self.loss_stats = defaultdict(SmoothedValue) | ||||||
|  |         self.batch_time = SmoothedValue() | ||||||
|  |         self.data_time = SmoothedValue() | ||||||
|  |         self.max_iter = self.cfg.total_iter  | ||||||
|  |         self.lr = 0. | ||||||
|  | 
 | ||||||
|  |     def get_work_dir(self): | ||||||
|  |         now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') | ||||||
|  |         hyper_param_str = '_lr_%1.0e_b_%d' % (self.cfg.optimizer.lr, self.cfg.batch_size) | ||||||
|  |         work_dir = os.path.join(self.cfg.work_dirs, now + hyper_param_str) | ||||||
|  |         if not os.path.exists(work_dir): | ||||||
|  |             os.makedirs(work_dir) | ||||||
|  |         return work_dir | ||||||
|  | 
 | ||||||
|  |     def update_loss_stats(self, loss_dict): | ||||||
|  |         for k, v in loss_dict.items(): | ||||||
|  |             self.loss_stats[k].update(v.detach().cpu()) | ||||||
|  | 
 | ||||||
|  |     def record(self, prefix, step=-1, loss_stats=None, image_stats=None): | ||||||
|  |         self.logger.info(self) | ||||||
|  |         # self.write(str(self)) | ||||||
|  | 
 | ||||||
|  |     def write(self, content): | ||||||
|  |         with open(self.log_path, 'a+') as f: | ||||||
|  |             f.write(content) | ||||||
|  |             f.write('\n') | ||||||
|  | 
 | ||||||
|  |     def state_dict(self): | ||||||
|  |         scalar_dict = {} | ||||||
|  |         scalar_dict['step'] = self.step | ||||||
|  |         return scalar_dict | ||||||
|  | 
 | ||||||
|  |     def load_state_dict(self, scalar_dict): | ||||||
|  |         self.step = scalar_dict['step'] | ||||||
|  | 
 | ||||||
|  |     def __str__(self): | ||||||
|  |         loss_state = [] | ||||||
|  |         for k, v in self.loss_stats.items(): | ||||||
|  |             loss_state.append('{}: {:.4f}'.format(k, v.avg)) | ||||||
|  |         loss_state = '  '.join(loss_state) | ||||||
|  | 
 | ||||||
|  |         recording_state = '  '.join(['epoch: {}', 'step: {}', 'lr: {:.4f}', '{}', 'data: {:.4f}', 'batch: {:.4f}', 'eta: {}']) | ||||||
|  |         eta_seconds = self.batch_time.global_avg * (self.max_iter - self.step) | ||||||
|  |         eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) | ||||||
|  |         return recording_state.format(self.epoch, self.step, self.lr, loss_state, self.data_time.avg, self.batch_time.avg, eta_string) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def build_recorder(cfg): | ||||||
|  |     return Recorder(cfg) | ||||||
|  | 
 | ||||||
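`SmoothedValue` keeps a sliding window for `median`/`avg` and running totals for `global_avg`:

```python
v = SmoothedValue(window_size=3)
for x in [1.0, 2.0, 3.0, 4.0]:
    v.update(x)
print(v.avg)         # -> 3.0 (mean of the last three: 2, 3, 4)
print(v.global_avg)  # -> 2.5 (10 / 4 over the whole series)
```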
19  runner/registry.py  Normal file
							| @ -0,0 +1,19 @@ | |||||||
|  | from torch import nn | ||||||
|  | from utils import Registry, build_from_cfg | ||||||
|  | 
 | ||||||
|  | TRAINER = Registry('trainer') | ||||||
|  | EVALUATOR = Registry('evaluator') | ||||||
|  | 
 | ||||||
|  | def build(cfg, registry, default_args=None): | ||||||
|  |     if isinstance(cfg, list): | ||||||
|  |         modules = [ | ||||||
|  |             build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg | ||||||
|  |         ] | ||||||
|  |         return nn.Sequential(*modules) | ||||||
|  |     else: | ||||||
|  |         return build_from_cfg(cfg, registry, default_args) | ||||||
|  | 
 | ||||||
|  | def build_trainer(cfg): | ||||||
|  |     return build(cfg.trainer, TRAINER, default_args=dict(cfg=cfg)) | ||||||
|  | 
 | ||||||
|  | def build_evaluator(cfg): | ||||||
|  |     return build(cfg.evaluator, EVALUATOR, default_args=dict(cfg=cfg)) | ||||||
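The intended pattern: a decorator registers a class by name, and `build_from_cfg` later instantiates it from `cfg.trainer` / `cfg.evaluator` (the `Registry`/`build_from_cfg` implementations live in `utils/registry.py`, not shown in this section; the class below is hypothetical):

```python
@TRAINER.register_module
class MyTrainer:  # hypothetical trainer, registered under its class name
    def __init__(self, cfg):
        self.cfg = cfg

# a config then selects it by name:
# cfg.trainer = dict(type='MyTrainer'); build_trainer(cfg)
```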
58  runner/resa_trainer.py  Normal file
							| @ -0,0 +1,58 @@ | |||||||
|  | import torch.nn as nn | ||||||
|  | import torch | ||||||
|  | import torch.nn.functional as F | ||||||
|  | 
 | ||||||
|  | from runner.registry import TRAINER | ||||||
|  | 
 | ||||||
|  | def dice_loss(input, target): | ||||||
|  |     input = input.contiguous().view(input.size()[0], -1) | ||||||
|  |     target = target.contiguous().view(target.size()[0], -1).float() | ||||||
|  | 
 | ||||||
|  |     a = torch.sum(input * target, 1) | ||||||
|  |     b = torch.sum(input * input, 1) + 0.001 | ||||||
|  |     c = torch.sum(target * target, 1) + 0.001 | ||||||
|  |     d = (2 * a) / (b + c) | ||||||
|  |     return (1-d).mean() | ||||||
|  | 
 | ||||||
|  | @TRAINER.register_module | ||||||
|  | class RESA(nn.Module): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         super(RESA, self).__init__() | ||||||
|  |         self.cfg = cfg | ||||||
|  |         self.loss_type = cfg.loss_type | ||||||
|  |         if self.loss_type == 'cross_entropy': | ||||||
|  |             weights = torch.ones(cfg.num_classes) | ||||||
|  |             weights[0] = cfg.bg_weight | ||||||
|  |             weights = weights.cuda() | ||||||
|  |             self.criterion = torch.nn.NLLLoss(ignore_index=self.cfg.ignore_label, | ||||||
|  |                                               weight=weights).cuda() | ||||||
|  | 
 | ||||||
|  |         self.criterion_exist = torch.nn.BCEWithLogitsLoss().cuda() | ||||||
|  | 
 | ||||||
|  |     def forward(self, net, batch): | ||||||
|  |         output = net(batch['img']) | ||||||
|  | 
 | ||||||
|  |         loss_stats = {} | ||||||
|  |         loss = 0. | ||||||
|  | 
 | ||||||
|  |         if self.loss_type == 'dice_loss': | ||||||
|  |             target = F.one_hot(batch['label'], num_classes=self.cfg.num_classes).permute(0, 3, 1, 2) | ||||||
|  |             seg_loss = dice_loss(F.softmax( | ||||||
|  |                 output['seg'], dim=1)[:, 1:], target[:, 1:]) | ||||||
|  |         else: | ||||||
|  |             seg_loss = self.criterion(F.log_softmax( | ||||||
|  |                 output['seg'], dim=1), batch['label'].long()) | ||||||
|  | 
 | ||||||
|  |         loss += seg_loss * self.cfg.seg_loss_weight | ||||||
|  | 
 | ||||||
|  |         loss_stats.update({'seg_loss': seg_loss}) | ||||||
|  | 
 | ||||||
|  |         if 'exist' in output: | ||||||
|  |             exist_loss = 0.1 * \ | ||||||
|  |                 self.criterion_exist(output['exist'], batch['exist'].float()) | ||||||
|  |             loss += exist_loss | ||||||
|  |             loss_stats.update({'exist_loss': exist_loss}) | ||||||
|  | 
 | ||||||
|  |         ret = {'loss': loss, 'loss_stats': loss_stats} | ||||||
|  | 
 | ||||||
|  |         return ret | ||||||
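A quick check of `dice_loss` on a tiny mask: a perfect prediction scores near 0 (the 0.001 terms only keep the ratio finite) and a fully disjoint one scores 1:

```python
import torch

t = torch.tensor([[1., 0., 1., 0.]])
print(dice_loss(t, t).item())      # ~0.0005
print(dice_loss(t, 1 - t).item())  # 1.0
```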
112  runner/runner.py  Normal file
							| @ -0,0 +1,112 @@ | |||||||
|  | import time | ||||||
|  | import torch | ||||||
|  | import numpy as np | ||||||
|  | from tqdm import tqdm | ||||||
|  | import pytorch_warmup as warmup | ||||||
|  | 
 | ||||||
|  | from models.registry import build_net | ||||||
|  | from .registry import build_trainer, build_evaluator | ||||||
|  | from .optimizer import build_optimizer | ||||||
|  | from .scheduler import build_scheduler | ||||||
|  | from datasets import build_dataloader | ||||||
|  | from .recorder import build_recorder | ||||||
|  | from .net_utils import save_model, load_network | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Runner(object): | ||||||
|  |     def __init__(self, cfg): | ||||||
|  |         self.cfg = cfg | ||||||
|  |         self.recorder = build_recorder(self.cfg) | ||||||
|  |         self.net = build_net(self.cfg) | ||||||
|  |         self.net = torch.nn.parallel.DataParallel( | ||||||
|  |                 self.net, device_ids = range(self.cfg.gpus)).cuda() | ||||||
|  |         self.recorder.logger.info('Network: \n' + str(self.net)) | ||||||
|  |         self.resume() | ||||||
|  |         self.optimizer = build_optimizer(self.cfg, self.net) | ||||||
|  |         self.scheduler = build_scheduler(self.cfg, self.optimizer) | ||||||
|  |         self.evaluator = build_evaluator(self.cfg) | ||||||
|  |         self.warmup_scheduler = warmup.LinearWarmup( | ||||||
|  |             self.optimizer, warmup_period=5000) | ||||||
|  |         self.metric = 0. | ||||||
|  | 
 | ||||||
|  |     def resume(self): | ||||||
|  |         if not self.cfg.load_from and not self.cfg.finetune_from: | ||||||
|  |             return | ||||||
|  |         load_network(self.net, self.cfg.load_from, | ||||||
|  |                 finetune_from=self.cfg.finetune_from, logger=self.recorder.logger) | ||||||
|  | 
 | ||||||
|  |     def to_cuda(self, batch): | ||||||
|  |         for k in batch: | ||||||
|  |             if k == 'meta': | ||||||
|  |                 continue | ||||||
|  |             batch[k] = batch[k].cuda() | ||||||
|  |         return batch | ||||||
|  |      | ||||||
|  |     def train_epoch(self, epoch, train_loader): | ||||||
|  |         self.net.train() | ||||||
|  |         end = time.time() | ||||||
|  |         max_iter = len(train_loader) | ||||||
|  |         for i, data in enumerate(train_loader): | ||||||
|  |             if self.recorder.step >= self.cfg.total_iter: | ||||||
|  |                 break | ||||||
|  |             data_time = time.time() - end | ||||||
|  |             self.recorder.step += 1 | ||||||
|  |             data = self.to_cuda(data) | ||||||
|  |             output = self.trainer.forward(self.net, data) | ||||||
|  |             self.optimizer.zero_grad() | ||||||
|  |             loss = output['loss'] | ||||||
|  |             loss.backward() | ||||||
|  |             self.optimizer.step() | ||||||
|  |             self.scheduler.step() | ||||||
|  |             self.warmup_scheduler.dampen() | ||||||
|  |             batch_time = time.time() - end | ||||||
|  |             end = time.time() | ||||||
|  |             self.recorder.update_loss_stats(output['loss_stats']) | ||||||
|  |             self.recorder.batch_time.update(batch_time) | ||||||
|  |             self.recorder.data_time.update(data_time) | ||||||
|  | 
 | ||||||
|  |             if i % self.cfg.log_interval == 0 or i == max_iter - 1: | ||||||
|  |                 lr = self.optimizer.param_groups[0]['lr'] | ||||||
|  |                 self.recorder.lr = lr | ||||||
|  |                 self.recorder.record('train') | ||||||
|  | 
 | ||||||
|  |     def train(self): | ||||||
|  |         self.recorder.logger.info('start training...') | ||||||
|  |         self.trainer = build_trainer(self.cfg) | ||||||
|  |         train_loader = build_dataloader(self.cfg.dataset.train, self.cfg, is_train=True) | ||||||
|  |         val_loader = build_dataloader(self.cfg.dataset.val, self.cfg, is_train=False) | ||||||
|  | 
 | ||||||
|  |         for epoch in range(self.cfg.epochs): | ||||||
|  |             print('Step: [{}/{}]'.format(self.recorder.step, self.cfg.total_iter)) | ||||||
|  |             print('Epoch: [{}/{}]'.format(epoch, self.cfg.epochs)) | ||||||
|  |             self.recorder.epoch = epoch | ||||||
|  |             self.train_epoch(epoch, train_loader) | ||||||
|  |             if (epoch + 1) % self.cfg.save_ep == 0 or epoch == self.cfg.epochs - 1: | ||||||
|  |                 self.save_ckpt() | ||||||
|  |             if (epoch + 1) % self.cfg.eval_ep == 0 or epoch == self.cfg.epochs - 1: | ||||||
|  |                 self.validate(val_loader) | ||||||
|  |             if self.recorder.step >= self.cfg.total_iter: | ||||||
|  |                 break | ||||||
|  | 
 | ||||||
|  |     def validate(self, val_loader): | ||||||
|  |         self.net.eval() | ||||||
|  |         for i, data in enumerate(tqdm(val_loader, desc='Validate')): | ||||||
|  |             start_time = time.time() | ||||||
|  |             data = self.to_cuda(data) | ||||||
|  |             with torch.no_grad(): | ||||||
|  |                 output = self.net(data['img']) | ||||||
|  |                 self.evaluator.evaluate(val_loader.dataset, output, data) | ||||||
|  |             # print("Image {} took {} seconds to detect".format(i, time.time()-start_time)) | ||||||
|  | 
 | ||||||
|  |         metric = self.evaluator.summarize() | ||||||
|  |         if not metric: | ||||||
|  |             return | ||||||
|  |         if metric > self.metric: | ||||||
|  |             self.metric = metric | ||||||
|  |             self.save_ckpt(is_best=True) | ||||||
|  |         self.recorder.logger.info('Best metric: ' + str(self.metric)) | ||||||
|  | 
 | ||||||
|  |     def save_ckpt(self, is_best=False): | ||||||
|  |         save_model(self.net, self.optimizer, self.scheduler, | ||||||
|  |                 self.recorder, is_best) | ||||||
20  runner/scheduler.py  Normal file
							| @ -0,0 +1,20 @@ | |||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | _scheduler_factory = { | ||||||
|  |     'LambdaLR': torch.optim.lr_scheduler.LambdaLR, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def build_scheduler(cfg, optimizer): | ||||||
|  | 
 | ||||||
|  |     assert cfg.scheduler.type in _scheduler_factory | ||||||
|  | 
 | ||||||
|  |     cfg_cp = cfg.scheduler.copy() | ||||||
|  |     cfg_cp.pop('type') | ||||||
|  | 
 | ||||||
|  |     scheduler = _scheduler_factory[cfg.scheduler.type](optimizer, **cfg_cp) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     return scheduler  | ||||||
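Since everything in `cfg.scheduler` other than `type` is forwarded as keyword arguments, a config entry might look like this (hypothetical values; polynomial decay over an assumed 80k-iteration budget):

```python
scheduler = dict(
    type='LambdaLR',
    lr_lambda=lambda step: (1 - step / 80000) ** 0.9,
)
```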
105  tools/generate_seg_tusimple.py  Normal file
							| @ -0,0 +1,105 @@ | |||||||
|  | import json | ||||||
|  | import numpy as np | ||||||
|  | import cv2 | ||||||
|  | import os | ||||||
|  | import argparse | ||||||
|  | 
 | ||||||
|  | TRAIN_SET = ['label_data_0313.json', 'label_data_0601.json'] | ||||||
|  | VAL_SET = ['label_data_0531.json'] | ||||||
|  | TRAIN_VAL_SET = TRAIN_SET + VAL_SET | ||||||
|  | TEST_SET = ['test_label.json'] | ||||||
|  | 
 | ||||||
|  | def gen_label_for_json(args, image_set): | ||||||
|  |     H, W = 720, 1280 | ||||||
|  |     SEG_WIDTH = 30 | ||||||
|  |     save_dir = args.savedir | ||||||
|  | 
 | ||||||
|  |     os.makedirs(os.path.join(args.root, args.savedir, "list"), exist_ok=True) | ||||||
|  |     list_f = open(os.path.join(args.root, args.savedir, "list", "{}_gt.txt".format(image_set)), "w") | ||||||
|  | 
 | ||||||
|  |     json_path = os.path.join(args.root, args.savedir, "{}.json".format(image_set)) | ||||||
|  |     with open(json_path) as f: | ||||||
|  |         for line in f: | ||||||
|  |             label = json.loads(line) | ||||||
|  |             # ---------- clean and sort lanes ------------- | ||||||
|  |             lanes = [] | ||||||
|  |             _lanes = [] | ||||||
|  |             slope = [] # identify 0th, 1st, 2nd, 3rd, 4th, 5th lane through slope | ||||||
|  |             for i in range(len(label['lanes'])): | ||||||
|  |                 l = [(x, y) for x, y in zip(label['lanes'][i], label['h_samples']) if x >= 0] | ||||||
|  |                 if (len(l)>1): | ||||||
|  |                     _lanes.append(l) | ||||||
|  |                     slope.append(np.arctan2(l[-1][1]-l[0][1], l[0][0]-l[-1][0]) / np.pi * 180) | ||||||
|  |             _lanes = [_lanes[i] for i in np.argsort(slope)] | ||||||
|  |             slope = [slope[i] for i in np.argsort(slope)] | ||||||
|  | 
 | ||||||
|  |             idx = [None for i in range(6)] | ||||||
|  |             for i in range(len(slope)): | ||||||
|  |                 if slope[i] <= 90: | ||||||
|  |                     idx[2] = i | ||||||
|  |                     idx[1] = i-1 if i > 0 else None | ||||||
|  |                     idx[0] = i-2 if i > 1 else None | ||||||
|  |                 else: | ||||||
|  |                     idx[3] = i | ||||||
|  |                     idx[4] = i+1 if i+1 < len(slope) else None | ||||||
|  |                     idx[5] = i+2 if i+2 < len(slope) else None | ||||||
|  |                     break | ||||||
|  |             for i in range(6): | ||||||
|  |                 lanes.append([] if idx[i] is None else _lanes[idx[i]]) | ||||||
|  | 
 | ||||||
|  |             # --------------------------------------------- | ||||||
|  | 
 | ||||||
|  |             img_path = label['raw_file'] | ||||||
|  |             seg_img = np.zeros((H, W, 3)) | ||||||
|  |             list_str = []  # str to be written to list.txt | ||||||
|  |             for i in range(len(lanes)): | ||||||
|  |                 coords = lanes[i] | ||||||
|  |                 if len(coords) < 4: | ||||||
|  |                     list_str.append('0') | ||||||
|  |                     continue | ||||||
|  |                 for j in range(len(coords)-1): | ||||||
|  |                     cv2.line(seg_img, coords[j], coords[j+1], (i+1, i+1, i+1), SEG_WIDTH//2) | ||||||
|  |                 list_str.append('1') | ||||||
|  | 
 | ||||||
|  |             seg_path = img_path.split("/") | ||||||
|  |             seg_path, img_name = os.path.join(args.root, args.savedir, seg_path[1], seg_path[2]), seg_path[3] | ||||||
|  |             os.makedirs(seg_path, exist_ok=True) | ||||||
|  |             seg_path = os.path.join(seg_path, img_name[:-3]+"png") | ||||||
|  |             cv2.imwrite(seg_path, seg_img) | ||||||
|  | 
 | ||||||
|  |             seg_path = "/".join([args.savedir, *img_path.split("/")[1:3], img_name[:-3]+"png"]) | ||||||
|  |             if seg_path[0] != '/': | ||||||
|  |                 seg_path = '/' + seg_path | ||||||
|  |             if img_path[0] != '/': | ||||||
|  |                 img_path = '/' + img_path | ||||||
|  |             list_str.insert(0, seg_path) | ||||||
|  |             list_str.insert(0, img_path) | ||||||
|  |             list_str = " ".join(list_str) + "\n" | ||||||
|  |             list_f.write(list_str) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def generate_json_file(args, save_dir, json_file, image_set): | ||||||
|  |     with open(os.path.join(save_dir, json_file), "w") as outfile: | ||||||
|  |         for json_name in (image_set): | ||||||
|  |             with open(os.path.join(args.root, json_name)) as infile: | ||||||
|  |                 for line in infile: | ||||||
|  |                     outfile.write(line) | ||||||
|  | 
 | ||||||
|  | def generate_label(args): | ||||||
|  |     save_dir = os.path.join(args.root, args.savedir) | ||||||
|  |     os.makedirs(save_dir, exist_ok=True) | ||||||
|  |     generate_json_file(save_dir, "train_val.json", TRAIN_VAL_SET) | ||||||
|  |     generate_json_file(save_dir, "test.json", TEST_SET) | ||||||
|  | 
 | ||||||
|  |     print("generating train_val set...") | ||||||
|  |     gen_label_for_json(args, 'train_val') | ||||||
|  |     print("generating test set...") | ||||||
|  |     gen_label_for_json(args, 'test') | ||||||
|  | 
 | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     parser = argparse.ArgumentParser() | ||||||
|  |     parser.add_argument('--root', required=True, help='The root of the Tusimple dataset') | ||||||
|  |     parser.add_argument('--savedir', type=str, default='seg_label', help='Directory under the dataset root to save the generated segmentation labels') | ||||||
|  |     args = parser.parse_args() | ||||||
|  | 
 | ||||||
|  |     generate_label(args) | ||||||
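The six-slot slope heuristic in `gen_label_for_json` is easiest to follow on concrete numbers; this re-runs the same loop on a sorted toy list with two left lanes (angle <= 90) and two right lanes:

```python
slope = [60.0, 75.0, 110.0, 130.0]  # already sorted, as in gen_label_for_json
idx = [None] * 6
for i in range(len(slope)):
    if slope[i] <= 90:
        idx[2] = i
        idx[1] = i - 1 if i > 0 else None
        idx[0] = i - 2 if i > 1 else None
    else:
        idx[3] = i
        idx[4] = i + 1 if i + 1 < len(slope) else None
        idx[5] = i + 2 if i + 2 < len(slope) else None
        break
print(idx)  # -> [None, 0, 1, 2, 3, None]
```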
2  utils/__init__.py  Normal file
							| @ -0,0 +1,2 @@ | |||||||
|  | from .config import Config | ||||||
|  | from .registry import Registry, build_from_cfg | ||||||
417  utils/config.py  Normal file
							| @ -0,0 +1,417 @@ | |||||||
|  | # Copyright (c) Open-MMLab. All rights reserved. | ||||||
|  | import ast | ||||||
|  | import os.path as osp | ||||||
|  | import shutil | ||||||
|  | import sys | ||||||
|  | import tempfile | ||||||
|  | from argparse import Action, ArgumentParser | ||||||
|  | from collections import abc | ||||||
|  | from importlib import import_module | ||||||
|  | 
 | ||||||
|  | from addict import Dict | ||||||
|  | from yapf.yapflib.yapf_api import FormatCode | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | BASE_KEY = '_base_' | ||||||
|  | DELETE_KEY = '_delete_' | ||||||
|  | RESERVED_KEYS = ['filename', 'text', 'pretty_text'] | ||||||
|  | 
 | ||||||
|  | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): | ||||||
|  |     if not osp.isfile(filename): | ||||||
|  |         raise FileNotFoundError(msg_tmpl.format(filename)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class ConfigDict(Dict): | ||||||
|  | 
 | ||||||
|  |     def __missing__(self, name): | ||||||
|  |         raise KeyError(name) | ||||||
|  | 
 | ||||||
|  |     def __getattr__(self, name): | ||||||
|  |         try: | ||||||
|  |             value = super(ConfigDict, self).__getattr__(name) | ||||||
|  |         except KeyError: | ||||||
|  |             ex = AttributeError(f"'{self.__class__.__name__}' object has no " | ||||||
|  |                                 f"attribute '{name}'") | ||||||
|  |         except Exception as e: | ||||||
|  |             ex = e | ||||||
|  |         else: | ||||||
|  |             return value | ||||||
|  |         raise ex | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def add_args(parser, cfg, prefix=''): | ||||||
|  |     for k, v in cfg.items(): | ||||||
|  |         if isinstance(v, str): | ||||||
|  |             parser.add_argument('--' + prefix + k) | ||||||
|  |         elif isinstance(v, int): | ||||||
|  |             parser.add_argument('--' + prefix + k, type=int) | ||||||
|  |         elif isinstance(v, float): | ||||||
|  |             parser.add_argument('--' + prefix + k, type=float) | ||||||
|  |         elif isinstance(v, bool): | ||||||
|  |             parser.add_argument('--' + prefix + k, action='store_true') | ||||||
|  |         elif isinstance(v, dict): | ||||||
|  |             add_args(parser, v, prefix + k + '.') | ||||||
|  |         elif isinstance(v, abc.Iterable): | ||||||
|  |             parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+') | ||||||
|  |         else: | ||||||
|  |             print(f'cannot parse key {prefix + k} of type {type(v)}') | ||||||
|  |     return parser | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
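|  | # For example (a sketch): given cfg = dict(optimizer=dict(lr=0.1)), add_args | ||||||
|  | # registers a single flattened flag `--optimizer.lr` of type float. | ||||||
|  | 
 | ||||||
|  | 
 | ||||||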
|  | class Config: | ||||||
|  |     """A facility for config and config files. | ||||||
|  |     It supports common file formats as configs: python/json/yaml. The interface | ||||||
|  |     is the same as a dict object and also allows accessing config values as | ||||||
|  |     attributes. | ||||||
|  |     Example: | ||||||
|  |         >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) | ||||||
|  |         >>> cfg.a | ||||||
|  |         1 | ||||||
|  |         >>> cfg.b | ||||||
|  |         {'b1': [0, 1]} | ||||||
|  |         >>> cfg.b.b1 | ||||||
|  |         [0, 1] | ||||||
|  |         >>> cfg = Config.fromfile('tests/data/config/a.py') | ||||||
|  |         >>> cfg.filename | ||||||
|  |         "/home/kchen/projects/mmcv/tests/data/config/a.py" | ||||||
|  |         >>> cfg.item4 | ||||||
|  |         'test' | ||||||
|  |         >>> cfg | ||||||
|  |         "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " | ||||||
|  |         "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def _validate_py_syntax(filename): | ||||||
|  |         with open(filename) as f: | ||||||
|  |             content = f.read() | ||||||
|  |         try: | ||||||
|  |             ast.parse(content) | ||||||
|  |         except SyntaxError: | ||||||
|  |             raise SyntaxError('There are syntax errors in config ' | ||||||
|  |                               f'file {filename}') | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def _file2dict(filename): | ||||||
|  |         filename = osp.abspath(osp.expanduser(filename)) | ||||||
|  |         check_file_exist(filename) | ||||||
|  |         if filename.endswith('.py'): | ||||||
|  |             with tempfile.TemporaryDirectory() as temp_config_dir: | ||||||
|  |                 temp_config_file = tempfile.NamedTemporaryFile( | ||||||
|  |                     dir=temp_config_dir, suffix='.py') | ||||||
|  |                 temp_config_name = osp.basename(temp_config_file.name) | ||||||
|  |                 shutil.copyfile(filename, | ||||||
|  |                                 osp.join(temp_config_dir, temp_config_name)) | ||||||
|  |                 temp_module_name = osp.splitext(temp_config_name)[0] | ||||||
|  |                 sys.path.insert(0, temp_config_dir) | ||||||
|  |                 Config._validate_py_syntax(filename) | ||||||
|  |                 mod = import_module(temp_module_name) | ||||||
|  |                 sys.path.pop(0) | ||||||
|  |                 cfg_dict = { | ||||||
|  |                     name: value | ||||||
|  |                     for name, value in mod.__dict__.items() | ||||||
|  |                     if not name.startswith('__') | ||||||
|  |                 } | ||||||
|  |                 # delete imported module | ||||||
|  |                 del sys.modules[temp_module_name] | ||||||
|  |                 # close temp file | ||||||
|  |                 temp_config_file.close() | ||||||
|  |         elif filename.endswith(('.yml', '.yaml', '.json')): | ||||||
|  |             import mmcv | ||||||
|  |             cfg_dict = mmcv.load(filename) | ||||||
|  |         else: | ||||||
|  |             raise IOError('Only py/yml/yaml/json types are supported for now!') | ||||||
|  | 
 | ||||||
|  |         cfg_text = filename + '\n' | ||||||
|  |         with open(filename, 'r') as f: | ||||||
|  |             cfg_text += f.read() | ||||||
|  | 
 | ||||||
|  |         if BASE_KEY in cfg_dict: | ||||||
|  |             cfg_dir = osp.dirname(filename) | ||||||
|  |             base_filename = cfg_dict.pop(BASE_KEY) | ||||||
|  |             base_filename = base_filename if isinstance( | ||||||
|  |                 base_filename, list) else [base_filename] | ||||||
|  | 
 | ||||||
|  |             cfg_dict_list = list() | ||||||
|  |             cfg_text_list = list() | ||||||
|  |             for f in base_filename: | ||||||
|  |                 _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) | ||||||
|  |                 cfg_dict_list.append(_cfg_dict) | ||||||
|  |                 cfg_text_list.append(_cfg_text) | ||||||
|  | 
 | ||||||
|  |             base_cfg_dict = dict() | ||||||
|  |             for c in cfg_dict_list: | ||||||
|  |                 if len(base_cfg_dict.keys() & c.keys()) > 0: | ||||||
|  |                     raise KeyError('Duplicate key is not allowed among bases') | ||||||
|  |                 base_cfg_dict.update(c) | ||||||
|  | 
 | ||||||
|  |             base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) | ||||||
|  |             cfg_dict = base_cfg_dict | ||||||
|  | 
 | ||||||
|  |             # merge cfg_text | ||||||
|  |             cfg_text_list.append(cfg_text) | ||||||
|  |             cfg_text = '\n'.join(cfg_text_list) | ||||||
|  | 
 | ||||||
|  |         return cfg_dict, cfg_text | ||||||
|  | 
 | ||||||
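|  |     # A sketch of the `_base_` inheritance handled above, with two | ||||||
|  |     # hypothetical config files: | ||||||
|  |     #   base.py:   lr = 0.1 | ||||||
|  |     #              total_epochs = 12 | ||||||
|  |     #   child.py:  _base_ = './base.py' | ||||||
|  |     #              lr = 0.01   # overrides the value from base.py | ||||||
|  |     # Config.fromfile('child.py') then yields lr=0.01, total_epochs=12. | ||||||
|  | 
 | ||||||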
|  |     @staticmethod | ||||||
|  |     def _merge_a_into_b(a, b): | ||||||
|  |         # merge dict `a` into dict `b` (non-inplace). values in `a` will | ||||||
|  |         # overwrite `b`. | ||||||
|  |         # copy first to avoid inplace modification | ||||||
|  |         b = b.copy() | ||||||
|  |         for k, v in a.items(): | ||||||
|  |             if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False): | ||||||
|  |                 if not isinstance(b[k], dict): | ||||||
|  |                     raise TypeError( | ||||||
|  |                         f'{k}={v} in child config cannot inherit from base ' | ||||||
|  |                         f'because {k} is a dict in the child config but is of ' | ||||||
|  |                         f'type {type(b[k])} in base config. You may set ' | ||||||
|  |                         f'`{DELETE_KEY}=True` to ignore the base config') | ||||||
|  |                 b[k] = Config._merge_a_into_b(v, b[k]) | ||||||
|  |             else: | ||||||
|  |                 b[k] = v | ||||||
|  |         return b | ||||||
|  | 
 | ||||||
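|  |     # Merge semantics in a nutshell (a doctest-style sketch): nested dicts | ||||||
|  |     # merge key by key unless the child sets `_delete_=True`: | ||||||
|  |     #   >>> Config._merge_a_into_b(dict(m=dict(a=1)), dict(m=dict(a=0, b=2))) | ||||||
|  |     #   {'m': {'a': 1, 'b': 2}} | ||||||
|  |     #   >>> Config._merge_a_into_b(dict(m=dict(_delete_=True, a=1)), | ||||||
|  |     #   ...                        dict(m=dict(a=0, b=2))) | ||||||
|  |     #   {'m': {'a': 1}} | ||||||
|  | 
 | ||||||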
|  |     @staticmethod | ||||||
|  |     def fromfile(filename): | ||||||
|  |         cfg_dict, cfg_text = Config._file2dict(filename) | ||||||
|  |         return Config(cfg_dict, cfg_text=cfg_text, filename=filename) | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def auto_argparser(description=None): | ||||||
|  |         """Generate argparser from config file automatically (experimental) | ||||||
|  |         """ | ||||||
|  |         partial_parser = ArgumentParser(description=description) | ||||||
|  |         partial_parser.add_argument('config', help='config file path') | ||||||
|  |         cfg_file = partial_parser.parse_known_args()[0].config | ||||||
|  |         cfg = Config.fromfile(cfg_file) | ||||||
|  |         parser = ArgumentParser(description=description) | ||||||
|  |         parser.add_argument('config', help='config file path') | ||||||
|  |         add_args(parser, cfg) | ||||||
|  |         return parser, cfg | ||||||
|  | 
 | ||||||
|  |     def __init__(self, cfg_dict=None, cfg_text=None, filename=None): | ||||||
|  |         if cfg_dict is None: | ||||||
|  |             cfg_dict = dict() | ||||||
|  |         elif not isinstance(cfg_dict, dict): | ||||||
|  |             raise TypeError('cfg_dict must be a dict, but ' | ||||||
|  |                             f'got {type(cfg_dict)}') | ||||||
|  |         for key in cfg_dict: | ||||||
|  |             if key in RESERVED_KEYS: | ||||||
|  |                 raise KeyError(f'{key} is reserved for config file') | ||||||
|  | 
 | ||||||
|  |         super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict)) | ||||||
|  |         super(Config, self).__setattr__('_filename', filename) | ||||||
|  |         if cfg_text: | ||||||
|  |             text = cfg_text | ||||||
|  |         elif filename: | ||||||
|  |             with open(filename, 'r') as f: | ||||||
|  |                 text = f.read() | ||||||
|  |         else: | ||||||
|  |             text = '' | ||||||
|  |         super(Config, self).__setattr__('_text', text) | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def filename(self): | ||||||
|  |         return self._filename | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def text(self): | ||||||
|  |         return self._text | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def pretty_text(self): | ||||||
|  | 
 | ||||||
|  |         indent = 4 | ||||||
|  | 
 | ||||||
|  |         def _indent(s_, num_spaces): | ||||||
|  |             s = s_.split('\n') | ||||||
|  |             if len(s) == 1: | ||||||
|  |                 return s_ | ||||||
|  |             first = s.pop(0) | ||||||
|  |             s = [(num_spaces * ' ') + line for line in s] | ||||||
|  |             s = '\n'.join(s) | ||||||
|  |             s = first + '\n' + s | ||||||
|  |             return s | ||||||
|  | 
 | ||||||
|  |         def _format_basic_types(k, v, use_mapping=False): | ||||||
|  |             if isinstance(v, str): | ||||||
|  |                 v_str = f"'{v}'" | ||||||
|  |             else: | ||||||
|  |                 v_str = str(v) | ||||||
|  | 
 | ||||||
|  |             if use_mapping: | ||||||
|  |                 k_str = f"'{k}'" if isinstance(k, str) else str(k) | ||||||
|  |                 attr_str = f'{k_str}: {v_str}' | ||||||
|  |             else: | ||||||
|  |                 attr_str = f'{str(k)}={v_str}' | ||||||
|  |             attr_str = _indent(attr_str, indent) | ||||||
|  | 
 | ||||||
|  |             return attr_str | ||||||
|  | 
 | ||||||
|  |         def _format_list(k, v, use_mapping=False): | ||||||
|  |             # check if all items in the list are dict | ||||||
|  |             if all(isinstance(_, dict) for _ in v): | ||||||
|  |                 v_str = '[\n' | ||||||
|  |                 v_str += '\n'.join( | ||||||
|  |                     f'dict({_indent(_format_dict(v_), indent)}),' | ||||||
|  |                     for v_ in v).rstrip(',') | ||||||
|  |                 if use_mapping: | ||||||
|  |                     k_str = f"'{k}'" if isinstance(k, str) else str(k) | ||||||
|  |                     attr_str = f'{k_str}: {v_str}' | ||||||
|  |                 else: | ||||||
|  |                     attr_str = f'{str(k)}={v_str}' | ||||||
|  |                 attr_str = _indent(attr_str, indent) + ']' | ||||||
|  |             else: | ||||||
|  |                 attr_str = _format_basic_types(k, v, use_mapping) | ||||||
|  |             return attr_str | ||||||
|  | 
 | ||||||
|  |         def _contain_invalid_identifier(dict_str): | ||||||
|  |             contain_invalid_identifier = False | ||||||
|  |             for key_name in dict_str: | ||||||
|  |                 contain_invalid_identifier |= \ | ||||||
|  |                     (not str(key_name).isidentifier()) | ||||||
|  |             return contain_invalid_identifier | ||||||
|  | 
 | ||||||
|  |         def _format_dict(input_dict, outest_level=False): | ||||||
|  |             r = '' | ||||||
|  |             s = [] | ||||||
|  | 
 | ||||||
|  |             use_mapping = _contain_invalid_identifier(input_dict) | ||||||
|  |             if use_mapping: | ||||||
|  |                 r += '{' | ||||||
|  |             for idx, (k, v) in enumerate(input_dict.items()): | ||||||
|  |                 is_last = idx >= len(input_dict) - 1 | ||||||
|  |                 end = '' if outest_level or is_last else ',' | ||||||
|  |                 if isinstance(v, dict): | ||||||
|  |                     v_str = '\n' + _format_dict(v) | ||||||
|  |                     if use_mapping: | ||||||
|  |                         k_str = f"'{k}'" if isinstance(k, str) else str(k) | ||||||
|  |                         attr_str = f'{k_str}: dict({v_str}' | ||||||
|  |                     else: | ||||||
|  |                         attr_str = f'{str(k)}=dict({v_str}' | ||||||
|  |                     attr_str = _indent(attr_str, indent) + ')' + end | ||||||
|  |                 elif isinstance(v, list): | ||||||
|  |                     attr_str = _format_list(k, v, use_mapping) + end | ||||||
|  |                 else: | ||||||
|  |                     attr_str = _format_basic_types(k, v, use_mapping) + end | ||||||
|  | 
 | ||||||
|  |                 s.append(attr_str) | ||||||
|  |             r += '\n'.join(s) | ||||||
|  |             if use_mapping: | ||||||
|  |                 r += '}' | ||||||
|  |             return r | ||||||
|  | 
 | ||||||
|  |         cfg_dict = self._cfg_dict.to_dict() | ||||||
|  |         text = _format_dict(cfg_dict, outest_level=True) | ||||||
|  |         # copied from setup.cfg | ||||||
|  |         yapf_style = dict( | ||||||
|  |             based_on_style='pep8', | ||||||
|  |             blank_line_before_nested_class_or_def=True, | ||||||
|  |             split_before_expression_after_opening_paren=True) | ||||||
|  |         text, _ = FormatCode(text, style_config=yapf_style, verify=True) | ||||||
|  | 
 | ||||||
|  |         return text | ||||||
|  | 
 | ||||||
|  |     def __repr__(self): | ||||||
|  |         return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' | ||||||
|  | 
 | ||||||
|  |     def __len__(self): | ||||||
|  |         return len(self._cfg_dict) | ||||||
|  | 
 | ||||||
|  |     def __getattr__(self, name): | ||||||
|  |         return getattr(self._cfg_dict, name) | ||||||
|  | 
 | ||||||
|  |     def __getitem__(self, name): | ||||||
|  |         return self._cfg_dict.__getitem__(name) | ||||||
|  | 
 | ||||||
|  |     def __setattr__(self, name, value): | ||||||
|  |         if isinstance(value, dict): | ||||||
|  |             value = ConfigDict(value) | ||||||
|  |         self._cfg_dict.__setattr__(name, value) | ||||||
|  | 
 | ||||||
|  |     def __setitem__(self, name, value): | ||||||
|  |         if isinstance(value, dict): | ||||||
|  |             value = ConfigDict(value) | ||||||
|  |         self._cfg_dict.__setitem__(name, value) | ||||||
|  | 
 | ||||||
|  |     def __iter__(self): | ||||||
|  |         return iter(self._cfg_dict) | ||||||
|  | 
 | ||||||
|  |     def dump(self, file=None): | ||||||
|  |         cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict() | ||||||
|  |         if self.filename.endswith('.py'): | ||||||
|  |             if file is None: | ||||||
|  |                 return self.pretty_text | ||||||
|  |             else: | ||||||
|  |                 with open(file, 'w') as f: | ||||||
|  |                     f.write(self.pretty_text) | ||||||
|  |         else: | ||||||
|  |             import mmcv | ||||||
|  |             if file is None: | ||||||
|  |                 file_format = self.filename.split('.')[-1] | ||||||
|  |                 return mmcv.dump(cfg_dict, file_format=file_format) | ||||||
|  |             else: | ||||||
|  |                 mmcv.dump(cfg_dict, file) | ||||||
|  | 
 | ||||||
|  |     def merge_from_dict(self, options): | ||||||
|  |         """Merge list into cfg_dict | ||||||
|  |         Merge the dict parsed by MultipleKVAction into this cfg. | ||||||
|  |         Examples: | ||||||
|  |             >>> options = {'model.backbone.depth': 50, | ||||||
|  |             ...            'model.backbone.with_cp':True} | ||||||
|  |             >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) | ||||||
|  |             >>> cfg.merge_from_dict(options) | ||||||
|  |             >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') | ||||||
|  |             >>> assert cfg_dict == dict( | ||||||
|  |             ...     model=dict(backbone=dict(depth=50, with_cp=True))) | ||||||
|  |         Args: | ||||||
|  |             options (dict): dict of configs to merge from. | ||||||
|  |         """ | ||||||
|  |         option_cfg_dict = {} | ||||||
|  |         for full_key, v in options.items(): | ||||||
|  |             d = option_cfg_dict | ||||||
|  |             key_list = full_key.split('.') | ||||||
|  |             for subkey in key_list[:-1]: | ||||||
|  |                 d.setdefault(subkey, ConfigDict()) | ||||||
|  |                 d = d[subkey] | ||||||
|  |             subkey = key_list[-1] | ||||||
|  |             d[subkey] = v | ||||||
|  | 
 | ||||||
|  |         cfg_dict = super(Config, self).__getattribute__('_cfg_dict') | ||||||
|  |         super(Config, self).__setattr__( | ||||||
|  |             '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class DictAction(Action): | ||||||
|  |     """ | ||||||
|  |     argparse action that splits each argument into KEY=VALUE form | ||||||
|  |     on the first '=' and appends it to a dictionary. List options should | ||||||
|  |     be passed as comma-separated values, i.e. KEY=V1,V2,V3 | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def _parse_int_float_bool(val): | ||||||
|  |         try: | ||||||
|  |             return int(val) | ||||||
|  |         except ValueError: | ||||||
|  |             pass | ||||||
|  |         try: | ||||||
|  |             return float(val) | ||||||
|  |         except ValueError: | ||||||
|  |             pass | ||||||
|  |         if val.lower() in ['true', 'false']: | ||||||
|  |             return True if val.lower() == 'true' else False | ||||||
|  |         return val | ||||||
|  | 
 | ||||||
|  |     def __call__(self, parser, namespace, values, option_string=None): | ||||||
|  |         options = {} | ||||||
|  |         for kv in values: | ||||||
|  |             key, val = kv.split('=', maxsplit=1) | ||||||
|  |             val = [self._parse_int_float_bool(v) for v in val.split(',')] | ||||||
|  |             if len(val) == 1: | ||||||
|  |                 val = val[0] | ||||||
|  |             options[key] = val | ||||||
|  |         setattr(namespace, self.dest, options) | ||||||
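|  | 
 | ||||||
|  | 
 | ||||||
|  | # A minimal end-to-end sketch of how Config and DictAction are typically | ||||||
|  | # wired together (the config path and option key are illustrative): | ||||||
|  | # | ||||||
|  | #   parser = ArgumentParser() | ||||||
|  | #   parser.add_argument('config') | ||||||
|  | #   parser.add_argument('--options', nargs='+', action=DictAction) | ||||||
|  | #   args = parser.parse_args(['configs/culane.py', | ||||||
|  | #                             '--options', 'optimizer.lr=0.01']) | ||||||
|  | #   cfg = Config.fromfile(args.config) | ||||||
|  | #   cfg.merge_from_dict(args.options)   # CLI overrides win over the file | ||||||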
							81	utils/registry.py	Normal file
							| @ -0,0 +1,81 @@ | |||||||
|  | import inspect | ||||||
|  | 
 | ||||||
|  | # borrowed from mmdetection | ||||||
|  | 
 | ||||||
|  | def is_str(x): | ||||||
|  |     """Whether the input is a string instance.""" | ||||||
|  |     return isinstance(x, str)  # six is unnecessary on Python 3 | ||||||
|  | 
 | ||||||
|  | class Registry(object): | ||||||
|  | 
 | ||||||
|  |     def __init__(self, name): | ||||||
|  |         self._name = name | ||||||
|  |         self._module_dict = dict() | ||||||
|  | 
 | ||||||
|  |     def __repr__(self): | ||||||
|  |         format_str = self.__class__.__name__ + '(name={}, items={})'.format( | ||||||
|  |             self._name, list(self._module_dict.keys())) | ||||||
|  |         return format_str | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def name(self): | ||||||
|  |         return self._name | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def module_dict(self): | ||||||
|  |         return self._module_dict | ||||||
|  | 
 | ||||||
|  |     def get(self, key): | ||||||
|  |         return self._module_dict.get(key, None) | ||||||
|  | 
 | ||||||
|  |     def _register_module(self, module_class): | ||||||
|  |         """Register a module. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             module_class (type): Class to be registered. | ||||||
|  |         """ | ||||||
|  |         if not inspect.isclass(module_class): | ||||||
|  |             raise TypeError('module must be a class, but got {}'.format( | ||||||
|  |                 type(module_class))) | ||||||
|  |         module_name = module_class.__name__ | ||||||
|  |         if module_name in self._module_dict: | ||||||
|  |             raise KeyError('{} is already registered in {}'.format( | ||||||
|  |                 module_name, self.name)) | ||||||
|  |         self._module_dict[module_name] = module_class | ||||||
|  | 
 | ||||||
|  |     def register_module(self, cls): | ||||||
|  |         self._register_module(cls) | ||||||
|  |         return cls | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def build_from_cfg(cfg, registry, default_args=None): | ||||||
|  |     """Build a module from config dict. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         cfg (dict): Config dict. It should at least contain the key "type". | ||||||
|  |         registry (:obj:`Registry`): The registry to search the type from. | ||||||
|  |         default_args (dict, optional): Default initialization arguments. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         obj: The constructed object. | ||||||
|  |     """ | ||||||
|  |     assert isinstance(cfg, dict) and 'type' in cfg | ||||||
|  |     assert isinstance(default_args, dict) or default_args is None | ||||||
|  |     args = {} | ||||||
|  |     # cfg needs attribute-style access (e.g. a ConfigDict). Note that only | ||||||
|  |     # `type` is read from cfg here; constructor kwargs come from default_args. | ||||||
|  |     obj_type = cfg.type | ||||||
|  |     if is_str(obj_type): | ||||||
|  |         obj_cls = registry.get(obj_type) | ||||||
|  |         if obj_cls is None: | ||||||
|  |             raise KeyError('{} is not in the {} registry'.format( | ||||||
|  |                 obj_type, registry.name)) | ||||||
|  |     elif inspect.isclass(obj_type): | ||||||
|  |         obj_cls = obj_type | ||||||
|  |     else: | ||||||
|  |         raise TypeError('type must be a str or valid type, but got {}'.format( | ||||||
|  |             type(obj_type))) | ||||||
|  |     if default_args is not None: | ||||||
|  |         for name, value in default_args.items(): | ||||||
|  |             args.setdefault(name, value) | ||||||
|  |     return obj_cls(**args) | ||||||
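|  | 
 | ||||||
|  | 
 | ||||||
|  | # A minimal usage sketch (the registry name and class are illustrative, | ||||||
|  | # not part of this repo): | ||||||
|  | # | ||||||
|  | #   NETS = Registry('nets') | ||||||
|  | # | ||||||
|  | #   @NETS.register_module | ||||||
|  | #   class ToyNet: | ||||||
|  | #       def __init__(self, cfg=None): | ||||||
|  | #           self.cfg = cfg | ||||||
|  | # | ||||||
|  | #   from utils import Config | ||||||
|  | #   cfg = Config(dict(net=dict(type='ToyNet'))) | ||||||
|  | #   net = build_from_cfg(cfg.net, NETS, default_args=dict(cfg=cfg)) | ||||||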
							357	utils/transforms.py	Normal file
							| @ -0,0 +1,357 @@ | |||||||
|  | import random | ||||||
|  | import cv2 | ||||||
|  | import numpy as np | ||||||
|  | import numbers | ||||||
|  | import collections.abc | ||||||
|  | 
 | ||||||
|  | # copied from: https://github.com/cardwing/Codes-for-Lane-Detection/blob/master/ERFNet-CULane-PyTorch/utils/transforms.py | ||||||
|  | 
 | ||||||
|  | __all__ = ['SampleResize', 'GroupRandomCrop', 'GroupRandomCropRatio', 'GroupCenterCrop', | ||||||
|  |            'GroupRandomPad', 'GroupCenterPad', 'GroupConcerPad', 'GroupRandomScaleNew', | ||||||
|  |            'GroupRandomScale', 'GroupRandomMultiScale', 'GroupRandomScaleRatio', | ||||||
|  |            'GroupRandomRotation', 'GroupRandomBlur', 'GroupRandomHorizontalFlip', | ||||||
|  |            'GroupNormalize'] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class SampleResize(object): | ||||||
|  |     def __init__(self, size): | ||||||
|  |         assert (isinstance(size, collections.abc.Iterable) and len(size) == 2) | ||||||
|  |         self.size = size | ||||||
|  | 
 | ||||||
|  |     def __call__(self, sample): | ||||||
|  |         out = list() | ||||||
|  |         out.append(cv2.resize(sample[0], self.size, | ||||||
|  |                               interpolation=cv2.INTER_CUBIC)) | ||||||
|  |         if len(sample) > 1: | ||||||
|  |             out.append(cv2.resize(sample[1], self.size, | ||||||
|  |                                   interpolation=cv2.INTER_NEAREST)) | ||||||
|  |         return out | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupRandomCrop(object): | ||||||
|  |     def __init__(self, size): | ||||||
|  |         if isinstance(size, numbers.Number): | ||||||
|  |             self.size = (int(size), int(size)) | ||||||
|  |         else: | ||||||
|  |             self.size = size | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         h, w = img_group[0].shape[0:2] | ||||||
|  |         th, tw = self.size | ||||||
|  | 
 | ||||||
|  |         out_images = list() | ||||||
|  |         h1 = random.randint(0, max(0, h - th)) | ||||||
|  |         w1 = random.randint(0, max(0, w - tw)) | ||||||
|  |         h2 = min(h1 + th, h) | ||||||
|  |         w2 = min(w1 + tw, w) | ||||||
|  | 
 | ||||||
|  |         for img in img_group: | ||||||
|  |             assert (img.shape[0] == h and img.shape[1] == w) | ||||||
|  |             out_images.append(img[h1:h2, w1:w2, ...]) | ||||||
|  |         return out_images | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupRandomCropRatio(object): | ||||||
|  |     def __init__(self, size): | ||||||
|  |         if isinstance(size, numbers.Number): | ||||||
|  |             self.size = (int(size), int(size)) | ||||||
|  |         else: | ||||||
|  |             self.size = size | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         h, w = img_group[0].shape[0:2] | ||||||
|  |         tw, th = self.size | ||||||
|  | 
 | ||||||
|  |         out_images = list() | ||||||
|  |         h1 = random.randint(0, max(0, h - th)) | ||||||
|  |         w1 = random.randint(0, max(0, w - tw)) | ||||||
|  |         h2 = min(h1 + th, h) | ||||||
|  |         w2 = min(w1 + tw, w) | ||||||
|  | 
 | ||||||
|  |         for img in img_group: | ||||||
|  |             assert (img.shape[0] == h and img.shape[1] == w) | ||||||
|  |             out_images.append(img[h1:h2, w1:w2, ...]) | ||||||
|  |         return out_images | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupCenterCrop(object): | ||||||
|  |     def __init__(self, size): | ||||||
|  |         if isinstance(size, numbers.Number): | ||||||
|  |             self.size = (int(size), int(size)) | ||||||
|  |         else: | ||||||
|  |             self.size = size | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         h, w = img_group[0].shape[0:2] | ||||||
|  |         th, tw = self.size | ||||||
|  | 
 | ||||||
|  |         out_images = list() | ||||||
|  |         h1 = max(0, int((h - th) / 2)) | ||||||
|  |         w1 = max(0, int((w - tw) / 2)) | ||||||
|  |         h2 = min(h1 + th, h) | ||||||
|  |         w2 = min(w1 + tw, w) | ||||||
|  | 
 | ||||||
|  |         for img in img_group: | ||||||
|  |             assert (img.shape[0] == h and img.shape[1] == w) | ||||||
|  |             out_images.append(img[h1:h2, w1:w2, ...]) | ||||||
|  |         return out_images | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupRandomPad(object): | ||||||
|  |     def __init__(self, size, padding): | ||||||
|  |         if isinstance(size, numbers.Number): | ||||||
|  |             self.size = (int(size), int(size)) | ||||||
|  |         else: | ||||||
|  |             self.size = size | ||||||
|  |         self.padding = padding | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         assert (len(self.padding) == len(img_group)) | ||||||
|  |         h, w = img_group[0].shape[0:2] | ||||||
|  |         th, tw = self.size | ||||||
|  | 
 | ||||||
|  |         out_images = list() | ||||||
|  |         h1 = random.randint(0, max(0, th - h)) | ||||||
|  |         w1 = random.randint(0, max(0, tw - w)) | ||||||
|  |         h2 = max(th - h - h1, 0) | ||||||
|  |         w2 = max(tw - w - w1, 0) | ||||||
|  | 
 | ||||||
|  |         for img, padding in zip(img_group, self.padding): | ||||||
|  |             assert (img.shape[0] == h and img.shape[1] == w) | ||||||
|  |             out_images.append(cv2.copyMakeBorder( | ||||||
|  |                 img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding)) | ||||||
|  |             if len(img.shape) > len(out_images[-1].shape): | ||||||
|  |                 out_images[-1] = out_images[-1][..., | ||||||
|  |                                                 np.newaxis]  # single channel image | ||||||
|  |         return out_images | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupCenterPad(object): | ||||||
|  |     def __init__(self, size, padding): | ||||||
|  |         if isinstance(size, numbers.Number): | ||||||
|  |             self.size = (int(size), int(size)) | ||||||
|  |         else: | ||||||
|  |             self.size = size | ||||||
|  |         self.padding = padding | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         assert (len(self.padding) == len(img_group)) | ||||||
|  |         h, w = img_group[0].shape[0:2] | ||||||
|  |         th, tw = self.size | ||||||
|  | 
 | ||||||
|  |         out_images = list() | ||||||
|  |         h1 = max(0, int((th - h) / 2)) | ||||||
|  |         w1 = max(0, int((tw - w) / 2)) | ||||||
|  |         h2 = max(th - h - h1, 0) | ||||||
|  |         w2 = max(tw - w - w1, 0) | ||||||
|  | 
 | ||||||
|  |         for img, padding in zip(img_group, self.padding): | ||||||
|  |             assert (img.shape[0] == h and img.shape[1] == w) | ||||||
|  |             out_images.append(cv2.copyMakeBorder( | ||||||
|  |                 img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding)) | ||||||
|  |             if len(img.shape) > len(out_images[-1].shape): | ||||||
|  |                 out_images[-1] = out_images[-1][..., | ||||||
|  |                                                 np.newaxis]  # single channel image | ||||||
|  |         return out_images | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupConcerPad(object):  # sic; pads only bottom/right, anchoring the image at the top-left corner | ||||||
|  |     def __init__(self, size, padding): | ||||||
|  |         if isinstance(size, numbers.Number): | ||||||
|  |             self.size = (int(size), int(size)) | ||||||
|  |         else: | ||||||
|  |             self.size = size | ||||||
|  |         self.padding = padding | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         assert (len(self.padding) == len(img_group)) | ||||||
|  |         h, w = img_group[0].shape[0:2] | ||||||
|  |         th, tw = self.size | ||||||
|  | 
 | ||||||
|  |         out_images = list() | ||||||
|  |         h1 = 0 | ||||||
|  |         w1 = 0 | ||||||
|  |         h2 = max(th - h - h1, 0) | ||||||
|  |         w2 = max(tw - w - w1, 0) | ||||||
|  | 
 | ||||||
|  |         for img, padding in zip(img_group, self.padding): | ||||||
|  |             assert (img.shape[0] == h and img.shape[1] == w) | ||||||
|  |             out_images.append(cv2.copyMakeBorder( | ||||||
|  |                 img, h1, h2, w1, w2, cv2.BORDER_CONSTANT, value=padding)) | ||||||
|  |             if len(img.shape) > len(out_images[-1].shape): | ||||||
|  |                 out_images[-1] = out_images[-1][..., | ||||||
|  |                                                 np.newaxis]  # single channel image | ||||||
|  |         return out_images | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupRandomScaleNew(object): | ||||||
|  |     def __init__(self, size=(976, 208), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)): | ||||||
|  |         self.size = size | ||||||
|  |         self.interpolation = interpolation | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         assert (len(self.interpolation) == len(img_group)) | ||||||
|  |         scale_w, scale_h = self.size[0] * 1.0 / 1640, self.size[1] * 1.0 / 590 | ||||||
|  |         out_images = list() | ||||||
|  |         for img, interpolation in zip(img_group, self.interpolation): | ||||||
|  |             out_images.append(cv2.resize(img, None, fx=scale_w, | ||||||
|  |                                          fy=scale_h, interpolation=interpolation)) | ||||||
|  |             if len(img.shape) > len(out_images[-1].shape): | ||||||
|  |                 out_images[-1] = out_images[-1][..., | ||||||
|  |                                                 np.newaxis]  # single channel image | ||||||
|  |         return out_images | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupRandomScale(object): | ||||||
|  |     def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)): | ||||||
|  |         self.size = size | ||||||
|  |         self.interpolation = interpolation | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         assert (len(self.interpolation) == len(img_group)) | ||||||
|  |         scale = random.uniform(self.size[0], self.size[1]) | ||||||
|  |         out_images = list() | ||||||
|  |         for img, interpolation in zip(img_group, self.interpolation): | ||||||
|  |             out_images.append(cv2.resize(img, None, fx=scale, | ||||||
|  |                                          fy=scale, interpolation=interpolation)) | ||||||
|  |             if len(img.shape) > len(out_images[-1].shape): | ||||||
|  |                 out_images[-1] = out_images[-1][..., | ||||||
|  |                                                 np.newaxis]  # single channel image | ||||||
|  |         return out_images | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupRandomMultiScale(object): | ||||||
|  |     def __init__(self, size=(0.5, 1.5), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)): | ||||||
|  |         self.size = size | ||||||
|  |         self.interpolation = interpolation | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         assert (len(self.interpolation) == len(img_group)) | ||||||
|  |         scales = [0.5, 1.0, 1.5]  # random.uniform(self.size[0], self.size[1]) | ||||||
|  |         out_images = list() | ||||||
|  |         for scale in scales: | ||||||
|  |             for img, interpolation in zip(img_group, self.interpolation): | ||||||
|  |                 out_images.append(cv2.resize( | ||||||
|  |                     img, None, fx=scale, fy=scale, interpolation=interpolation)) | ||||||
|  |                 if len(img.shape) > len(out_images[-1].shape): | ||||||
|  |                     out_images[-1] = out_images[-1][..., | ||||||
|  |                                                     np.newaxis]  # single channel image | ||||||
|  |         return out_images | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupRandomScaleRatio(object): | ||||||
|  |     def __init__(self, size=(680, 762, 562, 592), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST)): | ||||||
|  |         self.size = size | ||||||
|  |         self.interpolation = interpolation | ||||||
|  |         self.origin_id = [0, 1360, 580, 768, 255, 300, 680, 710, 312, 1509, 800, 1377, 880, 910, 1188, 128, 960, 1784, | ||||||
|  |                           1414, 1150, 512, 1162, 950, 750, 1575, 708, 2111, 1848, 1071, 1204, 892, 639, 2040, 1524, 832, 1122, 1224, 2295] | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         assert (len(self.interpolation) == len(img_group)) | ||||||
|  |         w_scale = random.randint(self.size[0], self.size[1]) | ||||||
|  |         h_scale = random.randint(self.size[2], self.size[3]) | ||||||
|  |         h, w, _ = img_group[0].shape | ||||||
|  |         out_images = list() | ||||||
|  |         out_images.append(cv2.resize(img_group[0], None, fx=w_scale*1.0/w, fy=h_scale*1.0/h, | ||||||
|  |                                      interpolation=self.interpolation[0]))  # fx=w_scale*1.0/w, fy=h_scale*1.0/h | ||||||
|  |         ### process label map ### | ||||||
|  |         # the label image encodes classes as colours; hash the three channels | ||||||
|  |         # of each pixel into one integer (c0*5 + c1*3 + c2) and map it back to | ||||||
|  |         # a class index in [0, 36] via the precomputed origin_id table | ||||||
|  |         origin_label = cv2.resize( | ||||||
|  |             img_group[1], None, fx=w_scale*1.0/w, fy=h_scale*1.0/h, interpolation=self.interpolation[1]) | ||||||
|  |         origin_label = origin_label.astype(int) | ||||||
|  |         label = origin_label[:, :, 0] * 5 + \ | ||||||
|  |             origin_label[:, :, 1] * 3 + origin_label[:, :, 2] | ||||||
|  |         new_label = np.ones(label.shape) * 100 | ||||||
|  |         new_label = new_label.astype(int) | ||||||
|  |         for cnt in range(37): | ||||||
|  |             new_label = ( | ||||||
|  |                 label == self.origin_id[cnt]) * (cnt - 100) + new_label | ||||||
|  |         new_label = (label == self.origin_id[37]) * (36 - 100) + new_label | ||||||
|  |         assert(100 not in np.unique(new_label)) | ||||||
|  |         out_images.append(new_label) | ||||||
|  |         return out_images | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupRandomRotation(object): | ||||||
|  |     def __init__(self, degree=(-10, 10), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST), padding=None): | ||||||
|  |         self.degree = degree | ||||||
|  |         self.interpolation = interpolation | ||||||
|  |         self.padding = padding | ||||||
|  |         if self.padding is None: | ||||||
|  |             self.padding = [0, 0] | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         assert (len(self.interpolation) == len(img_group)) | ||||||
|  |         v = random.random() | ||||||
|  |         if v < 0.5: | ||||||
|  |             degree = random.uniform(self.degree[0], self.degree[1]) | ||||||
|  |             h, w = img_group[0].shape[0:2] | ||||||
|  |             center = (w / 2, h / 2) | ||||||
|  |             map_matrix = cv2.getRotationMatrix2D(center, degree, 1.0) | ||||||
|  |             out_images = list() | ||||||
|  |             for img, interpolation, padding in zip(img_group, self.interpolation, self.padding): | ||||||
|  |                 out_images.append(cv2.warpAffine( | ||||||
|  |                     img, map_matrix, (w, h), flags=interpolation, borderMode=cv2.BORDER_CONSTANT, borderValue=padding)) | ||||||
|  |                 if len(img.shape) > len(out_images[-1].shape): | ||||||
|  |                     out_images[-1] = out_images[-1][..., | ||||||
|  |                                                     np.newaxis]  # single channel image | ||||||
|  |             return out_images | ||||||
|  |         else: | ||||||
|  |             return img_group | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupRandomBlur(object): | ||||||
|  |     def __init__(self, applied): | ||||||
|  |         self.applied = applied | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         assert (len(self.applied) == len(img_group)) | ||||||
|  |         v = random.random() | ||||||
|  |         if v < 0.5: | ||||||
|  |             out_images = [] | ||||||
|  |             for img, a in zip(img_group, self.applied): | ||||||
|  |                 if a: | ||||||
|  |                     img = cv2.GaussianBlur( | ||||||
|  |                         img, (5, 5), random.uniform(1e-6, 0.6)) | ||||||
|  |                 out_images.append(img) | ||||||
|  |                 if len(img.shape) > len(out_images[-1].shape): | ||||||
|  |                     out_images[-1] = out_images[-1][..., | ||||||
|  |                                                     np.newaxis]  # single channel image | ||||||
|  |             return out_images | ||||||
|  |         else: | ||||||
|  |             return img_group | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupRandomHorizontalFlip(object): | ||||||
|  |     """Randomly horizontally flips the given numpy Image with a probability of 0.5 | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, is_flow=False): | ||||||
|  |         self.is_flow = is_flow | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group, is_flow=False): | ||||||
|  |         v = random.random() | ||||||
|  |         if v < 0.5: | ||||||
|  |             out_images = [np.fliplr(img) for img in img_group] | ||||||
|  |             if self.is_flow: | ||||||
|  |                 for i in range(0, len(out_images), 2): | ||||||
|  |                     # invert flow pixel values when flipping | ||||||
|  |                     out_images[i] = -out_images[i] | ||||||
|  |             return out_images | ||||||
|  |         else: | ||||||
|  |             return img_group | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GroupNormalize(object): | ||||||
|  |     def __init__(self, mean, std): | ||||||
|  |         self.mean = mean | ||||||
|  |         self.std = std | ||||||
|  | 
 | ||||||
|  |     def __call__(self, img_group): | ||||||
|  |         out_images = list() | ||||||
|  |         for img, m, s in zip(img_group, self.mean, self.std): | ||||||
|  |             if len(m) == 1: | ||||||
|  |                 img = img - np.array(m)  # single channel image | ||||||
|  |                 img = img / np.array(s) | ||||||
|  |             else: | ||||||
|  |                 img = img - np.array(m)[np.newaxis, np.newaxis, ...] | ||||||
|  |                 img = img / np.array(s)[np.newaxis, np.newaxis, ...] | ||||||
|  |             out_images.append(img) | ||||||
|  | 
 | ||||||
|  |         return out_images | ||||||
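|  | 
 | ||||||
|  | 
 | ||||||
|  | # A minimal usage sketch (shapes and statistics are illustrative): the Group* | ||||||
|  | # transforms act jointly on an (image, label) pair, and nearest-neighbour | ||||||
|  | # interpolation is kept for the label map so class ids are never blended: | ||||||
|  | # | ||||||
|  | #   import numpy as np | ||||||
|  | #   img = np.zeros((590, 1640, 3), dtype=np.uint8) | ||||||
|  | #   label = np.zeros((590, 1640), dtype=np.uint8) | ||||||
|  | #   for t in [GroupRandomScale((0.5, 1.5)), | ||||||
|  | #             GroupRandomHorizontalFlip(), | ||||||
|  | #             GroupNormalize(([103.939, 116.779, 123.68], (0, )), | ||||||
|  | #                            ([1., 1., 1.], (1, )))]: | ||||||
|  | #       img, label = t([img, label]) | ||||||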