Index: mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
===================================================================
--- mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -1428,21 +1428,36 @@
   // being added to the accumulator by inserting `select` operations, for
   // example:
   //
-  //   %res = arith.addf %acc, %val : vector<128xf32>
-  //   %res_masked = select %mask, %res, %acc : vector<128xi1>, vector<128xf32>
-  //   affine.yield %res_masked : vector<128xf32>
+  //   %val_masked = select %mask, %val, %neutralCst
+  //       : vector<128xi1>, vector<128xf32>
+  //   %res = arith.addf %acc, %val_masked : vector<128xf32>
+  //   affine.yield %res : vector<128xf32>
   //
   if (Value mask = state.vecLoopToMask.lookup(newParentOp)) {
     state.builder.setInsertionPoint(newYieldOp);
     for (unsigned i = 0; i < newYieldOp->getNumOperands(); ++i) {
       Value result = newYieldOp->getOperand(i);
       Value iterArg = cast<AffineForOp>(newParentOp).getRegionIterArgs()[i];
-      Value maskedResult = state.builder.create<arith::SelectOp>(
-          result.getLoc(), mask, result, iterArg);
+      // The yielded value must come from a binary combiner op that takes the
+      // loop-carried accumulator as one operand; the other operand is masked.
+      Operation *defOp = result.getDefiningOp();
+      assert(defOp && defOp->getNumOperands() == 2 && "must be a binary op");
+      assert((defOp->getOperand(0) == iterArg ||
+              defOp->getOperand(1) == iterArg) &&
+             "must use iterArg");
+      Value input = defOp->getOperand(0) == iterArg ? defOp->getOperand(1)
+                                                    : defOp->getOperand(0);
+      // IterOperands are neutral element vectors, so selecting them for the
+      // masked-off lanes leaves those accumulator lanes unchanged.
+      Value neutralVal = cast<AffineForOp>(newParentOp).getIterOperands()[i];
+      state.builder.setInsertionPoint(defOp);
+      Value maskedInput = state.builder.create<arith::SelectOp>(
+          input.getLoc(), mask, input, neutralVal);
       LLVM_DEBUG(
-          dbgs() << "\n[early-vect]+++++ masking a yielded vector value: "
-                 << maskedResult);
-      newYieldOp->setOperand(i, maskedResult);
+          dbgs() << "\n[early-vect]+++++ masking an input to a binary op that "
+                    "produces value for a yield Op: "
+                 << maskedInput);
+      defOp->replaceUsesOfWith(input, maskedInput);
     }
   }
 
Index: mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
===================================================================
--- mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
+++ mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
@@ -475,9 +475,9 @@
 // CHECK: %[[elems_left:.*]] = affine.apply #[[$map0]](%[[iv]])
 // CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1>
 // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32>
-// CHECK: %[[add:.*]] = arith.addf %[[red_iter]], %[[ld]] : vector<128xf32>
-// CHECK: %[[new_acc:.*]] = arith.select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32>
-// CHECK: affine.yield %[[new_acc]] : vector<128xf32>
+// CHECK: %[[select:.*]] = arith.select %[[mask]], %[[ld]], %[[vzero]] : vector<128xi1>, vector<128xf32>
+// CHECK: %[[add:.*]] = arith.addf %[[red_iter]], %[[select]] : vector<128xf32>
+// CHECK: affine.yield %[[add]] : vector<128xf32>
 // CHECK: }
 // CHECK: %[[final_sum:.*]] = vector.reduction <add>, %[[vred:.*]] : vector<128xf32> into f32
 // CHECK: affine.store %[[final_sum]], %{{.*}} : memref<256xf32>
@@ -508,9 +508,9 @@
 // CHECK: %[[elems_left:.*]] = affine.apply #[[$map1]](%[[iv]])[%[[bnd]]]
 // CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1>
 // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32>
-// CHECK: %[[add:.*]] = arith.addf %[[red_iter]], %[[ld]] : vector<128xf32>
-// CHECK: %[[new_acc:.*]] = arith.select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32>
-// CHECK: affine.yield %[[new_acc]] : vector<128xf32>
+// CHECK: %[[select:.*]] = arith.select %[[mask]], %[[ld]], %[[vzero]] : vector<128xi1>, vector<128xf32>
+// CHECK: %[[add:.*]] = arith.addf %[[red_iter]], %[[select]] : vector<128xf32>
+// CHECK: affine.yield %[[add]] : vector<128xf32>
 // CHECK: }
 // CHECK: %[[final_sum:.*]] = vector.reduction <add>, %[[vred:.*]] : vector<128xf32> into f32
 // CHECK: affine.store %[[final_sum]], %{{.*}} : memref<256xf32>
@@ -557,13 +557,14 @@
 
 // CHECK: #[[$map2:.*]] = affine_map<([[d0:.*]]) -> (-[[d0]] + 512)>
 // CHECK-LABEL: @vecdim_reduction_masked_unknown_lb
+// CHECK: %[[vzero:.*]] = arith.constant dense<0.000000e+00> : vector<128xf32>
 // CHECK: %{{.*}} = affine.for %[[iv:.*]] = %[[lb:.*]] to 512 step 128 iter_args(%[[red_iter:.*]] = {{.*}}) -> (vector<128xf32>) {
 // CHECK: %[[elems_left:.*]] = affine.apply #[[$map2]](%[[iv]])
 // CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1>
 // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32>
