/*
 * ldsyms.c -- symbol table handling for ld(1)
 * Copyright (C) 2000 - 2003 Michael Riepe <michael@stud.uni-hannover.de>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.	 See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA	 02111-1307	 USA
 */

/* Version-control identification string for this source file (RCS `$Id$' keyword). */
static const char rcsid[] = "@(#) $Id: ldsyms.c,v 1.32 2003/02/08 13:09:46 michael Exp $";

#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <gelf.h>
#include <ar.h>
#include <assert.h>

#include <ld/ld.h>
#include <ld/ldmisc.h>

/* Number of hash buckets in the global symbol hash table (agbuckets). */
#define PRIME	263u

/* String table holding the names of all collected symbols; see ld_symbol_name(). */
Elf_Data astrtab = { NULL, };
/* Global symbol array; slot 0 is reserved (STN_UNDEF) once the array is allocated. */
struct gsym *aglobals = NULL;
/* Heads of the hash chains into aglobals, indexed by hash % PRIME; STN_UNDEF ends a chain. */
GElf_Word agbuckets[PRIME] = { 0, };
/* Number of used aglobals slots, including the reserved slot 0. */
size_t anglobs = 0;

struct gsym *alocals = NULL;	/* .symtab in larval stage */
size_t anlocs = 0;				/* size of alocals (used slots, incl. reserved slot 0) */

/*
 * Map a byte offset in the symbol string table (astrtab) to a pointer
 * to the name stored there.  Both the table and the offset are checked
 * with assertions only.
 */
char*
ld_symbol_name(size_t offset) {
	char *base;

	assert(astrtab.d_buf);
	assert(offset < astrtab.d_size);
	base = (char*)astrtab.d_buf;
	return &base[offset];
}

static size_t
ld_add_global(unsigned long hash) {
	if (anlocs + anglobs >= SYMBOLS_MAX) {
		fatal("symbol table overflow");
	}
	if (anglobs % 64u == 0) {
		aglobals = xrealloc(aglobals, (anglobs + 64u) * sizeof(*aglobals));
		if (anglobs == 0) {
			anglobs++;
		}
	}
	aglobals[anglobs].chain = agbuckets[hash % PRIME];
	aglobals[anglobs].hash = hash;
	aglobals[anglobs].sect = NULL;
	agbuckets[hash % PRIME] = anglobs;
	return anglobs++;
}

static size_t
ld_add_local(const char *name, const GElf_Sym *sym) {
	assert(name && sym);
	if (anlocs + anglobs >= SYMBOLS_MAX) {
		fatal("symbol table overflow");
	}
	if (anlocs % 64u == 0) {
		alocals = xrealloc(alocals, (anlocs + 64u) * sizeof(*alocals));
		if (anlocs == 0) {
			anlocs++;
		}
	}
	debug("adding local `%s' (%u)", name, (unsigned)anlocs);
	alocals[anlocs].sym = *sym;
	alocals[anlocs].sym.st_name = ld_add_string(&astrtab, name);
	alocals[anlocs].sect = NULL;
	return anlocs++;
}

/*
 * Look up global symbol `name' (whose elf_hash() value is `hash') by
 * walking its hash chain.  Entries are matched on hash first, then on
 * the full name.
 *
 * Returns the aglobals index of the match, or STN_UNDEF if none.
 */
size_t
ld_find_symbol(const char *name, unsigned long hash) {
	size_t i = agbuckets[hash % PRIME];

	while (i != STN_UNDEF) {
		assert(i < anglobs);
		if (aglobals[i].hash == hash
		 && !strcmp(ld_symbol_name(aglobals[i].sym.st_name), name)) {
			return i;
		}
		i = aglobals[i].chain;
	}
	return STN_UNDEF;
}

