Index: llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll =================================================================== --- llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll +++ llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll @@ -172,3 +172,77 @@ %tobool.not = icmp eq i32 %dec, 0 br i1 %tobool.not, label %for.cond.cleanup, label %for.body } + +; Consider the case where %a points to a buffer exactly 17 bytes long. The +; loop below will access bytes: 0, 4, 8, and 16. The key bit is that we +; advance the pointer IV by *4* each time, and thus on the iteration we write +; byte 16, %uglygep2 (the pointer increment) is past the end of the underlying +; storage and thus violates the inbounds requirements. As a result, %uglygep2 +; is poison on the final iteration. If we insert a branch on that value, we +; have inserted undefined behavior where it did not previously exist. +; FIXME: miscompile +define void @inbounds_poison_use(ptr %a) { +; CHECK-LABEL: @inbounds_poison_use( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[UGLYGEP:%.*]] = getelementptr i8, ptr [[A:%.*]], i32 16 +; CHECK-NEXT: br label [[FOR_BODY:%.*]] +; CHECK: for.body: +; CHECK-NEXT: [[LSR_IV1:%.*]] = phi ptr [ [[UGLYGEP2:%.*]], [[FOR_BODY]] ], [ [[A]], [[ENTRY:%.*]] ] +; CHECK-NEXT: store i8 1, ptr [[LSR_IV1]], align 4 +; CHECK-NEXT: [[UGLYGEP2]] = getelementptr inbounds i8, ptr [[LSR_IV1]], i64 4 +; CHECK-NEXT: [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[UGLYGEP2]], [[UGLYGEP]] +; CHECK-NEXT: br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_END:%.*]], label [[FOR_BODY]] +; CHECK: for.end: +; CHECK-NEXT: ret void +; +entry: + br label %for.body + +for.body: ; preds = %for.body, %entry + %lsr.iv1 = phi ptr [ %uglygep2, %for.body ], [ %a, %entry ] + %lsr.iv = phi i32 [ %lsr.iv.next, %for.body ], [ 4, %entry ] + store i8 1, ptr %lsr.iv1, align 4 + %lsr.iv.next = add nsw i32 %lsr.iv, -1 + %uglygep2 = getelementptr inbounds i8, ptr %lsr.iv1, i64 4 + 
%exitcond.not = icmp eq i32 %lsr.iv.next, 0 + br i1 %exitcond.not, label %for.end, label %for.body + +for.end: ; preds = %for.body + ret void +} + +; In this case, the integer IV has a larger bitwidth than the pointer IV. +; This means that the smaller IV may wrap around multiple times before +; the original loop exit is taken. +; FIXME: miscompile +define void @iv_size(ptr %a, i128 %N) { +; CHECK-LABEL: @iv_size( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = trunc i128 [[N:%.*]] to i32 +; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[TMP0]], 2 +; CHECK-NEXT: [[UGLYGEP:%.*]] = getelementptr i8, ptr [[A:%.*]], i32 [[TMP1]] +; CHECK-NEXT: br label [[FOR_BODY:%.*]] +; CHECK: for.body: +; CHECK-NEXT: [[LSR_IV1:%.*]] = phi ptr [ [[UGLYGEP2:%.*]], [[FOR_BODY]] ], [ [[A]], [[ENTRY:%.*]] ] +; CHECK-NEXT: store i32 1, ptr [[LSR_IV1]], align 4 +; CHECK-NEXT: [[UGLYGEP2]] = getelementptr i8, ptr [[LSR_IV1]], i64 4 +; CHECK-NEXT: [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[UGLYGEP2]], [[UGLYGEP]] +; CHECK-NEXT: br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_END:%.*]], label [[FOR_BODY]] +; CHECK: for.end: +; CHECK-NEXT: ret void +; +entry: + br label %for.body + +for.body: ; preds = %for.body, %entry + %lsr.iv1 = phi ptr [ %uglygep2, %for.body ], [ %a, %entry ] + %lsr.iv = phi i128 [ %lsr.iv.next, %for.body ], [ %N, %entry ] + store i32 1, ptr %lsr.iv1, align 4 + %lsr.iv.next = add nsw i128 %lsr.iv, -1 + %uglygep2 = getelementptr i8, ptr %lsr.iv1, i64 4 + %exitcond.not = icmp eq i128 %lsr.iv.next, 0 + br i1 %exitcond.not, label %for.end, label %for.body + +for.end: ; preds = %for.body + ret void +} Index: mlir/include/mlir/Dialect/Arith/IR/Arith.h =================================================================== --- mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -128,6 +128,9 @@ Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc); +/// Checks if `value` is an 
identity value associated with an AtomicRMWKind op. +bool isIdentityValue(AtomicRMWKind op, Value value); + /// Returns the value obtained by applying the reduction operation kind /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Index: mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp =================================================================== --- mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1205,6 +1205,38 @@ } }; +/// Checks every iterOperand of `forOp` from `reductions`. If an iterOperand is +/// not a neutral element for the corresponding reduction kind, replaces the +/// iterOperand with a neutral element and adds an op after `forOp` to combine +/// the original iterOperand and the related result of `forOp`. +static void +replaceIterOperandWithNeutralElement(AffineForOp forOp, + ArrayRef<LoopReduction> reductions) { + unsigned numIterArgs = forOp.getNumIterOperands(); + if (numIterArgs == 0) + return; + assert(reductions.size() == numIterArgs); + OpBuilder builder(forOp); + auto loc = forOp.getLoc(); + auto iterOperands = forOp.getIterOperands(); + for (const LoopReduction &reduction : reductions) { + unsigned pos = reduction.iterArgPosition; + arith::AtomicRMWKind kind = reduction.kind; + Value iterOperand = iterOperands[pos]; + if (isIdentityValue(kind, iterOperand)) + continue; + builder.setInsertionPoint(forOp); + Value neutralElem = + getIdentityValue(kind, iterOperand.getType(), builder, loc); + forOp.setOperand(pos + forOp.getNumControlOperands(), neutralElem); + builder.setInsertionPointAfter(forOp); + auto res = forOp.getResult(pos); + // Combine the original initial value with the result. + Value newRes = getReductionOp(kind, builder, loc, res, iterOperand); + res.replaceAllUsesExcept(newRes, newRes.getDefiningOp()); + } +} + /// Unrolls and jams this loop by the specified factor. 
LogicalResult mlir::loopUnrollJamByFactor(AffineForOp forOp, uint64_t unrollJamFactor) { @@ -1247,8 +1279,12 @@ // Get supported reductions to be used for creating reduction ops at the end. SmallVector<LoopReduction> reductions; - if (forOp.getNumIterOperands() > 0) + if (forOp.getNumIterOperands() > 0) { getSupportedReductions(forOp, reductions); + // Each iterOperand of `forOp` from `reductions` should be a neutral element + // for the corresponding reduction kind. + replaceIterOperandWithNeutralElement(forOp, reductions); + } // Generate the cleanup loop if trip count isn't a multiple of // unrollJamFactor. Index: mlir/lib/Dialect/Arith/IR/ArithOps.cpp =================================================================== --- mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2338,10 +2338,29 @@ /// Returns the identity value associated with an AtomicRMWKind op. Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc) { - Attribute attr = getIdentityValueAttr(op, resultType, builder, loc); + Type scalarTy = getElementTypeOrSelf(resultType); + Attribute attr = getIdentityValueAttr(op, scalarTy, builder, loc); + if (scalarTy != resultType) { + attr = DenseElementsAttr::get(resultType, attr); + } return builder.create<arith::ConstantOp>(loc, attr); } +/// Checks if `value` is an identity value associated with an AtomicRMWKind op. 
+bool mlir::arith::isIdentityValue(AtomicRMWKind op, Value value) { + auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp()); + if (!constOp) + return false; + OpBuilder builder(value.getContext()); + Type scalarTy = getElementTypeOrSelf(value); + Attribute valueAttr = + getIdentityValueAttr(op, scalarTy, builder, value.getLoc()); + if (constOp.getValue().isa<DenseElementsAttr>()) + return constOp.getValue() == + DenseElementsAttr::get(value.getType(), valueAttr); + return constOp.getValue() == valueAttr; +} + /// Return the value obtained by applying the reduction operation kind /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Index: mlir/test/Dialect/Affine/unroll-jam.mlir =================================================================== --- mlir/test/Dialect/Affine/unroll-jam.mlir +++ mlir/test/Dialect/Affine/unroll-jam.mlir @@ -458,7 +458,8 @@ } // CHECK: %[[CONST0:[a-zA-Z0-9_]*]] = arith.constant 20 : index -// CHECK-NEXT: [[RES:%[0-9]+]]:2 = affine.for %[[IV0:arg[0-9]+]] = 0 to 20 step 2 iter_args([[ACC0:%arg[0-9]+]] = [[INIT0]], [[ACC1:%arg[0-9]+]] = [[INIT0]]) -> (f32, f32) { +// CHECK-NEXT: [[CST1:%[a-zA-Z0-9_]*]] = arith.constant 1.000000e+00 : f32 +// CHECK-NEXT: [[RES:%[0-9]+]]:2 = affine.for %[[IV0:arg[0-9]+]] = 0 to 20 step 2 iter_args([[ACC0:%arg[0-9]+]] = [[CST1]], [[ACC1:%arg[0-9]+]] = [[CST1]]) -> (f32, f32) { // CHECK-NEXT: [[RES1:%[0-9]+]]:2 = affine.for %[[IV1:arg[0-9]+]] = 0 to 30 iter_args([[ACC2:%arg[0-9]+]] = [[INIT1]], [[ACC3:%arg[0-9]+]] = [[INIT1]]) -> (f32, f32) { // CHECK-NEXT: [[LOAD1:%[0-9]+]] = affine.load {{.*}}[%[[IV0]], %[[IV1]]] // CHECK-NEXT: [[ADD1:%[0-9]+]] = arith.addf [[ACC2]], [[LOAD1]] : f32 @@ -481,6 +482,7 @@ // CHECK-NEXT: affine.yield [[ADD3]] : f32 // CHECK-NEXT: } // CHECK-NEXT: [[MUL4:%[0-9]+]] = arith.mulf [[MUL3]], [[RES2]] : f32 +// CHECK-NEXT: [[MUL5:%[0-9]+]] = arith.mulf [[MUL4]], [[INIT0]] : f32 // CHECK-NEXT: return // CHECK-LABEL: func 
@unroll_jam_iter_args_addi @@ -495,7 +497,8 @@ } // CHECK: %[[CONST0:[a-zA-Z0-9_]*]] = arith.constant 20 : index -// CHECK-NEXT: [[RES:%[0-9]+]]:2 = affine.for %[[IV0:arg[0-9]+]] = 0 to 20 step 2 iter_args([[ACC0:%arg[0-9]+]] = [[INIT0]], [[ACC1:%arg[0-9]+]] = [[INIT0]]) -> (i32, i32) { +// CHECK-NEXT: [[CST0:%[a-zA-Z0-9_]*]] = arith.constant 0 : i32 +// CHECK-NEXT: [[RES:%[0-9]+]]:2 = affine.for %[[IV0:arg[0-9]+]] = 0 to 20 step 2 iter_args([[ACC0:%arg[0-9]+]] = [[CST0]], [[ACC1:%arg[0-9]+]] = [[CST0]]) -> (i32, i32) { // CHECK-NEXT: [[LOAD1:%[0-9]+]] = affine.load {{.*}}[%[[IV0]]] // CHECK-NEXT: [[ADD1:%[0-9]+]] = arith.addi [[ACC0]], [[LOAD1]] : i32 // CHECK-NEXT: %[[INC1:[0-9]+]] = affine.apply [[$MAP_PLUS_1]](%[[IV0]]) @@ -508,4 +511,5 @@ // Cleanup loop (single iteration). // CHECK-NEXT: [[LOAD3:%[0-9]+]] = affine.load {{.*}}[%[[CONST0]]] // CHECK-NEXT: [[ADD4:%[0-9]+]] = arith.addi [[ADD3]], [[LOAD3]] : i32 +// CHECK-NEXT: [[ADD5:%[0-9]+]] = arith.addi [[ADD4]], [[INIT0]] : i32 // CHECK-NEXT: return