Skip to content

Commit 41e98c7

Browse files
author
Nikita Konev
committed
Add MultipartFile support to WebMVC
1 parent b07a1df commit 41e98c7

File tree

2 files changed

+197
-0
lines changed

2 files changed

+197
-0
lines changed

spring-graphql/src/main/java/org/springframework/graphql/server/WebGraphQlRequest.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,34 @@ public WebGraphQlRequest(
6868
this.headers = headers;
6969
}
7070

71+
/**
72+
* Create an instance.
73+
* @param uri the URL for the HTTP request
74+
* @param headers the HTTP request headers
75+
* @param query GraphQL's query
76+
* @param operationName GraphQL's operation name
77+
* @param variables GraphQL's variables map
78+
* @param extensions GraphQL's extensions map
79+
* @param id an identifier for the GraphQL request
80+
* @param locale the locale from the HTTP request, if any
81+
*/
82+
public WebGraphQlRequest(
83+
URI uri, HttpHeaders headers,
84+
String query,
85+
String operationName,
86+
Map<String, Object> variables,
87+
Map<String, Object> extensions,
88+
String id, @Nullable Locale locale) {
89+
90+
super(query, operationName, variables, extensions, id, locale);
91+
92+
Assert.notNull(uri, "URI is required'");
93+
Assert.notNull(headers, "HttpHeaders is required'");
94+
95+
this.uri = UriComponentsBuilder.fromUri(uri).build(true);
96+
this.headers = headers;
97+
}
98+
7199
@SuppressWarnings("unchecked")
72100
private static <T> T getKey(String key, Map<String, Object> body) {
73101
if (key.equals("query") && !StringUtils.hasText((String) body.get(key))) {

spring-graphql/src/main/java/org/springframework/graphql/server/webmvc/GraphQlHttpHandler.java

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,20 @@
1818

1919
import java.io.IOException;
2020
import java.util.Arrays;
21+
import java.util.HashMap;
2122
import java.util.List;
2223
import java.util.Map;
24+
import java.util.Optional;
25+
import java.util.regex.Pattern;
2326

2427
import javax.servlet.ServletException;
2528

29+
import com.fasterxml.jackson.core.JsonProcessingException;
30+
import com.fasterxml.jackson.core.type.TypeReference;
2631
import org.apache.commons.logging.Log;
2732
import org.apache.commons.logging.LogFactory;
33+
import org.springframework.web.multipart.MultipartFile;
34+
import org.springframework.web.multipart.support.AbstractMultipartHttpServletRequest;
2835
import reactor.core.publisher.Mono;
2936

3037
import org.springframework.context.i18n.LocaleContextHolder;
@@ -102,6 +109,88 @@ public ServerResponse handleRequest(ServerRequest serverRequest) throws ServletE
102109
return ServerResponse.async(responseMono);
103110
}
104111

112+
public ServerResponse handleMultipartRequest(ServerRequest serverRequest) throws ServletException {
113+
Optional<String> operation = serverRequest.param("operations");
114+
Optional<String> mapParam = serverRequest.param("map");
115+
Map<String, Object> inputQuery = readJson(operation, new TypeReference<>() {});
116+
final Map<String, Object> queryVariables;
117+
if (inputQuery.containsKey("variables")) {
118+
queryVariables = (Map<String, Object>)inputQuery.get("variables");
119+
} else {
120+
queryVariables = new HashMap<>();
121+
}
122+
Map<String, Object> extensions = new HashMap<>();
123+
if (inputQuery.containsKey("extensions")) {
124+
extensions = (Map<String, Object>)inputQuery.get("extensions");
125+
}
126+
127+
Map<String, MultipartFile> fileParams = getMultipartMap(serverRequest);
128+
Map<String, List<String>> fileMapInput = readJson(mapParam, new TypeReference<>() {});
129+
fileMapInput.forEach((String fileKey, List<String> objectPaths) -> {
130+
MultipartFile file = fileParams.get(fileKey);
131+
if (file != null) {
132+
objectPaths.forEach((String objectPath) -> {
133+
MultipartVariableMapper.mapVariable(
134+
objectPath,
135+
queryVariables,
136+
file
137+
);
138+
});
139+
}
140+
});
141+
142+
String query = (String) inputQuery.get("query");
143+
String opName = (String) inputQuery.get("operationName");
144+
145+
WebGraphQlRequest graphQlRequest = new WebGraphQlRequest(
146+
serverRequest.uri(), serverRequest.headers().asHttpHeaders(),
147+
query,
148+
opName,
149+
queryVariables,
150+
extensions,
151+
this.idGenerator.generateId().toString(), LocaleContextHolder.getLocale());
152+
153+
if (logger.isDebugEnabled()) {
154+
logger.debug("Executing: " + graphQlRequest);
155+
}
156+
157+
Mono<ServerResponse> responseMono = this.graphQlHandler.handleRequest(graphQlRequest)
158+
.map(response -> {
159+
if (logger.isDebugEnabled()) {
160+
logger.debug("Execution complete");
161+
}
162+
ServerResponse.BodyBuilder builder = ServerResponse.ok();
163+
builder.headers(headers -> headers.putAll(response.getResponseHeaders()));
164+
builder.contentType(selectResponseMediaType(serverRequest));
165+
return builder.body(response.toMap());
166+
});
167+
168+
return ServerResponse.async(responseMono);
169+
}
170+
171+
private <T> T readJson(Optional<String> string, TypeReference<T> t) {
172+
Map<String, Object> map = new HashMap<>();
173+
if (string.isPresent()) {
174+
try {
175+
return objectMapper.readValue(string.get(), t);
176+
} catch (JsonProcessingException e) {
177+
throw new RuntimeException(e);
178+
}
179+
}
180+
return (T)map;
181+
}
182+
183+
private static Map<String, MultipartFile> getMultipartMap(ServerRequest request) {
184+
try {
185+
AbstractMultipartHttpServletRequest abstractMultipartHttpServletRequest =
186+
(AbstractMultipartHttpServletRequest) request.servletRequest();
187+
return abstractMultipartHttpServletRequest.getFileMap();
188+
}
189+
catch (RuntimeException ex) {
190+
throw new ServerWebInputException("Error while reading request parts", null, ex);
191+
}
192+
}
193+
105194
private static Map<String, Object> readBody(ServerRequest request) throws ServletException {
106195
try {
107196
return request.body(MAP_PARAMETERIZED_TYPE_REF);
@@ -121,3 +210,83 @@ private static MediaType selectResponseMediaType(ServerRequest serverRequest) {
121210
}
122211

123212
}
213+
214+
// As in DGS, this is borrowed from https://github.com/graphql-java-kickstart/graphql-java-servlet/blob/eb4dfdb5c0198adc1b4d4466c3b4ea4a77def5d1/graphql-java-servlet/src/main/java/graphql/kickstart/servlet/core/internal/VariableMapper.java
215+
class MultipartVariableMapper {
216+
217+
private static final Pattern PERIOD = Pattern.compile("\\.");
218+
219+
private static final Mapper<Map<String, Object>> MAP_MAPPER =
220+
new Mapper<Map<String, Object>>() {
221+
@Override
222+
public Object set(Map<String, Object> location, String target, MultipartFile value) {
223+
return location.put(target, value);
224+
}
225+
226+
@Override
227+
public Object recurse(Map<String, Object> location, String target) {
228+
return location.get(target);
229+
}
230+
};
231+
private static final Mapper<List<Object>> LIST_MAPPER =
232+
new Mapper<List<Object>>() {
233+
@Override
234+
public Object set(List<Object> location, String target, MultipartFile value) {
235+
return location.set(Integer.parseInt(target), value);
236+
}
237+
238+
@Override
239+
public Object recurse(List<Object> location, String target) {
240+
return location.get(Integer.parseInt(target));
241+
}
242+
};
243+
244+
@SuppressWarnings({"unchecked", "rawtypes"})
245+
public static void mapVariable(String objectPath, Map<String, Object> variables, MultipartFile part) {
246+
String[] segments = PERIOD.split(objectPath);
247+
248+
if (segments.length < 2) {
249+
throw new RuntimeException("object-path in map must have at least two segments");
250+
} else if (!"variables".equals(segments[0])) {
251+
throw new RuntimeException("can only map into variables");
252+
}
253+
254+
Object currentLocation = variables;
255+
for (int i = 1; i < segments.length; i++) {
256+
String segmentName = segments[i];
257+
Mapper mapper = determineMapper(currentLocation, objectPath, segmentName);
258+
259+
if (i == segments.length - 1) {
260+
if (null != mapper.set(currentLocation, segmentName, part)) {
261+
throw new RuntimeException("expected null value when mapping " + objectPath);
262+
}
263+
} else {
264+
currentLocation = mapper.recurse(currentLocation, segmentName);
265+
if (null == currentLocation) {
266+
throw new RuntimeException(
267+
"found null intermediate value when trying to map " + objectPath);
268+
}
269+
}
270+
}
271+
}
272+
273+
private static Mapper<?> determineMapper(
274+
Object currentLocation, String objectPath, String segmentName) {
275+
if (currentLocation instanceof Map) {
276+
return MAP_MAPPER;
277+
} else if (currentLocation instanceof List) {
278+
return LIST_MAPPER;
279+
}
280+
281+
throw new RuntimeException(
282+
"expected a map or list at " + segmentName + " when trying to map " + objectPath);
283+
}
284+
285+
interface Mapper<T> {
286+
287+
Object set(T location, String target, MultipartFile value);
288+
289+
Object recurse(T location, String target);
290+
}
291+
}
292+

0 commit comments

Comments
 (0)