diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -531,17 +531,21 @@
     cond = cond ? cond
                 : rewriter.create<ConstantIntOp>(loc, /*value=*/1, /*width=*/1);
 
     bool hasElseRegion = !op.elseRegion().empty();
-    auto ifOp = rewriter.create<scf::IfOp>(loc, cond, hasElseRegion);
+    // Forward the result types of the affine.if to the new scf.if. The list is
+    // empty for a result-less op, so a single code path covers both cases.
+    auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
+                                           hasElseRegion);
     rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.thenRegion().back());
     rewriter.eraseBlock(&ifOp.thenRegion().back());
     if (hasElseRegion) {
       rewriter.inlineRegionBefore(op.elseRegion(), &ifOp.elseRegion().back());
       rewriter.eraseBlock(&ifOp.elseRegion().back());
     }
 
-    // Ok, we're done!
-    rewriter.eraseOp(op);
+    // Replace the affine.if with the results of the scf.if. For a result-less
+    // op the replacement list is empty and the op is simply erased.
+    rewriter.replaceOp(op, ifOp.results());
     return success();
   }
 };
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -239,6 +239,67 @@
   return
 }
 
+func private @get_res() -> (f32)
+
+// CHECK-LABEL: func @if_with_yield
+// CHECK-NEXT:   %[[idx:.*]] = call @get_idx() : () -> index
+// CHECK-NEXT:   %[[res1:.*]] = call @get_res() : () -> f32
+// CHECK-NEXT:   %[[res2:.*]] = call @get_res() : () -> f32
+// CHECK-NEXT:   %[[c0:.*]] = constant 0 : index
+// CHECK-NEXT:   %[[cm1:.*]] = constant -1 : index
+// CHECK-NEXT:   %[[neg:.*]] = muli %[[idx]], %[[cm1]] : index
+// CHECK-NEXT:   %[[c20:.*]] = constant 20 : index
+// CHECK-NEXT:   %[[sum:.*]] = addi %[[neg]], %[[c20]] : index
+// CHECK-NEXT:   %[[cond:.*]] = cmpi sge, %[[sum]], %[[c0]] : index
+// CHECK-NEXT:   %[[ret:.*]] = scf.if %[[cond]] -> (f32) {
+// CHECK-NEXT:     scf.yield %[[res1]] : f32
+// CHECK-NEXT:   } else {
+// CHECK-NEXT:     scf.yield %[[res2]] : f32
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[ret]] : f32
+// CHECK-NEXT: }
+func @if_with_yield() ->(f32) {
+  %i = call @get_idx() : () -> (index)
+  %res1 = call @get_res() : () -> (f32)
+  %res2 = call @get_res() : () -> (f32)
+  %0 = affine.if #set1(%i) -> (f32) {
+    affine.yield %res1 : f32
+  } else {
+    affine.yield %res2 : f32
+  }
+  return %0 : f32
+}
+
+// CHECK-LABEL: func @if_with_multi_yield
+// CHECK-NEXT:   %[[idx:.*]] = call @get_idx() : () -> index
+// CHECK-NEXT:   %[[res1:.*]] = call @get_res() : () -> f32
+// CHECK-NEXT:   %[[res2:.*]] = call @get_res() : () -> f32
+// CHECK-NEXT:   %[[c0:.*]] = constant 0 : index
+// CHECK-NEXT:   %[[cm1:.*]] = constant -1 : index
+// CHECK-NEXT:   %[[neg:.*]] = muli %[[idx]], %[[cm1]] : index
+// CHECK-NEXT:   %[[c20:.*]] = constant 20 : index
+// CHECK-NEXT:   %[[sum:.*]] = addi %[[neg]], %[[c20]] : index
+// CHECK-NEXT:   %[[cond:.*]] = cmpi sge, %[[sum]], %[[c0]] : index
+// CHECK-NEXT:   %[[ret:.*]]:2 = scf.if %[[cond]] -> (f32, f32) {
+// CHECK-NEXT:     scf.yield %[[res1]], %[[res2]] : f32, f32
+// CHECK-NEXT:   } else {
+// CHECK-NEXT:     scf.yield %[[res2]], %[[res1]] : f32, f32
+// CHECK-NEXT:   }
+// CHECK-NEXT:   return %[[ret]]#0, %[[ret]]#1 : f32, f32
+// CHECK-NEXT: }
+func @if_with_multi_yield() ->(f32, f32) {
+  %i = call @get_idx() : () -> (index)
+  %res1 = call @get_res() : () -> (f32)
+  %res2 = call @get_res() : () -> (f32)
+  %0:2 = affine.if #set1(%i) -> (f32, f32) {
+    affine.yield %res1, %res2 : f32, f32
+  } else {
+    affine.yield %res2, %res1 : f32, f32
+  }
+  return %0#0, %0#1 : f32, f32
+}
+
+
 #setN = affine_set<(d0)[N,M,K,L] : (N - d0 + 1 >= 0, N - 1 >= 0, M - 1 >= 0, K - 1 >= 0, L - 42 == 0)>
 
 // CHECK-LABEL: func @multi_cond
