diff --git a/argus/src/main/scala/argus/macros/FromSchema.scala b/argus/src/main/scala/argus/macros/FromSchema.scala index 5c03790..0c3b28c 100644 --- a/argus/src/main/scala/argus/macros/FromSchema.scala +++ b/argus/src/main/scala/argus/macros/FromSchema.scala @@ -23,7 +23,7 @@ object JsonEngs { */ @compileTimeOnly("You must enable the macro paradise plugin.") class fromSchemaJson(json: String, debug: Boolean = false, jsonEng: Option[JsonEng] = None, outPath: Option[String] = None, - name: String = "Root") extends StaticAnnotation { + name: String = "Root", parent: Option[String] = None) extends StaticAnnotation { def macroTransform(annottees: Any*): Any = macro SchemaMacros.fromSchemaMacroImpl } @@ -39,7 +39,7 @@ class fromSchemaJson(json: String, debug: Boolean = false, jsonEng: Option[JsonE */ @compileTimeOnly("You must enable the macro paradise plugin.") class fromSchemaResource(path: String, debug: Boolean = false, jsonEng: Option[JsonEng] = None, outPath: Option[String] = None, - name: String = "Root") extends StaticAnnotation { + name: String = "Root", parent: Option[String] = None) extends StaticAnnotation { def macroTransform(annottees: Any*): Any = macro SchemaMacros.fromSchemaMacroImpl } @@ -55,7 +55,7 @@ class fromSchemaResource(path: String, debug: Boolean = false, jsonEng: Option[J */ @compileTimeOnly("You must enable the macro paradise plugin.") class fromSchemaURL(url: String, debug: Boolean = false, jsonEng: Option[JsonEng] = None, outPath: Option[String], - name: String = "Root") extends StaticAnnotation { + name: String = "Root", parent: Option[String] = None) extends StaticAnnotation { def macroTransform(annottees: Any*): Any = macro SchemaMacros.fromSchemaMacroImpl } @@ -68,14 +68,14 @@ class SchemaMacros(val c: Context) { private val helpers = new ASTHelpers[c.universe.type](c.universe) import helpers._ - case class Params(schema: Schema.Root, debug: Boolean, jsonEnd: Option[JsonEng], outPath: Option[String], name: String) + case class Params(schema: Schema.Root, debug: Boolean, jsonEnd: Option[JsonEng], outPath: Option[String], name: String, parent: Option[String]) private def extractParams(prefix: Tree): Params = { val q"new $name (..$paramASTs)" = prefix val (Ident(TypeName(fn: String))) = name val commonParams = ("debug", false) :: ("jsonEng", q"Some(JsonEngs.Circe)") :: ("outPath", None) :: - ("name", "Root") :: Nil + ("name", "Root") :: ("parent", None) :: Nil val params = fn match { case "fromSchemaResource" => { @@ -100,7 +100,8 @@ class SchemaMacros(val c: Context) { params("debug").asInstanceOf[Boolean], params("jsonEng") match { case q"Some(JsonEngs.Circe)" => Some(JsonEngs.Circe); case q"None" => None }, params("outPath").asInstanceOf[Option[String]], - params("name").asInstanceOf[String] + params("name").asInstanceOf[String], + params("parent").asInstanceOf[Option[String]] ) } @@ -142,7 +143,7 @@ class SchemaMacros(val c: Context) { // Add definitions and codecs to annotated object case (objDef @ q"$mods object $tname extends { ..$earlydefns } with ..$parents { $self => ..$stats }") :: _ => { - val (_, defs) = modelBuilder.mkSchemaDef(params.name, schema) + val (_, defs) = modelBuilder.mkSchemaDef(params.name, params.parent, schema) q""" $mods object $tname extends { ..$earlydefns } with ..$parents { $self => diff --git a/argus/src/main/scala/argus/macros/ModelBuilder.scala b/argus/src/main/scala/argus/macros/ModelBuilder.scala index 623a0ec..c61d52c 100644 --- a/argus/src/main/scala/argus/macros/ModelBuilder.scala +++ b/argus/src/main/scala/argus/macros/ModelBuilder.scala @@ -36,7 +36,7 @@ class ModelBuilder[U <: Universe](val u: U) { /** * Main workhorse. Creates case-classes from given fields. */ - def mkCaseClassDef(path: List[String], name: String, fields: List[Field], + def mkCaseClassDef(path: List[String], name: String, parent: Option[String], fields: List[Field], requiredFields: Option[List[String]]): (Tree, List[Tree]) = { // Build val defs for each field in case class, keeping track of new class defs created along the way (for nested @@ -50,7 +50,10 @@ class ModelBuilder[U <: Universe](val u: U) { } val typ = mkTypeSelectPath(path :+ name) - val ccDef = q"""case class ${ TypeName(name) } (..$params)""" + val ccDef = parent match { + case Some(p) => q"""case class ${TypeName(name)}(..$params) extends ${TypeName(p)}""" + case None => q"""case class ${TypeName(name)} (..$params)""" + } val defs = if (fieldDefs.isEmpty) ccDef :: Nil @@ -121,16 +124,17 @@ class ModelBuilder[U <: Universe](val u: U) { * Creates a Class/Type definition (i.e. creates a case class or type alias). * * @param name The name of the class/type to that is created - * @param schema. The schema that defines the type. Rough set of rules are: + * @param parent The name of a trait that the created class extends + * @param schema The schema that defines the type. Rough set of rules are: * - schema.typ.$ref, creates a type alias * - schema.typ.object, creates a new Case Class * - schema.typ.intrinicType, creates a type alias to the intrinic type * - schmea.typ.array, creates an array based on the type defined within schema.items * - schema.typ.List[st], ??? */ - def mkDef(path: List[String], name: String, schema: Root): (Tree, List[Tree]) = { + def mkDef(path: List[String], name: String, parent: Option[String], schema: Root): (Tree, List[Tree]) = { - (schema.$ref, schema.enum, schema.typ, schema.oneOf, schema.multiOf) match { + (schema.$ref, schema.enum, schema.typ, schema.oneOf, schema.multiOf) match { // Refs case (Some(ref),_,_,_,_) => { @@ -145,7 +149,7 @@ class ModelBuilder[U <: Universe](val u: U) { // Object (which defines a case-class) case (_,_,Some(SimpleTypeTyp(SimpleTypes.Object)),_,_) => { - mkCaseClassDef(path, name, schema.properties.get, schema.required) + mkCaseClassDef(path, name, parent, schema.properties.get, schema.required) } // Array, create type alias to List of type defined by schema.items (which itself is a schema) @@ -197,7 +201,7 @@ class ModelBuilder[U <: Universe](val u: U) { // Types are a bit strange. They are type definitions and schemas. We extract any inner /definitions // and embed those - val (_, defDefs) = mkSchemaDef(defaultName, schema.justDefinitions, path) + val (_, defDefs) = mkSchemaDef(defaultName, None, schema.justDefinitions, path) // If references existing schema, use that instead (schema.typ, schema.$ref, schema.enum, schema.oneOf, schema.multiOf) match { @@ -239,7 +243,7 @@ class ModelBuilder[U <: Universe](val u: U) { | (_,_,_,_,Some(_)) => { // NB: We ignore defDefs here since we're re-calling mkSchema - mkSchemaDef(defaultName, schema, path) + mkSchemaDef(defaultName, None, schema, path) } // If not type info specified then we have no option but to make it a map of strings (field names) to anys (values) @@ -299,18 +303,18 @@ class ModelBuilder[U <: Universe](val u: U) { * @param path A package path for where this is defined. Defaults to Nil. * @return A tuple containing the type of the root element that is generated, and all definitions required to support it */ - def mkSchemaDef(name: String, schema: Root, path: List[String] = Nil): (Tree, List[Tree]) = { + def mkSchemaDef(name: String, parent: Option[String], schema: Root, path: List[String] = Nil): (Tree, List[Tree]) = { // Make definitions val fieldDefs = for { fields <- schema.definitions.toList field <- fields - (_, defDefs) = mkSchemaDef(field.name.capitalize, field.schema, path) + (_, defDefs) = mkSchemaDef(field.name.capitalize, None, field.schema, path) defDef <- defDefs } yield defDef // Make root - val (typ, rootDefs) = mkDef(path, name, schema) + val (typ, rootDefs) = mkDef(path, name, parent, schema) (typ, fieldDefs ++ rootDefs) } diff --git a/argus/src/test/scala/argus/macros/FromSchemaSpec.scala b/argus/src/test/scala/argus/macros/FromSchemaSpec.scala index c7ecea2..9ed6bf3 100644 --- a/argus/src/test/scala/argus/macros/FromSchemaSpec.scala +++ b/argus/src/test/scala/argus/macros/FromSchemaSpec.scala @@ -393,6 +393,14 @@ class FromSchemaSpec extends FlatSpec with Matchers with JsonMatchers { Schema.Person(age=Some(42)).age should === (Some(42)) } + it should "support parent, and make the root element extend it" in { + trait Person + @fromSchemaResource("/simple.json", parent=Some("Person")) + object Schema + + implicitly[Schema.Root <:< Person] + } + "Complex example" should "work end to end" in { @fromSchemaResource("/vega-lite-schema.json") object Vega diff --git a/argus/src/test/scala/argus/macros/ModelBuilderSpec.scala b/argus/src/test/scala/argus/macros/ModelBuilderSpec.scala index 687e616..8230311 100644 --- a/argus/src/test/scala/argus/macros/ModelBuilderSpec.scala +++ b/argus/src/test/scala/argus/macros/ModelBuilderSpec.scala @@ -22,7 +22,7 @@ class ModelBuilderSpec extends FlatSpec with Matchers with ASTMatchers { val fields = Field("a", schemaFromSimpleType(SimpleTypes.Integer)) :: Field("b", schemaFromSimpleType(SimpleTypes.String)) :: Nil - val (typ, res) = mb.mkCaseClassDef(List("Foo"), "Bar", fields, None) + val (typ, res) = mb.mkCaseClassDef(List("Foo"), "Bar", None, fields, None) typ should === (tq"Foo.Bar") res should === (q"case class Bar(a: Option[Int] = None, b: Option[String] = None)" :: Nil) @@ -33,14 +33,14 @@ class ModelBuilderSpec extends FlatSpec with Matchers with ASTMatchers { Field("a", schemaFromSimpleType(SimpleTypes.Integer)) :: Field("b", schemaFromSimpleType(SimpleTypes.String)) :: Nil - val (_, res) = mb.mkCaseClassDef(List("Foo"), "Bar", fields, Some("b" :: Nil)) + val (_, res) = mb.mkCaseClassDef(List("Foo"), "Bar", None, fields, Some("b" :: Nil)) res should === (q"case class Bar(a: Option[Int] = None, b: String)" :: Nil) } it should "reference other classes when type is $ref" in { val fields = Field("a", schemaFromRef("#/definitions/Address")) :: Nil - val (_, res) = mb.mkCaseClassDef(List("Foo"), "Bar", fields, None) + val (_, res) = mb.mkCaseClassDef(List("Foo"), "Bar", None, fields, None) res should === (q"case class Bar(a: Option[Address] = None)" :: Nil) } @@ -54,7 +54,7 @@ class ModelBuilderSpec extends FlatSpec with Matchers with ASTMatchers { Field("name", schemaFromSimpleType(SimpleTypes.String)) :: Field("address", innerSchema) :: Nil - val (_, res) = mb.mkCaseClassDef(List("Foo"), "Person", fields, None) + val (_, res) = mb.mkCaseClassDef(List("Foo"), "Person", None, fields, None) res should === ( q"case class Person(name: Option[String] = None, address: Option[Foo.Person.Address] = None)" :: q""" @@ -178,7 +178,7 @@ class ModelBuilderSpec extends FlatSpec with Matchers with ASTMatchers { "mkDef()" should "create a type alias for a $ref schema" in { val schema = schemaFromRef("#/definitions/ABC") - val (typ, res) = mb.mkDef("Foo" :: Nil, "Bar", schema) + val (typ, res) = mb.mkDef("Foo" :: Nil, "Bar", None, schema) typ should === (tq"Foo.Bar") res should === ( @@ -189,7 +189,7 @@ class ModelBuilderSpec extends FlatSpec with Matchers with ASTMatchers { it should "create an enum for a enum schema" in { val schema = schemaFromEnum("\"A\"" :: "\"B\"" :: Nil) - val (typ, res) = mb.mkDef("Foo" :: Nil, "Bar", schema) + val (typ, res) = mb.mkDef("Foo" :: Nil, "Bar", None, schema) typ should === (tq"Foo.Bar") showCode(res.head) should include ("@enum") @@ -200,7 +200,7 @@ class ModelBuilderSpec extends FlatSpec with Matchers with ASTMatchers { Field("a", schemaFromSimpleType(SimpleTypes.Integer)) :: Nil ) - val (typ, res) = mb.mkDef("Foo" :: Nil, "Bar", schema) + val (typ, res) = mb.mkDef("Foo" :: Nil, "Bar", None, schema) typ should === (tq"Foo.Bar") showCode(res.head) should include ("case class Bar") @@ -208,7 +208,7 @@ class ModelBuilderSpec extends FlatSpec with Matchers with ASTMatchers { it should "create an type alias for an array schema named using name" in { val schema = schemaFromArray(schemaFromSimpleType(SimpleTypes.String)) - val (typ, res) = mb.mkDef("Foo" :: Nil, "Bar", schema) + val (typ, res) = mb.mkDef("Foo" :: Nil, "Bar", None, schema) typ should === (tq"Foo.Bar") showCode(res.head) should include ("type Bar = List[String]") @@ -216,7 +216,7 @@ class ModelBuilderSpec extends FlatSpec with Matchers with ASTMatchers { it should "create an type alias for intrinsic types" in { val schema = schemaFromSimpleType(SimpleTypes.String) - val (typ, res) = mb.mkDef("Foo" :: Nil, "Bar", schema) + val (typ, res) = mb.mkDef("Foo" :: Nil, "Bar", None, schema) typ should === (tq"Foo.Bar") showCode(res.head) should include ("type Bar = String") @@ -228,7 +228,7 @@ class ModelBuilderSpec extends FlatSpec with Matchers with ASTMatchers { schemaFromArray(schemaFromSimpleType(SimpleTypes.String)) :: Nil ) - val (typ, res) = mb.mkDef("Foo" :: Nil, "Bar", schema) + val (typ, res) = mb.mkDef("Foo" :: Nil, "Bar", None, schema) val code = res.map(showCode(_)).mkString("\n") typ should === (tq"Foo.BarUnion") @@ -341,7 +341,7 @@ class ModelBuilderSpec extends FlatSpec with Matchers with ASTMatchers { )) :: Nil - val (typ: Tree, res) = mb.mkSchemaDef("Root", schema=base.copy(definitions=Some(defs)), "Foo" :: Nil) + val (typ: Tree, res) = mb.mkSchemaDef("Root", None, schema=base.copy(definitions=Some(defs)), "Foo" :: Nil) val code = res.map(showCode(_)).mkString("\n") typ should === (tq"Foo.Root") @@ -361,10 +361,21 @@ class ModelBuilderSpec extends FlatSpec with Matchers with ASTMatchers { Nil ) - val (typ: Tree, res) = mb.mkSchemaDef("Root", schema) + val (typ: Tree, res) = mb.mkSchemaDef("Root", None, schema) val code = res.map(showCode(_)).mkString("\n") code should include ("case class Root(a: Option[Int] = None, b: Option[Root.C] = None)") code should include ("type C = String") } + it should "support extending a provided parent" in { + val schema = schemaFromFields( + Field("a", schemaFromSimpleType(SimpleTypes.Integer)) :: + Nil + ) + + val (typ: Tree, res) = mb.mkSchemaDef("Root", Some("ParentTrait"), schema) + val code = res.map(showCode(_)).mkString("\n") + code should include ("case class Root(a: Option[Int] = None) extends ParentTrait") + } + }