Task-Driven Prompt Evolution for Foundation Models

GE HealthCare
{rahul.venkataramani}@ge.com
1 Introduction
The recent release of a foundation model for image segmentation called Segment
Anything (SAM) [3] has generated unprecedented excitement about the possibility of realizing artificial general intelligence (AGI) in the field of medical image
analysis. SAM is a task-agnostic promptable segmentation model trained on 1
billion masks. This has raised the possibility of improved zero-shot segmentation performance, obviating the necessity for specialized techniques across medical imaging tasks [4].
Consequently, a number of studies [1,2,6] have evaluated the performance of SAM on a plethora of medical imaging segmentation tasks and have concluded that, while SAM is a promising first step, there exists a significant gap compared to supervised learning algorithms on many datasets. The hypothesized reasons include the lack of medical imaging samples in the training database and peculiarities of medical images (e.g., the scan cone in ultrasound, the 3D nature of CT/MR, large intensity variations in X-ray, and higher image resolution compared to natural images).
To bridge this gap without fine-tuning, we propose SAMPOT, which scores a candidate segmentation with a learnt surrogate for a task-based metric like the Dice value. At inference, given a test image and prompt, we iteratively maximize this task-based score to evolve the location of the prompt, producing superior results compared to utilizing the initial prompt location provided by the user.
Although we develop this method on SAM, SAMPOT can be used in a plug-
and-play fashion with any foundation segmentation model.
2 Methodology
Our method is aimed at evolving the location of the prompt to arrive at an optimal prompt $p^*$. If we had access to the ground truth mask $Y_{test}$ for a given input image, we could simply compute the loss $\mathcal{L}_{task}(\hat{Y}_{test}, Y_{test})$ and choose the $p$ that minimises it. However, since the ground truth is unavailable at inference, that choice would be circular; we therefore propose to use an oracle $\mathcal{O}$ that acts as a surrogate for the true loss $\mathcal{L}_{task}$. The scorer takes the input image $X_{test}$ and the predicted mask $\hat{Y}_{test}$ and produces a score $s$. The scorer can be a pre-learnt (and fixed) neural network model used in conjunction with the segmentation model, enabling us to compute the gradients of the score with respect to $p$.
Fig. 2. Schematic of SAMPOT. The spatial location of the user-provided prompt is updated based on the gradients of the segmentation score.
If the scorer is differentiable, the optimal prompt can be obtained as

$$p^* := \arg\max_{p} \, \mathcal{O}(X_{test}, \hat{Y}_{test}), \quad \text{where} \quad \hat{Y}_{test} := f_{SAM}(X_{test}, p). \tag{1}$$
Note that the gradient of $s$ is computed with respect to $p$ and therefore only $p$ gets updated, while the weights of SAM $f_{SAM}$ and the scorer $\mathcal{O}$ are held fixed.
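As an illustration, the following is a minimal sketch of this optimization loop, assuming hypothetical differentiable wrappers `f_sam` and `scorer`; the step size and iteration count are our assumptions, not values from the paper.

```python
import torch

def evolve_prompt(x_test, p_init, f_sam, scorer, lr=1.0, n_steps=50):
    """Gradient-ascent sketch of Eq. (1): only the prompt location moves.

    f_sam(x, p) -> predicted mask, differentiable w.r.t. the prompt p,
    and scorer(x, y_hat) -> scalar quality score; both are frozen models
    (hypothetical interfaces, not the authors' exact code).
    """
    p = p_init.clone().requires_grad_(True)  # (x, y) prompt coordinates
    opt = torch.optim.Adam([p], lr=lr)
    for _ in range(n_steps):
        y_hat = f_sam(x_test, p)   # frozen SAM forward pass
        s = scorer(x_test, y_hat)  # surrogate for the Dice score
        opt.zero_grad()
        (-s).backward()            # ascend the score
        opt.step()                 # update the prompt only
    return p.detach()
```

Gradients can reach $p$ because SAM's prompt encoder maps point coordinates through differentiable positional encodings before they enter the mask decoder.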
3 Experiments

In this study, we tapped into a database of X-ray images available within our institution, sourced through data partnerships from US, African, and European populations. The datasets were acquired after receiving approval from the relevant Institutional Review Boards.

Fig. 3. (a) A sample mask from the dataset, (b) the computed distance map, and the synthetically generated (c) over-segmented and (d) under-segmented masks. The Dice coefficient of the over-segmented mask is 0.57 and that of the under-segmented mask is 0.61.

The lung boundaries on the X-ray images
were delineated by a team of experienced radiologists. X-ray images from 122
subjects were split into train and test subjects in our experimental setup. This
split was used for training and evaluation of the segmentation regressor only.
Note that the SAM model is pretrained and is fixed throughout the study. We
have evaluated the effectiveness of the prompt optimization technique on the
test split of the dataset, thereby ensuring that the results are not biased by the
regressor, which has been optimized on the train split. The train cohort is further divided into training and validation sets with images from 41 and 28 subjects, respectively. The test set has images from 53 subjects.
Data preparation: We created several synthetic masks for every lung annotation in the dataset and computed the Dice coefficient of each mask as its ground truth segmentation score. We used the level sets of the ground-truth annotation to generate under- and over-segmented instances of the lung field, as presented in Fig. 3 and sketched below.
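A minimal sketch of this level-set perturbation, assuming the distance map is a signed distance transform of the mask and that over- and under-segmentation correspond to positive and negative offsets; `perturb_mask`, `dice`, and any offset values are hypothetical, not the paper's exact recipe.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

def perturb_mask(mask, offset):
    """Over- (offset > 0) or under-segment (offset < 0) a binary mask
    by thresholding its signed distance map at the given offset."""
    # Signed distance: positive inside the mask, negative outside.
    sdm = distance_transform_edt(mask) - distance_transform_edt(1 - mask)
    return (sdm > -offset).astype(np.uint8)

def dice(a, b):
    """Dice coefficient between two binary masks."""
    inter = np.logical_and(a, b).sum()
    return 2.0 * inter / (a.sum() + b.sum() + 1e-8)
```

Each perturbed mask is then paired with its Dice coefficient against the original annotation to form a (mask, score) training example for the regressor.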
Additionally, we also included the lung mask predicted by SAM when given a single positive prompt, together with the corresponding Dice coefficient. In every image, the lung field was divided into three horizontal bands and the centroids of these regions were chosen as prompts. We also chose random points outside the three bands, with an offset of 5 pixels, as prompts for SAM. We therefore obtained predictions corresponding to 6 separate prompts for each image (see the sketch below). In total, we had 901 images in the train set, 600 in the validation set, and 1205 in the test set for learning the regressor.
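The band-centroid construction might look as follows; `band_centroid_prompts` is a hypothetical helper, and the negative prompts outside the bands are omitted for brevity.

```python
import numpy as np

def band_centroid_prompts(mask, n_bands=3):
    """Split the lung field into horizontal bands and return one
    centroid prompt (x, y) per band (assumes each band is non-empty)."""
    rows = np.where(mask.any(axis=1))[0]  # rows containing lung pixels
    edges = np.linspace(rows.min(), rows.max() + 1, n_bands + 1).astype(int)
    prompts = []
    for top, bottom in zip(edges[:-1], edges[1:]):
        ys, xs = np.nonzero(mask[top:bottom])
        prompts.append((xs.mean(), top + ys.mean()))  # (x, y) image coords
    return prompts
```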
Training parameters and network architecture: The regressor network consisted of five 2D convolution layers interleaved with batch normalization and leaky ReLU activations, with a sigmoid activation for the final layer. The network was trained for 200 epochs with a batch size of 32 using the Adam optimizer and mean squared error (MSE) loss. A constant learning rate of 0.001 was used. We set the stopping criterion as minimal loss on the validation set.
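The paper specifies the layer types, depth, and training recipe but not the channel widths or input format; the sketch below fills those in with assumptions (two input channels for image plus predicted mask, doubling widths, strided convolutions, and global average pooling).

```python
import torch.nn as nn

class SegScorer(nn.Module):
    """Sketch of the segmentation regressor: five Conv2d + BatchNorm +
    LeakyReLU blocks, global average pooling, and a sigmoid score."""
    def __init__(self, in_ch=2, widths=(16, 32, 64, 128, 256)):
        super().__init__()
        blocks, prev = [], in_ch  # input: image + predicted mask
        for w in widths:
            blocks += [nn.Conv2d(prev, w, 3, stride=2, padding=1),
                       nn.BatchNorm2d(w), nn.LeakyReLU(0.1)]
            prev = w
        self.features = nn.Sequential(*blocks)
        self.head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                                  nn.Linear(prev, 1), nn.Sigmoid())

    def forward(self, x):  # x: (B, 2, H, W)
        return self.head(self.features(x)).squeeze(1)  # score in [0, 1]
```

Training then follows the stated recipe, e.g. `torch.optim.Adam(model.parameters(), lr=1e-3)` with `nn.MSELoss()` against the ground-truth Dice values.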
3.4 Results
Evaluation of the segmentation regressor: Figure 4(a) is a scatter plot of regressor outputs against true Dice coefficients for all samples in the test set, including the synthetically generated masks as well as SAM predictions. The high correlation coefficient (0.88) shows that the regressor output can serve as a proxy for the Dice coefficient of a segmented mask. We also present similar plots for SAM's confidence scores for segmentations when prompted at the centroid of the lung mask. We observe that SAM's confidence scores have a lower correlation coefficient with Dice (0.67) than our segmentation regressor does.
Fig. 4. Comparison of (a) Dice against regressor output for unseen synthetically generated masks (1205 samples) and, on the test set (53 samples), (b) Dice against SAM confidence score and (c) Dice against regressor output, when prompts are placed at the centroid of the lung mask. The correlation coefficient for the regressor on unseen synthetically generated masks is 0.88. On test samples, the correlation coefficient for the regressor is 0.90, compared with 0.67 for SAM.
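The correlation figures above are plain correlation coefficients (assumed Pearson here) and can be reproduced with, e.g., `scipy.stats.pearsonr`; the arrays below are placeholder data, not the paper's results.

```python
import numpy as np
from scipy.stats import pearsonr

rng = np.random.default_rng(0)
true_dice = rng.uniform(0.3, 1.0, size=1205)                  # stand-in Dice values
pred = np.clip(true_dice + rng.normal(0, 0.08, 1205), 0, 1)   # stand-in scores
r, _ = pearsonr(pred, true_dice)
print(f"Pearson correlation: {r:.2f}")
```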
Fig. 5. [Best viewed in color] Trajectory of the prompt during the optimization process. The initial prompt is set at the centroid of the ground-truth lung field annotation. Snapshots of the predicted masks at selected locations on the prompt trajectory, along with the computed Dice score, are also shown.
4 Discussion
The direct application of foundation models like SAM has shown sub-par performance on a number of medical image segmentation tasks. Given the relatively modest sizes of the datasets available for downstream medical imaging tasks, it may be prohibitive to fine-tune a very large model like SAM. On the previously unseen problem of lung segmentation on chest X-rays, we instead observed that the performance of SAM is sensitive to the spatial location of the prompt, which SAMPOT exploits without any fine-tuning.
Fig. 6. [Best viewed in color] Scatter plot of Dice coefficients resulting from the initial prompts and the final evolved prompts on the test set. 38 of 53 cases show an improvement in Dice after evolution, four of them with significant gains. The scatter plot on the right is a blown-up view of the top-right area of the scatter plot on the left. The top row shows images that gained significantly from prompt evolution; the bottom row shows some cases that under-perform upon prompt evolution.
5 Conclusions
On medical images, we observed that the spatial location of the prompt for a general-purpose foundation model (SAM) affects the segmentation accuracy. Taking a cue from the NLP community, we have presented SAMPOT, a method to optimize the prompt for a foundation model by altering its spatial location to obtain superior results on downstream tasks. We have demonstrated this method on lung segmentation of chest X-rays and obtained improvements on a significant number of cases (∼75%). We hope that our work opens up possibilities of prompt learning for extracting maximal value from general-purpose foundation models trained on natural images when they are applied to domain-specific downstream tasks in medical image analysis.
References
1. Cheng, D., Qin, Z., Jiang, Z., Zhang, S., Lao, Q., Li, K.: SAM on medical images:
A comprehensive study on three prompt modes. arXiv preprint arXiv:2305.00035
(2023)
2. He, S., Bao, R., Li, J., Grant, P.E., Ou, Y.: Accuracy of segment-anything model
(SAM) in medical image segmentation tasks. arXiv preprint arXiv:2304.09324
(2023)
3. Kirillov, A., Mintun, E., Ravi, N., Mao, H., Rolland, C., Gustafson, L., Xiao, T.,
Whitehead, S., Berg, A.C., Lo, W.Y., et al.: Segment anything. arXiv preprint
arXiv:2304.02643 (2023)
4. Li, C., Lin, X., Mao, Y., Lin, W., Qi, Q., Ding, X., Huang, Y., Liang, D., Yu, Y.: Domain generalization on medical imaging classification using episodic training with task augmentation. Computers in Biology and Medicine 141, 105144 (2022)
5. Luc, P., Couprie, C., Chintala, S., Verbeek, J.: Semantic segmentation using ad-
versarial networks. arXiv preprint arXiv:1611.08408 (2016)
6. Ma, J., Wang, B.: Segment anything in medical images. arXiv preprint
arXiv:2304.12306 (2023)
7. Ravishankar, H., Venkataramani, R., Thiruvenkadam, S., Sudhakar, P., Vaidya, V.: Learning and incorporating shape models for semantic segmentation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention. pp. 203–211. Springer (2017)
8. Wu, J., Fu, R., Fang, H., Liu, Y., Wang, Z., Xu, Y., Jin, Y., Arbel, T.: Medical
SAM adapter: Adapting segment anything model for medical image segmentation.
arXiv preprint arXiv:2304.12620 (2023)
9. Zhang, K., Liu, D.: Customized segment anything model for medical image seg-
mentation. arXiv preprint arXiv:2304.13785 (2023)
10. Zhou, K., Yang, J., Loy, C.C., Liu, Z.: Conditional prompt learning for vision-
language models. In: Proceedings of the IEEE/CVF Conference on Computer Vi-
sion and Pattern Recognition. pp. 16816–16825 (2022)
11. Zhou, K., Yang, J., Loy, C.C., Liu, Z.: Learning to prompt for vision-language
models. International Journal of Computer Vision 130(9), 2337–2348 (2022)