@@ -58,7 +58,7 @@ LearnerTorchVision = R6Class("LearnerTorchVision",
5858 private = list (
5959 .module_generator = NULL ,
6060 .network = function (task , param_vals ) {
61- nout = get_nout (task )
61+ nout = output_dim_for (task )
6262 if (param_vals $ pretrained ) {
6363 network = replace_head(private $ .module_generator(pretrained = TRUE ), nout )
6464 return (network )
@@ -107,126 +107,126 @@ replace_head.VGG = function(network, d_out) {
107107}
108108
109109# ' @include aaa.R
110- register_learner(" classif.alexnet" ,
110+ register_learner(" classif.alexnet" ,
111111 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
112112 LearnerTorchVision $ new(" alexnet" , torchvision :: model_alexnet , " AlexNet" ,
113113 loss = loss , optimizer = optimizer , callbacks = callbacks )
114114 }
115115)
116116
117- # register_learner("classif.inception_v3",
117+ # register_learner("classif.inception_v3",
118118# function(loss = NULL, optimizer = NULL, callbacks = list()) {
119119# LearnerTorchVision$new("inception_v3", torchvision::model_inception_v3, "Inception V3",
120120# loss = loss, optimizer = optimizer, callbacks = callbacks)
121121# }
122122# )
123123
124- register_learner(" classif.mobilenet_v2" ,
124+ register_learner(" classif.mobilenet_v2" ,
125125 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
126126 LearnerTorchVision $ new(" mobilenet_v2" , torchvision :: model_mobilenet_v2 , " Mobilenet V2" ,
127127 loss = loss , optimizer = optimizer , callbacks = callbacks )
128128 }
129129)
130130
131- register_learner(" classif.resnet18" ,
131+ register_learner(" classif.resnet18" ,
132132 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
133133 LearnerTorchVision $ new(" resnet18" , torchvision :: model_resnet18 , " ResNet-18" ,
134134 loss = loss , optimizer = optimizer , callbacks = callbacks )
135135 }
136136)
137137
138- register_learner(" classif.resnet34" ,
138+ register_learner(" classif.resnet34" ,
139139 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
140140 LearnerTorchVision $ new(" resnet34" , torchvision :: model_resnet34 , " ResNet-34" ,
141141 loss = loss , optimizer = optimizer , callbacks = callbacks )
142142 }
143143)
144144
145- register_learner(" classif.resnet50" ,
145+ register_learner(" classif.resnet50" ,
146146 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
147147 LearnerTorchVision $ new(" resnet50" , torchvision :: model_resnet50 , " ResNet-50" ,
148148 loss = loss , optimizer = optimizer , callbacks = callbacks )
149149 }
150150)
151151
152- register_learner(" classif.resnet101" ,
152+ register_learner(" classif.resnet101" ,
153153 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
154154 LearnerTorchVision $ new(" resnet101" , torchvision :: model_resnet101 , " ResNet-101" ,
155155 loss = loss , optimizer = optimizer , callbacks = callbacks )
156156 }
157157)
158158
159- register_learner(" classif.resnet152" ,
159+ register_learner(" classif.resnet152" ,
160160 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
161161 LearnerTorchVision $ new(" resnet152" , torchvision :: model_resnet152 , " ResNet-152" ,
162162 loss = loss , optimizer = optimizer , callbacks = callbacks )
163163 }
164164)
165165
166- register_learner(" classif.resnext101_32x8d" ,
166+ register_learner(" classif.resnext101_32x8d" ,
167167 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
168168 LearnerTorchVision $ new(" resnext101_32x8d" , torchvision :: model_resnext101_32x8d , " ResNeXt-101 32x8d" ,
169169 loss = loss , optimizer = optimizer , callbacks = callbacks )
170170 }
171171)
172172
173- register_learner(" classif.resnext50_32x4d" ,
173+ register_learner(" classif.resnext50_32x4d" ,
174174 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
175175 LearnerTorchVision $ new(" resnext50_32x4d" , torchvision :: model_resnext50_32x4d , " ResNeXt-50 32x4d" ,
176176 loss = loss , optimizer = optimizer , callbacks = callbacks )
177177 }
178178)
179179
180- register_learner(" classif.vgg11" ,
180+ register_learner(" classif.vgg11" ,
181181 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
182182 LearnerTorchVision $ new(" vgg11" , torchvision :: model_vgg11 , " VGG 11" ,
183183 loss = loss , optimizer = optimizer , callbacks = callbacks )
184184 }
185185)
186186
187- register_learner(" classif.vgg11_bn" ,
187+ register_learner(" classif.vgg11_bn" ,
188188 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
189189 LearnerTorchVision $ new(" vgg11_bn" , torchvision :: model_vgg11_bn , " VGG 11" ,
190190 loss = loss , optimizer = optimizer , callbacks = callbacks )
191191 }
192192)
193193
194- register_learner(" classif.vgg13" ,
194+ register_learner(" classif.vgg13" ,
195195 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
196196 LearnerTorchVision $ new(" vgg13" , torchvision :: model_vgg13 , " VGG 13" ,
197197 loss = loss , optimizer = optimizer , callbacks = callbacks )
198198 }
199199)
200200
201- register_learner(" classif.vgg13_bn" ,
201+ register_learner(" classif.vgg13_bn" ,
202202 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
203203 LearnerTorchVision $ new(" vgg13_bn" , torchvision :: model_vgg13_bn , " VGG 13" ,
204204 loss = loss , optimizer = optimizer , callbacks = callbacks )
205205 }
206206)
207207
208- register_learner(" classif.vgg16" ,
208+ register_learner(" classif.vgg16" ,
209209 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
210210 LearnerTorchVision $ new(" vgg16" , torchvision :: model_vgg16 , " VGG 16" ,
211211 loss = loss , optimizer = optimizer , callbacks = callbacks )
212212 }
213213)
214214
215- register_learner(" classif.vgg16_bn" ,
215+ register_learner(" classif.vgg16_bn" ,
216216 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
217217 LearnerTorchVision $ new(" vgg16_bn" , torchvision :: model_vgg16_bn , " VGG 16" ,
218218 loss = loss , optimizer = optimizer , callbacks = callbacks )
219219 }
220220)
221221
222- register_learner(" classif.vgg19" ,
222+ register_learner(" classif.vgg19" ,
223223 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
224224 LearnerTorchVision $ new(" vgg19" , torchvision :: model_vgg19 , " VGG 19" ,
225225 loss = loss , optimizer = optimizer , callbacks = callbacks )
226226 }
227227)
228228
229- register_learner(" classif.vgg19_bn" ,
229+ register_learner(" classif.vgg19_bn" ,
230230 function (loss = NULL , optimizer = NULL , callbacks = list ()) {
231231 LearnerTorchVision $ new(" vgg19_bn" , torchvision :: model_vgg19_bn , " VGG 19" ,
232232 loss = loss , optimizer = optimizer , callbacks = callbacks )
0 commit comments