vmci_transport.c 58.3 KB
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159
/*
 * VMware vSockets Driver
 *
 * Copyright (C) 2007-2013 VMware, Inc. All rights reserved.
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation version 2 and no later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 */

#include <linux/types.h>
#include <linux/bitops.h>
#include <linux/cred.h>
#include <linux/init.h>
#include <linux/io.h>
#include <linux/kernel.h>
#include <linux/kmod.h>
#include <linux/list.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/net.h>
#include <linux/poll.h>
#include <linux/skbuff.h>
#include <linux/smp.h>
#include <linux/socket.h>
#include <linux/stddef.h>
#include <linux/unistd.h>
#include <linux/wait.h>
#include <linux/workqueue.h>
#include <net/sock.h>
#include <net/af_vsock.h>

#include "vmci_transport_notify.h"

/* Forward declarations for the static VMCI callbacks, work-queue handlers
 * and per-socket-state packet handlers that are defined further down in
 * this file but referenced before their definitions.
 */
static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg);
static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg);
static void vmci_transport_peer_detach_cb(u32 sub_id,
					  const struct vmci_event_data *ed,
					  void *client_data);
static void vmci_transport_recv_pkt_work(struct work_struct *work);
static void vmci_transport_cleanup(struct work_struct *work);
static int vmci_transport_recv_listen(struct sock *sk,
				      struct vmci_transport_packet *pkt);
static int vmci_transport_recv_connecting_server(
					struct sock *sk,
					struct sock *pending,
					struct vmci_transport_packet *pkt);
static int vmci_transport_recv_connecting_client(
					struct sock *sk,
					struct vmci_transport_packet *pkt);
static int vmci_transport_recv_connecting_client_negotiate(
					struct sock *sk,
					struct vmci_transport_packet *pkt);
static int vmci_transport_recv_connecting_client_invalid(
					struct sock *sk,
					struct vmci_transport_packet *pkt);
static int vmci_transport_recv_connected(struct sock *sk,
					 struct vmci_transport_packet *pkt);
static bool vmci_transport_old_proto_override(bool *old_pkt_proto);
static u16 vmci_transport_new_proto_supported_versions(void);
static bool vmci_transport_proto_to_notify_struct(struct sock *sk, u16 *proto,
						  bool old_pkt_proto);

/* Deferred-work item queued by vmci_transport_recv_stream_cb() when a
 * stream packet cannot be handled directly in bottom-half context.
 */
struct vmci_transport_recv_pkt_info {
	/* Runs vmci_transport_recv_pkt_work(). */
	struct work_struct work;
	/* Target socket; the lookup reference is handed over to the work item
	 * rather than dropped in the callback.
	 */
	struct sock *sk;
	/* Private copy of the received packet (the datagram buffer is not
	 * kept past the callback).
	 */
	struct vmci_transport_packet pkt;
};

/* List of items processed by vmci_transport_cleanup() (run from
 * vmci_transport_cleanup_work), protected by vmci_transport_cleanup_lock.
 * NOTE(review): presumably transports awaiting deferred teardown — confirm
 * against vmci_transport_cleanup().
 */
static LIST_HEAD(vmci_transport_cleanup_list);
static DEFINE_SPINLOCK(vmci_transport_cleanup_lock);
static DECLARE_WORK(vmci_transport_cleanup_work, vmci_transport_cleanup);

/* Module-wide VMCI resources, both starting out as "invalid / not yet
 * registered".
 */
static struct vmci_handle vmci_transport_stream_handle = { VMCI_INVALID_ID,
							   VMCI_INVALID_ID };
static u32 vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;

/* Protocol override knob; -1 presumably means "no override" — TODO confirm
 * against vmci_transport_old_proto_override().
 */
static int PROTOCOL_OVERRIDE = -1;

/* Default queue pair size and its lower/upper bounds (presumably in bytes;
 * 262144 == 256 KiB).
 */
#define VMCI_TRANSPORT_DEFAULT_QP_SIZE_MIN   128
#define VMCI_TRANSPORT_DEFAULT_QP_SIZE       262144
#define VMCI_TRANSPORT_DEFAULT_QP_SIZE_MAX   262144

/* The default peer timeout indicates how long we will wait for a peer response
 * to a control message (2 seconds, expressed in jiffies).
 */
#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)

/* Translate a VMCI error code into the negative errno value expected by the
 * vsock core.  Any code without a dedicated mapping (including
 * VMCI_ERROR_INVALID_ARGS) is reported as -EINVAL.
 */

static s32 vmci_transport_error_to_vsock_error(s32 vmci_error)
{
	switch (vmci_error) {
	case VMCI_ERROR_NO_MEM:
		return -ENOMEM;
	case VMCI_ERROR_DUPLICATE_ENTRY:
	case VMCI_ERROR_ALREADY_EXISTS:
		return -EADDRINUSE;
	case VMCI_ERROR_NO_ACCESS:
		return -EPERM;
	case VMCI_ERROR_NO_RESOURCES:
		return -ENOBUFS;
	case VMCI_ERROR_INVALID_RESOURCE:
		return -EHOSTUNREACH;
	case VMCI_ERROR_INVALID_ARGS:
	default:
		return -EINVAL;
	}
}

static u32 vmci_transport_peer_rid(u32 peer_cid)
{
	/* Packets addressed to the hypervisor use a dedicated resource ID;
	 * every other peer uses the regular vsock packet resource ID.
	 */
	return peer_cid == VMADDR_CID_HYPERVISOR ?
	       VMCI_TRANSPORT_HYPERVISOR_PACKET_RID :
	       VMCI_TRANSPORT_PACKET_RID;
}

/* Fill in a vsock control packet sent from @src to @dst.
 *
 * The datagram header, version, type, ports and reserved fields are always
 * initialised.  @type then selects which of the remaining arguments are used:
 *   REQUEST, NEGOTIATE            -> @size
 *   REQUEST2, NEGOTIATE2          -> @size and @proto
 *   OFFER, ATTACH                 -> @handle
 *   SHUTDOWN                      -> @mode
 *   WAITING_READ, WAITING_WRITE   -> a copy of *@wait
 *   INVALID, WROTE, READ, RST     -> no payload (all zero)
 * Arguments not used by @type are ignored, and pkt->proto is zero for every
 * type other than REQUEST2/NEGOTIATE2.
 *
 * The payload union is cleared before the type-specific member is written.
 * Writing only one member (e.g. u.size) would otherwise leave any union
 * bytes beyond that member holding whatever the buffer contained before.
 * Some callers pass buffers that were never zeroed (kmalloc()ed or on the
 * stack), so those stale bytes would be sent to the peer.
 * NOTE(review): this assumes u.wait is wider than u.size/u.handle; confirm
 * against the packet definition in the header.
 */
static inline void
vmci_transport_packet_init(struct vmci_transport_packet *pkt,
			   struct sockaddr_vm *src,
			   struct sockaddr_vm *dst,
			   u8 type,
			   u64 size,
			   u64 mode,
			   struct vmci_transport_waiting_info *wait,
			   u16 proto,
			   struct vmci_handle handle)
{
	/* We register the stream control handler as an any cid handle so we
	 * must always send from a source address of VMADDR_CID_ANY
	 */
	pkt->dg.src = vmci_make_handle(VMADDR_CID_ANY,
				       VMCI_TRANSPORT_PACKET_RID);
	pkt->dg.dst = vmci_make_handle(dst->svm_cid,
				       vmci_transport_peer_rid(dst->svm_cid));
	pkt->dg.payload_size = sizeof(*pkt) - sizeof(pkt->dg);
	pkt->version = VMCI_TRANSPORT_PACKET_VERSION;
	pkt->type = type;
	pkt->src_port = src->svm_port;
	pkt->dst_port = dst->svm_port;
	memset(&pkt->proto, 0, sizeof(pkt->proto));
	memset(&pkt->_reserved2, 0, sizeof(pkt->_reserved2));
	/* Never send stale bytes from the unused part of the union. */
	memset(&pkt->u, 0, sizeof(pkt->u));

	switch (pkt->type) {
	case VMCI_TRANSPORT_PACKET_TYPE_REQUEST:
	case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE:
		pkt->u.size = size;
		break;

	case VMCI_TRANSPORT_PACKET_TYPE_OFFER:
	case VMCI_TRANSPORT_PACKET_TYPE_ATTACH:
		pkt->u.handle = handle;
		break;

	case VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN:
		pkt->u.mode = mode;
		break;

	case VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ:
	case VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE:
		memcpy(&pkt->u.wait, wait, sizeof(pkt->u.wait));
		break;

	case VMCI_TRANSPORT_PACKET_TYPE_REQUEST2:
	case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2:
		pkt->u.size = size;
		pkt->proto = proto;
		break;

	case VMCI_TRANSPORT_PACKET_TYPE_INVALID:
	case VMCI_TRANSPORT_PACKET_TYPE_WROTE:
	case VMCI_TRANSPORT_PACKET_TYPE_READ:
	case VMCI_TRANSPORT_PACKET_TYPE_RST:
	default:
		/* No payload: the union stays all-zero. */
		break;
	}
}

static inline void
vmci_transport_packet_get_addresses(struct vmci_transport_packet *pkt,
				    struct sockaddr_vm *local,
				    struct sockaddr_vm *remote)
{
	vsock_addr_init(local, pkt->dg.dst.context, pkt->dst_port);
	vsock_addr_init(remote, pkt->dg.src.context, pkt->src_port);
}

/* Build a control packet in the caller-supplied buffer @pkt and send it with
 * vmci_datagram_send().  Non-negative results are returned unchanged.  A
 * negative VMCI error is translated to a negative errno only when
 * @convert_error is set; otherwise the raw VMCI code is returned.
 */
static int
__vmci_transport_send_control_pkt(struct vmci_transport_packet *pkt,
				  struct sockaddr_vm *src,
				  struct sockaddr_vm *dst,
				  enum vmci_transport_packet_type type,
				  u64 size,
				  u64 mode,
				  struct vmci_transport_waiting_info *wait,
				  u16 proto,
				  struct vmci_handle handle,
				  bool convert_error)
{
	int result;

	vmci_transport_packet_init(pkt, src, dst, type, size, mode, wait,
				   proto, handle);

	result = vmci_datagram_send(&pkt->dg);
	if (result >= 0 || !convert_error)
		return result;

	return vmci_transport_error_to_vsock_error(result);
}

/* Answer @pkt with a control packet of @type, sent back to @pkt's sender
 * from @pkt's destination.  The reply is built in a stack buffer and VMCI
 * errors are converted to errno values.  Nothing is sent (0 is returned)
 * when @pkt is itself a RST, so resets are never answered.
 */
static int
vmci_transport_reply_control_pkt_fast(struct vmci_transport_packet *pkt,
				      enum vmci_transport_packet_type type,
				      u64 size,
				      u64 mode,
				      struct vmci_transport_waiting_info *wait,
				      struct vmci_handle handle)
{
	struct vmci_transport_packet reply;
	struct sockaddr_vm reply_src, reply_dst;

	if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST)
		return 0;

	/* The received packet's local address becomes the reply's source
	 * and its remote address the reply's destination.
	 */
	vmci_transport_packet_get_addresses(pkt, &reply_src, &reply_dst);

	return __vmci_transport_send_control_pkt(&reply, &reply_src,
						 &reply_dst, type, size, mode,
						 wait, VSOCK_PROTO_INVALID,
						 handle, true);
}

