diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py
--- a/mlir/python/mlir/dialects/_scf_ops_ext.py
+++ b/mlir/python/mlir/dialects/_scf_ops_ext.py
@@ -64,3 +64,50 @@
     To obtain the loop-carried operands, use `iter_args`.
     """
     return self.body.arguments[1:]
+
+
+class IfOp:
+  """Specialization for the SCF if op class."""
+
+  def __init__(self,
+               cond,
+               results_=None,
+               *,
+               hasElse=False,
+               loc=None,
+               ip=None):
+    """Creates an SCF `if` operation.
+
+    - `cond` is an MLIR value of `i1` type that selects which region is
+      executed.
+    - `results_` is an optional sequence of result types (defaults to none).
+    - `hasElse` determines whether an (empty) else block is created.
+    """
+    operands = [cond]
+    # Copy into a fresh list; never share a default or caller-owned list.
+    results = list(results_) if results_ is not None else []
+    super().__init__(
+        self.build_generic(
+            regions=2,
+            results=results,
+            operands=operands,
+            loc=loc,
+            ip=ip))
+    # Each populated region gets a single entry block with no arguments.
+    self.regions[0].blocks.append()
+    if hasElse:
+      self.regions[1].blocks.append()
+
+  @property
+  def then_block(self):
+    """Returns the then block of the if operation."""
+    return self.regions[0].blocks[0]
+
+  @property
+  def else_block(self):
+    """Returns the else block of the if operation.
+
+    Only valid when the op was created with `hasElse=True`; otherwise the
+    else region holds no block and indexing it fails.
+    """
+    return self.regions[1].blocks[0]
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -82,3 +82,58 @@
 # CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
 # CHECK: scf.yield %{{.*}}, %{{.*}}
 # CHECK: return
+
+
+@constructAndPrintInModule
+def testIfWithoutElse():
+  i1 = IntegerType.get_signless(1)
+  i32 = IntegerType.get_signless(32)
+
+  @builtin.FuncOp.from_py_func(i1)
+  def simple_if(cond):
+    if_op = scf.IfOp(cond)
+    with InsertionPoint(if_op.then_block):
+      one = arith.ConstantOp(i32, 1)
+      add = arith.AddIOp(one, one)
+      scf.YieldOp([])
+    return
+
+
+# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
+# CHECK: scf.if %[[ARG0]]
+# CHECK:   %[[ONE:.*]] = arith.constant 1
+# CHECK:   %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
+# CHECK: return
+
+
+@constructAndPrintInModule
+def testIfWithElse():
+  i1 = IntegerType.get_signless(1)
+  i32 = IntegerType.get_signless(32)
+
+  @builtin.FuncOp.from_py_func(i1)
+  def simple_if_else(cond):
+    if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
+    with InsertionPoint(if_op.then_block):
+      x_true = arith.ConstantOp(i32, 0)
+      y_true = arith.ConstantOp(i32, 1)
+      scf.YieldOp([x_true, y_true])
+    with InsertionPoint(if_op.else_block):
+      x_false = arith.ConstantOp(i32, 2)
+      y_false = arith.ConstantOp(i32, 3)
+      scf.YieldOp([x_false, y_false])
+    add = arith.AddIOp(if_op.results[0], if_op.results[1])
+    return
+
+
+# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
+# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0]]
+# CHECK:   %[[ZERO:.*]] = arith.constant 0
+# CHECK:   %[[ONE:.*]] = arith.constant 1
+# CHECK:   scf.yield %[[ZERO]], %[[ONE]]
+# CHECK: } else {
+# CHECK:   %[[TWO:.*]] = arith.constant 2
+# CHECK:   %[[THREE:.*]] = arith.constant 3
+# CHECK:   scf.yield %[[TWO]], %[[THREE]]
+# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
+# CHECK: return