diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -70,16 +70,17 @@
   """Specialization for the SCF if op class."""
 
   def __init__(self,
-               results_,
                cond,
-               withElseRegion=False,
+               results_=[],
                *,
+               hasElse=False,
                loc=None,
                ip=None):
     """Creates an SCF `if` operation.
 
-    - `cond` is a boolean value to determine which regions of code will be executed.
-    - `withElseRegion` determines whether the if operation has the else branch.
+    - `cond` is an MLIR value of `i1` type that selects which region is executed.
+    - `results_` is the list of result types of the operation (empty by default).
+    - `hasElse` determines whether an (empty) else block is created.
     """
     operands = []
     operands.append(cond)
@@ -92,9 +93,9 @@
             operands=operands,
             loc=loc,
             ip=ip))
-    self.regions[0].blocks.append(*results)
-    if withElseRegion:
-      self.regions[1].blocks.append(*results)
+    self.regions[0].blocks.append()
+    if hasElse:
+      self.regions[1].blocks.append()
 
   @property
   def then_block(self):
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -82,3 +82,58 @@
 # CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
 # CHECK: scf.yield %{{.*}}, %{{.*}}
 # CHECK: return
+
+
+@constructAndPrintInModule
+def testIfWithoutElse():
+  i1 = IntegerType.get_signless(1)
+  i32 = IntegerType.get_signless(32)
+
+  @builtin.FuncOp.from_py_func(i1)
+  def simple_if(cond):
+    if_op = scf.IfOp(cond)
+    with InsertionPoint(if_op.then_block):
+      one = arith.ConstantOp(i32, 1)
+      arith.AddIOp(one, one)
+      scf.YieldOp([])
+    return
+
+
+# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
+# CHECK: scf.if %[[ARG0]]
+# CHECK: %[[ONE:.*]] = arith.constant 1
+# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
+# CHECK: return
+
+
+@constructAndPrintInModule
+def testIfWithElse():
+  i1 = IntegerType.get_signless(1)
+  i32 = IntegerType.get_signless(32)
+
+  @builtin.FuncOp.from_py_func(i1)
+  def simple_if_else(cond):
+    if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
+    with InsertionPoint(if_op.then_block):
+      x_true = arith.ConstantOp(i32, 0)
+      y_true = arith.ConstantOp(i32, 1)
+      scf.YieldOp([x_true, y_true])
+    with InsertionPoint(if_op.else_block):
+      x_false = arith.ConstantOp(i32, 2)
+      y_false = arith.ConstantOp(i32, 3)
+      scf.YieldOp([x_false, y_false])
+    arith.AddIOp(if_op.results[0], if_op.results[1])
+    return
+
+
+# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
+# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0]]
+# CHECK: %[[ZERO:.*]] = arith.constant 0
+# CHECK: %[[ONE:.*]] = arith.constant 1
+# CHECK: scf.yield %[[ZERO]], %[[ONE]]
+# CHECK: } else {
+# CHECK: %[[TWO:.*]] = arith.constant 2
+# CHECK: %[[THREE:.*]] = arith.constant 3
+# CHECK: scf.yield %[[TWO]], %[[THREE]]
+# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
+# CHECK: return