diff --git a/webapp/src/edu/cornell/mannlib/vitro/webapp/controller/edit/Authenticate.java b/webapp/src/edu/cornell/mannlib/vitro/webapp/controller/edit/Authenticate.java index 715dea690..387010d72 100644 --- a/webapp/src/edu/cornell/mannlib/vitro/webapp/controller/edit/Authenticate.java +++ b/webapp/src/edu/cornell/mannlib/vitro/webapp/controller/edit/Authenticate.java @@ -27,7 +27,6 @@ import com.hp.hpl.jena.ontology.OntModel; import edu.cornell.mannlib.vedit.beans.LoginStatusBean; import edu.cornell.mannlib.vitro.webapp.beans.User; -import edu.cornell.mannlib.vitro.webapp.controller.Controllers; import edu.cornell.mannlib.vitro.webapp.controller.VitroHttpServlet; import edu.cornell.mannlib.vitro.webapp.controller.VitroRequest; import edu.cornell.mannlib.vitro.webapp.controller.authenticate.Authenticator; @@ -335,12 +334,18 @@ public class Authenticate extends VitroHttpServlet { } /** - * Exit: user is logging in, so show them the login screen. + * Exit: user is still logging in, so go back to the page they were on. */ private void showLoginScreen(VitroRequest vreq, HttpServletResponse response) throws IOException { log.debug("logging in."); - response.sendRedirect(getLoginScreenUrl(vreq)); + + String referringPage = vreq.getHeader("referer"); + if (referringPage == null) { + log.warn("No referring page on the request"); + referringPage = getHomeUrl(vreq); + } + response.sendRedirect(referringPage); return; } @@ -359,13 +364,6 @@ public class Authenticate extends VitroHttpServlet { return Authenticator.getInstance(request); } - /** What's the URL for the login screen? */ - private String getLoginScreenUrl(HttpServletRequest request) { - String contextPath = request.getContextPath(); - String urlParams = "?login=block"; - return contextPath + Controllers.LOGIN + urlParams; - } - /** What's the URL for the home page? */ private String getHomeUrl(HttpServletRequest request) { return request.getContextPath(); diff --git a/webapp/test/edu/cornell/mannlib/vitro/webapp/controller/edit/AuthenticateTest.java b/webapp/test/edu/cornell/mannlib/vitro/webapp/controller/edit/AuthenticateTest.java index 8e42f700b..e8e583a7a 100644 --- a/webapp/test/edu/cornell/mannlib/vitro/webapp/controller/edit/AuthenticateTest.java +++ b/webapp/test/edu/cornell/mannlib/vitro/webapp/controller/edit/AuthenticateTest.java @@ -44,8 +44,8 @@ public class AuthenticateTest extends AbstractTestClass { private static final String USER_OLDHAND_PASSWORD = "oldHandPassword"; private static final int USER_OLDHAND_LOGIN_COUNT = 100; - private static final String URL_LOGIN_PAGE = Controllers.LOGIN - + "?login=block"; + private static final String URL_LOGIN_PAGE = "http://my.local.site/vivo/" + + Controllers.LOGIN; private static final String URL_SITE_ADMIN_PAGE = Controllers.SITE_ADMIN + "?login=block"; private static final String URL_HOME_PAGE = ""; @@ -84,6 +84,7 @@ public class AuthenticateTest extends AbstractTestClass { request.setSession(session); request.setRequestUrl(new URL("http://this.that/vivo/siteAdmin")); request.setMethod("POST"); + request.setHeader("referer", URL_LOGIN_PAGE); response = new HttpServletResponseStub(); @@ -124,7 +125,7 @@ public class AuthenticateTest extends AbstractTestClass { auth.doPost(request, response); - assertExpectedRedirect(URL_LOGIN_PAGE); + assertExpectedRedirect(URL_SITE_ADMIN_PAGE); assertNoProcessBean(); assertExpectedLoginSessions(); } @@ -197,7 +198,7 @@ public class AuthenticateTest extends AbstractTestClass { auth.doPost(request, response); assertNoProcessBean(); - assertExpectedRedirect(URL_LOGIN_PAGE); + assertExpectedRedirect(URL_SITE_ADMIN_PAGE); assertExpectedLoginSessions(USER_OLDHAND_NAME); } @@ -277,7 +278,7 @@ public class AuthenticateTest extends AbstractTestClass { auth.doPost(request, response); assertNoProcessBean(); - assertExpectedRedirect(URL_LOGIN_PAGE); + assertExpectedRedirect(URL_SITE_ADMIN_PAGE); assertExpectedLoginSessions(USER_DBA_NAME); assertExpectedPasswordChanges(USER_DBA_NAME, "NewPassword"); } @@ -351,8 +352,13 @@ public class AuthenticateTest extends AbstractTestClass { } private void assertExpectedRedirect(String path) { - assertEquals("redirect", request.getContextPath() + path, - response.getRedirectLocation()); + if (path.startsWith("http://")) { + assertEquals("absolute redirect", path, + response.getRedirectLocation()); + } else { + assertEquals("relative redirect", request.getContextPath() + path, + response.getRedirectLocation()); + } } /** This is for explicit redirect URLs that already include context. */ diff --git a/webapp/test/stubs/javax/servlet/http/HttpServletRequestStub.java b/webapp/test/stubs/javax/servlet/http/HttpServletRequestStub.java index f6f52d746..6f97991ee 100644 --- a/webapp/test/stubs/javax/servlet/http/HttpServletRequestStub.java +++ b/webapp/test/stubs/javax/servlet/http/HttpServletRequestStub.java @@ -36,10 +36,12 @@ public class HttpServletRequestStub implements HttpServletRequest { private HttpSession session; private final Map> parameters; private final Map attributes; + private final Map> headers; public HttpServletRequestStub() { parameters = new HashMap>(); attributes = new HashMap(); + headers = new HashMap>(); } public HttpServletRequestStub(Map> parameters, @@ -61,6 +63,14 @@ public class HttpServletRequestStub implements HttpServletRequest { public void setRemoteAddr(String remoteAddr) { this.remoteAddr = remoteAddr; } + + public void setHeader(String name, String value) { + name = name.toLowerCase(); + if (!headers.containsKey(name)) { + headers.put(name, new ArrayList()); + } + headers.get(name).add(value); + } public void addParameter(String name, String value) { if (!parameters.containsKey(name)) { @@ -163,6 +173,30 @@ public class HttpServletRequestStub implements HttpServletRequest { attributes.put(name, value); } + @SuppressWarnings("rawtypes") + public Enumeration getHeaderNames() { + return Collections.enumeration(headers.keySet()); + } + + public String getHeader(String name) { + name = name.toLowerCase(); + if (headers.containsKey(name)) { + return headers.get(name).get(0); + } else { + return null; + } + } + + @SuppressWarnings("rawtypes") + public Enumeration getHeaders(String name) { + name = name.toLowerCase(); + if (headers.containsKey(name)) { + return Collections.enumeration(headers.get(name)); + } else { + return Collections.enumeration(Collections.emptyList()); + } + } + // ---------------------------------------------------------------------- // Un-implemented methods // ---------------------------------------------------------------------- @@ -182,23 +216,6 @@ public class HttpServletRequestStub implements HttpServletRequest { "HttpServletRequestStub.getDateHeader() not implemented."); } - public String getHeader(String arg0) { - throw new RuntimeException( - "HttpServletRequestStub.getHeader() not implemented."); - } - - @SuppressWarnings("rawtypes") - public Enumeration getHeaderNames() { - throw new RuntimeException( - "HttpServletRequestStub.getHeaderNames() not implemented."); - } - - @SuppressWarnings("rawtypes") - public Enumeration getHeaders(String arg0) { - throw new RuntimeException( - "HttpServletRequestStub.getHeaders() not implemented."); - } - public int getIntHeader(String arg0) { throw new RuntimeException( "HttpServletRequestStub.getIntHeader() not implemented.");