diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -16,7 +16,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" -include "mlir/Dialect/ArmSVE/ArmSVEOpBase.td" //===----------------------------------------------------------------------===// // ArmSVE dialect definition @@ -27,69 +26,11 @@ let cppNamespace = "::mlir::arm_sve"; let summary = "Basic dialect to target Arm SVE architectures"; let description = [{ - This dialect contains the definitions necessary to target Arm SVE scalable - vector operations, including a scalable vector type and intrinsics for - some Arm SVE instructions. + This dialect contains the definitions necessary to target specific Arm SVE + scalable vector operations. }]; - let useDefaultTypePrinterParser = 1; } -//===----------------------------------------------------------------------===// -// ArmSVE type definitions -//===----------------------------------------------------------------------===// - -def ArmSVE_ScalableVectorType : DialectType()">, - "scalable vector type">, - BuildableType<"$_builder.getType()"> { - let description = [{ - `arm_sve.vector` represents vectors that will be processed by a scalable - vector architecture. - }]; -} - -class ArmSVE_Type : TypeDef { } - -def ScalableVectorType : ArmSVE_Type<"ScalableVector"> { - let mnemonic = "vector"; - - let summary = "Scalable vector type"; - - let description = [{ - A type representing scalable length SIMD vectors. Unlike fixed-length SIMD - vectors, whose size is constant and known at compile time, scalable - vectors' length is constant but determined by the specific hardware at - run time. - }]; - - let parameters = (ins - ArrayRefParameter<"int64_t", "Vector shape">:$shape, - "Type":$elementType - ); - - let extraClassDeclaration = [{ - bool hasStaticShape() const { - return llvm::none_of(getShape(), ShapedType::isDynamic); - } - int64_t getNumElements() const { - assert(hasStaticShape() && - "cannot get element count of dynamic shaped type"); - ArrayRef shape = getShape(); - int64_t num = 1; - for (auto dim : shape) - num *= dim; - return num; - } - }]; -} - -//===----------------------------------------------------------------------===// -// Additional LLVM type constraints -//===----------------------------------------------------------------------===// -def LLVMScalableVectorType : - Type()">, - "LLVM dialect scalable vector type">; - //===----------------------------------------------------------------------===// // ArmSVE op definitions //===----------------------------------------------------------------------===// @@ -97,16 +38,6 @@ class ArmSVE_Op traits = []> : Op {} -class ArmSVE_NonSVEIntrUnaryOverloadedOp traits =[]> : - LLVM_IntrOpBase overloadedResults=*/[0], - /*list overloadedOperands=*/[], // defined by result overload - /*list traits=*/traits, - /*int numResults=*/1>; - class ArmSVE_IntrBinaryOverloadedOp traits = []> : LLVM_IntrOpBase traits=*/traits, /*int numResults=*/1>; -class ScalableFOp traits = []> : - ArmSVE_Op])> { - let summary = op_description # " for scalable vectors of floats"; - let description = [{ - The `arm_sve.}] # mnemonic # [{` operations takes two scalable vectors and - returns one scalable vector with the result of the }] # op_description # [{. 
- }]; - let arguments = (ins - ScalableVectorOf<[AnyFloat]>:$src1, - ScalableVectorOf<[AnyFloat]>:$src2 - ); - let results = (outs ScalableVectorOf<[AnyFloat]>:$dst); - let assemblyFormat = - "$src1 `,` $src2 attr-dict `:` type($src1)"; -} - -class ScalableIOp traits = []> : - ArmSVE_Op])> { - let summary = op_description # " for scalable vectors of integers"; - let description = [{ - The `arm_sve.}] # mnemonic # [{` operation takes two scalable vectors and - returns one scalable vector with the result of the }] # op_description # [{. - }]; - let arguments = (ins - ScalableVectorOf<[I8, I16, I32, I64]>:$src1, - ScalableVectorOf<[I8, I16, I32, I64]>:$src2 - ); - let results = (outs ScalableVectorOf<[I8, I16, I32, I64]>:$dst); - let assemblyFormat = - "$src1 `,` $src2 attr-dict `:` type($src1)"; -} - class ScalableMaskedFOp traits = []> : ArmSVE_Op { - let summary = "Load vector scale size"; - let description = [{ - The vector_scale op returns the scale of the scalable vectors, a positive - integer value that is constant at runtime but unknown at compile time. - The scale of the vector indicates the multiplicity of the vectors and - vector operations. I.e.: an !arm_sve.vector<4xi32> is equivalent to - vector_scale consecutive vector<4xi32>; and an operation on an - !arm_sve.vector<4xi32> is equivalent to performing that operation vector_scale - times, once on each <4xi32> segment of the scalable vector. The vector_scale - op can be used to calculate the step in vector-length agnostic (VLA) loops. - }]; - let results = (outs Index:$res); - let assemblyFormat = - "attr-dict `:` type($res)"; -} - -def ScalableLoadOp : ArmSVE_Op<"load">, - Arguments<(ins Arg:$base, Index:$index)>, - Results<(outs ScalableVectorOf<[AnyType]>:$result)> { - let summary = "Load scalable vector from memory"; - let description = [{ - Load a slice of memory into a scalable vector. - }]; - let extraClassDeclaration = [{ - MemRefType getMemRefType() { - return base().getType().cast(); - } - }]; - let assemblyFormat = "$base `[` $index `]` attr-dict `:` " - "type($result) `from` type($base)"; -} - -def ScalableStoreOp : ArmSVE_Op<"store">, - Arguments<(ins Arg:$base, Index:$index, - ScalableVectorOf<[AnyType]>:$value)> { - let summary = "Store scalable vector into memory"; - let description = [{ - Store a scalable vector on a slice of memory. 
- }]; - let extraClassDeclaration = [{ - MemRefType getMemRefType() { - return base().getType().cast(); - } - }]; - let assemblyFormat = "$value `,` $base `[` $index `]` attr-dict `:` " - "type($value) `to` type($base)"; -} - -def ScalableAddIOp : ScalableIOp<"addi", "addition", [Commutative]>; - -def ScalableAddFOp : ScalableFOp<"addf", "addition", [Commutative]>; - -def ScalableSubIOp : ScalableIOp<"subi", "subtraction">; - -def ScalableSubFOp : ScalableFOp<"subf", "subtraction">; - -def ScalableMulIOp : ScalableIOp<"muli", "multiplication", [Commutative]>; - -def ScalableMulFOp : ScalableFOp<"mulf", "multiplication", [Commutative]>; - -def ScalableSDivIOp : ScalableIOp<"divi_signed", "signed division">; - -def ScalableUDivIOp : ScalableIOp<"divi_unsigned", "unsigned division">; - -def ScalableDivFOp : ScalableFOp<"divf", "division">; - def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition", [Commutative]>; @@ -417,189 +244,56 @@ def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">; -//===----------------------------------------------------------------------===// -// ScalableCmpFOp -//===----------------------------------------------------------------------===// - -def ScalableCmpFOp : ArmSVE_Op<"cmpf", [NoSideEffect, SameTypeOperands, - TypesMatchWith<"result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">]> { - let summary = "floating-point comparison operation for scalable vectors"; - let description = [{ - The `arm_sve.cmpf` operation compares two scalable vectors of floating point - elements according to the float comparison rules and the predicate specified - by the respective attribute. The predicate defines the type of comparison: - (un)orderedness, (in)equality and signed less/greater than (or equal to) as - well as predicates that are always true or false. The result is a scalable - vector of i1 elements. Unlike `arm_sve.cmpi`, the operands are always - treated as signed. The u prefix indicates *unordered* comparison, not - unsigned comparison, so "une" means unordered not equal. For the sake of - readability by humans, custom assembly form for the operation uses a - string-typed attribute for the predicate. The value of this attribute - corresponds to lower-cased name of the predicate constant, e.g., "one" means - "ordered not equal". The string representation of the attribute is merely a - syntactic sugar and is converted to an integer attribute by the parser. 
- - Example: - - ```mlir - %r = arm_sve.cmpf oeq, %0, %1 : !arm_sve.vector<4xf32> - ``` - }]; - let arguments = (ins - Arith_CmpFPredicateAttr:$predicate, - ScalableVectorOf<[AnyFloat]>:$lhs, - ScalableVectorOf<[AnyFloat]>:$rhs // TODO: This should support a simple scalar - ); - let results = (outs ScalableVectorOf<[I1]>:$result); - - let builders = [ - OpBuilder<(ins "arith::CmpFPredicate":$predicate, "Value":$lhs, - "Value":$rhs), [{ - buildScalableCmpFOp($_builder, $_state, predicate, lhs, rhs); - }]>]; - - let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } - static arith::CmpFPredicate getPredicateByName(StringRef name); - - arith::CmpFPredicate getPredicate() { - return (arith::CmpFPredicate) (*this)->getAttrOfType( - getPredicateAttrName()).getInt(); - } - }]; - - let verifier = [{ return success(); }]; - - let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; -} - -//===----------------------------------------------------------------------===// -// ScalableCmpIOp -//===----------------------------------------------------------------------===// - -def ScalableCmpIOp : ArmSVE_Op<"cmpi", [NoSideEffect, SameTypeOperands, - TypesMatchWith<"result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">]> { - let summary = "integer comparison operation for scalable vectors"; - let description = [{ - The `arm_sve.cmpi` operation compares two scalable vectors of integer - elements according to the predicate specified by the respective attribute. - - The predicate defines the type of comparison: - - - equal (mnemonic: `"eq"`; integer value: `0`) - - not equal (mnemonic: `"ne"`; integer value: `1`) - - signed less than (mnemonic: `"slt"`; integer value: `2`) - - signed less than or equal (mnemonic: `"sle"`; integer value: `3`) - - signed greater than (mnemonic: `"sgt"`; integer value: `4`) - - signed greater than or equal (mnemonic: `"sge"`; integer value: `5`) - - unsigned less than (mnemonic: `"ult"`; integer value: `6`) - - unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`) - - unsigned greater than (mnemonic: `"ugt"`; integer value: `8`) - - unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`) - - Example: - - ```mlir - %r = arm_sve.cmpi uge, %0, %1 : !arm_sve.vector<4xi32> - ``` - }]; - - let arguments = (ins - Arith_CmpIPredicateAttr:$predicate, - ScalableVectorOf<[I8, I16, I32, I64]>:$lhs, - ScalableVectorOf<[I8, I16, I32, I64]>:$rhs - ); - let results = (outs ScalableVectorOf<[I1]>:$result); - - let builders = [ - OpBuilder<(ins "arith::CmpIPredicate":$predicate, "Value":$lhs, - "Value":$rhs), [{ - buildScalableCmpIOp($_builder, $_state, predicate, lhs, rhs); - }]>]; - - let extraClassDeclaration = [{ - static StringRef getPredicateAttrName() { return "predicate"; } - static arith::CmpIPredicate getPredicateByName(StringRef name); - - arith::CmpIPredicate getPredicate() { - return (arith::CmpIPredicate) (*this)->getAttrOfType( - getPredicateAttrName()).getInt(); - } - }]; - - let verifier = [{ return success(); }]; - - let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; -} - def UmmlaIntrOp : ArmSVE_IntrBinaryOverloadedOp<"ummla">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def SmmlaIntrOp : ArmSVE_IntrBinaryOverloadedOp<"smmla">, - Arguments<(ins LLVMScalableVectorType, 
LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def SdotIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sdot">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def UdotIntrOp : ArmSVE_IntrBinaryOverloadedOp<"udot">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedAddIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"add">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedAddFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fadd">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedMulIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"mul">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedMulFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fmul">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedSubIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sub">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedSubFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fsub">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedSDivIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"sdiv">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedUDivIIntrOp : ArmSVE_IntrBinaryOverloadedOp<"udiv">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; def ScalableMaskedDivFIntrOp : ArmSVE_IntrBinaryOverloadedOp<"fdiv">, - Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, - LLVMScalableVectorType)>; - -def VectorScaleIntrOp: - ArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">; + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>; #endif // ARMSVE_OPS diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h @@ -21,9 +21,6 @@ #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#define GET_TYPEDEF_CLASSES -#include "mlir/Dialect/ArmSVE/ArmSVETypes.h.inc" - #define GET_OP_CLASSES #include "mlir/Dialect/ArmSVE/ArmSVE.h.inc" diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEOpBase.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEOpBase.td deleted file mode 100644 --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEOpBase.td +++ /dev/null @@ -1,53 +0,0 @@ -//===-- 
ArmSVEOpBase.td - Base op definitions for ArmSVE ---*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is the base operation definition file for ArmSVE scalable vector types. -// -//===----------------------------------------------------------------------===// - -#ifndef ARMSVE_OP_BASE -#define ARMSVE_OP_BASE - -//===----------------------------------------------------------------------===// -// ArmSVE scalable vector type constraints -//===----------------------------------------------------------------------===// - -def IsScalableVectorTypePred : - CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">; - -class ScalableVectorOf allowedTypes> : - ContainerType, IsScalableVectorTypePred, - "$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()", - "scalable vector">; - -// Whether the number of elements of a scalable vector is from the given -// `allowedLengths` list -class IsScalableVectorOfLengthPred allowedLengths> : - And<[IsScalableVectorTypePred, - Or().getNumElements() == }] - # allowedlength>)>]>; - -// Any scalable vector where the number of elements is from the given -// `allowedLengths` list -class ScalableVectorOfLength allowedLengths> : Type< - IsScalableVectorOfLengthPred, - " of length " # !interleave(allowedLengths, "/"), - "::mlir::arm_sve::ScalableVectorType">; - -// Any scalable vector where the number of elements is from the given -// `allowedLengths` list and the type is from the given `allowedTypes` list -class ScalableVectorOfLengthAndType allowedLengths, - list allowedTypes> : Type< - And<[ScalableVectorOf.predicate, - ScalableVectorOfLength.predicate]>, - ScalableVectorOf.summary # - ScalableVectorOfLength.summary, - "::mlir::arm_sve::ScalableVectorType">; - -#endif // ARMSVE_OP_BASE diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1734,7 +1734,9 @@ let arguments = (ins LLVM_Type, LLVM_Type, LLVM_Type); } -// +/// Create a call to vscale intrinsic. +def LLVM_vscale : LLVM_IntrOp<"vscale", [0], [], [], 1>; + // Atomic operations. // diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -483,10 +483,22 @@ /// Returns the element count of any LLVM-compatible vector type. llvm::ElementCount getVectorNumElements(Type type); +/// Returns whether a vector type is scalable or not. +bool isScalableVectorType(Type vectorType); + +/// Creates an LLVM dialect-compatible vector type with the given element type +/// and length. +Type getVectorType(Type elementType, unsigned numElements, + bool isScalable = false); + /// Creates an LLVM dialect-compatible type with the given element type and /// length. Type getFixedVectorType(Type elementType, unsigned numElements); +/// Creates an LLVM dialect-compatible type with the given element type and +/// length. 
+Type getScalableVectorType(Type elementType, unsigned numElements); + /// Returns the size of the given primitive LLVM dialect-compatible type /// (including vectors) in bits, for example, the size of i16 is 16 and /// the size of vector<4xi16> is 64. Returns 0 for non-primitive diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -2383,4 +2383,36 @@ let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)"; } +//===----------------------------------------------------------------------===// +// VectorScaleOp +//===----------------------------------------------------------------------===// + +// TODO: In the future, we might want to have scalable vectors with different +// scales for different dimensions. E.g.: vector<[16]x[16]xf32>, in +// which case we might need to add an index to 'vscale' to select one +// of them. In order to support GPUs, we might also want to differentiate +// between a 'global' scale, a scale that's fixed throughout the +// execution, and a 'local' scale that is fixed but might vary with each +// call to the function. For that, it might be useful to have a +// 'vector.scale.global' and a 'vector.scale.local' operation. +def VectorScaleOp : Vector_Op<"vscale", + [NoSideEffect]> { + let summary = "Load vector scale size"; + let description = [{ + The `vscale` op returns the scale of the scalable vectors, a positive + integer value that is constant at runtime but unknown at compile-time. + The scale of the vector indicates the multiplicity of the vectors and + vector operations. For example, a `vector<[4]xi32>` is equivalent to + `vscale` consecutive `vector<4xi32>`; and an operation on a + `vector<[4]xi32>` is equivalent to performing that operation `vscale` + times, once on each `<4xi32>` segment of the scalable vector. The `vscale` + op can be used to calculate the step in vector-length agnostic (VLA) loops. + Right now we only support one contiguous set of scalable dimensions, all of + them grouped and scaled with the value returned by 'vscale'. + }]; + let results = (outs Index:$res); + let assemblyFormat = "attr-dict"; + let verifier = ?; +} + #endif // VECTOR_OPS diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -315,13 +315,18 @@ public: /// Build from another VectorType. explicit Builder(VectorType other) - : shape(other.getShape()), elementType(other.getElementType()) {} + : shape(other.getShape()), elementType(other.getElementType()), + numScalableDims(other.getNumScalableDims()) {} /// Build from scratch. - Builder(ArrayRef shape, Type elementType) - : shape(shape), elementType(elementType) {} - - Builder &setShape(ArrayRef newShape) { + Builder(ArrayRef shape, Type elementType, + unsigned numScalableDims = 0) + : shape(shape), elementType(elementType), + numScalableDims(numScalableDims) {} + + Builder &setShape(ArrayRef newShape, + unsigned newNumScalableDims = 0) { + numScalableDims = newNumScalableDims; shape = newShape; return *this; } @@ -334,6 +339,8 @@ /// Erase a dim from shape @pos. 
Builder &dropDim(unsigned pos) { assert(pos < shape.size() && "overflow"); + if (pos >= shape.size() - numScalableDims) + numScalableDims--; if (storage.empty()) storage.append(shape.begin(), shape.end()); storage.erase(storage.begin() + pos); @@ -347,7 +354,7 @@ operator Type() { if (shape.empty()) return elementType; - return VectorType::get(shape, elementType); + return VectorType::get(shape, elementType, numScalableDims); } private: @@ -355,6 +362,7 @@ // Owning shape data for copy-on-write operations. SmallVector storage; Type elementType; + unsigned numScalableDims; }; /// Given an `originalShape` and a `reducedShape` assumed to be a subset of diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -892,16 +892,21 @@ Syntax: ``` - vector-type ::= `vector` `<` static-dimension-list vector-element-type `>` + vector-type ::= `vector` `<` vector-dim-list vector-element-type `>` vector-element-type ::= float-type | integer-type | index-type - - static-dimension-list ::= (decimal-literal `x`)* + vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? + static-dim-list ::= decimal-literal (`x` decimal-literal)* ``` - The vector type represents a SIMD style vector, used by target-specific - operation sets like AVX. While the most common use is for 1D vectors (e.g. - vector<16 x f32>) we also support multidimensional registers on targets that - support them (like TPUs). + The vector type represents a SIMD style vector used by target-specific + operation sets like AVX or SVE. While the most common use is for 1D + vectors (e.g. vector<16 x f32>) we also support multidimensional registers + on targets that support them (like TPUs). The dimensions of a vector type + can be fixed-length, scalable, or a combination of the two. The scalable + dimensions in a vector are indicated between square brackets ([ ]), and + all fixed-length dimensions, if present, must precede the set of scalable + dimensions. That is, a `vector<2x[4]xf32>` is valid, but `vector<[4]x2xf32>` + is not. Vector shapes must be positive decimal integers. 0D vectors are allowed by omitting the dimension: `vector`. @@ -913,19 +918,31 @@ Examples: ```mlir + // A 2D fixed-length vector of 3x42 i32 elements. vector<3x42xi32> + + // A 1D scalable-length vector that contains a multiple of 4 f32 elements. + vector<[4]xf32> + + // A 2D scalable-length vector that contains a multiple of 2x8 i8 elements. + vector<[2x8]xf32> + + // A 2D mixed fixed/scalable vector that contains 4 scalable vectors of 4 f32 elements. + vector<4x[4]xf32> ``` }]; let parameters = (ins ArrayRefParameter<"int64_t">:$shape, - "Type":$elementType + "Type":$elementType, + "unsigned":$numScalableDims ); - let builders = [ TypeBuilderWithInferredContext<(ins - "ArrayRef":$shape, "Type":$elementType + "ArrayRef":$shape, "Type":$elementType, + CArg<"unsigned", "0">:$numScalableDims ), [{ - return $_get(elementType.getContext(), shape, elementType); + return $_get(elementType.getContext(), shape, elementType, + numScalableDims); }]> ]; let extraClassDeclaration = [{ @@ -933,13 +950,18 @@ /// Arguments that are passed into the builder must outlive the builder. class Builder; - /// Returns true of the given type can be used as an element of a vector + /// Returns true if the given type can be used as an element of a vector /// type. In particular, vectors can consist of integer, index, or float /// primitives. 
static bool isValidElementType(Type t) { return t.isa(); } + /// Returns true if the vector contains scalable dimensions. + bool isScalable() const { + return getNumScalableDims() > 0; + } + /// Get or create a new VectorType with the same shape as `this` and an /// element type of bitwidth scaled by `scale`. /// Return null if the scaled element type cannot be represented. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -216,6 +216,14 @@ // TODO: Remove this when all ops support 0-D vectors. def IsVectorOfAnyRankTypePred : CPred<"$_self.isa<::mlir::VectorType>()">; +// Whether a type is a fixed-length VectorType. +def IsFixedVectorTypePred : CPred<[{$_self.isa<::mlir::VectorType>() && + !$_self.cast().isScalable()}]>; + +// Whether a type is a scalable VectorType. +def IsScalableVectorTypePred : CPred<[{$_self.isa<::mlir::VectorType>() && + $_self.cast().isScalable()}]>; + // Whether a type is a TensorType. def IsTensorTypePred : CPred<"$_self.isa<::mlir::TensorType>()">; @@ -611,6 +619,14 @@ ShapedContainerType; +class FixedVectorOf allowedTypes> : + ShapedContainerType; + +class ScalableVectorOf allowedTypes> : + ShapedContainerType; + // Whether the number of elements of a vector is from the given // `allowedRanks` list class IsVectorOfRankPred allowedRanks> : @@ -643,6 +659,24 @@ == }] # allowedlength>)>]>; +// Whether the number of elements of a fixed-length vector is from the given +// `allowedLengths` list +class IsFixedVectorOfLengthPred allowedLengths> : + And<[IsFixedVectorTypePred, + Or().getNumElements() + == }] + # allowedlength>)>]>; + +// Whether the number of elements of a scalable vector is from the given +// `allowedLengths` list +class IsScalableVectorOfLengthPred allowedLengths> : + And<[IsScalableVectorTypePred, + Or().getNumElements() + == }] + # allowedlength>)>]>; + // Any vector where the number of elements is from the given // `allowedLengths` list class VectorOfLength allowedLengths> : Type< @@ -650,6 +684,20 @@ " of length " # !interleave(allowedLengths, "/"), "::mlir::VectorType">; +// Any fixed-length vector where the number of elements is from the given +// `allowedLengths` list +class FixedVectorOfLength allowedLengths> : Type< + IsFixedVectorOfLengthPred, + " of length " # !interleave(allowedLengths, "/"), + "::mlir::VectorType">; + +// Any scalable vector where the number of elements is from the given +// `allowedLengths` list +class ScalableVectorOfLength allowedLengths> : Type< + IsScalableVectorOfLengthPred, + " of length " # !interleave(allowedLengths, "/"), + "::mlir::VectorType">; + // Any vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` // list @@ -660,10 +708,34 @@ VectorOf.summary # VectorOfLength.summary, "::mlir::VectorType">; +// Any fixed-length vector where the number of elements is from the given +// `allowedLengths` list and the type is from the given `allowedTypes` list +class FixedVectorOfLengthAndType allowedLengths, + list allowedTypes> : Type< + And<[FixedVectorOf.predicate, + FixedVectorOfLength.predicate]>, + FixedVectorOf.summary # + FixedVectorOfLength.summary, + "::mlir::VectorType">; + +// Any scalable vector where the number of elements is from the given +// `allowedLengths` list and the type is from the given `allowedTypes` list +class ScalableVectorOfLengthAndType allowedLengths, + list allowedTypes> : Type< + And<[ScalableVectorOf.predicate, + 
ScalableVectorOfLength.predicate]>, + ScalableVectorOf.summary # + ScalableVectorOfLength.summary, + "::mlir::VectorType">; + def AnyVector : VectorOf<[AnyType]>; // Temporary vector type clone that allows gradual transition to 0-D vectors. def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; +def AnyFixedVector : FixedVectorOf<[AnyType]>; + +def AnyScalableVector : ScalableVectorOf<[AnyType]>; + // Shaped types. def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped", diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -411,7 +411,8 @@ return {}; if (type.getShape().empty()) return VectorType::get({1}, elementType); - Type vectorType = VectorType::get(type.getShape().back(), elementType); + Type vectorType = VectorType::get(type.getShape().back(), elementType, + type.getNumScalableDims()); assert(LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"); auto shape = type.getShape(); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -26,13 +26,21 @@ // Helper to reduce vector type by one rank at front. static VectorType reducedVectorTypeFront(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); - return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); + unsigned numScalableDims = tp.getNumScalableDims(); + if (tp.getShape().size() == numScalableDims) + --numScalableDims; + return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), + numScalableDims); } // Helper to reduce vector type by *all* but one rank at back. static VectorType reducedVectorTypeBack(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); - return VectorType::get(tp.getShape().take_back(), tp.getElementType()); + unsigned numScalableDims = tp.getNumScalableDims(); + if (numScalableDims > 0) + --numScalableDims; + return VectorType::get(tp.getShape().take_back(), tp.getElementType(), + numScalableDims); } // Helper that picks the proper sequence for inserting. @@ -112,6 +120,10 @@ namespace { +/// Trivial Vector to LLVM conversions +using VectorScaleOpConversion = + OneToOneConvertToLLVMPattern; + /// Conversion pattern for a vector.bitcast. 
class VectorBitCastOpConversion : public ConvertOpToLLVMPattern { @@ -1064,7 +1076,7 @@ VectorExtractElementOpConversion, VectorExtractOpConversion, VectorFMAOp1DConversion, VectorInsertElementOpConversion, VectorInsertOpConversion, VectorPrintOpConversion, - VectorTypeCastOpConversion, + VectorTypeCastOpConversion, VectorScaleOpConversion, VectorLoadStoreConversion, VectorLoadStoreConversion, diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -999,7 +999,8 @@ if (type.isa()) return UnrankedTensorType::get(i1Type); if (auto vectorType = type.dyn_cast()) - return VectorType::get(vectorType.getShape(), i1Type); + return VectorType::get(vectorType.getShape(), i1Type, + vectorType.getNumScalableDims()); return i1Type; } diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -25,12 +25,6 @@ #include "mlir/Dialect/ArmSVE/ArmSVEDialect.cpp.inc" static Type getI1SameShape(Type type); -static void buildScalableCmpIOp(OpBuilder &build, OperationState &result, - arith::CmpIPredicate predicate, Value lhs, - Value rhs); -static void buildScalableCmpFOp(OpBuilder &build, OperationState &result, - arith::CmpFPredicate predicate, Value lhs, - Value rhs); #define GET_OP_CLASSES #include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc" @@ -43,31 +37,6 @@ #define GET_OP_LIST #include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc" >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc" - >(); -} - -//===----------------------------------------------------------------------===// -// ScalableVectorType -//===----------------------------------------------------------------------===// - -void ScalableVectorType::print(AsmPrinter &printer) const { - printer << "<"; - for (int64_t dim : getShape()) - printer << dim << 'x'; - printer << getElementType() << '>'; -} - -Type ScalableVectorType::parse(AsmParser &parser) { - SmallVector dims; - Type eltType; - if (parser.parseLess() || - parser.parseDimensionList(dims, /*allowDynamic=*/false) || - parser.parseType(eltType) || parser.parseGreater()) - return {}; - return ScalableVectorType::get(eltType.getContext(), dims, eltType); } //===----------------------------------------------------------------------===// @@ -77,30 +46,8 @@ // Return the scalable vector of the same shape and containing i1. 
static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); - if (auto sVectorType = type.dyn_cast()) - return ScalableVectorType::get(type.getContext(), sVectorType.getShape(), - i1Type); + if (auto sVectorType = type.dyn_cast()) + return VectorType::get(sVectorType.getShape(), i1Type, + sVectorType.getNumScalableDims()); return nullptr; } - -//===----------------------------------------------------------------------===// -// CmpFOp -//===----------------------------------------------------------------------===// - -static void buildScalableCmpFOp(OpBuilder &build, OperationState &result, - arith::CmpFPredicate predicate, Value lhs, - Value rhs) { - result.addOperands({lhs, rhs}); - result.types.push_back(getI1SameShape(lhs.getType())); - result.addAttribute(ScalableCmpFOp::getPredicateAttrName(), - build.getI64IntegerAttr(static_cast(predicate))); -} - -static void buildScalableCmpIOp(OpBuilder &build, OperationState &result, - arith::CmpIPredicate predicate, Value lhs, - Value rhs) { - result.addOperands({lhs, rhs}); - result.types.push_back(getI1SameShape(lhs.getType())); - result.addAttribute(ScalableCmpIOp::getPredicateAttrName(), - build.getI64IntegerAttr(static_cast(predicate))); -} diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -18,29 +18,6 @@ using namespace mlir; using namespace mlir::arm_sve; -// Extract an LLVM IR type from the LLVM IR dialect type. -static Type unwrap(Type type) { - if (!type) - return nullptr; - auto *mlirContext = type.getContext(); - if (!LLVM::isCompatibleType(type)) - emitError(UnknownLoc::get(mlirContext), - "conversion resulted in a non-LLVM type"); - return type; -} - -static Optional -convertScalableVectorTypeToLLVM(ScalableVectorType svType, - LLVMTypeConverter &converter) { - auto elementType = unwrap(converter.convertType(svType.getElementType())); - if (!elementType) - return {}; - - auto sVectorType = - LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back()); - return sVectorType; -} - template class ForwardOperands : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -70,22 +47,10 @@ } }; -static Optional addUnrealizedCast(OpBuilder &builder, - ScalableVectorType svType, - ValueRange inputs, Location loc) { - if (inputs.size() != 1 || - !inputs[0].getType().isa()) - return Value(); - return builder.create(loc, svType, inputs) - .getResult(0); -} - using SdotOpLowering = OneToOneConvertToLLVMPattern; using SmmlaOpLowering = OneToOneConvertToLLVMPattern; using UdotOpLowering = OneToOneConvertToLLVMPattern; using UmmlaOpLowering = OneToOneConvertToLLVMPattern; -using VectorScaleOpLowering = - OneToOneConvertToLLVMPattern; using ScalableMaskedAddIOpLowering = OneToOneConvertToLLVMPattern; @@ -114,136 +79,10 @@ OneToOneConvertToLLVMPattern; -// Load operation is lowered to code that obtains a pointer to the indexed -// element and loads from it. 
-struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(ScalableLoadOp loadOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto type = loadOp.getMemRefType(); - if (!isConvertibleAndHasIdentityMaps(type)) - return failure(); - - LLVMTypeConverter converter(loadOp.getContext()); - - auto resultType = loadOp.result().getType(); - LLVM::LLVMPointerType llvmDataTypePtr; - if (resultType.isa()) { - llvmDataTypePtr = - LLVM::LLVMPointerType::get(resultType.cast()); - } else if (resultType.isa()) { - llvmDataTypePtr = LLVM::LLVMPointerType::get( - convertScalableVectorTypeToLLVM(resultType.cast(), - converter) - .getValue()); - } - Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.base(), - adaptor.index(), rewriter); - Value bitCastedPtr = rewriter.create( - loadOp.getLoc(), llvmDataTypePtr, dataPtr); - rewriter.replaceOpWithNewOp(loadOp, bitCastedPtr); - return success(); - } -}; - -// Store operation is lowered to code that obtains a pointer to the indexed -// element, and stores the given value to it. -struct ScalableStoreOpLowering - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(ScalableStoreOp storeOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto type = storeOp.getMemRefType(); - if (!isConvertibleAndHasIdentityMaps(type)) - return failure(); - - LLVMTypeConverter converter(storeOp.getContext()); - - auto resultType = storeOp.value().getType(); - LLVM::LLVMPointerType llvmDataTypePtr; - if (resultType.isa()) { - llvmDataTypePtr = - LLVM::LLVMPointerType::get(resultType.cast()); - } else if (resultType.isa()) { - llvmDataTypePtr = LLVM::LLVMPointerType::get( - convertScalableVectorTypeToLLVM(resultType.cast(), - converter) - .getValue()); - } - Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, adaptor.base(), - adaptor.index(), rewriter); - Value bitCastedPtr = rewriter.create( - storeOp.getLoc(), llvmDataTypePtr, dataPtr); - rewriter.replaceOpWithNewOp(storeOp, adaptor.value(), - bitCastedPtr); - return success(); - } -}; - -static void -populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter, - OwningRewritePatternList &patterns) { - // clang-format off - patterns.add, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern, - OneToOneConvertToLLVMPattern - >(converter); - // clang-format on -} - -static void -configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) { - // clang-format off - target.addIllegalOp(); - // clang-format on -} - -static void -populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter, - OwningRewritePatternList &patterns) { - // clang-format off - patterns.add, - OneToOneConvertToLLVMPattern - >(converter); - // clang-format on -} - -static void -configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) { - // clang-format off - target.addIllegalOp(); - // clang-format on -} - /// Populate the given list with patterns that convert from ArmSVE to LLVM. 
void mlir::populateArmSVELegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // Populate conversion patterns - // Remove any ArmSVE-specific types from function signatures and results. - populateFuncOpTypeConversionPattern(patterns, converter); - converter.addConversion([&converter](ScalableVectorType svType) { - return convertScalableVectorTypeToLLVM(svType, converter); - }); - converter.addSourceMaterialization(addUnrealizedCast); // clang-format off patterns.add, @@ -254,7 +93,6 @@ SmmlaOpLowering, UdotOpLowering, UmmlaOpLowering, - VectorScaleOpLowering, ScalableMaskedAddIOpLowering, ScalableMaskedAddFOpLowering, ScalableMaskedSubIOpLowering, @@ -264,11 +102,7 @@ ScalableMaskedSDivIOpLowering, ScalableMaskedUDivIOpLowering, ScalableMaskedDivFOpLowering>(converter); - patterns.add(converter); // clang-format on - populateBasicSVEArithmeticExportPatterns(converter, patterns); - populateSVEMaskGenerationExportPatterns(converter, patterns); } void mlir::configureArmSVELegalizeForExportTarget( @@ -278,7 +112,6 @@ SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp, - VectorScaleIntrOp, ScalableMaskedAddIIntrOp, ScalableMaskedAddFIntrOp, ScalableMaskedSubIIntrOp, @@ -292,7 +125,6 @@ SmmlaOp, UdotOp, UmmlaOp, - VectorScaleOp, ScalableMaskedAddIOp, ScalableMaskedAddFOp, ScalableMaskedSubIOp, @@ -301,25 +133,6 @@ ScalableMaskedMulFOp, ScalableMaskedSDivIOp, ScalableMaskedUDivIOp, - ScalableMaskedDivFOp, - ScalableLoadOp, - ScalableStoreOp>(); + ScalableMaskedDivFOp>(); // clang-format on - auto hasScalableVectorType = [](TypeRange types) { - for (Type type : types) - if (type.isa()) - return true; - return false; - }; - target.addDynamicallyLegalOp([hasScalableVectorType](FuncOp op) { - return !hasScalableVectorType(op.getType().getInputs()) && - !hasScalableVectorType(op.getType().getResults()); - }); - target.addDynamicallyLegalOp( - [hasScalableVectorType](Operation *op) { - return !hasScalableVectorType(op->getOperandTypes()) && - !hasScalableVectorType(op->getResultTypes()); - }); - configureBasicSVEArithmeticLegalizations(target); - configureSVEMaskGenerationLegalizations(target); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -155,12 +155,14 @@ return parser.emitError(trailingTypeLoc, "expected LLVM dialect-compatible type"); if (LLVM::isCompatibleVectorType(type)) { - if (type.isa()) { - resultType = LLVM::LLVMScalableVectorType::get( - resultType, LLVM::getVectorNumElements(type).getKnownMinValue()); + if (LLVM::isScalableVectorType(type)) { + resultType = LLVM::getVectorType( + resultType, LLVM::getVectorNumElements(type).getKnownMinValue(), + /*isScalable=*/true); } else { - resultType = LLVM::getFixedVectorType( - resultType, LLVM::getVectorNumElements(type).getFixedValue()); + resultType = LLVM::getVectorType( + resultType, LLVM::getVectorNumElements(type).getFixedValue(), + /*isScalable=*/false); } } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -775,7 +775,12 @@ llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) { return llvm::TypeSwitch(type) - .Case([](auto ty) { + .Case([](VectorType ty) { + if (ty.isScalable()) + return llvm::ElementCount::getScalable(ty.getNumElements()); + return llvm::ElementCount::getFixed(ty.getNumElements()); + 
}) + .Case([](LLVMFixedVectorType ty) { return llvm::ElementCount::getFixed(ty.getNumElements()); }) .Case([](LLVMScalableVectorType ty) { @@ -786,6 +791,31 @@ }); } +bool mlir::LLVM::isScalableVectorType(Type vectorType) { + assert( + (vectorType + .isa()) && + "expected LLVM-compatible vector type"); + return !vectorType.isa() && + (vectorType.isa() || + vectorType.cast().isScalable()); +} + +Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements, + bool isScalable) { + bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType); + bool useBuiltIn = VectorType::isValidElementType(elementType); + (void)useBuiltIn; + assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type " + "to be either builtin or LLVM dialect type"); + if (useLLVM) { + if (isScalable) + return LLVMScalableVectorType::get(elementType, numElements); + return LLVMFixedVectorType::get(elementType, numElements); + } + return VectorType::get(numElements, elementType, (unsigned)isScalable); +} + Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) { bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType); bool useBuiltIn = VectorType::isValidElementType(elementType); @@ -797,6 +827,18 @@ return VectorType::get(numElements, elementType); } +Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) { + bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType); + bool useBuiltIn = VectorType::isValidElementType(elementType); + (void)useBuiltIn; + assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector " + "type to be either builtin or LLVM dialect " + "type"); + if (useLLVM) + return LLVMScalableVectorType::get(elementType, numElements); + return VectorType::get(numElements, elementType, /*numScalableDims=*/1); +} + llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { assert(isCompatibleType(type) && "expected a type compatible with the LLVM dialect"); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -515,7 +515,8 @@ if (type.isa()) return UnrankedTensorType::get(i1Type); if (auto vectorType = type.dyn_cast()) - return VectorType::get(vectorType.getShape(), i1Type); + return VectorType::get(vectorType.getShape(), i1Type, + vectorType.getNumScalableDims()); return i1Type; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1954,8 +1954,19 @@ }) .Case([&](VectorType vectorTy) { os << "vector<"; - for (int64_t dim : vectorTy.getShape()) - os << dim << 'x'; + auto vShape = vectorTy.getShape(); + unsigned lastDim = vShape.size(); + unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims(); + unsigned dimIdx = 0; + for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++) + os << vShape[dimIdx] << 'x'; + if (vectorTy.isScalable()) { + os << '['; + unsigned secondToLastDim = lastDim - 1; + for (; dimIdx < secondToLastDim; dimIdx++) + os << vShape[dimIdx] << 'x'; + os << vShape[dimIdx] << "]x"; + } printType(vectorTy.getElementType()); os << '>'; }) diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -1165,8 +1165,9 @@ newArrayType = RankedTensorType::get(inType.getShape(), newElementType); else if (inType.isa()) newArrayType = 
RankedTensorType::get(inType.getShape(), newElementType); - else if (inType.isa()) - newArrayType = VectorType::get(inType.getShape(), newElementType); + else if (auto vType = inType.dyn_cast()) + newArrayType = VectorType::get(vType.getShape(), newElementType, + vType.getNumScalableDims()); else assert(newArrayType && "Unhandled tensor type"); diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -294,8 +294,8 @@ if (isa()) return RankedTensorType::get(shape, elementType); - if (isa()) - return VectorType::get(shape, elementType); + if (auto vecTy = dyn_cast()) + return VectorType::get(shape, elementType, vecTy.getNumScalableDims()); llvm_unreachable("Unhandled ShapedType clone case"); } @@ -317,8 +317,8 @@ if (isa()) return RankedTensorType::get(shape, getElementType()); - if (isa()) - return VectorType::get(shape, getElementType()); + if (auto vecTy = dyn_cast()) + return VectorType::get(shape, getElementType(), vecTy.getNumScalableDims()); llvm_unreachable("Unhandled ShapedType clone case"); } @@ -340,8 +340,8 @@ return UnrankedTensorType::get(elementType); } - if (isa()) - return VectorType::get(getShape(), elementType); + if (auto vecTy = dyn_cast()) + return VectorType::get(getShape(), elementType, vecTy.getNumScalableDims()); llvm_unreachable("Unhandled ShapedType clone hit"); } @@ -441,7 +441,8 @@ //===----------------------------------------------------------------------===// LogicalResult VectorType::verify(function_ref emitError, - ArrayRef shape, Type elementType) { + ArrayRef shape, Type elementType, + unsigned numScalableDims) { if (!isValidElementType(elementType)) return emitError() << "vector elements must be int/index/float type but got " @@ -460,10 +461,10 @@ return VectorType(); if (auto et = getElementType().dyn_cast()) if (auto scaledEt = et.scaleElementBitwidth(scale)) - return VectorType::get(getShape(), scaledEt); + return VectorType::get(getShape(), scaledEt, getNumScalableDims()); if (auto et = getElementType().dyn_cast()) if (auto scaledEt = et.scaleElementBitwidth(scale)) - return VectorType::get(getShape(), scaledEt); + return VectorType::get(getShape(), scaledEt, getNumScalableDims()); return VectorType(); } diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -196,8 +196,11 @@ /// Parse a vector type. VectorType parseVectorType(); + ParseResult parseVectorDimensionList(SmallVectorImpl &dimensions, + unsigned &numScalableDims); ParseResult parseDimensionListRanked(SmallVectorImpl &dimensions, bool allowDynamic = true); + ParseResult parseIntegerInDimensionList(int64_t &value); ParseResult parseXInDimensionList(); /// Parse strided layout specification. diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -13,6 +13,7 @@ #include "Parser.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/TensorEncoding.h" using namespace mlir; @@ -442,8 +443,9 @@ /// Parse a vector type. /// -/// vector-type ::= `vector` `<` static-dimension-list type `>` -/// static-dimension-list ::= (decimal-literal `x`)* +/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>` +/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? 
+/// static-dim-list ::= decimal-literal (`x` decimal-literal)* /// VectorType Parser::parseVectorType() { consumeToken(Token::kw_vector); @@ -452,7 +454,8 @@ return nullptr; SmallVector dimensions; - if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false)) + unsigned numScalableDims; + if (parseVectorDimensionList(dimensions, numScalableDims)) return nullptr; if (any_of(dimensions, [](int64_t i) { return i <= 0; })) return emitError(getToken().getLoc(), @@ -464,11 +467,59 @@ auto elementType = parseType(); if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) return nullptr; + if (!VectorType::isValidElementType(elementType)) return emitError(typeLoc, "vector elements must be int/index/float type"), nullptr; - return VectorType::get(dimensions, elementType); + return VectorType::get(dimensions, elementType, numScalableDims); +} + +/// Parse a dimension list in a vector type. This populates the dimension list, +/// and returns the number of scalable dimensions in `numScalableDims`. +/// +/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? +/// static-dim-list ::= decimal-literal (`x` decimal-literal)* +/// +ParseResult +Parser::parseVectorDimensionList(SmallVectorImpl &dimensions, + unsigned &numScalableDims) { + numScalableDims = 0; + // If there is a set of fixed-length dimensions, consume it + while (getToken().is(Token::integer)) { + int64_t value; + if (parseIntegerInDimensionList(value)) + return failure(); + dimensions.push_back(value); + // Make sure we have an 'x' or something like 'xbf32'. + if (parseXInDimensionList()) + return failure(); + } + // If there is a set of scalable dimensions, consume it + if (consumeIf(Token::l_square)) { + while (getToken().is(Token::integer)) { + int64_t value; + if (parseIntegerInDimensionList(value)) + return failure(); + dimensions.push_back(value); + numScalableDims++; + // Check if we have reached the end of the scalable dimension list + if (consumeIf(Token::r_square)) { + // Make sure we have something like 'xbf32'. + if (parseXInDimensionList()) + return failure(); + return success(); + } + // Make sure we have an 'x' + if (parseXInDimensionList()) + return failure(); + } + // If we make it here, we've finished parsing the dimension list + // without finding ']' closing the set of scalable dimensions + return emitError("missing ']' closing set of scalable dimensions"); + } + + return success(); } /// Parse a dimension list of a tensor or memref type. This populates the @@ -490,28 +541,11 @@ return emitError("expected static shape"); dimensions.push_back(-1); } else { - // Hexadecimal integer literals (starting with `0x`) are not allowed in - // aggregate type declarations. Therefore, `0xf32` should be processed as - // a sequence of separate elements `0`, `x`, `f32`. - if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { - // We can get here only if the token is an integer literal. Hexadecimal - // integer literals can only start with `0x` (`1x` wouldn't lex as a - // literal, just `1` would, at which point we don't get into this - // branch). - assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); - dimensions.push_back(0); - state.lex.resetPointer(getTokenSpelling().data() + 1); - consumeToken(); - } else { - // Make sure this integer value is in bound and valid. 
- Optional dimension = getToken().getUInt64IntegerValue(); - if (!dimension || *dimension > std::numeric_limits::max()) - return emitError("invalid dimension"); - dimensions.push_back((int64_t)dimension.getValue()); - consumeToken(Token::integer); - } + int64_t value; + if (parseIntegerInDimensionList(value)) + return failure(); + dimensions.push_back(value); } - // Make sure we have an 'x' or something like 'xbf32'. if (parseXInDimensionList()) return failure(); @@ -520,6 +554,30 @@ return success(); } +ParseResult Parser::parseIntegerInDimensionList(int64_t &value) { + // Hexadecimal integer literals (starting with `0x`) are not allowed in + // aggregate type declarations. Therefore, `0xf32` should be processed as + // a sequence of separate elements `0`, `x`, `f32`. + if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { + // We can get here only if the token is an integer literal. Hexadecimal + // integer literals can only start with `0x` (`1x` wouldn't lex as a + // literal, just `1` would, at which point we don't get into this + // branch). + assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); + value = 0; + state.lex.resetPointer(getTokenSpelling().data() + 1); + consumeToken(); + } else { + // Make sure this integer value is in bound and valid. + Optional dimension = getToken().getUInt64IntegerValue(); + if (!dimension || *dimension > std::numeric_limits::max()) + return emitError("invalid dimension"); + value = (int64_t)dimension.getValue(); + consumeToken(Token::integer); + } + return success(); +} + /// Parse an 'x' token in a dimension list, handling the case where the x is /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next /// token. diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -242,10 +242,14 @@ if (auto *arrayTy = dyn_cast(llvmType)) { elementType = arrayTy->getElementType(); numElements = arrayTy->getNumElements(); + } else if (auto fVectorTy = dyn_cast(llvmType)) { + elementType = fVectorTy->getElementType(); + numElements = fVectorTy->getNumElements(); + } else if (auto sVectorTy = dyn_cast(llvmType)) { + elementType = sVectorTy->getElementType(); + numElements = sVectorTy->getMinNumElements(); } else { - auto *vectorTy = cast(llvmType); - elementType = vectorTy->getElementType(); - numElements = vectorTy->getNumElements(); + llvm_unreachable("unrecognized constant vector type"); } // Splat value is a scalar. Extract it only if the element type is not // another sequence type. 
The recursion terminates because each step removes diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp --- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp @@ -137,6 +137,9 @@ llvm::Type *translate(VectorType type) { assert(LLVM::isCompatibleVectorType(type) && "expected compatible with LLVM vector type"); + if (type.isScalable()) + return llvm::ScalableVectorType::get(translateType(type.getElementType()), + type.getNumElements()); return llvm::FixedVectorType::get(translateType(type.getElementType()), type.getNumElements()); } diff --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir --- a/mlir/test/Dialect/Arithmetic/ops.mlir +++ b/mlir/test/Dialect/Arithmetic/ops.mlir @@ -19,6 +19,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_addi_scalable_vector +func @test_addi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.addi %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_subi func @test_subi(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.subi %arg0, %arg1 : i64 @@ -37,6 +43,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_subi_scalable_vector +func @test_subi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.subi %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_muli func @test_muli(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.muli %arg0, %arg1 : i64 @@ -55,6 +67,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_muli_scalable_vector +func @test_muli_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.muli %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_divui func @test_divui(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.divui %arg0, %arg1 : i64 @@ -73,6 +91,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_divui_scalable_vector +func @test_divui_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.divui %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_divsi func @test_divsi(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.divsi %arg0, %arg1 : i64 @@ -91,6 +115,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_divsi_scalable_vector +func @test_divsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.divsi %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_remui func @test_remui(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.remui %arg0, %arg1 : i64 @@ -109,6 +139,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_remui_scalable_vector +func @test_remui_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.remui %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_remsi func @test_remsi(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.remsi %arg0, %arg1 : i64 @@ -127,6 +163,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_remsi_scalable_vector +func @test_remsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.remsi %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_andi func @test_andi(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.andi %arg0, %arg1 : i64 @@ 
-145,6 +187,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_andi_scalable_vector +func @test_andi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.andi %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_ori func @test_ori(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.ori %arg0, %arg1 : i64 @@ -163,6 +211,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_ori_scalable_vector +func @test_ori_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.ori %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_xori func @test_xori(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.xori %arg0, %arg1 : i64 @@ -181,6 +235,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_xori_scalable_vector +func @test_xori_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.xori %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_ceildivsi func @test_ceildivsi(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.ceildivsi %arg0, %arg1 : i64 @@ -199,6 +259,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_ceildivsi_scalable_vector +func @test_ceildivsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.ceildivsi %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_floordivsi func @test_floordivsi(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.floordivsi %arg0, %arg1 : i64 @@ -217,6 +283,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_floordivsi_scalable_vector +func @test_floordivsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.floordivsi %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_shli func @test_shli(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.shli %arg0, %arg1 : i64 @@ -235,6 +307,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_shli_scalable_vector +func @test_shli_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.shli %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_shrui func @test_shrui(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.shrui %arg0, %arg1 : i64 @@ -253,6 +331,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_shrui_scalable_vector +func @test_shrui_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.shrui %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_shrsi func @test_shrsi(%arg0 : i64, %arg1 : i64) -> i64 { %0 = arith.shrsi %arg0, %arg1 : i64 @@ -271,6 +355,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_shrsi_scalable_vector +func @test_shrsi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> { + %0 = arith.shrsi %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_negf func @test_negf(%arg0 : f64) -> f64 { %0 = arith.negf %arg0 : f64 @@ -289,6 +379,12 @@ return %0 : vector<8xf64> } +// CHECK-LABEL: test_negf_scalable_vector +func @test_negf_scalable_vector(%arg0 : vector<[8]xf64>) -> vector<[8]xf64> { + %0 = arith.negf %arg0 : vector<[8]xf64> + return %0 : vector<[8]xf64> +} + // CHECK-LABEL: test_addf func @test_addf(%arg0 : f64, %arg1 : f64) -> f64 { %0 = arith.addf %arg0, %arg1 : f64 @@ -307,6 
+403,12 @@ return %0 : vector<8xf64> } +// CHECK-LABEL: test_addf_scalable_vector +func @test_addf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> { + %0 = arith.addf %arg0, %arg1 : vector<[8]xf64> + return %0 : vector<[8]xf64> +} + // CHECK-LABEL: test_subf func @test_subf(%arg0 : f64, %arg1 : f64) -> f64 { %0 = arith.subf %arg0, %arg1 : f64 @@ -325,6 +427,12 @@ return %0 : vector<8xf64> } +// CHECK-LABEL: test_subf_scalable_vector +func @test_subf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> { + %0 = arith.subf %arg0, %arg1 : vector<[8]xf64> + return %0 : vector<[8]xf64> +} + // CHECK-LABEL: test_mulf func @test_mulf(%arg0 : f64, %arg1 : f64) -> f64 { %0 = arith.mulf %arg0, %arg1 : f64 @@ -343,6 +451,12 @@ return %0 : vector<8xf64> } +// CHECK-LABEL: test_mulf_scalable_vector +func @test_mulf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> { + %0 = arith.mulf %arg0, %arg1 : vector<[8]xf64> + return %0 : vector<[8]xf64> +} + // CHECK-LABEL: test_divf func @test_divf(%arg0 : f64, %arg1 : f64) -> f64 { %0 = arith.divf %arg0, %arg1 : f64 @@ -361,6 +475,12 @@ return %0 : vector<8xf64> } +// CHECK-LABEL: test_divf_scalable_vector +func @test_divf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> { + %0 = arith.divf %arg0, %arg1 : vector<[8]xf64> + return %0 : vector<[8]xf64> +} + // CHECK-LABEL: test_remf func @test_remf(%arg0 : f64, %arg1 : f64) -> f64 { %0 = arith.remf %arg0, %arg1 : f64 @@ -379,6 +499,12 @@ return %0 : vector<8xf64> } +// CHECK-LABEL: test_remf_scalable_vector +func @test_remf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xf64> { + %0 = arith.remf %arg0, %arg1 : vector<[8]xf64> + return %0 : vector<[8]xf64> +} + // CHECK-LABEL: test_extui func @test_extui(%arg0 : i32) -> i64 { %0 = arith.extui %arg0 : i32 to i64 @@ -397,6 +523,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_extui_scalable_vector +func @test_extui_scalable_vector(%arg0 : vector<[8]xi32>) -> vector<[8]xi64> { + %0 = arith.extui %arg0 : vector<[8]xi32> to vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_extsi func @test_extsi(%arg0 : i32) -> i64 { %0 = arith.extsi %arg0 : i32 to i64 @@ -415,6 +547,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_extsi_scalable_vector +func @test_extsi_scalable_vector(%arg0 : vector<[8]xi32>) -> vector<[8]xi64> { + %0 = arith.extsi %arg0 : vector<[8]xi32> to vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_extf func @test_extf(%arg0 : f32) -> f64 { %0 = arith.extf %arg0 : f32 to f64 @@ -433,6 +571,12 @@ return %0 : vector<8xf64> } +// CHECK-LABEL: test_extf_scalable_vector +func @test_extf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xf64> { + %0 = arith.extf %arg0 : vector<[8]xf32> to vector<[8]xf64> + return %0 : vector<[8]xf64> +} + // CHECK-LABEL: test_trunci func @test_trunci(%arg0 : i32) -> i16 { %0 = arith.trunci %arg0 : i32 to i16 @@ -451,6 +595,12 @@ return %0 : vector<8xi16> } +// CHECK-LABEL: test_trunci_scalable_vector +func @test_trunci_scalable_vector(%arg0 : vector<[8]xi32>) -> vector<[8]xi16> { + %0 = arith.trunci %arg0 : vector<[8]xi32> to vector<[8]xi16> + return %0 : vector<[8]xi16> +} + // CHECK-LABEL: test_truncf func @test_truncf(%arg0 : f32) -> bf16 { %0 = arith.truncf %arg0 : f32 to bf16 @@ -469,6 +619,12 @@ return %0 : vector<8xbf16> } +// CHECK-LABEL: test_truncf_scalable_vector +func 
@test_truncf_scalable_vector(%arg0 : vector<[8]xf32>) -> vector<[8]xbf16> { + %0 = arith.truncf %arg0 : vector<[8]xf32> to vector<[8]xbf16> + return %0 : vector<[8]xbf16> +} + // CHECK-LABEL: test_uitofp func @test_uitofp(%arg0 : i32) -> f32 { %0 = arith.uitofp %arg0 : i32 to f32 @@ -487,6 +643,12 @@ return %0 : vector<8xf32> } +// CHECK-LABEL: test_uitofp_scalable_vector +func @test_uitofp_scalable_vector(%arg0 : vector<[8]xi32>) -> vector<[8]xf32> { + %0 = arith.uitofp %arg0 : vector<[8]xi32> to vector<[8]xf32> + return %0 : vector<[8]xf32> +} + // CHECK-LABEL: test_sitofp func @test_sitofp(%arg0 : i16) -> f64 { %0 = arith.sitofp %arg0 : i16 to f64 @@ -505,6 +667,12 @@ return %0 : vector<8xf64> } +// CHECK-LABEL: test_sitofp_scalable_vector +func @test_sitofp_scalable_vector(%arg0 : vector<[8]xi16>) -> vector<[8]xf64> { + %0 = arith.sitofp %arg0 : vector<[8]xi16> to vector<[8]xf64> + return %0 : vector<[8]xf64> +} + // CHECK-LABEL: test_fptoui func @test_fptoui(%arg0 : bf16) -> i8 { %0 = arith.fptoui %arg0 : bf16 to i8 @@ -523,6 +691,12 @@ return %0 : vector<8xi8> } +// CHECK-LABEL: test_fptoui_scalable_vector +func @test_fptoui_scalable_vector(%arg0 : vector<[8]xbf16>) -> vector<[8]xi8> { + %0 = arith.fptoui %arg0 : vector<[8]xbf16> to vector<[8]xi8> + return %0 : vector<[8]xi8> +} + // CHECK-LABEL: test_fptosi func @test_fptosi(%arg0 : f64) -> i64 { %0 = arith.fptosi %arg0 : f64 to i64 @@ -541,6 +715,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_fptosi_scalable_vector +func @test_fptosi_scalable_vector(%arg0 : vector<[8]xf64>) -> vector<[8]xi64> { + %0 = arith.fptosi %arg0 : vector<[8]xf64> to vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_index_cast0 func @test_index_cast0(%arg0 : i32) -> index { %0 = arith.index_cast %arg0 : i32 to index @@ -559,6 +739,12 @@ return %0 : vector<8xindex> } +// CHECK-LABEL: test_index_cast_scalable_vector0 +func @test_index_cast_scalable_vector0(%arg0 : vector<[8]xi32>) -> vector<[8]xindex> { + %0 = arith.index_cast %arg0 : vector<[8]xi32> to vector<[8]xindex> + return %0 : vector<[8]xindex> +} + // CHECK-LABEL: test_index_cast1 func @test_index_cast1(%arg0 : index) -> i64 { %0 = arith.index_cast %arg0 : index to i64 @@ -577,6 +763,12 @@ return %0 : vector<8xi64> } +// CHECK-LABEL: test_index_cast_scalable_vector1 +func @test_index_cast_scalable_vector1(%arg0 : vector<[8]xindex>) -> vector<[8]xi64> { + %0 = arith.index_cast %arg0 : vector<[8]xindex> to vector<[8]xi64> + return %0 : vector<[8]xi64> +} + // CHECK-LABEL: test_bitcast0 func @test_bitcast0(%arg0 : i64) -> f64 { %0 = arith.bitcast %arg0 : i64 to f64 @@ -595,6 +787,12 @@ return %0 : vector<8xf64> } +// CHECK-LABEL: test_bitcast_scalable_vector0 +func @test_bitcast_scalable_vector0(%arg0 : vector<[8]xi64>) -> vector<[8]xf64> { + %0 = arith.bitcast %arg0 : vector<[8]xi64> to vector<[8]xf64> + return %0 : vector<[8]xf64> +} + // CHECK-LABEL: test_bitcast1 func @test_bitcast1(%arg0 : f32) -> i32 { %0 = arith.bitcast %arg0 : f32 to i32 @@ -613,6 +811,12 @@ return %0 : vector<8xi32> } +// CHECK-LABEL: test_bitcast_scalable_vector1 +func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]xi32> { + %0 = arith.bitcast %arg0 : vector<[8]xf32> to vector<[8]xi32> + return %0 : vector<[8]xi32> +} + // CHECK-LABEL: test_cmpi func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 { %0 = arith.cmpi ne, %arg0, %arg1 : i64 @@ -631,6 +835,12 @@ return %0 : vector<8xi1> } +// CHECK-LABEL: test_cmpi_scalable_vector +func @test_cmpi_scalable_vector(%arg0 : 
vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi1> { + %0 = arith.cmpi ult, %arg0, %arg1 : vector<[8]xi64> + return %0 : vector<[8]xi1> +} + // CHECK-LABEL: test_cmpi_vector_0d func @test_cmpi_vector_0d(%arg0 : vector, %arg1 : vector) -> vector { %0 = arith.cmpi ult, %arg0, %arg1 : vector @@ -655,6 +865,12 @@ return %0 : vector<8xi1> } +// CHECK-LABEL: test_cmpf_scalable_vector +func @test_cmpf_scalable_vector(%arg0 : vector<[8]xf64>, %arg1 : vector<[8]xf64>) -> vector<[8]xi1> { + %0 = arith.cmpf ult, %arg0, %arg1 : vector<[8]xf64> + return %0 : vector<[8]xi1> +} + // CHECK-LABEL: test_index_cast func @test_index_cast(%arg0 : index) -> i64 { %0 = arith.index_cast %arg0 : index to i64 @@ -713,9 +929,11 @@ // CHECK-LABEL: func @maximum func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>, + %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>, %f1: f32, %f2: f32, %i1: i32, %i2: i32) { %max_vector = arith.maxf %v1, %v2 : vector<4xf32> + %max_scalable_vector = arith.maxf %sv1, %sv2 : vector<[4]xf32> %max_float = arith.maxf %f1, %f2 : f32 %max_signed = arith.maxsi %i1, %i2 : i32 %max_unsigned = arith.maxui %i1, %i2 : i32 @@ -724,9 +942,11 @@ // CHECK-LABEL: func @minimum func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>, + %sv1: vector<[4]xf32>, %sv2: vector<[4]xf32>, %f1: f32, %f2: f32, %i1: i32, %i2: i32) { %min_vector = arith.minf %v1, %v2 : vector<4xf32> + %min_scalable_vector = arith.minf %sv1, %sv2 : vector<[4]xf32> %min_float = arith.minf %f1, %f2 : f32 %min_signed = arith.minsi %i1, %i2 : i32 %min_unsigned = arith.minui %i1, %i2 : i32 diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir --- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir @@ -1,168 +1,118 @@ -// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-std-to-llvm | mlir-opt | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" -convert-std-to-llvm -reconcile-unrealized-casts | mlir-opt | FileCheck %s -func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi32> { +func @arm_sve_sdot(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: arm_sve.intr.sdot %0 = arm_sve.sdot %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> } -func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi32> { +func @arm_sve_smmla(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: arm_sve.intr.smmla %0 = arm_sve.smmla %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> } -func @arm_sve_udot(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi32> { +func @arm_sve_udot(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: arm_sve.intr.udot %0 = arm_sve.udot %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> } -func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: 
!arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi32> { +func @arm_sve_ummla(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: arm_sve.intr.ummla %0 = arm_sve.ummla %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> } -func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>, - %c: !arm_sve.vector<4xi32>, - %d: !arm_sve.vector<4xi32>, - %e: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: llvm.mul {{.*}}: !llvm.vec - %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32> - // CHECK: llvm.add {{.*}}: !llvm.vec - %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32> - // CHECK: llvm.sub {{.*}}: !llvm.vec - %2 = arm_sve.subi %1, %d : !arm_sve.vector<4xi32> - // CHECK: llvm.sdiv {{.*}}: !llvm.vec - %3 = arm_sve.divi_signed %2, %e : !arm_sve.vector<4xi32> - // CHECK: llvm.udiv {{.*}}: !llvm.vec - %4 = arm_sve.divi_unsigned %2, %e : !arm_sve.vector<4xi32> - return %4 : !arm_sve.vector<4xi32> +func @arm_sve_arithi_masked(%a: vector<[4]xi32>, + %b: vector<[4]xi32>, + %c: vector<[4]xi32>, + %d: vector<[4]xi32>, + %e: vector<[4]xi32>, + %mask: vector<[4]xi1> + ) -> vector<[4]xi32> { + // CHECK: arm_sve.intr.add{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %0 = arm_sve.masked.addi %mask, %a, %b : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: arm_sve.intr.sub{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %1 = arm_sve.masked.subi %mask, %0, %c : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: arm_sve.intr.mul{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %2 = arm_sve.masked.muli %mask, %1, %d : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: arm_sve.intr.sdiv{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %3 = arm_sve.masked.divi_signed %mask, %2, %e : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: arm_sve.intr.udiv{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : vector<[4]xi1>, + vector<[4]xi32> + return %4 : vector<[4]xi32> } -func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>, - %b: !arm_sve.vector<4xf32>, - %c: !arm_sve.vector<4xf32>, - %d: !arm_sve.vector<4xf32>, - %e: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> { - // CHECK: llvm.fmul {{.*}}: !llvm.vec - %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32> - // CHECK: llvm.fadd {{.*}}: !llvm.vec - %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32> - // CHECK: llvm.fsub {{.*}}: !llvm.vec - %2 = arm_sve.subf %1, %d : !arm_sve.vector<4xf32> - // CHECK: llvm.fdiv {{.*}}: !llvm.vec - %3 = arm_sve.divf %2, %e : !arm_sve.vector<4xf32> - return %3 : !arm_sve.vector<4xf32> +func @arm_sve_arithf_masked(%a: vector<[4]xf32>, + %b: vector<[4]xf32>, + %c: vector<[4]xf32>, + %d: vector<[4]xf32>, + %e: vector<[4]xf32>, + %mask: vector<[4]xi1> + ) -> vector<[4]xf32> { + // CHECK: arm_sve.intr.fadd{{.*}}: (vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32> + %0 = arm_sve.masked.addf %mask, %a, %b : vector<[4]xi1>, + vector<[4]xf32> + // CHECK: arm_sve.intr.fsub{{.*}}: (vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32> + %1 = arm_sve.masked.subf %mask, %0, %c : vector<[4]xi1>, + vector<[4]xf32> + // CHECK: arm_sve.intr.fmul{{.*}}: (vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32> + %2 = arm_sve.masked.mulf %mask, %1, 
%d : vector<[4]xi1>, + vector<[4]xf32> + // CHECK: arm_sve.intr.fdiv{{.*}}: (vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> vector<[4]xf32> + %3 = arm_sve.masked.divf %mask, %2, %e : vector<[4]xi1>, + vector<[4]xf32> + return %3 : vector<[4]xf32> } -func @arm_sve_arithi_masked(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>, - %c: !arm_sve.vector<4xi32>, - %d: !arm_sve.vector<4xi32>, - %e: !arm_sve.vector<4xi32>, - %mask: !arm_sve.vector<4xi1> - ) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.intr.add{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %0 = arm_sve.masked.addi %mask, %a, %b : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: arm_sve.intr.sub{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %1 = arm_sve.masked.subi %mask, %0, %c : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: arm_sve.intr.mul{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %2 = arm_sve.masked.muli %mask, %1, %d : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: arm_sve.intr.sdiv{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %3 = arm_sve.masked.divi_signed %mask, %2, %e : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: arm_sve.intr.udiv{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - return %4 : !arm_sve.vector<4xi32> -} - -func @arm_sve_arithf_masked(%a: !arm_sve.vector<4xf32>, - %b: !arm_sve.vector<4xf32>, - %c: !arm_sve.vector<4xf32>, - %d: !arm_sve.vector<4xf32>, - %e: !arm_sve.vector<4xf32>, - %mask: !arm_sve.vector<4xi1> - ) -> !arm_sve.vector<4xf32> { - // CHECK: arm_sve.intr.fadd{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %0 = arm_sve.masked.addf %mask, %a, %b : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - // CHECK: arm_sve.intr.fsub{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %1 = arm_sve.masked.subf %mask, %0, %c : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - // CHECK: arm_sve.intr.fmul{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %2 = arm_sve.masked.mulf %mask, %1, %d : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - // CHECK: arm_sve.intr.fdiv{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %3 = arm_sve.masked.divf %mask, %2, %e : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - return %3 : !arm_sve.vector<4xf32> -} - -func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>, - %b: !arm_sve.vector<4xf32>) - -> !arm_sve.vector<4xi1> { - // CHECK: llvm.fcmp "oeq" {{.*}}: !llvm.vec - %0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32> - return %0 : !arm_sve.vector<4xi1> -} - -func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi1> { - // CHECK: llvm.icmp "uge" {{.*}}: !llvm.vec - %0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi1> -} - -func @arm_sve_abs_diff(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi32> { - // CHECK: llvm.sub {{.*}}: !llvm.vec - %z = arm_sve.subi %a, %a : !arm_sve.vector<4xi32> - // CHECK: llvm.icmp "sge" {{.*}}: !llvm.vec - %agb = arm_sve.cmpi sge, %a, %b : !arm_sve.vector<4xi32> - // CHECK: llvm.icmp "slt" {{.*}}: !llvm.vec - %bga = arm_sve.cmpi slt, %a, %b : !arm_sve.vector<4xi32> - // CHECK: "arm_sve.intr.sub"{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %0 = arm_sve.masked.subi %agb, %a, %b : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: "arm_sve.intr.sub"{{.*}}: 
(!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %1 = arm_sve.masked.subi %bga, %b, %a : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: "arm_sve.intr.add"{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %2 = arm_sve.masked.addi %agb, %z, %0 : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: "arm_sve.intr.add"{{.*}}: (!llvm.vec, !llvm.vec, !llvm.vec) -> !llvm.vec - %3 = arm_sve.masked.addi %bga, %2, %1 : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - return %3 : !arm_sve.vector<4xi32> +func @arm_sve_abs_diff(%a: vector<[4]xi32>, + %b: vector<[4]xi32>) + -> vector<[4]xi32> { + // CHECK: llvm.mlir.constant(dense<0> : vector<[4]xi32>) : vector<[4]xi32> + %z = arith.subi %a, %a : vector<[4]xi32> + // CHECK: llvm.icmp "sge" {{.*}}: vector<[4]xi32> + %agb = arith.cmpi sge, %a, %b : vector<[4]xi32> + // CHECK: llvm.icmp "slt" {{.*}}: vector<[4]xi32> + %bga = arith.cmpi slt, %a, %b : vector<[4]xi32> + // CHECK: "arm_sve.intr.sub"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %0 = arm_sve.masked.subi %agb, %a, %b : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: "arm_sve.intr.sub"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %1 = arm_sve.masked.subi %bga, %b, %a : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: "arm_sve.intr.add"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %2 = arm_sve.masked.addi %agb, %z, %0 : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: "arm_sve.intr.add"{{.*}}: (vector<[4]xi1>, vector<[4]xi32>, vector<[4]xi32>) -> vector<[4]xi32> + %3 = arm_sve.masked.addi %bga, %2, %1 : vector<[4]xi1>, + vector<[4]xi32> + return %3 : vector<[4]xi32> } func @get_vector_scale() -> index { - // CHECK: arm_sve.vscale - %0 = arm_sve.vector_scale : index + // CHECK: llvm.intr.vscale + %0 = vector.vscale return %0 : index } diff --git a/mlir/test/Dialect/ArmSVE/memcpy.mlir b/mlir/test/Dialect/ArmSVE/memcpy.mlir deleted file mode 100644 --- a/mlir/test/Dialect/ArmSVE/memcpy.mlir +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s - -// CHECK: memcopy([[SRC:%arg[0-9]+]]: memref, [[DST:%arg[0-9]+]] -func @memcopy(%src : memref, %dst : memref, %size : index) { - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %vs = arm_sve.vector_scale : index - %step = arith.muli %c4, %vs : index - - // CHECK: [[SRCMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[SRC]] : memref to !llvm.struct<(ptr - // CHECK: [[DSTMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[DST]] : memref to !llvm.struct<(ptr - // CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}} - scf.for %i0 = %c0 to %size step %step { - // CHECK: [[SRCIDX:%[0-9]+]] = builtin.unrealized_conversion_cast [[LOOPIDX]] : index to i64 - // CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr - // CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr, i64) -> !llvm.ptr - // CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr to !llvm.ptr> - // CHECK-NEXT: [[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]] : !llvm.ptr> - %0 = arm_sve.load %src[%i0] : !arm_sve.vector<4xf32> from memref - // CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr - // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr, i64) -> !llvm.ptr - // CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr to !llvm.ptr> - 
// CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]] : !llvm.ptr> - arm_sve.store %0, %dst[%i0] : !arm_sve.vector<4xf32> to memref - } - - return -} diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir --- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir @@ -1,137 +1,84 @@ // RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s -func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.sdot {{.*}}: <16xi8> to <4xi32 +func @arm_sve_sdot(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: arm_sve.sdot {{.*}}: vector<[16]xi8> to vector<[4]xi32 %0 = arm_sve.sdot %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> } -func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.smmla {{.*}}: <16xi8> to <4xi3 +func @arm_sve_smmla(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: arm_sve.smmla {{.*}}: vector<[16]xi8> to vector<[4]xi3 %0 = arm_sve.smmla %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> } -func @arm_sve_udot(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.udot {{.*}}: <16xi8> to <4xi32 +func @arm_sve_udot(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: arm_sve.udot {{.*}}: vector<[16]xi8> to vector<[4]xi32 %0 = arm_sve.udot %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> } -func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>, - %b: !arm_sve.vector<16xi8>, - %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.ummla {{.*}}: <16xi8> to <4xi3 +func @arm_sve_ummla(%a: vector<[16]xi8>, + %b: vector<[16]xi8>, + %c: vector<[4]xi32>) -> vector<[4]xi32> { + // CHECK: arm_sve.ummla {{.*}}: vector<[16]xi8> to vector<[4]xi3 %0 = arm_sve.ummla %c, %a, %b : - !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi32> + vector<[16]xi8> to vector<[4]xi32> + return %0 : vector<[4]xi32> } -func @arm_sve_arithi(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>, - %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.muli {{.*}}: !arm_sve.vector<4xi32> - %0 = arm_sve.muli %a, %b : !arm_sve.vector<4xi32> - // CHECK: arm_sve.addi {{.*}}: !arm_sve.vector<4xi32> - %1 = arm_sve.addi %0, %c : !arm_sve.vector<4xi32> - return %1 : !arm_sve.vector<4xi32> -} - -func @arm_sve_arithf(%a: !arm_sve.vector<4xf32>, - %b: !arm_sve.vector<4xf32>, - %c: !arm_sve.vector<4xf32>) -> !arm_sve.vector<4xf32> { - // CHECK: arm_sve.mulf {{.*}}: !arm_sve.vector<4xf32> - %0 = arm_sve.mulf %a, %b : !arm_sve.vector<4xf32> - // CHECK: arm_sve.addf {{.*}}: !arm_sve.vector<4xf32> - %1 = arm_sve.addf %0, %c : !arm_sve.vector<4xf32> - return %1 : !arm_sve.vector<4xf32> -} - -func @arm_sve_masked_arithi(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>, - %c: !arm_sve.vector<4xi32>, - %d: 
!arm_sve.vector<4xi32>, - %e: !arm_sve.vector<4xi32>, - %mask: !arm_sve.vector<4xi1>) - -> !arm_sve.vector<4xi32> { - // CHECK: arm_sve.masked.muli {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %0 = arm_sve.masked.muli %mask, %a, %b : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: arm_sve.masked.addi {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %1 = arm_sve.masked.addi %mask, %0, %c : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - // CHECK: arm_sve.masked.subi {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %2 = arm_sve.masked.subi %mask, %1, %d : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> +func @arm_sve_masked_arithi(%a: vector<[4]xi32>, + %b: vector<[4]xi32>, + %c: vector<[4]xi32>, + %d: vector<[4]xi32>, + %e: vector<[4]xi32>, + %mask: vector<[4]xi1>) + -> vector<[4]xi32> { + // CHECK: arm_sve.masked.muli {{.*}}: vector<[4]xi1>, vector< + %0 = arm_sve.masked.muli %mask, %a, %b : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: arm_sve.masked.addi {{.*}}: vector<[4]xi1>, vector< + %1 = arm_sve.masked.addi %mask, %0, %c : vector<[4]xi1>, + vector<[4]xi32> + // CHECK: arm_sve.masked.subi {{.*}}: vector<[4]xi1>, vector< + %2 = arm_sve.masked.subi %mask, %1, %d : vector<[4]xi1>, + vector<[4]xi32> // CHECK: arm_sve.masked.divi_signed - %3 = arm_sve.masked.divi_signed %mask, %2, %e : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> + %3 = arm_sve.masked.divi_signed %mask, %2, %e : vector<[4]xi1>, + vector<[4]xi32> // CHECK: arm_sve.masked.divi_unsigned - %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xi32> - return %2 : !arm_sve.vector<4xi32> -} - -func @arm_sve_masked_arithf(%a: !arm_sve.vector<4xf32>, - %b: !arm_sve.vector<4xf32>, - %c: !arm_sve.vector<4xf32>, - %d: !arm_sve.vector<4xf32>, - %e: !arm_sve.vector<4xf32>, - %mask: !arm_sve.vector<4xi1>) - -> !arm_sve.vector<4xf32> { - // CHECK: arm_sve.masked.mulf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %0 = arm_sve.masked.mulf %mask, %a, %b : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - // CHECK: arm_sve.masked.addf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %1 = arm_sve.masked.addf %mask, %0, %c : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - // CHECK: arm_sve.masked.subf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %2 = arm_sve.masked.subf %mask, %1, %d : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - // CHECK: arm_sve.masked.divf {{.*}}: !arm_sve.vector<4xi1>, !arm_sve.vector - %3 = arm_sve.masked.divf %mask, %2, %e : !arm_sve.vector<4xi1>, - !arm_sve.vector<4xf32> - return %3 : !arm_sve.vector<4xf32> -} - -func @arm_sve_mask_genf(%a: !arm_sve.vector<4xf32>, - %b: !arm_sve.vector<4xf32>) - -> !arm_sve.vector<4xi1> { - // CHECK: arm_sve.cmpf oeq, {{.*}}: !arm_sve.vector<4xf32> - %0 = arm_sve.cmpf oeq, %a, %b : !arm_sve.vector<4xf32> - return %0 : !arm_sve.vector<4xi1> -} - -func @arm_sve_mask_geni(%a: !arm_sve.vector<4xi32>, - %b: !arm_sve.vector<4xi32>) - -> !arm_sve.vector<4xi1> { - // CHECK: arm_sve.cmpi uge, {{.*}}: !arm_sve.vector<4xi32> - %0 = arm_sve.cmpi uge, %a, %b : !arm_sve.vector<4xi32> - return %0 : !arm_sve.vector<4xi1> -} - -func @arm_sve_memory(%v: !arm_sve.vector<4xi32>, - %m: memref) - -> !arm_sve.vector<4xi32> { - %c0 = arith.constant 0 : index - // CHECK: arm_sve.load {{.*}}: !arm_sve.vector<4xi32> from memref - %0 = arm_sve.load %m[%c0] : !arm_sve.vector<4xi32> from memref - // CHECK: arm_sve.store {{.*}}: !arm_sve.vector<4xi32> to memref - arm_sve.store %v, %m[%c0] : !arm_sve.vector<4xi32> to memref - return %0 : 
!arm_sve.vector<4xi32> + %4 = arm_sve.masked.divi_unsigned %mask, %3, %e : vector<[4]xi1>, + vector<[4]xi32> + return %2 : vector<[4]xi32> } -func @get_vector_scale() -> index { - // CHECK: arm_sve.vector_scale : index - %0 = arm_sve.vector_scale : index - return %0 : index +func @arm_sve_masked_arithf(%a: vector<[4]xf32>, + %b: vector<[4]xf32>, + %c: vector<[4]xf32>, + %d: vector<[4]xf32>, + %e: vector<[4]xf32>, + %mask: vector<[4]xi1>) + -> vector<[4]xf32> { + // CHECK: arm_sve.masked.mulf {{.*}}: vector<[4]xi1>, vector< + %0 = arm_sve.masked.mulf %mask, %a, %b : vector<[4]xi1>, + vector<[4]xf32> + // CHECK: arm_sve.masked.addf {{.*}}: vector<[4]xi1>, vector< + %1 = arm_sve.masked.addf %mask, %0, %c : vector<[4]xi1>, + vector<[4]xf32> + // CHECK: arm_sve.masked.subf {{.*}}: vector<[4]xi1>, vector< + %2 = arm_sve.masked.subf %mask, %1, %d : vector<[4]xi1>, + vector<[4]xf32> + // CHECK: arm_sve.masked.divf {{.*}}: vector<[4]xi1>, vector< + %3 = arm_sve.masked.divf %mask, %2, %e : vector<[4]xi1>, + vector<[4]xf32> + return %3 : vector<[4]xf32> } diff --git a/mlir/test/Dialect/Builtin/invalid.mlir b/mlir/test/Dialect/Builtin/invalid.mlir --- a/mlir/test/Dialect/Builtin/invalid.mlir +++ b/mlir/test/Dialect/Builtin/invalid.mlir @@ -9,3 +9,11 @@ // ----- +//===----------------------------------------------------------------------===// +// VectorType +//===----------------------------------------------------------------------===// + +// expected-error@+1 {{missing ']' closing set of scalable dimensions}} +func @scalable_vector_arg(%arg0: vector<[4xf32>) { } + +// ----- diff --git a/mlir/test/Dialect/Builtin/ops.mlir b/mlir/test/Dialect/Builtin/ops.mlir --- a/mlir/test/Dialect/Builtin/ops.mlir +++ b/mlir/test/Dialect/Builtin/ops.mlir @@ -18,3 +18,19 @@ // An unrealized N-1 conversion. 
%result3 = unrealized_conversion_cast %operand, %operand : !foo.type, !foo.type to !bar.tuple_type + +//===----------------------------------------------------------------------===// +// VectorType +//===----------------------------------------------------------------------===// + +// A basic 1D scalable vector +%scalable_vector_1d = "foo.op"() : () -> vector<[4]xi32> + +// A 2D scalable vector +%scalable_vector_2d = "foo.op"() : () -> vector<[2x2]xf64> + +// A 2D scalable vector with fixed-length dimensions +%scalable_vector_2d_mixed = "foo.op"() : () -> vector<2x[4]xbf16> + +// A multi-dimensional vector with mixed scalable and fixed-length dimensions +%scalable_vector_multi_mixed = "foo.op"() : () -> vector<2x2x[4x4]xi8> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -578,6 +578,25 @@ return } +// CHECK-LABEL: @vector_load_and_store_scalable_vector_memref +func @vector_load_and_store_scalable_vector_memref(%v: vector<[4]xi32>, %m: memref) -> vector<[4]xi32> { + %c0 = arith.constant 0 : index + // CHECK: vector.load {{.*}}: memref, vector<[4]xi32> + %0 = vector.load %m[%c0] : memref, vector<[4]xi32> + // CHECK: vector.store {{.*}}: memref, vector<[4]xi32> + vector.store %v, %m[%c0] : memref, vector<[4]xi32> + return %0 : vector<[4]xi32> +} + +func @vector_load_and_store_1d_scalable_vector_memref(%memref : memref<200x100xvector<8xf32>>, + %i : index, %j : index) { + // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32> + %0 = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32> + // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xvector<8xf32>>, vector<8xf32> + vector.store %0, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32> + return +} + // CHECK-LABEL: @vector_load_and_store_out_of_bounds func @vector_load_and_store_out_of_bounds(%memref : memref<7xf32>) { %c0 = arith.constant 0 : index @@ -691,3 +710,10 @@ vector<4x16xf32> to f32 return %2 : f32 } + +// CHECK-LABEL: @get_vector_scale +func @get_vector_scale() -> index { + // CHECK: vector.vscale + %0 = vector.vscale + return %0 : index +} diff --git a/mlir/test/Dialect/Vector/vector-scalable-memcpy.mlir b/mlir/test/Dialect/Vector/vector-scalable-memcpy.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-scalable-memcpy.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm | mlir-opt | FileCheck %s + +// CHECK: vector_scalable_memcopy([[SRC:%arg[0-9]+]]: memref, [[DST:%arg[0-9]+]] +func @vector_scalable_memcopy(%src : memref, %dst : memref, %size : index) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %vs = vector.vscale + %step = arith.muli %c4, %vs : index + // CHECK: [[SRCMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[SRC]] : memref to !llvm.struct<(ptr + // CHECK: [[DSTMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[DST]] : memref to !llvm.struct<(ptr + // CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}} + scf.for %i0 = %c0 to %size step %step { + // CHECK: [[DATAIDX:%[0-9]+]] = builtin.unrealized_conversion_cast [[LOOPIDX]] : index to i64 + // CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr + // CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[DATAIDX]]{{.}} : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr to !llvm.ptr> + // CHECK-NEXT: 
[[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]]{{.*}}: !llvm.ptr> + %0 = vector.load %src[%i0] : memref, vector<[4]xf32> + // CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr + // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[DATAIDX]]{{.}} : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr to !llvm.ptr> + // CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]]{{.*}}: !llvm.ptr> + vector.store %0, %dst[%i0] : memref, vector<[4]xf32> + } + + return +} diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir --- a/mlir/test/Target/LLVMIR/arm-sve.mlir +++ b/mlir/test/Target/LLVMIR/arm-sve.mlir @@ -1,193 +1,193 @@ // RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s // CHECK-LABEL: define @arm_sve_sdot -llvm.func @arm_sve_sdot(%arg0: !llvm.vec, - %arg1: !llvm.vec, - %arg2: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_sdot(%arg0: vector<[16]xi8>, + %arg1: vector<[16]xi8>, + %arg2: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: call @llvm.aarch64.sve.sdot.nxv4i32(, !llvm.vec, !llvm.vec) - -> !llvm.vec - llvm.return %0 : !llvm.vec + (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) + -> vector<[4]xi32> + llvm.return %0 : vector<[4]xi32> } // CHECK-LABEL: define @arm_sve_smmla -llvm.func @arm_sve_smmla(%arg0: !llvm.vec, - %arg1: !llvm.vec, - %arg2: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_smmla(%arg0: vector<[16]xi8>, + %arg1: vector<[16]xi8>, + %arg2: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: call @llvm.aarch64.sve.smmla.nxv4i32(, !llvm.vec, !llvm.vec) - -> !llvm.vec - llvm.return %0 : !llvm.vec + (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) + -> vector<[4]xi32> + llvm.return %0 : vector<[4]xi32> } // CHECK-LABEL: define @arm_sve_udot -llvm.func @arm_sve_udot(%arg0: !llvm.vec, - %arg1: !llvm.vec, - %arg2: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_udot(%arg0: vector<[16]xi8>, + %arg1: vector<[16]xi8>, + %arg2: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: call @llvm.aarch64.sve.udot.nxv4i32(, !llvm.vec, !llvm.vec) - -> !llvm.vec - llvm.return %0 : !llvm.vec + (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) + -> vector<[4]xi32> + llvm.return %0 : vector<[4]xi32> } // CHECK-LABEL: define @arm_sve_ummla -llvm.func @arm_sve_ummla(%arg0: !llvm.vec, - %arg1: !llvm.vec, - %arg2: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>, + %arg1: vector<[16]xi8>, + %arg2: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: call @llvm.aarch64.sve.ummla.nxv4i32(, !llvm.vec, !llvm.vec) - -> !llvm.vec - llvm.return %0 : !llvm.vec + (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>) + -> vector<[4]xi32> + llvm.return %0 : vector<[4]xi32> } // CHECK-LABEL: define @arm_sve_arithi -llvm.func @arm_sve_arithi(%arg0: !llvm.vec, - %arg1: !llvm.vec, - %arg2: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>, + %arg1: vector<[4]xi32>, + %arg2: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: mul - %0 = llvm.mul %arg0, %arg1 : !llvm.vec + %0 = llvm.mul %arg0, %arg1 : vector<[4]xi32> // CHECK: add - %1 = llvm.add %0, %arg2 : !llvm.vec - llvm.return %1 : !llvm.vec + %1 = llvm.add %0, %arg2 : vector<[4]xi32> + llvm.return %1 : vector<[4]xi32> } // CHECK-LABEL: define @arm_sve_arithf -llvm.func @arm_sve_arithf(%arg0: !llvm.vec, - %arg1: !llvm.vec, - %arg2: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_arithf(%arg0: vector<[4]xf32>, + %arg1: vector<[4]xf32>, + %arg2: vector<[4]xf32>) + -> 
vector<[4]xf32> { // CHECK: fmul - %0 = llvm.fmul %arg0, %arg1 : !llvm.vec + %0 = llvm.fmul %arg0, %arg1 : vector<[4]xf32> // CHECK: fadd - %1 = llvm.fadd %0, %arg2 : !llvm.vec - llvm.return %1 : !llvm.vec + %1 = llvm.fadd %0, %arg2 : vector<[4]xf32> + llvm.return %1 : vector<[4]xf32> } // CHECK-LABEL: define @arm_sve_arithi_masked -llvm.func @arm_sve_arithi_masked(%arg0: !llvm.vec, - %arg1: !llvm.vec, - %arg2: !llvm.vec, - %arg3: !llvm.vec, - %arg4: !llvm.vec, - %arg5: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_arithi_masked(%arg0: vector<[4]xi32>, + %arg1: vector<[4]xi32>, + %arg2: vector<[4]xi32>, + %arg3: vector<[4]xi32>, + %arg4: vector<[4]xi32>, + %arg5: vector<[4]xi1>) + -> vector<[4]xi32> { // CHECK: call @llvm.aarch64.sve.add.nxv4i32 - %0 = "arm_sve.intr.add"(%arg5, %arg0, %arg1) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %0 = "arm_sve.intr.add"(%arg5, %arg0, %arg1) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> // CHECK: call @llvm.aarch64.sve.sub.nxv4i32 - %1 = "arm_sve.intr.sub"(%arg5, %0, %arg1) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %1 = "arm_sve.intr.sub"(%arg5, %0, %arg1) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> // CHECK: call @llvm.aarch64.sve.mul.nxv4i32 - %2 = "arm_sve.intr.mul"(%arg5, %1, %arg3) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %2 = "arm_sve.intr.mul"(%arg5, %1, %arg3) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> // CHECK: call @llvm.aarch64.sve.sdiv.nxv4i32 - %3 = "arm_sve.intr.sdiv"(%arg5, %2, %arg4) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %3 = "arm_sve.intr.sdiv"(%arg5, %2, %arg4) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> // CHECK: call @llvm.aarch64.sve.udiv.nxv4i32 - %4 = "arm_sve.intr.udiv"(%arg5, %3, %arg4) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec - llvm.return %4 : !llvm.vec + %4 = "arm_sve.intr.udiv"(%arg5, %3, %arg4) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> + llvm.return %4 : vector<[4]xi32> } // CHECK-LABEL: define @arm_sve_arithf_masked -llvm.func @arm_sve_arithf_masked(%arg0: !llvm.vec, - %arg1: !llvm.vec, - %arg2: !llvm.vec, - %arg3: !llvm.vec, - %arg4: !llvm.vec, - %arg5: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_arithf_masked(%arg0: vector<[4]xf32>, + %arg1: vector<[4]xf32>, + %arg2: vector<[4]xf32>, + %arg3: vector<[4]xf32>, + %arg4: vector<[4]xf32>, + %arg5: vector<[4]xi1>) + -> vector<[4]xf32> { // CHECK: call @llvm.aarch64.sve.fadd.nxv4f32 - %0 = "arm_sve.intr.fadd"(%arg5, %arg0, %arg1) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %0 = "arm_sve.intr.fadd"(%arg5, %arg0, %arg1) : (vector<[4]xi1>, + vector<[4]xf32>, + vector<[4]xf32>) + -> vector<[4]xf32> // CHECK: call @llvm.aarch64.sve.fsub.nxv4f32 - %1 = "arm_sve.intr.fsub"(%arg5, %0, %arg2) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %1 = "arm_sve.intr.fsub"(%arg5, %0, %arg2) : (vector<[4]xi1>, + vector<[4]xf32>, + vector<[4]xf32>) + -> vector<[4]xf32> // CHECK: call @llvm.aarch64.sve.fmul.nxv4f32 - %2 = "arm_sve.intr.fmul"(%arg5, %1, %arg3) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %2 = "arm_sve.intr.fmul"(%arg5, %1, %arg3) : (vector<[4]xi1>, + vector<[4]xf32>, + vector<[4]xf32>) + -> vector<[4]xf32> // CHECK: call @llvm.aarch64.sve.fdiv.nxv4f32 - %3 = "arm_sve.intr.fdiv"(%arg5, %2, %arg4) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec - llvm.return %3 : !llvm.vec + %3 = 
"arm_sve.intr.fdiv"(%arg5, %2, %arg4) : (vector<[4]xi1>, + vector<[4]xf32>, + vector<[4]xf32>) + -> vector<[4]xf32> + llvm.return %3 : vector<[4]xf32> } // CHECK-LABEL: define @arm_sve_mask_genf -llvm.func @arm_sve_mask_genf(%arg0: !llvm.vec, - %arg1: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_mask_genf(%arg0: vector<[4]xf32>, + %arg1: vector<[4]xf32>) + -> vector<[4]xi1> { // CHECK: fcmp oeq - %0 = llvm.fcmp "oeq" %arg0, %arg1 : !llvm.vec - llvm.return %0 : !llvm.vec + %0 = llvm.fcmp "oeq" %arg0, %arg1 : vector<[4]xf32> + llvm.return %0 : vector<[4]xi1> } // CHECK-LABEL: define @arm_sve_mask_geni -llvm.func @arm_sve_mask_geni(%arg0: !llvm.vec, - %arg1: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_mask_geni(%arg0: vector<[4]xi32>, + %arg1: vector<[4]xi32>) + -> vector<[4]xi1> { // CHECK: icmp uge - %0 = llvm.icmp "uge" %arg0, %arg1 : !llvm.vec - llvm.return %0 : !llvm.vec + %0 = llvm.icmp "uge" %arg0, %arg1 : vector<[4]xi32> + llvm.return %0 : vector<[4]xi1> } // CHECK-LABEL: define @arm_sve_abs_diff -llvm.func @arm_sve_abs_diff(%arg0: !llvm.vec, - %arg1: !llvm.vec) - -> !llvm.vec { +llvm.func @arm_sve_abs_diff(%arg0: vector<[4]xi32>, + %arg1: vector<[4]xi32>) + -> vector<[4]xi32> { // CHECK: sub - %0 = llvm.sub %arg0, %arg0 : !llvm.vec + %0 = llvm.sub %arg0, %arg0 : vector<[4]xi32> // CHECK: icmp sge - %1 = llvm.icmp "sge" %arg0, %arg1 : !llvm.vec + %1 = llvm.icmp "sge" %arg0, %arg1 : vector<[4]xi32> // CHECK: icmp slt - %2 = llvm.icmp "slt" %arg0, %arg1 : !llvm.vec + %2 = llvm.icmp "slt" %arg0, %arg1 : vector<[4]xi32> // CHECK: call @llvm.aarch64.sve.sub.nxv4i32 - %3 = "arm_sve.intr.sub"(%1, %arg0, %arg1) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %3 = "arm_sve.intr.sub"(%1, %arg0, %arg1) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> // CHECK: call @llvm.aarch64.sve.sub.nxv4i32 - %4 = "arm_sve.intr.sub"(%2, %arg1, %arg0) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %4 = "arm_sve.intr.sub"(%2, %arg1, %arg0) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> // CHECK: call @llvm.aarch64.sve.add.nxv4i32 - %5 = "arm_sve.intr.add"(%1, %0, %3) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec + %5 = "arm_sve.intr.add"(%1, %0, %3) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> // CHECK: call @llvm.aarch64.sve.add.nxv4i32 - %6 = "arm_sve.intr.add"(%2, %5, %4) : (!llvm.vec, - !llvm.vec, - !llvm.vec) - -> !llvm.vec - llvm.return %6 : !llvm.vec + %6 = "arm_sve.intr.add"(%2, %5, %4) : (vector<[4]xi1>, + vector<[4]xi32>, + vector<[4]xi32>) + -> vector<[4]xi32> + llvm.return %6 : vector<[4]xi32> } // CHECK-LABEL: define void @memcopy @@ -234,7 +234,7 @@ %12 = llvm.mlir.constant(0 : index) : i64 %13 = llvm.mlir.constant(4 : index) : i64 // CHECK: [[VL:%[0-9]+]] = call i64 @llvm.vscale.i64() - %14 = "arm_sve.vscale"() : () -> i64 + %14 = "llvm.intr.vscale"() : () -> i64 // CHECK: mul i64 [[VL]], 4 %15 = llvm.mul %14, %13 : i64 llvm.br ^bb1(%12 : i64) @@ -249,9 +249,9 @@ // CHECK: etelementptr float, float* %19 = llvm.getelementptr %18[%16] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK: bitcast float* %{{[0-9]+}} to * - %20 = llvm.bitcast %19 : !llvm.ptr to !llvm.ptr> + %20 = llvm.bitcast %19 : !llvm.ptr to !llvm.ptr> // CHECK: load , * - %21 = llvm.load %20 : !llvm.ptr> + %21 = llvm.load %20 : !llvm.ptr> // CHECK: extractvalue { float*, float*, i64, [1 x i64], [1 x i64] } %22 = llvm.extractvalue %11[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, @@ -259,9 +259,9 @@ // CHECK: 
getelementptr float, float* %32 %23 = llvm.getelementptr %22[%16] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK: bitcast float* %33 to * - %24 = llvm.bitcast %23 : !llvm.ptr to !llvm.ptr> + %24 = llvm.bitcast %23 : !llvm.ptr to !llvm.ptr> // CHECK: store %{{[0-9]+}}, * %{{[0-9]+}} - llvm.store %21, %24 : !llvm.ptr> + llvm.store %21, %24 : !llvm.ptr> %25 = llvm.add %16, %15 : i64 llvm.br ^bb1(%25 : i64) ^bb3: @@ -271,6 +271,6 @@ // CHECK-LABEL: define i64 @get_vector_scale() llvm.func @get_vector_scale() -> i64 { // CHECK: call i64 @llvm.vscale.i64() - %0 = "arm_sve.vscale"() : () -> i64 + %0 = "llvm.intr.vscale"() : () -> i64 llvm.return %0 : i64 }
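
Note (illustrative only, not part of the patch): with the vector-dim-list grammar introduced above, all scalable dimensions sit in a single trailing `[...]` group, and the parser records how many of the parsed dimensions are scalable in `numScalableDims`. A minimal sketch of accepted and rejected forms, using the same generic "foo.op" placeholder as the Builtin/ops.mlir test:

// All dimensions fixed: shape = [2, 4], numScalableDims = 0.
%fixed = "foo.op"() : () -> vector<2x4xf32>
// All dimensions scalable: shape = [4], numScalableDims = 1.
%scalable = "foo.op"() : () -> vector<[4]xf32>
// Mixed: fixed dimensions first, scalable group last; shape = [2, 4, 4], numScalableDims = 2.
%mixed = "foo.op"() : () -> vector<2x[4x4]xf32>
// Rejected with "missing ']' closing set of scalable dimensions" (see Builtin/invalid.mlir):
// %bad = "foo.op"() : () -> vector<[4xf32>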