/* Send a control packet from bottom-half context.  No allocation is
 * performed, and the raw VMCI result is returned (no errno conversion).
 */
static int
vmci_transport_send_control_pkt_bh(struct sockaddr_vm *src,
				   struct sockaddr_vm *dst,
				   enum vmci_transport_packet_type type,
				   u64 size,
				   u64 mode,
				   struct vmci_transport_waiting_info *wait,
				   struct vmci_handle handle)
{
	/* Note that it is safe to use a single packet across all CPUs since
	 * two tasklets of the same type are guaranteed to not ever run
	 * simultaneously. If that ever changes, or VMCI stops using tasklets,
	 * we can use per-cpu packets.
	 */
	static struct vmci_transport_packet bh_pkt;
	const bool convert_error = false;

	return __vmci_transport_send_control_pkt(&bh_pkt, src, dst, type,
						 size, mode, wait,
						 VSOCK_PROTO_INVALID, handle,
						 convert_error);
}

/* Send a control packet from @sk's local address to its remote address.
 *
 * Returns -EINVAL if either address is unbound, -ENOMEM if the temporary
 * packet cannot be allocated, and otherwise the send result with VMCI
 * errors converted to errno values.  The packet buffer is allocated with
 * GFP_KERNEL, so this may sleep; callers are presumably in process context.
 */
static int
vmci_transport_send_control_pkt(struct sock *sk,
				enum vmci_transport_packet_type type,
				u64 size,
				u64 mode,
				struct vmci_transport_waiting_info *wait,
				u16 proto,
				struct vmci_handle handle)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vmci_transport_packet *pkt;
	int ret;

	if (!vsock_addr_bound(&vsk->local_addr) ||
	    !vsock_addr_bound(&vsk->remote_addr))
		return -EINVAL;

	pkt = kmalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return -ENOMEM;

	ret = __vmci_transport_send_control_pkt(pkt, &vsk->local_addr,
						&vsk->remote_addr, type, size,
						mode, wait, proto, handle,
						true);
	kfree(pkt);

	return ret;
}

/* Send a RST from bottom-half context in reply to @pkt, unless @pkt is
 * itself a RST (then 0 is returned and nothing is sent).
 *
 * Beware the parameter naming: the first argument (@dst) becomes the RST's
 * SOURCE address and the second (@src) its DESTINATION.  Callers pass the
 * received packet's destination and source in that order, which turns the
 * RST into a reply.
 */
static int vmci_transport_send_reset_bh(struct sockaddr_vm *dst,
					struct sockaddr_vm *src,
					struct vmci_transport_packet *pkt)
{
	if (pkt->type != VMCI_TRANSPORT_PACKET_TYPE_RST)
		return vmci_transport_send_control_pkt_bh(
					dst, src,
					VMCI_TRANSPORT_PACKET_TYPE_RST,
					0, 0, NULL, VMCI_INVALID_HANDLE);

	return 0;
}

/* Send a RST over @sk's connection in response to @pkt, unless @pkt is
 * itself a RST (then 0 is returned and nothing is sent).
 */
static int vmci_transport_send_reset(struct sock *sk,
				     struct vmci_transport_packet *pkt)
{
	if (pkt->type != VMCI_TRANSPORT_PACKET_TYPE_RST)
		return vmci_transport_send_control_pkt(
					sk, VMCI_TRANSPORT_PACKET_TYPE_RST,
					0, 0, NULL, VSOCK_PROTO_INVALID,
					VMCI_INVALID_HANDLE);

	return 0;
}

/* Send a NEGOTIATE packet on @sk; @size travels in the payload's u.size. */
static int vmci_transport_send_negotiate(struct sock *sk, size_t size)
{
	return vmci_transport_send_control_pkt(sk,
					       VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE,
					       size, 0, NULL, VSOCK_PROTO_INVALID,
					       VMCI_INVALID_HANDLE);
}

/* Send a NEGOTIATE2 packet on @sk carrying @size and the protocol
 * @version (stored in the packet's proto field).
 */
static int vmci_transport_send_negotiate2(struct sock *sk, size_t size,
					  u16 version)
{
	return vmci_transport_send_control_pkt(sk,
					       VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2,
					       size, 0, NULL, version,
					       VMCI_INVALID_HANDLE);
}

/* Send an OFFER packet on @sk carrying the queue pair @handle. */
static int vmci_transport_send_qp_offer(struct sock *sk,
					struct vmci_handle handle)
{
	return vmci_transport_send_control_pkt(sk,
					       VMCI_TRANSPORT_PACKET_TYPE_OFFER,
					       0, 0, NULL, VSOCK_PROTO_INVALID,
					       handle);
}

/* Send an ATTACH packet on @sk carrying the queue pair @handle. */
static int vmci_transport_send_attach(struct sock *sk,
				      struct vmci_handle handle)
{
	return vmci_transport_send_control_pkt(sk,
					       VMCI_TRANSPORT_PACKET_TYPE_ATTACH,
					       0, 0, NULL, VSOCK_PROTO_INVALID,
					       handle);
}

/* Answer @pkt with a RST sent back to its originator.  A RST is never
 * answered; vmci_transport_reply_control_pkt_fast() returns 0 in that case.
 */
static int vmci_transport_reply_reset(struct vmci_transport_packet *pkt)
{
	return vmci_transport_reply_control_pkt_fast(pkt,
						     VMCI_TRANSPORT_PACKET_TYPE_RST,
						     0, 0, NULL,
						     VMCI_INVALID_HANDLE);
}

/* Send an INVALID packet from bottom-half context.  As with the other *_bh
 * senders, the first argument (@dst) is used as the packet's source and the
 * second (@src) as its destination.
 */
static int vmci_transport_send_invalid_bh(struct sockaddr_vm *dst,
					  struct sockaddr_vm *src)
{
	return vmci_transport_send_control_pkt_bh(dst, src,
						  VMCI_TRANSPORT_PACKET_TYPE_INVALID,
						  0, 0, NULL,
						  VMCI_INVALID_HANDLE);
}

/* Send a WROTE notification from bottom-half context.  The first argument
 * (@dst) is used as the packet's source and the second (@src) as its
 * destination.
 */
int vmci_transport_send_wrote_bh(struct sockaddr_vm *dst,
				 struct sockaddr_vm *src)
{
	return vmci_transport_send_control_pkt_bh(dst, src,
						  VMCI_TRANSPORT_PACKET_TYPE_WROTE,
						  0, 0, NULL,
						  VMCI_INVALID_HANDLE);
}

/* Send a READ notification from bottom-half context.  The first argument
 * (@dst) is used as the packet's source and the second (@src) as its
 * destination.
 */
int vmci_transport_send_read_bh(struct sockaddr_vm *dst,
				struct sockaddr_vm *src)
{
	return vmci_transport_send_control_pkt_bh(dst, src,
						  VMCI_TRANSPORT_PACKET_TYPE_READ,
						  0, 0, NULL,
						  VMCI_INVALID_HANDLE);
}

/* Send a WROTE notification over @sk's connection (process context). */
int vmci_transport_send_wrote(struct sock *sk)
{
	return vmci_transport_send_control_pkt(sk,
					       VMCI_TRANSPORT_PACKET_TYPE_WROTE,
					       0, 0, NULL, VSOCK_PROTO_INVALID,
					       VMCI_INVALID_HANDLE);
}

/* Send a READ notification over @sk's connection (process context). */
int vmci_transport_send_read(struct sock *sk)
{
	return vmci_transport_send_control_pkt(sk,
					       VMCI_TRANSPORT_PACKET_TYPE_READ,
					       0, 0, NULL, VSOCK_PROTO_INVALID,
					       VMCI_INVALID_HANDLE);
}

/* Send a WAITING_WRITE packet over @sk's connection.  *@wait is copied
 * into the packet payload.
 */
int vmci_transport_send_waiting_write(struct sock *sk,
				      struct vmci_transport_waiting_info *wait)
{
	return vmci_transport_send_control_pkt(sk,
					       VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE,
					       0, 0, wait, VSOCK_PROTO_INVALID,
					       VMCI_INVALID_HANDLE);
}

/* Send a WAITING_READ packet over @sk's connection.  *@wait is copied
 * into the packet payload.
 */
int vmci_transport_send_waiting_read(struct sock *sk,
				     struct vmci_transport_waiting_info *wait)
{
	return vmci_transport_send_control_pkt(sk,
					       VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ,
					       0, 0, wait, VSOCK_PROTO_INVALID,
					       VMCI_INVALID_HANDLE);
}

/* Tell the peer of @vsk about a shutdown; @mode is sent in the payload's
 * u.mode field.
 */
static int vmci_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct sock *sk = &vsk->sk;

	return vmci_transport_send_control_pkt(sk,
					       VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN,
					       0, mode, NULL, VSOCK_PROTO_INVALID,
					       VMCI_INVALID_HANDLE);
}

/* Send a (legacy) connection REQUEST on @sk carrying @size. */
static int vmci_transport_send_conn_request(struct sock *sk, size_t size)
{
	return vmci_transport_send_control_pkt(sk,
					       VMCI_TRANSPORT_PACKET_TYPE_REQUEST,
					       size, 0, NULL, VSOCK_PROTO_INVALID,
					       VMCI_INVALID_HANDLE);
}

/* Send a REQUEST2 connection request on @sk carrying @size and the protocol
 * @version (stored in the packet's proto field).
 */
static int vmci_transport_send_conn_request2(struct sock *sk, size_t size,
					     u16 version)
{
	return vmci_transport_send_control_pkt(sk,
					       VMCI_TRANSPORT_PACKET_TYPE_REQUEST2,
					       size, 0, NULL, version,
					       VMCI_INVALID_HANDLE);
}

/* Find the entry on @listener's pending_links list whose remote address
 * equals @pkt's source (context + src_port) and whose local port equals
 * @pkt's dst_port.
 *
 * On a match a reference is taken with sock_hold() and the socket is
 * returned; the caller must drop it, e.g. via
 * vmci_transport_release_pending().  Returns NULL if nothing matches.
 * NOTE(review): the list is walked without taking a lock here; presumably
 * the caller holds the listener's lock — confirm.
 */
static struct sock *vmci_transport_get_pending(
					struct sock *listener,
					struct vmci_transport_packet *pkt)
{
	struct vsock_sock *vlistener = vsock_sk(listener);
	struct vsock_sock *vpending;
	struct sockaddr_vm src;

	vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);

	list_for_each_entry(vpending, &vlistener->pending_links,
			    pending_links) {
		struct sock *pending;

		if (!vsock_addr_equals_addr(&src, &vpending->remote_addr) ||
		    pkt->dst_port != vpending->local_addr.svm_port)
			continue;

		pending = sk_vsock(vpending);
		sock_hold(pending);
		return pending;
	}

	return NULL;
}

/* Drop the reference that vmci_transport_get_pending() took on @pending. */
static void vmci_transport_release_pending(struct sock *pending)
{
	sock_put(pending);
}

/* Decide whether @vsock may talk to a restricted peer @peer_cid.  Two kinds
 * of sockets qualify:
 *   1) sockets marked trusted, and
 *   2) sockets whose owner's uid owns the peer's VMCI context (this only
 *      happens on the host side, with hosted products).
 */

