#nullable enable
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Amazon.Runtime.Internal.Util;
using AngleSharp.Text;
using BTCPayServer.Abstractions.Contracts;
using BTCPayServer.Configuration;
using BTCPayServer.Data;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Identity;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Conventions;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using MySqlConnector;
using NBXplorer;
using Newtonsoft.Json.Linq;
using Npgsql;

namespace BTCPayServer.Hosting
{
    /// <summary>
    /// Topological ordering helpers, used to drop and copy tables in foreign-key order.
    /// </summary>
    static class TopologySort
    {
        /// <summary>
        /// Orders relational tables so that every table comes after the tables it references
        /// through foreign keys. Tables that become ready at the same time are ordered by
        /// ordinal table name, so the result does not depend on the current culture.
        /// </summary>
        /// <remarks>
        /// The Invoices -> Refunds foreign key is ignored to break the cycle between those two
        /// tables; the migration clears that reference while copying invoices and restores it
        /// once every table has been copied.
        /// </remarks>
        /// <exception cref="InvalidOperationException">The remaining foreign keys form a cycle.</exception>
        public static IEnumerable<ITable> OrderByTopology(this IEnumerable<ITable> tables)
        {
            var comparer = Comparer<ITable>.Create((a, b) => string.CompareOrdinal(a.Name, b.Name));
            return OrderByTopology(
                tables,
                t => t.Name == "Invoices"
                    ? t.ForeignKeyConstraints.Select(f => f.PrincipalTable.Name).Where(f => f != "Refunds")
                    : t.ForeignKeyConstraints.Select(f => f.PrincipalTable.Name),
                t => t.Name,
                t => t,
                comparer);
        }

        /// <summary>
        /// Returns <paramref name="values"/> ordered so that every element comes after the
        /// elements it depends on. The result is computed eagerly.
        /// </summary>
        /// <param name="values">Elements to order.</param>
        /// <param name="dependsOn">Keys of the elements an element depends on. Keys that no element of
        /// <paramref name="values"/> produces through <paramref name="getKey"/> are ignored.</param>
        /// <param name="getKey">Key identifying an element.</param>
        /// <param name="getValue">Projection applied to each element when it is emitted.</param>
        /// <param name="solveTies">Order among elements that are ready at the same time; defaults to
        /// <see cref="Comparer{T}.Default"/>. Two distinct elements comparing equal make the sort throw.</param>
        /// <exception cref="InvalidOperationException">The dependency graph contains a cycle.</exception>
        public static IEnumerable<T> OrderByTopology<T, TDepend>(
            this IEnumerable<T> values,
            Func<T, IEnumerable<TDepend>> dependsOn,
            Func<T, TDepend> getKey,
            Func<T, T> getValue,
            IComparer<T>? solveTies = null) where T : notnull
        {
            var v = values.ToList();
            return TopologicalSort(v, dependsOn, getKey, getValue, solveTies);
        }

