diff --git a/cgen/__init__.py b/cgen/__init__.py index 7d9a8b1..9a0572d 100644 --- a/cgen/__init__.py +++ b/cgen/__init__.py @@ -748,7 +748,7 @@ def make_multiple_ifs(conditions_and_blocks, base=None): # {{{ simple statements class Define(Generable): - def __init__(self, symbol, value): + def __init__(self, symbol, value=""): self.symbol = symbol self.value = value @@ -1060,12 +1060,15 @@ class IfDef(Module): :param iflines: the block of code inside the if [an array of type Generable] :param elselines: the block of code inside the else [an array of type Generable] """ - def __init__(self, condition, iflines, elselines): + def __init__(self, condition, iflines, elselines=None): ifdef_line = Line('#ifdef %s' % condition) - if len(elselines): - elselines.insert(0, Line('#else')) - endif_line = Line('#endif') - lines = [ifdef_line]+iflines+elselines+[endif_line] + lines = [ifdef_line] + iflines + + if elselines: + lines += [Line('#else')] + lines += elselines + + lines += [Line('#endif')] super(IfDef, self).__init__(lines) mapper_method = "map_ifdef" @@ -1079,11 +1082,15 @@ class IfNDef(Module): [an array of type Generable] :param elselines: the block of code inside the else [an array of type Generable] """ - def __init__(self, condition, ifndeflines, elselines): + def __init__(self, condition, ifndeflines, elselines=None): ifndefdef_line = Line('#ifndef %s' % condition) - if len(elselines): - elselines.insert(0, Line('#else')) - lines = [ifndefdef_line]+ifndeflines+elselines+[Line('#endif')] + lines = [ifndefdef_line] + ifndeflines + + if elselines: + lines += [Line('#else')] + lines += elselines + + lines += [Line('#endif')] super(IfNDef, self).__init__(lines) mapper_method = "map_ifndef" diff --git a/test/test_cgen.py b/test/test_cgen.py index b1de365..ce1e636 100644 --- a/test/test_cgen.py +++ b/test/test_cgen.py @@ -1,7 +1,8 @@ +import sys from cgen import ( POD, Struct, FunctionBody, FunctionDeclaration, For, If, Assign, Value, Block, ArrayOf, Comment, - Template) + Template, Pointer, IfNDef, IfDef, Define) import numpy as np @@ -46,3 +47,68 @@ def test_cgen(): print(s) print(f_body) print(t_decl) + + +def test_ptr_to_array(): + t2 = Pointer(Pointer(ArrayOf(POD(np.float32, "yyy"), 2))) + assert str(t2) == "float **yyy[2];" + + +def test_ifndef_no_else(): + expected = """#ifndef SOME_DEFINE +/* TRUE */ +#endif""" + + code = IfNDef("SOME_DEFINE", [Comment("TRUE")]) + assert str(code) == expected + + +def test_ifndef(): + expected = """#ifndef SOME_DEFINE +/* TRUE */ +#else +/* FALSE */ +#endif""" + + code = IfNDef("SOME_DEFINE", [Comment("TRUE")], [Comment("FALSE")]) + assert str(code) == expected + + +def test_ifdef_no_else(): + expected = """#ifdef SOME_DEFINE +/* TRUE */ +#endif""" + + code = IfDef("SOME_DEFINE", [Comment("TRUE")]) + assert str(code) == expected + + +def test_ifdef(): + expected = """#ifdef SOME_DEFINE +/* TRUE */ +#else +/* FALSE */ +#endif""" + + code = IfDef("SOME_DEFINE", [Comment("TRUE")], [Comment("FALSE")]) + assert str(code) == expected + + +def test_define_no_val(): + expected = "#define SOME_DEFINE" + code = Define("SOME_DEFINE") + assert str(code) == expected + + +def test_define_with_val(): + expected = "#define SOME_DEFINE 42" + code = Define("SOME_DEFINE", 42) + assert str(code) == expected + + +if __name__ == "__main__": + if len(sys.argv) > 1: + exec(sys.argv[1]) + else: + from pytest import main + main([__file__])