diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -82,3 +82,61 @@
 # CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
 # CHECK: scf.yield %{{.*}}, %{{.*}}
 # CHECK: return
+
+
+@constructAndPrintInModule
+def testIfWithoutElse():
+  """Build an scf.if that has only a then region and produces no results."""
+  i1 = IntegerType.get_signless(1)
+  i32 = IntegerType.get_signless(32)
+
+  @builtin.FuncOp.from_py_func(i1)
+  def simple_if(cond):
+    if_op = scf.IfOp(cond)
+    with InsertionPoint(if_op.then_block):
+      one = arith.ConstantOp(i32, 1)
+      arith.AddIOp(one, one)
+      scf.YieldOp([])
+    return
+
+
+# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
+# CHECK: scf.if %[[ARG0]]
+# CHECK: %[[ONE:.*]] = arith.constant 1
+# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
+# CHECK-NOT: else
+# CHECK: return
+
+
+@constructAndPrintInModule
+def testIfWithElse():
+  """Build a two-result scf.if whose branches yield constants, then add them."""
+  i1 = IntegerType.get_signless(1)
+  i32 = IntegerType.get_signless(32)
+
+  @builtin.FuncOp.from_py_func(i1)
+  def simple_if_else(cond):
+    if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
+    with InsertionPoint(if_op.then_block):
+      x_true = arith.ConstantOp(i32, 0)
+      y_true = arith.ConstantOp(i32, 1)
+      scf.YieldOp([x_true, y_true])
+    with InsertionPoint(if_op.else_block):
+      x_false = arith.ConstantOp(i32, 2)
+      y_false = arith.ConstantOp(i32, 3)
+      scf.YieldOp([x_false, y_false])
+    arith.AddIOp(if_op.results[0], if_op.results[1])
+    return
+
+
+# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
+# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0]]
+# CHECK: %[[ZERO:.*]] = arith.constant 0
+# CHECK: %[[ONE:.*]] = arith.constant 1
+# CHECK: scf.yield %[[ZERO]], %[[ONE]]
+# CHECK: } else {
+# CHECK: %[[TWO:.*]] = arith.constant 2
+# CHECK: %[[THREE:.*]] = arith.constant 3
+# CHECK: scf.yield %[[TWO]], %[[THREE]]
+# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
+# CHECK: return