diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
--- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
+++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp
@@ -506,16 +506,46 @@
 /// `upperBound`.
 static Value deriveStaticUpperBound(Value upperBound,
                                     PatternRewriter &rewriter) {
-  if (AffineMinOp minOp =
-          dyn_cast_or_null<AffineMinOp>(upperBound.getDefiningOp())) {
+  // A constant is trivially its own static upper bound.
+  if (auto op = dyn_cast_or_null<ConstantIndexOp>(upperBound.getDefiningOp())) {
+    return op;
+  }
+
+  // Any constant result of the affine.min map bounds the minimum from above.
+  if (auto minOp = dyn_cast_or_null<AffineMinOp>(upperBound.getDefiningOp())) {
     for (const AffineExpr &result : minOp.map().getResults()) {
-      if (AffineConstantExpr constExpr =
-              result.dyn_cast<AffineConstantExpr>()) {
+      if (auto constExpr = result.dyn_cast<AffineConstantExpr>()) {
         return rewriter.create<ConstantIndexOp>(minOp.getLoc(),
                                                 constExpr.getValue());
       }
     }
   }
+
+  // For a product, derive a bound for each factor and multiply the bounds.
+  if (auto multiplyOp = dyn_cast_or_null<MulIOp>(upperBound.getDefiningOp())) {
+    // The recursive call yields a null Value when no bound can be derived;
+    // getDefiningOp() must not be called on a null Value, so bail out first.
+    Value lhsBound = deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter);
+    if (!lhsBound)
+      return {};
+    Value rhsBound = deriveStaticUpperBound(multiplyOp.getOperand(1), rewriter);
+    if (!rhsBound)
+      return {};
+
+    auto lhs = dyn_cast_or_null<ConstantIndexOp>(lhsBound.getDefiningOp());
+    auto rhs = dyn_cast_or_null<ConstantIndexOp>(rhsBound.getDefiningOp());
+    if (!lhs || !rhs)
+      return {};
+
+    // Assumptions about the upper bound of minimum computations no longer
+    // work if multiplied by a negative value, so abort in this case.
+    if (lhs.getValue() < 0 || rhs.getValue() < 0)
+      return {};
+
+    return rewriter.create<ConstantIndexOp>(multiplyOp.getLoc(),
+                                            lhs.getValue() * rhs.getValue());
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir b/mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir --- a/mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir +++ b/mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir @@ -213,9 +213,10 @@ loop.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c2, %c3) { %2 = dim %arg0, 0 : memref %3 = affine.min #map1(%arg3)[%2] + %squared_min = muli %3, %3 : index %4 = dim %arg0, 1 : memref %5 = affine.min #map2(%arg4)[%4] - %6 = std.subview %arg0[%arg3, %arg4][%3, %5][%c1, %c1] : memref to memref + %6 = std.subview %arg0[%arg3, %arg4][%squared_min, %5][%c1, %c1] : memref to memref %7 = dim %arg1, 0 : memref %8 = affine.min #map1(%arg3)[%7] %9 = dim %arg1, 1 : memref @@ -226,7 +227,7 @@ %14 = dim %arg2, 1 : memref %15 = affine.min #map2(%arg4)[%14] %16 = std.subview %arg2[%arg3, %arg4][%13, %15][%c1, %c1] : memref to memref - loop.parallel (%arg5, %arg6) = (%c0, %c0) to (%3, %5) step (%c1, %c1) { + loop.parallel (%arg5, %arg6) = (%c0, %c0) to (%squared_min, %5) step (%c1, %c1) { %17 = load %6[%arg5, %arg6] : memref %18 = load %11[%arg5, %arg6] : memref %19 = load %16[%arg5, %arg6] : memref @@ -259,7 +260,7 @@ // CHECK: [[VAL_9:%.*]] = constant 1 : index // CHECK: [[VAL_10:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_7]], [[VAL_4]], [[VAL_6]]] // CHECK: [[VAL_11:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_8]], [[VAL_4]], [[VAL_5]]] -// CHECK: [[VAL_12:%.*]] = constant 2 : index +// CHECK: [[VAL_12:%.*]] = constant 4 : index // CHECK: [[VAL_13:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_12]], [[VAL_4]], [[VAL_3]]] // CHECK: [[VAL_14:%.*]] = constant 3 : index // CHECK: [[VAL_15:%.*]] = affine.apply #[[MAP1]](){{\[}}[[VAL_14]], [[VAL_4]], [[VAL_3]]] @@ -268,9 +269,10 @@ // CHECK: [[VAL_29:%.*]] = affine.apply #[[MAP2]]([[VAL_17]]){{\[}}[[VAL_5]],
[[VAL_4]]] // CHECK: [[VAL_30:%.*]] = dim [[VAL_0]], 0 : memref // CHECK: [[VAL_31:%.*]] = affine.min #[[MAP3]]([[VAL_28]]){{\[}}[[VAL_30]]] +// CHECK: [[VAL_31_SQUARED:%.*]] = muli [[VAL_31]], [[VAL_31]] : index // CHECK: [[VAL_32:%.*]] = dim [[VAL_0]], 1 : memref // CHECK: [[VAL_33:%.*]] = affine.min #[[MAP4]]([[VAL_29]]){{\[}}[[VAL_32]]] -// CHECK: [[VAL_34:%.*]] = subview [[VAL_0]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_31]], [[VAL_33]]] {{\[}}[[VAL_3]], [[VAL_3]]] : memref to memref +// CHECK: [[VAL_34:%.*]] = subview [[VAL_0]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_31_SQUARED]], [[VAL_33]]] {{\[}}[[VAL_3]], [[VAL_3]]] : memref to memref // CHECK: [[VAL_35:%.*]] = dim [[VAL_1]], 0 : memref // CHECK: [[VAL_36:%.*]] = affine.min #[[MAP3]]([[VAL_28]]){{\[}}[[VAL_35]]] // CHECK: [[VAL_37:%.*]] = dim [[VAL_1]], 1 : memref @@ -282,7 +284,7 @@ // CHECK: [[VAL_43:%.*]] = affine.min #[[MAP4]]([[VAL_29]]){{\[}}[[VAL_42]]] // CHECK: [[VAL_44:%.*]] = subview [[VAL_2]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_41]], [[VAL_43]]] {{\[}}[[VAL_3]], [[VAL_3]]] : memref to memref // CHECK: [[VAL_45:%.*]] = affine.apply #[[MAP2]]([[VAL_22]]){{\[}}[[VAL_3]], [[VAL_4]]] -// CHECK: [[VAL_46:%.*]] = cmpi "slt", [[VAL_45]], [[VAL_31]] : index +// CHECK: [[VAL_46:%.*]] = cmpi "slt", [[VAL_45]], [[VAL_31_SQUARED]] : index // CHECK: loop.if [[VAL_46]] { // CHECK: [[VAL_47:%.*]] = affine.apply #[[MAP2]]([[VAL_23]]){{\[}}[[VAL_3]], [[VAL_4]]] // CHECK: [[VAL_48:%.*]] = cmpi "slt", [[VAL_47]], [[VAL_33]] : index