static bool vmci_transport_is_trusted(struct vsock_sock *vsock, u32 peer_cid)
{
	if (vsock->trusted)
		return true;

	return vmci_is_context_owner(peer_cid, vsock->owner->uid);
}

/* Decide whether @vsock may exchange datagrams with @peer_cid.
 *
 * The hypervisor is always allowed.  Any other peer is refused only when
 * its context carries VMCI_PRIVILEGE_FLAG_RESTRICTED and the socket is not
 * trusted (see vmci_transport_is_trusted()).  The verdict is cached in the
 * socket for the most recently seen peer, so repeated traffic from that
 * peer skips the VMCI queries.
 */

static bool vmci_transport_allow_dgram(struct vsock_sock *vsock, u32 peer_cid)
{
	bool refuse;

	if (peer_cid == VMADDR_CID_HYPERVISOR)
		return true;

	if (vsock->cached_peer == peer_cid)
		return vsock->cached_peer_allow_dgram;

	vsock->cached_peer = peer_cid;
	refuse = !vmci_transport_is_trusted(vsock, peer_cid) &&
		 (vmci_context_get_priv_flags(peer_cid) &
		  VMCI_PRIVILEGE_FLAG_RESTRICTED);
	vsock->cached_peer_allow_dgram = !refuse;

	return vsock->cached_peer_allow_dgram;
}

/* Allocate (or attach to) a queue pair with @peer.
 *
 * When @trusted is set, a trusted allocation is tried first.  The original
 * design notes this only works when vsock runs in the host.  If that attempt
 * is refused with VMCI_ERROR_NO_ACCESS, or @trusted is clear, an unprivileged
 * allocation is made.  Negative results are logged and converted to a
 * negative errno; non-negative results are returned unchanged.
 */
static int
vmci_transport_queue_pair_alloc(struct vmci_qp **qpair,
				struct vmci_handle *handle,
				u64 produce_size,
				u64 consume_size,
				u32 peer, u32 flags, bool trusted)
{
	int rc = VMCI_ERROR_NO_ACCESS;

	if (trusted)
		rc = vmci_qpair_alloc(qpair, handle, produce_size,
				      consume_size, peer, flags,
				      VMCI_PRIVILEGE_FLAG_TRUSTED);

	/* Untrusted request, or the trusted attempt was refused. */
	if (rc == VMCI_ERROR_NO_ACCESS)
		rc = vmci_qpair_alloc(qpair, handle, produce_size,
				      consume_size, peer, flags,
				      VMCI_NO_PRIVILEGE_FLAGS);

	if (rc >= 0)
		return rc;

	pr_err("Could not attach to queue pair with %d\n",
	       rc);
	return vmci_transport_error_to_vsock_error(rc);
}

/* Create a VMCI datagram handle for @resource_id that delivers to @recv_cb.
 *
 * A trusted handle is requested first; per the original design this only
 * works when vsock runs in the host.  If VMCI refuses it with
 * VMCI_ERROR_NO_ACCESS, a regular handle is created instead.  The raw VMCI
 * result is returned (no errno conversion).
 */
static int
vmci_transport_datagram_create_hnd(u32 resource_id,
				   u32 flags,
				   vmci_datagram_recv_cb recv_cb,
				   void *client_data,
				   struct vmci_handle *out_handle)
{
	int rc;

	rc = vmci_datagram_create_handle_priv(resource_id, flags,
					      VMCI_PRIVILEGE_FLAG_TRUSTED,
					      recv_cb, client_data,
					      out_handle);
	if (rc != VMCI_ERROR_NO_ACCESS)
		return rc;

	return vmci_datagram_create_handle(resource_id, flags, recv_cb,
					   client_data, out_handle);
}

/* Receive callback for datagram sockets (@data is the struct sock).
 *
 * Per the original design this runs from a tasklet scheduled by the VMCI
 * interrupt, i.e. in bottom-half context; any work that needs to sleep must
 * be deferred to a work queue.  The whole datagram, header included, is
 * copied into a freshly allocated skb and queued on the socket.
 *
 * Returns VMCI_ERROR_NO_ACCESS if the peer is not allowed to send to this
 * socket, VMCI_ERROR_NO_MEM if the skb cannot be allocated, otherwise
 * VMCI_SUCCESS.
 */

static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg)
{
	struct sock *sk = data;
	struct vsock_sock *vsk = vsock_sk(sk);
	struct sk_buff *skb;
	size_t len;

	/* This handler is privileged when running on the host, so it sees
	 * datagrams from every endpoint, including restricted VMs; those are
	 * only accepted by trusted sockets.
	 *
	 * The socket is accessed without its lock, relying on the trust
	 * fields being set only at socket create/destruct time.
	 * NOTE(review): vmci_transport_allow_dgram() also updates the
	 * cached_peer fields without the lock.
	 */
	if (!vmci_transport_allow_dgram(vsk, dg->src.context))
		return VMCI_ERROR_NO_ACCESS;

	len = VMCI_DG_SIZE(dg);

	skb = alloc_skb(len, GFP_ATOMIC);
	if (!skb)
		return VMCI_ERROR_NO_MEM;

	memcpy(skb_put(skb, len), dg, len);

	/* sk_receive_skb() will do a sock_put(), so hold here. */
	sock_hold(sk);
	sk_receive_skb(sk, skb, 0);

	return VMCI_SUCCESS;
}

/* Return false for contexts known never to host vsock sockets (currently
 * only VMADDR_CID_RESERVED) and true otherwise.  @port is not consulted.
 */
static bool vmci_transport_stream_allow(u32 cid, u32 port)
{
	static const u32 non_socket_contexts[] = {
		VMADDR_CID_RESERVED,
	};
	const u32 *ctx;
	const u32 *end = non_socket_contexts + ARRAY_SIZE(non_socket_contexts);

	BUILD_BUG_ON(sizeof(cid) != sizeof(*non_socket_contexts));

	for (ctx = non_socket_contexts; ctx < end; ctx++) {
		if (*ctx == cid)
			return false;
	}

	return true;
}

/* This is invoked as part of a tasklet that's scheduled when the VMCI
 * interrupt fires.  This is run in bottom-half context but it defers most of
 * its work to the packet handling work queue.
 *
 * Receive callback for stream control packets.  Only READ/WROTE-style
 * notifications on an SS_CONNECTED socket that is not owned by a user
 * context may be handled inline (decided by the notify ops via
 * bh_process_pkt).  Everything else is copied into a
 * vmci_transport_recv_pkt_info and handed to vmci_transport_recv_pkt_work().
 *
 * Returns:
 *   VMCI_ERROR_NO_ACCESS    - source context/resource not allowed, or the
 *                             peer may not talk to the target socket
 *   VMCI_ERROR_INVALID_ARGS - datagram shorter than a vsock packet, or an
 *                             unknown packet type (an INVALID packet is
 *                             sent back in the latter case)
 *   VMCI_ERROR_NOT_FOUND    - no socket matches (a RST is sent back unless
 *                             the packet was itself a RST)
 *   VMCI_ERROR_NO_MEM       - the work item could not be allocated (a RST
 *                             is sent back)
 *   VMCI_SUCCESS            - otherwise
 */

static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg)
{
	struct sock *sk;
	struct sockaddr_vm dst;
	struct sockaddr_vm src;
	struct vmci_transport_packet *pkt;
	struct vsock_sock *vsk;
	bool bh_process_pkt;
	int err;

	sk = NULL;
	err = VMCI_SUCCESS;
	bh_process_pkt = false;

	/* Ignore incoming packets from contexts without sockets, or resources
	 * that aren't vsock implementations.  (The -1 port argument is not
	 * consulted by vmci_transport_stream_allow().)
	 */

	if (!vmci_transport_stream_allow(dg->src.context, -1)
	    || vmci_transport_peer_rid(dg->src.context) != dg->src.resource)
		return VMCI_ERROR_NO_ACCESS;

	if (VMCI_DG_SIZE(dg) < sizeof(*pkt))
		/* Drop datagrams that do not contain full VSock packets. */
		return VMCI_ERROR_INVALID_ARGS;

	pkt = (struct vmci_transport_packet *)dg;

	/* Find the socket that should handle this packet.  First we look for a
	 * connected socket and if there is none we look for a socket bound to
	 * the destination address.  A socket found here comes with a
	 * reference that is either dropped at "out" or handed to the work
	 * item below.
	 */
	vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);
	vsock_addr_init(&dst, pkt->dg.dst.context, pkt->dst_port);

	sk = vsock_find_connected_socket(&src, &dst);
	if (!sk) {
		sk = vsock_find_bound_socket(&dst);
		if (!sk) {
			/* We could not find a socket for this specified
			 * address.  If this packet is a RST, we just drop it.
			 * If it is another packet, we send a RST.  Note that
			 * we do not send a RST reply to RSTs so that we do not
			 * continually send RSTs between two endpoints.
			 *
			 * Note that since this is a reply, dst is src and src
			 * is dst.
			 */
			if (vmci_transport_send_reset_bh(&dst, &src, pkt) < 0)
				pr_err("unable to send reset\n");

			err = VMCI_ERROR_NOT_FOUND;
			goto out;
		}
	}

	/* If the received packet type is beyond all types known to this
	 * implementation, reply with an invalid message.  Hopefully this will
	 * help when implementing backwards compatibility in the future.
	 * The send result is deliberately ignored (best effort).
	 */
	if (pkt->type >= VMCI_TRANSPORT_PACKET_TYPE_MAX) {
		vmci_transport_send_invalid_bh(&dst, &src);
		err = VMCI_ERROR_INVALID_ARGS;
		goto out;
	}

	/* This handler is privileged when this module is running on the host.
	 * We will get datagram connect requests from all endpoints (even VMs
	 * that are in a restricted context). If we get one from a restricted
	 * context then the destination socket must be trusted.
	 *
	 * NOTE: We access the socket struct without holding the lock here.
	 * This is ok because the field we are interested in is never modified
	 * outside of the create and destruct socket functions.
	 */
	vsk = vsock_sk(sk);
	if (!vmci_transport_allow_dgram(vsk, pkt->dg.src.context)) {
		err = VMCI_ERROR_NO_ACCESS;
		goto out;
	}

	/* We do most everything in a work queue, but let's fast path the
	 * notification of reads and writes to help data transfer performance.
	 * We can only do this if there is no process context code executing
	 * for this socket since that may change the state.
	 */
	bh_lock_sock(sk);

	if (!sock_owned_by_user(sk)) {
		/* The local context ID may be out of date, update it. */
		vsk->local_addr.svm_cid = dst.svm_cid;

		/* The notify ops set bh_process_pkt when they fully handled
		 * the packet here.
		 */
		if (sk->sk_state == SS_CONNECTED)
			vmci_trans(vsk)->notify_ops->handle_notify_pkt(
					sk, pkt, true, &dst, &src,
					&bh_process_pkt);
	}

	bh_unlock_sock(sk);

	if (!bh_process_pkt) {
		struct vmci_transport_recv_pkt_info *recv_pkt_info;

		/* GFP_ATOMIC: still in bottom-half context here. */
		recv_pkt_info = kmalloc(sizeof(*recv_pkt_info), GFP_ATOMIC);
		if (!recv_pkt_info) {
			if (vmci_transport_send_reset_bh(&dst, &src, pkt) < 0)
				pr_err("unable to send reset\n");

			err = VMCI_ERROR_NO_MEM;
			goto out;
		}

		/* Copy the packet: the work item runs after this callback
		 * has returned.
		 */
		recv_pkt_info->sk = sk;
		memcpy(&recv_pkt_info->pkt, pkt, sizeof(recv_pkt_info->pkt));
		INIT_WORK(&recv_pkt_info->work, vmci_transport_recv_pkt_work);

		schedule_work(&recv_pkt_info->work);
		/* Clear sk so that the reference count incremented by one of
		 * the Find functions above is not decremented below.  We need
		 * that reference count for the packet handler we've scheduled
		 * to run.
		 */
		sk = NULL;
	}

