Skip to content

Commit 2606c48

Browse files
committed
Restrict ranges of extension maps
To prevent copy statements from accessing arrays out of bounds, ranges of their extension maps are restricted, according to the constraints of domains. Reviewed-by: Michael Kruse <[email protected]> Differential Revision: https://reviews.llvm.org/D25655 llvm-svn: 289815
1 parent 2db6045 commit 2606c48

File tree

2 files changed

+147
-1
lines changed

2 files changed

+147
-1
lines changed

polly/lib/Transform/ScheduleOptimizer.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,8 @@ createExtensionNode(__isl_take isl_schedule_node *Node,
851851
static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern(
852852
__isl_take isl_schedule_node *Node, __isl_take isl_map *MapOldIndVar,
853853
MicroKernelParamsTy MicroParams, MacroKernelParamsTy MacroParams) {
854+
// Check whether memory accesses of the SCoP statement correspond to
855+
// the matrix multiplication pattern and if this is true, obtain them.
854856
auto InputDimsId = isl_map_get_tuple_id(MapOldIndVar, isl_dim_in);
855857
auto *Stmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
856858
isl_id_free(InputDimsId);
@@ -860,6 +862,9 @@ static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern(
860862
isl_map_free(MapOldIndVar);
861863
return Node;
862864
}
865+
866+
// Create a copy statement that corresponds to the memory access to the
867+
// matrix B, the second operand of the matrix multiplication.
863868
Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
864869
Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
865870
Node = isl_schedule_node_parent(Node);
@@ -879,10 +884,19 @@ static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern(
879884
isl_map_move_dims(ExtMap, isl_dim_in, 2, isl_dim_out, 0, 1);
880885
ExtMap = isl_map_project_out(ExtMap, isl_dim_in, 2, 1);
881886
auto *Domain = Stmt->getDomain();
887+
888+
// Restrict the domains of the copy statements so that they only execute when
889+
// their originating statement is also executed.
890+
auto *DomainId = isl_set_get_tuple_id(Domain);
882891
auto *NewStmt = Stmt->getParent()->addScopStmt(
883892
OldAcc, MemAccessB->getAccessRelation(), isl_set_copy(Domain));
893+
ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, isl_id_copy(DomainId));
894+
ExtMap = isl_map_intersect_range(ExtMap, isl_set_copy(Domain));
884895
ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
885896
Node = createExtensionNode(Node, ExtMap);
897+
898+
// Create a copy statement that corresponds to the memory access
899+
// to the matrix A, the first operand of the matrix multiplication.
886900
Node = isl_schedule_node_child(Node, 0);
887901
AccRel = getMatMulAccRel(MapOldIndVar, MacroParams.Kc, 4, 6);
888902
FirstDimSize = MacroParams.Mc * MacroParams.Kc / MicroParams.Mr;
@@ -896,7 +910,12 @@ static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern(
896910
isl_map_move_dims(ExtMap, isl_dim_out, 0, isl_dim_in, 0, 1);
897911
isl_map_move_dims(ExtMap, isl_dim_in, 2, isl_dim_out, 0, 1);
898912
NewStmt = Stmt->getParent()->addScopStmt(
899-
OldAcc, MemAccessA->getAccessRelation(), Domain);
913+
OldAcc, MemAccessA->getAccessRelation(), isl_set_copy(Domain));
914+
915+
// Restrict the domains of the copy statements so that they only execute when
916+
// their originating statement is also executed.
917+
ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, DomainId);
918+
ExtMap = isl_map_intersect_range(ExtMap, Domain);
900919
ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
901920
Node = createExtensionNode(Node, ExtMap);
902921
Node = isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
; RUN: opt %loadPolly -polly-opt-isl -polly-pattern-matching-based-opts=true -polly-target-througput-vector-fma=1 -polly-target-latency-vector-fma=8 -polly-target-cache-level-associativity=8,8 -polly-target-cache-level-sizes=32768,262144 -polly-ast -analyze < %s | FileCheck %s
2+
;
3+
; /* C := alpha*A*B + beta*C */
4+
; /* _PB_NK % Kc != 0 */
5+
; for (i = 0; i < _PB_NI; i++)
6+
; for (j = 0; j < _PB_NJ; j++)
7+
; {
8+
; C[i][j] *= beta;
9+
; for (k = 0; k < _PB_NK; ++k)
10+
; C[i][j] += alpha * A[i][k] * B[k][j];
11+
; }
12+
;
13+
; CHECK: {
14+
; CHECK-NEXT: // 1st level tiling - Tiles
15+
; CHECK-NEXT: for (int c0 = 0; c0 <= 32; c0 += 1)
16+
; CHECK-NEXT: for (int c1 = 0; c1 <= 32; c1 += 1) {
17+
; CHECK-NEXT: // 1st level tiling - Points
18+
; CHECK-NEXT: for (int c2 = 0; c2 <= 31; c2 += 1)
19+
; CHECK-NEXT: for (int c3 = 0; c3 <= 31; c3 += 1)
20+
; CHECK-NEXT: Stmt_bb9(32 * c0 + c2, 32 * c1 + c3);
21+
; CHECK-NEXT: }
22+
; CHECK-NEXT: // 1st level tiling - Tiles
23+
; CHECK-NEXT: for (int c0 = 0; c0 <= 65; c0 += 1)
24+
; CHECK-NEXT: for (int c1 = 0; c1 <= 3; c1 += 1) {
25+
; CHECK-NEXT: for (int c3 = 16 * c0; c3 <= 16 * c0 + 15; c3 += 1)
26+
; CHECK-NEXT: for (int c4 = 256 * c1; c4 <= min(1022, 256 * c1 + 255); c4 += 1)
27+
; CHECK-NEXT: CopyStmt_0(0, c3, c4);
28+
; CHECK-NEXT: for (int c2 = 0; c2 <= 10; c2 += 1) {
29+
; CHECK-NEXT: for (int c3 = 96 * c2; c3 <= 96 * c2 + 95; c3 += 1)
30+
; CHECK-NEXT: for (int c5 = 256 * c1; c5 <= min(1022, 256 * c1 + 255); c5 += 1)
31+
; CHECK-NEXT: CopyStmt_1(c3, 0, c5);
32+
; CHECK-NEXT: // 1st level tiling - Points
33+
; CHECK-NEXT: // Register tiling - Tiles
34+
; CHECK-NEXT: for (int c3 = 0; c3 <= 1; c3 += 1)
35+
; CHECK-NEXT: for (int c4 = 0; c4 <= 23; c4 += 1)
36+
; CHECK-NEXT: for (int c5 = 0; c5 <= min(255, -256 * c1 + 1022); c5 += 1) {
37+
; CHECK-NEXT: // Register tiling - Points
38+
; CHECK-NEXT: // 1st level tiling - Tiles
39+
; CHECK-NEXT: // 1st level tiling - Points
40+
; CHECK-NEXT: {
41+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4, 16 * c0 + 8 * c3, 256 * c1 + c5);
42+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4, 16 * c0 + 8 * c3 + 1, 256 * c1 + c5);
43+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4, 16 * c0 + 8 * c3 + 2, 256 * c1 + c5);
44+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4, 16 * c0 + 8 * c3 + 3, 256 * c1 + c5);
45+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4, 16 * c0 + 8 * c3 + 4, 256 * c1 + c5);
46+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4, 16 * c0 + 8 * c3 + 5, 256 * c1 + c5);
47+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4, 16 * c0 + 8 * c3 + 6, 256 * c1 + c5);
48+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4, 16 * c0 + 8 * c3 + 7, 256 * c1 + c5);
49+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 1, 16 * c0 + 8 * c3, 256 * c1 + c5);
50+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 1, 16 * c0 + 8 * c3 + 1, 256 * c1 + c5);
51+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 1, 16 * c0 + 8 * c3 + 2, 256 * c1 + c5);
52+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 1, 16 * c0 + 8 * c3 + 3, 256 * c1 + c5);
53+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 1, 16 * c0 + 8 * c3 + 4, 256 * c1 + c5);
54+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 1, 16 * c0 + 8 * c3 + 5, 256 * c1 + c5);
55+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 1, 16 * c0 + 8 * c3 + 6, 256 * c1 + c5);
56+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 1, 16 * c0 + 8 * c3 + 7, 256 * c1 + c5);
57+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 2, 16 * c0 + 8 * c3, 256 * c1 + c5);
58+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 2, 16 * c0 + 8 * c3 + 1, 256 * c1 + c5);
59+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 2, 16 * c0 + 8 * c3 + 2, 256 * c1 + c5);
60+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 2, 16 * c0 + 8 * c3 + 3, 256 * c1 + c5);
61+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 2, 16 * c0 + 8 * c3 + 4, 256 * c1 + c5);
62+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 2, 16 * c0 + 8 * c3 + 5, 256 * c1 + c5);
63+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 2, 16 * c0 + 8 * c3 + 6, 256 * c1 + c5);
64+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 2, 16 * c0 + 8 * c3 + 7, 256 * c1 + c5);
65+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 3, 16 * c0 + 8 * c3, 256 * c1 + c5);
66+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 3, 16 * c0 + 8 * c3 + 1, 256 * c1 + c5);
67+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 3, 16 * c0 + 8 * c3 + 2, 256 * c1 + c5);
68+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 3, 16 * c0 + 8 * c3 + 3, 256 * c1 + c5);
69+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 3, 16 * c0 + 8 * c3 + 4, 256 * c1 + c5);
70+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 3, 16 * c0 + 8 * c3 + 5, 256 * c1 + c5);
71+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 3, 16 * c0 + 8 * c3 + 6, 256 * c1 + c5);
72+
; CHECK-NEXT: Stmt_Copy_0(96 * c2 + 4 * c4 + 3, 16 * c0 + 8 * c3 + 7, 256 * c1 + c5);
73+
; CHECK-NEXT: }
74+
; CHECK-NEXT: }
75+
; CHECK-NEXT: }
76+
; CHECK-NEXT: }
77+
; CHECK-NEXT: }
78+
;
79+
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
80+
target triple = "x86_64-unknown-unknown"
81+
82+
define internal void @kernel_gemm(i32 %arg, i32 %arg1, i32 %arg2, double %arg3, double %arg4, [1056 x double]* %arg5, [1023 x double]* %arg6, [1056 x double]* %arg7) #0 {
83+
bb:
84+
br label %bb8
85+
86+
bb8: ; preds = %bb29, %bb
87+
%tmp = phi i64 [ 0, %bb ], [ %tmp30, %bb29 ]
88+
br label %bb9
89+
90+
bb9: ; preds = %bb26, %bb8
91+
%tmp10 = phi i64 [ 0, %bb8 ], [ %tmp27, %bb26 ]
92+
%tmp11 = getelementptr inbounds [1056 x double], [1056 x double]* %arg5, i64 %tmp, i64 %tmp10
93+
%tmp12 = load double, double* %tmp11, align 8
94+
%tmp13 = fmul double %tmp12, %arg4
95+
store double %tmp13, double* %tmp11, align 8
96+
br label %Copy_0
97+
98+
Copy_0: ; preds = %Copy_0, %bb9
99+
%tmp15 = phi i64 [ 0, %bb9 ], [ %tmp24, %Copy_0 ]
100+
%tmp16 = getelementptr inbounds [1023 x double], [1023 x double]* %arg6, i64 %tmp, i64 %tmp15
101+
%tmp17 = load double, double* %tmp16, align 8
102+
%tmp18 = fmul double %tmp17, %arg3
103+
%tmp19 = getelementptr inbounds [1056 x double], [1056 x double]* %arg7, i64 %tmp15, i64 %tmp10
104+
%tmp20 = load double, double* %tmp19, align 8
105+
%tmp21 = fmul double %tmp18, %tmp20
106+
%tmp22 = load double, double* %tmp11, align 8
107+
%tmp23 = fadd double %tmp22, %tmp21
108+
store double %tmp23, double* %tmp11, align 8
109+
%tmp24 = add nuw nsw i64 %tmp15, 1
110+
%tmp25 = icmp ne i64 %tmp24, 1023
111+
br i1 %tmp25, label %Copy_0, label %bb26
112+
113+
bb26: ; preds = %Copy_0
114+
%tmp27 = add nuw nsw i64 %tmp10, 1
115+
%tmp28 = icmp ne i64 %tmp27, 1056
116+
br i1 %tmp28, label %bb9, label %bb29
117+
118+
bb29: ; preds = %bb26
119+
%tmp30 = add nuw nsw i64 %tmp, 1
120+
%tmp31 = icmp ne i64 %tmp30, 1056
121+
br i1 %tmp31, label %bb8, label %bb32
122+
123+
bb32: ; preds = %bb29
124+
ret void
125+
}
126+
127+
attributes #0 = { nounwind uwtable "target-cpu"="x86-64" "target-features"="+aes,+avx,+cmov,+cx16,+fxsr,+mmx,+pclmul,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave,+xsaveopt" }

0 commit comments

Comments
 (0)