@@ -269,11 +269,10 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
 	struct intel_iommu *iommu = intel_svm_device_to_iommu(dev);
 	struct intel_svm_dev *sdev;
 	struct intel_svm *svm = NULL;
+	struct mm_struct *mm = NULL;
 	int pasid_max;
 	int ret;
 
-	BUG_ON(pasid && !current->mm);
-
 	if (WARN_ON(!iommu))
 		return -EINVAL;
 
@@ -284,12 +283,20 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
 	} else
 		pasid_max = 1 << 20;
 
+	if ((flags & SVM_FLAG_SUPERVISOR_MODE)) {
+		if (!ecap_srs(iommu->ecap))
+			return -EINVAL;
+	} else if (pasid) {
+		mm = get_task_mm(current);
+		BUG_ON(!mm);
+	}
+
 	mutex_lock(&pasid_mutex);
 	if (pasid && !(flags & SVM_FLAG_PRIVATE_PASID)) {
 		int i;
 
 		idr_for_each_entry(&iommu->pasid_idr, svm, i) {
-			if (svm->mm != current->mm ||
+			if (svm->mm != mm ||
 			    (svm->flags & SVM_FLAG_PRIVATE_PASID))
 				continue;
 
@@ -355,17 +362,22 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
 		}
 		svm->pasid = ret;
 		svm->notifier.ops = &intel_mmuops;
-		svm->mm = get_task_mm(current);
+		svm->mm = mm;
 		svm->flags = flags;
 		INIT_LIST_HEAD_RCU(&svm->devs);
 		ret = -ENOMEM;
-		if (!svm->mm || (ret = mmu_notifier_register(&svm->notifier, svm->mm))) {
-			idr_remove(&svm->iommu->pasid_idr, svm->pasid);
-			kfree(svm);
-			kfree(sdev);
-			goto out;
-		}
-		iommu->pasid_table[svm->pasid].val = (u64)__pa(svm->mm->pgd) | 1;
+		if (mm) {
+			ret = mmu_notifier_register(&svm->notifier, mm);
+			if (ret) {
+				idr_remove(&svm->iommu->pasid_idr, svm->pasid);
+				kfree(svm);
+				kfree(sdev);
+				goto out;
+			}
+			iommu->pasid_table[svm->pasid].val = (u64)__pa(mm->pgd) | 1;
+			mm = NULL;
+		} else
+			iommu->pasid_table[svm->pasid].val = (u64)__pa(init_mm.pgd) | 1 | (1ULL << 11);
 		wmb();
 	}
 	list_add_rcu(&sdev->list, &svm->devs);
@@ -375,6 +387,8 @@ int intel_svm_bind_mm(struct device *dev, int *pasid, int flags, struct svm_dev_
 	ret = 0;
  out:
 	mutex_unlock(&pasid_mutex);
+	if (mm)
+		mmput(mm);
 	return ret;
 }
 EXPORT_SYMBOL_GPL(intel_svm_bind_mm);
@@ -416,7 +430,8 @@ int intel_svm_unbind_mm(struct device *dev, int pasid)
 				mmu_notifier_unregister(&svm->notifier, svm->mm);
 
 				idr_remove(&svm->iommu->pasid_idr, svm->pasid);
-				mmput(svm->mm);
+				if (svm->mm)
+					mmput(svm->mm);
 				/* We mandate that no page faults may be outstanding
 				 * for the PASID when intel_svm_unbind_mm() is called.
 				 * If that is not obeyed, subtle errors will happen.
@@ -500,6 +515,10 @@ static irqreturn_t prq_event_thread(int irq, void *d)
 		}
 
 		result = QI_RESP_INVALID;
+		/* Since we're using init_mm.pgd directly, we should never take
+		 * any faults on kernel addresses. */
+		if (!svm->mm)
+			goto bad_req;
 		down_read(&svm->mm->mmap_sem);
 		vma = find_extend_vma(svm->mm, address);
 		if (!vma || address < vma->vm_start)
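
For context, a minimal caller sketch under stated assumptions: the entry points and flags (intel_svm_bind_mm, intel_svm_unbind_mm, SVM_FLAG_SUPERVISOR_MODE) come from the hunks above, while example_supervisor_bind and the device-programming step are hypothetical and not part of this patch. With SVM_FLAG_SUPERVISOR_MODE, no mm reference is taken; the PASID entry points at init_mm.pgd with the supervisor bit (1ULL << 11) set, so the device operates on kernel virtual addresses and, as the prq_event_thread change enforces, must never take page faults.

/* Hypothetical driver snippet (not from this patch): bind a
 * supervisor-mode PASID for DMA to kernel virtual addresses,
 * then release it when the device is quiesced. */
static int example_supervisor_bind(struct device *dev)
{
	int pasid, ret;

	/* No svm_dev_ops fault callback is passed: a supervisor PASID
	 * uses init_mm.pgd and is expected never to fault. */
	ret = intel_svm_bind_mm(dev, &pasid, SVM_FLAG_SUPERVISOR_MODE, NULL);
	if (ret)
		return ret;

	/* ... program 'pasid' into the device and perform DMA ... */

	intel_svm_unbind_mm(dev, pasid);
	return 0;
}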