Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,14 @@ struct StaticKVCache: CoreAIKVCache {
self.currentCapacity = min(capacity ?? maxCapacityFromModel, maxCapacityFromModel)
}

// Create modified requirements with adjusted sequence dimension.
var keyReqsMod = keyReqs
var valueReqsMod = valueReqs
keyReqsMod.shape[seqDim] = self.currentCapacity
valueReqsMod.shape[seqDim] = self.currentCapacity
// Build concrete shapes with adjusted sequence dimension.
var keyShape = keyReqs.shape
var valueShape = valueReqs.shape
keyShape[seqDim] = self.currentCapacity
valueShape[seqDim] = self.currentCapacity

let keyResolved = keyReqsMod.resolvingDynamicDimensions(keyReqsMod.shape)
let valueResolved = valueReqsMod.resolvingDynamicDimensions(valueReqsMod.shape)
let keyResolved = keyReqs.resolvingDynamicDimensions(keyShape)
let valueResolved = valueReqs.resolvingDynamicDimensions(valueShape)

let keyByteCount = keyResolved.minimumByteCount
let valueByteCount = valueResolved.minimumByteCount
Expand All @@ -260,16 +260,16 @@ struct StaticKVCache: CoreAIKVCache {
}

self.keyBinding = TensorBinding(
metalBuffer: keyBuf, shape: keyReqsMod.shape,
metalBuffer: keyBuf, shape: keyShape,
strides: keyResolved.preferredStrides, scalarType: keyReqs.scalarType)
self.valueBinding = TensorBinding(
metalBuffer: valueBuf, shape: valueReqsMod.shape,
metalBuffer: valueBuf, shape: valueShape,
strides: valueResolved.preferredStrides, scalarType: valueReqs.scalarType)

// Log final allocation summary
let fmt = ByteCountFormatter()
fmt.countStyle = .memory
let shapeDesc = KVCacheFactory.describeKVCacheStructure(shape: keyReqsMod.shape)
let shapeDesc = KVCacheFactory.describeKVCacheStructure(shape: keyShape)
CLILogger.log(
"StaticKVCache allocated: \(shapeDesc), Total: \(fmt.string(fromByteCount: Int64(keyByteCount + valueByteCount)))"
)
Expand Down Expand Up @@ -354,14 +354,14 @@ struct GrowingKVCache: CoreAIKVCache {
self.maxCapacity = maxCapacityFromModel > 0 ? maxCapacityFromModel : Int.max
self.currentCapacity = initialCapacity

// Create modified requirements with initial capacity.
var keyReqsMod = keyReqs
var valueReqsMod = valueReqs
keyReqsMod.shape[sequenceDim] = self.currentCapacity
valueReqsMod.shape[sequenceDim] = self.currentCapacity
// Build concrete shapes with initial capacity.
var keyShape = keyReqs.shape
var valueShape = valueReqs.shape
keyShape[sequenceDim] = self.currentCapacity
valueShape[sequenceDim] = self.currentCapacity

let keyResolved = keyReqsMod.resolvingDynamicDimensions(keyReqsMod.shape)
let valueResolved = valueReqsMod.resolvingDynamicDimensions(valueReqsMod.shape)
let keyResolved = keyReqs.resolvingDynamicDimensions(keyShape)
let valueResolved = valueReqs.resolvingDynamicDimensions(valueShape)

let keyByteCount = keyResolved.minimumByteCount
let valueByteCount = valueResolved.minimumByteCount
Expand All @@ -373,16 +373,16 @@ struct GrowingKVCache: CoreAIKVCache {
}

self.keyBinding = TensorBinding(
metalBuffer: keyBuf, shape: keyReqsMod.shape,
metalBuffer: keyBuf, shape: keyShape,
strides: keyResolved.preferredStrides, scalarType: keyReqs.scalarType)
self.valueBinding = TensorBinding(
metalBuffer: valueBuf, shape: valueReqsMod.shape,
metalBuffer: valueBuf, shape: valueShape,
strides: valueResolved.preferredStrides, scalarType: valueReqs.scalarType)

// Log final allocation summary
let fmt = ByteCountFormatter()
fmt.countStyle = .memory
let shapeDesc = KVCacheFactory.describeKVCacheStructure(shape: keyReqsMod.shape)
let shapeDesc = KVCacheFactory.describeKVCacheStructure(shape: keyShape)
CLILogger.log(
"GrowingKVCache allocated (initial): \(shapeDesc), Total: \(fmt.string(fromByteCount: Int64(keyByteCount + valueByteCount)))"
)
Expand Down Expand Up @@ -432,14 +432,14 @@ struct GrowingKVCache: CoreAIKVCache {
}
guard newCapacity > currentCapacity else { return nil }

// Create modified requirements with new capacity
var keyReqsMod = keyReqsTemplate
var valueReqsMod = valueReqsTemplate
keyReqsMod.shape[sequenceDim] = newCapacity
valueReqsMod.shape[sequenceDim] = newCapacity
// Build concrete shapes with new capacity.
var keyShape = keyReqsTemplate.shape
var valueShape = valueReqsTemplate.shape
keyShape[sequenceDim] = newCapacity
valueShape[sequenceDim] = newCapacity

let keyResolved = keyReqsMod.resolvingDynamicDimensions(keyReqsMod.shape)
let valueResolved = valueReqsMod.resolvingDynamicDimensions(valueReqsMod.shape)
let keyResolved = keyReqsTemplate.resolvingDynamicDimensions(keyShape)
let valueResolved = valueReqsTemplate.resolvingDynamicDimensions(valueShape)

let newKeyByteCount = keyResolved.minimumByteCount
let newValueByteCount = valueResolved.minimumByteCount
Expand All @@ -455,11 +455,10 @@ struct GrowingKVCache: CoreAIKVCache {
let oldValueBuf = valueBinding.metalBuffer

// Extract shape dimensions: [L, B, H, S, D]
let shape = keyReqsMod.shape
let l = shape[0]
let b = shape[1]
let h = shape[2]
let d = shape[4]
let l = keyShape[0]
let b = keyShape[1]
let h = keyShape[2]
let d = keyShape[4]
let oldS = currentCapacity
let newS = newCapacity
let mpsDataType = keyReqsTemplate.scalarType.mpsDataType
Expand All @@ -476,15 +475,15 @@ struct GrowingKVCache: CoreAIKVCache {

// Update bindings to new buffers (CPU metadata only — safe before GPU executes)
keyBinding = TensorBinding(
metalBuffer: newKeyBuf, shape: keyReqsMod.shape,
metalBuffer: newKeyBuf, shape: keyShape,
strides: keyResolved.preferredStrides, scalarType: keyReqsTemplate.scalarType)
valueBinding = TensorBinding(
metalBuffer: newValueBuf, shape: valueReqsMod.shape,
metalBuffer: newValueBuf, shape: valueShape,
strides: valueResolved.preferredStrides, scalarType: valueReqsTemplate.scalarType)

let fmt = ByteCountFormatter()
fmt.countStyle = .memory
let shapeDesc = KVCacheFactory.describeKVCacheStructure(shape: keyReqsMod.shape)
let shapeDesc = KVCacheFactory.describeKVCacheStructure(shape: keyShape)
CLILogger.log(
"GrowingKVCache pipelined grow: \(currentCapacity)\(newCapacity), \(shapeDesc), Total: \(fmt.string(fromByteCount: Int64(newKeyByteCount + newValueByteCount)))"
)
Expand Down