fork download
  1. #include <bits/stdc++.h>
  2.  
  3. #define name "test"
  4. #define FOR(i, a, b) for (int i = (a); i <= (int)(b); i++)
  5. #define ll long long
  6. #define fi first
  7. #define se second
  8. #define ii pair<int,int>
  9. #define pb push_back
  10. #define sz(a) a.size()
  11. #define mingdu signed main()
  12.  
  13. using namespace std;
  14.  
  15. const int N = 2e5 + 5;
  16. const int LOG = 20;
  17. const ll MOD = 1e9 + 7;
  18.  
  19. int n, q;
  20. vector<int> g[N];
  21. int h[N], id[N];
  22. vector<int> euler, dep;
  23. int st[LOG][2 * N];
  24. int pw2[2 * N];
  25. int tin[N], tout[N], timer;
  26.  
  27. void dfs1_iter(int root) {
  28. timer = 0;
  29. vector<pair<int,int>> stk;
  30. stk.reserve(n * 2);
  31. stk.push_back({root, 0});
  32. while (!stk.empty()) {
  33. auto cur = stk.back(); stk.pop_back();
  34. int u = cur.first, p = cur.second;
  35. if (u > 0) {
  36. tin[u] = ++timer;
  37. stk.push_back({-u, p});
  38. for (int i = (int)g[u].size() - 1; i >= 0; --i) {
  39. int v = g[u][i];
  40. if (v == p) continue;
  41. stk.push_back({v, u});
  42. }
  43. } else {
  44. u = -u;
  45. tout[u] = ++timer;
  46. }
  47. }
  48. }
  49.  
  50. bool is_anc(int a, int b) {
  51. return tin[a] <= tin[b] && tout[b] <= tout[a];
  52. }
  53.  
  54. struct SNode { int u, p, idx; };
  55. void dfs2_iter(int root) {
  56. euler.clear(); dep.clear();
  57. vector<SNode> stk;
  58. stk.reserve(n * 2);
  59. stk.push_back({root, 0, 0});
  60. while (!stk.empty()) {
  61. SNode &top = stk.back();
  62. int u = top.u, p = top.p, &idx = top.idx;
  63. if (idx == 0) {
  64. id[u] = sz(euler);
  65. euler.pb(u);
  66. dep.pb(h[u]);
  67. }
  68. if (idx < (int)g[u].size()) {
  69. int v = g[u][idx++];
  70. if (v == p) continue;
  71. h[v] = h[u] + 1;
  72. stk.push_back({v, u, 0});
  73. } else {
  74. stk.pop_back();
  75. if (!stk.empty()) {
  76. int par = stk.back().u;
  77. euler.pb(par);
  78. dep.pb(h[par]);
  79. }
  80. }
  81. }
  82. }
  83.  
  84. void RMQ() {
  85. int m = sz(dep);
  86. if (m == 0) return;
  87. for (int i = 0; i < m; ++i) st[0][i] = i;
  88. for (int k = 1; (1 << k) <= m; ++k)
  89. for (int i = 0; i + (1 << k) <= m; ++i) {
  90. int x = st[k - 1][i];
  91. int y = st[k - 1][i + (1 << (k - 1))];
  92. st[k][i] = (dep[x] < dep[y] ? x : y);
  93. }
  94. pw2[0] = 0;
  95. if (m >= 1) pw2[1] = 0;
  96. for (int i = 2; i <= m; ++i) pw2[i] = pw2[i / 2] + 1;
  97. }
  98.  
  99. int lca(int u, int v) {
  100. int L = id[u], R = id[v];
  101. if (L > R) swap(L, R);
  102. int len = R - L + 1;
  103. int k = pw2[len];
  104. int x = st[k][L], y = st[k][R - (1 << k) + 1];
  105. return dep[x] < dep[y] ? euler[x] : euler[y];
  106. }
  107.  
  108. vector<int> Nodes;
  109. vector<int> adj[N];
  110.  
  111. bool cmp(const int &a, const int &b) {
  112. return tin[a] < tin[b];
  113. }
  114.  
  115. int build(vector<int> nodes) {
  116. sort(nodes.begin(), nodes.end(), cmp);
  117. Nodes = nodes;
  118.  
  119. FOR(i, 0, sz(nodes) - 2)
  120. Nodes.pb(lca(nodes[i], nodes[i + 1]));
  121.  
  122. sort(Nodes.begin(), Nodes.end(), cmp);
  123. Nodes.erase(unique(Nodes.begin(), Nodes.end()), Nodes.end());
  124.  
  125. for (int x : Nodes) adj[x].clear();
  126.  
  127. vector<int> stk;
  128. stk.reserve(sz(Nodes));
  129. stk.push_back(Nodes[0]);
  130.  
  131. FOR(i, 1, sz(Nodes) - 1) {
  132. int u = Nodes[i];
  133. while (!stk.empty() && !is_anc(stk.back(), u)) stk.pop_back();
  134. if (!stk.empty()) adj[stk.back()].pb(u);
  135. stk.pb(u);
  136. }
  137.  
  138. return Nodes[0];
  139. }
  140.  
  141. ll solve_query(vector<int> &nodes) {
  142. int root = build(nodes);
  143. ll total_sum = 0;
  144. static bool inS[N];
  145. for (int x : Nodes) inS[x] = 0;
  146. for (int x : nodes) {
  147. inS[x] = 1;
  148. total_sum = (total_sum + x) % MOD;
  149. }
  150.  
  151. ll ans = 0;
  152. static ll sumSub[N];
  153. vector<int> order;
  154. order.reserve(sz(Nodes));
  155. stack<int> stck;
  156. stck.push(root);
  157. while (!stck.empty()) {
  158. int u = stck.top(); stck.pop();
  159. order.pb(u);
  160. for (int v : adj[u]) stck.push(v);
  161. }
  162.  
  163. for (int i = sz(order) - 1; i >= 0; --i) {
  164. int u = order[i];
  165. ll cur = (inS[u] ? u % MOD : 0);
  166. for (int v : adj[u]) {
  167. cur += sumSub[v];
  168. if (cur >= MOD) cur -= MOD;
  169. }
  170. sumSub[u] = cur;
  171.  
  172. for (int v : adj[u]) {
  173. ll s = sumSub[v];
  174. ll other = (total_sum - s + MOD) % MOD;
  175. ll len = h[v] - h[u];
  176. ll tmp = len * s % MOD * other % MOD;
  177. ans = (ans + tmp) % MOD;
  178. }
  179. }
  180.  
  181. return ans;
  182. }
  183.  
  184. void nhap() {
  185. cin >> n >> q;
  186. FOR(i, 1, n - 1) {
  187. int u, v; cin >> u >> v;
  188. g[u].pb(v); g[v].pb(u);
  189. }
  190. h[1] = 0;
  191. dfs1_iter(1);
  192. dfs2_iter(1);
  193. RMQ();
  194. }
  195.  
  196. void giai() {
  197. while (q--) {
  198. int k; cin >> k;
  199. vector<int> nodes(k);
  200. FOR(i, 0, k - 1) cin >> nodes[i];
  201. cout << solve_query(nodes) << '\n';
  202. }
  203. }
  204.  
  205. mingdu {
  206. ios_base::sync_with_stdio(0);
  207. cin.tie(0); cout.tie(0);
  208.  
  209. if (fopen(name".inp", "r")) {
  210. freopen(name".inp", "r", stdin);
  211. freopen(name".out", "w", stdout);
  212. }
  213.  
  214. nhap();
  215. giai();
  216.  
  217. return 0;
  218. }
  219.  
Success #stdin #stdout 0.01s 21820KB
stdin
Standard input is empty
stdout
Standard output is empty