Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions cgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def make_multiple_ifs(conditions_and_blocks, base=None):
# {{{ simple statements

class Define(Generable):
def __init__(self, symbol, value):
def __init__(self, symbol, value=""):
self.symbol = symbol
self.value = value

Expand Down Expand Up @@ -1060,12 +1060,15 @@ class IfDef(Module):
:param iflines: the block of code inside the if [an array of type Generable]
:param elselines: the block of code inside the else [an array of type Generable]
"""
def __init__(self, condition, iflines, elselines):
def __init__(self, condition, iflines, elselines=None):
ifdef_line = Line('#ifdef %s' % condition)
if len(elselines):
elselines.insert(0, Line('#else'))
endif_line = Line('#endif')
lines = [ifdef_line]+iflines+elselines+[endif_line]
lines = [ifdef_line] + iflines

if elselines:
lines += [Line('#else')]
lines += elselines

lines += [Line('#endif')]
super(IfDef, self).__init__(lines)

mapper_method = "map_ifdef"
Expand All @@ -1079,11 +1082,15 @@ class IfNDef(Module):
[an array of type Generable]
:param elselines: the block of code inside the else [an array of type Generable]
"""
def __init__(self, condition, ifndeflines, elselines):
def __init__(self, condition, ifndeflines, elselines=None):
ifndefdef_line = Line('#ifndef %s' % condition)
if len(elselines):
elselines.insert(0, Line('#else'))
lines = [ifndefdef_line]+ifndeflines+elselines+[Line('#endif')]
lines = [ifndefdef_line] + ifndeflines

if elselines:
lines += [Line('#else')]
lines += elselines

lines += [Line('#endif')]
super(IfNDef, self).__init__(lines)

mapper_method = "map_ifndef"
Expand Down
68 changes: 67 additions & 1 deletion test/test_cgen.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import sys
from cgen import (
POD, Struct, FunctionBody, FunctionDeclaration,
For, If, Assign, Value, Block, ArrayOf, Comment,
Template)
Template, Pointer, IfNDef, IfDef, Define)
import numpy as np


Expand Down Expand Up @@ -46,3 +47,68 @@ def test_cgen():
print(s)
print(f_body)
print(t_decl)


def test_ptr_to_array():
t2 = Pointer(Pointer(ArrayOf(POD(np.float32, "yyy"), 2)))
assert str(t2) == "float **yyy[2];"


def test_ifndef_no_else():
expected = """#ifndef SOME_DEFINE
/* TRUE */
#endif"""

code = IfNDef("SOME_DEFINE", [Comment("TRUE")])
assert str(code) == expected


def test_ifndef():
expected = """#ifndef SOME_DEFINE
/* TRUE */
#else
/* FALSE */
#endif"""

code = IfNDef("SOME_DEFINE", [Comment("TRUE")], [Comment("FALSE")])
assert str(code) == expected


def test_ifdef_no_else():
expected = """#ifdef SOME_DEFINE
/* TRUE */
#endif"""

code = IfDef("SOME_DEFINE", [Comment("TRUE")])
assert str(code) == expected


def test_ifdef():
expected = """#ifdef SOME_DEFINE
/* TRUE */
#else
/* FALSE */
#endif"""

code = IfDef("SOME_DEFINE", [Comment("TRUE")], [Comment("FALSE")])
assert str(code) == expected


def test_define_no_val():
expected = "#define SOME_DEFINE"
code = Define("SOME_DEFINE")
assert str(code) == expected


def test_define_with_val():
expected = "#define SOME_DEFINE 42"
code = Define("SOME_DEFINE", 42)
assert str(code) == expected


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
else:
from pytest import main
main([__file__])