        /// <summary>
        /// Kahn's algorithm: repeatedly emits the smallest (per <paramref name="solveTies"/>) node whose
        /// dependencies have all been emitted, then removes it from the other nodes' dependency sets.
        /// </summary>
        /// <exception cref="ArgumentNullException">A delegate argument is null and <paramref name="nodes"/> is not empty.</exception>
        /// <exception cref="InvalidOperationException">The dependency graph contains a cycle.</exception>
        static List<TValue> TopologicalSort<T, TDepend, TValue>(
            this IReadOnlyCollection<T> nodes,
            Func<T, IEnumerable<TDepend>> dependsOn,
            Func<T, TDepend> getKey,
            Func<T, TValue> getValue,
            IComparer<T>? solveTies = null) where T : notnull
        {
            if (nodes.Count == 0)
                return new List<TValue>();
            if (dependsOn == null)
                throw new ArgumentNullException(nameof(dependsOn));
            if (getKey == null)
                throw new ArgumentNullException(nameof(getKey));
            if (getValue == null)
                throw new ArgumentNullException(nameof(getValue));
            solveTies ??= Comparer<T>.Default;

            var result = new List<TValue>(nodes.Count);
            var allKeys = new HashSet<TDepend>(nodes.Select(getKey));

            // Not-yet-emitted dependencies of each node, restricted to keys that belong to the input.
            var pendingDependencies = nodes.ToDictionary(
                node => node,
                node => new HashSet<TDepend>(dependsOn(node).Where(n => allKeys.Contains(n))));

            // Nodes whose dependencies are all emitted, ordered by solveTies.
            var ready = new SortedDictionary<T, HashSet<TDepend>>(solveTies);
            foreach (var entry in pendingDependencies.Where(x => x.Value.Count == 0))
                ready.Add(entry.Key, entry.Value);
            if (ready.Count == 0)
                throw new InvalidOperationException("Impossible to topologically sort a cyclic graph");

            while (ready.Count > 0)
            {
                var next = ready.First().Key;
                ready.Remove(next);
                pendingDependencies.Remove(next);

                var nextKey = getKey(next);
                result.Add(getValue(next));
                foreach (var entry in pendingDependencies)
                {
                    if (entry.Value.Remove(nextKey) && entry.Value.Count == 0)
                        ready.Add(entry.Key, entry.Value);
                }
            }
            // Anything left still waits on a node that was never emitted: a cycle.
            if (pendingDependencies.Count != 0)
                throw new InvalidOperationException("Impossible to topologically sort a cyclic graph");
            return result;
        }
    }

    /// <summary>
    /// Startup task copying the data of a SQLite (<c>SQLITEFILE</c>) or MySQL (<c>MYSQL</c>) database
    /// into the Postgres database configured with <c>POSTGRES</c>.
    /// </summary>
    /// <remarks>
    /// Progress is stored in the Postgres <c>Settings</c> table under the id <c>MigrationData</c>
    /// as <c>{ "from": ..., "state": "pending" | "complete" }</c>. A "pending" migration is restarted
    /// from scratch (all tables dropped) on the next run; a "complete" one is not repeated.
    /// </remarks>
    public class ToPostgresMigrationStartupTask : IStartupTask
    {
        // Command timeout, in seconds, for both source and destination contexts (10 hours).
        private const int MigrationCommandTimeoutSeconds = 60 * 60 * 10;

        // Tables whose rows are not copied to Postgres.
        // NOTE(review): the reason is not visible in this file; confirm before changing the list.
        private static readonly HashSet<string> SkippedTables = new HashSet<string>(StringComparer.Ordinal)
        {
            "WebhookDeliveries",
            "InvoiceWebhookDeliveries",
            "StoreRoles",
        };

        public ToPostgresMigrationStartupTask(
            IConfiguration configuration,
            IOptions<DataDirectories> datadirs,
            ILogger<ToPostgresMigrationStartupTask> logger,
            IWebHostEnvironment environment,
            ApplicationDbContextFactory dbContextFactory)
        {
            Configuration = configuration;
            Datadirs = datadirs;
            Logger = logger;
            Environment = environment;
            DbContextFactory = dbContextFactory;
        }

        public IConfiguration Configuration { get; }
        public IOptions<DataDirectories> Datadirs { get; }
        public ILogger<ToPostgresMigrationStartupTask> Logger { get; }
        public IWebHostEnvironment Environment { get; }
        public ApplicationDbContextFactory DbContextFactory { get; }
        public bool HasError { get; private set; }

