Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 40 additions & 11 deletions paddlenlp/datasets/cote.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

class Cote(DatasetBuilder):
"""
COTE_DP dataset for Opinion Role Labeling task.
COTE_DP/COTE-BD/COTE-MFW dataset for Opinion Role Labeling task.
More information please refer to https://aistudio.baidu.com/aistudio/competition/detail/50/?isFromLuge=1.

"""
Expand All @@ -37,22 +37,52 @@ class Cote(DatasetBuilder):
'splits': {
'train': [
os.path.join('COTE-DP', 'train.tsv'),
'17d11ca91b7979f2c2023757650096e5', (0, 1), 1
'17d11ca91b7979f2c2023757650096e5'
],
'test': [
os.path.join('COTE-DP', 'test.tsv'),
'5bb9b9ccaaee6bcc1ac7a6c852b46f66', (1, ), 1
'5bb9b9ccaaee6bcc1ac7a6c852b46f66'
],
},
'labels': ["B", "I", "O"]
},
'bd': {
'url': "https://dataset-bj.cdn.bcebos.com/qianyan/COTE-BD.zip",
'md5': "8d87ff9bb6f5e5d46269d72632a1b01f",
'splits': {
'train': [
os.path.join('COTE-BD', 'train.tsv'),
'4c08ccbcc373cb3bf05c3429d435f608'
],
'test': [
os.path.join('COTE-BD', 'test.tsv'),
'aeb5c9af61488dadb12cbcc1d2180667'
],
},
'labels': ["B", "I", "O"]
},
'mfw': {
'url': "https://dataset-bj.cdn.bcebos.com/qianyan/COTE-MFW.zip",
'md5': "c85326bf2be4424d03373ea70cb32c3f",
'splits': {
'train': [
os.path.join('COTE-MFW', 'train.tsv'),
'01fc90b9098d35615df6b8d257eb46ca'
],
'test': [
os.path.join('COTE-MFW', 'test.tsv'),
'c61a475917a461089db141c59c688343'
],
},
'labels': ["B", "I", "O"]
}
}

def _get_data(self, mode, **kwargs):
"""Downloads dataset."""
builder_config = self.BUILDER_CONFIGS[self.name]
default_root = os.path.join(DATA_HOME, 'COTE-DP')
filename, data_hash, _, _ = builder_config['splits'][mode]
default_root = os.path.join(DATA_HOME, f'COTE-{self.name.upper()}')
filename, data_hash = builder_config['splits'][mode]
fullname = os.path.join(default_root, filename)
if not os.path.exists(fullname) or (data_hash and
not md5file(fullname) == data_hash):
Expand All @@ -64,21 +94,19 @@ def _get_data(self, mode, **kwargs):

def _read(self, filename, split):
"""Reads data"""
_, _, field_indices, num_discard_samples = self.BUILDER_CONFIGS[
self.name]['splits'][split]
with open(filename, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if idx < num_discard_samples:
if idx == 0:
# ignore first line about title
continue
line_stripped = line.strip().split('\t')
if not line_stripped:
continue
example = [line_stripped[indice] for indice in field_indices]
if split == "test":
yield {"tokens": list(example[0])}
yield {"tokens": list(line_stripped[1])}
else:
try:
entity, text = example[0], example[1]
entity, text = line_stripped[0], line_stripped[1]
start_idx = text.index(entity)
except:
# drop the dirty data
Expand All @@ -94,6 +122,7 @@ def _read(self, filename, split):
"entity": entity
}


def get_labels(self):
"""
Return labels of the COTE.
Expand Down