diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -25,6 +25,17 @@ namespace { +/// Lower a composite (tensor) ConstantOp to spv.constant of converted type. +// TODO(denis0x0D) : move to DRR. +class ConstantCompositeOpConversion final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(ConstantOp constCompositeOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + /// Convert constant operation with IndexType return to SPIR-V constant /// operation. Since IndexType is not used within SPIR-V dialect, this needs /// special handling to make sure the result type and the type of the value @@ -172,6 +183,27 @@ return builder.create(loc, basePtr, linearizedIndices); } +//===----------------------------------------------------------------------===// +// ConstantOp with composite (tensor) type. +//===----------------------------------------------------------------------===// + +PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite( + ConstantOp constCompositeOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + auto compositeType = + constCompositeOp.getResult().getType().dyn_cast(); + if (!compositeType) { + return matchFailure(); + } + auto spirvCompositeType = typeConverter.convertType(compositeType); + if (!spirvCompositeType) { + return matchFailure(); + } + rewriter.replaceOpWithNewOp( + constCompositeOp, spirvCompositeType, constCompositeOp.value()); + return matchSuccess(); +} + //===----------------------------------------------------------------------===// // ConstantOp with index type. 
//===----------------------------------------------------------------------===// @@ -354,7 +386,8 @@ OwningRewritePatternList &patterns) { // Add patterns that lower operations into SPIR-V dialect. populateWithGenerated(context, &patterns); - patterns.insert, IntegerOpConversion, IntegerOpConversion, diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -51,6 +51,9 @@ { for (auto argType : enumerate(funcOp.getType().getInputs())) { auto convertedType = typeConverter.convertType(argType.value()); + if (!convertedType) { + return matchFailure(); + } signatureConverter.addInputs(argType.index(), convertedType); } } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -71,6 +71,19 @@ memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); } return (offset + memrefSize) * elementSize.getValue(); + } else if (auto tensorType = t.dyn_cast()) { + if (!tensorType.hasStaticShape()) { + return llvm::None; + } + auto elementSize = getTypeNumBytes(tensorType.getElementType()); + if (!elementSize) { + return llvm::None; + } + int64_t size = elementSize.getValue(); + for (auto shape : tensorType.getShape()) { + size *= shape; + } + return size; } // TODO: Add size computation for other types. return llvm::None; @@ -122,6 +135,31 @@ } } + if (auto tensorType = type.dyn_cast()) { + // TODO(ravishankarm) : Handle dynamic shapes. 
+ if (!tensorType.hasStaticShape()) { + return Type(); + } + auto elementType = convertStdType(tensorType.getElementType()); + if (!elementType) { + return Type(); + } + auto elementSize = getTypeNumBytes(elementType); + // A zero element byte size would make the element-count division below + // divide by zero, so treat it as unconvertible. + if (!elementSize || elementSize.getValue() == 0) { + return Type(); + } + auto tensorSize = getTypeNumBytes(tensorType); + // A zero-sized tensor (some dimension is 0) cannot be represented: SPIR-V + // array types require a length of at least 1. + if (!tensorSize || tensorSize.getValue() == 0) { + return Type(); + } + return spirv::ArrayType::get(elementType, + tensorSize.getValue() / elementSize.getValue(), + elementSize.getValue()); + } return Type(); } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -220,6 +220,14 @@ %3 = constant dense<[2, 3]> : vector<2xi32> // CHECK: spv.constant 1 : i32 %4 = constant 1 : index + // CHECK: spv.constant dense<1> : tensor<2x3xi32> : !spv.array<6 x i32 [4]> + %5 = constant dense<1> : tensor<2x3xi32> + // CHECK: spv.constant dense<1.000000e+00> : tensor<2x3xf32> : !spv.array<6 x f32 [4]> + %6 = constant dense<1.0> : tensor<2x3xf32> + // CHECK: spv.constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<6 x i32 [4]> + %7 = constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> + // CHECK: spv.constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spv.array<6 x f32 [4]> + %8 = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> return }