diff --git a/llvm/docs/GettingInvolved.rst b/llvm/docs/GettingInvolved.rst --- a/llvm/docs/GettingInvolved.rst +++ b/llvm/docs/GettingInvolved.rst @@ -182,6 +182,7 @@ Proposals/TestSuite Proposals/VariableNames Proposals/VectorizationPlan + Proposals/VectorPredication :doc:`CodeOfConduct` Proposal to adopt a code of conduct on the LLVM social spaces (lists, events, @@ -203,4 +204,7 @@ Proposal to change the variable names coding standard. :doc:`Proposals/VectorizationPlan` - Proposal to model the process and upgrade the infrastructure of LLVM's Loop Vectorizer. \ No newline at end of file + Proposal to model the process and upgrade the infrastructure of LLVM's Loop Vectorizer. + +:doc:`Proposals/VectorPredication` + Proposal to support predicated vector instructions in LLVM. diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -7538,6 +7538,8 @@ <result> = fadd float 4.0, %var ; yields float:result = 4.0 + %var +.. _i_sub: + '``sub``' Instruction ^^^^^^^^^^^^^^^^^^^^^ @@ -7633,6 +7635,8 @@ <result> = fsub float 4.0, %var ; yields float:result = 4.0 - %var <result> = fsub float -0.0, %val ; yields float:result = -%var +.. _i_mul: + '``mul``' Instruction ^^^^^^^^^^^^^^^^^^^^^ @@ -7727,6 +7731,8 @@ <result> = fmul float 4.0, %var ; yields float:result = 4.0 * %var +.. _i_udiv: + '``udiv``' Instruction ^^^^^^^^^^^^^^^^^^^^^^ @@ -7773,6 +7779,8 @@ <result> = udiv i32 4, %var ; yields i32:result = 4 / %var +.. _i_sdiv: + '``sdiv``' Instruction ^^^^^^^^^^^^^^^^^^^^^^ @@ -7861,6 +7869,8 @@ <result> = fdiv float 4.0, %var ; yields float:result = 4.0 / %var +.. _i_urem: + '``urem``' Instruction ^^^^^^^^^^^^^^^^^^^^^^ @@ -7905,6 +7915,8 @@ <result> = urem i32 4, %var ; yields i32:result = 4 % %var +.. _i_srem: + '``srem``' Instruction ^^^^^^^^^^^^^^^^^^^^^^ @@ -8018,6 +8030,8 @@ operands of the same type, execute an operation on them, and produce a single value. The resulting value is the same type as its operands. +.. 
_i_shl: + '``shl``' Instruction ^^^^^^^^^^^^^^^^^^^^^ @@ -8070,6 +8084,9 @@ <result> = shl i32 1, 32 ; undefined <result> = shl <2 x i32> < i32 1, i32 1>, < i32 1, i32 2> ; yields: result=<2 x i32> < i32 2, i32 4> +.. _i_lshr: + + '``lshr``' Instruction ^^^^^^^^^^^^^^^^^^^^^^ @@ -8119,6 +8136,8 @@ <result> = lshr i32 1, 32 ; undefined <result> = lshr <2 x i32> < i32 -2, i32 4>, < i32 1, i32 2> ; yields: result=<2 x i32> < i32 0x7FFFFFFF, i32 1> +.. _i_ashr: + '``ashr``' Instruction ^^^^^^^^^^^^^^^^^^^^^^ @@ -8169,6 +8188,8 @@ <result> = ashr i32 1, 32 ; undefined <result> = ashr <2 x i32> < i32 -2, i32 4>, < i32 1, i32 3> ; yields: result=<2 x i32> < i32 -1, i32 0> +.. _i_and: + '``and``' Instruction ^^^^^^^^^^^^^^^^^^^^^ @@ -8218,6 +8239,8 @@ <result> = and i32 15, 40 ; yields i32:result = 8 <result> = and i32 4, 8 ; yields i32:result = 0 +.. _i_or: + '``or``' Instruction ^^^^^^^^^^^^^^^^^^^^ @@ -8267,6 +8290,8 @@ <result> = or i32 15, 40 ; yields i32:result = 47 <result> = or i32 4, 8 ; yields i32:result = 12 +.. _i_xor: + '``xor``' Instruction ^^^^^^^^^^^^^^^^^^^^^ @@ -14626,6 +14651,678 @@ after performing the required machine specific adjustments. The pointer returned can then be :ref:`bitcast and executed <int_trampoline>`. + +.. _int_vp: + +Vector Predication Intrinsics +----------------------------- +VP intrinsics are intended for predicated SIMD/vector code. +A typical VP operation takes a vector mask and an explicit vector length parameter as in: + +:: + + llvm.vp.<opcode>.*(<vector_type> %x, <vector_type> %y, <mask_type> %mask, i32 %evl) + +The vector mask parameter is always a vector of ``i1``, for example ``<32 x i1>``. +The explicit vector length parameter always has the type ``i32``. +The explicit vector length is only effective if the MSB of its value is zero. +Results are only computed for enabled lanes. +A lane is enabled if the mask at that position is true and, if the explicit vector length is effective, the lane position is below the explicit vector length. + + +.. _int_vp_add: + +'``llvm.vp.add.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic.
+ +:: + + declare <16 x i32> @llvm.vp.add.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.add.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Predicated integer addition of two vectors of integers. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" + +The '``llvm.vp.add``' intrinsic performs integer addition (:ref:`add <i_add>`) of the first and second vector operands on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = add <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + +.. _int_vp_sub: + +'``llvm.vp.sub.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.sub.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.sub.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Predicated integer subtraction of two vectors of integers. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" + +The '``llvm.vp.sub``' intrinsic performs integer subtraction (:ref:`sub <i_sub>`) of the first and second vector operands on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. 
code-block:: llvm + + %r = call <4 x i32> @llvm.vp.sub.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = sub <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + + +.. _int_vp_mul: + +'``llvm.vp.mul.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.mul.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.mul.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Predicated integer multiplication of two vectors of integers. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" +The '``llvm.vp.mul``' intrinsic performs integer multiplication (:ref:`mul <i_mul>`) of the first and second vector operands on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.mul.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = mul <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_sdiv: + +'``llvm.vp.sdiv.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.sdiv.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.sdiv.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Predicated signed division of two vectors of integers.
+ + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" + +The '``llvm.vp.sdiv``' intrinsic performs signed division (:ref:`sdiv <i_sdiv>`) of the first and second vector operands on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.sdiv.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = sdiv <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_udiv: + +'``llvm.vp.udiv.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.udiv.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.udiv.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Predicated unsigned division of two vectors of integers. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" + +The '``llvm.vp.udiv``' intrinsic performs unsigned division (:ref:`udiv <i_udiv>`) of the first and second vector operands on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. 
code-block:: llvm + + %r = call <4 x i32> @llvm.vp.udiv.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = udiv <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + + +.. _int_vp_srem: + +'``llvm.vp.srem.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.srem.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.srem.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Predicated computation of the signed remainder of two integer vectors. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" + +The '``llvm.vp.srem``' intrinsic computes the remainder of the signed division (:ref:`srem <i_srem>`) of the first and second vector operands on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.srem.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = srem <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + + +.. _int_vp_urem: + +'``llvm.vp.urem.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.urem.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.urem.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Predicated computation of the unsigned remainder of two integer vectors.
+ + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" + +The '``llvm.vp.urem``' intrinsic computes the remainder of the unsigned division (:ref:`urem <i_urem>`) of the first and second vector operands on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.urem.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = urem <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_ashr: + +'``llvm.vp.ashr.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.ashr.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.ashr.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Vector-predicated arithmetic right-shift. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" + +The '``llvm.vp.ashr``' intrinsic computes the arithmetic right shift (:ref:`ashr <i_ashr>`) of the first operand by the second operand on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. 
code-block:: llvm + + %r = call <4 x i32> @llvm.vp.ashr.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = ashr <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_lshr: + + +'``llvm.vp.lshr.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.lshr.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.lshr.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Vector-predicated logical right-shift. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" + +The '``llvm.vp.lshr``' intrinsic computes the logical right shift (:ref:`lshr <i_lshr>`) of the first operand by the second operand on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.lshr.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = lshr <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_shl: + +'``llvm.vp.shl.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.shl.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.shl.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Vector-predicated left shift. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type.
The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" + +The '``llvm.vp.shl``' intrinsic computes the left shift (:ref:`shl <i_shl>`) of the first operand by the second operand on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.shl.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = shl <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_or: + +'``llvm.vp.or.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.or.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.or.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Vector-predicated bitwise or. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" + +The '``llvm.vp.or``' intrinsic performs a bitwise or (:ref:`or <i_or>`) of the first two operands on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.or.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = or <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_and: + +'``llvm.vp.and.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic.
+ +:: + + declare <16 x i32> @llvm.vp.and.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.and.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Vector-predicated bitwise and. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" + +The '``llvm.vp.and``' intrinsic performs a bitwise and (:ref:`and <i_and>`) of the first two operands on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.and.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = and <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + +.. _int_vp_xor: + +'``llvm.vp.xor.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.xor.v16i32 (<16 x i32> <left_op>, <16 x i32> <right_op>, <16 x i1> <mask>, i32 <vector_length>) + declare <256 x i64> @llvm.vp.xor.v256i64 (<256 x i64> <left_op>, <256 x i64> <right_op>, <256 x i1> <mask>, i32 <vector_length>) + +Overview: +""""""""" + +Vector-predicated bitwise xor. + + +Arguments: +"""""""""" + +The first two operands and the result have the same vector of integer type. The third operand is the vector mask and has the same number of elements as the result vector type. The fourth operand is the explicit vector length of the operation. + +Semantics: +"""""""""" + +The '``llvm.vp.xor``' intrinsic performs a bitwise xor (:ref:`xor <i_xor>`) of the first two operands on each enabled lane. +The result on disabled lanes is undefined. + +Examples: +""""""""" + +.. 
code-block:: llvm + + %r = call <4 x i32> @llvm.vp.xor.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i1> %mask, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %t = xor <4 x i32> %a, %b + %also.r = select <4 x i1> %mask, <4 x i32> %t, <4 x i32> undef + + + +.. _int_vp_compose: + +'``llvm.vp.compose.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x float> @llvm.vp.compose.v16f32 (<16 x float> <vec1>, <16 x float> <vec2>, i32 <pivot>, i32 <vector_length>) + +Overview: +""""""""" + +The compose intrinsic blends two input vectors based on a pivot value. + + +Arguments: +"""""""""" + +The first operand is the vector whose elements are selected below the pivot. The second operand is the vector whose values are selected starting from the pivot position. The third operand is the pivot value. The fourth operand is the explicit vector length of the operation. + + +Semantics: +"""""""""" + +The '``llvm.vp.compose``' intrinsic is designed for conditional blending of two vectors based on a pivot number. All lanes below the pivot are taken from the first operand; all lanes at positions greater than or equal to the pivot are taken from the second operand. It is useful for targets that support an explicit vector length and guarantee that vector instructions preserve the contents of vector registers at and above the explicit vector length of the operation. Other targets may support this intrinsic differently, for example by lowering it into a select with a bitmask that represents the pivot comparison. +The result of this operation is equivalent to a select with an equivalent predicate mask based on the pivot operand. However, as for all VP intrinsics, all lanes at positions greater than or equal to the explicit vector length are undefined. + + +Examples: +""""""""" + +.. 
code-block:: llvm + + %r = call <4 x i32> @llvm.vp.compose.v4i32(<4 x i32> %a, <4 x i32> %b, i32 %pivot, i32 %evl) + ;; lanes of %r at positions >= %evl are undef + + ;; all other lanes of %r are equal to the corresponding lanes of %also.r + %tmp = insertelement <4 x i32> undef, i32 %pivot, i32 0 + %pivot.splat = shufflevector <4 x i32> %tmp, <4 x i32> undef, <4 x i32> zeroinitializer + %pivot.mask = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, %pivot.splat + %also.r = select <4 x i1> %pivot.mask, <4 x i32> %a, <4 x i32> %b + + + +.. _int_vp_select: + +'``llvm.vp.select.*``' Intrinsics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Syntax: +""""""" +This is an overloaded intrinsic. + +:: + + declare <16 x i32> @llvm.vp.select.v16i32 (<16 x i1> <mask>, <16 x i32> <on_true>, <16 x i32> <on_false>, i32 <vector_length>) + declare <256 x double> @llvm.vp.select.v256f64 (<256 x i1> <mask>, <256 x double> <on_true>, <256 x double> <on_false>, i32 <vector_length>) + +Overview: +""""""""" + +Conditional select with an explicit vector length. + + +Arguments: +"""""""""" + +The first three operands and the result are vectors of the same length. The first operand is the vector mask; the second and third operands and the result have the same vector type. The fourth operand is the explicit vector length. + +Semantics: +"""""""""" + +The '``llvm.vp.select``' intrinsic performs a conditional select (:ref:`select <i_select>`) of the second and third vector operands, using the first operand as the condition. +If the explicit vector length (the fourth operand) is effective, the result is undefined on lanes at positions greater than or equal to the explicit vector length. + +Examples: +""""""""" + +.. code-block:: llvm + + %r = call <4 x i32> @llvm.vp.select.v4i32(<4 x i1> %mask, <4 x i32> %onTrue, <4 x i32> %onFalse, i32 %avl) + ;; For all lanes below %avl, %r is lane-wise equivalent to %also.r + + %also.r = select <4 x i1> %mask, <4 x i32> %onTrue, <4 x i32> %onFalse + + + .. 
_int_mload_mstore: Masked Vector Load and Store Intrinsics diff --git a/llvm/docs/Proposals/VectorPredication.rst b/llvm/docs/Proposals/VectorPredication.rst new file mode 100644 --- /dev/null +++ b/llvm/docs/Proposals/VectorPredication.rst @@ -0,0 +1,83 @@ +========================== +Vector Predication Roadmap +========================== + +.. contents:: Table of Contents + :depth: 3 + :local: + +Motivation +========== + +This proposal defines a roadmap towards native vector predication in LLVM, specifically for vector instructions with a mask and/or an explicit vector length. +LLVM currently has no target-independent means to model predicated vector instructions for modern SIMD ISAs such as AVX512, ARM SVE, the RISC-V V extension and NEC SX-Aurora. +Only some predicated vector operations, such as masked loads and stores, are available through intrinsics [MaskedIR]_. + +The Vector Predication extension +================================ + +The Vector Predication (VP) extension [EvlRFC]_ can be a first step towards native vector predication. +The VP prototype accompanying this proposal demonstrates the following concepts: + +- Predicated vector intrinsics with an explicit mask and vector length parameter on IR level. +- First-class predicated SDNodes on ISel level. Mask and vector length are value operands. +- An incremental strategy to generalize PatternMatch/InstCombine/InstSimplify and DAGCombiner to work on both regular instructions and VP intrinsics. +- DAGCombiner example: FMA fusion. +- InstCombine/InstSimplify example: FSub pattern re-writes. +- Early experiments on the LNT test suite (Clang static release, O3 -ffast-math) indicate that compile time on non-VP IR is not affected by the API abstractions in PatternMatch, etc. + +Roadmap +======= + +Drawing from the VP prototype, we propose the following roadmap towards native vector predication in LLVM: + + +1. 
IR-level VP intrinsics +------------------------- + +- There is a consensus on the semantics/instruction set of VP. +- VP intrinsics and attributes are available on IR level. +- TTI has capability flags for VP (``supportsVP()``?, ``haveActiveVectorLength()``?). + +Result: VP usable for IR-level vectorizers (LV, VPlan, RegionVectorizer), potential integration in Clang with builtins. + +2. CodeGen support +------------------ + +- VP intrinsics translate to first-class SDNodes (``llvm.vp.fdiv.* -> ISD::VP_FDIV``). +- VP legalization (legalize explicit vector length to mask (AVX512), legalize VP SDNodes to pre-existing ones (SSE, NEON)). + +Result: Backend development based on VP SDNodes. + +3. Lift InstSimplify/InstCombine/DAGCombiner to VP +-------------------------------------------------- + +- Introduce PredicatedInstruction, PredicatedBinaryOperator, ... helper classes that match standard vector IR and VP intrinsics. +- Add a matcher context to PatternMatch and context-aware IR Builder APIs. +- Incrementally lift DAGCombiner to work on VP SDNodes as well as on regular vector instructions. +- Incrementally lift InstCombine/InstSimplify to operate on VP as well as regular IR instructions. + +Result: Optimization of VP intrinsics on par with standard vector instructions. + +4. Deprecate llvm.masked.* / llvm.experimental.reduce.* +------------------------------------------------------- + +- Modernize llvm.masked.* / llvm.experimental.reduce.* by translating to VP. +- DCE transitional APIs. + +Result: VP has superseded earlier vector intrinsics. + +5. Predicated IR Instructions +----------------------------- + +- Vector instructions have an optional mask and vector length parameter. These lower to VP SDNodes (from Stage 2). +- Phase out VP intrinsics, only keeping those that are not equivalent to vectorized scalar instructions (reductions, shuffles, ...). +- InstCombine/InstSimplify expect predication in regular Instructions (Stage 3 has laid the groundwork).
+ +Result: Native vector predication in IR. + +References +========== + +.. [MaskedIR] `llvm.masked.*` intrinsics, https://llvm.org/docs/LangRef.html#masked-vector-load-and-store-intrinsics +.. [EvlRFC] Explicit Vector Length RFC, https://reviews.llvm.org/D53613 diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -53,6 +53,10 @@ class Value; class MDNode; class BinaryOperator; +class VPIntrinsic; +namespace PatternMatch { + struct PredicatedContext; +} /// InstrInfoQuery provides an interface to query additional information for /// instructions like metadata or keywords like nsw, which provides conservative @@ -138,6 +142,13 @@ Value *SimplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); +/// Given operands for a predicated FSub, fold the result or return null. +/// Same contract as SimplifyFSubInst, but operands are matched under the +/// PatternMatch predication context \p PC. +Value *SimplifyPredicatedFSubInst(Value *LHS, Value *RHS, + FastMathFlags FMF, const SimplifyQuery &Q, + PatternMatch::PredicatedContext &PC); + /// Given operands for an FMul, fold the result or return null. Value *SimplifyFMulInst(Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); @@ -259,6 +270,9 @@ /// Given a callsite, fold the result or return null. Value *SimplifyCall(CallBase *Call, const SimplifyQuery &Q); +/// Given a VP intrinsic function, fold the result or return null. +Value *SimplifyVPIntrinsic(VPIntrinsic &VPInst, const SimplifyQuery &Q); + /// See if we can compute a simplified version of this instruction. If not, /// return null.
Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q, diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -45,6 +45,7 @@ class Function; class GlobalValue; class IntrinsicInst; +class PredicatedInstruction; class LoadInst; class Loop; class ProfileSummaryInfo; @@ -1130,6 +1131,12 @@ bool useReductionIntrinsic(unsigned Opcode, Type *Ty, ReductionFlags Flags) const; + /// \returns True if the vector length parameter should be folded into the vector mask. + bool shouldFoldVectorLengthIntoMask(const PredicatedInstruction &PredInst) const; + + /// \returns False if this VP op should be replaced by a non-VP op or an unpredicated op plus a select. + bool supportsVPOperation(const PredicatedInstruction &PredInst) const; + /// \returns True if the target wants to expand the given reduction intrinsic /// into a shuffle sequence.
bool shouldExpandReduction(const IntrinsicInst *II) const; @@ -1375,6 +1382,8 @@ virtual unsigned getStoreVectorFactor(unsigned VF, unsigned StoreSize, unsigned ChainSizeInBytes, VectorType *VecTy) const = 0; + virtual bool shouldFoldVectorLengthIntoMask(const PredicatedInstruction &PredInst) const = 0; + virtual bool supportsVPOperation(const PredicatedInstruction &PredInst) const = 0; virtual bool useReductionIntrinsic(unsigned Opcode, Type *Ty, ReductionFlags) const = 0; virtual bool shouldExpandReduction(const IntrinsicInst *II) const = 0; @@ -1843,6 +1852,12 @@ VectorType *VecTy) const override { return Impl.getStoreVectorFactor(VF, StoreSize, ChainSizeInBytes, VecTy); } + bool shouldFoldVectorLengthIntoMask(const PredicatedInstruction &PredInst) const override { + return Impl.shouldFoldVectorLengthIntoMask(PredInst); + } + bool supportsVPOperation(const PredicatedInstruction &PredInst) const override { + return Impl.supportsVPOperation(PredInst); + } bool useReductionIntrinsic(unsigned Opcode, Type *Ty, ReductionFlags Flags) const override { return Impl.useReductionIntrinsic(Opcode, Ty, Flags); diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -591,6 +591,14 @@ return VF; } + bool shouldFoldVectorLengthIntoMask(const PredicatedInstruction &PredInst) const { + return true; + } + + bool supportsVPOperation(const PredicatedInstruction &PredInst) const { + return false; + } + bool useReductionIntrinsic(unsigned Opcode, Type *Ty, TTI::ReductionFlags Flags) const { return false; diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h --- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h +++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h @@ -633,6 +633,9 @@ ATTR_KIND_NOFREE = 62, ATTR_KIND_NOSYNC = 63, ATTR_KIND_SANITIZE_MEMTAG = 64, + ATTR_KIND_MASK = 65,
+ ATTR_KIND_VECTORLENGTH = 66, + ATTR_KIND_PASSTHRU = 67, }; enum ComdatSelectionKindCodes { diff --git a/llvm/include/llvm/CodeGen/ExpandVectorPredication.h b/llvm/include/llvm/CodeGen/ExpandVectorPredication.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/CodeGen/ExpandVectorPredication.h @@ -0,0 +1,23 @@ +//===-- ExpandVectorPredication.h - Expand vector predication ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CODEGEN_EXPANDVECTORPREDICATION_H +#define LLVM_CODEGEN_EXPANDVECTORPREDICATION_H + +#include "llvm/IR/PassManager.h" + +namespace llvm { + +class ExpandVectorPredicationPass + : public PassInfoMixin<ExpandVectorPredicationPass> { +public: + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); +}; +} // end namespace llvm + +#endif // LLVM_CODEGEN_EXPANDVECTORPREDICATION_H diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h --- a/llvm/include/llvm/CodeGen/ISDOpcodes.h +++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h @@ -198,6 +198,7 @@ /// Simple integer binary arithmetic operators. ADD, SUB, MUL, SDIV, UDIV, SREM, UREM, + VP_ADD, VP_SUB, VP_MUL, VP_SDIV, VP_UDIV, VP_SREM, VP_UREM, /// SMUL_LOHI/UMUL_LOHI - Multiply two integers of type iN, producing /// a signed/unsigned value of type i[2*N], and return the full value as @@ -285,6 +286,7 @@ /// Simple binary floating point operators. FADD, FSUB, FMUL, FDIV, FREM, + VP_FADD, VP_FSUB, VP_FMUL, VP_FDIV, VP_FREM, /// Constrained versions of the binary floating point operators. /// These will be lowered to the simple operators before final selection.
@@ -310,8 +312,8 @@ STRICT_FP_TO_SINT, STRICT_FP_TO_UINT, - /// X = STRICT_FP_ROUND(Y, TRUNC) - Rounding 'Y' from a larger floating - /// point type down to the precision of the destination VT. TRUNC is a + /// X = STRICT_FP_ROUND(Y, TRUNC) - Rounding 'Y' from a larger floating + /// point type down to the precision of the destination VT. TRUNC is a /// flag, which is always an integer that is zero or one. If TRUNC is 0, /// this is a normal rounding, if it is 1, this FP_ROUND is known to not /// change the value of Y. @@ -332,6 +334,7 @@ /// FMA - Perform a * b + c with no intermediate rounding step. FMA, + VP_FMA, /// FMAD - Perform a * b + c, while getting the same result as the /// separately rounded operations. @@ -398,6 +401,19 @@ /// in terms of the element size of VEC1/VEC2, not in terms of bytes. VECTOR_SHUFFLE, + /// VP_VSHIFT(VEC1, AMOUNT, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. AMOUNT is an integer value. The returned vector is equivalent + /// to VEC1 shifted by AMOUNT (RETURNED_VEC[idx] = VEC1[idx + AMOUNT]). + VP_VSHIFT, + + /// VP_COMPRESS(VEC1, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. + VP_COMPRESS, + + /// VP_EXPAND(VEC1, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. + VP_EXPAND, + /// SCALAR_TO_VECTOR(VAL) - This represents the operation of loading a /// scalar value into element 0 of the resultant vector type. The top /// elements 1 to N-1 of the N-element vector are undefined. The type @@ -424,6 +440,7 @@ /// Bitwise operators - logical and, logical or, logical xor. AND, OR, XOR, + VP_AND, VP_OR, VP_XOR, /// ABS - Determine the unsigned absolute value of a signed integer value of /// the same bitwidth. @@ -447,6 +464,7 @@ /// fshl(X,Y,Z): (X << (Z % BW)) | (Y >> (BW - (Z % BW))) /// fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW)) SHL, SRA, SRL, ROTL, ROTR, FSHL, FSHR, + VP_SHL, VP_SRA, VP_SRL, /// Byte Swap and Counting operators. 
BSWAP, CTTZ, CTLZ, CTPOP, BITREVERSE, @@ -466,6 +484,14 @@ /// change the condition type in order to match the VSELECT node using a /// pattern. The condition follows the BooleanContent format of the target. VSELECT, + VP_SELECT, + + /// Select with an integer pivot (op #0) and two vector operands (ops #1 + /// and #2), returning a vector result. Op #3 is the vector length, all + /// vectors have the same length. + /// Vector elements below the pivot (op #0) are taken from op #1, elements + /// at positions greater than or equal to the pivot are taken from op #2. + VP_COMPOSE, /// Select with condition operator - This selects between a true value and /// a false value (ops #2 and #3) based on the boolean result of comparing @@ -480,6 +506,7 @@ /// them with (op #2) as a CondCodeSDNode. If the operands are vector types /// then the result type must also be a vector type. SETCC, + VP_SETCC, /// Like SetCC, ops #0 and #1 are the LHS and RHS operands to compare, but /// op #2 is a boolean indicating if there is an incoming carry. This @@ -620,6 +647,7 @@ FCEIL, FTRUNC, FRINT, FNEARBYINT, FROUND, FFLOOR, LROUND, LLROUND, LRINT, LLRINT, + VP_FNEG, // TODO supplement VP opcodes /// FMINNUM/FMAXNUM - Perform floating-point minimum or maximum on two /// values. // @@ -868,6 +896,7 @@ // Val, OutChain = MLOAD(BasePtr, Mask, PassThru) // OutChain = MSTORE(Value, BasePtr, Mask) MLOAD, MSTORE, + VP_LOAD, VP_STORE, // Masked gather and scatter - load and store operations for a vector of // random addresses with additional mask operand that prevents memory @@ -879,6 +908,7 @@ // The Index operand can have more vector elements than the other operands // due to type legalization. The extra elements are ignored. MGATHER, MSCATTER, + VP_GATHER, VP_SCATTER, /// This corresponds to the llvm.lifetime.* intrinsics. The first operand /// is the chain and the second operand is the alloca pointer.
@@ -916,6 +946,14 @@ VECREDUCE_AND, VECREDUCE_OR, VECREDUCE_XOR, VECREDUCE_SMAX, VECREDUCE_SMIN, VECREDUCE_UMAX, VECREDUCE_UMIN, + VP_REDUCE_FADD, VP_REDUCE_FMUL, + VP_REDUCE_ADD, VP_REDUCE_MUL, + VP_REDUCE_AND, VP_REDUCE_OR, VP_REDUCE_XOR, + VP_REDUCE_SMAX, VP_REDUCE_SMIN, VP_REDUCE_UMAX, VP_REDUCE_UMIN, + + /// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants. + VP_REDUCE_FMAX, VP_REDUCE_FMIN, + /// BUILTIN_OP_END - This must be the last enum value in this list. /// The target-specific pre-isel opcode values start here. BUILTIN_OP_END @@ -1092,6 +1130,20 @@ /// SETCC_INVALID if it is not possible to represent the resultant comparison. CondCode getSetCCAndOperation(CondCode Op1, CondCode Op2, bool isInteger); + /// Return the operand index of the mask for VP opcode \p OpCode. + /// Returns -1 if \p OpCode has no mask operand. + int GetMaskPosVP(unsigned OpCode); + + /// Return the operand index of the vector length for VP opcode \p OpCode. + /// Returns -1 if \p OpCode has no vector length operand. + int GetVectorLengthPosVP(unsigned OpCode); + + /// Translate this VP OpCode to an unpredicated instruction OpCode. + unsigned GetFunctionOpCodeForVP(unsigned VPOpCode, bool hasFPExcept); + + /// Translate this non-VP Opcode to its corresponding VP Opcode. + unsigned GetVPForFunctionOpCode(unsigned OpCode); + } // end llvm::ISD namespace } // end llvm namespace diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h --- a/llvm/include/llvm/CodeGen/Passes.h +++ b/llvm/include/llvm/CodeGen/Passes.h @@ -439,6 +439,10 @@ /// shuffles. FunctionPass *createExpandReductionsPass(); + /// This pass expands vector predication intrinsics into unpredicated instructions + /// with selects, or folds just the explicit vector length into the predicate mask. + FunctionPass *createExpandVectorPredicationPass(); + // This pass expands memcmp() to load/stores.
FunctionPass *createExpandMemCmpPass(); diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -1125,6 +1125,20 @@ SDValue getIndexedStore(SDValue OrigStore, const SDLoc &dl, SDValue Base, SDValue Offset, ISD::MemIndexedMode AM); + /// Create vector-predicated memory nodes (VP_LOAD, VP_STORE, VP_GATHER, VP_SCATTER). + SDValue getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, + SDValue Mask, SDValue VLen, EVT MemVT, + MachineMemOperand *MMO, ISD::LoadExtType); + SDValue getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val, + SDValue Ptr, SDValue Mask, SDValue VLen, EVT MemVT, + MachineMemOperand *MMO, bool IsTruncating = false); + SDValue getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, MachineMemOperand *MMO, + ISD::MemIndexType IndexType); + SDValue getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, MachineMemOperand *MMO, + ISD::MemIndexType IndexType); + /// Returns sum of the base pointer and offset. SDValue getMemBasePlusOffset(SDValue Base, unsigned Offset, const SDLoc &DL); diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -549,6 +549,7 @@ class LSBaseSDNodeBitfields { friend class LSBaseSDNode; friend class MaskedGatherScatterSDNode; + friend class VPGatherScatterSDNode; uint16_t : NumMemSDNodeBits; @@ -563,6 +564,7 @@ class LoadSDNodeBitfields { friend class LoadSDNode; friend class MaskedLoadSDNode; + friend class VPLoadSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -573,6 +575,7 @@ class StoreSDNodeBitfields { friend class StoreSDNode; friend class MaskedStoreSDNode; + friend class VPStoreSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -721,6 +724,66 @@ } } + /// Test whether this is a vector predicated node.
+ bool isVP() const { + switch (NodeType) { + default: + return false; + case ISD::VP_LOAD: + case ISD::VP_STORE: + case ISD::VP_GATHER: + case ISD::VP_SCATTER: + + case ISD::VP_FNEG: + + case ISD::VP_FADD: + case ISD::VP_FMUL: + case ISD::VP_FSUB: + case ISD::VP_FDIV: + case ISD::VP_FREM: + + case ISD::VP_FMA: + + case ISD::VP_ADD: + case ISD::VP_MUL: + case ISD::VP_SUB: + case ISD::VP_SRA: + case ISD::VP_SRL: + case ISD::VP_SHL: + case ISD::VP_UDIV: + case ISD::VP_SDIV: + case ISD::VP_UREM: + case ISD::VP_SREM: + + case ISD::VP_EXPAND: + case ISD::VP_COMPRESS: + case ISD::VP_VSHIFT: + case ISD::VP_SETCC: + case ISD::VP_COMPOSE: + case ISD::VP_SELECT: + case ISD::VP_AND: + case ISD::VP_XOR: + case ISD::VP_OR: + + case ISD::VP_REDUCE_ADD: + case ISD::VP_REDUCE_SMIN: + case ISD::VP_REDUCE_SMAX: + case ISD::VP_REDUCE_UMIN: + case ISD::VP_REDUCE_UMAX: + case ISD::VP_REDUCE_MUL: + case ISD::VP_REDUCE_AND: + case ISD::VP_REDUCE_OR: + case ISD::VP_REDUCE_XOR: + case ISD::VP_REDUCE_FADD: + case ISD::VP_REDUCE_FMUL: + case ISD::VP_REDUCE_FMIN: + case ISD::VP_REDUCE_FMAX: + + return true; + } + } + + /// Test if this node has a post-isel opcode, directly /// corresponding to a MachineInstr opcode.
bool isMachineOpcode() const { return NodeType < 0; } @@ -1426,6 +1489,10 @@ N->getOpcode() == ISD::MSTORE || N->getOpcode() == ISD::MGATHER || N->getOpcode() == ISD::MSCATTER || + N->getOpcode() == ISD::VP_LOAD || + N->getOpcode() == ISD::VP_STORE || + N->getOpcode() == ISD::VP_GATHER || + N->getOpcode() == ISD::VP_SCATTER || N->isMemIntrinsic() || N->isTargetMemoryOpcode(); } @@ -2287,6 +2354,96 @@ } }; +/// This base class is used to represent VP_LOAD and VP_STORE nodes +class VPLoadStoreSDNode : public MemSDNode { +public: + friend class SelectionDAG; + + VPLoadStoreSDNode(ISD::NodeType NodeTy, unsigned Order, + const DebugLoc &dl, SDVTList VTs, EVT MemVT, + MachineMemOperand *MMO) + : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {} + + // VPLoadSDNode (Chain, ptr, mask, VLen) + // VPStoreSDNode (Chain, data, ptr, mask, VLen) + // Mask is a vector of i1 elements, Vlen is i32 + const SDValue &getBasePtr() const { + return getOperand(getOpcode() == ISD::VP_LOAD ? 1 : 2); + } + const SDValue &getMask() const { + return getOperand(getOpcode() == ISD::VP_LOAD ? 2 : 3); + } + const SDValue &getVectorLength() const { + return getOperand(getOpcode() == ISD::VP_LOAD ? 
3 : 4); + } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_LOAD || + N->getOpcode() == ISD::VP_STORE; + } +}; + +/// This class is used to represent a VP_LOAD node +class VPLoadSDNode : public VPLoadStoreSDNode { +public: + friend class SelectionDAG; + + VPLoadSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + ISD::LoadExtType ETy, EVT MemVT, + MachineMemOperand *MMO) + : VPLoadStoreSDNode(ISD::VP_LOAD, Order, dl, VTs, MemVT, MMO) { + LoadSDNodeBits.ExtTy = ETy; + LoadSDNodeBits.IsExpanding = false; + } + + ISD::LoadExtType getExtensionType() const { + return static_cast(LoadSDNodeBits.ExtTy); + } + + const SDValue &getBasePtr() const { return getOperand(1); } + const SDValue &getMask() const { return getOperand(2); } + const SDValue &getVectorLength() const { return getOperand(3); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_LOAD; + } + bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; } +}; + +/// This class is used to represent a VP_STORE node +class VPStoreSDNode : public VPLoadStoreSDNode { +public: + friend class SelectionDAG; + + VPStoreSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + bool isTrunc, EVT MemVT, + MachineMemOperand *MMO) + : VPLoadStoreSDNode(ISD::VP_STORE, Order, dl, VTs, MemVT, MMO) { + StoreSDNodeBits.IsTruncating = isTrunc; + StoreSDNodeBits.IsCompressing = false; + } + + /// Return true if this is a truncating store. + /// For integers this is the same as doing a TRUNCATE and storing the result. + /// For floats, it is the same as doing an FP_ROUND and storing the result. + bool isTruncatingStore() const { return StoreSDNodeBits.IsTruncating; } + + /// Returns true if the op does a compression to the vector before storing. + /// The node contiguously stores the active elements (integers or floats) + /// in src (those with their respective bit set in writemask k) to unaligned + /// memory at base_addr. 
+ bool isCompressingStore() const { return StoreSDNodeBits.IsCompressing; } + + const SDValue &getValue() const { return getOperand(1); } + const SDValue &getBasePtr() const { return getOperand(2); } + const SDValue &getMask() const { return getOperand(3); } + const SDValue &getVectorLength() const { return getOperand(4); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_STORE; + } +}; + /// This base class is used to represent MLOAD and MSTORE nodes class MaskedLoadStoreSDNode : public MemSDNode { public: @@ -2374,6 +2531,85 @@ } }; +/// This is a base class used to represent +/// VP_GATHER and VP_SCATTER nodes +/// +class VPGatherScatterSDNode : public MemSDNode { +public: + friend class SelectionDAG; + + VPGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, + const DebugLoc &dl, SDVTList VTs, EVT MemVT, + MachineMemOperand *MMO, ISD::MemIndexType IndexType) + : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) { + LSBaseSDNodeBits.AddressingMode = IndexType; + assert(getIndexType() == IndexType && "Value truncated"); + } + + /// How is Index applied to BasePtr when computing addresses. + ISD::MemIndexType getIndexType() const { + return static_cast(LSBaseSDNodeBits.AddressingMode); + } + bool isIndexScaled() const { + return (getIndexType() == ISD::SIGNED_SCALED) || + (getIndexType() == ISD::UNSIGNED_SCALED); + } + bool isIndexSigned() const { + return (getIndexType() == ISD::SIGNED_SCALED) || + (getIndexType() == ISD::SIGNED_UNSCALED); + } + + // Operand layout (VP_SCATTER has an extra value operand at index 1): + // VPGatherSDNode (Chain, base, index, scale, mask, vlen) + // VPScatterSDNode (Chain, value, base, index, scale, mask, vlen) + // Mask is a vector of i1 elements + const SDValue &getBasePtr() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 1 : 2); } + const SDValue &getIndex() const { return getOperand((getOpcode() == ISD::VP_GATHER) ?
2 : 3); } + const SDValue &getScale() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 3 : 4); } + const SDValue &getMask() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 4 : 5); } + const SDValue &getVectorLength() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 5 : 6); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_GATHER || + N->getOpcode() == ISD::VP_SCATTER; + } +}; + +/// This class is used to represent a VP_GATHER node +/// +class VPGatherSDNode : public VPGatherScatterSDNode { +public: + friend class SelectionDAG; + + VPGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemVT, MachineMemOperand *MMO, + ISD::MemIndexType IndexType) + : VPGatherScatterSDNode(ISD::VP_GATHER, Order, dl, VTs, MemVT, MMO, IndexType) {} + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_GATHER; + } +}; + +/// This class is used to represent a VP_SCATTER node +/// +class VPScatterSDNode : public VPGatherScatterSDNode { +public: + friend class SelectionDAG; + + VPScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemVT, MachineMemOperand *MMO, + ISD::MemIndexType IndexType) + : VPGatherScatterSDNode(ISD::VP_SCATTER, Order, dl, VTs, MemVT, MMO, IndexType) {} + + const SDValue &getValue() const { return getOperand(1); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_SCATTER; + } +}; + + /// This is a base class used to represent /// MGATHER and MSCATTER nodes /// diff --git a/llvm/include/llvm/IR/Attributes.td b/llvm/include/llvm/IR/Attributes.td --- a/llvm/include/llvm/IR/Attributes.td +++ b/llvm/include/llvm/IR/Attributes.td @@ -139,6 +139,15 @@ /// Parameter is required to be a trivial constant. def ImmArg : EnumAttr<"immarg">; +/// Return value that is equal to this argument on disabled (masked-out) lanes. +def Passthru : EnumAttr<"passthru">; + +/// Mask argument that applies to this function.
+def Mask : EnumAttr<"mask">; + +/// Dynamic Vector Length argument of this function. +def VectorLength : EnumAttr<"vlen">; + /// Function can return twice. def ReturnsTwice : EnumAttr<"returns_twice">; diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h --- a/llvm/include/llvm/IR/IRBuilder.h +++ b/llvm/include/llvm/IR/IRBuilder.h @@ -29,6 +29,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstrTypes.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" @@ -97,8 +98,8 @@ FastMathFlags FMF; bool IsFPConstrained; - ConstrainedFPIntrinsic::ExceptionBehavior DefaultConstrainedExcept; - ConstrainedFPIntrinsic::RoundingMode DefaultConstrainedRounding; + ExceptionBehavior DefaultConstrainedExcept; + RoundingMode DefaultConstrainedRounding; ArrayRef DefaultOperandBundles; @@ -106,8 +107,8 @@ IRBuilderBase(LLVMContext &context, MDNode *FPMathTag = nullptr, ArrayRef OpBundles = None) : Context(context), DefaultFPMathTag(FPMathTag), IsFPConstrained(false), - DefaultConstrainedExcept(ConstrainedFPIntrinsic::ebStrict), - DefaultConstrainedRounding(ConstrainedFPIntrinsic::rmDynamic), + DefaultConstrainedExcept(ExceptionBehavior::ebStrict), + DefaultConstrainedRounding(RoundingMode::rmDynamic), DefaultOperandBundles(OpBundles) { ClearInsertionPoint(); } @@ -235,23 +236,23 @@ /// Set the exception handling to be used with constrained floating point void setDefaultConstrainedExcept( - ConstrainedFPIntrinsic::ExceptionBehavior NewExcept) { + ExceptionBehavior NewExcept) { DefaultConstrainedExcept = NewExcept; } /// Set the rounding mode handling to be used with constrained floating point void setDefaultConstrainedRounding( - ConstrainedFPIntrinsic::RoundingMode NewRounding) { + RoundingMode NewRounding) { DefaultConstrainedRounding = NewRounding; } /// Get the exception handling used with constrained floating point - 
ConstrainedFPIntrinsic::ExceptionBehavior getDefaultConstrainedExcept() { + ExceptionBehavior getDefaultConstrainedExcept() { return DefaultConstrainedExcept; } /// Get the rounding mode handling used with constrained floating point - ConstrainedFPIntrinsic::RoundingMode getDefaultConstrainedRounding() { + RoundingMode getDefaultConstrainedRounding() { return DefaultConstrainedRounding; } @@ -656,6 +657,23 @@ /// assume that the provided condition will be true. CallInst *CreateAssumption(Value *Cond); + /// Call an arithmetic VP intrinsic. + Instruction *CreateVectorPredicatedInst(unsigned OC, ArrayRef, + Instruction *FMFSource = nullptr, + const Twine &Name = ""); + + /// Call a comparison VP intrinsic. + Instruction *CreateVectorPredicatedCmp(CmpInst::Predicate Pred, + Value *FirstOp, Value *SndOp, Value *Mask, + Value *VectorLength, + const Twine &Name = ""); + + /// Call a reduction VP intrinsic. NOTE(review): parameters mirror CreateVectorPredicatedCmp; confirm. + Instruction *CreateVectorPredicatedReduce(Module &M, CmpInst::Predicate Pred, + Value *FirstOp, Value *SndOp, Value *Mask, + Value *VectorLength, + const Twine &Name = ""); + /// Create a call to the experimental.gc.statepoint intrinsic to /// start a new statepoint sequence.
CallInst *CreateGCStatepointCall(uint64_t ID, uint32_t NumPatchBytes, @@ -1098,35 +1116,25 @@ } Value *getConstrainedFPRounding( - Optional Rounding) { - ConstrainedFPIntrinsic::RoundingMode UseRounding = + Optional Rounding) { + RoundingMode UseRounding = DefaultConstrainedRounding; if (Rounding.hasValue()) UseRounding = Rounding.getValue(); - Optional RoundingStr = - ConstrainedFPIntrinsic::RoundingModeToStr(UseRounding); - assert(RoundingStr.hasValue() && "Garbage strict rounding mode!"); - auto *RoundingMDS = MDString::get(Context, RoundingStr.getValue()); - - return MetadataAsValue::get(Context, RoundingMDS); + return GetConstrainedFPRounding(Context, UseRounding); } Value *getConstrainedFPExcept( - Optional Except) { - ConstrainedFPIntrinsic::ExceptionBehavior UseExcept = + Optional Except) { + ExceptionBehavior UseExcept = DefaultConstrainedExcept; if (Except.hasValue()) UseExcept = Except.getValue(); - Optional ExceptStr = - ConstrainedFPIntrinsic::ExceptionBehaviorToStr(UseExcept); - assert(ExceptStr.hasValue() && "Garbage strict exception behavior!"); - auto *ExceptMDS = MDString::get(Context, ExceptStr.getValue()); - - return MetadataAsValue::get(Context, ExceptMDS); + return GetConstrainedFPExcept(Context, UseExcept); } public: @@ -1483,8 +1491,8 @@ CallInst *CreateConstrainedFPBinOp( Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource = nullptr, const Twine &Name = "", MDNode *FPMathTag = nullptr, - Optional Rounding = None, - Optional Except = None) { + Optional Rounding = None, + Optional Except = None) { Value *RoundingV = getConstrainedFPRounding(Rounding); Value *ExceptV = getConstrainedFPExcept(Except); @@ -2078,8 +2086,8 @@ Intrinsic::ID ID, Value *V, Type *DestTy, Instruction *FMFSource = nullptr, const Twine &Name = "", MDNode *FPMathTag = nullptr, - Optional Rounding = None, - Optional Except = None) { + Optional Rounding = None, + Optional Except = None) { Value *ExceptV = getConstrainedFPExcept(Except); FastMathFlags UseFMF = FMF; 
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h --- a/llvm/include/llvm/IR/IntrinsicInst.h +++ b/llvm/include/llvm/IR/IntrinsicInst.h @@ -205,50 +205,386 @@ /// @} }; - /// This is the common base class for constrained floating point intrinsics. - class ConstrainedFPIntrinsic : public IntrinsicInst { + enum class RoundingMode : uint8_t { + rmInvalid, + rmDynamic, + rmToNearest, + rmDownward, + rmUpward, + rmTowardZero + }; + + enum class ExceptionBehavior : uint8_t { + ebInvalid, + ebIgnore, + ebMayTrap, + ebStrict + }; + + /// Returns a valid RoundingMode enumerator when given a string + /// that is valid as input in constrained intrinsic rounding mode + /// metadata. + Optional StrToRoundingMode(StringRef); + + /// For any RoundingMode enumerator, returns a string valid as input in + /// constrained intrinsic rounding mode metadata. + Optional RoundingModeToStr(RoundingMode); + + /// Returns a valid ExceptionBehavior enumerator when given a string + /// valid as input in constrained intrinsic exception behavior metadata. + Optional StrToExceptionBehavior(StringRef); + + /// For any ExceptionBehavior enumerator, returns a string valid as + /// input in constrained intrinsic exception behavior metadata. + Optional ExceptionBehaviorToStr(ExceptionBehavior); + + /// Return the IR Value representation of any ExceptionBehavior. + Value* + GetConstrainedFPExcept(LLVMContext&, ExceptionBehavior); + + /// Return the IR Value representation of any RoundingMode. + Value* + GetConstrainedFPRounding(LLVMContext&, RoundingMode); + + /// This is the common base class for vector predication intrinsics. + class VPIntrinsic : public IntrinsicInst { public: - /// Specifies the rounding mode to be assumed. This is only used when - /// when constrained floating point is enabled. See the LLVM Language - /// Reference Manual for details. - enum RoundingMode : uint8_t { - rmDynamic, ///< This corresponds to "fpround.dynamic". 
- rmToNearest, ///< This corresponds to "fpround.tonearest". - rmDownward, ///< This corresponds to "fpround.downward". - rmUpward, ///< This corresponds to "fpround.upward". - rmTowardZero ///< This corresponds to "fpround.tozero". + enum class VPTypeToken : int8_t { + Returned = 0, // vectorized return type. + Vector = 1, // vector operand type + Pointer = 2, // vector pointer-operand type (memory op) + Mask = 3 // vector mask type }; - /// Specifies the required exception behavior. This is only used when - /// when constrained floating point is used. See the LLVM Language - /// Reference Manual for details. - enum ExceptionBehavior : uint8_t { - ebIgnore, ///< This corresponds to "fpexcept.ignore". - ebMayTrap, ///< This corresponds to "fpexcept.maytrap". - ebStrict ///< This corresponds to "fpexcept.strict". - }; + using TypeTokenVec = SmallVector; + using ShortTypeVec = SmallVector; + + /// \brief Declares an llvm.vp.* intrinsic in \p M that matches the parameters \p Params. + static Function* GetDeclarationForParams(Module *M, Intrinsic::ID, ArrayRef Params); + + // Type tokens required to instantiate this intrinsic. + static TypeTokenVec GetTypeTokens(Intrinsic::ID); + + // whether the intrinsic has a rounding mode parameter (regardless of + // setting). + static bool HasRoundingModeParam(Intrinsic::ID VPID) { return GetRoundingModeParamPos(VPID) != None; } + // whether the intrinsic has an exception behavior parameter (regardless of + // setting).
+ static bool HasExceptionBehaviorParam(Intrinsic::ID VPID) { return GetExceptionBehaviorParamPos(VPID) != None; } + static Optional GetMaskParamPos(Intrinsic::ID IntrinsicID); + static Optional GetVectorLengthParamPos(Intrinsic::ID IntrinsicID); + static Optional + GetExceptionBehaviorParamPos(Intrinsic::ID IntrinsicID); + static Optional GetRoundingModeParamPos(Intrinsic::ID IntrinsicID); + // the llvm.vp.* intrinsic for this llvm.experimental.constrained.* + // intrinsic + static Intrinsic::ID GetForConstrainedIntrinsic(Intrinsic::ID IntrinsicID); + static Intrinsic::ID GetForOpcode(unsigned OC); + + /// TODO make this private! + /// \brief Generate the disambiguating type vec for this VP Intrinsic. + /// \returns A disambiguating type vector to instantiate this intrinsic. + /// \param TTVec + /// Vector of disambiguating tokens. + /// \param VecRetTy + /// The return type of the intrinsic (optional) + /// \param VecPtrTy + /// The pointer operand type (optional) + /// \param VectorTy + /// The vector data type of the operation. + static VPIntrinsic::ShortTypeVec + EncodeTypeTokens(VPIntrinsic::TypeTokenVec TTVec, Type *VecRetTy, + Type *VecPtrTy, Type &VectorTy); + + /// Set the mask parameter. + /// This asserts if the underlying intrinsic has no mask parameter. + void setMaskParam(Value *); + + /// Set the vector length parameter. + /// This asserts if the underlying intrinsic has no vector length + /// parameter. + void setVectorLengthParam(Value *); + + /// \return the mask parameter or nullptr. + Value *getMaskParam() const; + + /// \return the vector length parameter or nullptr. + Value *getVectorLengthParam() const; + + /// \return whether the vector length param can be ignored. + bool canIgnoreVectorLengthParam() const; + + /// \return The pointer operand of this load, store, gather or scatter. + Value *getMemoryPointerParam() const; + static Optional GetMemoryPointerParamPos(Intrinsic::ID); + + /// \return The data (payload) operand of this store or scatter.
+ Value *getMemoryDataParam() const; + static Optional GetMemoryDataParamPos(Intrinsic::ID); + + /// \return The vector to reduce if this is a reduction operation. + Value *getReductionVectorParam() const; + static Optional GetReductionVectorParamPos(Intrinsic::ID VPID); + + /// \return The initial (accumulator) value if this is a reduction operation. + Value *getReductionAccuParam() const; + static Optional GetReductionAccuParamPos(Intrinsic::ID VPID); + + /// \return the static element count (vector number of elements) the vector + /// length parameter applies to. This returns None if the operation is + /// scalable. + Optional getStaticVectorLength() const; bool isUnaryOp() const; + static bool IsUnaryVPOp(Intrinsic::ID); + bool isBinaryOp() const; + static bool IsBinaryVPOp(Intrinsic::ID); bool isTernaryOp() const; + static bool IsTernaryVPOp(Intrinsic::ID); + + // compare intrinsic + bool isCompareOp() const { + switch (getIntrinsicID()) { + default: + return false; + case Intrinsic::vp_icmp: + case Intrinsic::vp_fcmp: + return true; + } + } + CmpInst::Predicate getCmpPredicate() const; + + // Constrained fp-math + // whether this is an fp op with non-standard rounding or exception + // behavior. + bool isConstrainedOp() const; + + // the specified rounding mode. Optional getRoundingMode() const; + // the specified exception behavior. Optional getExceptionBehavior() const; - /// Returns a valid RoundingMode enumerator when given a string - /// that is valid as input in constrained intrinsic rounding mode - /// metadata.
- static Optional StrToRoundingMode(StringRef); + // llvm.vp.reduction.* + bool isReductionOp() const; + static bool IsVPReduction(Intrinsic::ID VPIntrin); + + // Methods for support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + switch (I->getIntrinsicID()) { + default: + return false; + + // general cmp + case Intrinsic::vp_icmp: + case Intrinsic::vp_fcmp: + + // int arith + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + + // memory + case Intrinsic::vp_load: + case Intrinsic::vp_store: + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: + + // shuffle + case Intrinsic::vp_select: + case Intrinsic::vp_compose: + case Intrinsic::vp_compress: + case Intrinsic::vp_expand: + case Intrinsic::vp_vshift: + + // fp arith + case Intrinsic::vp_fneg: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + + case Intrinsic::vp_fma: + case Intrinsic::vp_ceil: + case Intrinsic::vp_cos: + case Intrinsic::vp_exp2: + case Intrinsic::vp_exp: + case Intrinsic::vp_floor: + case Intrinsic::vp_log10: + case Intrinsic::vp_log2: + case Intrinsic::vp_log: + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + case Intrinsic::vp_nearbyint: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_rint: + case Intrinsic::vp_round: + case Intrinsic::vp_sin: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_trunc: + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + + case Intrinsic::vp_lround: + case Intrinsic::vp_llround: + + // reductions + case Intrinsic::vp_reduce_add: + case 
Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + return true; + } + } + static bool classof(const Value *V) { + return isa(V) && classof(cast(V)); + } + + Intrinsic::ID getFunctionalIntrinsicID() const { + return GetFunctionalIntrinsicForVP(getIntrinsicID()); + } - /// For any RoundingMode enumerator, returns a string valid as input in - /// constrained intrinsic rounding mode metadata. - static Optional RoundingModeToStr(RoundingMode); + static Intrinsic::ID GetFunctionalIntrinsicForVP(Intrinsic::ID VPID) { + switch (VPID) { + default: + return VPID; + + case Intrinsic::vp_reduce_add: + return Intrinsic::experimental_vector_reduce_add; + case Intrinsic::vp_reduce_mul: + return Intrinsic::experimental_vector_reduce_mul; + case Intrinsic::vp_reduce_and: + return Intrinsic::experimental_vector_reduce_and; + case Intrinsic::vp_reduce_or: + return Intrinsic::experimental_vector_reduce_or; + case Intrinsic::vp_reduce_xor: + return Intrinsic::experimental_vector_reduce_xor; + case Intrinsic::vp_reduce_smin: + return Intrinsic::experimental_vector_reduce_smin; + case Intrinsic::vp_reduce_smax: + return Intrinsic::experimental_vector_reduce_smax; + case Intrinsic::vp_reduce_umin: + return Intrinsic::experimental_vector_reduce_umin; + case Intrinsic::vp_reduce_umax: + return Intrinsic::experimental_vector_reduce_umax; + + case Intrinsic::vp_reduce_fmin: + return Intrinsic::experimental_vector_reduce_fmin; + case Intrinsic::vp_reduce_fmax: + return Intrinsic::experimental_vector_reduce_fmax; + + case Intrinsic::vp_reduce_fadd: + return Intrinsic::experimental_vector_reduce_v2_fadd; + case Intrinsic::vp_reduce_fmul: + return 
Intrinsic::experimental_vector_reduce_v2_fmul; + } + } - /// Returns a valid ExceptionBehavior enumerator when given a string - /// valid as input in constrained intrinsic exception behavior metadata. - static Optional StrToExceptionBehavior(StringRef); + // Equivalent non-predicated opcode + unsigned getFunctionalOpcode() const { + if (isConstrainedOp()) { + return Instruction::Call; // TODO pass as constrained op + } + return GetFunctionalOpcodeForVP(getIntrinsicID()); + } - /// For any ExceptionBehavior enumerator, returns a string valid as - /// input in constrained intrinsic exception behavior metadata. - static Optional ExceptionBehaviorToStr(ExceptionBehavior); + // Equivalent non-predicated opcode + static unsigned GetFunctionalOpcodeForVP(Intrinsic::ID ID) { + switch (ID) { + default: + return Instruction::Call; + + case Intrinsic::vp_icmp: + return Instruction::ICmp; + case Intrinsic::vp_fcmp: + return Instruction::FCmp; + + case Intrinsic::vp_and: + return Instruction::And; + case Intrinsic::vp_or: + return Instruction::Or; + case Intrinsic::vp_xor: + return Instruction::Xor; + case Intrinsic::vp_ashr: + return Instruction::AShr; + case Intrinsic::vp_lshr: + return Instruction::LShr; + case Intrinsic::vp_shl: + return Instruction::Shl; + + case Intrinsic::vp_select: + return Instruction::Select; + + case Intrinsic::vp_load: + return Instruction::Load; + case Intrinsic::vp_store: + return Instruction::Store; + + case Intrinsic::vp_fneg: + return Instruction::FNeg; + + case Intrinsic::vp_fadd: + return Instruction::FAdd; + case Intrinsic::vp_fsub: + return Instruction::FSub; + case Intrinsic::vp_fmul: + return Instruction::FMul; + case Intrinsic::vp_fdiv: + return Instruction::FDiv; + case Intrinsic::vp_frem: + return Instruction::FRem; + + case Intrinsic::vp_add: + return Instruction::Add; + case Intrinsic::vp_sub: + return Instruction::Sub; + case Intrinsic::vp_mul: + return Instruction::Mul; + case Intrinsic::vp_udiv: + return Instruction::UDiv; + case 
Intrinsic::vp_sdiv: + return Instruction::SDiv; + case Intrinsic::vp_urem: + return Instruction::URem; + case Intrinsic::vp_srem: + return Instruction::SRem; + } + } + }; + + /// This is the common base class for constrained floating point intrinsics. + class ConstrainedFPIntrinsic : public IntrinsicInst { + public: + + bool isUnaryOp() const; + bool isTernaryOp() const; + Optional getRoundingMode() const; + Optional getExceptionBehavior() const; // Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const IntrinsicInst *I) { diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -98,6 +98,25 @@ int ArgNo = argNo; } +// VectorLength - The specified argument is the Dynamic Vector Length of the +// operation. +class VectorLength : IntrinsicProperty { + int ArgNo = argNo; +} + +// Mask - The specified argument contains the per-lane mask of this +// intrinsic. Inputs on masked-out lanes must not affect the result of this +// intrinsic (except for the Passthru argument). +class Mask : IntrinsicProperty { + int ArgNo = argNo; +} +// Passthru - The specified argument contains the per-lane return value +// for this vector intrinsic where the mask is false. 
+// (requires the Mask attribute in the same function) +class Passthru : IntrinsicProperty { + int ArgNo = argNo; +} + def IntrNoReturn : IntrinsicProperty; def IntrWillReturn : IntrinsicProperty; @@ -1099,8 +1118,479 @@ def int_ptrmask: Intrinsic<[llvm_anyptr_ty], [llvm_anyptr_ty, llvm_anyint_ty], [IntrNoMem, IntrSpeculatable, IntrWillReturn]>; +//===---------------- Vector Predication Intrinsics --------------===// + +// Memory Intrinsics +def int_vp_store : Intrinsic<[], + [ llvm_anyvector_ty, + LLVMAnyPointerType>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ NoCapture<1>, IntrArgMemOnly, IntrWillReturn, Mask<2>, VectorLength<3> ]>; + +def int_vp_load : Intrinsic<[ llvm_anyvector_ty], + [ LLVMAnyPointerType>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ NoCapture<0>, IntrReadMem, IntrWillReturn, IntrArgMemOnly, Mask<1>, VectorLength<2> ]>; + +def int_vp_gather: Intrinsic<[ llvm_anyvector_ty], + [ LLVMVectorOfAnyPointersToElt<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrReadMem, IntrWillReturn, IntrArgMemOnly, Mask<1>, VectorLength<2> ]>; + +def int_vp_scatter: Intrinsic<[], + [ llvm_anyvector_ty, + LLVMVectorOfAnyPointersToElt<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrArgMemOnly, IntrWillReturn, Mask<2>, VectorLength<3> ]>; +// TODO allow IntrNoCapture for vectors of pointers + +// Reductions +let IntrProperties = [IntrNoMem, IntrWillReturn, Mask<1>, VectorLength<2>] in { + def int_vp_reduce_add : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_mul : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_and : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_or : 
Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_xor : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_smax : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_smin : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_umax : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_umin : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_fmax : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_fmin : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} + +let IntrProperties = [IntrNoMem, IntrWillReturn, Mask<2>, VectorLength<3>] in { + def int_vp_reduce_fadd : Intrinsic<[LLVMVectorElementType<0>], + [LLVMVectorElementType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_fmul : Intrinsic<[LLVMVectorElementType<0>], + [LLVMVectorElementType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} + +// Binary operators +let IntrProperties = [IntrNoMem, IntrWillReturn, Mask<2>, VectorLength<3>] in { + def int_vp_add : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_sub : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + 
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_mul : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_sdiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_udiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_srem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_urem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +// Logical operators + def int_vp_ashr : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_lshr : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_shl : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_or : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_and : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_xor : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +} + +// Comparison +// TODO add signalling fcmp +// The last argument is the comparison predicate +def int_vp_icmp : Intrinsic<[ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ], + [ 
llvm_anyvector_ty, + LLVMMatchType<0>, + llvm_i8_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ], + [ IntrNoMem, Mask<3>, VectorLength<4>, ImmArg<2> ]>; + +def int_vp_fcmp : Intrinsic<[ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ], + [ llvm_anyvector_ty, + LLVMMatchType<0>, + llvm_i8_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ], + [ IntrNoMem, Mask<3>, VectorLength<4>, ImmArg<2> ]>; + + + +// Shuffle +def int_vp_vshift: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_i32_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, IntrWillReturn, Mask<2>, VectorLength<3> ]>; + +def int_vp_expand: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, IntrWillReturn, Mask<1>, VectorLength<2> ]>; + +def int_vp_compress: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, IntrWillReturn, VectorLength<2> ]>; + +// Select +def int_vp_select : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty], + [ IntrNoMem, IntrWillReturn, Passthru<2>, Mask<0>, VectorLength<3> ]>; + +// Compose +def int_vp_compose : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty, + llvm_i32_ty], + [ IntrNoMem, IntrWillReturn, VectorLength<3> ]>; + + + +// VP fp rounding and truncation +let IntrProperties = [ IntrNoMem, IntrWillReturn, Mask<2>, VectorLength<3> ] in { + + def int_vp_fptosi : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + + def int_vp_fptoui : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + + def int_vp_fpext : Intrinsic<[ llvm_anyfloat_ty ], + [ 
llvm_anyfloat_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + + def int_vp_lround : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_llround : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} + +let IntrProperties = [ IntrNoMem, IntrWillReturn, Mask<3>, VectorLength<4> ] in { + def int_vp_fptrunc : Intrinsic<[ llvm_anyfloat_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +} + +// VP single argument constrained intrinsics. +let IntrProperties = [ IntrNoMem, IntrWillReturn, Mask<3>, VectorLength<4> ] in { + // These intrinsics are sensitive to the rounding mode so we need constrained + // versions of each of them. When strict rounding and exception control are + // not required the non-constrained versions of these intrinsics should be + // used. 
+ def int_vp_sqrt : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_sin : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_cos : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_log : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_log10: Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_log2 : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_exp : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_exp2 : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_rint : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_nearbyint : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_lrint : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_llrint : Intrinsic<[ llvm_anyint_ty ], 
+ [ llvm_anyfloat_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_ceil : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_floor : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_round : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_trunc : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} + + +// VP two argument constrained intrinsics. +let IntrProperties = [ IntrNoMem, IntrWillReturn, Mask<4>, VectorLength<5> ] in { + // These intrinsics are sensitive to the rounding mode so we need constrained + // versions of each of them. When strict rounding and exception control are + // not required the non-constrained versions of these intrinsics should be + // used. 
+ def int_vp_powi : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_i32_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_pow : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_maxnum : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_minnum : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +} + + +// VP standard fp-math intrinsics. +def int_vp_fneg : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, IntrWillReturn, Mask<2>, VectorLength<3> ]>; + +let IntrProperties = [ IntrNoMem, IntrWillReturn, Mask<4>, VectorLength<5> ] in { + // These intrinsics are sensitive to the rounding mode so we need constrained + // versions of each of them. When strict rounding and exception control are + // not required the non-constrained versions of these intrinsics should be + // used. 
+ def int_vp_fadd : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_fsub : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_fmul : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_fdiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_frem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; +} + +def int_vp_fma : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ], + [ IntrNoMem, IntrWillReturn, Mask<5>, VectorLength<6> ]>; + + + + //===-------------------------- Masked Intrinsics -------------------------===// -// +// TODO poised for deprecation (to be superseded by llvm.vp.* intrinsics) def int_masked_store : Intrinsic<[], [llvm_anyvector_ty, LLVMAnyPointerType>, llvm_i32_ty, @@ -1200,6 +1690,7 @@ [ IntrArgMemOnly, IntrWillReturn, NoCapture<0>, WriteOnly<0>, ImmArg<3> ]>; //===------------------------ Reduction Intrinsics ------------------------===// +// TODO poised for deprecation (to be superseded by llvm.vp.*. 
intrinsics) // let IntrProperties = [IntrNoMem, IntrWillReturn] in { def int_experimental_vector_reduce_v2_fadd : Intrinsic<[llvm_anyfloat_ty], diff --git a/llvm/include/llvm/IR/MatcherCast.h b/llvm/include/llvm/IR/MatcherCast.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/MatcherCast.h @@ -0,0 +1,65 @@ +#ifndef LLVM_IR_MATCHERCAST_H +#define LLVM_IR_MATCHERCAST_H + +//===- MatcherCast.h - Match on the LLVM IR --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Parameterized class hierachy for templatized pattern matching. +// +//===----------------------------------------------------------------------===// + + +namespace llvm { +namespace PatternMatch { + + +// type modification +template +struct MatcherCast { }; + +// whether the Value \p Obj behaves like a \p Class. 
+template +bool match_isa(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return isa(Obj); +} + +template +auto match_cast(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return cast(Obj); +} +template +auto match_dyn_cast(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return dyn_cast(Obj); +} + +template +auto match_cast(Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return cast(Obj); +} +template +auto match_dyn_cast(Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return dyn_cast(Obj); +} + + +} // namespace PatternMatch + +} // namespace llvm + +#endif // LLVM_IR_MATCHERCAST_H + diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -40,22 +40,81 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" +#include "llvm/IR/MatcherCast.h" + #include + namespace llvm { namespace PatternMatch { +// Use verbatim types in default (empty) context. +struct EmptyContext { + EmptyContext() {} + + EmptyContext(const Value *) {} + + EmptyContext(const EmptyContext & E) {} + + // reset this match context to be rooted at \p V + void reset(Value * V) {} + + // accept a match where \p Val is in a non-leaf position in a match pattern + bool acceptInnerNode(const Value * Val) const { return true; } + + // accept a match where \p Val is bound to a free variable. + bool acceptBoundNode(const Value * Val) const { return true; } + + // whether this context is compatiable with \p E. 
+ bool acceptContext(EmptyContext E) const { return true; } + + // merge the context \p E into this context and return whether the resulting context is valid. + bool mergeContext(EmptyContext E) { return true; } + + // reset this context to \p Val. + template bool reset_match(Val *V, const Pattern &P) { + reset(V); + return const_cast(P).match_context(V, *this); + } + + // match in the current context + template bool try_match(Val *V, const Pattern &P) { + return const_cast(P).match_context(V, *this); + } +}; + +template +struct MatcherCast { using ActualCastType = DestClass; }; + + + + + + +// match without (== empty) context template bool match(Val *V, const Pattern &P) { - return const_cast(P).match(V); + EmptyContext ECtx; + return const_cast(P).match_context(V, ECtx); } +// match pattern in a given context +template bool match(Val *V, const Pattern &P, MatchContext & MContext) { + return const_cast(P).match_context(V, MContext); +} + + + template struct OneUse_match { SubPattern_t SubPattern; OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {} template bool match(OpTy *V) { - return V->hasOneUse() && SubPattern.match(V); + EmptyContext EContext; return match_context(V, EContext); + } + + template bool match_context(OpTy *V, MatchContext & MContext) { + return V->hasOneUse() && SubPattern.match_context(V, MContext); } }; @@ -64,7 +123,11 @@ } template struct class_match { - template bool match(ITy *V) { return isa(V); } + template bool match(ITy *V) { + EmptyContext EContext; return match_context(V, EContext); + } + template + bool match_context(ITy *V, MatchContext & MContext) { return match_isa(V); } }; /// Match an arbitrary value and ignore it. 
@@ -115,11 +178,17 @@ match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + MatchContext SubContext; + + if (L.match_context(V, SubContext) && MContext.acceptContext(SubContext)) { + MContext.mergeContext(SubContext); return true; - if (R.match(V)) + } + if (R.match_context(V, MContext)) { return true; + } return false; } }; @@ -130,9 +199,10 @@ match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) - if (R.match(V)) + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + if (L.match_context(V, MContext)) + if (R.match_context(V, MContext)) return true; return false; } @@ -155,7 +225,8 @@ apint_match(const APInt *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValue(); return true; @@ -175,7 +246,8 @@ struct apfloat_match { const APFloat *&Res; apfloat_match(const APFloat *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValueAPF(); return true; @@ -199,7 +271,8 @@ inline apfloat_match m_APFloat(const APFloat *&Res) { return Res; } template struct constantint_match { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = 
dyn_cast(V)) { const APInt &CIV = CI->getValue(); if (Val >= 0) @@ -222,7 +295,8 @@ /// satisfy a specified predicate. /// For vector constants, undefined elements are ignored. template struct cst_pred_ty : public Predicate { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) return this->isValue(CI->getValue()); if (V->getType()->isVectorTy()) { @@ -259,7 +333,8 @@ api_pred_ty(const APInt *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -281,7 +356,8 @@ /// constants that satisfy a specified predicate. /// For vector constants, undefined elements are ignored. template struct cstfp_pred_ty : public Predicate { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CF = dyn_cast(V)) return this->isValue(CF->getValueAPF()); if (V->getType()->isVectorTy()) { @@ -394,7 +470,8 @@ } struct is_zero { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { auto *C = dyn_cast(V); return C && (C->isNullValue() || cst_pred_ty().match(C)); } @@ -542,8 +619,11 @@ bind_ty(Class *&V) : VR(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CV = dyn_cast(V)) { + if (!MContext.acceptBoundNode(V)) return false; + VR = CV; return true; } 
@@ -583,7 +663,8 @@ specificval_ty(const Value *V) : Val(V) {} - template bool match(ITy *V) { return V == Val; } + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { return V == Val; } }; /// Match if we have a specific specified value. @@ -596,7 +677,8 @@ deferredval_ty(Class *const &V) : Val(V) {} - template bool match(ITy *const V) { return V == Val; } + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *const V, MatchContext & MContext) { return V == Val; } }; /// A commutative-friendly version of m_Specific(). @@ -612,7 +694,8 @@ specific_fpval(double V) : Val(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CFP = dyn_cast(V)) return CFP->isExactlyValue(Val); if (V->getType()->isVectorTy()) @@ -635,7 +718,8 @@ bind_const_intval_ty(uint64_t &V) : VR(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CV = dyn_cast(V)) if (CV->getValue().ule(UINT64_MAX)) { VR = CV->getZExtValue(); @@ -652,7 +736,8 @@ specific_intval(APInt V) : Val(std::move(V)) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { const auto *CI = dyn_cast(V); if (!CI && V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) @@ -682,7 +767,8 @@ specific_bbval(BasicBlock *Val) : Val(Val) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EC; return match_context(V, EC); } + template bool match_context(ITy *V, MatchContext & MContext) { 
const auto *BB = dyn_cast(V); return BB && BB == Val; } @@ -714,11 +800,16 @@ // The LHS is always matched first. AnyBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto * I = match_dyn_cast(V); + if (!I) return false; + + if (!MContext.acceptInnerNode(I)) return false; + + MatchContext LRContext(MContext); + if (L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext))) return true; return false; } }; @@ -742,12 +833,15 @@ // The LHS is always matched first. 
BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto * I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + MatchContext LRContext(MContext); + if (!MContext.acceptInnerNode(I)) return false; + if (L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext))) return true; + return false; } if (auto *CE = dyn_cast(V)) return CE->getOpcode() == Opcode && @@ -786,25 +880,26 @@ Op_t X; FNeg_match(const Op_t &Op) : X(Op) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { auto *FPMO = dyn_cast(V); if (!FPMO) return false; - if (FPMO->getOpcode() == Instruction::FNeg) + if (match_cast(V)->getOpcode() == Instruction::FNeg) return X.match(FPMO->getOperand(0)); - if (FPMO->getOpcode() == Instruction::FSub) { + if (match_cast(V)->getOpcode() == Instruction::FSub) { if (FPMO->hasNoSignedZeros()) { // With 'nsz', any zero goes. - if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!cstfp_pred_ty().match_context(FPMO->getOperand(0), MContext)) return false; } else { // Without 'nsz', we need fsub -0.0, X exactly. 
- if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!cstfp_pred_ty().match_context(FPMO->getOperand(0), MContext)) return false; } - return X.match(FPMO->getOperand(1)); + return X.match_context(FPMO->getOperand(1), MContext); } return false; @@ -940,7 +1035,8 @@ OverflowingBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *Op = dyn_cast(V)) { if (Op->getOpcode() != Opcode) return false; @@ -950,7 +1046,7 @@ if (WrapFlags & OverflowingBinaryOperator::NoSignedWrap && !Op->hasNoSignedWrap()) return false; - return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1)); + return L.match_context(Op->getOperand(0), MContext) && R.match_context(Op->getOperand(1), MContext); } return false; } @@ -1032,10 +1128,11 @@ BinOpPred_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return this->isOpType(I->getOpcode()) && L.match(I->getOperand(0)) && - R.match(I->getOperand(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *I = match_dyn_cast(V)) + return this->isOpType(I->getOpcode()) && L.match_context(I->getOperand(0), MContext) && + R.match_context(I->getOperand(1), MContext); if (auto *CE = dyn_cast(V)) return this->isOpType(CE->getOpcode()) && L.match(CE->getOperand(0)) && R.match(CE->getOperand(1)); @@ -1127,9 +1224,10 @@ Exact_match(const SubPattern_t &SP) : SubPattern(SP) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *PEO = dyn_cast(V)) - return PEO->isExact() && SubPattern.match(V); + return 
PEO->isExact() && SubPattern.match_context(V, MContext); return false; } }; @@ -1154,14 +1252,17 @@ CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) : Predicate(Pred), L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - if ((L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0)))) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *I = match_dyn_cast(V)) { + if (!MContext.acceptInnerNode(I)) return false; + MatchContext LRContext(MContext); + if ((L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) || + (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext)))) { Predicate = I->getPredicate(); return true; } + } return false; } }; @@ -1194,10 +1295,11 @@ OneOps_match(const T0 &Op1) : Op1(Op1) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext); } return false; } @@ -1210,10 +1312,12 @@ TwoOps_match(const T0 &Op1, const T1 &Op2) : Op1(Op1), Op2(Op2) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) 
{ + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext) && + Op2.match_context(I->getOperand(1), MContext); } return false; } @@ -1229,11 +1333,13 @@ ThreeOps_match(const T0 &Op1, const T1 &Op2, const T2 &Op3) : Op1(Op1), Op2(Op2), Op3(Op3) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && - Op3.match(I->getOperand(2)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext) && + Op2.match_context(I->getOperand(1), MContext) && + Op3.match_context(I->getOperand(2), MContext); } return false; } @@ -1301,9 +1407,10 @@ CastClass_match(const Op_t &OpMatch) : Op(OpMatch) {} - template bool match(OpTy *V) { - if (auto *O = dyn_cast(V)) - return O->getOpcode() == Opcode && Op.match(O->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto O = match_dyn_cast(V)) + return O->getOpcode() == Opcode && MContext.acceptInnerNode(O) && Op.match_context(O->getOperand(0), MContext); return false; } }; @@ -1405,8 +1512,9 @@ br_match(BasicBlock *&Succ) : Succ(Succ) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *BI = match_dyn_cast(V)) if (BI->isUnconditional()) { Succ = BI->getSuccessor(0); return true; @@ -1426,10 +1534,12 @@ brc_match(const Cond_t &C, const TrueBlock_t &t, const 
FalseBlock_t &f) : Cond(C), T(t), F(f) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) - if (BI->isConditional() && Cond.match(BI->getCondition())) - return T.match(BI->getSuccessor(0)) && F.match(BI->getSuccessor(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *BI = match_dyn_cast(V)) + if (BI->isConditional() && Cond.match(BI->getCondition())) { + return T.match_context(BI->getSuccessor(0), MContext) && F.match_context(BI->getSuccessor(1), MContext); + } return false; } }; @@ -1461,13 +1571,14 @@ // The LHS is always matched first. MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { // Look for "(x pred y) ? x : y" or "(x pred y) ? y : x". - auto *SI = dyn_cast(V); - if (!SI) + auto *SI = match_dyn_cast(V); + if (!SI || !MContext.acceptInnerNode(SI)) return false; - auto *Cmp = dyn_cast(SI->getCondition()); - if (!Cmp) + auto *Cmp = match_dyn_cast(SI->getCondition()); + if (!Cmp || !MContext.acceptInnerNode(Cmp)) return false; // At this point we have a select conditioned on a comparison. Check that // it is the values returned by the select that are being compared. @@ -1483,9 +1594,12 @@ // Does "(x pred y) ? x : y" represent the desired max/min operation? if (!Pred_t::match(Pred)) return false; + // It does! Bind the operands. 
- return (L.match(LHS) && R.match(RHS)) || - (Commutable && L.match(RHS) && R.match(LHS)); + MatchContext LRContext(MContext); + if (L.match_context(LHS, LRContext) && R.match_context(RHS, LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(RHS, MContext) && R.match_context(LHS, MContext))) return true; + return false; } }; @@ -1642,7 +1756,8 @@ UAddWithOverflow_match(const LHS_t &L, const RHS_t &R, const Sum_t &S) : L(L), R(R), S(S) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { Value *ICmpLHS, *ICmpRHS; ICmpInst::Predicate Pred; if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V)) @@ -1695,9 +1810,10 @@ Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { // FIXME: Should likely be switched to use `CallBase`. 
- if (const auto *CI = dyn_cast(V)) + if (const auto *CI = match_dyn_cast(V)) return Val.match(CI->getArgOperand(OpI)); return false; } @@ -1715,8 +1831,9 @@ IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {} - template bool match(OpTy *V) { - if (const auto *CI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (const auto *CI = match_dyn_cast(V)) if (const auto *F = CI->getCalledFunction()) return F->getIntrinsicID() == ID; return false; @@ -1926,7 +2043,8 @@ Opnd_t Val; Signum_match(const Opnd_t &V) : Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { unsigned TypeSize = V->getType()->getScalarSizeInBits(); if (TypeSize == 0) return false; @@ -1966,10 +2084,11 @@ Opnd_t Val; ExtractValue_match(const Opnd_t &V) : Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EC; return match_context(V, EC); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *I = dyn_cast(V)) return I->getNumIndices() == 1 && I->getIndices()[0] == Ind && - Val.match(I->getAggregateOperand()); + Val.match_context(I->getAggregateOperand(), MContext); return false; } }; diff --git a/llvm/include/llvm/IR/PredicatedInst.h b/llvm/include/llvm/IR/PredicatedInst.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/PredicatedInst.h @@ -0,0 +1,419 @@ +//===-- llvm/PredicatedInst.h - Predication utility subclass --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines various classes for working with predicated instructions. +// Predicated instructions are either regular instructions or calls to +// Vector Predication (VP) intrinsics that have a mask and an explicit +// vector length argument. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_IR_PREDICATEDINST_H +#define LLVM_IR_PREDICATEDINST_H + +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/MatcherCast.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" + +#include + +namespace llvm { + +class BasicBlock; + +class PredicatedInstruction : public User { +public: + // The PredicatedInstruction class is intended to be used as a utility, and is + // never itself instantiated. + PredicatedInstruction() = delete; + ~PredicatedInstruction() = delete; + + void copyIRFlags(const Value *V, bool IncludeWrapFlags) { + cast(this)->copyIRFlags(V, IncludeWrapFlags); + } + + BasicBlock *getParent() { return cast(this)->getParent(); } + const BasicBlock *getParent() const { + return cast(this)->getParent(); + } + + void *operator new(size_t s) = delete; + + Value *getMaskParam() const { + auto thisVP = dyn_cast(this); + if (!thisVP) + return nullptr; + return thisVP->getMaskParam(); + } + + Value *getVectorLengthParam() const { + auto thisVP = dyn_cast(this); + if (!thisVP) + return nullptr; + return thisVP->getVectorLengthParam(); + } + + /// \returns True if the passed vector length value has no predicating effect + /// on the op. 
+ bool canIgnoreVectorLengthParam() const; + + /// \return True if the static operator of this instruction has a mask or + /// vector length parameter. + bool isVectorPredicatedOp() const { return isa(this); } + + /// \returns the effective Opcode of this operation (ignoring the mask and + /// vector length param). + unsigned getOpcode() const { + auto *VPInst = dyn_cast(this); + + if (!VPInst) { + return cast(this)->getOpcode(); + } + + return VPInst->getFunctionalOpcode(); + } + + static bool classof(const Instruction *I) { return isa(I); } + static bool classof(const ConstantExpr *CE) { return false; } + static bool classof(const Value *V) { return isa(V); } + + /// Convenience function for getting all the fast-math flags, which must be an + /// operator which supports these flags. See LangRef.html for the meaning of + /// these flags. + FastMathFlags getFastMathFlags() const; +}; + +class PredicatedOperator : public User { +public: + // The PredicatedOperator class is intended to be used as a utility, and is + // never itself instantiated. + PredicatedOperator() = delete; + ~PredicatedOperator() = delete; + + void *operator new(size_t s) = delete; + + /// Return the opcode for this Instruction or ConstantExpr. 
+ unsigned getOpcode() const { + auto *VPInst = dyn_cast(this); + + // Conceal the fp operation if it has non-default rounding mode or exception + // behavior + if (VPInst && !VPInst->isConstrainedOp()) { + return VPInst->getFunctionalOpcode(); + } + + if (const Instruction *I = dyn_cast(this)) + return I->getOpcode(); + + return cast(this)->getOpcode(); + } + + Value *getMask() const { + auto thisVP = dyn_cast(this); + if (!thisVP) + return nullptr; + return thisVP->getMaskParam(); + } + + Value *getVectorLength() const { + auto thisVP = dyn_cast(this); + if (!thisVP) + return nullptr; + return thisVP->getVectorLengthParam(); + } + + void copyIRFlags(const Value *V, bool IncludeWrapFlags = true); + FastMathFlags getFastMathFlags() const { + auto *I = dyn_cast(this); + if (I) + return I->getFastMathFlags(); + else + return FastMathFlags(); + } + + static bool classof(const Instruction *I) { + return isa(I) || isa(I); + } + static bool classof(const ConstantExpr *CE) { return isa(CE); } + static bool classof(const Value *V) { + return isa(V) || isa(V); + } +}; + +class PredicatedBinaryOperator : public PredicatedOperator { +public: + // The PredicatedBinaryOperator class is intended to be used as a utility, and + // is never itself instantiated. + PredicatedBinaryOperator() = delete; + ~PredicatedBinaryOperator() = delete; + + using BinaryOps = Instruction::BinaryOps; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->isBinaryOp(); + } + static bool classof(const ConstantExpr *CE) { + return isa(CE); + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + auto *CE = dyn_cast(V); + return CE && classof(CE); + } + + /// Construct a predicated binary instruction, given the opcode and the two + /// operands. 
+ static Instruction *Create(Module *Mod, Value *Mask, Value *VectorLen, + Instruction::BinaryOps Opc, Value *V1, Value *V2, + const Twine &Name, BasicBlock *InsertAtEnd, + Instruction *InsertBefore); + + static Instruction *Create(Module *Mod, Value *Mask, Value *VectorLen, + BinaryOps Opc, Value *V1, Value *V2, + const Twine &Name = Twine(), + Instruction *InsertBefore = nullptr) { + return Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, nullptr, + InsertBefore); + } + + static Instruction *Create(Module *Mod, Value *Mask, Value *VectorLen, + BinaryOps Opc, Value *V1, Value *V2, + const Twine &Name, BasicBlock *InsertAtEnd) { + return Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, InsertAtEnd, + nullptr); + } + + static Instruction *CreateWithCopiedFlags(Module *Mod, Value *Mask, + Value *VectorLen, BinaryOps Opc, + Value *V1, Value *V2, + Instruction *CopyBO, + const Twine &Name = "") { + Instruction *BO = + Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, nullptr, nullptr); + BO->copyIRFlags(CopyBO); + return BO; + } +}; + +class PredicatedICmpInst : public PredicatedBinaryOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. 
+ PredicatedICmpInst() = delete; + ~PredicatedICmpInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->getFunctionalOpcode() == Instruction::ICmp; + } + static bool classof(const ConstantExpr *CE) { + return CE->getOpcode() == Instruction::ICmp; + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + auto *CE = dyn_cast(V); + return CE && classof(CE); + } + + ICmpInst::Predicate getPredicate() const { + auto *ICInst = dyn_cast(this); + if (ICInst) + return ICInst->getPredicate(); + auto *CE = dyn_cast(this); + if (CE) + return static_cast(CE->getPredicate()); + return static_cast( + cast(this)->getCmpPredicate()); + } +}; + +class PredicatedFCmpInst : public PredicatedBinaryOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedFCmpInst() = delete; + ~PredicatedFCmpInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->getFunctionalOpcode() == Instruction::FCmp; + } + static bool classof(const ConstantExpr *CE) { + return CE->getOpcode() == Instruction::FCmp; + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + return isa(V); + } + + FCmpInst::Predicate getPredicate() const { + auto *FCInst = dyn_cast(this); + if (FCInst) + return FCInst->getPredicate(); + auto *CE = dyn_cast(this); + if (CE) + return static_cast(CE->getPredicate()); + return static_cast( + cast(this)->getCmpPredicate()); + } +}; + +class PredicatedSelectInst : public PredicatedOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. 
+ PredicatedSelectInst() = delete; + ~PredicatedSelectInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->getFunctionalOpcode() == Instruction::Select; + } + static bool classof(const ConstantExpr *CE) { + return CE->getOpcode() == Instruction::Select; + } + static bool classof(const Value *V) { + auto *I = dyn_cast(V); + if (I && classof(I)) + return true; + auto *CE = dyn_cast(V); + return CE && CE->getOpcode() == Instruction::Select; + } + + const Value *getCondition() const { return getOperand(0); } + const Value *getTrueValue() const { return getOperand(1); } + const Value *getFalseValue() const { return getOperand(2); } + Value *getCondition() { return getOperand(0); } + Value *getTrueValue() { return getOperand(1); } + Value *getFalseValue() { return getOperand(2); } + + void setCondition(Value *V) { setOperand(0, V); } + void setTrueValue(Value *V) { setOperand(1, V); } + void setFalseValue(Value *V) { setOperand(2, V); } +}; + +namespace PatternMatch { + +// PredicatedMatchContext for pattern matching +struct PredicatedContext { + Value *Mask; + Value *VectorLength; + Module *Mod; + + void reset(Value *V) { + auto *PI = dyn_cast(V); + if (!PI) { + VectorLength = nullptr; + Mask = nullptr; + Mod = nullptr; + } else { + VectorLength = PI->getVectorLengthParam(); + Mask = PI->getMaskParam(); + Mod = PI->getParent()->getParent()->getParent(); + } + } + + PredicatedContext(Value *Val) + : Mask(nullptr), VectorLength(nullptr), Mod(nullptr) { + reset(Val); + } + + PredicatedContext(const PredicatedContext &PC) + : Mask(PC.Mask), VectorLength(PC.VectorLength), Mod(PC.Mod) {} + + /// accept a match where \p Val is in a non-leaf position in a match pattern + bool acceptInnerNode(const Value *Val) const { + auto PredI = dyn_cast(Val); + if (!PredI) + return VectorLength == nullptr && Mask == nullptr; + return VectorLength == 
PredI->getVectorLengthParam() && + Mask == PredI->getMaskParam(); + } + + /// accept a match where \p Val is bound to a free variable. + bool acceptBoundNode(const Value *Val) const { return true; } + + /// whether this context is compatiable with \p E. + bool acceptContext(PredicatedContext PC) const { + return std::tie(PC.Mask, PC.VectorLength) == std::tie(Mask, VectorLength); + } + + /// merge the context \p E into this context and return whether the resulting + /// context is valid. + bool mergeContext(PredicatedContext PC) const { return acceptContext(PC); } + + /// match \p P in a new contest for \p Val. + template + bool reset_match(Val *V, const Pattern &P) { + reset(V); + return const_cast(P).match_context(V, *this); + } + + /// match \p P in the current context. + template + bool try_match(Val *V, const Pattern &P) { + PredicatedContext SubContext(*this); + return const_cast(P).match_context(V, SubContext); + } +}; + +struct PredicatedContext; +template <> struct MatcherCast { + using ActualCastType = PredicatedBinaryOperator; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedOperator; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedICmpInst; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedFCmpInst; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedSelectInst; +}; +template <> struct MatcherCast { + using ActualCastType = PredicatedInstruction; +}; + +} // namespace PatternMatch + +} // namespace llvm + +#endif // LLVM_IR_PREDICATEDINST_H diff --git a/llvm/include/llvm/IR/VPBuilder.h b/llvm/include/llvm/IR/VPBuilder.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/VPBuilder.h @@ -0,0 +1,231 @@ +#ifndef LLVM_IR_VPBUILDER_H +#define LLVM_IR_VPBUILDER_H + +#include +#include +#include +#include +#include +#include + +namespace llvm { + +using ValArray = ArrayRef; + +class VPBuilder { + IRBuilder<> & Builder; + + // Explicit mask parameter + Value * 
Mask; + // Explicit vector length parameter + Value * ExplicitVectorLength; + // Compile-time vector length + int StaticVectorLength; + + // get a valid mask/evl argument for the current predication contet + Value& GetMaskForType(VectorType & VecTy); + Value& GetEVLForType(VectorType & VecTy); + +public: + VPBuilder(IRBuilder<> & _builder) + : Builder(_builder) + , Mask(nullptr) + , ExplicitVectorLength(nullptr) + , StaticVectorLength(-1) + {} + + Module & getModule() const; + LLVMContext & getContext() const { return Builder.getContext(); } + + // The cannonical vector type for this \p ElementTy + VectorType& getVectorType(Type &ElementTy); + + // Predication context tracker + VPBuilder& setMask(Value * _Mask) { Mask = _Mask; return *this; } + VPBuilder& setEVL(Value * _ExplicitVectorLength) { ExplicitVectorLength = _ExplicitVectorLength; return *this; } + VPBuilder& setStaticVL(int VLen) { StaticVectorLength = VLen; return *this; } + + // Create a map-vectorized copy of the instruction \p Inst with the underlying IRBuilder instance. + // This operation may return nullptr if the instruction could not be vectorized. 
+ Value* CreateVectorCopy(Instruction & Inst, ValArray VecOpArray); + + // Memory + Value& CreateContiguousStore(Value & Val, Value & Pointer, Align Alignment); + Value& CreateContiguousLoad(Value & Pointer, Align Alignment); + Value& CreateScatter(Value & Val, Value & PointerVec, Align Alignment); + Value& CreateGather(Value & PointerVec, Align Alignment); +}; + + + + + +namespace PatternMatch { + // Factory class to generate instructions in a context + template + class MatchContextBuilder { + public: + // MatchContextBuilder(MatcherContext MC); + }; + + +// Context-free instruction builder +template<> +class MatchContextBuilder { +public: + MatchContextBuilder(EmptyContext & EC) {} + + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name);\ + } \ + template \ + Value *Create##OPC(IRBuilderType & Builder, Value *V1, Value *V2, \ + const Twine &Name = "") const { \ + auto * Inst = BinaryOperator::Create(Instruction::OPC, V1, V2, Name); \ + Builder.Insert(Inst); return Inst; \ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Value *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, BasicBlock *BB) const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name, BB);\ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Value *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, Instruction *I) const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name, I);\ + } + #include "llvm/IR/Instruction.def" + #undef HANDLE_BINARY_INST + + BinaryOperator *CreateFAddFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FAdd, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFSubFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine 
&Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FSub, V1, V2, FMFSource, Name); + } + template + BinaryOperator *CreateFSubFMF(IRBuilderType & Builder, Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + auto * Inst = CreateFSubFMF(V1, V2, FMFSource, Name); + Builder.Insert(Inst); return Inst; + } + BinaryOperator *CreateFMulFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FMul, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFDivFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FDiv, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFRemFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FRem, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFNegFMF(Value *Op, Instruction *FMFSource, + const Twine &Name = "") { + Value *Zero = ConstantFP::getNegativeZero(Op->getType()); + return BinaryOperator::CreateWithCopiedFlags(Instruction::FSub, Zero, Op, FMFSource, Name); + } + + template + Value *CreateFPTrunc(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPTrunc(V, DestTy, Name); } + template + Value *CreateFPExt(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPExt(V, DestTy, Name); } +}; + + + +// Predicated instruction builder: created instructions use the mask and vector length of PC +template<> +class MatchContextBuilder { + PredicatedContext & PC; +public: + MatchContextBuilder(PredicatedContext & PC) : PC(PC) {} + + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name);\ + } \ + template \ + Instruction
*Create##OPC(IRBuilderType & Builder, Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + auto * PredInst = Create##OPC(V1, V2, Name); Builder.Insert(PredInst); return PredInst; \ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, BasicBlock *BB) const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name, BB);\ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, Instruction *I) const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name, I);\ + } + #include "llvm/IR/Instruction.def" + #undef HANDLE_BINARY_INST + + Instruction *CreateFAddFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FAdd, V1, V2, FMFSource, Name); + } + Instruction *CreateFSubFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FSub, V1, V2, FMFSource, Name); + } + template + Instruction *CreateFSubFMF(IRBuilderType & Builder, Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + auto * Inst = CreateFSubFMF(V1, V2, FMFSource, Name); + Builder.Insert(Inst); return Inst; + } + Instruction *CreateFMulFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FMul, V1, V2, FMFSource, Name); + } + Instruction *CreateFDivFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, 
PC.Mask, PC.VectorLength, Instruction::FDiv, V1, V2, FMFSource, Name); + } + Instruction *CreateFRemFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FRem, V1, V2, FMFSource, Name); + } + Instruction *CreateFNegFMF(Value *Op, Instruction *FMFSource, + const Twine &Name = "") { + Value *Zero = ConstantFP::getNegativeZero(Op->getType()); + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FSub, Zero, Op, FMFSource, Name); + } + + // TODO predicated casts + template + Value *CreateFPTrunc(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPTrunc(V, DestTy, Name); } + template + Value *CreateFPExt(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPExt(V, DestTy, Name); } +}; + +} + +} // namespace llvm + +#endif // LLVM_IR_VPBUILDER_H diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h --- a/llvm/include/llvm/InitializePasses.h +++ b/llvm/include/llvm/InitializePasses.h @@ -143,6 +143,7 @@ void initializeExpandMemCmpPassPass(PassRegistry&); void initializeExpandPostRAPass(PassRegistry&); void initializeExpandReductionsPass(PassRegistry&); +void initializeExpandVectorPredicationPass(PassRegistry&); void initializeMakeGuardsExplicitLegacyPassPass(PassRegistry&); void initializeExternalAAWrapperPassPass(PassRegistry&); void initializeFEntryInserterPass(PassRegistry&); diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td --- a/llvm/include/llvm/Target/TargetSelectionDAG.td +++ b/llvm/include/llvm/Target/TargetSelectionDAG.td @@ -128,6 +128,13 @@ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisInt<3> ]>; +def SDTIntBinOpVP : SDTypeProfile<1, 4, [ // vp_add, vp_and, etc.
+ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; +def SDTIntShiftOpVP : SDTypeProfile<1, 4, [ // vp_shl, vp_sra, vp_srl + SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; + def SDTFPBinOp : SDTypeProfile<1, 2, [ // fadd, fmul, etc. SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0> ]>; @@ -173,6 +180,16 @@ SDTCisOpSmallerThanOp<1, 0> ]>; +def SDTFPUnOpVP : SDTypeProfile<1, 3, [ // vp_fneg, etc. + SDTCisSameAs<0, 1>, SDTCisFP<0>, SDTCisInt<3>, SDTCisSameNumEltsAs<0, 2> +]>; +def SDTFPBinOpVP : SDTypeProfile<1, 4, [ // vp_fadd, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; +def SDTFPTernaryOpVP : SDTypeProfile<1, 5, [ // vp_fmadd, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisFP<0>, SDTCisInt<5>, SDTCisSameNumEltsAs<0, 4> +]>; + def SDTSetCC : SDTypeProfile<1, 3, [ // setcc SDTCisInt<0>, SDTCisSameAs<1, 2>, SDTCisVT<3, OtherVT> ]>; @@ -185,6 +202,10 @@ SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1> ]>; +def SDTVSelectVP : SDTypeProfile<1, 5, [ // vp_vselect + SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<5>, SDTCisSameNumEltsAs<0, 4> +]>; + def SDTSelectCC : SDTypeProfile<1, 5, [ // select_cc SDTCisSameAs<1, 2>, SDTCisSameAs<3, 4>, SDTCisSameAs<0, 3>, SDTCisVT<5, OtherVT> @@ -228,11 +249,20 @@ SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameNumEltsAs<0, 2> ]>; +def SDTStoreVP: SDTypeProfile<0, 4, [ // evl store + SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameNumEltsAs<0, 2>, SDTCisInt<3> +]>; + def SDTMaskedLoad: SDTypeProfile<1, 3, [ // masked load SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameAs<0, 3>, SDTCisSameNumEltsAs<0, 2> ]>; +def SDTLoadVP : SDTypeProfile<1, 3, [ // evl load + SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisSameNumEltsAs<0, 2>, SDTCisInt<3>, +
SDTCisVec<2> +]>; + def SDTVecShuffle : SDTypeProfile<1, 2, [ SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2> ]>; @@ -391,6 +421,26 @@ def umax : SDNode<"ISD::UMAX" , SDTIntBinOp, [SDNPCommutative, SDNPAssociative]>; +def vp_and : SDNode<"ISD::VP_AND" , SDTIntBinOpVP , + [SDNPCommutative, SDNPAssociative]>; +def vp_or : SDNode<"ISD::VP_OR" , SDTIntBinOpVP , + [SDNPCommutative, SDNPAssociative]>; +def vp_xor : SDNode<"ISD::VP_XOR" , SDTIntBinOpVP , + [SDNPCommutative, SDNPAssociative]>; +def vp_srl : SDNode<"ISD::VP_SRL" , SDTIntShiftOpVP>; +def vp_sra : SDNode<"ISD::VP_SRA" , SDTIntShiftOpVP>; +def vp_shl : SDNode<"ISD::VP_SHL" , SDTIntShiftOpVP>; + +def vp_add : SDNode<"ISD::VP_ADD" , SDTIntBinOpVP , + [SDNPCommutative, SDNPAssociative]>; +def vp_sub : SDNode<"ISD::VP_SUB" , SDTIntBinOpVP>; +def vp_mul : SDNode<"ISD::VP_MUL" , SDTIntBinOpVP, + [SDNPCommutative, SDNPAssociative]>; +def vp_sdiv : SDNode<"ISD::VP_SDIV" , SDTIntBinOpVP>; +def vp_udiv : SDNode<"ISD::VP_UDIV" , SDTIntBinOpVP>; +def vp_srem : SDNode<"ISD::VP_SREM" , SDTIntBinOpVP>; +def vp_urem : SDNode<"ISD::VP_UREM" , SDTIntBinOpVP>; + def saddsat : SDNode<"ISD::SADDSAT" , SDTIntBinOp, [SDNPCommutative]>; def uaddsat : SDNode<"ISD::UADDSAT" , SDTIntBinOp, [SDNPCommutative]>; def ssubsat : SDNode<"ISD::SSUBSAT" , SDTIntBinOp>; @@ -473,6 +523,14 @@ def fpextend : SDNode<"ISD::FP_EXTEND" , SDTFPExtendOp>; def fcopysign : SDNode<"ISD::FCOPYSIGN" , SDTFPSignOp>; +def vp_fneg : SDNode<"ISD::VP_FNEG" , SDTFPUnOpVP>; +def vp_fadd : SDNode<"ISD::VP_FADD" , SDTFPBinOpVP, [SDNPCommutative]>; +def vp_fsub : SDNode<"ISD::VP_FSUB" , SDTFPBinOpVP>; +def vp_fmul : SDNode<"ISD::VP_FMUL" , SDTFPBinOpVP, [SDNPCommutative]>; +def vp_fdiv : SDNode<"ISD::VP_FDIV" , SDTFPBinOpVP>; +def vp_frem : SDNode<"ISD::VP_FREM" , SDTFPBinOpVP>; +def vp_fma : SDNode<"ISD::VP_FMA" , SDTFPTernaryOpVP>; + def sint_to_fp : SDNode<"ISD::SINT_TO_FP" , SDTIntToFPOp>; def uint_to_fp : SDNode<"ISD::UINT_TO_FP" , SDTIntToFPOp>; def fp_to_sint 
: SDNode<"ISD::FP_TO_SINT" , SDTFPToIntOp>; @@ -610,6 +668,11 @@ def masked_ld : SDNode<"ISD::MLOAD", SDTMaskedLoad, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; +def vp_store : SDNode<"ISD::VP_STORE", SDTStoreVP, + [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; +def vp_load : SDNode<"ISD::VP_LOAD", SDTLoadVP, + [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; + // Do not use ld, st directly. Use load, extload, sextload, zextload, store, // and truncst (see below). def ld : SDNode<"ISD::LOAD" , SDTLoad, diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -37,6 +37,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/KnownBits.h" #include @@ -4554,8 +4555,10 @@ /// Given operands for an FSub, see if we can fold the result. If not, this /// returns null.
-static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { +template +static Value *SimplifyFSubInstGeneric(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse, MatchContext & MC) { + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) return C; @@ -4563,26 +4566,26 @@ return C; // fsub X, +0 ==> X - if (match(Op1, m_PosZeroFP())) + if (MC.try_match(Op1, m_PosZeroFP())) return Op0; // fsub X, -0 ==> X, when we know X is not -0 - if (match(Op1, m_NegZeroFP()) && + if (MC.try_match(Op1, m_NegZeroFP()) && (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) return Op0; // fsub -0.0, (fsub -0.0, X) ==> X // fsub -0.0, (fneg X) ==> X Value *X; - if (match(Op0, m_NegZeroFP()) && - match(Op1, m_FNeg(m_Value(X)))) + if (MC.try_match(Op0, m_NegZeroFP()) && + MC.try_match(Op1, m_FNeg(m_Value(X)))) return X; // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. // fsub 0.0, (fneg X) ==> X if signed zeros are ignored. - if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) && - (match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) || - match(Op1, m_FNeg(m_Value(X))))) + if (FMF.noSignedZeros() && MC.try_match(Op0, m_AnyZeroFP()) && + (MC.try_match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) || + MC.try_match(Op1, m_FNeg(m_Value(X))))) return X; // fsub nnan x, x ==> 0.0 @@ -4592,8 +4595,8 @@ // Y - (Y - X) --> X // (X + Y) - Y --> X if (FMF.noSignedZeros() && FMF.allowReassoc() && - (match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || - match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) + (MC.try_match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || + MC.try_match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) return X; return nullptr; @@ -4648,9 +4651,25 @@ } + +/// Given operands for an FSub, see if we can fold the result.
+static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) + return C; + + EmptyContext EC; + return SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, MaxRecurse, EC); +} + Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q) { return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); +} + +Value *llvm::SimplifyPredicatedFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, PredicatedContext & PC) { + return ::SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, RecursionLimit, PC); } Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, @@ -5241,9 +5260,20 @@ return ConstantFoldCall(Call, F, ConstantArgs, Q.TLI); } +Value *llvm::SimplifyVPIntrinsic(VPIntrinsic & VPInst, const SimplifyQuery &Q) { + PredicatedContext PC(&VPInst); + + auto & PI = cast(VPInst); + switch (PI.getOpcode()) { + default: + return nullptr; + + case Instruction::FSub: return SimplifyPredicatedFSubInst(VPInst.getOperand(0), VPInst.getOperand(1), VPInst.getFastMathFlags(), Q, PC); + } +} + /// See if we can compute a simplified version of this instruction. /// If not, this returns null. - Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, OptimizationRemarkEmitter *ORE) { const SimplifyQuery Q = SQ.CxtI ?
SQ : SQ.getWithInstruction(I); @@ -5380,6 +5410,11 @@ Result = SimplifyPHINode(cast(I), Q); break; case Instruction::Call: { + auto * VPInst = dyn_cast(I); + if (VPInst) { + Result = SimplifyVPIntrinsic(*VPInst, Q); + if (Result) break; + } Result = SimplifyCall(cast(I), Q); break; } diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -830,6 +830,14 @@ return TTIImpl->getStoreVectorFactor(VF, StoreSize, ChainSizeInBytes, VecTy); } +bool TargetTransformInfo::shouldFoldVectorLengthIntoMask(const PredicatedInstruction &PI) const { + return TTIImpl->shouldFoldVectorLengthIntoMask(PI); +} + +bool TargetTransformInfo::supportsVPOperation(const PredicatedInstruction &PI) const { + return TTIImpl->supportsVPOperation(PI); +} + bool TargetTransformInfo::useReductionIntrinsic(unsigned Opcode, Type *Ty, ReductionFlags Flags) const { return TTIImpl->useReductionIntrinsic(Opcode, Ty, Flags); diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp --- a/llvm/lib/AsmParser/LLLexer.cpp +++ b/llvm/lib/AsmParser/LLLexer.cpp @@ -645,6 +645,7 @@ KEYWORD(inlinehint); KEYWORD(inreg); KEYWORD(jumptable); + KEYWORD(mask); KEYWORD(minsize); KEYWORD(naked); KEYWORD(nest); @@ -666,6 +667,7 @@ KEYWORD(optforfuzzing); KEYWORD(optnone); KEYWORD(optsize); + KEYWORD(passthru); KEYWORD(readnone); KEYWORD(readonly); KEYWORD(returned); @@ -689,6 +691,7 @@ KEYWORD(swiftself); KEYWORD(uwtable); KEYWORD(willreturn); + KEYWORD(vlen); KEYWORD(writeonly); KEYWORD(zeroext); KEYWORD(immarg); diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -1338,15 +1338,18 @@ case lltok::kw_dereferenceable: case lltok::kw_dereferenceable_or_null: case lltok::kw_inalloca: + case lltok::kw_mask: case lltok::kw_nest: case
lltok::kw_noalias: case lltok::kw_nocapture: case lltok::kw_nonnull: + case lltok::kw_passthru: case lltok::kw_returned: case lltok::kw_sret: case lltok::kw_swifterror: case lltok::kw_swiftself: case lltok::kw_immarg: + case lltok::kw_vlen: HaveError |= Error(Lex.getLoc(), "invalid use of parameter-only attribute on a function"); @@ -1633,11 +1636,13 @@ } case lltok::kw_inalloca: B.addAttribute(Attribute::InAlloca); break; case lltok::kw_inreg: B.addAttribute(Attribute::InReg); break; + case lltok::kw_mask: B.addAttribute(Attribute::Mask); break; case lltok::kw_nest: B.addAttribute(Attribute::Nest); break; case lltok::kw_noalias: B.addAttribute(Attribute::NoAlias); break; case lltok::kw_nocapture: B.addAttribute(Attribute::NoCapture); break; case lltok::kw_nofree: B.addAttribute(Attribute::NoFree); break; case lltok::kw_nonnull: B.addAttribute(Attribute::NonNull); break; + case lltok::kw_passthru: B.addAttribute(Attribute::Passthru); break; case lltok::kw_readnone: B.addAttribute(Attribute::ReadNone); break; case lltok::kw_readonly: B.addAttribute(Attribute::ReadOnly); break; case lltok::kw_returned: B.addAttribute(Attribute::Returned); break; @@ -1645,6 +1650,7 @@ case lltok::kw_sret: B.addAttribute(Attribute::StructRet); break; case lltok::kw_swifterror: B.addAttribute(Attribute::SwiftError); break; case lltok::kw_swiftself: B.addAttribute(Attribute::SwiftSelf); break; + case lltok::kw_vlen: B.addAttribute(Attribute::VectorLength); break; case lltok::kw_writeonly: B.addAttribute(Attribute::WriteOnly); break; case lltok::kw_zeroext: B.addAttribute(Attribute::ZExt); break; case lltok::kw_immarg: B.addAttribute(Attribute::ImmArg); break; @@ -1737,13 +1743,16 @@ // Error handling.
case lltok::kw_byval: case lltok::kw_inalloca: + case lltok::kw_mask: case lltok::kw_nest: case lltok::kw_nocapture: + case lltok::kw_passthru: case lltok::kw_returned: case lltok::kw_sret: case lltok::kw_swifterror: case lltok::kw_swiftself: case lltok::kw_immarg: + case lltok::kw_vlen: HaveError |= Error(Lex.getLoc(), "invalid use of parameter-only attribute"); break; @@ -3412,7 +3421,7 @@ ID.Kind = ValID::t_Constant; return false; } - + // Unary Operators. case lltok::kw_fneg: case lltok::kw_freeze: { @@ -3423,7 +3432,7 @@ ParseGlobalTypeAndValue(Val) || ParseToken(lltok::rparen, "expected ')' in unary constantexpr")) return true; - + // Check that the type is valid for the operator. switch (Opc) { case Instruction::FNeg: @@ -4761,7 +4770,7 @@ OPTIONAL(declaration, MDField, ); \ OPTIONAL(name, MDStringField, ); \ OPTIONAL(file, MDField, ); \ - OPTIONAL(line, LineField, ); + OPTIONAL(line, LineField, ); PARSE_MD_FIELDS(); #undef VISIT_MD_FIELDS diff --git a/llvm/lib/AsmParser/LLToken.h b/llvm/lib/AsmParser/LLToken.h --- a/llvm/lib/AsmParser/LLToken.h +++ b/llvm/lib/AsmParser/LLToken.h @@ -191,6 +191,7 @@ kw_inlinehint, kw_inreg, kw_jumptable, + kw_mask, kw_minsize, kw_naked, kw_nest, @@ -212,6 +213,7 @@ kw_optforfuzzing, kw_optnone, kw_optsize, + kw_passthru, kw_readnone, kw_readonly, kw_returned, @@ -232,6 +234,7 @@ kw_swiftself, kw_uwtable, kw_willreturn, + kw_vlen, kw_writeonly, kw_zeroext, kw_immarg, diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -1296,6 +1296,15 @@ case Attribute::SanitizeMemTag: llvm_unreachable("sanitize_memtag attribute not supported in raw format"); break; + case Attribute::Mask: + llvm_unreachable("mask attribute not supported in raw format"); + break; + case Attribute::VectorLength: + llvm_unreachable("vlen attribute not supported in raw format"); + break; + case Attribute::Passthru: +
llvm_unreachable("passthru attribute not supported in raw format"); + break; } llvm_unreachable("Unsupported attribute type"); } @@ -1310,6 +1319,9 @@ I == Attribute::DereferenceableOrNull || I == Attribute::ArgMemOnly || I == Attribute::AllocSize || + I == Attribute::Mask || + I == Attribute::VectorLength || + I == Attribute::Passthru || I == Attribute::NoSync) continue; if (uint64_t A = (Val & getRawAttributeMask(I))) { @@ -1435,6 +1447,8 @@ return Attribute::InReg; case bitc::ATTR_KIND_JUMP_TABLE: return Attribute::JumpTable; + case bitc::ATTR_KIND_MASK: + return Attribute::Mask; case bitc::ATTR_KIND_MIN_SIZE: return Attribute::MinSize; case bitc::ATTR_KIND_NAKED: @@ -1483,6 +1497,8 @@ return Attribute::OptimizeForSize; case bitc::ATTR_KIND_OPTIMIZE_NONE: return Attribute::OptimizeNone; + case bitc::ATTR_KIND_PASSTHRU: + return Attribute::Passthru; case bitc::ATTR_KIND_READ_NONE: return Attribute::ReadNone; case bitc::ATTR_KIND_READ_ONLY: @@ -1529,6 +1545,8 @@ return Attribute::UWTable; case bitc::ATTR_KIND_WILLRETURN: return Attribute::WillReturn; + case bitc::ATTR_KIND_VECTORLENGTH: + return Attribute::VectorLength; case bitc::ATTR_KIND_WRITEONLY: return Attribute::WriteOnly; case bitc::ATTR_KIND_Z_EXT: diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -678,6 +678,12 @@ return bitc::ATTR_KIND_READ_ONLY; case Attribute::Returned: return bitc::ATTR_KIND_RETURNED; + case Attribute::Mask: + return bitc::ATTR_KIND_MASK; + case Attribute::VectorLength: + return bitc::ATTR_KIND_VECTORLENGTH; + case Attribute::Passthru: + return bitc::ATTR_KIND_PASSTHRU; case Attribute::ReturnsTwice: return bitc::ATTR_KIND_RETURNS_TWICE; case Attribute::SExt: diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt --- a/llvm/lib/CodeGen/CMakeLists.txt +++ b/llvm/lib/CodeGen/CMakeLists.txt @@ -25,6 +25,7 @@
ExpandMemCmp.cpp ExpandPostRAPseudos.cpp ExpandReductions.cpp + ExpandVectorPredication.cpp FaultMaps.cpp FEntryInserter.cpp FinalizeISel.cpp diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp @@ -0,0 +1,602 @@ +//===--- ExpandVectorPredication.cpp - Expand vector predication intrinsics +//--===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass implements IR expansion for vector predication intrinsics, allowing +// targets to enable vector predication until just before codegen. +// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/ExpandVectorPredication.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PredicatedInst.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Utils/LoopUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "expand-vec-pred" + +STATISTIC(NumFoldedVL, "Number of folded vector length params"); +STATISTIC(numLoweredVPOps, "Number of lowered vector predication operations"); + +namespace { + +/// \brief The logical vector element size of this operation.
+int32_t GetFunctionalVectorElementSize() {
+  return 64; // TODO infer from operation (eg
+             // VPIntrinsic::getVectorElementSize())
+}
+
+/// \returns A vector with ascending integer indices (<0, 1, ..., NumElems-1>).
+Value *CreateStepVector(IRBuilder<> &Builder, int32_t ElemBits,
+                        int32_t NumElems) {
+  // TODO add caching
+  SmallVector ConstElems;
+
+  Type *LaneTy = Builder.getIntNTy(ElemBits);
+
+  for (int32_t Idx = 0; Idx < NumElems; ++Idx) {
+    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));
+  }
+
+  return ConstantVector::get(ConstElems);
+}
+
+/// \returns A bitmask that is true where the lane position is less-than
+/// \p VLParam (compared as unsigned).
+/// \p Builder
+///    Used for instruction creation.
+/// \p VLParam
+///    The explicit vector length parameter to test against the lane
+///    positions.
+/// \p ElemBits
+///    Integer bitsize used for the generated ICmp instruction.
+/// \p NumElems
+///    Static vector length of the operation.
+Value *ConvertVLToMask(IRBuilder<> &Builder, Value *VLParam, int32_t ElemBits,
+                       int32_t NumElems) {
+  // TODO increase elem bits to shrink wrap VLParam where necessary (eg if
+  // operating on i8)
+  Type *LaneTy = Builder.getIntNTy(ElemBits);
+  // SExt: an ineffective EVL (MSB set) turns huge, so every lane is enabled.
+  auto ExtVLParam = Builder.CreateSExt(VLParam, LaneTy);
+  auto VLSplat = Builder.CreateVectorSplat(NumElems, ExtVLParam);
+
+  auto IdxVec = CreateStepVector(Builder, ElemBits, NumElems);
+
+  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
+}
+
+/// \returns A non-trapping divisor for \p DivTy: 1 (or 1.0), splatted for
+/// vector types. -1 is not safe: INT_MIN sdiv/srem -1 overflows.
+Constant *getSafeDivisor(Type *DivTy) {
+  if (DivTy->isIntOrIntVectorTy()) {
+    return ConstantInt::get(DivTy, 1, false);
+  }
+  if (DivTy->isFPOrFPVectorTy()) {
+    // Unlike getVectorNumElements(), this also accepts scalar FP types.
+    return ConstantFP::get(DivTy, 1.0);
+  }
+  llvm_unreachable("Not a valid type for division");
+}
+
+/// Transfer operation properties from \p OldVPI to \p NewVal.
+void TransferDecorations(Value *NewVal, VPIntrinsic *OldVPI) {
+  auto NewInst = dyn_cast(NewVal);
+  if (!NewInst)
+    return;
+
+  if (auto FPMathOp = dyn_cast(OldVPI)) {
+    NewInst->setFastMathFlags(FPMathOp->getFastMathFlags());
+  }
+}
+
+/// \brief Lower this vector-predicated operator into standard IR.
+void LowerVPUnaryOperator(VPIntrinsic *VPI) {
+  assert(VPI->canIgnoreVectorLengthParam());
+  assert(VPI->getFunctionalOpcode() == Instruction::FNeg &&
+         "expected a VP fneg");
+  auto FirstOp = VPI->getOperand(0);
+  auto I = cast(VPI);
+  IRBuilder<> Builder(I);
+  auto NewFNeg = Builder.CreateFNegFMF(FirstOp, I, I->getName());
+  I->replaceAllUsesWith(NewFNeg);
+  I->eraseFromParent();
+}
+
+/// \brief Lower this VP binary operator to a non-VP binary operator.
+void LowerVPBinaryOperator(VPIntrinsic *VPI) {
+  assert(VPI->canIgnoreVectorLengthParam());
+  assert(VPI->isBinaryOp());
+
+  auto OldBinOp = cast(VPI);
+
+  auto FirstOp = VPI->getOperand(0);
+  auto SndOp = VPI->getOperand(1);
+
+  IRBuilder<> Builder(OldBinOp);
+  auto Mask = VPI->getMaskParam();
+
+  switch (VPI->getFunctionalOpcode()) {
+  default:
+    // can safely ignore the predicate
+    break;
+
+  // Division operators need a safe divisor on masked-off lanes (1.0)
+  case Instruction::FDiv:
+  case Instruction::FRem:
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::URem:
+  case Instruction::SRem:
+    // 2nd operand must not be zero
+    auto SafeDivisor = getSafeDivisor(VPI->getType());
+    SndOp = Builder.CreateSelect(Mask, SndOp, SafeDivisor);
+  }
+
+  auto NewBinOp = Builder.CreateBinOp(
+      static_cast(VPI->getFunctionalOpcode()), FirstOp,
+      SndOp, VPI->getName(), nullptr);
+
+  if (auto *NewBinOpInst = dyn_cast<Instruction>(NewBinOp)) {
+    // NewBinOp may be a folded Constant; only fast-math flags are copied.
+    TransferDecorations(NewBinOpInst, VPI);
+  }
+
+  OldBinOp->replaceAllUsesWith(NewBinOp);
+  OldBinOp->eraseFromParent();
+}
+
+/// \brief Lower llvm.vp.compose.* into a select instruction
+void LowerVPCompose(VPIntrinsic *VPI) {
+  auto &I = cast(*VPI);
+  auto ElemBits = GetFunctionalVectorElementSize();
+  auto NumElems = VPI->getStaticVectorLength();
+  assert(NumElems.hasValue() && "TODO scalable vector support");
+
+  IRBuilder<> Builder(cast(VPI));
+  auto PivotMask = ConvertVLToMask(Builder, VPI->getOperand(2), ElemBits,
+                                   NumElems.getValue());
+  auto NewCompose = Builder.CreateSelect(PivotMask, VPI->getOperand(0),
+                                         VPI->getOperand(1), VPI->getName());
+  I.replaceAllUsesWith(NewCompose);
+  I.eraseFromParent();
+}
+
+/// \brief Lower this llvm.vp.fma intrinsic to a llvm.fma intrinsic.
+void LowerVPFMA(VPIntrinsic *VPI) {
+  assert(VPI->canIgnoreVectorLengthParam());
+
+  auto I = cast(VPI);
+  auto M = I->getParent()->getModule();
+  IRBuilder<> Builder(I);
+  auto FMAFunc = Intrinsic::getDeclaration(M, Intrinsic::fma, VPI->getType());
+  auto NewFMA = Builder.CreateCall(
+      FMAFunc, {VPI->getOperand(0), VPI->getOperand(1), VPI->getOperand(2)},
+      VPI->getName());
+  TransferDecorations(NewFMA, VPI);
+  I->replaceAllUsesWith(NewFMA);
+  I->eraseFromParent();
+}
+
+/// \returns Whether the vector mask \p MaskVal has all lane bits set.
+static bool IsAllTrueMask(Value *MaskVal) {
+  auto ConstVec = dyn_cast(MaskVal);
+  if (!ConstVec)
+    return false;
+  return ConstVec->isAllOnesValue();
+}
+
+/// \returns The constant \p ConstVal broadcasted to \p VecTy.
+static Value *BroadcastConstant(Constant *ConstVal, VectorType *VecTy) {
+  return ConstantDataVector::getSplat(VecTy->getVectorNumElements(), ConstVal);
+}
+
+/// \returns The neutral element of the reduction \p VPRedID.
+static Value *GetNeutralElementVector(Intrinsic::ID VPRedID,
+                                      VectorType *VecTy) {
+  unsigned ElemBits = VecTy->getScalarSizeInBits();
+
+  switch (VPRedID) {
+  default:
+    llvm_unreachable("invalid vp reduction intrinsic");
+
+  case Intrinsic::vp_reduce_add:
+  case Intrinsic::vp_reduce_or:
+  case Intrinsic::vp_reduce_xor:
+  case Intrinsic::vp_reduce_umax:
+    return Constant::getNullValue(VecTy);
+
+  case Intrinsic::vp_reduce_mul:
+    return BroadcastConstant(
+        ConstantInt::get(VecTy->getElementType(), 1, false), VecTy);
+
+  case Intrinsic::vp_reduce_and:
+  case Intrinsic::vp_reduce_umin:
+    return Constant::getAllOnesValue(VecTy);
+
+  case Intrinsic::vp_reduce_smin:
+    return BroadcastConstant(
+        ConstantInt::get(VecTy->getContext(),
+                         APInt::getSignedMaxValue(ElemBits)),
+        VecTy);
+  case Intrinsic::vp_reduce_smax:
+    return BroadcastConstant(
+        ConstantInt::get(VecTy->getContext(),
+                         APInt::getSignedMinValue(ElemBits)),
+        VecTy);
+
+  case Intrinsic::vp_reduce_fmin:
+  case Intrinsic::vp_reduce_fmax: // NOTE(review): NaN lanes are poison under 'nnan' - confirm
+    return BroadcastConstant(ConstantFP::getQNaN(VecTy->getElementType()),
+                             VecTy);
+  case Intrinsic::vp_reduce_fadd: // -0.0 is the identity: +0.0 + -0.0 == +0.0
+    return BroadcastConstant(
+        ConstantFP::getNegativeZero(VecTy->getElementType()), VecTy);
+  case Intrinsic::vp_reduce_fmul:
+    return BroadcastConstant(ConstantFP::get(VecTy->getElementType(), 1.0),
+                             VecTy);
+  }
+}
+
+/// \brief Lower this llvm.vp.reduce.* intrinsic to a llvm.experimental.reduce.*
+/// intrinsic.
+void LowerVPReduction(VPIntrinsic *VPI) {
+  assert(VPI->canIgnoreVectorLengthParam());
+  assert(VPI->isReductionOp());
+
+  auto &I = *cast(VPI);
+  IRBuilder<> Builder(&I);
+  auto M = Builder.GetInsertBlock()->getModule();
+  assert(M && "No module to declare reduction intrinsic in!");
+
+  // Read the operands of the VP reduction.
+
+  Value *RedVectorParam = VPI->getReductionVectorParam();
+  Value *RedAccuParam = VPI->getReductionAccuParam();
+  Value *MaskParam = VPI->getMaskParam();
+  auto FunctionalID = VPI->getFunctionalIntrinsicID();
+
+  // Insert neutral element in masked-out positions
+  bool IsUnmasked = IsAllTrueMask(MaskParam);
+  if (!IsUnmasked) {
+    auto *NeutralVector = GetNeutralElementVector(
+        VPI->getIntrinsicID(), cast(RedVectorParam->getType()));
+    RedVectorParam =
+        Builder.CreateSelect(MaskParam, RedVectorParam, NeutralVector);
+  }
+
+  auto VecTypeArg = RedVectorParam->getType();
+
+  Value *NewReduct;
+  switch (FunctionalID) {
+  default: {
+    auto RedIntrinFunc = Intrinsic::getDeclaration(M, FunctionalID, VecTypeArg);
+    NewReduct = Builder.CreateCall(RedIntrinFunc, RedVectorParam, I.getName());
+    assert(!RedAccuParam && "accu dropped");
+  } break;
+
+  case Intrinsic::experimental_vector_reduce_v2_fadd:
+  case Intrinsic::experimental_vector_reduce_v2_fmul: {
+    auto TypeArg = RedAccuParam->getType();
+    auto RedIntrinFunc =
+        Intrinsic::getDeclaration(M, FunctionalID, {TypeArg, VecTypeArg});
+    NewReduct = Builder.CreateCall(RedIntrinFunc,
+                                   {RedAccuParam, RedVectorParam}, I.getName());
+  } break;
+  }
+
+  TransferDecorations(NewReduct, VPI);
+  I.replaceAllUsesWith(NewReduct);
+  I.eraseFromParent();
+}
+
+/// \brief Lower this llvm.vp.(load|store|gather|scatter) to a non-vp
+/// instruction.
+void LowerVPMemoryIntrinsic(VPIntrinsic *VPI) {
+  assert(VPI->canIgnoreVectorLengthParam());
+  auto &I = cast(*VPI);
+
+  auto MaskParam = VPI->getMaskParam();
+  auto PtrParam = VPI->getMemoryPointerParam();
+  auto DataParam = VPI->getMemoryDataParam();
+  bool IsUnmasked = IsAllTrueMask(MaskParam);
+
+  IRBuilder<> Builder(&I);
+  auto &DL = Builder.GetInsertBlock()->getModule()->getDataLayout();
+
+  Value *NewMemoryInst = nullptr;
+  switch (VPI->getIntrinsicID()) {
+  default:
+    llvm_unreachable("not a VP memory intrinsic");
+
+  case Intrinsic::vp_store: {
+    // Known pointer alignment only: a plain store implies ABI alignment.
+    Align MayAlign = PtrParam->getPointerAlignment(DL).valueOrOne();
+    if (IsUnmasked) {
+      StoreInst *NewStore = Builder.CreateStore(DataParam, PtrParam, false);
+      NewStore->setAlignment(MayAlign);
+      NewMemoryInst = NewStore;
+    } else {
+      NewMemoryInst = Builder.CreateMaskedStore(DataParam, PtrParam,
+                                                MayAlign.value(), MaskParam);
+    }
+  } break;
+
+  case Intrinsic::vp_load: {
+    Align MayAlign = PtrParam->getPointerAlignment(DL).valueOrOne();
+    if (IsUnmasked) {
+      LoadInst *NewLoad =
+          Builder.CreateLoad(I.getType(), PtrParam, false, I.getName());
+      NewLoad->setAlignment(MayAlign);
+      NewMemoryInst = NewLoad;
+    } else {
+      NewMemoryInst = Builder.CreateMaskedLoad(PtrParam, MayAlign.value(),
+                                               MaskParam, nullptr, I.getName());
+    }
+  } break;
+
+  // Scatter/gather take a vector of pointers, so they stay masked intrinsics
+  // even for an all-true mask: a plain load/store through it is invalid IR.
+  case Intrinsic::vp_scatter: {
+    Align MayAlign; // FIXME = PtrParam->getPointerAlignment(DL).valueOrOne();
+    NewMemoryInst = Builder.CreateMaskedScatter(DataParam, PtrParam,
+                                                MayAlign.value(), MaskParam);
+  } break;
+
+  case Intrinsic::vp_gather: {
+    Align MayAlign; // FIXME = PtrParam->getPointerAlignment(DL).valueOrOne();
+    NewMemoryInst = Builder.CreateMaskedGather(
+        PtrParam, MayAlign.value(), MaskParam, nullptr, I.getName());
+  } break;
+  }
+
+  assert(NewMemoryInst);
+  I.replaceAllUsesWith(NewMemoryInst);
+  I.eraseFromParent();
+}
+
+/// \brief Lower llvm.vp.select.* to a select instruction.
+void LowerVPSelectInst(VPIntrinsic *VPI) { + auto I = cast(VPI); + + auto NewVal = SelectInst::Create(VPI->getMaskParam(), VPI->getOperand(1), + VPI->getOperand(2), I->getName(), I, I); + TransferDecorations(NewVal, VPI); + I->replaceAllUsesWith(NewVal); + I->eraseFromParent(); +} + +/// \brief Lower llvm.vp.(icmp|fcmp) to an icmp or fcmp instruction. +void LowerVPCompare(VPIntrinsic *VPI) { + auto NewCmp = CmpInst::Create( + static_cast(VPI->getFunctionalOpcode()), + VPI->getCmpPredicate(), VPI->getOperand(0), VPI->getOperand(1), + VPI->getName(), cast(VPI)); + VPI->replaceAllUsesWith(NewCmp); + VPI->eraseFromParent(); +} + +/// \brief Lower a llvm.vp.* intrinsic that is not functionally equivalent to a +/// standard IR instruction. +void LowerUnmatchedVPIntrinsic(VPIntrinsic *VPI) { + if (VPI->isReductionOp()) + return LowerVPReduction(VPI); + + switch (VPI->getIntrinsicID()) { + default: + abort(); // unexpected intrinsic + + case Intrinsic::vp_compress: + case Intrinsic::vp_expand: + case Intrinsic::vp_vshift: + LLVM_DEBUG(dbgs() << "Silently keeping VP intrinsic: can not substitute: " + << *VPI << "\n"); + return; + + case Intrinsic::vp_compose: + LowerVPCompose(VPI); + break; + + case Intrinsic::vp_fma: + LowerVPFMA(VPI); + break; + + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: + LowerVPMemoryIntrinsic(VPI); + break; + } +} + +/// \brief Expand llvm.vp.* intrinsics as requested by \p TTI. +bool expandVectorPredication(Function &F, const TargetTransformInfo *TTI) { + bool Changed = false; + + // Holds all vector-predicated ops with an effective vector length param that + // needs to be folded into the mask param. + SmallVector ExpandVLWorklist; + + // Holds all vector-predicated ops that need to be translated into non-VP ops.
+ SmallVector ExpandOpWorklist; + + for (auto &I : instructions(F)) { + auto *VPI = dyn_cast(&I); + if (!VPI) + continue; + + auto &PI = cast(*VPI); + + bool supportsVPOp = TTI->supportsVPOperation(PI); + bool hasEffectiveVLParam = !VPI->canIgnoreVectorLengthParam(); + bool shouldFoldVLParam = + !supportsVPOp || TTI->shouldFoldVectorLengthIntoMask(PI); + + LLVM_DEBUG(dbgs() << "Inspecting " << *VPI + << "\n:: target-support=" << supportsVPOp + << ", effectiveVecLen=" << hasEffectiveVLParam + << ", shouldFoldVecLen=" << shouldFoldVLParam << "\n"); + + if (shouldFoldVLParam) { + if (hasEffectiveVLParam && VPI->getMaskParam()) { + ExpandVLWorklist.push_back(VPI); + } else if (!supportsVPOp) { // nothing to fold: only ops the target cannot execute are lowered + ExpandOpWorklist.push_back(VPI); + } + } + } + + // Fold vector-length params into the mask param. + LLVM_DEBUG(dbgs() << "\n:::: Folding vlen into mask. ::::\n"); + for (VPIntrinsic *VPI : ExpandVLWorklist) { + ++NumFoldedVL; + Changed = true; + + LLVM_DEBUG(dbgs() << "Folding vlen for op: " << *VPI << '\n'); + + IRBuilder<> Builder(cast(VPI)); + + Value *OldMaskParam = VPI->getMaskParam(); + Value *OldVLParam = VPI->getVectorLengthParam(); + assert(OldMaskParam && "no mask param to fold the vl param into"); + assert(OldVLParam && "no vector length param to fold away"); + + LLVM_DEBUG(dbgs() << "OLD vlen: " << *OldVLParam << '\n'); + LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n'); + + // Determine the lane bit size that should be used to lower this op + auto ElemBits = GetFunctionalVectorElementSize(); + auto NumElems = VPI->getStaticVectorLength(); + assert(NumElems.hasValue() && "TODO scalable vector support"); + + // Lower VL to M + auto *VLMask = + ConvertVLToMask(Builder, OldVLParam, ElemBits, NumElems.getValue()); + auto NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam); + VPI->setMaskParam( + NewMaskParam); // FIXME cannot trivially use the PI abstraction here.
+ + // Disable VL + auto FullVL = Builder.getInt32(NumElems.getValue()); + VPI->setVectorLengthParam(FullVL); + assert(VPI->canIgnoreVectorLengthParam() && + "transformation did not render the vl param ineffective!"); + + LLVM_DEBUG(dbgs() << "NEW vlen: " << *FullVL << '\n'); + LLVM_DEBUG(dbgs() << "NEW mask: " << *NewMaskParam << '\n'); + + auto &PI = cast(*VPI); + if (!TTI->supportsVPOperation(PI)) { + ExpandOpWorklist.push_back(VPI); + } + } + + // Translate into non-VP ops + LLVM_DEBUG(dbgs() << "\n:::: Lowering VP into non-VP ops ::::\n"); + for (VPIntrinsic *VPI : ExpandOpWorklist) { + ++numLoweredVPOps; + Changed = true; + + LLVM_DEBUG(dbgs() << "Lowering vp op: " << *VPI << '\n'); + + unsigned OC = VPI->getFunctionalOpcode(); +#define FIRST_UNARY_INST(X) unsigned FirstUnOp = X; +#define LAST_UNARY_INST(X) unsigned LastUnOp = X; +#define FIRST_BINARY_INST(X) unsigned FirstBinOp = X; +#define LAST_BINARY_INST(X) unsigned LastBinOp = X; +#include "llvm/IR/Instruction.def" + + if (FirstBinOp <= OC && OC <= LastBinOp) { + LowerVPBinaryOperator(VPI); + continue; + } + if (FirstUnOp <= OC && OC <= LastUnOp) { + LowerVPUnaryOperator(VPI); + continue; + } + + switch (OC) { + default: + abort(); // unexpected intrinsic + + case Instruction::Call: + LowerUnmatchedVPIntrinsic(VPI); + break; + + case Instruction::Select: + LowerVPSelectInst(VPI); + break; + + case Instruction::Store: + case Instruction::Load: + LowerVPMemoryIntrinsic(VPI); + break; + + case Instruction::ICmp: + case Instruction::FCmp: + LowerVPCompare(VPI); + break; + } + } + + return Changed; +} + +class ExpandVectorPredication : public FunctionPass { +public: + static char ID; + ExpandVectorPredication() : FunctionPass(ID) { + initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + const auto *TTI = &getAnalysis().getTTI(F); + return expandVectorPredication(F, TTI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { +
AU.addRequired(); + AU.setPreservesCFG(); + } +}; +} // namespace + +char ExpandVectorPredication::ID; +INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expand-vec-pred", + "Expand vector predication intrinsics", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(ExpandVectorPredication, "expand-vec-pred", + "Expand vector predication intrinsics", false, false) + +FunctionPass *llvm::createExpandVectorPredicationPass() { + return new ExpandVectorPredication(); +} + +PreservedAnalyses +ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) { + const auto &TTI = AM.getResult(F); + if (!expandVectorPredication(F, &TTI)) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserveSet(); + return PA; +} diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -427,6 +427,7 @@ SDValue visitBITCAST(SDNode *N); SDValue visitBUILD_PAIR(SDNode *N); SDValue visitFADD(SDNode *N); + SDValue visitFADD_VP(SDNode *N); SDValue visitFSUB(SDNode *N); SDValue visitFMUL(SDNode *N); SDValue visitFMA(SDNode *N); @@ -475,6 +476,7 @@ SDValue visitFP16_TO_FP(SDNode *N); SDValue visitVECREDUCE(SDNode *N); + template SDValue visitFADDForFMACombine(SDNode *N); SDValue visitFSUBForFMACombine(SDNode *N); SDValue visitFMULForFMADistributiveCombine(SDNode *N); @@ -736,6 +738,137 @@ void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); } }; +struct EmptyMatchContext { + SelectionDAG & DAG; + + EmptyMatchContext(SelectionDAG & DAG, SDNode * Root) + : DAG(DAG) + {} + + bool match(SDValue OpN, unsigned OpCode) const { return OpCode == OpN->getOpcode(); } + + unsigned getFunctionOpCode(SDValue N) const { + return N->getOpcode(); + } + + bool isCompatible(SDValue OpVal) const { return true; } + + // Specialize
based on number of operands. + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, Operand, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, N1, N2, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, + const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, Flags); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDValue N4) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, N4); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDValue N4, SDValue N5) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, N4, N5); + } +}; + +struct +VPMatchContext { + SelectionDAG & DAG; + SDNode * Root; + SDValue RootMaskOp; + SDValue RootVectorLenOp; + + VPMatchContext(SelectionDAG & DAG, SDNode * Root) + : DAG(DAG) + , Root(Root) + , RootMaskOp() + , RootVectorLenOp() + { + if (Root->isVP()) { + int RootMaskPos = ISD::GetMaskPosVP(Root->getOpcode()); + if (RootMaskPos != -1) { + RootMaskOp = Root->getOperand(RootMaskPos); + } + + int RootVLenPos = ISD::GetVectorLengthPosVP(Root->getOpcode()); + if (RootVLenPos != -1) { + RootVectorLenOp = Root->getOperand(RootVLenPos); + } + } + } + + unsigned getFunctionOpCode(SDValue N) const { + unsigned VPOpCode = N->getOpcode(); + return ISD::GetFunctionOpCodeForVP(VPOpCode, N->getFlags().hasFPExcept()); + } + + bool isCompatible(SDValue OpVal) const { + if (!OpVal->isVP()) { + return !Root->isVP(); + + } else { + unsigned VPOpCode = OpVal->getOpcode(); + int MaskPos = ISD::GetMaskPosVP(VPOpCode); + if (MaskPos !=
-1 && RootMaskOp != OpVal.getOperand(MaskPos)) { + return false; + } + + int VLenPos = ISD::GetVectorLengthPosVP(VPOpCode); + if (VLenPos != -1 && RootVectorLenOp != OpVal.getOperand(VLenPos)) { + return false; + } + + return true; + } + } + + /// Whether \p OpVal is compatible with the root and is functionally an \p OpNT node. + bool match(SDValue OpVal, unsigned OpNT) const { + return isCompatible(OpVal) && getFunctionOpCode(OpVal) == OpNT; + } + + // Specialize based on number of operands. + // TODO emit VP intrinsics where MaskOp/VectorLenOp != null + // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + const SDNodeFlags Flags = SDNodeFlags()) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + int MaskPos = ISD::GetMaskPosVP(VPOpcode); + int VLenPos = ISD::GetVectorLengthPosVP(VPOpcode); + assert(MaskPos == 1 && VLenPos == 2); + + return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp}, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, const SDNodeFlags Flags = SDNodeFlags()) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + int MaskPos = ISD::GetMaskPosVP(VPOpcode); + int VLenPos = ISD::GetVectorLengthPosVP(VPOpcode); + assert(MaskPos == 2 && VLenPos == 3); + + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, + const SDNodeFlags Flags = SDNodeFlags()) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + int MaskPos = ISD::GetMaskPosVP(VPOpcode); + int VLenPos = ISD::GetVectorLengthPosVP(VPOpcode); + assert(MaskPos == 3 && VLenPos == 4); + + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags); + } +}; + } // end anonymous namespace
//===----------------------------------------------------------------------===// @@ -1560,6 +1693,7 @@ case ISD::BITCAST: return visitBITCAST(N); case ISD::BUILD_PAIR: return visitBUILD_PAIR(N); case ISD::FADD: return visitFADD(N); + case ISD::VP_FADD: return visitFADD_VP(N); case ISD::FSUB: return visitFSUB(N); case ISD::FMUL: return visitFMUL(N); case ISD::FMA: return visitFMA(N); @@ -11325,13 +11459,18 @@ return F.hasAllowContract() || F.hasAllowReassociation(); } + /// Try to perform FMA combining on a given FADD node. +template SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); + MatchContextClass matcher(DAG, N); + if (!matcher.isCompatible(N0) || !matcher.isCompatible(N1)) return SDValue(); + const TargetOptions &Options = DAG.getTarget().Options; // Floating-point multiply-add with intermediate rounding. @@ -11364,8 +11503,8 @@ // Is the node an FMUL and contractable either due to global flags or // SDNodeFlags. - auto isContractableFMUL = [AllowFusionGlobally](SDValue N) { - if (N.getOpcode() != ISD::FMUL) + auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) { + if (!matcher.match(N, ISD::FMUL)) return false; return AllowFusionGlobally || isContractable(N.getNode()); }; @@ -11378,44 +11517,44 @@ // fold (fadd (fmul x, y), z) -> (fma x, y, z) if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), N1, Flags); } // fold (fadd x, (fmul y, z)) -> (fma y, z, x) // Note: Commutes FADD operands. if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), N0, Flags); } // Look through FP_EXTEND nodes to do more combining.
// fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) - if (N0.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N0, ISD::FP_EXTEND)) { SDValue N00 = N0.getOperand(0); if (isContractableFMUL(N00) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1, Flags); } } // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x) // Note: Commutes FADD operands. - if (N1.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N1, ISD::FP_EXTEND)) { SDValue N10 = N1.getOperand(0); if (isContractableFMUL(N10) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, N10.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0, Flags); } } @@ -11424,12 +11563,12 @@ if (Aggressive) { // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) if (CanFuse && - N0.getOpcode() == PreferredFusedOpcode && - N0.getOperand(2).getOpcode() == ISD::FMUL && + matcher.match(N0, PreferredFusedOpcode) && + matcher.match(N0.getOperand(2), ISD::FMUL) && N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(2).getOperand(0), N0.getOperand(2).getOperand(1), N1, Flags), Flags); @@ -11437,12 +11576,12 @@ // fold (fadd x, (fma y, z,
(fmul u, v)) -> (fma y, z (fma u, v, x)) if (CanFuse && - N1->getOpcode() == PreferredFusedOpcode && - N1.getOperand(2).getOpcode() == ISD::FMUL && + matcher.match(N1, PreferredFusedOpcode) && + matcher.match(N1.getOperand(2), ISD::FMUL) && N1->hasOneUse() && N1.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(2).getOperand(0), N1.getOperand(2).getOperand(1), N0, Flags), Flags); @@ -11454,15 +11593,15 @@ auto FoldFAddFMAFPExtFMul = [&] ( SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, SDNodeFlags Flags) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y, - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + return matcher.getNode(PreferredFusedOpcode, SL, VT, X, Y, + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z, Flags), Flags); }; - if (N0.getOpcode() == PreferredFusedOpcode) { + if (matcher.match(N0, PreferredFusedOpcode)) { SDValue N02 = N0.getOperand(2); - if (N02.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N02, ISD::FP_EXTEND)) { SDValue N020 = N02.getOperand(0); if (isContractableFMUL(N020) && TLI.isFPExtFoldable(DAG, PreferredFusedOpcode, VT, @@ -11482,12 +11621,12 @@ auto FoldFAddFPExtFMAFMul = [&] ( SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, SDNodeFlags Flags) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, X), - DAG.getNode(ISD::FP_EXTEND, SL, VT, Y), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, X), +
matcher.getNode(ISD::FP_EXTEND, SL, VT, Y), + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z, Flags), Flags); }; if (N0.getOpcode() == ISD::FP_EXTEND) { @@ -11935,6 +12074,15 @@ return SDValue(); } +SDValue DAGCombiner::visitFADD_VP(SDNode *N) { + // FADD -> FMA combines: + if (SDValue Fused = visitFADDForFMACombine(N)) { + AddToWorklist(Fused.getNode()); + return Fused; + } + return SDValue(); +} + SDValue DAGCombiner::visitFADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -12107,7 +12255,7 @@ } // enable-unsafe-fp-math // FADD -> FMA combines: - if (SDValue Fused = visitFADDForFMACombine(N)) { + if (SDValue Fused = visitFADDForFMACombine(N)) { AddToWorklist(Fused.getNode()); return Fused; } diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -992,7 +992,7 @@ } // Handle promotion for the ADDE/SUBE/ADDCARRY/SUBCARRY nodes. Notice that -// the third operand of ADDE/SUBE nodes is carry flag, which differs from +// the third operand of ADDE/SUBE nodes is carry flag, which differs from // the ADDCARRY/SUBCARRY nodes in that the third operand is carry Boolean. SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBCARRY(SDNode *N, unsigned ResNo) { if (ResNo == 1) @@ -1144,6 +1144,9 @@ return false; } + if (N->isVP()) { + Res = PromoteIntOp_VP(N, OpNo); + } else { switch (N->getOpcode()) { default: #ifndef NDEBUG @@ -1222,6 +1225,7 @@ case ISD::VECREDUCE_UMAX: case ISD::VECREDUCE_UMIN: Res = PromoteIntOp_VECREDUCE(N); break; } + } // If the result is null, the sub-method took care of registering results etc.
if (!Res.getNode()) return false; @@ -1502,6 +1506,25 @@ TruncateStore, N->isCompressingStore()); } +SDValue DAGTypeLegalizer::PromoteIntOp_VP(SDNode *N, unsigned OpNo) { + EVT DataVT; + switch (N->getOpcode()) { + default: + DataVT = N->getValueType(0); + break; + + case ISD::VP_STORE: + case ISD::VP_SCATTER: + llvm_unreachable("TODO implement VP memory nodes"); + } + + // TODO assert that \p OpNo is the mask + SDValue Mask = PromoteTargetBoolean(N->getOperand(OpNo), DataVT); + SmallVector NewOps(N->op_begin(), N->op_end()); + NewOps[OpNo] = Mask; + return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0); +} + SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo) { assert(OpNo == 2 && "Only know how to promote the mask!"); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -341,6 +341,7 @@ SDValue PromoteIntRes_VECREDUCE(SDNode *N); SDValue PromoteIntRes_ABS(SDNode *N); + // Integer Operand Promotion.
bool PromoteIntegerOperand(SDNode *N, unsigned OpNo); SDValue PromoteIntOp_ANY_EXTEND(SDNode *N); @@ -376,6 +377,7 @@ SDValue PromoteIntOp_MULFIX(SDNode *N); SDValue PromoteIntOp_FPOWI(SDNode *N); SDValue PromoteIntOp_VECREDUCE(SDNode *N); + SDValue PromoteIntOp_VP(SDNode *N, unsigned OpNo); void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -428,6 +428,217 @@ return Result; } +//===----------------------------------------------------------------------===// +// SDNode VP Support +//===----------------------------------------------------------------------===// + +int +ISD::GetMaskPosVP(unsigned OpCode) { + switch (OpCode) { + default: return -1; + + case ISD::VP_FNEG: + return 1; + + case ISD::VP_ADD: + case ISD::VP_SUB: + case ISD::VP_MUL: + case ISD::VP_SDIV: + case ISD::VP_SREM: + case ISD::VP_UDIV: + case ISD::VP_UREM: + + case ISD::VP_AND: + case ISD::VP_OR: + case ISD::VP_XOR: + case ISD::VP_SHL: + case ISD::VP_SRA: + case ISD::VP_SRL: + + case ISD::VP_FADD: + case ISD::VP_FMUL: + case ISD::VP_FSUB: + case ISD::VP_FDIV: + case ISD::VP_FREM: + return 2; + + case ISD::VP_FMA: + case ISD::VP_SELECT: + return 3; + + case VP_REDUCE_ADD: + case VP_REDUCE_MUL: + case VP_REDUCE_AND: + case VP_REDUCE_OR: + case VP_REDUCE_XOR: + case VP_REDUCE_SMAX: + case VP_REDUCE_SMIN: + case VP_REDUCE_UMAX: + case VP_REDUCE_UMIN: + case VP_REDUCE_FMAX: + case VP_REDUCE_FMIN: + return 1; + + case VP_REDUCE_FADD: + case VP_REDUCE_FMUL: + return 2; + + // Opcodes without a mask operand take the default (-1), e.g.:
+ // (implicit) case ISD::VP_COMPOSE: return -1 + } +} + +int +ISD::GetVectorLengthPosVP(unsigned OpCode) { + switch (OpCode) { + default: return -1; + + // NOTE(review): every other masked opcode has its vector length right after + // the mask; confirm VP_SELECT's layout (vlen 0 here, mask 3 in GetMaskPosVP). + case VP_SELECT: + return 0; + + case VP_FNEG: + return 2; + + case VP_ADD: + case VP_SUB: + case VP_MUL: + case VP_SDIV: + case VP_SREM: + case VP_UDIV: + case VP_UREM: + + case VP_AND: + case VP_OR: + case VP_XOR: + case VP_SHL: + case VP_SRA: + case VP_SRL: + + case VP_FADD: + case VP_FSUB: + case VP_FMUL: + case VP_FDIV: + case VP_FREM: + return 3; + + case VP_FMA: + return 4; + + case VP_COMPOSE: + return 3; + + case VP_REDUCE_ADD: + case VP_REDUCE_MUL: + case VP_REDUCE_AND: + case VP_REDUCE_OR: + case VP_REDUCE_XOR: + case VP_REDUCE_SMAX: + case VP_REDUCE_SMIN: + case VP_REDUCE_UMAX: + case VP_REDUCE_UMIN: + case VP_REDUCE_FMAX: + case VP_REDUCE_FMIN: + return 2; + + case VP_REDUCE_FADD: + case VP_REDUCE_FMUL: + return 3; + + } +} + +unsigned +ISD::GetFunctionOpCodeForVP(unsigned OpCode, bool hasFPExcept) { + switch (OpCode) { + default: return OpCode; + + case VP_SELECT: return ISD::VSELECT; + case VP_ADD: return ISD::ADD; + case VP_SUB: return ISD::SUB; + case VP_MUL: return ISD::MUL; + case VP_SDIV: return ISD::SDIV; + case VP_SREM: return ISD::SREM; + case VP_UDIV: return ISD::UDIV; + case VP_UREM: return ISD::UREM; + + case VP_AND: return ISD::AND; + case VP_OR: return ISD::OR; + case VP_XOR: return ISD::XOR; + case VP_SHL: return ISD::SHL; + case VP_SRA: return ISD::SRA; + case VP_SRL: return ISD::SRL; + + case VP_FNEG: return ISD::FNEG; + case VP_FADD: return hasFPExcept ? ISD::STRICT_FADD : ISD::FADD; + case VP_FSUB: return hasFPExcept ? ISD::STRICT_FSUB : ISD::FSUB; + case VP_FMUL: return hasFPExcept ? ISD::STRICT_FMUL : ISD::FMUL; + case VP_FDIV: return hasFPExcept ? ISD::STRICT_FDIV : ISD::FDIV; + case VP_FREM: return hasFPExcept ?
ISD::STRICT_FREM : ISD::FREM; + + case VP_REDUCE_AND: return VECREDUCE_AND; + case VP_REDUCE_OR: return VECREDUCE_OR; + case VP_REDUCE_XOR: return VECREDUCE_XOR; + case VP_REDUCE_ADD: return VECREDUCE_ADD; + case VP_REDUCE_MUL: return VECREDUCE_MUL; + case VP_REDUCE_FADD: return VECREDUCE_FADD; + case VP_REDUCE_FMUL: return VECREDUCE_FMUL; + case VP_REDUCE_FMAX: return VECREDUCE_FMAX; + case VP_REDUCE_FMIN: return VECREDUCE_FMIN; + case VP_REDUCE_UMAX: return VECREDUCE_UMAX; + case VP_REDUCE_UMIN: return VECREDUCE_UMIN; + case VP_REDUCE_SMAX: return VECREDUCE_SMAX; + case VP_REDUCE_SMIN: return VECREDUCE_SMIN; + + case VP_STORE: return ISD::MSTORE; + case VP_LOAD: return ISD::MLOAD; + case VP_GATHER: return ISD::MGATHER; + case VP_SCATTER: return ISD::MSCATTER; + + case VP_FMA: return hasFPExcept ? ISD::STRICT_FMA : ISD::FMA; + } +} + +unsigned +ISD::GetVPForFunctionOpCode(unsigned OpCode) { + switch (OpCode) { + default: llvm_unreachable("can not translate this Opcode to VP"); + + case VSELECT: return ISD::VP_SELECT; + case ADD: return ISD::VP_ADD; + case SUB: return ISD::VP_SUB; + case MUL: return ISD::VP_MUL; + case SDIV: return ISD::VP_SDIV; + case SREM: return ISD::VP_SREM; + case UDIV: return ISD::VP_UDIV; + case UREM: return ISD::VP_UREM; + + case AND: return ISD::VP_AND; + case OR: return ISD::VP_OR; + case XOR: return ISD::VP_XOR; + case SHL: return ISD::VP_SHL; + case SRA: return ISD::VP_SRA; + case SRL: return ISD::VP_SRL; + + case FNEG: return ISD::VP_FNEG; + case STRICT_FADD: + case FADD: return ISD::VP_FADD; + case STRICT_FSUB: + case FSUB: return ISD::VP_FSUB; + case STRICT_FMUL: + case FMUL: return ISD::VP_FMUL; + case STRICT_FDIV: + case FDIV: return ISD::VP_FDIV; + case STRICT_FREM: + case FREM: return ISD::VP_FREM; + + case STRICT_FMA: + case FMA: return ISD::VP_FMA; + } +} + + //===----------------------------------------------------------------------===// // SDNode Profile Support //===----------------------------------------------------------------------===// @@ -558,6 +769,34 @@
ID.AddInteger(ST->getPointerInfo().getAddrSpace()); break; } + case ISD::VP_LOAD: { + const VPLoadSDNode *ELD = cast(N); + ID.AddInteger(ELD->getMemoryVT().getRawBits()); + ID.AddInteger(ELD->getRawSubclassData()); + ID.AddInteger(ELD->getPointerInfo().getAddrSpace()); + break; + } + case ISD::VP_STORE: { + const VPStoreSDNode *EST = cast(N); + ID.AddInteger(EST->getMemoryVT().getRawBits()); + ID.AddInteger(EST->getRawSubclassData()); + ID.AddInteger(EST->getPointerInfo().getAddrSpace()); + break; + } + case ISD::VP_GATHER: { + const VPGatherSDNode *EG = cast(N); + ID.AddInteger(EG->getMemoryVT().getRawBits()); + ID.AddInteger(EG->getRawSubclassData()); + ID.AddInteger(EG->getPointerInfo().getAddrSpace()); + break; + } + case ISD::VP_SCATTER: { + const VPScatterSDNode *ES = cast(N); + ID.AddInteger(ES->getMemoryVT().getRawBits()); + ID.AddInteger(ES->getRawSubclassData()); + ID.AddInteger(ES->getPointerInfo().getAddrSpace()); + break; + } case ISD::MLOAD: { const MaskedLoadSDNode *MLD = cast(N); ID.AddInteger(MLD->getMemoryVT().getRawBits()); @@ -6989,6 +7228,144 @@ return V; } +SDValue SelectionDAG::getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain, + SDValue Ptr, SDValue Mask, SDValue VLen, + EVT MemVT, MachineMemOperand *MMO, + ISD::LoadExtType ExtTy) { + SDVTList VTs = getVTList(VT, MVT::Other); + SDValue Ops[] = { Chain, Ptr, Mask, VLen }; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_LOAD, VTs, Ops); + // Hash the memory VT, not the result VT: AddNodeIDCustom profiles VP_LOAD + // nodes via getMemoryVT(), and both IDs must agree (they differ for extloads). + ID.AddInteger(MemVT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, ExtTy, MemVT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), VTs, + ExtTy, MemVT, MMO); + createOperands(N, Ops); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +
+SDValue SelectionDAG::getStoreVP(SDValue Chain, const SDLoc &dl, + SDValue Val, SDValue Ptr, SDValue Mask, + SDValue VLen, EVT MemVT, MachineMemOperand *MMO, + bool IsTruncating) { + assert(Chain.getValueType() == MVT::Other && + "Invalid chain type"); + SDVTList VTs = getVTList(MVT::Other); + SDValue Ops[] = { Chain, Val, Ptr, Mask, VLen }; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_STORE, VTs, Ops); // must match the created node's opcode (AddNodeIDCustom's VP_STORE case) + ID.AddInteger(MemVT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, IsTruncating, MemVT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), VTs, + IsTruncating, MemVT, MMO); + createOperands(N, Ops); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +SDValue SelectionDAG::getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, + MachineMemOperand *MMO, + ISD::MemIndexType IndexType) { + assert(Ops.size() == 6 && "Incompatible number of operands"); + + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_GATHER, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, VT, MMO, IndexType)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), + VTs, VT, MMO, IndexType); + createOperands(N, Ops); + + assert(N->getMask().getValueType().getVectorNumElements() == + N->getValueType(0).getVectorNumElements() && + "Vector width mismatch between mask and data"); + assert(N->getIndex().getValueType().getVectorNumElements() >= + N->getValueType(0).getVectorNumElements() && + "Vector width
mismatch between index and data"); + assert(isa(N->getScale()) && + cast(N->getScale())->getAPIntValue().isPowerOf2() && + "Scale should be a constant power of 2"); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +SDValue SelectionDAG::getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, + MachineMemOperand *MMO, + ISD::MemIndexType IndexType) { + assert(Ops.size() == 7 && "Incompatible number of operands"); + + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_SCATTER, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, VT, MMO, IndexType)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), + VTs, VT, MMO, IndexType); + createOperands(N, Ops); + + assert(N->getMask().getValueType().getVectorNumElements() == + N->getValue().getValueType().getVectorNumElements() && + "Vector width mismatch between mask and data"); + assert(N->getIndex().getValueType().getVectorNumElements() >= + N->getValue().getValueType().getVectorNumElements() && + "Vector width mismatch between index and data"); + assert(isa(N->getScale()) && + cast(N->getScale())->getAPIntValue().isPowerOf2() && + "Scale should be a constant power of 2"); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, SDValue Mask, EVT MemVT, MachineMemOperand *MMO, diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h +++
b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h @@ -748,6 +748,12 @@ void visitIntrinsicCall(const CallInst &I, unsigned Intrinsic); void visitTargetIntrinsic(const CallInst &I, unsigned Intrinsic); void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI); + void visitVectorPredicationIntrinsic(const VPIntrinsic &VPI); + void visitCmpVP(const VPIntrinsic &I); + void visitLoadVP(const CallInst &I); + void visitStoreVP(const CallInst &I); + void visitGatherVP(const CallInst &I); + void visitScatterVP(const CallInst &I); void visitVAStart(const CallInst &I); void visitVAArg(const VAArgInst &I); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -4313,6 +4313,46 @@ setValue(&I, StoreNode); } +void SelectionDAGBuilder::visitStoreVP(const CallInst &I) { // Lowers a vp.store call to a VP_STORE node and makes it the new DAG root. + SDLoc sdl = getCurSDLoc(); + + auto getVPStoreOps = [&](Value* &Ptr, Value* &Mask, Value* &Src0, + Value * &VLen, unsigned & Alignment) { + // @llvm.vp.store.*(Src0, Ptr, Mask, VLen) + Src0 = I.getArgOperand(0); + Ptr = I.getArgOperand(1); + Alignment = I.getParamAlignment(1); + Mask = I.getArgOperand(2); + VLen = I.getArgOperand(3); + }; + + Value *PtrOperand, *MaskOperand, *Src0Operand, *VLenOperand; + unsigned Alignment = 0; + getVPStoreOps(PtrOperand, MaskOperand, Src0Operand, VLenOperand, Alignment); + + SDValue Ptr = getValue(PtrOperand); + SDValue Src0 = getValue(Src0Operand); + SDValue Mask = getValue(MaskOperand); + SDValue VLen = getValue(VLenOperand); + + EVT VT = Src0.getValueType(); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + + MachineMemOperand *MMO = + DAG.getMachineFunction().
+ getMachineMemOperand(MachinePointerInfo(PtrOperand), + MachineMemOperand::MOStore, VT.getStoreSize(), + Alignment, AAInfo); + SDValue StoreNode = DAG.getStoreVP(getRoot(), sdl, Src0, Ptr, Mask, VLen, VT, + MMO, false /* Truncating */); + DAG.setRoot(StoreNode); + setValue(&I, StoreNode); +} + // Get a uniform base for the Gather/Scatter intrinsic. // The first argument of the Gather/Scatter intrinsic is a vector of pointers. // We try to represent it as a base pointer + vector of indices. @@ -4547,6 +4587,160 @@ setValue(&I, Gather); } +void SelectionDAGBuilder::visitGatherVP(const CallInst &I) { // Lowers a vp.gather call to a VP_GATHER node; its chain joins PendingLoads unless the memory is known constant. + SDLoc sdl = getCurSDLoc(); + + // @llvm.vp.gather.*(Ptrs, Mask, VLen) + const Value *Ptr = I.getArgOperand(0); + SDValue Mask = getValue(I.getArgOperand(1)); + SDValue VLen = getValue(I.getArgOperand(2)); + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); + unsigned Alignment = I.getParamAlignment(0); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); + + SDValue Root = DAG.getRoot(); + SDValue Base; + SDValue Index; + ISD::MemIndexType IndexType = ISD::SIGNED_SCALED; // Initialized: getGatherVP consumes it even when getUniformBase fails. + SDValue Scale; + const Value *BasePtr = Ptr; + bool UniformBase = getUniformBase(BasePtr, Base, Index, IndexType, Scale, this); + bool ConstantMemory = false; + if (UniformBase && AA && + AA->pointsToConstantMemory( + MemoryLocation(BasePtr, + LocationSize::precise( + DAG.getDataLayout().getTypeStoreSize(I.getType())), + AAInfo))) { + // Do not serialize (non-volatile) loads of constant memory with anything. + Root = DAG.getEntryNode(); + ConstantMemory = true; + } + + MachineMemOperand *MMO = + DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(UniformBase ?
BasePtr : nullptr), + MachineMemOperand::MOLoad, VT.getStoreSize(), + Alignment, AAInfo, Ranges); + + if (!UniformBase) { + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Index = getValue(Ptr); + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); + } + SDValue Ops[] = { Root, Base, Index, Scale, Mask, VLen }; + SDValue Gather = DAG.getGatherVP(DAG.getVTList(VT, MVT::Other), VT, sdl, Ops, MMO, IndexType); + + SDValue OutChain = Gather.getValue(1); + if (!ConstantMemory) + PendingLoads.push_back(OutChain); + setValue(&I, Gather); +} + +void SelectionDAGBuilder::visitScatterVP(const CallInst &I) { // Lowers a vp.scatter call to a VP_SCATTER node and makes it the new DAG root. + SDLoc sdl = getCurSDLoc(); + + // @llvm.vp.scatter.*(Src0, Ptrs, Mask, VLen) + const Value *Ptr = I.getArgOperand(1); + SDValue Src0 = getValue(I.getArgOperand(0)); + SDValue Mask = getValue(I.getArgOperand(2)); + SDValue VLen = getValue(I.getArgOperand(3)); + EVT VT = Src0.getValueType(); + unsigned Alignment = I.getParamAlignment(1); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + + SDValue Base; + SDValue Index; + ISD::MemIndexType IndexType = ISD::SIGNED_SCALED; // Initialized: getScatterVP hashes it into the node even when getUniformBase fails. + SDValue Scale; + const Value *BasePtr = Ptr; + bool UniformBase = getUniformBase(BasePtr, Base, Index, IndexType, Scale, this); + + const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr; + MachineMemOperand *MMO = DAG.getMachineFunction().
+ getMachineMemOperand(MachinePointerInfo(MemOpBasePtr), + MachineMemOperand::MOStore, VT.getStoreSize(), + Alignment, AAInfo); + if (!UniformBase) { + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Index = getValue(Ptr); + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); + } + SDValue Ops[] = { getRoot(), Src0, Base, Index, Scale, Mask, VLen }; + SDValue Scatter = DAG.getScatterVP(DAG.getVTList(MVT::Other), VT, sdl, + Ops, MMO, IndexType); + DAG.setRoot(Scatter); + setValue(&I, Scatter); +} + +void SelectionDAGBuilder::visitLoadVP(const CallInst &I) { // Lowers a vp.load call to a VP_LOAD node; its chain joins PendingLoads unless the memory is known constant. + SDLoc sdl = getCurSDLoc(); + + auto getMaskedLoadOps = [&](Value* &Ptr, Value* &Mask, Value* &VLen, + unsigned& Alignment) { + // @llvm.vp.load.*(Ptr, Mask, VLen) + Ptr = I.getArgOperand(0); + Alignment = I.getParamAlignment(0); + Mask = I.getArgOperand(1); + VLen = I.getArgOperand(2); + }; + + Value *PtrOperand, *MaskOperand, *VLenOperand; + unsigned Alignment = 0; + getMaskedLoadOps(PtrOperand, MaskOperand, VLenOperand, Alignment); + + SDValue Ptr = getValue(PtrOperand); + SDValue VLen = getValue(VLenOperand); + SDValue Mask = getValue(MaskOperand); + + // infer the return type + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + SmallVector ValValueVTs; + ComputeValueVTs(TLI, DAG.getDataLayout(), I.getType(), ValValueVTs); + assert((ValValueVTs.size() == 1) && "splitting not implemented"); + EVT VT = ValValueVTs[0]; + + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); + + // Do not serialize masked loads of constant memory with anything. + bool AddToChain = + !AA || !AA->pointsToConstantMemory(MemoryLocation( + PtrOperand, + LocationSize::precise( + DAG.getDataLayout().getTypeStoreSize(I.getType())), + AAInfo)); + SDValue InChain = AddToChain ? DAG.getRoot() : DAG.getEntryNode(); + + MachineMemOperand *MMO = + DAG.getMachineFunction().
+ getMachineMemOperand(MachinePointerInfo(PtrOperand), + MachineMemOperand::MOLoad, VT.getStoreSize(), + Alignment, AAInfo, Ranges); + + SDValue Load = DAG.getLoadVP(VT, sdl, InChain, Ptr, Mask, VLen, VT, MMO, + ISD::NON_EXTLOAD); + if (AddToChain) + PendingLoads.push_back(Load.getValue(1)); + setValue(&I, Load); +} + void SelectionDAGBuilder::visitAtomicCmpXchg(const AtomicCmpXchgInst &I) { SDLoc dl = getCurSDLoc(); AtomicOrdering SuccessOrdering = I.getSuccessOrdering(); @@ -6154,6 +6348,64 @@ case Intrinsic::experimental_constrained_trunc: visitConstrainedFPIntrinsic(cast(I)); return; + + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + + case Intrinsic::vp_select: + case Intrinsic::vp_compose: + case Intrinsic::vp_compress: + case Intrinsic::vp_expand: + case Intrinsic::vp_vshift: + + case Intrinsic::vp_load: + case Intrinsic::vp_store: + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: + + case Intrinsic::vp_fneg: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + + case Intrinsic::vp_fma: + + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + + case Intrinsic::vp_fcmp: + case Intrinsic::vp_icmp: + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmax: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmul: + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_umax: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_smin: + visitVectorPredicationIntrinsic(cast(I)); + return; + case Intrinsic::fmuladd: { EVT VT =
TLI.getValueType(DAG.getDataLayout(), I.getType()); if (TM.Options.AllowFPOpFusion != FPOpFusion::Strict && @@ -7020,8 +7272,7 @@ SDVTList VTs = DAG.getVTList(ValueVTs); SDValue Result = DAG.getNode(Opcode, sdl, VTs, Opers); - if (FPI.getExceptionBehavior() != - ConstrainedFPIntrinsic::ExceptionBehavior::ebIgnore) { + if (FPI.getExceptionBehavior() != ExceptionBehavior::ebIgnore) { SDNodeFlags Flags; Flags.setFPExcept(true); Result->setFlags(Flags); @@ -7034,6 +7285,206 @@ setValue(&FPI, FPResult); } +void SelectionDAGBuilder::visitCmpVP(const VPIntrinsic &I) { + ISD::CondCode Condition; + CmpInst::Predicate predicate = I.getCmpPredicate(); + bool IsFP = I.getOperand(0)->getType()->isFPOrFPVectorTy(); + if (IsFP) { + Condition = getFCmpCondCode(predicate); + auto *FPMO = dyn_cast(&I); + if ((FPMO && FPMO->hasNoNaNs()) || TM.Options.NoNaNsFPMath) + Condition = getFCmpCodeWithoutNaN(Condition); + + } else { + Condition = getICmpCondCode(predicate); + } + + SDValue Op1 = getValue(I.getOperand(0)); + SDValue Op2 = getValue(I.getOperand(1)); + + EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(), + I.getType()); + setValue(&I, DAG.getSetCC(getCurSDLoc(), DestVT, Op1, Op2, Condition)); // NOTE(review): mask and EVL are not forwarded, so every lane is compared; assumes disabled lanes have unspecified results -- confirm. +} + +void SelectionDAGBuilder::visitVectorPredicationIntrinsic( + const VPIntrinsic &VPIntrin) { // Memory and compare VP intrinsics are delegated; all others map 1:1 onto a VP_* ISD opcode. + SDLoc sdl = getCurSDLoc(); + unsigned Opcode; + switch (VPIntrin.getIntrinsicID()) { + default: + llvm_unreachable("Unforeseen intrinsic"); // Can't reach here.
+ + case Intrinsic::vp_load: + visitLoadVP(VPIntrin); + return; + case Intrinsic::vp_store: + visitStoreVP(VPIntrin); + return; + case Intrinsic::vp_gather: + visitGatherVP(VPIntrin); + return; + case Intrinsic::vp_scatter: + visitScatterVP(VPIntrin); + return; + + case Intrinsic::vp_fcmp: + case Intrinsic::vp_icmp: + visitCmpVP(VPIntrin); + return; + + case Intrinsic::vp_add: + Opcode = ISD::VP_ADD; + break; + case Intrinsic::vp_sub: + Opcode = ISD::VP_SUB; + break; + case Intrinsic::vp_mul: + Opcode = ISD::VP_MUL; + break; + case Intrinsic::vp_udiv: + Opcode = ISD::VP_UDIV; + break; + case Intrinsic::vp_sdiv: + Opcode = ISD::VP_SDIV; + break; + case Intrinsic::vp_urem: + Opcode = ISD::VP_UREM; + break; + case Intrinsic::vp_srem: + Opcode = ISD::VP_SREM; + break; + + case Intrinsic::vp_and: + Opcode = ISD::VP_AND; + break; + case Intrinsic::vp_or: + Opcode = ISD::VP_OR; + break; + case Intrinsic::vp_xor: + Opcode = ISD::VP_XOR; + break; + case Intrinsic::vp_ashr: + Opcode = ISD::VP_SRA; + break; + case Intrinsic::vp_lshr: + Opcode = ISD::VP_SRL; + break; + case Intrinsic::vp_shl: + Opcode = ISD::VP_SHL; + break; + + case Intrinsic::vp_fneg: + Opcode = ISD::VP_FNEG; + break; + case Intrinsic::vp_fadd: + Opcode = ISD::VP_FADD; + break; + case Intrinsic::vp_fsub: + Opcode = ISD::VP_FSUB; + break; + case Intrinsic::vp_fmul: + Opcode = ISD::VP_FMUL; + break; + case Intrinsic::vp_fdiv: + Opcode = ISD::VP_FDIV; + break; + case Intrinsic::vp_frem: + Opcode = ISD::VP_FREM; + break; + + case Intrinsic::vp_fma: + Opcode = ISD::VP_FMA; + break; + + case Intrinsic::vp_select: + Opcode = ISD::VP_SELECT; + break; + case Intrinsic::vp_compose: + Opcode = ISD::VP_COMPOSE; + break; + case Intrinsic::vp_compress: + Opcode = ISD::VP_COMPRESS; + break; + case Intrinsic::vp_expand: + Opcode = ISD::VP_EXPAND; + break; + case Intrinsic::vp_vshift: + Opcode = ISD::VP_VSHIFT; + break; + + case Intrinsic::vp_reduce_and: + Opcode = ISD::VP_REDUCE_AND; + break; + case Intrinsic::vp_reduce_or:
+ Opcode = ISD::VP_REDUCE_OR; + break; + case Intrinsic::vp_reduce_xor: + Opcode = ISD::VP_REDUCE_XOR; + break; + case Intrinsic::vp_reduce_add: + Opcode = ISD::VP_REDUCE_ADD; + break; + case Intrinsic::vp_reduce_mul: + Opcode = ISD::VP_REDUCE_MUL; + break; + case Intrinsic::vp_reduce_fadd: + Opcode = ISD::VP_REDUCE_FADD; + break; + case Intrinsic::vp_reduce_fmul: + Opcode = ISD::VP_REDUCE_FMUL; + break; + case Intrinsic::vp_reduce_smax: + Opcode = ISD::VP_REDUCE_SMAX; + break; + case Intrinsic::vp_reduce_smin: + Opcode = ISD::VP_REDUCE_SMIN; + break; + case Intrinsic::vp_reduce_umax: + Opcode = ISD::VP_REDUCE_UMAX; + break; + case Intrinsic::vp_reduce_umin: + Opcode = ISD::VP_REDUCE_UMIN; + break; + case Intrinsic::vp_reduce_fmax: + Opcode = ISD::VP_REDUCE_FMAX; + break; + case Intrinsic::vp_reduce_fmin: + Opcode = ISD::VP_REDUCE_FMIN; + break; + } + + // TODO memory evl: SDValue Chain = getRoot(); + + SmallVector ValueVTs; + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + ComputeValueVTs(TLI, DAG.getDataLayout(), VPIntrin.getType(), ValueVTs); + SDVTList VTs = DAG.getVTList(ValueVTs); + + // ValueVTs.push_back(MVT::Other); // Out chain + + // Lower the value operands in order. Metadata arguments (the constrained + // rounding-mode/exception-behavior operands that VPIntrinsic exposes via + // getRoundingMode/getExceptionBehavior) have no SDValue form and are + // skipped instead of being passed to getValue(). + SmallVector<SDValue, 7> OpValues; + for (const Value *Arg : VPIntrin.arg_operands()) + if (!Arg->getType()->isMetadataTy()) + OpValues.push_back(getValue(Arg)); + SDValue Result = DAG.getNode(Opcode, sdl, VTs, OpValues); + + if (Result.getNode()->getNumValues() == 2) { + // this VP node has a chain + SDValue OutChain = Result.getValue(1); + DAG.setRoot(OutChain); + SDValue VPResult =
Result.getValue(0); + setValue(&VPIntrin, VPResult); + } else { + // this is a pure node + setValue(&VPIntrin, Result); + } +} + std::pair SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI, const BasicBlock *EHPadBB) { diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -439,6 +439,65 @@ case ISD::VECREDUCE_UMIN: return "vecreduce_umin"; case ISD::VECREDUCE_FMAX: return "vecreduce_fmax"; case ISD::VECREDUCE_FMIN: return "vecreduce_fmin"; + + // Vector Predication (explicit vector length) extension + // VP Memory + case ISD::VP_LOAD: return "vp_load"; + case ISD::VP_STORE: return "vp_store"; + case ISD::VP_GATHER: return "vp_gather"; + case ISD::VP_SCATTER: return "vp_scatter"; + + // VP Unary operators + case ISD::VP_FNEG: return "vp_fneg"; + + // VP Binary operators + case ISD::VP_ADD: return "vp_add"; + case ISD::VP_SUB: return "vp_sub"; + case ISD::VP_MUL: return "vp_mul"; + case ISD::VP_SDIV: return "vp_sdiv"; + case ISD::VP_UDIV: return "vp_udiv"; + case ISD::VP_SREM: return "vp_srem"; + case ISD::VP_UREM: return "vp_urem"; + case ISD::VP_AND: return "vp_and"; + case ISD::VP_OR: return "vp_or"; + case ISD::VP_XOR: return "vp_xor"; + case ISD::VP_SHL: return "vp_shl"; + case ISD::VP_SRA: return "vp_sra"; + case ISD::VP_SRL: return "vp_srl"; + case ISD::VP_FADD: return "vp_fadd"; + case ISD::VP_FSUB: return "vp_fsub"; + case ISD::VP_FMUL: return "vp_fmul"; + case ISD::VP_FDIV: return "vp_fdiv"; + case ISD::VP_FREM: return "vp_frem"; + + // VP comparison + case ISD::VP_SETCC: return "vp_setcc"; + + // VP ternary operators + case ISD::VP_FMA: return "vp_fma"; + + // VP shuffle + case ISD::VP_VSHIFT: return "vp_vshift"; + case ISD::VP_COMPRESS: return "vp_compress"; + case ISD::VP_EXPAND: return "vp_expand"; + + case ISD::VP_COMPOSE: return "vp_compose"; + case
ISD::VP_SELECT: return "vp_select"; + + // VP reduction operators + case ISD::VP_REDUCE_FADD: return "vp_reduce_fadd"; + case ISD::VP_REDUCE_FMUL: return "vp_reduce_fmul"; + case ISD::VP_REDUCE_ADD: return "vp_reduce_add"; + case ISD::VP_REDUCE_MUL: return "vp_reduce_mul"; + case ISD::VP_REDUCE_AND: return "vp_reduce_and"; + case ISD::VP_REDUCE_OR: return "vp_reduce_or"; + case ISD::VP_REDUCE_XOR: return "vp_reduce_xor"; + case ISD::VP_REDUCE_SMAX: return "vp_reduce_smax"; + case ISD::VP_REDUCE_SMIN: return "vp_reduce_smin"; + case ISD::VP_REDUCE_UMAX: return "vp_reduce_umax"; + case ISD::VP_REDUCE_UMIN: return "vp_reduce_umin"; + case ISD::VP_REDUCE_FMAX: return "vp_reduce_fmax"; + case ISD::VP_REDUCE_FMIN: return "vp_reduce_fmin"; } } diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp --- a/llvm/lib/CodeGen/TargetPassConfig.cpp +++ b/llvm/lib/CodeGen/TargetPassConfig.cpp @@ -672,6 +672,11 @@ // Instrument function entry and exit, e.g. with calls to mcount(). addPass(createPostInlineEntryExitInstrumenterPass()); + // Expand vector predication intrinsics into standard IR instructions. + // This pass has to run before ScalarizeMaskedMemIntrin and ExpandReduction + // passes since it emits those kinds of intrinsics. + addPass(createExpandVectorPredicationPass()); + // Add scalarization of target's unsupported masked memory intrinsics pass. // the unsupported intrinsic will be replaced with a chain of basic blocks, // that stores/loads element one-by-one if the appropriate mask bit is set.
diff --git a/llvm/lib/IR/Attributes.cpp b/llvm/lib/IR/Attributes.cpp --- a/llvm/lib/IR/Attributes.cpp +++ b/llvm/lib/IR/Attributes.cpp @@ -290,6 +290,8 @@ return "builtin"; if (hasAttribute(Attribute::Convergent)) return "convergent"; + if (hasAttribute(Attribute::VectorLength)) + return "vlen"; if (hasAttribute(Attribute::SwiftError)) return "swifterror"; if (hasAttribute(Attribute::SwiftSelf)) @@ -306,6 +308,10 @@ return "inreg"; if (hasAttribute(Attribute::JumpTable)) return "jumptable"; + if (hasAttribute(Attribute::Mask)) + return "mask"; + if (hasAttribute(Attribute::Passthru)) + return "passthru"; if (hasAttribute(Attribute::MinSize)) return "minsize"; if (hasAttribute(Attribute::Naked)) diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt --- a/llvm/lib/IR/CMakeLists.txt +++ b/llvm/lib/IR/CMakeLists.txt @@ -46,18 +46,19 @@ PassManager.cpp PassRegistry.cpp PassTimingInfo.cpp + PredicatedInst.cpp + ProfileSummary.cpp RemarkStreamer.cpp SafepointIRVerifier.cpp - ProfileSummary.cpp Statepoint.cpp Type.cpp TypeFinder.cpp Use.cpp User.cpp + VPBuilder.cpp Value.cpp ValueSymbolTable.cpp Verifier.cpp - ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/IR diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp --- a/llvm/lib/IR/IRBuilder.cpp +++ b/llvm/lib/IR/IRBuilder.cpp @@ -511,6 +511,62 @@ return createCallHelper(TheFn, Ops, this, Name); } + +/// Create a call to a vector-predicated intrinsic (VP).
+/// \p OC - The LLVM IR Opcode of the operation +/// \p Params - Intrinsic operand list +/// \p FMFSource - Copy source for Fast Math Flags +/// \p Name - name of the result variable +Instruction *IRBuilderBase::CreateVectorPredicatedInst( + unsigned OC, ArrayRef Params, Instruction *FMFSource, + const Twine &Name) { + + Module *M = BB->getParent()->getParent(); + + Intrinsic::ID VPID = VPIntrinsic::GetForOpcode(OC); + auto VPFunc = VPIntrinsic::GetDeclarationForParams(M, VPID, Params); + auto *VPCall = createCallHelper(VPFunc, Params, this, Name); + + // transfer fast math flags + if (FMFSource && isa(FMFSource)) { + VPCall->copyFastMathFlags(FMFSource); + } + + return VPCall; +} + +/// Create a call to a vector-predicated comparison intrinsic (VP). +/// \p Pred - comparison predicate +/// \p FirstParam - First vector operand +/// \p SndParam - Second vector operand +/// \p MaskParam - Mask operand +/// \p VectorLengthParam - Vector length operand +/// \p Name - name of the result variable +Instruction *IRBuilderBase::CreateVectorPredicatedCmp(CmpInst::Predicate Pred, + Value *FirstParam, + Value *SndParam, Value *MaskParam, + Value *VectorLengthParam, + const Twine &Name) { + + Module *M = BB->getParent()->getParent(); + + // Encode the comparison predicate as an i8 constant operand (argument 2). + uint8_t RawPred = static_cast(Pred); + auto Int8Ty = Type::getInt8Ty(getContext()); + auto PredParam = ConstantInt::get(Int8Ty, RawPred, false); + + Intrinsic::ID VPID = FirstParam->getType()->isIntOrIntVectorTy() + ? Intrinsic::vp_icmp + : Intrinsic::vp_fcmp; + + auto VPFunc = VPIntrinsic::GetDeclarationForParams( + M, VPID, {FirstParam, SndParam, PredParam, MaskParam, VectorLengthParam}); + + return createCallHelper( + VPFunc, {FirstParam, SndParam, PredParam, MaskParam, VectorLengthParam}, + this, Name); +} + /// Create a call to a Masked Gather intrinsic.
/// \p Ptrs - vector of pointers for loading /// \p Align - alignment for one element diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp --- a/llvm/lib/IR/IntrinsicInst.cpp +++ b/llvm/lib/IR/IntrinsicInst.cpp @@ -21,13 +21,13 @@ //===----------------------------------------------------------------------===// #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Operator.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; @@ -102,55 +102,918 @@ return ConstantInt::get(Type::getInt64Ty(Context), 1); } -Optional -ConstrainedFPIntrinsic::getRoundingMode() const { - unsigned NumOperands = getNumArgOperands(); - Metadata *MD = - cast(getArgOperand(NumOperands - 2))->getMetadata(); - if (!MD || !isa(MD)) - return None; - return StrToRoundingMode(cast(MD)->getString()); +Optional +llvm::StrToExceptionBehavior(StringRef ExceptionArg) { // Parses an "fpexcept.*" string; None if it is not recognized. + return StringSwitch>(ExceptionArg) + .Case("fpexcept.ignore", ExceptionBehavior::ebIgnore) + .Case("fpexcept.maytrap", ExceptionBehavior::ebMayTrap) + .Case("fpexcept.strict", ExceptionBehavior::ebStrict) + .Default(None); } -Optional -ConstrainedFPIntrinsic::StrToRoundingMode(StringRef RoundingArg) { +Optional llvm::ExceptionBehaviorToStr(ExceptionBehavior UseExcept) { // Inverse of StrToExceptionBehavior; None for a value it does not name. + Optional ExceptStr = None; + switch (UseExcept) { + default: + break; + case ExceptionBehavior::ebStrict: + ExceptStr = "fpexcept.strict"; + break; + case ExceptionBehavior::ebIgnore: + ExceptStr = "fpexcept.ignore"; + break; + case ExceptionBehavior::ebMayTrap: + ExceptStr = "fpexcept.maytrap"; + break; + } + return ExceptStr; +} + +Optional llvm::StrToRoundingMode(StringRef RoundingArg) { // Parses a "round.*" string; None if it is not recognized. 
// For dynamic rounding mode, we use round to nearest but we will set the // 'exact' SDNodeFlag so that the value will not
be rounded. return StringSwitch>(RoundingArg) - .Case("round.dynamic", rmDynamic) - .Case("round.tonearest", rmToNearest) - .Case("round.downward", rmDownward) - .Case("round.upward", rmUpward) - .Case("round.towardzero", rmTowardZero) - .Default(None); + .Case("round.dynamic", RoundingMode::rmDynamic) + .Case("round.tonearest", RoundingMode::rmToNearest) + .Case("round.downward", RoundingMode::rmDownward) + .Case("round.upward", RoundingMode::rmUpward) + .Case("round.towardzero", RoundingMode::rmTowardZero) + .Default(None); } -Optional -ConstrainedFPIntrinsic::RoundingModeToStr(RoundingMode UseRounding) { +Optional llvm::RoundingModeToStr(RoundingMode UseRounding) { // Inverse of StrToRoundingMode; None for a value it does not name. Optional RoundingStr = None; switch (UseRounding) { - case ConstrainedFPIntrinsic::rmDynamic: + default: + break; + case RoundingMode::rmDynamic: RoundingStr = "round.dynamic"; break; - case ConstrainedFPIntrinsic::rmToNearest: + case RoundingMode::rmToNearest: RoundingStr = "round.tonearest"; break; - case ConstrainedFPIntrinsic::rmDownward: + case RoundingMode::rmDownward: RoundingStr = "round.downward"; break; - case ConstrainedFPIntrinsic::rmUpward: + case RoundingMode::rmUpward: RoundingStr = "round.upward"; break; - case ConstrainedFPIntrinsic::rmTowardZero: + case RoundingMode::rmTowardZero: RoundingStr = "round.towardzero"; break; } return RoundingStr; } -Optional +/// Return \p UseExcept as a metadata operand wrapping its "fpexcept.*" string; asserts if the value has no string form. +Value *llvm::GetConstrainedFPExcept(LLVMContext &Context, + ExceptionBehavior UseExcept) { + Optional ExceptStr = ExceptionBehaviorToStr(UseExcept); + assert(ExceptStr.hasValue() && "Garbage strict exception behavior!"); + auto *ExceptMDS = MDString::get(Context, ExceptStr.getValue()); + + return MetadataAsValue::get(Context, ExceptMDS); +} + +/// Return \p UseRounding as a metadata operand wrapping its "round.*" string; asserts if the value has no string form.
+Value *llvm::GetConstrainedFPRounding(LLVMContext &Context, + RoundingMode UseRounding) { + Optional RoundingStr = RoundingModeToStr(UseRounding); + assert(RoundingStr.hasValue() && "Garbage strict rounding mode!"); + auto *RoundingMDS = MDString::get(Context, RoundingStr.getValue()); + + return MetadataAsValue::get(Context, RoundingMDS); +} + + +Optional VPIntrinsic::getStaticVectorLength() const { // Lane count of the mask type (of the result type for vp.compose); None for non-vector or scalable types. + auto GetStaticVectorLengthOfType = [](const Type *T) -> Optional { + auto VT = dyn_cast(T); + if (!VT || VT->isScalable()) + return None; + + // Corner case for excessive number of elements in the vector type + auto Num = VT->getNumElements(); + if (Num >= + static_cast(std::numeric_limits::max())) { + return std::numeric_limits::max(); + } + + return static_cast(Num); + }; + + auto VPMask = getMaskParam(); + if (VPMask) { + return GetStaticVectorLengthOfType(VPMask->getType()); + } + + // only compose does not have a mask param + assert(getIntrinsicID() == Intrinsic::vp_compose && "only vp.compose lacks a mask parameter"); + return GetStaticVectorLengthOfType(getType()); +} + +void VPIntrinsic::setMaskParam(Value *NewMask) { + auto MaskPos = GetMaskParamPos(getIntrinsicID()); + assert(MaskPos.hasValue()); + setArgOperand(MaskPos.getValue(), NewMask); +} + +void VPIntrinsic::setVectorLengthParam(Value *NewVL) { + auto VLPos = GetVectorLengthParamPos(getIntrinsicID()); + assert(VLPos.hasValue()); + setArgOperand(VLPos.getValue(), NewVL); +} + +Value *VPIntrinsic::getMaskParam() const { + auto maskPos = GetMaskParamPos(getIntrinsicID()); + if (maskPos) + return getArgOperand(maskPos.getValue()); + return nullptr; +} + +Value *VPIntrinsic::getVectorLengthParam() const { + auto vlenPos = GetVectorLengthParamPos(getIntrinsicID()); + if (vlenPos) + return getArgOperand(vlenPos.getValue()); + return nullptr; +} + +Optional VPIntrinsic::GetMaskParamPos(Intrinsic::ID IntrinsicID) { // Argument index of the mask operand, or None if the intrinsic has none. 
+ switch (IntrinsicID) { + default: + return None; + + // general cmp + case Intrinsic::vp_icmp: + case Intrinsic::vp_fcmp: + return
3; + + // int arith + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + return 2; + + // memory + case Intrinsic::vp_load: + case Intrinsic::vp_gather: + return 1; + case Intrinsic::vp_store: + case Intrinsic::vp_scatter: + return 2; + + // shuffle + case Intrinsic::vp_select: + return 0; + + case Intrinsic::vp_compose: + return None; + + case Intrinsic::vp_compress: + case Intrinsic::vp_expand: + return 1; + case Intrinsic::vp_vshift: + return 2; + + // fp arith + case Intrinsic::vp_fneg: + return 2; + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + return 4; + + case Intrinsic::vp_fma: + return 5; + + case Intrinsic::vp_ceil: + case Intrinsic::vp_cos: + case Intrinsic::vp_exp2: + case Intrinsic::vp_exp: + case Intrinsic::vp_floor: + case Intrinsic::vp_log10: + case Intrinsic::vp_log2: + case Intrinsic::vp_log: + return 3; + + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + return 4; + + case Intrinsic::vp_nearbyint: + case Intrinsic::vp_pow: // NOTE(review): pow/powi take two value operands like maxnum/minnum, so the mask may belong at 4; confirm against the intrinsic definitions. + case Intrinsic::vp_powi: + case Intrinsic::vp_rint: + case Intrinsic::vp_round: + case Intrinsic::vp_sin: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_trunc: + return 3; + + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_lround: + case Intrinsic::vp_llround: + return 2; + + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + return 3; + + // reductions + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + case 
Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_and: + case
Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + return 1; + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + return 2; + } +} + +Optional VPIntrinsic::GetVectorLengthParamPos(Intrinsic::ID IntrinsicID) { // vp.select and vp.compose keep the EVL at index 3; otherwise it directly follows the mask. + switch (IntrinsicID) { + default: + break; + + case Intrinsic::vp_select: + case Intrinsic::vp_compose: + return 3; + } + + auto maskPos = GetMaskParamPos(IntrinsicID); + if (maskPos) { + return maskPos.getValue() + 1; + } + return None; +} + +Intrinsic::ID VPIntrinsic::GetForOpcode(unsigned OC) { // Maps an IR instruction opcode to its VP intrinsic, or not_intrinsic. + switch (OC) { + default: + return Intrinsic::not_intrinsic; + + // fp unary + case Instruction::FNeg: + return Intrinsic::vp_fneg; + + // fp binary + case Instruction::FAdd: + return Intrinsic::vp_fadd; + case Instruction::FSub: + return Intrinsic::vp_fsub; + case Instruction::FMul: + return Intrinsic::vp_fmul; + case Instruction::FDiv: + return Intrinsic::vp_fdiv; + case Instruction::FRem: + return Intrinsic::vp_frem; + + // sign-oblivious int + case Instruction::Add: + return Intrinsic::vp_add; + case Instruction::Sub: + return Intrinsic::vp_sub; + case Instruction::Mul: + return Intrinsic::vp_mul; + + // signed/unsigned int + case Instruction::SDiv: + return Intrinsic::vp_sdiv; + case Instruction::UDiv: + return Intrinsic::vp_udiv; + case Instruction::SRem: + return Intrinsic::vp_srem; + case Instruction::URem: + return Intrinsic::vp_urem; + + // logical + case Instruction::Or: + return Intrinsic::vp_or; + case Instruction::And: + return Intrinsic::vp_and; + case Instruction::Xor: + return Intrinsic::vp_xor; + + case Instruction::LShr: + return Intrinsic::vp_lshr; + case Instruction::AShr: + return Intrinsic::vp_ashr; + case Instruction::Shl: + return Intrinsic::vp_shl; + + // comparison + case Instruction::ICmp: + return 
Intrinsic::vp_icmp; + case Instruction::FCmp: + return Intrinsic::vp_fcmp; + } +} + +bool VPIntrinsic::canIgnoreVectorLengthParam() const { // True if the EVL is a constant that enables every lane: negative (MSB set, so not effective) or at least the static vector length. + auto
StaticVL = getStaticVectorLength(); + if (!StaticVL.hasValue()) + return false; + + auto *VLParam = getVectorLengthParam(); + assert(VLParam); + + // Check whether the vector length param is an out-of-range constant. + auto VLConst = dyn_cast(VLParam); + if (!VLConst) + return false; + int64_t VLNum = VLConst->getSExtValue(); + if (VLNum < 0 || VLNum >= StaticVL.getValue()) + return true; + + return false; +} + +CmpInst::Predicate VPIntrinsic::getCmpPredicate() const { // The predicate is stored as an integer constant in argument 2. + return static_cast( + cast(getArgOperand(2))->getZExtValue()); +} + +Optional VPIntrinsic::getRoundingMode() const { // None unless the intrinsic has a rounding-mode metadata argument holding a recognized string. + auto RmParamPos = GetRoundingModeParamPos(getIntrinsicID()); + if (!RmParamPos) + return None; + + auto *MDArg = dyn_cast<MetadataAsValue>(getArgOperand(RmParamPos.getValue())); + Metadata *MD = MDArg ? MDArg->getMetadata() : nullptr; + if (!MD || !isa(MD)) + return None; + StringRef RoundingArg = cast(MD)->getString(); + return StrToRoundingMode(RoundingArg); +} + +Optional VPIntrinsic::getExceptionBehavior() const { // None unless the intrinsic has an exception-behavior metadata argument holding a recognized string. + auto EbParamPos = GetExceptionBehaviorParamPos(getIntrinsicID()); + if (!EbParamPos) + return None; + + auto *MDArg = dyn_cast<MetadataAsValue>(getArgOperand(EbParamPos.getValue())); + Metadata *MD = MDArg ? MDArg->getMetadata() : nullptr; + if (!MD || !isa(MD)) + return None; + StringRef ExceptionArg = cast(MD)->getString(); + return StrToExceptionBehavior(ExceptionArg); +} + +/// \return The vector to reduce if this is a reduction operation.
+Value *VPIntrinsic::getReductionVectorParam() const { + auto PosOpt = GetReductionVectorParamPos(getIntrinsicID()); + if (!PosOpt.hasValue()) + return nullptr; + return getArgOperand(PosOpt.getValue()); +} + +Optional VPIntrinsic::GetReductionVectorParamPos(Intrinsic::ID VPID) { + switch (VPID) { + default: + return None; + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + return 0; + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + return 1; + } +} + +Optional VPIntrinsic::GetReductionAccuParamPos(Intrinsic::ID VPID) { // Only vp.reduce.fadd/fmul carry a start value, at index 0. + switch (VPID) { + default: + return None; + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + return 0; + } +} + +/// \return The accumulator initial value if this is a reduction operation. +Value *VPIntrinsic::getReductionAccuParam() const { + auto PosOpt = GetReductionAccuParamPos(getIntrinsicID()); + if (!PosOpt.hasValue()) + return nullptr; + return getArgOperand(PosOpt.getValue()); +} + +/// \return The pointer operand of this load, store, gather or scatter. +Value *VPIntrinsic::getMemoryPointerParam() const { + auto PtrParamOpt = GetMemoryPointerParamPos(getIntrinsicID()); + if (!PtrParamOpt.hasValue()) + return nullptr; + return getArgOperand(PtrParamOpt.getValue()); +} + +Optional VPIntrinsic::GetMemoryPointerParamPos(Intrinsic::ID VPID) { + switch (VPID) { + default: + return None; + + case Intrinsic::vp_load: + return 0; + case Intrinsic::vp_gather: + return 0; + case Intrinsic::vp_store: + return 1; + case Intrinsic::vp_scatter: + return 1; + } +} + +/// \return The data (payload) operand of this store or scatter.
+Value *VPIntrinsic::getMemoryDataParam() const { + auto DataParamOpt = GetMemoryDataParamPos(getIntrinsicID()); + if (!DataParamOpt.hasValue()) + return nullptr; + return getArgOperand(DataParamOpt.getValue()); +} + +Optional VPIntrinsic::GetMemoryDataParamPos(Intrinsic::ID VPID) { + switch (VPID) { + default: + return None; + + case Intrinsic::vp_store: + return 0; + case Intrinsic::vp_scatter: + return 0; + } +} + +Function *VPIntrinsic::GetDeclarationForParams(Module *M, Intrinsic::ID VPID, + ArrayRef Params) { // Derives the overload types of VPID from the actual call operands in Params. + assert(VPID != Intrinsic::not_intrinsic && "todo dispatch to default insts"); + + bool IsArithOp = VPIntrinsic::IsBinaryVPOp(VPID) || + VPIntrinsic::IsUnaryVPOp(VPID) || + VPIntrinsic::IsTernaryVPOp(VPID); + bool IsCmpOp = (VPID == Intrinsic::vp_icmp) || (VPID == Intrinsic::vp_fcmp); + bool IsReduceOp = VPIntrinsic::IsVPReduction(VPID); + bool IsShuffleOp = + (VPID == Intrinsic::vp_compress) || (VPID == Intrinsic::vp_expand) || + (VPID == Intrinsic::vp_vshift) || (VPID == Intrinsic::vp_select) || + (VPID == Intrinsic::vp_compose); + bool IsMemoryOp = + (VPID == Intrinsic::vp_store) || (VPID == Intrinsic::vp_load) || + (VPID == Intrinsic::vp_scatter) || (VPID == Intrinsic::vp_gather); + + Type *VecTy = nullptr; + Type *VecRetTy = nullptr; + Type *VecPtrTy = nullptr; + + if (IsArithOp || IsCmpOp) { + Value &FirstOp = *Params[0]; + + // Fetch the VP intrinsic + VecTy = cast(FirstOp.getType()); + VecRetTy = VecTy; + + } else if (IsReduceOp) { + auto VectorPosOpt = GetReductionVectorParamPos(VPID); + auto AccuPosOpt = GetReductionAccuParamPos(VPID); + Value *VectorParam = Params[VectorPosOpt.getValue()]; + + VecTy = VectorParam->getType(); + + if (AccuPosOpt.hasValue()) { + Value *AccuParam = Params[AccuPosOpt.getValue()]; + VecRetTy = AccuParam->getType(); + } else { + VecRetTy = VecTy; + } + + } else if (IsMemoryOp) { + auto DataPosOpt = VPIntrinsic::GetMemoryDataParamPos(VPID); + auto 
PtrPosOpt = VPIntrinsic::GetMemoryPointerParamPos(VPID); + VecPtrTy =
Params[PtrPosOpt.getValue()]->getType(); + if (DataPosOpt.hasValue()) { + // store-kind operation (vp.store, vp.scatter) + VecTy = Params[DataPosOpt.getValue()]->getType(); + } else if (auto *PtrVecTy = dyn_cast<VectorType>(VecPtrTy)) { + // vp.gather: a vector of pointers yields one data element per lane + VecTy = VectorType::get(PtrVecTy->getElementType()->getPointerElementType(), PtrVecTy->getElementCount()); + } else { + VecTy = VecPtrTy->getPointerElementType(); // load-kind operation (vp.load) + } + } else if (IsShuffleOp) { + VecTy = (VPID == Intrinsic::vp_select) ? Params[1]->getType() + : Params[0]->getType(); + VecRetTy = VecTy; + } + + auto TypeTokens = VPIntrinsic::GetTypeTokens(VPID); + auto *VPFunc = Intrinsic::getDeclaration( + M, VPID, + VPIntrinsic::EncodeTypeTokens(TypeTokens, VecTy, VecPtrTy, *VecTy)); + assert(VPFunc && "not a VP intrinsic"); + + return VPFunc; +} + +VPIntrinsic::TypeTokenVec VPIntrinsic::GetTypeTokens(Intrinsic::ID ID) { + switch (ID) { + default: + llvm_unreachable("not implemented!"); + + case Intrinsic::vp_cos: + case Intrinsic::vp_sin: + case Intrinsic::vp_exp: + case Intrinsic::vp_exp2: + + case Intrinsic::vp_log: + case Intrinsic::vp_log2: + case Intrinsic::vp_log10: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_ceil: + case Intrinsic::vp_floor: + case Intrinsic::vp_round: + case Intrinsic::vp_trunc: + case Intrinsic::vp_rint: + case Intrinsic::vp_nearbyint: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + return TypeTokenVec{VPTypeToken::Returned}; + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + case 
Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + return TypeTokenVec{VPTypeToken::Vector}; + + case Intrinsic::vp_gather: + case Intrinsic::vp_load:
return TypeTokenVec{VPTypeToken::Returned, VPTypeToken::Pointer}; + + case Intrinsic::vp_scatter: + case Intrinsic::vp_store: + return TypeTokenVec{VPTypeToken::Pointer, VPTypeToken::Vector}; + + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_lround: + case Intrinsic::vp_llround: + case Intrinsic::vp_lrint: + case Intrinsic::vp_llrint: + return TypeTokenVec{VPTypeToken::Returned, VPTypeToken::Vector}; + + case Intrinsic::vp_icmp: + case Intrinsic::vp_fcmp: + return TypeTokenVec{VPTypeToken::Mask, VPTypeToken::Vector}; + } +} + +bool VPIntrinsic::isReductionOp() const { + return IsVPReduction(getIntrinsicID()); +} + +bool VPIntrinsic::IsVPReduction(Intrinsic::ID ID) { + switch (ID) { + default: + return false; + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + + return true; + } +} + +bool VPIntrinsic::isConstrainedOp() const { + return (getRoundingMode() != None && + getRoundingMode() != RoundingMode::rmToNearest) || + (getExceptionBehavior() != None && + getExceptionBehavior() != ExceptionBehavior::ebIgnore); +} + +bool VPIntrinsic::isUnaryOp() const { return IsUnaryVPOp(getIntrinsicID()); } + +bool VPIntrinsic::IsUnaryVPOp(Intrinsic::ID VPID) { + return VPID == Intrinsic::vp_fneg; +} + +bool VPIntrinsic::isBinaryOp() const { return IsBinaryVPOp(getIntrinsicID()); } + +bool VPIntrinsic::IsBinaryVPOp(Intrinsic::ID VPID) { + switch (VPID) { + default: + return false; + + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case 
Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + return true; + } +} + +bool VPIntrinsic::IsTernaryVPOp(Intrinsic::ID VPID) { + switch (VPID) { + default: + return false; + + case Intrinsic::vp_fma: + return true; + } +} + +bool VPIntrinsic::isTernaryOp() const { + return IsTernaryVPOp(getIntrinsicID()); +} + +Optional +VPIntrinsic::GetExceptionBehaviorParamPos(Intrinsic::ID IntrinsicID) { + switch (IntrinsicID) { + default: + return None; + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + return 3; + + case Intrinsic::vp_fma: + return 4; + + case Intrinsic::vp_ceil: + case Intrinsic::vp_cos: + case Intrinsic::vp_exp2: + case Intrinsic::vp_exp: + case Intrinsic::vp_floor: + case Intrinsic::vp_log10: + case Intrinsic::vp_log2: + case Intrinsic::vp_log: + return 2; + + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + return 3; + case Intrinsic::vp_nearbyint: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_rint: + case Intrinsic::vp_round: + case Intrinsic::vp_sin: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_trunc: + return 2; + + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + return 2; + + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_lround: + case Intrinsic::vp_llround: + return 1; + } +} + +Optional VPIntrinsic::GetRoundingModeParamPos(Intrinsic::ID IntrinsicID) { + switch (IntrinsicID) { + default: + return None; + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + return 2; + + case 
Intrinsic::vp_fma: + return 3; + + case Intrinsic::vp_ceil: + case Intrinsic::vp_cos: + case Intrinsic::vp_exp2: + case Intrinsic::vp_exp: + case Intrinsic::vp_floor: + case Intrinsic::vp_log10: + case Intrinsic::vp_log2: + case Intrinsic::vp_log: + return 1; + + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + return 2; + case Intrinsic::vp_nearbyint: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_rint: + case Intrinsic::vp_round: + case Intrinsic::vp_sin: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_trunc: + return 1; + + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_lround: + case Intrinsic::vp_llround: + return None; + + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + return 2; + } +} + +Intrinsic::ID +VPIntrinsic::GetForConstrainedIntrinsic(Intrinsic::ID IntrinsicID) { + switch (IntrinsicID) { + default: + return Intrinsic::not_intrinsic; + + // llvm.experimental.constrained.* + case Intrinsic::experimental_constrained_cos: + return Intrinsic::vp_cos; + case Intrinsic::experimental_constrained_sin: + return Intrinsic::vp_sin; + case Intrinsic::experimental_constrained_exp: + return Intrinsic::vp_exp; + case Intrinsic::experimental_constrained_exp2: + return Intrinsic::vp_exp2; + case Intrinsic::experimental_constrained_log: + return Intrinsic::vp_log; + case Intrinsic::experimental_constrained_log2: + return Intrinsic::vp_log2; + case Intrinsic::experimental_constrained_log10: + return Intrinsic::vp_log10; + case Intrinsic::experimental_constrained_sqrt: + return Intrinsic::vp_sqrt; + case Intrinsic::experimental_constrained_ceil: + return Intrinsic::vp_ceil; + case Intrinsic::experimental_constrained_floor: + return Intrinsic::vp_floor; + case Intrinsic::experimental_constrained_round: + return Intrinsic::vp_round; + case Intrinsic::experimental_constrained_trunc: + return Intrinsic::vp_trunc; + case Intrinsic::experimental_constrained_rint: + return Intrinsic::vp_rint; + case 
Intrinsic::experimental_constrained_nearbyint: + return Intrinsic::vp_nearbyint; + + case Intrinsic::experimental_constrained_fadd: + return Intrinsic::vp_fadd; + case Intrinsic::experimental_constrained_fsub: + return Intrinsic::vp_fsub; + case Intrinsic::experimental_constrained_fmul: + return Intrinsic::vp_fmul; + case Intrinsic::experimental_constrained_fdiv: + return Intrinsic::vp_fdiv; + case Intrinsic::experimental_constrained_frem: + return Intrinsic::vp_frem; + case Intrinsic::experimental_constrained_pow: + return Intrinsic::vp_pow; + case Intrinsic::experimental_constrained_powi: + return Intrinsic::vp_powi; + case Intrinsic::experimental_constrained_maxnum: + return Intrinsic::vp_maxnum; + case Intrinsic::experimental_constrained_minnum: + return Intrinsic::vp_minnum; + + case Intrinsic::experimental_constrained_fma: + return Intrinsic::fma; + } +} + +VPIntrinsic::ShortTypeVec +VPIntrinsic::EncodeTypeTokens(VPIntrinsic::TypeTokenVec TTVec, Type *VecRetTy, + Type *VecPtrTy, Type &VectorTy) { + ShortTypeVec STV; + + for (auto Token : TTVec) { + switch (Token) { + default: + llvm_unreachable("unsupported token"); // unsupported VPTypeToken + + case VPIntrinsic::VPTypeToken::Vector: + STV.push_back(&VectorTy); + break; + case VPIntrinsic::VPTypeToken::Pointer: + STV.push_back(VecPtrTy); + break; + case VPIntrinsic::VPTypeToken::Returned: + assert(VecRetTy); + STV.push_back(VecRetTy); + break; + case VPIntrinsic::VPTypeToken::Mask: + auto NumElems = VectorTy.getVectorNumElements(); + auto MaskTy = + VectorType::get(Type::getInt1Ty(VectorTy.getContext()), NumElems); + STV.push_back(MaskTy); + break; + } + } + + return STV; +} + +Optional ConstrainedFPIntrinsic::getRoundingMode() const { + unsigned NumOperands = getNumArgOperands(); + assert(NumOperands >= 2 && "underflow"); + Metadata *MD = + cast(getArgOperand(NumOperands - 2))->getMetadata(); + if (!MD || !isa(MD)) + return None; + return StrToRoundingMode(cast(MD)->getString()); +} + +Optional 
ConstrainedFPIntrinsic::getExceptionBehavior() const { unsigned NumOperands = getNumArgOperands(); + assert(NumOperands >= 1 && "underflow"); Metadata *MD = cast(getArgOperand(NumOperands - 1))->getMetadata(); if (!MD || !isa(MD)) @@ -158,31 +1021,6 @@ return StrToExceptionBehavior(cast(MD)->getString()); } -Optional -ConstrainedFPIntrinsic::StrToExceptionBehavior(StringRef ExceptionArg) { - return StringSwitch>(ExceptionArg) - .Case("fpexcept.ignore", ebIgnore) - .Case("fpexcept.maytrap", ebMayTrap) - .Case("fpexcept.strict", ebStrict) - .Default(None); -} - -Optional -ConstrainedFPIntrinsic::ExceptionBehaviorToStr(ExceptionBehavior UseExcept) { - Optional ExceptStr = None; - switch (UseExcept) { - case ConstrainedFPIntrinsic::ebStrict: - ExceptStr = "fpexcept.strict"; - break; - case ConstrainedFPIntrinsic::ebIgnore: - ExceptStr = "fpexcept.ignore"; - break; - case ConstrainedFPIntrinsic::ebMayTrap: - ExceptStr = "fpexcept.maytrap"; - break; - } - return ExceptStr; -} bool ConstrainedFPIntrinsic::isUnaryOp() const { switch (getIntrinsicID()) { diff --git a/llvm/lib/IR/PredicatedInst.cpp b/llvm/lib/IR/PredicatedInst.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/IR/PredicatedInst.cpp @@ -0,0 +1,85 @@ +#include +#include +#include +#include +#include + +namespace { +using namespace llvm; +using ShortValueVec = SmallVector; +} // namespace + +namespace llvm { + +bool PredicatedInstruction::canIgnoreVectorLengthParam() const { + auto VPI = dyn_cast(this); + if (!VPI) + return true; + + return VPI->canIgnoreVectorLengthParam(); +} + +FastMathFlags PredicatedInstruction::getFastMathFlags() const { + return cast(this)->getFastMathFlags(); +} + +void PredicatedOperator::copyIRFlags(const Value *V, bool IncludeWrapFlags) { + auto *I = dyn_cast(this); + if (I) + I->copyIRFlags(V, IncludeWrapFlags); +} + +Instruction *PredicatedBinaryOperator::Create( + Module *Mod, Value *Mask, Value *VectorLen, Instruction::BinaryOps Opc, + Value *V1, Value *V2, const Twine &Name, 
BasicBlock *InsertAtEnd, + Instruction *InsertBefore) { + assert(!(InsertAtEnd && InsertBefore)); + auto VPID = VPIntrinsic::GetForOpcode(Opc); + + // Default Code Path + if ((!Mod || (!Mask && !VectorLen)) || VPID == Intrinsic::not_intrinsic) { + if (InsertAtEnd) { + return BinaryOperator::Create(Opc, V1, V2, Name, InsertAtEnd); + } else { + return BinaryOperator::Create(Opc, V1, V2, Name, InsertBefore); + } + } + + assert(Mod && "Need a module to emit VP Intrinsics"); + + // Fetch the VP intrinsic + auto &VecTy = cast(*V1->getType()); + auto TypeTokens = VPIntrinsic::GetTypeTokens(VPID); + auto *VPFunc = Intrinsic::getDeclaration( + Mod, VPID, + VPIntrinsic::EncodeTypeTokens(TypeTokens, &VecTy, nullptr, VecTy)); + + // Encode default environment fp behavior + LLVMContext &Ctx = V1->getContext(); + SmallVector BinOpArgs({V1, V2}); + if (VPIntrinsic::HasRoundingModeParam(VPID)) { + BinOpArgs.push_back( + GetConstrainedFPRounding(Ctx, RoundingMode::rmToNearest)); + } + if (VPIntrinsic::HasExceptionBehaviorParam(VPID)) { + BinOpArgs.push_back( + GetConstrainedFPExcept(Ctx, ExceptionBehavior::ebIgnore)); + } + + BinOpArgs.push_back(Mask); + BinOpArgs.push_back(VectorLen); + + CallInst *CI; + if (InsertAtEnd) { + CI = CallInst::Create(VPFunc, BinOpArgs, Name, InsertAtEnd); + } else { + CI = CallInst::Create(VPFunc, BinOpArgs, Name, InsertBefore); + } + + // the VP inst does not touch memory if the exception behavior is + // "fpecept.ignore" + CI->setDoesNotAccessMemory(); + return CI; +} + +} // namespace llvm diff --git a/llvm/lib/IR/VPBuilder.cpp b/llvm/lib/IR/VPBuilder.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/IR/VPBuilder.cpp @@ -0,0 +1,101 @@ +#include +#include +#include +#include +#include + +namespace { + using namespace llvm; + using ShortTypeVec = VPIntrinsic::ShortTypeVec; + using ShortValueVec = SmallVector; +} + +namespace llvm { + +Module & +VPBuilder::getModule() const { + return *Builder.GetInsertBlock()->getParent()->getParent(); +} + 
+Value& +VPBuilder::GetMaskForType(VectorType & VecTy) { + if (Mask) return *Mask; + + auto * boolTy = Builder.getInt1Ty(); + auto * maskTy = VectorType::get(boolTy, StaticVectorLength); + return *ConstantInt::getAllOnesValue(maskTy); +} + +Value& +VPBuilder::GetEVLForType(VectorType & VecTy) { + if (ExplicitVectorLength) return *ExplicitVectorLength; + + auto * intTy = Builder.getInt32Ty(); + return *ConstantInt::get(intTy, StaticVectorLength); +} + +Value* +VPBuilder::CreateVectorCopy(Instruction & Inst, ValArray VecOpArray) { + auto OC = Inst.getOpcode(); + auto VPID = VPIntrinsic::GetForOpcode(OC); + if (VPID == Intrinsic::not_intrinsic) { + return nullptr; + } + + abort(); // TODO implement + + return nullptr; +} + + +VectorType& +VPBuilder::getVectorType(Type &ElementTy) { + return *VectorType::get(&ElementTy, StaticVectorLength); +} + +Value& +VPBuilder::CreateContiguousStore(Value & Val, Value & Pointer, Align Alignment) { + auto & VecTy = cast(*Val.getType()); + auto * StoreFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_store, {Val.getType(), Pointer.getType()}); + ShortValueVec Args{&Val, &Pointer, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &StoreCall = *Builder.CreateCall(StoreFunc, Args); + if (Alignment != None) StoreCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return StoreCall; +} + +Value& +VPBuilder::CreateContiguousLoad(Value & Pointer, Align Alignment) { + auto & PointerTy = cast(*Pointer.getType()); + auto & VecTy = getVectorType(*PointerTy.getPointerElementType()); + + auto * LoadFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_load, {&VecTy, &PointerTy}); + ShortValueVec Args{&Pointer, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &LoadCall= *Builder.CreateCall(LoadFunc, Args); + if (Alignment != None) LoadCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return LoadCall; +} + +Value& +VPBuilder::CreateScatter(Value & Val, Value & 
PointerVec, Align Alignment) { + auto & VecTy = cast(*Val.getType()); + auto * ScatterFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_scatter, {Val.getType(), PointerVec.getType()}); + ShortValueVec Args{&Val, &PointerVec, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &ScatterCall = *Builder.CreateCall(ScatterFunc, Args); + if (Alignment != None) ScatterCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return ScatterCall; +} + +Value& +VPBuilder::CreateGather(Value & PointerVec, Align Alignment) { + auto & PointerVecTy = cast(*PointerVec.getType()); + auto & ElemTy = *cast(*PointerVecTy.getVectorElementType()).getPointerElementType(); + auto & VecTy = *VectorType::get(&ElemTy, PointerVecTy.getNumElements()); + auto * GatherFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_gather, {&VecTy, &PointerVecTy}); + + ShortValueVec Args{&PointerVec, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &GatherCall = *Builder.CreateCall(GatherFunc, Args); + if (Alignment != None) GatherCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return GatherCall; +} + +} // namespace llvm diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -475,6 +475,7 @@ void visitUserOp2(Instruction &I) { visitUserOp1(I); } void visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call); void visitConstrainedFPIntrinsic(ConstrainedFPIntrinsic &FPI); + void visitVPIntrinsic(VPIntrinsic &FPI); void visitDbgIntrinsic(StringRef Kind, DbgVariableIntrinsic &DII); void visitDbgLabelIntrinsic(StringRef Kind, DbgLabelInst &DLI); void visitAtomicCmpXchgInst(AtomicCmpXchgInst &CXI); @@ -1689,11 +1690,14 @@ if (Attrs.isEmpty()) return; + bool SawMask = false; bool SawNest = false; + bool SawPassthru = false; bool SawReturned = false; bool SawSRet = false; bool SawSwiftSelf = false; bool SawSwiftError = false; + bool SawVectorLength = false; // 
Verify return value attributes. AttributeSet RetAttrs = Attrs.getRetAttributes(); @@ -1762,12 +1766,33 @@ SawSwiftError = true; } + if (ArgAttrs.hasAttribute(Attribute::VectorLength)) { + Assert(!SawVectorLength, "Cannot have multiple 'vlen' parameters!", + V); + SawVectorLength = true; + } + + if (ArgAttrs.hasAttribute(Attribute::Passthru)) { + Assert(!SawPassthru, "Cannot have multiple 'passthru' parameters!", + V); + SawPassthru = true; + } + + if (ArgAttrs.hasAttribute(Attribute::Mask)) { + Assert(!SawMask, "Cannot have multiple 'mask' parameters!", + V); + SawMask = true; + } + if (ArgAttrs.hasAttribute(Attribute::InAlloca)) { Assert(i == FT->getNumParams() - 1, "inalloca isn't on the last parameter!", V); } } + Assert(!SawPassthru || SawMask, + "Cannot have 'passthru' parameter without 'mask' parameter!", V); + if (!Attrs.hasAttributes(AttributeList::FunctionIndex)) return; @@ -3133,7 +3158,7 @@ /// visitUnaryOperator - Check the argument to the unary operator. /// void Verifier::visitUnaryOperator(UnaryOperator &U) { - Assert(U.getType() == U.getOperand(0)->getType(), + Assert(U.getType() == U.getOperand(0)->getType(), "Unary operators must have same type for" "operands and result!", &U); @@ -4334,6 +4359,94 @@ case Intrinsic::experimental_constrained_trunc: visitConstrainedFPIntrinsic(cast(Call)); break; + + // general cmp + case Intrinsic::vp_icmp: + case Intrinsic::vp_fcmp: + + // int arith + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + + // memory + case Intrinsic::vp_load: + case Intrinsic::vp_store: + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: + + // shuffle + case Intrinsic::vp_select: + case Intrinsic::vp_compose: + case Intrinsic::vp_compress: 
+ case Intrinsic::vp_expand: + case Intrinsic::vp_vshift: + + // fp arith + case Intrinsic::vp_fneg: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + + case Intrinsic::vp_fma: + case Intrinsic::vp_ceil: + case Intrinsic::vp_cos: + case Intrinsic::vp_exp2: + case Intrinsic::vp_exp: + case Intrinsic::vp_floor: + case Intrinsic::vp_log10: + case Intrinsic::vp_log2: + case Intrinsic::vp_log: + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + case Intrinsic::vp_nearbyint: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_rint: + case Intrinsic::vp_round: + case Intrinsic::vp_sin: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_trunc: + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + + case Intrinsic::vp_lround: + case Intrinsic::vp_llround: + + // reductions + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + visitVPIntrinsic(cast(Call)); + break; + case Intrinsic::dbg_declare: // llvm.dbg.declare Assert(isa(Call.getArgOperand(0)), "invalid llvm.dbg.declare intrinsic call 1", Call); @@ -4757,6 +4870,16 @@ return nullptr; } +void Verifier::visitVPIntrinsic(VPIntrinsic &VPI) { + Assert(!VPI.isConstrainedOp(), "VP intrinsics only support the default fp environment for now (round.tonearest; fpexcept.ignore)."); + if (VPI.isConstrainedOp()) { + Assert(VPI.getExceptionBehavior() != ExceptionBehavior::ebInvalid, + "invalid exception behavior argument", &VPI); + Assert(VPI.getRoundingMode() != 
RoundingMode::rmInvalid, + "invalid rounding mode argument", &VPI); + } +} + void Verifier::visitConstrainedFPIntrinsic(ConstrainedFPIntrinsic &FPI) { unsigned NumOperands = FPI.getNumArgOperands(); bool HasExceptionMD = false; @@ -5162,7 +5285,7 @@ bool runOnFunction(Function &F) override { if (!V->verify(F) && FatalErrors) { - errs() << "in function " << F.getName() << '\n'; + errs() << "in function " << F.getName() << '\n'; report_fatal_error("Broken function found, compilation aborted!"); } return false; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -24,6 +24,9 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredicatedInst.h" +#include "llvm/IR/VPBuilder.h" +#include "llvm/IR/MatcherCast.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/AlignOf.h" @@ -2087,6 +2090,17 @@ return nullptr; } +Instruction *InstCombiner::visitPredicatedFSub(PredicatedBinaryOperator& I) { + auto * Inst = cast(&I); + PredicatedContext PC(&I); + if (Value *V = SimplifyPredicatedFSubInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), + SQ.getWithInstruction(Inst), PC)) + return replaceInstUsesWith(*Inst, V); + + return visitFSubGeneric(*Inst); +} + Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -2096,11 +2110,19 @@ if (Instruction *X = foldVectorBinop(I)) return X; + return visitFSubGeneric(I); +} + +template +Instruction *InstCombiner::visitFSubGeneric(BinaryOpTy &I) { + MatchContextType MC(cast(&I)); + MatchContextBuilder MCBuilder(MC); + // Subtraction from -0.0 is the canonical form of fneg. 
// fsub nsz 0, X ==> fsub nsz -0.0, X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (I.hasNoSignedZeros() && match(Op0, m_PosZeroFP())) - return BinaryOperator::CreateFNegFMF(Op1, &I); + if (I.hasNoSignedZeros() && MC.try_match(Op0, m_PosZeroFP())) + return MCBuilder.CreateFNegFMF(Op1, &I); if (Instruction *X = foldFNegIntoConstant(I)) return X; @@ -2111,6 +2133,18 @@ Value *X, *Y; Constant *C; + // Fold negation into constant operand. This is limited with one-use because + // fneg is assumed better for analysis and cheaper in codegen than fmul/fdiv. + // -(X * C) --> X * (-C) + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FMul(m_Value(X), m_Constant(C)))))) + return MCBuilder.CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + // -(X / C) --> X / (-C) + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FDiv(m_Value(X), m_Constant(C)))))) + return MCBuilder.CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + // -(C / X) --> (-C) / X + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) + return MCBuilder.CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X) // Canonicalize to fadd to make analysis easier. // This can also help codegen because fadd is commutative. @@ -2118,36 +2152,37 @@ // killed later. We still limit that particular transform with 'hasOneUse' // because an fneg is assumed better/cheaper than a generic fsub. 
if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) { - if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { - Value *NewSub = Builder.CreateFSubFMF(Y, X, &I); - return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I); + if (MC.try_match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { + Value *NewSub = MCBuilder.CreateFSubFMF(Builder, Y, X, &I); + return MCBuilder.CreateFAddFMF(Op0, NewSub, &I); } } - if (isa(Op0)) - if (SelectInst *SI = dyn_cast(Op1)) - if (Instruction *NV = FoldOpIntoSelect(I, SI)) - return NV; + if (auto * PlainBinOp = dyn_cast(&I)) + if (isa(Op0)) + if (SelectInst *SI = dyn_cast(Op1)) + if (Instruction *NV = FoldOpIntoSelect(*PlainBinOp, SI)) + return NV; // X - C --> X + (-C) // But don't transform constant expressions because there's an inverse fold // for X + (-Y) --> X - Y. - if (match(Op1, m_Constant(C)) && !isa(Op1)) - return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(Op1, m_Constant(C)) && !isa(Op1)) + return MCBuilder.CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); // X - (-Y) --> X + Y - if (match(Op1, m_FNeg(m_Value(Y)))) - return BinaryOperator::CreateFAddFMF(Op0, Y, &I); + if (MC.try_match(Op1, m_FNeg(m_Value(Y)))) + return MCBuilder.CreateFAddFMF(Op0, Y, &I); // Similar to above, but look through a cast of the negated value: // X - (fptrunc(-Y)) --> X + fptrunc(Y) Type *Ty = I.getType(); - if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) - return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPTrunc(Y, Ty), &I); + if (MC.try_match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) + return MCBuilder.CreateFAddFMF(Op0, MCBuilder.CreateFPTrunc(Builder, Y, Ty), &I); // X - (fpext(-Y)) --> X + fpext(Y) - if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) - return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I); + if (MC.try_match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) + return MCBuilder.CreateFAddFMF(Op0, MCBuilder.CreateFPExt(Builder, Y, 
Ty), &I); // Similar to above, but look through fmul/fdiv of the negated value: // Op0 - (-X * Y) --> Op0 + (X * Y) @@ -2165,39 +2200,42 @@ } // Handle special cases for FSub with selects feeding the operation - if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) - return replaceInstUsesWith(I, V); + if (auto * PlainBinOp = dyn_cast(&I)) + if (Value *V = SimplifySelectsFeedingBinaryOp(*PlainBinOp, Op0, Op1)) + return replaceInstUsesWith(I, V); if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { // (Y - X) - Y --> -X - if (match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) - return BinaryOperator::CreateFNegFMF(X, &I); + if (MC.try_match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) + return MCBuilder.CreateFNegFMF(X, &I); // Y - (X + Y) --> -X // Y - (Y + X) --> -X - if (match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) - return BinaryOperator::CreateFNegFMF(X, &I); + if (MC.try_match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) + return MCBuilder.CreateFNegFMF(X, &I); // (X * C) - X --> X * (C - 1.0) - if (match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { + if (MC.try_match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { Constant *CSubOne = ConstantExpr::getFSub(C, ConstantFP::get(Ty, 1.0)); - return BinaryOperator::CreateFMulFMF(Op1, CSubOne, &I); + return MCBuilder.CreateFMulFMF(Op1, CSubOne, &I); } // X - (X * C) --> X * (1.0 - C) - if (match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { + if (MC.try_match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { Constant *OneSubC = ConstantExpr::getFSub(ConstantFP::get(Ty, 1.0), C); - return BinaryOperator::CreateFMulFMF(Op0, OneSubC, &I); + return MCBuilder.CreateFMulFMF(Op0, OneSubC, &I); } - if (Instruction *F = factorizeFAddFSub(I, Builder)) - return F; + if (auto * PlainBinOp = dyn_cast(&I)) { + if (Instruction *F = factorizeFAddFSub(*PlainBinOp, Builder)) + return F; - // TODO: This performs reassociative folds for FP ops. 
Some fraction of the - // functionality has been subsumed by simple pattern matching here and in - // InstSimplify. We should let a dedicated reassociation pass handle more - // complex pattern matching and remove this from InstCombine. - if (Value *V = FAddCombine(Builder).simplify(&I)) - return replaceInstUsesWith(I, V); + // TODO: This performs reassociative folds for FP ops. Some fraction of the + // functionality has been subsumed by simple pattern matching here and in + // InstSimplify. We should let a dedicated reassociation pass handle more + // complex pattern matching and remove this from InstCombine. + if (Value *V = FAddCombine(Builder).simplify(PlainBinOp)) + return replaceInstUsesWith(*PlainBinOp, V); + } } return nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -38,6 +38,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" @@ -1793,6 +1794,14 @@ return &CI; } + // Predicated instruction patterns + auto * VPInst = dyn_cast(&CI); + if (VPInst) { + auto * PredInst = cast(VPInst); + auto Result = visitPredicatedInstruction(PredInst); + if (Result) return Result; + } + IntrinsicInst *II = dyn_cast(&CI); if (!II) return visitCallBase(CI); @@ -1857,7 +1866,8 @@ if (Changed) return II; } - // For vector result intrinsics, use the generic demanded vector support. + // For vector result intrinsics, use the generic demanded vector support to + // simplify any operands before moving on to the per-intrinsic rules. 
if (II->getType()->isVectorTy()) { auto VWidth = II->getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -30,6 +30,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Use.h" @@ -371,6 +372,8 @@ Instruction *visitFAdd(BinaryOperator &I); Value *OptimizePointerDifference(Value *LHS, Value *RHS, Type *Ty); Instruction *visitSub(BinaryOperator &I); + template Instruction *visitFSubGeneric(BinaryOpTy &I); + Instruction *visitPredicatedFSub(PredicatedBinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); Instruction *visitFMul(BinaryOperator &I); @@ -447,6 +450,16 @@ Instruction *visitVAStartInst(VAStartInst &I); Instruction *visitVACopyInst(VACopyInst &I); + // Entry point to VPIntrinsic + Instruction *visitPredicatedInstruction(PredicatedInstruction * PI) { + switch (PI->getOpcode()) { + default: + return nullptr; + case Instruction::FSub: + return visitPredicatedFSub(cast(*PI)); + } + } + /// Specify what to return for unhandled instructions. 
Instruction *visitInstruction(Instruction &I) { return nullptr; } diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -860,6 +860,7 @@ case Attribute::InaccessibleMemOnly: case Attribute::InaccessibleMemOrArgMemOnly: case Attribute::JumpTable: + case Attribute::Mask: case Attribute::Naked: case Attribute::Nest: case Attribute::NoAlias: @@ -869,6 +870,7 @@ case Attribute::NoSync: case Attribute::None: case Attribute::NonNull: + case Attribute::Passthru: case Attribute::ReadNone: case Attribute::ReadOnly: case Attribute::Returned: @@ -880,6 +882,7 @@ case Attribute::SwiftError: case Attribute::SwiftSelf: case Attribute::WillReturn: + case Attribute::VectorLength: case Attribute::WriteOnly: case Attribute::ZExt: case Attribute::ImmArg: diff --git a/llvm/test/Bitcode/attributes.ll b/llvm/test/Bitcode/attributes.ll --- a/llvm/test/Bitcode/attributes.ll +++ b/llvm/test/Bitcode/attributes.ll @@ -374,6 +374,11 @@ ret void; } +; CHECK: define <8 x double> @f64(<8 x double> passthru %0, <8 x i1> mask %1, i32 vlen %2) { +define <8 x double> @f64(<8 x double> passthru, <8 x i1> mask, i32 vlen) { + ret <8 x double> undef +} + ; CHECK: attributes #0 = { noreturn } ; CHECK: attributes #1 = { nounwind } ; CHECK: attributes #2 = { readnone } diff --git a/llvm/test/CodeGen/AArch64/O0-pipeline.ll b/llvm/test/CodeGen/AArch64/O0-pipeline.ll --- a/llvm/test/CodeGen/AArch64/O0-pipeline.ll +++ b/llvm/test/CodeGen/AArch64/O0-pipeline.ll @@ -24,6 +24,7 @@ ; CHECK-NEXT: Lower constant intrinsics ; CHECK-NEXT: Remove unreachable blocks from the CFG ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. 
mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: AArch64 Stack Tagging diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll --- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll +++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll @@ -47,6 +47,7 @@ ; CHECK-NEXT: Constant Hoisting ; CHECK-NEXT: Partially inline calls to library functions ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Dominator Tree Construction diff --git a/llvm/test/CodeGen/ARM/O3-pipeline.ll b/llvm/test/CodeGen/ARM/O3-pipeline.ll --- a/llvm/test/CodeGen/ARM/O3-pipeline.ll +++ b/llvm/test/CodeGen/ARM/O3-pipeline.ll @@ -31,6 +31,7 @@ ; CHECK-NEXT: Constant Hoisting ; CHECK-NEXT: Partially inline calls to library functions ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. 
mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Dominator Tree Construction diff --git a/llvm/test/CodeGen/Generic/expand-vp.ll b/llvm/test/CodeGen/Generic/expand-vp.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/Generic/expand-vp.ll @@ -0,0 +1,162 @@ +; RUN: opt --expand-vec-pred -S < %s | FileCheck %s + +define void @test_vp_constrainedfp(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, <8 x double> %f3, <8 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.reduce}} +; CHECK-NOT: {{call.* @llvm.vp.fadd}} +; CHECK-NOT: {{call.* @llvm.vp.fsub}} +; CHECK-NOT: {{call.* @llvm.vp.fmul}} +; CHECK-NOT: {{call.* @llvm.vp.frem}} +; CHECK-NOT: {{call.* @llvm.vp.fma}} +; CHECK-NOT: {{call.* @llvm.vp.fneg}} + %r0 = call <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r1 = call <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r2 = call <8 x double> @llvm.vp.fmul.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r3 = call <8 x double> @llvm.vp.fdiv.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r4 = call <8 x double> @llvm.vp.frem.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r5 = call <8 x double> @llvm.vp.fma.v8f64(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, metadata !"round.tozero", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r6 = call <8 x double> @llvm.vp.fneg.v8f64(<8 x double> %f2, metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + ret void +} + +define void @test_vp_int(<8 x i32> %i0, 
<8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.add}} +; CHECK-NOT: {{call.* @llvm.vp.sub}} +; CHECK-NOT: {{call.* @llvm.vp.mul}} +; CHECK-NOT: {{call.* @llvm.vp.sdiv}} +; CHECK-NOT: {{call.* @llvm.vp.udiv}} +; CHECK-NOT: {{call.* @llvm.vp.srem}} +; CHECK-NOT: {{call.* @llvm.vp.urem}} + %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r1 = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r2 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r4 = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r5 = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r6 = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + ret void +} + +define void @test_mem(<16 x i32*> %p0, <16 x i32>* %p1, <16 x i32> %i0, <16 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.load}} +; CHECK-NOT: {{call.* @llvm.vp.store}} +; CHECK-NOT: {{call.* @llvm.vp.gather}} +; CHECK-NOT: {{call.* @llvm.vp.scatter}} + call void @llvm.vp.store.v16i32.p0v16i32(<16 x i32> %i0, <16 x i32>* %p1, <16 x i1> %m, i32 %n) + call void @llvm.vp.scatter.v16i32.v16p0i32(<16 x i32> %i0 , <16 x i32*> %p0, <16 x i1> %m, i32 %n) + %l0 = call <16 x i32> @llvm.vp.load.v16i32.p0v16i32(<16 x i32>* %p1, <16 x i1> %m, i32 %n) + %l1 = call <16 x i32> @llvm.vp.gather.v16i32.v16p0i32(<16 x i32*> %p0, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_reduce_fp(<16 x float> %v, <16 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.reduce.fadd}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.fmul}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.fmin}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.fmax}} + %r0 = call float @llvm.vp.reduce.fadd.v16f32(float 0.0, <16 x float> %v, 
<16 x i1> %m, i32 %n) + %r1 = call float @llvm.vp.reduce.fmul.v16f32(float 42.0, <16 x float> %v, <16 x i1> %m, i32 %n) + %r2 = call float @llvm.vp.reduce.fmin.v16f32(<16 x float> %v, <16 x i1> %m, i32 %n) + %r3 = call float @llvm.vp.reduce.fmax.v16f32(<16 x float> %v, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_reduce_int(<16 x i32> %v, <16 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.reduce.add}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.mul}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.and}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.xor}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.or}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.smin}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.smax}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.umin}} +; CHECK-NOT: {{call.* @llvm.vp.reduce.umax}} + %r0 = call i32 @llvm.vp.reduce.add.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r1 = call i32 @llvm.vp.reduce.mul.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r2 = call i32 @llvm.vp.reduce.and.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r3 = call i32 @llvm.vp.reduce.xor.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r4 = call i32 @llvm.vp.reduce.or.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r5 = call i32 @llvm.vp.reduce.smin.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r6 = call i32 @llvm.vp.reduce.smax.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r7 = call i32 @llvm.vp.reduce.umin.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r8 = call i32 @llvm.vp.reduce.umax.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_shuffle(<16 x float> %v0, <16 x float> %v1, <16 x i1> %m, i32 %k, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.select}} +; CHECK-NOT: {{call.* @llvm.vp.compose}} +; no generic lowering available: {{call.* @llvm.vp.compress}} +; no generic lowering available: {{call.* @llvm.vp.expand}} +; no generic lowering available: {{call.* @llvm.vp.vshift}} + %r0 = call <16 x float> @llvm.vp.select.v16f32(<16 x i1> %m, <16 x float> %v0, <16 x float> %v1, i32 
%n) + %r1 = call <16 x float> @llvm.vp.compose.v16f32(<16 x float> %v0, <16 x float> %v1, i32 %k, i32 %n) + %r2 = call <16 x float> @llvm.vp.vshift.v16f32(<16 x float> %v0, i32 %k, <16 x i1> %m, i32 %n) + %r3 = call <16 x float> @llvm.vp.compress.v16f32(<16 x float> %v0, <16 x i1> %m, i32 %n) + %r4 = call <16 x float> @llvm.vp.expand.v16f32(<16 x float> %v0, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_xcmp(<16 x i32> %i0, <16 x i32> %i1, <16 x float> %f0, <16 x float> %f1, <16 x i1> %m, i32 %n) { +; CHECK-NOT: {{call.* @llvm.vp.icmp}} +; CHECK-NOT: {{call.* @llvm.vp.fcmp}} + %r0 = call <16 x i1> @llvm.vp.icmp.v16i32(<16 x i32> %i0, <16 x i32> %i1, i8 38, <16 x i1> %m, i32 %n) + %r1 = call <16 x i1> @llvm.vp.fcmp.v16f32(<16 x float> %f0, <16 x float> %f1, i8 10, <16 x i1> %m, i32 %n) + ret void +} + +; standard floating point arith +declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fsub.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fmul.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fdiv.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.frem.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fma.v8f64(<8 x double>, <8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fneg.v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) + +; integer arith +declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, 
i32 vlen) +declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +; bit arith +declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) + +; memory +declare void @llvm.vp.store.v16i32.p0v16i32(<16 x i32>, <16 x i32>*, <16 x i1> mask, i32 vlen) +declare void @llvm.vp.scatter.v16i32.v16p0i32(<16 x i32>, <16 x i32*>, <16 x i1> mask, i32 vlen) +declare <16 x i32> @llvm.vp.load.v16i32.p0v16i32(<16 x i32>*, <16 x i1> mask, i32 vlen) +declare <16 x i32> @llvm.vp.gather.v16i32.v16p0i32(<16 x i32*>, <16 x i1> mask, i32 vlen) + +; reductions +declare float @llvm.vp.reduce.fadd.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmul.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmin.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmax.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.add.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.mul.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.and.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.xor.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.or.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.smin.v16i32(<16 x i32>, <16 x i1> mask, 
i32 vlen) +declare i32 @llvm.vp.reduce.smax.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.umin.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.umax.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) + +; shuffles +declare <16 x float> @llvm.vp.select.v16f32(<16 x i1>, <16 x float>, <16 x float>, i32 vlen) +declare <16 x float> @llvm.vp.compose.v16f32(<16 x float>, <16 x float>, i32, i32 vlen) +declare <16 x float> @llvm.vp.vshift.v16f32(<16 x float>, i32, <16 x i1>, i32 vlen) +declare <16 x float> @llvm.vp.compress.v16f32(<16 x float>, <16 x i1>, i32 vlen) +declare <16 x float> @llvm.vp.expand.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) + +; icmp , fcmp +declare <16 x i1> @llvm.vp.icmp.v16i32(<16 x i32>, <16 x i32>, i8, <16 x i1> mask, i32 vlen) +declare <16 x i1> @llvm.vp.fcmp.v16f32(<16 x float>, <16 x float>, i8, <16 x i1> mask, i32 vlen) diff --git a/llvm/test/CodeGen/X86/O0-pipeline.ll b/llvm/test/CodeGen/X86/O0-pipeline.ll --- a/llvm/test/CodeGen/X86/O0-pipeline.ll +++ b/llvm/test/CodeGen/X86/O0-pipeline.ll @@ -27,6 +27,7 @@ ; CHECK-NEXT: Lower constant intrinsics ; CHECK-NEXT: Remove unreachable blocks from the CFG ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Expand indirectbr instructions diff --git a/llvm/test/CodeGen/X86/O3-pipeline.ll b/llvm/test/CodeGen/X86/O3-pipeline.ll --- a/llvm/test/CodeGen/X86/O3-pipeline.ll +++ b/llvm/test/CodeGen/X86/O3-pipeline.ll @@ -44,6 +44,7 @@ ; CHECK-NEXT: Constant Hoisting ; CHECK-NEXT: Partially inline calls to library functions ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. 
mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Dominator Tree Construction diff --git a/llvm/test/Transforms/InstCombine/vp-fsub.ll b/llvm/test/Transforms/InstCombine/vp-fsub.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/vp-fsub.ll @@ -0,0 +1,45 @@ +; RUN: opt < %s -instcombine -S | FileCheck %s + +; PR4374 + +define <4 x float> @test1_vp(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) { +; CHECK-LABEL: @test1_vp( +; + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + %t2 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> , <4 x float> %t1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + ret <4 x float> %t2 +} + +; Can't do anything with the test above because -0.0 - 0.0 = -0.0, but if we have nsz: +; -(X - Y) --> Y - X + +; TODO predicated FAdd folding +define <4 x float> @neg_sub_nsz_vp(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) { +; CH***-LABEL: @neg_sub_nsz_vp( +; + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + %t2 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> , <4 x float> %t1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + ret <4 x float> %t2 +} + +; With nsz: Z - (X - Y) --> Z + (Y - X) + +define <4 x float> @sub_sub_nsz_vp(<4 x float> %x, <4 x float> %y, <4 x float> %z, <4 x i1> %M, i32 %L) { +; CHECK-LABEL: @sub_sub_nsz_vp( +; CHECK-NEXT: %1 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %y, <4 x float> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) # +; CHECK-NEXT: %t2 = call nsz <4 x float> @llvm.vp.fadd.v4f32(<4 x float> %z, <4 x 
float> %1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) # +; CHECK-NEXT: ret <4 x float> %t2 + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + %t2 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %z, <4 x float> %t1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + ret <4 x float> %t2 +} + + + +; Function Attrs: nounwind readnone +declare <4 x float> @llvm.vp.fadd.v4f32(<4 x float>, <4 x float>, metadata, metadata, <4 x i1> mask, i32 vlen) + +; Function Attrs: nounwind readnone +declare <4 x float> @llvm.vp.fsub.v4f32(<4 x float>, <4 x float>, metadata, metadata, <4 x i1> mask, i32 vlen) + +attributes #0 = { readnone } diff --git a/llvm/test/Transforms/InstSimplify/vp-fsub.ll b/llvm/test/Transforms/InstSimplify/vp-fsub.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/vp-fsub.ll @@ -0,0 +1,55 @@ +; RUN: opt < %s -instsimplify -S | FileCheck %s + +define <8 x double> @fsub_fadd_fold_vp_xy(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) { +; CHECK-LABEL: fsub_fadd_fold_vp_xy +; CHECK: ret <8 x double> %x + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %x, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + %res0 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + ret <8 x double> %res0 +} + +define <8 x double> @fsub_fadd_fold_vp_zw(<8 x double> %z, <8 x double> %w, <8 x i1> %m, i32 %len) { +; CHECK-LABEL: fsub_fadd_fold_vp_zw +; CHECK: ret <8 x double> %z + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %w, <8 x double> %z, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + %res1 = call reassoc nsz <8 x 
double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %w, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + ret <8 x double> %res1 +} + +; REQUIRES-CONSTRAINED-VP: define <8 x double> @fsub_fadd_fold_vp_yx_fpexcept(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) #0 { +; REQUIRES-CONSTRAINED-VP: ; *HECK-LABEL: fsub_fadd_fold_vp_yx +; REQUIRES-CONSTRAINED-VP: ; *HECK-NEXT: %tmp = +; REQUIRES-CONSTRAINED-VP: ; *HECK-NEXT: %res2 = +; REQUIRES-CONSTRAINED-VP: ; *HECK-NEXT: ret +; REQUIRES-CONSTRAINED-VP: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.strict", <8 x i1> %m, i32 %len) +; REQUIRES-CONSTRAINED-VP: %res2 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.strict", <8 x i1> %m, i32 %len) +; REQUIRES-CONSTRAINED-VP: ret <8 x double> %res2 +; REQUIRES-CONSTRAINED-VP: } + +define <8 x double> @fsub_fadd_fold_vp_yx_olen(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len, i32 %otherLen) { +; CHECK-LABEL: fsub_fadd_fold_vp_yx_olen +; CHECK-NEXT: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %otherLen) +; CHECK-NEXT: %res3 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) +; CHECK-NEXT: ret <8 x double> %res3 + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %otherLen) + %res3 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + ret <8 x double> %res3 +} + +define <8 x 
double> @fsub_fadd_fold_vp_yx_omask(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len, <8 x i1> %othermask) { +; CHECK-LABEL: fsub_fadd_fold_vp_yx_omask +; CHECK-NEXT: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) +; CHECK-NEXT: %res4 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %othermask, i32 %len) +; CHECK-NEXT: ret <8 x double> %res4 + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + %res4 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %othermask, i32 %len) + ret <8 x double> %res4 +} + +; Function Attrs: nounwind readnone +declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) + +; Function Attrs: nounwind readnone +declare <8 x double> @llvm.vp.fsub.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) + +attributes #0 = { strictfp } diff --git a/llvm/test/Verifier/evl_attribs.ll b/llvm/test/Verifier/evl_attribs.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Verifier/evl_attribs.ll @@ -0,0 +1,13 @@ +; RUN: not llvm-as %s -o /dev/null 2>&1 | FileCheck %s + +declare void @a(<16 x i1> mask %a, <16 x i1> mask %b) +; CHECK: Cannot have multiple 'mask' parameters! + +declare void @b(<16 x i1> mask %a, i32 vlen %x, i32 vlen %y) +; CHECK: Cannot have multiple 'vlen' parameters! + +declare <16 x double> @c(<16 x double> passthru %a) +; CHECK: Cannot have 'passthru' parameter without 'mask' parameter! 
+ +declare <16 x double> @d(<16 x double> passthru %a, <16 x i1> mask %M, <16 x double> passthru %b) +; CHECK: Cannot have multiple 'passthru' parameters! diff --git a/llvm/test/Verifier/vp-intrinsics-constrained.ll b/llvm/test/Verifier/vp-intrinsics-constrained.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Verifier/vp-intrinsics-constrained.ll @@ -0,0 +1,17 @@ +; RUN: not opt -S < %s |& FileCheck %s +; CHECK: VP intrinsics only support the default fp environment for now (round.tonearest; fpexcept.ignore). +; CHECK: error: input module is broken! + +define void @test_vp_strictfp(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, <8 x double> %f3, <8 x i1> %m, i32 %n) #0 { + %r0 = call <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.strict", <8 x i1> %m, i32 %n) + ret void +} + +define void @test_vp_rounding(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, <8 x double> %f3, <8 x i1> %m, i32 %n) #0 { + %r0 = call <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + ret void +} + +declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) + +attributes #0 = { strictfp } diff --git a/llvm/test/Verifier/vp-intrinsics.ll b/llvm/test/Verifier/vp-intrinsics.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Verifier/vp-intrinsics.ll @@ -0,0 +1,122 @@ +; RUN: opt --verify %s + +define void @test_vp_constrainedfp(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, <8 x double> %f3, <8 x i1> %m, i32 %n) { + %r0 = call <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r1 = call <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r2 = call <8 x 
double> @llvm.vp.fmul.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r3 = call <8 x double> @llvm.vp.fdiv.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r4 = call <8 x double> @llvm.vp.frem.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r5 = call <8 x double> @llvm.vp.fma.v8f64(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r6 = call <8 x double> @llvm.vp.fneg.v8f64(<8 x double> %f2, metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + ret void +} + +define void @test_vp_int(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) { + %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r1 = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r2 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r4 = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r5 = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r6 = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r7 = call <8 x i32> @llvm.vp.and.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r8 = call <8 x i32> @llvm.vp.or.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r9 = call <8 x i32> @llvm.vp.xor.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rA = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rB = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rC = call <8 x i32> 
@llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + ret void +} + +define void @test_mem(<16 x i32*> %p0, <16 x i32>* %p1, <16 x i32> %i0, <16 x i1> %m, i32 %n) { + call void @llvm.vp.store.v16i32.p0v16i32(<16 x i32> %i0, <16 x i32>* %p1, <16 x i1> %m, i32 %n) + call void @llvm.vp.scatter.v16i32.v16p0i32(<16 x i32> %i0 , <16 x i32*> %p0, <16 x i1> %m, i32 %n) + %l0 = call <16 x i32> @llvm.vp.load.v16i32.p0v16i32(<16 x i32>* %p1, <16 x i1> %m, i32 %n) + %l1 = call <16 x i32> @llvm.vp.gather.v16i32.v16p0i32(<16 x i32*> %p0, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_reduce_fp(<16 x float> %v, <16 x i1> %m, i32 %n) { + %r0 = call float @llvm.vp.reduce.fadd.v16f32(float 0.0, <16 x float> %v, <16 x i1> %m, i32 %n) + %r1 = call float @llvm.vp.reduce.fmul.v16f32(float 42.0, <16 x float> %v, <16 x i1> %m, i32 %n) + %r2 = call float @llvm.vp.reduce.fmin.v16f32(<16 x float> %v, <16 x i1> %m, i32 %n) + %r3 = call float @llvm.vp.reduce.fmax.v16f32(<16 x float> %v, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_reduce_int(<16 x i32> %v, <16 x i1> %m, i32 %n) { + %r0 = call i32 @llvm.vp.reduce.add.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r1 = call i32 @llvm.vp.reduce.mul.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r2 = call i32 @llvm.vp.reduce.and.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r3 = call i32 @llvm.vp.reduce.xor.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r4 = call i32 @llvm.vp.reduce.or.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_shuffle(<16 x float> %v0, <16 x float> %v1, <16 x i1> %m, i32 %k, i32 %n) { + %r0 = call <16 x float> @llvm.vp.select.v16f32(<16 x i1> %m, <16 x float> %v0, <16 x float> %v1, i32 %n) + %r1 = call <16 x float> @llvm.vp.compose.v16f32(<16 x float> %v0, <16 x float> %v1, i32 %k, i32 %n) + %r2 = call <16 x float> @llvm.vp.shift.v16f32(<16 x float> %v0, i32 %k, <16 x i1> %m, i32 %n) + %r3 = call <16 x float> @llvm.vp.compress.v16f32(<16 x float> %v0, 
<16 x i1> %m, i32 %n) + %r4 = call <16 x float> @llvm.vp.expand.v16f32(<16 x float> %v0, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_xcmp(<16 x i32> %v0, <16 x i32> %v1, <16 x i1> %m, i32 %n) { + %r0 = call <16 x i1> @llvm.vp.icmp.v16i32(<16 x i32> %v0, <16 x i32> %v1, i8 8, <16 x i1> %m, i32 %n) + %r1 = call <16 x i1> @llvm.vp.fcmp.v16i32(<16 x i32> %v0, <16 x i32> %v1, i8 12, <16 x i1> %m, i32 %n) + ret void +} + +; standard floating point arith +declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fsub.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fmul.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fdiv.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.frem.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fma.v8f64(<8 x double>, <8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fneg.v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) + +; integer arith +declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +; bit arith +declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> 
@llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) + +; memory +declare void @llvm.vp.store.v16i32.p0v16i32(<16 x i32>, <16 x i32>*, <16 x i1> mask, i32 vlen) +declare void @llvm.vp.scatter.v16i32.v16p0i32(<16 x i32>, <16 x i32*>, <16 x i1> mask, i32 vlen) +declare <16 x i32> @llvm.vp.load.v16i32.p0v16i32(<16 x i32>*, <16 x i1> mask, i32 vlen) +declare <16 x i32> @llvm.vp.gather.v16i32.v16p0i32(<16 x i32*>, <16 x i1> mask, i32 vlen) + +; reductions +declare float @llvm.vp.reduce.fadd.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmul.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmin.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmax.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.add.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.mul.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.and.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.xor.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.or.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) + +; shuffles +declare <16 x float> @llvm.vp.select.v16f32(<16 x i1>, <16 x float>, <16 x float>, i32 vlen) +declare <16 x float> @llvm.vp.compose.v16f32(<16 x float>, <16 x float>, i32, i32 vlen) +declare <16 x float> @llvm.vp.shift.v16f32(<16 x float>, i32, <16 x i1>, i32 vlen) +declare <16 x float> @llvm.vp.compress.v16f32(<16 x float>, <16 x i1>, i32 vlen) +declare <16 x float> @llvm.vp.expand.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) + +; 
icmp , fcmp +declare <16 x i1> @llvm.vp.icmp.v16i32(<16 x i32>, <16 x i32>, i8, <16 x i1> mask, i32 vlen) +declare <16 x i1> @llvm.vp.fcmp.v16i32(<16 x i32>, <16 x i32>, i8, <16 x i1> mask, i32 vlen) diff --git a/llvm/tools/llc/llc.cpp b/llvm/tools/llc/llc.cpp --- a/llvm/tools/llc/llc.cpp +++ b/llvm/tools/llc/llc.cpp @@ -314,6 +314,7 @@ initializeVectorization(*Registry); initializeScalarizeMaskedMemIntrinPass(*Registry); initializeExpandReductionsPass(*Registry); + initializeExpandVectorPredicationPass(*Registry); initializeHardwareLoopsPass(*Registry); // Initialize debugging passes. diff --git a/llvm/tools/opt/opt.cpp b/llvm/tools/opt/opt.cpp --- a/llvm/tools/opt/opt.cpp +++ b/llvm/tools/opt/opt.cpp @@ -532,6 +532,7 @@ initializePostInlineEntryExitInstrumenterPass(Registry); initializeUnreachableBlockElimLegacyPassPass(Registry); initializeExpandReductionsPass(Registry); + initializeExpandVectorPredicationPass(Registry); initializeWasmEHPreparePass(Registry); initializeWriteBitcodePassPass(Registry); initializeHardwareLoopsPass(Registry); diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt --- a/llvm/unittests/IR/CMakeLists.txt +++ b/llvm/unittests/IR/CMakeLists.txt @@ -39,6 +39,7 @@ ValueTest.cpp VectorTypesTest.cpp VerifierTest.cpp + VPIntrinsicTest.cpp WaymarkTest.cpp ) diff --git a/llvm/unittests/IR/IRBuilderTest.cpp b/llvm/unittests/IR/IRBuilderTest.cpp --- a/llvm/unittests/IR/IRBuilderTest.cpp +++ b/llvm/unittests/IR/IRBuilderTest.cpp @@ -242,52 +242,52 @@ V = Builder.CreateFAdd(V, V); ASSERT_TRUE(isa(V)); auto *CII = cast(V); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebStrict); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDynamic); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebStrict); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmDynamic); - Builder.setDefaultConstrainedExcept(ConstrainedFPIntrinsic::ebIgnore); - 
Builder.setDefaultConstrainedRounding(ConstrainedFPIntrinsic::rmUpward); + Builder.setDefaultConstrainedExcept(ExceptionBehavior::ebIgnore); + Builder.setDefaultConstrainedRounding(RoundingMode::rmUpward); V = Builder.CreateFAdd(V, V); CII = cast(V); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebIgnore); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmUpward); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebIgnore); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmUpward); - Builder.setDefaultConstrainedExcept(ConstrainedFPIntrinsic::ebIgnore); - Builder.setDefaultConstrainedRounding(ConstrainedFPIntrinsic::rmToNearest); + Builder.setDefaultConstrainedExcept(ExceptionBehavior::ebIgnore); + Builder.setDefaultConstrainedRounding(RoundingMode::rmToNearest); V = Builder.CreateFAdd(V, V); CII = cast(V); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebIgnore); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmToNearest); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebIgnore); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmToNearest); - Builder.setDefaultConstrainedExcept(ConstrainedFPIntrinsic::ebMayTrap); - Builder.setDefaultConstrainedRounding(ConstrainedFPIntrinsic::rmDownward); + Builder.setDefaultConstrainedExcept(ExceptionBehavior::ebMayTrap); + Builder.setDefaultConstrainedRounding(RoundingMode::rmDownward); V = Builder.CreateFAdd(V, V); CII = cast(V); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebMayTrap); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDownward); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebMayTrap); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmDownward); - Builder.setDefaultConstrainedExcept(ConstrainedFPIntrinsic::ebStrict); - Builder.setDefaultConstrainedRounding(ConstrainedFPIntrinsic::rmTowardZero); + 
Builder.setDefaultConstrainedExcept(ExceptionBehavior::ebStrict); + Builder.setDefaultConstrainedRounding(RoundingMode::rmTowardZero); V = Builder.CreateFAdd(V, V); CII = cast(V); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebStrict); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmTowardZero); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebStrict); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmTowardZero); - Builder.setDefaultConstrainedExcept(ConstrainedFPIntrinsic::ebIgnore); - Builder.setDefaultConstrainedRounding(ConstrainedFPIntrinsic::rmDynamic); + Builder.setDefaultConstrainedExcept(ExceptionBehavior::ebIgnore); + Builder.setDefaultConstrainedRounding(RoundingMode::rmDynamic); V = Builder.CreateFAdd(V, V); CII = cast(V); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebIgnore); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDynamic); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebIgnore); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmDynamic); // Now override the defaults. 
Call = Builder.CreateConstrainedFPBinOp( Intrinsic::experimental_constrained_fadd, V, V, nullptr, "", nullptr, - ConstrainedFPIntrinsic::rmDownward, ConstrainedFPIntrinsic::ebMayTrap); + RoundingMode::rmDownward, ExceptionBehavior::ebMayTrap); CII = cast(Call); EXPECT_EQ(CII->getIntrinsicID(), Intrinsic::experimental_constrained_fadd); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebMayTrap); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDownward); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebMayTrap); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmDownward); Builder.CreateRetVoid(); EXPECT_FALSE(verifyModule(*M)); diff --git a/llvm/unittests/IR/VPIntrinsicTest.cpp b/llvm/unittests/IR/VPIntrinsicTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/IR/VPIntrinsicTest.cpp @@ -0,0 +1,190 @@ +//===- VPIntrinsicTest.cpp - VPIntrinsic unit tests ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SmallVector.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" + +namespace llvm { +namespace { + +class VPIntrinsicTest : public testing::Test { +protected: + LLVMContext Context; + + VPIntrinsicTest() : Context() {} + + LLVMContext C; + SMDiagnostic Err; + + std::unique_ptr CreateVPDeclarationModule() { + return parseAssemblyString( +" declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) " +" declare <8 x double> @llvm.vp.fsub.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) " +" declare <8 x double> @llvm.vp.fmul.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) " +" declare <8 x double> @llvm.vp.fdiv.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) " +" declare <8 x double> @llvm.vp.frem.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) " +" declare <8 x double> @llvm.vp.fma.v8f64(<8 x double>, <8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) " +" declare <8 x double> @llvm.vp.fneg.v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x 
i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) " +" declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) " +" declare void @llvm.vp.store.v16i32.p0v16i32(<16 x i32>, <16 x i32>*, <16 x i1> mask, i32 vlen) " +" declare void @llvm.vp.scatter.v16i32.v16p0i32(<16 x i32>, <16 x i32*>, <16 x i1> mask, i32 vlen) " +" declare <16 x i32> @llvm.vp.load.v16i32.p0v16i32(<16 x i32>*, <16 x i1> mask, i32 vlen) " +" declare <16 x i32> @llvm.vp.gather.v16i32.v16p0i32(<16 x i32*>, <16 x i1> mask, i32 vlen) " +" declare float @llvm.vp.reduce.fadd.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) " +" declare float @llvm.vp.reduce.fmul.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) " +" declare float @llvm.vp.reduce.fmin.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) " +" declare float @llvm.vp.reduce.fmax.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) " +" declare i32 @llvm.vp.reduce.add.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) " +" declare i32 @llvm.vp.reduce.mul.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) " +" declare i32 @llvm.vp.reduce.and.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) " +" declare i32 @llvm.vp.reduce.xor.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) " +" declare i32 @llvm.vp.reduce.or.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) " +" declare i32 @llvm.vp.reduce.smin.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) " +" declare i32 
@llvm.vp.reduce.smax.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) " +" declare i32 @llvm.vp.reduce.umin.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) " +" declare i32 @llvm.vp.reduce.umax.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) " +" declare <16 x float> @llvm.vp.select.v16f32(<16 x i1>, <16 x float>, <16 x float>, i32 vlen) " +" declare <16 x float> @llvm.vp.compose.v16f32(<16 x float>, <16 x float>, i32, i32 vlen) " +" declare <16 x float> @llvm.vp.vshift.v16f32(<16 x float>, i32, <16 x i1>, i32 vlen) " +" declare <16 x float> @llvm.vp.compress.v16f32(<16 x float>, <16 x i1>, i32 vlen) " +" declare <16 x float> @llvm.vp.expand.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) " +" declare <16 x i1> @llvm.vp.icmp.v16i32(<16 x i32>, <16 x i32>, i8, <16 x i1> mask, i32 vlen) " +" declare <16 x i1> @llvm.vp.fcmp.v16f32(<16 x float>, <16 x float>, i8, <16 x i1> mask, i32 vlen) ", + Err, C); + } +}; + +/// Check that VPIntrinsic:canIgnoreVectorLengthParam() returns true +/// if the vector length parameter does not mask-off any lanes. 
+TEST_F(VPIntrinsicTest, CanIgnoreVectorLength) { + LLVMContext C; + SMDiagnostic Err; + + std::unique_ptr M = + parseAssemblyString( +"declare <256 x i64> @llvm.vp.mul.v256i64(<256 x i64>, <256 x i64>, <256 x i1>, i32)" +"define void @test_static_vlen(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 %vl) { " +" %r0 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 %vl)" +" %r1 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 256)" +" %r2 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 0)" +" %r3 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 -1)" +" %r4 = call <256 x i64> @llvm.vp.mul.v256i64(<256 x i64> %i0, <256 x i64> %i1, <256 x i1> %m, i32 123)" +" ret void " +"}", + Err, C); + + auto *F = M->getFunction("test_static_vlen"); + assert(F); + + const int NumExpected = 5; + const bool Expected[] = {false, true, false, true, false}; + int i = 0; + for (auto &I : F->getEntryBlock()) { + VPIntrinsic *VPI = dyn_cast(&I); + if (!VPI) { + ASSERT_TRUE(I.isTerminator()); + continue; + } + + ASSERT_LT(i, NumExpected); + ASSERT_EQ(Expected[i], VPI->canIgnoreVectorLengthParam()); + ++i; + } +} + +/// Check that the argument returned by +/// VPIntrinsic::GetParamPos(Intrinsic::ID) has the expected type. 
+TEST_F(VPIntrinsicTest, GetParamPos) { + std::unique_ptr M = CreateVPDeclarationModule(); + assert(M); + + for (Function &F : *M) { + ASSERT_TRUE(F.isIntrinsic()); + Optional MaskParamPos = VPIntrinsic::GetMaskParamPos(F.getIntrinsicID()); + if (MaskParamPos.hasValue()) { + Type *MaskParamType = F.getArg(MaskParamPos.getValue())->getType(); + ASSERT_TRUE(MaskParamType->isVectorTy()); + ASSERT_TRUE(MaskParamType->getVectorElementType()->isIntegerTy(1)); + } + + Optional VecLenParamPos = VPIntrinsic::GetVectorLengthParamPos(F.getIntrinsicID()); + if (VecLenParamPos.hasValue()) { + Type *VecLenParamType = F.getArg(VecLenParamPos.getValue())->getType(); + ASSERT_TRUE(VecLenParamType->isIntegerTy(32)); + } + + Optional MemPtrParamPos = VPIntrinsic::GetMemoryPointerParamPos(F.getIntrinsicID()); + if (MemPtrParamPos.hasValue()) { + Type *MemPtrParamType = F.getArg(MemPtrParamPos.getValue())->getType(); + ASSERT_TRUE(MemPtrParamType->isPtrOrPtrVectorTy()); + } + + Optional RoundingParamPos = VPIntrinsic::GetRoundingModeParamPos(F.getIntrinsicID()); + if (RoundingParamPos.hasValue()) { + Type *RoundingParamType = F.getArg(RoundingParamPos.getValue())->getType(); + ASSERT_TRUE(RoundingParamType->isMetadataTy()); + } + + Optional ExceptParamPos = VPIntrinsic::GetExceptionBehaviorParamPos(F.getIntrinsicID()); + if (ExceptParamPos.hasValue()) { + Type *ExceptParamType = F.getArg(ExceptParamPos.getValue())->getType(); + ASSERT_TRUE(ExceptParamType->isMetadataTy()); + } + } +} + +/// Check that going from Opcode to VP intrinsic and back results in the same Opcode. 
+TEST_F(VPIntrinsicTest, OpcodeRoundTrip) { + std::vector Opcodes; + Opcodes.reserve(100); + + { +#define HANDLE_INST(OCNum, OCName, Class) Opcodes.push_back(OCNum); +#include "llvm/IR/Instruction.def" + } + + unsigned FullTripCounts = 0; + for (unsigned OC : Opcodes) { + Intrinsic::ID VPID = VPIntrinsic::GetForOpcode(OC); + // no equivalent VP intrinsic available + if (VPID == Intrinsic::not_intrinsic) + continue; + + unsigned RoundTripOC = VPIntrinsic::GetFunctionalOpcodeForVP(VPID); + // no equivalent Opcode available + if (RoundTripOC == Instruction::Call) + continue; + + ASSERT_EQ(RoundTripOC, OC); + ++FullTripCounts; + } + ASSERT_NE(FullTripCounts, 0u); +} + +} // end anonymous namespace +} // end namespace llvm diff --git a/llvm/utils/TableGen/CodeGenIntrinsics.h b/llvm/utils/TableGen/CodeGenIntrinsics.h --- a/llvm/utils/TableGen/CodeGenIntrinsics.h +++ b/llvm/utils/TableGen/CodeGenIntrinsics.h @@ -146,7 +146,10 @@ ReadOnly, WriteOnly, ReadNone, - ImmArg + ImmArg, + Mask, + VectorLength, + Passthru }; std::vector> ArgumentAttributes; diff --git a/llvm/utils/TableGen/CodeGenTarget.cpp b/llvm/utils/TableGen/CodeGenTarget.cpp --- a/llvm/utils/TableGen/CodeGenTarget.cpp +++ b/llvm/utils/TableGen/CodeGenTarget.cpp @@ -728,8 +728,7 @@ // variants with iAny types; otherwise, if the intrinsic is not // overloaded, all the types can be specified directly. 
assert(((!TyEl->isSubClassOf("LLVMExtendedType") && - !TyEl->isSubClassOf("LLVMTruncatedType") && - !TyEl->isSubClassOf("LLVMScalarOrSameVectorWidth")) || + !TyEl->isSubClassOf("LLVMTruncatedType")) || VT == MVT::iAny || VT == MVT::vAny) && "Expected iAny or vAny type"); } else @@ -791,6 +790,15 @@ } else if (Property->isSubClassOf("Returned")) { unsigned ArgNo = Property->getValueAsInt("ArgNo"); ArgumentAttributes.push_back(std::make_pair(ArgNo, Returned)); + } else if (Property->isSubClassOf("VectorLength")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, VectorLength)); + } else if (Property->isSubClassOf("Mask")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, Mask)); + } else if (Property->isSubClassOf("Passthru")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, Passthru)); } else if (Property->isSubClassOf("ReadOnly")) { unsigned ArgNo = Property->getValueAsInt("ArgNo"); ArgumentAttributes.push_back(std::make_pair(ArgNo, ReadOnly)); diff --git a/llvm/utils/TableGen/IntrinsicEmitter.cpp b/llvm/utils/TableGen/IntrinsicEmitter.cpp --- a/llvm/utils/TableGen/IntrinsicEmitter.cpp +++ b/llvm/utils/TableGen/IntrinsicEmitter.cpp @@ -671,6 +671,24 @@ OS << "Attribute::Returned"; addComma = true; break; + case CodeGenIntrinsic::VectorLength: + if (addComma) + OS << ","; + OS << "Attribute::VectorLength"; + addComma = true; + break; + case CodeGenIntrinsic::Mask: + if (addComma) + OS << ","; + OS << "Attribute::Mask"; + addComma = true; + break; + case CodeGenIntrinsic::Passthru: + if (addComma) + OS << ","; + OS << "Attribute::Passthru"; + addComma = true; + break; case CodeGenIntrinsic::ReadOnly: if (addComma) OS << ",";