diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -328,6 +328,13 @@
     }
   }
 
+  if (auto minOp = upperBound.getDefiningOp<arith::MinSIOp>()) {
+    for (Value operand : {minOp.getLhs(), minOp.getRhs()}) {
+      if (auto staticBound = deriveStaticUpperBound(operand, rewriter))
+        return staticBound;
+    }
+  }
+
   if (auto multiplyOp = upperBound.getDefiningOp<arith::MulIOp>()) {
     if (auto lhs = dyn_cast_or_null<arith::ConstantIndexOp>(
             deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter)
@@ -336,8 +343,8 @@
               deriveStaticUpperBound(multiplyOp.getOperand(1), rewriter)
                   .getDefiningOp())) {
         // Assumptions about the upper bound of minimum computations no longer
-        // work if multiplied by a negative value, so abort in this case.
-        if (lhs.value() < 0 || rhs.value() < 0)
+        // work if multiplied by mixed signs, so abort in this case.
+        if ((lhs.value() < 0) != (rhs.value() < 0))
           return {};
 
         return rewriter.create<arith::ConstantIndexOp>(
diff --git a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
--- a/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
+++ b/mlir/test/Conversion/SCFToGPU/parallel_loop.mlir
@@ -215,7 +215,8 @@
       %3 = affine.min #map1(%arg3)[%2]
       %squared_min = arith.muli %3, %3 : index
       %4 = memref.dim %arg0, %c1 : memref
-      %5 = affine.min #map2(%arg4)[%4]
+      %d = arith.subi %4, %arg4 : index
+      %5 = arith.minsi %c3, %d : index
       %6 = memref.subview %arg0[%arg3, %arg4][%squared_min, %5][%c1, %c1] : memref to memref
       %7 = memref.dim %arg1, %c0 : memref
       %8 = affine.min #map1(%arg3)[%7]
@@ -241,12 +242,12 @@
   }
 }
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
-// CHECK: #[[$MAP3:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-// CHECK: #[[$MAP4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
-// CHECK: #[[$MAP5:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0, s1] -> ((d0 - s0) ceildiv s1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)>
+// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
 
 // CHECK: module {
 // CHECK-LABEL: func @sum(
@@ -262,8 +263,7 @@
 // CHECK: [[VAL_11:%.*]] = affine.apply #[[$MAP1]]([[VAL_8]]){{\[}}%[[C0]], %[[C3]]]
 // CHECK: [[VAL_12:%.*]] = arith.constant 4 : index
 // CHECK: [[VAL_13:%.*]] = affine.apply #[[$MAP1]]([[VAL_12]]){{\[}}%[[C0]], %[[C1]]]
-// CHECK: [[VAL_14:%.*]] = arith.constant 3 : index
-// CHECK: [[VAL_15:%.*]] = affine.apply #[[$MAP1]]([[VAL_14]]){{\[}}%[[C0]], %[[C1]]]
+// CHECK: [[VAL_15:%.*]] = affine.apply #[[$MAP1]](%[[C3]]){{\[}}%[[C0]], %[[C1]]]
 // CHECK: gpu.launch blocks([[VAL_16:%.*]], [[VAL_17:%.*]], [[VAL_18:%.*]]) in ([[VAL_19:%.*]] = [[VAL_10]], [[VAL_20:%.*]] = [[VAL_11]], [[VAL_21:%.*]] = [[VAL_9]]) threads([[VAL_22:%.*]], [[VAL_23:%.*]], [[VAL_24:%.*]]) in ([[VAL_25:%.*]] = [[VAL_13]], [[VAL_26:%.*]] = [[VAL_15]], [[VAL_27:%.*]] = [[VAL_9]]) {
 // CHECK: [[VAL_28:%.*]] = affine.apply #[[$MAP2]]([[VAL_16]]){{\[}}%[[C2]], %[[C0]]]
 // CHECK: [[VAL_29:%.*]] = affine.apply #[[$MAP2]]([[VAL_17]]){{\[}}%[[C3]], %[[C0]]]
@@ -271,7 +271,8 @@
 // CHECK: [[VAL_31:%.*]] = affine.min #[[$MAP3]]([[VAL_28]]){{\[}}[[VAL_30]]]
 // CHECK: [[VAL_31_SQUARED:%.*]] = arith.muli [[VAL_31]], [[VAL_31]] : index
 // CHECK: [[VAL_32:%.*]] = memref.dim [[VAL_0]], %[[C1]] : memref
-// CHECK: [[VAL_33:%.*]] = affine.min #[[$MAP4]]([[VAL_29]]){{\[}}[[VAL_32]]]
+// CHECK: [[VAL_D:%.*]] = arith.subi [[VAL_32]], [[VAL_29]] : index
+// CHECK: [[VAL_33:%.*]] = arith.minsi %[[C3]], [[VAL_D]] : index
 // CHECK: [[VAL_34:%.*]] = memref.subview [[VAL_0]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_31_SQUARED]], [[VAL_33]]] {{\[}}%[[C1]], %[[C1]]] : memref to memref
 // CHECK: [[VAL_35:%.*]] = memref.dim [[VAL_1]], %[[C0]] : memref
 // CHECK: [[VAL_36:%.*]] = affine.min #[[$MAP3]]([[VAL_28]]){{\[}}[[VAL_35]]]