[mlir][spirv] Use a declarative assembly format for spv.CompositeConstruct

spv.CompositeConstruct now prints and parses its constituent types
explicitly, e.g.
  %0 = spv.CompositeConstruct %a, %v : (f32, vector<2xf32>) -> vector<3xf32>
and the hand-written parser/printer are removed. The verifier additionally:
  * accepts a mix of scalar and vector constituents when building a vector,
    as long as element types match and the element counts add up;
  * checks the constituent type when building a cooperative matrix;
  * rejects result types whose element count is not known at compile time
    (e.g. runtime arrays) before calling getNumElements() on them.

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -64,6 +64,10 @@
   let results = (outs
     SPV_Composite:$result
   );
+
+  let assemblyFormat = [{
+    $constituents attr-dict `:` `(` type(operands) `)` `->` type($result)
+  }];
 }
 
 // -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -31,6 +31,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/bit.h"
+#include <numeric>
 
 using namespace mlir;
 
@@ -1618,66 +1619,71 @@
 // spv.CompositeConstruct
 //===----------------------------------------------------------------------===//
 
-ParseResult spirv::CompositeConstructOp::parse(OpAsmParser &parser,
-                                               OperationState &state) {
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
-  Type type;
-  auto loc = parser.getCurrentLocation();
-
-  if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
-    return failure();
-  }
-  auto cType = type.dyn_cast<spirv::CompositeType>();
-  if (!cType) {
-    return parser.emitError(
-               loc, "result type must be a composite type, but provided ")
-           << type;
-  }
-
-  if (cType.hasCompileTimeKnownNumElements() &&
-      operands.size() != cType.getNumElements()) {
-    return parser.emitError(loc, "has incorrect number of operands: expected ")
-           << cType.getNumElements() << ", but provided " << operands.size();
-  }
-  // TODO: Add support for constructing a vector type from the vector operands.
-  // According to the spec: "for constructing a vector, the operands may
-  // also be vectors with the same component type as the Result Type component
-  // type".
-  SmallVector<Type, 4> elementTypes;
-  elementTypes.reserve(operands.size());
-  for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
-    elementTypes.push_back(cType.getElementType(index));
-  }
-  state.addTypes(type);
-  return parser.resolveOperands(operands, elementTypes, loc, state.operands);
-}
-
-void spirv::CompositeConstructOp::print(OpAsmPrinter &printer) {
-  printer << " " << constituents() << " : " << getResult().getType();
-}
-
 LogicalResult spirv::CompositeConstructOp::verify() {
   auto cType = getType().cast<spirv::CompositeType>();
   operand_range constituents = this->constituents();
 
-  if (cType.isa<spirv::CooperativeMatrixNVType>()) {
+  if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
     if (constituents.size() != 1)
-      return emitError("has incorrect number of operands: expected ")
+      return emitOpError("has incorrect number of operands: expected ")
              << "1, but provided " << constituents.size();
-  } else if (constituents.size() != cType.getNumElements()) {
-    return emitError("has incorrect number of operands: expected ")
-           << cType.getNumElements() << ", but provided "
-           << constituents.size();
+    if (coopType.getElementType() != constituents.front().getType())
+      return emitOpError("operand type mismatch: expected operand type ")
+             << coopType.getElementType() << ", but provided "
+             << constituents.front().getType();
+    return success();
   }
 
-  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
-    if (constituents[index].getType() != cType.getElementType(index)) {
-      return emitError("operand type mismatch: expected operand type ")
-             << cType.getElementType(index) << ", but provided "
-             << constituents[index].getType();
+  // getNumElements() below is only valid for composites whose element count
+  // is known at compile time; reject others (e.g. runtime arrays) up front.
+  if (!cType.hasCompileTimeKnownNumElements())
+    return emitOpError("expected result type to have a compile-time known "
+                       "number of elements, but provided ")
+           << cType;
+
+  if (constituents.size() == cType.getNumElements()) {
+    for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
+      if (constituents[index].getType() != cType.getElementType(index)) {
+        return emitOpError("operand type mismatch: expected operand type ")
+               << cType.getElementType(index) << ", but provided "
+               << constituents[index].getType();
+      }
     }
+    return success();
   }
 
+  // If not constructing a cooperative matrix type, then we must be constructing
+  // a vector type.
+  auto resultType = cType.dyn_cast<VectorType>();
+  if (!resultType)
+    return emitOpError(
+        "expected to return a vector or cooperative matrix when the number of "
+        "constituents is less than what the result needs");
+
+  SmallVector<unsigned> sizes;
+  for (Value component : constituents) {
+    if (!component.getType().isa<VectorType>() &&
+        !component.getType().isIntOrFloat())
+      return emitOpError("operand type mismatch: expected operand to have "
+                         "a scalar or vector type, but provided ")
+             << component.getType();
+
+    Type elementType = component.getType();
+    if (auto vectorType = component.getType().dyn_cast<VectorType>()) {
+      sizes.push_back(vectorType.getNumElements());
+      elementType = vectorType.getElementType();
+    } else {
+      sizes.push_back(1);
+    }
+
+    if (elementType != resultType.getElementType())
+      return emitOpError("operand element type mismatch: expected to be ")
+             << resultType.getElementType() << ", but provided " << elementType;
+  }
+  unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0u);
+  if (totalCount != cType.getNumElements())
+    return emitOpError("has incorrect number of operands: expected ")
+           << cType.getNumElements() << ", but provided " << totalCount;
   return success();
 }
 
diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
--- a/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-core-spirv.mlir
@@ -32,8 +32,8 @@
 // CHECK-SAME: (%[[VALUE:.+]]: vector<3xf16>, %[[SIGN:.+]]: vector<3xf16>)
 // CHECK: %[[SMASK:.+]] = spv.Constant -32768 : i16
 // CHECK: %[[VMASK:.+]] = spv.Constant 32767 : i16
-// CHECK: %[[SVMASK:.+]] = spv.CompositeConstruct %[[SMASK]], %[[SMASK]], %[[SMASK]] : vector<3xi16>
-// CHECK: %[[VVMASK:.+]] = spv.CompositeConstruct %[[VMASK]], %[[VMASK]], %[[VMASK]] : vector<3xi16>
+// CHECK: %[[SVMASK:.+]] = spv.CompositeConstruct %[[SMASK]], %[[SMASK]], %[[SMASK]]
+// CHECK: %[[VVMASK:.+]] = spv.CompositeConstruct %[[VMASK]], %[[VMASK]], %[[VMASK]]
 // CHECK: %[[VCAST:.+]] = spv.Bitcast %[[VALUE]] : vector<3xf16> to vector<3xi16>
 // CHECK: %[[SCAST:.+]] = spv.Bitcast %[[SIGN]] : vector<3xf16> to vector<3xi16>
 // CHECK: %[[VAND:.+]] = spv.BitwiseAnd %[[VCAST]], %[[VVMASK]] : vector<3xi16>
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -18,8 +18,8 @@
 
 // CHECK-LABEL: @broadcast
 // CHECK-SAME: %[[A:.*]]: f32
-// CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
-// CHECK: spv.CompositeConstruct %[[A]], %[[A]] : vector<2xf32>
+// CHECK: spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
+// CHECK: spv.CompositeConstruct %[[A]], %[[A]]
 func.func @broadcast(%arg0 : f32) -> (vector<4xf32>, vector<2xf32>) {
   %0 = vector.broadcast %arg0 : f32 to vector<4xf32>
   %1 = vector.broadcast %arg0 : f32 to vector<2xf32>
@@ -182,7 +182,7 @@
 
 // CHECK-LABEL: func @splat
 // CHECK-SAME: (%[[A:.+]]: f32)
-// CHECK: %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]] : vector<4xf32>
+// CHECK: %[[VAL:.+]] = spv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
 // CHECK: return %[[VAL]]
 func.func @splat(%f : f32) -> vector<4xf32> {
   %splat = vector.splat %f : vector<4xf32>
@@ -206,7 +206,7 @@
 // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>
 // CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
 // CHECK: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
-// CHECK: spv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : vector<4xf32>
+// CHECK: spv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : (f32, f32, f32, f32) -> vector<4xf32>
 func.func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
   %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xf32>, vector<1xf32>
   return %shuffle : vector<4xf32>
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -5,48 +5,41 @@
 //===----------------------------------------------------------------------===//
 
 func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
-  // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
-  %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32>
+  // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
+  %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
   return %0: vector<3xf32>
 }
 
 // -----
 
 func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spv.array<4xf32>, %arg2 : !spv.struct<(f32)>) -> !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)> {
-  // CHECK: spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<(vector<3xf32>, !spv.array<4 x f32>, !spv.struct<(f32)>)>
-  %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)>
+  // CHECK: spv.CompositeConstruct
+  %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : (vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>) -> !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)>
   return %0: !spv.struct<(vector<3xf32>, !spv.array<4xf32>, !spv.struct<(f32)>)>
 }
 
 // -----
 
