diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -25,6 +25,17 @@ namespace { +/// Converts a ConstantOp whose result has a composite type to the SPIR-V dialect: the result type is mapped through the type converter and the original value attribute is reused unchanged; the match fails when the result type is not the expected composite type or cannot be converted. NOTE(review): the value attribute is forwarded without reshaping; confirm this is valid for multi-dimensional constants, since the visible tensor type conversion produces a single array sized by the total element count. +// TODO(denis0x0D) : move to DRR. +class ConstantCompositeOpConversion final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(ConstantOp constCompositeOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + /// Convert constant operation with IndexType return to SPIR-V constant /// operation. Since IndexType is not used within SPIR-V dialect, this needs /// special handling to make sure the result type and the type of the value @@ -172,6 +183,27 @@ return builder.create(loc, basePtr, linearizedIndices); } +//===----------------------------------------------------------------------===// +// ConstantOp with composite type. +//===----------------------------------------------------------------------===// + +PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite( + ConstantOp constCompositeOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + auto compositeType = + constCompositeOp.getResult().getType().dyn_cast(); + if (!compositeType) { + return matchFailure(); + } + auto spirvCompositeType = typeConverter.convertType(compositeType); + if (!spirvCompositeType) { + return matchFailure(); + } + rewriter.replaceOpWithNewOp( + constCompositeOp, spirvCompositeType, constCompositeOp.value()); + return matchSuccess(); +} + //===----------------------------------------------------------------------===// // ConstantOp with index type. 
//===----------------------------------------------------------------------===// @@ -354,7 +386,8 @@ OwningRewritePatternList &patterns) { // Add patterns that lower operations into SPIR-V dialect. populateWithGenerated(context, &patterns); - patterns.insert, IntegerOpConversion, IntegerOpConversion, diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -72,6 +72,19 @@ memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); } return (offset + memrefSize) * elementSize.getValue(); + } else if (auto tensorType = t.dyn_cast()) { + if (!tensorType.hasStaticShape()) { + return llvm::None; + } + auto elementSize = getTypeNumBytes(tensorType.getElementType()); + if (!elementSize) { + return llvm::None; + } + int64_t size = elementSize.getValue(); + for (auto shape : tensorType.getShape()) { + size *= shape; + } + return size; } // TODO: Add size computation for other types. return llvm::None; @@ -123,6 +136,27 @@ } } + if (auto tensorType = type.dyn_cast()) { + // TODO(ravishankarm) : Handle dynamic shapes; for now a tensor without a static shape is not converted and Type() is returned. 
+ if (!tensorType.hasStaticShape()) { + return Type(); + } + auto elementType = convertStdType(tensorType.getElementType()); + if (!elementType) { + return Type(); + } + auto elementSize = getTypeNumBytes(elementType); + if (!elementSize) { + return Type(); + } + auto tensorSize = getTypeNumBytes(tensorType); + if (!tensorSize) { + return Type(); + } + return spirv::ArrayType::get(elementType, + tensorSize.getValue() / elementSize.getValue(), + elementSize.getValue()); + } return Type(); } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -220,6 +220,14 @@ %3 = constant dense<[2, 3]> : vector<2xi32> // CHECK: spv.constant 1 : i32 %4 = constant 1 : index + // CHECK: spv.constant dense<1> : tensor<6xi32> : !spv.array<6 x i32 [4]> + %5 = constant dense<1> : tensor<6xi32> + // CHECK: spv.constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32 [4]> + %6 = constant dense<1.0> : tensor<6xf32> + // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]> + %7 = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> + // CHECK: spv.constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32 [4]> + %8 = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]> : tensor<6xf32> return }