Skip to content

Commit 8eb3272

Browse files
committed
[CudaMathEngine] CUBLAS_POINTER_MODE_DEVICE allows device pointers only
Signed-off-by: Kirill Golikov <[email protected]>
1 parent 5efd161 commit 8eb3272

File tree

4 files changed

+40
-20
lines changed

4 files changed

+40
-20
lines changed

NeoMathEngine/src/GPU/CUDA/CudaMathEngineCublas.cu

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,6 @@ void CCudaMathEngine::VectorDotProduct(const CConstFloatHandle& firstHandle, con
4141
GetRaw( secondHandle ), 1, GetRaw( resultHandle ) ) );
4242
}
4343

44-
void CCudaMathEngine::VectorMultiplyAndAdd( const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
45-
const CFloatHandle& resultHandle, int vectorSize, CFloatParam multParam )
46-
{
47-
ASSERT_EXPR( firstHandle.GetMathEngine() == this );
48-
ASSERT_EXPR( secondHandle.GetMathEngine() == this );
49-
ASSERT_EXPR( resultHandle.GetMathEngine() == this );
50-
SetCudaDevice( device->DeviceNumber );
51-
52-
const float* const first = GetRaw( firstHandle );
53-
const float* const second = GetRaw( secondHandle );
54-
float* const result = GetRaw( resultHandle );
55-
// cublasSaxpy allows (host or device) pointer
56-
const float* mult = multParam.Handle.IsNull() ? &multParam.Value : GetRaw( multParam.Handle );
57-
58-
if( result != first ) {
59-
ASSERT_CUDA( cudaMemcpy( result, first, vectorSize * sizeof( float ), cudaMemcpyDeviceToDevice ) );
60-
}
61-
ASSERT_CUBLAS( cublas->Saxpy( cublasHandle, vectorSize, mult, second, 1, result, 1 ) );
62-
}
6344

6445
void CCudaMathEngine::MultiplyMatrixByTransposedMatrix( const CConstFloatHandle& firstHandle, int firstHeight,
6546
int firstWidth, int firstRowSize, const CConstFloatHandle& secondHandle, int secondHeight, int secondRowSize,

NeoMathEngine/src/GPU/CUDA/CudaMathEngineVectorMath.cu

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,22 @@ void CCudaMathEngine::VectorSub(float first, const CConstFloatHandle& secondHand
951951
( first, GetRaw( secondHandle ), GetRaw(resultHandle), vectorSize);
952952
}
953953

954+
void CCudaMathEngine::VectorMultiplyAndAdd(const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
955+
const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult)
956+
{
957+
ASSERT_EXPR(firstHandle.GetMathEngine() == this);
958+
ASSERT_EXPR(secondHandle.GetMathEngine() == this);
959+
ASSERT_EXPR(resultHandle.GetMathEngine() == this);
960+
SetCudaDevice(device->DeviceNumber);
961+
962+
int blockCount = 0;
963+
int threadCount = 0;
964+
getCudaTaskGrid(blockCount, threadCount, vectorSize);
965+
966+
VectorMultiplyAndAddKernel<<<blockCount, threadCount>>>
967+
( GetRaw(firstHandle), GetRaw(secondHandle), GetRaw(resultHandle), vectorSize, mult );
968+
}
969+
954970
void CCudaMathEngine::VectorMultiplyAndSub(const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle,
955971
const CFloatHandle& resultHandle, int vectorSize, CFloatParam mult)
956972
{

NeoMathEngine/src/GPU/CUDA/Kernels/CudaVectorMathKernels.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,6 +979,16 @@ __global__ void VectorSubKernel( float first, const float* __restrict__ second,
979979
}
980980
}
981981

982+
// MultiplyAndAdd
983+
__global__ void VectorMultiplyAndAddKernel( const float* __restrict__ first,
984+
const float* __restrict__ second, float* result, int count, CCudaScalarParameter<float> mult )
985+
{
986+
int index = 0;
987+
if( GetCudaTaskIndex( count, index ) ) {
988+
result[index] = first[index] + mult * second[index];
989+
}
990+
}
991+
982992
// MultiplyAndSub
983993
__global__ void VectorMultiplyAndSubKernel(const float* __restrict__ first,
984994
const float* __restrict__ second, float* result, int count, CCudaScalarParameter<float> mult)

NeoMathEngine/test/src/inference/VectorMultiplyAndAddTest.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright © 2017-2020 ABBYY Production LLC
1+
/* Copyright © 2017-2024 ABBYY
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -38,6 +38,19 @@ static void vectorMultiplyAndAddImpl( const CTestParams& params, int seed )
3838
float expected = a[i] + mult * b[i];
3939
ASSERT_NEAR( expected, result[i], 1e-3 );
4040
}
41+
42+
{
43+
auto resultWrapper = CARRAY_FLOAT_WRAPPER( result );
44+
{
45+
float multTemp = mult;
46+
MathEngine().VectorMultiplyAndAdd( CARRAY_FLOAT_WRAPPER( a ), CARRAY_FLOAT_WRAPPER( b ), resultWrapper, vectorSize, multTemp );
47+
}
48+
}
49+
50+
for( int i = 0; i < vectorSize; i++ ) {
51+
float expected = a[i] + mult * b[i];
52+
ASSERT_NEAR( expected, result[i], 1e-3 );
53+
}
4154
}
4255

4356
//------------------------------------------------------------------------------------------------------------

0 commit comments

Comments
 (0)