int
ld_undefined_symbols(void) {
	GElf_Sym *gsym;
	const char *name;
	int err;
	size_t i;

	err = 0;
	for (i = 1; i < anglobs; i++) {
		gsym = &aglobals[i].sym;
		name = ld_symbol_name(gsym->st_name);
		if (gsym->st_shndx != SHN_UNDEF) {
			continue;
		}
		assert(!aglobals[i].sect);
		switch (GELF_ST_BIND(gsym->st_info)) {
			case STB_LOCAL:
				fatal("local symbol in global symbol table");
				/* not reached */
			case STB_WEAK:
				if (ld_target_type != ET_EXEC || ld_opt_dynamic != 'n') {
					break;
				}
				/*
				 * Change to absolute zero
				 */
				gsym->st_shndx = SHN_ABS;
				gsym->st_value = 0;
				gsym->st_info = GELF_ST_INFO(STB_GLOBAL, STT_NOTYPE);
				break;
			case STB_GLOBAL:
				if (ld_target_type == ET_REL) {
					break;
				}
				if (ld_target_type == ET_DYN && ld_opt_zdefs != 'y') {
					break;
				}
				if (ld_opt_zdefs == 'n') {
					break;
				}
				error("undefined symbol `%s'", name);
				err = -1;
				break;
			default:
				fatal("unrecognized symbol binding %#x",
					(unsigned)GELF_ST_BIND(gsym->st_info));
		}
	}
	return err;
}

/*
 * Rank a symbol for conflict resolution in ld_resolve_symbol().
 * When two symbols with the same name meet, the higher rank wins
 * ("u" = undefined):
 *
 *   global  x global   -> error (unless ld_opt_zmuldefs is set)
 *   global  x common   -> global
 *   global  x weak     -> global
 *   global  x uglobal  -> global
 *   global  x uweak    -> global
 *   common  x common   -> common
 *   common  x weak     -> common    (XXX: or should the first one win?)
 *   common  x uglobal  -> common
 *   common  x uweak    -> common
 *   weak    x weak     -> weak (first)
 *   weak    x uglobal  -> weak
 *   weak    x uweak    -> weak
 *   uglobal x uglobal  -> uglobal
 *   uglobal x uweak    -> uglobal
 *   uweak   x uweak    -> uweak
 *
 * In a nutshell:  global > common > weak > uglobal > uweak.
 *
 * Rank values (section class as base, +2 for STB_GLOBAL binding):
 *   7  defined global
 *   6  common
 *   5  defined weak
 *   4  weak common (does something like that exist at all?)
 *   3  unused
 *   2  undefined global
 *   1  unused
 *   0  undefined weak
 */
static int
ld_symbol_prio(const GElf_Sym *sym) {
	int base;

	if (sym->st_shndx == SHN_UNDEF) {
		base = 0;
	}
	else if (sym->st_shndx == SHN_COMMON) {
		base = 4;
	}
	else {
		base = 5;
	}
	return GELF_ST_BIND(sym->st_info) == STB_GLOBAL ? base + 2 : base;
}

/*
 * Merge global (or weak) symbol `name', read from input file `fn',
 * into the global symbol table.
 *
 * If the name is not known yet, a new entry is created from `sym' and
 * `sect'.  Otherwise the existing entry and `sym' are ranked with
 * ld_symbol_prio() and the higher-ranking one is kept.  Among two
 * equally ranked defined (or common) symbols:
 *   - an untyped existing symbol loses against a typed one,
 *   - a smaller object loses against a larger one,
 *   - for two commons of equal size, the less aligned one loses.
 * Two defined global symbols are reported as duplicates unless
 * ld_opt_zmuldefs is set; the existing entry is kept in that case.
 *
 * When the existing entry is replaced, its name offset is preserved and
 * its section becomes `sect'.
 *
 * Returns the index of the symbol in aglobals.
 */
