# Source code for gluonnlp.data.baidu_ernie_data
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Baidu ernie data, contains XNLI."""
__all__ = ['BaiduErnieXNLI', 'BaiduErnieLCQMC', 'BaiduErnieChnSentiCorp']
import os
import tarfile
from urllib.request import urlretrieve
from ..base import get_home_dir
from .dataset import TSVDataset
from .registry import register
# Location of the task-data tarball that _BaiduErnieDataset downloads when it
# is constructed without an explicit local ``filename``.
_baidu_ernie_data_url = 'https://ernie.bj.bcebos.com/task_data_zh.tgz'
class _BaiduErnieDataset(TSVDataset):
    """Common loader for the TSV task datasets distributed with Baidu ERNIE.

    The dataset is read either from an explicit local file (``filename``) or,
    when ``filename`` is not given, from the archive at
    ``_baidu_ernie_data_url``: the archive is stored as ``task_data.tgz``
    under ``root`` (downloaded only if that file does not exist yet), unpacked
    into ``root``, and ``<root>/task_data/<dataset_name>/<segment>.tsv`` is
    loaded.

    Parameters
    ----------
    root : str, optional
        Directory used to store the downloaded archive and its contents.
        ``~`` is expanded. Required when ``filename`` is not given.
    dataset_name : str, optional
        Name of the sub-directory of ``task_data`` holding the dataset.
        Required when ``filename`` is not given.
    segment : str, optional
        Base name (without ``.tsv``) of the file to load. Required when
        ``filename`` is not given.
    filename : str, optional
        Path to an already available TSV file. When given, nothing is
        downloaded and the other three arguments are ignored.
    **kwargs
        Passed through unchanged to ``TSVDataset.__init__``.
    """

    def __init__(self, root=None, dataset_name=None, segment=None, filename=None, **kwargs):
        assert (filename or (root and dataset_name and segment)), \
            'Either `filename` or all of `root`, `dataset_name` and `segment` must be given.'
        if not filename:
            root = os.path.expanduser(root)
            os.makedirs(root, exist_ok=True)
            self._root = root
            download_data_path = os.path.join(self._root, 'task_data.tgz')
            if not os.path.exists(download_data_path):
                # Download under a temporary name and rename only on success,
                # so an interrupted transfer cannot leave a truncated archive
                # that later runs would mistake for a complete one.
                partial_path = download_data_path + '.part'
                urlretrieve(_baidu_ernie_data_url, partial_path)
                os.replace(partial_path, download_data_path)
            # The context manager closes the archive even if extraction fails.
            with tarfile.open(download_data_path, mode='r:gz') as tar_file:
                tar_file.extractall(self._root)
            filename = os.path.join(self._root, 'task_data', dataset_name, '%s.tsv' % segment)
        super(_BaiduErnieDataset, self).__init__(filename, **kwargs)
@register(segment=['train', 'dev', 'test'])
class BaiduErnieXNLI(_BaiduErnieDataset):
    """ The XNLI dataset redistributed by Baidu
    <https://github.com/PaddlePaddle/LARK/tree/develop/ERNIE>.

    Original from:
    Conneau, Alexis, et al. "Xnli: Evaluating cross-lingual sentence representations."
    arXiv preprint arXiv:1809.05053 (2018).
    https://github.com/facebookresearch/XNLI

    Licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
    License details: https://creativecommons.org/licenses/by-nc/4.0/

    Parameters
    ----------
    segment : {'train', 'dev', 'test'}, default 'train'
        Dataset segment.
    root : str, default '$MXNET_HOME/datasets/baidu_ernie_data'
        Path to temp folder for storing data.
        MXNET_HOME defaults to '~/.mxnet'.
    return_all_fields : bool, default False
        Return all fields available in the dataset.

    Raises
    ------
    ValueError
        If `segment` is not one of 'train', 'dev' or 'test'.

    Examples
    --------
    >>> xnli_dev = BaiduErnieXNLI('dev', root='./datasets/baidu_ernie_task_data/')
    >>> len(xnli_dev)
    2490
    >>> len(xnli_dev[0])
    3
    >>> xnli_dev[0]
    ['他说,妈妈,我回来了。', '校车把他放下后,他立即给他妈妈打了电话。', 'neutral']
    >>> xnli_test = BaiduErnieXNLI('test', root='./datasets/baidu_ernie_task_data/')
    >>> len(xnli_test)
    5010
    >>> len(xnli_test[0])
    2
    >>> xnli_test[0]
    ['嗯,我根本没想过,但是我很沮丧,最后我又和他说话了。', '我还没有和他再次谈论。']
    """

    def __init__(self, segment='train', root=os.path.join(get_home_dir(), 'datasets',
                                                          'baidu_ernie_data'),
                 return_all_fields=False):
        A_IDX, B_IDX, LABEL_IDX = 0, 1, 2
        if segment in ['train', 'dev']:
            field_indices = [A_IDX, B_IDX, LABEL_IDX] if not return_all_fields else None
        elif segment == 'test':
            # The default test view keeps only the sentence pair, no label column.
            field_indices = [A_IDX, B_IDX] if not return_all_fields else None
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError('Unsupported segment %r; expected \'train\', \'dev\' or \'test\'.'
                             % (segment,))
        # Every segment drops its first row (presumably the TSV header line).
        num_discard_samples = 1
        super(BaiduErnieXNLI,
              self).__init__(root, 'xnli', segment, num_discard_samples=num_discard_samples,
                             field_indices=field_indices)
@register(segment=['train', 'dev', 'test'])
class BaiduErnieLCQMC(_BaiduErnieDataset):
    """ The LCQMC dataset original from:
    Xin Liu, Qingcai Chen, Chong Deng, Huajun Zeng, Jing Chen, Dongfang Li, Buzhou Tang,
    LCQMC: A Large-scale Chinese Question Matching Corpus,COLING2018.

    No license granted. You can request a private license via
    http://icrc.hitsz.edu.cn/LCQMC_Application_Form.pdf

    The code fits the dataset format which was redistributed by Baidu in ERNIE repo.
    (Baidu does not hold this version any more.)

    Parameters
    ----------
    file_path : str
        Path to the downloaded dataset file.
    segment : {'train', 'dev', 'test'}, default 'train'
        Dataset segment. It only selects which columns are returned; the data
        itself is always read from `file_path`.
    return_all_fields : bool, default False
        Return all fields available in the dataset.

    Raises
    ------
    ValueError
        If `segment` is not one of 'train', 'dev' or 'test'.
    """

    def __init__(self, file_path, segment='train', return_all_fields=False):
        A_IDX, B_IDX, LABEL_IDX = 0, 1, 2
        if segment in ['train', 'dev']:
            field_indices = [A_IDX, B_IDX, LABEL_IDX] if not return_all_fields else None
        elif segment == 'test':
            # The default test view keeps only the sentence pair, no label column.
            field_indices = [A_IDX, B_IDX] if not return_all_fields else None
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError('Unsupported segment %r; expected \'train\', \'dev\' or \'test\'.'
                             % (segment,))
        # Every segment drops its first row (presumably the TSV header line).
        num_discard_samples = 1
        super(BaiduErnieLCQMC,
              self).__init__(filename=file_path, num_discard_samples=num_discard_samples,
                             field_indices=field_indices)
@register(segment=['train', 'dev', 'test'])
class BaiduErnieChnSentiCorp(_BaiduErnieDataset):
    """ The ChnSentiCorp dataset redistributed by Baidu
    <https://github.com/PaddlePaddle/LARK/tree/develop/ERNIE>.

    Original from Tan Songbo (Chinese Academy of Sciences, tansongbo@software.ict.ac.cn).

    Parameters
    ----------
    segment : {'train', 'dev', 'test'}, default 'train'
        Dataset segment.
    root : str, default '$MXNET_HOME/datasets/baidu_ernie_data'
        Path to temp folder for storing data.
        MXNET_HOME defaults to '~/.mxnet'.
    return_all_fields : bool, default False
        Return all fields available in the dataset.

    Raises
    ------
    ValueError
        If `segment` is not one of 'train', 'dev' or 'test'.

    Examples
    --------
    >>> chnsenticorp_dev = BaiduErnieChnSentiCorp('dev', root='./datasets/baidu_ernie_task_data/')
    >>> len(chnsenticorp_dev)
    1200
    >>> len(chnsenticorp_dev[0])
    2
    >>> chnsenticorp_dev[2]
    ['商品的不足暂时还没发现,京东的订单处理速度实在.......周二就打包完成,周五才发货...', '0']
    >>> chnsenticorp_test = BaiduErnieChnSentiCorp('test', root='./datasets/baidu_ernie_task_data/')
    >>> len(chnsenticorp_test)
    1200
    >>> len(chnsenticorp_test[0])
    1
    >>> chnsenticorp_test[0]
    ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']
    """

    def __init__(self, segment='train', root=os.path.join(get_home_dir(), 'datasets',
                                                          'baidu_ernie_data'),
                 return_all_fields=False):
        # In this file the label comes first and the text second; the selected
        # fields are reordered to (text, label).
        LABEL_IDX, A_IDX = 0, 1
        if segment in ['train', 'dev']:
            field_indices = [A_IDX, LABEL_IDX] if not return_all_fields else None
        elif segment == 'test':
            # The default test view keeps only the text column, no label.
            field_indices = [A_IDX] if not return_all_fields else None
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError('Unsupported segment %r; expected \'train\', \'dev\' or \'test\'.'
                             % (segment,))
        # Every segment drops its first row (presumably the TSV header line).
        num_discard_samples = 1
        super(BaiduErnieChnSentiCorp,
              self).__init__(root, 'chnsenticorp', segment,
                             num_discard_samples=num_discard_samples, field_indices=field_indices)