Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions python/caffe/pycaffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,16 @@ using boost::python::vector_indexing_suite;
class CaffeBlob {
public:

// Wraps `blob` and records `name` as the name reported by name().
// The shared_ptr is copied, so this wrapper shares ownership of the blob.
CaffeBlob(const shared_ptr<Blob<float> > &blob, const string& name)
  : blob_(blob),
    name_(name) {
}

// Wraps `blob` without a name: name_ is value-initialized to "", so name()
// returns an empty string for wrappers built this way.
CaffeBlob(const shared_ptr<Blob<float> > &blob)
  : blob_(blob),
    name_() {
}

CaffeBlob()
{}

// Name attached to this wrapper at construction; "" if none was supplied.
// Returned by value (a copy of name_).
string name() const {
  return name_;
}
// Forwards to num() of the wrapped blob; blob_ must be non-null.
int num() const {
  return (*blob_).num();
}
// Forwards to channels() of the wrapped blob; blob_ must be non-null.
int channels() const {
  return (*blob_).channels();
}
// Forwards to height() of the wrapped blob; blob_ must be non-null.
int height() const {
  return (*blob_).height();
}
Expand All @@ -51,6 +55,7 @@ class CaffeBlob {

protected:
// Shared handle to the wrapped blob; empty (null) for a default-constructed
// wrapper, in which case the dimension accessors must not be called.
shared_ptr<Blob<float> > blob_;
// Name returned by name(); empty unless given to the two-argument
// constructor.
string name_;
};


Expand Down Expand Up @@ -219,15 +224,27 @@ struct CaffeNet
// Forwards device_id unchanged to Caffe::SetDevice.
void set_device(int device_id) {
  Caffe::SetDevice(device_id);
}

// Returns one CaffeBlob per entry of net_->blobs(), each tagged with the entry
// of net_->blob_names() at the same index so the name is exposed through
// CaffeBlob::name.
// Throws std::out_of_range if blob_names() has fewer entries than blobs().
vector<CaffeBlob> blobs() {
  const vector<shared_ptr<Blob<float> > >& net_blobs = net_->blobs();
  const vector<string>& names = net_->blob_names();
  vector<CaffeBlob> result;
  result.reserve(net_blobs.size());
  for (size_t i = 0; i < net_blobs.size(); ++i) {
    // at() turns a blobs()/blob_names() size mismatch into an exception
    // instead of an out-of-bounds read.
    result.push_back(CaffeBlob(net_blobs[i], names.at(i)));
  }
  return result;
}

// Returns one CaffeBlob per entry of net_->params(), each tagged with the name
// of the layer whose blobs() it is paired with.
// NOTE(review): the pairing assumes net_->params() lists every layer's blobs
// consecutively in layer order (all of layer 0's blobs, then layer 1's, ...).
// Confirm against Net's construction of params().
// Throws std::out_of_range if params() or layer_names() is shorter than this
// walk over the layers expects, instead of reading past the end.
vector<CaffeBlob> params() {
  const vector<shared_ptr<Blob<float> > >& net_params = net_->params();
  const vector<string>& names = net_->layer_names();
  vector<CaffeBlob> result;
  result.reserve(net_params.size());
  size_t param_ix = 0;  // running index into net_params across all layers
  for (size_t i = 0; i < net_->layers().size(); ++i) {
    const size_t num_layer_blobs = net_->layers()[i]->blobs().size();
    for (size_t j = 0; j < num_layer_blobs; ++j) {
      result.push_back(CaffeBlob(net_params.at(param_ix), names.at(i)));
      ++param_ix;
    }
  }
  return result;
}

// The pointer to the internal caffe::Net instance.
// Declared once: a second identical declaration is a redeclaration error.
shared_ptr<Net<float> > net_;
};


Expand All @@ -251,6 +268,7 @@ BOOST_PYTHON_MODULE(pycaffe)

boost::python::class_<CaffeBlob, CaffeBlobWrap>(
"CaffeBlob", boost::python::no_init)
.add_property("name", &CaffeBlob::name)
.add_property("num", &CaffeBlob::num)
.add_property("channels", &CaffeBlob::channels)
.add_property("height", &CaffeBlob::height)
Expand Down