diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -328,6 +328,20 @@
     setInsertionPoint(op->getBlock(), ++Block::iterator(op));
   }
 
+  /// Sets the insertion point to the node after the specified value. If the
+  /// value has a defining operation, sets the insertion point right after
+  /// that defining operation, causing subsequent insertions to go right
+  /// after it. Otherwise, the value is a BlockArgument and the insertion
+  /// point is set to the start of its block.
+  void setInsertionPointAfter(Value val) {
+    if (Operation *op = val.getDefiningOp()) {
+      setInsertionPointAfter(op);
+    } else {
+      auto blockArg = val.cast<BlockArgument>();
+      setInsertionPointToStart(blockArg.getOwner());
+    }
+  }
+
   /// Sets the insertion point to the start of the specified block.
   void setInsertionPointToStart(Block *block) {
     setInsertionPoint(block, block->begin());
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -38,6 +38,7 @@
 #include "llvm/Support/Debug.h"
 
 using namespace mlir;
+using namespace vector;
 
 ///
 /// Implements a high-level vectorization strategy on a Function.
@@ -918,6 +919,42 @@
   return b.createOperation(state)->getResult(0);
 }
 
+/// Returns the vector type resulting from applying the provided
+/// vectorization strategy on the scalar type.
+static VectorType getVectorType(Type scalarTy,
+                                const VectorizationStrategy *strategy) {
+  assert(!scalarTy.isa<VectorType>() && "Expected scalar type");
+  return VectorType::get(strategy->vectorSizes, scalarTy);
+}
+
+/// Returns true if the provided value is vector uniform given the
+/// vectorization strategy.
+// TODO: For now, only values that are invariant to all the loops in the
+// vectorization strategy are considered vector uniforms.
+static bool isUniformDefinition(Value value,
+                                const VectorizationStrategy *strategy) {
+  for (auto loopToDim : strategy->loopToVectorDim) {
+    auto loop = cast<AffineForOp>(loopToDim.first);
+    if (!loop.isDefinedOutsideOfLoop(value))
+      return false;
+  }
+  return true;
+}
+
+/// Generates a broadcast op for the provided uniform value using the
+/// vectorization strategy in 'state'.
+static Value vectorizeUniform(Value value, VectorizationState *state) {
+  OpBuilder builder(value.getContext());
+  builder.setInsertionPointAfter(value);
+
+  auto vectorTy = getVectorType(value.getType(), state->strategy);
+  auto bcast = builder.create<BroadcastOp>(value.getLoc(), vectorTy, value);
+
+  // Add the broadcast to the replacement map to reuse it for other uses.
+  state->replacementMap[value] = bcast;
+  return bcast;
+}
+
 /// Tries to vectorize a given `operand` of Operation `op` during
 /// def-chain propagation or during terminal vectorization, by applying the
 /// following logic:
@@ -927,7 +964,8 @@
 ///    vectorize atm (i.e. broadcasting required), returns nullptr to indicate
 ///    failure;
 /// 3. if the `op` is a constant, returns the vectorized form of the constant;
-/// 4. non-constant scalars are currently non-vectorizable, in particular to
+/// 4. if the `op` is uniform, returns a vector broadcast of the `op`;
+/// 5. non-constant scalars are currently non-vectorizable, in particular to
 ///    guard against vectorizing an index which may be loop-variant and needs
 ///    special handling.
 ///
@@ -963,12 +1001,15 @@
     return nullptr;
   }
   // 3. vectorize constant.
-  if (auto constant = operand.getDefiningOp<ConstantOp>()) {
-    return vectorizeConstant(
-        op, constant,
-        VectorType::get(state->strategy->vectorSizes, operand.getType()));
-  }
-  // 4. currently non-vectorizable.
+  if (auto constant = operand.getDefiningOp<ConstantOp>())
+    return vectorizeConstant(op, constant,
+                             getVectorType(operand.getType(), state->strategy));
+
+  // 4. Uniform values.
+  if (isUniformDefinition(operand, state->strategy))
+    return vectorizeUniform(operand, state);
+
+  // 5. currently non-vectorizable.
   LLVM_DEBUG(dbgs() << "-> non-vectorizable: " << operand);
   return nullptr;
 }
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/uniform_divergent.mlir b/mlir/test/Dialect/Affine/SuperVectorize/uniform_divergent.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Affine/SuperVectorize/uniform_divergent.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -affine-super-vectorize="virtual-vector-size=128" -split-input-file | FileCheck %s
+
+// Specific tests to check vectorization of uniform/divergent values.
+
+// CHECK-LABEL: @uniform_arg
+// CHECK-SAME:  %[[in:.*]]: memref<512xf32>,
+// CHECK-SAME:  %[[uniform:.*]]: f32
+func @uniform_arg(%in : memref<512xf32>, %uniform : f32) {
+  affine.for %i = 0 to 512 {
+    %ld = affine.load %in[%i] : memref<512xf32>
+    %add = addf %ld, %uniform : f32
+  }
+  return
+}
+
+// CHECK-NEXT: %[[bcast:.*]] = vector.broadcast %[[uniform]] : f32 to vector<128xf32>
+// CHECK-NEXT: affine.for
+// CHECK:        addf %{{.*}}, %[[bcast]] : vector<128xf32>
+
+// -----
+
+// CHECK-LABEL: @multi_use_uniform_arg
+// CHECK-SAME:  %[[in:.*]]: memref<512xf32>
+// CHECK-SAME:  %[[uniform:.*]]: f32
+func @multi_use_uniform_arg(%in : memref<512xf32>, %uniform : f32) {
+  affine.for %i = 0 to 512 {
+    %ld = affine.load %in[%i] : memref<512xf32>
+    %user0 = addf %ld, %uniform : f32
+    %user1 = addf %ld, %uniform : f32
+  }
+  return
+}
+
+// CHECK-NEXT: %[[bcast:.*]] = vector.broadcast %[[uniform]] : f32 to vector<128xf32>
+// CHECK-NOT:  vector.broadcast
+// CHECK-NEXT: affine.for
+// CHECK:        addf %{{.*}}, %[[bcast]] : vector<128xf32>
+// CHECK:        addf %{{.*}}, %[[bcast]] : vector<128xf32>
+
+// -----
+
+// CHECK-LABEL: @uniform_load
+func @uniform_load(%A : memref<?x?xf32>, %C : memref<?x?xf32>) {
+  %c0 = constant 0 : index
+  %N = dim %A, %c0 : memref<?x?xf32>
+  affine.for %i = 0 to %N {
+    %uniform_ld = affine.load %A[%i, %i] : memref<?x?xf32>
+    affine.for %j = 0 to %N {
+      %b = affine.load %A[%i, %j] : memref<?x?xf32>
+      %c = addf %uniform_ld, %b : f32
+    }
+  }
+  return
+}
+
+// CHECK:      affine.for
+// CHECK-NEXT:   %[[uniform_ld:.*]] = affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
+// CHECK-NEXT:   %[[bcast:.*]] = vector.broadcast %[[uniform_ld]] : f32 to vector<128xf32>
+// CHECK-NEXT:   affine.for
+// CHECK:          addf %[[bcast]], %{{.*}} : vector<128xf32>
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
@@ -396,25 +396,6 @@
   return
 }
 
-// This should not vectorize and should not crash.
-// CHECK-LABEL: @vec_rejected_11
-func @vec_rejected_11(%A : memref<?x?xf32>, %C : memref<?x?xf32>) {
-  %c0 = constant 0 : index
-  %N = dim %A, %c0 : memref<?x?xf32>
-  affine.for %i = 0 to %N {
-// CHECK-NOT: vector
-    %a = affine.load %A[%i, %i] : memref<?x?xf32> // not vectorized
-    affine.for %j = 0 to %N {
-      %b = affine.load %A[%i, %j] : memref<?x?xf32> // may be vectorized
-// CHECK-NOT: vector
-      %c = addf %a, %b : f32 // not vectorized because %a wasn't
-// CHECK-NOT: vector
-      affine.store %c, %C[%i, %j] : memref<?x?xf32> // not vectorized because %c wasn't
-    }
-  }
-  return
-}
-
 // This should not vectorize due to the sequential dependence in the scf.
 // CHECK-LABEL: @vec_rejected_sequential
 func @vec_rejected_sequential(%A : memref<?xf32>) {
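
For reference, a minimal usage sketch of how the new OpBuilder::setInsertionPointAfter(Value) overload composes with vector.broadcast creation, in the spirit of vectorizeUniform() in the patch. The helper name broadcastAfter and its signature are illustrative assumptions, not part of the patch:

// Illustrative sketch only (not part of the patch): broadcast a scalar
// value to the given vector type, inserting the broadcast right after the
// value's definition. The helper name `broadcastAfter` is hypothetical.
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"

static mlir::Value broadcastAfter(mlir::Value scalar,
                                  mlir::VectorType vectorTy) {
  mlir::OpBuilder builder(scalar.getContext());
  // The new overload handles both cases: for an op result, the insertion
  // point goes right after the defining op; for a block argument, it goes
  // to the start of the argument's owning block.
  builder.setInsertionPointAfter(scalar);
  return builder.create<mlir::vector::BroadcastOp>(scalar.getLoc(), vectorTy,
                                                   scalar);
}

Placing the broadcast at the value's definition rather than at its use means a single broadcast dominates all uses, which is what lets the replacement map reuse it for multi-use uniform values (see @multi_use_uniform_arg above).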