|
17 | 17 | package org.springframework.graphql.server.webmvc;
|
18 | 18 |
|
19 | 19 | import java.io.IOException;
|
| 20 | +import java.io.InputStream; |
| 21 | +import java.lang.reflect.Type; |
20 | 22 | import java.util.Arrays;
|
21 | 23 | import java.util.List;
|
22 | 24 | import java.util.Map;
|
|
25 | 27 |
|
26 | 28 | import javax.servlet.ServletException;
|
27 | 29 | import javax.servlet.http.HttpServletRequest;
|
| 30 | +import javax.servlet.http.Part; |
28 | 31 |
|
29 | 32 | import com.fasterxml.jackson.databind.ObjectMapper;
|
30 | 33 | import org.apache.commons.logging.Log;
|
@@ -124,27 +127,26 @@ public ServerResponse handleRequest(ServerRequest serverRequest) throws ServletE
|
124 | 127 | }
|
125 | 128 |
|
126 | 129 | public ServerResponse handleMultipartRequest(ServerRequest serverRequest) throws ServletException {
|
127 |
| - Optional<String> operation = serverRequest.param("operations"); |
128 |
| - Optional<String> mapParam = serverRequest.param("map"); |
| 130 | + HttpServletRequest httpServletRequest = serverRequest.servletRequest(); |
129 | 131 |
|
130 |
| - Map<String, Object> inputQuery = operation |
131 |
| - .map(part -> |
132 |
| - partReader.<Map<String, Object>>readPart(part, MAP_PARAMETERIZED_TYPE_REF.getType()) |
133 |
| - ) |
134 |
| - .orElse(new HashMap<>()); |
| 132 | + Map<String, Object> inputQuery = Optional.ofNullable(this.<Map<String, Object>>deserializePart( |
| 133 | + httpServletRequest, |
| 134 | + "operations", |
| 135 | + MAP_PARAMETERIZED_TYPE_REF.getType() |
| 136 | + )).orElse(new HashMap<>()); |
135 | 137 |
|
136 | 138 | final Map<String, Object> queryVariables = getFromMapOrEmpty(inputQuery, "variables");
|
137 | 139 | final Map<String, Object> extensions = getFromMapOrEmpty(inputQuery, "extensions");
|
138 | 140 |
|
139 |
| - Map<String, MultipartFile> fileParams = readMultipartFiles(serverRequest); |
| 141 | + Map<String, MultipartFile> fileParams = readMultipartFiles(httpServletRequest); |
140 | 142 |
|
141 |
| - Map<String, List<String>> fileMapInput = |
142 |
| - mapParam.map(part -> |
143 |
| - partReader.<Map<String, List<String>>>readPart(part, LIST_PARAMETERIZED_TYPE_REF.getType()) |
144 |
| - ) |
145 |
| - .orElse(new HashMap<>()); |
| 143 | + Map<String, List<String>> fileMappings = Optional.ofNullable(this.<Map<String, List<String>>>deserializePart( |
| 144 | + httpServletRequest, |
| 145 | + "map", |
| 146 | + LIST_PARAMETERIZED_TYPE_REF.getType() |
| 147 | + )).orElse(new HashMap<>()); |
146 | 148 |
|
147 |
| - fileMapInput.forEach((String fileKey, List<String> objectPaths) -> { |
| 149 | + fileMappings.forEach((String fileKey, List<String> objectPaths) -> { |
148 | 150 | MultipartFile file = fileParams.get(fileKey);
|
149 | 151 | if (file != null) {
|
150 | 152 | objectPaths.forEach((String objectPath) -> {
|
@@ -186,16 +188,31 @@ public ServerResponse handleMultipartRequest(ServerRequest serverRequest) throws
|
186 | 188 | return ServerResponse.async(responseMono);
|
187 | 189 | }
|
188 | 190 |
|
189 |
| - private Map<String, Object> getFromMapOrEmpty(Map<String, Object> input, String key) { |
| 191 | + private <T> T deserializePart(HttpServletRequest httpServletRequest, String name, Type type) { |
| 192 | + try { |
| 193 | + Part part = httpServletRequest.getPart(name); |
| 194 | + if (part == null) { |
| 195 | + return null; |
| 196 | + } |
| 197 | + try(InputStream inputStream = part.getInputStream()) { |
| 198 | + return partReader.readPart(inputStream, type); |
| 199 | + } catch (IOException e) { |
| 200 | + throw new RuntimeException(e); |
| 201 | + } |
| 202 | + } catch (IOException | ServletException e) { |
| 203 | + throw new RuntimeException(e); |
| 204 | + } |
| 205 | + } |
| 206 | + |
| 207 | + private Map<String, Object> getFromMapOrEmpty(Map<String, Object> input, String key) { |
190 | 208 | if (input.containsKey(key)) {
|
191 | 209 | return (Map<String, Object>)input.get(key);
|
192 | 210 | } else {
|
193 | 211 | return new HashMap<>();
|
194 | 212 | }
|
195 | 213 | }
|
196 | 214 |
|
197 |
| - private static Map<String, MultipartFile> readMultipartFiles(ServerRequest request) { |
198 |
| - HttpServletRequest httpServletRequest = request.servletRequest(); |
| 215 | + private static Map<String, MultipartFile> readMultipartFiles(HttpServletRequest httpServletRequest) { |
199 | 216 | Assert.isInstanceOf(MultipartHttpServletRequest.class, httpServletRequest,
|
200 | 217 | "Request should be of type MultipartHttpServletRequest");
|
201 | 218 | MultipartHttpServletRequest multipartHttpServletRequest = (MultipartHttpServletRequest) httpServletRequest;
|
|
0 commit comments