diff --git a/app/src/main/java/app/closer/data/remote/FirestoreBackupDataSource.kt b/app/src/main/java/app/closer/data/remote/FirestoreBackupDataSource.kt new file mode 100644 index 00000000..cfebe003 --- /dev/null +++ b/app/src/main/java/app/closer/data/remote/FirestoreBackupDataSource.kt @@ -0,0 +1,294 @@ +package app.closer.data.remote + +import app.closer.crypto.CoupleEncryptionManager +import app.closer.crypto.FieldEncryptor +import app.closer.domain.model.BackupCursor +import app.closer.domain.model.BackupManifest +import app.closer.domain.model.BackupMessageRecord +import app.closer.domain.model.RestoreRequest +import app.closer.domain.model.RestoreStatus +import com.google.firebase.firestore.DocumentSnapshot +import com.google.firebase.firestore.FirebaseFirestore +import kotlinx.coroutines.channels.awaitClose +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.callbackFlow +import kotlinx.coroutines.tasks.await +import javax.inject.Inject +import javax.inject.Singleton + +/** + * Server side of the E2EE conversation backup + full partner-assisted restore. + * + * Layout (all couple-key ciphertext / ECIES keyboxes — the server holds nothing readable): + * couples/{id}/backup/manifest — pointers + `generation` (optimistic concurrency) + * couples/{id}/backup/manifest/chunks/{seq} — incremental encrypted chunks + * couples/{id}/restore_requests/{recipientUid} — partner-assist request + written keybox + * (snapshot blob lives in Storage: couples/{id}/backups/{snapshotId}, via FirebaseStorageDataSource) + * + * Both partners write the same couple backup; convergence is by **message-id dedupe** + **manifest + * `generation` CAS** in a transaction. Chunks are deleted only AFTER the manifest commit (crash-safe). + */ +@Singleton +class FirestoreBackupDataSource @Inject constructor( + private val db: FirebaseFirestore, + private val encryptionManager: CoupleEncryptionManager, + private val fieldEncryptor: FieldEncryptor, + private val storageDataSource: FirebaseStorageDataSource +) { + private fun manifestRef(coupleId: String) = + db.collection(FirestoreCollections.COUPLES).document(coupleId) + .collection(FirestoreCollections.Couples.BACKUP) + .document(FirestoreCollections.Backup.MANIFEST_ID) + + private fun chunksRef(coupleId: String) = + manifestRef(coupleId).collection(FirestoreCollections.Backup.CHUNKS) + + private fun restoreRequestRef(coupleId: String, recipientUid: String) = + db.collection(FirestoreCollections.COUPLES).document(coupleId) + .collection(FirestoreCollections.Couples.RESTORE_REQUESTS).document(recipientUid) + + // ─── Manifest ──────────────────────────────────────────────────────────── + + suspend fun getManifest(coupleId: String): BackupManifest? { + val snap = manifestRef(coupleId).get().await() + return if (snap.exists()) snap.toManifest() else null + } + + // ─── Incremental append ────────────────────────────────────────────────── + + /** + * Append [records] (strictly after [afterCursor]) as one encrypted chunk and advance the manifest + * cursor + seq atomically. Optimistic: fails (returns false) if another writer bumped `generation` + * meanwhile — the caller re-reads and retries. Idempotent at restore time via message-id dedupe. + */ + suspend fun appendChunk( + coupleId: String, + userId: String, + records: List, + newCursor: BackupCursor, + addedMessageCount: Int + ): Boolean { + if (records.isEmpty()) return true + val aead = encryptionManager.aeadFor(coupleId) ?: return false + val payload = fieldEncryptor.encrypt(BackupCodec.encode(records), aead, coupleId) + val now = System.currentTimeMillis() + return db.runTransaction { txn -> + val mSnap = txn.get(manifestRef(coupleId)) + val current = if (mSnap.exists()) mSnap.toManifest() else BackupManifest() + val seq = current.latestChunkSeq + 1 + txn.set( + chunksRef(coupleId).document(seq.toString()), + mapOf( + "seq" to seq, + "payload" to payload, + "count" to records.size, + "createdBy" to userId, + "createdAt" to now + ) + ) + txn.set( + manifestRef(coupleId), + current.copy( + generation = current.generation + 1, + latestChunkSeq = seq, + snapshotThroughCursor = maxCursor(current.snapshotThroughCursor, newCursor), + messageCount = current.messageCount + addedMessageCount, + updatedAt = now, + updatedBy = userId + ).toMap() + ) + true + }.await() + } + + /** All chunk payloads (ciphertext) ordered by seq, for restore/compaction. */ + suspend fun getChunks(coupleId: String): List { + val q = chunksRef(coupleId).orderBy("seq").get().await() + return q.documents.mapNotNull { d -> + val payload = d.getString("payload") ?: return@mapNotNull null + ChunkDoc(seq = d.getLong("seq") ?: 0L, payload = payload) + } + } + + /** Decrypt a chunk/snapshot ciphertext payload into records (null if the key is unavailable). */ + fun decodeCiphertext(coupleId: String, ciphertext: String?): List { + val aead = encryptionManager.aeadFor(coupleId) ?: return emptyList() + val plain = fieldEncryptor.decrypt(ciphertext, aead, coupleId) ?: return emptyList() + return runCatching { BackupCodec.decode(plain) }.getOrDefault(emptyList()) + } + + // ─── Compaction (fold chunks → snapshot blob) ──────────────────────────── + + /** + * Upload a full-state snapshot blob, then CAS the manifest to point at it, then delete the folded + * chunks (only after the manifest commit → crash-safe). Returns the previous snapshotId to delete + * from Storage, or null on a lost CAS race (caller retries). + */ + suspend fun writeSnapshot( + coupleId: String, + userId: String, + records: List, + throughCursor: BackupCursor, + expectedGeneration: Long, + foldedChunkSeqs: List + ): SnapshotResult? { + val aead = encryptionManager.aeadFor(coupleId) ?: return null + val plain = BackupCodec.encode(records) + val checksum = BackupCodec.checksum(plain) + val ciphertext = fieldEncryptor.encrypt(plain, aead, coupleId) + val snapshotId = java.util.UUID.randomUUID().toString() + val url = storageDataSource.uploadBackupSnapshot(userId, snapshotId, ciphertext.toByteArray(Charsets.UTF_8)) + val now = System.currentTimeMillis() + + // Returns "owner|snapshotId" of the PREVIOUS snapshot to clean up, or ABORT on a lost race. + val prevRef = db.runTransaction { txn -> + val mSnap = txn.get(manifestRef(coupleId)) + val current = if (mSnap.exists()) mSnap.toManifest() else BackupManifest() + if (current.generation != expectedGeneration) { + // Lost the race — abort so the freshly-uploaded blob is orphaned (cleaned up by caller). + return@runTransaction ABORT + } + txn.set( + manifestRef(coupleId), + current.copy( + generation = current.generation + 1, + snapshotUrl = url, + snapshotOwner = userId, + snapshotChecksum = checksum, + snapshotThroughCursor = throughCursor, + messageCount = records.count { !it.deleted }, + updatedAt = now, + updatedBy = userId + ).toMap() + ) + "${current.snapshotOwner}|${snapshotIdFromUrl(current.snapshotUrl) ?: ""}" + }.await() + + if (prevRef == ABORT) { + storageDataSource.deleteBackupSnapshot(userId, snapshotId) // clean the orphan + return null + } + // Manifest committed → now safe to delete folded chunks + the previous snapshot (best-effort; + // cross-owner deletes silently no-op and are cleaned by that owner / account deletion). + foldedChunkSeqs.forEach { seq -> + runCatching { chunksRef(coupleId).document(seq.toString()).delete().await() } + } + val prevOwner = prevRef.substringBefore("|") + val prevId = prevRef.substringAfter("|") + if (prevId.isNotBlank()) storageDataSource.deleteBackupSnapshot(prevOwner, prevId) + return SnapshotResult(snapshotId = snapshotId, previousSnapshotId = prevId.takeIf { it.isNotBlank() }) + } + + suspend fun downloadSnapshotCiphertext(snapshotUrl: String): String = + String(storageDataSource.downloadBytes(snapshotUrl), Charsets.UTF_8) + + // ─── Restore requests (partner-assist) ─────────────────────────────────── + + suspend fun createRestoreRequest( + coupleId: String, + recipientUid: String, + recipientPublicKey: String, + requestNonce: String, + expiresAt: Long + ) { + restoreRequestRef(coupleId, recipientUid).set( + mapOf( + "recipientUid" to recipientUid, + "recipientPublicKey" to recipientPublicKey, + "requestNonce" to requestNonce, + "status" to RestoreStatus.REQUESTED.name, + "createdAt" to System.currentTimeMillis(), + "expiresAt" to expiresAt + ) + ).await() + } + + /** Partner writes the couple key wrapped to the recipient's fresh pubkey, after OOB-code confirm. */ + suspend fun fulfillRestoreRequest(coupleId: String, recipientUid: String, keybox: String) { + restoreRequestRef(coupleId, recipientUid).update( + mapOf( + "keybox" to keybox, + "status" to RestoreStatus.READY.name, + "fulfilledAt" to System.currentTimeMillis() + ) + ).await() + } + + suspend fun updateRestoreStatus(coupleId: String, recipientUid: String, status: RestoreStatus) { + restoreRequestRef(coupleId, recipientUid).update("status", status.name).await() + } + + /** A consumes (deletes) its own request after unwrapping — no wrapped key lingers. */ + suspend fun deleteRestoreRequest(coupleId: String, recipientUid: String) { + runCatching { restoreRequestRef(coupleId, recipientUid).delete().await() } + } + + suspend fun getRestoreRequest(coupleId: String, recipientUid: String): RestoreRequest? { + val snap = restoreRequestRef(coupleId, recipientUid).get().await() + return if (snap.exists()) snap.toRestoreRequest() else null + } + + /** Live view of a restore request (A observes own; B observes the partner's uid doc). */ + fun observeRestoreRequest(coupleId: String, recipientUid: String): Flow = callbackFlow { + val reg = restoreRequestRef(coupleId, recipientUid).addSnapshotListener { snap, err -> + if (err != null) return@addSnapshotListener + trySend(snap?.takeIf { it.exists() }?.toRestoreRequest()) + } + awaitClose { reg.remove() } + } + + // ─── mapping ───────────────────────────────────────────────────────────── + + private fun DocumentSnapshot.toManifest() = BackupManifest( + schemaVersion = (getLong("schemaVersion") ?: 1L).toInt(), + generation = getLong("generation") ?: 0L, + snapshotUrl = getString("snapshotUrl"), + snapshotOwner = getString("snapshotOwner") ?: "", + snapshotChecksum = getString("snapshotChecksum"), + snapshotThroughCursor = BackupCursor( + createdAt = getLong("snapshotThroughCursorAt") ?: 0L, + messageId = getString("snapshotThroughCursorId") ?: "" + ), + latestChunkSeq = getLong("latestChunkSeq") ?: 0L, + messageCount = (getLong("messageCount") ?: 0L).toInt(), + updatedAt = getLong("updatedAt") ?: 0L, + updatedBy = getString("updatedBy") ?: "" + ) + + private fun BackupManifest.toMap(): Map = mapOf( + "schemaVersion" to schemaVersion, + "generation" to generation, + "snapshotUrl" to snapshotUrl, + "snapshotOwner" to snapshotOwner, + "snapshotChecksum" to snapshotChecksum, + "snapshotThroughCursorAt" to snapshotThroughCursor.createdAt, + "snapshotThroughCursorId" to snapshotThroughCursor.messageId, + "latestChunkSeq" to latestChunkSeq, + "messageCount" to messageCount, + "updatedAt" to updatedAt, + "updatedBy" to updatedBy + ) + + private fun DocumentSnapshot.toRestoreRequest() = RestoreRequest( + recipientUid = getString("recipientUid") ?: "", + recipientPublicKey = getString("recipientPublicKey") ?: "", + requestNonce = getString("requestNonce") ?: "", + keybox = getString("keybox"), + status = runCatching { RestoreStatus.valueOf(getString("status") ?: "") }.getOrDefault(RestoreStatus.REQUESTED), + createdAt = getLong("createdAt") ?: 0L, + expiresAt = getLong("expiresAt") ?: 0L, + fulfilledAt = getLong("fulfilledAt") + ) + + private fun maxCursor(a: BackupCursor, b: BackupCursor): BackupCursor = if (b.isAfter(a)) b else a + + private fun snapshotIdFromUrl(url: String?): String? = + url?.substringAfter("backups%2F", "")?.substringBefore("?", "")?.takeIf { it.isNotBlank() } + + data class ChunkDoc(val seq: Long, val payload: String) + data class SnapshotResult(val snapshotId: String, val previousSnapshotId: String?) + + private companion object { + const val ABORT = "__abort__" + } +}