src/main/java/io/anserini/util/FeatureVector.java: 30 changes (13 additions, 17 deletions)
@@ -16,18 +16,18 @@
 
 package io.anserini.util;
 
-import it.unimi.dsi.fastutil.objects.Object2FloatOpenHashMap;
-
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 
 public class FeatureVector {
-  private Object2FloatOpenHashMap<String> features = new Object2FloatOpenHashMap<String>();
+  private Map<String, Float> features = new HashMap<>();
 
   public enum Order {
     FEATURE_DESCENDING, FEATURE_ASCENDING, VALUE_DESCENDING, VALUE_ASCENDING
@@ -36,16 +36,12 @@ public enum Order {
   public FeatureVector() {}
 
   public void addFeatureValue(String feature, float value) {
-    if (!features.containsKey(feature)) {
-      features.put(feature, value);
-    } else {
-      features.put(feature, features.getFloat(feature) + value);
-    }
+    features.put(feature, features.getOrDefault(feature, 0.0f) + value);
   }
 
   public FeatureVector pruneToSize(int k) {
     List<FeatureValuePair> pairs = getOrderedFeatures();
-    Object2FloatOpenHashMap<String> pruned = new Object2FloatOpenHashMap<>();
+    Map<String, Float> pruned = new HashMap<>();
 
     for (FeatureValuePair pair : pairs) {
       pruned.put(pair.getFeature(), pair.getValue());
@@ -60,17 +56,17 @@ public FeatureVector pruneToSize(int k) {
 
   public FeatureVector scaleToUnitL2Norm() {
     double norm = computeL2Norm();
-    for (String f : features.keySet()) {
-      features.put(f, (float) (features.getFloat(f) / norm));
+    for (Map.Entry<String, Float> e : features.entrySet()) {
+      e.setValue((float) (e.getValue() / norm));
     }
 
     return this;
   }
 
   public FeatureVector scaleToUnitL1Norm() {
     double norm = computeL1Norm();
-    for (String f : features.keySet()) {
-      features.put(f, (float) (features.getFloat(f) / norm));
+    for (Map.Entry<String, Float> e : features.entrySet()) {
+      e.setValue((float) (e.getValue() / norm));
     }
 
     return this;
@@ -81,7 +77,7 @@ public Set<String> getFeatures() {
   }
 
   public float getValue(String feature) {
-    return features.containsKey(feature) ? features.getFloat(feature) : 0.0f;
+    return features.getOrDefault(feature, 0.0f);
   }
 
   public Iterator<String> iterator() {
@@ -95,15 +91,15 @@ public boolean contains(String feature) {
   public double computeL2Norm() {
     double norm = 0.0;
     for (String term : features.keySet()) {
-      norm += Math.pow(features.getFloat(term), 2.0);
+      norm += Math.pow(features.getOrDefault(term, 0.0f), 2.0);
     }
     return Math.sqrt(norm);
   }
 
   public double computeL1Norm() {
     double norm = 0.0;
     for (String term : features.keySet()) {
-      norm += Math.abs(features.getFloat(term));
+      norm += Math.abs(features.getOrDefault(term, 0.0f));
     }
     return norm;
   }
@@ -125,7 +121,7 @@ private List<FeatureValuePair> getOrderedFeatures(Order order) {
     Iterator<String> featureIterator = features.keySet().iterator();
     while (featureIterator.hasNext()) {
       String feature = featureIterator.next();
-      float value = features.getFloat(feature);
+      float value = features.getOrDefault(feature, 0.0f);
       FeatureValuePair featureValuePair = new FeatureValuePair(feature, value);
       pairs.add(featureValuePair);
     }
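For context, here is a minimal usage sketch of the class as it looks after this change, using only methods visible in the diff (addFeatureValue, getValue, scaleToUnitL2Norm, computeL2Norm). The demo class, main method, and feature names are illustrative assumptions and are not part of the PR.

```java
// Hypothetical demo, not part of this PR: exercises the refactored FeatureVector,
// which now backs its features with a plain java.util.HashMap.
import io.anserini.util.FeatureVector;

public class FeatureVectorDemo {
  public static void main(String[] args) {
    FeatureVector fv = new FeatureVector();

    // addFeatureValue accumulates weights for repeated features via getOrDefault.
    fv.addFeatureValue("retrieval", 1.0f);
    fv.addFeatureValue("retrieval", 2.0f);   // "retrieval" -> 3.0
    fv.addFeatureValue("ranking", 4.0f);

    System.out.println(fv.getValue("retrieval"));  // 3.0
    System.out.println(fv.getValue("missing"));    // 0.0, the getOrDefault fallback

    // In-place scaling now mutates entries through Map.Entry.setValue.
    fv.scaleToUnitL2Norm();
    System.out.println(fv.computeL2Norm());        // ~1.0
  }
}
```

Iterating over entrySet() and calling setValue avoids the second hash lookup that the old keySet()-plus-put pattern incurred on every feature during normalization.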