|
@@ -678,9 +678,20 @@ struct mem_cgroup *mem_cgroup_from_task(struct task_struct *p)
|
|
|
}
|
|
|
EXPORT_SYMBOL(mem_cgroup_from_task);
|
|
|
|
|
|
-static struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm)
|
|
|
+/**
|
|
|
+ * get_mem_cgroup_from_mm: Obtain a reference on given mm_struct's memcg.
|
|
|
+ * @mm: mm from which memcg should be extracted. It can be NULL.
|
|
|
+ *
|
|
|
+ * Obtain a reference on mm->memcg and returns it if successful. Otherwise
|
|
|
+ * root_mem_cgroup is returned. However if mem_cgroup is disabled, NULL is
|
|
|
+ * returned.
|
|
|
+ */
|
|
|
+struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm)
|
|
|
{
|
|
|
- struct mem_cgroup *memcg = NULL;
|
|
|
+ struct mem_cgroup *memcg;
|
|
|
+
|
|
|
+ if (mem_cgroup_disabled())
|
|
|
+ return NULL;
|
|
|
|
|
|
rcu_read_lock();
|
|
|
do {
|
|
@@ -700,6 +711,24 @@ static struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm)
|
|
|
rcu_read_unlock();
|
|
|
return memcg;
|
|
|
}
|
|
|
+EXPORT_SYMBOL(get_mem_cgroup_from_mm);
|
|
|
+
|
|
|
+/**
|
|
|
+ * If current->active_memcg is non-NULL, do not fallback to current->mm->memcg.
|
|
|
+ */
|
|
|
+static __always_inline struct mem_cgroup *get_mem_cgroup_from_current(void)
|
|
|
+{
|
|
|
+ if (unlikely(current->active_memcg)) {
|
|
|
+ struct mem_cgroup *memcg = root_mem_cgroup;
|
|
|
+
|
|
|
+ rcu_read_lock();
|
|
|
+ if (css_tryget_online(¤t->active_memcg->css))
|
|
|
+ memcg = current->active_memcg;
|
|
|
+ rcu_read_unlock();
|
|
|
+ return memcg;
|
|
|
+ }
|
|
|
+ return get_mem_cgroup_from_mm(current->mm);
|
|
|
+}
|
|
|
|
|
|
/**
|
|
|
* mem_cgroup_iter - iterate over memory cgroup hierarchy
|
|
@@ -2261,7 +2290,7 @@ struct kmem_cache *memcg_kmem_get_cache(struct kmem_cache *cachep)
|
|
|
if (current->memcg_kmem_skip_account)
|
|
|
return cachep;
|
|
|
|
|
|
- memcg = get_mem_cgroup_from_mm(current->mm);
|
|
|
+ memcg = get_mem_cgroup_from_current();
|
|
|
kmemcg_id = READ_ONCE(memcg->kmemcg_id);
|
|
|
if (kmemcg_id < 0)
|
|
|
goto out;
|
|
@@ -2345,7 +2374,7 @@ int memcg_kmem_charge(struct page *page, gfp_t gfp, int order)
|
|
|
if (memcg_kmem_bypass())
|
|
|
return 0;
|
|
|
|
|
|
- memcg = get_mem_cgroup_from_mm(current->mm);
|
|
|
+ memcg = get_mem_cgroup_from_current();
|
|
|
if (!mem_cgroup_is_root(memcg)) {
|
|
|
ret = memcg_kmem_charge_memcg(page, gfp, order, memcg);
|
|
|
if (!ret)
|