diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -254,10 +254,11 @@
 ///    * Uniform operands (only operands defined outside of the loop nest,
 ///      for now) are broadcasted to a vector.
 ///      TODO: Support more uniform cases.
+///    * Affine for operations with 'iter_args' are vectorized by
+///      vectorizing their 'iter_args' operands and results.
+///      TODO: Support more complex loops with divergent lbs and/or ubs.
 ///    * The remaining operations in the loop nest are vectorized by
 ///      widening their scalar types to vector types.
-///    * TODO: Add vectorization support for loops with 'iter_args' and
-///      more complex loops with divergent lbs and/or ubs.
 ///   b. if everything under the root AffineForOp in the current pattern
 ///      is vectorized properly, we commit that loop to the IR and remove the
 ///      scalar loop. Otherwise, we discard the vectorized loop and keep the
@@ -620,6 +621,14 @@
   ///   * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
   void registerValueVectorReplacement(Value replaced, Operation *replacement);
 
+  /// Registers the vector replacement of a block argument (e.g., iter_args).
+  ///
+  /// Example:
+  ///   * 'replaced': 'iter_arg' block argument.
+  ///   * 'replacement': vectorized 'iter_arg' block argument.
+  void registerBlockArgVectorReplacement(BlockArgument replaced,
+                                         BlockArgument replacement);
+
   /// Registers the scalar replacement of a scalar value. 'replacement' must be
   /// scalar. Both values must be block arguments. Operation results should be
   /// replaced using the 'registerOp*' utilities.
@@ -685,15 +694,15 @@
   LLVM_DEBUG(dbgs() << "into\n");
   LLVM_DEBUG(dbgs() << *replacement << "\n");
 
-  assert(replaced->getNumResults() <= 1 && "Unsupported multi-result op");
   assert(replaced->getNumResults() == replacement->getNumResults() &&
          "Unexpected replaced and replacement results");
   assert(opVectorReplacement.count(replaced) == 0 && "already registered");
   opVectorReplacement[replaced] = replacement;
 
-  if (replaced->getNumResults() > 0)
-    registerValueVectorReplacementImpl(replaced->getResult(0),
-                                       replacement->getResult(0));
+  for (auto resultTuple :
+       llvm::zip(replaced->getResults(), replacement->getResults()))
+    registerValueVectorReplacementImpl(std::get<0>(resultTuple),
+                                       std::get<1>(resultTuple));
 }
 
 /// Registers the vector replacement of a scalar value. The replacement
@@ -716,6 +725,16 @@
   registerValueVectorReplacementImpl(replaced, replacement->getResult(0));
 }
 
+/// Registers the vector replacement of a block argument (e.g., iter_args).
+///
+/// Example:
+///   * 'replaced': 'iter_arg' block argument.
+///   * 'replacement': vectorized 'iter_arg' block argument.
+void VectorizationState::registerBlockArgVectorReplacement(
+    BlockArgument replaced, BlockArgument replacement) {
+  registerValueVectorReplacementImpl(replaced, replacement);
+}
+
 void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
                                                             Value replacement) {
   assert(!valueVectorReplacement.contains(replaced) &&
@@ -1013,16 +1032,20 @@
 // vectorized at this point.
 static Operation *vectorizeAffineForOp(AffineForOp forOp,
                                        VectorizationState &state) {
-  // 'iter_args' not supported yet.
-  if (forOp.getNumIterOperands() > 0)
+  const VectorizationStrategy &strategy = *state.strategy;
+  auto loopToVecDimIt = strategy.loopToVectorDim.find(forOp);
+  bool isLoopVecDim = loopToVecDimIt != strategy.loopToVectorDim.end();
+
+  // We only support 'iter_args' when the loop is not one of the vector
+  // dimensions.
+  // TODO: Support vector dimension loops. They require special handling:
+  // generate horizontal reduction, last-value extraction, etc.
+  if (forOp.getNumIterOperands() > 0 && isLoopVecDim)
     return nullptr;
 
   // If we are vectorizing a vector dimension, compute a new step for the new
   // vectorized loop using the vectorization factor for the vector dimension.
   // Otherwise, propagate the step of the scalar loop.
-  const VectorizationStrategy &strategy = *state.strategy;
-  auto loopToVecDimIt = strategy.loopToVectorDim.find(forOp);
-  bool isLoopVecDim = loopToVecDimIt != strategy.loopToVectorDim.end();
   unsigned newStep;
   if (isLoopVecDim) {
     unsigned vectorDim = loopToVecDimIt->second;
@@ -1033,10 +1056,15 @@
     newStep = forOp.getStep();
   }
 
+  // Vectorize 'iter_args'.
+  SmallVector<Value, 8> vecIterOperands;
+  for (auto operand : forOp.getIterOperands())
+    vecIterOperands.push_back(vectorizeOperand(operand, state));
+
   auto vecForOp = state.builder.create<AffineForOp>(
       forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(),
       forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep,
-      forOp.getIterOperands(),
+      vecIterOperands,
      /*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) {
        // Make sure we don't create a default terminator in the loop body as
        // the proper terminator will be added during vectorization.
@@ -1051,11 +1079,16 @@
   //     since a scalar copy of the iv will prevail in the vectorized loop.
   //     TODO: A vector replacement will also be added in the future when
   //     vectorization of linear ops is supported.
-  //  3) TODO: Support 'iter_args' along non-vector dimensions.
+  //  3) The new 'iter_args' region arguments are registered as vector
+  //     replacements since they have been vectorized.
   state.registerOpVectorReplacement(forOp, vecForOp);
   state.registerValueScalarReplacement(forOp.getInductionVar(),
                                        vecForOp.getInductionVar());
-  // Map the new vectorized loop to its vector dimension.
+  for (auto iterTuple :
+       llvm::zip(forOp.getRegionIterArgs(), vecForOp.getRegionIterArgs()))
+    state.registerBlockArgVectorReplacement(std::get<0>(iterTuple),
+                                            std::get<1>(iterTuple));
+
   if (isLoopVecDim)
     state.vecLoopToVecDim[vecForOp] = loopToVecDimIt->second;
 
@@ -1102,12 +1135,6 @@
 /// operations after the parent op.
 static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp,
                                          VectorizationState &state) {
-  // 'iter_args' not supported yet.
-  if (yieldOp.getNumOperands() > 0)
-    return nullptr;
-
-  // Vectorize the yield op and change the insertion point right after the new
-  // parent op.
   Operation *newYieldOp = widenOp(yieldOp, state);
   Operation *newParentOp = state.builder.getInsertionBlock()->getParentOp();
   state.builder.setInsertionPointAfter(newParentOp);
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
@@ -500,11 +500,11 @@
 
 // -----
 
-// CHECK-LABEL: @vec_rejected_unsupported_reduction
-func @vec_rejected_unsupported_reduction(%in: memref<128x256xf32>, %out: memref<256xf32>) {
+// '%i' loop is vectorized, including the inner reduction over '%j'.
+
+func @vec_non_vecdim_reduction(%in: memref<128x256xf32>, %out: memref<256xf32>) {
   %cst = constant 0.000000e+00 : f32
   affine.for %i = 0 to 256 {
-    // CHECK-NOT: vector
     %final_red = affine.for %j = 0 to 128 iter_args(%red_iter = %cst) -> (f32) {
       %ld = affine.load %in[%j, %i] : memref<128x256xf32>
       %add = addf %red_iter, %ld : f32
@@ -515,13 +515,63 @@
   return
 }
 
+// CHECK-LABEL: @vec_non_vecdim_reduction
+// CHECK:       affine.for %{{.*}} = 0 to 256 step 128 {
+// CHECK:         %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32>
+// CHECK:         %[[final_red:.*]] = affine.for %{{.*}} = 0 to 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) {
+// CHECK:           %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<128x256xf32>, vector<128xf32>
+// CHECK:           %[[add:.*]] = addf %[[red_iter]], %[[ld]] : vector<128xf32>
+// CHECK:           affine.yield %[[add]] : vector<128xf32>
+// CHECK:         }
+// CHECK:         vector.transfer_write %[[final_red]], %{{.*}} : vector<128xf32>, memref<256xf32>
+// CHECK:       }
+
+// -----
+
+// '%i' loop is vectorized, including the inner reductions over '%j'.
+
+func @vec_non_vecdim_reductions(%in0: memref<128x256xf32>, %in1: memref<128x256xi32>,
+                                %out0: memref<256xf32>, %out1: memref<256xi32>) {
+  %zero = constant 0.000000e+00 : f32
+  %one = constant 1 : i32
+  affine.for %i = 0 to 256 {
+    %red0, %red1 = affine.for %j = 0 to 128
+      iter_args(%red_iter0 = %zero, %red_iter1 = %one) -> (f32, i32) {
+      %ld0 = affine.load %in0[%j, %i] : memref<128x256xf32>
+      %add = addf %red_iter0, %ld0 : f32
+      %ld1 = affine.load %in1[%j, %i] : memref<128x256xi32>
+      %mul = muli %red_iter1, %ld1 : i32
+      affine.yield %add, %mul : f32, i32
+    }
+    affine.store %red0, %out0[%i] : memref<256xf32>
+    affine.store %red1, %out1[%i] : memref<256xi32>
+  }
+  return
+}
+
+// CHECK-LABEL: @vec_non_vecdim_reductions
+// CHECK:       affine.for %{{.*}} = 0 to 256 step 128 {
+// CHECK:         %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32>
+// CHECK:         %[[vone:.*]] = constant dense<1> : vector<128xi32>
+// CHECK:         %[[reds:.*]]:2 = affine.for %{{.*}} = 0 to 128
+// CHECK-SAME:      iter_args(%[[red_iter0:.*]] = %[[vzero]], %[[red_iter1:.*]] = %[[vone]]) -> (vector<128xf32>, vector<128xi32>) {
+// CHECK:           %[[ld0:.*]] = vector.transfer_read %{{.*}} : memref<128x256xf32>, vector<128xf32>
+// CHECK:           %[[add:.*]] = addf %[[red_iter0]], %[[ld0]] : vector<128xf32>
+// CHECK:           %[[ld1:.*]] = vector.transfer_read %{{.*}} : memref<128x256xi32>, vector<128xi32>
+// CHECK:           %[[mul:.*]] = muli %[[red_iter1]], %[[ld1]] : vector<128xi32>
+// CHECK:           affine.yield %[[add]], %[[mul]] : vector<128xf32>, vector<128xi32>
+// CHECK:         }
+// CHECK:         vector.transfer_write %[[reds]]#0, %{{.*}} : vector<128xf32>, memref<256xf32>
+// CHECK:         vector.transfer_write %[[reds]]#1, %{{.*}} : vector<128xi32>, memref<256xi32>
+// CHECK:       }
+
 // -----
 
-// CHECK-LABEL: @vec_rejected_unsupported_last_value
-func @vec_rejected_unsupported_last_value(%in: memref<128x256xf32>, %out: memref<256xf32>) {
+// '%i' loop is vectorized, including the inner last value computation over '%j'.
+
+func @vec_no_vecdim_last_value(%in: memref<128x256xf32>, %out: memref<256xf32>) {
   %cst = constant 0.000000e+00 : f32
   affine.for %i = 0 to 256 {
-    // CHECK-NOT: vector
     %last_val = affine.for %j = 0 to 128 iter_args(%last_iter = %cst) -> (f32) {
       %ld = affine.load %in[%j, %i] : memref<128x256xf32>
       affine.yield %ld : f32
@@ -530,3 +580,13 @@
   }
   return
 }
+
+// CHECK-LABEL: @vec_no_vecdim_last_value
+// CHECK:       affine.for %{{.*}} = 0 to 256 step 128 {
+// CHECK:         %[[vzero:.*]] = constant dense<0.000000e+00> : vector<128xf32>
+// CHECK:         %[[last_val:.*]] = affine.for %{{.*}} = 0 to 128 iter_args(%[[last_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) {
+// CHECK:           %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<128x256xf32>, vector<128xf32>
+// CHECK:           affine.yield %[[ld]] : vector<128xf32>
+// CHECK:         }
+// CHECK:         vector.transfer_write %[[last_val]], %{{.*}} : vector<128xf32>, memref<256xf32>
+// CHECK:       }
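
Note: a minimal end-to-end sketch of the new behavior, for readers who want to
reproduce it by hand. This assumes the pass is driven the way the RUN line of
vectorize_1d.mlir drives it, i.e. -affine-super-vectorize with a virtual vector
size of 128 along the fastest-varying dimension; the exact option spelling may
differ across MLIR revisions, and the input function '@column_sums' below is a
hypothetical example, not part of the patch:

  $ mlir-opt column_sums.mlir \
      -affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0"

  // column_sums.mlir: the 'iter_args' reduction runs along '%j', which is
  // not the vectorized dimension ('%i' is), so it is now accepted.
  func @column_sums(%in: memref<128x256xf32>, %out: memref<256xf32>) {
    %cst = constant 0.000000e+00 : f32
    affine.for %i = 0 to 256 {
      %sum = affine.for %j = 0 to 128 iter_args(%acc = %cst) -> (f32) {
        %ld = affine.load %in[%j, %i] : memref<128x256xf32>
        %add = addf %acc, %ld : f32
        affine.yield %add : f32
      }
      affine.store %sum, %out[%i] : memref<256xf32>
    }
    return
  }

With the patch applied, '%i' is rewritten to step 128 and the '%j' loop carries
a vector<128xf32> accumulator through its 'iter_args', as the CHECK lines of
@vec_non_vecdim_reduction above verify. Reductions along a vectorized dimension
itself are still rejected by the isLoopVecDim bail-out in vectorizeAffineForOp,
since they would additionally require a horizontal reduction or a last-value
extraction after the loop.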