
[libspirv][ptx-nvidiacl] Change __clc__group_scratch size to 32 x i128 #18431

Merged

libclc/libspirv/lib/amdgcn-amdhsa/group/collectives_helpers.ll (57 changes: 21 additions & 36 deletions)
@@ -1,62 +1,47 @@
; 32 storage locations is sufficient for all current-generation AMD GPUs
; 64 bits per wavefront is sufficient for all fundamental data types
-; Reducing storage for small data types or increasing it for user-defined types
-; will likely require an additional pass to track group algorithm usage
-@__clc__group_scratch = internal addrspace(3) global [32 x i64] undef, align 1
+@__clc__group_scratch_i1 = internal addrspace(3) global [32 x i1] poison, align 1
+@__clc__group_scratch_i8 = internal addrspace(3) global [32 x i8] poison, align 1
+@__clc__group_scratch_i16 = internal addrspace(3) global [32 x i16] poison, align 2
+@__clc__group_scratch_i32 = internal addrspace(3) global [32 x i32] poison, align 4
+@__clc__group_scratch_i64 = internal addrspace(3) global [32 x i64] poison, align 8

-define i8 addrspace(3)* @__clc__get_group_scratch_bool() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_bool() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [32 x i64], [32 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to i8 addrspace(3)*
-  ret i8 addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i1
}

-define i8 addrspace(3)* @__clc__get_group_scratch_char() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_char() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [32 x i64], [32 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to i8 addrspace(3)*
-  ret i8 addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i8
}

-define i16 addrspace(3)* @__clc__get_group_scratch_short() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_short() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [32 x i64], [32 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to i16 addrspace(3)*
-  ret i16 addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i16
}

-define i32 addrspace(3)* @__clc__get_group_scratch_int() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_int() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [32 x i64], [32 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to i32 addrspace(3)*
-  ret i32 addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i32
}

-define i64 addrspace(3)* @__clc__get_group_scratch_long() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_long() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [32 x i64], [32 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to i64 addrspace(3)*
-  ret i64 addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i64
}

-define half addrspace(3)* @__clc__get_group_scratch_half() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_half() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [32 x i64], [32 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to half addrspace(3)*
-  ret half addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i16
}

-define float addrspace(3)* @__clc__get_group_scratch_float() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_float() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [32 x i64], [32 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to float addrspace(3)*
-  ret float addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i32
}

-define double addrspace(3)* @__clc__get_group_scratch_double() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_double() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [32 x i64], [32 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to double addrspace(3)*
-  ret double addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i64
}

libclc/libspirv/lib/ptx-nvidiacl/group/collectives_helpers.ll (91 changes: 28 additions & 63 deletions)
@@ -1,97 +1,62 @@
; 32 storage locations is sufficient for all current-generation NVIDIA GPUs
; 128 bits per warp is sufficient for all fundamental data types and complex
-; Reducing storage for small data types or increasing it for user-defined types
-; will likely require an additional pass to track group algorithm usage
-@__clc__group_scratch = internal addrspace(3) global [128 x i64] undef, align 1
+@__clc__group_scratch_i1 = internal addrspace(3) global [32 x i1] poison, align 1
+@__clc__group_scratch_i8 = internal addrspace(3) global [32 x i8] poison, align 1
+@__clc__group_scratch_i16 = internal addrspace(3) global [32 x i16] poison, align 2
+@__clc__group_scratch_i32 = internal addrspace(3) global [32 x i32] poison, align 4
+@__clc__group_scratch_i64 = internal addrspace(3) global [32 x i64] poison, align 8
+@__clc__group_scratch_i128 = internal addrspace(3) global [32 x i128] poison, align 8

-define i8 addrspace(3)* @__clc__get_group_scratch_bool() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_bool() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
Review thread on this line:

Contributor:
We could/should probably rewrite this to use opaque pointers while we're here. If we do, almost all of this can go away. You could just return ptr addrspace(3) @__clc__group_scratch for every overload, I think?

Contributor Author:
> We could/should probably rewrite this to use opaque pointers while we're here.

done

> If we do, almost all of this can go away. You could just return ptr addrspace(3) @__clc__group_scratch for every overload, I think?

I added more global variables for different sizes. It should resolve the comment at the top of the file: "Reducing storage for small data types or increasing it for user-defined types will likely require an additional pass to track group algorithm usage".

Contributor:
Thanks. I take it the scratch memory isn't meant to be shared between the different types? If so we couldn't have separate globals in this way.
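For context, a minimal sketch of the single-buffer alternative the reviewer describes, not part of this PR: with opaque pointers every overload could return the same untyped global. The [32 x i128] sizing here is an assumption taken from the PR title, not from the patch itself.

; Illustration only: one shared scratch buffer returned by every getter.
; The [32 x i128] size is an assumption based on the PR title.
@__clc__group_scratch = internal addrspace(3) global [32 x i128] poison, align 8

define ptr addrspace(3) @__clc__get_group_scratch_int() nounwind alwaysinline {
entry:
  ; with opaque pointers there is no bitcast; callers choose the element type
  ret ptr addrspace(3) @__clc__group_scratch
}

The patch instead keeps one global per element width (the @__clc__group_scratch_* definitions above), which is how the author addresses the old comment about per-type storage.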

-  %cast = bitcast i64 addrspace(3)* %ptr to i8 addrspace(3)*
-  ret i8 addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i1
}

-define i8 addrspace(3)* @__clc__get_group_scratch_char() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_char() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to i8 addrspace(3)*
-  ret i8 addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i8
}

-define i16 addrspace(3)* @__clc__get_group_scratch_short() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_short() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to i16 addrspace(3)*
-  ret i16 addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i16
}

-define i32 addrspace(3)* @__clc__get_group_scratch_int() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_int() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to i32 addrspace(3)*
-  ret i32 addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i32
}

-define i64 addrspace(3)* @__clc__get_group_scratch_long() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_long() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to i64 addrspace(3)*
-  ret i64 addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i64
}

-define half addrspace(3)* @__clc__get_group_scratch_half() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_half() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to half addrspace(3)*
-  ret half addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i16
}

-define float addrspace(3)* @__clc__get_group_scratch_float() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_float() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to float addrspace(3)*
-  ret float addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i32
}

-define double addrspace(3)* @__clc__get_group_scratch_double() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_double() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to double addrspace(3)*
-  ret double addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i64
}

-%complex_half = type {
-  half,
-  half
-}

-%complex_float = type {
-  float,
-  float
-}

-%complex_double = type {
-  double,
-  double
-}

-define %complex_half addrspace(3)* @__clc__get_group_scratch_complex_half() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_complex_half() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to %complex_half addrspace(3)*
-  ret %complex_half addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i32
}

-define %complex_float addrspace(3)* @__clc__get_group_scratch_complex_float() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_complex_float() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to %complex_float addrspace(3)*
-  ret %complex_float addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i64
}

-define %complex_double addrspace(3)* @__clc__get_group_scratch_complex_double() nounwind alwaysinline {
+define ptr addrspace(3) @__clc__get_group_scratch_complex_double() nounwind alwaysinline {
entry:
-  %ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
-  %cast = bitcast i64 addrspace(3)* %ptr to %complex_double addrspace(3)*
-  ret %complex_double addrspace(3)* %cast
+  ret ptr addrspace(3) @__clc__group_scratch_i128
}
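
As a closing sanity check on the new sizing: a %complex_double is two doubles, 2 x 64 = 128 bits, so one i128 slot holds one partial result, and 32 slots presumably corresponds to the maximum of 32 warps in a 1024-thread block. The sketch below is a hypothetical caller, not code from this PR; the function name and its use of the scratch buffer are invented for illustration.

%complex_double = type { double, double }

declare ptr addrspace(3) @__clc__get_group_scratch_complex_double()

; Hypothetical helper: store one warp's partial complex-double result into
; its i128-sized slot. With opaque pointers no bitcast of the base is needed.
define void @example_store_partial(%complex_double %v, i64 %warp_id) {
entry:
  %base = call ptr addrspace(3) @__clc__get_group_scratch_complex_double()
  %slot = getelementptr inbounds %complex_double, ptr addrspace(3) %base, i64 %warp_id
  store %complex_double %v, ptr addrspace(3) %slot, align 8
  ret void
}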