-func.func @composite_construct_coopmatrix(%arg0 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
-  // CHECK: spv.CompositeConstruct {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup>
-  %0 = spv.CompositeConstruct %arg0 : !spv.coopmatrix<8x16xf32, Subgroup>
-  return %0: !spv.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: func @composite_construct_mixed_scalar_vector
+func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
+  // CHECK: spv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector<2xf32>, f32) -> vector<4xf32>
+  %0 = spv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xf32>, f32) -> vector<4xf32>
+  return %0: vector<4xf32>
 }
 
 // -----
 
-func.func @composite_construct_empty_struct() -> !spv.struct<()> {
-  // CHECK: spv.CompositeConstruct : !spv.struct<()>
-  %0 = spv.CompositeConstruct : !spv.struct<()>
-  return %0: !spv.struct<()>
-}
-
-// -----
-
-func.func @composite_construct_invalid_num_of_elements(%arg0: f32) -> f32 {
-  // expected-error @+1 {{result type must be a composite type, but provided 'f32'}}
-  %0 = spv.CompositeConstruct %arg0 : f32
-  return %0: f32
+func.func @composite_construct_coopmatrix(%arg0 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
+  // CHECK: spv.CompositeConstruct {{%.*}} : (f32) -> !spv.coopmatrix<8x16xf32, Subgroup>
+  %0 = spv.CompositeConstruct %arg0 : (f32) -> !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xf32, Subgroup>
 }
 
 // -----
 
 func.func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
   // expected-error @+1 {{has incorrect number of operands: expected 3, but provided 2}}
-  %0 = spv.CompositeConstruct %arg0, %arg2 : vector<3xf32>
+  %0 = spv.CompositeConstruct %arg0, %arg2 : (f32, f32) -> vector<3xf32>
   return %0: vector<3xf32>
 }
 
@@ -54,20 +47,60 @@
 
 func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xi32> {
   // expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32'}}
-  %0 = "spv.CompositeConstruct" (%arg0, %arg1, %arg2) : (f32, f32, f32) -> vector<3xi32>
+  %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xi32>
   return %0: vector<3xi32>
 }
 
 // -----
 
-func.func @composite_construct_coopmatrix(%arg0 : f32, %arg1 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
+func.func @composite_construct_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
   // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
-  %0 = spv.CompositeConstruct %arg0, %arg1 : !spv.coopmatrix<8x16xf32, Subgroup>
+  %0 = spv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xf32, Subgroup>
+}
+
+// -----
+
+func.func @composite_construct_coopmatrix_incorrect_element_type(%arg0 : i32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
+  // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
+  %0 = spv.CompositeConstruct %arg0 : (i32) -> !spv.coopmatrix<8x16xf32, Subgroup>
   return %0: !spv.coopmatrix<8x16xf32, Subgroup>
 }
 
 // -----
 
