|
@@ -814,6 +814,7 @@ static struct iommu_group *iommu_group_get_for_pci_dev(struct pci_dev *pdev)
|
|
|
*/
|
|
|
struct iommu_group *iommu_group_get_for_dev(struct device *dev)
|
|
|
{
|
|
|
+ const struct iommu_ops *ops = dev->bus->iommu_ops;
|
|
|
struct iommu_group *group;
|
|
|
int ret;
|
|
|
|
|
@@ -821,10 +822,12 @@ struct iommu_group *iommu_group_get_for_dev(struct device *dev)
|
|
|
if (group)
|
|
|
return group;
|
|
|
|
|
|
- if (!dev_is_pci(dev))
|
|
|
- return ERR_PTR(-EINVAL);
|
|
|
+ group = ERR_PTR(-EINVAL);
|
|
|
|
|
|
- group = iommu_group_get_for_pci_dev(to_pci_dev(dev));
|
|
|
+ if (ops && ops->device_group)
|
|
|
+ group = ops->device_group(dev);
|
|
|
+ else if (dev_is_pci(dev))
|
|
|
+ group = iommu_group_get_for_pci_dev(to_pci_dev(dev));
|
|
|
|
|
|
if (IS_ERR(group))
|
|
|
return group;
|