add SEA-LION support #6448


Merged
merged 6 commits into ggml-org:master on Apr 3, 2024
Conversation

@bryanSwk (Contributor) commented Apr 3, 2024

This PR adds support for SEA-LION models, which are based on the MPT architecture with added bias, pos_embd, and qk_ln layers.

It builds on @datquocnguyen's PR, modified so that the pos_embd and qk_ln layers are optional.

Sanity checks have been done on SEA-LION 7B Instruct and MPT 7B Instruct.
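
For orientation, here is a minimal PyTorch sketch of what these layers add on top of MPT-style attention. The class and parameter names below are illustrative assumptions for this write-up, not the PR's actual llama.cpp implementation:

import torch
import torch.nn as nn

class SeaLionStyleAttention(nn.Module):
    """Illustrative sketch only (not the PR's llama.cpp code): SEA-LION is
    MPT plus bias on the projections, optional LayerNorm on the query/key
    projections (qk_ln), and, at the model level, optional learned
    positional embeddings (pos_embd)."""

    def __init__(self, d_model: int, qk_ln: bool = False):
        super().__init__()
        # MPT proper uses bias-free projections; SEA-LION adds the bias.
        self.wqkv = nn.Linear(d_model, 3 * d_model, bias=True)
        self.out = nn.Linear(d_model, d_model, bias=True)
        # qk_ln: normalize queries and keys before computing scores.
        self.q_ln = nn.LayerNorm(d_model) if qk_ln else nn.Identity()
        self.k_ln = nn.LayerNorm(d_model) if qk_ln else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Single-head form for brevity; x is (batch, seq, d_model).
        q, k, v = self.wqkv(x).chunk(3, dim=-1)
        q, k = self.q_ln(q), self.k_ln(k)
        scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
        return self.out(torch.softmax(scores, dim=-1) @ v)

# pos_embd, when present, is a learned nn.Embedding(n_ctx, d_model) whose
# output is added to the token embeddings before the first block.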


github-actions bot commented Apr 3, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3: 494 iterations 🚀

  • Concurrent users: 8, duration: 10m
  • HTTP request: avg=9475.79ms p(90)=26962.72ms fails=0, finish reason: stop=494 truncated=0
  • Prompt processing (pp): avg=243.43tk/s p(90)=739.3tk/s total=197.3tk/s
  • Token generation (tg): avg=101.53tk/s p(90)=286.82tk/s total=129.45tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=master commit=d9e48194e49ecf223d42730d6ed8f7169978b736
Time series

[Charts: llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 494 iterations. Panels: prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, requests_processing.]

@ggerganov ggerganov merged commit bb43cf7 into ggml-org:master on Apr 3, 2024
def set_vocab(self):
    try:
        self._set_vocab_gpt2()
    except:

@HanClinto (Collaborator) commented Apr 4, 2024

@bryanSwk Why would _set_vocab_gpt2() ever fail, and why did you add this except? I'm trying to understand what the except clause is doing here, whether we should have it, and whether it should be qualified a bit more. This open-ended except is breaking the CI linter. We can fix that by changing it to except Exception:, but I'd prefer to understand why this branch is here.

@bryanSwk (Contributor, Author)

Hi @HanClinto, the intention of this try-except is to differentiate the SEA-LION variant of MPT, which uses an SPM (SentencePiece) tokenizer, from base MPT.

There isn't a differentiating field in config.json for SEA-LION 7B, since it also uses the MPTForCausalLM class. Hence, this is just a fallback for the SEA-LION model.

@HanClinto (Collaborator)

Very helpful, thank you!

Do you happen to know what kind of exception it will throw in that instance? If not, I can just change it to except Exception: and it should still pass lint.
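
For reference, a minimal sketch of the narrower form under discussion. The diff above is truncated at the except, so the fallback body shown here is an assumption based on this thread, with _set_vocab_sentencepiece assumed to be the existing SentencePiece vocab helper in the conversion script:

def set_vocab(self):
    try:
        # BPE/GPT-2-style vocab: the path taken by base MPT checkpoints.
        self._set_vocab_gpt2()
    except Exception:
        # Assumed fallback body: SEA-LION ships a SentencePiece (SPM)
        # tokenizer, so fall back to the SentencePiece path when the
        # GPT-2 path fails.
        self._set_vocab_sentencepiece()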

@bryanSwk (Contributor, Author)

Yes, except Exception: will work; you can go ahead and make the change.

Thank you!

@HanClinto (Collaborator)

Sounds good! Feel free to make any suggestions or wording changes on this PR:
#6470

I'm not very familiar with SEA-LION, so I welcome any adjustments you may have that would make things clearer. :)

tybalex pushed a commit to rubra-ai/tools.cpp that referenced this pull request Apr 17, 2024
* initial commit for sealion support

* add sealion support

* minor fix

* q/k ln and pos_embd only if required

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <[email protected]>

* minor : clear whitespaces

---------

Co-authored-by: bryan <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>