diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Math/Transforms/Approximation.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
@@ -92,6 +93,101 @@
              : value;
 }
 
+//----------------------------------------------------------------------------//
+// Helper function to handle n-D vectors with 1-D operations.
+//----------------------------------------------------------------------------//
+
+// Expands and unrolls n-D vector operands into multiple fixed-size 1-D vectors
+// and calls the compute function with 1-D vector operands. Stitches back all
+// results into the original n-D vector result.
+//
+// Examples: vectorWidth = 8
+//   - vector<4x8xf32> unrolled 4 times
+//   - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
+//   - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
+//
+// Some math approximations rely on ISA-specific operations that only accept
+// fixed-size 1-D vectors (e.g. AVX expects vectors of width 8).
+//
+// It is the caller's responsibility to verify that the inner dimension is
+// divisible by the vectorWidth, and that all operands have the same vector
+// shape.
+static Value
+handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
+                              ValueRange operands, int64_t vectorWidth,
+                              std::function<Value(ValueRange)> compute) {
+  assert(!operands.empty() && "operands must not be empty");
+  assert(vectorWidth > 0 && "vector width must be larger than 0");
+
+  VectorType inputType = operands[0].getType().cast<VectorType>();
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+
+  // If the input shape matches the target vector width, we can just call the
+  // user-provided compute function with the operands.
+  if (inputShape == llvm::makeArrayRef(vectorWidth))
+    return compute(operands);
+
+  // Check if the inner dimension has to be expanded, or we can directly
+  // iterate over the outer dimensions of the vector.
+  int64_t innerDim = inputShape.back();
+  int64_t expansionDim = innerDim / vectorWidth;
+  assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");
+
+  // Maybe expand operands to the higher rank vector shape that we'll use to
+  // iterate over and extract one dimensional vectors.
+  SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end());
+  SmallVector<Value> expandedOperands(operands);
+
+  if (expansionDim > 1) {
+    // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
+    expandedShape.insert(expandedShape.end() - 1, expansionDim);
+    expandedShape.back() = vectorWidth;
+
+    for (unsigned i = 0; i < operands.size(); ++i) {
+      auto operand = operands[i];
+      auto eltType = operand.getType().cast<VectorType>().getElementType();
+      auto expandedType = VectorType::get(expandedShape, eltType);
+      expandedOperands[i] =
+          builder.create<vector::ShapeCastOp>(expandedType, operand);
+    }
+  }
+
+  // Iterate over all outer dimensions of the expanded vector shape.
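+  // For example, expanding vector<4x16xf32> gives expandedShape = [4, 2, 8],
+  // so the iteration space below is [4, 2]: maxLinearIndex = 8, strides are
+  // [2, 1], and e.g. linear index 5 delinearizes to offsets [2, 1], i.e. the
+  // 1-D slice at position [2][1] of the expanded operands.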
+  auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
+  int64_t maxLinearIndex = computeMaxLinearIndex(iterationDims);
+
+  SmallVector<int64_t> ones(iterationDims.size(), 1);
+  auto strides = computeStrides(iterationDims, ones);
+
+  // Compute results for each one dimensional vector.
+  SmallVector<Value> results(maxLinearIndex);
+
+  for (int64_t i = 0; i < maxLinearIndex; ++i) {
+    auto offsets = delinearize(strides, i);
+
+    SmallVector<Value> extracted(expandedOperands.size());
+    for (auto tuple : llvm::enumerate(expandedOperands))
+      extracted[tuple.index()] =
+          builder.create<vector::ExtractOp>(tuple.value(), offsets);
+
+    results[i] = compute(extracted);
+  }
+
+  // Stitch results together into one large vector.
+  Type resultEltType = results[0].getType().cast<VectorType>().getElementType();
+  Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
+  Value result = builder.create<arith::ConstantOp>(
+      resultExpandedType, builder.getZeroAttr(resultExpandedType));
+
+  for (int64_t i = 0; i < maxLinearIndex; ++i)
+    result = builder.create<vector::InsertOp>(results[i], result,
+                                              delinearize(strides, i));
+
+  // Reshape back to the original vector shape.
+  return builder.create<vector::ShapeCastOp>(
+      VectorType::get(inputShape, resultEltType), result);
+}
+
 //----------------------------------------------------------------------------//
 // Helper functions to create constants.
 //----------------------------------------------------------------------------//
@@ -943,7 +1039,7 @@
                                     PatternRewriter &rewriter) const {
   auto shape = vectorShape(op.operand().getType(), isF32);
   // Only support already-vectorized rsqrt's.
-  if (!shape.hasValue() || (*shape)[0] != 8)
+  if (!shape.hasValue() || shape->back() % 8 != 0)
     return rewriter.notifyMatchFailure(op, "unsupported operand type");
 
   ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
@@ -967,7 +1063,10 @@
   Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
 
   // Compute an approximate result.
-  Value yApprox = builder.create<x86vector::RsqrtOp>(op.operand());
+  Value yApprox = handleMultidimensionalVectors(
+      builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
+        return builder.create<x86vector::RsqrtOp>(operands);
+      });
 
   // Do a single step of Newton-Raphson iteration to improve the approximation.
   // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -402,9 +402,9 @@
   return %0 : f32
 }
 
-// CHECK-LABEL: func @rsqrt_vector
+// CHECK-LABEL: func @rsqrt_vector_8xf32
 // CHECK: math.rsqrt
-// AVX2-LABEL: func @rsqrt_vector(
+// AVX2-LABEL: func @rsqrt_vector_8xf32(
 // AVX2-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
 // AVX2: %[[VAL_1:.*]] = arith.constant dense<0x7F800000> : vector<8xf32>
 // AVX2: %[[VAL_2:.*]] = arith.constant dense<1.500000e+00> : vector<8xf32>
@@ -421,7 +421,89 @@
 // AVX2: %[[VAL_13:.*]] = select %[[VAL_8]], %[[VAL_9]], %[[VAL_12]] : vector<8xi1>, vector<8xf32>
 // AVX2: return %[[VAL_13]] : vector<8xf32>
 // AVX2: }
-func @rsqrt_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
+func @rsqrt_vector_8xf32(%arg0: vector<8xf32>) -> vector<8xf32> {
   %0 = math.rsqrt %arg0 : vector<8xf32>
   return %0 : vector<8xf32>
 }
+
+// Virtual vector width is not a multiple of the AVX2 vector width.
+// +// CHECK-LABEL: func @rsqrt_vector_5xf32 +// CHECK: math.rsqrt +// AVX2-LABEL: func @rsqrt_vector_5xf32 +// AVX2: math.rsqrt +func @rsqrt_vector_5xf32(%arg0: vector<5xf32>) -> vector<5xf32> { + %0 = math.rsqrt %arg0 : vector<5xf32> + return %0 : vector<5xf32> +} + +// One dimensional virtual vector expanded and unrolled into multiple AVX2-sized +// vectors. +// +// CHECK-LABEL: func @rsqrt_vector_16xf32 +// CHECK: math.rsqrt +// AVX2-LABEL: func @rsqrt_vector_16xf32( +// AVX2-SAME: %[[ARG:.*]]: vector<16xf32> +// AVX2-SAME: ) -> vector<16xf32> +// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32> +// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<16xf32> to vector<2x8xf32> +// AVX2: %[[VEC0:.*]] = vector.extract %[[EXPAND]][0] +// AVX2: %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]] +// AVX2: %[[VEC1:.*]] = vector.extract %[[EXPAND]][1] +// AVX2: %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]] +// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0] +// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1] +// AVX2: %[[RSQRT:.*]] = vector.shape_cast %[[RESULT1]] : vector<2x8xf32> to vector<16xf32> +func @rsqrt_vector_16xf32(%arg0: vector<16xf32>) -> vector<16xf32> { + %0 = math.rsqrt %arg0 : vector<16xf32> + return %0 : vector<16xf32> +} + +// Two dimensional virtual vector unrolled into multiple AVX2-sized vectors. +// +// CHECK-LABEL: func @rsqrt_vector_2x8xf32 +// CHECK: math.rsqrt +// AVX2-LABEL: func @rsqrt_vector_2x8xf32( +// AVX2-SAME: %[[ARG:.*]]: vector<2x8xf32> +// AVX2-SAME: ) -> vector<2x8xf32> +// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x8xf32> +// AVX2-NOT: vector.shape_cast +// AVX2: %[[VEC0:.*]] = vector.extract %[[ARG]][0] +// AVX2: %[[RSQRT0:.*]] = x86vector.avx.rsqrt %[[VEC0]] +// AVX2: %[[VEC1:.*]] = vector.extract %[[ARG]][1] +// AVX2: %[[RSQRT1:.*]] = x86vector.avx.rsqrt %[[VEC1]] +// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT0]], %[[INIT]] [0] +// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT1]], %[[RESULT0]] [1] +// AVX2-NOT: vector.shape_cast +func @rsqrt_vector_2x8xf32(%arg0: vector<2x8xf32>) -> vector<2x8xf32> { + %0 = math.rsqrt %arg0 : vector<2x8xf32> + return %0 : vector<2x8xf32> +} + +// Two dimensional virtual vector expanded and unrolled into multiple AVX2-sized +// vectors. 
+// +// CHECK-LABEL: func @rsqrt_vector_2x16xf32 +// CHECK: math.rsqrt +// AVX2-LABEL: func @rsqrt_vector_2x16xf32( +// AVX2-SAME: %[[ARG:.*]]: vector<2x16xf32> +// AVX2-SAME: ) -> vector<2x16xf32> +// AVX2: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x8xf32> +// AVX2: %[[EXPAND:.*]] = vector.shape_cast %[[ARG]] : vector<2x16xf32> to vector<2x2x8xf32> +// AVX2: %[[VEC00:.*]] = vector.extract %[[EXPAND]][0, 0] +// AVX2: %[[RSQRT00:.*]] = x86vector.avx.rsqrt %[[VEC00]] +// AVX2: %[[VEC01:.*]] = vector.extract %[[EXPAND]][0, 1] +// AVX2: %[[RSQRT01:.*]] = x86vector.avx.rsqrt %[[VEC01]] +// AVX2: %[[VEC10:.*]] = vector.extract %[[EXPAND]][1, 0] +// AVX2: %[[RSQRT10:.*]] = x86vector.avx.rsqrt %[[VEC10]] +// AVX2: %[[VEC11:.*]] = vector.extract %[[EXPAND]][1, 1] +// AVX2: %[[RSQRT11:.*]] = x86vector.avx.rsqrt %[[VEC11]] +// AVX2: %[[RESULT0:.*]] = vector.insert %[[RSQRT00]], %[[INIT]] [0, 0] +// AVX2: %[[RESULT1:.*]] = vector.insert %[[RSQRT01]], %[[RESULT0]] [0, 1] +// AVX2: %[[RESULT2:.*]] = vector.insert %[[RSQRT10]], %[[RESULT1]] [1, 0] +// AVX2: %[[RESULT3:.*]] = vector.insert %[[RSQRT11]], %[[RESULT2]] [1, 1] +// AVX2: %[[RSQRT:.*]] = vector.shape_cast %[[RESULT3]] : vector<2x2x8xf32> to vector<2x16xf32> +func @rsqrt_vector_2x16xf32(%arg0: vector<2x16xf32>) -> vector<2x16xf32> { + %0 = math.rsqrt %arg0 : vector<2x16xf32> + return %0 : vector<2x16xf32> +}
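
Reviewer note (not part of the patch): the unroll order produced by the helper is the row-major linearization over the outer dimensions, which is what the extract/insert offsets in the tests above reflect. Below is a minimal standalone C++ sketch of that index mapping, mirroring how computeStrides (with unit slice sizes) and delinearize from VectorUtils are used in handleMultidimensionalVectors; the function names and the example shape here are illustrative only, not MLIR APIs.

#include <cstdint>
#include <iostream>
#include <vector>

// Suffix-product strides over the iteration dims (slice sizes are all 1,
// matching the `ones` vector passed to computeStrides in the patch).
static std::vector<int64_t> makeStrides(const std::vector<int64_t> &dims) {
  std::vector<int64_t> strides(dims.size(), 1);
  for (int64_t r = static_cast<int64_t>(dims.size()) - 2; r >= 0; --r)
    strides[r] = strides[r + 1] * dims[r + 1];
  return strides;
}

// Maps a linear index back to per-dimension offsets, like delinearize().
static std::vector<int64_t> toOffsets(const std::vector<int64_t> &strides,
                                      int64_t index) {
  std::vector<int64_t> offsets(strides.size());
  for (size_t r = 0; r < strides.size(); ++r) {
    offsets[r] = index / strides[r];
    index %= strides[r];
  }
  return offsets;
}

int main() {
  // vector<2x16xf32> expands to vector<2x2x8xf32>; the iteration dims are the
  // outer [2, 2], so there are 4 one-dimensional slices to process.
  std::vector<int64_t> iterationDims = {2, 2};
  std::vector<int64_t> strides = makeStrides(iterationDims); // {2, 1}
  for (int64_t i = 0; i < 4; ++i) {
    std::vector<int64_t> offsets = toOffsets(strides, i);
    std::cout << i << " -> [" << offsets[0] << ", " << offsets[1] << "]\n";
  }
  // Prints 0 -> [0, 0], 1 -> [0, 1], 2 -> [1, 0], 3 -> [1, 1], matching the
  // vector.extract / vector.insert offsets checked in @rsqrt_vector_2x16xf32.
  return 0;
}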