out:
	if (sk)
		sock_put(sk);

	return err;
}

static void vmci_transport_handle_detach(struct sock *sk)
{
	struct vsock_sock *vsk;

	vsk = vsock_sk(sk);
	if (!vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)) {
		sock_set_flag(sk, SOCK_DONE);

		/* On a detach the peer will not be sending or receiving
		 * anymore.
		 */
		vsk->peer_shutdown = SHUTDOWN_MASK;

		/* We should not be sending anymore since the peer won't be
		 * there to receive, but we can still receive if there is data
		 * left in our consume queue.
		 */
		if (vsock_stream_has_data(vsk) <= 0) {
			if (sk->sk_state == SS_CONNECTING) {
				/* The peer may detach from a queue pair while
				 * we are still in the connecting state, i.e.,
				 * if the peer VM is killed after attaching to
				 * a queue pair, but before we complete the
				 * handshake. In that case, we treat the detach
				 * event like a reset.
				 */

				sk->sk_state = SS_UNCONNECTED;
				sk->sk_err = ECONNRESET;
				sk->sk_error_report(sk);
				return;
			}
			sk->sk_state = SS_UNCONNECTED;
		}
		sk->sk_state_change(sk);
	}
}

/* VMCI_EVENT_QP_PEER_DETACH subscription callback.
 *
 * @client_data is the struct vmci_transport that subscribed.  Events for
 * an invalid handle, or for a queue pair other than this transport's, are
 * ignored.  For a matching event the owning socket (if it has not been
 * destructed yet) is processed by vmci_transport_handle_detach() with its
 * BH socket lock held.
 */
static void vmci_transport_peer_detach_cb(u32 sub_id,
					  const struct vmci_event_data *e_data,
					  void *client_data)
{
	struct vmci_transport *trans = client_data;
	const struct vmci_event_payload_qp *qp_payload =
		vmci_event_data_const_payload(e_data);
	struct sock *owner;

	/* There is no lookup from qp_handle to socket, so every subscriber
	 * receives every detach event and must filter out the ones that
	 * are not for its own queue pair.
	 */
	if (vmci_handle_is_invalid(qp_payload->handle))
		return;
	if (!vmci_handle_is_equal(trans->qp_handle, qp_payload->handle))
		return;

	/* The subscription does not request delayed callbacks, so VMCI may
	 * invoke us from BH or process context; the _bh lock variant is
	 * safe in all of them.  vmci_transport_destruct() clears trans->sk
	 * under this same lock, so the pointer stays stable while held.
	 */
	spin_lock_bh(&trans->lock);

	owner = trans->sk;
	if (owner) {
		bh_lock_sock(owner);
		vmci_transport_handle_detach(owner);
		bh_unlock_sock(owner);
	}

	spin_unlock_bh(&trans->lock);
}

/* Queue-pair "resumed" event callback.
 *
 * Every connected socket is processed as if its peer had detached; the
 * event arguments themselves are unused.  (Presumably queue pairs do not
 * survive the suspend/resume cycle — TODO confirm against VMCI docs.)
 */
static void vmci_transport_qp_resumed_cb(u32 sub_id,
					 const struct vmci_event_data *e_data,
					 void *client_data)
{
	vsock_for_each_connected_socket(&vmci_transport_handle_detach);
}

/* Deferred (process-context) handling of a control packet.
 *
 * The work item is embedded in a vmci_transport_recv_pkt_info that holds
 * the target socket and a copy of the packet.  The socket is locked, its
 * local CID refreshed from the packet, and the packet dispatched according
 * to the socket's state.  Afterwards the info struct is freed and the
 * socket reference taken by the stream callback is dropped.
 */
static void vmci_transport_recv_pkt_work(struct work_struct *work)
{
	struct vmci_transport_recv_pkt_info *info =
		container_of(work, struct vmci_transport_recv_pkt_info, work);
	struct vmci_transport_packet *packet = &info->pkt;
	struct sock *sk = info->sk;

	lock_sock(sk);

	/* Our context ID may have changed since the socket was bound. */
	vsock_sk(sk)->local_addr.svm_cid = packet->dg.dst.context;

	switch (sk->sk_state) {
	case VSOCK_SS_LISTEN:
		vmci_transport_recv_listen(sk, packet);
		break;
	case SS_CONNECTING:
		/* Only client-side connects arrive here: server-side pending
		 * connections are driven through the listening socket by
		 * vmci_transport_recv_listen().
		 */
		vmci_transport_recv_connecting_client(sk, packet);
		break;
	case SS_CONNECTED:
		vmci_transport_recv_connected(sk, packet);
		break;
	default:
		/* This work runs asynchronously to
		 * vmci_transport_recv_stream_cb, so the socket may have been
		 * closed in the meantime.  Reset the peer so it does not hang
		 * forever in connect().
		 */
		vmci_transport_send_reset(sk, packet);
		break;
	}

	release_sock(sk);
	kfree(info);

	/* Drop the reference the stream callback took when it looked this
	 * socket up in the bound or connected table.
	 */
	sock_put(sk);
}

/* Handle a control packet that arrived on a listening socket @sk.
 *
 * If the packet belongs to a connection already on @sk's pending list, it
 * is forwarded to vmci_transport_recv_connecting_server() (or answered
 * with a reset when that pending socket is not SS_CONNECTING).  Otherwise
 * the packet must be a REQUEST/REQUEST2 with a non-zero queue pair size
 * and the backlog must have room.  In that case a child ("pending")
 * socket is created, a NEGOTIATE/NEGOTIATE2 reply is sent, and the child
 * is added to the pending list with delayed cleanup work armed.
 *
 * Called from vmci_transport_recv_pkt_work() with @sk locked; that caller
 * ignores the return value.
 *
 * Returns 0 on success or a negative errno.
 */
static int vmci_transport_recv_listen(struct sock *sk,
				      struct vmci_transport_packet *pkt)
{
	struct sock *pending;
	struct vsock_sock *vpending;
	int err;
	u64 qp_size;
	bool old_request = false;
	bool old_pkt_proto = false;

	err = 0;

	/* Because we are in the listen state, we could be receiving a packet
	 * for ourself or any previous connection requests that we received.
	 * If it's the latter, we try to find a socket in our list of pending
	 * connections and, if we do, call the appropriate handler for the
	 * state that that socket is in.  Otherwise we try to service the
	 * connection request.
	 */
	pending = vmci_transport_get_pending(sk, pkt);
	if (pending) {
		lock_sock(pending);

		/* The local context ID may be out of date. */
		vsock_sk(pending)->local_addr.svm_cid = pkt->dg.dst.context;

		switch (pending->sk_state) {
		case SS_CONNECTING:
			err = vmci_transport_recv_connecting_server(sk,
								    pending,
								    pkt);
			break;
		default:
			vmci_transport_send_reset(pending, pkt);
			err = -EINVAL;
		}

		/* On failure take the pending socket off the listener's
		 * pending list.
		 *
		 * NOTE(review): sk_ack_backlog was incremented when this
		 * socket was added to the pending list (further below) but
		 * is not decremented here; confirm that the delayed pending
		 * work (vsock_pending_work) accounts for sockets that were
		 * already removed from the list.
		 */
		if (err < 0)
			vsock_remove_pending(sk, pending);

		release_sock(pending);
		/* Presumably balances the reference taken by
		 * vmci_transport_get_pending() — TODO confirm.
		 */
		vmci_transport_release_pending(pending);

		return err;
	}

	/* The listen state only accepts connection requests.  Reply with a
	 * reset unless we received a reset.
	 */

	if (!(pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST ||
	      pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST2)) {
		vmci_transport_reply_reset(pkt);
		return -EINVAL;
	}

	/* A request must propose a non-zero queue pair size. */
	if (pkt->u.size == 0) {
		vmci_transport_reply_reset(pkt);
		return -EINVAL;
	}

	/* If this socket can't accommodate this connection request, we send a
	 * reset.  Otherwise we create and initialize a child socket and reply
	 * with a connection negotiation.
	 */
	if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog) {
		vmci_transport_reply_reset(pkt);
		return -ECONNREFUSED;
	}

	/* GFP_KERNEL is fine: we run from a work item, not from BH. */
	pending = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
				 sk->sk_type, 0);
	if (!pending) {
		vmci_transport_send_reset(sk, pkt);
		return -ENOMEM;
	}

	vpending = vsock_sk(pending);

	/* The child's addresses mirror the request: our end is the packet's
	 * destination, the remote end is its source.
	 */
	vsock_addr_init(&vpending->local_addr, pkt->dg.dst.context,
			pkt->dst_port);
	vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context,
			pkt->src_port);

	/* If the proposed size fits within our min/max, accept it. Otherwise
	 * propose our own size.
	 */
	if (pkt->u.size >= vmci_trans(vpending)->queue_pair_min_size &&
	    pkt->u.size <= vmci_trans(vpending)->queue_pair_max_size) {
		qp_size = pkt->u.size;
	} else {
		qp_size = vmci_trans(vpending)->queue_pair_size;
	}

	/* Figure out if we are using old or new requests based on the
	 * overrides pkt types sent by our peer.
	 */
	if (vmci_transport_old_proto_override(&old_pkt_proto)) {
		old_request = old_pkt_proto;
	} else {
		if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST)
			old_request = true;
		else if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST2)
			old_request = false;

	}

	if (old_request) {
		/* Handle a REQUEST (or override) */
		u16 version = VSOCK_PROTO_INVALID;
		if (vmci_transport_proto_to_notify_struct(
			pending, &version, true))
			err = vmci_transport_send_negotiate(pending, qp_size);
		else
			err = -EINVAL;

	} else {
		/* Handle a REQUEST2 (or override) */
		int proto_int = pkt->proto;
		int pos;
		u16 active_proto_version = 0;

		/* The list of possible protocols is the intersection of all
		 * protocols the client supports ... plus all the protocols we
		 * support.
		 */
		proto_int &= vmci_transport_new_proto_supported_versions();

		/* We choose the highest possible protocol version and use that
		 * one.
		 */
		pos = fls(proto_int);
		if (pos) {
			active_proto_version = (1 << (pos - 1));
			if (vmci_transport_proto_to_notify_struct(
				pending, &active_proto_version, false))
				err = vmci_transport_send_negotiate2(pending,
							qp_size,
							active_proto_version);
			else
				err = -EINVAL;

		} else {
			err = -EINVAL;
		}
	}

	/* Negotiation failed: reset the peer and drop our reference to the
	 * child, which was never added to the pending list.
	 */
	if (err < 0) {
		vmci_transport_send_reset(sk, pkt);
		sock_put(pending);
		err = vmci_transport_error_to_vsock_error(err);
		goto out;
	}

	vsock_add_pending(sk, pending);
	sk->sk_ack_backlog++;

	pending->sk_state = SS_CONNECTING;
	vmci_trans(vpending)->produce_size =
		vmci_trans(vpending)->consume_size = qp_size;
	vmci_trans(vpending)->queue_pair_size = qp_size;

	vmci_trans(vpending)->notify_ops->process_request(pending);

	/* We might never receive another message for this socket and it's not
	 * connected to any process, so we have to ensure it gets cleaned up
	 * ourself.  Our delayed work function will take care of that.  Note
	 * that we do not ever cancel this function since we have few
	 * guarantees about its state when calling cancel_delayed_work().
	 * Instead we hold a reference on the socket for that function and make
	 * it capable of handling cases where it needs to do nothing but
	 * release that reference.
	 */
	vpending->listener = sk;
	sock_hold(sk);
	sock_hold(pending);
	INIT_DELAYED_WORK(&vpending->dwork, vsock_pending_work);
	schedule_delayed_work(&vpending->dwork, HZ);

