import java.io.IOException; import java.util.Arrays; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.search.IndexSearcher; import org.apache.solr.search.DelegatingCollector; import org.apache.solr.search.ExtendedQueryBase; import org.apache.solr.search.PostFilter; public class AccessControlQuery extends ExtendedQueryBase implements PostFilter { private String user; private String[] groups; public AccessControlQuery(String user, String groups) { this.user = user; this.groups = groups.split(","); } /** * acl is in the form of a series of whitespace separated [+|-][u|g]:name * allowed is determined by any explicit user or group mentions, plus or minus * order matters * if nothing matches, it is not allowed */ public static boolean isAllowed(String acl, String user, String[] groups) { if (user == null && groups == null) return false; String[] permissions = acl.split(" "); for(String p : permissions) { boolean allowed = p.charAt(0) == '+'; String name = p.substring(3); if (p.charAt(1) == 'u') { // user if (user != null && user.equals(name)) return allowed; } else { // group if (groups != null) { for (String g : groups) { if (g.equals(name)) return allowed; } } } } return false; } @Override public boolean getCache() { return false; // never cache } @Override public int getCost() { return Math.max(super.getCost(), 100); // never return less than 100 since we only support post filtering } @Override public DelegatingCollector getFilterCollector(IndexSearcher searcher) { return new DelegatingCollector() { SortedDocValues acls; @Override protected void doSetNextReader(LeafReaderContext context) throws IOException { acls = context.reader().getSortedDocValues("acl"); super.doSetNextReader(context); } @Override public void collect(int doc) throws IOException { if (isAllowed(acls.get(doc).utf8ToString(), user, groups)) super.collect(doc); } }; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; if (!super.equals(o)) return false; AccessControlQuery that = (AccessControlQuery) o; if (!Arrays.equals(groups, that.groups)) return false; if (user != null ? !user.equals(that.user) : that.user != null) return false; return true; } @Override public int hashCode() { int result = super.hashCode(); result = 31 * result + (user != null ? user.hashCode() : 0); result = 31 * result + (groups != null ? Arrays.hashCode(groups) : 0); return result; } public static void main(String[] args) { String acl = "+u:user1 +g:group1 -g:group2 +u:user2 -u:user3"; System.out.println("acl = " + acl); test(acl, "user1", null); test(acl, "user2", null); test(acl, "user1", new String[] {"group1"}); test(acl, "user2", new String[] {"group2"}); test(acl, "user3", new String[] {"group1"}); test(acl, "user3", new String[] {"group2"}); test(acl, "user3", new String[] {"group1","group2"}); } private static void test(String acl, String user, String[] groups) { System.out.println("user='" + user + "'" + ", groups=" + (groups == null ? null : Arrays.asList(groups)) + ": " + (isAllowed(acl, user, groups) ? "allowed" : "NOT ALLOWED")); } }