        /// <summary>
        /// Runs the migration. Does nothing when <c>POSTGRES</c> is not set, when neither
        /// <c>SQLITEFILE</c> nor <c>MYSQL</c> is set, when the SQLite file does not exist, when the
        /// MySQL database does not exist, when the source has no <c>Settings</c> rows, or when
        /// Postgres already records the migration as complete.
        /// </summary>
        /// <exception cref="ConfigException">The Postgres database holds data that was not created by this migration.</exception>
        public async Task ExecuteAsync(CancellationToken cancellationToken = default)
        {
            var p = Configuration.GetOrDefault<string?>("POSTGRES", null);
            var sqlite = Configuration.GetOrDefault<string?>("SQLITEFILE", null);
            var mysql = Configuration.GetOrDefault<string?>("MYSQL", null);
            if (string.IsNullOrEmpty(p))
                return;

            var source = CreateSourceContext(sqlite, mysql);
            if (source is null)
                return;
            var (migratingFrom, otherContext) = source.Value;

            bool migrated;
            // Dispose the source context on every exit path (early return or exception), and before
            // the pools are cleared below so its connection is already back in the pool by then.
            using (otherContext)
            {
                migrated = await MigrateToPostgresAsync(otherContext, migratingFrom, p, cancellationToken);
            }
            if (!migrated)
                return;

            // Drop pooled connections to the source database, which is no longer needed.
            SqliteConnection.ClearAllPools();
            MySqlConnection.ClearAllPools();
            Logger.LogInformation("Migration to postgres from {MigratingFrom} successful", migratingFrom);
        }

        /// <summary>
        /// Creates a context on the database to migrate from, or returns null when there is no source:
        /// neither setting is provided, or the SQLite file does not exist. <c>SQLITEFILE</c> wins over <c>MYSQL</c>.
        /// </summary>
        private (string MigratingFrom, ApplicationDbContext Context)? CreateSourceContext(string? sqlite, string? mysql)
        {
            if (!string.IsNullOrEmpty(sqlite))
            {
                var sqlitePath = Datadirs.Value.ToDatadirFullPath(sqlite);
                if (!File.Exists(sqlitePath))
                    return null;
                var sqliteOptions = new DbContextOptionsBuilder<ApplicationDbContext>()
                    .UseSqlite("Data Source=" + sqlitePath, o => o.CommandTimeout(MigrationCommandTimeoutSeconds))
                    .Options;
                return ("SQLite", new ApplicationDbContext(sqliteOptions));
            }
            if (!string.IsNullOrEmpty(mysql))
            {
                var mysqlOptions = new DbContextOptionsBuilder<ApplicationDbContext>()
                    .UseMySql(mysql, ServerVersion.AutoDetect(mysql), o => o.CommandTimeout(MigrationCommandTimeoutSeconds))
                    .Options;
                return ("MySQL", new ApplicationDbContext(mysqlOptions));
            }
            return null;
        }

