diff options
Diffstat (limited to 'mm/hmm.c')
-rw-r--r-- | mm/hmm.c | 303 |
1 files changed, 131 insertions, 172 deletions
@@ -20,6 +20,7 @@ #include <linux/swapops.h> #include <linux/hugetlb.h> #include <linux/memremap.h> +#include <linux/sched/mm.h> #include <linux/jump_label.h> #include <linux/dma-mapping.h> #include <linux/mmu_notifier.h> @@ -27,16 +28,6 @@ static const struct mmu_notifier_ops hmm_mmu_notifier_ops; -static inline struct hmm *mm_get_hmm(struct mm_struct *mm) -{ - struct hmm *hmm = READ_ONCE(mm->hmm); - - if (hmm && kref_get_unless_zero(&hmm->kref)) - return hmm; - - return NULL; -} - /** * hmm_get_or_create - register HMM against an mm (HMM internal) * @@ -51,11 +42,16 @@ static inline struct hmm *mm_get_hmm(struct mm_struct *mm) */ static struct hmm *hmm_get_or_create(struct mm_struct *mm) { - struct hmm *hmm = mm_get_hmm(mm); - bool cleanup = false; + struct hmm *hmm; + + lockdep_assert_held_exclusive(&mm->mmap_sem); - if (hmm) - return hmm; + /* Abuse the page_table_lock to also protect mm->hmm. */ + spin_lock(&mm->page_table_lock); + hmm = mm->hmm; + if (mm->hmm && kref_get_unless_zero(&mm->hmm->kref)) + goto out_unlock; + spin_unlock(&mm->page_table_lock); hmm = kmalloc(sizeof(*hmm), GFP_KERNEL); if (!hmm) @@ -65,55 +61,50 @@ static struct hmm *hmm_get_or_create(struct mm_struct *mm) init_rwsem(&hmm->mirrors_sem); hmm->mmu_notifier.ops = NULL; INIT_LIST_HEAD(&hmm->ranges); - mutex_init(&hmm->lock); + spin_lock_init(&hmm->ranges_lock); kref_init(&hmm->kref); hmm->notifiers = 0; - hmm->dead = false; hmm->mm = mm; - spin_lock(&mm->page_table_lock); - if (!mm->hmm) - mm->hmm = hmm; - else - cleanup = true; - spin_unlock(&mm->page_table_lock); + hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops; + if (__mmu_notifier_register(&hmm->mmu_notifier, mm)) { + kfree(hmm); + return NULL; + } - if (cleanup) - goto error; + mmgrab(hmm->mm); /* - * We should only get here if hold the mmap_sem in write mode ie on - * registration of first mirror through hmm_mirror_register() + * We hold the exclusive mmap_sem here so we know that mm->hmm is + * still NULL or 0 kref, and is safe to update. */ - hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops; - if (__mmu_notifier_register(&hmm->mmu_notifier, mm)) - goto error_mm; + spin_lock(&mm->page_table_lock); + mm->hmm = hmm; +out_unlock: + spin_unlock(&mm->page_table_lock); return hmm; +} -error_mm: - spin_lock(&mm->page_table_lock); - if (mm->hmm == hmm) - mm->hmm = NULL; - spin_unlock(&mm->page_table_lock); -error: +static void hmm_free_rcu(struct rcu_head *rcu) +{ + struct hmm *hmm = container_of(rcu, struct hmm, rcu); + + mmdrop(hmm->mm); kfree(hmm); - return NULL; } static void hmm_free(struct kref *kref) { struct hmm *hmm = container_of(kref, struct hmm, kref); - struct mm_struct *mm = hmm->mm; - - mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm); - spin_lock(&mm->page_table_lock); - if (mm->hmm == hmm) - mm->hmm = NULL; - spin_unlock(&mm->page_table_lock); + spin_lock(&hmm->mm->page_table_lock); + if (hmm->mm->hmm == hmm) + hmm->mm->hmm = NULL; + spin_unlock(&hmm->mm->page_table_lock); - kfree(hmm); + mmu_notifier_unregister_no_release(&hmm->mmu_notifier, hmm->mm); + mmu_notifier_call_srcu(&hmm->rcu, hmm_free_rcu); } static inline void hmm_put(struct hmm *hmm) @@ -121,86 +112,73 @@ static inline void hmm_put(struct hmm *hmm) kref_put(&hmm->kref, hmm_free); } -void hmm_mm_destroy(struct mm_struct *mm) +static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm) { - struct hmm *hmm; + struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier); + struct hmm_mirror *mirror; - spin_lock(&mm->page_table_lock); - hmm = mm_get_hmm(mm); - mm->hmm = NULL; - if (hmm) { - hmm->mm = NULL; - hmm->dead = true; - spin_unlock(&mm->page_table_lock); - hmm_put(hmm); + /* Bail out if hmm is in the process of being freed */ + if (!kref_get_unless_zero(&hmm->kref)) return; + + /* + * Since hmm_range_register() holds the mmget() lock hmm_release() is + * prevented as long as a range exists. + */ + WARN_ON(!list_empty_careful(&hmm->ranges)); + + down_read(&hmm->mirrors_sem); + list_for_each_entry(mirror, &hmm->mirrors, list) { + /* + * Note: The driver is not allowed to trigger + * hmm_mirror_unregister() from this thread. + */ + if (mirror->ops->release) + mirror->ops->release(mirror); } + up_read(&hmm->mirrors_sem); - spin_unlock(&mm->page_table_lock); + hmm_put(hmm); } -static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm) +static void notifiers_decrement(struct hmm *hmm) { - struct hmm *hmm = mm_get_hmm(mm); - struct hmm_mirror *mirror; - struct hmm_range *range; - - /* Report this HMM as dying. */ - hmm->dead = true; + unsigned long flags; - /* Wake-up everyone waiting on any range. */ - mutex_lock(&hmm->lock); - list_for_each_entry(range, &hmm->ranges, list) { - range->valid = false; - } - wake_up_all(&hmm->wq); - mutex_unlock(&hmm->lock); + spin_lock_irqsave(&hmm->ranges_lock, flags); + hmm->notifiers--; + if (!hmm->notifiers) { + struct hmm_range *range; - down_write(&hmm->mirrors_sem); - mirror = list_first_entry_or_null(&hmm->mirrors, struct hmm_mirror, - list); - while (mirror) { - list_del_init(&mirror->list); - if (mirror->ops->release) { - /* - * Drop mirrors_sem so callback can wait on any pending - * work that might itself trigger mmu_notifier callback - * and thus would deadlock with us. - */ - up_write(&hmm->mirrors_sem); - mirror->ops->release(mirror); - down_write(&hmm->mirrors_sem); + list_for_each_entry(range, &hmm->ranges, list) { + if (range->valid) + continue; + range->valid = true; } - mirror = list_first_entry_or_null(&hmm->mirrors, - struct hmm_mirror, list); + wake_up_all(&hmm->wq); } - up_write(&hmm->mirrors_sem); - - hmm_put(hmm); + spin_unlock_irqrestore(&hmm->ranges_lock, flags); } static int hmm_invalidate_range_start(struct mmu_notifier *mn, const struct mmu_notifier_range *nrange) { - struct hmm *hmm = mm_get_hmm(nrange->mm); + struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier); struct hmm_mirror *mirror; struct hmm_update update; struct hmm_range *range; + unsigned long flags; int ret = 0; - VM_BUG_ON(!hmm); + if (!kref_get_unless_zero(&hmm->kref)) + return 0; update.start = nrange->start; update.end = nrange->end; update.event = HMM_UPDATE_INVALIDATE; update.blockable = mmu_notifier_range_blockable(nrange); - if (mmu_notifier_range_blockable(nrange)) - mutex_lock(&hmm->lock); - else if (!mutex_trylock(&hmm->lock)) { - ret = -EAGAIN; - goto out; - } + spin_lock_irqsave(&hmm->ranges_lock, flags); hmm->notifiers++; list_for_each_entry(range, &hmm->ranges, list) { if (update.end < range->start || update.start >= range->end) @@ -208,7 +186,7 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn, range->valid = false; } - mutex_unlock(&hmm->lock); + spin_unlock_irqrestore(&hmm->ranges_lock, flags); if (mmu_notifier_range_blockable(nrange)) down_read(&hmm->mirrors_sem); @@ -216,19 +194,23 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn, ret = -EAGAIN; goto out; } + list_for_each_entry(mirror, &hmm->mirrors, list) { - int ret; + int rc; - ret = mirror->ops->sync_cpu_device_pagetables(mirror, &update); - if (!update.blockable && ret == -EAGAIN) { - up_read(&hmm->mirrors_sem); + rc = mirror->ops->sync_cpu_device_pagetables(mirror, &update); + if (rc) { + if (WARN_ON(update.blockable || rc != -EAGAIN)) + continue; ret = -EAGAIN; - goto out; + break; } } up_read(&hmm->mirrors_sem); out: + if (ret) + notifiers_decrement(hmm); hmm_put(hmm); return ret; } @@ -236,24 +218,12 @@ out: static void hmm_invalidate_range_end(struct mmu_notifier *mn, const struct mmu_notifier_range *nrange) { - struct hmm *hmm = mm_get_hmm(nrange->mm); - - VM_BUG_ON(!hmm); - - mutex_lock(&hmm->lock); - hmm->notifiers--; - if (!hmm->notifiers) { - struct hmm_range *range; + struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier); - list_for_each_entry(range, &hmm->ranges, list) { - if (range->valid) - continue; - range->valid = true; - } - wake_up_all(&hmm->wq); - } - mutex_unlock(&hmm->lock); + if (!kref_get_unless_zero(&hmm->kref)) + return; + notifiers_decrement(hmm); hmm_put(hmm); } @@ -268,14 +238,15 @@ static const struct mmu_notifier_ops hmm_mmu_notifier_ops = { * * @mirror: new mirror struct to register * @mm: mm to register against + * Return: 0 on success, -ENOMEM if no memory, -EINVAL if invalid arguments * * To start mirroring a process address space, the device driver must register * an HMM mirror struct. - * - * THE mm->mmap_sem MUST BE HELD IN WRITE MODE ! */ int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm) { + lockdep_assert_held_exclusive(&mm->mmap_sem); + /* Sanity check */ if (!mm || !mirror || !mirror->ops) return -EINVAL; @@ -295,23 +266,17 @@ EXPORT_SYMBOL(hmm_mirror_register); /* * hmm_mirror_unregister() - unregister a mirror * - * @mirror: new mirror struct to register + * @mirror: mirror struct to unregister * * Stop mirroring a process address space, and cleanup. */ void hmm_mirror_unregister(struct hmm_mirror *mirror) { - struct hmm *hmm = READ_ONCE(mirror->hmm); - - if (hmm == NULL) - return; + struct hmm *hmm = mirror->hmm; down_write(&hmm->mirrors_sem); - list_del_init(&mirror->list); - /* To protect us against double unregister ... */ - mirror->hmm = NULL; + list_del(&mirror->list); up_write(&hmm->mirrors_sem); - hmm_put(hmm); } EXPORT_SYMBOL(hmm_mirror_unregister); @@ -327,7 +292,7 @@ struct hmm_vma_walk { static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr, bool write_fault, uint64_t *pfn) { - unsigned int flags = FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_REMOTE; + unsigned int flags = FAULT_FLAG_REMOTE; struct hmm_vma_walk *hmm_vma_walk = walk->private; struct hmm_range *range = hmm_vma_walk->range; struct vm_area_struct *vma = walk->vma; @@ -369,7 +334,7 @@ static int hmm_pfns_bad(unsigned long addr, * @fault: should we fault or not ? * @write_fault: write fault ? * @walk: mm_walk structure - * Returns: 0 on success, -EBUSY after page fault, or page fault error + * Return: 0 on success, -EBUSY after page fault, or page fault error * * This function will be called whenever pmd_none() or pte_none() returns true, * or whenever there is no page directory covering the virtual address range. @@ -547,7 +512,7 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk, static inline uint64_t pte_to_hmm_pfn_flags(struct hmm_range *range, pte_t pte) { - if (pte_none(pte) || !pte_present(pte)) + if (pte_none(pte) || !pte_present(pte) || pte_protnone(pte)) return 0; return pte_write(pte) ? range->flags[HMM_PFN_VALID] | range->flags[HMM_PFN_WRITE] : @@ -785,7 +750,6 @@ again: return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk); -#ifdef CONFIG_HUGETLB_PAGE pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT); for (i = 0; i < npages; ++i, ++pfn) { hmm_vma_walk->pgmap = get_dev_pagemap(pfn, @@ -801,9 +765,6 @@ again: } hmm_vma_walk->last = end; return 0; -#else - return -EINVAL; -#endif } split_huge_pud(walk->vma, pudp, addr); @@ -906,12 +867,14 @@ static void hmm_pfns_clear(struct hmm_range *range, * Track updates to the CPU page table see include/linux/hmm.h */ int hmm_range_register(struct hmm_range *range, - struct mm_struct *mm, + struct hmm_mirror *mirror, unsigned long start, unsigned long end, unsigned page_shift) { unsigned long mask = ((1UL << page_shift) - 1UL); + struct hmm *hmm = mirror->hmm; + unsigned long flags; range->valid = false; range->hmm = NULL; @@ -925,28 +888,24 @@ int hmm_range_register(struct hmm_range *range, range->start = start; range->end = end; - range->hmm = hmm_get_or_create(mm); - if (!range->hmm) + /* Prevent hmm_release() from running while the range is valid */ + if (!mmget_not_zero(hmm->mm)) return -EFAULT; - /* Check if hmm_mm_destroy() was call. */ - if (range->hmm->mm == NULL || range->hmm->dead) { - hmm_put(range->hmm); - return -EFAULT; - } + /* Initialize range to track CPU page table updates. */ + spin_lock_irqsave(&hmm->ranges_lock, flags); - /* Initialize range to track CPU page table update */ - mutex_lock(&range->hmm->lock); - - list_add_rcu(&range->list, &range->hmm->ranges); + range->hmm = hmm; + kref_get(&hmm->kref); + list_add(&range->list, &hmm->ranges); /* * If there are any concurrent notifiers we have to wait for them for * the range to be valid (see hmm_range_wait_until_valid()). */ - if (!range->hmm->notifiers) + if (!hmm->notifiers) range->valid = true; - mutex_unlock(&range->hmm->lock); + spin_unlock_irqrestore(&hmm->ranges_lock, flags); return 0; } @@ -961,25 +920,31 @@ EXPORT_SYMBOL(hmm_range_register); */ void hmm_range_unregister(struct hmm_range *range) { - /* Sanity check this really should not happen. */ - if (range->hmm == NULL || range->end <= range->start) - return; + struct hmm *hmm = range->hmm; + unsigned long flags; - mutex_lock(&range->hmm->lock); - list_del_rcu(&range->list); - mutex_unlock(&range->hmm->lock); + spin_lock_irqsave(&hmm->ranges_lock, flags); + list_del_init(&range->list); + spin_unlock_irqrestore(&hmm->ranges_lock, flags); /* Drop reference taken by hmm_range_register() */ + mmput(hmm->mm); + hmm_put(hmm); + + /* + * The range is now invalid and the ref on the hmm is dropped, so + * poison the pointer. Leave other fields in place, for the caller's + * use. + */ range->valid = false; - hmm_put(range->hmm); - range->hmm = NULL; + memset(&range->hmm, POISON_INUSE, sizeof(range->hmm)); } EXPORT_SYMBOL(hmm_range_unregister); /* * hmm_range_snapshot() - snapshot CPU page table for a range * @range: range - * Returns: -EINVAL if invalid argument, -ENOMEM out of memory, -EPERM invalid + * Return: -EINVAL if invalid argument, -ENOMEM out of memory, -EPERM invalid * permission (for instance asking for write and range is read only), * -EAGAIN if you need to retry, -EFAULT invalid (ie either no valid * vma or it is illegal to access that range), number of valid pages @@ -998,10 +963,7 @@ long hmm_range_snapshot(struct hmm_range *range) struct vm_area_struct *vma; struct mm_walk mm_walk; - /* Check if hmm_mm_destroy() was call. */ - if (hmm->mm == NULL || hmm->dead) - return -EFAULT; - + lockdep_assert_held(&hmm->mm->mmap_sem); do { /* If range is no longer valid force retry. */ if (!range->valid) @@ -1012,9 +974,8 @@ long hmm_range_snapshot(struct hmm_range *range) return -EFAULT; if (is_vm_hugetlb_page(vma)) { - struct hstate *h = hstate_vma(vma); - - if (huge_page_shift(h) != range->page_shift && + if (huge_page_shift(hstate_vma(vma)) != + range->page_shift && range->page_shift != PAGE_SHIFT) return -EINVAL; } else { @@ -1063,7 +1024,7 @@ EXPORT_SYMBOL(hmm_range_snapshot); * hmm_range_fault() - try to fault some address in a virtual address range * @range: range being faulted * @block: allow blocking on fault (if true it sleeps and do not drop mmap_sem) - * Returns: number of valid pages in range->pfns[] (from range start + * Return: number of valid pages in range->pfns[] (from range start * address). This may be zero. If the return value is negative, * then one of the following values may be returned: * @@ -1097,9 +1058,7 @@ long hmm_range_fault(struct hmm_range *range, bool block) struct mm_walk mm_walk; int ret; - /* Check if hmm_mm_destroy() was call. */ - if (hmm->mm == NULL || hmm->dead) - return -EFAULT; + lockdep_assert_held(&hmm->mm->mmap_sem); do { /* If range is no longer valid force retry. */ @@ -1181,7 +1140,7 @@ EXPORT_SYMBOL(hmm_range_fault); * @device: device against to dma map page to * @daddrs: dma address of mapped pages * @block: allow blocking on fault (if true it sleeps and do not drop mmap_sem) - * Returns: number of pages mapped on success, -EAGAIN if mmap_sem have been + * Return: number of pages mapped on success, -EAGAIN if mmap_sem have been * drop and you need to try again, some other error value otherwise * * Note same usage pattern as hmm_range_fault(). @@ -1269,7 +1228,7 @@ EXPORT_SYMBOL(hmm_range_dma_map); * @device: device against which dma map was done * @daddrs: dma address of mapped pages * @dirty: dirty page if it had the write flag set - * Returns: number of page unmapped on success, -EINVAL otherwise + * Return: number of page unmapped on success, -EINVAL otherwise * * Note that caller MUST abide by mmu notifier or use HMM mirror and abide * to the sync_cpu_device_pagetables() callback so that it is safe here to |