out:
	return err;
}

/* Handle a packet for a server-side pending socket in SS_CONNECTING.
 *
 * Only an OFFER carrying a valid queue pair handle is accepted.  In that
 * case the pending socket subscribes to peer-detach events, attaches to
 * the client's queue pair (ATTACH_ONLY), is inserted into the connected
 * table in SS_CONNECTED, sends an ATTACH, and is moved from @listener's
 * pending list to its accept queue.  For any other packet, or any failure
 * along the way, vmci_transport_send_reset() is called and the pending
 * socket is marked SS_UNCONNECTED with sk_err set.
 *
 * Called from vmci_transport_recv_listen() with @pending locked.  When
 * this returns a negative errno, that caller removes @pending from the
 * listener's pending list.
 *
 * Returns 0 on success (and also when the packet is an RST), or a
 * negative errno.
 */
static int
vmci_transport_recv_connecting_server(struct sock *listener,
				      struct sock *pending,
				      struct vmci_transport_packet *pkt)
{
	struct vsock_sock *vpending;
	struct vmci_handle handle;
	struct vmci_qp *qpair;
	bool is_local;
	u32 flags;
	u32 detach_sub_id;
	int err;
	int skerr;

	vpending = vsock_sk(pending);
	detach_sub_id = VMCI_INVALID_ID;

	switch (pkt->type) {
	case VMCI_TRANSPORT_PACKET_TYPE_OFFER:
		if (vmci_handle_is_invalid(pkt->u.handle)) {
			vmci_transport_send_reset(pending, pkt);
			skerr = EPROTO;
			err = -EINVAL;
			goto destroy;
		}
		break;
	default:
		/* Close and cleanup the connection. */
		vmci_transport_send_reset(pending, pkt);
		skerr = EPROTO;
		err = pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST ? 0 : -EINVAL;
		goto destroy;
	}

	/* In order to complete the connection we need to attach to the offered
	 * queue pair and send an attach notification.  We also subscribe to the
	 * detach event so we know when our peer goes away, and we do that
	 * before attaching so we don't miss an event.  If all this succeeds,
	 * we update our state and wakeup anything waiting in accept() for a
	 * connection.
	 */

	/* We don't care about attach since we ensure the other side has
	 * attached by specifying the ATTACH_ONLY flag below.
	 */
	err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_DETACH,
				   vmci_transport_peer_detach_cb,
				   vmci_trans(vpending), &detach_sub_id);
	if (err < VMCI_SUCCESS) {
		vmci_transport_send_reset(pending, pkt);
		err = vmci_transport_error_to_vsock_error(err);
		skerr = -err;
		goto destroy;
	}

	/* Recorded in the transport so that the deferred cleanup path
	 * (vmci_transport_free_resources) unsubscribes it, including when
	 * a later step below fails.
	 */
	vmci_trans(vpending)->detach_sub_id = detach_sub_id;

	/* Now attach to the queue pair the client created. */
	handle = pkt->u.handle;

	/* vpending->local_addr always has a context id so we do not need to
	 * worry about VMADDR_CID_ANY in this case.
	 */
	is_local =
	    vpending->remote_addr.svm_cid == vpending->local_addr.svm_cid;
	flags = VMCI_QPFLAG_ATTACH_ONLY;
	flags |= is_local ? VMCI_QPFLAG_LOCAL : 0;

	err = vmci_transport_queue_pair_alloc(
					&qpair,
					&handle,
					vmci_trans(vpending)->produce_size,
					vmci_trans(vpending)->consume_size,
					pkt->dg.src.context,
					flags,
					vmci_transport_is_trusted(
						vpending,
						vpending->remote_addr.svm_cid));
	if (err < 0) {
		vmci_transport_send_reset(pending, pkt);
		skerr = -err;
		goto destroy;
	}

	vmci_trans(vpending)->qp_handle = handle;
	vmci_trans(vpending)->qpair = qpair;

	/* When we send the attach message, we must be ready to handle incoming
	 * control messages on the newly connected socket. So we move the
	 * pending socket to the connected state before sending the attach
	 * message. Otherwise, an incoming packet triggered by the attach being
	 * received by the peer may be processed concurrently with what happens
	 * below after sending the attach message, and that incoming packet
	 * will find the listening socket instead of the (currently) pending
	 * socket. Note that enqueueing the socket increments the reference
	 * count, so even if a reset comes before the connection is accepted,
	 * the socket will be valid until it is removed from the queue.
	 *
	 * If we fail sending the attach below, we remove the socket from the
	 * connected list and move the socket to SS_UNCONNECTED before
	 * releasing the lock, so a pending slow path processing of an incoming
	 * packet will not see the socket in the connected state in that case.
	 */
	pending->sk_state = SS_CONNECTED;

	vsock_insert_connected(vpending);

	/* Notify our peer of our attach. */
	err = vmci_transport_send_attach(pending, handle);
	if (err < 0) {
		vsock_remove_connected(vpending);
		pr_err("Could not send attach\n");
		vmci_transport_send_reset(pending, pkt);
		err = vmci_transport_error_to_vsock_error(err);
		skerr = -err;
		goto destroy;
	}

	/* We have a connection. Move the now connected socket from the
	 * listener's pending list to the accept queue so callers of accept()
	 * can find it.
	 */
	vsock_remove_pending(listener, pending);
	vsock_enqueue_accept(listener, pending);

	/* Callers of accept() will be be waiting on the listening socket, not
	 * the pending socket.
	 */
	listener->sk_data_ready(listener);

	return 0;

destroy:
	pending->sk_err = skerr;
	pending->sk_state = SS_UNCONNECTED;
	/* As long as we drop our reference, all necessary cleanup will handle
	 * when the cleanup function drops its reference and our destruct
	 * implementation is called.  Note that since the listen handler will
	 * remove pending from the pending list upon our failure, the cleanup
	 * function won't drop the additional reference, which is why we do it
	 * here.
	 */
	/* NOTE(review): the comment above assumes the listen handler always
	 * removes pending from the pending list on failure, but
	 * vmci_transport_recv_listen() only does so when err < 0.  For a
	 * received RST this function returns 0, so pending remains on the
	 * pending list while a reference is still dropped here.  Confirm
	 * that the reference counting is balanced on that path.
	 */
	sock_put(pending);

	return err;
}

/* Handle a packet for a client socket that is in SS_CONNECTING.
 *
 * - ATTACH matching our queue pair handle completes the connection.
 * - NEGOTIATE/NEGOTIATE2 is sanity-checked and handed to
 *   vmci_transport_recv_connecting_client_negotiate().
 * - INVALID is handled by vmci_transport_recv_connecting_client_invalid().
 * - RST is ignored once after that handler retried with an old-style
 *   REQUEST; otherwise it aborts the connect with ECONNRESET.
 *
 * Any other packet type, or any failure, sends a reset, moves the socket to
 * SS_UNCONNECTED with sk_err set, and calls sk_error_report().
 *
 * Returns 0 on success (and for an aborting RST), or a negative errno.
 */
static int
vmci_transport_recv_connecting_client(struct sock *sk,
				      struct vmci_transport_packet *pkt)
{
	struct vsock_sock *vsk;
	int err;
	int skerr;

	vsk = vsock_sk(sk);

	switch (pkt->type) {
	case VMCI_TRANSPORT_PACKET_TYPE_ATTACH:
		if (vmci_handle_is_invalid(pkt->u.handle) ||
		    !vmci_handle_is_equal(pkt->u.handle,
					  vmci_trans(vsk)->qp_handle)) {
			skerr = EPROTO;
			err = -EINVAL;
			goto destroy;
		}

		/* Signify the socket is connected and wakeup the waiter in
		 * connect(). Also place the socket in the connected table for
		 * accounting (it can already be found since it's in the bound
		 * table).
		 */
		/* NOTE(review): this dereferences sk->sk_socket; confirm a
		 * socket that was released (orphaned) while connecting can
		 * never reach this path with a NULL sk_socket.
		 */
		sk->sk_state = SS_CONNECTED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);

		break;
	case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE:
	case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2:
		/* Reject a negotiate that proposes no size, does not come
		 * from the peer we are connecting to, or arrives after queue
		 * pair state has already been set up.
		 */
		if (pkt->u.size == 0
		    || pkt->dg.src.context != vsk->remote_addr.svm_cid
		    || pkt->src_port != vsk->remote_addr.svm_port
		    || !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)
		    || vmci_trans(vsk)->qpair
		    || vmci_trans(vsk)->produce_size != 0
		    || vmci_trans(vsk)->consume_size != 0
		    || vmci_trans(vsk)->detach_sub_id != VMCI_INVALID_ID) {
			skerr = EPROTO;
			err = -EINVAL;

			goto destroy;
		}

		err = vmci_transport_recv_connecting_client_negotiate(sk, pkt);
		if (err) {
			skerr = -err;
			goto destroy;
		}

		break;
	case VMCI_TRANSPORT_PACKET_TYPE_INVALID:
		err = vmci_transport_recv_connecting_client_invalid(sk, pkt);
		if (err) {
			skerr = -err;
			goto destroy;
		}

		break;
	case VMCI_TRANSPORT_PACKET_TYPE_RST:
		/* Older versions of the linux code (WS 6.5 / ESX 4.0) used to
		 * continue processing here after they sent an INVALID packet.
		 * This meant that we got a RST after the INVALID. We ignore a
		 * RST after an INVALID. The common code doesn't send the RST
		 * ... so we can hang if an old version of the common code
		 * fails between getting a REQUEST and sending an OFFER back.
		 * Not much we can do about it... except hope that it doesn't
		 * happen.
		 */
		if (vsk->ignore_connecting_rst) {
			vsk->ignore_connecting_rst = false;
		} else {
			skerr = ECONNRESET;
			err = 0;
			goto destroy;
		}

		break;
	default:
		/* Close and cleanup the connection. */
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}

	return 0;

destroy:
	/* Common failure path: reset the peer and report skerr to the
	 * waiter in connect().
	 */
	vmci_transport_send_reset(sk, pkt);

	sk->sk_state = SS_UNCONNECTED;
	sk->sk_err = skerr;
	sk->sk_error_report(sk);
	return err;
}

static int vmci_transport_recv_connecting_client_negotiate(
					struct sock *sk,
					struct vmci_transport_packet *pkt)
{
	int err;
	struct vsock_sock *vsk;
	struct vmci_handle handle;
	struct vmci_qp *qpair;
	u32 detach_sub_id;
	bool is_local;
	u32 flags;
	bool old_proto = true;
	bool old_pkt_proto;
	u16 version;

