btcpayserver/BTCPayServer/Payments/PayJoin/PayJoinRepository.cs

75 lines
2.3 KiB
C#
Raw Normal View History

using System.Collections.Generic;
using System.Linq;
2020-03-30 00:28:22 +09:00
using System.Threading.Tasks;
2020-04-13 15:43:25 +09:00
using BTCPayServer.Data;
using Dapper;
2020-04-13 15:43:25 +09:00
using Microsoft.EntityFrameworkCore;
2020-03-30 00:28:22 +09:00
using NBitcoin;
using Npgsql;
2020-03-30 00:28:22 +09:00
namespace BTCPayServer.Payments.PayJoin
{
    /// <summary>
    /// Database-backed <see cref="IUTXOLocker"/>. Each lock is a row in the <c>PayjoinLocks</c>
    /// table keyed by a string id; an insert that hits a unique-key violation means the id is
    /// already locked.
    /// </summary>
    public class UTXOLocker : IUTXOLocker
    {
        // Prefix applied to ids written by TryLockInputs(OutPoint[]). It keeps those rows distinct
        // from the un-prefixed ids used by TryLock, TryUnlock and FindLocks.
        private const string InputLockPrefix = "K-";

        private readonly ApplicationDbContextFactory _dbContextFactory;

        public UTXOLocker(ApplicationDbContextFactory dbContextFactory)
        {
            _dbContextFactory = dbContextFactory;
        }

        /// <summary>
        /// Deletes the un-prefixed lock rows (as written by <see cref="TryLock"/>) for the given outpoints.
        /// Rows written by <see cref="TryLockInputs(OutPoint[])"/> carry a prefix and are not matched here.
        /// </summary>
        /// <param name="outPoints">Outpoints to unlock; duplicates are ignored.</param>
        /// <returns>
        /// <c>true</c> if one row was deleted per distinct outpoint (or nothing was requested);
        /// <c>false</c> if the save failed with a <see cref="DbUpdateException"/> or affected a different number of rows.
        /// </returns>
        public async Task<bool> TryUnlock(params OutPoint[] outPoints)
        {
            // Deduplicate: EF Core cannot track two PayjoinLock instances with the same key, so
            // removing the same outpoint twice would throw InvalidOperationException instead of
            // returning a result.
            var ids = outPoints.Select(o => o.ToString()).Distinct().ToArray();
            if (ids.Length == 0)
            {
                return true;
            }

            await using var ctx = _dbContextFactory.CreateContext();
            foreach (var id in ids)
            {
                // Removing a stub entity issues a DELETE by key without loading the row first.
                ctx.PayjoinLocks.Remove(new PayjoinLock() { Id = id });
            }
            try
            {
                return await ctx.SaveChangesAsync() == ids.Length;
            }
            catch (DbUpdateException)
            {
                // e.g. a row that is already gone: EF Core reports the 0-rows-affected DELETE as a
                // DbUpdateConcurrencyException, which derives from DbUpdateException.
                return false;
            }
        }

        /// <summary>
        /// Inserts one <c>PayjoinLocks</c> row per distinct id in a single statement.
        /// </summary>
        /// <returns><c>true</c> if all rows were inserted; <c>false</c> if any id already existed.</returns>
        private async Task<bool> TryLockIds(string[] ids)
        {
            // Duplicate ids would collide with each other inside the same INSERT and be
            // misreported as "already locked".
            var distinctIds = ids.Distinct().ToArray();
            if (distinctIds.Length == 0)
            {
                return true;
            }

            await using var ctx = _dbContextFactory.CreateContext();
            var connection = ctx.Database.GetDbConnection();
            try
            {
                // Kept as string[] so the array is bound as a single array parameter for unnest().
                await connection.ExecuteAsync("""
                    INSERT INTO "PayjoinLocks"("Id")
                    SELECT * FROM unnest(@ids)
                    """, new { ids = distinctIds });
                return true;
            }
            catch (Npgsql.PostgresException ex) when (ex.SqlState == PostgresErrorCodes.UniqueViolation)
            {
                // At least one id is already present, so the whole INSERT fails.
                return false;
            }
        }

        /// <summary>Locks a single outpoint under its un-prefixed id.</summary>
        public Task<bool> TryLock(OutPoint outpoint)
            => TryLockIds([outpoint.ToString()]);

        /// <summary>Locks a set of outpoints under <see cref="InputLockPrefix"/>-prefixed ids.</summary>
        public Task<bool> TryLockInputs(OutPoint[] outPoints)
            => TryLockIds(outPoints.Select(o => InputLockPrefix + o.ToString()).ToArray());

        /// <summary>
        /// Returns the subset of <paramref name="outpoints"/> that currently have an un-prefixed lock row.
        /// </summary>
        public async Task<HashSet<OutPoint>> FindLocks(OutPoint[] outpoints)
        {
            // Materialize once so EF Core receives a concrete array for the IN/ANY translation.
            var ids = outpoints.Select(o => o.ToString()).ToArray();
            if (ids.Length == 0)
            {
                return new HashSet<OutPoint>();
            }

            await using var ctx = _dbContextFactory.CreateContext();
            var lockedIds = await ctx.PayjoinLocks
                .Where(l => ids.Contains(l.Id))
                .Select(l => l.Id)
                .ToArrayAsync();
            return lockedIds.Select(id => OutPoint.Parse(id)).ToHashSet();
        }
    }
}