Skip to content

Commit

Permalink
Fixing some issues with unassigned questions
Browse files Browse the repository at this point in the history
  • Loading branch information
nreimers committed Aug 6, 2020
1 parent 77390c7 commit f4377b2
Showing 1 changed file with 58 additions and 33 deletions.
91 changes: 58 additions & 33 deletions examples/training_quora_duplicate_questions/create_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
import os
from sentence_transformers import util

random.seed(42)

#Get raw file
source_file = "quora-IR-dataset/quora_duplicate_questions.tsv"
os.makedirs('quora-IR-dataset', exist_ok=True)
Expand Down Expand Up @@ -104,7 +106,6 @@
#Distribute rows to train/dev/test split
#Ensure that sets contain distinct sentences
is_assigned = set()
random.seed(42)
random.shuffle(rows)

train_ids = set()
Expand All @@ -113,15 +114,28 @@

counter = 0
for row in rows:
if row['qid1'] in is_assigned or row['qid2'] in is_assigned:
if row['qid1'] in is_assigned and row['qid2'] in is_assigned:
continue

#Distribution about 85%/5%/10%
target_set = train_ids
if counter%10 == 0:
target_set = dev_ids
elif counter%10 == 1 or counter%10 == 2:
target_set = test_ids
elif row['qid1'] in is_assigned or row['qid2'] in is_assigned:

if row['qid2'] in is_assigned: #Ensure that qid1 is assigned and qid2 not yet
row['qid1'], row['qid2'] = row['qid2'], row['qid1']

#Move qid2 to the same split as qid1
target_set = train_ids
if row['qid1'] in dev_ids:
target_set = dev_ids
elif row['qid1'] in test_ids:
target_set = test_ids

else:
#Distribution about 85%/5%/10%
target_set = train_ids
if counter%10 == 0:
target_set = dev_ids
elif counter%10 == 1 or counter%10 == 2:
target_set = test_ids
counter += 1

#Get the sentence with all duplicates and add it to the respective sets
target_set.add(row['qid1'])
Expand All @@ -134,9 +148,14 @@
target_set.add(b)
is_assigned.add(b)

counter += 1

print("Train sentences:", len(train_ids))
#Assert all sets are mutually exclusive
assert len(train_ids.intersection(dev_ids)) == 0
assert len(train_ids.intersection(test_ids)) == 0
assert len(test_ids.intersection(dev_ids)) == 0


print("\nTrain sentences:", len(train_ids))
print("Dev sentences:", len(dev_ids))
print("Test sentences:", len(test_ids))

Expand All @@ -154,8 +173,8 @@ def get_duplicate_set(ids_set):
test_duplicates = get_duplicate_set(test_ids)


print("Train duplicates", len(train_duplicates))
print("dev duplicates", len(dev_duplicates))
print("\nTrain duplicates", len(train_duplicates))
print("Dev duplicates", len(dev_duplicates))
print("Test duplicates", len(test_duplicates))

############### Write general files about the duplicate questions graph ############
Expand All @@ -174,7 +193,7 @@ def get_duplicate_set(ids_set):
duplicates_list = sorted(duplicates_list, key=lambda x: x[0]*1000000+x[1])


print("Write duplicate graph in pairwise format")
print("\nWrite duplicate graph in pairwise format")
with open('quora-IR-dataset/graph/duplicates-graph-pairwise.tsv', 'w', encoding='utf8') as fOut:
fOut.write("qid1\tqid2\n")
for a, b in duplicates_list:
Expand All @@ -192,7 +211,7 @@ def get_duplicate_set(ids_set):
def write_qids(name, ids_list):
    """Write the question ids of one split to a single-column TSV file.

    The file is created at ``quora-IR-dataset/graph/<name>-questions.tsv``
    (the directory must already exist). It starts with a ``qid`` header line,
    followed by one id per line, without a trailing newline.

    :param name: Split name used as the file-name prefix (e.g. ``'train'``).
    :param ids_list: Iterable of question ids given as numeric strings.
    """
    with open('quora-IR-dataset/graph/'+name+'-questions.tsv', 'w', encoding='utf8') as fOut:
        fOut.write("qid\n")
        # Ids are strings: sort by integer value so that "9" precedes "10"
        # (plain string sorting would order them lexicographically), and
        # write the list exactly once.
        fOut.write("\n".join(sorted(ids_list, key=int)))

write_qids('train', train_ids)
write_qids('dev', dev_ids)
Expand Down Expand Up @@ -249,54 +268,60 @@ def write_mining_files(name, ids, dups):
test_queries = set()

#Create dev queries
for a, b in dev_duplicates:
if a not in corpus_ids and b not in corpus_ids:
if len(dev_queries) < num_dev_queries:
rnd_dev_ids = sorted(list(dev_ids))
random.shuffle(rnd_dev_ids)

for a in rnd_dev_ids:
if a not in corpus_ids:
if len(dev_queries) < num_dev_queries and len(duplicates[a]) > 0:
dev_queries.add(a)
else:
corpus_ids.add(a)

corpus_ids.add(b)
for further_dups in duplicates[b]:
if further_dups not in dev_queries:
corpus_ids.add(further_dups)
for b in duplicates[a]:
if b not in dev_queries:
corpus_ids.add(b)

#Create test queries
for a, b in test_duplicates:
if a not in corpus_ids and b not in corpus_ids:
if len(test_queries) < num_test_queries:
rnd_test_ids = sorted(list(test_ids))
random.shuffle(rnd_test_ids)

for a in rnd_test_ids:
if a not in corpus_ids:
if len(test_queries) < num_test_queries and len(duplicates[a]) > 0:
test_queries.add(a)
else:
corpus_ids.add(a)

corpus_ids.add(b)
for further_dups in duplicates[b]:
if further_dups not in test_queries:
corpus_ids.add(further_dups)
for b in duplicates[a]:
if b not in test_queries:
corpus_ids.add(b)

#Write output for information-retrieval
print("\nInformation Retrival Setup")
print("Corpus size:", len(corpus_ids))
print("Dev queries:", len(dev_queries))
print("Test queries:", len(test_queries))

with open('quora-IR-dataset/information-retrieval/corpus.tsv', 'w', encoding='utf8') as fOut:
fOut.write("qid\tquestion\n")
for id in corpus_ids:
for id in sorted(corpus_ids, key=lambda id: int(id)):
fOut.write("{}\t{}\n".format(id, sentences[id]))

with open('quora-IR-dataset/information-retrieval/dev-queries.tsv', 'w', encoding='utf8') as fOut:
fOut.write("qid\tquestion\tduplicate_qids\n")
for id in dev_queries:
for id in sorted(dev_queries, key=lambda id: int(id)):
fOut.write("{}\t{}\t{}\n".format(id, sentences[id], ",".join(duplicates[id])))

with open('quora-IR-dataset/information-retrieval/test-queries.tsv', 'w', encoding='utf8') as fOut:
fOut.write("qid\tquestion\tduplicate_qids\n")
for id in test_queries:
for id in sorted(test_queries, key=lambda id: int(id)):
fOut.write("{}\t{}\t{}\n".format(id, sentences[id], ",".join(duplicates[id])))


print("--DONE--")





print("--DONE--")

0 comments on commit f4377b2

Please sign in to comment.