diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -58,21 +58,28 @@ if (!vecType.hasRank()) return failure(); auto shape = vecType.getShape(); - // TODO: support multidimensional vectors - if (shape.size() != 1) - return failure(); + int64_t numElements = 1; + for (auto i = 0; i < shape.size(); ++i) + numElements *= shape[i]; Value result = rewriter.create( loc, DenseElementsAttr::get( vecType, FloatAttr::get(vecType.getElementType(), 0.0))); - for (auto i = 0; i < shape.front(); ++i) { + for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) { + SmallVector positions(shape.size()); + int64_t cur = linearIndex; + for (int i = shape.size() - 1; i >= 0; --i) { + positions[i] = cur % shape[i]; + cur /= shape[i]; + } SmallVector operands; for (auto input : op->getOperands()) operands.push_back( - rewriter.create(loc, input, i)); + rewriter.create(loc, input, positions)); Value scalarOp = rewriter.create(loc, vecType.getElementType(), operands); - result = rewriter.create(loc, scalarOp, result, i); + result = + rewriter.create(loc, scalarOp, result, positions); } rewriter.replaceOp(op, {result}); return success(); diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir --- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir @@ -68,20 +68,39 @@ // CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { // CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> // CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64> -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : i32 -// CHECK: %[[IN0_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C0]] : i32] : vector<2xf32> +// CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32> // CHECK: %[[OUT0_F32:.*]] = call @expm1f(%[[IN0_F32]]) : (f32) -> f32 -// CHECK: %[[VAL_8:.*]] = vector.insertelement %[[OUT0_F32]], %[[CVF]]{{\[}}%[[C0]] : i32] : vector<2xf32> -// CHECK: %[[IN1_F32:.*]] = vector.extractelement %[[VAL_0]]{{\[}}%[[C1]] : i32] : vector<2xf32> +// CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32> +// CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32> // CHECK: %[[OUT1_F32:.*]] = call @expm1f(%[[IN1_F32]]) : (f32) -> f32 -// CHECK: %[[VAL_11:.*]] = vector.insertelement %[[OUT1_F32]], %[[VAL_8]]{{\[}}%[[C1]] : i32] : vector<2xf32> -// CHECK: %[[IN0_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C0]] : i32] : vector<2xf64> +// CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32> +// CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64> // CHECK: %[[OUT0_F64:.*]] = call @expm1(%[[IN0_F64]]) : (f64) -> f64 -// CHECK: %[[VAL_14:.*]] = vector.insertelement %[[OUT0_F64]], %[[CVD]]{{\[}}%[[C0]] : i32] : vector<2xf64> -// CHECK: %[[IN1_F64:.*]] = vector.extractelement %[[VAL_1]]{{\[}}%[[C1]] : i32] : vector<2xf64> +// CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64> +// CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64> // CHECK: %[[OUT1_F64:.*]] = call @expm1(%[[IN1_F64]]) : (f64) -> f64 -// CHECK: %[[VAL_17:.*]] = vector.insertelement %[[OUT1_F64]], %[[VAL_14]]{{\[}}%[[C1]] : i32] : vector<2xf64> +// CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64> // CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64> // CHECK: } +func @expm1_multidim_vec_caller(%float: vector<2x2xf32>) -> (vector<2x2xf32>) { + %result = math.expm1 %float : vector<2x2xf32> + return %result : vector<2x2xf32> +} +// CHECK-LABEL: func @expm1_multidim_vec_caller( +// CHECK-SAME: %[[VAL:.*]]: vector<2x2xf32> +// CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> +// CHECK: %[[IN0_0_F32:.*]] = vector.extract %[[VAL]][0, 0] : vector<2x2xf32> +// CHECK: %[[OUT0_0_F32:.*]] = call @expm1f(%[[IN0_0_F32]]) : (f32) -> f32 +// CHECK: %[[VAL_1:.*]] = vector.insert %[[OUT0_0_F32]], %[[CVF]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[IN0_1_F32:.*]] = vector.extract %[[VAL]][0, 1] : vector<2x2xf32> +// CHECK: %[[OUT0_1_F32:.*]] = call @expm1f(%[[IN0_1_F32]]) : (f32) -> f32 +// CHECK: %[[VAL_2:.*]] = vector.insert %[[OUT0_1_F32]], %[[VAL_1]] [0, 1] : f32 into vector<2x2xf32> +// CHECK: %[[IN1_0_F32:.*]] = vector.extract %[[VAL]][1, 0] : vector<2x2xf32> +// CHECK: %[[OUT1_0_F32:.*]] = call @expm1f(%[[IN1_0_F32]]) : (f32) -> f32 +// CHECK: %[[VAL_3:.*]] = vector.insert %[[OUT1_0_F32]], %[[VAL_2]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[IN1_1_F32:.*]] = vector.extract %[[VAL]][1, 1] : vector<2x2xf32> +// CHECK: %[[OUT1_1_F32:.*]] = call @expm1f(%[[IN1_1_F32]]) : (f32) -> f32 +// CHECK: %[[VAL_4:.*]] = vector.insert %[[OUT1_1_F32]], %[[VAL_3]] [1, 1] : f32 into vector<2x2xf32> +// CHECK: return %[[VAL_4]] : vector<2x2xf32> +// CHECK: }