# Code generator to generate code for everything contained in COM type
# libraries.
import contextlib
import io
import keyword
import logging
import os
import textwrap
import warnings
from collections.abc import Iterator, Sequence
from typing import Any, Literal, Optional
from typing import Union as _UnionT

import comtypes
from comtypes import typeinfo
from comtypes.tools import tlbparser, typedesc
from comtypes.tools.codegenerator import heads, namespaces, packing
from comtypes.tools.codegenerator.comments import ComInterfaceBodyImplCommentWriter
from comtypes.tools.codegenerator.helpers import (
    ASSUME_STRINGS,
    ComMethodGenerator,
    DispMethodGenerator,
    DispPropertyGenerator,
    TypeNamer,
    get_real_type,
)
from comtypes.tools.codegenerator.modulenamer import name_wrapper_module

version = comtypes.__version__

logger = logging.getLogger(__name__)

__warn_on_munge__ = __debug__


_InterfaceTypeDesc = _UnionT[
    typedesc.ComInterface,
    typedesc.ComInterfaceHead,
    typedesc.ComInterfaceBody,
    typedesc.DispInterface,
    typedesc.DispInterfaceHead,
    typedesc.DispInterfaceBody,
]


class CodeGenerator:
    def __init__(self, known_symbols=None, known_interfaces=None) -> None:
        self.stream = io.StringIO()
        self.imports = namespaces.ImportedNamespaces()
        self.declarations = namespaces.DeclaredNamespaces()
        self.enums = namespaces.EnumerationNamespaces()
        self.unnamed_enum_members: list[tuple[str, int]] = []
        self._to_type_name = TypeNamer()
        self.known_symbols = known_symbols or {}
        self.known_interfaces = known_interfaces or {}

        self.done = set()  # type descriptions that have been generated
        self.names = set()  # names that have been generated
        self.externals = []  # typelibs imported to generated module
        self.enum_aliases: dict[str, str] = {}
        self.last_item = "attribute"

    @contextlib.contextmanager
    def adjust_blank(
        self, item: Literal["assert", "attribute", "class", "comment", "variable"]
    ) -> Iterator[io.StringIO]:
        if self.last_item == "class":
            print(file=self.stream)
            print(file=self.stream)
        elif self.last_item == "attribute":
            print(file=self.stream)
            if item == "class":
                print(file=self.stream)
        elif self.last_item in ("variable", "assert", "comment"):
            if item == "class":
                print(file=self.stream)
                print(file=self.stream)
            elif item == "attribute":
                print(file=self.stream)
        else:
            raise TypeError
        yield self.stream
        self.last_item = item

    def generate(self, item):
        if item in self.done:
            return
        if self._is_interface_typedesc(item):
            self._define_interface(item)
            return
        if isinstance(item, typedesc.StructureHead):
            name = getattr(item.struct, "name", None)
        else:
            name = getattr(item, "name", None)
        if name in self.known_symbols:
            self.imports.add(name, symbols=self.known_symbols)
            self.done.add(item)
            if isinstance(item, typedesc.Structure):
                self.done.add(item.get_head())
                self.done.add(item.get_body())
            return
        mth = getattr(self, type(item).__name__)
        # to avoid infinite recursion, we have to mark it as done
        # before actually generating the code.
        self.done.add(item)
        mth(item)

    def generate_all(self, items):
        for item in items:
            self.generate(item)

    def _make_relative_path(self, path1, path2):
        """path1 and path2 are pathnames.
        Return path1 as a relative path to path2, if possible.
        """
        path1 = os.path.abspath(path1)
        path2 = os.path.abspath(path2)
        common = os.path.commonprefix(
            [os.path.normcase(path1), os.path.normcase(path2)]
        )
        if not os.path.isdir(common):
            return path1
        if not common.endswith("\\"):
            return path1
        if not os.path.isdir(path2):
            path2 = os.path.dirname(path2)
        # strip the common prefix
        path1 = path1[len(common) :]
        path2 = path2[len(common) :]

        parts2 = path2.split("\\")
        return "..\\" * len(parts2) + path1

    def _generate_typelib_path(self, filename):
        # NOTE: the logic in this function appears completely different from that
        # of the handling of tlib (given as a string) in GetModule. There, relative
        # references are resolved wrt to the directory of the calling module. Here,
        # resolution is with respect to current working directory -- later to be
        # relativized to comtypes.gen.
        if filename is not None:
            if os.path.isabs(filename):
                # absolute path
                self.declarations.add("typelib_path", repr(filename))
            elif not os.path.dirname(filename) and not os.path.isfile(filename):
                # no directory given, and not in current directory.
                self.declarations.add("typelib_path", repr(filename))
            else:
                # relative path; make relative to comtypes.gen.
                path = self._make_relative_path(filename, comtypes.gen.__path__[0])
                self.imports.add("os")
                definition = (
                    "os.path.normpath(\n"
                    "    os.path.abspath(os.path.join(os.path.dirname(__file__),\n"
                    f"                                 {path!r})))"
                )
                self.declarations.add("typelib_path", definition)
                p = os.path.normpath(
                    os.path.abspath(os.path.join(comtypes.gen.__path__[0], path))
                )
                assert os.path.isfile(p)
            self.names.add("typelib_path")

    def generate_wrapper_code(
        self, tdescs: Sequence[Any], filename: Optional[str]
    ) -> str:
        """Returns the code for the COM type library wrapper module.

        The returned `Python` code string is containing definitions of interfaces,
        coclasses, constants, and structures.

        The module will have long name that is derived from the type library guid, lcid
        and version numbers.
        Such as `comtypes.gen._xxxxxxxx_xxxx_xxxx_xxxx_xxxxxxxxxxxx_l_M_m`.
        """
        tlib_mtime = None

        if filename is not None:
            # get full path to DLL first (os.stat can't work with relative DLL paths properly)
            loaded_typelib = typeinfo.LoadTypeLib(filename)
            full_filename = tlbparser.get_tlib_filename(loaded_typelib)

            while full_filename and not os.path.exists(full_filename):
                full_filename = os.path.split(full_filename)[0]

            if full_filename and os.path.isfile(full_filename):
                # get DLL timestamp at the moment of wrapper generation

                tlib_mtime = os.stat(full_filename).st_mtime

                if not full_filename.endswith(filename):
                    filename = full_filename

        self.filename = filename
        self.declarations.add("_lcid", "0", "change this if required")
        self._generate_typelib_path(filename)

        items = set(tdescs)
        loops = 0
        while items:
            loops += 1
            self.more = set()
            self.generate_all(items)

            items |= self.more
            items -= self.done

        self.imports.add("ctypes", "*")  # HACK: wildcard import is so ugly.
        if tlib_mtime is not None:
            logger.debug('filename: "%s": tlib_mtime: %s', filename, tlib_mtime)
            self.imports.add("comtypes", "_check_version")
        output = io.StringIO()
        if filename is not None:
            # Hm, what is the CORRECT encoding?
            print("# -*- coding: mbcs -*-", file=output)
            print(file=output)
        print(self.imports.getvalue(), file=output)
        print("from typing import TYPE_CHECKING", file=output)
        print(file=output)
        print("if TYPE_CHECKING:", file=output)
        print("    from comtypes import hints", file=output)
        print(file=output)
        print(file=output)
        print(self.declarations.getvalue(), file=output)
        print(file=output)
        if self.unnamed_enum_members:
            print("# values for unnamed enumeration", file=output)
            for n, v in self.unnamed_enum_members:
                print(f"{n} = {v}", file=output)
            print(file=output)
        if self.enums:
            print(self.enums.to_constants(), file=output)
            print(file=output)
        if self.enum_aliases:
            print("# aliases for enums", file=output)
            for k, v in self.enum_aliases.items():
                print(f"{k} = {v}", file=output)
            print(file=output)
        print(self.stream.getvalue(), file=output)
        print(self._make_dunder_all_part(), file=output)
        print(file=output)
        if tlib_mtime is not None:
            print(f"_check_version({version!r}, {tlib_mtime:f})", file=output)
        return output.getvalue()

    def generate_friendly_code(self, modname: str) -> str:
        """Returns the code for the COM type library friendly module.

        The returned `Python` code string is containing `from {modname} import
        DefinedInWrapper, ...` and `__all__ = ['DefinedInWrapper', ...]`
        The `modname` is the wrapper module name like `comtypes.gen._xxxx..._x_x_x`.

        The module will have shorter name that is derived from the type library name.
        Such as "comtypes.gen.stdole" and "comtypes.gen.Excel".
        """
        output = io.StringIO()
        print("from enum import IntFlag", file=output)
        print(file=output)
        print(f"import {modname} as __wrapper_module__", file=output)
        print(self._make_friendly_module_import_part(modname), file=output)
        print(file=output)
        print(file=output)
        if self.enums:
            print(self.enums.to_intflags(), file=output)
            print(file=output)
            print(file=output)
        if self.enum_aliases:
            for k, v in self.enum_aliases.items():
                print(f"{k} = {v}", file=output)
            print(file=output)
            print(file=output)
        print(self._make_dunder_all_part(), file=output)
        return output.getvalue()

    def _make_dunder_all_part(self) -> str:
        joined_names = ", ".join(repr(str(n)) for n in self.names)
        dunder_all = f"__all__ = [{joined_names}]"
        if len(dunder_all) > 80:
            txtwrapper = textwrap.TextWrapper(
                subsequent_indent="    ", initial_indent="    ", break_long_words=False
            )
            joined_names = "\n".join(txtwrapper.wrap(joined_names))
            dunder_all = f"__all__ = [\n{joined_names}\n]"
        return dunder_all

    def _make_friendly_module_import_part(self, modname: str) -> str:
        # The `modname` is the wrapper module name like `comtypes.gen._xxxx..._x_x_x`
        txtwrapper = textwrap.TextWrapper(
            subsequent_indent="    ", initial_indent="    ", break_long_words=False
        )
        symbols = set(self.names)
        symbols.update(self.imports.get_symbols())
        symbols.update(self.declarations.get_symbols())
        symbols -= set(self.enums.get_symbols())
        symbols -= set(self.enum_aliases)
        joined_names = ", ".join(str(n) for n in symbols)
        part = f"from {modname} import {joined_names}"
        if len(part) > 80:
            txtwrapper = textwrap.TextWrapper(
                subsequent_indent="    ", initial_indent="    ", break_long_words=False
            )
            joined_names = "\n".join(txtwrapper.wrap(joined_names))
            part = f"from {modname} import (\n{joined_names}\n)"
        return part

    def need_VARIANT_imports(self, value):
        text = repr(value)
        if "Decimal(" in text:
            self.imports.add("decimal", "Decimal")
        if "datetime.datetime(" in text:
            self.imports.add("datetime")

    def ArrayType(self, tp: typedesc.ArrayType) -> None:
        self.generate(get_real_type(tp.typ))
        self.generate(tp.typ)

    def EnumValue(self, tp: typedesc.EnumValue) -> None:
        value = int(tp.value)
        if keyword.iskeyword(tp.name):
            # XXX use logging!
            if __warn_on_munge__:
                print(f"# Fixing keyword as EnumValue for {tp.name}")
        tp_name = self._to_type_name(tp)
        if tp.enumeration.name:
            self.enums.add(tp.enumeration.name, tp_name, value)
        else:
            self.unnamed_enum_members.append((tp_name, value))
        self.names.add(tp_name)

    def Enumeration(self, tp: typedesc.Enumeration) -> None:
        for item in tp.values:
            self.generate(item)
        if tp.name:
            self.names.add(tp.name)

    def Typedef(self, tp: typedesc.Typedef) -> None:
        if isinstance(tp.typ, (typedesc.Structure, typedesc.Union)):
            self.generate(tp.typ.get_head())
            self.more.add(tp.typ)
        else:
            self.generate(tp.typ)
        definition = self._to_type_name(tp.typ)
        if tp.name != definition:
            if definition in self.known_symbols:
                self.declarations.add(tp.name, definition)
            else:
                if isinstance(tp.typ, typedesc.Enumeration):
                    self.enum_aliases[tp.name] = definition
                else:
                    with self.adjust_blank("variable") as ofi:
                        print(f"{tp.name} = {definition}", file=ofi)
        self.names.add(tp.name)

    def FundamentalType(self, item: typedesc.FundamentalType) -> None:
        pass  # we should check if this is known somewhere

    def StructureHead(self, head: typedesc.StructureHead) -> None:
        for struct in head.struct.bases:
            self.generate(struct.get_head())
            self.more.add(struct)
        basenames = [self._to_type_name(b) for b in head.struct.bases]
        if basenames:
            self.imports.add("comtypes", "GUID")

        with self.adjust_blank("class") as ofi:
            heads.StructureHeadWriter(ofi).write(head, basenames)
        self.names.add(head.struct.name)

    def Structure(self, struct: typedesc.Structure) -> None:
        self.generate(struct.get_head())
        self.generate(struct.get_body())

    def Union(self, union: typedesc.Union) -> None:
        self.generate(union.get_head())
        self.generate(union.get_body())

    def StructureBody(self, body: typedesc.StructureBody) -> None:
        fields: list[typedesc.Field] = []
        methods: list[typedesc.Method] = []
        for m in body.struct.members:
            if type(m) is typedesc.Field:
                fields.append(m)
                if type(m.typ) is typedesc.Typedef:
                    self.generate(get_real_type(m.typ))
                self.generate(m.typ)
            elif type(m) is typedesc.Method:
                methods.append(m)
                self.generate(m.returns)
                self.generate_all(m.iterArgTypes())
            elif type(m) is typedesc.Constructor:
                pass

        # we don't need _pack_ on Unions (I hope, at least), and not
        # on COM interfaces:
        if not methods:
            try:
                pack = packing.calc_packing(body.struct, fields)
                if pack is not None:
                    with self.adjust_blank("attribute") as ofi:
                        print(f"{body.struct.name}._pack_ = {pack}", file=ofi)
            except packing.PackingError as details:
                # if packing fails, write a warning comment to the output.
                message = f"Structure {body.struct.name}: {details}"
                warnings.warn(message, UserWarning)
                with self.adjust_blank("comment") as ofi:
                    print(f"# WARNING: {details}", file=ofi)
        if fields:
            if body.struct.bases:
                assert len(body.struct.bases) == 1
                self.generate(body.struct.bases[0].get_body())
            with self.adjust_blank("attribute") as ofi:
                self._write_structbody_fields(body, fields, ofi)
            if body.struct.size is None:
                with self.adjust_blank("comment") as ofi:
                    self._write_structbody_size_comments(body, ofi)
            elif body.struct.name not in packing.dont_assert_size:
                with self.adjust_blank("assert") as ofi:
                    self._write_structbody_size_assertion(body, ofi)
        if methods:
            self.imports.add("comtypes", "COMMETHOD")
            with self.adjust_blank("attribute") as ofi:
                self._write_structbody_commethods(body, methods, ofi)

    def _write_structbody_fields(
        self,
        body: typedesc.StructureBody,
        fields: list[typedesc.Field],
        ofi: io.StringIO,
    ) -> None:
        print(f"{body.struct.name}._fields_ = [", file=ofi)
        if body.struct.location:
            print(f"    # {body.struct.location}", file=ofi)
        # unnamed fields will get autogenerated names "_", "_1". "_2", "_3", ...
        unnamed_index = 0
        for f in fields:
            if not f.name:
                fieldname = "_" if not unnamed_index else f"_{unnamed_index:d}"
                unnamed_index += 1
                print(f"    # Unnamed field renamed to '{fieldname}'", file=ofi)
            else:
                fieldname = f.name
            typename = self._to_type_name(f.typ)
            if f.bits is None:
                print(f"    ('{fieldname}', {typename}),", file=ofi)
            else:
                print(f"    ('{fieldname}', {typename}, {f.bits}),", file=ofi)
        print("]", file=ofi)

    def _write_structbody_size_comments(
        self, body: typedesc.StructureBody, ofi: io.StringIO
    ) -> None:
        msg1 = "# The size provided by the typelib is incorrect."
        msg2 = f"# The size and alignment check for {body.struct.name} is skipped."
        print(msg1, file=ofi)
        print(msg2, file=ofi)

    def _write_structbody_size_assertion(
        self, body: typedesc.StructureBody, ofi: io.StringIO
    ) -> None:
        name = body.struct.name
        assert body.struct.size is not None
        size = body.struct.size // 8
        print(f"assert sizeof({name}) == {size}, sizeof({name})", file=ofi)
        align = body.struct.align // 8
        print(f"assert alignment({name}) == {align}, alignment({name})", file=ofi)

    def _write_structbody_commethods(
        self,
        body: typedesc.StructureBody,
        methods: list[typedesc.Method],
        ofi: io.StringIO,
    ) -> None:
        print(f"{body.struct.name}._methods_ = [", file=ofi)
        if body.struct.location:
            print(f"# {body.struct.location}", file=ofi)
        for m in methods:
            if m.location:
                print(f"    # {m.location}", file=ofi)
            print(
                (
                    "    COMMETHOD(\n"
                    "        [],\n"
                    f"        {self._to_type_name(m.returns)},\n"
                    f"        '{m.name}',\n"
                ),
                file=ofi,
            )
            for a in m.iterArgTypes():
                print(f"        ([], {self._to_type_name(a)}),\n", file=ofi)
                print("    ),", file=ofi)
        print("]", file=ofi)

    ################################################################
    # top-level typedesc generators
    #
    def TypeLib(self, lib: typedesc.TypeLib) -> None:
        # Hm, in user code we have to write:
        # class MyServer(COMObject, ...):
        #     _com_interfaces_ = [MyTypeLib.IInterface]
        #     _reg_typelib_ = MyTypeLib.Library._reg_typelib_
        #                               ^^^^^^^
        # Should the '_reg_typelib_' attribute be at top-level in the
        # generated code, instead as being an attribute of the
        # 'Library' symbol?

        with self.adjust_blank("class") as ofi:
            heads.LibraryHeadWriter(ofi).write(lib)
        self.names.add("Library")

    def External(self, ext: typedesc.External) -> None:
        modname = name_wrapper_module(ext.tlib)
        if modname not in self.imports:
            self.externals.append(ext.tlib)
            self.imports.add(modname)

    def Constant(self, tp: typedesc.Constant) -> None:
        with self.adjust_blank("variable") as ofi:
            type_name = self._to_type_name(tp.typ)
            print(f"{tp.name} = {tp.value!r}  # Constant {type_name}", file=ofi)
            self.names.add(tp.name)

    def SAFEARRAYType(self, sa: typedesc.SAFEARRAYType) -> None:
        self.generate(sa.typ)
        self.imports.add("comtypes.automation", "_midlSAFEARRAY")

    def PointerType(self, tp: typedesc.PointerType) -> None:
        if type(tp.typ) is typedesc.ComInterface:
            # this defines the class
            self.generate(tp.typ.get_head())
            # this defines the _methods_
            self.more.add(tp.typ)
        elif type(tp.typ) is typedesc.PointerType:
            self.generate(tp.typ)
        elif type(tp.typ) in (typedesc.Union, typedesc.Structure):
            self.generate(tp.typ.get_head())
            self.more.add(tp.typ)
        elif type(tp.typ) is typedesc.Typedef:
            self.generate(tp.typ)
        else:
            self.generate(tp.typ)
        if not ASSUME_STRINGS:
            return
        real_type = get_real_type(tp.typ)
        if isinstance(real_type, typedesc.FundamentalType):
            if real_type.name == "char":
                self.declarations.add("STRING", "c_char_p")
            elif real_type.name == "wchar_t":
                self.declarations.add("WSTRING", "c_wchar_p")

    def CoClass(self, coclass: typedesc.CoClass) -> None:
        self.imports.add("comtypes", "GUID")
        self.imports.add("comtypes", "CoClass")
        with self.adjust_blank("class") as ofi:
            heads.CoClassHeadWriter(ofi, self.filename).write(coclass)

        for itf, _ in coclass.interfaces:
            self.generate(itf.get_head())
        impl, src = typedesc.groupby_impltypeflags(coclass.interfaces)
        implemented = [self._to_type_name(itf) for itf in impl]
        sources = [self._to_type_name(itf) for itf in src]

        with self.adjust_blank("attribute") as ofi:
            if implemented:
                itfs = ", ".join(implemented)
                print(f"{coclass.name}._com_interfaces_ = [{itfs}]", file=ofi)
            if sources:
                itfs = ", ".join(sources)
                print(f"{coclass.name}._outgoing_interfaces_ = [{itfs}]", file=ofi)

        self.names.add(coclass.name)

    def _is_interface_typedesc(
        self, item: Any
    ) -> "comtypes.hints.TypeGuard[_InterfaceTypeDesc]":
        return isinstance(
            item,
            (
                typedesc.ComInterface,
                typedesc.ComInterfaceHead,
                typedesc.ComInterfaceBody,
                typedesc.DispInterface,
                typedesc.DispInterfaceHead,
                typedesc.DispInterfaceBody,
            ),
        )

    def _define_interface(self, item: _InterfaceTypeDesc) -> None:
        if isinstance(
            item,
            (
                typedesc.ComInterfaceHead,
                typedesc.ComInterfaceBody,
                typedesc.DispInterfaceHead,
                typedesc.DispInterfaceBody,
            ),
        ):
            if self._is_known_interface(item.itf):
                self.imports.add(item.itf.name, symbols=self.known_symbols)
                self.done.add(item)
                return
        elif isinstance(item, (typedesc.ComInterface, typedesc.DispInterface)):
            if self._is_known_interface(item):
                self.imports.add(item.name, symbols=self.known_symbols)
                self.done.add(item)
                self.done.add(item.get_head())
                self.done.add(item.get_body())
                return
        else:
            raise TypeError
        self.done.add(item)  # to avoid infinite recursion.
        mth = getattr(self, type(item).__name__)
        mth(item)

    def ComInterface(self, itf: typedesc.ComInterface) -> None:
        self.generate(itf.get_head())
        self.generate(itf.get_body())
        self.names.add(itf.name)

    def _is_known_interface(
        self, item: _UnionT[typedesc.ComInterface, typedesc.DispInterface]
    ) -> bool:
        """Returns whether an interface is statically defined in `comtypes`,
        based on its name and iid.
        """
        if item.name in self.known_interfaces:
            return self.known_interfaces[item.name] == item.iid
        return False

    def ComInterfaceHead(self, head: typedesc.ComInterfaceHead) -> None:
        if head.itf.base is None:
            # we don't beed to generate IUnknown
            return
        self.generate(head.itf.base.get_head())
        self.more.add(head.itf.base)
        basename = self._to_type_name(head.itf.base)

        self.imports.add("comtypes", "GUID")

        with self.adjust_blank("class") as ofi:
            heads.ComInterfaceHeadWriter(ofi).write(head, basename)

    def ComInterfaceBody(self, body: typedesc.ComInterfaceBody) -> None:
        # The base class must be fully generated, including the
        # _methods_ list.
        self.generate(body.itf.base)

        # make sure we can generate the body
        for m in body.itf.members:
            for a in m.arguments:
                self.generate(a[0])
            self.generate(m.returns)

        with self.adjust_blank("attribute") as ofi:
            print(f"{body.itf.name}._methods_ = [", file=ofi)
            for m in body.itf.members:
                if isinstance(m, typedesc.ComMethod):
                    isdual = "dual" in body.itf.idlflags
                    print(ComMethodGenerator(m, isdual).generate(), file=ofi)
                    self.add_ComMth_requirements(m, isdual)
                else:
                    raise TypeError("what's this?")

            print("]", file=ofi)
        print(file=self.stream)
        ComInterfaceBodyImplCommentWriter(self.stream).write(body)

    def DispInterface(self, itf: typedesc.DispInterface) -> None:
        self.generate(itf.get_head())
        self.generate(itf.get_body())
        self.names.add(itf.name)

    def DispInterfaceHead(self, head: typedesc.DispInterfaceHead) -> None:
        self.generate(head.itf.base)
        basename = self._to_type_name(head.itf.base)

        self.imports.add("comtypes", "GUID")

        with self.adjust_blank("class") as ofi:
            heads.DispInterfaceHeadWriter(ofi).write(head, basename)

    def DispInterfaceBody(self, body: typedesc.DispInterfaceBody) -> None:
        # make sure we can generate the body
        for m in body.itf.members:
            if isinstance(m, typedesc.DispMethod):
                for a in m.arguments:
                    self.generate(a[0])
                self.generate(m.returns)
            elif isinstance(m, typedesc.DispProperty):
                self.generate(m.typ)
            else:
                raise TypeError(m)
        with self.adjust_blank("attribute") as ofi:
            print(f"{body.itf.name}._disp_methods_ = [", file=ofi)
            for m in body.itf.members:
                if isinstance(m, typedesc.DispMethod):
                    print(DispMethodGenerator(m).generate(), file=ofi)
                    self.add_DispMth_requirements(m)
                elif isinstance(m, typedesc.DispProperty):
                    print(DispPropertyGenerator(m).generate(), file=ofi)
                    self.add_DispProp_requirements(m)
                else:
                    raise TypeError(m)
            print("]", file=ofi)

    ################################################################
    # non-toplevel method requirements
    #
    def add_ComMth_requirements(self, m: typedesc.ComMethod, isdual: bool) -> None:
        self.imports.add("comtypes", "COMMETHOD")
        if isdual:
            self.imports.add("comtypes", "dispid")
        if __debug__ and m.doc:
            self.imports.add("comtypes", "helpstring")
        for typ, _, _, default in m.arguments:
            if isinstance(typ, typedesc.ComInterface):
                self.declarations.add(
                    "OPENARRAY",
                    "POINTER(c_ubyte)",
                    "hack, see comtypes/tools/codegenerator.py",
                )
            if default is not None:
                self.need_VARIANT_imports(default)

    def add_DispMth_requirements(self, m: typedesc.DispMethod) -> None:
        self.imports.add("comtypes", "DISPMETHOD")
        self.imports.add("comtypes", "dispid")
        if __debug__ and m.doc:
            self.imports.add("comtypes", "helpstring")
        for _, _, _, default in m.arguments:
            if default is not None:
                self.need_VARIANT_imports(default)

    def add_DispProp_requirements(self, prop: typedesc.DispProperty) -> None:
        self.imports.add("comtypes", "DISPPROPERTY")
        self.imports.add("comtypes", "dispid")
        if __debug__ and prop.doc:
            self.imports.add("comtypes", "helpstring")
