# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=
"""Machine translation datasets."""
__all__ = ['IWSLT2015', 'WMT2014', 'WMT2014BPE', 'WMT2016', 'WMT2016BPE']
import os
import zipfile
import shutil
import io
from mxnet.gluon.utils import download, check_sha1, _get_repo_file_url
from mxnet.gluon.data import ArrayDataset
from .dataset import TextLineDataset
from ..vocab import Vocab
from .registry import register
from ..base import get_home_dir


def _get_pair_key(src_lang, tgt_lang):
return '_'.join(sorted([src_lang, tgt_lang]))
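

# Note: _get_pair_key sorts the two language codes, so both translation
# directions map to the same key, e.g. _get_pair_key('en', 'de') ==
# _get_pair_key('de', 'en') == 'de_en', and a single archive entry serves
# both directions.
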
class _TranslationDataset(ArrayDataset):
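    """Base class for the translation datasets.

    Subclasses define ``_supported_segments``, ``_archive_file`` and
    ``_data_file`` to describe the downloadable files. This class handles
    downloading and extracting the archive, filtering out empty source/target
    sentence pairs, and lazily loading the vocabularies when the dataset
    provides them.
    """
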
def __init__(self, namespace, segment, src_lang, tgt_lang, root):
assert _get_pair_key(src_lang, tgt_lang) in self._archive_file, \
'The given language combination: src_lang={}, tgt_lang={}, is not supported. ' \
'Only supports language pairs = {}.'.format(
src_lang, tgt_lang, str(self._archive_file.keys()))
if isinstance(segment, str):
assert segment in self._supported_segments, \
'Only supports {} for the segment. Received segment={}'.format(
self._supported_segments, segment)
else:
for ele_segment in segment:
assert ele_segment in self._supported_segments, \
'segment should only contain elements in {}. Received segment={}'.format(
self._supported_segments, segment)
self._namespace = 'gluon/dataset/{}'.format(namespace)
self._segment = segment
self._src_lang = src_lang
self._tgt_lang = tgt_lang
self._src_vocab = None
self._tgt_vocab = None
self._pair_key = _get_pair_key(src_lang, tgt_lang)
root = os.path.expanduser(root)
os.makedirs(root, exist_ok=True)
self._root = root
if isinstance(segment, str):
segment = [segment]
src_corpus = []
tgt_corpus = []
for ele_segment in segment:
[src_corpus_path, tgt_corpus_path] = self._get_data(ele_segment)
src_corpus.extend(TextLineDataset(src_corpus_path))
tgt_corpus.extend(TextLineDataset(tgt_corpus_path))
# Filter 0-length src/tgt sentences
src_lines = []
tgt_lines = []
for src_line, tgt_line in zip(list(src_corpus), list(tgt_corpus)):
if len(src_line) > 0 and len(tgt_line) > 0:
src_lines.append(src_line)
tgt_lines.append(tgt_line)
super(_TranslationDataset, self).__init__(src_lines, tgt_lines)

    def _fetch_data_path(self, file_name_hashs):
archive_file_name, archive_hash = self._archive_file[self._pair_key]
paths = []
root = self._root
for data_file_name, data_hash in file_name_hashs:
path = os.path.join(root, data_file_name)
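            # Download and extract the archive only when the target file is
            # missing or fails its SHA-1 check; extracted files are cached
            # under `root`.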
if not os.path.exists(path) or not check_sha1(path, data_hash):
downloaded_file_path = download(_get_repo_file_url(self._namespace,
archive_file_name),
path=root,
sha1_hash=archive_hash)
with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
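                    # Flatten the archive: write each member directly into
                    # `root`, dropping any directory components in its path.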
for member in zf.namelist():
filename = os.path.basename(member)
if filename:
dest = os.path.join(root, filename)
with zf.open(member) as source, \
open(dest, 'wb') as target:
shutil.copyfileobj(source, target)
paths.append(path)
return paths

    def _get_data(self, segment):
src_corpus_file_name, src_corpus_hash =\
self._data_file[self._pair_key][segment + '_' + self._src_lang]
tgt_corpus_file_name, tgt_corpus_hash =\
self._data_file[self._pair_key][segment + '_' + self._tgt_lang]
return self._fetch_data_path([(src_corpus_file_name, src_corpus_hash),
(tgt_corpus_file_name, tgt_corpus_hash)])

    @property
def src_vocab(self):
"""Source Vocabulary of the Dataset.
Returns
-------
src_vocab : Vocab
Source vocabulary.
"""
if self._src_vocab is None:
src_vocab_file_name, src_vocab_hash = \
self._data_file[self._pair_key]['vocab' + '_' + self._src_lang]
[src_vocab_path] = self._fetch_data_path([(src_vocab_file_name, src_vocab_hash)])
with io.open(src_vocab_path, 'r', encoding='utf-8') as in_file:
self._src_vocab = Vocab.from_json(in_file.read())
return self._src_vocab

    @property
def tgt_vocab(self):
"""Target Vocabulary of the Dataset.
Returns
-------
tgt_vocab : Vocab
Target vocabulary.
"""
if self._tgt_vocab is None:
tgt_vocab_file_name, tgt_vocab_hash = \
self._data_file[self._pair_key]['vocab' + '_' + self._tgt_lang]
[tgt_vocab_path] = self._fetch_data_path([(tgt_vocab_file_name, tgt_vocab_hash)])
with io.open(tgt_vocab_path, 'r', encoding='utf-8') as in_file:
self._tgt_vocab = Vocab.from_json(in_file.read())
return self._tgt_vocab


@register(segment=['train', 'val', 'test'])
class IWSLT2015(_TranslationDataset):
"""Preprocessed IWSLT English-Vietnamese Translation Dataset.
We use the preprocessed version provided in https://nlp.stanford.edu/projects/nmt/
Parameters
----------
segment : str or list of str, default 'train'
Dataset segment. Options are 'train', 'val', 'test' or their combinations.
src_lang : str, default 'en'
The source language. Option for source and target languages are 'en' <-> 'vi'
tgt_lang : str, default 'vi'
The target language. Option for source and target languages are 'en' <-> 'vi'
root : str, default '$MXNET_HOME/datasets/iwslt2015'
Path to temp folder for storing data.
MXNET_HOME defaults to '~/.mxnet'.
"""
def __init__(self, segment='train', src_lang='en', tgt_lang='vi',
root=os.path.join(get_home_dir(), 'datasets', 'iwslt2015')):
self._supported_segments = ['train', 'val', 'test']
self._archive_file = {_get_pair_key('en', 'vi'):
('iwslt15.zip', '15a05df23caccb1db458fb3f9d156308b97a217b')}
self._data_file = {_get_pair_key('en', 'vi'):
{'train_en': ('train.en',
'675d16d057f2b6268fb294124b1646d311477325'),
'train_vi': ('train.vi',
'bb6e21d4b02b286f2a570374b0bf22fb070589fd'),
'val_en': ('tst2012.en',
'e381f782d637b8db827d7b4d8bb3494822ec935e'),
'val_vi': ('tst2012.vi',
'4511988ce67591dc8bcdbb999314715f21e5a1e1'),
'test_en': ('tst2013.en',
'd320db4c8127a85de81802f239a6e6b1af473c3d'),
'test_vi': ('tst2013.vi',
'af212c48a68465ceada9263a049f2331f8af6290'),
'vocab_en': ('vocab.en.json',
'b6f8e77a45f6dce648327409acd5d52b37a45d94'),
'vocab_vi' : ('vocab.vi.json',
'9be11a9edd8219647754d04e0793d2d8c19dc852')}}
super(IWSLT2015, self).__init__('iwslt2015', segment=segment, src_lang=src_lang,
tgt_lang=tgt_lang, root=root)
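
# Usage sketch for IWSLT2015 (illustrative only, not executed on import; the
# first call downloads the data, so network access is assumed):
#
#   >>> train = IWSLT2015('train', src_lang='en', tgt_lang='vi')
#   >>> src_sentence, tgt_sentence = train[0]        # each item is a (source, target) pair
#   >>> len(train.src_vocab), len(train.tgt_vocab)   # vocabularies are loaded lazily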


@register(segment=['train', 'newstest2009', 'newstest2010', 'newstest2011',
                   'newstest2012', 'newstest2013', 'newstest2014'])
class WMT2014(_TranslationDataset):
"""Translation Corpus of the WMT2014 Evaluation Campaign.
http://www.statmt.org/wmt14/translation-task.html
Parameters
----------
segment : str or list of str, default 'train'
Dataset segment. Options are 'train', 'newstest2009', 'newstest2010',
'newstest2011', 'newstest2012', 'newstest2013', 'newstest2014' or their combinations
src_lang : str, default 'en'
The source language. Option for source and target languages are 'en' <-> 'de'
tgt_lang : str, default 'de'
The target language. Option for source and target languages are 'en' <-> 'de'
full : bool, default False
By default, we use the "filtered test sets" while if full is True, we use the "cleaned test
sets".
root : str, default '$MXNET_HOME/datasets/wmt2014'
Path to temp folder for storing data.
MXNET_HOME defaults to '~/.mxnet'.
"""
def __init__(self, segment='train', src_lang='en', tgt_lang='de', full=False,
root=os.path.join(get_home_dir(), 'datasets', 'wmt2014')):
self._supported_segments = ['train'] + ['newstest%d' % i for i in range(2009, 2015)]
self._archive_file = {_get_pair_key('de', 'en'):
('wmt2014_de_en-b0e0e703.zip',
'b0e0e7036217ffa94f4b35a5b5d2a96a27f680a4')}
self._data_file = {_get_pair_key('de', 'en'):
{'train_en': ('train.en',
'cec2d4c5035df2a54094076348eaf37e8b588a9b'),
'train_de': ('train.de',
'6348764640ffc40992e7de89a8c48d32a8bcf458'),
'newstest2009_en': ('newstest2009.en',
'f8623af2de682924f9841488427e81c430e3ce60'),
'newstest2009_de': ('newstest2009.de',
'dec03f14cb47e726ccb19bec80c645d4a996f8a9'),
'newstest2010_en': ('newstest2010.en',
'5966eb13bd7cc8855cc6b40f9797607e36e9cc80'),
'newstest2010_de': ('newstest2010.de',
'b9af0cb004fa6996eda246d0173c191693b26025'),
'newstest2011_en': ('newstest2011.en',
'2c1d9d077fdbfe9d0e052a6e08a85ee7959479ab'),
'newstest2011_de': ('newstest2011.de',
'efbded3d175a9d472aa5938fe22afcc55c6055ff'),
'newstest2012_en': ('newstest2012.en',
'52f05ae725be45ee4012c6e208cef13614abacf1'),
'newstest2012_de': ('newstest2012.de',
'd9fe32143b88e6fe770843e15ee442a69ff6752d'),
'newstest2013_en': ('newstest2013.en',
'5dca5d02cf40278d8586ee7d58d58215253156a9'),
'newstest2013_de': ('newstest2013.de',
'ddda1e7b3270cb68108858640bfb619c37ede2ab'),
'newstest2014_en': ('newstest2014.src.en',
'610c5bb4cc866ad04ab1f6f80d740e1f4435027c'),
'newstest2014_de': ('newstest2014.ref.de',
'03b02c7f60c8509ba9bb4c85295358f7c9f00d2d')}}
if full:
self._data_file[_get_pair_key('de', 'en')]['newstest2014_en'] = \
('newstest2014.full.en', '528742a3a9690995d031f49d1dbb704844684976')
self._data_file[_get_pair_key('de', 'en')]['newstest2014_de'] = \
('newstest2014.full.de', '2374b6a28cecbd965b73a9acc35a425e1ed81963')
else:
if src_lang == 'de':
self._data_file[_get_pair_key('de', 'en')]['newstest2014_en'] = \
('newstest2014.ref.en', 'cf23229ec6db8b85f240618d2a245f69afebed1f')
self._data_file[_get_pair_key('de', 'en')]['newstest2014_de'] = \
('newstest2014.src.de', '791d644b1a031268ca19600b2734a63c7bfcecc4')
super(WMT2014, self).__init__('wmt2014', segment=segment, src_lang=src_lang,
tgt_lang=tgt_lang,
root=os.path.join(root, _get_pair_key(src_lang, tgt_lang)))
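
# Usage sketch for WMT2014 (illustrative only; network access is assumed for
# the initial download). Segments can be combined, and `full` selects the
# cleaned newstest2014 test set:
#
#   >>> val = WMT2014(['newstest2012', 'newstest2013'], src_lang='en', tgt_lang='de')
#   >>> test = WMT2014('newstest2014', full=True)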


@register(segment=['train', 'newstest2009', 'newstest2010', 'newstest2011',
                   'newstest2012', 'newstest2013', 'newstest2014'])
class WMT2014BPE(_TranslationDataset):
"""Preprocessed Translation Corpus of the WMT2014 Evaluation Campaign.
We preprocess the dataset by adapting
https://github.com/tensorflow/nmt/blob/master/nmt/scripts/wmt16_en_de.sh
Parameters
----------
segment : str or list of str, default 'train'
Dataset segment. Options are 'train', 'newstest2009', 'newstest2010',
'newstest2011', 'newstest2012', 'newstest2013', 'newstest2014' or their combinations
src_lang : str, default 'en'
The source language. Option for source and target languages are 'en' <-> 'de'
tgt_lang : str, default 'de'
The target language. Option for source and target languages are 'en' <-> 'de'
full : bool, default False
In default, we use the test dataset in http://statmt.org/wmt14/test-filtered.tgz.
When full is True, we use the test dataset in http://statmt.org/wmt14/test-full.tgz
root : str, default '$MXNET_HOME/datasets/wmt2014'
Path to temp folder for storing data.
MXNET_HOME defaults to '~/.mxnet'.
"""
def __init__(self, segment='train', src_lang='en', tgt_lang='de', full=False,
root=os.path.join(get_home_dir(), 'datasets', 'wmt2014')):
self._supported_segments = ['train'] + ['newstest%d' % i for i in range(2009, 2015)]
self._archive_file = {_get_pair_key('de', 'en'):
('wmt2014bpe_de_en-ace8f41c.zip',
'ace8f41c22c0da8729ff15f40d416ebd16738979')}
self._data_file = {_get_pair_key('de', 'en'):
{'train_en': ('train.tok.clean.bpe.32000.en',
'e3f093b64468db7084035c9650d9eecb86a3db5f'),
'train_de': ('train.tok.clean.bpe.32000.de',
'60703ad088706a3d9d1f3328889c6f4725a36cfb'),
'newstest2009_en': ('newstest2009.tok.bpe.32000.en',
'5678547f579528a8716298e895f886e3976085e1'),
'newstest2009_de': ('newstest2009.tok.bpe.32000.de',
'32caa69023eac1750a0036780f9d511d979aed2c'),
'newstest2010_en': ('newstest2010.tok.bpe.32000.en',
'813103f7b4b472cf213fe3b2c3439e267dbc4afb'),
'newstest2010_de': ('newstest2010.tok.bpe.32000.de',
'972076a897ecbc7a3acb639961241b33fd58a374'),
'newstest2011_en': ('newstest2011.tok.bpe.32000.en',
'c3de2d72d5e7bdbe848839c55c284fece90464ce'),
'newstest2011_de': ('newstest2011.tok.bpe.32000.de',
'7a8722aeedacd99f1aa8dffb6d8d072430048011'),
'newstest2012_en': ('newstest2012.tok.bpe.32000.en',
'876ad3c72e33d8e1ed14f5362f97c771ce6a9c7f'),
'newstest2012_de': ('newstest2012.tok.bpe.32000.de',
'57467fcba8442164d058a05eaf642a1da1d92c13'),
'newstest2013_en': ('newstest2013.tok.bpe.32000.en',
'de06a155c3224674b2434f3ff3b2c4a4a293d238'),
'newstest2013_de': ('newstest2013.tok.bpe.32000.de',
'094084989128dd091a2fe2a5818a86bc99ecc0e7'),
'newstest2014_en': ('newstest2014.tok.bpe.32000.src.en',
'347cf4d3d5c3c46ca1220247d22c07aa90092bd9'),
'newstest2014_de': ('newstest2014.tok.bpe.32000.ref.de',
'f66b80a0c460c524ec42731e527c54aab5507a66'),
'vocab_en': ('vocab.bpe.32000.json',
'71413f497ce3a0fa691c55277f367e5d672b27ee'),
'vocab_de': ('vocab.bpe.32000.json',
'71413f497ce3a0fa691c55277f367e5d672b27ee')}}
if full:
self._data_file[_get_pair_key('de', 'en')]['newstest2014_en'] = \
('newstest2014.tok.bpe.32000.full.en', '6c398b61641cd39f186b417c54b171876563193f')
self._data_file[_get_pair_key('de', 'en')]['newstest2014_de'] = \
('newstest2014.tok.bpe.32000.full.de', 'b890a8dfc2146dde570fcbcb42e4157292e95251')
else:
if src_lang == 'de':
self._data_file[_get_pair_key('de', 'en')]['newstest2014_en'] = \
('newstest2014.tok.bpe.32000.ref.en',
'cd416085db722bf07cbba4ff29942fe94e966023')
self._data_file[_get_pair_key('de', 'en')]['newstest2014_de'] = \
('newstest2014.tok.bpe.32000.src.de',
'9274d31f92141933f29a405753d5fae051fa5725')
super(WMT2014BPE, self).__init__('wmt2014', segment=segment, src_lang=src_lang,
tgt_lang=tgt_lang,
root=os.path.join(root, _get_pair_key(src_lang, tgt_lang)))
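
# Usage sketch for WMT2014BPE (illustrative only; network access is assumed
# for the initial download). Sentences are BPE-tokenized strings, and the
# vocabulary maps tokens to integer indices:
#
#   >>> data = WMT2014BPE('newstest2013', src_lang='en', tgt_lang='de')
#   >>> src, tgt = data[0]
#   >>> ids = data.src_vocab[src.split()]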


@register(segment=['train', 'newstest2012', 'newstest2013', 'newstest2014',
                   'newstest2015', 'newstest2016'])
class WMT2016(_TranslationDataset):
"""Translation Corpus of the WMT2016 Evaluation Campaign.
Parameters
----------
segment : str or list of str, default 'train'
Dataset segment. Options are 'train', 'newstest2012', 'newstest2013',
'newstest2014', 'newstest2015', 'newstest2016' or their combinations
src_lang : str, default 'en'
The source language. Option for source and target languages are 'en' <-> 'de'
tgt_lang : str, default 'de'
The target language. Option for source and target languages are 'en' <-> 'de'
root : str, default '$MXNET_HOME/datasets/wmt2016'
Path to temp folder for storing data.
MXNET_HOME defaults to '~/.mxnet'.
"""
def __init__(self, segment='train', src_lang='en', tgt_lang='de',
root=os.path.join(get_home_dir(), 'datasets', 'wmt2016')):
self._supported_segments = ['train'] + ['newstest%d' % i for i in range(2012, 2017)]
self._archive_file = {_get_pair_key('de', 'en'):
('wmt2016_de_en-88767407.zip',
'887674077b951ce949fe3e597086b826bd7574d8')}
self._data_file = {_get_pair_key('de', 'en'):
{'train_en': ('train.en',
'1be6d00c255c57183305276c5de60771e201d3b0'),
'train_de': ('train.de',
'4eec608b8486bfb65b61bda237b0c9b3c0f66f17'),
'newstest2012_en': ('newstest2012.en',
'52f05ae725be45ee4012c6e208cef13614abacf1'),
'newstest2012_de': ('newstest2012.de',
'd9fe32143b88e6fe770843e15ee442a69ff6752d'),
'newstest2013_en': ('newstest2013.en',
'5dca5d02cf40278d8586ee7d58d58215253156a9'),
'newstest2013_de': ('newstest2013.de',
'ddda1e7b3270cb68108858640bfb619c37ede2ab'),
'newstest2014_en': ('newstest2014.en',
'528742a3a9690995d031f49d1dbb704844684976'),
'newstest2014_de': ('newstest2014.de',
'2374b6a28cecbd965b73a9acc35a425e1ed81963'),
'newstest2015_en': ('newstest2015.en',
'bf90439b209a496128995c4b948ad757979d0756'),
'newstest2015_de': ('newstest2015.de',
'd69ac825fe3d5796b4990b969ad71903a38a0cd1'),
'newstest2016_en': ('newstest2016.en',
'a99c145d5214eb1645b56d21b02a541fbe7eb3c2'),
'newstest2016_de': ('newstest2016.de',
'fcdd3104f21eb4b9c49ba8ddef46d9b2d472b3fe')}}
super(WMT2016, self).__init__('wmt2016', segment=segment, src_lang=src_lang,
tgt_lang=tgt_lang,
root=os.path.join(root, _get_pair_key(src_lang, tgt_lang)))
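
# Usage sketch for WMT2016 (illustrative only; network access is assumed for
# the initial download). The reversed direction reuses the same archive:
#
#   >>> train = WMT2016('train', src_lang='de', tgt_lang='en')
#   >>> src_sentence, tgt_sentence = train[0]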


@register(segment=['train', 'newstest2012', 'newstest2013', 'newstest2014',
                   'newstest2015', 'newstest2016'])
class WMT2016BPE(_TranslationDataset):
"""Preprocessed Translation Corpus of the WMT2016 Evaluation Campaign.
We use the preprocessing script in
https://github.com/tensorflow/nmt/blob/master/nmt/scripts/wmt16_en_de.sh
Parameters
----------
segment : str or list of str, default 'train'
Dataset segment. Options are 'train', 'newstest2012', 'newstest2013',
'newstest2014', 'newstest2015', 'newstest2016' or their combinations
src_lang : str, default 'en'
The source language. Option for source and target languages are 'en' <-> 'de'
tgt_lang : str, default 'de'
The target language. Option for source and target languages are 'en' <-> 'de'
root : str, default '$MXNET_HOME/datasets/wmt2016'
Path to temp folder for storing data.
MXNET_HOME defaults to '~/.mxnet'.
"""
def __init__(self, segment='train', src_lang='en', tgt_lang='de',
root=os.path.join(get_home_dir(), 'datasets', 'wmt2016')):
self._supported_segments = ['train'] + ['newstest%d' % i for i in range(2012, 2017)]
self._archive_file = {_get_pair_key('de', 'en'):
('wmt2016bpe_de_en-8cf0dbf6.zip',
'8cf0dbf6a102381443a472bcf9f181299231b496')}
self._data_file = {_get_pair_key('de', 'en'):
{'train_en': ('train.tok.clean.bpe.32000.en',
'56f37cb4d68c2f83efd6a0c555275d1fe09f36b5'),
'train_de': ('train.tok.clean.bpe.32000.de',
'58f30a0ba7f80a8840a5cf3deff3c147de7d3f68'),
'newstest2012_en': ('newstest2012.tok.bpe.32000.en',
'25ed9ad228a236f57f97bf81db1bb004bedb7f33'),
'newstest2012_de': ('newstest2012.tok.bpe.32000.de',
'bb5622831ceea1894966fa993ebcd882cc461943'),
'newstest2013_en': ('newstest2013.tok.bpe.32000.en',
'fa03fe189fe68cb25014c5e64096ac8daf2919fa'),
'newstest2013_de': ('newstest2013.tok.bpe.32000.de',
'7d10a884499d352c2fea6f1badafb40473737640'),
'newstest2014_en': ('newstest2014.tok.bpe.32000.en',
'7b8ea824021cc5291e6a54bb32a1fc27c2955588'),
'newstest2014_de': ('newstest2014.tok.bpe.32000.de',
'd84497d4c425fa4e9b2b6be4b62c763086410aad'),
'newstest2015_en': ('newstest2015.tok.bpe.32000.en',
'ca335076f67b2f9b98848f8abc2cd424386f2309'),
'newstest2015_de': ('newstest2015.tok.bpe.32000.de',
'e633a3fb74506eb498fcad654d82c9b1a0a347b3'),
'newstest2016_en': ('newstest2016.tok.bpe.32000.en',
'5a5e36a6285823035b642aef7c1a9ec218da59f7'),
'newstest2016_de': ('newstest2016.tok.bpe.32000.de',
'135a79acb6a4f8fad0cbf5f74a15d9c0b5bf8c73'),
'vocab_en': ('vocab.bpe.32000.json',
'1c5aea0a77cad592c4e9c1136ec3b70ceeff4e8c'),
'vocab_de': ('vocab.bpe.32000.json',
'1c5aea0a77cad592c4e9c1136ec3b70ceeff4e8c')}}
super(WMT2016BPE, self).__init__('wmt2016', segment=segment, src_lang=src_lang,
tgt_lang=tgt_lang,
root=os.path.join(root, _get_pair_key(src_lang, tgt_lang)))
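
# Usage sketch for WMT2016BPE (illustrative only; network access is assumed
# for the initial download). Source and target share the same BPE vocabulary
# file, so both vocabularies have identical contents:
#
#   >>> data = WMT2016BPE('newstest2016', src_lang='en', tgt_lang='de')
#   >>> len(data.src_vocab) == len(data.tgt_vocab)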