Commit 9c378abc5c0c6fc8e3acf5968924d274503819b3

Authored by Michael S. Tsirkin
Committed by Rusty Russell
1 parent 02edf6abe0

virtio-balloon: fix add/get API use

Since ee7cd8981e15bcb365fc762afe3fc47b8242f630 'virtio: expose added
descriptors immediately.', in virtio balloon virtqueue_get_buf might
now run concurrently with virtqueue_kick.  I audited both and this
seems safe in practice but this is not guaranteed by the API.
Additionally, a spurious interrupt might in theory make
virtqueue_get_buf run in parallel with virtqueue_add_buf, which is
racy.

While we might try to protect against spurious callbacks it's
easier to fix the driver: balloon seems to be the only one
(mis)using the API like this, so let's just fix balloon.

Signed-off-by: Michael S. Tsirkin <mst@redhat.com>
Signed-off-by: Rusty Russell <rusty@rustcorp.com.au> (removed unused var)

Showing 1 changed file with 10 additions and 14 deletions Side-by-side Diff

drivers/virtio/virtio_balloon.c
... ... @@ -47,7 +47,7 @@
47 47 struct task_struct *thread;
48 48  
49 49 /* Waiting for host to ack the pages we released. */
50   - struct completion acked;
  50 + wait_queue_head_t acked;
51 51  
52 52 /* Number of balloon pages we've told the Host we're not using. */
53 53 unsigned int num_pages;
54 54  
55 55  
56 56  
57 57  
... ... @@ -89,29 +89,25 @@
89 89  
90 90 static void balloon_ack(struct virtqueue *vq)
91 91 {
92   - struct virtio_balloon *vb;
93   - unsigned int len;
  92 + struct virtio_balloon *vb = vq->vdev->priv;
94 93  
95   - vb = virtqueue_get_buf(vq, &len);
96   - if (vb)
97   - complete(&vb->acked);
  94 + wake_up(&vb->acked);
98 95 }
99 96  
100 97 static void tell_host(struct virtio_balloon *vb, struct virtqueue *vq)
101 98 {
102 99 struct scatterlist sg;
  100 + unsigned int len;
103 101  
104 102 sg_init_one(&sg, vb->pfns, sizeof(vb->pfns[0]) * vb->num_pfns);
105 103  
106   - init_completion(&vb->acked);
107   -
108 104 /* We should always be able to add one buffer to an empty queue. */
109 105 if (virtqueue_add_buf(vq, &sg, 1, 0, vb, GFP_KERNEL) < 0)
110 106 BUG();
111 107 virtqueue_kick(vq);
112 108  
113 109 /* When host has read buffer, this completes via balloon_ack */
114   - wait_for_completion(&vb->acked);
  110 + wait_event(vb->acked, virtqueue_get_buf(vq, &len));
115 111 }
116 112  
117 113 static void set_page_pfns(u32 pfns[], struct page *page)
118 114  
... ... @@ -231,12 +227,8 @@
231 227 */
232 228 static void stats_request(struct virtqueue *vq)
233 229 {
234   - struct virtio_balloon *vb;
235   - unsigned int len;
  230 + struct virtio_balloon *vb = vq->vdev->priv;
236 231  
237   - vb = virtqueue_get_buf(vq, &len);
238   - if (!vb)
239   - return;
240 232 vb->need_stats_update = 1;
241 233 wake_up(&vb->config_change);
242 234 }
243 235  
... ... @@ -245,11 +237,14 @@
245 237 {
246 238 struct virtqueue *vq;
247 239 struct scatterlist sg;
  240 + unsigned int len;
248 241  
249 242 vb->need_stats_update = 0;
250 243 update_balloon_stats(vb);
251 244  
252 245 vq = vb->stats_vq;
  246 + if (!virtqueue_get_buf(vq, &len))
  247 + return;
253 248 sg_init_one(&sg, vb->stats, sizeof(vb->stats));
254 249 if (virtqueue_add_buf(vq, &sg, 1, 0, vb, GFP_KERNEL) < 0)
255 250 BUG();
... ... @@ -358,6 +353,7 @@
358 353 INIT_LIST_HEAD(&vb->pages);
359 354 vb->num_pages = 0;
360 355 init_waitqueue_head(&vb->config_change);
  356 + init_waitqueue_head(&vb->acked);
361 357 vb->vdev = vdev;
362 358 vb->need_stats_update = 0;
363 359