	vsk = vsock_sk(sk);
	handle = VMCI_INVALID_HANDLE;
	detach_sub_id = VMCI_INVALID_ID;

	/* If we have gotten here then we should be past the point where old
	 * linux vsock could have sent the bogus rst.
	 */
	vsk->sent_request = false;
	vsk->ignore_connecting_rst = false;

	/* Verify that we're OK with the proposed queue pair size */
	if (pkt->u.size < vmci_trans(vsk)->queue_pair_min_size ||
	    pkt->u.size > vmci_trans(vsk)->queue_pair_max_size) {
		err = -EINVAL;
		goto destroy;
	}

	/* At this point we know the CID the peer is using to talk to us. */

	if (vsk->local_addr.svm_cid == VMADDR_CID_ANY)
		vsk->local_addr.svm_cid = pkt->dg.dst.context;

	/* Setup the notify ops to be the highest supported version that both
	 * the server and the client support.
	 */

	if (vmci_transport_old_proto_override(&old_pkt_proto)) {
		old_proto = old_pkt_proto;
	} else {
		if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE)
			old_proto = true;
		else if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2)
			old_proto = false;

	}

	if (old_proto)
		version = VSOCK_PROTO_INVALID;
	else
		version = pkt->proto;

	if (!vmci_transport_proto_to_notify_struct(sk, &version, old_proto)) {
		err = -EINVAL;
		goto destroy;
	}

	/* Subscribe to detach events first.
	 *
	 * XXX We attach once for each queue pair created for now so it is easy
	 * to find the socket (it's provided), but later we should only
	 * subscribe once and add a way to lookup sockets by queue pair handle.
	 */
	err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_DETACH,
				   vmci_transport_peer_detach_cb,
				   vmci_trans(vsk), &detach_sub_id);
	if (err < VMCI_SUCCESS) {
		err = vmci_transport_error_to_vsock_error(err);
		goto destroy;
	}

	/* Make VMCI select the handle for us. */
	handle = VMCI_INVALID_HANDLE;
	is_local = vsk->remote_addr.svm_cid == vsk->local_addr.svm_cid;
	flags = is_local ? VMCI_QPFLAG_LOCAL : 0;

	err = vmci_transport_queue_pair_alloc(&qpair,
					      &handle,
					      pkt->u.size,
					      pkt->u.size,
					      vsk->remote_addr.svm_cid,
					      flags,
					      vmci_transport_is_trusted(
						  vsk,
						  vsk->
						  remote_addr.svm_cid));
	if (err < 0)
		goto destroy;

	err = vmci_transport_send_qp_offer(sk, handle);
	if (err < 0) {
		err = vmci_transport_error_to_vsock_error(err);
		goto destroy;
	}

	vmci_trans(vsk)->qp_handle = handle;
	vmci_trans(vsk)->qpair = qpair;

	vmci_trans(vsk)->produce_size = vmci_trans(vsk)->consume_size =
		pkt->u.size;

	vmci_trans(vsk)->detach_sub_id = detach_sub_id;

	vmci_trans(vsk)->notify_ops->process_negotiate(sk);

	return 0;

destroy:
	if (detach_sub_id != VMCI_INVALID_ID)
		vmci_event_unsubscribe(detach_sub_id);

	if (!vmci_handle_is_invalid(handle))
		vmci_qpair_detach(&qpair);

	return err;
}

/* Handle an INVALID reply to our connection request.
 *
 * vmci_transport_connect() sets sent_request only after sending a
 * REQUEST2.  If that is what the peer rejected, fall back to an old-style
 * REQUEST and arrange to ignore the stray RST that older peers send after
 * an INVALID.  If sent_request is not set, the packet is ignored.
 *
 * Returns 0, or a negative errno if the fallback request cannot be sent.
 */
static int
vmci_transport_recv_connecting_client_invalid(struct sock *sk,
					      struct vmci_transport_packet *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int rc;

	if (!vsk->sent_request)
		return 0;

	vsk->sent_request = false;
	vsk->ignore_connecting_rst = true;

	rc = vmci_transport_send_conn_request(
		sk, vmci_trans(vsk)->queue_pair_size);

	return rc < 0 ? vmci_transport_error_to_vsock_error(rc) : 0;
}

/* Handle a control packet on an SS_CONNECTED socket.
 *
 * When the connection is being closed we only record the state change
 * (and possibly an error) and wake any waiters.  A connected socket is
 * owned by a user process, which cleans it up when the failure surfaces on
 * the current or next system call, so those calls must check for errors
 * and state changes both on entry and after being woken.
 *
 * Returns 0 if the packet was handled, -EINVAL if the notify protocol did
 * not recognise it.
 */
static int vmci_transport_recv_connected(struct sock *sk,
					 struct vmci_transport_packet *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	bool handled = false;

	if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN) {
		/* A zero mode carries no shutdown bits; nothing to do. */
		if (pkt->u.mode) {
			vsk->peer_shutdown |= pkt->u.mode;
			sk->sk_state_change(sk);
		}
		return 0;
	}

	if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST) {
		/* We may have sent the peer something (e.g. a WAITING_READ)
		 * just before learning it had detached, and the answer is
		 * this RST even though unread data remains.  So treat an RST
		 * in connected mode as a clean shutdown and only leave the
		 * connected state once the local reader has drained the
		 * queue pair.
		 */
		sock_set_flag(sk, SOCK_DONE);
		vsk->peer_shutdown = SHUTDOWN_MASK;
		if (vsock_stream_has_data(vsk) <= 0)
			sk->sk_state = SS_DISCONNECTING;

		sk->sk_state_change(sk);
		return 0;
	}

	/* Everything else belongs to the notify protocol in use. */
	vmci_trans(vsk)->notify_ops->handle_notify_pkt(
			sk, pkt, false, NULL, NULL, &handled);

	return handled ? 0 : -EINVAL;
}

static int vmci_transport_socket_init(struct vsock_sock *vsk,
				      struct vsock_sock *psk)
{
	vsk->trans = kmalloc(sizeof(struct vmci_transport), GFP_KERNEL);
	if (!vsk->trans)
		return -ENOMEM;

	vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE;
	vmci_trans(vsk)->qp_handle = VMCI_INVALID_HANDLE;
	vmci_trans(vsk)->qpair = NULL;
	vmci_trans(vsk)->produce_size = vmci_trans(vsk)->consume_size = 0;
	vmci_trans(vsk)->detach_sub_id = VMCI_INVALID_ID;
	vmci_trans(vsk)->notify_ops = NULL;
	INIT_LIST_HEAD(&vmci_trans(vsk)->elem);
	vmci_trans(vsk)->sk = &vsk->sk;
	spin_lock_init(&vmci_trans(vsk)->lock);
	if (psk) {
		vmci_trans(vsk)->queue_pair_size =
			vmci_trans(psk)->queue_pair_size;
		vmci_trans(vsk)->queue_pair_min_size =
			vmci_trans(psk)->queue_pair_min_size;
		vmci_trans(vsk)->queue_pair_max_size =
			vmci_trans(psk)->queue_pair_max_size;
	} else {
		vmci_trans(vsk)->queue_pair_size =
			VMCI_TRANSPORT_DEFAULT_QP_SIZE;
		vmci_trans(vsk)->queue_pair_min_size =
			 VMCI_TRANSPORT_DEFAULT_QP_SIZE_MIN;
		vmci_trans(vsk)->queue_pair_max_size =
			VMCI_TRANSPORT_DEFAULT_QP_SIZE_MAX;
	}

	return 0;
}

/* Drain @transport_list, releasing each transport's detach subscription
 * and queue pair (when present) before freeing the transport itself.
 */
static void vmci_transport_free_resources(struct list_head *transport_list)
{
	struct vmci_transport *victim;

	while (!list_empty(transport_list)) {
		victim = list_first_entry(transport_list,
					  struct vmci_transport, elem);
		list_del(&victim->elem);

		if (victim->detach_sub_id != VMCI_INVALID_ID) {
			vmci_event_unsubscribe(victim->detach_sub_id);
			victim->detach_sub_id = VMCI_INVALID_ID;
		}

		if (!vmci_handle_is_invalid(victim->qp_handle)) {
			vmci_qpair_detach(&victim->qpair);
			victim->qp_handle = VMCI_INVALID_HANDLE;
			victim->produce_size = 0;
			victim->consume_size = 0;
		}

		kfree(victim);
	}
}

static void vmci_transport_cleanup(struct work_struct *work)
{
	LIST_HEAD(pending);

	spin_lock_bh(&vmci_transport_cleanup_lock);
	list_replace_init(&vmci_transport_cleanup_list, &pending);
	spin_unlock_bh(&vmci_transport_cleanup_lock);
	vmci_transport_free_resources(&pending);
}

static void vmci_transport_destruct(struct vsock_sock *vsk)
{
	/* Ensure that the detach callback doesn't use the sk/vsk
	 * we are about to destruct.
	 */
	spin_lock_bh(&vmci_trans(vsk)->lock);
	vmci_trans(vsk)->sk = NULL;
	spin_unlock_bh(&vmci_trans(vsk)->lock);

	if (vmci_trans(vsk)->notify_ops)
		vmci_trans(vsk)->notify_ops->socket_destruct(vsk);

	spin_lock_bh(&vmci_transport_cleanup_lock);
	list_add(&vmci_trans(vsk)->elem, &vmci_transport_cleanup_list);
	spin_unlock_bh(&vmci_transport_cleanup_lock);
	schedule_work(&vmci_transport_cleanup_work);

	vsk->trans = NULL;
}

static void vmci_transport_release(struct vsock_sock *vsk)
{
	vsock_remove_sock(vsk);

	if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) {
		vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle);
		vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE;
	}
}

static int vmci_transport_dgram_bind(struct vsock_sock *vsk,
				     struct sockaddr_vm *addr)
{
	u32 port;
	u32 flags;
	int err;

	/* VMCI will select a resource ID for us if we provide
	 * VMCI_INVALID_ID.
	 */
	port = addr->svm_port == VMADDR_PORT_ANY ?
			VMCI_INVALID_ID : addr->svm_port;

	if (port <= LAST_RESERVED_PORT && !capable(CAP_NET_BIND_SERVICE))
		return -EACCES;

	flags = addr->svm_cid == VMADDR_CID_ANY ?
				VMCI_FLAG_ANYCID_DG_HND : 0;

	err = vmci_transport_datagram_create_hnd(port, flags,
						 vmci_transport_recv_dgram_cb,
						 &vsk->sk,
						 &vmci_trans(vsk)->dg_handle);
	if (err < VMCI_SUCCESS)
		return vmci_transport_error_to_vsock_error(err);
	vsock_addr_init(&vsk->local_addr, addr->svm_cid,
			vmci_trans(vsk)->dg_handle.resource);

	return 0;
}

/* Send one datagram of @len bytes taken from @msg to @remote_addr.
 *
 * Returns the number of payload bytes sent, or:
 *   -EMSGSIZE if @len exceeds VMCI_MAX_DG_PAYLOAD_SIZE,
 *   -EPERM    if datagrams to the destination CID are not allowed,
 *   -ENOMEM   if the bounce buffer cannot be allocated,
 *   the error from memcpy_from_msg() if the user data cannot be copied,
 *   a translated VMCI error if vmci_datagram_send() fails.
 */
