Diffstat (limited to 'drivers/iommu')
 drivers/iommu/Kconfig                               |   9
 drivers/iommu/arm/arm-smmu-v3/Makefile              |   1
 drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-iommufd.c | 401
 drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c         | 139
 drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h         |  92
 drivers/iommu/arm/arm-smmu/arm-smmu.c               |  16
 drivers/iommu/io-pgtable-arm.c                      |  27
 drivers/iommu/iommu.c                               |  10
 drivers/iommu/iommufd/Kconfig                       |   4
 drivers/iommu/iommufd/Makefile                      |   6
 drivers/iommu/iommufd/driver.c                      |  53
 drivers/iommu/iommufd/fault.c                       |   9
 drivers/iommu/iommufd/hw_pagetable.c                | 113
 drivers/iommu/iommufd/io_pagetable.c                | 105
 drivers/iommu/iommufd/io_pagetable.h                |  26
 drivers/iommu/iommufd/ioas.c                        | 259
 drivers/iommu/iommufd/iommufd_private.h             |  58
 drivers/iommu/iommufd/iommufd_test.h                |  32
 drivers/iommu/iommufd/main.c                        |  65
 drivers/iommu/iommufd/pages.c                       | 319
 drivers/iommu/iommufd/selftest.c                    | 364
 drivers/iommu/iommufd/vfio_compat.c                 |   7
 drivers/iommu/iommufd/viommu.c                      | 157
 23 files changed, 1936 insertions(+), 336 deletions(-)
diff --git a/drivers/iommu/Kconfig b/drivers/iommu/Kconfig
index b3aa1f5d5321..0c9bceb1653d 100644
--- a/drivers/iommu/Kconfig
+++ b/drivers/iommu/Kconfig
@@ -415,6 +415,15 @@ config ARM_SMMU_V3_SVA
Say Y here if your system supports SVA extensions such as PCIe PASID
and PRI.
+config ARM_SMMU_V3_IOMMUFD
+ bool "Enable IOMMUFD features for ARM SMMUv3 (EXPERIMENTAL)"
+ depends on IOMMUFD
+ help
+ Support for IOMMUFD features intended to support virtual machines
+ with accelerated virtual IOMMUs.
+
+ Say Y here if you are doing development and testing on this feature.
+
config ARM_SMMU_V3_KUNIT_TEST
tristate "KUnit tests for arm-smmu-v3 driver" if !KUNIT_ALL_TESTS
depends on KUNIT
diff --git a/drivers/iommu/arm/arm-smmu-v3/Makefile b/drivers/iommu/arm/arm-smmu-v3/Makefile
index dc98c88b48c8..493a659cc66b 100644
--- a/drivers/iommu/arm/arm-smmu-v3/Makefile
+++ b/drivers/iommu/arm/arm-smmu-v3/Makefile
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: GPL-2.0
obj-$(CONFIG_ARM_SMMU_V3) += arm_smmu_v3.o
arm_smmu_v3-y := arm-smmu-v3.o
+arm_smmu_v3-$(CONFIG_ARM_SMMU_V3_IOMMUFD) += arm-smmu-v3-iommufd.o
arm_smmu_v3-$(CONFIG_ARM_SMMU_V3_SVA) += arm-smmu-v3-sva.o
arm_smmu_v3-$(CONFIG_TEGRA241_CMDQV) += tegra241-cmdqv.o
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-iommufd.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-iommufd.c
new file mode 100644
index 000000000000..6cc14d82399f
--- /dev/null
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-iommufd.c
@@ -0,0 +1,401 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES
+ */
+
+#include <uapi/linux/iommufd.h>
+
+#include "arm-smmu-v3.h"
+
+void *arm_smmu_hw_info(struct device *dev, u32 *length, u32 *type)
+{
+ struct arm_smmu_master *master = dev_iommu_priv_get(dev);
+ struct iommu_hw_info_arm_smmuv3 *info;
+ u32 __iomem *base_idr;
+ unsigned int i;
+
+ info = kzalloc(sizeof(*info), GFP_KERNEL);
+ if (!info)
+ return ERR_PTR(-ENOMEM);
+
+ base_idr = master->smmu->base + ARM_SMMU_IDR0;
+ for (i = 0; i <= 5; i++)
+ info->idr[i] = readl_relaxed(base_idr + i);
+ info->iidr = readl_relaxed(master->smmu->base + ARM_SMMU_IIDR);
+ info->aidr = readl_relaxed(master->smmu->base + ARM_SMMU_AIDR);
+
+ *length = sizeof(*info);
+ *type = IOMMU_HW_INFO_TYPE_ARM_SMMUV3;
+
+ return info;
+}
+
+static void arm_smmu_make_nested_cd_table_ste(
+ struct arm_smmu_ste *target, struct arm_smmu_master *master,
+ struct arm_smmu_nested_domain *nested_domain, bool ats_enabled)
+{
+ arm_smmu_make_s2_domain_ste(
+ target, master, nested_domain->vsmmu->s2_parent, ats_enabled);
+
+ target->data[0] = cpu_to_le64(STRTAB_STE_0_V |
+ FIELD_PREP(STRTAB_STE_0_CFG,
+ STRTAB_STE_0_CFG_NESTED));
+ target->data[0] |= nested_domain->ste[0] &
+ ~cpu_to_le64(STRTAB_STE_0_CFG);
+ target->data[1] |= nested_domain->ste[1];
+}
+
+/*
+ * Create a physical STE from the virtual STE that userspace provided when it
+ * created the nested domain. Using the vSTE userspace can request:
+ * - Non-valid STE
+ * - Abort STE
+ * - Bypass STE (install the S2, no CD table)
+ * - CD table STE (install the S2 and the userspace CD table)
+ */
+static void arm_smmu_make_nested_domain_ste(
+ struct arm_smmu_ste *target, struct arm_smmu_master *master,
+ struct arm_smmu_nested_domain *nested_domain, bool ats_enabled)
+{
+ unsigned int cfg =
+ FIELD_GET(STRTAB_STE_0_CFG, le64_to_cpu(nested_domain->ste[0]));
+
+ /*
+ * Userspace can request a non-valid STE through the nesting interface.
+ * We relay that into an abort physical STE with the intention that
+ * C_BAD_STE for this SID can be generated to userspace.
+ */
+ if (!(nested_domain->ste[0] & cpu_to_le64(STRTAB_STE_0_V)))
+ cfg = STRTAB_STE_0_CFG_ABORT;
+
+ switch (cfg) {
+ case STRTAB_STE_0_CFG_S1_TRANS:
+ arm_smmu_make_nested_cd_table_ste(target, master, nested_domain,
+ ats_enabled);
+ break;
+ case STRTAB_STE_0_CFG_BYPASS:
+ arm_smmu_make_s2_domain_ste(target, master,
+ nested_domain->vsmmu->s2_parent,
+ ats_enabled);
+ break;
+ case STRTAB_STE_0_CFG_ABORT:
+ default:
+ arm_smmu_make_abort_ste(target);
+ break;
+ }
+}
+
+static int arm_smmu_attach_dev_nested(struct iommu_domain *domain,
+ struct device *dev)
+{
+ struct arm_smmu_nested_domain *nested_domain =
+ to_smmu_nested_domain(domain);
+ struct arm_smmu_master *master = dev_iommu_priv_get(dev);
+ struct arm_smmu_attach_state state = {
+ .master = master,
+ .old_domain = iommu_get_domain_for_dev(dev),
+ .ssid = IOMMU_NO_PASID,
+ };
+ struct arm_smmu_ste ste;
+ int ret;
+
+ if (nested_domain->vsmmu->smmu != master->smmu)
+ return -EINVAL;
+ if (arm_smmu_ssids_in_use(&master->cd_table))
+ return -EBUSY;
+
+ mutex_lock(&arm_smmu_asid_lock);
+ /*
+ * The VM has to control the actual ATS state at the PCI device because
+ * we forward the invalidations directly from the VM. If the VM doesn't
+ * think ATS is on it will not generate ATC flushes and the ATC will
+ * become incoherent. Since we can't access the actual virtual PCI ATS
+ * config bit here, base this off the EATS value in the STE. If the EATS
+ * is set then the VM must generate ATC flushes.
+ */
+ state.disable_ats = !nested_domain->enable_ats;
+ ret = arm_smmu_attach_prepare(&state, domain);
+ if (ret) {
+ mutex_unlock(&arm_smmu_asid_lock);
+ return ret;
+ }
+
+ arm_smmu_make_nested_domain_ste(&ste, master, nested_domain,
+ state.ats_enabled);
+ arm_smmu_install_ste_for_dev(master, &ste);
+ arm_smmu_attach_commit(&state);
+ mutex_unlock(&arm_smmu_asid_lock);
+ return 0;
+}
+
+static void arm_smmu_domain_nested_free(struct iommu_domain *domain)
+{
+ kfree(to_smmu_nested_domain(domain));
+}
+
+static const struct iommu_domain_ops arm_smmu_nested_ops = {
+ .attach_dev = arm_smmu_attach_dev_nested,
+ .free = arm_smmu_domain_nested_free,
+};
+
+static int arm_smmu_validate_vste(struct iommu_hwpt_arm_smmuv3 *arg,
+ bool *enable_ats)
+{
+ unsigned int eats;
+ unsigned int cfg;
+
+ if (!(arg->ste[0] & cpu_to_le64(STRTAB_STE_0_V))) {
+ memset(arg->ste, 0, sizeof(arg->ste));
+ return 0;
+ }
+
+ /* EIO is reserved for invalid STE data. */
+ if ((arg->ste[0] & ~STRTAB_STE_0_NESTING_ALLOWED) ||
+ (arg->ste[1] & ~STRTAB_STE_1_NESTING_ALLOWED))
+ return -EIO;
+
+ cfg = FIELD_GET(STRTAB_STE_0_CFG, le64_to_cpu(arg->ste[0]));
+ if (cfg != STRTAB_STE_0_CFG_ABORT && cfg != STRTAB_STE_0_CFG_BYPASS &&
+ cfg != STRTAB_STE_0_CFG_S1_TRANS)
+ return -EIO;
+
+ /*
+ * Only Full ATS or ATS UR is supported. The EATS field will be set by
+ * arm_smmu_make_nested_domain_ste().
+ */
+ eats = FIELD_GET(STRTAB_STE_1_EATS, le64_to_cpu(arg->ste[1]));
+ arg->ste[1] &= ~cpu_to_le64(STRTAB_STE_1_EATS);
+ if (eats != STRTAB_STE_1_EATS_ABT && eats != STRTAB_STE_1_EATS_TRANS)
+ return -EIO;
+
+ if (cfg == STRTAB_STE_0_CFG_S1_TRANS)
+ *enable_ats = (eats == STRTAB_STE_1_EATS_TRANS);
+ return 0;
+}
+
+static struct iommu_domain *
+arm_vsmmu_alloc_domain_nested(struct iommufd_viommu *viommu, u32 flags,
+ const struct iommu_user_data *user_data)
+{
+ struct arm_vsmmu *vsmmu = container_of(viommu, struct arm_vsmmu, core);
+ const u32 SUPPORTED_FLAGS = IOMMU_HWPT_FAULT_ID_VALID;
+ struct arm_smmu_nested_domain *nested_domain;
+ struct iommu_hwpt_arm_smmuv3 arg;
+ bool enable_ats = false;
+ int ret;
+
+ /*
+ * Faults delivered to the nested domain are faults that originate from
+ * the S1 in the domain. The core code will match all PASIDs when
+ * delivering the fault due to user_pasid_table.
+ */
+ if (flags & ~SUPPORTED_FLAGS)
+ return ERR_PTR(-EOPNOTSUPP);
+
+ ret = iommu_copy_struct_from_user(&arg, user_data,
+ IOMMU_HWPT_DATA_ARM_SMMUV3, ste);
+ if (ret)
+ return ERR_PTR(ret);
+
+ ret = arm_smmu_validate_vste(&arg, &enable_ats);
+ if (ret)
+ return ERR_PTR(ret);
+
+ nested_domain = kzalloc(sizeof(*nested_domain), GFP_KERNEL_ACCOUNT);
+ if (!nested_domain)
+ return ERR_PTR(-ENOMEM);
+
+ nested_domain->domain.type = IOMMU_DOMAIN_NESTED;
+ nested_domain->domain.ops = &arm_smmu_nested_ops;
+ nested_domain->enable_ats = enable_ats;
+ nested_domain->vsmmu = vsmmu;
+ nested_domain->ste[0] = arg.ste[0];
+ nested_domain->ste[1] = arg.ste[1] & ~cpu_to_le64(STRTAB_STE_1_EATS);
+
+ return &nested_domain->domain;
+}
+
+static int arm_vsmmu_vsid_to_sid(struct arm_vsmmu *vsmmu, u32 vsid, u32 *sid)
+{
+ struct arm_smmu_master *master;
+ struct device *dev;
+ int ret = 0;
+
+ xa_lock(&vsmmu->core.vdevs);
+ dev = iommufd_viommu_find_dev(&vsmmu->core, (unsigned long)vsid);
+ if (!dev) {
+ ret = -EIO;
+ goto unlock;
+ }
+ master = dev_iommu_priv_get(dev);
+
+ /* At this moment, iommufd only supports PCI devices that have one SID */
+ if (sid)
+ *sid = master->streams[0].id;
+unlock:
+ xa_unlock(&vsmmu->core.vdevs);
+ return ret;
+}
+
+/* This is basically iommu_viommu_arm_smmuv3_invalidate in u64 for conversion */
+struct arm_vsmmu_invalidation_cmd {
+ union {
+ u64 cmd[2];
+ struct iommu_viommu_arm_smmuv3_invalidate ucmd;
+ };
+};
+
+/*
+ * Convert, in place, the raw invalidation command into an internal format that
+ * can be passed to arm_smmu_cmdq_issue_cmdlist(). Internally commands are
+ * stored in CPU endian.
+ *
+ * Enforce the VMID or SID on the command.
+ */
+static int arm_vsmmu_convert_user_cmd(struct arm_vsmmu *vsmmu,
+ struct arm_vsmmu_invalidation_cmd *cmd)
+{
+ /* Commands are le64 stored in u64 */
+ cmd->cmd[0] = le64_to_cpu(cmd->ucmd.cmd[0]);
+ cmd->cmd[1] = le64_to_cpu(cmd->ucmd.cmd[1]);
+
+ switch (cmd->cmd[0] & CMDQ_0_OP) {
+ case CMDQ_OP_TLBI_NSNH_ALL:
+ /* Convert to NH_ALL */
+ cmd->cmd[0] = CMDQ_OP_TLBI_NH_ALL |
+ FIELD_PREP(CMDQ_TLBI_0_VMID, vsmmu->vmid);
+ cmd->cmd[1] = 0;
+ break;
+ case CMDQ_OP_TLBI_NH_VA:
+ case CMDQ_OP_TLBI_NH_VAA:
+ case CMDQ_OP_TLBI_NH_ALL:
+ case CMDQ_OP_TLBI_NH_ASID:
+ cmd->cmd[0] &= ~CMDQ_TLBI_0_VMID;
+ cmd->cmd[0] |= FIELD_PREP(CMDQ_TLBI_0_VMID, vsmmu->vmid);
+ break;
+ case CMDQ_OP_ATC_INV:
+ case CMDQ_OP_CFGI_CD:
+ case CMDQ_OP_CFGI_CD_ALL: {
+ u32 sid, vsid = FIELD_GET(CMDQ_CFGI_0_SID, cmd->cmd[0]);
+
+ if (arm_vsmmu_vsid_to_sid(vsmmu, vsid, &sid))
+ return -EIO;
+ cmd->cmd[0] &= ~CMDQ_CFGI_0_SID;
+ cmd->cmd[0] |= FIELD_PREP(CMDQ_CFGI_0_SID, sid);
+ break;
+ }
+ default:
+ return -EIO;
+ }
+ return 0;
+}
+
+static int arm_vsmmu_cache_invalidate(struct iommufd_viommu *viommu,
+ struct iommu_user_data_array *array)
+{
+ struct arm_vsmmu *vsmmu = container_of(viommu, struct arm_vsmmu, core);
+ struct arm_smmu_device *smmu = vsmmu->smmu;
+ struct arm_vsmmu_invalidation_cmd *last;
+ struct arm_vsmmu_invalidation_cmd *cmds;
+ struct arm_vsmmu_invalidation_cmd *cur;
+ struct arm_vsmmu_invalidation_cmd *end;
+ int ret;
+
+ cmds = kcalloc(array->entry_num, sizeof(*cmds), GFP_KERNEL);
+ if (!cmds)
+ return -ENOMEM;
+ cur = cmds;
+ end = cmds + array->entry_num;
+
+ static_assert(sizeof(*cmds) == 2 * sizeof(u64));
+ ret = iommu_copy_struct_from_full_user_array(
+ cmds, sizeof(*cmds), array,
+ IOMMU_VIOMMU_INVALIDATE_DATA_ARM_SMMUV3);
+ if (ret)
+ goto out;
+
+ last = cmds;
+ while (cur != end) {
+ ret = arm_vsmmu_convert_user_cmd(vsmmu, cur);
+ if (ret)
+ goto out;
+
+ /* FIXME work in blocks of CMDQ_BATCH_ENTRIES and copy each block? */
+ cur++;
+ if (cur != end && (cur - last) != CMDQ_BATCH_ENTRIES - 1)
+ continue;
+
+ /* FIXME always uses the main cmdq rather than trying to group by type */
+ ret = arm_smmu_cmdq_issue_cmdlist(smmu, &smmu->cmdq, last->cmd,
+ cur - last, true);
+ if (ret) {
+ cur--;
+ goto out;
+ }
+ last = cur;
+ }
+out:
+ array->entry_num = cur - cmds;
+ kfree(cmds);
+ return ret;
+}
+
+static const struct iommufd_viommu_ops arm_vsmmu_ops = {
+ .alloc_domain_nested = arm_vsmmu_alloc_domain_nested,
+ .cache_invalidate = arm_vsmmu_cache_invalidate,
+};
+
+struct iommufd_viommu *arm_vsmmu_alloc(struct device *dev,
+ struct iommu_domain *parent,
+ struct iommufd_ctx *ictx,
+ unsigned int viommu_type)
+{
+ struct arm_smmu_device *smmu =
+ iommu_get_iommu_dev(dev, struct arm_smmu_device, iommu);
+ struct arm_smmu_master *master = dev_iommu_priv_get(dev);
+ struct arm_smmu_domain *s2_parent = to_smmu_domain(parent);
+ struct arm_vsmmu *vsmmu;
+
+ if (viommu_type != IOMMU_VIOMMU_TYPE_ARM_SMMUV3)
+ return ERR_PTR(-EOPNOTSUPP);
+
+ if (!(smmu->features & ARM_SMMU_FEAT_NESTING))
+ return ERR_PTR(-EOPNOTSUPP);
+
+ if (s2_parent->smmu != master->smmu)
+ return ERR_PTR(-EINVAL);
+
+ /*
+ * FORCE_SYNC is not set with FEAT_NESTING. Some study of the exact HW
+ * defect is needed to determine if arm_vsmmu_cache_invalidate() needs
+ * any change to remove this.
+ */
+ if (WARN_ON(smmu->options & ARM_SMMU_OPT_CMDQ_FORCE_SYNC))
+ return ERR_PTR(-EOPNOTSUPP);
+
+ /*
+ * Must support some way to prevent the VM from bypassing the cache
+ * because VFIO currently does not do any cache maintenance. canwbs
+ * indicates the device is fully coherent and no cache maintenance is
+ * ever required, even for PCI No-Snoop. S2FWB means the S1 can't make
+ * things non-coherent using the memattr, but No-Snoop behavior is not
+ * affected.
+ */
+ if (!arm_smmu_master_canwbs(master) &&
+ !(smmu->features & ARM_SMMU_FEAT_S2FWB))
+ return ERR_PTR(-EOPNOTSUPP);
+
+ vsmmu = iommufd_viommu_alloc(ictx, struct arm_vsmmu, core,
+ &arm_vsmmu_ops);
+ if (IS_ERR(vsmmu))
+ return ERR_CAST(vsmmu);
+
+ vsmmu->smmu = smmu;
+ vsmmu->s2_parent = s2_parent;
+ /* FIXME Move VMID allocation from the S2 domain allocation to here */
+ vsmmu->vmid = s2_parent->s2_cfg.vmid;
+
+ return &vsmmu->core;
+}
+
+MODULE_IMPORT_NS(IOMMUFD);
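
For orientation, here is a hypothetical VMM-side sketch (not part of the patch) of how the objects implemented above are expected to be created from userspace: allocate a vIOMMU of type IOMMU_VIOMMU_TYPE_ARM_SMMUV3 on top of a nesting-parent S2 hw_pagetable, then allocate a nested S1 hw_pagetable carrying the guest STE words via struct iommu_hwpt_arm_smmuv3. The helper name is invented, the iommu_viommu_alloc/iommu_hwpt_alloc field layout is assumed from the iommufd uAPI accompanying this series, and error paths are omitted.

    #include <stdint.h>
    #include <sys/ioctl.h>
    #include <linux/iommufd.h>

    /* Hypothetical helper; assumes the iommufd uAPI from this series. */
    static int alloc_arm_nested_s1(int iommufd, __u32 dev_id, __u32 s2_hwpt_id,
                                   const __u64 guest_ste[2], __u32 *out_hwpt_id)
    {
            struct iommu_hwpt_arm_smmuv3 vste = {
                    .ste = { guest_ste[0], guest_ste[1] }, /* vSTE as the guest wrote it */
            };
            struct iommu_viommu_alloc viommu = {
                    .size = sizeof(viommu),
                    .type = IOMMU_VIOMMU_TYPE_ARM_SMMUV3,
                    .dev_id = dev_id,
                    .hwpt_id = s2_hwpt_id, /* allocated with IOMMU_HWPT_ALLOC_NEST_PARENT */
            };
            struct iommu_hwpt_alloc hwpt = {
                    .size = sizeof(hwpt),
                    .dev_id = dev_id,
                    .data_type = IOMMU_HWPT_DATA_ARM_SMMUV3,
                    .data_len = sizeof(vste),
                    .data_uptr = (__u64)(uintptr_t)&vste,
            };

            if (ioctl(iommufd, IOMMU_VIOMMU_ALLOC, &viommu))
                    return -1;
            hwpt.pt_id = viommu.out_viommu_id; /* parent object is the vIOMMU */
            if (ioctl(iommufd, IOMMU_HWPT_ALLOC, &hwpt))
                    return -1;
            *out_hwpt_id = hwpt.out_hwpt_id;
            return 0;
    }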
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
index 353fea58cd31..10b4dbc8d027 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
@@ -295,6 +295,7 @@ static int arm_smmu_cmdq_build_cmd(u64 *cmd, struct arm_smmu_cmdq_ent *ent)
case CMDQ_OP_TLBI_NH_ASID:
cmd[0] |= FIELD_PREP(CMDQ_TLBI_0_ASID, ent->tlbi.asid);
fallthrough;
+ case CMDQ_OP_TLBI_NH_ALL:
case CMDQ_OP_TLBI_S12_VMALL:
cmd[0] |= FIELD_PREP(CMDQ_TLBI_0_VMID, ent->tlbi.vmid);
break;
@@ -765,9 +766,9 @@ static void arm_smmu_cmdq_write_entries(struct arm_smmu_cmdq *cmdq, u64 *cmds,
* insert their own list of commands then all of the commands from one
* CPU will appear before any of the commands from the other CPU.
*/
-static int arm_smmu_cmdq_issue_cmdlist(struct arm_smmu_device *smmu,
- struct arm_smmu_cmdq *cmdq,
- u64 *cmds, int n, bool sync)
+int arm_smmu_cmdq_issue_cmdlist(struct arm_smmu_device *smmu,
+ struct arm_smmu_cmdq *cmdq, u64 *cmds, int n,
+ bool sync)
{
u64 cmd_sync[CMDQ_ENT_DWORDS];
u32 prod;
@@ -1045,7 +1046,8 @@ void arm_smmu_get_ste_used(const __le64 *ent, __le64 *used_bits)
/* S2 translates */
if (cfg & BIT(1)) {
used_bits[1] |=
- cpu_to_le64(STRTAB_STE_1_EATS | STRTAB_STE_1_SHCFG);
+ cpu_to_le64(STRTAB_STE_1_S2FWB | STRTAB_STE_1_EATS |
+ STRTAB_STE_1_SHCFG);
used_bits[2] |=
cpu_to_le64(STRTAB_STE_2_S2VMID | STRTAB_STE_2_VTCR |
STRTAB_STE_2_S2AA64 | STRTAB_STE_2_S2ENDI |
@@ -1549,7 +1551,6 @@ static void arm_smmu_write_ste(struct arm_smmu_master *master, u32 sid,
}
}
-VISIBLE_IF_KUNIT
void arm_smmu_make_abort_ste(struct arm_smmu_ste *target)
{
memset(target, 0, sizeof(*target));
@@ -1632,7 +1633,6 @@ void arm_smmu_make_cdtable_ste(struct arm_smmu_ste *target,
}
EXPORT_SYMBOL_IF_KUNIT(arm_smmu_make_cdtable_ste);
-VISIBLE_IF_KUNIT
void arm_smmu_make_s2_domain_ste(struct arm_smmu_ste *target,
struct arm_smmu_master *master,
struct arm_smmu_domain *smmu_domain,
@@ -1655,6 +1655,8 @@ void arm_smmu_make_s2_domain_ste(struct arm_smmu_ste *target,
FIELD_PREP(STRTAB_STE_1_EATS,
ats_enabled ? STRTAB_STE_1_EATS_TRANS : 0));
+ if (pgtbl_cfg->quirks & IO_PGTABLE_QUIRK_ARM_S2FWB)
+ target->data[1] |= cpu_to_le64(STRTAB_STE_1_S2FWB);
if (smmu->features & ARM_SMMU_FEAT_ATTR_TYPES_OVR)
target->data[1] |= cpu_to_le64(FIELD_PREP(STRTAB_STE_1_SHCFG,
STRTAB_STE_1_SHCFG_INCOMING));
@@ -2105,7 +2107,16 @@ int arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
if (!master->ats_enabled)
continue;
- arm_smmu_atc_inv_to_cmd(master_domain->ssid, iova, size, &cmd);
+ if (master_domain->nested_ats_flush) {
+ /*
+ * If an S2 used as a nesting parent is changed we have
+ * no option but to completely flush the ATC.
+ */
+ arm_smmu_atc_inv_to_cmd(IOMMU_NO_PASID, 0, 0, &cmd);
+ } else {
+ arm_smmu_atc_inv_to_cmd(master_domain->ssid, iova, size,
+ &cmd);
+ }
for (i = 0; i < master->num_streams; i++) {
cmd.atc.sid = master->streams[i].id;
@@ -2232,6 +2243,15 @@ static void arm_smmu_tlb_inv_range_domain(unsigned long iova, size_t size,
}
__arm_smmu_tlb_inv_range(&cmd, iova, size, granule, smmu_domain);
+ if (smmu_domain->nest_parent) {
+ /*
+ * When the S2 domain changes all the nested S1 ASIDs have to be
+ * flushed too.
+ */
+ cmd.opcode = CMDQ_OP_TLBI_NH_ALL;
+ arm_smmu_cmdq_issue_cmd_with_sync(smmu_domain->smmu, &cmd);
+ }
+
/*
* Unfortunately, this can't be leaf-only since we may have
* zapped an entire table.
@@ -2293,6 +2313,8 @@ static bool arm_smmu_capable(struct device *dev, enum iommu_cap cap)
case IOMMU_CAP_CACHE_COHERENCY:
/* Assume that a coherent TCU implies coherent TBUs */
return master->smmu->features & ARM_SMMU_FEAT_COHERENCY;
+ case IOMMU_CAP_ENFORCE_CACHE_COHERENCY:
+ return arm_smmu_master_canwbs(master);
case IOMMU_CAP_NOEXEC:
case IOMMU_CAP_DEFERRED_FLUSH:
return true;
@@ -2303,6 +2325,26 @@ static bool arm_smmu_capable(struct device *dev, enum iommu_cap cap)
}
}
+static bool arm_smmu_enforce_cache_coherency(struct iommu_domain *domain)
+{
+ struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
+ struct arm_smmu_master_domain *master_domain;
+ unsigned long flags;
+ bool ret = true;
+
+ spin_lock_irqsave(&smmu_domain->devices_lock, flags);
+ list_for_each_entry(master_domain, &smmu_domain->devices,
+ devices_elm) {
+ if (!arm_smmu_master_canwbs(master_domain->master)) {
+ ret = false;
+ break;
+ }
+ }
+ smmu_domain->enforce_cache_coherency = ret;
+ spin_unlock_irqrestore(&smmu_domain->devices_lock, flags);
+ return ret;
+}
+
struct arm_smmu_domain *arm_smmu_domain_alloc(void)
{
struct arm_smmu_domain *smmu_domain;
@@ -2442,6 +2484,9 @@ static int arm_smmu_domain_finalise(struct arm_smmu_domain *smmu_domain,
pgtbl_cfg.oas = smmu->oas;
fmt = ARM_64_LPAE_S2;
finalise_stage_fn = arm_smmu_domain_finalise_s2;
+ if ((smmu->features & ARM_SMMU_FEAT_S2FWB) &&
+ (flags & IOMMU_HWPT_ALLOC_NEST_PARENT))
+ pgtbl_cfg.quirks |= IO_PGTABLE_QUIRK_ARM_S2FWB;
break;
default:
return -EINVAL;
@@ -2483,8 +2528,8 @@ arm_smmu_get_step_for_sid(struct arm_smmu_device *smmu, u32 sid)
}
}
-static void arm_smmu_install_ste_for_dev(struct arm_smmu_master *master,
- const struct arm_smmu_ste *target)
+void arm_smmu_install_ste_for_dev(struct arm_smmu_master *master,
+ const struct arm_smmu_ste *target)
{
int i, j;
struct arm_smmu_device *smmu = master->smmu;
@@ -2595,7 +2640,7 @@ static void arm_smmu_disable_pasid(struct arm_smmu_master *master)
static struct arm_smmu_master_domain *
arm_smmu_find_master_domain(struct arm_smmu_domain *smmu_domain,
struct arm_smmu_master *master,
- ioasid_t ssid)
+ ioasid_t ssid, bool nested_ats_flush)
{
struct arm_smmu_master_domain *master_domain;
@@ -2604,7 +2649,8 @@ arm_smmu_find_master_domain(struct arm_smmu_domain *smmu_domain,
list_for_each_entry(master_domain, &smmu_domain->devices,
devices_elm) {
if (master_domain->master == master &&
- master_domain->ssid == ssid)
+ master_domain->ssid == ssid &&
+ master_domain->nested_ats_flush == nested_ats_flush)
return master_domain;
}
return NULL;
@@ -2624,6 +2670,8 @@ to_smmu_domain_devices(struct iommu_domain *domain)
if ((domain->type & __IOMMU_DOMAIN_PAGING) ||
domain->type == IOMMU_DOMAIN_SVA)
return to_smmu_domain(domain);
+ if (domain->type == IOMMU_DOMAIN_NESTED)
+ return to_smmu_nested_domain(domain)->vsmmu->s2_parent;
return NULL;
}
@@ -2633,13 +2681,18 @@ static void arm_smmu_remove_master_domain(struct arm_smmu_master *master,
{
struct arm_smmu_domain *smmu_domain = to_smmu_domain_devices(domain);
struct arm_smmu_master_domain *master_domain;
+ bool nested_ats_flush = false;
unsigned long flags;
if (!smmu_domain)
return;
+ if (domain->type == IOMMU_DOMAIN_NESTED)
+ nested_ats_flush = to_smmu_nested_domain(domain)->enable_ats;
+
spin_lock_irqsave(&smmu_domain->devices_lock, flags);
- master_domain = arm_smmu_find_master_domain(smmu_domain, master, ssid);
+ master_domain = arm_smmu_find_master_domain(smmu_domain, master, ssid,
+ nested_ats_flush);
if (master_domain) {
list_del(&master_domain->devices_elm);
kfree(master_domain);
@@ -2649,16 +2702,6 @@ static void arm_smmu_remove_master_domain(struct arm_smmu_master *master,
spin_unlock_irqrestore(&smmu_domain->devices_lock, flags);
}
-struct arm_smmu_attach_state {
- /* Inputs */
- struct iommu_domain *old_domain;
- struct arm_smmu_master *master;
- bool cd_needs_ats;
- ioasid_t ssid;
- /* Resulting state */
- bool ats_enabled;
-};
-
/*
* Start the sequence to attach a domain to a master. The sequence contains three
* steps:
@@ -2679,8 +2722,8 @@ struct arm_smmu_attach_state {
* new_domain can be a non-paging domain. In this case ATS will not be enabled,
* and invalidations won't be tracked.
*/
-static int arm_smmu_attach_prepare(struct arm_smmu_attach_state *state,
- struct iommu_domain *new_domain)
+int arm_smmu_attach_prepare(struct arm_smmu_attach_state *state,
+ struct iommu_domain *new_domain)
{
struct arm_smmu_master *master = state->master;
struct arm_smmu_master_domain *master_domain;
@@ -2706,7 +2749,8 @@ static int arm_smmu_attach_prepare(struct arm_smmu_attach_state *state,
* enabled if we have arm_smmu_domain, those always have page
* tables.
*/
- state->ats_enabled = arm_smmu_ats_supported(master);
+ state->ats_enabled = !state->disable_ats &&
+ arm_smmu_ats_supported(master);
}
if (smmu_domain) {
@@ -2715,6 +2759,9 @@ static int arm_smmu_attach_prepare(struct arm_smmu_attach_state *state,
return -ENOMEM;
master_domain->master = master;
master_domain->ssid = state->ssid;
+ if (new_domain->type == IOMMU_DOMAIN_NESTED)
+ master_domain->nested_ats_flush =
+ to_smmu_nested_domain(new_domain)->enable_ats;
/*
* During prepare we want the current smmu_domain and new
@@ -2731,6 +2778,14 @@ static int arm_smmu_attach_prepare(struct arm_smmu_attach_state *state,
* one of them.
*/
spin_lock_irqsave(&smmu_domain->devices_lock, flags);
+ if (smmu_domain->enforce_cache_coherency &&
+ !arm_smmu_master_canwbs(master)) {
+ spin_unlock_irqrestore(&smmu_domain->devices_lock,
+ flags);
+ kfree(master_domain);
+ return -EINVAL;
+ }
+
if (state->ats_enabled)
atomic_inc(&smmu_domain->nr_ats_masters);
list_add(&master_domain->devices_elm, &smmu_domain->devices);
@@ -2754,7 +2809,7 @@ static int arm_smmu_attach_prepare(struct arm_smmu_attach_state *state,
* completes synchronizing the PCI device's ATC and finishes manipulating the
* smmu_domain->devices list.
*/
-static void arm_smmu_attach_commit(struct arm_smmu_attach_state *state)
+void arm_smmu_attach_commit(struct arm_smmu_attach_state *state)
{
struct arm_smmu_master *master = state->master;
@@ -3084,7 +3139,8 @@ arm_smmu_domain_alloc_user(struct device *dev, u32 flags,
const struct iommu_user_data *user_data)
{
struct arm_smmu_master *master = dev_iommu_priv_get(dev);
- const u32 PAGING_FLAGS = IOMMU_HWPT_ALLOC_DIRTY_TRACKING;
+ const u32 PAGING_FLAGS = IOMMU_HWPT_ALLOC_DIRTY_TRACKING |
+ IOMMU_HWPT_ALLOC_NEST_PARENT;
struct arm_smmu_domain *smmu_domain;
int ret;
@@ -3097,6 +3153,15 @@ arm_smmu_domain_alloc_user(struct device *dev, u32 flags,
if (IS_ERR(smmu_domain))
return ERR_CAST(smmu_domain);
+ if (flags & IOMMU_HWPT_ALLOC_NEST_PARENT) {
+ if (!(master->smmu->features & ARM_SMMU_FEAT_NESTING)) {
+ ret = -EOPNOTSUPP;
+ goto err_free;
+ }
+ smmu_domain->stage = ARM_SMMU_DOMAIN_S2;
+ smmu_domain->nest_parent = true;
+ }
+
smmu_domain->domain.type = IOMMU_DOMAIN_UNMANAGED;
smmu_domain->domain.ops = arm_smmu_ops.default_domain_ops;
ret = arm_smmu_domain_finalise(smmu_domain, master->smmu, flags);
@@ -3378,21 +3443,6 @@ static struct iommu_group *arm_smmu_device_group(struct device *dev)
return group;
}
-static int arm_smmu_enable_nesting(struct iommu_domain *domain)
-{
- struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
- int ret = 0;
-
- mutex_lock(&smmu_domain->init_mutex);
- if (smmu_domain->smmu)
- ret = -EPERM;
- else
- smmu_domain->stage = ARM_SMMU_DOMAIN_S2;
- mutex_unlock(&smmu_domain->init_mutex);
-
- return ret;
-}
-
static int arm_smmu_of_xlate(struct device *dev,
const struct of_phandle_args *args)
{
@@ -3491,6 +3541,7 @@ static struct iommu_ops arm_smmu_ops = {
.identity_domain = &arm_smmu_identity_domain,
.blocked_domain = &arm_smmu_blocked_domain,
.capable = arm_smmu_capable,
+ .hw_info = arm_smmu_hw_info,
.domain_alloc_paging = arm_smmu_domain_alloc_paging,
.domain_alloc_sva = arm_smmu_sva_domain_alloc,
.domain_alloc_user = arm_smmu_domain_alloc_user,
@@ -3504,17 +3555,19 @@ static struct iommu_ops arm_smmu_ops = {
.dev_disable_feat = arm_smmu_dev_disable_feature,
.page_response = arm_smmu_page_response,
.def_domain_type = arm_smmu_def_domain_type,
+ .viommu_alloc = arm_vsmmu_alloc,
+ .user_pasid_table = 1,
.pgsize_bitmap = -1UL, /* Restricted during device attach */
.owner = THIS_MODULE,
.default_domain_ops = &(const struct iommu_domain_ops) {
.attach_dev = arm_smmu_attach_dev,
+ .enforce_cache_coherency = arm_smmu_enforce_cache_coherency,
.set_dev_pasid = arm_smmu_s1_set_dev_pasid,
.map_pages = arm_smmu_map_pages,
.unmap_pages = arm_smmu_unmap_pages,
.flush_iotlb_all = arm_smmu_flush_iotlb_all,
.iotlb_sync = arm_smmu_iotlb_sync,
.iova_to_phys = arm_smmu_iova_to_phys,
- .enable_nesting = arm_smmu_enable_nesting,
.free = arm_smmu_domain_free_paging,
}
};
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
index 1e9952ca989f..af25f092303f 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
@@ -10,6 +10,7 @@
#include <linux/bitfield.h>
#include <linux/iommu.h>
+#include <linux/iommufd.h>
#include <linux/kernel.h>
#include <linux/mmzone.h>
#include <linux/sizes.h>
@@ -57,6 +58,7 @@ struct arm_smmu_device;
#define IDR1_SIDSIZE GENMASK(5, 0)
#define ARM_SMMU_IDR3 0xc
+#define IDR3_FWB (1 << 8)
#define IDR3_RIL (1 << 10)
#define ARM_SMMU_IDR5 0x14
@@ -81,6 +83,8 @@ struct arm_smmu_device;
#define IIDR_REVISION GENMASK(15, 12)
#define IIDR_IMPLEMENTER GENMASK(11, 0)
+#define ARM_SMMU_AIDR 0x1C
+
#define ARM_SMMU_CR0 0x20
#define CR0_ATSCHK (1 << 4)
#define CR0_CMDQEN (1 << 3)
@@ -241,6 +245,7 @@ static inline u32 arm_smmu_strtab_l2_idx(u32 sid)
#define STRTAB_STE_0_CFG_BYPASS 4
#define STRTAB_STE_0_CFG_S1_TRANS 5
#define STRTAB_STE_0_CFG_S2_TRANS 6
+#define STRTAB_STE_0_CFG_NESTED 7
#define STRTAB_STE_0_S1FMT GENMASK_ULL(5, 4)
#define STRTAB_STE_0_S1FMT_LINEAR 0
@@ -261,6 +266,7 @@ static inline u32 arm_smmu_strtab_l2_idx(u32 sid)
#define STRTAB_STE_1_S1COR GENMASK_ULL(5, 4)
#define STRTAB_STE_1_S1CSH GENMASK_ULL(7, 6)
+#define STRTAB_STE_1_S2FWB (1UL << 25)
#define STRTAB_STE_1_S1STALLD (1UL << 27)
#define STRTAB_STE_1_EATS GENMASK_ULL(29, 28)
@@ -292,6 +298,15 @@ static inline u32 arm_smmu_strtab_l2_idx(u32 sid)
#define STRTAB_STE_3_S2TTB_MASK GENMASK_ULL(51, 4)
+/* These bits can be controlled by userspace for STRTAB_STE_0_CFG_NESTED */
+#define STRTAB_STE_0_NESTING_ALLOWED \
+ cpu_to_le64(STRTAB_STE_0_V | STRTAB_STE_0_CFG | STRTAB_STE_0_S1FMT | \
+ STRTAB_STE_0_S1CTXPTR_MASK | STRTAB_STE_0_S1CDMAX)
+#define STRTAB_STE_1_NESTING_ALLOWED \
+ cpu_to_le64(STRTAB_STE_1_S1DSS | STRTAB_STE_1_S1CIR | \
+ STRTAB_STE_1_S1COR | STRTAB_STE_1_S1CSH | \
+ STRTAB_STE_1_S1STALLD | STRTAB_STE_1_EATS)
+
/*
* Context descriptors.
*
@@ -511,8 +526,10 @@ struct arm_smmu_cmdq_ent {
};
} cfgi;
+ #define CMDQ_OP_TLBI_NH_ALL 0x10
#define CMDQ_OP_TLBI_NH_ASID 0x11
#define CMDQ_OP_TLBI_NH_VA 0x12
+ #define CMDQ_OP_TLBI_NH_VAA 0x13
#define CMDQ_OP_TLBI_EL2_ALL 0x20
#define CMDQ_OP_TLBI_EL2_ASID 0x21
#define CMDQ_OP_TLBI_EL2_VA 0x22
@@ -726,6 +743,7 @@ struct arm_smmu_device {
#define ARM_SMMU_FEAT_ATTR_TYPES_OVR (1 << 20)
#define ARM_SMMU_FEAT_HA (1 << 21)
#define ARM_SMMU_FEAT_HD (1 << 22)
+#define ARM_SMMU_FEAT_S2FWB (1 << 23)
u32 features;
#define ARM_SMMU_OPT_SKIP_PREFETCH (1 << 0)
@@ -811,10 +829,20 @@ struct arm_smmu_domain {
/* List of struct arm_smmu_master_domain */
struct list_head devices;
spinlock_t devices_lock;
+ bool enforce_cache_coherency : 1;
+ bool nest_parent : 1;
struct mmu_notifier mmu_notifier;
};
+struct arm_smmu_nested_domain {
+ struct iommu_domain domain;
+ struct arm_vsmmu *vsmmu;
+ bool enable_ats : 1;
+
+ __le64 ste[2];
+};
+
/* The following are exposed for testing purposes. */
struct arm_smmu_entry_writer_ops;
struct arm_smmu_entry_writer {
@@ -827,21 +855,22 @@ struct arm_smmu_entry_writer_ops {
void (*sync)(struct arm_smmu_entry_writer *writer);
};
+void arm_smmu_make_abort_ste(struct arm_smmu_ste *target);
+void arm_smmu_make_s2_domain_ste(struct arm_smmu_ste *target,
+ struct arm_smmu_master *master,
+ struct arm_smmu_domain *smmu_domain,
+ bool ats_enabled);
+
#if IS_ENABLED(CONFIG_KUNIT)
void arm_smmu_get_ste_used(const __le64 *ent, __le64 *used_bits);
void arm_smmu_write_entry(struct arm_smmu_entry_writer *writer, __le64 *cur,
const __le64 *target);
void arm_smmu_get_cd_used(const __le64 *ent, __le64 *used_bits);
-void arm_smmu_make_abort_ste(struct arm_smmu_ste *target);
void arm_smmu_make_bypass_ste(struct arm_smmu_device *smmu,
struct arm_smmu_ste *target);
void arm_smmu_make_cdtable_ste(struct arm_smmu_ste *target,
struct arm_smmu_master *master, bool ats_enabled,
unsigned int s1dss);
-void arm_smmu_make_s2_domain_ste(struct arm_smmu_ste *target,
- struct arm_smmu_master *master,
- struct arm_smmu_domain *smmu_domain,
- bool ats_enabled);
void arm_smmu_make_sva_cd(struct arm_smmu_cd *target,
struct arm_smmu_master *master, struct mm_struct *mm,
u16 asid);
@@ -851,6 +880,7 @@ struct arm_smmu_master_domain {
struct list_head devices_elm;
struct arm_smmu_master *master;
ioasid_t ssid;
+ bool nested_ats_flush : 1;
};
static inline struct arm_smmu_domain *to_smmu_domain(struct iommu_domain *dom)
@@ -858,6 +888,12 @@ static inline struct arm_smmu_domain *to_smmu_domain(struct iommu_domain *dom)
return container_of(dom, struct arm_smmu_domain, domain);
}
+static inline struct arm_smmu_nested_domain *
+to_smmu_nested_domain(struct iommu_domain *dom)
+{
+ return container_of(dom, struct arm_smmu_nested_domain, domain);
+}
+
extern struct xarray arm_smmu_asid_xa;
extern struct mutex arm_smmu_asid_lock;
@@ -893,6 +929,33 @@ int arm_smmu_init_one_queue(struct arm_smmu_device *smmu,
int arm_smmu_cmdq_init(struct arm_smmu_device *smmu,
struct arm_smmu_cmdq *cmdq);
+static inline bool arm_smmu_master_canwbs(struct arm_smmu_master *master)
+{
+ return dev_iommu_fwspec_get(master->dev)->flags &
+ IOMMU_FWSPEC_PCI_RC_CANWBS;
+}
+
+struct arm_smmu_attach_state {
+ /* Inputs */
+ struct iommu_domain *old_domain;
+ struct arm_smmu_master *master;
+ bool cd_needs_ats;
+ bool disable_ats;
+ ioasid_t ssid;
+ /* Resulting state */
+ bool ats_enabled;
+};
+
+int arm_smmu_attach_prepare(struct arm_smmu_attach_state *state,
+ struct iommu_domain *new_domain);
+void arm_smmu_attach_commit(struct arm_smmu_attach_state *state);
+void arm_smmu_install_ste_for_dev(struct arm_smmu_master *master,
+ const struct arm_smmu_ste *target);
+
+int arm_smmu_cmdq_issue_cmdlist(struct arm_smmu_device *smmu,
+ struct arm_smmu_cmdq *cmdq, u64 *cmds, int n,
+ bool sync);
+
#ifdef CONFIG_ARM_SMMU_V3_SVA
bool arm_smmu_sva_supported(struct arm_smmu_device *smmu);
bool arm_smmu_master_sva_supported(struct arm_smmu_master *master);
@@ -949,4 +1012,23 @@ tegra241_cmdqv_probe(struct arm_smmu_device *smmu)
return ERR_PTR(-ENODEV);
}
#endif /* CONFIG_TEGRA241_CMDQV */
+
+struct arm_vsmmu {
+ struct iommufd_viommu core;
+ struct arm_smmu_device *smmu;
+ struct arm_smmu_domain *s2_parent;
+ u16 vmid;
+};
+
+#if IS_ENABLED(CONFIG_ARM_SMMU_V3_IOMMUFD)
+void *arm_smmu_hw_info(struct device *dev, u32 *length, u32 *type);
+struct iommufd_viommu *arm_vsmmu_alloc(struct device *dev,
+ struct iommu_domain *parent,
+ struct iommufd_ctx *ictx,
+ unsigned int viommu_type);
+#else
+#define arm_smmu_hw_info NULL
+#define arm_vsmmu_alloc NULL
+#endif /* CONFIG_ARM_SMMU_V3_IOMMUFD */
+
#endif /* _ARM_SMMU_V3_H */
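
The STRTAB_STE_0/1_NESTING_ALLOWED masks above spell out exactly which vSTE bits a guest may control. As an illustration, a hypothetical helper (not part of the patch, reusing only macros from this header) that composes a guest STE accepted by arm_smmu_validate_vste():

    /*
     * Hypothetical example (not in the patch): compose a vSTE that
     * arm_smmu_validate_vste() accepts - S1 translate config, linear CD table
     * at a guest PA, full ATS. Only bits covered by the NESTING_ALLOWED masks
     * are set; the kernel strips and re-derives EATS when building the shadow
     * STE.
     */
    static void example_make_guest_cd_table_vste(__le64 vste[2],
                                                 u64 guest_cd_table_pa,
                                                 unsigned int log2_num_cds)
    {
            vste[0] = cpu_to_le64(
                    STRTAB_STE_0_V |
                    FIELD_PREP(STRTAB_STE_0_CFG, STRTAB_STE_0_CFG_S1_TRANS) |
                    FIELD_PREP(STRTAB_STE_0_S1FMT, STRTAB_STE_0_S1FMT_LINEAR) |
                    (guest_cd_table_pa & STRTAB_STE_0_S1CTXPTR_MASK) |
                    FIELD_PREP(STRTAB_STE_0_S1CDMAX, log2_num_cds));
            vste[1] = cpu_to_le64(
                    FIELD_PREP(STRTAB_STE_1_EATS, STRTAB_STE_1_EATS_TRANS));
    }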
diff --git a/drivers/iommu/arm/arm-smmu/arm-smmu.c b/drivers/iommu/arm/arm-smmu/arm-smmu.c
index 8321962b3714..12b173eec454 100644
--- a/drivers/iommu/arm/arm-smmu/arm-smmu.c
+++ b/drivers/iommu/arm/arm-smmu/arm-smmu.c
@@ -1558,21 +1558,6 @@ static struct iommu_group *arm_smmu_device_group(struct device *dev)
return group;
}
-static int arm_smmu_enable_nesting(struct iommu_domain *domain)
-{
- struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
- int ret = 0;
-
- mutex_lock(&smmu_domain->init_mutex);
- if (smmu_domain->smmu)
- ret = -EPERM;
- else
- smmu_domain->stage = ARM_SMMU_DOMAIN_NESTED;
- mutex_unlock(&smmu_domain->init_mutex);
-
- return ret;
-}
-
static int arm_smmu_set_pgtable_quirks(struct iommu_domain *domain,
unsigned long quirks)
{
@@ -1656,7 +1641,6 @@ static struct iommu_ops arm_smmu_ops = {
.flush_iotlb_all = arm_smmu_flush_iotlb_all,
.iotlb_sync = arm_smmu_iotlb_sync,
.iova_to_phys = arm_smmu_iova_to_phys,
- .enable_nesting = arm_smmu_enable_nesting,
.set_pgtable_quirks = arm_smmu_set_pgtable_quirks,
.free = arm_smmu_domain_free,
}
diff --git a/drivers/iommu/io-pgtable-arm.c b/drivers/iommu/io-pgtable-arm.c
index 0e67f1721a3d..74f58c6ac30c 100644
--- a/drivers/iommu/io-pgtable-arm.c
+++ b/drivers/iommu/io-pgtable-arm.c
@@ -106,6 +106,18 @@
#define ARM_LPAE_PTE_HAP_FAULT (((arm_lpae_iopte)0) << 6)
#define ARM_LPAE_PTE_HAP_READ (((arm_lpae_iopte)1) << 6)
#define ARM_LPAE_PTE_HAP_WRITE (((arm_lpae_iopte)2) << 6)
+/*
+ * For !FWB these encode:
+ * 1111 = Normal outer write back cacheable / Inner Write Back Cacheable
+ * Permit S1 to override
+ * 0101 = Normal Non-cacheable / Inner Non-cacheable
+ * 0001 = Device / Device-nGnRE
+ * For S2FWB these encode:
+ * 0110 Force Normal Write Back
+ * 0101 Normal* is forced Normal-NC, Device unchanged
+ * 0001 Force Device-nGnRE
+ */
+#define ARM_LPAE_PTE_MEMATTR_FWB_WB (((arm_lpae_iopte)0x6) << 2)
#define ARM_LPAE_PTE_MEMATTR_OIWB (((arm_lpae_iopte)0xf) << 2)
#define ARM_LPAE_PTE_MEMATTR_NC (((arm_lpae_iopte)0x5) << 2)
#define ARM_LPAE_PTE_MEMATTR_DEV (((arm_lpae_iopte)0x1) << 2)
@@ -458,12 +470,16 @@ static arm_lpae_iopte arm_lpae_prot_to_pte(struct arm_lpae_io_pgtable *data,
*/
if (data->iop.fmt == ARM_64_LPAE_S2 ||
data->iop.fmt == ARM_32_LPAE_S2) {
- if (prot & IOMMU_MMIO)
+ if (prot & IOMMU_MMIO) {
pte |= ARM_LPAE_PTE_MEMATTR_DEV;
- else if (prot & IOMMU_CACHE)
- pte |= ARM_LPAE_PTE_MEMATTR_OIWB;
- else
+ } else if (prot & IOMMU_CACHE) {
+ if (data->iop.cfg.quirks & IO_PGTABLE_QUIRK_ARM_S2FWB)
+ pte |= ARM_LPAE_PTE_MEMATTR_FWB_WB;
+ else
+ pte |= ARM_LPAE_PTE_MEMATTR_OIWB;
+ } else {
pte |= ARM_LPAE_PTE_MEMATTR_NC;
+ }
} else {
if (prot & IOMMU_MMIO)
pte |= (ARM_LPAE_MAIR_ATTR_IDX_DEV
@@ -1035,8 +1051,7 @@ arm_64_lpae_alloc_pgtable_s2(struct io_pgtable_cfg *cfg, void *cookie)
struct arm_lpae_io_pgtable *data;
typeof(&cfg->arm_lpae_s2_cfg.vtcr) vtcr = &cfg->arm_lpae_s2_cfg.vtcr;
- /* The NS quirk doesn't apply at stage 2 */
- if (cfg->quirks)
+ if (cfg->quirks & ~(IO_PGTABLE_QUIRK_ARM_S2FWB))
return NULL;
data = arm_lpae_alloc_pgtable(cfg);
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 83c8e617a2c5..dbd70d5a4702 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -2723,16 +2723,6 @@ static int __init iommu_init(void)
}
core_initcall(iommu_init);
-int iommu_enable_nesting(struct iommu_domain *domain)
-{
- if (domain->type != IOMMU_DOMAIN_UNMANAGED)
- return -EINVAL;
- if (!domain->ops->enable_nesting)
- return -EINVAL;
- return domain->ops->enable_nesting(domain);
-}
-EXPORT_SYMBOL_GPL(iommu_enable_nesting);
-
int iommu_set_pgtable_quirks(struct iommu_domain *domain,
unsigned long quirk)
{
diff --git a/drivers/iommu/iommufd/Kconfig b/drivers/iommu/iommufd/Kconfig
index 76656fe0470d..0a07f9449fd9 100644
--- a/drivers/iommu/iommufd/Kconfig
+++ b/drivers/iommu/iommufd/Kconfig
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: GPL-2.0-only
+config IOMMUFD_DRIVER_CORE
+ tristate
+ default (IOMMUFD_DRIVER || IOMMUFD) if IOMMUFD!=n
+
config IOMMUFD
tristate "IOMMU Userspace API"
select INTERVAL_TREE
diff --git a/drivers/iommu/iommufd/Makefile b/drivers/iommu/iommufd/Makefile
index cf4605962bea..cb784da6cddc 100644
--- a/drivers/iommu/iommufd/Makefile
+++ b/drivers/iommu/iommufd/Makefile
@@ -7,9 +7,13 @@ iommufd-y := \
ioas.o \
main.o \
pages.o \
- vfio_compat.o
+ vfio_compat.o \
+ viommu.o
iommufd-$(CONFIG_IOMMUFD_TEST) += selftest.o
obj-$(CONFIG_IOMMUFD) += iommufd.o
obj-$(CONFIG_IOMMUFD_DRIVER) += iova_bitmap.o
+
+iommufd_driver-y := driver.o
+obj-$(CONFIG_IOMMUFD_DRIVER_CORE) += iommufd_driver.o
diff --git a/drivers/iommu/iommufd/driver.c b/drivers/iommu/iommufd/driver.c
new file mode 100644
index 000000000000..7b67fdf44134
--- /dev/null
+++ b/drivers/iommu/iommufd/driver.c
@@ -0,0 +1,53 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES
+ */
+#include "iommufd_private.h"
+
+struct iommufd_object *_iommufd_object_alloc(struct iommufd_ctx *ictx,
+ size_t size,
+ enum iommufd_object_type type)
+{
+ struct iommufd_object *obj;
+ int rc;
+
+ obj = kzalloc(size, GFP_KERNEL_ACCOUNT);
+ if (!obj)
+ return ERR_PTR(-ENOMEM);
+ obj->type = type;
+ /* Starts out biased by 1 until it is removed from the xarray */
+ refcount_set(&obj->shortterm_users, 1);
+ refcount_set(&obj->users, 1);
+
+ /*
+ * Reserve an ID in the xarray but do not publish the pointer yet since
+ * the caller hasn't initialized it. Once the pointer is published
+ * in the xarray and visible to other threads we can't reliably destroy
+ * it anymore, so the caller must complete all errorable operations
+ * before calling iommufd_object_finalize().
+ */
+ rc = xa_alloc(&ictx->objects, &obj->id, XA_ZERO_ENTRY, xa_limit_31b,
+ GFP_KERNEL_ACCOUNT);
+ if (rc)
+ goto out_free;
+ return obj;
+out_free:
+ kfree(obj);
+ return ERR_PTR(rc);
+}
+EXPORT_SYMBOL_NS_GPL(_iommufd_object_alloc, IOMMUFD);
+
+/* Caller should xa_lock(&viommu->vdevs) to protect the return value */
+struct device *iommufd_viommu_find_dev(struct iommufd_viommu *viommu,
+ unsigned long vdev_id)
+{
+ struct iommufd_vdevice *vdev;
+
+ lockdep_assert_held(&viommu->vdevs.xa_lock);
+
+ vdev = xa_load(&viommu->vdevs, vdev_id);
+ return vdev ? vdev->dev : NULL;
+}
+EXPORT_SYMBOL_NS_GPL(iommufd_viommu_find_dev, IOMMUFD);
+
+MODULE_DESCRIPTION("iommufd code shared with builtin modules");
+MODULE_LICENSE("GPL");
diff --git a/drivers/iommu/iommufd/fault.c b/drivers/iommu/iommufd/fault.c
index e590973ce5cf..053b0e30f55a 100644
--- a/drivers/iommu/iommufd/fault.c
+++ b/drivers/iommu/iommufd/fault.c
@@ -10,6 +10,7 @@
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/pci.h>
+#include <linux/pci-ats.h>
#include <linux/poll.h>
#include <uapi/linux/iommufd.h>
@@ -27,8 +28,12 @@ static int iommufd_fault_iopf_enable(struct iommufd_device *idev)
* resource between PF and VFs. There is no coordination for this
* shared capability. This waits for a vPRI reset to recover.
*/
- if (dev_is_pci(dev) && to_pci_dev(dev)->is_virtfn)
- return -EINVAL;
+ if (dev_is_pci(dev)) {
+ struct pci_dev *pdev = to_pci_dev(dev);
+
+ if (pdev->is_virtfn && pci_pri_supported(pdev))
+ return -EINVAL;
+ }
mutex_lock(&idev->iopf_lock);
/* Device iopf has already been on. */
diff --git a/drivers/iommu/iommufd/hw_pagetable.c b/drivers/iommu/iommufd/hw_pagetable.c
index d06bf6e6c19f..702057655a81 100644
--- a/drivers/iommu/iommufd/hw_pagetable.c
+++ b/drivers/iommu/iommufd/hw_pagetable.c
@@ -57,7 +57,10 @@ void iommufd_hwpt_nested_destroy(struct iommufd_object *obj)
container_of(obj, struct iommufd_hwpt_nested, common.obj);
__iommufd_hwpt_destroy(&hwpt_nested->common);
- refcount_dec(&hwpt_nested->parent->common.obj.users);
+ if (hwpt_nested->viommu)
+ refcount_dec(&hwpt_nested->viommu->obj.users);
+ else
+ refcount_dec(&hwpt_nested->parent->common.obj.users);
}
void iommufd_hwpt_nested_abort(struct iommufd_object *obj)
@@ -248,8 +251,7 @@ iommufd_hwpt_nested_alloc(struct iommufd_ctx *ictx,
}
hwpt->domain->owner = ops;
- if (WARN_ON_ONCE(hwpt->domain->type != IOMMU_DOMAIN_NESTED ||
- !hwpt->domain->ops->cache_invalidate_user)) {
+ if (WARN_ON_ONCE(hwpt->domain->type != IOMMU_DOMAIN_NESTED)) {
rc = -EINVAL;
goto out_abort;
}
@@ -260,6 +262,58 @@ out_abort:
return ERR_PTR(rc);
}
+/**
+ * iommufd_viommu_alloc_hwpt_nested() - Get a hwpt_nested for a vIOMMU
+ * @viommu: vIOMMU object to associate the hwpt_nested/domain with
+ * @flags: Flags from userspace
+ * @user_data: user_data pointer. Must be valid
+ *
+ * Allocate a new IOMMU_DOMAIN_NESTED for a vIOMMU and return it as a NESTED
+ * hw_pagetable.
+ */
+static struct iommufd_hwpt_nested *
+iommufd_viommu_alloc_hwpt_nested(struct iommufd_viommu *viommu, u32 flags,
+ const struct iommu_user_data *user_data)
+{
+ struct iommufd_hwpt_nested *hwpt_nested;
+ struct iommufd_hw_pagetable *hwpt;
+ int rc;
+
+ if (!user_data->len)
+ return ERR_PTR(-EOPNOTSUPP);
+ if (!viommu->ops || !viommu->ops->alloc_domain_nested)
+ return ERR_PTR(-EOPNOTSUPP);
+
+ hwpt_nested = __iommufd_object_alloc(
+ viommu->ictx, hwpt_nested, IOMMUFD_OBJ_HWPT_NESTED, common.obj);
+ if (IS_ERR(hwpt_nested))
+ return ERR_CAST(hwpt_nested);
+ hwpt = &hwpt_nested->common;
+
+ hwpt_nested->viommu = viommu;
+ refcount_inc(&viommu->obj.users);
+ hwpt_nested->parent = viommu->hwpt;
+
+ hwpt->domain =
+ viommu->ops->alloc_domain_nested(viommu, flags, user_data);
+ if (IS_ERR(hwpt->domain)) {
+ rc = PTR_ERR(hwpt->domain);
+ hwpt->domain = NULL;
+ goto out_abort;
+ }
+ hwpt->domain->owner = viommu->iommu_dev->ops;
+
+ if (WARN_ON_ONCE(hwpt->domain->type != IOMMU_DOMAIN_NESTED)) {
+ rc = -EINVAL;
+ goto out_abort;
+ }
+ return hwpt_nested;
+
+out_abort:
+ iommufd_object_abort_and_destroy(viommu->ictx, &hwpt->obj);
+ return ERR_PTR(rc);
+}
+
int iommufd_hwpt_alloc(struct iommufd_ucmd *ucmd)
{
struct iommu_hwpt_alloc *cmd = ucmd->cmd;
@@ -316,6 +370,22 @@ int iommufd_hwpt_alloc(struct iommufd_ucmd *ucmd)
goto out_unlock;
}
hwpt = &hwpt_nested->common;
+ } else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) {
+ struct iommufd_hwpt_nested *hwpt_nested;
+ struct iommufd_viommu *viommu;
+
+ viommu = container_of(pt_obj, struct iommufd_viommu, obj);
+ if (viommu->iommu_dev != __iommu_get_iommu_dev(idev->dev)) {
+ rc = -EINVAL;
+ goto out_unlock;
+ }
+ hwpt_nested = iommufd_viommu_alloc_hwpt_nested(
+ viommu, cmd->flags, &user_data);
+ if (IS_ERR(hwpt_nested)) {
+ rc = PTR_ERR(hwpt_nested);
+ goto out_unlock;
+ }
+ hwpt = &hwpt_nested->common;
} else {
rc = -EINVAL;
goto out_put_pt;
@@ -412,7 +482,7 @@ int iommufd_hwpt_invalidate(struct iommufd_ucmd *ucmd)
.entry_len = cmd->entry_len,
.entry_num = cmd->entry_num,
};
- struct iommufd_hw_pagetable *hwpt;
+ struct iommufd_object *pt_obj;
u32 done_num = 0;
int rc;
@@ -426,17 +496,40 @@ int iommufd_hwpt_invalidate(struct iommufd_ucmd *ucmd)
goto out;
}
- hwpt = iommufd_get_hwpt_nested(ucmd, cmd->hwpt_id);
- if (IS_ERR(hwpt)) {
- rc = PTR_ERR(hwpt);
+ pt_obj = iommufd_get_object(ucmd->ictx, cmd->hwpt_id, IOMMUFD_OBJ_ANY);
+ if (IS_ERR(pt_obj)) {
+ rc = PTR_ERR(pt_obj);
goto out;
}
+ if (pt_obj->type == IOMMUFD_OBJ_HWPT_NESTED) {
+ struct iommufd_hw_pagetable *hwpt =
+ container_of(pt_obj, struct iommufd_hw_pagetable, obj);
+
+ if (!hwpt->domain->ops ||
+ !hwpt->domain->ops->cache_invalidate_user) {
+ rc = -EOPNOTSUPP;
+ goto out_put_pt;
+ }
+ rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain,
+ &data_array);
+ } else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) {
+ struct iommufd_viommu *viommu =
+ container_of(pt_obj, struct iommufd_viommu, obj);
+
+ if (!viommu->ops || !viommu->ops->cache_invalidate) {
+ rc = -EOPNOTSUPP;
+ goto out_put_pt;
+ }
+ rc = viommu->ops->cache_invalidate(viommu, &data_array);
+ } else {
+ rc = -EINVAL;
+ goto out_put_pt;
+ }
- rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain,
- &data_array);
done_num = data_array.entry_num;
- iommufd_put_object(ucmd->ictx, &hwpt->obj);
+out_put_pt:
+ iommufd_put_object(ucmd->ictx, pt_obj);
out:
cmd->entry_num = done_num;
if (iommufd_ucmd_respond(ucmd, sizeof(*cmd)))
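
Since IOMMU_HWPT_INVALIDATE now also accepts a vIOMMU object ID, a VMM can forward raw guest CMDQ entries in batches. A hypothetical userspace sketch follows (the helper name is invented; struct iommu_hwpt_invalidate is the existing iommufd uAPI and the SMMUv3 entry type comes from this series):

    #include <stdint.h>
    #include <string.h>
    #include <sys/ioctl.h>
    #include <linux/iommufd.h>

    /* Hypothetical helper; forwards raw guest CMDQ entries through a vIOMMU. */
    static int forward_guest_invalidations(int iommufd, __u32 viommu_id,
                                           const __u64 (*guest_cmds)[2], __u32 n)
    {
            struct iommu_viommu_arm_smmuv3_invalidate entries[64];
            struct iommu_hwpt_invalidate cmd = {
                    .size = sizeof(cmd),
                    .hwpt_id = viommu_id, /* a vIOMMU object ID works here now */
                    .data_uptr = (__u64)(uintptr_t)entries,
                    .data_type = IOMMU_VIOMMU_INVALIDATE_DATA_ARM_SMMUV3,
                    .entry_len = sizeof(entries[0]),
                    .entry_num = n,
            };

            if (n > 64)
                    return -1;
            /*
             * Each entry is the two le64 command words exactly as the guest
             * wrote them; the kernel rewrites VMID/SID in
             * arm_vsmmu_convert_user_cmd().
             */
            memcpy(entries, guest_cmds, n * sizeof(entries[0]));
            if (ioctl(iommufd, IOMMU_HWPT_INVALIDATE, &cmd))
                    return -1;
            return (int)cmd.entry_num; /* entries actually consumed */
    }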
diff --git a/drivers/iommu/iommufd/io_pagetable.c b/drivers/iommu/iommufd/io_pagetable.c
index 4bf7ccd39d46..8a790e597e12 100644
--- a/drivers/iommu/iommufd/io_pagetable.c
+++ b/drivers/iommu/iommufd/io_pagetable.c
@@ -107,9 +107,9 @@ static bool __alloc_iova_check_used(struct interval_tree_span_iter *span,
* Does not return a 0 IOVA even if it is valid.
*/
static int iopt_alloc_iova(struct io_pagetable *iopt, unsigned long *iova,
- unsigned long uptr, unsigned long length)
+ unsigned long addr, unsigned long length)
{
- unsigned long page_offset = uptr % PAGE_SIZE;
+ unsigned long page_offset = addr % PAGE_SIZE;
struct interval_tree_double_span_iter used_span;
struct interval_tree_span_iter allowed_span;
unsigned long max_alignment = PAGE_SIZE;
@@ -122,15 +122,15 @@ static int iopt_alloc_iova(struct io_pagetable *iopt, unsigned long *iova,
return -EOVERFLOW;
/*
- * Keep alignment present in the uptr when building the IOVA, this
+ * Keep alignment present in addr when building the IOVA, which
* increases the chance we can map a THP.
*/
- if (!uptr)
+ if (!addr)
iova_alignment = roundup_pow_of_two(length);
else
iova_alignment = min_t(unsigned long,
roundup_pow_of_two(length),
- 1UL << __ffs64(uptr));
+ 1UL << __ffs64(addr));
#ifdef CONFIG_TRANSPARENT_HUGEPAGE
max_alignment = HPAGE_SIZE;
@@ -248,6 +248,7 @@ static int iopt_alloc_area_pages(struct io_pagetable *iopt,
int iommu_prot, unsigned int flags)
{
struct iopt_pages_list *elm;
+ unsigned long start;
unsigned long iova;
int rc = 0;
@@ -267,9 +268,15 @@ static int iopt_alloc_area_pages(struct io_pagetable *iopt,
/* Use the first entry to guess the ideal IOVA alignment */
elm = list_first_entry(pages_list, struct iopt_pages_list,
next);
- rc = iopt_alloc_iova(
- iopt, dst_iova,
- (uintptr_t)elm->pages->uptr + elm->start_byte, length);
+ switch (elm->pages->type) {
+ case IOPT_ADDRESS_USER:
+ start = elm->start_byte + (uintptr_t)elm->pages->uptr;
+ break;
+ case IOPT_ADDRESS_FILE:
+ start = elm->start_byte + elm->pages->start;
+ break;
+ }
+ rc = iopt_alloc_iova(iopt, dst_iova, start, length);
if (rc)
goto out_unlock;
if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
@@ -384,6 +391,34 @@ out_unlock_domains:
return rc;
}
+static int iopt_map_common(struct iommufd_ctx *ictx, struct io_pagetable *iopt,
+ struct iopt_pages *pages, unsigned long *iova,
+ unsigned long length, unsigned long start_byte,
+ int iommu_prot, unsigned int flags)
+{
+ struct iopt_pages_list elm = {};
+ LIST_HEAD(pages_list);
+ int rc;
+
+ elm.pages = pages;
+ elm.start_byte = start_byte;
+ if (ictx->account_mode == IOPT_PAGES_ACCOUNT_MM &&
+ elm.pages->account_mode == IOPT_PAGES_ACCOUNT_USER)
+ elm.pages->account_mode = IOPT_PAGES_ACCOUNT_MM;
+ elm.length = length;
+ list_add(&elm.next, &pages_list);
+
+ rc = iopt_map_pages(iopt, &pages_list, length, iova, iommu_prot, flags);
+ if (rc) {
+ if (elm.area)
+ iopt_abort_area(elm.area);
+ if (elm.pages)
+ iopt_put_pages(elm.pages);
+ return rc;
+ }
+ return 0;
+}
+
/**
* iopt_map_user_pages() - Map a user VA to an iova in the io page table
* @ictx: iommufd_ctx the iopt is part of
@@ -408,29 +443,41 @@ int iopt_map_user_pages(struct iommufd_ctx *ictx, struct io_pagetable *iopt,
unsigned long length, int iommu_prot,
unsigned int flags)
{
- struct iopt_pages_list elm = {};
- LIST_HEAD(pages_list);
- int rc;
+ struct iopt_pages *pages;
- elm.pages = iopt_alloc_pages(uptr, length, iommu_prot & IOMMU_WRITE);
- if (IS_ERR(elm.pages))
- return PTR_ERR(elm.pages);
- if (ictx->account_mode == IOPT_PAGES_ACCOUNT_MM &&
- elm.pages->account_mode == IOPT_PAGES_ACCOUNT_USER)
- elm.pages->account_mode = IOPT_PAGES_ACCOUNT_MM;
- elm.start_byte = uptr - elm.pages->uptr;
- elm.length = length;
- list_add(&elm.next, &pages_list);
+ pages = iopt_alloc_user_pages(uptr, length, iommu_prot & IOMMU_WRITE);
+ if (IS_ERR(pages))
+ return PTR_ERR(pages);
- rc = iopt_map_pages(iopt, &pages_list, length, iova, iommu_prot, flags);
- if (rc) {
- if (elm.area)
- iopt_abort_area(elm.area);
- if (elm.pages)
- iopt_put_pages(elm.pages);
- return rc;
- }
- return 0;
+ return iopt_map_common(ictx, iopt, pages, iova, length,
+ uptr - pages->uptr, iommu_prot, flags);
+}
+
+/**
+ * iopt_map_file_pages() - Like iopt_map_user_pages, but map a file.
+ * @ictx: iommufd_ctx the iopt is part of
+ * @iopt: io_pagetable to act on
+ * @iova: If IOPT_ALLOC_IOVA is set this is unused on input and contains
+ * the chosen iova on output. Otherwise is the iova to map to on input
+ * @file: file to map
+ * @start: map file starting at this byte offset
+ * @length: Number of bytes to map
+ * @iommu_prot: Combination of IOMMU_READ/WRITE/etc bits for the mapping
+ * @flags: IOPT_ALLOC_IOVA or zero
+ */
+int iopt_map_file_pages(struct iommufd_ctx *ictx, struct io_pagetable *iopt,
+ unsigned long *iova, struct file *file,
+ unsigned long start, unsigned long length,
+ int iommu_prot, unsigned int flags)
+{
+ struct iopt_pages *pages;
+
+ pages = iopt_alloc_file_pages(file, start, length,
+ iommu_prot & IOMMU_WRITE);
+ if (IS_ERR(pages))
+ return PTR_ERR(pages);
+ return iopt_map_common(ictx, iopt, pages, iova, length,
+ start - pages->start, iommu_prot, flags);
}
struct iova_bitmap_fn_arg {
diff --git a/drivers/iommu/iommufd/io_pagetable.h b/drivers/iommu/iommufd/io_pagetable.h
index c61d74471684..10c928a9a463 100644
--- a/drivers/iommu/iommufd/io_pagetable.h
+++ b/drivers/iommu/iommufd/io_pagetable.h
@@ -173,6 +173,12 @@ enum {
IOPT_PAGES_ACCOUNT_NONE = 0,
IOPT_PAGES_ACCOUNT_USER = 1,
IOPT_PAGES_ACCOUNT_MM = 2,
+ IOPT_PAGES_ACCOUNT_MODE_NUM = 3,
+};
+
+enum iopt_address_type {
+ IOPT_ADDRESS_USER = 0,
+ IOPT_ADDRESS_FILE = 1,
};
/*
@@ -195,7 +201,14 @@ struct iopt_pages {
struct task_struct *source_task;
struct mm_struct *source_mm;
struct user_struct *source_user;
- void __user *uptr;
+ enum iopt_address_type type;
+ union {
+ void __user *uptr; /* IOPT_ADDRESS_USER */
+ struct { /* IOPT_ADDRESS_FILE */
+ struct file *file;
+ unsigned long start;
+ };
+ };
bool writable:1;
u8 account_mode;
@@ -206,8 +219,10 @@ struct iopt_pages {
struct rb_root_cached domains_itree;
};
-struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
- bool writable);
+struct iopt_pages *iopt_alloc_user_pages(void __user *uptr,
+ unsigned long length, bool writable);
+struct iopt_pages *iopt_alloc_file_pages(struct file *file, unsigned long start,
+ unsigned long length, bool writable);
void iopt_release_pages(struct kref *kref);
static inline void iopt_put_pages(struct iopt_pages *pages)
{
@@ -238,4 +253,9 @@ struct iopt_pages_access {
unsigned int users;
};
+struct pfn_reader_user;
+
+int iopt_pages_update_pinned(struct iopt_pages *pages, unsigned long npages,
+ bool inc, struct pfn_reader_user *user);
+
#endif
diff --git a/drivers/iommu/iommufd/ioas.c b/drivers/iommu/iommufd/ioas.c
index 2c4b2bb11e78..1542c5fd10a8 100644
--- a/drivers/iommu/iommufd/ioas.c
+++ b/drivers/iommu/iommufd/ioas.c
@@ -2,6 +2,7 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES
*/
+#include <linux/file.h>
#include <linux/interval_tree.h>
#include <linux/iommu.h>
#include <linux/iommufd.h>
@@ -51,7 +52,10 @@ int iommufd_ioas_alloc_ioctl(struct iommufd_ucmd *ucmd)
rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
if (rc)
goto out_table;
+
+ down_read(&ucmd->ictx->ioas_creation_lock);
iommufd_object_finalize(ucmd->ictx, &ioas->obj);
+ up_read(&ucmd->ictx->ioas_creation_lock);
return 0;
out_table:
@@ -197,6 +201,52 @@ static int conv_iommu_prot(u32 map_flags)
return iommu_prot;
}
+int iommufd_ioas_map_file(struct iommufd_ucmd *ucmd)
+{
+ struct iommu_ioas_map_file *cmd = ucmd->cmd;
+ unsigned long iova = cmd->iova;
+ struct iommufd_ioas *ioas;
+ unsigned int flags = 0;
+ struct file *file;
+ int rc;
+
+ if (cmd->flags &
+ ~(IOMMU_IOAS_MAP_FIXED_IOVA | IOMMU_IOAS_MAP_WRITEABLE |
+ IOMMU_IOAS_MAP_READABLE))
+ return -EOPNOTSUPP;
+
+ if (cmd->iova >= ULONG_MAX || cmd->length >= ULONG_MAX)
+ return -EOVERFLOW;
+
+ if (!(cmd->flags &
+ (IOMMU_IOAS_MAP_WRITEABLE | IOMMU_IOAS_MAP_READABLE)))
+ return -EINVAL;
+
+ ioas = iommufd_get_ioas(ucmd->ictx, cmd->ioas_id);
+ if (IS_ERR(ioas))
+ return PTR_ERR(ioas);
+
+ if (!(cmd->flags & IOMMU_IOAS_MAP_FIXED_IOVA))
+ flags = IOPT_ALLOC_IOVA;
+
+ file = fget(cmd->fd);
+ if (!file)
+ return -EBADF;
+
+ rc = iopt_map_file_pages(ucmd->ictx, &ioas->iopt, &iova, file,
+ cmd->start, cmd->length,
+ conv_iommu_prot(cmd->flags), flags);
+ if (rc)
+ goto out_put;
+
+ cmd->iova = iova;
+ rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
+out_put:
+ iommufd_put_object(ucmd->ictx, &ioas->obj);
+ fput(file);
+ return rc;
+}
+
int iommufd_ioas_map(struct iommufd_ucmd *ucmd)
{
struct iommu_ioas_map *cmd = ucmd->cmd;
@@ -327,6 +377,215 @@ out_put:
return rc;
}
+static void iommufd_release_all_iova_rwsem(struct iommufd_ctx *ictx,
+ struct xarray *ioas_list)
+{
+ struct iommufd_ioas *ioas;
+ unsigned long index;
+
+ xa_for_each(ioas_list, index, ioas) {
+ up_write(&ioas->iopt.iova_rwsem);
+ refcount_dec(&ioas->obj.users);
+ }
+ up_write(&ictx->ioas_creation_lock);
+ xa_destroy(ioas_list);
+}
+
+static int iommufd_take_all_iova_rwsem(struct iommufd_ctx *ictx,
+ struct xarray *ioas_list)
+{
+ struct iommufd_object *obj;
+ unsigned long index;
+ int rc;
+
+ /*
+ * This is very ugly: instead of adding a lock around pages->source_mm,
+ * which is a performance path for mdev, we just obtain the write side
+ * of all the iova_rwsems, which also protects the pages->source_*.
+ * Due to copies we can't know which IOAS could read from the pages,
+ * so we just lock everything. This is the only place locks are nested
+ * and they are uniformly taken in ID order.
+ *
+ * ioas_creation_lock prevents new IOAS from being installed in the
+ * xarray while we do this, and also prevents more than one thread from
+ * holding nested locks.
+ */
+ down_write(&ictx->ioas_creation_lock);
+ xa_lock(&ictx->objects);
+ xa_for_each(&ictx->objects, index, obj) {
+ struct iommufd_ioas *ioas;
+
+ if (!obj || obj->type != IOMMUFD_OBJ_IOAS)
+ continue;
+
+ if (!refcount_inc_not_zero(&obj->users))
+ continue;
+
+ xa_unlock(&ictx->objects);
+
+ ioas = container_of(obj, struct iommufd_ioas, obj);
+ down_write_nest_lock(&ioas->iopt.iova_rwsem,
+ &ictx->ioas_creation_lock);
+
+ rc = xa_err(xa_store(ioas_list, index, ioas, GFP_KERNEL));
+ if (rc) {
+ iommufd_release_all_iova_rwsem(ictx, ioas_list);
+ return rc;
+ }
+
+ xa_lock(&ictx->objects);
+ }
+ xa_unlock(&ictx->objects);
+ return 0;
+}
+
+static bool need_charge_update(struct iopt_pages *pages)
+{
+ switch (pages->account_mode) {
+ case IOPT_PAGES_ACCOUNT_NONE:
+ return false;
+ case IOPT_PAGES_ACCOUNT_MM:
+ return pages->source_mm != current->mm;
+ case IOPT_PAGES_ACCOUNT_USER:
+ /*
+ * Update when mm changes because it also accounts
+ * in mm->pinned_vm.
+ */
+ return (pages->source_user != current_user()) ||
+ (pages->source_mm != current->mm);
+ }
+ return true;
+}
+
+static int charge_current(unsigned long *npinned)
+{
+ struct iopt_pages tmp = {
+ .source_mm = current->mm,
+ .source_task = current->group_leader,
+ .source_user = current_user(),
+ };
+ unsigned int account_mode;
+ int rc;
+
+ for (account_mode = 0; account_mode != IOPT_PAGES_ACCOUNT_MODE_NUM;
+ account_mode++) {
+ if (!npinned[account_mode])
+ continue;
+
+ tmp.account_mode = account_mode;
+ rc = iopt_pages_update_pinned(&tmp, npinned[account_mode], true,
+ NULL);
+ if (rc)
+ goto err_undo;
+ }
+ return 0;
+
+err_undo:
+ while (account_mode != 0) {
+ account_mode--;
+ if (!npinned[account_mode])
+ continue;
+ tmp.account_mode = account_mode;
+ iopt_pages_update_pinned(&tmp, npinned[account_mode], false,
+ NULL);
+ }
+ return rc;
+}
+
+static void change_mm(struct iopt_pages *pages)
+{
+ struct task_struct *old_task = pages->source_task;
+ struct user_struct *old_user = pages->source_user;
+ struct mm_struct *old_mm = pages->source_mm;
+
+ pages->source_mm = current->mm;
+ mmgrab(pages->source_mm);
+ mmdrop(old_mm);
+
+ pages->source_task = current->group_leader;
+ get_task_struct(pages->source_task);
+ put_task_struct(old_task);
+
+ pages->source_user = get_uid(current_user());
+ free_uid(old_user);
+}
+
+#define for_each_ioas_area(_xa, _index, _ioas, _area) \
+ xa_for_each((_xa), (_index), (_ioas)) \
+ for (_area = iopt_area_iter_first(&_ioas->iopt, 0, ULONG_MAX); \
+ _area; \
+ _area = iopt_area_iter_next(_area, 0, ULONG_MAX))
+
+int iommufd_ioas_change_process(struct iommufd_ucmd *ucmd)
+{
+ struct iommu_ioas_change_process *cmd = ucmd->cmd;
+ struct iommufd_ctx *ictx = ucmd->ictx;
+ unsigned long all_npinned[IOPT_PAGES_ACCOUNT_MODE_NUM] = {};
+ struct iommufd_ioas *ioas;
+ struct iopt_area *area;
+ struct iopt_pages *pages;
+ struct xarray ioas_list;
+ unsigned long index;
+ int rc;
+
+ if (cmd->__reserved)
+ return -EOPNOTSUPP;
+
+ xa_init(&ioas_list);
+ rc = iommufd_take_all_iova_rwsem(ictx, &ioas_list);
+ if (rc)
+ return rc;
+
+ for_each_ioas_area(&ioas_list, index, ioas, area) {
+ if (area->pages->type != IOPT_ADDRESS_FILE) {
+ rc = -EINVAL;
+ goto out;
+ }
+ }
+
+ /*
+ * Count the last_npinned pages, then clear last_npinned to avoid double
+ * counting if the same iopt_pages is visited multiple times in this loop.
+ * Since we are under all the locks, npinned == last_npinned, so we can
+ * easily restore last_npinned before we return.
+ */
+ for_each_ioas_area(&ioas_list, index, ioas, area) {
+ pages = area->pages;
+
+ if (need_charge_update(pages)) {
+ all_npinned[pages->account_mode] += pages->last_npinned;
+ pages->last_npinned = 0;
+ }
+ }
+
+ rc = charge_current(all_npinned);
+
+ if (rc) {
+ /* Charge failed. Fix last_npinned and bail. */
+ for_each_ioas_area(&ioas_list, index, ioas, area)
+ area->pages->last_npinned = area->pages->npinned;
+ goto out;
+ }
+
+ for_each_ioas_area(&ioas_list, index, ioas, area) {
+ pages = area->pages;
+
+ /* Uncharge the old one (which also restores last_npinned) */
+ if (need_charge_update(pages)) {
+ int r = iopt_pages_update_pinned(pages, pages->npinned,
+ false, NULL);
+
+ if (WARN_ON(r))
+ rc = r;
+ }
+ change_mm(pages);
+ }
+
+out:
+ iommufd_release_all_iova_rwsem(ictx, &ioas_list);
+ return rc;
+}
+
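A usage sketch, not part of the patch: IOMMU_IOAS_CHANGE_PROCESS is meant to be issued by the process that takes over an existing iommufd (for instance after receiving the fd over a unix socket), so the pinned-page accounting moves to the caller's mm and user. The helper name is hypothetical, and the struct is assumed to carry only the usual leading size member plus the __reserved field checked above.

#include <linux/iommufd.h>
#include <sys/ioctl.h>

/* Re-account all pinned pages to the calling process */
static int iommufd_change_process(int iommufd)
{
	struct iommu_ioas_change_process cmd = {
		.size = sizeof(cmd),
		/* __reserved must stay zero or the kernel returns -EOPNOTSUPP */
	};

	/*
	 * The handler requires every mapped area in every IOAS to be
	 * file-backed (IOPT_ADDRESS_FILE); otherwise it fails with -EINVAL.
	 */
	return ioctl(iommufd, IOMMU_IOAS_CHANGE_PROCESS, &cmd);
}
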
int iommufd_option_rlimit_mode(struct iommu_option *cmd,
struct iommufd_ctx *ictx)
{
diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h
index f1d865e6fab6..b6d706cf2c66 100644
--- a/drivers/iommu/iommufd/iommufd_private.h
+++ b/drivers/iommu/iommufd/iommufd_private.h
@@ -5,8 +5,8 @@
#define __IOMMUFD_PRIVATE_H
#include <linux/iommu.h>
+#include <linux/iommufd.h>
#include <linux/iova_bitmap.h>
-#include <linux/refcount.h>
#include <linux/rwsem.h>
#include <linux/uaccess.h>
#include <linux/xarray.h>
@@ -24,6 +24,7 @@ struct iommufd_ctx {
struct xarray objects;
struct xarray groups;
wait_queue_head_t destroy_wait;
+ struct rw_semaphore ioas_creation_lock;
u8 account_mode;
/* Compatibility with VFIO no iommu */
@@ -69,6 +70,10 @@ int iopt_map_user_pages(struct iommufd_ctx *ictx, struct io_pagetable *iopt,
unsigned long *iova, void __user *uptr,
unsigned long length, int iommu_prot,
unsigned int flags);
+int iopt_map_file_pages(struct iommufd_ctx *ictx, struct io_pagetable *iopt,
+ unsigned long *iova, struct file *file,
+ unsigned long start, unsigned long length,
+ int iommu_prot, unsigned int flags);
int iopt_map_pages(struct io_pagetable *iopt, struct list_head *pages_list,
unsigned long length, unsigned long *dst_iova,
int iommu_prot, unsigned int flags);
@@ -122,29 +127,6 @@ static inline int iommufd_ucmd_respond(struct iommufd_ucmd *ucmd,
return 0;
}
-enum iommufd_object_type {
- IOMMUFD_OBJ_NONE,
- IOMMUFD_OBJ_ANY = IOMMUFD_OBJ_NONE,
- IOMMUFD_OBJ_DEVICE,
- IOMMUFD_OBJ_HWPT_PAGING,
- IOMMUFD_OBJ_HWPT_NESTED,
- IOMMUFD_OBJ_IOAS,
- IOMMUFD_OBJ_ACCESS,
- IOMMUFD_OBJ_FAULT,
-#ifdef CONFIG_IOMMUFD_TEST
- IOMMUFD_OBJ_SELFTEST,
-#endif
- IOMMUFD_OBJ_MAX,
-};
-
-/* Base struct for all objects with a userspace ID handle. */
-struct iommufd_object {
- refcount_t shortterm_users;
- refcount_t users;
- enum iommufd_object_type type;
- unsigned int id;
-};
-
static inline bool iommufd_lock_obj(struct iommufd_object *obj)
{
if (!refcount_inc_not_zero(&obj->users))
@@ -225,10 +207,6 @@ iommufd_object_put_and_try_destroy(struct iommufd_ctx *ictx,
iommufd_object_remove(ictx, obj, obj->id, 0);
}
-struct iommufd_object *_iommufd_object_alloc(struct iommufd_ctx *ictx,
- size_t size,
- enum iommufd_object_type type);
-
#define __iommufd_object_alloc(ictx, ptr, type, obj) \
container_of(_iommufd_object_alloc( \
ictx, \
@@ -276,6 +254,8 @@ void iommufd_ioas_destroy(struct iommufd_object *obj);
int iommufd_ioas_iova_ranges(struct iommufd_ucmd *ucmd);
int iommufd_ioas_allow_iovas(struct iommufd_ucmd *ucmd);
int iommufd_ioas_map(struct iommufd_ucmd *ucmd);
+int iommufd_ioas_map_file(struct iommufd_ucmd *ucmd);
+int iommufd_ioas_change_process(struct iommufd_ucmd *ucmd);
int iommufd_ioas_copy(struct iommufd_ucmd *ucmd);
int iommufd_ioas_unmap(struct iommufd_ucmd *ucmd);
int iommufd_ioas_option(struct iommufd_ucmd *ucmd);
@@ -312,6 +292,7 @@ struct iommufd_hwpt_paging {
struct iommufd_hwpt_nested {
struct iommufd_hw_pagetable common;
struct iommufd_hwpt_paging *parent;
+ struct iommufd_viommu *viommu;
};
static inline bool hwpt_is_paging(struct iommufd_hw_pagetable *hwpt)
@@ -528,6 +509,27 @@ static inline int iommufd_hwpt_replace_device(struct iommufd_device *idev,
return iommu_group_replace_domain(idev->igroup->group, hwpt->domain);
}
+static inline struct iommufd_viommu *
+iommufd_get_viommu(struct iommufd_ucmd *ucmd, u32 id)
+{
+ return container_of(iommufd_get_object(ucmd->ictx, id,
+ IOMMUFD_OBJ_VIOMMU),
+ struct iommufd_viommu, obj);
+}
+
+int iommufd_viommu_alloc_ioctl(struct iommufd_ucmd *ucmd);
+void iommufd_viommu_destroy(struct iommufd_object *obj);
+int iommufd_vdevice_alloc_ioctl(struct iommufd_ucmd *ucmd);
+void iommufd_vdevice_destroy(struct iommufd_object *obj);
+
+struct iommufd_vdevice {
+ struct iommufd_object obj;
+ struct iommufd_ctx *ictx;
+ struct iommufd_viommu *viommu;
+ struct device *dev;
+ u64 id; /* per-vIOMMU virtual ID */
+};
+
#ifdef CONFIG_IOMMUFD_TEST
int iommufd_test(struct iommufd_ucmd *ucmd);
void iommufd_selftest_destroy(struct iommufd_object *obj);
diff --git a/drivers/iommu/iommufd/iommufd_test.h b/drivers/iommu/iommufd/iommufd_test.h
index f4bc23a92f9a..a6b7a163f636 100644
--- a/drivers/iommu/iommufd/iommufd_test.h
+++ b/drivers/iommu/iommufd/iommufd_test.h
@@ -23,6 +23,7 @@ enum {
IOMMU_TEST_OP_DIRTY,
IOMMU_TEST_OP_MD_CHECK_IOTLB,
IOMMU_TEST_OP_TRIGGER_IOPF,
+ IOMMU_TEST_OP_DEV_CHECK_CACHE,
};
enum {
@@ -54,6 +55,11 @@ enum {
MOCK_NESTED_DOMAIN_IOTLB_NUM = 4,
};
+enum {
+ MOCK_DEV_CACHE_ID_MAX = 3,
+ MOCK_DEV_CACHE_NUM = 4,
+};
+
struct iommu_test_cmd {
__u32 size;
__u32 op;
@@ -135,6 +141,10 @@ struct iommu_test_cmd {
__u32 perm;
__u64 addr;
} trigger_iopf;
+ struct {
+ __u32 id;
+ __u32 cache;
+ } check_dev_cache;
};
__u32 last;
};
@@ -152,6 +162,7 @@ struct iommu_test_hw_info {
/* Should not be equal to any defined value in enum iommu_hwpt_data_type */
#define IOMMU_HWPT_DATA_SELFTEST 0xdead
#define IOMMU_TEST_IOTLB_DEFAULT 0xbadbeef
+#define IOMMU_TEST_DEV_CACHE_DEFAULT 0xbaddad
/**
* struct iommu_hwpt_selftest
@@ -180,4 +191,25 @@ struct iommu_hwpt_invalidate_selftest {
__u32 iotlb_id;
};
+#define IOMMU_VIOMMU_TYPE_SELFTEST 0xdeadbeef
+
+/* Should not be equal to any defined value in enum iommu_viommu_invalidate_data_type */
+#define IOMMU_VIOMMU_INVALIDATE_DATA_SELFTEST 0xdeadbeef
+#define IOMMU_VIOMMU_INVALIDATE_DATA_SELFTEST_INVALID 0xdadbeef
+
+/**
+ * struct iommu_viommu_invalidate_selftest - Invalidation data for Mock VIOMMU
+ * (IOMMU_VIOMMU_INVALIDATE_DATA_SELFTEST)
+ * @flags: Invalidate flags
+ * @vdev_id: Virtual device ID of the device whose cache is invalidated
+ * @cache_id: Invalidate cache entry index
+ *
+ * If IOMMU_TEST_INVALIDATE_FLAG_ALL is set in @flags, @cache_id will be ignored
+ */
+struct iommu_viommu_invalidate_selftest {
+#define IOMMU_TEST_INVALIDATE_FLAG_ALL (1 << 0)
+ __u32 flags;
+ __u32 vdev_id;
+ __u32 cache_id;
+};
+
#endif
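A sketch, not part of the patch, of how a test might build one of these entries before handing an array of them to the kernel for mock_viommu_cache_invalidate() in selftest.c (further below) to consume; the helper name and the concrete values are illustrative only.

#include "iommufd_test.h"

/* Build one mock invalidation entry: drop cache slot 2 of virtual device 0x99 */
static inline struct iommu_viommu_invalidate_selftest mock_inv_entry(void)
{
	return (struct iommu_viommu_invalidate_selftest){
		.flags = 0,		/* or IOMMU_TEST_INVALIDATE_FLAG_ALL */
		.vdev_id = 0x99,	/* the virt_id given at vDEVICE creation */
		.cache_id = 2,		/* must not exceed MOCK_DEV_CACHE_ID_MAX */
	};
}
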
diff --git a/drivers/iommu/iommufd/main.c b/drivers/iommu/iommufd/main.c
index b5f5d27ee963..0a96cc8f27da 100644
--- a/drivers/iommu/iommufd/main.c
+++ b/drivers/iommu/iommufd/main.c
@@ -29,38 +29,6 @@ struct iommufd_object_ops {
static const struct iommufd_object_ops iommufd_object_ops[];
static struct miscdevice vfio_misc_dev;
-struct iommufd_object *_iommufd_object_alloc(struct iommufd_ctx *ictx,
- size_t size,
- enum iommufd_object_type type)
-{
- struct iommufd_object *obj;
- int rc;
-
- obj = kzalloc(size, GFP_KERNEL_ACCOUNT);
- if (!obj)
- return ERR_PTR(-ENOMEM);
- obj->type = type;
- /* Starts out bias'd by 1 until it is removed from the xarray */
- refcount_set(&obj->shortterm_users, 1);
- refcount_set(&obj->users, 1);
-
- /*
- * Reserve an ID in the xarray but do not publish the pointer yet since
- * the caller hasn't initialized it yet. Once the pointer is published
- * in the xarray and visible to other threads we can't reliably destroy
- * it anymore, so the caller must complete all errorable operations
- * before calling iommufd_object_finalize().
- */
- rc = xa_alloc(&ictx->objects, &obj->id, XA_ZERO_ENTRY,
- xa_limit_31b, GFP_KERNEL_ACCOUNT);
- if (rc)
- goto out_free;
- return obj;
-out_free:
- kfree(obj);
- return ERR_PTR(rc);
-}
-
/*
* Allow concurrent access to the object.
*
@@ -73,20 +41,26 @@ out_free:
void iommufd_object_finalize(struct iommufd_ctx *ictx,
struct iommufd_object *obj)
{
+ XA_STATE(xas, &ictx->objects, obj->id);
void *old;
- old = xa_store(&ictx->objects, obj->id, obj, GFP_KERNEL);
- /* obj->id was returned from xa_alloc() so the xa_store() cannot fail */
- WARN_ON(old);
+ xa_lock(&ictx->objects);
+ old = xas_store(&xas, obj);
+ xa_unlock(&ictx->objects);
+ /* obj->id was returned from xa_alloc() so the xas_store() cannot fail */
+ WARN_ON(old != XA_ZERO_ENTRY);
}
/* Undo _iommufd_object_alloc() if iommufd_object_finalize() was not called */
void iommufd_object_abort(struct iommufd_ctx *ictx, struct iommufd_object *obj)
{
+ XA_STATE(xas, &ictx->objects, obj->id);
void *old;
- old = xa_erase(&ictx->objects, obj->id);
- WARN_ON(old);
+ xa_lock(&ictx->objects);
+ old = xas_store(&xas, NULL);
+ xa_unlock(&ictx->objects);
+ WARN_ON(old != XA_ZERO_ENTRY);
kfree(obj);
}
@@ -248,6 +222,7 @@ static int iommufd_fops_open(struct inode *inode, struct file *filp)
pr_info_once("IOMMUFD is providing /dev/vfio/vfio, not VFIO.\n");
}
+ init_rwsem(&ictx->ioas_creation_lock);
xa_init_flags(&ictx->objects, XA_FLAGS_ALLOC1 | XA_FLAGS_ACCOUNT);
xa_init(&ictx->groups);
ictx->file = filp;
@@ -333,6 +308,8 @@ union ucmd_buffer {
struct iommu_ioas_unmap unmap;
struct iommu_option option;
struct iommu_vfio_ioas vfio_ioas;
+ struct iommu_viommu_alloc viommu;
+ struct iommu_vdevice_alloc vdev;
#ifdef CONFIG_IOMMUFD_TEST
struct iommu_test_cmd test;
#endif
@@ -372,18 +349,26 @@ static const struct iommufd_ioctl_op iommufd_ioctl_ops[] = {
struct iommu_ioas_alloc, out_ioas_id),
IOCTL_OP(IOMMU_IOAS_ALLOW_IOVAS, iommufd_ioas_allow_iovas,
struct iommu_ioas_allow_iovas, allowed_iovas),
+ IOCTL_OP(IOMMU_IOAS_CHANGE_PROCESS, iommufd_ioas_change_process,
+ struct iommu_ioas_change_process, __reserved),
IOCTL_OP(IOMMU_IOAS_COPY, iommufd_ioas_copy, struct iommu_ioas_copy,
src_iova),
IOCTL_OP(IOMMU_IOAS_IOVA_RANGES, iommufd_ioas_iova_ranges,
struct iommu_ioas_iova_ranges, out_iova_alignment),
IOCTL_OP(IOMMU_IOAS_MAP, iommufd_ioas_map, struct iommu_ioas_map,
iova),
+ IOCTL_OP(IOMMU_IOAS_MAP_FILE, iommufd_ioas_map_file,
+ struct iommu_ioas_map_file, iova),
IOCTL_OP(IOMMU_IOAS_UNMAP, iommufd_ioas_unmap, struct iommu_ioas_unmap,
length),
IOCTL_OP(IOMMU_OPTION, iommufd_option, struct iommu_option,
val64),
IOCTL_OP(IOMMU_VFIO_IOAS, iommufd_vfio_ioas, struct iommu_vfio_ioas,
__reserved),
+ IOCTL_OP(IOMMU_VIOMMU_ALLOC, iommufd_viommu_alloc_ioctl,
+ struct iommu_viommu_alloc, out_viommu_id),
+ IOCTL_OP(IOMMU_VDEVICE_ALLOC, iommufd_vdevice_alloc_ioctl,
+ struct iommu_vdevice_alloc, virt_id),
#ifdef CONFIG_IOMMUFD_TEST
IOCTL_OP(IOMMU_TEST_CMD, iommufd_test, struct iommu_test_cmd, last),
#endif
@@ -519,6 +504,12 @@ static const struct iommufd_object_ops iommufd_object_ops[] = {
[IOMMUFD_OBJ_FAULT] = {
.destroy = iommufd_fault_destroy,
},
+ [IOMMUFD_OBJ_VIOMMU] = {
+ .destroy = iommufd_viommu_destroy,
+ },
+ [IOMMUFD_OBJ_VDEVICE] = {
+ .destroy = iommufd_vdevice_destroy,
+ },
#ifdef CONFIG_IOMMUFD_TEST
[IOMMUFD_OBJ_SELFTEST] = {
.destroy = iommufd_selftest_destroy,
diff --git a/drivers/iommu/iommufd/pages.c b/drivers/iommu/iommufd/pages.c
index 93d806c9c073..3427749bc5ce 100644
--- a/drivers/iommu/iommufd/pages.c
+++ b/drivers/iommu/iommufd/pages.c
@@ -45,6 +45,7 @@
* last_iova + 1 can overflow. An iopt_pages index will always be much less than
* ULONG_MAX so last_index + 1 cannot overflow.
*/
+#include <linux/file.h>
#include <linux/highmem.h>
#include <linux/iommu.h>
#include <linux/iommufd.h>
@@ -346,27 +347,41 @@ static void batch_destroy(struct pfn_batch *batch, void *backup)
kfree(batch->pfns);
}
-/* true if the pfn was added, false otherwise */
-static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
+static bool batch_add_pfn_num(struct pfn_batch *batch, unsigned long pfn,
+ u32 nr)
{
const unsigned int MAX_NPFNS = type_max(typeof(*batch->npfns));
-
- if (batch->end &&
- pfn == batch->pfns[batch->end - 1] + batch->npfns[batch->end - 1] &&
- batch->npfns[batch->end - 1] != MAX_NPFNS) {
- batch->npfns[batch->end - 1]++;
- batch->total_pfns++;
- return true;
- }
- if (batch->end == batch->array_size)
+ unsigned int end = batch->end;
+
+ if (end && pfn == batch->pfns[end - 1] + batch->npfns[end - 1] &&
+ nr <= MAX_NPFNS - batch->npfns[end - 1]) {
+ batch->npfns[end - 1] += nr;
+ } else if (end < batch->array_size) {
+ batch->pfns[end] = pfn;
+ batch->npfns[end] = nr;
+ batch->end++;
+ } else {
return false;
- batch->total_pfns++;
- batch->pfns[batch->end] = pfn;
- batch->npfns[batch->end] = 1;
- batch->end++;
+ }
+
+ batch->total_pfns += nr;
return true;
}
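+/* Undo the most recent batch_add_pfn_num(), dropping @nr pfns from the tail */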
+static void batch_remove_pfn_num(struct pfn_batch *batch, unsigned long nr)
+{
+ batch->npfns[batch->end - 1] -= nr;
+ if (batch->npfns[batch->end - 1] == 0)
+ batch->end--;
+ batch->total_pfns -= nr;
+}
+
+/* true if the pfn was added, false otherwise */
+static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
+{
+ return batch_add_pfn_num(batch, pfn, 1);
+}
+
/*
* Fill the batch with pfns from the domain. When the batch is full, or it
* reaches last_index, the function will return. The caller should use
@@ -622,6 +637,41 @@ static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
break;
}
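+/*
+ * Fill the batch with pfns taken from an array of pinned folios, adding
+ * extra pins with folio_add_pins() for the pages placed in the batch. On
+ * return *folios_p and *offset_p point at the first page not yet consumed,
+ * so the caller can resume from that position.
+ */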
+static int batch_from_folios(struct pfn_batch *batch, struct folio ***folios_p,
+ unsigned long *offset_p, unsigned long npages)
+{
+ int rc = 0;
+ struct folio **folios = *folios_p;
+ unsigned long offset = *offset_p;
+
+ while (npages) {
+ struct folio *folio = *folios;
+ unsigned long nr = folio_nr_pages(folio) - offset;
+ unsigned long pfn = page_to_pfn(folio_page(folio, offset));
+
+ nr = min(nr, npages);
+ npages -= nr;
+
+ if (!batch_add_pfn_num(batch, pfn, nr))
+ break;
+ if (nr > 1) {
+ rc = folio_add_pins(folio, nr - 1);
+ if (rc) {
+ batch_remove_pfn_num(batch, nr);
+ goto out;
+ }
+ }
+
+ folios++;
+ offset = 0;
+ }
+
+out:
+ *folios_p = folios;
+ *offset_p = offset;
+ return rc;
+}
+
static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
unsigned int first_page_off, size_t npages)
{
@@ -703,19 +753,32 @@ struct pfn_reader_user {
* neither
*/
int locked;
+
+ /* The following are only valid if file != NULL. */
+ struct file *file;
+ struct folio **ufolios;
+ size_t ufolios_len;
+ unsigned long ufolios_offset;
+ struct folio **ufolios_next;
};
static void pfn_reader_user_init(struct pfn_reader_user *user,
struct iopt_pages *pages)
{
user->upages = NULL;
+ user->upages_len = 0;
user->upages_start = 0;
user->upages_end = 0;
user->locked = -1;
-
user->gup_flags = FOLL_LONGTERM;
if (pages->writable)
user->gup_flags |= FOLL_WRITE;
+
+ user->file = (pages->type == IOPT_ADDRESS_FILE) ? pages->file : NULL;
+ user->ufolios = NULL;
+ user->ufolios_len = 0;
+ user->ufolios_next = NULL;
+ user->ufolios_offset = 0;
}
static void pfn_reader_user_destroy(struct pfn_reader_user *user,
@@ -724,13 +787,67 @@ static void pfn_reader_user_destroy(struct pfn_reader_user *user,
if (user->locked != -1) {
if (user->locked)
mmap_read_unlock(pages->source_mm);
- if (pages->source_mm != current->mm)
+ if (!user->file && pages->source_mm != current->mm)
mmput(pages->source_mm);
user->locked = -1;
}
kfree(user->upages);
user->upages = NULL;
+ kfree(user->ufolios);
+ user->ufolios = NULL;
+}
+
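+/*
+ * Pin up to @npages pages of user->file starting at byte offset @start using
+ * memfd_pin_folios(). Returns the number of pages covered by the pinned
+ * folios (at most @npages) or a negative errno. When user->upages is set,
+ * each page also gets its own pin and is recorded in that array.
+ */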
+static long pin_memfd_pages(struct pfn_reader_user *user, unsigned long start,
+ unsigned long npages)
+{
+ unsigned long i;
+ unsigned long offset;
+ unsigned long npages_out = 0;
+ struct page **upages = user->upages;
+ unsigned long end = start + (npages << PAGE_SHIFT) - 1;
+ long nfolios = user->ufolios_len / sizeof(*user->ufolios);
+
+ /*
+ * TODO: memfd_pin_folios() should return the last pinned offset so we
+ * can compute the number of pages pinned and avoid looping over the
+ * folios here when upages == NULL.
+ */
+ nfolios = memfd_pin_folios(user->file, start, end, user->ufolios,
+ nfolios, &offset);
+ if (nfolios <= 0)
+ return nfolios;
+
+ offset >>= PAGE_SHIFT;
+ user->ufolios_next = user->ufolios;
+ user->ufolios_offset = offset;
+
+ for (i = 0; i < nfolios; i++) {
+ struct folio *folio = user->ufolios[i];
+ unsigned long nr = folio_nr_pages(folio);
+ unsigned long npin = min(nr - offset, npages);
+
+ npages -= npin;
+ npages_out += npin;
+
+ if (upages) {
+ if (npin == 1) {
+ *upages++ = folio_page(folio, offset);
+ } else {
+ int rc = folio_add_pins(folio, npin - 1);
+
+ if (rc)
+ return rc;
+
+ while (npin--)
+ *upages++ = folio_page(folio, offset++);
+ }
+ }
+
+ offset = 0;
+ }
+
+ return npages_out;
}
static int pfn_reader_user_pin(struct pfn_reader_user *user,
@@ -739,7 +856,9 @@ static int pfn_reader_user_pin(struct pfn_reader_user *user,
unsigned long last_index)
{
bool remote_mm = pages->source_mm != current->mm;
- unsigned long npages;
+ unsigned long npages = last_index - start_index + 1;
+ unsigned long start;
+ unsigned long unum;
uintptr_t uptr;
long rc;
@@ -747,40 +866,50 @@ static int pfn_reader_user_pin(struct pfn_reader_user *user,
WARN_ON(last_index < start_index))
return -EINVAL;
- if (!user->upages) {
+ if (!user->file && !user->upages) {
/* All undone in pfn_reader_destroy() */
- user->upages_len =
- (last_index - start_index + 1) * sizeof(*user->upages);
+ user->upages_len = npages * sizeof(*user->upages);
user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
if (!user->upages)
return -ENOMEM;
}
+ if (user->file && !user->ufolios) {
+ user->ufolios_len = npages * sizeof(*user->ufolios);
+ user->ufolios = temp_kmalloc(&user->ufolios_len, NULL, 0);
+ if (!user->ufolios)
+ return -ENOMEM;
+ }
+
if (user->locked == -1) {
/*
* The majority of usages will run the map task within the mm
* providing the pages, so we can optimize into
* get_user_pages_fast()
*/
- if (remote_mm) {
+ if (!user->file && remote_mm) {
if (!mmget_not_zero(pages->source_mm))
return -EFAULT;
}
user->locked = 0;
}
- npages = min_t(unsigned long, last_index - start_index + 1,
- user->upages_len / sizeof(*user->upages));
-
+ unum = user->file ? user->ufolios_len / sizeof(*user->ufolios) :
+ user->upages_len / sizeof(*user->upages);
+ npages = min_t(unsigned long, npages, unum);
if (iommufd_should_fail())
return -EFAULT;
- uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
- if (!remote_mm)
+ if (user->file) {
+ start = pages->start + (start_index * PAGE_SIZE);
+ rc = pin_memfd_pages(user, start, npages);
+ } else if (!remote_mm) {
+ uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
user->upages);
- else {
+ } else {
+ uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
if (!user->locked) {
mmap_read_lock(pages->source_mm);
user->locked = 1;
@@ -838,7 +967,8 @@ static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
mmap_read_unlock(pages->source_mm);
user->locked = 0;
/* If we had the lock then we also have a get */
- } else if ((!user || !user->upages) &&
+
+ } else if ((!user || (!user->upages && !user->ufolios)) &&
pages->source_mm != current->mm) {
if (!mmget_not_zero(pages->source_mm))
return -EINVAL;
@@ -855,8 +985,8 @@ static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
return rc;
}
-static int do_update_pinned(struct iopt_pages *pages, unsigned long npages,
- bool inc, struct pfn_reader_user *user)
+int iopt_pages_update_pinned(struct iopt_pages *pages, unsigned long npages,
+ bool inc, struct pfn_reader_user *user)
{
int rc = 0;
@@ -890,8 +1020,8 @@ static void update_unpinned(struct iopt_pages *pages)
return;
if (pages->npinned == pages->last_npinned)
return;
- do_update_pinned(pages, pages->last_npinned - pages->npinned, false,
- NULL);
+ iopt_pages_update_pinned(pages, pages->last_npinned - pages->npinned,
+ false, NULL);
}
/*
@@ -921,7 +1051,7 @@ static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
npages = pages->npinned - pages->last_npinned;
inc = true;
}
- return do_update_pinned(pages, npages, inc, user);
+ return iopt_pages_update_pinned(pages, npages, inc, user);
}
/*
@@ -978,6 +1108,8 @@ static int pfn_reader_fill_span(struct pfn_reader *pfns)
{
struct interval_tree_double_span_iter *span = &pfns->span;
unsigned long start_index = pfns->batch_end_index;
+ struct pfn_reader_user *user = &pfns->user;
+ unsigned long npages;
struct iopt_area *area;
int rc;
@@ -1015,11 +1147,17 @@ static int pfn_reader_fill_span(struct pfn_reader *pfns)
return rc;
}
- batch_from_pages(&pfns->batch,
- pfns->user.upages +
- (start_index - pfns->user.upages_start),
- pfns->user.upages_end - start_index);
- return 0;
+ npages = user->upages_end - start_index;
+ start_index -= user->upages_start;
+ rc = 0;
+
+ if (!user->file)
+ batch_from_pages(&pfns->batch, user->upages + start_index,
+ npages);
+ else
+ rc = batch_from_folios(&pfns->batch, &user->ufolios_next,
+ &user->ufolios_offset, npages);
+ return rc;
}
static bool pfn_reader_done(struct pfn_reader *pfns)
@@ -1092,16 +1230,25 @@ static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
static void pfn_reader_release_pins(struct pfn_reader *pfns)
{
struct iopt_pages *pages = pfns->pages;
+ struct pfn_reader_user *user = &pfns->user;
- if (pfns->user.upages_end > pfns->batch_end_index) {
- size_t npages = pfns->user.upages_end - pfns->batch_end_index;
-
+ if (user->upages_end > pfns->batch_end_index) {
/* Any pages not transferred to the batch are just unpinned */
- unpin_user_pages(pfns->user.upages + (pfns->batch_end_index -
- pfns->user.upages_start),
- npages);
+
+ unsigned long npages = user->upages_end - pfns->batch_end_index;
+ unsigned long start_index = pfns->batch_end_index -
+ user->upages_start;
+
+ if (!user->file) {
+ unpin_user_pages(user->upages + start_index, npages);
+ } else {
+ long n = user->ufolios_len / sizeof(*user->ufolios);
+
+ unpin_folios(user->ufolios_next,
+ user->ufolios + n - user->ufolios_next);
+ }
iopt_pages_sub_npinned(pages, npages);
- pfns->user.upages_end = pfns->batch_end_index;
+ user->upages_end = pfns->batch_end_index;
}
if (pfns->batch_start_index != pfns->batch_end_index) {
pfn_reader_unpin(pfns);
@@ -1139,11 +1286,11 @@ static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
return 0;
}
-struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
- bool writable)
+static struct iopt_pages *iopt_alloc_pages(unsigned long start_byte,
+ unsigned long length,
+ bool writable)
{
struct iopt_pages *pages;
- unsigned long end;
/*
* The iommu API uses size_t as the length, and protect the DIV_ROUND_UP
@@ -1152,9 +1299,6 @@ struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
if (length > SIZE_MAX - PAGE_SIZE || length == 0)
return ERR_PTR(-EINVAL);
- if (check_add_overflow((unsigned long)uptr, length, &end))
- return ERR_PTR(-EOVERFLOW);
-
pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
if (!pages)
return ERR_PTR(-ENOMEM);
@@ -1164,8 +1308,7 @@ struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
mutex_init(&pages->mutex);
pages->source_mm = current->mm;
mmgrab(pages->source_mm);
- pages->uptr = (void __user *)ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
- pages->npages = DIV_ROUND_UP(length + (uptr - pages->uptr), PAGE_SIZE);
+ pages->npages = DIV_ROUND_UP(length + start_byte, PAGE_SIZE);
pages->access_itree = RB_ROOT_CACHED;
pages->domains_itree = RB_ROOT_CACHED;
pages->writable = writable;
@@ -1179,6 +1322,45 @@ struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
return pages;
}
+struct iopt_pages *iopt_alloc_user_pages(void __user *uptr,
+ unsigned long length, bool writable)
+{
+ struct iopt_pages *pages;
+ unsigned long end;
+ void __user *uptr_down =
+ (void __user *) ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
+
+ if (check_add_overflow((unsigned long)uptr, length, &end))
+ return ERR_PTR(-EOVERFLOW);
+
+ pages = iopt_alloc_pages(uptr - uptr_down, length, writable);
+ if (IS_ERR(pages))
+ return pages;
+ pages->uptr = uptr_down;
+ pages->type = IOPT_ADDRESS_USER;
+ return pages;
+}
+
+struct iopt_pages *iopt_alloc_file_pages(struct file *file, unsigned long start,
+ unsigned long length, bool writable)
+{
+ struct iopt_pages *pages;
+ unsigned long start_down = ALIGN_DOWN(start, PAGE_SIZE);
+ unsigned long end;
+
+ if (length && check_add_overflow(start, length - 1, &end))
+ return ERR_PTR(-EOVERFLOW);
+
+ pages = iopt_alloc_pages(start - start_down, length, writable);
+ if (IS_ERR(pages))
+ return pages;
+ pages->file = get_file(file);
+ pages->start = start_down;
+ pages->type = IOPT_ADDRESS_FILE;
+ return pages;
+}
+
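To make the offset handling concrete (an illustration, not part of the patch): with a 4 KiB page size, a call with start = 0x11234 and length = 0x3000 yields start_down = 0x11000, passes start_byte = 0x234 to iopt_alloc_pages(), and ends up with npages = DIV_ROUND_UP(0x3000 + 0x234, 0x1000) = 4, so the first and last backing pages are only partially covered by the mapping.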
void iopt_release_pages(struct kref *kref)
{
struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
@@ -1191,6 +1373,8 @@ void iopt_release_pages(struct kref *kref)
mutex_destroy(&pages->mutex);
put_task_struct(pages->source_task);
free_uid(pages->source_user);
+ if (pages->type == IOPT_ADDRESS_FILE)
+ fput(pages->file);
kfree(pages);
}
@@ -1630,11 +1814,11 @@ static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
return 0;
}
-static int iopt_pages_fill_from_mm(struct iopt_pages *pages,
- struct pfn_reader_user *user,
- unsigned long start_index,
- unsigned long last_index,
- struct page **out_pages)
+static int iopt_pages_fill(struct iopt_pages *pages,
+ struct pfn_reader_user *user,
+ unsigned long start_index,
+ unsigned long last_index,
+ struct page **out_pages)
{
unsigned long cur_index = start_index;
int rc;
@@ -1708,8 +1892,8 @@ int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
/* hole */
cur_pages = out_pages + (span.start_hole - start_index);
- rc = iopt_pages_fill_from_mm(pages, &user, span.start_hole,
- span.last_hole, cur_pages);
+ rc = iopt_pages_fill(pages, &user, span.start_hole,
+ span.last_hole, cur_pages);
if (rc)
goto out_clean_xa;
rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
@@ -1789,6 +1973,10 @@ static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
struct page *page = NULL;
int rc;
+ if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
+ WARN_ON(pages->type != IOPT_ADDRESS_USER))
+ return -EINVAL;
+
if (!mmget_not_zero(pages->source_mm))
return iopt_pages_rw_slow(pages, index, index, offset, data,
length, flags);
@@ -1844,6 +2032,15 @@ int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
return -EPERM;
+ if (pages->type == IOPT_ADDRESS_FILE)
+ return iopt_pages_rw_slow(pages, start_index, last_index,
+ start_byte % PAGE_SIZE, data, length,
+ flags);
+
+ if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
+ WARN_ON(pages->type != IOPT_ADDRESS_USER))
+ return -EINVAL;
+
if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
if (start_index == last_index)
return iopt_pages_rw_page(pages, start_index,
diff --git a/drivers/iommu/iommufd/selftest.c b/drivers/iommu/iommufd/selftest.c
index 540437be168a..2f9de177dffc 100644
--- a/drivers/iommu/iommufd/selftest.c
+++ b/drivers/iommu/iommufd/selftest.c
@@ -126,12 +126,35 @@ struct mock_iommu_domain {
struct xarray pfns;
};
+static inline struct mock_iommu_domain *
+to_mock_domain(struct iommu_domain *domain)
+{
+ return container_of(domain, struct mock_iommu_domain, domain);
+}
+
struct mock_iommu_domain_nested {
struct iommu_domain domain;
+ struct mock_viommu *mock_viommu;
struct mock_iommu_domain *parent;
u32 iotlb[MOCK_NESTED_DOMAIN_IOTLB_NUM];
};
+static inline struct mock_iommu_domain_nested *
+to_mock_nested(struct iommu_domain *domain)
+{
+ return container_of(domain, struct mock_iommu_domain_nested, domain);
+}
+
+struct mock_viommu {
+ struct iommufd_viommu core;
+ struct mock_iommu_domain *s2_parent;
+};
+
+static inline struct mock_viommu *to_mock_viommu(struct iommufd_viommu *viommu)
+{
+ return container_of(viommu, struct mock_viommu, core);
+}
+
enum selftest_obj_type {
TYPE_IDEV,
};
@@ -140,8 +163,14 @@ struct mock_dev {
struct device dev;
unsigned long flags;
int id;
+ u32 cache[MOCK_DEV_CACHE_NUM];
};
+static inline struct mock_dev *to_mock_dev(struct device *dev)
+{
+ return container_of(dev, struct mock_dev, dev);
+}
+
struct selftest_obj {
struct iommufd_object obj;
enum selftest_obj_type type;
@@ -155,10 +184,15 @@ struct selftest_obj {
};
};
+static inline struct selftest_obj *to_selftest_obj(struct iommufd_object *obj)
+{
+ return container_of(obj, struct selftest_obj, obj);
+}
+
static int mock_domain_nop_attach(struct iommu_domain *domain,
struct device *dev)
{
- struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
+ struct mock_dev *mdev = to_mock_dev(dev);
if (domain->dirty_ops && (mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY))
return -EINVAL;
@@ -193,8 +227,7 @@ static void *mock_domain_hw_info(struct device *dev, u32 *length, u32 *type)
static int mock_domain_set_dirty_tracking(struct iommu_domain *domain,
bool enable)
{
- struct mock_iommu_domain *mock =
- container_of(domain, struct mock_iommu_domain, domain);
+ struct mock_iommu_domain *mock = to_mock_domain(domain);
unsigned long flags = mock->flags;
if (enable && !domain->dirty_ops)
@@ -243,8 +276,7 @@ static int mock_domain_read_and_clear_dirty(struct iommu_domain *domain,
unsigned long flags,
struct iommu_dirty_bitmap *dirty)
{
- struct mock_iommu_domain *mock =
- container_of(domain, struct mock_iommu_domain, domain);
+ struct mock_iommu_domain *mock = to_mock_domain(domain);
unsigned long end = iova + size;
void *ent;
@@ -281,7 +313,7 @@ static const struct iommu_dirty_ops dirty_ops = {
static struct iommu_domain *mock_domain_alloc_paging(struct device *dev)
{
- struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
+ struct mock_dev *mdev = to_mock_dev(dev);
struct mock_iommu_domain *mock;
mock = kzalloc(sizeof(*mock), GFP_KERNEL);
@@ -298,76 +330,85 @@ static struct iommu_domain *mock_domain_alloc_paging(struct device *dev)
return &mock->domain;
}
-static struct iommu_domain *
-__mock_domain_alloc_nested(struct mock_iommu_domain *mock_parent,
- const struct iommu_hwpt_selftest *user_cfg)
+static struct mock_iommu_domain_nested *
+__mock_domain_alloc_nested(const struct iommu_user_data *user_data)
{
struct mock_iommu_domain_nested *mock_nested;
- int i;
+ struct iommu_hwpt_selftest user_cfg;
+ int rc, i;
+
+ if (user_data->type != IOMMU_HWPT_DATA_SELFTEST)
+ return ERR_PTR(-EOPNOTSUPP);
+
+ rc = iommu_copy_struct_from_user(&user_cfg, user_data,
+ IOMMU_HWPT_DATA_SELFTEST, iotlb);
+ if (rc)
+ return ERR_PTR(rc);
mock_nested = kzalloc(sizeof(*mock_nested), GFP_KERNEL);
if (!mock_nested)
return ERR_PTR(-ENOMEM);
- mock_nested->parent = mock_parent;
mock_nested->domain.ops = &domain_nested_ops;
mock_nested->domain.type = IOMMU_DOMAIN_NESTED;
for (i = 0; i < MOCK_NESTED_DOMAIN_IOTLB_NUM; i++)
- mock_nested->iotlb[i] = user_cfg->iotlb;
- return &mock_nested->domain;
+ mock_nested->iotlb[i] = user_cfg.iotlb;
+ return mock_nested;
}
static struct iommu_domain *
-mock_domain_alloc_user(struct device *dev, u32 flags,
- struct iommu_domain *parent,
- const struct iommu_user_data *user_data)
+mock_domain_alloc_nested(struct iommu_domain *parent, u32 flags,
+ const struct iommu_user_data *user_data)
{
+ struct mock_iommu_domain_nested *mock_nested;
struct mock_iommu_domain *mock_parent;
- struct iommu_hwpt_selftest user_cfg;
- int rc;
-
- /* must be mock_domain */
- if (!parent) {
- struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
- bool has_dirty_flag = flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING;
- bool no_dirty_ops = mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY;
- struct iommu_domain *domain;
-
- if (flags & (~(IOMMU_HWPT_ALLOC_NEST_PARENT |
- IOMMU_HWPT_ALLOC_DIRTY_TRACKING)))
- return ERR_PTR(-EOPNOTSUPP);
- if (user_data || (has_dirty_flag && no_dirty_ops))
- return ERR_PTR(-EOPNOTSUPP);
- domain = mock_domain_alloc_paging(dev);
- if (!domain)
- return ERR_PTR(-ENOMEM);
- if (has_dirty_flag)
- container_of(domain, struct mock_iommu_domain, domain)
- ->domain.dirty_ops = &dirty_ops;
- return domain;
- }
- /* must be mock_domain_nested */
- if (user_data->type != IOMMU_HWPT_DATA_SELFTEST || flags)
+ if (flags)
return ERR_PTR(-EOPNOTSUPP);
if (!parent || parent->ops != mock_ops.default_domain_ops)
return ERR_PTR(-EINVAL);
- mock_parent = container_of(parent, struct mock_iommu_domain, domain);
+ mock_parent = to_mock_domain(parent);
if (!mock_parent)
return ERR_PTR(-EINVAL);
- rc = iommu_copy_struct_from_user(&user_cfg, user_data,
- IOMMU_HWPT_DATA_SELFTEST, iotlb);
- if (rc)
- return ERR_PTR(rc);
+ mock_nested = __mock_domain_alloc_nested(user_data);
+ if (IS_ERR(mock_nested))
+ return ERR_CAST(mock_nested);
+ mock_nested->parent = mock_parent;
+ return &mock_nested->domain;
+}
+
+static struct iommu_domain *
+mock_domain_alloc_user(struct device *dev, u32 flags,
+ struct iommu_domain *parent,
+ const struct iommu_user_data *user_data)
+{
+ bool has_dirty_flag = flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING;
+ const u32 PAGING_FLAGS = IOMMU_HWPT_ALLOC_DIRTY_TRACKING |
+ IOMMU_HWPT_ALLOC_NEST_PARENT;
+ bool no_dirty_ops = to_mock_dev(dev)->flags &
+ MOCK_FLAGS_DEVICE_NO_DIRTY;
+ struct iommu_domain *domain;
- return __mock_domain_alloc_nested(mock_parent, &user_cfg);
+ if (parent)
+ return mock_domain_alloc_nested(parent, flags, user_data);
+
+ if (user_data)
+ return ERR_PTR(-EOPNOTSUPP);
+ if ((flags & ~PAGING_FLAGS) || (has_dirty_flag && no_dirty_ops))
+ return ERR_PTR(-EOPNOTSUPP);
+
+ domain = mock_domain_alloc_paging(dev);
+ if (!domain)
+ return ERR_PTR(-ENOMEM);
+ if (has_dirty_flag)
+ domain->dirty_ops = &dirty_ops;
+ return domain;
}
static void mock_domain_free(struct iommu_domain *domain)
{
- struct mock_iommu_domain *mock =
- container_of(domain, struct mock_iommu_domain, domain);
+ struct mock_iommu_domain *mock = to_mock_domain(domain);
WARN_ON(!xa_empty(&mock->pfns));
kfree(mock);
@@ -378,8 +419,7 @@ static int mock_domain_map_pages(struct iommu_domain *domain,
size_t pgsize, size_t pgcount, int prot,
gfp_t gfp, size_t *mapped)
{
- struct mock_iommu_domain *mock =
- container_of(domain, struct mock_iommu_domain, domain);
+ struct mock_iommu_domain *mock = to_mock_domain(domain);
unsigned long flags = MOCK_PFN_START_IOVA;
unsigned long start_iova = iova;
@@ -430,8 +470,7 @@ static size_t mock_domain_unmap_pages(struct iommu_domain *domain,
size_t pgcount,
struct iommu_iotlb_gather *iotlb_gather)
{
- struct mock_iommu_domain *mock =
- container_of(domain, struct mock_iommu_domain, domain);
+ struct mock_iommu_domain *mock = to_mock_domain(domain);
bool first = true;
size_t ret = 0;
void *ent;
@@ -479,8 +518,7 @@ static size_t mock_domain_unmap_pages(struct iommu_domain *domain,
static phys_addr_t mock_domain_iova_to_phys(struct iommu_domain *domain,
dma_addr_t iova)
{
- struct mock_iommu_domain *mock =
- container_of(domain, struct mock_iommu_domain, domain);
+ struct mock_iommu_domain *mock = to_mock_domain(domain);
void *ent;
WARN_ON(iova % MOCK_IO_PAGE_SIZE);
@@ -491,7 +529,7 @@ static phys_addr_t mock_domain_iova_to_phys(struct iommu_domain *domain,
static bool mock_domain_capable(struct device *dev, enum iommu_cap cap)
{
- struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
+ struct mock_dev *mdev = to_mock_dev(dev);
switch (cap) {
case IOMMU_CAP_CACHE_COHERENCY:
@@ -507,14 +545,17 @@ static bool mock_domain_capable(struct device *dev, enum iommu_cap cap)
static struct iopf_queue *mock_iommu_iopf_queue;
-static struct iommu_device mock_iommu_device = {
-};
+static struct mock_iommu_device {
+ struct iommu_device iommu_dev;
+ struct completion complete;
+ refcount_t users;
+} mock_iommu;
static struct iommu_device *mock_probe_device(struct device *dev)
{
if (dev->bus != &iommufd_mock_bus_type.bus)
return ERR_PTR(-ENODEV);
- return &mock_iommu_device;
+ return &mock_iommu.iommu_dev;
}
static void mock_domain_page_response(struct device *dev, struct iopf_fault *evt,
@@ -540,6 +581,132 @@ static int mock_dev_disable_feat(struct device *dev, enum iommu_dev_features fea
return 0;
}
+static void mock_viommu_destroy(struct iommufd_viommu *viommu)
+{
+ struct mock_iommu_device *mock_iommu = container_of(
+ viommu->iommu_dev, struct mock_iommu_device, iommu_dev);
+
+ if (refcount_dec_and_test(&mock_iommu->users))
+ complete(&mock_iommu->complete);
+
+ /* iommufd core frees mock_viommu and viommu */
+}
+
+static struct iommu_domain *
+mock_viommu_alloc_domain_nested(struct iommufd_viommu *viommu, u32 flags,
+ const struct iommu_user_data *user_data)
+{
+ struct mock_viommu *mock_viommu = to_mock_viommu(viommu);
+ struct mock_iommu_domain_nested *mock_nested;
+
+ if (flags & ~IOMMU_HWPT_FAULT_ID_VALID)
+ return ERR_PTR(-EOPNOTSUPP);
+
+ mock_nested = __mock_domain_alloc_nested(user_data);
+ if (IS_ERR(mock_nested))
+ return ERR_CAST(mock_nested);
+ mock_nested->mock_viommu = mock_viommu;
+ mock_nested->parent = mock_viommu->s2_parent;
+ return &mock_nested->domain;
+}
+
+static int mock_viommu_cache_invalidate(struct iommufd_viommu *viommu,
+ struct iommu_user_data_array *array)
+{
+ struct iommu_viommu_invalidate_selftest *cmds;
+ struct iommu_viommu_invalidate_selftest *cur;
+ struct iommu_viommu_invalidate_selftest *end;
+ int rc;
+
+ /* A zero-length array is allowed to validate the array type */
+ if (array->entry_num == 0 &&
+ array->type == IOMMU_VIOMMU_INVALIDATE_DATA_SELFTEST) {
+ array->entry_num = 0;
+ return 0;
+ }
+
+ cmds = kcalloc(array->entry_num, sizeof(*cmds), GFP_KERNEL);
+ if (!cmds)
+ return -ENOMEM;
+ cur = cmds;
+ end = cmds + array->entry_num;
+
+ static_assert(sizeof(*cmds) == 3 * sizeof(u32));
+ rc = iommu_copy_struct_from_full_user_array(
+ cmds, sizeof(*cmds), array,
+ IOMMU_VIOMMU_INVALIDATE_DATA_SELFTEST);
+ if (rc)
+ goto out;
+
+ while (cur != end) {
+ struct mock_dev *mdev;
+ struct device *dev;
+ int i;
+
+ if (cur->flags & ~IOMMU_TEST_INVALIDATE_FLAG_ALL) {
+ rc = -EOPNOTSUPP;
+ goto out;
+ }
+
+ if (cur->cache_id > MOCK_DEV_CACHE_ID_MAX) {
+ rc = -EINVAL;
+ goto out;
+ }
+
+ xa_lock(&viommu->vdevs);
+ dev = iommufd_viommu_find_dev(viommu,
+ (unsigned long)cur->vdev_id);
+ if (!dev) {
+ xa_unlock(&viommu->vdevs);
+ rc = -EINVAL;
+ goto out;
+ }
+ mdev = container_of(dev, struct mock_dev, dev);
+
+ if (cur->flags & IOMMU_TEST_INVALIDATE_FLAG_ALL) {
+ /* Invalidate all cache entries and ignore cache_id */
+ for (i = 0; i < MOCK_DEV_CACHE_NUM; i++)
+ mdev->cache[i] = 0;
+ } else {
+ mdev->cache[cur->cache_id] = 0;
+ }
+ xa_unlock(&viommu->vdevs);
+
+ cur++;
+ }
+out:
+ array->entry_num = cur - cmds;
+ kfree(cmds);
+ return rc;
+}
+
+static struct iommufd_viommu_ops mock_viommu_ops = {
+ .destroy = mock_viommu_destroy,
+ .alloc_domain_nested = mock_viommu_alloc_domain_nested,
+ .cache_invalidate = mock_viommu_cache_invalidate,
+};
+
+static struct iommufd_viommu *mock_viommu_alloc(struct device *dev,
+ struct iommu_domain *domain,
+ struct iommufd_ctx *ictx,
+ unsigned int viommu_type)
+{
+ struct mock_iommu_device *mock_iommu =
+ iommu_get_iommu_dev(dev, struct mock_iommu_device, iommu_dev);
+ struct mock_viommu *mock_viommu;
+
+ if (viommu_type != IOMMU_VIOMMU_TYPE_SELFTEST)
+ return ERR_PTR(-EOPNOTSUPP);
+
+ mock_viommu = iommufd_viommu_alloc(ictx, struct mock_viommu, core,
+ &mock_viommu_ops);
+ if (IS_ERR(mock_viommu))
+ return ERR_CAST(mock_viommu);
+
+ refcount_inc(&mock_iommu->users);
+ return &mock_viommu->core;
+}
+
static const struct iommu_ops mock_ops = {
/*
* IOMMU_DOMAIN_BLOCKED cannot be returned from def_domain_type()
@@ -559,6 +726,7 @@ static const struct iommu_ops mock_ops = {
.dev_enable_feat = mock_dev_enable_feat,
.dev_disable_feat = mock_dev_disable_feat,
.user_pasid_table = true,
+ .viommu_alloc = mock_viommu_alloc,
.default_domain_ops =
&(struct iommu_domain_ops){
.free = mock_domain_free,
@@ -571,18 +739,14 @@ static const struct iommu_ops mock_ops = {
static void mock_domain_free_nested(struct iommu_domain *domain)
{
- struct mock_iommu_domain_nested *mock_nested =
- container_of(domain, struct mock_iommu_domain_nested, domain);
-
- kfree(mock_nested);
+ kfree(to_mock_nested(domain));
}
static int
mock_domain_cache_invalidate_user(struct iommu_domain *domain,
struct iommu_user_data_array *array)
{
- struct mock_iommu_domain_nested *mock_nested =
- container_of(domain, struct mock_iommu_domain_nested, domain);
+ struct mock_iommu_domain_nested *mock_nested = to_mock_nested(domain);
struct iommu_hwpt_invalidate_selftest inv;
u32 processed = 0;
int i = 0, j;
@@ -657,7 +821,7 @@ get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id,
iommufd_put_object(ucmd->ictx, &hwpt->obj);
return ERR_PTR(-EINVAL);
}
- *mock = container_of(hwpt->domain, struct mock_iommu_domain, domain);
+ *mock = to_mock_domain(hwpt->domain);
return hwpt;
}
@@ -675,14 +839,13 @@ get_md_pagetable_nested(struct iommufd_ucmd *ucmd, u32 mockpt_id,
iommufd_put_object(ucmd->ictx, &hwpt->obj);
return ERR_PTR(-EINVAL);
}
- *mock_nested = container_of(hwpt->domain,
- struct mock_iommu_domain_nested, domain);
+ *mock_nested = to_mock_nested(hwpt->domain);
return hwpt;
}
static void mock_dev_release(struct device *dev)
{
- struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
+ struct mock_dev *mdev = to_mock_dev(dev);
ida_free(&mock_dev_ida, mdev->id);
kfree(mdev);
@@ -691,7 +854,7 @@ static void mock_dev_release(struct device *dev)
static struct mock_dev *mock_dev_create(unsigned long dev_flags)
{
struct mock_dev *mdev;
- int rc;
+ int rc, i;
if (dev_flags &
~(MOCK_FLAGS_DEVICE_NO_DIRTY | MOCK_FLAGS_DEVICE_HUGE_IOVA))
@@ -705,6 +868,8 @@ static struct mock_dev *mock_dev_create(unsigned long dev_flags)
mdev->flags = dev_flags;
mdev->dev.release = mock_dev_release;
mdev->dev.bus = &iommufd_mock_bus_type.bus;
+ for (i = 0; i < MOCK_DEV_CACHE_NUM; i++)
+ mdev->cache[i] = IOMMU_TEST_DEV_CACHE_DEFAULT;
rc = ida_alloc(&mock_dev_ida, GFP_KERNEL);
if (rc < 0)
@@ -813,7 +978,7 @@ static int iommufd_test_mock_domain_replace(struct iommufd_ucmd *ucmd,
if (IS_ERR(dev_obj))
return PTR_ERR(dev_obj);
- sobj = container_of(dev_obj, struct selftest_obj, obj);
+ sobj = to_selftest_obj(dev_obj);
if (sobj->type != TYPE_IDEV) {
rc = -EINVAL;
goto out_dev_obj;
@@ -951,8 +1116,7 @@ static int iommufd_test_md_check_iotlb(struct iommufd_ucmd *ucmd,
if (IS_ERR(hwpt))
return PTR_ERR(hwpt);
- mock_nested = container_of(hwpt->domain,
- struct mock_iommu_domain_nested, domain);
+ mock_nested = to_mock_nested(hwpt->domain);
if (iotlb_id > MOCK_NESTED_DOMAIN_IOTLB_ID_MAX ||
mock_nested->iotlb[iotlb_id] != iotlb)
@@ -961,6 +1125,24 @@ static int iommufd_test_md_check_iotlb(struct iommufd_ucmd *ucmd,
return rc;
}
+static int iommufd_test_dev_check_cache(struct iommufd_ucmd *ucmd, u32 idev_id,
+ unsigned int cache_id, u32 cache)
+{
+ struct iommufd_device *idev;
+ struct mock_dev *mdev;
+ int rc = 0;
+
+ idev = iommufd_get_device(ucmd, idev_id);
+ if (IS_ERR(idev))
+ return PTR_ERR(idev);
+ mdev = container_of(idev->dev, struct mock_dev, dev);
+
+ if (cache_id > MOCK_DEV_CACHE_ID_MAX || mdev->cache[cache_id] != cache)
+ rc = -EINVAL;
+ iommufd_put_object(ucmd->ictx, &idev->obj);
+ return rc;
+}
+
struct selftest_access {
struct iommufd_access *access;
struct file *file;
@@ -1431,7 +1613,7 @@ static int iommufd_test_trigger_iopf(struct iommufd_ucmd *ucmd,
void iommufd_selftest_destroy(struct iommufd_object *obj)
{
- struct selftest_obj *sobj = container_of(obj, struct selftest_obj, obj);
+ struct selftest_obj *sobj = to_selftest_obj(obj);
switch (sobj->type) {
case TYPE_IDEV:
@@ -1470,6 +1652,10 @@ int iommufd_test(struct iommufd_ucmd *ucmd)
return iommufd_test_md_check_iotlb(ucmd, cmd->id,
cmd->check_iotlb.id,
cmd->check_iotlb.iotlb);
+ case IOMMU_TEST_OP_DEV_CHECK_CACHE:
+ return iommufd_test_dev_check_cache(ucmd, cmd->id,
+ cmd->check_dev_cache.id,
+ cmd->check_dev_cache.cache);
case IOMMU_TEST_OP_CREATE_ACCESS:
return iommufd_test_create_access(ucmd, cmd->id,
cmd->create_access.flags);
@@ -1536,24 +1722,27 @@ int __init iommufd_test_init(void)
if (rc)
goto err_platform;
- rc = iommu_device_sysfs_add(&mock_iommu_device,
+ rc = iommu_device_sysfs_add(&mock_iommu.iommu_dev,
&selftest_iommu_dev->dev, NULL, "%s",
dev_name(&selftest_iommu_dev->dev));
if (rc)
goto err_bus;
- rc = iommu_device_register_bus(&mock_iommu_device, &mock_ops,
+ rc = iommu_device_register_bus(&mock_iommu.iommu_dev, &mock_ops,
&iommufd_mock_bus_type.bus,
&iommufd_mock_bus_type.nb);
if (rc)
goto err_sysfs;
+ refcount_set(&mock_iommu.users, 1);
+ init_completion(&mock_iommu.complete);
+
mock_iommu_iopf_queue = iopf_queue_alloc("mock-iopfq");
return 0;
err_sysfs:
- iommu_device_sysfs_remove(&mock_iommu_device);
+ iommu_device_sysfs_remove(&mock_iommu.iommu_dev);
err_bus:
bus_unregister(&iommufd_mock_bus_type.bus);
err_platform:
@@ -1563,6 +1752,22 @@ err_dbgfs:
return rc;
}
+static void iommufd_test_wait_for_users(void)
+{
+ if (refcount_dec_and_test(&mock_iommu.users))
+ return;
+ /*
+ * Time out waiting for the iommu device user count to reach 0.
+ *
+ * Note that this is largely illustrative: the selftest is built into the
+ * iommufd module, so the mock iommu device is only unplugged when the
+ * module is unloaded, which cannot happen while any iommufd FDs are still
+ * open. Hence this WARN_ON is not expected to trigger.
+ */
+ WARN_ON(!wait_for_completion_timeout(&mock_iommu.complete,
+ msecs_to_jiffies(10000)));
+}
+
void iommufd_test_exit(void)
{
if (mock_iommu_iopf_queue) {
@@ -1570,8 +1775,9 @@ void iommufd_test_exit(void)
mock_iommu_iopf_queue = NULL;
}
- iommu_device_sysfs_remove(&mock_iommu_device);
- iommu_device_unregister_bus(&mock_iommu_device,
+ iommufd_test_wait_for_users();
+ iommu_device_sysfs_remove(&mock_iommu.iommu_dev);
+ iommu_device_unregister_bus(&mock_iommu.iommu_dev,
&iommufd_mock_bus_type.bus,
&iommufd_mock_bus_type.nb);
bus_unregister(&iommufd_mock_bus_type.bus);
diff --git a/drivers/iommu/iommufd/vfio_compat.c b/drivers/iommu/iommufd/vfio_compat.c
index a3ad5f0b6c59..514aacd64009 100644
--- a/drivers/iommu/iommufd/vfio_compat.c
+++ b/drivers/iommu/iommufd/vfio_compat.c
@@ -291,12 +291,7 @@ static int iommufd_vfio_check_extension(struct iommufd_ctx *ictx,
case VFIO_DMA_CC_IOMMU:
return iommufd_vfio_cc_iommu(ictx);
- /*
- * This is obsolete, and to be removed from VFIO. It was an incomplete
- * idea that got merged.
- * https://lore.kernel.org/kvm/0-v1-0093c9b0e345+19-vfio_no_nesting_jgg@nvidia.com/
- */
- case VFIO_TYPE1_NESTING_IOMMU:
+ case __VFIO_RESERVED_TYPE1_NESTING_IOMMU:
return 0;
/*
diff --git a/drivers/iommu/iommufd/viommu.c b/drivers/iommu/iommufd/viommu.c
new file mode 100644
index 000000000000..69b88e8c7c26
--- /dev/null
+++ b/drivers/iommu/iommufd/viommu.c
@@ -0,0 +1,157 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES
+ */
+#include "iommufd_private.h"
+
+void iommufd_viommu_destroy(struct iommufd_object *obj)
+{
+ struct iommufd_viommu *viommu =
+ container_of(obj, struct iommufd_viommu, obj);
+
+ if (viommu->ops && viommu->ops->destroy)
+ viommu->ops->destroy(viommu);
+ refcount_dec(&viommu->hwpt->common.obj.users);
+ xa_destroy(&viommu->vdevs);
+}
+
+int iommufd_viommu_alloc_ioctl(struct iommufd_ucmd *ucmd)
+{
+ struct iommu_viommu_alloc *cmd = ucmd->cmd;
+ struct iommufd_hwpt_paging *hwpt_paging;
+ struct iommufd_viommu *viommu;
+ struct iommufd_device *idev;
+ const struct iommu_ops *ops;
+ int rc;
+
+ if (cmd->flags || cmd->type == IOMMU_VIOMMU_TYPE_DEFAULT)
+ return -EOPNOTSUPP;
+
+ idev = iommufd_get_device(ucmd, cmd->dev_id);
+ if (IS_ERR(idev))
+ return PTR_ERR(idev);
+
+ ops = dev_iommu_ops(idev->dev);
+ if (!ops->viommu_alloc) {
+ rc = -EOPNOTSUPP;
+ goto out_put_idev;
+ }
+
+ hwpt_paging = iommufd_get_hwpt_paging(ucmd, cmd->hwpt_id);
+ if (IS_ERR(hwpt_paging)) {
+ rc = PTR_ERR(hwpt_paging);
+ goto out_put_idev;
+ }
+
+ if (!hwpt_paging->nest_parent) {
+ rc = -EINVAL;
+ goto out_put_hwpt;
+ }
+
+ viommu = ops->viommu_alloc(idev->dev, hwpt_paging->common.domain,
+ ucmd->ictx, cmd->type);
+ if (IS_ERR(viommu)) {
+ rc = PTR_ERR(viommu);
+ goto out_put_hwpt;
+ }
+
+ xa_init(&viommu->vdevs);
+ viommu->type = cmd->type;
+ viommu->ictx = ucmd->ictx;
+ viommu->hwpt = hwpt_paging;
+ refcount_inc(&viommu->hwpt->common.obj.users);
+ /*
+ * In the most likely case a physical IOMMU cannot be unplugged, so no
+ * reference is taken here. A hot-pluggable IOMMU instance (if one exists)
+ * is responsible for its own refcounting.
+ */
+ viommu->iommu_dev = __iommu_get_iommu_dev(idev->dev);
+
+ cmd->out_viommu_id = viommu->obj.id;
+ rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
+ if (rc)
+ goto out_abort;
+ iommufd_object_finalize(ucmd->ictx, &viommu->obj);
+ goto out_put_hwpt;
+
+out_abort:
+ iommufd_object_abort_and_destroy(ucmd->ictx, &viommu->obj);
+out_put_hwpt:
+ iommufd_put_object(ucmd->ictx, &hwpt_paging->common.obj);
+out_put_idev:
+ iommufd_put_object(ucmd->ictx, &idev->obj);
+ return rc;
+}
+
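A usage sketch, not part of the patch: this handler backs the IOMMU_VIOMMU_ALLOC ioctl added to the table in main.c. The helper name is hypothetical, the leading size member follows the usual iommufd convention, and the mock IOMMU_VIOMMU_TYPE_SELFTEST from iommufd_test.h is used only to show a concrete type; a real driver defines its own, and IOMMU_VIOMMU_TYPE_DEFAULT is rejected above. The HWPT passed in must have been allocated as a nesting parent.

#include <linux/iommufd.h>
#include <sys/ioctl.h>

/* Allocate a vIOMMU on top of a nesting-parent HWPT for a bound device */
static int viommu_alloc(int iommufd, __u32 dev_id, __u32 nest_parent_hwpt_id,
			__u32 *viommu_id)
{
	struct iommu_viommu_alloc cmd = {
		.size = sizeof(cmd),
		.type = IOMMU_VIOMMU_TYPE_SELFTEST,	/* driver-specific type */
		.dev_id = dev_id,
		.hwpt_id = nest_parent_hwpt_id,
	};

	if (ioctl(iommufd, IOMMU_VIOMMU_ALLOC, &cmd))
		return -1;	/* errno carries the kernel's error code */
	*viommu_id = cmd.out_viommu_id;
	return 0;
}
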
+void iommufd_vdevice_destroy(struct iommufd_object *obj)
+{
+ struct iommufd_vdevice *vdev =
+ container_of(obj, struct iommufd_vdevice, obj);
+ struct iommufd_viommu *viommu = vdev->viommu;
+
+ /*
+ * This xa_cmpxchg() is allowed to fail if the xa_cmpxchg() in the alloc
+ * path failed previously.
+ */
+ xa_cmpxchg(&viommu->vdevs, vdev->id, vdev, NULL, GFP_KERNEL);
+ refcount_dec(&viommu->obj.users);
+ put_device(vdev->dev);
+}
+
+int iommufd_vdevice_alloc_ioctl(struct iommufd_ucmd *ucmd)
+{
+ struct iommu_vdevice_alloc *cmd = ucmd->cmd;
+ struct iommufd_vdevice *vdev, *curr;
+ struct iommufd_viommu *viommu;
+ struct iommufd_device *idev;
+ u64 virt_id = cmd->virt_id;
+ int rc = 0;
+
+ /* virt_id indexes an xarray */
+ if (virt_id > ULONG_MAX)
+ return -EINVAL;
+
+ viommu = iommufd_get_viommu(ucmd, cmd->viommu_id);
+ if (IS_ERR(viommu))
+ return PTR_ERR(viommu);
+
+ idev = iommufd_get_device(ucmd, cmd->dev_id);
+ if (IS_ERR(idev)) {
+ rc = PTR_ERR(idev);
+ goto out_put_viommu;
+ }
+
+ if (viommu->iommu_dev != __iommu_get_iommu_dev(idev->dev)) {
+ rc = -EINVAL;
+ goto out_put_idev;
+ }
+
+ vdev = iommufd_object_alloc(ucmd->ictx, vdev, IOMMUFD_OBJ_VDEVICE);
+ if (IS_ERR(vdev)) {
+ rc = PTR_ERR(vdev);
+ goto out_put_idev;
+ }
+
+ vdev->id = virt_id;
+ vdev->dev = idev->dev;
+ get_device(idev->dev);
+ vdev->viommu = viommu;
+ refcount_inc(&viommu->obj.users);
+
+ curr = xa_cmpxchg(&viommu->vdevs, virt_id, NULL, vdev, GFP_KERNEL);
+ if (curr) {
+ rc = xa_err(curr) ?: -EEXIST;
+ goto out_abort;
+ }
+
+ cmd->out_vdevice_id = vdev->obj.id;
+ rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
+ if (rc)
+ goto out_abort;
+ iommufd_object_finalize(ucmd->ictx, &vdev->obj);
+ goto out_put_idev;
+
+out_abort:
+ iommufd_object_abort_and_destroy(ucmd->ictx, &vdev->obj);
+out_put_idev:
+ iommufd_put_object(ucmd->ictx, &idev->obj);
+out_put_viommu:
+ iommufd_put_object(ucmd->ictx, &viommu->obj);
+ return rc;
+}
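
And the matching step for the handler above, a sketch with a hypothetical helper name: IOMMU_VDEVICE_ALLOC ties an already-bound device to a virtual device ID inside the vIOMMU, mirroring the viommu_id/dev_id/virt_id inputs and the out_vdevice_id output read and written above.

#include <linux/iommufd.h>
#include <sys/ioctl.h>

/* Record @virt_id as the vIOMMU-visible ID of an already-bound device */
static int vdevice_alloc(int iommufd, __u32 viommu_id, __u32 dev_id,
			 __u64 virt_id, __u32 *vdevice_id)
{
	struct iommu_vdevice_alloc cmd = {
		.size = sizeof(cmd),
		.viommu_id = viommu_id,
		.dev_id = dev_id,
		.virt_id = virt_id,	/* must fit in an unsigned long */
	};

	if (ioctl(iommufd, IOMMU_VDEVICE_ALLOC, &cmd))
		return -1;	/* errno carries the kernel's error code */
	*vdevice_id = cmd.out_vdevice_id;
	return 0;
}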