#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import argparse
from utils import *
import sys

parser = argparse.ArgumentParser(description="RCSLS for supervised word alignment")

parser.add_argument("--src_emb", type=str, default='', help="Load source embeddings")
parser.add_argument("--tgt_emb", type=str, default='', help="Load target embeddings")
parser.add_argument("--center", action='store_true', help="whether to center embeddings or not")

parser.add_argument("--dico_train", type=str, default='', help="train dictionary")
parser.add_argument("--dico_test", type=str, default='', help="validation dictionary")
parser.add_argument("--output", type=str, default='', help="where to save aligned embeddings")

parser.add_argument("--knn", type=int, default=10, help="number of nearest neighbors in RCSLS/CSLS")
parser.add_argument("--maxneg", type=int, default=200000, help="maximum number of negatives for the extended RCSLS")
parser.add_argument("--maxsup", type=int, default=-1, help="maximum number of training examples")
parser.add_argument("--maxload", type=int, default=200000, help="maximum number of loaded vectors")

parser.add_argument("--model", type=str, default="none", help="set of constraints: spectral or none")
parser.add_argument("--reg", type=float, default=0.0, help="regularization parameter")

parser.add_argument("--lr", type=float, default=1.0, help="learning rate")
parser.add_argument("--niter", type=int, default=10, help="number of iterations")
parser.add_argument("--sgd", action='store_true', help="use sgd")
parser.add_argument("--batchsize", type=int, default=10000, help="batch size for sgd")

params = parser.parse_args()


###### SPECIFIC FUNCTIONS ######
# functions specific to RCSLS
# the rest of the functions are in utils.py


def getknn(sc, x, y, k=10):
    # f: sum of similarities to the k nearest neighbors in y, averaged by k
    # df: gradient of f with respect to the mapping matrix (see rcsls)
    sidx = np.argpartition(sc, -k, axis=1)[:, -k:]
    ytopk = y[sidx.flatten(), :]
    ytopk = ytopk.reshape(sidx.shape[0], sidx.shape[1], y.shape[1])
    f = np.sum(sc[np.arange(sc.shape[0])[:, None], sidx])
    df = np.dot(ytopk.sum(1).T, x)
    return f / k, df / k


def rcsls(X_src, Y_tgt, Z_src, Z_tgt, R, knn=10):
    # RCSLS loss and its gradient in R (see the comment block below)
    X_trans = np.dot(X_src, R.T)
    f = 2 * np.sum(X_trans * Y_tgt)
    df = 2 * np.dot(Y_tgt.T, X_src)
    fk0, dfk0 = getknn(np.dot(X_trans, Z_tgt.T), X_src, Z_tgt, knn)
    fk1, dfk1 = getknn(np.dot(np.dot(Z_src, R.T), Y_tgt.T).T, Y_tgt, Z_src, knn)
    f = f - fk0 - fk1
    df = df - dfk0 - dfk1.T
    return -f / X_src.shape[0], -df / X_src.shape[0]


def proj_spectral(R):
    # project R onto the unit spectral-norm ball by clipping singular values
    U, s, V = np.linalg.svd(R)
    s[s > 1] = 1
    s[s < 0] = 0
    return np.dot(U, np.dot(np.diag(s), V))
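
# What rcsls() computes, written out (rows of all embedding matrices are
# assumed l2-normalized, as in the CSLS retrieval criterion): for n training
# pairs (x_i, y_i) and negative pools Z_src, Z_tgt, the returned loss is
#
#   (1/n) * sum_i [ mean_{z in NN_k(R x_i, Z_tgt)} <R x_i, z>
#                 + mean_{z : R z in NN_k(y_i, R Z_src)} <R z, y_i>
#                 - 2 <R x_i, y_i> ]
#
# together with its gradient in R, where the k-NN sets are treated as
# constant (no gradient flows through the argpartition in getknn).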


###### MAIN ######

# load word embeddings
words_tgt, x_tgt = load_vectors(params.tgt_emb, maxload=params.maxload, center=params.center)
words_src, x_src = load_vectors(params.src_emb, maxload=params.maxload, center=params.center)

# load validation bilingual lexicon
src2tgt, lexicon_size = load_lexicon(params.dico_test, words_src, words_tgt)

# word --> vector indices
idx_src = idx(words_src)
idx_tgt = idx(words_tgt)

# load train bilingual lexicon
pairs = load_pairs(params.dico_train, idx_src, idx_tgt)
if params.maxsup > 0 and params.maxsup < len(pairs):
    pairs = pairs[:params.maxsup]

# selecting training vector pairs
X_src, Y_tgt = select_vectors_from_pairs(x_src, x_tgt, pairs)

# adding negatives for RCSLS
Z_src = x_src[:params.maxneg, :]
Z_tgt = x_tgt[:params.maxneg, :]

# initialization with the Procrustes solution
R = procrustes(X_src, Y_tgt)
nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
print("[init -- Procrustes] NN: %.4f" % nnacc)
sys.stdout.flush()

# optimization
fold, Rold = 0, []
niter, lr = params.niter, params.lr
for it in range(0, niter + 1):
    if lr < 1e-4:
        break

    if params.sgd:
        # sample a mini-batch of training pairs
        indices = np.random.choice(X_src.shape[0], size=params.batchsize, replace=False)
        f, df = rcsls(X_src[indices, :], Y_tgt[indices, :], Z_src, Z_tgt, R, params.knn)
    else:
        f, df = rcsls(X_src, Y_tgt, Z_src, Z_tgt, R, params.knn)

    if params.reg > 0:
        R *= (1 - lr * params.reg)  # l2 regularization as weight decay
    R -= lr * df
    if params.model == "spectral":
        R = proj_spectral(R)

    print("[it=%d] f = %.4f" % (it, f))
    sys.stdout.flush()

    # in full-batch mode, halve the learning rate and revert to the
    # previous iterate whenever the loss increases
    if f > fold and it > 0 and not params.sgd:
        lr /= 2
        f, R = fold, Rold

    fold, Rold = f, R

    # periodic evaluation on the validation lexicon
    if (it > 0 and it % 10 == 0) or it == niter:
        nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
        print("[it=%d] NN = %.4f - Coverage = %.4f" % (it, nnacc, len(src2tgt) / lexicon_size))

# final evaluation
nnacc = compute_nn_accuracy(np.dot(x_src, R.T), x_tgt, src2tgt, lexicon_size=lexicon_size)
print("[final] NN = %.4f - Coverage = %.4f" % (nnacc, len(src2tgt) / lexicon_size))

# apply R to the full (untruncated) source embeddings, renormalize, and save
if params.output != "":
    print("Saving all aligned vectors at %s" % params.output)
    words_full, x_full = load_vectors(params.src_emb, maxload=-1, center=params.center, verbose=False)
    x = np.dot(x_full, R.T)
    x /= np.linalg.norm(x, axis=1)[:, np.newaxis] + 1e-8
    save_vectors(params.output, x, words_full)
    save_matrix(params.output + "-mat", R)
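
# Example invocation (a sketch: file names are hypothetical and the script
# name is assumed to be align.py; every flag used here is defined above):
#
#   python align.py --src_emb wiki.en.vec --tgt_emb wiki.es.vec \
#       --dico_train en-es.train.txt --dico_test en-es.test.txt \
#       --output vectors-en-es.txt --knn 10 --maxneg 200000 --niter 10
#
# With --sgd, each iteration samples --batchsize training pairs instead of
# using the full lexicon; with --model spectral, R is projected onto the
# unit spectral-norm ball after every gradient step.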