static int vmci_transport_dgram_enqueue(
	struct vsock_sock *vsk,
	struct sockaddr_vm *remote_addr,
	struct msghdr *msg,
	size_t len)
{
	int err;
	struct vmci_datagram *dg;

	if (len > VMCI_MAX_DG_PAYLOAD_SIZE)
		return -EMSGSIZE;

	if (!vmci_transport_allow_dgram(vsk, remote_addr->svm_cid))
		return -EPERM;

	/* Allocate a buffer for the user's message and our packet header. */
	dg = kmalloc(len + sizeof(*dg), GFP_KERNEL);
	if (!dg)
		return -ENOMEM;

	/* If the copy fails (e.g. a bad user pointer) the payload would
	 * still hold uninitialized heap memory, which we must not send to
	 * the peer.
	 */
	err = memcpy_from_msg(VMCI_DG_PAYLOAD(dg), msg, len);
	if (err) {
		kfree(dg);
		return err;
	}

	dg->dst = vmci_make_handle(remote_addr->svm_cid,
				   remote_addr->svm_port);
	dg->src = vmci_make_handle(vsk->local_addr.svm_cid,
				   vsk->local_addr.svm_port);
	dg->payload_size = len;

	err = vmci_datagram_send(dg);
	kfree(dg);
	if (err < 0)
		return vmci_transport_error_to_vsock_error(err);

	/* The send count includes the header; report payload bytes only. */
	return err - sizeof(*dg);
}

/* Receive one queued datagram into @msg.
 *
 * At most @len payload bytes are copied; if the datagram is larger,
 * MSG_TRUNC is set.  When msg_name is supplied it receives the sender's
 * address.  MSG_OOB and MSG_ERRQUEUE are not supported.
 *
 * Returns the number of bytes copied, or a negative errno: -EOPNOTSUPP,
 * the error from skb_recv_datagram(), -EINVAL for a malformed skb, or the
 * error from skb_copy_datagram_msg().
 */
static int vmci_transport_dgram_dequeue(struct vsock_sock *vsk,
					struct msghdr *msg, size_t len,
					int flags)
{
	int err;
	int noblock;
	struct vmci_datagram *dg;
	size_t payload_len;
	struct sk_buff *skb;

	noblock = flags & MSG_DONTWAIT;

	if (flags & MSG_OOB || flags & MSG_ERRQUEUE)
		return -EOPNOTSUPP;

	/* Retrieve the head sk_buff from the socket's receive queue. */
	err = 0;
	skb = skb_recv_datagram(&vsk->sk, flags, noblock, &err);
	if (!skb)
		return err;

	/* The skb must hold at least a full datagram header.  Otherwise
	 * reading dg->payload_size would run past the data, and
	 * skb->len - sizeof(*dg) below would wrap around.
	 */
	if (skb->len < sizeof(*dg)) {
		err = -EINVAL;
		goto out;
	}

	dg = (struct vmci_datagram *)skb->data;

	payload_len = dg->payload_size;
	/* Ensure the sk_buff matches the payload size claimed in the packet. */
	if (payload_len != skb->len - sizeof(*dg)) {
		err = -EINVAL;
		goto out;
	}

	if (payload_len > len) {
		payload_len = len;
		msg->msg_flags |= MSG_TRUNC;
	}

	/* Place the datagram payload in the user's iovec. */
	err = skb_copy_datagram_msg(skb, sizeof(*dg), msg, payload_len);
	if (err)
		goto out;

	if (msg->msg_name) {
		/* Provide the address of the sender. */
		DECLARE_SOCKADDR(struct sockaddr_vm *, vm_addr, msg->msg_name);
		vsock_addr_init(vm_addr, dg->src.context, dg->src.resource);
		msg->msg_namelen = sizeof(*vm_addr);
	}
	err = payload_len;

out:
	skb_free_datagram(&vsk->sk, skb);
	return err;
}

/* Decide whether datagrams may be sent to (@cid, @port).
 *
 * Any CID other than the hypervisor is allowed.  For the hypervisor only
 * VMCI_UNITY_PBRPC_REGISTER is permitted: PBRPC server registrations do
 * not modify VMX/hypervisor state.
 */
static bool vmci_transport_dgram_allow(u32 cid, u32 port)
{
	return cid != VMADDR_CID_HYPERVISOR ||
	       port == VMCI_UNITY_PBRPC_REGISTER;
}

/* Start an outgoing connection by sending a connection request.
 *
 * An old-style REQUEST is sent only when the protocol override forces the
 * old packet protocol; otherwise a REQUEST2 advertising every supported
 * notify protocol version is sent, and sent_request is set so that an
 * INVALID reply can fall back to REQUEST.  If sending fails the socket is
 * moved to SS_UNCONNECTED.
 *
 * Returns the result of the send call (negative on failure).
 */
static int vmci_transport_connect(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;
	bool old_pkt_proto = false;
	bool legacy;
	int rc;

	legacy = vmci_transport_old_proto_override(&old_pkt_proto) &&
		 old_pkt_proto;

	if (legacy)
		rc = vmci_transport_send_conn_request(
			sk, vmci_trans(vsk)->queue_pair_size);
	else
		rc = vmci_transport_send_conn_request2(
			sk, vmci_trans(vsk)->queue_pair_size,
			vmci_transport_new_proto_supported_versions());

	if (rc < 0) {
		sk->sk_state = SS_UNCONNECTED;
		return rc;
	}

	if (!legacy)
		vsk->sent_request = true;

	return rc;
}

/* Copy up to @len bytes of stream data into @msg.  With MSG_PEEK the data
 * is peeked; otherwise it is dequeued from the queue pair.
 */
static ssize_t vmci_transport_stream_dequeue(
	struct vsock_sock *vsk,
	struct msghdr *msg,
	size_t len,
	int flags)
{
	struct vmci_qp *qp = vmci_trans(vsk)->qpair;

	if (!(flags & MSG_PEEK))
		return vmci_qpair_dequev(qp, msg, len, 0);

	return vmci_qpair_peekv(qp, msg, len, 0);
}

/* Enqueue up to @len bytes from @msg onto the socket's queue pair. */
static ssize_t vmci_transport_stream_enqueue(
	struct vsock_sock *vsk,
	struct msghdr *msg,
	size_t len)
{
	struct vmci_qp *qp = vmci_trans(vsk)->qpair;

	return vmci_qpair_enquev(qp, msg, len, 0);
}

/* Bytes ready to be consumed, as reported by the queue pair. */
static s64 vmci_transport_stream_has_data(struct vsock_sock *vsk)
{
	struct vmci_qp *qp = vmci_trans(vsk)->qpair;

	return vmci_qpair_consume_buf_ready(qp);
}

/* Free space in the produce side of the queue pair. */
static s64 vmci_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct vmci_qp *qp = vmci_trans(vsk)->qpair;

	return vmci_qpair_produce_free_space(qp);
}

static u64 vmci_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	return vmci_trans(vsk)->consume_size;
}

static bool vmci_transport_stream_is_active(struct vsock_sock *vsk)
{
	return !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle);
}

static u64 vmci_transport_get_buffer_size(struct vsock_sock *vsk)
{
	return vmci_trans(vsk)->queue_pair_size;
}

static u64 vmci_transport_get_min_buffer_size(struct vsock_sock *vsk)
{
	return vmci_trans(vsk)->queue_pair_min_size;
}

static u64 vmci_transport_get_max_buffer_size(struct vsock_sock *vsk)
{
	return vmci_trans(vsk)->queue_pair_max_size;
}

/* Set the preferred queue pair size, widening the [min, max] window if
 * necessary so that it still contains @val.
 */
static void vmci_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct vmci_transport *trans = vmci_trans(vsk);

	if (trans->queue_pair_min_size > val)
		trans->queue_pair_min_size = val;
	if (trans->queue_pair_max_size < val)
		trans->queue_pair_max_size = val;
	trans->queue_pair_size = val;
}

/* Set the minimum queue pair size, raising the preferred size to @val if
 * it would otherwise fall below the new minimum.
 */
static void vmci_transport_set_min_buffer_size(struct vsock_sock *vsk,
					       u64 val)
{
	struct vmci_transport *trans = vmci_trans(vsk);

	if (trans->queue_pair_size < val)
		trans->queue_pair_size = val;
	trans->queue_pair_min_size = val;
}

/* Set the maximum queue pair size, lowering the preferred size to @val if
 * it would otherwise exceed the new maximum.
 */
static void vmci_transport_set_max_buffer_size(struct vsock_sock *vsk,
					       u64 val)
{
	struct vmci_transport *trans = vmci_trans(vsk);

	if (trans->queue_pair_size > val)
		trans->queue_pair_size = val;
	trans->queue_pair_max_size = val;
}

/* Forward poll_in to the socket's notify protocol implementation. */
static int vmci_transport_notify_poll_in(
	struct vsock_sock *vsk,
	size_t target,
	bool *data_ready_now)
{
	struct vmci_transport *trans = vmci_trans(vsk);
	struct sock *sk = &vsk->sk;

	return trans->notify_ops->poll_in(sk, target, data_ready_now);
}

/* Forward poll_out to the socket's notify protocol implementation. */
static int vmci_transport_notify_poll_out(
	struct vsock_sock *vsk,
	size_t target,
	bool *space_available_now)
{
	struct vmci_transport *trans = vmci_trans(vsk);
	struct sock *sk = &vsk->sk;

	return trans->notify_ops->poll_out(sk, target, space_available_now);
}

/* Forward recv_init to the notify ops, reinterpreting the generic notify
 * data as the VMCI-specific layout.
 */
static int vmci_transport_notify_recv_init(
	struct vsock_sock *vsk,
	size_t target,
	struct vsock_transport_recv_notify_data *data)
{
	struct vmci_transport *trans = vmci_trans(vsk);
	struct vmci_transport_recv_notify_data *vdata =
		(struct vmci_transport_recv_notify_data *)data;

	return trans->notify_ops->recv_init(&vsk->sk, target, vdata);
}

/* Forward recv_pre_block to the notify ops with VMCI-specific data. */
static int vmci_transport_notify_recv_pre_block(
	struct vsock_sock *vsk,
	size_t target,
	struct vsock_transport_recv_notify_data *data)
{
	struct vmci_transport *trans = vmci_trans(vsk);
	struct vmci_transport_recv_notify_data *vdata =
		(struct vmci_transport_recv_notify_data *)data;

	return trans->notify_ops->recv_pre_block(&vsk->sk, target, vdata);
}

/* Forward recv_pre_dequeue to the notify ops with VMCI-specific data. */
static int vmci_transport_notify_recv_pre_dequeue(
	struct vsock_sock *vsk,
	size_t target,
	struct vsock_transport_recv_notify_data *data)
{
	struct vmci_transport *trans = vmci_trans(vsk);
	struct vmci_transport_recv_notify_data *vdata =
		(struct vmci_transport_recv_notify_data *)data;

	return trans->notify_ops->recv_pre_dequeue(&vsk->sk, target, vdata);
}

