Overcoming OSS Contribution Anxiety

Author

Yusuf Mohammad

Published

April 14, 2026

OSS contribution is achievable for everyone; with the help and guidance of LLMs, I believe anyone can grow into a great contributor. In my case I worked on two PRs in vLLM: one fixing the batch invariance feature, and one adding support for a new model type to that feature.

Starting Small in vLLM - My Approach

I’ve always been interested in OSS contribution. Most of the work we do is underpinned by open source software (OSS) libraries, so naturally we ought to try and give something back. Often, a contribution is the best thing we can do (alongside things like supporting OSS projects with donations). However, I’ve found it pretty difficult to get into this world: imposter syndrome, hesitation around which issue to pick, and finding issues which aren’t already claimed are some of the problems I’ve faced (and I’m sure you have too)! So, when I got my first proper code contribution to vLLM1, I was over the moon. Not only did it open the door to another contribution, but it also built my confidence.

A bit more on this: many a time I have visited a repo I use to look at the “good first issue” tag, only to find many people clamoring to claim and work on each issue. This makes it hard to get started; the place we are often recommended to start isn’t very accessible. Outside of watching the repo for issues which get this tag added, it’s hard to find them and claim them. Secondly, the issues page can be intimidating; knowing what to work on is a skill in and of itself.

So I want to give an insight into my method and how I got my first contribution, before diving into what exactly the contribution entails. I want everyone to walk away not feeling anxious about getting involved in the OSS community, and to realise that we can all contribute and improve the ecosystem for all!

Finding my First PR - Having a Personal Issue

I recently led a journal club on the following post [https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/]. As part of this I wanted to run the batch invariance feature on my local machine (I am lucky enough to have a 3090). However, upon testing, it turned out the 3090 was not a supported card for the vLLM batch invariance feature.

The Proper Way to Test Things

I learnt of this issue by first running a prompt in regular vLLM inference mode, then enabling batch invariance and running the same prompt; my goal was to demonstrate batch invariance in practice for my journal club. To my untrained eye it worked: the output matched 100/100 times. Armed with this info I went to the batch invariance issue and, seeing that the 3090 was not listed as a supported card (which surprised me), I commented boldly claiming batch invariance works on the 3090. A maintainer quickly responded; he was surprised by my result and asked me to run the unit tests (which I probably should have started with). Upon running these they failed, quickly informing me my initial result was a fluke.

My First PR - A One Line Change

This was the start of my first PR to vLLM, and in the end it turned out to be a one line fix. This small change opened Pandora’s box for me. Firstly, it allowed me to get familiar with the PR process for vLLM; small changes are good for this, as they let you work quickly, get feedback from the maintainers quickly, and thus build a holistic view of how PRs are managed before moving on to larger changes!

So, in this case it was a personal issue which led me to my first PR. I was lucky in that sense: I stumbled upon a change small enough that I felt comfortable undertaking it, and in an area I am interested in. However, I think this approach extends to cases where you are actively looking for PRs too. Looking for PRs is great; it’s an amazing way as a junior to work with very experienced engineers, and it might be your first time working with such a complex repository (which can be fun). To find such a PR, you could look for recently added features, such as batch invariance, in open source repos. This one was only implemented in around October 2025, meaning it hasn’t had time to fully mature like other features… therefore there is still work to be done and you can get involved in that work!

It could be things like testing features on whatever hardware you have available, finding where things break, and fixing those issues. Oftentimes in the ML world, older and lower powered hardware (CPU only and older GPUs) is overlooked in the testing phase, but many people still operate on such hardware, and your fixes could enable them to utilise these amazing features!

Lastly, the maintainers of open source projects are very willing to help. There are more issues than people to solve them in most repos, so any well structured PR is likely to be met with gratitude.

An Intro to vLLM and Batch Invariance

Before we go on to discuss the second PR and how I found it, let’s review what vLLM and batch invariance actually are.

“vLLM is a fast and easy-to-use library for LLM inference and serving.”2 It has a feature which enables batch invariance when serving the model of your choice, enabled with the following environment variable:

VLLM_BATCH_INVARIANT=1

This ensures that vLLM will replace the default kernels used when performing LLM inference with batch invariant kernels. Given this feature is a work in progress, one of the outstanding tasks is testing batch invariance on new models. As a new contributor to the vLLM project (and to open source projects in general), I thought this task would be a great place to start. Armed with my 3090 and a dream, I set off to see what other models vLLM batch invariance can support!

Batch Invariance

In a nutshell, batch invariance is a feature which makes LLM output deterministic. It does this by changing the GPU kernels employed in the forward pass of the LLM such that computation completes in the same manner no matter how many items are in the batch.

The key point here is that LLMs are in fact run-to-run deterministic, despite the fact that we often experience otherwise (when you ask the same question twice, how often do you receive a different answer?). So naturally we ask: what is the cause? The answer is changes to batch size. If you were to run an LLM with the same batch size each time, the answer you get would be exactly the same; however, in an LLM inference server it’s not possible to guarantee the batch size when you send off your prompt to be processed.

The issue is that changing the batch size changes the order in which the GPU organises computation across cores. By introducing batch invariance we fix the computation order and make the LLM run-to-run deterministic, with changing batch size having no effect on determinism. For more details, check out this great post: https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/.
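The cause is easiest to see in miniature: floating-point addition is not associative, so summing the same values in a different grouping can produce a bitwise-different result. Here is a tiny sketch of the effect in plain Python (the numbers are purely illustrative, not vLLM code):

```python
# Float addition is not associative: the same three values summed in a
# different grouping give bitwise-different results. A GPU changing its
# reduction order with batch size is this effect at scale.
a, b, c = 0.1, 0.2, 0.3

left = (a + b) + c   # summed left-to-right
right = a + (b + c)  # same values, different grouping

print(left == right)  # False: 0.6000000000000001 vs 0.6
```

A batch-invariant kernel fixes the grouping of these reductions so the result no longer depends on how many other requests share the batch.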

The Second PR - vLLM, Batch Invariance and Supporting New Models

So, 3090 in tow, with a working batch invariance implementation and a goal to grow this feature, I set out to look for more issues. A natural starting point was the batch invariance TODOs. This once again highlights how you might find things to work on; this list of TODOs has many asks. A nice one to start with, I felt, was testing more models: given my lack of (large) compute, I thought it would be good to run some smaller LLMs and see how they mesh with batch invariance. Given my focus on small models, and the list of models already supported and tested on batch invariance3, I wanted to test and add support for a new type of model. I chose the Qwen3 series, and specifically their AWQ quantized versions.

Qwen3-4B-AWQ was my model of choice.

Activation-aware Weight Quantization (AWQ) is a weight-only quantization technique. The goal of any quantization is to reduce the size of the model, allowing inference to be more efficient or to run on smaller GPUs. For example, the weights alone of Qwen3-4B-AWQ take up about 2.7 GB, while for Qwen3-4B that jumps to 8.0 GB.
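A rough back-of-the-envelope shows where those sizes come from. Assuming roughly 4 billion parameters (a round number for illustration), fp16 stores 2 bytes per weight and int4 stores half a byte; real checkpoints also carry quantization scales and zero-points, and some layers may stay in higher precision, which is why the actual AWQ file is a bit larger than the naive estimate:

```python
# Back-of-the-envelope model size estimate (assumption: ~4e9 parameters;
# real checkpoints also store scales/zero-points and some fp16 layers).
params = 4e9

fp16_gb = params * 2 / 1e9    # 2 bytes per weight
int4_gb = params * 0.5 / 1e9  # 4 bits per weight

print(f"fp16: ~{fp16_gb:.1f} GB, int4: ~{int4_gb:.1f} GB")
```

This lands in the same ballpark as the 8.0 GB and 2.7 GB figures above.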

To get a full introduction to AWQ, I recommend Lei Mao’s blog post on the topic: https://leimao.github.io/blog/AWQ-Activation-Aware-Weight-Quantization/. It’s a great post, and the rest of his work is great too.

Discovering the Issues - Running the Unit Tests

To begin testing and checking whether Qwen3-4B-AWQ is supported, I ran the unit tests again with the following command to specify this model:

VLLM_TEST_MODEL=Qwen/Qwen3-4B-AWQ .venv/bin/python -m pytest tests/v1/determinism/test_batch_invariance.py

This fired up the unit tests and returned the output (errors or otherwise). If the model were supported, all tests would have gone green (they didn’t) and we would be able to say “this model works on a 3090” (the maintainers have newer hardware and can test models on that to double check).

When I ran this the first time I saw the following output:

=== short test summary info ===
FAILED .../test_batch_invariance::deterministic_across_batch_sizes[FLASH_ATTN]  — 3/5 trials failed
FAILED .../test_batch_invariance::deterministic_across_batch_sizes[TRITON_ATTN] — 4/5 trials failed
FAILED .../test_batch_invariance::logprobs_bitwise_bs1_vs_bsN[FLASH_ATTN]      — 32/32 prompts violated
FAILED .../test_batch_invariance::logprobs_bitwise_bs1_vs_bsN[TRITON_ATTN]     — 32/32 prompts violated
FAILED .../test_batch_invariance::decode_logprobs_match_prefill[FLASH_ATTN]     — 48 mismatches
=== 5 failed, 4 passed, 27 warnings in 4m28s ===

Now to figure out why they failed, I began by looking at the log output before the tests began. It looked something like this:

================================================================================
BATCH INVARIANCE MODE: Disabling custom all-reduce (TP=1)
================================================================================

INFO 04-10 08:30:19 [utils.py:233] non-default args: {'dtype': 'bfloat16', 'max_model_len': 8192, 'max_num_seqs': 32, 'disable_log_stats': True, 'enforce_eager': True, 'attention_config': AttentionConfig(backend=<AttentionBackendEnum.FLASH_ATTN: 'vllm.v1.attention.backends.flash_attn.FlashAttentionBackend'>, flash_attn_version=None, use_prefill_decode_attention=False, flash_attn_max_num_splits_for_cuda_graph=32, use_cudnn_prefill=False, use_trtllm_ragged_deepseek_prefill=True, use_trtllm_attention=None, disable_flashinfer_prefill=True, disable_flashinfer_q_quantization=False, use_prefill_query_quantization=False), 'model': 'Qwen/Qwen3-4B-AWQ'}
WARNING 04-10 08:30:19 [envs.py:1783] Unknown vLLM environment variable detected: VLLM_TEST_MODEL
INFO 04-10 08:30:19 [model.py:549] Resolved architecture: Qwen3ForCausalLM
WARNING 04-10 08:30:19 [model.py:2018] Casting torch.float16 to torch.bfloat16.
INFO 04-10 08:30:19 [model.py:1680] Using max model len 8192
INFO 04-10 08:30:19 [awq_marlin.py:246] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
WARNING 04-10 08:30:20 [vllm.py:857] Enforce eager set, disabling torch.compile and CUDAGraphs. This is equivalent to setting -cc.mode=none -cc.cudagraph_mode=none
WARNING 04-10 08:30:20 [vllm.py:868] Inductor compilation was disabled by user settings, optimizations settings that are only active during inductor compilation will be ignored.
INFO 04-10 08:30:20 [kernel.py:196] Final IR op priority after setting platform defaults: IrOpPriorityConfig(rms_norm=['vllm_c', 'native'])
INFO 04-10 08:30:20 [vllm.py:1046] Cudagraph is disabled under eager mode
(EngineCore pid=241117) INFO 04-10 08:30:20 [core.py:105] Initializing a V1 LLM engine (v0.18.1rc1.dev135+g70a215283) with config: model='Qwen/Qwen3-4B-AWQ', speculative_config=None, tokenizer='Qwen/Qwen3-4B-AWQ', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, decode_context_parallel_size=1, dcp_comm_backend=ag_rs, disable_custom_all_reduce=True, quantization=awq_marlin, enforce_eager=True, enable_return_routed_experts=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, kv_cache_metrics=False, kv_cache_metrics_sample=0.01, cudagraph_metrics=False, enable_layerwise_nvtx_tracing=False, enable_mfu_metrics=False, enable_mm_processor_stats=False, enable_logging_iteration_details=False), seed=0, served_model_name=Qwen/Qwen3-4B-AWQ, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'mode': <CompilationMode.NONE: 0>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'inductor', 'custom_ops': ['all'], 'ir_enable_torch_wrap': False, 'splitting_ops': [], 'compile_mm_encoder': False, 'cudagraph_mm_encoder': False, 'encoder_cudagraph_token_budgets': [], 'encoder_cudagraph_max_images_per_batch': 0, 'compile_sizes': [], 'compile_ranges_endpoints': [8192], 'inductor_compile_config': {'enable_auto_functionalized_v2': False, 'size_asserts': False, 'alignment_asserts': False, 'scalar_asserts': False, 'combo_kernels': True, 'benchmark_combo_kernel': 
True}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.NONE: 0>, 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': [], 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {'fuse_norm_quant': True, 'fuse_act_quant': True, 'fuse_attn_quant': False, 'enable_sp': False, 'fuse_gemm_comms': False, 'fuse_allreduce_rms': False}, 'max_cudagraph_capture_size': 0, 'dynamic_shapes_config': {'type': <DynamicShapesType.BACKED: 'backed'>, 'evaluate_guards': False, 'assume_32_bit_indexing': False}, 'local_cache_dir': None, 'fast_moe_cold_start': True, 'static_all_moe_layers': []}, kernel_config=KernelConfig(ir_op_priority=IrOpPriorityConfig(rms_norm=['vllm_c', 'native']), enable_flashinfer_autotune=True, moe_backend='auto')
(EngineCore pid=241117) INFO 04-10 08:30:20 [parallel_state.py:1400] world_size=1 rank=0 local_rank=0 distributed_init_method=tcp://192.168.1.107:58693 backend=nccl
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(EngineCore pid=241117) INFO 04-10 08:30:20 [parallel_state.py:1712] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank N/A, EPLB rank N/A
(EngineCore pid=241117) INFO 04-10 08:30:21 [gpu_model_runner.py:4733] Starting to load model Qwen/Qwen3-4B-AWQ...
(EngineCore pid=241117) INFO 04-10 08:30:21 [awq_marlin.py:414] Using MarlinLinearKernel for AWQMarlinLinearMethod
(EngineCore pid=241117) INFO 04-10 08:30:21 [cuda.py:302] Using AttentionBackendEnum.FLASH_ATTN backend.
(EngineCore pid=241117) INFO 04-10 08:30:21 [flash_attn.py:622] Using FlashAttention version 2
(EngineCore pid=241117) INFO 04-10 08:30:21 [weight_utils.py:627] No model.safetensors.index.json found in remote.
(EngineCore pid=241117) INFO 04-10 08:30:22 [default_loader.py:384] Loading weights took 0.50 seconds
(EngineCore pid=241117) INFO 04-10 08:30:23 [gpu_model_runner.py:4818] Model loading took 2.5 GiB memory and 1.429585 seconds
(EngineCore pid=241117) INFO 04-10 08:30:25 [gpu_worker.py:436] Available KV cache memory: 17.99 GiB
(EngineCore pid=241117) INFO 04-10 08:30:25 [kv_cache_utils.py:1319] GPU KV cache size: 131,024 tokens
(EngineCore pid=241117) INFO 04-10 08:30:25 [kv_cache_utils.py:1324] Maximum concurrency for 8,192 tokens per request: 15.99x
(EngineCore pid=241117) INFO 04-10 08:30:25 [core.py:283] init engine (profile, create kv cache, warmup model) took 2.50 seconds
(EngineCore pid=241117) WARNING 04-10 08:30:26 [vllm.py:857] Enforce eager set, disabling torch.compile and CUDAGraphs. This is equivalent to setting -cc.mode=none -cc.cudagraph_mode=none
(EngineCore pid=241117) WARNING 04-10 08:30:26 [vllm.py:868] Inductor compilation was disabled by user settings, optimizations settings that are only active during inductor compilation will be ignored.
(EngineCore pid=241117) INFO 04-10 08:30:26 [kernel.py:196] Final IR op priority after setting platform defaults: IrOpPriorityConfig(rms_norm=['vllm_c', 'native'])
(EngineCore pid=241117) INFO 04-10 08:30:26 [vllm.py:1046] Cudagraph is disabled under eager mode

As you can see it contains a lot of info, but it is this line which tells us what happened:

INFO 04-10 08:30:19 [awq_marlin.py:246] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.

Remember, batch invariance works by switching out the regular kernels for batch invariant ones. This line tells us the kernels are being swapped to AWQ Marlin kernels, and we cannot assume these are batch invariant. It was from here that I began my investigation; next we need to look at that awq_marlin.py file.

Tracing the Issue - awq_marlin.py - The Marlin Autoconversion

Let’s start at line 246, which sits within the following function (it is the line containing the logger.info(msg)):

@classmethod
def override_quantization_method(
    cls, hf_quant_cfg, user_quant
) -> "QuantizationMethods | None":
    can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
    is_valid_user_quant = (
        user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
    )

    if can_convert and is_valid_user_quant:
        msg = (
            "The model is convertible to {} during runtime."
            " Using {} kernel.".format(cls.get_name(), cls.get_name())
        )
        logger.info(msg)
        return cls.get_name()

    if can_convert and user_quant == "awq":
        logger.info(
            "Detected that the model can run with awq_marlin"
            ", however you specified quantization=awq explicitly,"
            " so forcing awq. Use quantization=awq_marlin for"
            " faster inference"
        )
    return None

The purpose of this function is to check whether the kernel can be replaced with a Marlin kernel. The simplest approach, then, is to just skip that path: return None here if batch invariant mode is turned on, and never allow the Marlin conversion to take place.

Marlin kernels are super fast, mixed precision kernels. They operate on FP16 and INT4 values. We need mixed precision because AWQ model weights are stored in INT4, but the activations run in FP16 (remember, AWQ is a weight-only quantization method).

Marlin kernels are probably one of many kernel families which allow for efficient computation of FP16 x INT4 operations.

For some further info on Marlin kernels check out: https://developers.redhat.com/articles/2024/04/17/how-marlin-pushes-boundaries-mixed-precision-llm-inference
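To make “weight-only mixed precision” concrete, here is a NumPy sketch of the idea (this is illustrative only, not the Marlin kernel, which is hand-tuned GPU code; the shapes and the single per-column scale are my own simplification):

```python
import numpy as np

# Illustrative weight-only quantization sketch (NOT the Marlin kernel):
# weights are stored as int4 values plus fp16 scales; at compute time they
# are dequantized to fp16 and multiplied with fp16 activations.
rng = np.random.default_rng(0)

w_int4 = rng.integers(-8, 8, size=(64, 32), dtype=np.int8)  # int4 range [-8, 7]
scales = np.full((1, 32), 0.01, dtype=np.float16)           # per-output-column scale
x = rng.standard_normal((4, 64)).astype(np.float16)         # fp16 activations

w_fp16 = w_int4.astype(np.float16) * scales  # dequantize: int4 -> fp16
y = x @ w_fp16                               # matmul runs in fp16

print(y.shape)  # (4, 32)
```

Marlin fuses the dequantize and matmul steps into one kernel rather than materialising w_fp16, which is where its speed comes from.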

Blocking the Marlin Conversion

It’s a simple change; the function is now:

@classmethod
def override_quantization_method(
    cls, hf_quant_cfg, user_quant
) -> "QuantizationMethods | None":
    # Skip override to marlin kernels, as they are not
    # batch invariant
    if envs.VLLM_BATCH_INVARIANT:
        return None

    can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
    is_valid_user_quant = (
        user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
    )

    if can_convert and is_valid_user_quant:
        msg = (
            "The model is convertible to {} during runtime."
            " Using {} kernel.".format(cls.get_name(), cls.get_name())
        )
        logger.info(msg)
        return cls.get_name()

    if can_convert and user_quant == "awq":
        logger.info(
            "Detected that the model can run with awq_marlin"
            ", however you specified quantization=awq explicitly,"
            " so forcing awq. Use quantization=awq_marlin for"
            " faster inference"
        )
    return None

If batch invariant mode is on, we return None from this method, and the AWQ kernels are not swapped with Marlin ones. The next step was to run the unit tests again and check the output:

=== short test summary info ===
FAILED ...deterministic_across_batch_sizes[FLASH_ATTN]  — ValidationError: VllmConfig
FAILED ...deterministic_across_batch_sizes[TRITON_ATTN] — ValidationError: VllmConfig
FAILED ...logprobs_bitwise_bs1_vs_bsN[FLASH_ATTN]      — ValidationError: VllmConfig
FAILED ...logprobs_bitwise_bs1_vs_bsN[TRITON_ATTN]      — ValidationError: VllmConfig
FAILED ...test_simple_generation[FLASH_ATTN]             — ValidationError: VllmConfig
FAILED ...test_simple_generation[TRITON_ATTN]            — ValidationError: VllmConfig
FAILED ...decode_logprobs_match_prefill[FLASH_ATTN]      — ValidationError: VllmConfig
=== 7 failed, 2 passed, 18 warnings in 36s ===

It didn’t work! This time for a different reason. Let’s investigate by taking a look at some of the actual test output:

E       pydantic_core._pydantic_core.ValidationError: 1 validation error for VllmConfig
E         Value error, torch.bfloat16 is not supported for quantization method awq. Supported dtypes: [torch.float16] [type=value_error, input_value=ArgsKwargs((), {'model_co... 'shutdown_timeout': 0}), input_type=ArgsKwargs]
E           For further information visit https://errors.pydantic.dev/2.12/v/value_error

vllm/engine/arg_utils.py:2001: ValidationError
------------------------------------------------------------------------------------------------------- Captured stdout call -------------------------------------------------------------------------------------------------------

================================================================================
BATCH INVARIANCE MODE: Disabling custom all-reduce (TP=1)
================================================================================

INFO 04-10 16:02:24 [utils.py:233] non-default args: {'dtype': 'bfloat16', 'max_model_len': 8192, 'max_num_seqs': 32, 'disable_log_stats': True, 'enforce_eager': True, 'attention_config': AttentionConfig(backend=<AttentionBackendEnum.FLASH_ATTN: 'vllm.v1.attention.backends.flash_attn.FlashAttentionBackend'>, flash_attn_version=None, use_prefill_decode_attention=False, flash_attn_max_num_splits_for_cuda_graph=32, use_cudnn_prefill=False, use_trtllm_ragged_deepseek_prefill=True, use_trtllm_attention=None, disable_flashinfer_prefill=True, disable_flashinfer_q_quantization=False, use_prefill_query_quantization=False), 'model': 'Qwen/Qwen3-4B-AWQ'}
WARNING 04-10 16:02:24 [envs.py:1783] Unknown vLLM environment variable detected: VLLM_TEST_MODEL
INFO 04-10 16:02:25 [model.py:549] Resolved architecture: Qwen3ForCausalLM
WARNING 04-10 16:02:25 [model.py:2018] Casting torch.float16 to torch.bfloat16.
INFO 04-10 16:02:25 [model.py:1680] Using max model len 8192

Two things to hone in on, firstly the error itself:

torch.bfloat16 is not supported for quantization method awq. Supported dtypes: [torch.float16]

Secondly, the model is being cast to bfloat16, as explicitly stated by the logging:

WARNING 04-10 16:02:25 [model.py:2018] Casting torch.float16 to torch.bfloat16.

By blocking the Marlin conversion, I ran into this new issue: the regular AWQ kernels do not support the bfloat16 dtype (whereas Marlin kernels support both FP16 and bfloat16). The tests themselves hardcode a cast of every model to bfloat16, and as such they were relying on the flexibility of the Marlin kernels to work properly.

The fix here is simple: remove the hardcoded casting. It appears a few times in the test_batch_invariance.py file; the change is as follows. We change:

dtype="bfloat16"

to

dtype="auto"

This allows vLLM to choose the best dtype for the given model, making it a more general test in the process and leaving the current model in float16. Having changed this, the next step was once again to run the unit tests.
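Conceptually, “auto” means deferring to the dtype the checkpoint itself declares. The sketch below is hypothetical (the function name and logic are my own illustration, not vLLM’s actual resolution code, which handles many more cases): a Hugging Face checkpoint carries a torch_dtype field in its config, and AWQ checkpoints typically declare float16 there.

```python
# Hypothetical sketch of dtype="auto" resolution (illustrative only, NOT
# vLLM's actual implementation): fall back to the dtype declared in the
# model's Hugging Face config instead of forcing a cast.
def resolve_dtype(requested: str, hf_config: dict) -> str:
    if requested != "auto":
        return requested  # an explicit dtype wins
    # AWQ checkpoints typically declare float16 here
    return hf_config.get("torch_dtype", "float16")

print(resolve_dtype("auto", {"torch_dtype": "float16"}))      # float16
print(resolve_dtype("bfloat16", {"torch_dtype": "float16"}))  # bfloat16
```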

Result of Setting dtype = “auto”

triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 106496, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

Another failure, this time due to running out of memory. Specifically, GPU shared memory: the fast on-chip pool of memory that a GPU’s processors share (think of it as the L1 cache on a CPU). In this case the operation requires 106496 bytes, but the 3090 only provides 101376 bytes.

The fix here, luckily, was given in the error message. The offending file and function is:

File "/home/yusuf/PycharmProjects/ym_vllm/vllm/vllm/model_executor/layers/batch_invariant.py", line 196, in matmul_persistent

In this function the following config exists:

configs = {
    torch.bfloat16: {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 8,
        "num_stages": 3,
        "num_warps": 8,
    },
    torch.float16: {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 8,
        "num_stages": 3,
        "num_warps": 8,
    },
    torch.float32: {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
        "num_stages": 3,
        "num_warps": 8,
    },
}

Given the AWQ model is float16, we use those settings. Now, observe how the torch.float16 entry has a BLOCK_SIZE_N of 256, whereas the other two have 128. This is the cause of the out of memory issue (and why it wasn’t seen when Marlin kernels were allowed: they ran in bfloat16, so BLOCK_SIZE_N was 128 and the tiles fit in shared memory). I halved the BLOCK_SIZE_N, which allowed the operation to run without exhausting shared memory.
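A rough estimate shows why halving BLOCK_SIZE_N helps. In a pipelined Triton matmul, each stage buffers a tile of each input matrix in shared memory at 2 bytes per fp16 element. This naive model does not exactly reproduce the 106496 bytes the error reported (Triton’s real accounting differs), so treat the numbers as indicative only:

```python
# Naive shared-memory model for a pipelined matmul tile: num_stages buffers
# of the A (M x K) and B (K x N) tiles, 2 bytes per fp16 element. Triton's
# actual accounting differs somewhat; this is only a rough sanity check.
def smem_bytes(block_m, block_n, block_k, num_stages, elem_bytes=2):
    a_tile = block_m * block_k
    b_tile = block_k * block_n
    return (a_tile + b_tile) * num_stages * elem_bytes

limit_3090 = 101376  # hardware limit reported in the error message

before = smem_bytes(128, 256, 64, 3)  # fp16 config with BLOCK_SIZE_N=256
after = smem_bytes(128, 128, 64, 3)   # BLOCK_SIZE_N halved

print(before > limit_3090, after <= limit_3090)  # True True
```

The BLOCK_SIZE_N=256 tile blows past the 3090’s limit under this model, while the halved tile fits comfortably, matching what the fix achieved in practice.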

Enough Memory?

With that error taken care of, it was once again time to run the unit tests!

Alas, another error reared its head! We must be getting close to the end of these now, though… right?

File "/home/yusuf/PycharmProjects/ym_vllm/vllm/vllm/model_executor/layers/batch_invariant.py", line 755, in _log_softmax_batch_invariant
assert not _half_to_float, "not implemented" 

AssertionError: not implemented

A not implemented error! Getting batch invariance to work on the 3090 is leading me down many untouched code paths. I think this highlights the need for people with this hardware to get involved in projects like vLLM: it can be hard to know a priori how different hardware will handle the code, and only through testing can we reveal these effects.

The function itself is as follows:

def _log_softmax_batch_invariant(input, dim, _half_to_float):
    assert not _half_to_float, "not implemented"
    return log_softmax(input, dim=dim)

It calls log_softmax and raises an error if the input is a “half” (float16) and the output is a “float” (float32). The reason this passed when using bfloat16 is that _half_to_float is False; this boolean flags whether a cast from fp16 to fp32 is needed. This is a quirk of PyTorch internals: when we pass bfloat16, the input is cast before this function is called, so it arrives as fp32; however, when we pass fp16, the value is not cast before the function is called.

We can prove this with the following snippet (after adding a print statement to _log_softmax_batch_invariant: print(f"input dtype: {input.dtype}, _half_to_float: {_half_to_float}")):

VLLM_BATCH_INVARIANT=1 .venv/bin/python -c "
import os; os.environ['VLLM_BATCH_INVARIANT'] = '1'
import torch
from vllm.model_executor.layers.batch_invariant import enable_batch_invariant_mode, _log_softmax_batch_invariant
enable_batch_invariant_mode()
# Test bfloat16
bf16 = torch.randn(2, 10, dtype=torch.bfloat16, device='cuda')
print('=== bfloat16 ===')
print(f'Calling log_softmax with dtype=float32 on bfloat16 tensor')
result = bf16.log_softmax(dim=-1, dtype=torch.float32)
# Test float16
fp16 = torch.randn(2, 10, dtype=torch.float16, device='cuda')
print('=== float16 ===')
print(f'Calling log_softmax with dtype=float32 on float16 tensor')
result = fp16.log_softmax(dim=-1, dtype=torch.float32)
"

Output:

=== bfloat16 ===
Calling log_softmax with dtype=float32 on bfloat16 tensor
input dtype: torch.float32, _half_to_float: False
=== float16 ===
Calling log_softmax with dtype=float32 on float16 tensor
input dtype: torch.float16, _half_to_float: True

Observe how when the input is bfloat16 it arrives as fp32 inside the function, whereas when calling with fp16 it is left as fp16.

Another small fix suffices here:

def _log_softmax_batch_invariant(input, dim, _half_to_float):
    if _half_to_float:
        return log_softmax(input.float(), dim=dim)
    return log_softmax(input, dim=dim)

Now, I addressed the not implemented error… by implementing it! Simply put, I do what PyTorch already does for bfloat16: convert the input from fp16 to fp32 before we call log_softmax. That’s what .float() does.
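For intuition, here is a torch-free sketch of what the upcast accomplishes: a numerically stable log-softmax where the inputs are promoted to full-precision floats first, mirroring the role of input.float() in the fix (the function here is my own illustration, not vLLM’s kernel):

```python
import math

# Torch-free sketch of the fix: upcast half-precision inputs to full floats
# before the numerically stable log-softmax (mirrors what input.float() does).
def log_softmax(xs):
    xs = [float(x) for x in xs]  # the "half_to_float" upcast
    m = max(xs)                  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]

# The exponentiated outputs form a probability distribution (sums to 1)
probs = [math.exp(v) for v in log_softmax([1.0, 2.0, 3.0])]
print(round(sum(probs), 6))  # 1.0
```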

The Final Change

Now with all these changes completed, I ran the unit tests again. Unsurprisingly, they failed again (I promise this is the last time).

However, this time the failure was different: there was no code error. The code runs, but the output is non-deterministic. This requires us to dig a little deeper. The test suite consists of the following tests:

  • test_v1_generation_is_deterministic_across_batch_sizes_with_needle
  • test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
  • test_simple_generation
  • test_logprobs_without_batch_invariance_should_fail
  • test_decode_logprobs_match_prefill_logprobs

test_simple_generation - runs the model, passes a prompt, and checks that some output is generated.

test_logprobs_without_batch_invariance_should_fail - runs test_logprobs_bitwise_batch_invariance_bs1_vs_bsN but with the batch invariance feature turned off. In doing so, it checks that the feature actually does something. Basically: if batch invariance is turned off, is the model actually non-deterministic?

test_v1_generation_is_deterministic_across_batch_sizes_with_needle - starts with a “needle” prompt, “There once was a”, which is basically a starting point for generation. It then ensures that a call with the same needle prompt yields identical output when batch_size (bs) = 1 and when bs = N. Basically, it runs the needle prompt alone, then hidden within a batch, and verifies the responses for the two needle prompts match (verifying that batch size has no impact on outputs).

test_logprobs_bitwise_batch_invariance_bs1_vs_bsN - runs 32 prompts individually and then all together in one batch. The logprobs are recorded at each step and compared per prompt; they must be identical at each step.

test_decode_logprobs_match_prefill_logprobs - runs a single prompt through prefill (the first forward pass, processing the input prompt), then runs the decode phase (generating N tokens), storing the logprobs at each decode step. Then it takes some of the generated tokens, say 3 of them, appends them to the prompt, runs prefill again, and checks the outputted logprobs. They should match the decode logprobs.

Reviewing the 2 Failing Tests

FAILED ...deterministic_across_batch_sizes[FLASH_ATTN]  — 3/5 trials failed
FAILED ...deterministic_across_batch_sizes[TRITON_ATTN] — 4/5 trials failed
FAILED ...logprobs_bitwise_bs1_vs_bsN[FLASH_ATTN]      — 32/32 prompts violated
FAILED ...logprobs_bitwise_bs1_vs_bsN[TRITON_ATTN]      — 32/32 prompts violated

4 failures: 2 tests failing, each on 2 different backends. The tests now just tell us it isn’t working, but there’s no explicit error present. A typical machine learning error4, where it fails silently.

The funny thing is that the difference in the logprobs is so tiny that the tokens generated at bs=1 and bs=n are identical.

Prompt 31 (step 0):
  Reason: Bitwise mismatch (max_diff=6.332994e-07)
  Preview: I've been thinking about getting a new laptop because...
  BS=1 tokens: [847, 1482, 825, 374, 3709, 2238, 6301, 323, 358, 1184, 803, 5819, 13, 358, 2776, 537]
  BS=N tokens: [847, 1482, 825, 374, 3709, 2238, 6301, 323, 358, 1184, 803, 5819, 13, 358, 2776, 537]
  BS=1 logprobs for all 16 steps:
    Step 0: -0.12404439598321915
    Step 1: -0.273791640996933
    Step 2: -0.012951892800629139
    Step 3: -0.11462680250406265
    Step 4: -1.4374871253967285
    Step 5: -1.070314884185791
    Step 6: -0.11129239946603775
    Step 7: -1.8875861167907715
    Step 8: -1.3804924488067627
    Step 9: -0.5117585062980652
    Step 10: -1.5740704536437988
    Step 11: -0.810169517993927
    Step 12: -0.06133199855685234
    Step 13: -0.36362993717193604
    Step 14: -0.9980694055557251
    Step 15: -0.9822379350662231
  BS=N logprobs for all 16 steps:
    Step 0: -0.12404376268386841
    Step 1: -0.2737951874732971
    Step 2: -0.012942831963300705
    Step 3: -0.11547216027975082
    Step 4: -1.4370230436325073
    Step 5: -1.063231110572815
    Step 6: -0.10990802943706512
    Step 7: -1.8877787590026855
    Step 8: -1.368670105934143
    Step 9: -0.5117335319519043
    Step 10: -1.5733510255813599
    Step 11: -0.8182716369628906
    Step 12: -0.06144455820322037
    Step 13: -0.36342307925224304
    Step 14: -1.000751256942749
    Step 15: -0.9830570220947266

Notice that the tokens match but the raw logprobs are not equal - an easy one to spot is Step 14. Interesting, but on its own this doesn’t get us any closer to solving the issue.
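Where do these tiny discrepancies come from? Floating-point addition is not associative, so when a different batch size changes a kernel’s reduction order, the same mathematical sum can come out bitwise different. A minimal illustration in plain Python (a generic demonstration, not vLLM code):

```python
# Floating-point addition is not associative: regrouping the same three
# numbers changes the result in the last bits, much like the Step 14
# logprobs above differ without changing which token wins.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left == right)      # False
print(abs(left - right))  # on the order of 1e-16: far too small to flip argmax
```

This is exactly why batch invariance requires pinning down the reduction order inside the kernels, not just tightening tolerances.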

Calling in Assistance From Claude Code

LLMs are great at reasoning about code bases, and we can use this to our advantage in OSS contribution. My method, when I get stuck, is to have Claude ask me questions about the issue and help me understand the structure of the repo. I’m sure Claude could have closed this issue entirely on its own, but as an early contributor I don’t think “closing the PR” and earning the stamp “I contributed to some OSS project” is the goal. The goal ought to be personal development, wanting to fix an issue, or just plain curiosity - and used correctly, LLMs can support this massively.

By guiding me through the structure of the repo, Claude led me to the awq.py file and to the heart of the final issue.

To be honest, it’s easy to use LLMs to just do the PR - but as mentioned before, by doing so you add nothing the maintainer couldn’t do themselves. Instead, by working on this PR in collaboration with Claude, I built my curiosity, which led me to discover more issues, fixes and avenues to explore in the vLLM project. Without Claude’s support, maybe I wouldn’t have taken it this far. In the past, a large part of my anxiety around OSS contribution came from exactly this kind of thinking - if LLMs help abate that for many people, imagine how much we can collectively achieve!

If we Blocked awq_marlin conversion, Which Kernel is Used?

In awq_marlin.py we blocked the awq_marlin conversion. So, if the awq_marlin kernel is blocked, which kernel IS being used now?

There is another file, awq.py, sitting in the same directory as awq_marlin.py. Within it is a function which decides which kernel is actually used for the matmul:

def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    qweight = layer.qweight
    scales = layer.scales
    qzeros = layer.qzeros
    pack_factor = self.quant_config.pack_factor
    out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
    reshaped_x = x.reshape(-1, x.shape[-1])

    # num_tokens >= threshold
    FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

    if FP16_MATMUL_HEURISTIC_CONDITION:
        out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
        out = torch.matmul(reshaped_x, out)
    else:
        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
    if bias is not None:
        out.add_(bias)
    return out.reshape(out_shape)
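To see why the unit tests hit the second path: x.shape[:-1].numel() counts the number of tokens (every dimension except the hidden one). A prefill over a long prompt easily clears the 256-token threshold, while a single decode step processes one token per sequence. A small sketch of the same condition using plain tuples (illustrative shapes, not vLLM code):

```python
from math import prod

# Mirrors FP16_MATMUL_HEURISTIC_CONDITION: x.shape[:-1].numel() >= 256,
# i.e. the number of tokens (all dims except the hidden dim).
def takes_dequantize_path(shape, threshold=256):
    return prod(shape[:-1]) >= threshold

print(takes_dequantize_path((512, 4096)))  # True: a 512-token prefill
print(takes_dequantize_path((1, 4096)))    # False: one decode token -> awq_gemm
```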

Take note of the if statement: there are two paths, each using a different kernel. In the unit tests (where the decode token count is often 1) the second path is taken, and this is the cause of the non-determinism in the logprobs. The awq_gemm kernel is not batch invariant, so the simple fix was to block this kernel and force the first path when batch invariance is enabled:

if FP16_MATMUL_HEURISTIC_CONDITION or envs.VLLM_BATCH_INVARIANT:
    out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
    out = torch.matmul(reshaped_x, out)
else:
    out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)

With this we can run the unit tests one more time and:

9 passed, 27 warnings in 469.21s (0:07:49) 

Woohoo! Qwen3-4B-AWQ works in batch invariant mode on the 3090! (And the B200 as tested by a vLLM maintainer).

A Closing Note

This whole process was exceedingly fun! Starting from a tiny initial PR and watching it lead to a second (albeit also tiny) PR was great to experience first hand. The anxiety slowly slipped away as I went through the PR process, building my confidence and enabling me to look for and work on more issues.

I know it can be daunting to begin - even in this small piece of work we touched on some pretty heavy topics. But I hope you can see now that it doesn’t need to be, and I hope other newcomers to the OSS community take similar approaches and grow into meaningful contributors. Run the unit tests, check the format, understand the behaviour, and get out there and contribute :)


1 https://github.com/vllm-project/vllm

2 https://docs.vllm.ai/en/stable/

3 https://docs.vllm.ai/en/latest/features/batch_invariance/#tested-models

4 https://karpathy.github.io/2019/04/25/recipe/