        /// <summary>
        /// Copies every table (except <see cref="SkippedTables"/>) from <paramref name="otherContext"/>
        /// into Postgres and marks the migration as complete.
        /// </summary>
        /// <returns>true when data was migrated; false when there was nothing to do.</returns>
        /// <exception cref="ConfigException">The Postgres database was not created by this migration.</exception>
        private async Task<bool> MigrateToPostgresAsync(ApplicationDbContext otherContext, string migratingFrom, string postgresConnectionString, CancellationToken cancellationToken)
        {
            bool sourceHasData;
            try
            {
                sourceHasData = await otherContext.Settings.AnyAsync(cancellationToken);
            }
            catch (MySqlException ex) when (ex.SqlState == "42000") // DB doesn't exists
            {
                return false;
            }
            if (!sourceHasData)
                return false;

            using var postgresContext = new ApplicationDbContext(
                new DbContextOptionsBuilder<ApplicationDbContext>()
                    .UseNpgsql(postgresConnectionString, o => o.CommandTimeout(MigrationCommandTimeoutSeconds))
                    .Options);

            string? state;
            try
            {
                state = await GetMigrationState(postgresContext, cancellationToken);
                if (state == "complete")
                    return false;
                if (state == null)
                    throw new ConfigException("This postgres database isn't created during a migration. Please use an empty database for postgres when migrating. If it's not a migration, remove --sqlitefile or --mysql settings.");
            }
            catch (NpgsqlException ex) when (ex.SqlState == PostgresErrorCodes.InvalidCatalogName || ex.SqlState == PostgresErrorCodes.UndefinedTable) // DB doesn't exists
            {
                await postgresContext.Database.MigrateAsync(cancellationToken);
                state = "pending";
                await SetMigrationState(postgresContext, migratingFrom, "pending", cancellationToken);
            }

            Logger.LogInformation("Migrating from {MigratingFrom} to Postgres...", migratingFrom);
            if (state != "pending")
                throw new ConfigException("This database isn't created during a migration. Please use an empty database for postgres when migrating.");

            var tables = postgresContext.Model.GetRelationalModel().Tables.OrderByTopology().ToList();

            // A previous attempt may have stopped mid-copy: start again from an empty schema.
            Logger.LogInformation("There is a unfinished migration in postgres... dropping all tables");
            foreach (var table in tables)
                await postgresContext.Database.ExecuteSqlRawAsync($"DROP TABLE IF EXISTS \"{table.Name}\" CASCADE", cancellationToken);
            await postgresContext.Database.ExecuteSqlRawAsync("DROP TABLE IF EXISTS \"__EFMigrationsHistory\" CASCADE", cancellationToken);
            await postgresContext.Database.MigrateAsync(cancellationToken);

            await otherContext.Database.MigrateAsync(cancellationToken);
            await SetMigrationState(postgresContext, migratingFrom, "pending", cancellationToken);

            // DbContext.Set<TEntity>(), used to query each table's entity type through reflection.
            var setMethod = otherContext.GetType().GetMethod("Set", Type.EmptyTypes)!;
            foreach (var table in tables)
            {
                if (SkippedTables.Contains(table.Name))
                    continue;
                var entityClrType = table.EntityTypeMappings.Single().EntityType.ClrType;
                var query = (IQueryable<object>)setMethod.MakeGenericMethod(entityClrType).Invoke(otherContext, null)!;
                Logger.LogInformation("Migrating table: {TableName}", table.Name);

                var dateProperties = GetDateProperties(table);
                var rows = await query.ToListAsync(cancellationToken);
                foreach (var row in rows)
                {
                    // There is a circular dependency between invoices and refunds: the link is
                    // cleared here and restored after every table has been copied.
                    if (row is InvoiceData invoiceRow)
                        invoiceRow.CurrentRefundId = null;
                    NormalizeDatesToUtc(row, dateProperties);
                    postgresContext.Entry(row).State = EntityState.Added;
                }
                await postgresContext.SaveChangesAsync(cancellationToken);
                postgresContext.ChangeTracker.Clear();
            }

            // Restore the invoice -> refund links now that the Refunds table exists in Postgres.
            // These are fresh copies read from the source, so they need the same UTC normalization.
            var invoiceDateProperties = tables
                .Where(table => table.Name == "Invoices")
                .SelectMany(table => GetDateProperties(table))
                .ToList();
            var invoicesWithRefund = await otherContext.Invoices.AsNoTracking()
                .Where(i => i.CurrentRefundId != null)
                .ToListAsync(cancellationToken);
            foreach (var invoice in invoicesWithRefund)
            {
                NormalizeDatesToUtc(invoice, invoiceDateProperties);
                postgresContext.Entry(invoice).State = EntityState.Modified;
            }
            await postgresContext.SaveChangesAsync(cancellationToken);
            postgresContext.ChangeTracker.Clear();

            await UpdateSequenceInvoiceSearch(postgresContext, cancellationToken);
            await SetMigrationState(postgresContext, migratingFrom, "complete", cancellationToken);
            return true;
        }