-// CHECK: %[[add:.*]] = arith.addf %[[red_iter]], %[[ld]] : vector<128xf32>
-// CHECK: %[[new_acc:.*]] = arith.select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32>
-// CHECK: affine.yield %[[new_acc]] : vector<128xf32>
+// CHECK: %[[select:.*]] = arith.select %[[mask]], %[[ld]], %[[vzero]] : vector<128xi1>, vector<128xf32>
+// CHECK: %[[add:.*]] = arith.addf %[[red_iter]], %[[select]] : vector<128xf32>
+// CHECK: affine.yield %[[add]] : vector<128xf32>
 
 // -----
 
@@ -585,14 +586,15 @@
 // CHECK: #[[$map3:.*]] = affine_map<([[d0:.*]], [[d1:.*]]) -> ([[d0]], [[d1]] * 2)>
 // CHECK: #[[$map3_sub:.*]] = affine_map<([[d0:.*]], [[d1:.*]]) -> ([[d0]] - [[d1]])>
 // CHECK-LABEL: @vecdim_reduction_complex_ub
+// CHECK: %[[vzero:.*]] = arith.constant dense<0.000000e+00> : vector<128xf32>
 // CHECK: %{{.*}} = affine.for %[[iv:.*]] = 0 to min #[[$map3]](%[[M:.*]], %[[N:.*]]) step 128 iter_args(%[[red_iter:.*]] = {{.*}}) -> (vector<128xf32>) {
 // CHECK: %[[ub:.*]] = affine.min #[[$map3]](%[[M]], %[[N]])
 // CHECK: %[[elems_left:.*]] = affine.apply #[[$map3_sub]](%[[ub]], %[[iv]])
 // CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1>
 // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32>
-// CHECK: %[[add:.*]] = arith.addf %[[red_iter]], %[[ld]] : vector<128xf32>
-// CHECK: %[[new_acc:.*]] = arith.select %[[mask]], %[[add]], %[[red_iter]] : vector<128xi1>, vector<128xf32>
-// CHECK: affine.yield %[[new_acc]] : vector<128xf32>
+// CHECK: %[[select:.*]] = arith.select %[[mask]], %[[ld]], %[[vzero]] : vector<128xi1>, vector<128xf32>
+// CHECK: %[[add:.*]] = arith.addf %[[red_iter]], %[[select]] : vector<128xf32>
+// CHECK: affine.yield %[[add]] : vector<128xf32>
 
 // -----
 
@@ -617,14 +619,16 @@
 // CHECK: #[[$map4:.*]] = affine_map<([[d0:.*]]) -> (-[[d0]] + 500)>
 // CHECK-LABEL: @vecdim_two_reductions_masked
 // CHECK: affine.for %{{.*}} = 0 to 256 {
+// CHECK: %[[vzero0:.*]] = arith.constant dense<0.000000e+00> : vector<128xf32>
+// CHECK: %[[vzero1:.*]] = arith.constant dense<0.000000e+00> : vector<128xf32>
 // CHECK: %{{.*}} = affine.for %[[iv:.*]] = 0 to 500 step 128 iter_args(%[[sum_iter:.*]] = {{.*}}, %[[esum_iter:.*]] = {{.*}}) -> (vector<128xf32>, vector<128xf32>) {
 // CHECK: %[[elems_left:.*]] = affine.apply #[[$map4]](%[[iv]])
 // CHECK: %[[mask:.*]] = vector.create_mask %[[elems_left]] : vector<128xi1>
 // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32>
 // CHECK: %[[exp:.*]] = math.exp %[[ld]] : vector<128xf32>
-// CHECK: %[[add:.*]] = arith.addf %[[sum_iter]], %[[ld]] : vector<128xf32>
-// CHECK: %[[eadd:.*]] = arith.addf %[[esum_iter]], %[[exp]] : vector<128xf32>
-// CHECK: %[[new_acc:.*]] = arith.select %[[mask]], %[[add]], %[[sum_iter]] : vector<128xi1>, vector<128xf32>
-// CHECK: %[[new_eacc:.*]] = arith.select %[[mask]], %[[eadd]], %[[esum_iter]] : vector<128xi1>, vector<128xf32>
-// CHECK: affine.yield %[[new_acc]], %[[new_eacc]] : vector<128xf32>
+// CHECK: %[[select0:.*]] = arith.select %[[mask]], %[[ld]], %[[vzero0]] : vector<128xi1>, vector<128xf32>
+// CHECK: %[[add:.*]] = arith.addf %[[sum_iter]], %[[select0]] : vector<128xf32>
+// CHECK: %[[select1:.*]] = arith.select %[[mask]], %[[exp]], %[[vzero1]] : vector<128xi1>, vector<128xf32>
+// CHECK: %[[eadd:.*]] = arith.addf %[[esum_iter]], %[[select1]] : vector<128xf32>
+// CHECK: affine.yield %[[add]], %[[eadd]] : vector<128xf32>
 // CHECK: }