
Commit 6f16ea4

make the StagingBuffer implementation more robust
1 parent 347d902 commit 6f16ea4

5 files changed: 89 additions & 62 deletions


wgpu-core/src/device/global.rs

Lines changed: 4 additions & 7 deletions
@@ -2519,20 +2519,17 @@ impl Global {
         }
         let map_state = &*buffer.map_state.lock();
         match *map_state {
-            resource::BufferMapState::Init { ref ptr, .. } => {
+            resource::BufferMapState::Init { ref staging_buffer } => {
                 // offset (u64) can not be < 0, so no need to validate the lower bound
                 if offset + range_size > buffer.size {
                     return Err(BufferAccessError::OutOfBoundsOverrun {
                         index: offset + range_size - 1,
                         max: buffer.size,
                     });
                 }
-                unsafe {
-                    Ok((
-                        NonNull::new_unchecked(ptr.as_ptr().offset(offset as isize)),
-                        range_size,
-                    ))
-                }
+                let ptr = unsafe { staging_buffer.ptr() };
+                let ptr = unsafe { NonNull::new_unchecked(ptr.as_ptr().offset(offset as isize)) };
+                Ok((ptr, range_size))
             }
             resource::BufferMapState::Active {
                 ref ptr, ref range, ..
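The change stops passing the mapped pointer around separately: the `Init` state now holds only the `StagingBuffer`, and the pointer is obtained through the new `unsafe fn ptr()` accessor right where it is needed. A minimal sketch of that pattern, using a hypothetical `MappedAllocation` type (not a wgpu type) to show the bounds check followed by the unchecked offset:

```rust
use std::ptr::NonNull;

/// Hypothetical stand-in for wgpu-core's `StagingBuffer`: the mapped
/// pointer lives inside the owning struct instead of traveling beside it.
struct MappedAllocation {
    ptr: NonNull<u8>,
    len: u64,
}

impl MappedAllocation {
    /// SAFETY: stop using the returned pointer before calling any other
    /// method on `self` (mirrors the contract of `StagingBuffer::ptr`).
    unsafe fn ptr(&self) -> NonNull<u8> {
        self.ptr
    }
}

/// Validate the range first, then do the unchecked pointer arithmetic,
/// in the same order as the `BufferMapState::Init` arm above.
fn map_range(alloc: &MappedAllocation, offset: u64, size: u64) -> Option<NonNull<u8>> {
    if offset + size > alloc.len {
        return None; // corresponds to `BufferAccessError::OutOfBoundsOverrun`
    }
    let ptr = unsafe { alloc.ptr() };
    // SAFETY: `offset` is in bounds, so the result stays inside the
    // allocation and is therefore non-null.
    Some(unsafe { NonNull::new_unchecked(ptr.as_ptr().offset(offset as isize)) })
}
```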

wgpu-core/src/device/queue.rs

Lines changed: 19 additions & 30 deletions
@@ -31,7 +31,7 @@ use smallvec::SmallVec;
 use std::{
     iter,
     mem::{self},
-    ptr::{self, NonNull},
+    ptr::NonNull,
     sync::{atomic::Ordering, Arc},
 };
 use thiserror::Error;
@@ -405,17 +405,13 @@
         // Platform validation requires that the staging buffer always be
         // freed, even if an error occurs. All paths from here must call
         // `device.pending_writes.consume`.
-        let (staging_buffer, staging_buffer_ptr) = StagingBuffer::new(device, data_size)?;
+        let mut staging_buffer = StagingBuffer::new(device, data_size)?;
         let mut pending_writes = device.pending_writes.lock();
         let pending_writes = pending_writes.as_mut().unwrap();
 
-        let staging_buffer = unsafe {
+        let staging_buffer = {
             profiling::scope!("copy");
-            ptr::copy_nonoverlapping(
-                data.as_ptr(),
-                staging_buffer_ptr.as_ptr(),
-                data_size.get() as usize,
-            );
+            staging_buffer.write(data);
             staging_buffer.flush()
         };
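With the copy moved behind `StagingBuffer::write`, this call site no longer needs any `unsafe`, and `flush` drops its `unsafe` marker as well. A self-contained sketch of the resulting create → write → flush lifecycle, using a hypothetical `Vec`-backed mock rather than a real GPU mapping:

```rust
/// Hypothetical Vec-backed mock of the new StagingBuffer surface.
struct MockStaging {
    storage: Vec<u8>,
}

impl MockStaging {
    fn new(size: usize) -> Self {
        Self { storage: vec![0; size] }
    }

    /// Mirrors `StagingBuffer::write`: the source may be longer than the
    /// buffer (callers pass suffixes like `&data[offset..]`), never shorter.
    fn write(&mut self, data: &[u8]) {
        let size = self.storage.len();
        assert!(data.len() >= size);
        self.storage.copy_from_slice(&data[..size]);
    }

    /// Consuming `self` models `flush`: once flushed, no further CPU-side
    /// writes are possible, which is what lets the call sites stay safe.
    fn flush(self) -> Vec<u8> {
        self.storage
    }
}

fn main() {
    let mut staging = MockStaging::new(4);
    staging.write(&[1, 2, 3, 4, 5]);
    assert_eq!(staging.flush(), vec![1, 2, 3, 4]);
}
```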

@@ -448,13 +444,14 @@
 
         let device = &queue.device;
 
-        let (staging_buffer, staging_buffer_ptr) = StagingBuffer::new(device, buffer_size)?;
+        let staging_buffer = StagingBuffer::new(device, buffer_size)?;
+        let ptr = unsafe { staging_buffer.ptr() };
 
         let fid = hub.staging_buffers.prepare(id_in);
         let id = fid.assign(Arc::new(staging_buffer));
         resource_log!("Queue::create_staging_buffer {id:?}");
 
-        Ok((id, staging_buffer_ptr))
+        Ok((id, ptr))
     }
 
     pub fn queue_write_staging_buffer<A: HalApi>(
@@ -487,7 +484,7 @@
         // user. Platform validation requires that the staging buffer always
         // be freed, even if an error occurs. All paths from here must call
         // `device.pending_writes.consume`.
-        let staging_buffer = unsafe { staging_buffer.flush() };
+        let staging_buffer = staging_buffer.flush();
 
         let result = self.queue_write_staging_buffer_impl(
             &queue,
@@ -779,42 +776,34 @@
         // Platform validation requires that the staging buffer always be
         // freed, even if an error occurs. All paths from here must call
         // `device.pending_writes.consume`.
-        let (staging_buffer, staging_buffer_ptr) = StagingBuffer::new(device, stage_size)?;
+        let mut staging_buffer = StagingBuffer::new(device, stage_size)?;
 
         if stage_bytes_per_row == bytes_per_row {
             profiling::scope!("copy aligned");
             // Fast path if the data is already being aligned optimally.
-            unsafe {
-                ptr::copy_nonoverlapping(
-                    data.as_ptr().offset(data_layout.offset as isize),
-                    staging_buffer_ptr.as_ptr(),
-                    stage_size.get() as usize,
-                );
-            }
+            staging_buffer.write(&data[data_layout.offset as usize..]);
         } else {
             profiling::scope!("copy chunked");
             // Copy row by row into the optimal alignment.
             let copy_bytes_per_row = stage_bytes_per_row.min(bytes_per_row) as usize;
             for layer in 0..size.depth_or_array_layers {
                 let rows_offset = layer * block_rows_per_image;
-                for row in 0..height_blocks {
+                for row in rows_offset..rows_offset + height_blocks {
+                    let src_offset = data_layout.offset as u32 + row * bytes_per_row;
+                    let dst_offset = row * stage_bytes_per_row;
                     unsafe {
-                        ptr::copy_nonoverlapping(
-                            data.as_ptr().offset(
-                                data_layout.offset as isize
-                                    + (rows_offset + row) as isize * bytes_per_row as isize,
-                            ),
-                            staging_buffer_ptr.as_ptr().offset(
-                                (rows_offset + row) as isize * stage_bytes_per_row as isize,
-                            ),
+                        staging_buffer.write_with_offset(
+                            data,
+                            src_offset as isize,
+                            dst_offset as isize,
                             copy_bytes_per_row,
-                        );
+                        )
                     }
                 }
             }
         }
 
-        let staging_buffer = unsafe { staging_buffer.flush() };
+        let staging_buffer = staging_buffer.flush();
 
         let regions = (0..array_layer_count).map(|rel_array_layer| {
             let mut texture_base = dst_base.clone();
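The chunked path now computes `src_offset` and `dst_offset` up front and iterates over absolute block rows (`rows_offset..rows_offset + height_blocks`) instead of re-deriving `rows_offset + row` inside the pointer arithmetic. A self-contained sketch of the same row-pitch math over safe slices (names mirror the diff; the sizes in `main` are illustrative):

```rust
/// Copy tightly packed rows into a destination with a padded row pitch,
/// mirroring the "copy chunked" loop above but over safe slices.
fn copy_rows_chunked(
    data: &[u8],
    staging: &mut [u8],
    data_offset: usize,
    bytes_per_row: usize,       // source pitch
    stage_bytes_per_row: usize, // aligned destination pitch
    height_blocks: usize,
    layers: usize,
    block_rows_per_image: usize,
) {
    let copy_bytes_per_row = stage_bytes_per_row.min(bytes_per_row);
    for layer in 0..layers {
        let rows_offset = layer * block_rows_per_image;
        // `row` is an absolute block-row index, as in the rewritten loop.
        for row in rows_offset..rows_offset + height_blocks {
            let src = data_offset + row * bytes_per_row;
            let dst = row * stage_bytes_per_row;
            staging[dst..dst + copy_bytes_per_row]
                .copy_from_slice(&data[src..src + copy_bytes_per_row]);
        }
    }
}

fn main() {
    // Two rows of 3 bytes, staged with a pitch of 4 (one padding byte per row).
    let data = [1, 2, 3, 4, 5, 6];
    let mut staging = [0u8; 8];
    copy_rows_chunked(&data, &mut staging, 0, 3, 4, 2, 1, 2);
    assert_eq!(staging, [1, 2, 3, 0, 4, 5, 6, 0]);
}
```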

wgpu-core/src/device/resource.rs

Lines changed: 3 additions & 6 deletions
@@ -591,18 +591,15 @@ impl<A: HalApi> Device<A> {
             };
             hal::BufferUses::MAP_WRITE
         } else {
-            let (staging_buffer, staging_buffer_ptr) =
+            let mut staging_buffer =
                 StagingBuffer::new(self, wgt::BufferSize::new(aligned_size).unwrap())?;
 
             // Zero initialize memory and then mark the buffer as initialized
             // (it's guaranteed that this is the case by the time the buffer is usable)
-            unsafe { std::ptr::write_bytes(staging_buffer_ptr.as_ptr(), 0, aligned_size as usize) };
+            staging_buffer.write_zeros();
             buffer.initialization_status.write().drain(0..aligned_size);
 
-            *buffer.map_state.lock() = resource::BufferMapState::Init {
-                staging_buffer,
-                ptr: staging_buffer_ptr,
-            };
+            *buffer.map_state.lock() = resource::BufferMapState::Init { staging_buffer };
             hal::BufferUses::COPY_DST
         };
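`write_zeros` replaces the open-coded `std::ptr::write_bytes`, so the zero-initialization guarantee for buffers mapped at creation now lives next to the pointer it touches. A minimal sketch of the encapsulation, again over a hypothetical simplified type:

```rust
use std::ptr::NonNull;

/// Hypothetical simplified staging buffer that owns its mapped pointer.
struct Staging {
    ptr: NonNull<u8>,
    size: usize,
}

impl Staging {
    /// Zero the entire mapped range, mirroring `StagingBuffer::write_zeros`.
    fn write_zeros(&mut self) {
        // SAFETY: `ptr` points at `size` writable bytes for as long as the
        // mapping is alive, which `&mut self` guarantees here.
        unsafe { std::ptr::write_bytes(self.ptr.as_ptr(), 0, self.size) };
    }
}
```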

wgpu-core/src/resource.rs

Lines changed: 58 additions & 18 deletions
@@ -256,10 +256,7 @@ pub enum BufferMapAsyncStatus {
 #[derive(Debug)]
 pub(crate) enum BufferMapState<A: HalApi> {
     /// Mapped at creation.
-    Init {
-        staging_buffer: StagingBuffer<A>,
-        ptr: NonNull<u8>,
-    },
+    Init { staging_buffer: StagingBuffer<A> },
     /// Waiting for GPU to be done before mapping
     Waiting(BufferPendingMapping<A>),
     /// Mapped
@@ -651,28 +648,22 @@ impl<A: HalApi> Buffer<A> {
         let raw_buf = self.try_raw(&snatch_guard)?;
         log::debug!("{} map state -> Idle", self.error_ident());
         match mem::replace(&mut *self.map_state.lock(), BufferMapState::Idle) {
-            BufferMapState::Init {
-                staging_buffer,
-                ptr,
-            } => {
+            BufferMapState::Init { staging_buffer } => {
                 #[cfg(feature = "trace")]
                 if let Some(ref mut trace) = *device.trace.lock() {
-                    let data = trace.make_binary("bin", unsafe {
-                        std::slice::from_raw_parts(ptr.as_ptr(), self.size as usize)
-                    });
+                    let data = trace.make_binary("bin", staging_buffer.get_data());
                     trace.add(trace::Action::WriteBuffer {
                         id: buffer_id,
                         data,
                         range: 0..self.size,
                         queued: true,
                     });
                 }
-                let _ = ptr;
 
                 let mut pending_writes = device.pending_writes.lock();
                 let pending_writes = pending_writes.as_mut().unwrap();
 
-                let staging_buffer = unsafe { staging_buffer.flush() };
+                let staging_buffer = staging_buffer.flush();
 
                 self.use_at(device.active_submission_index.load(Ordering::Relaxed) + 1);
                 let region = wgt::BufferSize::new(self.size).map(|size| hal::BufferCopy {
@@ -832,6 +823,11 @@ impl<A: HalApi> Drop for DestroyedBuffer<A> {
     }
 }
 
+#[cfg(send_sync)]
+unsafe impl<A: HalApi> Send for StagingBuffer<A> {}
+#[cfg(send_sync)]
+unsafe impl<A: HalApi> Sync for StagingBuffer<A> {}
+
 /// A temporary buffer, consumed by the command that uses it.
 ///
 /// A [`StagingBuffer`] is designed for one-shot uploads of data to the GPU. It
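Storing the `NonNull<u8>` inside `StagingBuffer` makes the struct `!Send`/`!Sync` by default, which is why the commit adds manual `unsafe impl`s under `cfg(send_sync)`. A sketch of the mechanism, with illustrative types:

```rust
use std::ptr::NonNull;

#[allow(dead_code)]
struct OwnsMappedPtr {
    ptr: NonNull<u8>, // raw pointers are neither Send nor Sync...
}

// ...so without these impls the struct cannot cross threads. They are
// justified only when the type has exclusive ownership of the mapping.
unsafe impl Send for OwnsMappedPtr {}
unsafe impl Sync for OwnsMappedPtr {}

fn assert_send_sync<T: Send + Sync>() {}

fn main() {
    // Fails to compile without the unsafe impls above.
    assert_send_sync::<OwnsMappedPtr>();
}
```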
@@ -857,13 +853,11 @@ pub struct StagingBuffer<A: HalApi> {
     device: Arc<Device<A>>,
     pub(crate) size: wgt::BufferSize,
     is_coherent: bool,
+    ptr: NonNull<u8>,
 }
 
 impl<A: HalApi> StagingBuffer<A> {
-    pub(crate) fn new(
-        device: &Arc<Device<A>>,
-        size: wgt::BufferSize,
-    ) -> Result<(Self, NonNull<u8>), DeviceError> {
+    pub(crate) fn new(device: &Arc<Device<A>>, size: wgt::BufferSize) -> Result<Self, DeviceError> {
         use hal::Device;
         profiling::scope!("StagingBuffer::new");
         let stage_desc = hal::BufferDescriptor {
@@ -881,9 +875,55 @@
             device: device.clone(),
             size,
             is_coherent: mapping.is_coherent,
+            ptr: mapping.ptr,
         };
 
-        Ok((staging_buffer, mapping.ptr))
+        Ok(staging_buffer)
+    }
+
+    /// SAFETY: You must not call any functions of `self`
+    /// until you have stopped using the returned pointer.
+    pub(crate) unsafe fn ptr(&self) -> NonNull<u8> {
+        self.ptr
+    }
+
+    #[cfg(feature = "trace")]
+    pub(crate) fn get_data(&self) -> &[u8] {
+        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size.get() as usize) }
+    }
+
+    pub(crate) fn write_zeros(&mut self) {
+        unsafe { core::ptr::write_bytes(self.ptr.as_ptr(), 0, self.size.get() as usize) };
+    }
+
+    pub(crate) fn write(&mut self, data: &[u8]) {
+        assert!(data.len() >= self.size.get() as usize);
+        // SAFETY: With the assert above, all of `copy_nonoverlapping`'s
+        // requirements are satisfied.
+        unsafe {
+            core::ptr::copy_nonoverlapping(
+                data.as_ptr(),
+                self.ptr.as_ptr(),
+                self.size.get() as usize,
+            );
+        }
+    }
+
+    /// SAFETY: The offsets and size must be in-bounds.
+    pub(crate) unsafe fn write_with_offset(
+        &mut self,
+        data: &[u8],
+        src_offset: isize,
+        dst_offset: isize,
+        size: usize,
+    ) {
+        unsafe {
+            core::ptr::copy_nonoverlapping(
+                data.as_ptr().offset(src_offset),
+                self.ptr.as_ptr().offset(dst_offset),
+                size,
+            );
+        }
     }
 
     pub(crate) fn flush(self) -> FlushedStagingBuffer<A> {
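`write_with_offset` is the one write that stays `unsafe`: unlike `write`, it only sees raw offsets and cannot check bounds itself. A sketch of the checked front-end a caller could keep on the safe side of that boundary (hypothetical; not part of the commit):

```rust
/// Hypothetical checked front-end for a `write_with_offset`-style copy:
/// both the source and destination ranges must lie inside their buffers.
fn checked_copy(
    data: &[u8],
    staging: &mut [u8],
    src_offset: usize,
    dst_offset: usize,
    size: usize,
) -> Result<(), &'static str> {
    if src_offset.checked_add(size).map_or(true, |end| end > data.len()) {
        return Err("source range out of bounds");
    }
    if dst_offset.checked_add(size).map_or(true, |end| end > staging.len()) {
        return Err("destination range out of bounds");
    }
    staging[dst_offset..dst_offset + size]
        .copy_from_slice(&data[src_offset..src_offset + size]);
    Ok(())
}
```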

wgpu-hal/src/lib.rs

Lines changed: 5 additions & 1 deletion
@@ -714,9 +714,13 @@ pub trait Device: WasmNotSendSync {
     /// be ordered, so it is meaningful to talk about what must occur
     /// "between" them.
     ///
+    /// - Zero-sized mappings are not allowed.
+    ///
+    /// - The returned [`BufferMapping::ptr`] must not be used after a call to
+    ///   [`Device::unmap_buffer`].
+    ///
     /// [`MAP_READ`]: BufferUses::MAP_READ
     /// [`MAP_WRITE`]: BufferUses::MAP_WRITE
-    //TODO: clarify if zero-sized mapping is allowed
     unsafe fn map_buffer(
         &self,
         buffer: &<Self::A as Api>::Buffer,
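The resolved TODO becomes two hard rules: no zero-sized mappings, and no use of the mapped pointer after `unmap_buffer`. wgpu already pushes the first rule toward the type system, since `wgt::BufferSize` (the size type taken by `StagingBuffer::new`) is a `NonZeroU64`; a minimal sketch of rejecting a zero size before it reaches a backend:

```rust
use std::num::NonZeroU64;

/// Reject zero-sized map requests up front, in the spirit of
/// `wgt::BufferSize` being a `NonZeroU64`.
fn validate_map_size(size: u64) -> Result<NonZeroU64, &'static str> {
    NonZeroU64::new(size).ok_or("zero-sized mappings are not allowed")
}

fn main() {
    assert!(validate_map_size(0).is_err());
    assert_eq!(validate_map_size(256).unwrap().get(), 256);
}
```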