size_t
ld_resolve_symbol(const char *fn, const char *name, const GElf_Sym *sym, struct outscn *sect) {
	GElf_Sym *gsym;
	unsigned long hash;
	int pa;
	int pb;
	int replace;
	size_t index;

	assert(fn);
	assert(name);
	assert(sym);
	assert(GELF_ST_BIND(sym->st_info) != STB_LOCAL);

	hash = elf_hash(name);
	index = ld_find_symbol(name, hash);
	if (index == STN_UNDEF) {
		/*
		 * Not found, add new symbol
		 */
		index = ld_add_global(hash);
		debug("adding global `%s' (%u)", name, (unsigned)index);
		aglobals[index].sym = *sym;
		aglobals[index].sym.st_name = ld_add_string(&astrtab, name);
		aglobals[index].sect = sect;
		return index;
	}
	gsym = &aglobals[index].sym;
	pa = ld_symbol_prio(gsym);
	pb = ld_symbol_prio(sym);
	if (pa == 7 && pb == 7) {
		/*
		 * Both symbols are defined (and global)
		 */
		if (!ld_opt_zmuldefs) {
			file_error(fn, "duplicate symbol `%s'", name);
		}
		return index;
	}
	/*
	 * The tie-breakers below set `replace' instead of clobbering `pa',
	 * so that the debug output shows the real priorities.
	 */
	replace = pa < pb;
	if (pa >= 4 && pb >= 4) {
		/*
		 * Both defined (global, weak or tentative)
		 * Check types, sizes and alignments
		 */
		if (GELF_ST_TYPE(gsym->st_info) != GELF_ST_TYPE(sym->st_info)) {
			file_warn(fn, "symbol `%s' has different types", name);
			if (pa == pb && GELF_ST_TYPE(gsym->st_info) == STT_NOTYPE) {
				replace = 1;
			}
		}
		else if (GELF_ST_TYPE(gsym->st_info) != STT_OBJECT) {
			/*
			 * Nothing to do
			 */
		}
		else if (gsym->st_size != sym->st_size) {
			/*
			 * Objects of different size.
			 * st_size is a 64-bit GElf_Xword: cast it to match %lu,
			 * passing it unconverted is undefined on 32-bit hosts.
			 */
			if (!ld_opt_no_size_warn) {
				file_warn(fn, "symbol `%s' has different sizes (%lu, %lu)",
						  name, (unsigned long)gsym->st_size,
						  (unsigned long)sym->st_size);
			}
			if (pa == pb && gsym->st_size < sym->st_size) {
				replace = 1;
			}
		}
		else if (gsym->st_shndx == SHN_COMMON && sym->st_shndx == SHN_COMMON) {
			/*
			 * Both tentative; st_value holds the alignment here
			 */
			if (gsym->st_value != sym->st_value) {
				if (!ld_opt_no_size_warn) {
					file_warn(fn, "symbol `%s' has different alignments (%lu, %lu)",
							  name, (unsigned long)gsym->st_value,
							  (unsigned long)sym->st_value);
				}
				if (pa == pb && gsym->st_value < sym->st_value) {
					replace = 1;
				}
			}
		}
		/* XXX: check alignment for `common + nocommon' case? */
	}
	if (replace) {
		GElf_Word tmpname;

		if (pa < 4 && pb >= 4) {
			debug("resolving symbol `%s' (%d, %d)", name, pa, pb);
		}
		else {
			debug("overriding symbol `%s' (%d, %d)", name, pa, pb);
		}
		/* take over the new definition but keep our own name offset */
		tmpname = gsym->st_name;
		*gsym = *sym;
		gsym->st_name = tmpname;
		aglobals[index].sect = sect;
	}
	return index;
}

size_t
ld_undef_symbol(const char *name) {
	size_t index;
	unsigned long hash;

	assert(name);
	hash = elf_hash(name);
	index = ld_find_symbol(name, hash);
	if (index == STN_UNDEF) {
		/*
		 * not found, add new (undefined) symbol
		 */
		index = ld_add_global(hash);
		aglobals[index].sym.st_name = ld_add_string(&astrtab, name);
	}
	else {
		warn("undefining symbol `%s'", name);
		aglobals[index].sect = NULL;
	}
	aglobals[index].sym.st_value = 0;
	aglobals[index].sym.st_size = 0;
	aglobals[index].sym.st_info = GELF_ST_INFO(STB_GLOBAL, STT_NOTYPE);
	aglobals[index].sym.st_other = 0;
	aglobals[index].sym.st_shndx = SHN_UNDEF;
	return index;
}

/*
 * Enter the symbols of the current input file (`cur') into the output
 * symbol tables.
 *
 * For every input symbol i >= 1, cur.stable[i] is set to:
 *   - 0 if the symbol was not entered (undefined locals, section
 *     symbols -- the latter are created on demand later),
 *   - n for a local symbol stored as alocals[n],
 *   - SYMBOLS_MAX - n for a global/weak symbol stored as aglobals[n].
 * Values of symbols in ordinary sections are adjusted by
 * cur.scns[secno].off (presumably the input section's offset within
 * its output section).
 *
 * Returns 0 on success, -1 after reporting a malformed symbol.
 */