diff --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir
--- a/mlir/test/Dialect/Linalg/affine.mlir
+++ b/mlir/test/Dialect/Linalg/affine.mlir
@@ -9,6 +9,8 @@
 
 // CHECK-DAG: #[[$clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)>
 
+// CHECK-DAG: #[[$tset:.*]] = affine_set<(d0, d1) : (d0 - d1 >= 0)>
+
 func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -149,3 +151,72 @@
 // CHECK-NEXT: cmpf
 // CHECK-NEXT: select
 // CHECK-NEXT: affine.store
+
+//----------------------------------------------------------------------------//
+// Generic ops to loops.
+//----------------------------------------------------------------------------//
+#transpose_accesses = [
+  affine_map<(H,W)->(H,W)>,
+  affine_map<(H,W)->(W,H)>
+]
+#trans_trait = {
+  indexing_maps = #transpose_accesses,
+  iterator_types = ["parallel","parallel"]
+}
+
+func @transpose(%in: memref<?x?xf32>, %out: memref<?x?xf32>){
+
+  // Transpose
+  linalg.generic #trans_trait
+    ins(%in : memref<?x?xf32>)
+    outs(%out : memref<?x?xf32>) {
+    ^bb0(%a: f32, %b:f32):
+      linalg.yield %a : f32
+  }
+  return
+}
+// CHECK-LABEL: @transpose
+// CHECK-NEXT: constant
+// CHECK-NEXT: constant
+// CHECK-NEXT: memref.dim
+// CHECK-NEXT: memref.dim
+// CHECK-NEXT: affine.for %[[idx0:.*]] = {{.*}}
+// CHECK-NEXT: affine.for %[[idx1:.*]] = {{.*}}
+// CHECK-NEXT: %[[v0:.*]] = affine.load %{{.*}}[%[[idx0]], %[[idx1]]] {{.*}}
+// CHECK-NEXT: affine.store %[[v0]], %{{.*}}[%[[idx1]], %[[idx0]]] {{.*}}
+
+#set0 = affine_set<(d0,d1) : (d0-d1>=0)>
+func @transpose_inplace(%out: memref<?x?xf32>){
+
+  linalg.indexed_generic #trans_trait
+    outs(%out, %out : memref<?x?xf32>, memref<?x?xf32>) {
+    ^bb0(%i: index, %j: index, %a: f32, %b:f32):
+
+      // With the addition of the AffineScope trait, linalg.*generic ops can
+      // have affine ops in their bodies.
+      %r1, %r2 = affine.if #set0(%i,%j) -> (f32,f32) {
+        affine.yield %b,%a : f32, f32
+      } else {
+        affine.yield %a,%b : f32, f32
+      }
+      linalg.yield %r1, %r2 : f32 ,f32
+  }
+  return
+}
+
+// CHECK-LABEL: @transpose_inplace
+// CHECK-NEXT: constant
+// CHECK-NEXT: constant
+// CHECK-NEXT: memref.dim
+// CHECK-NEXT: memref.dim
+// CHECK-NEXT: affine.for %[[idx0:.*]] = {{.*}}
+// CHECK-NEXT: affine.for %[[idx1:.*]] = {{.*}}
+// CHECK-NEXT: %[[v0:.*]] = affine.load %{{.*}}[%[[idx0]], %[[idx1]]] {{.*}}
+// CHECK-NEXT: %[[v1:.*]] = affine.load %{{.*}}[%[[idx1]], %[[idx0]]] {{.*}}
+// CHECK-NEXT: %[[ret:.*]]:2 = affine.if #[[$tset]](%[[idx0]], %[[idx1]]) {{.*}}
+// CHECK-NEXT: affine.yield %[[v1]], %[[v0]] {{.*}}
+// CHECK-NEXT: } else {
+// CHECK-NEXT: affine.yield %[[v0]], %[[v1]] {{.*}}
+// CHECK-NEXT: }
+// CHECK-NEXT: affine.store %[[ret]]#0, %{{.*}}[%[[idx0]], %[[idx1]]] {{.*}}
+// CHECK-NEXT: affine.store %[[ret]]#1, %{{.*}}[%[[idx1]], %[[idx0]]] {{.*}}