+func.func @composite_construct_array(%arg0: f32) -> !spv.array<4xf32> {
+  // expected-error @+1 {{expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
+  %0 = spv.CompositeConstruct %arg0 : (f32) -> !spv.array<4xf32>
+  return %0: !spv.array<4xf32>
+}
+
+// -----
+
+func.func @composite_construct_vector_wrong_element_type(%arg0: f32, %arg1: f32, %arg2 : vector<2xi32>) -> vector<4xf32> {
+  // expected-error @+1 {{operand element type mismatch: expected to be 'f32', but provided 'i32'}}
+  %0 = spv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xi32>, f32) -> vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
+func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
+  // expected-error @+1 {{op has incorrect number of operands: expected 4, but provided 3}}
+  %0 = spv.CompositeConstruct %arg0, %arg2 : (f32, vector<2xf32>) -> vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// -----
+
+func.func @composite_construct_runtime_array(%arg0: f32) -> !spv.rtarray<f32> {
+  // expected-error @+1 {{expected result type to have a compile-time known number of elements}}
+  %0 = spv.CompositeConstruct %arg0 : (f32) -> !spv.rtarray<f32>
+  return %0: !spv.rtarray<f32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.CompositeExtractOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
--- a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
@@ -3,26 +3,26 @@
 spv.module Logical GLSL450 {
   spv.func @rewrite(%value0 : f32, %value1 : f32, %value2 : f32, %value3 : i32, %value4: !spv.array<3xf32>) -> vector<3xf32> "None" {
     %0 = spv.Undef : vector<3xf32>
-    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
     %1 = spv.CompositeInsert %value0, %0[0 : i32] : f32 into vector<3xf32>
     %2 = spv.CompositeInsert %value1, %1[1 : i32] : f32 into vector<3xf32>
     %3 = spv.CompositeInsert %value2, %2[2 : i32] : f32 into vector<3xf32>
 
     %4 = spv.Undef : !spv.array<4xf32>
-    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spv.array<4 x f32>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32, f32) -> !spv.array<4 x f32>
     %5 = spv.CompositeInsert %value0, %4[0 : i32] : f32 into !spv.array<4xf32>
     %6 = spv.CompositeInsert %value1, %5[1 : i32] : f32 into !spv.array<4xf32>
     %7 = spv.CompositeInsert %value2, %6[2 : i32] : f32 into !spv.array<4xf32>
     %8 = spv.CompositeInsert %value0, %7[3 : i32] : f32 into !spv.array<4xf32>
 
     %9 = spv.Undef : !spv.struct<(f32, i32, f32)>
-    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : !spv.struct<(f32, i32, f32)>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, i32, f32) -> !spv.struct<(f32, i32, f32)>
     %10 = spv.CompositeInsert %value0, %9[0 : i32] : f32 into !spv.struct<(f32, i32, f32)>
     %11 = spv.CompositeInsert %value3, %10[1 : i32] : i32 into !spv.struct<(f32, i32, f32)>
     %12 = spv.CompositeInsert %value1, %11[2 : i32] : f32 into !spv.struct<(f32, i32, f32)>
 
     %13 = spv.Undef : !spv.struct<(f32, !spv.array<3xf32>)>
-    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}} : !spv.struct<(f32, !spv.array<3 x f32>)>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}} : (f32, !spv.array<3 x f32>) -> !spv.struct<(f32, !spv.array<3 x f32>)>
     %14 = spv.CompositeInsert %value0, %13[0 : i32] : f32 into !spv.struct<(f32, !spv.array<3xf32>)>
     %15 = spv.CompositeInsert %value4, %14[1 : i32] : !spv.array<3xf32> into !spv.struct<(f32, !spv.array<3xf32>)>
 
diff --git a/mlir/test/Target/SPIRV/composite-op.mlir b/mlir/test/Target/SPIRV/composite-op.mlir
--- a/mlir/test/Target/SPIRV/composite-op.mlir
+++ b/mlir/test/Target/SPIRV/composite-op.mlir
@@ -7,8 +7,8 @@
     spv.ReturnValue %0: !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)>
   }
   spv.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> "None" {
-    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : vector<3xf32>
-    %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : vector<3xf32>
+    // CHECK: spv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
+    %0 = spv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
     spv.ReturnValue %0: vector<3xf32>
   }
   spv.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 "None" {
diff --git a/mlir/test/Target/SPIRV/debug.mlir b/mlir/test/Target/SPIRV/debug.mlir
--- a/mlir/test/Target/SPIRV/debug.mlir
+++ b/mlir/test/Target/SPIRV/debug.mlir
@@ -33,7 +33,7 @@
     // CHECK: loc({{".*debug.mlir"}}:34:10)
     %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct<(f32, !spv.struct<(!spv.array<4xf32>, f32)>)>
     // CHECK: loc({{".*debug.mlir"}}:36:10)
-    %1 = spv.CompositeConstruct %arg2, %arg3 : vector<2xf32>
+    %1 = spv.CompositeConstruct %arg2, %arg3 : (f32, f32) -> vector<2xf32>
     spv.Return
   }
 