int
ld_copy_symbols(void) {
	GElf_Sym sym;
	const char *name;
	size_t i;
	size_t n;
	size_t secno;

	debug("ld_copy_symbols");

	/*
	 * Process symbols.
	 * Allocate at least one cur.stable entry: entry 0 is written
	 * unconditionally below, which would be out of bounds for an
	 * input file without any symbols (cur.nsyms == 0).
	 */
	cur.stable = xrealloc(cur.stable,
		(cur.nsyms ? cur.nsyms : 1) * sizeof(*cur.stable));
	cur.stable[0] = 0;
	for (i = 1; i < cur.nsyms; i++) {
		struct outscn *sect = NULL;

		cur.stable[i] = 0;
		sym = cur.syms[i];
		name = cur.strs + sym.st_name;
		switch ((secno = sym.st_shndx)) {
			case SHN_UNDEF:
				if (GELF_ST_BIND(sym.st_info) == STB_LOCAL) {
					file_warn(cur.fn, "skipping undefined local `%s'", name);
					continue;
				}
				break;
			case SHN_ABS:
				break;
			case SHN_COMMON:
				if (GELF_ST_BIND(sym.st_info) == STB_LOCAL) {
					file_warn(cur.fn, "common symbol `%s' is local", name);
				}
				break;
			case SHN_XINDEX:
				/*
				 * Escape: the real index is in cur.shndx (presumably the
				 * SHT_SYMTAB_SHNDX data); it is only used for indices
				 * that do not fit into st_shndx.
				 */
				secno = cur.shndx[i];
				if (secno < SHN_LORESERVE) {
					file_error(cur.fn, "bad st_shndx escape value for `%s'", name);
					return -1;
				}
				goto ordinary_symbol;
			default:
				if (secno >= SHN_LOPROC && secno <= SHN_HIPROC) {
					file_error(cur.fn, "unrecognized CPU-specific "
						"section index (SHN_LOPROC+%u) for `%s'",
						(unsigned)secno - SHN_LOPROC, name);
					return -1;
				}
				if (secno >= SHN_LOOS && secno <= SHN_HIOS) {
					file_error(cur.fn, "unrecognized OS-specific "
						"section index (SHN_LOOS+%u) for `%s'",
						(unsigned)secno - SHN_LOOS, name);
					return -1;
				}
				if (secno >= SHN_LORESERVE && secno <= SHN_HIRESERVE) {
					file_error(cur.fn, "unrecognized reserved "
						"section index (SHN_LORESERVE+%u) for `%s'",
						(unsigned)secno - SHN_LORESERVE, name);
					return -1;
				}
			ordinary_symbol:
				if (secno >= cur.nscns) {
					file_error(cur.fn, "symbol `%s' references non-existing section #%u",
						name, (unsigned)secno);
					return -1;
				}
				/*
				 * Skip section symbols.
				 * We'll re-create them later (if they're needed).
				 */
				if (GELF_ST_TYPE(sym.st_info) == STT_SECTION) {
					if (GELF_ST_BIND(sym.st_info) != STB_LOCAL) {
						file_error(cur.fn, "STT_SECTION symbol is not local");
						return -1;
					}
					continue;
				}
				if ((sect = cur.scns[secno].out)) {
					/*
					 * Adjust section index and offset
					 */
					sym.st_value += cur.scns[secno].off;
					break;
				}
				/*
				 * Section has not been copied,
				 * XXX: this should never happen!
				 */
				fatal("symbol `%s' references ignored section `%s' (%u)",
					name, cur.names + cur.scns[secno].sh.sh_name,
					(unsigned)secno);
		}
		/* section symbols must live in an ordinary section */
		if (GELF_ST_TYPE(sym.st_info) == STT_SECTION) {
			file_error(cur.fn, "bad st_shndx (0x%04x) in section symbol",
				(unsigned)sym.st_shndx);
			return -1;
		}
		if (GELF_ST_BIND(sym.st_info) != STB_LOCAL) {
			/*
			 * Values greater than `anlocs' indicate global symbols
			 */
			n = ld_resolve_symbol(cur.fn, name, &sym, sect);
			cur.stable[i] = SYMBOLS_MAX - n;
		}
		else {
			n = ld_add_local(name, &sym);
			alocals[n].sect = sect;
			cur.stable[i] = n;
		}
	}
	return 0;
}

