diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1737,17 +1737,13 @@ printer << " : " << constOp.getType(); } -LogicalResult spirv::ConstantOp::verify() { - auto opType = getType(); - auto value = valueAttr(); +static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, + Type opType) { auto valueType = value.getType(); - // ODS already generates checks to make sure the result type is valid. We just - // need to additionally check that the value's attribute type is consistent - // with the result type. if (value.isa()) { if (valueType != opType) - return emitOpError("result type (") + return op.emitOpError("result type (") << opType << ") does not match value type (" << valueType << ")"; return success(); } @@ -1757,7 +1753,9 @@ auto arrayType = opType.dyn_cast(); auto shapedType = valueType.dyn_cast(); if (!arrayType) - return emitOpError("must have spv.array result type for array value"); + return op.emitOpError("result or element type (") + << opType << ") does not match value type (" << valueType + << "), must be the same or spv.array"; int numElements = arrayType.getNumElements(); auto opElemType = arrayType.getElementType(); @@ -1766,37 +1764,42 @@ opElemType = t.getElementType(); } if (!opElemType.isIntOrFloat()) - return emitOpError("only support nested array result type"); + return op.emitOpError("only support nested array result type"); auto valueElemType = shapedType.getElementType(); if (valueElemType != opElemType) { - return emitOpError("result element type (") + return op.emitOpError("result element type (") << opElemType << ") does not match value element type (" << valueElemType << ")"; } if (numElements != shapedType.getNumElements()) { - return emitOpError("result number of elements (") + return op.emitOpError("result number of elements (") << numElements << ") does not match value number of elements (" << shapedType.getNumElements() << ")"; } return success(); } - if (auto attayAttr = value.dyn_cast()) { + if (auto arrayAttr = value.dyn_cast()) { auto arrayType = opType.dyn_cast(); if (!arrayType) - return emitOpError("must have spv.array result type for array value"); + return op.emitOpError("must have spv.array result type for array value"); Type elemType = arrayType.getElementType(); - for (Attribute element : attayAttr.getValue()) { - if (element.getType() != elemType) - return emitOpError("has array element whose type (") - << element.getType() - << ") does not match the result element type (" << elemType - << ')'; + for (Attribute element : arrayAttr.getValue()) { + // Verify array elements recursively. + if (failed(verifyConstantType(op, element, elemType))) + return failure(); } return success(); } - return emitOpError("cannot have value of type ") << valueType; + return op.emitOpError("cannot have value of type ") << valueType; +} + +LogicalResult spirv::ConstantOp::verify() { + // ODS already generates checks to make sure the result type is valid. We just + // need to additionally check that the value's attribute type is consistent + // with the result type. + return verifyConstantType(*this, valueAttr(), getType()); } bool spirv::ConstantOp::isBuildableWith(Type type) { diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp --- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/Verifier.h" #include "mlir/Parser.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Target/SPIRV/Deserialization.h" @@ -151,6 +152,8 @@ FileLineColLoc::get(&deserializationContext, /*filename=*/"", /*line=*/0, /*column=*/0))); dstModule->getBody()->push_front(spirvModule.release()); + if (failed(verify(*dstModule))) + return failure(); dstModule->print(output); return mlir::success(); diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -72,6 +72,7 @@ %6 = spv.Constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>> %7 = spv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32>> %8 = spv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32>> + %9 = spv.Constant [[dense<3.0> : vector<2xf32>]] : !spv.array<1 x !spv.array<1xvector<2xf32>>> return } @@ -86,7 +87,7 @@ // ----- func @array_constant() -> () { - // expected-error @+1 {{has array element whose type ('vector<2xi32>') does not match the result element type ('vector<2xf32>')}} + // expected-error @+1 {{result or element type ('vector<2xf32>') does not match value type ('vector<2xi32>')}} %0 = spv.Constant [dense<3.0> : vector<2xf32>, dense<4> : vector<2xi32>] : !spv.array<2xvector<2xf32>> return } @@ -110,7 +111,7 @@ // ----- func @value_result_type_mismatch() -> () { - // expected-error @+1 {{must have spv.array result type for array value}} + // expected-error @+1 {{result or element type ('vector<4xi32>') does not match value type ('tensor<4xi32>')}} %0 = "spv.Constant"() {value = dense<0> : tensor<4xi32>} : () -> (vector<4xi32>) }