Skip to content

Commit 83b7b0b

Browse files
authored
add generated columns support (#20)
* add generated columns support
1 parent c2f11e5 commit 83b7b0b

3 files changed

Lines changed: 56 additions & 43 deletions

File tree

modules/core/src/main/scala/com/anymindgroup/PgCodeGen.scala

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ class PgCodeGen(
9797
val filterFragment: Fragment[Void] =
9898
sql" AND table_name NOT IN (#${(schemaHistoryTableName :: excludeTables).mkString("'", "','", "'")})"
9999

100-
val q: Query[Void, String ~ String ~ String ~ Option[Int] ~ Option[Int] ~ Option[Int] ~ String ~ Option[String]] =
101-
sql"""SELECT table_name,column_name,udt_name,character_maximum_length,numeric_precision,numeric_scale,is_nullable,column_default
100+
val q =
101+
sql"""SELECT table_name,column_name,udt_name,character_maximum_length,numeric_precision,numeric_scale,is_nullable,column_default,is_generated
102102
FROM information_schema.COLUMNS WHERE table_schema = 'public'$filterFragment UNION
103103
(SELECT
104104
cls.relname AS table_name,
@@ -114,24 +114,27 @@ class PgCodeGen(
114114
WHEN attr.attnotnull OR tp.typtype = 'd'::"char" AND tp.typnotnull THEN 'NO'::text
115115
ELSE 'YES'::text
116116
END::information_schema.yes_or_no AS is_nullable,
117-
NULL AS column_default
117+
NULL AS column_default,
118+
'NEVER' AS is_generated
118119
FROM pg_catalog.pg_attribute as attr
119120
JOIN pg_catalog.pg_class as cls on cls.oid = attr.attrelid
120121
JOIN pg_catalog.pg_namespace as ns on ns.oid = cls.relnamespace
121122
JOIN pg_catalog.pg_type as tp on tp.oid = attr.atttypid
122123
WHERE cls.relkind = 'm' and attr.attnum >= 1 AND ns.nspname = 'public'
123124
ORDER by attr.attnum)
124-
""".query(name ~ name ~ name ~ int4.opt ~ int4.opt ~ int4.opt ~ varchar(3) ~ varchar.opt)
125+
""".query(name ~ name ~ name ~ int4.opt ~ int4.opt ~ int4.opt ~ varchar(3) ~ varchar.opt ~ varchar)
125126

126-
s.execute(q.map { case tName ~ colName ~ udt ~ maxCharLength ~ numPrecision ~ numScale ~ nullable ~ default =>
127-
(
128-
tName,
129-
colName,
130-
toType(udt, maxCharLength, numPrecision, numScale),
131-
nullable == "YES",
132-
default.flatMap(ColumnDefault.fromString),
133-
)
134-
}).map(_.map { case (tName, colName, udt, isNullable, default) =>
127+
s.execute(q.map {
128+
case tName ~ colName ~ udt ~ maxCharLength ~ numPrecision ~ numScale ~ nullable ~ default ~ is_generated =>
129+
(
130+
tName,
131+
colName,
132+
toType(udt, maxCharLength, numPrecision, numScale),
133+
nullable == "YES",
134+
default.flatMap(ColumnDefault.fromString),
135+
is_generated == "ALWAYS",
136+
)
137+
}).map(_.map { case (tName, colName, udt, isNullable, default, isAlwaysGenerated) =>
135138
toScalaType(udt, isNullable, enums).map { st =>
136139
(
137140
tName,
@@ -142,6 +145,7 @@ class PgCodeGen(
142145
scalaType = st,
143146
isNullable = isNullable,
144147
default = default,
148+
isAlwaysGenerated = isAlwaysGenerated,
145149
),
146150
)
147151
}.leftMap(new Exception(_))
@@ -316,7 +320,7 @@ class PgCodeGen(
316320

317321
columns.toList.map { case (tname, tableCols) =>
318322
val tableConstraints = constraints.getOrElse(tname, Nil)
319-
val autoIncColumns = findAutoIncColumns(tname)
323+
val generatedCols = findAutoIncColumns(tname) ::: tableCols.filter(_.isAlwaysGenerated)
320324
val autoIncFk = tableConstraints.collect { case c: Constraint.ForeignKey => c }.flatMap {
321325
_.refs.flatMap { ref =>
322326
tableCols.find(c => c.columnName == ref.fromColName).filter { _ =>
@@ -327,8 +331,8 @@ class PgCodeGen(
327331

328332
Table(
329333
name = tname,
330-
columns = tableCols.filterNot((autoIncColumns ::: autoIncFk).contains),
331-
autoIncColumns = autoIncColumns,
334+
columns = tableCols.filterNot((generatedCols ::: autoIncFk).contains),
335+
generatedColumns = generatedCols,
332336
constraints = tableConstraints,
333337
indexes = indexes.getOrElse(tname, Nil),
334338
autoIncFk = autoIncFk,
@@ -625,7 +629,7 @@ class PgCodeGen(
625629
if (autoIncFk.isEmpty) {
626630
(rowClassName, s"${rowClassName}.codec")
627631
} else {
628-
val autoIncFkCodecs = autoIncFk.map(col => s"skunk.codec.all.${col.pgType.name}").mkString(" *: ")
632+
val autoIncFkCodecs = autoIncFk.map(_.codecName).mkString(" *: ")
629633
val autoIncFkScalaTypes = autoIncFk.map(_.scalaType).mkString(" *: ")
630634
(s"($autoIncFkScalaTypes ~ $rowClassName)", s"$autoIncFkCodecs ~ ${rowClassName}.codec")
631635
}
@@ -640,18 +644,20 @@ class PgCodeGen(
640644
val allColNames = allCols.map(_.columnName).mkString(",")
641645
val (insertScalaType, insertCodec) = queryTypesStr(table)
642646

643-
val returningStatement = autoIncColumns match {
647+
val returningStatement = generatedColumns match {
644648
case Nil => ""
645-
case _ => autoIncColumns.map(_.columnName).mkString(" RETURNING ", ",", "")
649+
case _ => generatedColumns.map(_.columnName).mkString(" RETURNING ", ",", "")
646650
}
647-
val returningType = autoIncColumns.map(_.scalaType).mkString(" *: ")
648-
val fragmentType = autoIncColumns match {
651+
val returningType = generatedColumns
652+
.map(_.scalaType)
653+
.mkString("", " *: ", if (generatedColumns.length > 1) " *: EmptyTuple" else "")
654+
val fragmentType = generatedColumns match {
649655
case Nil => "command"
650-
case _ => s"query(${autoIncColumns.map(col => s"skunk.codec.all.${col.pgType.name}").mkString(" *: ")})"
656+
case _ => s"query(${generatedColumns.map(_.codecName).mkString(" *: ")})"
651657
}
652658

653659
val upsertQ = primaryUniqueConstraint.map { cstr =>
654-
val queryType = autoIncColumns match {
660+
val queryType = generatedColumns match {
655661
case Nil => s"Command[$insertScalaType *: updateFr.A *: EmptyTuple]"
656662
case _ => s"Query[$insertScalaType *: updateFr.A *: EmptyTuple, $returningType]"
657663
}
@@ -662,7 +668,7 @@ class PgCodeGen(
662668
| DO UPDATE SET $${updateFr.fragment}$returningStatement\"\"\".$fragmentType""".stripMargin
663669
}
664670

665-
val queryType = autoIncColumns match {
671+
val queryType = generatedColumns match {
666672
case Nil => s"Command[$insertScalaType]"
667673
case _ => s"Query[$insertScalaType, $returningType]"
668674
}
@@ -693,7 +699,7 @@ class PgCodeGen(
693699
}
694700

695701
private def tableColumns(table: Table): (Option[String], String) = {
696-
val allCols = table.autoIncColumns ::: table.autoIncFk ::: table.columns
702+
val allCols = table.generatedColumns ::: table.autoIncFk ::: table.columns
697703
val cols =
698704
allCols.map(column =>
699705
s""" val ${column.snakeCaseScalaName} = Cols(NonEmptyList.of("${column.columnName}"), ${column.codecName}, tableName)"""
@@ -714,13 +720,13 @@ class PgCodeGen(
714720
private def selectAllStatement(table: Table): String = {
715721
import table.*
716722

717-
val autoIncStm = if (autoIncColumns.nonEmpty) {
718-
val types = autoIncColumns.map(_.codecName).mkString(" *: ")
719-
val sTypes = autoIncColumns.map(_.scalaType).mkString(" *: ")
720-
val colNamesStr = (autoIncColumns ::: columns).map(_.columnName).mkString(", ")
723+
val generatedColStm = if (generatedColumns.nonEmpty) {
724+
val types = generatedColumns.map(_.codecName).mkString(" *: ")
725+
val sTypes = generatedColumns.map(_.scalaType).mkString(" *: ")
726+
val colNamesStr = (generatedColumns ::: columns).map(_.columnName).mkString(", ")
721727

722728
s"""
723-
| def selectAllWithId[A](addClause: Fragment[A] = Fragment.empty): Query[A, $sTypes *: $rowClassName *: EmptyTuple] =
729+
| def selectAllWithGenerated[A](addClause: Fragment[A] = Fragment.empty): Query[A, $sTypes *: $rowClassName *: EmptyTuple] =
724730
| sql"SELECT $colNamesStr FROM #$$tableName $$addClause".query($types *: ${rowClassName}.codec)
725731
|
726732
""".stripMargin
@@ -740,7 +746,7 @@ class PgCodeGen(
740746
val selectCol = s"""| def select[A, B](cols: Cols[A], rest: Fragment[B] = Fragment.empty): Query[B, A] =
741747
| sql"SELECT #$${cols.name} FROM #$$tableName $$rest".query(cols.codec)
742748
|""".stripMargin
743-
autoIncStm ++ defaultStm ++ selectCol
749+
generatedColStm ++ defaultStm ++ selectCol
744750
}
745751

746752
private def lastModified(modified: List[Long]): Option[Long] =
@@ -806,6 +812,7 @@ object PgCodeGen {
806812
isEnum: Boolean,
807813
isNullable: Boolean,
808814
default: Option[ColumnDefault],
815+
isAlwaysGenerated: Boolean,
809816
) {
810817
val scalaName: String = toScalaName(columnName)
811818
val snakeCaseScalaName: String = escapeScalaKeywords(columnName)
@@ -849,7 +856,7 @@ object PgCodeGen {
849856
final case class Table(
850857
name: String,
851858
columns: List[Column],
852-
autoIncColumns: List[Column],
859+
generatedColumns: List[Column],
853860
constraints: List[Constraint],
854861
indexes: List[Index],
855862
autoIncFk: List[Column],

modules/core/src/test/resources/db/migration/V1__test.sql

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ CREATE TABLE test (
1414
tla_var varchar(3) NOT NULL,
1515
numeric_default numeric NOT NULL,
1616
numeric_24p numeric(24) NOT NULL,
17-
numeric_16p_2s numeric(16, 2) NOT NULL
17+
numeric_16p_2s numeric(16, 2) NOT NULL,
18+
gen INT NOT NULL GENERATED ALWAYS AS (1 + 1) STORED,
19+
gen_opt INT GENERATED ALWAYS AS (1 + 1) STORED
1820
);
1921

2022
CREATE TABLE test_ref_only (

modules/core/src/test/scala/com/anymindgroup/GeneratedCodeTest._scala

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import better.files.File
1616
import skunk.*
1717
import skunk.codec.all.*
1818
import skunk.util.Origin
19-
import skunk.{Command as SqlCommand}
19+
import skunk.Command as SqlCommand
2020
import cats.implicits.*
2121
import java.time.temporal.ChronoUnit
2222
import java.time.ZoneOffset
@@ -44,7 +44,7 @@ object GeneratedCodeTest extends IOApp {
4444
tlaVar = "abc",
4545
numericDefault = BigDecimal(1),
4646
numeric24p = BigDecimal(2),
47-
numeric16p2s = BigDecimal(3)
47+
numeric16p2s = BigDecimal(3),
4848
).withUpdateAll
4949

5050
val testBRow = TestBRow(
@@ -83,20 +83,24 @@ object GeneratedCodeTest extends IOApp {
8383
// Test table
8484
p <- s.prepare(TestTable.upsertQuery(testUpdateFr))
8585
_ <- s.prepare(TestTable.insertQuery(ignoreConflict = true))
86-
id <- p.option((testRow, testUpdateFr.argument))
87-
_ <- IO.raiseWhen(id.isEmpty)(new Throwable("test A did not return a generated id"))
86+
res <- p.option((testRow, testUpdateFr.argument))
87+
_ <- IO.raiseWhen(res.isEmpty)(new Throwable("test A did not return generated columns"))
88+
id = res.get._1
89+
_ <- IO.raiseWhen(res.get._2 != 2 && res.get._3 != Some(2))(
90+
new Throwable("unexpected result for generated columns")
91+
)
8892
all <- s.execute(TestTable.selectAll())
89-
allWithId <- s.execute(TestTable.selectAllWithId())
93+
allWithGen <- s.execute(TestTable.selectAllWithGenerated())
9094
_ <- IO.raiseWhen(all != List(testRow))(new Throwable("test A result not equal"))
91-
_ <- IO.raiseWhen(allWithId.map(_._2) != List(testRow))(new Throwable("test A result with id not equal"))
95+
_ <- IO.raiseWhen(allWithGen.map(_._4) != List(testRow))(new Throwable("test A result with id not equal"))
9296
aliasedTestTable = TestTable.withAlias("t")
9397
idAndName2 = aliasedTestTable.column.id ~ aliasedTestTable.column.name_2
9498
xs <-
9599
s.execute(
96100
sql"""SELECT #${idAndName2.aliasedName},#${aliasedTestTable.column.name.fullName} FROM #${TestTable.tableName} #${aliasedTestTable.tableName}"""
97101
.query(idAndName2.codec ~ TestTable.column.name.codec)
98102
)
99-
_ <- IO.raiseWhen(xs != List((id.get, testRow.name2) -> testRow.name))(
103+
_ <- IO.raiseWhen(xs != List((id, testRow.name2) -> testRow.name))(
100104
new Throwable("test A select fields not equal")
101105
)
102106
all2 <- s.execute(TestTable.select(TestTable.all))
@@ -179,10 +183,10 @@ object GeneratedCodeTest extends IOApp {
179183
_ <- IO.raiseWhen(
180184
Some(testBRow.copy(val27 = None, val2 = "updated_val_2", val14 = "updated_val_14")) != loadedById
181185
)(new Throwable("test B result missing update"))
182-
_ <- s.execute(sql"REFRESH MATERIALIZED VIEW test_materialized_view".command)
186+
_ <- s.execute(sql"REFRESH MATERIALIZED VIEW test_materialized_view".command)
183187
result <- s.execute(TestMaterializedViewTable.selectAll())
184-
_ <- IO.raiseWhen(result.isEmpty)(new Throwable(s"materialized view doesn't have correct value: ${result}"))
185-
_ <- IO.println("Test successful!")
188+
_ <- IO.raiseWhen(result.isEmpty)(new Throwable(s"materialized view doesn't have correct value: ${result}"))
189+
_ <- IO.println("Test successful!")
186190
} yield ()
187191

188192
}

0 commit comments

Comments
 (0)