Skip to content

Commit 002074b

Browse files
committed
Merge pull request keycloak#2320 from mposolda/master
KEYCLOAK-2523 Fix concurrency tests with all databases by track trans…
2 parents 375d4e9 + 286de3e commit 002074b

File tree

6 files changed

+109
-75
lines changed

6 files changed

+109
-75
lines changed

model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/StreamCacheRealmProvider.java

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ public RealmModel getRealm(String id) {
305305
if (model == null) return null;
306306
if (invalidations.contains(id)) return model;
307307
cached = new CachedRealm(loaded, model);
308-
cache.addRevisioned(cached);
308+
cache.addRevisioned(cached, session);
309309
} else if (invalidations.contains(id)) {
310310
return getDelegate().getRealm(id);
311311
} else if (managedRealms.containsKey(id)) {
@@ -329,7 +329,7 @@ public RealmModel getRealmByName(String name) {
329329
if (model == null) return null;
330330
if (invalidations.contains(model.getId())) return model;
331331
query = new RealmListQuery(loaded, cacheKey, model.getId());
332-
cache.addRevisioned(query);
332+
cache.addRevisioned(query, session);
333333
return model;
334334
} else if (invalidations.contains(cacheKey)) {
335335
return getDelegate().getRealmByName(name);
@@ -435,7 +435,7 @@ public List<ClientModel> getClients(RealmModel realm) {
435435
for (ClientModel client : model) ids.add(client.getId());
436436
query = new ClientListQuery(loaded, cacheKey, realm, ids);
437437
logger.tracev("adding realm clients cache miss: realm {0} key {1}", realm.getName(), cacheKey);
438-
cache.addRevisioned(query);
438+
cache.addRevisioned(query, session);
439439
return model;
440440
}
441441
List<ClientModel> list = new LinkedList<>();
@@ -508,7 +508,7 @@ public Set<RoleModel> getRealmRoles(RealmModel realm) {
508508
for (RoleModel role : model) ids.add(role.getId());
509509
query = new RoleListQuery(loaded, cacheKey, realm, ids);
510510
logger.tracev("adding realm roles cache miss: realm {0} key {1}", realm.getName(), cacheKey);
511-
cache.addRevisioned(query);
511+
cache.addRevisioned(query, session);
512512
return model;
513513
}
514514
Set<RoleModel> list = new HashSet<>();
@@ -544,7 +544,7 @@ public Set<RoleModel> getClientRoles(RealmModel realm, ClientModel client) {
544544
for (RoleModel role : model) ids.add(role.getId());
545545
query = new RoleListQuery(loaded, cacheKey, realm, ids, client.getClientId());
546546
logger.tracev("adding client roles cache miss: client {0} key {1}", client.getClientId(), cacheKey);
547-
cache.addRevisioned(query);
547+
cache.addRevisioned(query, session);
548548
return model;
549549
}
550550
Set<RoleModel> list = new HashSet<>();
@@ -593,7 +593,7 @@ public RoleModel getRealmRole(RealmModel realm, String name) {
593593
if (model == null) return null;
594594
query = new RoleListQuery(loaded, cacheKey, realm, model.getId());
595595
logger.tracev("adding realm role cache miss: client {0} key {1}", realm.getName(), cacheKey);
596-
cache.addRevisioned(query);
596+
cache.addRevisioned(query, session);
597597
return model;
598598
}
599599
RoleModel role = getRoleById(query.getRoles().iterator().next(), realm);
@@ -623,7 +623,7 @@ public RoleModel getClientRole(RealmModel realm, ClientModel client, String name
623623
if (model == null) return null;
624624
query = new RoleListQuery(loaded, cacheKey, realm, model.getId(), client.getClientId());
625625
logger.tracev("adding client role cache miss: client {0} key {1}", client.getClientId(), cacheKey);
626-
cache.addRevisioned(query);
626+
cache.addRevisioned(query, session);
627627
return model;
628628
}
629629
RoleModel role = getRoleById(query.getRoles().iterator().next(), realm);
@@ -660,7 +660,7 @@ public RoleModel getRoleById(String id, RealmModel realm) {
660660
} else {
661661
cached = new CachedRealmRole(loaded, model, realm);
662662
}
663-
cache.addRevisioned(cached);
663+
cache.addRevisioned(cached, session);
664664

665665
} else if (invalidations.contains(id)) {
666666
return getDelegate().getRoleById(id, realm);
@@ -685,7 +685,7 @@ public GroupModel getGroupById(String id, RealmModel realm) {
685685
if (model == null) return null;
686686
if (invalidations.contains(id)) return model;
687687
cached = new CachedGroup(loaded, realm, model);
688-
cache.addRevisioned(cached);
688+
cache.addRevisioned(cached, session);
689689

690690
} else if (invalidations.contains(id)) {
691691
return getDelegate().getGroupById(id, realm);
@@ -725,7 +725,7 @@ public List<GroupModel> getGroups(RealmModel realm) {
725725
for (GroupModel client : model) ids.add(client.getId());
726726
query = new GroupListQuery(loaded, cacheKey, realm, ids);
727727
logger.tracev("adding realm getGroups cache miss: realm {0} key {1}", realm.getName(), cacheKey);
728-
cache.addRevisioned(query);
728+
cache.addRevisioned(query, session);
729729
return model;
730730
}
731731
List<GroupModel> list = new LinkedList<>();
@@ -761,7 +761,7 @@ public List<GroupModel> getTopLevelGroups(RealmModel realm) {
761761
for (GroupModel client : model) ids.add(client.getId());
762762
query = new GroupListQuery(loaded, cacheKey, realm, ids);
763763
logger.tracev("adding realm getTopLevelGroups cache miss: realm {0} key {1}", realm.getName(), cacheKey);
764-
cache.addRevisioned(query);
764+
cache.addRevisioned(query, session);
765765
return model;
766766
}
767767
List<GroupModel> list = new LinkedList<>();
@@ -837,7 +837,7 @@ public ClientModel getClientById(String id, RealmModel realm) {
837837
if (invalidations.contains(id)) return model;
838838
cached = new CachedClient(loaded, realm, model);
839839
logger.tracev("adding client by id cache miss: {0}", cached.getClientId());
840-
cache.addRevisioned(cached);
840+
cache.addRevisioned(cached, session);
841841
} else if (invalidations.contains(id)) {
842842
return getDelegate().getClientById(id, realm);
843843
} else if (managedApplications.containsKey(id)) {
@@ -866,7 +866,7 @@ public ClientModel getClientByClientId(String clientId, RealmModel realm) {
866866
id = model.getId();
867867
query = new ClientListQuery(loaded, cacheKey, realm, id);
868868
logger.tracev("adding client by name cache miss: {0}", clientId);
869-
cache.addRevisioned(query);
869+
cache.addRevisioned(query, session);
870870
} else if (invalidations.contains(cacheKey)) {
871871
return getDelegate().getClientByClientId(clientId, realm);
872872
} else {
@@ -895,7 +895,7 @@ public ClientTemplateModel getClientTemplateById(String id, RealmModel realm) {
895895
if (model == null) return null;
896896
if (invalidations.contains(id)) return model;
897897
cached = new CachedClientTemplate(loaded, realm, model);
898-
cache.addRevisioned(cached);
898+
cache.addRevisioned(cached, session);
899899
} else if (invalidations.contains(id)) {
900900
return getDelegate().getClientTemplateById(id, realm);
901901
} else if (managedClientTemplates.containsKey(id)) {

model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/StreamRealmCache.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.infinispan.notifications.cachelistener.event.CacheEntriesEvictedEvent;
2525
import org.infinispan.notifications.cachelistener.event.CacheEntryInvalidatedEvent;
2626
import org.jboss.logging.Logger;
27+
import org.keycloak.models.KeycloakSession;
2728
import org.keycloak.models.cache.infinispan.entities.AbstractRevisioned;
2829
import org.keycloak.models.cache.infinispan.entities.CachedClient;
2930
import org.keycloak.models.cache.infinispan.entities.CachedClientTemplate;
@@ -38,7 +39,7 @@
3839
import org.keycloak.models.cache.infinispan.stream.InClientPredicate;
3940
import org.keycloak.models.cache.infinispan.stream.InRealmPredicate;
4041
import org.keycloak.models.cache.infinispan.stream.RealmQueryPredicate;
41-
import org.keycloak.models.cache.infinispan.stream.RoleQueryPredicate;
42+
import org.keycloak.models.utils.UpdateCounter;
4243

4344
import java.util.HashSet;
4445
import java.util.Iterator;
@@ -73,7 +74,9 @@ public Cache<String, Long> getRevisions() {
7374

7475
public Long getCurrentRevision(String id) {
7576
Long revision = revisions.get(id);
76-
if (revision == null) revision = UpdateCounter.current();
77+
if (revision == null) {
78+
revision = UpdateCounter.current();
79+
}
7780
// if you do cache.remove() on node 1 and the entry doesn't exist on node 2, node 2 never receives a invalidation event
7881
// so, we do this to force this.
7982
String invalidationKey = "invalidation.key" + id;
@@ -121,7 +124,7 @@ protected void bumpVersion(String id) {
121124
Object rev = revisions.put(id, next);
122125
}
123126

124-
public void addRevisioned(Revisioned object) {
127+
public void addRevisioned(Revisioned object, KeycloakSession session) {
125128
//startRevisionBatch();
126129
String id = object.getId();
127130
try {
@@ -135,12 +138,19 @@ public void addRevisioned(Revisioned object) {
135138
revisions.startBatch();
136139
if (!revisions.getAdvancedCache().lock(id)) {
137140
logger.trace("Could not obtain version lock");
141+
return;
138142
}
139143
rev = revisions.get(id);
140144
if (rev == null) {
141145
if (id.endsWith("realm.clients")) logger.trace("addRevisioned rev2 == null realm.clients");
142146
return;
143147
}
148+
if (rev > session.getTransaction().getStartupRevision()) { // revision is ahead transaction start. Other transaction updated in the meantime. Don't cache
149+
if (logger.isTraceEnabled()) {
150+
logger.tracev("Skipped cache. Current revision {0}, Transaction start revision {1}", object.getRevision(), session.getTransaction().getStartupRevision());
151+
}
152+
return;
153+
}
144154
if (rev.equals(object.getRevision())) {
145155
if (id.endsWith("realm.clients")) logger.tracev("adding Object.revision {0} rev {1}", object.getRevision(), rev);
146156
cache.putForExternalRead(id, object);

server-spi/src/main/java/org/keycloak/models/KeycloakTransactionManager.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
*/
2424
public interface KeycloakTransactionManager extends KeycloakTransaction {
2525

26+
long getStartupRevision();
27+
2628
void enlist(KeycloakTransaction transaction);
2729
void enlistAfterCompletion(KeycloakTransaction transaction);
2830

model/infinispan/src/main/java/org/keycloak/models/cache/infinispan/UpdateCounter.java renamed to server-spi/src/main/java/org/keycloak/models/utils/UpdateCounter.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
package org.keycloak.models.cache.infinispan;
1+
package org.keycloak.models.utils;
22

33
import java.util.concurrent.atomic.AtomicLong;
44

55
/**
6+
* Used to track cache revisions
7+
*
68
* @author <a href="mailto:[email protected]">Stian Thorgersen</a>
79
*/
810
public class UpdateCounter {

services/src/main/java/org/keycloak/services/DefaultKeycloakTransactionManager.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import org.keycloak.models.KeycloakTransaction;
2020
import org.keycloak.models.KeycloakTransactionManager;
21+
import org.keycloak.models.utils.UpdateCounter;
2122
import org.keycloak.services.ServicesLogger;
2223

2324
import java.util.LinkedList;
@@ -35,6 +36,12 @@ public class DefaultKeycloakTransactionManager implements KeycloakTransactionMan
3536
private List<KeycloakTransaction> afterCompletion = new LinkedList<KeycloakTransaction>();
3637
private boolean active;
3738
private boolean rollback;
39+
private long startupRevision;
40+
41+
@Override
42+
public long getStartupRevision() {
43+
return startupRevision;
44+
}
3845

3946
@Override
4047
public void enlist(KeycloakTransaction transaction) {
@@ -69,6 +76,8 @@ public void begin() {
6976
throw new IllegalStateException("Transaction already active");
7077
}
7178

79+
startupRevision = UpdateCounter.current();
80+
7281
for (KeycloakTransaction tx : transactions) {
7382
tx.begin();
7483
}

testsuite/integration/src/test/java/org/keycloak/testsuite/broker/AbstractKeycloakIdentityProviderTest.java

Lines changed: 68 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -464,72 +464,83 @@ public void testTokenStorageAndRetrievalByApplication() {
464464
setUpdateProfileFirstLogin(IdentityProviderRepresentation.UPFLM_ON);
465465
IdentityProviderModel identityProviderModel = getIdentityProviderModel();
466466

467-
identityProviderModel.setStoreToken(true);
467+
setStoreToken(identityProviderModel, true);
468+
try {
469+
authenticateWithIdentityProvider(identityProviderModel, "test-user", true);
468470

469-
authenticateWithIdentityProvider(identityProviderModel, "test-user", true);
471+
brokerServerRule.stopSession(session, true);
472+
session = brokerServerRule.startSession();
470473

471-
brokerServerRule.stopSession(session, true);
472-
session = brokerServerRule.startSession();
474+
UserModel federatedUser = getFederatedUser();
475+
RealmModel realm = getRealm();
476+
Set<FederatedIdentityModel> federatedIdentities = this.session.users().getFederatedIdentities(federatedUser, realm);
473477

474-
UserModel federatedUser = getFederatedUser();
475-
RealmModel realm = getRealm();
476-
Set<FederatedIdentityModel> federatedIdentities = this.session.users().getFederatedIdentities(federatedUser, realm);
478+
assertFalse(federatedIdentities.isEmpty());
479+
assertEquals(1, federatedIdentities.size());
477480

478-
assertFalse(federatedIdentities.isEmpty());
479-
assertEquals(1, federatedIdentities.size());
481+
FederatedIdentityModel identityModel = federatedIdentities.iterator().next();
480482

481-
FederatedIdentityModel identityModel = federatedIdentities.iterator().next();
482-
483-
assertNotNull(identityModel.getToken());
484-
485-
UserSessionStatusServlet.UserSessionStatus userSessionStatus = retrieveSessionStatus();
486-
String accessToken = userSessionStatus.getAccessTokenString();
487-
URI tokenEndpointUrl = Urls.identityProviderRetrieveToken(BASE_URI, getProviderId(), realm.getName());
488-
final String authHeader = "Bearer " + accessToken;
489-
ClientRequestFilter authFilter = new ClientRequestFilter() {
490-
@Override
491-
public void filter(ClientRequestContext requestContext) throws IOException {
492-
requestContext.getHeaders().add(HttpHeaders.AUTHORIZATION, authHeader);
493-
}
494-
};
495-
Client client = ClientBuilder.newBuilder().register(authFilter).build();
496-
WebTarget tokenEndpoint = client.target(tokenEndpointUrl);
497-
Response response = tokenEndpoint.request().get();
498-
assertEquals(Response.Status.OK.getStatusCode(), response.getStatus());
499-
assertNotNull(response.readEntity(String.class));
500-
revokeGrant();
483+
assertNotNull(identityModel.getToken());
501484

485+
UserSessionStatusServlet.UserSessionStatus userSessionStatus = retrieveSessionStatus();
486+
String accessToken = userSessionStatus.getAccessTokenString();
487+
URI tokenEndpointUrl = Urls.identityProviderRetrieveToken(BASE_URI, getProviderId(), realm.getName());
488+
final String authHeader = "Bearer " + accessToken;
489+
ClientRequestFilter authFilter = new ClientRequestFilter() {
490+
@Override
491+
public void filter(ClientRequestContext requestContext) throws IOException {
492+
requestContext.getHeaders().add(HttpHeaders.AUTHORIZATION, authHeader);
493+
}
494+
};
495+
Client client = ClientBuilder.newBuilder().register(authFilter).build();
496+
WebTarget tokenEndpoint = client.target(tokenEndpointUrl);
497+
Response response = tokenEndpoint.request().get();
498+
assertEquals(Response.Status.OK.getStatusCode(), response.getStatus());
499+
assertNotNull(response.readEntity(String.class));
500+
revokeGrant();
502501

503-
driver.navigate().to("http://localhost:8081/test-app/logout");
504-
String currentUrl = this.driver.getCurrentUrl();
505-
System.out.println("after logout currentUrl: " + currentUrl);
506-
assertTrue(currentUrl.startsWith("http://localhost:8081/auth/realms/realm-with-broker/protocol/openid-connect/auth"));
507-
508-
unconfigureUserRetrieveToken("test-user");
509-
loginIDP("test-user");
510-
//authenticateWithIdentityProvider(identityProviderModel, "test-user");
511-
assertEquals("http://localhost:8081/test-app", driver.getCurrentUrl());
512-
513-
userSessionStatus = retrieveSessionStatus();
514-
accessToken = userSessionStatus.getAccessTokenString();
515-
final String authHeader2 = "Bearer " + accessToken;
516-
ClientRequestFilter authFilter2 = new ClientRequestFilter() {
517-
@Override
518-
public void filter(ClientRequestContext requestContext) throws IOException {
519-
requestContext.getHeaders().add(HttpHeaders.AUTHORIZATION, authHeader2);
520-
}
521-
};
522-
client = ClientBuilder.newBuilder().register(authFilter2).build();
523-
tokenEndpoint = client.target(tokenEndpointUrl);
524-
response = tokenEndpoint.request().get();
525-
526-
assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus());
527502

528-
revokeGrant();
529-
driver.navigate().to("http://localhost:8081/test-app/logout");
530-
driver.navigate().to("http://localhost:8081/test-app");
503+
driver.navigate().to("http://localhost:8081/test-app/logout");
504+
String currentUrl = this.driver.getCurrentUrl();
505+
System.out.println("after logout currentUrl: " + currentUrl);
506+
assertTrue(currentUrl.startsWith("http://localhost:8081/auth/realms/realm-with-broker/protocol/openid-connect/auth"));
507+
508+
unconfigureUserRetrieveToken("test-user");
509+
loginIDP("test-user");
510+
//authenticateWithIdentityProvider(identityProviderModel, "test-user");
511+
assertEquals("http://localhost:8081/test-app", driver.getCurrentUrl());
512+
513+
userSessionStatus = retrieveSessionStatus();
514+
accessToken = userSessionStatus.getAccessTokenString();
515+
final String authHeader2 = "Bearer " + accessToken;
516+
ClientRequestFilter authFilter2 = new ClientRequestFilter() {
517+
@Override
518+
public void filter(ClientRequestContext requestContext) throws IOException {
519+
requestContext.getHeaders().add(HttpHeaders.AUTHORIZATION, authHeader2);
520+
}
521+
};
522+
client = ClientBuilder.newBuilder().register(authFilter2).build();
523+
tokenEndpoint = client.target(tokenEndpointUrl);
524+
response = tokenEndpoint.request().get();
525+
526+
assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus());
527+
528+
revokeGrant();
529+
driver.navigate().to("http://localhost:8081/test-app/logout");
530+
driver.navigate().to("http://localhost:8081/test-app");
531531

532-
assertTrue(this.driver.getCurrentUrl().startsWith("http://localhost:8081/auth/realms/realm-with-broker/protocol/openid-connect/auth"));
532+
assertTrue(this.driver.getCurrentUrl().startsWith("http://localhost:8081/auth/realms/realm-with-broker/protocol/openid-connect/auth"));
533+
} finally {
534+
setStoreToken(identityProviderModel, false);
535+
}
536+
}
537+
538+
private void setStoreToken(IdentityProviderModel identityProviderModel, boolean storeToken) {
539+
identityProviderModel.setStoreToken(storeToken);
540+
getRealm().updateIdentityProvider(identityProviderModel);
541+
542+
brokerServerRule.stopSession(session, storeToken);
543+
session = brokerServerRule.startSession();
533544
}
534545

535546
protected abstract void doAssertTokenRetrieval(String pageSource);

0 commit comments

Comments
 (0)