/*
 * Return the output symbol reference for input symbol `index' of the
 * current file, creating a local copy of it on demand.
 *
 * This covers symbols that ld_copy_symbols() left unentered
 * (cur.stable[index] == 0), notably section symbols.  A non-zero
 * cur.stable entry is returned unchanged.  A newly created local has
 * its value adjusted by cur.scns[secno].off and is attached to the
 * output section of its input section.
 *
 * Returns the alocals index (or the existing cur.stable value), or 0
 * after reporting an error.
 */
size_t
ld_section_symbol(size_t index) {
	GElf_Sym sym;
	size_t n;
	size_t secno;
	struct outscn *sect;

	if (index < 1 || index >= cur.nsyms) {
		file_error(cur.fn, "bad symbol index");
		return 0;
	}
	if (cur.stable[index]) {
		return cur.stable[index];
	}
	sym = cur.syms[index];
	if ((secno = sym.st_shndx) == SHN_XINDEX) {
		secno = cur.shndx[index];
		if (secno < SHN_LORESERVE) {
			file_error(cur.fn, "bad st_shndx escape value in section symbol");
			return 0;
		}
	}
	else if (secno == SHN_UNDEF
	 || (secno >= SHN_LORESERVE && secno <= SHN_HIRESERVE)) {
		file_error(cur.fn, "bad st_shndx in section symbol");
		return 0;
	}
	if (secno >= cur.nscns) {
		file_error(cur.fn, "section symbol references non-existing section #%u",
			(unsigned)secno);
		return 0;
	}
	if (!(sect = cur.scns[secno].out)) {
		file_error(cur.fn, "reference to ignored section");
		return 0;
	}
	sym.st_value += cur.scns[secno].off;
	/*
	 * st_name is an offset into the input file's string table
	 * (cur.strs), exactly as in ld_copy_symbols().  It must not be
	 * looked up in astrtab: that gives the wrong string, and it would
	 * hand ld_add_local() a pointer into the very table that
	 * ld_add_string() writes to (and may reallocate).
	 */
	n = ld_add_local(cur.strs + sym.st_name, &sym);
	alocals[n].sect = sect;
	cur.stable[index] = n;
	return n;
}

/*
 * Set gsym->sym.st_shndx to the ELF section header index of the
 * symbol's output section (gsym->sect->scn).
 *
 * Indices that do not fit below SHN_LORESERVE are stored in
 * gsym->shndx, and st_shndx becomes SHN_XINDEX; otherwise gsym->shndx
 * is 0.  Undefined symbols and symbols with a reserved index other
 * than SHN_XINDEX keep their st_shndx.
 */
void
ld_symbol_index(struct gsym *gsym) {
	struct outscn *sect;
	size_t index;
	unsigned shn;

	assert(gsym);
	gsym->shndx = 0;
	shn = gsym->sym.st_shndx;
	if (shn == SHN_UNDEF
	 || (shn != SHN_XINDEX && shn >= SHN_LORESERVE && shn <= SHN_HIRESERVE)) {
		return;
	}
	sect = gsym->sect;
	assert(sect && sect->scn);
	index = elf_ndxscn(sect->scn);
	assert(index);
	if (index < SHN_LORESERVE) {
		gsym->sym.st_shndx = index;
	}
	else {
		/* too large for st_shndx: use the extended index escape */
		gsym->shndx = index;
		gsym->sym.st_shndx = SHN_XINDEX;
	}
}

/*
 * Add the address of the symbol's output section
 * (gsym->sect->shdr.sh_addr) to its value.
 *
 * Undefined symbols and symbols with a reserved section index other
 * than SHN_XINDEX (e.g. SHN_ABS, SHN_COMMON) are left unchanged.
 */
void
ld_symbol_addr(struct gsym *gsym) {
	unsigned shn;

	assert(gsym);
	shn = gsym->sym.st_shndx;
	if (shn == SHN_UNDEF
	 || (shn != SHN_XINDEX && shn >= SHN_LORESERVE && shn <= SHN_HIRESERVE)) {
		return;
	}
	assert(gsym->sect);
	gsym->sym.st_value += gsym->sect->shdr.sh_addr;
}