/* Forward recv_post_dequeue to the notify ops with VMCI-specific data. */
static int vmci_transport_notify_recv_post_dequeue(
	struct vsock_sock *vsk,
	size_t target,
	ssize_t copied,
	bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	struct vmci_transport *trans = vmci_trans(vsk);
	struct vmci_transport_recv_notify_data *vdata =
		(struct vmci_transport_recv_notify_data *)data;

	return trans->notify_ops->recv_post_dequeue(&vsk->sk, target, copied,
						    data_read, vdata);
}

static int vmci_transport_notify_send_init(
	struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return vmci_trans(vsk)->notify_ops->send_init(
			&vsk->sk,
			(struct vmci_transport_send_notify_data *)data);
}

/*
 * Forward the "about to block on send" notification to the notify_ops
 * send_pre_block hook, passing the caller's data as VMCI notify data.
 */
static int vmci_transport_notify_send_pre_block(
	struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	struct vmci_transport_send_notify_data *vmci_data =
		(struct vmci_transport_send_notify_data *)data;

	return vmci_trans(vsk)->notify_ops->send_pre_block(&vsk->sk,
							    vmci_data);
}

/*
 * Forward the "about to enqueue" notification to the notify_ops
 * send_pre_enqueue hook, passing the caller's data as VMCI notify data.
 */
static int vmci_transport_notify_send_pre_enqueue(
	struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	struct vmci_transport_send_notify_data *vmci_data =
		(struct vmci_transport_send_notify_data *)data;

	return vmci_trans(vsk)->notify_ops->send_pre_enqueue(&vsk->sk,
							      vmci_data);
}

/*
 * Forward the post-enqueue notification (@written bytes) to the notify_ops
 * send_post_enqueue hook.
 */
static int vmci_transport_notify_send_post_enqueue(
	struct vsock_sock *vsk,
	ssize_t written,
	struct vsock_transport_send_notify_data *data)
{
	struct vmci_transport_send_notify_data *vmci_data =
		(struct vmci_transport_send_notify_data *)data;

	return vmci_trans(vsk)->notify_ops->send_post_enqueue(
			&vsk->sk, written, vmci_data);
}

/*
 * Apply the compile-time PROTOCOL_OVERRIDE, if one is configured.
 *
 * When PROTOCOL_OVERRIDE is -1, nothing is overridden: *@old_pkt_proto is
 * left untouched and false is returned.  Otherwise *@old_pkt_proto is set to
 * true for an override of 0 and false for any other value, a message is
 * logged, and true is returned.
 */
static bool vmci_transport_old_proto_override(bool *old_pkt_proto)
{
	if (PROTOCOL_OVERRIDE == -1)
		return false;

	*old_pkt_proto = (PROTOCOL_OVERRIDE == 0);
	pr_info("Proto override in use\n");

	return true;
}

/*
 * Select and initialise the notify_ops for @sk from the negotiated protocol.
 *
 * @old_pkt_proto: use the legacy packet notify ops.  In that case *@proto
 *                 must be VSOCK_PROTO_INVALID; anything else is rejected.
 * @proto:         otherwise, the negotiated protocol.  Only
 *                 VSOCK_PROTO_PKT_ON_NOTIFY is recognised, and it selects the
 *                 queue-state notify ops.  *@proto is only read.
 *
 * On success, the chosen ops' socket_init hook is run on @sk and true is
 * returned.  On an invalid combination, an error is logged, notify_ops is
 * left unchanged, and false is returned.
 */
static bool vmci_transport_proto_to_notify_struct(struct sock *sk,
						  u16 *proto,
						  bool old_pkt_proto)
{
	struct vmci_transport *trans = vmci_trans(vsock_sk(sk));

	if (old_pkt_proto) {
		/* The legacy protocol cannot be combined with a new one. */
		if (*proto != VSOCK_PROTO_INVALID) {
			pr_err("Can't set both an old and new protocol\n");
			return false;
		}
		trans->notify_ops = &vmci_transport_notify_pkt_ops;
	} else if (*proto == VSOCK_PROTO_PKT_ON_NOTIFY) {
		trans->notify_ops = &vmci_transport_notify_pkt_q_state_ops;
	} else {
		pr_err("Unknown notify protocol version\n");
		return false;
	}

	trans->notify_ops->socket_init(sk);
	return true;
}

/*
 * Return the set of new-style protocols to advertise.  It is
 * VSOCK_PROTO_ALL_SUPPORTED unless PROTOCOL_OVERRIDE is configured (not -1),
 * in which case the override value is advertised instead.
 */
static u16 vmci_transport_new_proto_supported_versions(void)
{
	return (PROTOCOL_OVERRIDE != -1) ? PROTOCOL_OVERRIDE
					 : VSOCK_PROTO_ALL_SUPPORTED;
}

/* Report the local context ID, as returned by vmci_get_context_id(). */
static u32 vmci_transport_get_local_cid(void)
{
	u32 cid = vmci_get_context_id();

	return cid;
}

static const struct vsock_transport vmci_transport = {
	.init = vmci_transport_socket_init,
	.destruct = vmci_transport_destruct,
	.release = vmci_transport_release,
	.connect = vmci_transport_connect,
	.dgram_bind = vmci_transport_dgram_bind,
	.dgram_dequeue = vmci_transport_dgram_dequeue,
	.dgram_enqueue = vmci_transport_dgram_enqueue,
	.dgram_allow = vmci_transport_dgram_allow,
	.stream_dequeue = vmci_transport_stream_dequeue,
	.stream_enqueue = vmci_transport_stream_enqueue,
	.stream_has_data = vmci_transport_stream_has_data,
	.stream_has_space = vmci_transport_stream_has_space,
	.stream_rcvhiwat = vmci_transport_stream_rcvhiwat,
	.stream_is_active = vmci_transport_stream_is_active,
	.stream_allow = vmci_transport_stream_allow,
	.notify_poll_in = vmci_transport_notify_poll_in,
	.notify_poll_out = vmci_transport_notify_poll_out,
	.notify_recv_init = vmci_transport_notify_recv_init,
	.notify_recv_pre_block = vmci_transport_notify_recv_pre_block,
	.notify_recv_pre_dequeue = vmci_transport_notify_recv_pre_dequeue,
	.notify_recv_post_dequeue = vmci_transport_notify_recv_post_dequeue,
	.notify_send_init = vmci_transport_notify_send_init,
	.notify_send_pre_block = vmci_transport_notify_send_pre_block,
	.notify_send_pre_enqueue = vmci_transport_notify_send_pre_enqueue,
	.notify_send_post_enqueue = vmci_transport_notify_send_post_enqueue,
	.shutdown = vmci_transport_shutdown,
	.set_buffer_size = vmci_transport_set_buffer_size,
	.set_min_buffer_size = vmci_transport_set_min_buffer_size,
	.set_max_buffer_size = vmci_transport_set_max_buffer_size,
	.get_buffer_size = vmci_transport_get_buffer_size,
	.get_min_buffer_size = vmci_transport_get_min_buffer_size,
	.get_max_buffer_size = vmci_transport_get_max_buffer_size,
	.get_local_cid = vmci_transport_get_local_cid,
};

/*
 * Module entry point: acquire the VMCI resources the transport needs, then
 * register the transport with the vsock core.
 *
 * Steps, in order:
 *   1. create the datagram handle (VMCI_TRANSPORT_PACKET_RID, any-CID) whose
 *      receive callback is vmci_transport_recv_stream_cb;
 *   2. subscribe vmci_transport_qp_resumed_cb to VMCI_EVENT_QP_RESUMED;
 *   3. register vmci_transport through vsock_core_init().
 * If a step fails, the earlier steps are undone in reverse order through the
 * goto ladder at the bottom.
 *
 * Returns 0 on success.  On a VMCI failure, the status code is converted by
 * vmci_transport_error_to_vsock_error() (presumably to a negative errno;
 * confirm).  A negative result from vsock_core_init() is returned as is.
 */
static int __init vmci_transport_init(void)
{
	int err;

	/* Create the datagram handle that we will use to send and receive all
	 * VSocket control messages for this context.
	 */
	err = vmci_transport_datagram_create_hnd(VMCI_TRANSPORT_PACKET_RID,
						 VMCI_FLAG_ANYCID_DG_HND,
						 vmci_transport_recv_stream_cb,
						 NULL,
						 &vmci_transport_stream_handle);
	/* VMCI status codes below VMCI_SUCCESS signal failure. */
	if (err < VMCI_SUCCESS) {
		pr_err("Unable to create datagram handle. (%d)\n", err);
		return vmci_transport_error_to_vsock_error(err);
	}

	err = vmci_event_subscribe(VMCI_EVENT_QP_RESUMED,
				   vmci_transport_qp_resumed_cb,
				   NULL, &vmci_transport_qp_resumed_sub_id);
	if (err < VMCI_SUCCESS) {
		pr_err("Unable to subscribe to resumed event. (%d)\n", err);
		err = vmci_transport_error_to_vsock_error(err);
		/* Mark "not subscribed", the state vmci_transport_exit()
		 * checks before unsubscribing.
		 */
		vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
		goto err_destroy_stream_handle;
	}

	err = vsock_core_init(&vmci_transport);
	if (err < 0)
		goto err_unsubscribe;

	return 0;

	/* Error unwind: labels run in reverse setup order and fall through. */
err_unsubscribe:
	vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id);
err_destroy_stream_handle:
	/* Best effort: the destroy status is not checked on this path. */
	vmci_datagram_destroy_handle(vmci_transport_stream_handle);
	return err;
}
module_init(vmci_transport_init);

/*
 * Module exit: tear down what vmci_transport_init() acquired.
 *
 * Order of operations:
 *   1. flush pending deferred cleanup work;
 *   2. destroy the control datagram handle;
 *   3. drop the QP-resumed event subscription;
 *   4. call vsock_core_exit().
 * Steps 2 and 3 run only if the resource is still valid, and each resets its
 * global to the invalid value afterwards.
 *
 * NOTE(review): vsock_core_exit() runs last, so this is not the strict
 * reverse of init.  Keep the order unless you have confirmed a change is safe.
 */
static void __exit vmci_transport_exit(void)
{
	/* Stop the cleanup worker before freeing the cleanup list directly,
	 * presumably so the two cannot run concurrently; confirm against the
	 * worker.
	 */
	cancel_work_sync(&vmci_transport_cleanup_work);
	vmci_transport_free_resources(&vmci_transport_cleanup_list);

	if (!vmci_handle_is_invalid(vmci_transport_stream_handle)) {
		/* A failed destroy is logged; the handle is invalidated anyway. */
		if (vmci_datagram_destroy_handle(
			vmci_transport_stream_handle) != VMCI_SUCCESS)
			pr_err("Couldn't destroy datagram handle\n");
		vmci_transport_stream_handle = VMCI_INVALID_HANDLE;
	}

	if (vmci_transport_qp_resumed_sub_id != VMCI_INVALID_ID) {
		vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id);
		vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
	}

	vsock_core_exit();
}
module_exit(vmci_transport_exit);

/*
 * Module metadata and load aliases: an explicit "vmware_vsock" alias plus a
 * network-protocol alias for the PF_VSOCK family.
 */
MODULE_AUTHOR("VMware, Inc.");
MODULE_DESCRIPTION("VMCI transport for Virtual Sockets");
MODULE_VERSION("1.0.4.0-k");
MODULE_LICENSE("GPL v2");
MODULE_ALIAS("vmware_vsock");
MODULE_ALIAS_NETPROTO(PF_VSOCK);