Index: mlir/lib/Dialect/Affine/IR/AffineOps.cpp
===================================================================
--- mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1816,7 +1816,10 @@
   if (!ivArg || !ivArg.getOwner())
     return AffineForOp();
   auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
-  return dyn_cast<AffineForOp>(containingInst);
+  if (auto forOp = dyn_cast<AffineForOp>(containingInst))
+    // Check to make sure `val` is the induction variable, not an iter_arg.
+    return forOp.getInductionVar() == val ? forOp : AffineForOp();
+  return AffineForOp();
 }
 
 /// Extracts the induction variables from a list of AffineForOps and returns
Index: mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
===================================================================
--- mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -256,6 +256,8 @@
 ///   * Uniform operands (only operands defined outside of the loop nest,
 ///     for now) are broadcasted to a vector.
 ///     TODO: Support more uniform cases.
+///   * Induction variables of loops not mapped to a vector dimension are
+///     broadcasted to a vector.
 ///   * Affine for operations with 'iter_args' are vectorized by
 ///     vectorizing their 'iter_args' operands and results.
 ///     TODO: Support more complex loops with divergent lbs and/or ubs.
@@ -1088,13 +1090,42 @@
   return bcastOp;
 }
 
+/// Returns true if the provided value is an induction variable of a loop
+/// not in `loopToVectorDim` given the vectorization strategy.
+static bool isNonVecDimLoopIV(Value value,
+                              const VectorizationStrategy *strategy) {
+  AffineForOp forOp = getForInductionVarOwner(value);
+  if (!forOp ||
+      strategy->loopToVectorDim.find(forOp) != strategy->loopToVectorDim.end())
+    return false;
+
+  return true;
+}
+
+/// Generates a broadcast op for the provided induction variable or its scalar
+/// replacement using the vectorization strategy in 'state'.
+static Operation *vectorizeIV(Value iv, VectorizationState &state) {
+  OpBuilder::InsertionGuard guard(state.builder);
+  auto ivScalarRepl = state.valueScalarReplacement.lookupOrDefault(iv);
+  state.builder.setInsertionPointAfterValue(ivScalarRepl);
+
+  auto vectorTy = getVectorType(iv.getType(), state.strategy);
+  auto bcastOp = state.builder.create<vector::BroadcastOp>(
+      iv.getLoc(), vectorTy, ivScalarRepl);
+  state.registerValueVectorReplacement(iv, bcastOp);
+  return bcastOp;
+}
+
 /// Tries to vectorize a given `operand` by applying the following logic:
 /// 1. if the defining operation has been already vectorized, `operand` is
 ///    already in the proper vector form;
 /// 2. if the `operand` is a constant, returns the vectorized form of the
 ///    constant;
 /// 3. if the `operand` is uniform, returns a vector broadcast of the `op`;
-/// 4. otherwise, the vectorization of `operand` is not supported.
+/// 4. if the `operand` is an induction variable of a loop not in
+///    `loopToVectorDim`, returns a vector broadcast of the `operand` or its
+///    scalar replacement;
+/// 5. otherwise, the vectorization of `operand` is not supported.
 /// Newly created vector operations are registered in `state` as replacement
 /// for their scalar counterparts.
 /// In particular this logic captures some of the use cases where definitions
@@ -1133,6 +1164,13 @@
     return vecUniform->getResult(0);
   }
 
+  // Vectorize induction variables of loops not in `loopToVectorDim`.
+  if (isNonVecDimLoopIV(operand, state.strategy)) {
+    Operation *vecIV = vectorizeIV(operand, state);
+    LLVM_DEBUG(dbgs() << "-> iv: " << *vecIV);
+    return vecIV->getResult(0);
+  }
+
   // Check for unsupported block argument scenarios. A supported block argument
   // should have been vectorized already.
   if (!operand.getDefiningOp())
Index: mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
===================================================================
--- mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
+++ mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
@@ -165,6 +165,65 @@
 
 // -----
 
+// CHECK-LABEL: func @vec_block_arg
+func @vec_block_arg(%A : memref<32x512xi32>) {
+  // CHECK:      affine.for %[[IV0:[arg0-9]+]] = 0 to 512 step 128 {
+  // CHECK-NEXT:   affine.for %[[IV1:[arg0-9]+]] = 0 to 32 {
+  // CHECK-NEXT:     %[[BROADCAST:.*]] = vector.broadcast %[[IV1]] : index to vector<128xindex>
+  // CHECK-NEXT:     %[[CAST:.*]] = index_cast %[[BROADCAST]] : vector<128xindex> to vector<128xi32>
+  // CHECK-NEXT:     vector.transfer_write %[[CAST]], {{.*}}[%[[IV1]], %[[IV0]]] : vector<128xi32>, memref<32x512xi32>
+  affine.for %i = 0 to 512 { // vectorized
+    affine.for %j = 0 to 32 {
+      %idx = std.index_cast %j : index to i32
+      affine.store %idx, %A[%j, %i] : memref<32x512xi32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @vec_block_arg_2
+func @vec_block_arg_2(%A : memref<?x?xf32>, %B : memref<?x?xf32>) {
+  %c0 = constant 0 : index
+  %M = memref.dim %A, %c0 : memref<?x?xf32>
+  %N = memref.dim %B, %c0 : memref<?x?xf32>
+  %cst_0 = constant 0.0 : f32
+  %c2 = constant 2 : index
+  %c1 = constant 1 : index
+  // CHECK:      affine.for %[[IV0:[arg0-9]+]] = 0 to %{{.*}} {
+  // CHECK-NEXT:   %[[BROADCAST1:.*]] = vector.broadcast %[[IV0]] : index to vector<128xindex>
+  // CHECK-NEXT:   affine.for %[[IV1:[arg0-9]+]] = 0 to 512 step 128 {
+  // CHECK-NOT:      vector.broadcast %[[IV1]]
+  // CHECK:          %{{.*}} = affine.for %[[IV2:[arg0-9]+]] = 0 to 2 {{.*}} -> (vector<128xf32>) {
+  // CHECK-NEXT:       %[[BROADCAST2:.*]] = vector.broadcast %[[IV2]] : index to vector<128xindex>
+  // CHECK:            vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xf32>, vector<128xf32>
+  // CHECK-NEXT:       muli %[[BROADCAST1]], %{{.*}} : vector<128xindex>
+  // CHECK-NEXT:       addi %{{.*}}, %[[BROADCAST2]] : vector<128xindex>
+  // CHECK:          }
+  // CHECK-NEXT:   vector.transfer_write %{{.*}}, %{{.*}} : vector<128xf32>, memref<?x?xf32>
+  affine.for %i0 = 0 to %N {
+    affine.for %i1 = 0 to 512 { // vectorized
+      %0 = affine.for %i2 = 0 to 2 iter_args(%accum = %cst_0) -> f32 {
+        %2 = affine.load %A[%i0 * 2 + %i2 - 1, %i1] : memref<?x?xf32>
+        %mul = muli %i0, %c2 : index
+        %add = addi %mul, %i2 : index
+        %sub = subi %add, %c1 : index
+        %3 = cmpi uge, %sub, %c0 : index
+        %4 = cmpi ule, %sub, %M : index
+        %5 = and %3, %4 : i1
+        %6 = select %5, %2, %cst_0 : f32
+        %7 = addf %accum, %6 : f32
+        affine.yield %7 : f32
+      }
+      affine.store %0, %B[%i0, %i1] : memref<?x?xf32>
+    }
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @vec_rejected_1
 func @vec_rejected_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
   // CHECK-DAG: %[[C0:.*]] = constant 0 : index
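
Note on the AffineOps.cpp change: for an `affine.for` with `iter_args`, the
induction variable and the iter_args are all block arguments of the same loop
body, so `getForInductionVarOwner` previously also claimed ownership when
handed an iter_arg. A minimal illustration (a sketch based on the
`vec_block_arg_2` test above, not itself part of the patch):

    // %i2 (the IV) and %accum (an iter_arg) are both block arguments of
    // the loop body.
    %0 = affine.for %i2 = 0 to 2 iter_args(%accum = %cst_0) -> f32 {
      ...
    }

With the fix, only `getForInductionVarOwner(%i2)` returns the loop; querying
%accum yields a null AffineForOp.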
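
For reference, the `vec_block_arg` test exercises roughly the following
rewrite (a sketch assembled from the CHECK lines, assuming a 128-wide
vectorization strategy mapped to the %i loop). The IV of the %j loop has no
vector dimension, so it is broadcast to feed the vectorized index_cast and
store:

    // Input:
    affine.for %i = 0 to 512 {
      affine.for %j = 0 to 32 {
        %idx = std.index_cast %j : index to i32
        affine.store %idx, %A[%j, %i] : memref<32x512xi32>
      }
    }

    // Output (sketch; SSA names are illustrative):
    affine.for %i = 0 to 512 step 128 {
      affine.for %j = 0 to 32 {
        %b = vector.broadcast %j : index to vector<128xindex>
        %v = index_cast %b : vector<128xindex> to vector<128xi32>
        vector.transfer_write %v, %A[%j, %i] : vector<128xi32>, memref<32x512xi32>
      }
    }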