Was anyone able to AC the last problem in Java? I have similar solution to the editorial but I kept TLE
import java.io.*;
import java.util.*;
public class Main {
final int IMAX = Integer.MAX_VALUE, IMIN = Integer.MIN_VALUE;
final long LMAX = Long.MAX_VALUE, LMIN = Long.MIN_VALUE;
public static void main(String[] args) {
InputReader in = new InputReader(System.in);
PrintWriter w = new PrintWriter(System.out);
int T = in.ii();
Main ok = new Main();
for (int i = 0; i < T; i++)
ok.solve(in, w);
w.close();
}
static class Node {
Node n0;
Node n1;
int count;
Node() {
n0 = null;
n1 = null;
count = 0;
}
}
Node root;
void insert(int val) {
Node tmp = root;
for (int i = 31; i >= 0; i--) {
if ((val & (1 << i)) == 0) { // current bit is 0
if (tmp.n0 == null)
tmp.n0 = new Node();
tmp = tmp.n0;
} else {
if (tmp.n1 == null)
tmp.n1 = new Node();
tmp = tmp.n1;
}
tmp.count++;
}
}
long count(Node cur, int i, int val, int high, boolean less) {
if (cur == null)
return 0;
long res = 0;
if (less)
return cur.count;
if (i < 0)
return res;
if ((high & (1 << i)) == 0) { // current high bit is 0
if ((val & (1 << i)) == 0) {// current val bit is 0
res += count(cur.n0, i - 1, val, high, less);
if (less)
res += count(cur.n1, i - 1, val, high, less);
} else { // current val bit is 1
if (less)
res += count(cur.n0, i - 1, val, high, less);
res += count(cur.n1, i - 1, val, high, less);
}
} else { // currrent high bit is 1
if ((val & (1 << i)) == 0) {// current val bit is 0
res += count(cur.n0, i - 1, val, high, true);
res += count(cur.n1, i - 1, val, high, less);
} else {
res += count(cur.n0, i - 1, val, high, less);
res += count(cur.n1, i - 1, val, high, true);
}
}
return res;
}
public long countPairs(int[] nums, int low, int high) {
root = new Node();
long res = 0;
for (int num : nums) {
res += count(root, 31, num, high + 1, false);
insert(num);
}
return res;
}
public long max_val(int[] nums) {
long res = 0;
for (int num : nums) {
res = (res | num);
}
return res;
}
public void solve(InputReader in, PrintWriter w) {
int n = in.ii();
long k = in.ll();
int[] nums = in.nextIntArray(n);
long left = 0;
long right = max_val(nums);
long res = right;
while (left <= right) {
long m = left + (right - left) / 2;
long count = countPairs(nums, 0, (int) m);
if (count >= k) {
res = Math.min(res, m);
right = m - 1;
} else {
left = m + 1;
}
}
w.println(res);
}
public static class InputReader {
private final InputStream stream;
private final byte[] buf = new byte[8192];
private int curChar, snumChars;
private SpaceCharFilter filter;
public InputReader(InputStream stream) {
this.stream = stream;
}
public int snext() {
if (snumChars == -1)
throw new InputMismatchException();
if (curChar >= snumChars) {
curChar = 0;
try {
snumChars = stream.read(buf);
} catch (IOException e) {
throw new InputMismatchException();
}
if (snumChars <= 0)
return -1;
}
return buf[curChar++];
}
public int ii() {
int c = snext();
while (isSpaceChar(c)) {
c = snext();
}
int sgn = 1;
if (c == '-') {
sgn = -1;
c = snext();
}
int res = 0;
do {
if (c < '0' || c > '9')
throw new InputMismatchException();
res *= 10;
res += c - '0';
c = snext();
} while (!isSpaceChar(c));
return res * sgn;
}
public long ll() {
int c = snext();
while (isSpaceChar(c)) {
c = snext();
}
int sgn = 1;
if (c == '-') {
sgn = -1;
c = snext();
}
long res = 0;
do {
if (c < '0' || c > '9')
throw new InputMismatchException();
res *= 10;
res += c - '0';
c = snext();
} while (!isSpaceChar(c));
return res * sgn;
}
public double nextDouble() {
return Double.parseDouble(readString());
}
public int[] nextIntArray(int n) {
int a[] = new int[n];
for (int i = 0; i < n; i++) {
a[i] = ii();
}
return a;
}
public long[] nextLongArray(int n) {
long a[] = new long[n];
for (int i = 0; i < n; i++) {
a[i] = ll();
}
return a;
}
public String readString() {
int c = snext();
while (isSpaceChar(c)) {
c = snext();
}
StringBuilder res = new StringBuilder();
do {
res.appendCodePoint(c);
c = snext();
} while (!isSpaceChar(c));
return res.toString();
}
public String nextLine() {
int c = snext();
while (isSpaceChar(c))
c = snext();
StringBuilder res = new StringBuilder();
do {
res.appendCodePoint(c);
c = snext();
} while (!isEndOfLine(c));
return res.toString();
}
public boolean isSpaceChar(int c) {
if (filter != null)
return filter.isSpaceChar(c);
return c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == -1;
}
private boolean isEndOfLine(int c) {
return c == '\n' || c == '\r' || c == -1;
}
public interface SpaceCharFilter {
public boolean isSpaceChar(int ch);
}
}
class Pair<S extends Comparable<S>, T extends Comparable<T>> implements Comparable<Pair<S, T>> {
S first;
T second;
Pair(S f, T s) {
first = f;
second = s;
}
@Override
public int compareTo(Pair<S, T> o) {
int t = first.compareTo(o.first);
if (t == 0)
return second.compareTo(o.second);
return t;
}
@Override
public int hashCode() {
return (31 + first.hashCode()) * 31 + second.hashCode();
}
@Override
public boolean equals(Object o) {
if (!(o instanceof Pair))
return false;
if (o == this)
return true;
Pair p = (Pair) o;
return first.equals(p.first) && second.equals(p.second);
}
@Override
public String toString() {
return "Pair{" + first + ", " + second + "}";
}
}
}