        /// <summary>
        /// CLR properties mapped to <paramref name="table"/>'s columns whose type is <see cref="DateTime"/>
        /// or <see cref="DateTimeOffset"/>, nullable or not. Properties without a CLR
        /// <see cref="PropertyInfo"/> (e.g. shadow properties) cannot be reached by reflection and are skipped.
        /// </summary>
        private static List<PropertyInfo> GetDateProperties(ITable table)
        {
            var result = new List<PropertyInfo>();
            foreach (var column in table.Columns)
            {
                var property = column.PropertyMappings.Single().Property;
                var info = property.PropertyInfo;
                if (info == null)
                    continue;
                var clrType = Nullable.GetUnderlyingType(property.ClrType) ?? property.ClrType;
                if (clrType == typeof(DateTime) || clrType == typeof(DateTimeOffset))
                    result.Add(info);
            }
            return result;
        }

        /// <summary>
        /// Rewrites the given date properties of <paramref name="row"/> in UTC: <see cref="DateTime"/> values of
        /// kind Unspecified are reinterpreted as UTC (ticks unchanged), Local ones are converted, and
        /// <see cref="DateTimeOffset"/> values are shifted to offset zero. Null values are left alone.
        /// </summary>
        /// <remarks>NOTE(review): presumably needed because the Postgres columns are timestamptz, which
        /// Npgsql only writes from UTC values — confirm.</remarks>
        private static void NormalizeDatesToUtc(object row, IEnumerable<PropertyInfo> dateProperties)
        {
            foreach (var prop in dateProperties)
            {
                switch (prop.GetValue(row))
                {
                    case DateTime dateTime when dateTime.Kind == DateTimeKind.Unspecified:
                        prop.SetValue(row, new DateTime(dateTime.Ticks, DateTimeKind.Utc));
                        break;
                    case DateTime dateTime when dateTime.Kind == DateTimeKind.Local:
                        prop.SetValue(row, dateTime.ToUniversalTime());
                        break;
                    case DateTimeOffset dateTimeOffset when dateTimeOffset.Offset != TimeSpan.Zero:
                        prop.SetValue(row, dateTimeOffset.ToOffset(TimeSpan.Zero));
                        break;
                }
            }
        }

        /// <summary>
        /// Moves the <c>InvoiceSearches_Id_seq</c> sequence to the highest copied <c>InvoiceSearches.Id</c>,
        /// since rows were inserted with their original ids.
        /// </summary>
        internal static async Task UpdateSequenceInvoiceSearch(ApplicationDbContext postgresContext, CancellationToken cancellationToken = default)
        {
            await postgresContext.Database.ExecuteSqlRawAsync(
                "SELECT SETVAL('\"InvoiceSearches_Id_seq\"', (SELECT max(\"Id\") FROM \"InvoiceSearches\"));",
                cancellationToken);
        }

        /// <summary>
        /// Reads the <c>state</c> field of the <c>MigrationData</c> setting, or null when the setting
        /// (or its <c>state</c> field) is absent.
        /// </summary>
        internal static async Task<string?> GetMigrationState(ApplicationDbContext postgresContext, CancellationToken cancellationToken = default)
        {
            var o = (await postgresContext.Settings
                .FromSqlRaw("SELECT \"Id\", \"Value\" FROM \"Settings\" WHERE \"Id\"='MigrationData'")
                .AsNoTracking()
                .FirstOrDefaultAsync(cancellationToken))?.Value;
            if (o is null)
                return null;
            return JObject.Parse(o)["state"]?.Value<string>();
        }

        /// <summary>
        /// Inserts or updates the <c>MigrationData</c> setting with the source database and state.
        /// </summary>
        private static async Task SetMigrationState(ApplicationDbContext postgresContext, string migratingFrom, string state, CancellationToken cancellationToken = default)
        {
            // Build the JSON with JObject so the values are always correctly escaped.
            var migrationData = new JObject
            {
                ["from"] = migratingFrom,
                ["state"] = state,
            }.ToString(Newtonsoft.Json.Formatting.None);
            await postgresContext.Database.ExecuteSqlRawAsync(
                "INSERT INTO \"Settings\" VALUES ('MigrationData', @p0::JSONB) ON CONFLICT (\"Id\") DO UPDATE SET \"Value\"=@p0::JSONB",
                new object[] { migrationData },
                cancellationToken);
        }
    }
}